Diffstat (limited to 'tensorflow/contrib/lite/kernels')
-rw-r--r--  tensorflow/contrib/lite/kernels/BUILD | 510
-rw-r--r--  tensorflow/contrib/lite/kernels/activation_functor.h | 2
-rw-r--r--  tensorflow/contrib/lite/kernels/activations.cc | 388
-rw-r--r--  tensorflow/contrib/lite/kernels/activations_test.cc | 260
-rw-r--r--  tensorflow/contrib/lite/kernels/add.cc | 267
-rw-r--r--  tensorflow/contrib/lite/kernels/add_test.cc | 96
-rw-r--r--  tensorflow/contrib/lite/kernels/arg_min_max.cc (renamed from tensorflow/contrib/lite/kernels/arg_max.cc) | 82
-rw-r--r--  tensorflow/contrib/lite/kernels/arg_min_max_test.cc (renamed from tensorflow/contrib/lite/kernels/arg_max_test.cc) | 89
-rw-r--r--  tensorflow/contrib/lite/kernels/audio_spectrogram.cc | 6
-rw-r--r--  tensorflow/contrib/lite/kernels/audio_spectrogram_test.cc | 2
-rw-r--r--  tensorflow/contrib/lite/kernels/basic_rnn.cc | 77
-rw-r--r--  tensorflow/contrib/lite/kernels/basic_rnn_test.cc | 21
-rw-r--r--  tensorflow/contrib/lite/kernels/batch_to_space_nd.cc | 21
-rw-r--r--  tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc | 745
-rw-r--r--  tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm_test.cc | 612
-rw-r--r--  tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc | 514
-rw-r--r--  tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc | 96
-rw-r--r--  tensorflow/contrib/lite/kernels/cast.cc | 27
-rw-r--r--  tensorflow/contrib/lite/kernels/cast_test.cc | 67
-rw-r--r--  tensorflow/contrib/lite/kernels/comparisons.cc | 178
-rw-r--r--  tensorflow/contrib/lite/kernels/comparisons_test.cc | 539
-rw-r--r--  tensorflow/contrib/lite/kernels/concatenation.cc | 63
-rw-r--r--  tensorflow/contrib/lite/kernels/conv.cc | 407
-rw-r--r--  tensorflow/contrib/lite/kernels/conv_test.cc | 476
-rw-r--r--  tensorflow/contrib/lite/kernels/depthwise_conv.cc | 108
-rw-r--r--  tensorflow/contrib/lite/kernels/depthwise_conv_test.cc | 301
-rw-r--r--  tensorflow/contrib/lite/kernels/dequantize.cc | 50
-rw-r--r--  tensorflow/contrib/lite/kernels/detection_postprocess.cc | 591
-rw-r--r--  tensorflow/contrib/lite/kernels/detection_postprocess_test.cc | 235
-rw-r--r--  tensorflow/contrib/lite/kernels/div.cc | 72
-rw-r--r--  tensorflow/contrib/lite/kernels/div_test.cc | 61
-rw-r--r--  tensorflow/contrib/lite/kernels/eigen_support.cc | 92
-rw-r--r--  tensorflow/contrib/lite/kernels/eigen_support.h | 10
-rw-r--r--  tensorflow/contrib/lite/kernels/elementwise.cc | 127
-rw-r--r--  tensorflow/contrib/lite/kernels/elementwise_test.cc | 80
-rw-r--r--  tensorflow/contrib/lite/kernels/embedding_lookup.cc | 63
-rw-r--r--  tensorflow/contrib/lite/kernels/embedding_lookup_sparse.cc | 4
-rw-r--r--  tensorflow/contrib/lite/kernels/embedding_lookup_test.cc | 110
-rw-r--r--  tensorflow/contrib/lite/kernels/exp.cc | 4
-rw-r--r--  tensorflow/contrib/lite/kernels/expand_dims.cc | 113
-rw-r--r--  tensorflow/contrib/lite/kernels/expand_dims_test.cc | 83
-rw-r--r--  tensorflow/contrib/lite/kernels/fake_quant.cc | 95
-rw-r--r--  tensorflow/contrib/lite/kernels/fake_quant_test.cc | 112
-rw-r--r--  tensorflow/contrib/lite/kernels/floor.cc | 7
-rw-r--r--  tensorflow/contrib/lite/kernels/floor_div.cc | 146
-rw-r--r--  tensorflow/contrib/lite/kernels/floor_div_test.cc | 90
-rw-r--r--  tensorflow/contrib/lite/kernels/fully_connected.cc | 189
-rw-r--r--  tensorflow/contrib/lite/kernels/fully_connected_test.cc | 302
-rw-r--r--  tensorflow/contrib/lite/kernels/gather.cc | 30
-rw-r--r--  tensorflow/contrib/lite/kernels/gather_test.cc | 11
-rw-r--r--  tensorflow/contrib/lite/kernels/gemm_support.cc | 55
-rw-r--r--  tensorflow/contrib/lite/kernels/gemm_support.h | 5
-rw-r--r--  tensorflow/contrib/lite/kernels/hashtable_lookup.cc | 5
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/BUILD | 238
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/common.h | 145
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/compatibility.h | 55
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/depthwiseconv_float_test.cc | 157
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/depthwiseconv_quantized_test.cc | 349
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/kernel_utils.cc | 232
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/kernel_utils.h | 62
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/log_quantized_test.cc | 334
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/logsoftmax_quantized_test.cc | 251
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/optimized/cblas_conv.h | 61
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/optimized/cpu_check.h | 6
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h | 131
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h | 242
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h | 7700
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_google.h | 4
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h | 1872
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h | 123
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc | 229
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h | 38
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h | 4649
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h | 32
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/quantization_util.cc | 275
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/quantization_util.h | 80
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc | 173
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h | 89
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h | 118
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/reference/fully_connected.h | 326
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h | 2120
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc | 94
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h | 60
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h | 4865
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/reference/softmax.h | 179
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/resize_bilinear_test.cc | 138
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/softmax_quantized_test.cc | 236
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/spectrogram.cc | 10
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/strided_slice_logic.h | 100
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/tensor.h | 119
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/tensor_ctypes.h | 102
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/tensor_test.cc | 36
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/tensor_utils.h | 32
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc | 262
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/test_util.cc | 107
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/test_util.h | 103
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/types.h | 766
-rw-r--r--  tensorflow/contrib/lite/kernels/kernel_util.cc | 61
-rw-r--r--  tensorflow/contrib/lite/kernels/kernel_util.h | 43
-rw-r--r--  tensorflow/contrib/lite/kernels/l2norm.cc | 27
-rw-r--r--  tensorflow/contrib/lite/kernels/l2norm_test.cc | 30
-rw-r--r--  tensorflow/contrib/lite/kernels/layer_norm_lstm.cc | 1316
-rw-r--r--  tensorflow/contrib/lite/kernels/layer_norm_lstm_test.cc | 664
-rw-r--r--  tensorflow/contrib/lite/kernels/local_response_norm.cc | 21
-rw-r--r--  tensorflow/contrib/lite/kernels/log_softmax_test.cc | 8
-rw-r--r--  tensorflow/contrib/lite/kernels/logical.cc | 134
-rw-r--r--  tensorflow/contrib/lite/kernels/logical_test.cc | 112
-rw-r--r--  tensorflow/contrib/lite/kernels/lsh_projection.cc | 5
-rw-r--r--  tensorflow/contrib/lite/kernels/lstm.cc | 573
-rw-r--r--  tensorflow/contrib/lite/kernels/lstm_eval.cc | 912
-rw-r--r--  tensorflow/contrib/lite/kernels/lstm_eval.h | 79
-rw-r--r--  tensorflow/contrib/lite/kernels/lstm_test.cc | 1809
-rw-r--r--  tensorflow/contrib/lite/kernels/maximum_minimum.cc | 13
-rw-r--r--  tensorflow/contrib/lite/kernels/maximum_minimum_test.cc | 4
-rw-r--r--  tensorflow/contrib/lite/kernels/mean.cc | 271
-rw-r--r--  tensorflow/contrib/lite/kernels/mean_test.cc | 219
-rw-r--r--  tensorflow/contrib/lite/kernels/mfcc.cc | 6
-rw-r--r--  tensorflow/contrib/lite/kernels/mfcc_test.cc | 2
-rw-r--r--  tensorflow/contrib/lite/kernels/mul.cc | 195
-rw-r--r--  tensorflow/contrib/lite/kernels/mul_test.cc | 98
-rw-r--r--  tensorflow/contrib/lite/kernels/neg.cc | 5
-rw-r--r--  tensorflow/contrib/lite/kernels/neg_test.cc | 4
-rw-r--r--  tensorflow/contrib/lite/kernels/one_hot.cc | 199
-rw-r--r--  tensorflow/contrib/lite/kernels/one_hot_test.cc | 182
-rw-r--r--  tensorflow/contrib/lite/kernels/op_macros.h | 52
-rw-r--r--  tensorflow/contrib/lite/kernels/optional_tensor_test.cc | 32
-rw-r--r--  tensorflow/contrib/lite/kernels/pack.cc | 135
-rw-r--r--  tensorflow/contrib/lite/kernels/pack_test.cc | 154
-rw-r--r--  tensorflow/contrib/lite/kernels/pad.cc | 36
-rw-r--r--  tensorflow/contrib/lite/kernels/pad_test.cc | 13
-rw-r--r--  tensorflow/contrib/lite/kernels/padding.h | 2
-rw-r--r--  tensorflow/contrib/lite/kernels/pooling.cc | 151
-rw-r--r--  tensorflow/contrib/lite/kernels/pow.cc | 143
-rw-r--r--  tensorflow/contrib/lite/kernels/pow_test.cc | 117
-rw-r--r--  tensorflow/contrib/lite/kernels/reduce.cc | 513
-rw-r--r--  tensorflow/contrib/lite/kernels/reduce_test.cc | 975
-rw-r--r--  tensorflow/contrib/lite/kernels/register.cc | 98
-rw-r--r--  tensorflow/contrib/lite/kernels/register.h | 9
-rw-r--r--  tensorflow/contrib/lite/kernels/relu1.cc | 59
-rw-r--r--  tensorflow/contrib/lite/kernels/relu1_test.cc | 79
-rw-r--r--  tensorflow/contrib/lite/kernels/reshape.cc | 84
-rw-r--r--  tensorflow/contrib/lite/kernels/reshape_test.cc | 37
-rw-r--r--  tensorflow/contrib/lite/kernels/resize_bilinear.cc | 34
-rw-r--r--  tensorflow/contrib/lite/kernels/resize_bilinear_test.cc | 235
-rw-r--r--  tensorflow/contrib/lite/kernels/select.cc | 21
-rw-r--r--  tensorflow/contrib/lite/kernels/select_test.cc | 37
-rw-r--r--  tensorflow/contrib/lite/kernels/shape.cc | 93
-rw-r--r--  tensorflow/contrib/lite/kernels/shape_test.cc | 95
-rw-r--r--  tensorflow/contrib/lite/kernels/skip_gram.cc | 4
-rw-r--r--  tensorflow/contrib/lite/kernels/slice.cc | 40
-rw-r--r--  tensorflow/contrib/lite/kernels/softmax_test.cc | 18
-rw-r--r--  tensorflow/contrib/lite/kernels/space_to_batch_nd.cc | 43
-rw-r--r--  tensorflow/contrib/lite/kernels/space_to_batch_nd_test.cc | 142
-rw-r--r--  tensorflow/contrib/lite/kernels/space_to_depth.cc | 17
-rw-r--r--  tensorflow/contrib/lite/kernels/sparse_output_fully_connected.cc | 235
-rw-r--r--  tensorflow/contrib/lite/kernels/sparse_output_fully_connected_test.cc | 158
-rw-r--r--  tensorflow/contrib/lite/kernels/sparse_to_dense.cc | 275
-rw-r--r--  tensorflow/contrib/lite/kernels/sparse_to_dense_test.cc | 155
-rw-r--r--  tensorflow/contrib/lite/kernels/split.cc | 46
-rw-r--r--  tensorflow/contrib/lite/kernels/squeeze.cc | 4
-rw-r--r--  tensorflow/contrib/lite/kernels/strided_slice.cc | 62
-rw-r--r--  tensorflow/contrib/lite/kernels/strided_slice_test.cc | 65
-rw-r--r--  tensorflow/contrib/lite/kernels/sub.cc | 114
-rw-r--r--  tensorflow/contrib/lite/kernels/sub_test.cc | 58
-rw-r--r--  tensorflow/contrib/lite/kernels/svdf.cc | 373
-rw-r--r--  tensorflow/contrib/lite/kernels/svdf_test.cc | 211
-rw-r--r--  tensorflow/contrib/lite/kernels/test_util.cc | 17
-rw-r--r--  tensorflow/contrib/lite/kernels/test_util.h | 53
-rw-r--r--  tensorflow/contrib/lite/kernels/test_util_test.cc | 12
-rw-r--r--  tensorflow/contrib/lite/kernels/tile.cc | 195
-rw-r--r--  tensorflow/contrib/lite/kernels/tile_test.cc | 256
-rw-r--r--  tensorflow/contrib/lite/kernels/topk_v2.cc | 12
-rw-r--r--  tensorflow/contrib/lite/kernels/topk_v2_test.cc | 26
-rw-r--r--  tensorflow/contrib/lite/kernels/transpose.cc | 30
-rw-r--r--  tensorflow/contrib/lite/kernels/transpose_conv.cc | 34
-rw-r--r--  tensorflow/contrib/lite/kernels/transpose_conv_test.cc | 18
-rw-r--r--  tensorflow/contrib/lite/kernels/transpose_test.cc | 24
-rw-r--r--  tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc | 344
-rw-r--r--  tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm_test.cc | 1823
-rw-r--r--  tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc | 80
-rw-r--r--  tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn_test.cc | 20
-rw-r--r--  tensorflow/contrib/lite/kernels/unpack.cc | 133
-rw-r--r--  tensorflow/contrib/lite/kernels/unpack_test.cc | 225
-rw-r--r--  tensorflow/contrib/lite/kernels/zeros_like.cc | 73
-rw-r--r--  tensorflow/contrib/lite/kernels/zeros_like_test.cc | 78
185 files changed, 37209 insertions, 14456 deletions
diff --git a/tensorflow/contrib/lite/kernels/BUILD b/tensorflow/contrib/lite/kernels/BUILD
index b7291dd379..d2d8073abd 100644
--- a/tensorflow/contrib/lite/kernels/BUILD
+++ b/tensorflow/contrib/lite/kernels/BUILD
@@ -6,13 +6,29 @@ licenses(["notice"]) # Apache 2.0
load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts")
load("//tensorflow/contrib/lite:special_rules.bzl", "tflite_portable_test_suite")
-load("//tensorflow:tensorflow.bzl", "tf_cc_test")
+load("//tensorflow:tensorflow.bzl", "tf_cc_test", "tf_opts_nortti_if_android")
+
+# Suppress warnings that are introduced by Eigen Tensor.
+EXTRA_EIGEN_COPTS = select({
+ "//tensorflow:ios": [
+ "-Wno-error=invalid-partial-specialization",
+ "-Wno-error=reorder",
+ ],
+ "//tensorflow:windows": [
+ "/DEIGEN_HAS_C99_MATH",
+ "/DEIGEN_AVOID_STL_ARRAY",
+ ],
+ "//conditions:default": ["-Wno-error=reorder"],
+})
tf_cc_test(
name = "optional_tensor_test",
size = "small",
srcs = ["optional_tensor_test.cc"],
- tags = ["tflite_not_portable_ios"],
+ tags = [
+ "no_oss",
+ "tflite_not_portable_ios",
+ ],
deps = [
":builtin_ops",
"//tensorflow/contrib/lite:framework",
@@ -46,11 +62,12 @@ cc_library(
hdrs = [
"eigen_support.h",
],
- copts = tflite_copts(),
+ copts = tflite_copts() + EXTRA_EIGEN_COPTS,
deps = [
":op_macros",
- "//tensorflow/contrib/lite:context",
- "//third_party/eigen3",
+ "//tensorflow/contrib/lite:arena_planner",
+ "//tensorflow/contrib/lite/c:c_api_internal",
+ "//tensorflow/contrib/lite/kernels/internal:optimized",
],
)
@@ -65,7 +82,7 @@ cc_library(
copts = tflite_copts(),
deps = [
":op_macros",
- "//tensorflow/contrib/lite:context",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"@gemmlowp",
],
)
@@ -76,7 +93,7 @@ cc_library(
"activation_functor.h",
],
deps = [
- "//tensorflow/contrib/lite:builtin_op_data",
+ "//tensorflow/contrib/lite/c:c_api_internal",
],
)
@@ -96,9 +113,9 @@ cc_library(
"kernel_util.h",
],
deps = [
- "//tensorflow/contrib/lite:builtin_op_data",
- "//tensorflow/contrib/lite:context",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite/kernels/internal:round",
+ "//tensorflow/contrib/lite/kernels/internal:types",
],
)
@@ -106,7 +123,10 @@ tf_cc_test(
name = "kernel_util_test",
size = "small",
srcs = ["kernel_util_test.cc"],
- tags = ["tflite_not_portable_ios"],
+ tags = [
+ "no_oss",
+ "tflite_not_portable_ios",
+ ],
deps = [
":kernel_util",
"//tensorflow/contrib/lite/testing:util",
@@ -118,6 +138,7 @@ tf_cc_test(
name = "test_util_test",
size = "small",
srcs = ["test_util_test.cc"],
+ tags = ["no_oss"],
deps = [
":test_util",
"//tensorflow/contrib/lite/testing:util",
@@ -126,11 +147,20 @@ tf_cc_test(
)
cc_library(
- name = "builtin_ops",
+ name = "padding",
+ srcs = [],
+ hdrs = ["padding.h"],
+ deps = [
+ "//tensorflow/contrib/lite/c:c_api_internal",
+ ],
+)
+
+cc_library(
+ name = "builtin_op_kernels",
srcs = [
"activations.cc",
"add.cc",
- "arg_max.cc",
+ "arg_min_max.cc",
"audio_spectrogram.cc",
"basic_rnn.cc",
"batch_to_space_nd.cc",
@@ -142,84 +172,120 @@ cc_library(
"conv.cc",
"depthwise_conv.cc",
"dequantize.cc",
+ "detection_postprocess.cc",
"div.cc",
"elementwise.cc",
"embedding_lookup.cc",
"embedding_lookup_sparse.cc",
"exp.cc",
+ "expand_dims.cc",
+ "fake_quant.cc",
"floor.cc",
+ "floor_div.cc",
"fully_connected.cc",
"gather.cc",
"hashtable_lookup.cc",
"l2norm.cc",
+ "layer_norm_lstm.cc",
"local_response_norm.cc",
+ "logical.cc",
"lsh_projection.cc",
"lstm.cc",
"maximum_minimum.cc",
- "mean.cc",
"mfcc.cc",
"mul.cc",
"neg.cc",
+ "one_hot.cc",
+ "pack.cc",
"pad.cc",
"pooling.cc",
- "register.cc",
+ "pow.cc",
+ "reduce.cc",
+ "relu1.cc",
"reshape.cc",
"resize_bilinear.cc",
"select.cc",
+ "shape.cc",
"skip_gram.cc",
"slice.cc",
"space_to_batch_nd.cc",
"space_to_depth.cc",
+ "sparse_output_fully_connected.cc",
+ "sparse_to_dense.cc",
"split.cc",
"squeeze.cc",
"strided_slice.cc",
"sub.cc",
"svdf.cc",
+ "tile.cc",
"topk_v2.cc",
"transpose.cc",
"transpose_conv.cc",
"unidirectional_sequence_lstm.cc",
"unidirectional_sequence_rnn.cc",
+ "unpack.cc",
+ "zeros_like.cc",
],
hdrs = [
- "padding.h",
- "register.h",
],
- # Suppress warnings that are introduced by Eigen Tensor.
- copts = tflite_copts() + [
- "-Wno-error=reorder",
- ] + select({
- "//tensorflow:ios": ["-Wno-error=invalid-partial-specialization"],
- "//conditions:default": [
- ],
- }),
+ copts = tflite_copts() + tf_opts_nortti_if_android() + EXTRA_EIGEN_COPTS,
+ visibility = ["//visibility:private"],
deps = [
":activation_functor",
":eigen_support",
":kernel_util",
+ ":lstm_eval",
":op_macros",
- "//tensorflow/contrib/lite:builtin_op_data",
+ ":padding",
"//tensorflow/contrib/lite:framework",
"//tensorflow/contrib/lite:string_util",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite/kernels:gemm_support",
"//tensorflow/contrib/lite/kernels/internal:audio_utils",
"//tensorflow/contrib/lite/kernels/internal:kernel_utils",
"//tensorflow/contrib/lite/kernels/internal:optimized",
"//tensorflow/contrib/lite/kernels/internal:optimized_base",
"//tensorflow/contrib/lite/kernels/internal:quantization_util",
- "//tensorflow/contrib/lite/kernels/internal:reference",
"//tensorflow/contrib/lite/kernels/internal:reference_base",
+ "//tensorflow/contrib/lite/kernels/internal:tensor",
"//tensorflow/contrib/lite/kernels/internal:tensor_utils",
"@farmhash_archive//:farmhash",
"@flatbuffers",
],
)
+cc_library(
+ name = "lstm_eval",
+ srcs = ["lstm_eval.cc"],
+ hdrs = ["lstm_eval.h"],
+ deps = [
+ ":op_macros",
+ "//tensorflow/contrib/lite/c:c_api_internal",
+ "//tensorflow/contrib/lite/kernels/internal:kernel_utils",
+ "//tensorflow/contrib/lite/kernels/internal:tensor_utils",
+ ],
+)
+
+cc_library(
+ name = "builtin_ops",
+ srcs = ["register.cc"],
+ hdrs = ["register.h"],
+ deps = [
+ ":builtin_op_kernels",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite:util",
+ "//tensorflow/contrib/lite/c:c_api_internal",
+ ],
+)
+
tf_cc_test(
name = "audio_spectrogram_test",
size = "small",
srcs = ["audio_spectrogram_test.cc"],
- tags = ["tflite_not_portable_ios"],
+ tags = [
+ "no_oss",
+ "tflite_not_portable_ios",
+ ],
deps = [
":builtin_ops",
"//tensorflow/contrib/lite:framework",
@@ -233,7 +299,61 @@ tf_cc_test(
name = "mfcc_test",
size = "small",
srcs = ["mfcc_test.cc"],
- tags = ["tflite_not_portable_ios"],
+ tags = [
+ "no_oss",
+ "tflite_not_portable_ios",
+ ],
+ deps = [
+ ":builtin_ops",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ "@flatbuffers",
+ ],
+)
+
+tf_cc_test(
+ name = "detection_postprocess_test",
+ size = "small",
+ srcs = ["detection_postprocess_test.cc"],
+ tags = [
+ "no_oss",
+ "tflite_not_portable_ios",
+ ],
+ deps = [
+ ":builtin_ops",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ "@flatbuffers",
+ ],
+)
+
+tf_cc_test(
+ name = "relu1_test",
+ size = "small",
+ srcs = ["relu1_test.cc"],
+ tags = [
+ "no_oss",
+ "tflite_not_portable_ios",
+ ],
+ deps = [
+ ":builtin_ops",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ "@flatbuffers",
+ ],
+)
+
+tf_cc_test(
+ name = "sparse_output_fully_connected_test",
+ size = "small",
+ srcs = ["sparse_output_fully_connected_test.cc"],
+ tags = [
+ "no_oss",
+ "tflite_not_portable_ios",
+ ],
deps = [
":builtin_ops",
"//tensorflow/contrib/lite:framework",
@@ -270,10 +390,11 @@ tf_cc_test(
)
tf_cc_test(
- name = "arg_max_test",
+ name = "arg_min_max_test",
size = "small",
- srcs = ["arg_max_test.cc"],
+ srcs = ["arg_min_max_test.cc"],
tags = [
+ "no_oss",
"tflite_not_portable_ios",
],
deps = [
@@ -288,7 +409,10 @@ tf_cc_test(
name = "div_test",
size = "small",
srcs = ["div_test.cc"],
- tags = ["tflite_not_portable_ios"],
+ tags = [
+ "no_oss",
+ "tflite_not_portable_ios",
+ ],
deps = [
":builtin_ops",
"//tensorflow/contrib/lite:framework",
@@ -301,7 +425,10 @@ tf_cc_test(
name = "sub_test",
size = "small",
srcs = ["sub_test.cc"],
- tags = ["tflite_not_portable_ios"],
+ tags = [
+ "no_oss",
+ "tflite_not_portable_ios",
+ ],
deps = [
":builtin_ops",
"//tensorflow/contrib/lite:framework",
@@ -314,7 +441,10 @@ tf_cc_test(
name = "transpose_test",
size = "small",
srcs = ["transpose_test.cc"],
- tags = ["tflite_not_portable_ios"],
+ tags = [
+ "no_oss",
+ "tflite_not_portable_ios",
+ ],
deps = [
":builtin_ops",
"//tensorflow/contrib/lite:framework",
@@ -329,7 +459,10 @@ tf_cc_test(
name = "space_to_batch_nd_test",
size = "small",
srcs = ["space_to_batch_nd_test.cc"],
- tags = ["tflite_not_portable_ios"],
+ tags = [
+ "no_oss",
+ "tflite_not_portable_ios",
+ ],
deps = [
":builtin_ops",
"//tensorflow/contrib/lite:framework",
@@ -342,7 +475,10 @@ tf_cc_test(
name = "batch_to_space_nd_test",
size = "small",
srcs = ["batch_to_space_nd_test.cc"],
- tags = ["tflite_not_portable_ios"],
+ tags = [
+ "no_oss",
+ "tflite_not_portable_ios",
+ ],
deps = [
":builtin_ops",
"//tensorflow/contrib/lite:framework",
@@ -355,7 +491,10 @@ tf_cc_test(
name = "cast_test",
size = "small",
srcs = ["cast_test.cc"],
- tags = ["tflite_not_portable_ios"],
+ tags = [
+ "no_oss",
+ "tflite_not_portable_ios",
+ ],
deps = [
":builtin_ops",
"//tensorflow/contrib/lite:framework",
@@ -400,6 +539,7 @@ tf_cc_test(
":builtin_ops",
"//tensorflow/contrib/lite:framework",
"//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_absl//absl/memory",
"@com_google_googletest//:gtest",
],
)
@@ -408,7 +548,10 @@ tf_cc_test(
name = "dequantize_test",
size = "small",
srcs = ["dequantize_test.cc"],
- tags = ["tflite_not_portable_ios"],
+ tags = [
+ "no_oss",
+ "tflite_not_portable_ios",
+ ],
deps = [
":builtin_ops",
"//tensorflow/contrib/lite:framework",
@@ -435,7 +578,10 @@ tf_cc_test(
name = "bidirectional_sequence_lstm_test",
size = "small",
srcs = ["bidirectional_sequence_lstm_test.cc"],
- tags = ["tflite_not_portable_ios"],
+ tags = [
+ "no_oss",
+ "tflite_not_portable_ios",
+ ],
deps = [
":builtin_ops",
"//tensorflow/contrib/lite:framework",
@@ -448,7 +594,10 @@ tf_cc_test(
name = "floor_test",
size = "small",
srcs = ["floor_test.cc"],
- tags = ["tflite_not_portable_ios"],
+ tags = [
+ "no_oss",
+ "tflite_not_portable_ios",
+ ],
deps = [
":builtin_ops",
"//tensorflow/contrib/lite:framework",
@@ -461,7 +610,10 @@ tf_cc_test(
name = "elementwise_test",
size = "small",
srcs = ["elementwise_test.cc"],
- tags = ["tflite_not_portable_ios"],
+ tags = [
+ "no_oss",
+ "tflite_not_portable_ios",
+ ],
deps = [
":builtin_ops",
"//tensorflow/contrib/lite:framework",
@@ -474,7 +626,10 @@ tf_cc_test(
name = "unidirectional_sequence_lstm_test",
size = "small",
srcs = ["unidirectional_sequence_lstm_test.cc"],
- tags = ["tflite_not_portable_ios"],
+ tags = [
+ "no_oss",
+ "tflite_not_portable_ios",
+ ],
deps = [
":builtin_ops",
"//tensorflow/contrib/lite:framework",
@@ -488,6 +643,7 @@ tf_cc_test(
size = "small",
srcs = ["bidirectional_sequence_rnn_test.cc"],
tags = [
+ "no_oss",
"tflite_not_portable",
],
deps = [
@@ -502,7 +658,10 @@ tf_cc_test(
name = "unidirectional_sequence_rnn_test",
size = "small",
srcs = ["unidirectional_sequence_rnn_test.cc"],
- tags = ["tflite_not_portable_ios"],
+ tags = [
+ "no_oss",
+ "tflite_not_portable_ios",
+ ],
deps = [
":builtin_ops",
"//tensorflow/contrib/lite:framework",
@@ -528,7 +687,26 @@ tf_cc_test(
name = "exp_test",
size = "small",
srcs = ["exp_test.cc"],
- tags = ["tflite_not_portable_ios"],
+ tags = [
+ "no_oss",
+ "tflite_not_portable_ios",
+ ],
+ deps = [
+ ":builtin_ops",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+tf_cc_test(
+ name = "fake_quant_test",
+ size = "small",
+ srcs = ["fake_quant_test.cc"],
+ tags = [
+ "no_oss",
+ "tflite_not_portable_ios",
+ ],
deps = [
":builtin_ops",
"//tensorflow/contrib/lite:framework",
@@ -541,7 +719,10 @@ tf_cc_test(
name = "maximum_minimum_test",
size = "small",
srcs = ["maximum_minimum_test.cc"],
- tags = ["tflite_not_portable_ios"],
+ tags = [
+ "no_oss",
+ "tflite_not_portable_ios",
+ ],
deps = [
":builtin_ops",
"//tensorflow/contrib/lite:framework",
@@ -551,10 +732,13 @@ tf_cc_test(
)
tf_cc_test(
- name = "mean_test",
+ name = "reduce_test",
size = "small",
- srcs = ["mean_test.cc"],
- tags = ["tflite_not_portable_ios"],
+ srcs = ["reduce_test.cc"],
+ tags = [
+ "no_oss",
+ "tflite_not_portable_ios",
+ ],
deps = [
":builtin_ops",
"//tensorflow/contrib/lite:framework",
@@ -580,7 +764,10 @@ tf_cc_test(
name = "pad_test",
size = "small",
srcs = ["pad_test.cc"],
- tags = ["tflite_not_portable_ios"],
+ tags = [
+ "no_oss",
+ "tflite_not_portable_ios",
+ ],
deps = [
":builtin_ops",
"//tensorflow/contrib/lite:framework",
@@ -606,11 +793,14 @@ tf_cc_test(
name = "gather_test",
size = "small",
srcs = ["gather_test.cc"],
- tags = ["tflite_not_portable_ios"],
+ tags = [
+ "no_oss",
+ "tflite_not_portable_ios",
+ ],
deps = [
":builtin_ops",
- "//tensorflow/contrib/lite:builtin_op_data",
"//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite/kernels:test_util",
"@com_google_googletest//:gtest",
],
@@ -620,11 +810,14 @@ tf_cc_test(
name = "topk_v2_test",
size = "small",
srcs = ["topk_v2_test.cc"],
- tags = ["tflite_not_portable_ios"],
+ tags = [
+ "no_oss",
+ "tflite_not_portable_ios",
+ ],
deps = [
":builtin_ops",
- "//tensorflow/contrib/lite:builtin_op_data",
"//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite/kernels:test_util",
"@com_google_googletest//:gtest",
],
@@ -741,7 +934,10 @@ tf_cc_test(
name = "log_softmax_test",
size = "small",
srcs = ["log_softmax_test.cc"],
- tags = ["tflite_not_portable_ios"],
+ tags = [
+ "no_oss",
+ "tflite_not_portable_ios",
+ ],
deps = [
":builtin_ops",
"//tensorflow/contrib/lite:framework",
@@ -779,6 +975,20 @@ tf_cc_test(
)
tf_cc_test(
+ name = "layer_norm_lstm_test",
+ size = "small",
+ srcs = ["layer_norm_lstm_test.cc"],
+ tags = ["tflite_not_portable_ios"],
+ deps = [
+ ":builtin_ops",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ "@flatbuffers",
+ ],
+)
+
+tf_cc_test(
name = "lstm_test",
size = "small",
srcs = ["lstm_test.cc"],
@@ -822,7 +1032,10 @@ tf_cc_test(
name = "split_test",
size = "small",
srcs = ["split_test.cc"],
- tags = ["tflite_not_portable_ios"],
+ tags = [
+ "no_oss",
+ "tflite_not_portable_ios",
+ ],
deps = [
":builtin_ops",
"//tensorflow/contrib/lite:framework",
@@ -835,7 +1048,10 @@ tf_cc_test(
name = "squeeze_test",
size = "small",
srcs = ["squeeze_test.cc"],
- tags = ["tflite_not_portable_ios"],
+ tags = [
+ "no_oss",
+ "tflite_not_portable_ios",
+ ],
deps = [
":builtin_ops",
"//tensorflow/contrib/lite:framework",
@@ -848,7 +1064,10 @@ tf_cc_test(
name = "strided_slice_test",
size = "small",
srcs = ["strided_slice_test.cc"],
- tags = ["tflite_not_portable_ios"],
+ tags = [
+ "no_oss",
+ "tflite_not_portable_ios",
+ ],
deps = [
":builtin_ops",
"//tensorflow/contrib/lite:framework",
@@ -858,12 +1077,30 @@ tf_cc_test(
)
tf_cc_test(
+ name = "tile_test",
+ size = "small",
+ srcs = ["tile_test.cc"],
+ tags = [
+ "no_oss",
+ "tflite_not_portable_ios",
+ ],
+ deps = [
+ ":builtin_ops",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/c:c_api_internal",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+tf_cc_test(
name = "comparisons_test",
size = "small",
srcs = [
"comparisons_test.cc",
],
tags = [
+ "no_oss",
"tflite_not_portable_ios",
],
deps = [
@@ -878,7 +1115,10 @@ tf_cc_test(
name = "neg_test",
size = "small",
srcs = ["neg_test.cc"],
- tags = ["tflite_not_portable_ios"],
+ tags = [
+ "no_oss",
+ "tflite_not_portable_ios",
+ ],
deps = [
":builtin_ops",
"//tensorflow/contrib/lite:framework",
@@ -894,6 +1134,7 @@ tf_cc_test(
"select_test.cc",
],
tags = [
+ "no_oss",
"tflite_not_portable_ios",
],
deps = [
@@ -911,6 +1152,7 @@ tf_cc_test(
"slice_test.cc",
],
tags = [
+ "no_oss",
"tflite_not_portable_ios",
],
deps = [
@@ -925,9 +1167,163 @@ tf_cc_test(
name = "transpose_conv_test",
size = "small",
srcs = ["transpose_conv_test.cc"],
+ tags = [
+ "no_oss",
+ "tflite_not_portable_ios",
+ ],
+ deps = [
+ ":builtin_ops",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+tf_cc_test(
+ name = "expand_dims_test",
+ size = "small",
+ srcs = ["expand_dims_test.cc"],
+ tags = [
+ "no_oss",
+ "tflite_not_portable_ios",
+ ],
+ deps = [
+ ":builtin_ops",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/c:c_api_internal",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+tf_cc_test(
+ name = "sparse_to_dense_test",
+ size = "small",
+ srcs = ["sparse_to_dense_test.cc"],
+ tags = [
+ "no_oss",
+ "tflite_not_portable_ios",
+ ],
+ deps = [
+ ":builtin_ops",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/c:c_api_internal",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+tf_cc_test(
+ name = "shape_test",
+ size = "small",
+ srcs = ["shape_test.cc"],
+ tags = [
+ "no_oss",
+ "tflite_not_portable_ios",
+ ],
+ deps = [
+ ":builtin_ops",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/c:c_api_internal",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+tf_cc_test(
+ name = "pow_test",
+ size = "small",
+ srcs = ["pow_test.cc"],
+ tags = [
+ "no_oss",
+ "tflite_not_portable_ios",
+ ],
+ deps = [
+ ":builtin_ops",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/c:c_api_internal",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+tf_cc_test(
+ name = "pack_test",
+ size = "small",
+ srcs = ["pack_test.cc"],
+ tags = ["tflite_not_portable_ios"],
+ deps = [
+ ":builtin_ops",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/c:c_api_internal",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+tf_cc_test(
+ name = "one_hot_test",
+ size = "small",
+ srcs = ["one_hot_test.cc"],
+ tags = ["tflite_not_portable_ios"],
+ deps = [
+ ":builtin_ops",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+tf_cc_test(
+ name = "logical_test",
+ size = "small",
+ srcs = ["logical_test.cc"],
+ tags = ["tflite_not_portable_ios"],
+ deps = [
+ ":builtin_ops",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/c:c_api_internal",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+tf_cc_test(
+ name = "unpack_test",
+ size = "small",
+ srcs = ["unpack_test.cc"],
tags = ["tflite_not_portable_ios"],
deps = [
":builtin_ops",
+ "//tensorflow/contrib/lite:builtin_op_data",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+tf_cc_test(
+ name = "floor_div_test",
+ size = "small",
+ srcs = ["floor_div_test.cc"],
+ tags = ["tflite_not_portable_ios"],
+ deps = [
+ ":builtin_ops",
+ "//tensorflow/contrib/lite:builtin_op_data",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+tf_cc_test(
+ name = "zeros_like_test",
+ size = "small",
+ srcs = ["zeros_like_test.cc"],
+ tags = ["tflite_not_portable_ios"],
+ deps = [
+ ":builtin_ops",
+ "//tensorflow/contrib/lite:builtin_op_data",
"//tensorflow/contrib/lite:framework",
"//tensorflow/contrib/lite/kernels:test_util",
"@com_google_googletest//:gtest",
diff --git a/tensorflow/contrib/lite/kernels/activation_functor.h b/tensorflow/contrib/lite/kernels/activation_functor.h
index 41ec3cca33..e075dc7054 100644
--- a/tensorflow/contrib/lite/kernels/activation_functor.h
+++ b/tensorflow/contrib/lite/kernels/activation_functor.h
@@ -19,7 +19,7 @@ limitations under the License.
#include <cmath>
#include <cstdlib>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
namespace tflite {
diff --git a/tensorflow/contrib/lite/kernels/activations.cc b/tensorflow/contrib/lite/kernels/activations.cc
index 4972159a05..9aed4f09b8 100644
--- a/tensorflow/contrib/lite/kernels/activations.cc
+++ b/tensorflow/contrib/lite/kernels/activations.cc
@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include <unistd.h>
#include <cassert>
#include <cmath>
#include <cstdio>
@@ -20,8 +19,8 @@ limitations under the License.
#include <iostream>
#include <limits>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
@@ -41,6 +40,11 @@ struct OpData {
int diff_min = 0;
};
+struct LogSoftmaxOpData : public OpData {
+ int32_t reverse_scaling_divisor = 0;
+ int32_t reverse_scaling_right_shift = 0;
+};
+
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
// This is a builtin op, so we don't use the contents in 'buffer', if any.
// Instead, we allocate a new object to carry information from Prepare() to
@@ -48,10 +52,19 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
return new OpData;
}
+void* LogSoftmaxInit(TfLiteContext* context, const char* buffer,
+ size_t length) {
+ return new LogSoftmaxOpData;
+}
+
void Free(TfLiteContext* context, void* buffer) {
delete reinterpret_cast<OpData*>(buffer);
}
+void LogSoftmaxFree(TfLiteContext* context, void* buffer) {
+ delete reinterpret_cast<LogSoftmaxOpData*>(buffer);
+}
+
TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
@@ -84,6 +97,38 @@ TfLiteStatus TanhPrepare(TfLiteContext* context, TfLiteNode* node) {
&data->input_left_shift);
data->input_range_radius =
CalculateInputRadius(kInputIntegerBits, data->input_left_shift);
+ } else if (input->type == kTfLiteInt16) {
+ static constexpr int kInputIntegerBits = 3;
+ static constexpr int kOutputFractionalBits = 15;
+
+ // These operators are implemented in fixed-point arithmetic,
+ // which intrinsically wants symmetric ranges (zero_point==0)
+ // and power-of-two scales (power-of-two is abbreviated below as POT).
+ // While more general support would be possible by means of rescaling,
+ // that would add some overhead and some loss of accuracy and wouldn't
+ // be used at the moment as current quantized LSTM applications are
+ // happy with symmetric, power-of-two-scales quantization. So we just
+ // implement that narrow case only for now.
+
+ TF_LITE_ENSURE_EQ(context, input->params.zero_point, 0);
+ TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
+
+ int input_scale_log2_rounded;
+ TF_LITE_ENSURE(context,
+ CheckedLog2(input->params.scale, &input_scale_log2_rounded));
+
+ int output_scale_log2_rounded;
+ TF_LITE_ENSURE(
+ context, CheckedLog2(output->params.scale, &output_scale_log2_rounded));
+ TF_LITE_ENSURE_EQ(context, output_scale_log2_rounded,
+ -kOutputFractionalBits);
+
+ data->input_left_shift =
+ (15 - kInputIntegerBits) + input_scale_log2_rounded;
+ // Support for shifts is limited until we have a parameterized version of
+ // SaturatingRoundingMultiplyByPOT().
+ TF_LITE_ENSURE(context, data->input_left_shift >= 0);
+ TF_LITE_ENSURE(context, data->input_left_shift <= 1);
}
return context->ResizeTensor(context, output,
@@ -114,6 +159,30 @@ TfLiteStatus SigmoidPrepare(TfLiteContext* context, TfLiteNode* node) {
&data->input_left_shift);
data->input_range_radius =
CalculateInputRadius(kInputIntegerBits, data->input_left_shift);
+ } else if (input->type == kTfLiteInt16) {
+ static constexpr int kInputIntegerBits = 3;
+ static constexpr int kOutputFractionalBits = 15;
+
+ // See comments in TanhPrepare about requiring zero_point==0
+ // and a power-of-two ("POT") scale.
+
+ TF_LITE_ENSURE_EQ(context, input->params.zero_point, 0);
+ TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
+
+ int input_scale_log2_rounded;
+ TF_LITE_ENSURE(context,
+ CheckedLog2(input->params.scale, &input_scale_log2_rounded));
+
+ int output_scale_log2_rounded;
+ TF_LITE_ENSURE(
+ context, CheckedLog2(output->params.scale, &output_scale_log2_rounded));
+ TF_LITE_ENSURE_EQ(context, output_scale_log2_rounded,
+ -kOutputFractionalBits);
+
+ data->input_left_shift =
+ (15 - kInputIntegerBits) + input_scale_log2_rounded;
+ // The int16 logistic implementation does not support shifting of the input.
+ TF_LITE_ENSURE_EQ(context, data->input_left_shift, 0);
}
return context->ResizeTensor(context, output,
@@ -130,8 +199,8 @@ TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* output = GetOutput(context, node, 0);
TF_LITE_ENSURE_EQ(context, input->type, output->type);
- TF_LITE_ENSURE(context,
- NumDimensions(input) == 2 || NumDimensions(input) == 4);
+ const int num_dims = NumDimensions(input);
+ TF_LITE_ENSURE(context, num_dims >= 1 && num_dims <= 4);
if (input->type == kTfLiteUInt8) {
TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
@@ -150,6 +219,34 @@ TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteIntArrayCopy(input->dims));
}
+TfLiteStatus LogSoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) {
+ LogSoftmaxOpData* data = reinterpret_cast<LogSoftmaxOpData*>(node->user_data);
+
+ TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+ const TfLiteTensor* input = GetInput(context, node, 0);
+ TfLiteTensor* output = GetOutput(context, node, 0);
+ TF_LITE_ENSURE_EQ(context, input->type, output->type);
+
+ if (input->type == kTfLiteUInt8) {
+ TF_LITE_ENSURE_EQ(context, output->params.zero_point, 255);
+ TF_LITE_ENSURE_EQ(context, output->params.scale, 16.0 / 256);
+
+ static const double kBeta = 1.0;
+ static const int kScaledDiffIntegerBits = 5;
+ tflite::PreprocessLogSoftmaxScalingExp(
+ kBeta, input->params.scale, kScaledDiffIntegerBits,
+ &data->input_multiplier, &data->input_left_shift,
+ &data->reverse_scaling_divisor, &data->reverse_scaling_right_shift);
+ data->reverse_scaling_right_shift *= -1;
+ data->diff_min = -1.0 * tflite::CalculateInputRadius(
+ kScaledDiffIntegerBits, data->input_left_shift);
+ }
+
+ return context->ResizeTensor(context, output,
+ TfLiteIntArrayCopy(input->dims));
+}
+
TfLiteStatus PreluPrepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
@@ -157,25 +254,25 @@ TfLiteStatus PreluPrepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* output = GetOutput(context, node, 0);
const TfLiteTensor* alpha = GetInput(context, node, 1);
- output->type = input->type;
-
// Currently only Float32 is supported
// TODO(ycling): Support other data types.
TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
TF_LITE_ENSURE_EQ(context, alpha->type, kTfLiteFloat32);
+ output->type = input->type;
- // Currently, only support 4D `input` and 3D `alpha` with shape
- // (1, 1, channels).
- // TODO(impjdi): Support other cases where `alpha` is broadcastable
- // to `input`.
- TF_LITE_ENSURE_EQ(context, input->dims->size, 4);
- TF_LITE_ENSURE_EQ(context, alpha->dims->size, 3);
- TF_LITE_ENSURE_EQ(context, alpha->dims->data[0], 1);
- TF_LITE_ENSURE_EQ(context, alpha->dims->data[1], 1);
- TF_LITE_ENSURE_EQ(context, alpha->dims->data[2], input->dims->data[3]);
+  // PRelu (parametric Relu) shares the same alpha value on the "shared axis".
+ // This means it's always required to "broadcast" alpha values in PRelu.
+ TfLiteIntArray* output_size = nullptr;
+ TF_LITE_ENSURE_OK(
+ context, CalculateShapeForBroadcast(context, input, alpha, &output_size));
- return context->ResizeTensor(context, output,
- TfLiteIntArrayCopy(input->dims));
+ TF_LITE_ENSURE_OK(context,
+ context->ResizeTensor(context, output, output_size));
+ // After broadcasting, the output shape should always be the same as the
+ // input shape.
+ TF_LITE_ENSURE(context, HaveSameShapes(input, output));
+
+ return kTfLiteOk;
}
TfLiteStatus ReluEval(TfLiteContext* context, TfLiteNode* node) {
@@ -191,7 +288,8 @@ TfLiteStatus ReluEval(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
} break;
default:
- context->ReportError(context, "Only float32 supported currently.");
+ context->ReportError(context, "Only float32 supported currently, got %d.",
+ input->type);
return kTfLiteError;
}
}
@@ -211,7 +309,8 @@ TfLiteStatus Relu1Eval(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
} break;
default:
- context->ReportError(context, "Only float32 supported currently.");
+ context->ReportError(context, "Only float32 supported currently, got %d.",
+ input->type);
return kTfLiteError;
}
}
@@ -229,7 +328,8 @@ TfLiteStatus Relu6Eval(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
} break;
default:
- context->ReportError(context, "Only float32 supported currently.");
+ context->ReportError(context, "Only float32 supported currently, got %d.",
+ input->type);
return kTfLiteError;
}
}
@@ -247,16 +347,28 @@ TfLiteStatus TanhEval(TfLiteContext* context, TfLiteNode* node) {
for (; in < in_end; in++, out++) *out = std::tanh(*in);
return kTfLiteOk;
} break;
+ case kTfLiteInt16: {
+ TanhParams params;
+ params.input_left_shift = data->input_left_shift;
+ optimized_ops::Tanh(params, GetTensorShape(input),
+ GetTensorData<int16_t>(input), GetTensorShape(output),
+ GetTensorData<int16_t>(output));
+ return kTfLiteOk;
+ } break;
case kTfLiteUInt8: {
- optimized_ops::Tanh(GetTensorData<uint8_t>(input), GetTensorDims(input),
- input->params.zero_point, data->input_range_radius,
- data->input_multiplier, data->input_left_shift,
- GetTensorData<uint8_t>(output),
- GetTensorDims(output));
+ TanhParams params;
+ params.input_zero_point = input->params.zero_point;
+ params.input_range_radius = data->input_range_radius;
+ params.input_multiplier = data->input_multiplier;
+ params.input_left_shift = data->input_left_shift;
+ optimized_ops::Tanh(params, GetTensorShape(input),
+ GetTensorData<uint8_t>(input), GetTensorShape(output),
+ GetTensorData<uint8_t>(output));
return kTfLiteOk;
} break;
default:
- context->ReportError(context, "Only float32 supported currently.");
+ context->ReportError(context, "Only float32 supported currently, got %d.",
+ input->type);
return kTfLiteError;
}
}
@@ -276,28 +388,35 @@ TfLiteStatus SigmoidEval(TfLiteContext* context, TfLiteNode* node) {
for (; in < in_end; in++, out++) *out = 1.f / (1.f + std::exp(-*in));
break;
}
+ case kTfLiteInt16: {
+ LogisticParams params;
+ optimized_ops::Logistic(
+ params, GetTensorShape(input), GetTensorData<int16_t>(input),
+ GetTensorShape(output), GetTensorData<int16_t>(output));
+ break;
+ }
case kTfLiteUInt8: {
+ LogisticParams params;
+ params.input_zero_point = input->params.zero_point;
+ params.input_range_radius = data->input_range_radius;
+ params.input_multiplier = data->input_multiplier;
+ params.input_left_shift = data->input_left_shift;
optimized_ops::Logistic(
- GetTensorData<uint8_t>(input), GetTensorDims(input),
- input->params.zero_point, data->input_range_radius,
- data->input_multiplier, data->input_left_shift,
- GetTensorData<uint8_t>(output), GetTensorDims(output));
+ params, GetTensorShape(input), GetTensorData<uint8_t>(input),
+ GetTensorShape(output), GetTensorData<uint8_t>(output));
break;
}
default:
- context->ReportError(context, "Only float32 supported currently.");
+ context->ReportError(context, "Only float32 supported currently, got %d.",
+ input->type);
return kTfLiteError;
}
return kTfLiteOk;
}
-// Takes a 2D tensor and perform softmax along the second dimension.
-void Softmax2DFloat(const TfLiteTensor* input, TfLiteTensor* output,
- TfLiteSoftmaxParams* params) {
- const int batch_size = input->dims->data[0];
- const int input_size = input->dims->data[1];
- float* in = input->data.f;
- float* out = output->data.f;
+// Performs softmax along the input of size (input_size * batch_size).
+void Softmax(const float* in, const int input_size, const int batch_size,
+ const float beta, float* out) {
TF_LITE_ASSERT(input_size > 0);
// For each batch
@@ -311,7 +430,7 @@ void Softmax2DFloat(const TfLiteTensor* input, TfLiteTensor* output,
// Compute the normalized sum of exps.
float exp_sum = 0.0;
for (int i = 0; i < input_size; i++) {
- out[i] = std::exp((in[i] - max_coeff) * params->beta);
+ out[i] = std::exp((in[i] - max_coeff) * beta);
exp_sum += out[i];
}
@@ -327,6 +446,52 @@ void Softmax2DFloat(const TfLiteTensor* input, TfLiteTensor* output,
}
}
+// Takes a 1D tensor and performs softmax along it.
+void Softmax1DFloat(const TfLiteTensor* input, TfLiteTensor* output,
+ TfLiteSoftmaxParams* params) {
+ const int input_size = input->dims->data[0];
+ Softmax(input->data.f, input_size, 1, params->beta, output->data.f);
+}
+
+// Takes a 2D tensor and performs softmax along the last dimension.
+void Softmax2DFloat(const TfLiteTensor* input, TfLiteTensor* output,
+ TfLiteSoftmaxParams* params) {
+ const int batch_size = input->dims->data[0];
+ const int input_size = input->dims->data[1];
+ Softmax(input->data.f, input_size, batch_size, params->beta, output->data.f);
+}
+
+// Takes a 3D tensor and performs softmax along the last dimension.
+void Softmax3DFloat(const TfLiteTensor* input, TfLiteTensor* output,
+ TfLiteSoftmaxParams* params) {
+ const int batch_size = input->dims->data[0];
+ const int intermediate_size = input->dims->data[1];
+ const int input_size = input->dims->data[2];
+ SoftmaxParams op_params;
+ op_params.beta = params->beta;
+ optimized_ops::Softmax(
+ op_params, GetTensorShape({batch_size, intermediate_size, 1, input_size}),
+ GetTensorData<float>(input),
+ GetTensorShape({batch_size, intermediate_size, 1, input_size}),
+ GetTensorData<float>(output));
+}
+
+void Softmax1DQuantized(const TfLiteTensor* input, TfLiteTensor* output,
+ TfLiteSoftmaxParams* params, OpData* data) {
+ // TODO(ahentz): this is arguably a dirty trick. Since the implementation
+ // always traverses the last dimension of a 4D tensor, we will pretend our 1D
+ // tensor is 4D in a special way. We will convert a (Y) shape into a (1,
+ // 1, 1, Y) shape.
+ const int input_size = input->dims->data[0];
+ SoftmaxParams op_params;
+ op_params.input_multiplier = data->input_multiplier;
+ op_params.input_left_shift = data->input_left_shift;
+ op_params.diff_min = data->diff_min;
+ optimized_ops::Softmax(op_params, GetTensorShape({1, 1, 1, input_size}),
+ GetTensorData<uint8_t>(input),
+ GetTensorShape({1, 1, 1, input_size}),
+ GetTensorData<uint8_t>(output));
+}
void Softmax2DQuantized(const TfLiteTensor* input, TfLiteTensor* output,
TfLiteSoftmaxParams* params, OpData* data) {
// TODO(ahentz): this is arguably a dirty trick. Since the implementation
@@ -335,27 +500,52 @@ void Softmax2DQuantized(const TfLiteTensor* input, TfLiteTensor* output,
// 1, 1, Y) shape.
const int batch_size = input->dims->data[0];
const int input_size = input->dims->data[1];
- optimized_ops::Softmax(GetTensorData<uint8_t>(input),
- GetTensorDims({batch_size, 1, 1, input_size}),
- data->input_multiplier, data->input_left_shift,
- data->diff_min, GetTensorData<uint8_t>(output),
- GetTensorDims({batch_size, 1, 1, input_size}));
+ SoftmaxParams op_params;
+ op_params.input_multiplier = data->input_multiplier;
+ op_params.input_left_shift = data->input_left_shift;
+ op_params.diff_min = data->diff_min;
+ optimized_ops::Softmax(op_params,
+ GetTensorShape({batch_size, 1, 1, input_size}),
+ GetTensorData<uint8_t>(input),
+ GetTensorShape({batch_size, 1, 1, input_size}),
+ GetTensorData<uint8_t>(output));
+}
+
+void Softmax3DQuantized(const TfLiteTensor* input, TfLiteTensor* output,
+ TfLiteSoftmaxParams* params, OpData* data) {
+ const int batch_size = input->dims->data[0];
+ const int intermediate_size = input->dims->data[1];
+ const int input_size = input->dims->data[2];
+ SoftmaxParams op_params;
+ op_params.input_multiplier = data->input_multiplier;
+ op_params.input_left_shift = data->input_left_shift;
+ op_params.diff_min = data->diff_min;
+ optimized_ops::Softmax(
+ op_params, GetTensorShape({batch_size, intermediate_size, 1, input_size}),
+ GetTensorData<uint8_t>(input),
+ GetTensorShape({batch_size, intermediate_size, 1, input_size}),
+ GetTensorData<uint8_t>(output));
}
// Takes a 4D tensor and perform softmax along the forth dimension.
void Softmax4DFloat(const TfLiteTensor* input, TfLiteTensor* output,
TfLiteSoftmaxParams* params) {
- optimized_ops::Softmax(GetTensorData<float>(input), GetTensorDims(input),
- params->beta, GetTensorData<float>(output),
- GetTensorDims(output));
+ SoftmaxParams op_params;
+ op_params.beta = params->beta;
+ optimized_ops::Softmax(op_params, GetTensorShape(input),
+ GetTensorData<float>(input), GetTensorShape(output),
+ GetTensorData<float>(output));
}
void Softmax4DQuantized(const TfLiteTensor* input, TfLiteTensor* output,
TfLiteSoftmaxParams* params, OpData* data) {
- optimized_ops::Softmax(GetTensorData<uint8_t>(input), GetTensorDims(input),
- data->input_multiplier, data->input_left_shift,
- data->diff_min, GetTensorData<uint8_t>(output),
- GetTensorDims(output));
+ SoftmaxParams op_params;
+ op_params.input_multiplier = data->input_multiplier;
+ op_params.input_left_shift = data->input_left_shift;
+ op_params.diff_min = data->diff_min;
+ optimized_ops::Softmax(op_params, GetTensorShape(input),
+ GetTensorData<uint8_t>(input), GetTensorShape(output),
+ GetTensorData<uint8_t>(output));
}
TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
@@ -369,79 +559,107 @@ TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
// dimensions.
switch (input->type) {
case kTfLiteFloat32: {
+ if (NumDimensions(input) == 1) {
+ Softmax1DFloat(input, output, params);
+ return kTfLiteOk;
+ }
if (NumDimensions(input) == 2) {
Softmax2DFloat(input, output, params);
return kTfLiteOk;
}
+ if (NumDimensions(input) == 3) {
+ Softmax3DFloat(input, output, params);
+ return kTfLiteOk;
+ }
if (NumDimensions(input) == 4) {
Softmax4DFloat(input, output, params);
return kTfLiteOk;
}
- context->ReportError(context,
- "Only 2D and 4D tensors supported currently.");
+ context->ReportError(
+ context, "Only 1D, 2D and 4D tensors supported currently, got %dD.",
+ NumDimensions(input));
return kTfLiteError;
}
case kTfLiteUInt8: {
+ if (NumDimensions(input) == 1) {
+ Softmax1DQuantized(input, output, params, data);
+ return kTfLiteOk;
+ }
if (NumDimensions(input) == 2) {
Softmax2DQuantized(input, output, params, data);
return kTfLiteOk;
}
+ if (NumDimensions(input) == 3) {
+ Softmax3DQuantized(input, output, params, data);
+ return kTfLiteOk;
+ }
if (NumDimensions(input) == 4) {
Softmax4DQuantized(input, output, params, data);
return kTfLiteOk;
}
- context->ReportError(context,
- "Only 2D and 4D tensors supported currently.");
+ context->ReportError(
+ context, "Only 2D and 4D tensors supported currently, got %dD.",
+ NumDimensions(input));
return kTfLiteError;
}
default:
- context->ReportError(context,
- "Only float32 and uint8_t supported currently.");
+ context->ReportError(
+ context, "Only float32 and uint8_t supported currently, got %d.",
+ input->type);
return kTfLiteError;
}
}
TfLiteStatus LogSoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
+ const LogSoftmaxOpData* data =
+ reinterpret_cast<LogSoftmaxOpData*>(node->user_data);
const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
switch (input->type) {
- case kTfLiteFloat32:
+ case kTfLiteFloat32: {
+ SoftmaxParams op_params;
+ optimized_ops::LogSoftmax(
+ op_params, GetTensorShape(input), GetTensorData<float>(input),
+ GetTensorShape(output), GetTensorData<float>(output));
+ return kTfLiteOk;
+ }
+ case kTfLiteUInt8: {
+ SoftmaxParams op_params;
+ op_params.input_multiplier = data->input_multiplier;
+ op_params.input_left_shift = data->input_left_shift;
+ op_params.reverse_scaling_divisor = data->reverse_scaling_divisor;
+ op_params.reverse_scaling_right_shift = data->reverse_scaling_right_shift;
+ op_params.diff_min = data->diff_min;
optimized_ops::LogSoftmax(
- GetTensorData<float>(input), GetTensorDims(input),
- GetTensorData<float>(output), GetTensorDims(output));
+ op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
+ GetTensorShape(output), GetTensorData<uint8_t>(output));
return kTfLiteOk;
+ }
default:
- context->ReportError(context, "Only float32 supported currently.");
+ context->ReportError(context, "Only float32 supported currently., got %d",
+ input->type);
return kTfLiteError;
}
}
+template <typename T>
+T ApplyPrelu(T input, T alpha) {
+ return input >= 0.0 ? input : input * alpha;
+}
+
TfLiteStatus PreluEval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, 0);
const TfLiteTensor* alpha = GetInput(context, node, 1);
- const TfLiteTensor* output = GetOutput(context, node, 0);
-
+ TfLiteTensor* output = GetOutput(context, node, 0);
if (input->type != kTfLiteFloat32) {
- context->ReportError(context, "Only float32 supported currently.");
+ context->ReportError(context, "Only float32 supported currently, got %d.",
+ input->type);
return kTfLiteError;
}
- TF_LITE_ENSURE_EQ(context, input->dims->size, 4);
- const int batches = input->dims->data[0];
- const int height = input->dims->data[1];
- const int width = input->dims->data[2];
- const int channels = input->dims->data[3];
-
- TF_LITE_ENSURE_EQ(context, alpha->dims->size, 3);
- TF_LITE_ENSURE_EQ(context, alpha->dims->data[0], 1);
- TF_LITE_ENSURE_EQ(context, alpha->dims->data[1], 1);
- TF_LITE_ENSURE_EQ(context, alpha->dims->data[2], channels);
-
- const int n = batches * height * width * channels;
- for (int i = 0; i < n; ++i) {
- const float x = input->data.f[i];
- output->data.f[i] = x >= 0.0f ? x : alpha->data.f[i % channels] * x;
- }
-
+ reference_ops::BroadcastBinaryFunction4DSlow<float, float, float>(
+ GetTensorShape(input), GetTensorData<float>(input), GetTensorShape(alpha),
+ GetTensorData<float>(alpha), GetTensorShape(output),
+ GetTensorData<float>(output), ApplyPrelu<float>);
return kTfLiteOk;
}
@@ -490,9 +708,9 @@ TfLiteRegistration* Register_SOFTMAX() {
}
TfLiteRegistration* Register_LOG_SOFTMAX() {
- static TfLiteRegistration r = {activations::Init, activations::Free,
- activations::GenericPrepare,
- activations::LogSoftmaxEval};
+ static TfLiteRegistration r = {
+ activations::LogSoftmaxInit, activations::LogSoftmaxFree,
+ activations::LogSoftmaxPrepare, activations::LogSoftmaxEval};
return &r;
}
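Note (not part of the patch): the new kTfLiteInt16 branches in TanhPrepare and SigmoidPrepare above only accept symmetric quantization (zero_point == 0) with a power-of-two scale, and derive input_left_shift as (15 - kInputIntegerBits) + log2(scale). A standalone sketch of that derivation follows; CheckedLog2 here is a stand-in for the helper the kernels use, and the numbers mirror the TanhInt16 test below (input range roughly [-8, 8], i.e. scale 2^-12).

    // Sketch only: derive the int16 Tanh input shift for a power-of-two scale.
    #include <cmath>
    #include <cstdio>

    // Stand-in for the CheckedLog2 helper referenced above: returns true iff
    // `scale` is (numerically) a power of two and writes the rounded log2.
    static bool CheckedLog2(float scale, int* log2_result) {
      const float log2_value = std::log2(scale);
      const int rounded = static_cast<int>(std::round(log2_value));
      *log2_result = rounded;
      return std::fabs(log2_value - static_cast<float>(rounded)) < 1e-6f;
    }

    int main() {
      constexpr int kInputIntegerBits = 3;       // input interpreted as Q3.12
      const float input_scale = 1.0f / 4096.0f;  // int16 covering ~[-8, 8] -> 2^-12

      int input_scale_log2 = 0;
      if (!CheckedLog2(input_scale, &input_scale_log2)) {
        std::printf("non power-of-two scale: the kernel rejects this tensor\n");
        return 1;
      }
      // Same formula as in the TanhPrepare hunk: (15 - 3) + (-12) == 0, so the
      // stored int16 values are already in the Q3.12 format the kernel expects.
      const int input_left_shift = (15 - kInputIntegerBits) + input_scale_log2;
      std::printf("input_left_shift = %d\n", input_left_shift);
      return 0;
    }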
diff --git a/tensorflow/contrib/lite/kernels/activations_test.cc b/tensorflow/contrib/lite/kernels/activations_test.cc
index 50a84edd47..9fa47e190a 100644
--- a/tensorflow/contrib/lite/kernels/activations_test.cc
+++ b/tensorflow/contrib/lite/kernels/activations_test.cc
@@ -75,23 +75,42 @@ class FloatActivationsOpModel : public BaseActivationsOpModel {
std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
};
-// TODO(ahentz): I don't quite understand the tradeoffs in the quantized
-// implementation of sigmoid and software, but a tolerance of twice the output
-// scale seems reasonable. We might want to change this if we have a better
-// theoretical bound.
+// Our fixed-point math function implementations have roughly 12 bits of
+// accuracy, when specialized to 16-bit fixed-point arithmetic.
+// That is purely an implementation compromise, it would have been possible
+// to get closer to 16 bits of accuracy but that would be more expensive,
+// and not needed for our purposes as ultimately the output is either
+// immediately down-quantized to 8 bits, or will typically be at the output
+// of the surrounding LSTM cell.
+// So we can require roughly 2^-12 accuracy when the output is 16-bit, and
+// we can more or less expect the full 2^-8 accuracy when the output is 8-bit.
+//
+// However, the representable output interval is often [-1, 1] (it has to be
+// for tanh, and even for logistic, when we implement it in fixed-point, we
+// typically have to do so on such a symmetric interval, e.g. ARM NEON only
+// has signed fixed-point arithmetic (SQRDMULH)). As the width of [-1, 1]
+// is 2, our representable values are often diluted by a factor of 2, whence
+// the factor of 2 below.
const float kQuantizedTolerance = 2 * (1. / 256);
+const float kQuantizedToleranceInt16 = 2 * (1. / 4096);
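
To relate these constants to the quantization grids used below: assuming the
test harness derives the scale as (max - min) / (qmax - qmin), the uint8 Tanh
output on [-1, 127/128] has a step of exactly 1/128, so kQuantizedTolerance is
one output step there, while kQuantizedToleranceInt16 is 16 steps of the int16
grid, matching the ~12-bit accuracy discussed above. A small arithmetic check
(illustrative only):

    #include <cstdio>

    int main() {
      // Output step sizes for the Tanh tests, assuming an affine scale of
      // (max - min) / (qmax - qmin).
      const float uint8_step = (127.0f / 128 + 1.0f) / 255;        // 1/128
      const float int16_step = (32767.0f / 32768 + 1.0f) / 65535;  // 1/32768
      const float tol_uint8 = 2 * (1.0f / 256);   // 1/128  -> 1 uint8 step
      const float tol_int16 = 2 * (1.0f / 4096);  // 1/2048 -> 16 int16 steps
      std::printf("uint8: %g steps, int16: %g steps\n",
                  tol_uint8 / uint8_step, tol_int16 / int16_step);
      return 0;
    }
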
class QuantizedActivationsOpModel : public BaseActivationsOpModel {
public:
using BaseActivationsOpModel::BaseActivationsOpModel;
+ template <typename T>
void SetInput(std::initializer_list<float> data) {
- QuantizeAndPopulate<uint8_t>(input_, data);
+ QuantizeAndPopulate<T>(input_, data);
}
- std::vector<uint8_t> GetOutput() { return ExtractVector<uint8_t>(output_); }
+  template <typename T>
+  std::vector<T> GetOutput() {
+    return ExtractVector<T>(output_);
+  }
+
+  template <typename T>
std::vector<float> GetDequantizedOutput() {
- return Dequantize<uint8_t>(ExtractVector<uint8_t>(output_),
- GetScale(output_), GetZeroPoint(output_));
+ return Dequantize<T>(ExtractVector<T>(output_), GetScale(output_),
+ GetZeroPoint(output_));
}
};
@@ -152,24 +171,47 @@ TEST(FloatActivationsOpTest, Tanh) {
}
TEST(QuantizedActivationsOpTest, Tanh) {
+ const float kMin = -1;
+ const float kMax = 127.f / 128.f;
QuantizedActivationsOpModel m(
BuiltinOperator_TANH,
- /*input=*/{TensorType_UINT8, {1, 2, 4, 1}, -8, 8},
- /*output=*/{TensorType_UINT8, {1, 2, 4, 1}, -1, 1});
- m.SetInput({
+ /*input=*/{TensorType_UINT8, {1, 2, 4, 1}, 8 * kMin, 8 * kMax},
+ /*output=*/{TensorType_UINT8, {1, 2, 4, 1}, kMin, kMax});
+ m.SetInput<uint8_t>({
0, -6, 2, 4, //
-4, -2, 8, 1, //
});
m.Invoke();
- EXPECT_THAT(m.GetDequantizedOutput(),
+ EXPECT_THAT(m.GetDequantizedOutput<uint8_t>(),
ElementsAreArray(ArrayFloatNear(
{
0.0, -0.999987, 0.964027, 0.999329, //
- -0.996078, -0.96402, 0.99999, 0.76159, //
+ -0.999329, -0.96402, 0.99999, 0.76159, //
},
- 4 * (1. / 256))));
- EXPECT_THAT(m.GetOutput(),
- ElementsAreArray({128, 0, 251, 255, 0, 5, 255, 226}));
+ kQuantizedTolerance)));
+ EXPECT_THAT(m.GetOutput<uint8_t>(),
+ ElementsAreArray({128, 0, 251, 255, 0, 5, 255, 225}));
+}
+
+TEST(QuantizedActivationsOpTest, TanhInt16) {
+ const float kMin = -1;
+ const float kMax = 32767.f / 32768.f;
+ QuantizedActivationsOpModel m(
+ BuiltinOperator_TANH,
+ /*input=*/{TensorType_INT16, {1, 2, 4, 1}, 8 * kMin, 8 * kMax},
+ /*output=*/{TensorType_INT16, {1, 2, 4, 1}, kMin, kMax});
+ m.SetInput<int16_t>({
+ 0, -6, 2, 4, //
+ -4, -2, 8, 1, //
+ });
+ m.Invoke();
+ EXPECT_THAT(m.GetDequantizedOutput<int16_t>(),
+ ElementsAreArray(ArrayFloatNear(
+ {
+ 0.0, -0.999987, 0.964027, 0.999329, //
+ -0.999329, -0.96402, 0.99999, 0.76159, //
+ },
+ kQuantizedToleranceInt16)));
}
TEST(FloatActivationsOpTest, Sigmoid) {
@@ -190,22 +232,43 @@ TEST(QuantizedActivationsOpTest, Sigmoid) {
QuantizedActivationsOpModel m(
BuiltinOperator_LOGISTIC,
/*input=*/{TensorType_UINT8, {1, 2, 4, 1}, -10, 10});
- m.SetInput({
+ m.SetInput<uint8_t>({
0, -6, 2, 4, //
3, -2, 10, 1, //
});
m.Invoke();
- EXPECT_THAT(m.GetDequantizedOutput(),
+ EXPECT_THAT(m.GetDequantizedOutput<uint8_t>(),
ElementsAreArray(ArrayFloatNear(
{
0.5, 0.002473, 0.880797, 0.982014, //
0.952574, 0.119203, 0.999955, 0.731059, //
},
kQuantizedTolerance)));
- EXPECT_THAT(m.GetOutput(),
+ EXPECT_THAT(m.GetOutput<uint8_t>(),
ElementsAreArray({128, 1, 227, 251, 244, 32, 255, 188}));
}
+TEST(QuantizedActivationsOpTest, SigmoidInt16) {
+ const float kMin = -1;
+ const float kMax = 32767.f / 32768.f;
+ QuantizedActivationsOpModel m(
+ BuiltinOperator_LOGISTIC,
+ /*input=*/{TensorType_INT16, {1, 2, 4, 1}, 8 * kMin, 8 * kMax},
+ /*output=*/{TensorType_INT16, {1, 2, 4, 1}, kMin, kMax});
+ m.SetInput<int16_t>({
+ 0, -6, 2, 4, //
+ 3, -2, 10, 1, //
+ });
+ m.Invoke();
+ EXPECT_THAT(m.GetDequantizedOutput<int16_t>(),
+ ElementsAreArray(ArrayFloatNear(
+ {
+ 0.5, 0.002473, 0.880797, 0.982014, //
+ 0.952574, 0.119203, 0.999955, 0.731059, //
+ },
+ kQuantizedToleranceInt16)));
+}
+
TEST(FloatActivationsOpTest, Softmax4D) {
FloatActivationsOpModel m(0.1,
/*input=*/{TensorType_FLOAT32, {1, 2, 1, 4}});
@@ -241,12 +304,12 @@ TEST(QuantizedActivationsOpTest, Softmax4D) {
QuantizedActivationsOpModel m(
0.1,
/*input=*/{TensorType_UINT8, {1, 2, 1, 4}, -10, 10});
- m.SetInput({
+ m.SetInput<uint8_t>({
0, -6, 2, 4, // depth = 0
3, -2, 10, 1, // depth = 1
});
m.Invoke();
- EXPECT_THAT(m.GetDequantizedOutput(),
+ EXPECT_THAT(m.GetDequantizedOutput<uint8_t>(),
ElementsAreArray(ArrayFloatNear(
{
.23463, .12877, .28658, .35003, //
@@ -258,6 +321,40 @@ TEST(QuantizedActivationsOpTest, Softmax4D) {
QuantizedActivationsOpModel m2(
0.1,
/*input=*/{TensorType_UINT8, {4, 1, 1, 2}, -10, 10});
+ m2.SetInput<uint8_t>({
+ 0, -6, //
+ 2, 4, //
+ 3, -2, //
+ 10, 1, //
+ });
+ m2.Invoke();
+ EXPECT_THAT(m2.GetDequantizedOutput<uint8_t>(),
+ ElementsAreArray(ArrayFloatNear(
+ {
+ 0.645656, 0.354344, //
+ 0.450166, 0.549834, //
+ 0.622459, 0.377541, //
+ 0.710949, 0.28905, //
+ },
+ kQuantizedTolerance)));
+}
+
+TEST(FloatActivationsOpTest, Softmax3D) {
+ FloatActivationsOpModel m(0.1,
+ /*input=*/{TensorType_FLOAT32, {1, 2, 4}});
+ m.SetInput({
+ 0, -6, 2, 4, // depth = 0
+ 3, -2, 10, 1, // depth = 1
+ });
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({
+ .23463, .12877, .28658, .35003, //
+ .22528, .13664, .45365, .18443, //
+ })));
+
+ // Same input, but a different shape.
+ FloatActivationsOpModel m2(0.1,
+ /*input=*/{TensorType_FLOAT32, {4, 1, 2}});
m2.SetInput({
0, -6, //
2, 4, //
@@ -265,14 +362,74 @@ TEST(QuantizedActivationsOpTest, Softmax4D) {
10, 1, //
});
m2.Invoke();
- EXPECT_THAT(m2.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear(
- {
- 0.645656, 0.354344, //
- 0.450166, 0.549834, //
- 0.622459, 0.377541, //
- 0.710949, 0.28905, //
- },
- kQuantizedTolerance)));
+ EXPECT_THAT(m2.GetOutput(), ElementsAreArray(ArrayFloatNear({
+ 0.645656, 0.354344, //
+ 0.450166, 0.549834, //
+ 0.622459, 0.377541, //
+ 0.710949, 0.28905, //
+ })));
+}
+
+TEST(QuantizedActivationsOpTest, Softmax3D) {
+ QuantizedActivationsOpModel m(
+ 0.1,
+ /*input=*/{TensorType_UINT8, {1, 2, 4}, -10, 10});
+ m.SetInput<uint8_t>({
+ 0, -6, 2, 4, // depth = 0
+ 3, -2, 10, 1, // depth = 1
+ });
+ m.Invoke();
+ EXPECT_THAT(m.GetDequantizedOutput<uint8_t>(),
+ ElementsAreArray(ArrayFloatNear(
+ {
+ .23463, .12877, .28658, .35003, //
+ .22528, .13664, .45365, .18443, //
+ },
+ kQuantizedTolerance)));
+
+ // Same input, but a different shape.
+ QuantizedActivationsOpModel m2(
+ 0.1,
+ /*input=*/{TensorType_UINT8, {4, 1, 2}, -10, 10});
+ m2.SetInput<uint8_t>({
+ 0, -6, //
+ 2, 4, //
+ 3, -2, //
+ 10, 1, //
+ });
+ m2.Invoke();
+ EXPECT_THAT(m2.GetDequantizedOutput<uint8_t>(),
+ ElementsAreArray(ArrayFloatNear(
+ {
+ 0.645656, 0.354344, //
+ 0.450166, 0.549834, //
+ 0.622459, 0.377541, //
+ 0.710949, 0.28905, //
+ },
+ kQuantizedTolerance)));
+}
+
+TEST(FloatActivationsOpTest, Softmax1D) {
+ FloatActivationsOpModel m(0.1,
+ /*input=*/{TensorType_FLOAT32, {8}});
+ m.SetInput({0, -6, 2, 4, 3, -2, 10, 1});
+ m.Invoke();
+ EXPECT_THAT(
+ m.GetOutput(),
+ ElementsAreArray(ArrayFloatNear(
+ {.09752, .05352, .11911, .14548, .13164, .07984, .26509, .10778})));
+}
+
+TEST(QuantizedActivationsOpTest, Softmax1D) {
+ QuantizedActivationsOpModel m(0.1,
+ /*input=*/{TensorType_UINT8, {8}, -10, 10});
+ m.SetInput<uint8_t>({0, -6, 2, 4, 3, -2, 10, 1});
+ m.Invoke();
+ EXPECT_THAT(
+ m.GetDequantizedOutput<uint8_t>(),
+ ElementsAreArray(ArrayFloatNear({0.09766, 0.05469, 0.12109, 0.14453,
+ 0.13281, 0.07813, 0.26563, 0.10938},
+ kQuantizedTolerance)));
}
TEST(FloatActivationsOpTest, Softmax2D) {
@@ -309,12 +466,12 @@ TEST(FloatActivationsOpTest, Softmax2D) {
TEST(QuantizedActivationsOpTest, Softmax2D) {
QuantizedActivationsOpModel m(0.1,
/*input=*/{TensorType_UINT8, {2, 4}, -10, 10});
- m.SetInput({
+ m.SetInput<uint8_t>({
0, -6, 2, 4, //
3, -2, 10, 1, //
});
m.Invoke();
- EXPECT_THAT(m.GetDequantizedOutput(),
+ EXPECT_THAT(m.GetDequantizedOutput<uint8_t>(),
ElementsAreArray(ArrayFloatNear(
{
.23463, .12877, .28658, .35003, //
@@ -325,21 +482,22 @@ TEST(QuantizedActivationsOpTest, Softmax2D) {
// Same input, but a different shape.
QuantizedActivationsOpModel m2(0.1,
/*input=*/{TensorType_UINT8, {4, 2}, -10, 10});
- m2.SetInput({
+ m2.SetInput<uint8_t>({
0, -6, //
2, 4, //
3, -2, //
10, 1, //
});
m2.Invoke();
- EXPECT_THAT(m2.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear(
- {
- 0.645656, 0.354344, //
- 0.450166, 0.549834, //
- 0.622459, 0.377541, //
- 0.710949, 0.28905, //
- },
- kQuantizedTolerance)));
+ EXPECT_THAT(m2.GetDequantizedOutput<uint8_t>(),
+ ElementsAreArray(ArrayFloatNear(
+ {
+ 0.645656, 0.354344, //
+ 0.450166, 0.549834, //
+ 0.622459, 0.377541, //
+ 0.710949, 0.28905, //
+ },
+ kQuantizedTolerance)));
}
// This contains the same test values as the Softmax test, but reference answer
@@ -383,6 +541,28 @@ TEST(FloatActivationsOpTest, LogSoftmax) {
})));
}
+TEST(QuantizedActivationsOpTest, LogSoftmax) {
+ const float kLogSoftmaxQuantizedTolerance = 16 / 256.0;
+ QuantizedActivationsOpModel m(
+ BuiltinOperator_LOG_SOFTMAX,
+ /*input=*/{TensorType_UINT8, {2, 4}, -10, 10},
+ /*output=*/{TensorType_UINT8, {}, 0, 0, 16. / 256, 255});
+ m.SetInput<uint8_t>({
+ 0, -6, 2, 4, //
+ 3, -2, 10, 1, //
+ });
+ m.Invoke();
+ EXPECT_THAT(m.GetDequantizedOutput<uint8_t>(),
+ ElementsAreArray(ArrayFloatNear(
+ {
+ -4.14297, -10.14297, -2.14297, -.142971, //
+ -7.00104, -12.00104, -.00104087, -9.00104, //
+ },
+ kLogSoftmaxQuantizedTolerance)));
+ EXPECT_THAT(m.GetOutput<uint8_t>(),
+ ElementsAreArray({189, 93, 221, 253, 142, 63, 255, 111}));
+}
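
The output quantization {0, 0, 16. / 256, 255} in this test pins scale to
0.0625 and zero_point to 255, which suits log-softmax: its outputs lie in
(-inf, 0], so mapping the top of the uint8 range to 0.0 makes the
representable interval roughly [-15.94, 0]. A quick check of the first
expected value using the standard affine mapping real = scale * (q - zero_point):

    #include <cstdio>

    int main() {
      const float scale = 16.0f / 256;  // 0.0625
      const int zero_point = 255;       // uint8 value 255 maps back to 0.0
      const int q = 189;                // first value of the expected output
      // 0.0625 * (189 - 255) = -4.125, vs. the exact -4.14297 (well within
      // the 16/256 tolerance used above).
      std::printf("dequantized = %f\n", scale * (q - zero_point));
      return 0;
    }
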
+
class PReluOpModel : public SingleOpModel {
public:
PReluOpModel(const TensorData& input, const TensorData& alpha) {
diff --git a/tensorflow/contrib/lite/kernels/add.cc b/tensorflow/contrib/lite/kernels/add.cc
index 7ca1e35489..b4393e8097 100644
--- a/tensorflow/contrib/lite/kernels/add.cc
+++ b/tensorflow/contrib/lite/kernels/add.cc
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
@@ -39,6 +39,23 @@ constexpr int kOutputTensor = 0;
struct OpData {
bool requires_broadcast;
+
+  // These fields are used in both the general 8-bit -> 8-bit quantized path
+  // and the special 16-bit -> 16-bit quantized path.
+ int input1_shift;
+ int input2_shift;
+ int32 output_activation_min;
+ int32 output_activation_max;
+
+  // These fields are used only in the general 8-bit -> 8-bit quantized path.
+ int32 input1_multiplier;
+ int32 input2_multiplier;
+ int32 output_multiplier;
+ int output_shift;
+ int left_shift;
+ int32 input1_offset;
+ int32 input2_offset;
+ int32 output_offset;
};
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
@@ -52,6 +69,7 @@ void Free(TfLiteContext* context, void* buffer) {
}
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ auto* params = reinterpret_cast<TfLiteAddParams*>(node->builtin_data);
OpData* data = reinterpret_cast<OpData*>(node->user_data);
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
@@ -74,89 +92,182 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
output_size = TfLiteIntArrayCopy(input1->dims);
}
+ if (output->type == kTfLiteUInt8) {
+ // 8bit -> 8bit general quantized path, with general rescalings
+ data->input1_offset = -input1->params.zero_point;
+ data->input2_offset = -input2->params.zero_point;
+ data->output_offset = output->params.zero_point;
+ data->left_shift = 20;
+ const double twice_max_input_scale =
+ 2 * std::max(input1->params.scale, input2->params.scale);
+ const double real_input1_multiplier =
+ input1->params.scale / twice_max_input_scale;
+ const double real_input2_multiplier =
+ input2->params.scale / twice_max_input_scale;
+ const double real_output_multiplier =
+ twice_max_input_scale /
+ ((1 << data->left_shift) * output->params.scale);
+
+ QuantizeMultiplierSmallerThanOneExp(
+ real_input1_multiplier, &data->input1_multiplier, &data->input1_shift);
+
+ QuantizeMultiplierSmallerThanOneExp(
+ real_input2_multiplier, &data->input2_multiplier, &data->input2_shift);
+
+ QuantizeMultiplierSmallerThanOneExp(
+ real_output_multiplier, &data->output_multiplier, &data->output_shift);
+
+ CalculateActivationRangeUint8(params->activation, output,
+ &data->output_activation_min,
+ &data->output_activation_max);
+
+ } else if (output->type == kTfLiteInt16) {
+ // 16bit -> 16bit special quantized path, supporting only a rather
+ // narrow case of quantization parameters: zero_points must all be 0
+ // ("symmetric quantization") and scales must be power-of-two (which
+ // we abbreviate as "POT" below). The intended use case for this path
+ // is in LSTM cells, where, due to the constraints of implementing
+ // some of the math in these LSTM cells in fixed-point arithmetic,
+ // we need to have such symmetric, power-of-two quantization
+ // (Fixed-point formats are inherently symmetric, power-of-two).
+ TF_LITE_ENSURE_EQ(context, input1->params.zero_point, 0);
+ TF_LITE_ENSURE_EQ(context, input2->params.zero_point, 0);
+ TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
+
+ int input1_scale_log2_rounded;
+ bool input1_scale_is_pot =
+ CheckedLog2(input1->params.scale, &input1_scale_log2_rounded);
+ TF_LITE_ENSURE(context, input1_scale_is_pot);
+
+ int input2_scale_log2_rounded;
+ bool input2_scale_is_pot =
+ CheckedLog2(input2->params.scale, &input2_scale_log2_rounded);
+ TF_LITE_ENSURE(context, input2_scale_is_pot);
+
+ int output_scale_log2_rounded;
+ bool output_scale_is_pot =
+ CheckedLog2(output->params.scale, &output_scale_log2_rounded);
+ TF_LITE_ENSURE(context, output_scale_is_pot);
+
+ data->input1_shift = input1_scale_log2_rounded - output_scale_log2_rounded;
+ data->input2_shift = input2_scale_log2_rounded - output_scale_log2_rounded;
+
+ // Shifting of one input is supported. The graph quantization should ensure
+ // that the other input matches the output.
+ TF_LITE_ENSURE(context, data->input1_shift == 0 || data->input2_shift == 0);
+ TF_LITE_ENSURE(context, data->input1_shift <= 0);
+ TF_LITE_ENSURE(context, data->input2_shift <= 0);
+
+ CalculateActivationRangeQuantized(context, params->activation, output,
+ &data->output_activation_min,
+ &data->output_activation_max);
+ }
+
return context->ResizeTensor(context, output, output_size);
}
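
The int16 branch of Prepare rejects any scale that is not a power of two via
CheckedLog2 before the scale ratios are reduced to pure shifts. A minimal
sketch of what such a check can look like (IsPowerOfTwoScale is a hypothetical
helper for illustration, not the actual CheckedLog2):

    #include <cmath>

    // Returns true and writes the rounded exponent if a positive `scale` is
    // numerically very close to a power of two.
    bool IsPowerOfTwoScale(float scale, int* log2_result) {
      const float log2_value = std::log2(scale);
      const float rounded = std::round(log2_value);
      *log2_result = static_cast<int>(rounded);
      // Tolerate tiny floating-point error around an exact power of two.
      return std::fabs(log2_value - rounded) < 1e-6f;
    }
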
template <KernelType kernel_type>
-void EvalAddFloat(TfLiteContext* context, TfLiteNode* node,
- TfLiteAddParams* params, const OpData* data,
- const TfLiteTensor* input1, const TfLiteTensor* input2,
- TfLiteTensor* output) {
- float output_activation_min, output_activation_max;
- CalculateActivationRangeFloat(params->activation, &output_activation_min,
- &output_activation_max);
-#define TF_LITE_ADD(type, opname) \
- type::opname(GetTensorData<float>(input1), GetTensorDims(input1), \
- GetTensorData<float>(input2), GetTensorDims(input2), \
- output_activation_min, output_activation_max, \
- GetTensorData<float>(output), GetTensorDims(output))
- if (kernel_type == kReference) {
- if (data->requires_broadcast) {
- TF_LITE_ADD(reference_ops, BroadcastAdd);
+void EvalAdd(TfLiteContext* context, TfLiteNode* node, TfLiteAddParams* params,
+ const OpData* data, const TfLiteTensor* input1,
+ const TfLiteTensor* input2, TfLiteTensor* output) {
+#define TF_LITE_ADD(type, opname, data_type) \
+ data_type output_activation_min, output_activation_max; \
+ CalculateActivationRange(params->activation, &output_activation_min, \
+ &output_activation_max); \
+ tflite::ArithmeticParams op_params; \
+ SetActivationParams(output_activation_min, output_activation_max, \
+ &op_params); \
+ type::opname(op_params, GetTensorShape(input1), \
+ GetTensorData<data_type>(input1), GetTensorShape(input2), \
+ GetTensorData<data_type>(input2), GetTensorShape(output), \
+ GetTensorData<data_type>(output))
+ if (output->type == kTfLiteInt32) {
+ if (kernel_type == kReference) {
+ if (data->requires_broadcast) {
+ TF_LITE_ADD(reference_ops, BroadcastAdd4DSlow, int32_t);
+ } else {
+ TF_LITE_ADD(reference_ops, Add, int32_t);
+ }
} else {
- TF_LITE_ADD(reference_ops, Add);
+ if (data->requires_broadcast) {
+ TF_LITE_ADD(optimized_ops, BroadcastAdd4DSlow, int32_t);
+ } else {
+ TF_LITE_ADD(optimized_ops, Add, int32_t);
+ }
}
- } else {
- if (data->requires_broadcast) {
- TF_LITE_ADD(optimized_ops, BroadcastAdd);
+ } else if (output->type == kTfLiteFloat32) {
+ if (kernel_type == kReference) {
+ if (data->requires_broadcast) {
+ TF_LITE_ADD(reference_ops, BroadcastAdd4DSlow, float);
+ } else {
+ TF_LITE_ADD(reference_ops, Add, float);
+ }
} else {
- TF_LITE_ADD(optimized_ops, Add);
+ if (data->requires_broadcast) {
+ TF_LITE_ADD(optimized_ops, BroadcastAdd4DSlow, float);
+ } else {
+ TF_LITE_ADD(optimized_ops, Add, float);
+ }
}
}
#undef TF_LITE_ADD
}
template <KernelType kernel_type>
-void EvalAddQuantized(TfLiteContext* context, TfLiteNode* node,
- TfLiteAddParams* params, const OpData* data,
- const TfLiteTensor* input1, const TfLiteTensor* input2,
- TfLiteTensor* output) {
- auto input1_offset = -input1->params.zero_point;
- auto input2_offset = -input2->params.zero_point;
- auto output_offset = output->params.zero_point;
- const int left_shift = 20;
- const double twice_max_input_scale =
- 2 * std::max(input1->params.scale, input2->params.scale);
- const double real_input1_multiplier =
- input1->params.scale / twice_max_input_scale;
- const double real_input2_multiplier =
- input2->params.scale / twice_max_input_scale;
- const double real_output_multiplier =
- twice_max_input_scale / ((1 << left_shift) * output->params.scale);
-
- int32 input1_multiplier;
- int input1_shift;
- QuantizeMultiplierSmallerThanOne(real_input1_multiplier, &input1_multiplier,
- &input1_shift);
- int32 input2_multiplier;
- int input2_shift;
- QuantizeMultiplierSmallerThanOne(real_input2_multiplier, &input2_multiplier,
- &input2_shift);
- int32 output_multiplier;
- int output_shift;
- QuantizeMultiplierSmallerThanOne(real_output_multiplier, &output_multiplier,
- &output_shift);
-
- int32 output_activation_min, output_activation_max;
- CalculateActivationRangeUint8(params->activation, output,
- &output_activation_min, &output_activation_max);
-
-#define TF_LITE_ADD(type, opname) \
- type::opname(left_shift, GetTensorData<uint8_t>(input1), \
- GetTensorDims(input1), input1_offset, input1_multiplier, \
- input1_shift, GetTensorData<uint8_t>(input2), \
- GetTensorDims(input2), input2_offset, input2_multiplier, \
- input2_shift, output_offset, output_multiplier, output_shift, \
- output_activation_min, output_activation_max, \
- GetTensorData<uint8_t>(output), GetTensorDims(output));
- // The quantized version of Add doesn't support activations, so we
- // always use BroadcastAdd.
- if (kernel_type == kReference) {
- TF_LITE_ADD(reference_ops, BroadcastAdd);
- } else {
- TF_LITE_ADD(optimized_ops, BroadcastAdd);
- }
+TfLiteStatus EvalAddQuantized(TfLiteContext* context, TfLiteNode* node,
+ TfLiteAddParams* params, const OpData* data,
+ const TfLiteTensor* input1,
+ const TfLiteTensor* input2,
+ TfLiteTensor* output) {
+ if (output->type == kTfLiteUInt8) {
+#define TF_LITE_ADD(type, opname) \
+ tflite::ArithmeticParams op_params; \
+ op_params.left_shift = data->left_shift; \
+ op_params.input1_offset = data->input1_offset; \
+ op_params.input1_multiplier = data->input1_multiplier; \
+ op_params.input1_shift = data->input1_shift; \
+ op_params.input2_offset = data->input2_offset; \
+ op_params.input2_multiplier = data->input2_multiplier; \
+ op_params.input2_shift = data->input2_shift; \
+ op_params.output_offset = data->output_offset; \
+ op_params.output_multiplier = data->output_multiplier; \
+ op_params.output_shift = data->output_shift; \
+ SetActivationParams(data->output_activation_min, \
+ data->output_activation_max, &op_params); \
+ type::opname(op_params, GetTensorShape(input1), \
+ GetTensorData<uint8_t>(input1), GetTensorShape(input2), \
+ GetTensorData<uint8_t>(input2), GetTensorShape(output), \
+ GetTensorData<uint8_t>(output))
+ // The quantized version of Add doesn't support activations, so we
+ // always use BroadcastAdd.
+ if (kernel_type == kReference) {
+ TF_LITE_ADD(reference_ops, BroadcastAdd4DSlow);
+ } else {
+ TF_LITE_ADD(optimized_ops, BroadcastAdd4DSlow);
+ }
#undef TF_LITE_ADD
+ } else if (output->type == kTfLiteInt16) {
+#define TF_LITE_ADD(type, opname) \
+ tflite::ArithmeticParams op_params; \
+ op_params.input1_shift = data->input1_shift; \
+ op_params.input2_shift = data->input2_shift; \
+ SetActivationParams(data->output_activation_min, \
+ data->output_activation_max, &op_params); \
+ type::opname(op_params, GetTensorShape(input1), \
+ GetTensorData<int16_t>(input1), GetTensorShape(input2), \
+ GetTensorData<int16_t>(input2), GetTensorShape(output), \
+ GetTensorData<int16_t>(output))
+    // Unlike the uint8 path above, this 16-bit path does not go through a
+    // broadcast kernel; it calls the non-broadcasting Add directly.
+ if (kernel_type == kReference) {
+ TF_LITE_ADD(reference_ops, Add);
+ } else {
+ TF_LITE_ADD(optimized_ops, Add);
+ }
+#undef TF_LITE_ADD
+ }
+
+ return kTfLiteOk;
}
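
Numerically, the uint8 branch relies on the identity real1 + real2 =
twice_max * (m1 * (q1 - zp1) + m2 * (q2 - zp2)), with m1 = scale1 / twice_max
and m2 = scale2 / twice_max, where twice_max is twice the larger input scale;
the left_shift of 20 set in Prepare only gives the fixed-point forms of these
multipliers extra precision. A float-only sketch of that identity with made-up
values:

    #include <algorithm>
    #include <cstdio>

    int main() {
      const float s1 = 0.02f, s2 = 0.05f;  // hypothetical input scales
      const int q1 = 30, zp1 = 10;         // hypothetical quantized input 1
      const int q2 = 70, zp2 = 128;        // hypothetical quantized input 2
      const float twice_max = 2 * std::max(s1, s2);
      const float m1 = s1 / twice_max;     // real-valued input1_multiplier
      const float m2 = s2 / twice_max;     // real-valued input2_multiplier
      const float direct = s1 * (q1 - zp1) + s2 * (q2 - zp2);
      const float rescaled = twice_max * (m1 * (q1 - zp1) + m2 * (q2 - zp2));
      std::printf("direct = %f, rescaled = %f\n", direct, rescaled);  // -2.5
      return 0;
    }
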
template <KernelType kernel_type>
@@ -168,15 +279,15 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
- if (output->type == kTfLiteFloat32) {
- EvalAddFloat<kernel_type>(context, node, params, data, input1, input2,
- output);
- } else if (output->type == kTfLiteUInt8) {
- EvalAddQuantized<kernel_type>(context, node, params, data, input1, input2,
- output);
+ if (output->type == kTfLiteFloat32 || output->type == kTfLiteInt32) {
+ EvalAdd<kernel_type>(context, node, params, data, input1, input2, output);
+ } else if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt16) {
+ TF_LITE_ENSURE_OK(context,
+ EvalAddQuantized<kernel_type>(context, node, params, data,
+ input1, input2, output));
} else {
context->ReportError(context,
- "Inputs and outputs not all float|uint8 types.");
+ "Inputs and outputs not all float|uint8|int16 types.");
return kTfLiteError;
}
diff --git a/tensorflow/contrib/lite/kernels/add_test.cc b/tensorflow/contrib/lite/kernels/add_test.cc
index 956d05bed5..0b58443211 100644
--- a/tensorflow/contrib/lite/kernels/add_test.cc
+++ b/tensorflow/contrib/lite/kernels/add_test.cc
@@ -52,6 +52,13 @@ class FloatAddOpModel : public BaseAddOpModel {
std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
};
+class IntegerAddOpModel : public BaseAddOpModel {
+ public:
+ using BaseAddOpModel::BaseAddOpModel;
+
+ std::vector<int32_t> GetOutput() { return ExtractVector<int32_t>(output_); }
+};
+
class QuantizedAddOpModel : public BaseAddOpModel {
public:
using BaseAddOpModel::BaseAddOpModel;
@@ -60,15 +67,26 @@ class QuantizedAddOpModel : public BaseAddOpModel {
return Dequantize<uint8_t>(ExtractVector<uint8_t>(output_),
GetScale(output_), GetZeroPoint(output_));
}
+
+ std::vector<float> GetDequantizedOutputInt16() {
+ return Dequantize<int16_t>(ExtractVector<int16_t>(output_),
+ GetScale(output_), GetZeroPoint(output_));
+ }
};
// for quantized Add, the error shouldn't exceed 2*step
-float GetTolerance(int min, int max) {
+float GetTolerance(float min, float max) {
float kQuantizedStep = (max - min) / 255.0;
float kQuantizedTolerance = 2.0 * kQuantizedStep;
return kQuantizedTolerance;
}
+float GetToleranceInt16(float min, float max) {
+ float kQuantizedStep = (max - min) / 32767.f;
+ float kQuantizedTolerance = 2.0 * kQuantizedStep;
+ return kQuantizedTolerance;
+}
+
TEST(FloatAddOpModel, NoActivation) {
FloatAddOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}},
{TensorType_FLOAT32, {1, 2, 2, 1}},
@@ -122,6 +140,57 @@ TEST(FloatAddOpModel, WithBroadcast) {
}
}
+TEST(IntegerAddOpModel, NoActivation) {
+ IntegerAddOpModel m({TensorType_INT32, {1, 2, 2, 1}},
+ {TensorType_INT32, {1, 2, 2, 1}}, {TensorType_INT32, {}},
+ ActivationFunctionType_NONE);
+ m.PopulateTensor<int32_t>(m.input1(), {-20, 2, 7, 8});
+ m.PopulateTensor<int32_t>(m.input2(), {1, 2, 3, 5});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({-19, 4, 10, 13}));
+}
+
+TEST(IntegerAddOpModel, ActivationRELU_N1_TO_1) {
+ IntegerAddOpModel m({TensorType_INT32, {1, 2, 2, 1}},
+ {TensorType_INT32, {1, 2, 2, 1}}, {TensorType_INT32, {}},
+ ActivationFunctionType_RELU_N1_TO_1);
+ m.PopulateTensor<int32_t>(m.input1(), {-20, 2, 7, 8});
+ m.PopulateTensor<int32_t>(m.input2(), {1, 2, 3, 5});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({-1, 1, 1, 1}));
+}
+
+TEST(IntegerAddOpModel, VariousInputShapes) {
+ std::vector<std::initializer_list<int>> test_shapes = {
+ {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}};
+ for (int i = 0; i < test_shapes.size(); ++i) {
+ IntegerAddOpModel m({TensorType_INT32, test_shapes[i]},
+ {TensorType_INT32, test_shapes[i]},
+ {TensorType_INT32, {}}, ActivationFunctionType_NONE);
+ m.PopulateTensor<int32_t>(m.input1(), {-20, 2, 7, 8, 11, 20});
+ m.PopulateTensor<int32_t>(m.input2(), {1, 2, 3, 5, 11, 1});
+ m.Invoke();
+    EXPECT_THAT(m.GetOutput(), ElementsAreArray({-19, 4, 10, 13, 22, 21}))
+ << "With shape number " << i;
+ }
+}
+
+TEST(IntegerAddOpModel, WithBroadcast) {
+ std::vector<std::initializer_list<int>> test_shapes = {
+ {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}};
+ for (int i = 0; i < test_shapes.size(); ++i) {
+ IntegerAddOpModel m({TensorType_INT32, test_shapes[i]},
+ {TensorType_INT32, {}}, // always a scalar
+ {TensorType_INT32, {}}, ActivationFunctionType_NONE);
+ m.PopulateTensor<int32_t>(m.input1(), {-20, 2, 7, 8, 11, 20});
+ m.PopulateTensor<int32_t>(m.input2(), {1});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(),
+ ElementsAreArray(ArrayFloatNear({-19, 3, 8, 9, 12, 21})))
+ << "With shape number " << i;
+ }
+}
+
TEST(QuantizedAddOpModel, QuantizedTestsNoActivation) {
float kQuantizedTolerance = GetTolerance(-1.0, 1.0);
std::vector<std::initializer_list<float>> inputs1 = {
@@ -144,6 +213,31 @@ TEST(QuantizedAddOpModel, QuantizedTestsNoActivation) {
}
}
+TEST(QuantizedAddOpModel, QuantizedTestsNoActivationInt16) {
+ const float kMin = -1.f;
+ const float kMax = 32767.f / 32768.f;
+ float kQuantizedTolerance = GetToleranceInt16(kMin, kMax);
+ std::vector<std::initializer_list<float>> inputs1 = {
+ {0.1, 0.2, 0.3, 0.4}, {-0.8, 0.2, 0.4, 0.7}, {-0.8, 0.2, 0.7, 0.3}};
+ std::vector<std::initializer_list<float>> inputs2 = {
+ {0.6, 0.4, 0.3, 0.1}, {0.6, 0.4, 0.5, -0.8}, {0.6, 0.4, -0.8, 0.5}};
+ std::vector<std::initializer_list<float>> results = {
+ {0.7, 0.6, 0.6, 0.5}, {-0.2, 0.6, 0.9, -0.1}, {-0.2, 0.6, -0.1, 0.8}};
+ for (int i = 0; i < inputs1.size(); ++i) {
+ QuantizedAddOpModel m({TensorType_INT16, {1, 2, 2, 1}, kMin, kMax},
+ {TensorType_INT16, {1, 2, 2, 1}, kMin, kMax},
+ {TensorType_INT16, {}, kMin, kMax},
+ ActivationFunctionType_NONE);
+ m.QuantizeAndPopulate<int16_t>(m.input1(), inputs1[i]);
+ m.QuantizeAndPopulate<int16_t>(m.input2(), inputs2[i]);
+ m.Invoke();
+ EXPECT_THAT(
+ m.GetDequantizedOutputInt16(),
+ ElementsAreArray(ArrayFloatNear(results[i], kQuantizedTolerance)))
+ << "With test number " << i;
+ }
+}
+
TEST(QuantizedAddOpModel, QuantizedTestsActivationRELU_N1_TO_1) {
float kQuantizedTolerance = GetTolerance(-1.0, 1.0);
std::vector<std::initializer_list<float>> inputs1 = {{-0.8, 0.2, 0.9, 0.7},
diff --git a/tensorflow/contrib/lite/kernels/arg_max.cc b/tensorflow/contrib/lite/kernels/arg_min_max.cc
index 738d475f60..b91e348c27 100644
--- a/tensorflow/contrib/lite/kernels/arg_max.cc
+++ b/tensorflow/contrib/lite/kernels/arg_min_max.cc
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
@@ -23,7 +23,7 @@ limitations under the License.
namespace tflite {
namespace ops {
namespace builtin {
-namespace arg_max {
+namespace arg_min_max {
constexpr int kInputTensor = 0;
constexpr int kAxis = 1;
@@ -52,7 +52,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
output->type = kTfLiteInt64;
break;
default:
- context->ReportError(context, "Unknown index output data type");
+ context->ReportError(context, "Unknown index output data type: %d",
+ params->output_type);
return kTfLiteError;
}
@@ -64,7 +65,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
break;
default:
- context->ReportError(context, "Only float32 and int types are supported");
+ context->ReportError(
+ context,
+ "Unkonwn input type: %d, only float32 and int types are supported",
+ input->type);
return kTfLiteError;
}
@@ -76,30 +80,40 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
return context->ResizeTensor(context, output, output_size);
}
+template <typename T>
+std::function<bool(T, T)> GetComparefunction(bool is_arg_max) {
+ if (is_arg_max) {
+ return std::greater<T>();
+ } else {
+ return std::less<T>();
+ }
+}
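
With the comparator factored out this way, a single Eval serves both ops:
std::greater keeps a running maximum and std::less a running minimum. A
standalone sketch of a comparator-driven scan over one innermost row
(ArgExtremum is illustrative and does not mirror the optimized_ops::ArgMinMax
signature):

    #include <functional>
    #include <vector>

    // Index of the "best" element of `row` under `better`: argmax with
    // std::greater<float>(), argmin with std::less<float>().
    int ArgExtremum(const std::vector<float>& row,
                    const std::function<bool(float, float)>& better) {
      int best_index = 0;
      for (int i = 1; i < static_cast<int>(row.size()); ++i) {
        if (better(row[i], row[best_index])) best_index = i;
      }
      return best_index;
    }
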
+
// The current impl actually ignores the axis argument.
// Only determine the index of the maximum value in the last dimension.
-TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node, bool is_arg_max) {
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* axis = GetInput(context, node, kAxis);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
-#define TF_LITE_ARG_MAX(data_type, axis_type, output_type) \
- optimized_ops::ArgMax(GetTensorData<axis_type>(axis), \
- GetTensorData<data_type>(input), GetTensorDims(input), \
- GetTensorData<output_type>(output), \
- GetTensorDims(output))
+#define TF_LITE_ARG_MIN_MAX(data_type, axis_type, output_type) \
+ optimized_ops::ArgMinMax( \
+ GetTensorShape(input), GetTensorData<data_type>(input), \
+ GetTensorData<axis_type>(axis), GetTensorShape(output), \
+ GetTensorData<output_type>(output), \
+ GetComparefunction<data_type>(is_arg_max))
if (axis->type == kTfLiteInt32) {
switch (output->type) {
case kTfLiteInt32: {
switch (input->type) {
case kTfLiteFloat32:
- TF_LITE_ARG_MAX(float, int32_t, int32_t);
+ TF_LITE_ARG_MIN_MAX(float, int32_t, int32_t);
break;
case kTfLiteUInt8:
- TF_LITE_ARG_MAX(uint8_t, int32_t, int32_t);
+ TF_LITE_ARG_MIN_MAX(uint8_t, int32_t, int32_t);
break;
case kTfLiteInt32:
- TF_LITE_ARG_MAX(int32_t, int32_t, int32_t);
+ TF_LITE_ARG_MIN_MAX(int32_t, int32_t, int32_t);
break;
default:
return kTfLiteError;
@@ -108,13 +122,13 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
case kTfLiteInt64: {
switch (input->type) {
case kTfLiteFloat32:
- TF_LITE_ARG_MAX(float, int32_t, int64_t);
+ TF_LITE_ARG_MIN_MAX(float, int32_t, int64_t);
break;
case kTfLiteUInt8:
- TF_LITE_ARG_MAX(uint8_t, int32_t, int64_t);
+ TF_LITE_ARG_MIN_MAX(uint8_t, int32_t, int64_t);
break;
case kTfLiteInt32:
- TF_LITE_ARG_MAX(int32_t, int32_t, int64_t);
+ TF_LITE_ARG_MIN_MAX(int32_t, int32_t, int64_t);
break;
default:
return kTfLiteError;
@@ -128,13 +142,13 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
case kTfLiteInt32: {
switch (input->type) {
case kTfLiteFloat32:
- TF_LITE_ARG_MAX(float, int64_t, int32_t);
+ TF_LITE_ARG_MIN_MAX(float, int64_t, int32_t);
break;
case kTfLiteUInt8:
- TF_LITE_ARG_MAX(uint8_t, int64_t, int32_t);
+ TF_LITE_ARG_MIN_MAX(uint8_t, int64_t, int32_t);
break;
case kTfLiteInt32:
- TF_LITE_ARG_MAX(int32_t, int64_t, int32_t);
+ TF_LITE_ARG_MIN_MAX(int32_t, int64_t, int32_t);
break;
default:
return kTfLiteError;
@@ -143,13 +157,13 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
case kTfLiteInt64: {
switch (input->type) {
case kTfLiteFloat32:
- TF_LITE_ARG_MAX(float, int64_t, int64_t);
+ TF_LITE_ARG_MIN_MAX(float, int64_t, int64_t);
break;
case kTfLiteUInt8:
- TF_LITE_ARG_MAX(uint8_t, int64_t, int64_t);
+ TF_LITE_ARG_MIN_MAX(uint8_t, int64_t, int64_t);
break;
case kTfLiteInt32:
- TF_LITE_ARG_MAX(int32_t, int64_t, int64_t);
+ TF_LITE_ARG_MIN_MAX(int32_t, int64_t, int64_t);
break;
default:
return kTfLiteError;
@@ -159,16 +173,30 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteError;
}
}
-#undef TF_LITE_ARG_MAX
+#undef TF_LITE_ARG_MIN_MAX
return kTfLiteOk;
}
-} // namespace arg_max
+TfLiteStatus ArgMinEval(TfLiteContext* context, TfLiteNode* node) {
+ return Eval(context, node, false);
+}
+
+TfLiteStatus ArgMaxEval(TfLiteContext* context, TfLiteNode* node) {
+ return Eval(context, node, true);
+}
+
+} // namespace arg_min_max
TfLiteRegistration* Register_ARG_MAX() {
- static TfLiteRegistration r = {nullptr, nullptr, arg_max::Prepare,
- arg_max::Eval};
+ static TfLiteRegistration r = {nullptr, nullptr, arg_min_max::Prepare,
+ arg_min_max::ArgMaxEval};
+ return &r;
+}
+
+TfLiteRegistration* Register_ARG_MIN() {
+ static TfLiteRegistration r = {nullptr, nullptr, arg_min_max::Prepare,
+ arg_min_max::ArgMinEval};
return &r;
}
diff --git a/tensorflow/contrib/lite/kernels/arg_max_test.cc b/tensorflow/contrib/lite/kernels/arg_min_max_test.cc
index 31b15fe19a..90e5fdc532 100644
--- a/tensorflow/contrib/lite/kernels/arg_max_test.cc
+++ b/tensorflow/contrib/lite/kernels/arg_min_max_test.cc
@@ -24,16 +24,13 @@ namespace {
using ::testing::ElementsAreArray;
template <typename T>
-class ArgMaxOpModel : public SingleOpModel {
+class ArgBaseOpModel : public SingleOpModel {
public:
- ArgMaxOpModel(std::initializer_list<int> input_shape, TensorType input_type,
- TensorType output_type, TensorType index_output_type) {
+ ArgBaseOpModel(std::initializer_list<int> input_shape, TensorType input_type,
+ TensorType output_type, TensorType index_output_type) {
input_ = AddInput(input_type);
axis_ = AddInput(TensorType_INT32);
output_ = AddOutput(output_type);
- SetBuiltinOp(BuiltinOperator_ARG_MAX, BuiltinOptions_ArgMaxOptions,
- CreateArgMaxOptions(builder_, index_output_type).Union());
- BuildInterpreter({input_shape, {1, 1, 1, 1}});
}
int input() { return input_; }
@@ -42,12 +39,42 @@ class ArgMaxOpModel : public SingleOpModel {
std::vector<T> GetOutput() { return ExtractVector<T>(output_); }
std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
- private:
+ protected:
int input_;
int axis_;
int output_;
};
+template <typename T>
+class ArgMaxOpModel : public ArgBaseOpModel<T> {
+ public:
+ ArgMaxOpModel(std::initializer_list<int> input_shape, TensorType input_type,
+ TensorType output_type, TensorType index_output_type)
+ : ArgBaseOpModel<T>(input_shape, input_type, output_type,
+ index_output_type) {
+ ArgBaseOpModel<T>::SetBuiltinOp(
+ BuiltinOperator_ARG_MAX, BuiltinOptions_ArgMaxOptions,
+ CreateArgMaxOptions(ArgBaseOpModel<T>::builder_, index_output_type)
+ .Union());
+ ArgBaseOpModel<T>::BuildInterpreter({input_shape, {1, 1, 1, 1}});
+ }
+};
+
+template <typename T>
+class ArgMinOpModel : public ArgBaseOpModel<T> {
+ public:
+ ArgMinOpModel(std::initializer_list<int> input_shape, TensorType input_type,
+ TensorType output_type, TensorType index_output_type)
+ : ArgBaseOpModel<T>(input_shape, input_type, output_type,
+ index_output_type) {
+ ArgBaseOpModel<T>::SetBuiltinOp(
+ BuiltinOperator_ARG_MIN, BuiltinOptions_ArgMinOptions,
+ CreateArgMinOptions(ArgBaseOpModel<T>::builder_, index_output_type)
+ .Union());
+ ArgBaseOpModel<T>::BuildInterpreter({input_shape, {1, 1, 1, 1}});
+ }
+};
+
TEST(ArgMaxOpTest, GetMaxArgFloat) {
ArgMaxOpModel<int32_t> model({1, 1, 1, 4}, TensorType_FLOAT32,
TensorType_INT32, TensorType_INT32);
@@ -96,6 +123,54 @@ TEST(ArgMaxOpTest, GetMaxArgOutput64) {
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2, 1}));
}
+TEST(ArgMinOpTest, GetMinArgFloat) {
+ ArgMinOpModel<int32_t> model({1, 1, 1, 4}, TensorType_FLOAT32,
+ TensorType_INT32, TensorType_INT32);
+ model.PopulateTensor<float>(model.input(), {0.1, 0.9, 0.7, 0.3});
+ // Currently only support the last dimension.
+ model.PopulateTensor<int>(model.axis(), {3});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({0}));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 1}));
+}
+
+TEST(ArgMinOpTest, GetMinArgInt) {
+ ArgMinOpModel<int32_t> model({1, 1, 1, 4}, TensorType_INT32, TensorType_INT32,
+ TensorType_INT32);
+ model.PopulateTensor<int>(model.input(), {1, 9, 7, 3});
+ // Currently only support the last dimension.
+ model.PopulateTensor<int>(model.axis(), {3});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({0}));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 1}));
+}
+
+TEST(ArgMinOpTest, GetMinArgMulDimensions) {
+ ArgMinOpModel<int32_t> model({1, 1, 2, 4}, TensorType_INT32, TensorType_INT32,
+ TensorType_INT32);
+ model.PopulateTensor<int>(model.input(), {1, 2, 7, 8, 1, 9, 7, 3});
+ // Currently only support the last dimension.
+ model.PopulateTensor<int>(model.axis(), {3});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({0, 0}));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2, 1}));
+}
+
+TEST(ArgMinOpTest, GetMinArgOutput64) {
+ ArgMinOpModel<int64_t> model({1, 1, 2, 4}, TensorType_INT32, TensorType_INT64,
+ TensorType_INT64);
+ model.PopulateTensor<int>(model.input(), {10, 2, 7, 8, 1, 9, 7, 3});
+ // Currently only support the last dimension.
+ model.PopulateTensor<int>(model.axis(), {3});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 0}));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2, 1}));
+}
+
} // namespace
} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/audio_spectrogram.cc b/tensorflow/contrib/lite/kernels/audio_spectrogram.cc
index 91d8dd3fa7..0d2d5e775f 100644
--- a/tensorflow/contrib/lite/kernels/audio_spectrogram.cc
+++ b/tensorflow/contrib/lite/kernels/audio_spectrogram.cc
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/spectrogram.h"
@@ -22,7 +22,7 @@ limitations under the License.
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"
-#include "flatbuffers/flexbuffers.h"
+#include "flatbuffers/flexbuffers.h" // TF:flatbuffers
namespace tflite {
namespace ops {
diff --git a/tensorflow/contrib/lite/kernels/audio_spectrogram_test.cc b/tensorflow/contrib/lite/kernels/audio_spectrogram_test.cc
index 8d460fdfc6..7e4ff6fc16 100644
--- a/tensorflow/contrib/lite/kernels/audio_spectrogram_test.cc
+++ b/tensorflow/contrib/lite/kernels/audio_spectrogram_test.cc
@@ -18,7 +18,7 @@ limitations under the License.
#include <vector>
#include <gtest/gtest.h>
-#include "flatbuffers/flexbuffers.h"
+#include "flatbuffers/flexbuffers.h" // TF:flatbuffers
#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/kernels/test_util.h"
diff --git a/tensorflow/contrib/lite/kernels/basic_rnn.cc b/tensorflow/contrib/lite/kernels/basic_rnn.cc
index 0907547f9f..1aa27602e5 100644
--- a/tensorflow/contrib/lite/kernels/basic_rnn.cc
+++ b/tensorflow/contrib/lite/kernels/basic_rnn.cc
@@ -15,8 +15,8 @@ limitations under the License.
#include <stddef.h>
#include <stdint.h>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/activation_functor.h"
#include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
@@ -31,12 +31,14 @@ constexpr int kInputTensor = 0;
constexpr int kWeightsTensor = 1;
constexpr int kRecurrentWeightsTensor = 2;
constexpr int kBiasTensor = 3;
-constexpr int kHiddenStateTensor = 0;
-constexpr int kOutputTensor = 1;
+constexpr int kHiddenStateTensor = 4;
+
+// Output tensor.
+constexpr int kOutputTensor = 0;
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
auto* scratch_tensor_index = new int;
- context->AddTensors(context, /*tensors_to_add=*/2, scratch_tensor_index);
+ context->AddTensors(context, /*tensors_to_add=*/3, scratch_tensor_index);
return scratch_tensor_index;
}
@@ -46,14 +48,16 @@ void Free(TfLiteContext* context, void* buffer) {
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Check we have all the inputs and outputs we need.
- TF_LITE_ENSURE_EQ(context, node->inputs->size, 4);
- TF_LITE_ENSURE_EQ(context, node->outputs->size, 2);
+ TF_LITE_ENSURE_EQ(context, node->inputs->size, 5);
+ TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* input_weights = GetInput(context, node, kWeightsTensor);
const TfLiteTensor* recurrent_weights =
GetInput(context, node, kRecurrentWeightsTensor);
const TfLiteTensor* bias = GetInput(context, node, kBiasTensor);
+ const TfLiteTensor* hidden_state =
+ GetInput(context, node, kHiddenStateTensor);
// Check all the parameters of tensor match within themselves and match the
// input configuration.
@@ -65,20 +69,12 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ASSERT_EQ(recurrent_weights->dims->data[1], bias->dims->data[0]);
TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
TF_LITE_ENSURE_EQ(context, input_weights->type, recurrent_weights->type);
+ TF_LITE_ENSURE_EQ(context, NumDimensions(hidden_state), 2);
+ TF_LITE_ENSURE_EQ(context, hidden_state->dims->data[0], batch_size);
+ TF_LITE_ENSURE_EQ(context, hidden_state->dims->data[1], num_units);
- TfLiteTensor* hidden_state = GetOutput(context, node, kHiddenStateTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
- // Resize state.
- TfLiteIntArray* hidden_state_size_array = TfLiteIntArrayCreate(2);
- hidden_state_size_array->data[0] = batch_size;
- hidden_state_size_array->data[1] = num_units;
- TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, hidden_state,
- hidden_state_size_array));
-
- // Mark hidden state as a persistent tensor.
- hidden_state->allocation_type = kTfLiteArenaRwPersistent;
-
// Resize output.
TfLiteIntArray* output_size_array = TfLiteIntArrayCreate(2);
output_size_array->data[0] = batch_size;
@@ -91,7 +87,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
if (input->type == kTfLiteFloat32 && input_weights->type == kTfLiteUInt8) {
int* scratch_tensor_index = reinterpret_cast<int*>(node->user_data);
TfLiteIntArrayFree(node->temporaries);
- node->temporaries = TfLiteIntArrayCreate(2);
+ node->temporaries = TfLiteIntArrayCreate(3);
node->temporaries->data[0] = *scratch_tensor_index;
TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/0);
input_quantized->type = kTfLiteUInt8;
@@ -114,6 +110,16 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
context->ResizeTensor(context, hidden_state_quantized,
hidden_state_quantized_size));
}
+ node->temporaries->data[2] = *scratch_tensor_index + 2;
+ TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/2);
+ scaling_factors->type = kTfLiteFloat32;
+ scaling_factors->allocation_type = kTfLiteArenaRw;
+ TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1);
+ scaling_factors_size->data[0] = batch_size;
+ if (!TfLiteIntArrayEqual(scaling_factors->dims, scaling_factors_size)) {
+ TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors,
+ scaling_factors_size));
+ }
}
return kTfLiteOk;
@@ -145,14 +151,14 @@ TfLiteStatus EvalFloat(const TfLiteTensor* input,
return kTfLiteOk;
}
-TfLiteStatus EvalQuantized(const TfLiteTensor* input,
- const TfLiteTensor* input_weights,
- const TfLiteTensor* recurrent_weights,
- const TfLiteTensor* bias,
- const TfLiteRNNParams* params,
- TfLiteTensor* input_scratch,
- TfLiteTensor* hidden_state_scratch,
- TfLiteTensor* hidden_state, TfLiteTensor* output) {
+TfLiteStatus EvalHybrid(const TfLiteTensor* input,
+ const TfLiteTensor* input_weights,
+ const TfLiteTensor* recurrent_weights,
+ const TfLiteTensor* bias, const TfLiteRNNParams* params,
+ TfLiteTensor* input_scratch,
+ TfLiteTensor* hidden_state_scratch,
+ TfLiteTensor* scaling_factors,
+ TfLiteTensor* hidden_state, TfLiteTensor* output) {
const int batch_size = input->dims->data[0];
const int num_units = input_weights->dims->data[0];
const int input_size = input->dims->data[1];
@@ -176,12 +182,14 @@ TfLiteStatus EvalQuantized(const TfLiteTensor* input,
reinterpret_cast<int8_t*>(input_scratch->data.uint8);
int8_t* quantized_hidden_state_ptr =
reinterpret_cast<int8_t*>(hidden_state_scratch->data.uint8);
+ float* scaling_factors_ptr = scaling_factors->data.f;
kernel_utils::RnnBatchStep(
input_ptr_batch, input_weights_ptr, input_weights_scale,
recurrent_weights_ptr, recurrent_weights_scale, bias_ptr, input_size,
num_units, batch_size, params->activation, quantized_input_ptr,
- quantized_hidden_state_ptr, hidden_state_ptr_batch, output_ptr_batch);
+ quantized_hidden_state_ptr, scaling_factors_ptr, hidden_state_ptr_batch,
+ output_ptr_batch);
return kTfLiteOk;
}
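
The scaling_factors temporary exists because the hybrid path quantizes each
batch row of the float input to int8 on the fly and must remember, per row,
the scale needed to turn the integer accumulators back into floats. A sketch
of the usual symmetric per-row quantization (illustrative only, not the
kernel_utils implementation):

    #include <algorithm>
    #include <cmath>
    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // Symmetrically quantize one float row to int8 and return the per-row
    // scaling factor that maps the int8 values back to floats.
    float SymmetricQuantizeRow(const std::vector<float>& row,
                               std::vector<int8_t>* quantized) {
      float max_abs = 0.f;
      for (float v : row) max_abs = std::max(max_abs, std::fabs(v));
      quantized->assign(row.size(), 0);
      if (max_abs == 0.f) return 0.f;  // all-zero row: nothing to scale
      const float scaling_factor = max_abs / 127.f;  // int8 range [-127, 127]
      for (std::size_t i = 0; i < row.size(); ++i) {
        (*quantized)[i] =
            static_cast<int8_t>(std::round(row[i] / scaling_factor));
      }
      return scaling_factor;
    }
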
@@ -193,7 +201,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* recurrent_weights =
GetInput(context, node, kRecurrentWeightsTensor);
const TfLiteTensor* bias = GetInput(context, node, kBiasTensor);
- TfLiteTensor* hidden_state = GetOutput(context, node, kHiddenStateTensor);
+ TfLiteTensor* hidden_state =
+ &context->tensors[node->inputs->data[kHiddenStateTensor]];
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
// We already checked that weight types are consistent, so branch on one.
@@ -205,12 +214,14 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
// TODO(mirkov): implement eval with quantized inputs as well.
TfLiteTensor* input_quantized = GetTemporary(context, node, 0);
TfLiteTensor* hidden_state_quantized = GetTemporary(context, node, 1);
- return EvalQuantized(input, input_weights, recurrent_weights, bias,
- params, input_quantized, hidden_state_quantized,
- hidden_state, output);
+ TfLiteTensor* scaling_factors = GetTemporary(context, node, 2);
+ return EvalHybrid(input, input_weights, recurrent_weights, bias, params,
+ input_quantized, hidden_state_quantized,
+ scaling_factors, hidden_state, output);
}
default:
- context->ReportError(context, "Type not currently supported.");
+ context->ReportError(context, "Type %d not currently supported.",
+ input_weights->type);
return kTfLiteError;
}
return kTfLiteOk;
diff --git a/tensorflow/contrib/lite/kernels/basic_rnn_test.cc b/tensorflow/contrib/lite/kernels/basic_rnn_test.cc
index 96465fcaf0..d179735404 100644
--- a/tensorflow/contrib/lite/kernels/basic_rnn_test.cc
+++ b/tensorflow/contrib/lite/kernels/basic_rnn_test.cc
@@ -181,15 +181,16 @@ class RNNOpModel : public SingleOpModel {
weights_ = AddInput(weights);
recurrent_weights_ = AddInput(recurrent_weights);
bias_ = AddInput(TensorType_FLOAT32);
- hidden_state_ = AddOutput(TensorType_FLOAT32);
+ hidden_state_ = AddInput(TensorType_FLOAT32, true);
output_ = AddOutput(TensorType_FLOAT32);
SetBuiltinOp(
BuiltinOperator_RNN, BuiltinOptions_RNNOptions,
CreateRNNOptions(builder_, ActivationFunctionType_RELU).Union());
- BuildInterpreter({{batches_, input_size_},
- {units_, input_size_},
- {units_, units_},
- {units_}});
+ BuildInterpreter({{batches_, input_size_}, // input tensor
+ {units_, input_size_}, // weights tensor
+ {units_, units_}, // recurrent weights tensor
+ {units_}, // bias tensor
+ {batches_, units_}}); // hidden state tensor
}
void SetBias(std::initializer_list<float> f) { PopulateTensor(bias_, f); }
@@ -210,14 +211,6 @@ class RNNOpModel : public SingleOpModel {
PopulateTensor(input_, offset, begin, end);
}
- void ResetHiddenState() {
- const int zero_buffer_size = units_ * batches_;
- std::unique_ptr<float[]> zero_buffer(new float[zero_buffer_size]);
- memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float));
- PopulateTensor(hidden_state_, 0, zero_buffer.get(),
- zero_buffer.get() + zero_buffer_size);
- }
-
std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
int input_size() { return input_size_; }
@@ -258,7 +251,6 @@ TEST(RnnOpTest, BlackBoxTest) {
rnn.SetBias(rnn_bias);
rnn.SetRecurrentWeights(rnn_recurrent_weights);
- rnn.ResetHiddenState();
const int input_sequence_size = sizeof(rnn_input) / sizeof(float) /
(rnn.input_size() * rnn.num_batches());
@@ -286,7 +278,6 @@ TEST(HybridRnnOpTest, BlackBoxTest) {
rnn.SetBias(rnn_bias);
rnn.SetRecurrentWeights(rnn_recurrent_weights);
- rnn.ResetHiddenState();
const int input_sequence_size = sizeof(rnn_input) / sizeof(float) /
(rnn.input_size() * rnn.num_batches());
diff --git a/tensorflow/contrib/lite/kernels/batch_to_space_nd.cc b/tensorflow/contrib/lite/kernels/batch_to_space_nd.cc
index 262e1aeab1..fe2865dfb9 100644
--- a/tensorflow/contrib/lite/kernels/batch_to_space_nd.cc
+++ b/tensorflow/contrib/lite/kernels/batch_to_space_nd.cc
@@ -14,8 +14,8 @@ limitations under the License.
==============================================================================*/
#include <string.h>
#include <vector>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
@@ -125,14 +125,14 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
}
#define TF_LITE_BATCH_TO_SPACE_ND(type, scalar) \
- type::BatchToSpaceND(GetTensorData<scalar>(op_context.input), \
- GetTensorDims(op_context.input), \
+ type::BatchToSpaceND(GetTensorShape(op_context.input), \
+ GetTensorData<scalar>(op_context.input), \
+ GetTensorShape(op_context.block_shape), \
GetTensorData<int32_t>(op_context.block_shape), \
- GetTensorDims(op_context.block_shape), \
+ GetTensorShape(op_context.crops), \
GetTensorData<int32_t>(op_context.crops), \
- GetTensorDims(op_context.crops), \
- GetTensorData<scalar>(op_context.output), \
- GetTensorDims(op_context.output))
+ GetTensorShape(op_context.output), \
+ GetTensorData<scalar>(op_context.output))
switch (op_context.input->type) { // Already know in/out types are same.
case kTfLiteFloat32:
if (kernel_type == kReference) {
@@ -163,8 +163,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
}
break;
default:
- context->ReportError(context,
- "Type is currently not supported by BatchToSpace.");
+ context->ReportError(
+ context, "Type %d is currently not supported by BatchToSpace.",
+ op_context.input->type);
return kTfLiteError;
}
#undef TF_LITE_BATCH_TO_SPACE_ND
diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc
index 3425288f02..a326827b1e 100644
--- a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc
+++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc
@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include <unistd.h>
#include <cassert>
#include <cmath>
#include <cstdio>
@@ -21,12 +20,13 @@ limitations under the License.
#include <iostream>
#include <limits>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/activation_functor.h"
#include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/lstm_eval.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"
namespace tflite {
@@ -95,18 +95,54 @@ constexpr int kBwProjectionWeightsTensor = 33; // Optional
// Projection bias tensor of size {n_output}
constexpr int kBwProjectionBiasTensor = 34; // Optional
-// Output tensors.
-constexpr int kFwOutputStateTensor = 0;
-constexpr int kFwCellStateTensor = 1;
-constexpr int kFwOutputTensor = 2;
+// Stateful input tensors that are variables and will be modified by the Op.
+// Activation state tensors of size {n_batch, n_output}
+constexpr int kFwInputActivationStateTensor = 35;
+// Cell state tensors of size {n_batch, n_cell}
+constexpr int kFwInputCellStateTensor = 36;
+// Activation state tensors of size {n_batch, n_output}
+constexpr int kBwInputActivationStateTensor = 37;
+// Cell state tensors of size {n_batch, n_cell}
+constexpr int kBwInputCellStateTensor = 38;
+
+// Auxiliary input and weights when stacking.
+constexpr int kAuxInputTensor = 39; // Optional
+// Forward weights.
+constexpr int kFwAuxInputToInputWeightsTensor = 40; // Optional
+constexpr int kFwAuxInputToForgetWeightsTensor = 41; // Optional
+constexpr int kFwAuxInputToCellWeightsTensor = 42; // Optional
+constexpr int kFwAuxInputToOutputWeightsTensor = 43; // Optional
+// Backward weights.
+constexpr int kBwAuxInputToInputWeightsTensor = 44; // Optional
+constexpr int kBwAuxInputToForgetWeightsTensor = 45; // Optional
+constexpr int kBwAuxInputToCellWeightsTensor = 46; // Optional
+constexpr int kBwAuxInputToOutputWeightsTensor = 47; // Optional
-constexpr int kBwOutputStateTensor = 3;
-constexpr int kBwCellStateTensor = 4;
-constexpr int kBwOutputTensor = 5;
+// Output tensors.
+constexpr int kFwOutputTensor = 0;
+constexpr int kBwOutputTensor = 1; // Ignored if merge_outputs is set.
+
+// Temporary tensors.
+enum TemporaryTensor {
+ // Scratch buffers for input, forget, etc. gates
+ kFwScratchBuffer = 0,
+ kBwScratchBuffer = 1,
+ // Quantized tensors needed for the hybrid kernel.
+ kInputQuantized = 2,
+ kAuxInputQuantized = 3, // Quantized tensor needed for auxiliary input.
+ kFwActivationStateQuantized = 4,
+ kBwActivationStateQuantized = 5,
+ kFwCellStateQuantized = 6,
+ kBwCellStateQuantized = 7,
+ kScalingFactors = 8,
+ kProductScalingFactors = 9,
+ kRecoveredCellWeights = 10,
+ kNumTemporaryTensors = 11
+};
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
auto* scratch_tensor_index = new int;
- context->AddTensors(context, 2, scratch_tensor_index);
+ context->AddTensors(context, kNumTemporaryTensors, scratch_tensor_index);
return scratch_tensor_index;
}
@@ -127,7 +163,8 @@ TfLiteStatus CheckLstmTensorDimensions(
int input_gate_bias_tensor, int forget_gate_bias_tensor,
int cell_gate_bias_tensor, int output_gate_bias_tensor,
int projection_weights_tensor, int projection_bias_tensor) {
- auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);
+ const auto* params = reinterpret_cast<TfLiteBidirectionalSequenceLSTMParams*>(
+ node->builtin_data);
// Making sure clipping parameters have valid values.
// == 0 means no clipping
@@ -276,45 +313,55 @@ TfLiteStatus CheckLstmTensorDimensions(
TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
TfLiteNode* node, int n_input,
int n_output, int n_cell) {
- CheckLstmTensorDimensions(
- context, node, n_input, n_output, n_cell, kFwInputToInputWeightsTensor,
- kFwInputToForgetWeightsTensor, kFwInputToCellWeightsTensor,
- kFwInputToOutputWeightsTensor, kFwRecurrentToInputWeightsTensor,
- kFwRecurrentToForgetWeightsTensor, kFwRecurrentToCellWeightsTensor,
- kFwRecurrentToOutputWeightsTensor, kFwCellToInputWeightsTensor,
- kFwCellToForgetWeightsTensor, kFwCellToOutputWeightsTensor,
- kFwInputGateBiasTensor, kFwForgetGateBiasTensor, kFwCellGateBiasTensor,
- kFwOutputGateBiasTensor, kFwProjectionWeightsTensor,
- kFwProjectionBiasTensor);
-
- CheckLstmTensorDimensions(
- context, node, n_input, n_output, n_cell, kBwInputToInputWeightsTensor,
- kBwInputToForgetWeightsTensor, kBwInputToCellWeightsTensor,
- kBwInputToOutputWeightsTensor, kBwRecurrentToInputWeightsTensor,
- kBwRecurrentToForgetWeightsTensor, kBwRecurrentToCellWeightsTensor,
- kBwRecurrentToOutputWeightsTensor, kBwCellToInputWeightsTensor,
- kBwCellToForgetWeightsTensor, kBwCellToOutputWeightsTensor,
- kBwInputGateBiasTensor, kBwForgetGateBiasTensor, kBwCellGateBiasTensor,
- kBwOutputGateBiasTensor, kBwProjectionWeightsTensor,
- kBwProjectionBiasTensor);
+ TF_LITE_ENSURE_OK(
+ context,
+ CheckLstmTensorDimensions(
+ context, node, n_input, n_output, n_cell,
+ kFwInputToInputWeightsTensor, kFwInputToForgetWeightsTensor,
+ kFwInputToCellWeightsTensor, kFwInputToOutputWeightsTensor,
+ kFwRecurrentToInputWeightsTensor, kFwRecurrentToForgetWeightsTensor,
+ kFwRecurrentToCellWeightsTensor, kFwRecurrentToOutputWeightsTensor,
+ kFwCellToInputWeightsTensor, kFwCellToForgetWeightsTensor,
+ kFwCellToOutputWeightsTensor, kFwInputGateBiasTensor,
+ kFwForgetGateBiasTensor, kFwCellGateBiasTensor,
+ kFwOutputGateBiasTensor, kFwProjectionWeightsTensor,
+ kFwProjectionBiasTensor));
+
+ TF_LITE_ENSURE_OK(
+ context,
+ CheckLstmTensorDimensions(
+ context, node, n_input, n_output, n_cell,
+ kBwInputToInputWeightsTensor, kBwInputToForgetWeightsTensor,
+ kBwInputToCellWeightsTensor, kBwInputToOutputWeightsTensor,
+ kBwRecurrentToInputWeightsTensor, kBwRecurrentToForgetWeightsTensor,
+ kBwRecurrentToCellWeightsTensor, kBwRecurrentToOutputWeightsTensor,
+ kBwCellToInputWeightsTensor, kBwCellToForgetWeightsTensor,
+ kBwCellToOutputWeightsTensor, kBwInputGateBiasTensor,
+ kBwForgetGateBiasTensor, kBwCellGateBiasTensor,
+ kBwOutputGateBiasTensor, kBwProjectionWeightsTensor,
+ kBwProjectionBiasTensor));
// Check if Forward and Backward tensors match along required dimensions.
return kTfLiteOk;
}
-// Resize the output, state and scratch tensors based on the sizes of the input
+// Resize the output and scratch tensors based on the sizes of the input
// tensors. Also check that the size of the input tensors match each other.
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
int* scratch_tensor_index = reinterpret_cast<int*>(node->user_data);
+ const auto* params = reinterpret_cast<TfLiteBidirectionalSequenceLSTMParams*>(
+ node->builtin_data);
// Check we have all the inputs and outputs we need.
- TF_LITE_ENSURE_EQ(context, node->inputs->size, 35);
- TF_LITE_ENSURE_EQ(context, node->outputs->size, 6);
+ TF_LITE_ENSURE_EQ(context, node->inputs->size, 48);
+ TF_LITE_ENSURE_EQ(context, node->outputs->size,
+ params->merge_outputs ? 1 : 2);
// Inferring batch size, number of outputs and sequence length and
// number of cells from the input tensors.
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
- TF_LITE_ENSURE(context, input->dims->size > 1);
+ TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
+ TF_LITE_ENSURE_EQ(context, input->dims->size, 3);
const int max_time = input->dims->data[0];
const int n_batch = input->dims->data[1];
const int n_input = input->dims->data[2];
@@ -326,6 +373,13 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, fw_input_to_output_weights->dims->data[1],
n_input);
+ const TfLiteTensor* bw_input_to_output_weights =
+ GetInput(context, node, kBwInputToOutputWeightsTensor);
+ const int n_bw_cell = bw_input_to_output_weights->dims->data[0];
+ TF_LITE_ENSURE_EQ(context, bw_input_to_output_weights->dims->size, 2);
+ TF_LITE_ENSURE_EQ(context, bw_input_to_output_weights->dims->data[1],
+ n_input);
+
const TfLiteTensor* fw_recurrent_to_output_weights =
GetInput(context, node, kFwRecurrentToOutputWeightsTensor);
TF_LITE_ENSURE_EQ(context, fw_recurrent_to_output_weights->dims->size, 2);
@@ -333,49 +387,105 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
n_fw_cell);
const int n_fw_output = fw_recurrent_to_output_weights->dims->data[1];
+ const TfLiteTensor* bw_recurrent_to_output_weights =
+ GetInput(context, node, kBwRecurrentToOutputWeightsTensor);
+ TF_LITE_ENSURE_EQ(context, bw_recurrent_to_output_weights->dims->size, 2);
+ TF_LITE_ENSURE_EQ(context, bw_recurrent_to_output_weights->dims->data[0],
+ n_bw_cell);
+ const int n_bw_output = bw_recurrent_to_output_weights->dims->data[1];
+
// Check that input tensor dimensions matches with each other.
- CheckInputTensorDimensions(context, node, n_input, n_fw_output, n_fw_cell);
+ TF_LITE_ENSURE_OK(
+ context, CheckInputTensorDimensions(context, node, n_input, n_fw_output,
+ n_fw_cell));
+
+ // Get (optional) auxiliary inputs and weights.
+ const TfLiteTensor* aux_input =
+ GetOptionalInputTensor(context, node, kAuxInputTensor);
+ const TfLiteTensor* fw_aux_input_to_input_weights =
+ GetOptionalInputTensor(context, node, kFwAuxInputToInputWeightsTensor);
+ const TfLiteTensor* fw_aux_input_to_forget_weights =
+ GetOptionalInputTensor(context, node, kFwAuxInputToForgetWeightsTensor);
+ const TfLiteTensor* fw_aux_input_to_cell_weights =
+ GetOptionalInputTensor(context, node, kFwAuxInputToCellWeightsTensor);
+ const TfLiteTensor* fw_aux_input_to_output_weights =
+ GetOptionalInputTensor(context, node, kFwAuxInputToOutputWeightsTensor);
+ const TfLiteTensor* bw_aux_input_to_input_weights =
+ GetOptionalInputTensor(context, node, kBwAuxInputToInputWeightsTensor);
+ const TfLiteTensor* bw_aux_input_to_forget_weights =
+ GetOptionalInputTensor(context, node, kBwAuxInputToForgetWeightsTensor);
+ const TfLiteTensor* bw_aux_input_to_cell_weights =
+ GetOptionalInputTensor(context, node, kBwAuxInputToCellWeightsTensor);
+ const TfLiteTensor* bw_aux_input_to_output_weights =
+ GetOptionalInputTensor(context, node, kBwAuxInputToOutputWeightsTensor);
+
+ const bool aux_inputs_all_or_none =
+ ((aux_input != nullptr) && (fw_aux_input_to_cell_weights != nullptr) &&
+ (fw_aux_input_to_forget_weights != nullptr) &&
+ (fw_aux_input_to_output_weights != nullptr) &&
+ (bw_aux_input_to_cell_weights != nullptr) &&
+ (bw_aux_input_to_forget_weights != nullptr) &&
+ (bw_aux_input_to_output_weights != nullptr)) ||
+ ((fw_aux_input_to_cell_weights == nullptr) &&
+ (fw_aux_input_to_forget_weights == nullptr) &&
+ (fw_aux_input_to_output_weights == nullptr) &&
+ (bw_aux_input_to_cell_weights == nullptr) &&
+ (bw_aux_input_to_forget_weights == nullptr) &&
+ (bw_aux_input_to_output_weights == nullptr));
+ TF_LITE_ENSURE(context, aux_inputs_all_or_none);
+ const bool has_aux_input = (aux_input != nullptr);
+
+ if (has_aux_input) {
+ // Check that aux_input has the same dimensions (except last) as the input.
+ TF_LITE_ASSERT_EQ(aux_input->dims->data[0], input->dims->data[0]);
+ TF_LITE_ASSERT_EQ(aux_input->dims->data[1], input->dims->data[1]);
+ }
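The auxiliary tensors are accepted only as a complete set: either the aux input and all aux weights are supplied, or none of them are. A small self-contained sketch of that all-or-none predicate follows, with raw pointers standing in for TfLiteTensor* (the helper name AllOrNone is made up for illustration):

#include <array>
#include <cassert>

bool AllOrNone(const float* aux_input,
               const std::array<const float*, 3>& aux_weights) {
  bool all_present = (aux_input != nullptr);
  bool none_present = true;
  for (const float* w : aux_weights) {
    all_present = all_present && (w != nullptr);
    none_present = none_present && (w == nullptr);
  }
  return all_present || none_present;
}

int main() {
  float in = 0.f, w = 0.f;
  assert(AllOrNone(&in, {&w, &w, &w}));                     // complete set: ok
  assert(AllOrNone(nullptr, {nullptr, nullptr, nullptr}));  // nothing: ok
  assert(!AllOrNone(&in, {&w, nullptr, &w}));               // partial: rejected
  return 0;
}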
- // Get the pointer to output, state and scratch buffer tensors.
+ // Get the pointer to output, activation_state and cell_state buffer tensors.
TfLiteTensor* fw_output = GetOutput(context, node, kFwOutputTensor);
- TfLiteTensor* fw_output_state =
- GetOutput(context, node, kFwOutputStateTensor);
- TfLiteTensor* fw_cell_state = GetOutput(context, node, kFwCellStateTensor);
-
- // Resize the output, output_state and cell_state tensors.
+ TfLiteTensor* fw_activation_state =
+ GetVariableInput(context, node, kFwInputActivationStateTensor);
+ TfLiteTensor* fw_cell_state =
+ GetVariableInput(context, node, kFwInputCellStateTensor);
+
+  // Check the shape of the input state tensors.
+  // These tensors may be 1D or 2D. Any shape is fine as long as the total
+  // number of elements is correct.
+ TF_LITE_ENSURE_EQ(context, NumElements(fw_activation_state),
+ n_batch * n_fw_output);
+ TF_LITE_ENSURE_EQ(context, NumElements(fw_cell_state), n_batch * n_fw_cell);
+
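Because the state tensors are validated by element count rather than rank, a {n_batch, n_output} tensor and its flattened {n_batch * n_output} equivalent are both accepted. A minimal sketch of that check, assuming a trivial NumElements over a dimension vector (an illustrative stand-in, not the TF Lite helper itself):

#include <cassert>
#include <functional>
#include <numeric>
#include <vector>

// Illustrative stand-in for the NumElements() helper used above.
int NumElements(const std::vector<int>& dims) {
  return std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<int>());
}

int main() {
  const int n_batch = 2, n_output = 16;
  const std::vector<int> state_2d = {n_batch, n_output};
  const std::vector<int> state_1d = {n_batch * n_output};
  // Both layouts satisfy the element-count check used in Prepare().
  assert(NumElements(state_2d) == n_batch * n_output);
  assert(NumElements(state_1d) == n_batch * n_output);
  return 0;
}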
+ // Resize the output tensors.
TfLiteIntArray* fw_output_size = TfLiteIntArrayCreate(3);
fw_output_size->data[0] = max_time;
fw_output_size->data[1] = n_batch;
- fw_output_size->data[2] = n_fw_output;
+ fw_output_size->data[2] =
+ params->merge_outputs ? n_bw_output + n_fw_output : n_fw_output;
TF_LITE_ENSURE_OK(context,
context->ResizeTensor(context, fw_output, fw_output_size));
- TfLiteIntArray* fw_output_state_size = TfLiteIntArrayCreate(2);
- fw_output_state_size->data[0] = n_batch;
- fw_output_state_size->data[1] = n_fw_output;
- TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, fw_output_state,
- fw_output_state_size));
+ // The weights are of consistent type, so it suffices to check one.
+ const bool is_hybrid_op = (fw_input_to_output_weights->type == kTfLiteUInt8);
- TfLiteIntArray* fw_cell_size = TfLiteIntArrayCreate(2);
- fw_cell_size->data[0] = n_batch;
- fw_cell_size->data[1] = n_fw_cell;
- TF_LITE_ENSURE_OK(
- context, context->ResizeTensor(context, fw_cell_state, fw_cell_size));
-
- // Create a scratch buffer tensor.
TfLiteIntArrayFree(node->temporaries);
- node->temporaries = TfLiteIntArrayCreate(2);
- node->temporaries->data[0] = *scratch_tensor_index;
- TfLiteTensor* fw_scratch_buffer = GetTemporary(context, node, /*index=*/0);
+ if (is_hybrid_op) {
+ node->temporaries = TfLiteIntArrayCreate(kNumTemporaryTensors);
+ } else {
+ node->temporaries = TfLiteIntArrayCreate(2); // the two scratch buffers.
+ }
+ // Create a scratch buffer tensor.
+ node->temporaries->data[kFwScratchBuffer] = *scratch_tensor_index;
+ TfLiteTensor* fw_scratch_buffer =
+ GetTemporary(context, node, kFwScratchBuffer);
fw_scratch_buffer->type = input->type;
fw_scratch_buffer->allocation_type = kTfLiteArenaRw;
- // Mark state tensors as persistent tensors.
- fw_output_state->allocation_type = kTfLiteArenaRwPersistent;
- fw_cell_state->allocation_type = kTfLiteArenaRwPersistent;
-
const TfLiteTensor* fw_input_to_input_weights =
GetOptionalInputTensor(context, node, kFwInputToInputWeightsTensor);
+ if (has_aux_input) {
+ TF_LITE_ENSURE_EQ(context, fw_aux_input_to_input_weights->dims->data[0],
+ fw_input_to_input_weights->dims->data[0]);
+ }
const bool fw_use_cifg = (fw_input_to_input_weights == nullptr);
TfLiteIntArray* fw_scratch_buffer_size = TfLiteIntArrayCreate(2);
fw_scratch_buffer_size->data[0] = n_batch;
@@ -389,61 +499,50 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, fw_scratch_buffer,
fw_scratch_buffer_size));
// Same for the backward cell.
- const TfLiteTensor* bw_input_to_output_weights =
- GetInput(context, node, kBwInputToOutputWeightsTensor);
- const int n_bw_cell = bw_input_to_output_weights->dims->data[0];
- TF_LITE_ENSURE_EQ(context, bw_input_to_output_weights->dims->size, 2);
- TF_LITE_ENSURE_EQ(context, bw_input_to_output_weights->dims->data[1],
- n_input);
-
- const TfLiteTensor* bw_recurrent_to_output_weights =
- GetInput(context, node, kBwRecurrentToOutputWeightsTensor);
- TF_LITE_ENSURE_EQ(context, bw_recurrent_to_output_weights->dims->size, 2);
- TF_LITE_ENSURE_EQ(context, bw_recurrent_to_output_weights->dims->data[0],
- n_bw_cell);
- const int n_bw_output = bw_recurrent_to_output_weights->dims->data[1];
// Check that input tensor dimensions matches with each other.
- CheckInputTensorDimensions(context, node, n_input, n_bw_output, n_bw_cell);
-
- // Get the pointer to output, output_state and cell_state buffer tensors.
- TfLiteTensor* bw_output = GetOutput(context, node, kBwOutputTensor);
- TfLiteTensor* bw_output_state =
- GetOutput(context, node, kBwOutputStateTensor);
- TfLiteTensor* bw_cell_state = GetOutput(context, node, kBwCellStateTensor);
-
- // Resize the output, output_state and cell_state tensors.
- TfLiteIntArray* bw_output_size = TfLiteIntArrayCreate(3);
- bw_output_size->data[0] = max_time;
- bw_output_size->data[1] = n_batch;
- bw_output_size->data[2] = n_bw_output;
- TF_LITE_ENSURE_OK(context,
- context->ResizeTensor(context, bw_output, bw_output_size));
-
- TfLiteIntArray* bw_output_state_size = TfLiteIntArrayCreate(2);
- bw_output_state_size->data[0] = n_batch;
- bw_output_state_size->data[1] = n_bw_output;
- TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, bw_output_state,
- bw_output_state_size));
-
- TfLiteIntArray* bw_cell_size = TfLiteIntArrayCreate(2);
- bw_cell_size->data[0] = n_batch;
- bw_cell_size->data[1] = n_bw_cell;
TF_LITE_ENSURE_OK(
- context, context->ResizeTensor(context, bw_cell_state, bw_cell_size));
+ context, CheckInputTensorDimensions(context, node, n_input, n_bw_output,
+ n_bw_cell));
+
+ // Get the pointer to activation_state and cell_state buffer tensors.
+ TfLiteTensor* bw_activation_state =
+ GetVariableInput(context, node, kBwInputActivationStateTensor);
+ TfLiteTensor* bw_cell_state =
+ GetVariableInput(context, node, kBwInputCellStateTensor);
+
+ // Resize the output tensors.
+ if (!params->merge_outputs) {
+ TfLiteTensor* bw_output = GetOutput(context, node, kBwOutputTensor);
+ TfLiteIntArray* bw_output_size = TfLiteIntArrayCreate(3);
+ bw_output_size->data[0] = max_time;
+ bw_output_size->data[1] = n_batch;
+ bw_output_size->data[2] = n_bw_output;
+ TF_LITE_ENSURE_OK(
+ context, context->ResizeTensor(context, bw_output, bw_output_size));
+ }
+
+  // Check the shape of the input state tensors.
+  // These tensors may be 1D or 2D. Any shape is fine as long as the total
+  // number of elements is correct.
+ TF_LITE_ENSURE_EQ(context, NumElements(bw_activation_state),
+ n_batch * n_bw_output);
+ TF_LITE_ENSURE_EQ(context, NumElements(bw_cell_state), n_batch * n_bw_cell);
// Create a scratch buffer tensor.
- node->temporaries->data[1] = *(scratch_tensor_index) + 1;
- TfLiteTensor* bw_scratch_buffer = GetTemporary(context, node, /*index=*/1);
+ node->temporaries->data[kBwScratchBuffer] =
+ *(scratch_tensor_index) + kBwScratchBuffer;
+ TfLiteTensor* bw_scratch_buffer =
+ GetTemporary(context, node, kBwScratchBuffer);
bw_scratch_buffer->type = input->type;
bw_scratch_buffer->allocation_type = kTfLiteArenaRw;
- // Mark state tensors as persistent tensors.
- bw_output_state->allocation_type = kTfLiteArenaRwPersistent;
- bw_cell_state->allocation_type = kTfLiteArenaRwPersistent;
-
const TfLiteTensor* bw_input_to_input_weights =
GetOptionalInputTensor(context, node, kBwInputToInputWeightsTensor);
+ if (has_aux_input) {
+ TF_LITE_ENSURE_EQ(context, bw_aux_input_to_input_weights->dims->data[0],
+ bw_input_to_input_weights->dims->data[0]);
+ }
const bool bw_use_cifg = (bw_input_to_input_weights == nullptr);
TfLiteIntArray* bw_scratch_buffer_size = TfLiteIntArrayCreate(2);
bw_scratch_buffer_size->data[0] = n_batch;
@@ -456,18 +555,153 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, bw_scratch_buffer,
bw_scratch_buffer_size));
+ if (is_hybrid_op) {
+ // Allocate temporary tensors to store quantized values of input, aux_input
+ // (if present), activation_state and cell_state tensors.
+ node->temporaries->data[kInputQuantized] =
+ *scratch_tensor_index + kInputQuantized;
+ TfLiteTensor* input_quantized =
+ GetTemporary(context, node, kInputQuantized);
+ input_quantized->type = kTfLiteUInt8;
+ input_quantized->allocation_type = kTfLiteArenaRw;
+ if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) {
+ TfLiteIntArray* input_quantized_size = TfLiteIntArrayCopy(input->dims);
+ TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized,
+ input_quantized_size));
+ }
+
+ if (has_aux_input) {
+ node->temporaries->data[kAuxInputQuantized] =
+ *scratch_tensor_index + kAuxInputQuantized;
+ TfLiteTensor* aux_input_quantized =
+ GetTemporary(context, node, kAuxInputQuantized);
+ aux_input_quantized->type = kTfLiteUInt8;
+ aux_input_quantized->allocation_type = kTfLiteArenaRw;
+ if (!TfLiteIntArrayEqual(aux_input_quantized->dims, aux_input->dims)) {
+ TfLiteIntArray* aux_input_quantized_size =
+ TfLiteIntArrayCopy(aux_input->dims);
+ TF_LITE_ENSURE_OK(context,
+ context->ResizeTensor(context, aux_input_quantized,
+ aux_input_quantized_size));
+ }
+ }
+
+ node->temporaries->data[kFwActivationStateQuantized] =
+ *scratch_tensor_index + kFwActivationStateQuantized;
+ TfLiteTensor* fw_activation_state_quantized =
+ GetTemporary(context, node, kFwActivationStateQuantized);
+ fw_activation_state_quantized->type = kTfLiteUInt8;
+ fw_activation_state_quantized->allocation_type = kTfLiteArenaRw;
+ if (!TfLiteIntArrayEqual(fw_activation_state_quantized->dims,
+ fw_activation_state->dims)) {
+ TfLiteIntArray* fw_activation_state_quantized_size =
+ TfLiteIntArrayCopy(fw_activation_state->dims);
+ TF_LITE_ENSURE_OK(
+ context, context->ResizeTensor(context, fw_activation_state_quantized,
+ fw_activation_state_quantized_size));
+ }
+ node->temporaries->data[kBwActivationStateQuantized] =
+ *scratch_tensor_index + kBwActivationStateQuantized;
+ TfLiteTensor* bw_activation_state_quantized =
+ GetTemporary(context, node, kBwActivationStateQuantized);
+ bw_activation_state_quantized->type = kTfLiteUInt8;
+ bw_activation_state_quantized->allocation_type = kTfLiteArenaRw;
+ if (!TfLiteIntArrayEqual(bw_activation_state_quantized->dims,
+ bw_activation_state->dims)) {
+ TfLiteIntArray* bw_activation_state_quantized_size =
+ TfLiteIntArrayCopy(bw_activation_state->dims);
+ TF_LITE_ENSURE_OK(
+ context, context->ResizeTensor(context, bw_activation_state_quantized,
+ bw_activation_state_quantized_size));
+ }
+ node->temporaries->data[kFwCellStateQuantized] =
+ *scratch_tensor_index + kFwCellStateQuantized;
+ TfLiteTensor* fw_cell_state_quantized =
+ GetTemporary(context, node, kFwCellStateQuantized);
+ fw_cell_state_quantized->type = kTfLiteUInt8;
+ fw_cell_state_quantized->allocation_type = kTfLiteArenaRw;
+ if (!TfLiteIntArrayEqual(fw_cell_state_quantized->dims,
+ fw_cell_state->dims)) {
+ TfLiteIntArray* fw_cell_state_quantized_size =
+ TfLiteIntArrayCopy(fw_cell_state->dims);
+ TF_LITE_ENSURE_OK(context,
+ context->ResizeTensor(context, fw_cell_state_quantized,
+ fw_cell_state_quantized_size));
+ }
+ node->temporaries->data[kBwCellStateQuantized] =
+ *scratch_tensor_index + kBwCellStateQuantized;
+ TfLiteTensor* bw_cell_state_quantized =
+ GetTemporary(context, node, kBwCellStateQuantized);
+ bw_cell_state_quantized->type = kTfLiteUInt8;
+ bw_cell_state_quantized->allocation_type = kTfLiteArenaRw;
+ if (!TfLiteIntArrayEqual(bw_cell_state_quantized->dims,
+ bw_cell_state->dims)) {
+ TfLiteIntArray* bw_cell_state_quantized_size =
+ TfLiteIntArrayCopy(bw_cell_state->dims);
+ TF_LITE_ENSURE_OK(context,
+ context->ResizeTensor(context, bw_cell_state_quantized,
+ bw_cell_state_quantized_size));
+ }
+
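Each of the quantized-temporary blocks above follows the same pattern: give every float tensor a uint8 shadow buffer of identical shape, and resize it only when the cached dimensions no longer match. A compact standalone sketch of that idiom, with a hypothetical FakeTensor type in place of TfLiteTensor:

#include <cstddef>
#include <cstdint>
#include <vector>

// Simplified stand-in for TfLiteTensor: just dimensions plus uint8 storage.
struct FakeTensor {
  std::vector<int> dims;
  std::vector<uint8_t> data;
};

// Mirrors the "resize only if the dims changed" check used for each
// quantized temporary above.
void EnsureQuantizedShadow(const std::vector<int>& float_dims,
                           FakeTensor* quantized) {
  if (quantized->dims != float_dims) {
    quantized->dims = float_dims;
    std::size_t total = 1;
    for (int d : float_dims) total *= static_cast<std::size_t>(d);
    quantized->data.resize(total);  // uint8 buffer matching the float shape
  }
}

int main() {
  FakeTensor input_quantized;
  EnsureQuantizedShadow({3, 1, 2}, &input_quantized);  // {max_time, batch, in}
  EnsureQuantizedShadow({3, 1, 2}, &input_quantized);  // repeat call is a no-op
  return input_quantized.data.size() == 6 ? 0 : 1;
}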
+    // Allocate temporary tensors to store the scaling factors and product
+    // scaling factors. The latter is convenience storage that lets a vector
+    // be quantized once (producing the scaling factors) and then multiplied
+    // with several different matrices (which requires multiplying each
+    // scaling factor by the scaling factor of the matrix).
+ node->temporaries->data[kScalingFactors] =
+ *scratch_tensor_index + kScalingFactors;
+ TfLiteTensor* scaling_factors =
+ GetTemporary(context, node, kScalingFactors);
+ scaling_factors->type = kTfLiteFloat32;
+ scaling_factors->allocation_type = kTfLiteArenaRw;
+ TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1);
+ scaling_factors_size->data[0] = n_batch;
+ if (!TfLiteIntArrayEqual(scaling_factors->dims, scaling_factors_size)) {
+ TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors,
+ scaling_factors_size));
+ }
+ node->temporaries->data[kProductScalingFactors] =
+ *scratch_tensor_index + kProductScalingFactors;
+ TfLiteTensor* prod_scaling_factors =
+ GetTemporary(context, node, kProductScalingFactors);
+ prod_scaling_factors->type = kTfLiteFloat32;
+ prod_scaling_factors->allocation_type = kTfLiteArenaRw;
+ TfLiteIntArray* prod_scaling_factors_size = TfLiteIntArrayCreate(1);
+ prod_scaling_factors_size->data[0] = n_batch;
+ if (!TfLiteIntArrayEqual(prod_scaling_factors->dims,
+ prod_scaling_factors_size)) {
+ TF_LITE_ENSURE_OK(context,
+ context->ResizeTensor(context, prod_scaling_factors,
+ prod_scaling_factors_size));
+ }
+
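The product scaling factors exist because a batch row is quantized once, producing one scale per row, while each matrix it is multiplied against contributes only its own weight scale; the effective dequantization factor is their product. A small numeric sketch of that idea under a simple symmetric int8 scheme (illustrative only, not the tensor_utils implementation):

#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  const std::vector<float> row = {0.5f, -1.0f, 0.25f};

  // Symmetric per-row quantization to int8: scale = max|x| / 127.
  float max_abs = 0.0f;
  for (float v : row) max_abs = std::fmax(max_abs, std::fabs(v));
  const float input_scale = max_abs / 127.0f;
  std::vector<int8_t> quantized(row.size());
  for (std::size_t i = 0; i < row.size(); ++i)
    quantized[i] = static_cast<int8_t>(std::lround(row[i] / input_scale));

  // The same quantized row is reused against two matrices; only the
  // per-matrix weight scale changes, so each product scaling factor is
  // input_scale * weight_scale.
  const float weight_scales[] = {0.02f, 0.05f};
  for (float ws : weight_scales) {
    std::printf("product scaling factor: %f\n", input_scale * ws);
  }
  return 0;
}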
+    // Allocate a temporary tensor to store the recovered cell weights. Since
+    // these are used for diagonal matrices, only n_cell values need to be
+    // stored.
+ node->temporaries->data[kRecoveredCellWeights] =
+ *scratch_tensor_index + kRecoveredCellWeights;
+ TfLiteTensor* recovered_cell_weights =
+ GetTemporary(context, node, kRecoveredCellWeights);
+ recovered_cell_weights->type = kTfLiteFloat32;
+ recovered_cell_weights->allocation_type = kTfLiteArenaRw;
+ TfLiteIntArray* recovered_cell_weights_size = TfLiteIntArrayCreate(1);
+ recovered_cell_weights_size->data[0] = n_fw_cell;
+ if (!TfLiteIntArrayEqual(recovered_cell_weights->dims,
+ recovered_cell_weights_size)) {
+ TF_LITE_ENSURE_OK(context,
+ context->ResizeTensor(context, recovered_cell_weights,
+ recovered_cell_weights_size));
+ }
+ }
return kTfLiteOk;
}
// The LSTM Op engine.
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
- auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);
+ const auto* params = reinterpret_cast<TfLiteBidirectionalSequenceLSTMParams*>(
+ node->builtin_data);
// Input tensor.
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
- const int max_time = input->dims->data[0];
- const int n_batch = input->dims->data[1];
- const int n_input = input->dims->data[2];
// Tensors for the forward cell.
const TfLiteTensor* fw_input_to_input_weights =
@@ -509,9 +743,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* fw_projection_bias =
GetOptionalInputTensor(context, node, kFwProjectionBiasTensor);
- TfLiteTensor* fw_output_state =
- GetOutput(context, node, kFwOutputStateTensor);
- TfLiteTensor* fw_cell_state = GetOutput(context, node, kFwCellStateTensor);
+ TfLiteTensor* fw_activation_state =
+ GetVariableInput(context, node, kFwInputActivationStateTensor);
+ TfLiteTensor* fw_cell_state =
+ GetVariableInput(context, node, kFwInputCellStateTensor);
TfLiteTensor* fw_output = GetOutput(context, node, kFwOutputTensor);
// Tensors for the backward cell.
@@ -554,154 +789,144 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* bw_projection_bias =
GetOptionalInputTensor(context, node, kBwProjectionBiasTensor);
- TfLiteTensor* bw_output_state =
- GetOutput(context, node, kBwOutputStateTensor);
- TfLiteTensor* bw_cell_state = GetOutput(context, node, kBwCellStateTensor);
- TfLiteTensor* bw_output = GetOutput(context, node, kBwOutputTensor);
-
- // n_cell and n_output will be the same size when there is no projection.
- const int n_fw_cell = fw_input_to_output_weights->dims->data[0];
- const int n_fw_output = fw_recurrent_to_output_weights->dims->data[1];
-
- // Since we have already checked that weights are all there or none, we can
- // check the existense of only one to the get the condition.
- const bool fw_use_cifg = (fw_input_to_input_weights == nullptr);
- const bool fw_use_peephole = (fw_cell_to_output_weights != nullptr);
+ // State tensors.
+ TfLiteTensor* bw_activation_state =
+ GetVariableInput(context, node, kBwInputActivationStateTensor);
+ TfLiteTensor* bw_cell_state =
+ GetVariableInput(context, node, kBwInputCellStateTensor);
+ TfLiteTensor* bw_output = params->merge_outputs
+ ? nullptr
+ : GetOutput(context, node, kBwOutputTensor);
- // Index the scratch buffers pointers to the global scratch buffer.
+ // Temporary tensors.
TfLiteTensor* fw_scratch_buffer =
- &context->tensors[node->temporaries->data[0]];
- float* fw_input_gate_scratch = nullptr;
- float* fw_cell_scratch = nullptr;
- float* fw_forget_gate_scratch = nullptr;
- float* fw_output_gate_scratch = nullptr;
- if (fw_use_cifg) {
- fw_cell_scratch = fw_scratch_buffer->data.f;
- fw_forget_gate_scratch = fw_scratch_buffer->data.f + n_fw_cell * n_batch;
- fw_output_gate_scratch =
- fw_scratch_buffer->data.f + 2 * n_fw_cell * n_batch;
- } else {
- fw_input_gate_scratch = fw_scratch_buffer->data.f;
- fw_cell_scratch = fw_scratch_buffer->data.f + n_fw_cell * n_batch;
- fw_forget_gate_scratch =
- fw_scratch_buffer->data.f + 2 * n_fw_cell * n_batch;
- fw_output_gate_scratch =
- fw_scratch_buffer->data.f + 3 * n_fw_cell * n_batch;
- }
-
- // Check optional tensors, the respective pointers can be null.
- const float* fw_input_to_input_weights_ptr =
- (fw_use_cifg) ? nullptr : fw_input_to_input_weights->data.f;
- const float* fw_recurrent_to_input_weights_ptr =
- (fw_use_cifg) ? nullptr : fw_recurrent_to_input_weights->data.f;
- const float* fw_input_gate_bias_ptr =
- (fw_use_cifg) ? nullptr : fw_input_gate_bias->data.f;
- const float* fw_cell_to_input_weights_ptr =
- (fw_use_peephole && !fw_use_cifg) ? fw_cell_to_input_weights->data.f
- : nullptr;
- const float* fw_cell_to_forget_weights_ptr =
- (fw_use_peephole) ? fw_cell_to_forget_weights->data.f : nullptr;
- const float* fw_cell_to_output_weights_ptr =
- (fw_use_peephole) ? fw_cell_to_output_weights->data.f : nullptr;
- const float* fw_projection_weights_ptr = (fw_projection_weights == nullptr)
- ? nullptr
- : fw_projection_weights->data.f;
- const float* fw_projection_bias_ptr =
- (fw_projection_bias == nullptr) ? nullptr : fw_projection_bias->data.f;
-
- // Loop through the sequence.
- for (int t = 0; t < max_time; t++) {
- const float* input_ptr_batch = input->data.f + t * n_batch * n_input;
- float* output_ptr_time = fw_output->data.f + t * n_batch * n_fw_output;
-
- kernel_utils::LstmStep(
- input_ptr_batch, fw_input_to_input_weights_ptr,
- fw_input_to_forget_weights->data.f, fw_input_to_cell_weights->data.f,
- fw_input_to_output_weights->data.f, fw_recurrent_to_input_weights_ptr,
- fw_recurrent_to_forget_weights->data.f,
- fw_recurrent_to_cell_weights->data.f,
- fw_recurrent_to_output_weights->data.f, fw_cell_to_input_weights_ptr,
- fw_cell_to_forget_weights_ptr, fw_cell_to_output_weights_ptr,
- fw_input_gate_bias_ptr, fw_forget_gate_bias->data.f,
- fw_cell_bias->data.f, fw_output_gate_bias->data.f,
- fw_projection_weights_ptr, fw_projection_bias_ptr, params, n_batch,
- n_fw_cell, n_input, n_fw_output, fw_output_state->data.f,
- fw_cell_state->data.f, fw_input_gate_scratch, fw_forget_gate_scratch,
- fw_cell_scratch, fw_output_gate_scratch, output_ptr_time);
- }
-
- // n_cell and n_output will be the same size when there is no projection.
- const int n_bw_cell = bw_input_to_output_weights->dims->data[0];
- const int n_bw_output = bw_recurrent_to_output_weights->dims->data[1];
-
- // Since we have already checked that weights are all there or none, we can
- // check the existense of only one to the get the condition.
- const bool bw_use_cifg = (bw_input_to_input_weights == nullptr);
- const bool bw_use_peephole = (bw_cell_to_output_weights != nullptr);
-
- // Index the scratch buffers pointers to the global scratch buffer.
+ GetTemporary(context, node, kFwScratchBuffer);
TfLiteTensor* bw_scratch_buffer =
- &context->tensors[node->temporaries->data[1]];
- float* bw_input_gate_scratch = nullptr;
- float* bw_cell_scratch = nullptr;
- float* bw_forget_gate_scratch = nullptr;
- float* bw_output_gate_scratch = nullptr;
- if (bw_use_cifg) {
- bw_cell_scratch = bw_scratch_buffer->data.f;
- bw_forget_gate_scratch = bw_scratch_buffer->data.f + n_bw_cell * n_batch;
- bw_output_gate_scratch =
- bw_scratch_buffer->data.f + 2 * n_bw_cell * n_batch;
- } else {
- bw_input_gate_scratch = bw_scratch_buffer->data.f;
- bw_cell_scratch = bw_scratch_buffer->data.f + n_bw_cell * n_batch;
- bw_forget_gate_scratch =
- bw_scratch_buffer->data.f + 2 * n_bw_cell * n_batch;
- bw_output_gate_scratch =
- bw_scratch_buffer->data.f + 3 * n_bw_cell * n_batch;
- }
-
- // Check optional tensors, the respective pointers can be null.
- const float* bw_input_to_input_weights_ptr =
- (bw_use_cifg) ? nullptr : bw_input_to_input_weights->data.f;
- const float* bw_recurrent_to_input_weights_ptr =
- (bw_use_cifg) ? nullptr : bw_recurrent_to_input_weights->data.f;
- const float* bw_input_gate_bias_ptr =
- (bw_use_cifg) ? nullptr : bw_input_gate_bias->data.f;
- const float* bw_cell_to_input_weights_ptr =
- (bw_use_peephole && !bw_use_cifg) ? bw_cell_to_input_weights->data.f
- : nullptr;
- const float* bw_cell_to_forget_weights_ptr =
- (bw_use_peephole) ? bw_cell_to_forget_weights->data.f : nullptr;
- const float* bw_cell_to_output_weights_ptr =
- (bw_use_peephole) ? bw_cell_to_output_weights->data.f : nullptr;
- const float* bw_projection_weights_ptr = (bw_projection_weights == nullptr)
- ? nullptr
- : bw_projection_weights->data.f;
- const float* bw_projection_bias_ptr =
- (bw_projection_bias == nullptr) ? nullptr : bw_projection_bias->data.f;
-
- // Loop through the sequence backwards.
- for (int t = max_time - 1; t >= 0; t--) {
- const float* input_ptr_batch = input->data.f + t * n_batch * n_input;
- float* output_ptr_time = bw_output->data.f + t * n_batch * n_bw_output;
-
- kernel_utils::LstmStep(
- input_ptr_batch, bw_input_to_input_weights_ptr,
- bw_input_to_forget_weights->data.f, bw_input_to_cell_weights->data.f,
- bw_input_to_output_weights->data.f, bw_recurrent_to_input_weights_ptr,
- bw_recurrent_to_forget_weights->data.f,
- bw_recurrent_to_cell_weights->data.f,
- bw_recurrent_to_output_weights->data.f, bw_cell_to_input_weights_ptr,
- bw_cell_to_forget_weights_ptr, bw_cell_to_output_weights_ptr,
- bw_input_gate_bias_ptr, bw_forget_gate_bias->data.f,
- bw_cell_bias->data.f, bw_output_gate_bias->data.f,
- bw_projection_weights_ptr, bw_projection_bias_ptr, params, n_batch,
- n_bw_cell, n_input, n_bw_output, bw_output_state->data.f,
- bw_cell_state->data.f, bw_input_gate_scratch, bw_forget_gate_scratch,
- bw_cell_scratch, bw_output_gate_scratch, output_ptr_time);
+ GetTemporary(context, node, kBwScratchBuffer);
+
+ // (Optional) auxiliary inputs.
+ const TfLiteTensor* aux_input =
+ GetOptionalInputTensor(context, node, kAuxInputTensor);
+ const TfLiteTensor* fw_aux_input_to_input_weights =
+ GetOptionalInputTensor(context, node, kFwAuxInputToInputWeightsTensor);
+ const TfLiteTensor* fw_aux_input_to_forget_weights =
+ GetOptionalInputTensor(context, node, kFwAuxInputToForgetWeightsTensor);
+ const TfLiteTensor* fw_aux_input_to_cell_weights =
+ GetOptionalInputTensor(context, node, kFwAuxInputToCellWeightsTensor);
+ const TfLiteTensor* fw_aux_input_to_output_weights =
+ GetOptionalInputTensor(context, node, kFwAuxInputToOutputWeightsTensor);
+ const TfLiteTensor* bw_aux_input_to_input_weights =
+ GetOptionalInputTensor(context, node, kBwAuxInputToInputWeightsTensor);
+ const TfLiteTensor* bw_aux_input_to_forget_weights =
+ GetOptionalInputTensor(context, node, kBwAuxInputToForgetWeightsTensor);
+ const TfLiteTensor* bw_aux_input_to_cell_weights =
+ GetOptionalInputTensor(context, node, kBwAuxInputToCellWeightsTensor);
+ const TfLiteTensor* bw_aux_input_to_output_weights =
+ GetOptionalInputTensor(context, node, kBwAuxInputToOutputWeightsTensor);
+
+ // Populate a TfLiteLSTMParams struct for the evaluation functions.
+ TfLiteLSTMParams lstm_params = {params->activation, params->cell_clip,
+ params->proj_clip, kTfLiteLSTMFullKernel};
+
+ const int bw_output_offset =
+ params->merge_outputs ? fw_recurrent_to_output_weights->dims->data[1] : 0;
+ const auto actual_bw_output = params->merge_outputs ? fw_output : bw_output;
+
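When merge_outputs is set, both directions write into a single [max_time, n_batch, n_fw_output + n_bw_output] buffer, with the backward cell starting at the column offset computed above. A minimal sketch of that layout, using made-up sizes:

#include <cassert>
#include <vector>

int main() {
  // Illustrative sizes; n_batch is 1 to keep the indexing easy to read.
  const int max_time = 2, n_batch = 1, n_fw_output = 3, n_bw_output = 3;
  const int row_width = n_fw_output + n_bw_output;
  std::vector<float> merged(max_time * n_batch * row_width, 0.0f);

  // Forward results go at column offset 0, backward results at n_fw_output,
  // matching bw_output_offset in the merged case.
  for (int t = 0; t < max_time; ++t) {
    for (int i = 0; i < n_fw_output; ++i) merged[t * row_width + i] = 1.0f;
    for (int i = 0; i < n_bw_output; ++i)
      merged[t * row_width + n_fw_output + i] = 2.0f;
  }
  assert(merged[0] == 1.0f && merged[n_fw_output] == 2.0f);
  return 0;
}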
+ switch (fw_input_to_output_weights->type) {
+ case kTfLiteFloat32: {
+ TfLiteStatus fw_pass_status = lstm_eval::EvalFloat(
+ input, fw_input_to_input_weights, fw_input_to_forget_weights,
+ fw_input_to_cell_weights, fw_input_to_output_weights,
+ fw_recurrent_to_input_weights, fw_recurrent_to_forget_weights,
+ fw_recurrent_to_cell_weights, fw_recurrent_to_output_weights,
+ fw_cell_to_input_weights, fw_cell_to_forget_weights,
+ fw_cell_to_output_weights, aux_input, fw_aux_input_to_input_weights,
+ fw_aux_input_to_forget_weights, fw_aux_input_to_cell_weights,
+ fw_aux_input_to_output_weights, fw_input_gate_bias,
+ fw_forget_gate_bias, fw_cell_bias, fw_output_gate_bias,
+ fw_projection_weights, fw_projection_bias, &lstm_params,
+ /*forward_sequence=*/true, /*output_offset=*/0, fw_scratch_buffer,
+ fw_activation_state, fw_cell_state, fw_output);
+ TF_LITE_ENSURE_OK(context, fw_pass_status);
+
+ TfLiteStatus bw_pass_status = lstm_eval::EvalFloat(
+ input, bw_input_to_input_weights, bw_input_to_forget_weights,
+ bw_input_to_cell_weights, bw_input_to_output_weights,
+ bw_recurrent_to_input_weights, bw_recurrent_to_forget_weights,
+ bw_recurrent_to_cell_weights, bw_recurrent_to_output_weights,
+ bw_cell_to_input_weights, bw_cell_to_forget_weights,
+ bw_cell_to_output_weights, aux_input, bw_aux_input_to_input_weights,
+ bw_aux_input_to_forget_weights, bw_aux_input_to_cell_weights,
+ bw_aux_input_to_output_weights, bw_input_gate_bias,
+ bw_forget_gate_bias, bw_cell_bias, bw_output_gate_bias,
+ bw_projection_weights, bw_projection_bias, &lstm_params,
+ /*forward_sequence=*/false, bw_output_offset, bw_scratch_buffer,
+ bw_activation_state, bw_cell_state, actual_bw_output);
+ TF_LITE_ENSURE_OK(context, bw_pass_status);
+ return kTfLiteOk;
+ }
+ case kTfLiteUInt8: {
+ TfLiteTensor* input_quantized =
+ GetTemporary(context, node, kInputQuantized);
+ TfLiteTensor* aux_input_quantized =
+ GetTemporary(context, node, kAuxInputQuantized);
+ TfLiteTensor* fw_activation_state_quantized =
+ GetTemporary(context, node, kFwActivationStateQuantized);
+ TfLiteTensor* bw_activation_state_quantized =
+ GetTemporary(context, node, kBwActivationStateQuantized);
+ TfLiteTensor* fw_cell_state_quantized =
+ GetTemporary(context, node, kFwCellStateQuantized);
+ TfLiteTensor* bw_cell_state_quantized =
+ GetTemporary(context, node, kBwCellStateQuantized);
+ TfLiteTensor* scaling_factors =
+ GetTemporary(context, node, kScalingFactors);
+ TfLiteTensor* prod_scaling_factors =
+ GetTemporary(context, node, kProductScalingFactors);
+ TfLiteTensor* recovered_cell_weights =
+ GetTemporary(context, node, kRecoveredCellWeights);
+
+ TfLiteStatus fw_pass_status = lstm_eval::EvalHybrid(
+ input, fw_input_to_input_weights, fw_input_to_forget_weights,
+ fw_input_to_cell_weights, fw_input_to_output_weights,
+ fw_recurrent_to_input_weights, fw_recurrent_to_forget_weights,
+ fw_recurrent_to_cell_weights, fw_recurrent_to_output_weights,
+ fw_cell_to_input_weights, fw_cell_to_forget_weights,
+ fw_cell_to_output_weights, aux_input, fw_aux_input_to_input_weights,
+ fw_aux_input_to_forget_weights, fw_aux_input_to_cell_weights,
+ fw_aux_input_to_output_weights, fw_input_gate_bias,
+ fw_forget_gate_bias, fw_cell_bias, fw_output_gate_bias,
+ fw_projection_weights, fw_projection_bias, &lstm_params,
+ /*forward_sequence=*/true, /*output_offset=*/0, fw_scratch_buffer,
+ scaling_factors, prod_scaling_factors, recovered_cell_weights,
+ input_quantized, aux_input_quantized, fw_activation_state_quantized,
+ fw_cell_state_quantized, fw_activation_state, fw_cell_state,
+ fw_output);
+ TF_LITE_ENSURE_OK(context, fw_pass_status);
+
+ TfLiteStatus bw_pass_status = lstm_eval::EvalHybrid(
+ input, bw_input_to_input_weights, bw_input_to_forget_weights,
+ bw_input_to_cell_weights, bw_input_to_output_weights,
+ bw_recurrent_to_input_weights, bw_recurrent_to_forget_weights,
+ bw_recurrent_to_cell_weights, bw_recurrent_to_output_weights,
+ bw_cell_to_input_weights, bw_cell_to_forget_weights,
+          bw_cell_to_output_weights, aux_input, bw_aux_input_to_input_weights,
+          bw_aux_input_to_forget_weights, bw_aux_input_to_cell_weights,
+          bw_aux_input_to_output_weights, bw_input_gate_bias,
+ bw_forget_gate_bias, bw_cell_bias, bw_output_gate_bias,
+ bw_projection_weights, bw_projection_bias, &lstm_params,
+ /*forward_sequence=*/false, bw_output_offset, bw_scratch_buffer,
+ scaling_factors, prod_scaling_factors, recovered_cell_weights,
+ input_quantized, aux_input_quantized, bw_activation_state_quantized,
+ bw_cell_state_quantized, bw_activation_state, bw_cell_state,
+ actual_bw_output);
+ TF_LITE_ENSURE_OK(context, bw_pass_status);
+ return kTfLiteOk;
+ }
+ default:
+ context->ReportError(context, "Type %d is not currently supported.",
+ fw_input_to_output_weights->type);
+ return kTfLiteError;
}
-
- // Backward step.
return kTfLiteOk;
}
diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm_test.cc b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm_test.cc
index a18e1bce34..9cc04907e1 100644
--- a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm_test.cc
+++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm_test.cc
@@ -35,8 +35,8 @@ class BidirectionalLSTMOpModel : public SingleOpModel {
BidirectionalLSTMOpModel(int n_batch, int n_input, int n_cell, int n_output,
int sequence_length, bool use_cifg,
bool use_peephole, bool use_projection_weights,
- bool use_projection_bias, float cell_clip,
- float proj_clip,
+ bool use_projection_bias, bool merge_outputs,
+ float cell_clip, float proj_clip,
const std::vector<std::vector<int>>& input_shapes)
: n_batch_(n_batch),
n_input_(n_input),
@@ -102,10 +102,6 @@ class BidirectionalLSTMOpModel : public SingleOpModel {
fw_projection_bias_ = AddNullInput();
}
- fw_output_state_ = AddOutput(TensorType_FLOAT32);
- fw_cell_state_ = AddOutput(TensorType_FLOAT32);
- fw_output_ = AddOutput(TensorType_FLOAT32);
-
if (use_cifg) {
bw_input_to_input_weights_ = AddNullInput();
} else {
@@ -161,14 +157,43 @@ class BidirectionalLSTMOpModel : public SingleOpModel {
bw_projection_bias_ = AddNullInput();
}
- bw_output_state_ = AddOutput(TensorType_FLOAT32);
- bw_cell_state_ = AddOutput(TensorType_FLOAT32);
- bw_output_ = AddOutput(TensorType_FLOAT32);
+    // Add the two forward input state tensors.
+ fw_input_activation_state_ =
+ AddInput(TensorData{TensorType_FLOAT32, {n_fw_output_ * n_batch_}},
+ /*is_variable=*/true);
+ fw_input_cell_state_ =
+ AddInput(TensorData{TensorType_FLOAT32, {n_fw_cell_ * n_batch_}},
+ /*is_variable=*/true);
+
+    // Add the two backward input state tensors.
+ bw_input_activation_state_ =
+ AddInput(TensorData{TensorType_FLOAT32, {n_bw_output_ * n_batch_}},
+ /*is_variable=*/true);
+ bw_input_cell_state_ =
+ AddInput(TensorData{TensorType_FLOAT32, {n_bw_cell_ * n_batch_}},
+ /*is_variable=*/true);
+
+ fw_output_ = AddOutput(TensorType_FLOAT32);
+
+ if (!merge_outputs) {
+ bw_output_ = AddOutput(TensorType_FLOAT32);
+ }
+
+ aux_input_ = AddNullInput();
+ fw_aux_input_to_input_weights_ = AddNullInput();
+ fw_aux_input_to_forget_weights_ = AddNullInput();
+ fw_aux_input_to_cell_weights_ = AddNullInput();
+ fw_aux_input_to_output_weights_ = AddNullInput();
+ bw_aux_input_to_input_weights_ = AddNullInput();
+ bw_aux_input_to_forget_weights_ = AddNullInput();
+ bw_aux_input_to_cell_weights_ = AddNullInput();
+ bw_aux_input_to_output_weights_ = AddNullInput();
SetBuiltinOp(BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM,
- BuiltinOptions_LSTMOptions,
- CreateLSTMOptions(builder_, ActivationFunctionType_TANH,
- cell_clip, proj_clip)
+ BuiltinOptions_BidirectionalSequenceLSTMOptions,
+ CreateBidirectionalSequenceLSTMOptions(
+ builder_, ActivationFunctionType_TANH, cell_clip,
+ proj_clip, merge_outputs)
.Union());
BuildInterpreter(input_shapes);
}
@@ -259,26 +284,6 @@ class BidirectionalLSTMOpModel : public SingleOpModel {
PopulateTensor(bw_projection_bias_, f);
}
- void ResetFwOutputAndCellStates() {
- const int zero_buffer_size = n_fw_cell_ * n_batch_;
- std::unique_ptr<float[]> zero_buffer(new float[zero_buffer_size]);
- memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float));
- PopulateTensor(fw_output_state_, 0, zero_buffer.get(),
- zero_buffer.get() + zero_buffer_size);
- PopulateTensor(fw_cell_state_, 0, zero_buffer.get(),
- zero_buffer.get() + zero_buffer_size);
- }
-
- void ResetBwOutputAndCellStates() {
- const int zero_buffer_size = n_bw_cell_ * n_batch_;
- std::unique_ptr<float[]> zero_buffer(new float[zero_buffer_size]);
- memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float));
- PopulateTensor(bw_output_state_, 0, zero_buffer.get(),
- zero_buffer.get() + zero_buffer_size);
- PopulateTensor(bw_cell_state_, 0, zero_buffer.get(),
- zero_buffer.get() + zero_buffer_size);
- }
-
void SetInput(int offset, float* begin, float* end) {
PopulateTensor(input_, offset, begin, end);
}
@@ -340,13 +345,23 @@ class BidirectionalLSTMOpModel : public SingleOpModel {
int bw_projection_weights_;
int bw_projection_bias_;
- int fw_output_;
- int fw_output_state_;
- int fw_cell_state_;
+ int fw_input_activation_state_;
+ int fw_input_cell_state_;
+ int bw_input_activation_state_;
+ int bw_input_cell_state_;
+ int fw_output_;
int bw_output_;
- int bw_output_state_;
- int bw_cell_state_;
+
+ int aux_input_;
+ int fw_aux_input_to_input_weights_;
+ int fw_aux_input_to_forget_weights_;
+ int fw_aux_input_to_cell_weights_;
+ int fw_aux_input_to_output_weights_;
+ int bw_aux_input_to_input_weights_;
+ int bw_aux_input_to_forget_weights_;
+ int bw_aux_input_to_cell_weights_;
+ int bw_aux_input_to_output_weights_;
int n_batch_;
int n_input_;
@@ -368,7 +383,8 @@ TEST(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping) {
BidirectionalLSTMOpModel lstm(
n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false,
/*use_peephole=*/false, /*use_projection_weights=*/false,
- /*use_projection_bias=*/false, /*cell_clip=*/0.0, /*proj_clip=*/0.0,
+ /*use_projection_bias=*/false, /*merge_outputs=*/false, /*cell_clip=*/0.0,
+ /*proj_clip=*/0.0,
{
{sequence_length, n_batch, n_input}, // input tensor
@@ -417,6 +433,22 @@ TEST(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping) {
{0, 0}, // projection_weight tensor
{0}, // projection_bias tensor
+
+ {n_batch, n_output}, // activation_state tensor
+ {n_batch, n_cell}, // cell_state tensor
+
+ {n_batch, n_output}, // activation_state tensor
+ {n_batch, n_cell}, // cell_state tensor
+
+ {n_batch, sequence_length, 0}, // aux_input tensor
+ {n_cell, 0}, // aux_fw_input_to_input tensor
+ {n_cell, 0}, // aux_fw_input_to_forget tensor
+ {n_cell, 0}, // aux_fw_input_to_cell tensor
+ {n_cell, 0}, // aux_fw_input_to_output tensor
+ {n_cell, 0}, // aux_bw_input_to_input tensor
+ {n_cell, 0}, // aux_bw_input_to_forget tensor
+ {n_cell, 0}, // aux_bw_input_to_cell tensor
+ {n_cell, 0}, // aux_bw_input_to_output tensor
});
lstm.SetInputToInputWeights({-0.45018822, -0.02338299, -0.0870589,
@@ -474,10 +506,6 @@ TEST(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping) {
-0.0332076, 0.123838, 0.309777, -0.17621,
-0.0490733, 0.0739237, 0.067706, -0.0208124};
- // Resetting cell_state and output_state
- lstm.ResetFwOutputAndCellStates();
- lstm.ResetBwOutputAndCellStates();
-
float* batch0_start = lstm_input;
float* batch0_end = batch0_start + lstm.num_inputs() * lstm.sequence_length();
@@ -500,34 +528,318 @@ TEST(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping) {
bw_expected.insert(bw_expected.end(), bw_golden_start, bw_golden_end);
EXPECT_THAT(lstm.GetBwOutput(),
ElementsAreArray(ArrayFloatNear(bw_expected)));
+}
+
+// Same as the previous test, but with a single merged output tensor.
+TEST(LSTMOpTest, BlackBoxTestMergedOutput) {
+ const int n_batch = 1;
+ const int n_input = 2;
+ // n_cell and n_output have the same size when there is no projection.
+ const int n_cell = 4;
+ const int n_output = 4;
+ const int sequence_length = 3;
+
+ BidirectionalLSTMOpModel lstm(
+ n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false,
+ /*use_peephole=*/false, /*use_projection_weights=*/false,
+ /*use_projection_bias=*/false, /*merge_outputs=*/true, /*cell_clip=*/0.0,
+ /*proj_clip=*/0.0,
+ {
+ {sequence_length, n_batch, n_input}, // input tensor
+
+ // Forward cell
+ {n_cell, n_input}, // input_to_input_weight tensor
+ {n_cell, n_input}, // input_to_forget_weight tensor
+ {n_cell, n_input}, // input_to_cell_weight tensor
+ {n_cell, n_input}, // input_to_output_weight tensor
+
+ {n_cell, n_output}, // recurrent_to_input_weight tensor
+ {n_cell, n_output}, // recurrent_to_forget_weight tensor
+ {n_cell, n_output}, // recurrent_to_cell_weight tensor
+ {n_cell, n_output}, // recurrent_to_output_weight tensor
+
+ {0}, // cell_to_input_weight tensor
+ {0}, // cell_to_forget_weight tensor
+ {0}, // cell_to_output_weight tensor
+
+ {n_cell}, // input_gate_bias tensor
+ {n_cell}, // forget_gate_bias tensor
+ {n_cell}, // cell_bias tensor
+ {n_cell}, // output_gate_bias tensor
+
+ {0, 0}, // projection_weight tensor
+ {0}, // projection_bias tensor
+
+ // Backward cell
+ {n_cell, n_input}, // input_to_input_weight tensor
+ {n_cell, n_input}, // input_to_forget_weight tensor
+ {n_cell, n_input}, // input_to_cell_weight tensor
+ {n_cell, n_input}, // input_to_output_weight tensor
+
+ {n_cell, n_output}, // recurrent_to_input_weight tensor
+ {n_cell, n_output}, // recurrent_to_forget_weight tensor
+ {n_cell, n_output}, // recurrent_to_cell_weight tensor
+ {n_cell, n_output}, // recurrent_to_output_weight tensor
+
+ {0}, // cell_to_input_weight tensor
+ {0}, // cell_to_forget_weight tensor
+ {0}, // cell_to_output_weight tensor
+
+ {n_cell}, // input_gate_bias tensor
+ {n_cell}, // forget_gate_bias tensor
+ {n_cell}, // cell_bias tensor
+ {n_cell}, // output_gate_bias tensor
+
+ {0, 0}, // projection_weight tensor
+ {0}, // projection_bias tensor
+
+ {n_batch, n_output}, // activation_state tensor
+ {n_batch, n_cell}, // cell_state tensor
+
+ {n_batch, n_output}, // activation_state tensor
+ {n_batch, n_cell}, // cell_state tensor
+
+ {n_batch, sequence_length, 0}, // aux_input tensor
+ {n_cell, 0}, // aux_fw_input_to_input tensor
+ {n_cell, 0}, // aux_fw_input_to_forget tensor
+ {n_cell, 0}, // aux_fw_input_to_cell tensor
+ {n_cell, 0}, // aux_fw_input_to_output tensor
+ {n_cell, 0}, // aux_bw_input_to_input tensor
+ {n_cell, 0}, // aux_bw_input_to_forget tensor
+ {n_cell, 0}, // aux_bw_input_to_cell tensor
+ {n_cell, 0}, // aux_bw_input_to_output tensor
+ });
+
+ lstm.SetInputToInputWeights({-0.45018822, -0.02338299, -0.0870589,
+ -0.34550029, 0.04266912, -0.15680569,
+ -0.34856534, 0.43890524});
+
+ lstm.SetInputToCellWeights({-0.50013041, 0.1370284, 0.11810488, 0.2013163,
+ -0.20583314, 0.44344562, 0.22077113,
+ -0.29909778});
+
+ lstm.SetInputToForgetWeights({0.09701663, 0.20334584, -0.50592935,
+ -0.31343272, -0.40032279, 0.44781327,
+ 0.01387155, -0.35593212});
+
+ lstm.SetInputToOutputWeights({-0.25065863, -0.28290087, 0.04613829,
+ 0.40525138, 0.44272184, 0.03897077, -0.1556896,
+ 0.19487578});
+
+ lstm.SetInputGateBias({0., 0., 0., 0.});
+
+ lstm.SetCellBias({0., 0., 0., 0.});
+
+ lstm.SetForgetGateBias({1., 1., 1., 1.});
+
+ lstm.SetOutputGateBias({0., 0., 0., 0.});
+
+ lstm.SetRecurrentToInputWeights(
+ {-0.0063535, -0.2042388, 0.31454784, -0.35746509, 0.28902304, 0.08183324,
+ -0.16555229, 0.02286911, -0.13566875, 0.03034258, 0.48091322,
+ -0.12528998, 0.24077177, -0.51332325, -0.33502164, 0.10629296});
+
+ lstm.SetRecurrentToCellWeights(
+ {-0.3407414, 0.24443203, -0.2078532, 0.26320225, 0.05695659, -0.00123841,
+ -0.4744786, -0.35869038, -0.06418842, -0.13502428, -0.501764, 0.22830659,
+ -0.46367589, 0.26016325, -0.03894562, -0.16368064});
+
+ lstm.SetRecurrentToForgetWeights(
+ {-0.48684245, -0.06655136, 0.42224967, 0.2112639, 0.27654213, 0.20864892,
+ -0.07646349, 0.45877004, 0.00141793, -0.14609534, 0.36447752, 0.09196436,
+ 0.28053468, 0.01560611, -0.20127171, -0.01140004});
+ lstm.SetRecurrentToOutputWeights(
+ {0.43385774, -0.17194885, 0.2718237, 0.09215671, 0.24107647, -0.39835793,
+ 0.18212086, 0.01301402, 0.48572797, -0.50656658, 0.20047462, -0.20607421,
+ -0.51818722, -0.15390486, 0.0468148, 0.39922136});
+
+ // Input should have n_input * sequence_length many values.
+ static float lstm_input[] = {2., 3., 3., 4., 1., 1.};
+ static float lstm_fw_golden_output[] = {
+ -0.02973187, 0.1229473, 0.20885126, -0.15358765,
+ -0.03716109, 0.12507336, 0.41193449, -0.20860538,
+ -0.15053082, 0.09120187, 0.24278517, -0.12222792};
+ static float lstm_bw_golden_output[] = {
+ -0.0806187, 0.139077, 0.400476, -0.197842, -0.0332076, 0.123838,
+ 0.309777, -0.17621, -0.0490733, 0.0739237, 0.067706, -0.0208124};
+
+ float* batch0_start = lstm_input;
+ float* batch0_end = batch0_start + lstm.num_inputs() * lstm.sequence_length();
+
+ lstm.SetInput(0, batch0_start, batch0_end);
+
+ lstm.Invoke();
+
+ std::vector<float> merged_expected;
+ for (int k = 0; k < lstm.sequence_length(); k++) {
+ merged_expected.insert(
+ merged_expected.end(),
+ lstm_fw_golden_output + k * lstm.num_fw_outputs(),
+ lstm_fw_golden_output + (k + 1) * lstm.num_fw_outputs());
+ merged_expected.insert(
+ merged_expected.end(),
+ lstm_bw_golden_output + k * lstm.num_bw_outputs(),
+ lstm_bw_golden_output + (k + 1) * lstm.num_bw_outputs());
+ }
+ EXPECT_THAT(lstm.GetFwOutput(),
+ ElementsAreArray(ArrayFloatNear(merged_expected)));
+}
+
+TEST(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClippingReverse) {
+ const int n_batch = 1;
+ const int n_input = 2;
+ // n_cell and n_output have the same size when there is no projection.
+ const int n_cell = 4;
+ const int n_output = 4;
+ const int sequence_length = 3;
+
+ BidirectionalLSTMOpModel lstm(
+ n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false,
+ /*use_peephole=*/false, /*use_projection_weights=*/false,
+ /*use_projection_bias=*/false, /*merge_outputs=*/false, /*cell_clip=*/0.0,
+ /*proj_clip=*/0.0,
+ {
+ {sequence_length, n_batch, n_input}, // input tensor
+
+ // Forward cell
+ {n_cell, n_input}, // input_to_input_weight tensor
+ {n_cell, n_input}, // input_to_forget_weight tensor
+ {n_cell, n_input}, // input_to_cell_weight tensor
+ {n_cell, n_input}, // input_to_output_weight tensor
+
+ {n_cell, n_output}, // recurrent_to_input_weight tensor
+ {n_cell, n_output}, // recurrent_to_forget_weight tensor
+ {n_cell, n_output}, // recurrent_to_cell_weight tensor
+ {n_cell, n_output}, // recurrent_to_output_weight tensor
+
+ {0}, // cell_to_input_weight tensor
+ {0}, // cell_to_forget_weight tensor
+ {0}, // cell_to_output_weight tensor
+
+ {n_cell}, // input_gate_bias tensor
+ {n_cell}, // forget_gate_bias tensor
+ {n_cell}, // cell_bias tensor
+ {n_cell}, // output_gate_bias tensor
+
+ {0, 0}, // projection_weight tensor
+ {0}, // projection_bias tensor
+
+ // Backward cell
+ {n_cell, n_input}, // input_to_input_weight tensor
+ {n_cell, n_input}, // input_to_forget_weight tensor
+ {n_cell, n_input}, // input_to_cell_weight tensor
+ {n_cell, n_input}, // input_to_output_weight tensor
+
+ {n_cell, n_output}, // recurrent_to_input_weight tensor
+ {n_cell, n_output}, // recurrent_to_forget_weight tensor
+ {n_cell, n_output}, // recurrent_to_cell_weight tensor
+ {n_cell, n_output}, // recurrent_to_output_weight tensor
+
+ {0}, // cell_to_input_weight tensor
+ {0}, // cell_to_forget_weight tensor
+ {0}, // cell_to_output_weight tensor
+
+ {n_cell}, // input_gate_bias tensor
+ {n_cell}, // forget_gate_bias tensor
+ {n_cell}, // cell_bias tensor
+ {n_cell}, // output_gate_bias tensor
+
+ {0, 0}, // projection_weight tensor
+ {0}, // projection_bias tensor
+
+ {n_batch, n_output}, // activation_state tensor
+ {n_batch, n_cell}, // cell_state tensor
+
+ {n_batch, n_output}, // activation_state tensor
+ {n_batch, n_cell}, // cell_state tensor
+
+ {n_batch, sequence_length, 0}, // aux_input tensor
+ {n_cell, 0}, // aux_fw_input_to_input tensor
+ {n_cell, 0}, // aux_fw_input_to_forget tensor
+ {n_cell, 0}, // aux_fw_input_to_cell tensor
+ {n_cell, 0}, // aux_fw_input_to_output tensor
+ {n_cell, 0}, // aux_bw_input_to_input tensor
+ {n_cell, 0}, // aux_bw_input_to_forget tensor
+ {n_cell, 0}, // aux_bw_input_to_cell tensor
+ {n_cell, 0}, // aux_bw_input_to_output tensor
+ });
+
+ lstm.SetInputToInputWeights({-0.45018822, -0.02338299, -0.0870589,
+ -0.34550029, 0.04266912, -0.15680569,
+ -0.34856534, 0.43890524});
+
+ lstm.SetInputToCellWeights({-0.50013041, 0.1370284, 0.11810488, 0.2013163,
+ -0.20583314, 0.44344562, 0.22077113,
+ -0.29909778});
+
+ lstm.SetInputToForgetWeights({0.09701663, 0.20334584, -0.50592935,
+ -0.31343272, -0.40032279, 0.44781327,
+ 0.01387155, -0.35593212});
+
+ lstm.SetInputToOutputWeights({-0.25065863, -0.28290087, 0.04613829,
+ 0.40525138, 0.44272184, 0.03897077, -0.1556896,
+ 0.19487578});
+
+ lstm.SetInputGateBias({0., 0., 0., 0.});
+
+ lstm.SetCellBias({0., 0., 0., 0.});
+
+ lstm.SetForgetGateBias({1., 1., 1., 1.});
+
+ lstm.SetOutputGateBias({0., 0., 0., 0.});
+
+ lstm.SetRecurrentToInputWeights(
+ {-0.0063535, -0.2042388, 0.31454784, -0.35746509, 0.28902304, 0.08183324,
+ -0.16555229, 0.02286911, -0.13566875, 0.03034258, 0.48091322,
+ -0.12528998, 0.24077177, -0.51332325, -0.33502164, 0.10629296});
+
+ lstm.SetRecurrentToCellWeights(
+ {-0.3407414, 0.24443203, -0.2078532, 0.26320225, 0.05695659, -0.00123841,
+ -0.4744786, -0.35869038, -0.06418842, -0.13502428, -0.501764, 0.22830659,
+ -0.46367589, 0.26016325, -0.03894562, -0.16368064});
+
+ lstm.SetRecurrentToForgetWeights(
+ {-0.48684245, -0.06655136, 0.42224967, 0.2112639, 0.27654213, 0.20864892,
+ -0.07646349, 0.45877004, 0.00141793, -0.14609534, 0.36447752, 0.09196436,
+ 0.28053468, 0.01560611, -0.20127171, -0.01140004});
+
+ lstm.SetRecurrentToOutputWeights(
+ {0.43385774, -0.17194885, 0.2718237, 0.09215671, 0.24107647, -0.39835793,
+ 0.18212086, 0.01301402, 0.48572797, -0.50656658, 0.20047462, -0.20607421,
+ -0.51818722, -0.15390486, 0.0468148, 0.39922136});
+
+ // Input should have n_input * sequence_length many values.
// Check reversed inputs.
static float lstm_input_reversed[] = {1., 1., 3., 4., 2., 3.};
+ static float lstm_fw_golden_output[] = {
+ -0.02973187, 0.1229473, 0.20885126, -0.15358765,
+ -0.03716109, 0.12507336, 0.41193449, -0.20860538,
+ -0.15053082, 0.09120187, 0.24278517, -0.12222792};
+ static float lstm_bw_golden_output[] = {
+ -0.0806187, 0.139077, 0.400476, -0.197842, -0.0332076, 0.123838,
+ 0.309777, -0.17621, -0.0490733, 0.0739237, 0.067706, -0.0208124};
- // Resetting cell_state and output_state
- lstm.ResetFwOutputAndCellStates();
- lstm.ResetBwOutputAndCellStates();
-
- batch0_start = lstm_input_reversed;
- batch0_end = batch0_start + lstm.num_inputs() * lstm.sequence_length();
+ float* batch0_start = lstm_input_reversed;
+ float* batch0_end = batch0_start + lstm.num_inputs() * lstm.sequence_length();
lstm.SetInput(0, batch0_start, batch0_end);
lstm.Invoke();
- fw_expected.clear();
+ std::vector<float> fw_expected;
for (int s = 0; s < lstm.sequence_length(); s++) {
- fw_golden_start = lstm_fw_golden_output + s * lstm.num_fw_outputs();
- fw_golden_end = fw_golden_start + lstm.num_fw_outputs();
+ float* fw_golden_start = lstm_fw_golden_output + s * lstm.num_fw_outputs();
+ float* fw_golden_end = fw_golden_start + lstm.num_fw_outputs();
fw_expected.insert(fw_expected.begin(), fw_golden_start, fw_golden_end);
}
EXPECT_THAT(lstm.GetBwOutput(),
ElementsAreArray(ArrayFloatNear(fw_expected)));
- bw_expected.clear();
+ std::vector<float> bw_expected;
for (int s = 0; s < lstm.sequence_length(); s++) {
- bw_golden_start = lstm_bw_golden_output + s * lstm.num_bw_outputs();
- bw_golden_end = bw_golden_start + lstm.num_bw_outputs();
+ float* bw_golden_start = lstm_bw_golden_output + s * lstm.num_bw_outputs();
+ float* bw_golden_end = bw_golden_start + lstm.num_bw_outputs();
bw_expected.insert(bw_expected.begin(), bw_golden_start, bw_golden_end);
}
EXPECT_THAT(lstm.GetFwOutput(),
@@ -545,7 +857,8 @@ TEST(LSTMOpTest, BlackBoxTestWithCifgWithPeepholeNoProjectionNoClipping) {
BidirectionalLSTMOpModel lstm(
n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/true,
/*use_peephole=*/true, /*use_projection_weights=*/false,
- /*use_projection_bias=*/false, /*cell_clip=*/0.0, /*proj_clip=*/0.0,
+ /*use_projection_bias=*/false, /*merge_outputs=*/false, /*cell_clip=*/0.0,
+ /*proj_clip=*/0.0,
{
{sequence_length, n_batch, n_input}, // input tensor
@@ -592,6 +905,22 @@ TEST(LSTMOpTest, BlackBoxTestWithCifgWithPeepholeNoProjectionNoClipping) {
{0, 0}, // projection_weight tensor
{0}, // projection_bias tensor
+
+ {n_batch, n_output}, // activation_state tensor
+ {n_batch, n_cell}, // cell_state tensor
+
+ {n_batch, n_output}, // activation_state tensor
+ {n_batch, n_cell}, // cell_state tensor
+
+ {n_batch, sequence_length, 0}, // aux_input tensor
+ {n_cell, 0}, // aux_fw_input_to_input tensor
+ {n_cell, 0}, // aux_fw_input_to_forget tensor
+ {n_cell, 0}, // aux_fw_input_to_cell tensor
+ {n_cell, 0}, // aux_fw_input_to_output tensor
+ {n_cell, 0}, // aux_bw_input_to_input tensor
+ {n_cell, 0}, // aux_bw_input_to_forget tensor
+ {n_cell, 0}, // aux_bw_input_to_cell tensor
+ {n_cell, 0}, // aux_bw_input_to_output tensor
});
lstm.SetInputToCellWeights({-0.49770179, -0.27711356, -0.09624726, 0.05100781,
@@ -642,10 +971,6 @@ TEST(LSTMOpTest, BlackBoxTestWithCifgWithPeepholeNoProjectionNoClipping) {
-0.401685, -0.0232794, 0.288642, -0.123074, -0.42915, -0.00871577,
0.20912, -0.103567, -0.166398, -0.00486649, 0.0697471, -0.0537578};
- // Resetting cell_state and output_state
- lstm.ResetFwOutputAndCellStates();
- lstm.ResetBwOutputAndCellStates();
-
float* batch0_start = lstm_input;
float* batch0_end = batch0_start + lstm.num_inputs() * lstm.sequence_length();
@@ -668,34 +993,154 @@ TEST(LSTMOpTest, BlackBoxTestWithCifgWithPeepholeNoProjectionNoClipping) {
bw_expected.insert(bw_expected.end(), bw_golden_start, bw_golden_end);
EXPECT_THAT(lstm.GetBwOutput(),
ElementsAreArray(ArrayFloatNear(bw_expected)));
+}
- // Check reversed inputs.
- static float lstm_input_reversed[] = {1., 1., 3., 4., 2., 3.};
+TEST(LSTMOpTest,
+ BlackBoxTestWithCifgWithPeepholeNoProjectionNoClippingReversed) {
+ const int n_batch = 1;
+ const int n_input = 2;
+ // n_cell and n_output have the same size when there is no projection.
+ const int n_cell = 4;
+ const int n_output = 4;
+ const int sequence_length = 3;
- // Resetting cell_state and output_state
- lstm.ResetFwOutputAndCellStates();
- lstm.ResetBwOutputAndCellStates();
+ BidirectionalLSTMOpModel lstm(
+ n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/true,
+ /*use_peephole=*/true, /*use_projection_weights=*/false,
+ /*use_projection_bias=*/false, /*merge_outputs=*/false, /*cell_clip=*/0.0,
+ /*proj_clip=*/0.0,
+ {
+ {sequence_length, n_batch, n_input}, // input tensor
- batch0_start = lstm_input_reversed;
- batch0_end = batch0_start + lstm.num_inputs() * lstm.sequence_length();
+ {0, 0}, // input_to_input_weight tensor
+ {n_cell, n_input}, // input_to_forget_weight tensor
+ {n_cell, n_input}, // input_to_cell_weight tensor
+ {n_cell, n_input}, // input_to_output_weight tensor
+
+ {0, 0}, // recurrent_to_input_weight tensor
+ {n_cell, n_output}, // recurrent_to_forget_weight tensor
+ {n_cell, n_output}, // recurrent_to_cell_weight tensor
+ {n_cell, n_output}, // recurrent_to_output_weight tensor
+
+ {0}, // cell_to_input_weight tensor
+ {n_cell}, // cell_to_forget_weight tensor
+ {n_cell}, // cell_to_output_weight tensor
+
+ {0}, // input_gate_bias tensor
+ {n_cell}, // forget_gate_bias tensor
+ {n_cell}, // cell_bias tensor
+ {n_cell}, // output_gate_bias tensor
+
+ {0, 0}, // projection_weight tensor
+ {0}, // projection_bias tensor
+
+ {0, 0}, // input_to_input_weight tensor
+ {n_cell, n_input}, // input_to_forget_weight tensor
+ {n_cell, n_input}, // input_to_cell_weight tensor
+ {n_cell, n_input}, // input_to_output_weight tensor
+
+ {0, 0}, // recurrent_to_input_weight tensor
+ {n_cell, n_output}, // recurrent_to_forget_weight tensor
+ {n_cell, n_output}, // recurrent_to_cell_weight tensor
+ {n_cell, n_output}, // recurrent_to_output_weight tensor
+
+ {0}, // cell_to_input_weight tensor
+ {n_cell}, // cell_to_forget_weight tensor
+ {n_cell}, // cell_to_output_weight tensor
+
+ {0}, // input_gate_bias tensor
+ {n_cell}, // forget_gate_bias tensor
+ {n_cell}, // cell_bias tensor
+ {n_cell}, // output_gate_bias tensor
+
+ {0, 0}, // projection_weight tensor
+ {0}, // projection_bias tensor
+
+ {n_batch, n_output}, // activation_state tensor
+ {n_batch, n_cell}, // cell_state tensor
+
+ {n_batch, n_output}, // activation_state tensor
+ {n_batch, n_cell}, // cell_state tensor
+
+ {n_batch, sequence_length, 0}, // aux_input tensor
+ {n_cell, 0}, // aux_fw_input_to_input tensor
+ {n_cell, 0}, // aux_fw_input_to_forget tensor
+ {n_cell, 0}, // aux_fw_input_to_cell tensor
+ {n_cell, 0}, // aux_fw_input_to_output tensor
+ {n_cell, 0}, // aux_bw_input_to_input tensor
+ {n_cell, 0}, // aux_bw_input_to_forget tensor
+ {n_cell, 0}, // aux_bw_input_to_cell tensor
+ {n_cell, 0}, // aux_bw_input_to_output tensor
+ });
+
+ lstm.SetInputToCellWeights({-0.49770179, -0.27711356, -0.09624726, 0.05100781,
+ 0.04717243, 0.48944736, -0.38535351,
+ -0.17212132});
+
+ lstm.SetInputToForgetWeights({-0.55291498, -0.42866567, 0.13056988,
+ -0.3633365, -0.22755712, 0.28253698, 0.24407166,
+ 0.33826375});
+
+ lstm.SetInputToOutputWeights({0.10725588, -0.02335852, -0.55932593,
+ -0.09426838, -0.44257352, 0.54939759,
+ 0.01533556, 0.42751634});
+
+ lstm.SetCellBias({0., 0., 0., 0.});
+
+ lstm.SetForgetGateBias({1., 1., 1., 1.});
+
+ lstm.SetOutputGateBias({0., 0., 0., 0.});
+
+ lstm.SetRecurrentToCellWeights(
+ {0.54066205, -0.32668582, -0.43562764, -0.56094903, 0.42957711,
+ 0.01841056, -0.32764608, -0.33027974, -0.10826075, 0.20675004,
+ 0.19069612, -0.03026325, -0.54532051, 0.33003211, 0.44901288,
+ 0.21193194});
+
+ lstm.SetRecurrentToForgetWeights(
+ {-0.13832897, -0.0515101, -0.2359007, -0.16661474, -0.14340827,
+ 0.36986142, 0.23414481, 0.55899, 0.10798943, -0.41174671, 0.17751795,
+ -0.34484994, -0.35874045, -0.11352962, 0.27268326, 0.54058349});
+
+ lstm.SetRecurrentToOutputWeights(
+ {0.41613156, 0.42610586, -0.16495961, -0.5663873, 0.30579174, -0.05115908,
+ -0.33941799, 0.23364776, 0.11178309, 0.09481031, -0.26424935, 0.46261835,
+ 0.50248802, 0.26114327, -0.43736315, 0.33149987});
+
+ lstm.SetCellToForgetWeights(
+ {0.47485286, -0.51955009, -0.24458408, 0.31544167});
+ lstm.SetCellToOutputWeights(
+ {-0.17135078, 0.82760304, 0.85573703, -0.77109635});
+
+ static float lstm_input_reversed[] = {1., 1., 3., 4., 2., 3.};
+ static float lstm_fw_golden_output[] = {
+ -0.36444446, -0.00352185, 0.12886585, -0.05163646,
+ -0.42312205, -0.01218222, 0.24201041, -0.08124574,
+ -0.358325, -0.04621704, 0.21641694, -0.06471302};
+ static float lstm_bw_golden_output[] = {
+ -0.401685, -0.0232794, 0.288642, -0.123074, -0.42915, -0.00871577,
+ 0.20912, -0.103567, -0.166398, -0.00486649, 0.0697471, -0.0537578};
+
+ float* batch0_start = lstm_input_reversed;
+ float* batch0_end = batch0_start + lstm.num_inputs() * lstm.sequence_length();
lstm.SetInput(0, batch0_start, batch0_end);
lstm.Invoke();
- fw_expected.clear();
+ std::vector<float> fw_expected;
for (int s = 0; s < lstm.sequence_length(); s++) {
- fw_golden_start = lstm_fw_golden_output + s * lstm.num_fw_outputs();
- fw_golden_end = fw_golden_start + lstm.num_fw_outputs();
+ float* fw_golden_start = lstm_fw_golden_output + s * lstm.num_fw_outputs();
+ float* fw_golden_end = fw_golden_start + lstm.num_fw_outputs();
fw_expected.insert(fw_expected.begin(), fw_golden_start, fw_golden_end);
}
EXPECT_THAT(lstm.GetBwOutput(),
ElementsAreArray(ArrayFloatNear(fw_expected)));
- bw_expected.clear();
+ std::vector<float> bw_expected;
for (int s = 0; s < lstm.sequence_length(); s++) {
- bw_golden_start = lstm_bw_golden_output + s * lstm.num_bw_outputs();
- bw_golden_end = bw_golden_start + lstm.num_bw_outputs();
+ float* bw_golden_start = lstm_bw_golden_output + s * lstm.num_bw_outputs();
+ float* bw_golden_end = bw_golden_start + lstm.num_bw_outputs();
bw_expected.insert(bw_expected.begin(), bw_golden_start, bw_golden_end);
}
EXPECT_THAT(lstm.GetFwOutput(),
@@ -712,7 +1157,8 @@ TEST(LSTMOpTest, BlackBoxTestWithPeepholeWithProjectionNoClipping) {
BidirectionalLSTMOpModel lstm(
n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false,
/*use_peephole=*/true, /*use_projection_weights=*/true,
- /*use_projection_bias=*/false, /*cell_clip=*/0.0, /*proj_clip=*/0.0,
+ /*use_projection_bias=*/false, /*merge_outputs=*/false, /*cell_clip=*/0.0,
+ /*proj_clip=*/0.0,
{
{sequence_length, n_batch, n_input}, // input tensor
@@ -759,6 +1205,22 @@ TEST(LSTMOpTest, BlackBoxTestWithPeepholeWithProjectionNoClipping) {
{n_output, n_cell}, // projection_weight tensor
{0}, // projection_bias tensor
+
+ {n_batch, n_output}, // activation_state tensor
+ {n_batch, n_cell}, // cell_state tensor
+
+ {n_batch, n_output}, // activation_state tensor
+ {n_batch, n_cell}, // cell_state tensor
+
+ {n_batch, sequence_length, 0}, // aux_input tensor
+ {n_cell, 0}, // aux_fw_input_to_input tensor
+ {n_cell, 0}, // aux_fw_input_to_forget tensor
+ {n_cell, 0}, // aux_fw_input_to_cell tensor
+ {n_cell, 0}, // aux_fw_input_to_output tensor
+ {n_cell, 0}, // aux_bw_input_to_input tensor
+ {n_cell, 0}, // aux_bw_input_to_forget tensor
+ {n_cell, 0}, // aux_bw_input_to_cell tensor
+ {n_cell, 0}, // aux_bw_input_to_output tensor
});
lstm.SetInputToInputWeights(
@@ -1343,10 +1805,6 @@ TEST(LSTMOpTest, BlackBoxTestWithPeepholeWithProjectionNoClipping) {
0.065133, 0.024321, 0.038473, 0.062438
}};
- // Resetting cell_state and output_state
- lstm.ResetFwOutputAndCellStates();
- lstm.ResetBwOutputAndCellStates();
-
for (int i = 0; i < lstm.sequence_length(); i++) {
float* batch0_start = lstm_input[0] + i * lstm.num_inputs();
float* batch0_end = batch0_start + lstm.num_inputs();
diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc b/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc
index aa24c1f34c..c22a457a71 100644
--- a/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc
+++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc
@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include <unistd.h>
#include <cassert>
#include <cmath>
#include <cstdlib>
@@ -20,10 +19,11 @@ limitations under the License.
#include <iostream>
#include <limits>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/activation_functor.h"
#include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"
namespace tflite {
@@ -36,34 +36,84 @@ constexpr int kInputTensor = 0;
constexpr int kFwWeightsTensor = 1;
constexpr int kFwRecurrentWeightsTensor = 2;
constexpr int kFwBiasTensor = 3;
-constexpr int kBwWeightsTensor = 4;
-constexpr int kBwRecurrentWeightsTensor = 5;
-constexpr int kBwBiasTensor = 6;
-// State and output tensors.
-constexpr int kFwHiddenStateTensor = 0;
-constexpr int kFwOutputTensor = 1;
-constexpr int kBwHiddenStateTensor = 2;
-constexpr int kBwOutputTensor = 3;
+constexpr int kFwHiddenStateTensor = 4;
+constexpr int kBwWeightsTensor = 5;
+constexpr int kBwRecurrentWeightsTensor = 6;
+constexpr int kBwBiasTensor = 7;
+constexpr int kBwHiddenStateTensor = 8;
+// Auxiliary inputs.
+constexpr int kAuxInputTensor = 9; // Optional.
+constexpr int kFwAuxWeightsTensor = 10; // Optional.
+constexpr int kBwAuxWeightsTensor = 11; // Optional.
+// Output tensors.
+constexpr int kFwOutputTensor = 0;
+constexpr int kBwOutputTensor = 1; // Only if merge_outputs is false.
+
+// Temporary tensors.
+enum TemporaryTensor {
+ kInputQuantized = 0,
+ kFwHiddenStateQuantized = 1,
+ kBwHiddenStateQuantized = 2,
+ kScalingFactors = 3,
+ kAuxInputQuantized = 4,
+ kNumTemporaryTensors = 5
+};
+
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+ auto* scratch_tensor_index = new int;
+ context->AddTensors(context, kNumTemporaryTensors, scratch_tensor_index);
+ return scratch_tensor_index;
+}
+
+void Free(TfLiteContext* context, void* buffer) {
+ delete reinterpret_cast<int*>(buffer);
+}
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ const auto* params = reinterpret_cast<TfLiteBidirectionalSequenceRNNParams*>(
+ node->builtin_data);
+
// Check we have all the inputs and outputs we need.
- TF_LITE_ENSURE_EQ(context, node->inputs->size, 7);
- TF_LITE_ENSURE_EQ(context, node->outputs->size, 4);
-
- TfLiteTensor* input = &context->tensors[node->inputs->data[kInputTensor]];
- TfLiteTensor* fw_input_weights =
- &context->tensors[node->inputs->data[kFwWeightsTensor]];
- TfLiteTensor* fw_recurrent_weights =
- &context->tensors[node->inputs->data[kFwRecurrentWeightsTensor]];
- TfLiteTensor* fw_bias = &context->tensors[node->inputs->data[kFwBiasTensor]];
- TfLiteTensor* bw_input_weights =
- &context->tensors[node->inputs->data[kBwWeightsTensor]];
- TfLiteTensor* bw_recurrent_weights =
- &context->tensors[node->inputs->data[kBwRecurrentWeightsTensor]];
- TfLiteTensor* bw_bias = &context->tensors[node->inputs->data[kBwBiasTensor]];
+ TF_LITE_ENSURE_EQ(context, node->inputs->size, 12);
+ TF_LITE_ENSURE_EQ(context, node->outputs->size,
+ params->merge_outputs ? 1 : 2);
+
+ const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ const TfLiteTensor* fw_input_weights =
+ GetInput(context, node, kFwWeightsTensor);
+ const TfLiteTensor* fw_recurrent_weights =
+ GetInput(context, node, kFwRecurrentWeightsTensor);
+ const TfLiteTensor* fw_bias = GetInput(context, node, kFwBiasTensor);
+ const TfLiteTensor* fw_hidden_state =
+ GetInput(context, node, kFwHiddenStateTensor);
+ const TfLiteTensor* bw_input_weights =
+ GetInput(context, node, kBwWeightsTensor);
+ const TfLiteTensor* bw_recurrent_weights =
+ GetInput(context, node, kBwRecurrentWeightsTensor);
+ const TfLiteTensor* bw_bias = GetInput(context, node, kBwBiasTensor);
+ const TfLiteTensor* bw_hidden_state =
+ GetInput(context, node, kBwHiddenStateTensor);
+
+ const TfLiteTensor* aux_input =
+ GetOptionalInputTensor(context, node, kAuxInputTensor);
+ const TfLiteTensor* fw_aux_input_weights =
+ GetOptionalInputTensor(context, node, kFwAuxWeightsTensor);
+ const TfLiteTensor* bw_aux_input_weights =
+ GetOptionalInputTensor(context, node, kBwAuxWeightsTensor);
+
+ const bool aux_inputs_all_or_none =
+ ((aux_input != nullptr) && (fw_aux_input_weights != nullptr) &&
+ (bw_aux_input_weights != nullptr)) ||
+ ((aux_input == nullptr) && (fw_aux_input_weights == nullptr) &&
+ (bw_aux_input_weights == nullptr));
+ TF_LITE_ENSURE(context, aux_inputs_all_or_none);
+ const bool has_aux_input = (aux_input != nullptr);
// Check that the tensor parameters are consistent with each other and with the
// input configuration.
+ TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
+
+ TF_LITE_ENSURE_EQ(context, input->dims->size, 3);
const int batch_size = input->dims->data[0];
const int max_time = input->dims->data[1];
const int fw_num_units = fw_input_weights->dims->data[0];
@@ -76,77 +126,150 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
fw_bias->dims->data[0]);
TF_LITE_ASSERT_EQ(bw_recurrent_weights->dims->data[1],
bw_bias->dims->data[0]);
+ TF_LITE_ENSURE_EQ(context, NumDimensions(fw_hidden_state), 2);
+ TF_LITE_ENSURE_EQ(context, fw_hidden_state->dims->data[0], batch_size);
+ TF_LITE_ENSURE_EQ(context, fw_hidden_state->dims->data[1], fw_num_units);
+ TF_LITE_ENSURE_EQ(context, NumDimensions(bw_hidden_state), 2);
+ TF_LITE_ENSURE_EQ(context, bw_hidden_state->dims->data[0], batch_size);
+ TF_LITE_ENSURE_EQ(context, bw_hidden_state->dims->data[1], bw_num_units);
- TfLiteTensor* fw_output =
- &context->tensors[node->outputs->data[kFwOutputTensor]];
- TfLiteTensor* bw_output =
- &context->tensors[node->outputs->data[kBwOutputTensor]];
+ if (has_aux_input) {
+ // Check that aux_input has the same dimensions (except last) as the input.
+ TF_LITE_ASSERT_EQ(aux_input->dims->data[0], input->dims->data[0]);
+ TF_LITE_ASSERT_EQ(aux_input->dims->data[1], input->dims->data[1]);
+ // Check that aux_input_weights has the same dimensions (except last) as
+ // the input_weights.
+ TF_LITE_ASSERT_EQ(fw_aux_input_weights->dims->data[0], fw_num_units);
+ TF_LITE_ASSERT_EQ(bw_aux_input_weights->dims->data[0], bw_num_units);
+ TF_LITE_ASSERT_EQ(aux_input->dims->data[2],
+ fw_aux_input_weights->dims->data[1]);
+ TF_LITE_ASSERT_EQ(aux_input->dims->data[2],
+ bw_aux_input_weights->dims->data[1]);
+ }
- // Resize hidden states.
- TfLiteIntArray* fw_hidden_state_size_array = TfLiteIntArrayCreate(2);
- fw_hidden_state_size_array->data[0] = batch_size;
- fw_hidden_state_size_array->data[1] = fw_num_units;
- TfLiteTensor* fw_hidden_state =
- &context->tensors[node->outputs->data[kFwHiddenStateTensor]];
- TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, fw_hidden_state,
- fw_hidden_state_size_array));
+ const bool is_hybrid_op =
+ (fw_input_weights->type == kTfLiteUInt8 && input->type == kTfLiteFloat32);
- TfLiteIntArray* bw_hidden_state_size_array = TfLiteIntArrayCreate(2);
- bw_hidden_state_size_array->data[0] = batch_size;
- bw_hidden_state_size_array->data[1] = fw_num_units;
- TfLiteTensor* bw_hidden_state =
- &context->tensors[node->outputs->data[kBwHiddenStateTensor]];
- TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, bw_hidden_state,
- bw_hidden_state_size_array));
+ if (is_hybrid_op) {
+ int* scratch_tensor_index = reinterpret_cast<int*>(node->user_data);
+
+ TfLiteIntArrayFree(node->temporaries);
+ if (has_aux_input) {
+ node->temporaries = TfLiteIntArrayCreate(kNumTemporaryTensors);
+ } else {
+ // No need to create a temporary tensor for the non-existent aux_input.
+ node->temporaries = TfLiteIntArrayCreate(kNumTemporaryTensors - 1);
+ }
+
+ node->temporaries->data[kInputQuantized] =
+ *scratch_tensor_index + kInputQuantized;
+ TfLiteTensor* input_quantized =
+ GetTemporary(context, node, kInputQuantized);
+ input_quantized->type = kTfLiteUInt8;
+ input_quantized->allocation_type = kTfLiteArenaRw;
+ if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) {
+ TfLiteIntArray* input_quantized_size = TfLiteIntArrayCopy(input->dims);
+ TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized,
+ input_quantized_size));
+ }
- // Mark hidden states as a persistent tensor.
- fw_hidden_state->allocation_type = kTfLiteArenaRwPersistent;
- bw_hidden_state->allocation_type = kTfLiteArenaRwPersistent;
+ node->temporaries->data[kFwHiddenStateQuantized] =
+ *scratch_tensor_index + kFwHiddenStateQuantized;
+ TfLiteTensor* fw_hidden_state_quantized =
+ GetTemporary(context, node, kFwHiddenStateQuantized);
+ fw_hidden_state_quantized->type = kTfLiteUInt8;
+ fw_hidden_state_quantized->allocation_type = kTfLiteArenaRw;
+ if (!TfLiteIntArrayEqual(fw_hidden_state_quantized->dims,
+ fw_hidden_state->dims)) {
+ TfLiteIntArray* fw_hidden_state_quantized_size =
+ TfLiteIntArrayCopy(fw_hidden_state->dims);
+ TF_LITE_ENSURE_OK(
+ context, context->ResizeTensor(context, fw_hidden_state_quantized,
+ fw_hidden_state_quantized_size));
+ }
+
+ node->temporaries->data[kBwHiddenStateQuantized] =
+ *scratch_tensor_index + kBwHiddenStateQuantized;
+ TfLiteTensor* bw_hidden_state_quantized =
+ GetTemporary(context, node, kBwHiddenStateQuantized);
+ bw_hidden_state_quantized->type = kTfLiteUInt8;
+ bw_hidden_state_quantized->allocation_type = kTfLiteArenaRw;
+ if (!TfLiteIntArrayEqual(bw_hidden_state_quantized->dims,
+ bw_hidden_state->dims)) {
+ TfLiteIntArray* bw_hidden_state_quantized_size =
+ TfLiteIntArrayCopy(bw_hidden_state->dims);
+ TF_LITE_ENSURE_OK(
+ context, context->ResizeTensor(context, bw_hidden_state_quantized,
+ bw_hidden_state_quantized_size));
+ }
+
+    // Allocate a temporary tensor to store the quantization scaling factors.
+ node->temporaries->data[kScalingFactors] =
+ *scratch_tensor_index + kScalingFactors;
+ TfLiteTensor* scaling_factors =
+ GetTemporary(context, node, kScalingFactors);
+ scaling_factors->type = kTfLiteFloat32;
+ scaling_factors->allocation_type = kTfLiteArenaRw;
+ TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1);
+ scaling_factors_size->data[0] = batch_size;
+ if (!TfLiteIntArrayEqual(scaling_factors->dims, scaling_factors_size)) {
+ TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors,
+ scaling_factors_size));
+ }
+
+ if (has_aux_input) {
+ node->temporaries->data[kAuxInputQuantized] =
+ *scratch_tensor_index + kAuxInputQuantized;
+ TfLiteTensor* aux_input_quantized =
+ GetTemporary(context, node, kAuxInputQuantized);
+ aux_input_quantized->type = kTfLiteUInt8;
+ aux_input_quantized->allocation_type = kTfLiteArenaRw;
+ if (!TfLiteIntArrayEqual(aux_input_quantized->dims, aux_input->dims)) {
+ TfLiteIntArray* aux_input_quantized_size =
+ TfLiteIntArrayCopy(aux_input->dims);
+ TF_LITE_ENSURE_OK(context,
+ context->ResizeTensor(context, aux_input_quantized,
+ aux_input_quantized_size));
+ }
+ }
+ }
// Resize outputs.
+ TfLiteTensor* fw_output = GetOutput(context, node, kFwOutputTensor);
TfLiteIntArray* fw_output_size_array = TfLiteIntArrayCreate(3);
fw_output_size_array->data[0] = batch_size;
fw_output_size_array->data[1] = max_time;
- fw_output_size_array->data[2] = fw_num_units;
+ fw_output_size_array->data[2] =
+ params->merge_outputs ? fw_num_units + bw_num_units : fw_num_units;
TF_LITE_ENSURE_OK(
context, context->ResizeTensor(context, fw_output, fw_output_size_array));
- TfLiteIntArray* bw_output_size_array = TfLiteIntArrayCreate(3);
- bw_output_size_array->data[0] = batch_size;
- bw_output_size_array->data[1] = max_time;
- bw_output_size_array->data[2] = bw_num_units;
- TF_LITE_ENSURE_OK(
- context, context->ResizeTensor(context, bw_output, bw_output_size_array));
+ if (!params->merge_outputs) {
+ TfLiteTensor* bw_output = GetOutput(context, node, kBwOutputTensor);
+ TfLiteIntArray* bw_output_size_array = TfLiteIntArrayCreate(3);
+ bw_output_size_array->data[0] = batch_size;
+ bw_output_size_array->data[1] = max_time;
+ bw_output_size_array->data[2] = bw_num_units;
+ TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, bw_output,
+ bw_output_size_array));
+ }
return kTfLiteOk;
}
-TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
- auto* params = reinterpret_cast<TfLiteSequenceRNNParams*>(node->builtin_data);
-
- TfLiteTensor* input = &context->tensors[node->inputs->data[kInputTensor]];
- TfLiteTensor* fw_input_weights =
- &context->tensors[node->inputs->data[kFwWeightsTensor]];
- TfLiteTensor* fw_recurrent_weights =
- &context->tensors[node->inputs->data[kFwRecurrentWeightsTensor]];
- TfLiteTensor* fw_bias = &context->tensors[node->inputs->data[kFwBiasTensor]];
- TfLiteTensor* fw_hidden_state =
- &context->tensors[node->outputs->data[kFwHiddenStateTensor]];
- TfLiteTensor* fw_output =
- &context->tensors[node->outputs->data[kFwOutputTensor]];
-
- TfLiteTensor* bw_input_weights =
- &context->tensors[node->inputs->data[kBwWeightsTensor]];
- TfLiteTensor* bw_recurrent_weights =
- &context->tensors[node->inputs->data[kBwRecurrentWeightsTensor]];
- TfLiteTensor* bw_bias = &context->tensors[node->inputs->data[kBwBiasTensor]];
- TfLiteTensor* bw_hidden_state =
- &context->tensors[node->outputs->data[kBwHiddenStateTensor]];
- TfLiteTensor* bw_output =
- &context->tensors[node->outputs->data[kBwOutputTensor]];
-
+TfLiteStatus EvalFloat(
+ const TfLiteTensor* input, const TfLiteTensor* fw_input_weights,
+ const TfLiteTensor* fw_recurrent_weights, const TfLiteTensor* fw_bias,
+ const TfLiteTensor* bw_input_weights,
+ const TfLiteTensor* bw_recurrent_weights, const TfLiteTensor* bw_bias,
+ const TfLiteTensor* aux_input, const TfLiteTensor* fw_aux_input_weights,
+ const TfLiteTensor* bw_aux_input_weights,
+ const TfLiteBidirectionalSequenceRNNParams* params,
+ TfLiteTensor* fw_hidden_state, TfLiteTensor* fw_output,
+ TfLiteTensor* bw_hidden_state, TfLiteTensor* bw_output) {
const int batch_size = input->dims->data[0];
const int max_time = input->dims->data[1];
const int input_size = input->dims->data[2];
+ const int aux_input_size = (aux_input) ? aux_input->dims->data[2] : 0;
const int fw_num_units = fw_input_weights->dims->data[0];
const float* fw_bias_ptr = fw_bias->data.f;
@@ -158,45 +281,258 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const float* bw_input_weights_ptr = bw_input_weights->data.f;
const float* bw_recurrent_weights_ptr = bw_recurrent_weights->data.f;
+ const float* fw_aux_input_weights_ptr = (fw_aux_input_weights != nullptr)
+ ? fw_aux_input_weights->data.f
+ : nullptr;
+ const float* bw_aux_input_weights_ptr = (bw_aux_input_weights != nullptr)
+ ? bw_aux_input_weights->data.f
+ : nullptr;
+
+ const int fw_output_step =
+ params->merge_outputs ? fw_num_units + bw_num_units : fw_num_units;
+ const int bw_output_step =
+ params->merge_outputs ? fw_num_units + bw_num_units : bw_num_units;
for (int b = 0; b < batch_size; b++) {
// Forward cell.
float* fw_hidden_state_ptr_batch =
fw_hidden_state->data.f + b * fw_num_units;
+ float* fw_output_offset = fw_output->data.f + b * fw_output_step * max_time;
for (int s = 0; s < max_time; s++) {
const float* input_ptr_batch =
input->data.f + b * input_size * max_time + s * input_size;
- float* output_ptr_batch =
- fw_output->data.f + b * fw_num_units * max_time + s * fw_num_units;
+ const float* aux_input_ptr_batch =
+ (aux_input != nullptr)
+              ? aux_input->data.f + b * aux_input_size * max_time +
+                    s * aux_input_size
+ : nullptr;
+ float* output_ptr_batch = fw_output_offset + s * fw_output_step;
kernel_utils::RnnBatchStep(
- input_ptr_batch, fw_input_weights_ptr, fw_recurrent_weights_ptr,
- fw_bias_ptr, input_size, fw_num_units, /*batch_size=*/1,
+ input_ptr_batch, fw_input_weights_ptr, aux_input_ptr_batch,
+ fw_aux_input_weights_ptr, fw_recurrent_weights_ptr, fw_bias_ptr,
+ input_size, aux_input_size, fw_num_units, /*batch_size=*/1,
params->activation, fw_hidden_state_ptr_batch, output_ptr_batch);
}
// Backward cell.
float* bw_hidden_state_ptr_batch =
bw_hidden_state->data.f + b * bw_num_units;
+ float* bw_output_offset =
+ params->merge_outputs
+ ? fw_output->data.f + b * bw_output_step * max_time + fw_num_units
+ : bw_output->data.f + b * bw_output_step * max_time;
for (int s = max_time - 1; s >= 0; s--) {
const float* input_ptr_batch =
input->data.f + b * input_size * max_time + s * input_size;
- float* output_ptr_batch =
- bw_output->data.f + b * bw_num_units * max_time + s * bw_num_units;
+ const float* aux_input_ptr_batch =
+ (aux_input != nullptr)
+              ? aux_input->data.f + b * aux_input_size * max_time +
+                    s * aux_input_size
+ : nullptr;
+ float* output_ptr_batch = bw_output_offset + s * bw_output_step;
kernel_utils::RnnBatchStep(
- input_ptr_batch, bw_input_weights_ptr, bw_recurrent_weights_ptr,
- bw_bias_ptr, input_size, bw_num_units, /*batch_size=*/1,
+ input_ptr_batch, bw_input_weights_ptr, aux_input_ptr_batch,
+ bw_aux_input_weights_ptr, bw_recurrent_weights_ptr, bw_bias_ptr,
+ input_size, aux_input_size, bw_num_units, /*batch_size=*/1,
params->activation, bw_hidden_state_ptr_batch, output_ptr_batch);
}
}
return kTfLiteOk;
}
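
The merge_outputs handling above packs the forward and backward results for a
time step into one output row: the forward cell writes columns
[0, fw_num_units) and the backward cell writes columns
[fw_num_units, fw_num_units + bw_num_units). A toy, self-contained sketch of
that addressing (sizes and values assumed purely for illustration, not taken
from this change):

#include <vector>

int main() {
  const int max_time = 3, fw_num_units = 2, bw_num_units = 2;  // assumed sizes
  const int step = fw_num_units + bw_num_units;  // width of one merged row
  std::vector<float> merged(max_time * step, 0.f);
  for (int s = 0; s < max_time; ++s) {
    float* fw_half = merged.data() + s * step;                 // forward slots
    float* bw_half = merged.data() + s * step + fw_num_units;  // backward slots
    fw_half[0] = 1.f;  // a forward-cell result for step s would land here
    bw_half[0] = 2.f;  // the backward-cell result is appended after it
  }
  return 0;
}
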
+TfLiteStatus EvalHybrid(
+ const TfLiteTensor* input, const TfLiteTensor* fw_input_weights,
+ const TfLiteTensor* fw_recurrent_weights, const TfLiteTensor* fw_bias,
+ const TfLiteTensor* bw_input_weights,
+ const TfLiteTensor* bw_recurrent_weights, const TfLiteTensor* bw_bias,
+ const TfLiteTensor* aux_input, const TfLiteTensor* aux_fw_input_weights,
+ const TfLiteTensor* aux_bw_input_weights,
+ const TfLiteBidirectionalSequenceRNNParams* params,
+ TfLiteTensor* scaling_factors, TfLiteTensor* input_quantized,
+ TfLiteTensor* aux_input_quantized, TfLiteTensor* fw_hidden_state_quantized,
+ TfLiteTensor* fw_hidden_state, TfLiteTensor* fw_output,
+ TfLiteTensor* bw_hidden_state_quantized, TfLiteTensor* bw_hidden_state,
+ TfLiteTensor* bw_output) {
+ const int batch_size = input->dims->data[0];
+ const int max_time = input->dims->data[1];
+ const int input_size = input->dims->data[2];
+ const int aux_input_size = (aux_input) ? aux_input->dims->data[2] : 0;
+
+ const int fw_num_units = fw_input_weights->dims->data[0];
+ const float* fw_bias_ptr = fw_bias->data.f;
+ const int8_t* fw_input_weights_ptr =
+ reinterpret_cast<const int8_t*>(fw_input_weights->data.uint8);
+ float fw_input_weights_scale = fw_input_weights->params.scale;
+ const int8_t* fw_recurrent_weights_ptr =
+ reinterpret_cast<const int8_t*>(fw_recurrent_weights->data.uint8);
+ float fw_recurrent_weights_scale = fw_recurrent_weights->params.scale;
+
+ const int bw_num_units = bw_input_weights->dims->data[0];
+ const float* bw_bias_ptr = bw_bias->data.f;
+ const int8_t* bw_input_weights_ptr =
+ reinterpret_cast<const int8_t*>(bw_input_weights->data.uint8);
+ float bw_input_weights_scale = bw_input_weights->params.scale;
+ const int8_t* bw_recurrent_weights_ptr =
+ reinterpret_cast<const int8_t*>(bw_recurrent_weights->data.uint8);
+ float bw_recurrent_weights_scale = bw_recurrent_weights->params.scale;
+
+ // Set the auxiliary pointers and scales if needed.
+ int8_t* aux_fw_input_weights_ptr = nullptr;
+ float aux_fw_input_weights_scale = 0.0f;
+ int8_t* aux_bw_input_weights_ptr = nullptr;
+ float aux_bw_input_weights_scale = 0.0f;
+ int8_t* aux_quantized_input_ptr = nullptr;
+ if (aux_input_size > 0) {
+ aux_fw_input_weights_ptr =
+ reinterpret_cast<int8_t*>(aux_fw_input_weights->data.uint8);
+ aux_fw_input_weights_scale = aux_fw_input_weights->params.scale;
+ aux_bw_input_weights_ptr =
+ reinterpret_cast<int8_t*>(aux_bw_input_weights->data.uint8);
+ aux_bw_input_weights_scale = aux_bw_input_weights->params.scale;
+    aux_quantized_input_ptr =
+        reinterpret_cast<int8_t*>(aux_input_quantized->data.uint8);
+ }
+
+ // Initialize temporary storage for quantized values.
+ int8_t* quantized_input_ptr =
+ reinterpret_cast<int8_t*>(input_quantized->data.uint8);
+ int8_t* fw_quantized_hidden_state_ptr =
+ reinterpret_cast<int8_t*>(fw_hidden_state_quantized->data.uint8);
+ int8_t* bw_quantized_hidden_state_ptr =
+ reinterpret_cast<int8_t*>(bw_hidden_state_quantized->data.uint8);
+ float* scaling_factors_ptr = scaling_factors->data.f;
+
+ const int fw_output_step =
+ params->merge_outputs ? fw_num_units + bw_num_units : fw_num_units;
+ const int bw_output_step =
+ params->merge_outputs ? fw_num_units + bw_num_units : bw_num_units;
+ for (int b = 0; b < batch_size; b++) {
+ // Forward cell.
+ float* fw_hidden_state_ptr_batch =
+ fw_hidden_state->data.f + b * fw_num_units;
+ float* fw_output_offset = fw_output->data.f + b * fw_output_step * max_time;
+ for (int s = 0; s < max_time; s++) {
+ const float* input_ptr_batch =
+ input->data.f + b * input_size * max_time + s * input_size;
+ const float* aux_input_ptr_batch =
+ (aux_input != nullptr)
+              ? aux_input->data.f + b * aux_input_size * max_time +
+                    s * aux_input_size
+ : nullptr;
+ float* output_ptr_batch = fw_output_offset + s * fw_output_step;
+
+ kernel_utils::RnnBatchStep(
+ input_ptr_batch, fw_input_weights_ptr, fw_input_weights_scale,
+ aux_input_ptr_batch, aux_fw_input_weights_ptr,
+ aux_fw_input_weights_scale, fw_recurrent_weights_ptr,
+ fw_recurrent_weights_scale, fw_bias_ptr, input_size, aux_input_size,
+ fw_num_units, /*batch_size=*/1, params->activation,
+ quantized_input_ptr, aux_quantized_input_ptr,
+ fw_quantized_hidden_state_ptr, scaling_factors_ptr,
+ fw_hidden_state_ptr_batch, output_ptr_batch);
+ }
+ // Backward cell.
+ float* bw_hidden_state_ptr_batch =
+ bw_hidden_state->data.f + b * bw_num_units;
+ float* bw_output_offset =
+ params->merge_outputs
+            ? fw_output->data.f + b * bw_output_step * max_time + fw_num_units
+ : bw_output->data.f + b * bw_output_step * max_time;
+ for (int s = max_time - 1; s >= 0; s--) {
+ const float* input_ptr_batch =
+ input->data.f + b * input_size * max_time + s * input_size;
+ const float* aux_input_ptr_batch =
+ (aux_input != nullptr)
+              ? aux_input->data.f + b * aux_input_size * max_time +
+                    s * aux_input_size
+ : nullptr;
+ float* output_ptr_batch = bw_output_offset + s * bw_output_step;
+
+ kernel_utils::RnnBatchStep(
+ input_ptr_batch, bw_input_weights_ptr, bw_input_weights_scale,
+ aux_input_ptr_batch, aux_bw_input_weights_ptr,
+ aux_bw_input_weights_scale, bw_recurrent_weights_ptr,
+ bw_recurrent_weights_scale, bw_bias_ptr, input_size, aux_input_size,
+ bw_num_units, /*batch_size=*/1, params->activation,
+ quantized_input_ptr, aux_quantized_input_ptr,
+ bw_quantized_hidden_state_ptr, scaling_factors_ptr,
+ bw_hidden_state_ptr_batch, output_ptr_batch);
+ }
+ }
+ return kTfLiteOk;
+}
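
EvalHybrid keeps the weights in 8-bit form and quantizes each float input row
on the fly, which is what the per-batch scaling_factors temporary is sized for;
the quantization itself is delegated to kernel_utils::RnnBatchStep. The snippet
below is only a standalone sketch of the usual symmetric scheme (all names here
are illustrative, not TF Lite API):

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Symmetric per-row quantization: choose one scale so that the largest
// magnitude maps to 127, then round every value into int8.
void QuantizeRow(const std::vector<float>& row, std::vector<int8_t>* quantized,
                 float* scaling_factor) {
  float max_abs = 0.f;
  for (float v : row) max_abs = std::max(max_abs, std::fabs(v));
  *scaling_factor = (max_abs > 0.f) ? max_abs / 127.f : 1.f;
  quantized->resize(row.size());
  for (size_t i = 0; i < row.size(); ++i) {
    float scaled = std::round(row[i] / *scaling_factor);
    scaled = std::min(127.f, std::max(-127.f, scaled));
    (*quantized)[i] = static_cast<int8_t>(scaled);
  }
}
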
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ const auto* params = reinterpret_cast<TfLiteBidirectionalSequenceRNNParams*>(
+ node->builtin_data);
+
+ const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ const TfLiteTensor* fw_input_weights =
+ GetInput(context, node, kFwWeightsTensor);
+ const TfLiteTensor* fw_recurrent_weights =
+ GetInput(context, node, kFwRecurrentWeightsTensor);
+ const TfLiteTensor* fw_bias = GetInput(context, node, kFwBiasTensor);
+ const TfLiteTensor* bw_input_weights =
+ GetInput(context, node, kBwWeightsTensor);
+ const TfLiteTensor* bw_recurrent_weights =
+ GetInput(context, node, kBwRecurrentWeightsTensor);
+ const TfLiteTensor* bw_bias = GetInput(context, node, kBwBiasTensor);
+
+ // Get auxiliary inputs.
+ const TfLiteTensor* aux_input =
+ GetOptionalInputTensor(context, node, kAuxInputTensor);
+ const TfLiteTensor* fw_aux_input_weights =
+ GetOptionalInputTensor(context, node, kFwAuxWeightsTensor);
+ const TfLiteTensor* bw_aux_input_weights =
+ GetOptionalInputTensor(context, node, kBwAuxWeightsTensor);
+
+ TfLiteTensor* fw_hidden_state =
+ GetVariableInput(context, node, kFwHiddenStateTensor);
+ TfLiteTensor* bw_hidden_state =
+ GetVariableInput(context, node, kBwHiddenStateTensor);
+
+ TfLiteTensor* fw_output = GetOutput(context, node, kFwOutputTensor);
+ TfLiteTensor* bw_output = params->merge_outputs
+ ? nullptr
+ : GetOutput(context, node, kBwOutputTensor);
+
+ switch (fw_input_weights->type) {
+ case kTfLiteFloat32:
+ return EvalFloat(input, fw_input_weights, fw_recurrent_weights, fw_bias,
+ bw_input_weights, bw_recurrent_weights, bw_bias,
+ aux_input, fw_aux_input_weights, bw_aux_input_weights,
+ params, fw_hidden_state, fw_output, bw_hidden_state,
+ bw_output);
+ case kTfLiteUInt8: {
+ TfLiteTensor* input_quantized =
+ GetTemporary(context, node, kInputQuantized);
+ TfLiteTensor* fw_hidden_state_quantized =
+ GetTemporary(context, node, kFwHiddenStateQuantized);
+ TfLiteTensor* bw_hidden_state_quantized =
+ GetTemporary(context, node, kBwHiddenStateQuantized);
+ TfLiteTensor* scaling_factors =
+ GetTemporary(context, node, kScalingFactors);
+ TfLiteTensor* aux_input_quantized =
+ (aux_input != nullptr)
+ ? GetTemporary(context, node, kAuxInputQuantized)
+ : nullptr;
+
+ return EvalHybrid(input, fw_input_weights, fw_recurrent_weights, fw_bias,
+ bw_input_weights, bw_recurrent_weights, bw_bias,
+ aux_input, fw_aux_input_weights, bw_aux_input_weights,
+ params, scaling_factors, input_quantized,
+ aux_input_quantized, fw_hidden_state_quantized,
+ fw_hidden_state, fw_output, bw_hidden_state_quantized,
+ bw_hidden_state, bw_output);
+ }
+ default:
+ context->ReportError(context, "Type not currently supported.");
+ return kTfLiteError;
+ }
+ return kTfLiteOk;
+}
+
} // namespace bidirectional_sequence_rnn
TfLiteRegistration* Register_BIDIRECTIONAL_SEQUENCE_RNN() {
- static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
- bidirectional_sequence_rnn::Prepare,
- bidirectional_sequence_rnn::Eval};
+ static TfLiteRegistration r = {
+ bidirectional_sequence_rnn::Init, bidirectional_sequence_rnn::Free,
+ bidirectional_sequence_rnn::Prepare, bidirectional_sequence_rnn::Eval};
return &r;
}
diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc b/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc
index 911b108eaa..f555c472f5 100644
--- a/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc
+++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc
@@ -654,7 +654,7 @@ const std::initializer_list<float> recurrent_weights = {
class BidirectionalRNNOpModel : public SingleOpModel {
public:
BidirectionalRNNOpModel(int batches, int sequence_len, int fw_units,
- int bw_units, int input_size)
+ int bw_units, int input_size, bool merge_outputs)
: batches_(batches),
sequence_len_(sequence_len),
fw_units_(fw_units),
@@ -664,26 +664,40 @@ class BidirectionalRNNOpModel : public SingleOpModel {
fw_weights_ = AddInput(TensorType_FLOAT32);
fw_recurrent_weights_ = AddInput(TensorType_FLOAT32);
fw_bias_ = AddInput(TensorType_FLOAT32);
- fw_hidden_state_ = AddOutput(TensorType_FLOAT32);
- fw_output_ = AddOutput(TensorType_FLOAT32);
+ fw_hidden_state_ = AddInput(TensorType_FLOAT32, true);
bw_weights_ = AddInput(TensorType_FLOAT32);
bw_recurrent_weights_ = AddInput(TensorType_FLOAT32);
bw_bias_ = AddInput(TensorType_FLOAT32);
- bw_hidden_state_ = AddOutput(TensorType_FLOAT32);
- bw_output_ = AddOutput(TensorType_FLOAT32);
+ bw_hidden_state_ = AddInput(TensorType_FLOAT32, true);
+
+ aux_input_ = AddNullInput();
+ aux_fw_weights_ = AddNullInput();
+ aux_bw_weights_ = AddNullInput();
+
+ fw_output_ = AddOutput(TensorType_FLOAT32);
+ if (!merge_outputs) {
+ bw_output_ = AddOutput(TensorType_FLOAT32);
+ }
+
SetBuiltinOp(BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN,
- BuiltinOptions_SequenceRNNOptions,
- CreateSequenceRNNOptions(builder_, /*time_major=*/false,
- ActivationFunctionType_RELU)
+ BuiltinOptions_BidirectionalSequenceRNNOptions,
+ CreateBidirectionalSequenceRNNOptions(
+ builder_, /*time_major=*/false,
+ ActivationFunctionType_RELU, merge_outputs)
.Union());
BuildInterpreter({
{batches_, sequence_len_, input_size_}, // input
{fw_units_, input_size_}, // fw_weights
{fw_units_, fw_units_}, // fw_recurrent_weights
{fw_units_}, // fw_bias
+ {batches_, fw_units_}, // fw_hidden_state
{bw_units_, input_size_}, // bw_weights
{bw_units_, bw_units_}, // bw_recurrent_weights
- {bw_units_} // bw_bias
+ {bw_units_}, // bw_bias
+ {batches_, bw_units_}, // bw_hidden_state
+ {batches_, sequence_len_, 0}, // aux_input
+ {fw_units_, 0}, // aux_fw_weights
+ {bw_units_, 0}, // aux_bw_weights
});
}
@@ -719,19 +733,6 @@ class BidirectionalRNNOpModel : public SingleOpModel {
PopulateTensor(input_, offset, begin, end);
}
- void ResetHiddenStates() {
- const int fw_zero_buffer_size = fw_units_ * batches_;
- std::unique_ptr<float[]> fw_zero_buffer(new float[fw_zero_buffer_size]);
- memset(fw_zero_buffer.get(), 0, fw_zero_buffer_size * sizeof(float));
- PopulateTensor(fw_hidden_state_, 0, fw_zero_buffer.get(),
- fw_zero_buffer.get() + fw_zero_buffer_size);
- const int bw_zero_buffer_size = bw_units_ * batches_;
- std::unique_ptr<float[]> bw_zero_buffer(new float[bw_zero_buffer_size]);
- memset(bw_zero_buffer.get(), 0, bw_zero_buffer_size * sizeof(float));
- PopulateTensor(bw_hidden_state_, 0, bw_zero_buffer.get(),
- bw_zero_buffer.get() + bw_zero_buffer_size);
- }
-
std::vector<float> GetFwOutput() { return ExtractVector<float>(fw_output_); }
std::vector<float> GetBwOutput() { return ExtractVector<float>(bw_output_); }
@@ -753,6 +754,9 @@ class BidirectionalRNNOpModel : public SingleOpModel {
int bw_bias_;
int bw_hidden_state_;
int bw_output_;
+ int aux_input_;
+ int aux_fw_weights_;
+ int aux_bw_weights_;
int batches_;
int sequence_len_;
@@ -766,7 +770,7 @@ class BidirectionalRNNOpModel : public SingleOpModel {
TEST(BidirectionalRNNOpTest, BlackBoxTest) {
BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
/*fw_units=*/16, /*bw_units=*/16,
- /*input_size=*/8);
+ /*input_size=*/8, /*merge_outputs=*/false);
rnn.SetFwWeights(weights);
rnn.SetBwWeights(weights);
rnn.SetFwBias(biases);
@@ -774,7 +778,6 @@ TEST(BidirectionalRNNOpTest, BlackBoxTest) {
rnn.SetFwRecurrentWeights(recurrent_weights);
rnn.SetBwRecurrentWeights(recurrent_weights);
- rnn.ResetHiddenStates();
const int input_sequence_size = rnn.input_size() * rnn.sequence_len();
float* batch_start = rnn_input;
float* batch_end = batch_start + input_sequence_size;
@@ -800,12 +803,49 @@ TEST(BidirectionalRNNOpTest, BlackBoxTest) {
EXPECT_THAT(rnn.GetBwOutput(), ElementsAreArray(ArrayFloatNear(bw_expected)));
}
+// Same as the previous test, but with merged outputs.
+TEST(BidirectionalRNNOpTest, BlackBoxTestMergeOutputs) {
+ BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
+ /*fw_units=*/16, /*bw_units=*/16,
+ /*input_size=*/8, /*merge_outputs=*/true);
+ rnn.SetFwWeights(weights);
+ rnn.SetBwWeights(weights);
+ rnn.SetFwBias(biases);
+ rnn.SetBwBias(biases);
+ rnn.SetFwRecurrentWeights(recurrent_weights);
+ rnn.SetBwRecurrentWeights(recurrent_weights);
+
+ const int input_sequence_size = rnn.input_size() * rnn.sequence_len();
+ float* batch_start = rnn_input;
+ float* batch_end = batch_start + input_sequence_size;
+ rnn.SetInput(0, batch_start, batch_end);
+ rnn.SetInput(input_sequence_size, batch_start, batch_end);
+
+ rnn.Invoke();
+
+ std::vector<float> merged_expected;
+ for (int bid = 0; bid < rnn.num_batches(); bid++) {
+ for (int step = 0; step < rnn.sequence_len(); step++) {
+ merged_expected.insert(
+ merged_expected.end(),
+ rnn_golden_fw_output + rnn.num_fw_units() * step,
+ rnn_golden_fw_output + rnn.num_fw_units() * (step + 1));
+ merged_expected.insert(
+ merged_expected.end(),
+ rnn_golden_bw_output + rnn.num_bw_units() * step,
+ rnn_golden_bw_output + rnn.num_bw_units() * (step + 1));
+ }
+ }
+ EXPECT_THAT(rnn.GetFwOutput(),
+ ElementsAreArray(ArrayFloatNear(merged_expected)));
+}
+
// Check that if the input sequence is reversed, the outputs are the same, just
// with forward and backward swapped (and reversed).
TEST(BidirectionalRNNOpTest, BlackBoxTestReverseInputs) {
BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
/*fw_units=*/16, /*bw_units=*/16,
- /*input_size=*/8);
+ /*input_size=*/8, /*merge_outputs=*/false);
rnn.SetFwWeights(weights);
rnn.SetBwWeights(weights);
rnn.SetFwBias(biases);
@@ -813,8 +853,6 @@ TEST(BidirectionalRNNOpTest, BlackBoxTestReverseInputs) {
rnn.SetFwRecurrentWeights(recurrent_weights);
rnn.SetBwRecurrentWeights(recurrent_weights);
- rnn.ResetHiddenStates();
-
// Reverse the inputs in each batch: in_1, in_2, ..., in_k are inserted in the
// following order: [in_k, ..., in_2, in_1, in_k, ..., in_2, in_1].
for (int i = 0; i < rnn.sequence_len(); i++) {
@@ -853,7 +891,7 @@ TEST(BidirectionalRNNOpTest, BlackBoxTestReverseInputs) {
TEST(BidirectionalRNNOpTest, EndToEndTest) {
BidirectionalRNNOpModel rnn(/*batches=*/1, /*sequence_len=*/4,
/*fw_units=*/16, /*bw_units=*/16,
- /*input_size=*/8);
+ /*input_size=*/8, /*merge_outputs=*/false);
const int output_size = 4;
float dnn_weights[] = {
-0.5782342, -0.052212059, 0.73036242, -0.81216097, -0.80088139,
@@ -880,8 +918,6 @@ TEST(BidirectionalRNNOpTest, EndToEndTest) {
rnn.SetFwRecurrentWeights(recurrent_weights);
rnn.SetBwRecurrentWeights(recurrent_weights);
- rnn.ResetHiddenStates();
-
const int input_sequence_size = rnn.input_size() * rnn.sequence_len();
const int output_sequence_size = output_size * rnn.sequence_len();
const int num_examples = 64;
diff --git a/tensorflow/contrib/lite/kernels/cast.cc b/tensorflow/contrib/lite/kernels/cast.cc
index 60770ca0aa..a7972140ac 100644
--- a/tensorflow/contrib/lite/kernels/cast.cc
+++ b/tensorflow/contrib/lite/kernels/cast.cc
@@ -14,8 +14,9 @@ limitations under the License.
==============================================================================*/
#include <string.h>
#include <algorithm>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include <complex>
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
@@ -53,6 +54,20 @@ void copyCast(const FromT* in, ToT* out, int num_elements) {
[](FromT a) { return static_cast<ToT>(a); });
}
+template <typename ToT>
+void copyCast(const std::complex<float>* in, ToT* out, int num_elements) {
+ std::transform(in, in + num_elements, out, [](std::complex<float> a) {
+ return static_cast<ToT>(std::real(a));
+ });
+}
+
+template <>
+void copyCast(const std::complex<float>* in, std::complex<float>* out,
+ int num_elements) {
+ std::transform(in, in + num_elements, out,
+ [](std::complex<float> a) { return a; });
+}
+
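
The complex-to-real overload above keeps only the real component of each
element; the imaginary part is simply dropped. A self-contained illustration of
the same transform, outside the kernel:

#include <algorithm>
#include <complex>
#include <iostream>
#include <vector>

int main() {
  std::vector<std::complex<float>> in = {{1.0f, 11.0f}, {2.0f, 12.0f}};
  std::vector<float> out(in.size());
  // Same lambda shape as copyCast: take std::real() and cast to the target.
  std::transform(in.begin(), in.end(), out.begin(),
                 [](std::complex<float> a) { return std::real(a); });
  std::cout << out[0] << " " << out[1] << "\n";  // prints "1 2"
  return 0;
}
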
template <typename FromT>
TfLiteStatus copyToTensor(const FromT* in, TfLiteTensor* out,
int num_elements) {
@@ -72,6 +87,10 @@ TfLiteStatus copyToTensor(const FromT* in, TfLiteTensor* out,
case kTfLiteBool:
copyCast(in, out->data.b, num_elements);
break;
+ case kTfLiteComplex64:
+ copyCast(in, reinterpret_cast<std::complex<float>*>(out->data.c64),
+ num_elements);
+ break;
default:
// Unsupported type.
return kTfLiteError;
@@ -95,6 +114,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
return copyToTensor(input->data.f, output, num_elements);
case kTfLiteBool:
return copyToTensor(input->data.b, output, num_elements);
+ case kTfLiteComplex64:
+ return copyToTensor(
+ reinterpret_cast<std::complex<float>*>(input->data.c64), output,
+ num_elements);
default:
// Unsupported type.
return kTfLiteError;
diff --git a/tensorflow/contrib/lite/kernels/cast_test.cc b/tensorflow/contrib/lite/kernels/cast_test.cc
index 53e2000737..954f998206 100644
--- a/tensorflow/contrib/lite/kernels/cast_test.cc
+++ b/tensorflow/contrib/lite/kernels/cast_test.cc
@@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include <complex>
+
#include <gtest/gtest.h>
#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/kernels/register.h"
@@ -73,6 +75,71 @@ TEST(CastOpModel, CastBoolToFloat) {
ElementsAreArray({1.f, 1.0f, 0.f, 1.0f, 0.0f, 1.0f}));
}
+TEST(CastOpModel, CastComplex64ToFloat) {
+ CastOpModel m({TensorType_COMPLEX64, {2, 3}}, {TensorType_FLOAT32, {2, 3}});
+ m.PopulateTensor<std::complex<float>>(
+ m.input(),
+ {std::complex<float>(1.0f, 11.0f), std::complex<float>(2.0f, 12.0f),
+ std::complex<float>(3.0f, 13.0f), std::complex<float>(4.0f, 14.0f),
+ std::complex<float>(5.0f, 15.0f), std::complex<float>(6.0f, 16.0f)});
+ m.Invoke();
+ EXPECT_THAT(m.ExtractVector<float>(m.output()),
+ ElementsAreArray({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}));
+}
+
+TEST(CastOpModel, CastFloatToComplex64) {
+ CastOpModel m({TensorType_FLOAT32, {2, 3}}, {TensorType_COMPLEX64, {2, 3}});
+ m.PopulateTensor<float>(m.input(), {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f});
+ m.Invoke();
+ EXPECT_THAT(
+ m.ExtractVector<std::complex<float>>(m.output()),
+ ElementsAreArray(
+ {std::complex<float>(1.0f, 0.0f), std::complex<float>(2.0f, 0.0f),
+ std::complex<float>(3.0f, 0.0f), std::complex<float>(4.0f, 0.0f),
+ std::complex<float>(5.0f, 0.0f), std::complex<float>(6.0f, 0.0f)}));
+}
+
+TEST(CastOpModel, CastComplex64ToInt) {
+ CastOpModel m({TensorType_COMPLEX64, {2, 3}}, {TensorType_INT32, {2, 3}});
+ m.PopulateTensor<std::complex<float>>(
+ m.input(),
+ {std::complex<float>(1.0f, 11.0f), std::complex<float>(2.0f, 12.0f),
+ std::complex<float>(3.0f, 13.0f), std::complex<float>(4.0f, 14.0f),
+ std::complex<float>(5.0f, 15.0f), std::complex<float>(6.0f, 16.0f)});
+ m.Invoke();
+ EXPECT_THAT(m.ExtractVector<int>(m.output()),
+ ElementsAreArray({1, 2, 3, 4, 5, 6}));
+}
+
+TEST(CastOpModel, CastIntToComplex64) {
+ CastOpModel m({TensorType_INT32, {2, 3}}, {TensorType_COMPLEX64, {2, 3}});
+ m.PopulateTensor<int>(m.input(), {1, 2, 3, 4, 5, 6});
+ m.Invoke();
+ EXPECT_THAT(
+ m.ExtractVector<std::complex<float>>(m.output()),
+ ElementsAreArray(
+ {std::complex<float>(1.0f, 0.0f), std::complex<float>(2.0f, 0.0f),
+ std::complex<float>(3.0f, 0.0f), std::complex<float>(4.0f, 0.0f),
+ std::complex<float>(5.0f, 0.0f), std::complex<float>(6.0f, 0.0f)}));
+}
+
+TEST(CastOpModel, CastComplex64ToComplex64) {
+ CastOpModel m({TensorType_COMPLEX64, {2, 3}}, {TensorType_COMPLEX64, {2, 3}});
+ m.PopulateTensor<std::complex<float>>(
+ m.input(),
+ {std::complex<float>(1.0f, 11.0f), std::complex<float>(2.0f, 12.0f),
+ std::complex<float>(3.0f, 13.0f), std::complex<float>(4.0f, 14.0f),
+ std::complex<float>(5.0f, 15.0f), std::complex<float>(6.0f, 16.0f)});
+ m.Invoke();
+ EXPECT_THAT(
+ m.ExtractVector<std::complex<float>>(m.output()),
+ ElementsAreArray(
+ {std::complex<float>(1.0f, 11.0f), std::complex<float>(2.0f, 12.0f),
+ std::complex<float>(3.0f, 13.0f), std::complex<float>(4.0f, 14.0f),
+ std::complex<float>(5.0f, 15.0f),
+ std::complex<float>(6.0f, 16.0f)}));
+}
+
} // namespace
} // namespace tflite
int main(int argc, char** argv) {
diff --git a/tensorflow/contrib/lite/kernels/comparisons.cc b/tensorflow/contrib/lite/kernels/comparisons.cc
index b948334b6d..3926af5b97 100644
--- a/tensorflow/contrib/lite/kernels/comparisons.cc
+++ b/tensorflow/contrib/lite/kernels/comparisons.cc
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
@@ -23,6 +23,7 @@ namespace tflite {
namespace ops {
namespace builtin {
namespace comparisons {
+namespace {
constexpr int kInputTensor1 = 0;
constexpr int kInputTensor2 = 1;
@@ -56,23 +57,131 @@ TfLiteStatus ComparisonPrepare(TfLiteContext* context, TfLiteNode* node) {
return context->ResizeTensor(context, output, output_size);
}
-#define TF_LITE_COMPARISON(type, opname, requires_broadcast) \
- requires_broadcast \
- ? reference_ops::Broadcast##opname( \
- GetTensorData<type>(input1), GetTensorDims(input1), \
- GetTensorData<type>(input2), GetTensorDims(input2), \
- GetTensorData<bool>(output), GetTensorDims(output)) \
- : reference_ops::opname( \
- GetTensorData<type>(input1), GetTensorDims(input1), \
- GetTensorData<type>(input2), GetTensorDims(input2), \
- GetTensorData<bool>(output), GetTensorDims(output));
+// TODO(ruic): optimize the macros below by using template functions.
+#define TF_LITE_QUANTIZE_COMPARISON(opname) \
+ void EvalQuantized##opname(TfLiteContext* context, TfLiteNode* node, \
+ const TfLiteTensor* input1, \
+ const TfLiteTensor* input2, TfLiteTensor* output, \
+ bool requires_broadcast) { \
+ if (input1->type == kTfLiteUInt8) { \
+ auto input1_offset = -input1->params.zero_point; \
+ auto input2_offset = -input2->params.zero_point; \
+ const int left_shift = 8; \
+ \
+ int32 input1_multiplier; \
+ int input1_shift; \
+ QuantizeMultiplierSmallerThanOneExp(input1->params.scale, \
+ &input1_multiplier, &input1_shift); \
+ int32 input2_multiplier; \
+ int input2_shift; \
+ QuantizeMultiplierSmallerThanOneExp(input2->params.scale, \
+ &input2_multiplier, &input2_shift); \
+ \
+ ComparisonParams op_params; \
+ op_params.left_shift = left_shift; \
+ op_params.input1_offset = input1_offset; \
+ op_params.input1_multiplier = input1_multiplier; \
+ op_params.input1_shift = input1_shift; \
+ op_params.input2_offset = input2_offset; \
+ op_params.input2_multiplier = input2_multiplier; \
+ op_params.input2_shift = input2_shift; \
+ if (requires_broadcast) { \
+ reference_ops::Broadcast4DSlow##opname##WithScaling( \
+ op_params, GetTensorShape(input1), GetTensorData<uint8_t>(input1), \
+ GetTensorShape(input2), GetTensorData<uint8_t>(input2), \
+ GetTensorShape(output), GetTensorData<bool>(output)); \
+ } else { \
+ reference_ops::opname##WithScaling( \
+ op_params, GetTensorShape(input1), GetTensorData<uint8_t>(input1), \
+ GetTensorShape(input2), GetTensorData<uint8_t>(input2), \
+ GetTensorShape(output), GetTensorData<bool>(output)); \
+ } \
+ } \
+ }
+TF_LITE_QUANTIZE_COMPARISON(Equal);
+TF_LITE_QUANTIZE_COMPARISON(NotEqual);
+TF_LITE_QUANTIZE_COMPARISON(Greater);
+TF_LITE_QUANTIZE_COMPARISON(GreaterEqual);
+TF_LITE_QUANTIZE_COMPARISON(Less);
+TF_LITE_QUANTIZE_COMPARISON(LessEqual);
+#undef TF_LITE_QUANTIZE_COMPARISON
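
The generated EvalQuantized* helpers rescale both uint8 operands into a shared
fixed-point domain (using the multipliers and shifts computed above) before
comparing. The float reference of the same idea is simply to dequantize each
side with its own scale and zero point and compare the real values; a
standalone sketch (names chosen here for illustration only):

#include <cstdint>

// Reference semantics: real_value = scale * (quantized_value - zero_point).
inline bool GreaterDequantized(uint8_t a, float scale_a, int32_t zero_point_a,
                               uint8_t b, float scale_b, int32_t zero_point_b) {
  const float real_a = scale_a * (static_cast<int32_t>(a) - zero_point_a);
  const float real_b = scale_b * (static_cast<int32_t>(b) - zero_point_b);
  return real_a > real_b;
}
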
+
+#define TF_LITE_COMPARISON(type, opname, requires_broadcast) \
+ { \
+ ComparisonParams op_params; \
+ requires_broadcast \
+ ? reference_ops::Broadcast4DSlow##opname##NoScaling( \
+ op_params, GetTensorShape(input1), GetTensorData<type>(input1), \
+ GetTensorShape(input2), GetTensorData<type>(input2), \
+ GetTensorShape(output), GetTensorData<bool>(output)) \
+ : reference_ops::opname##NoScaling( \
+ op_params, GetTensorShape(input1), GetTensorData<type>(input1), \
+ GetTensorShape(input2), GetTensorData<type>(input2), \
+ GetTensorShape(output), GetTensorData<bool>(output)); \
+ }
+
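
One possible shape for the template refactor suggested in the TODO above,
assuming the reference kernels keep their current NoScaling signatures (a
sketch, not part of this change):

template <typename T, typename NoBroadcastFn, typename BroadcastFn>
void Comparison(const TfLiteTensor* input1, const TfLiteTensor* input2,
                TfLiteTensor* output, bool requires_broadcast,
                NoBroadcastFn no_broadcast, BroadcastFn broadcast) {
  ComparisonParams op_params;  // no scaling needed for float/int comparisons
  if (requires_broadcast) {
    broadcast(op_params, GetTensorShape(input1), GetTensorData<T>(input1),
              GetTensorShape(input2), GetTensorData<T>(input2),
              GetTensorShape(output), GetTensorData<bool>(output));
  } else {
    no_broadcast(op_params, GetTensorShape(input1), GetTensorData<T>(input1),
                 GetTensorShape(input2), GetTensorData<T>(input2),
                 GetTensorShape(output), GetTensorData<bool>(output));
  }
}

Each Eval function could then pass the reference_ops kernels that match its
element type instead of expanding the macro per type.
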
+TfLiteStatus EqualEval(TfLiteContext* context, TfLiteNode* node) {
+ const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
+ const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+ bool requires_broadcast = !HaveSameShapes(input1, input2);
+ switch (input1->type) {
+ case kTfLiteFloat32:
+ TF_LITE_COMPARISON(float, Equal, requires_broadcast);
+ break;
+ case kTfLiteInt32:
+ TF_LITE_COMPARISON(int32_t, Equal, requires_broadcast);
+ break;
+ case kTfLiteInt64:
+ TF_LITE_COMPARISON(int64_t, Equal, requires_broadcast);
+ break;
+ case kTfLiteUInt8:
+ EvalQuantizedEqual(context, node, input1, input2, output,
+ requires_broadcast);
+ break;
+ default:
+ context->ReportError(context,
+ "Does not support type %d, requires float|int|uint8",
+ input1->type);
+ return kTfLiteError;
+ }
+ return kTfLiteOk;
+}
+
+// TODO(renjieliu): Refactor the logic to avoid duplication.
+TfLiteStatus NotEqualEval(TfLiteContext* context, TfLiteNode* node) {
+ const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
+ const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+ bool requires_broadcast = !HaveSameShapes(input1, input2);
+ switch (input1->type) {
+ case kTfLiteFloat32:
+ TF_LITE_COMPARISON(float, NotEqual, requires_broadcast);
+ break;
+ case kTfLiteInt32:
+ TF_LITE_COMPARISON(int32_t, NotEqual, requires_broadcast);
+ break;
+ case kTfLiteInt64:
+ TF_LITE_COMPARISON(int64_t, NotEqual, requires_broadcast);
+ break;
+ case kTfLiteUInt8:
+ EvalQuantizedNotEqual(context, node, input1, input2, output,
+ requires_broadcast);
+ break;
+ default:
+ context->ReportError(context,
+ "Does not support type %d, requires float|int|uint8",
+ input1->type);
+ return kTfLiteError;
+ }
+ return kTfLiteOk;
+}
TfLiteStatus GreaterEval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
bool requires_broadcast = !HaveSameShapes(input1, input2);
- // TODO(renjieliu): Support quantized data.
switch (input1->type) {
case kTfLiteFloat32:
TF_LITE_COMPARISON(float, Greater, requires_broadcast);
@@ -83,9 +192,14 @@ TfLiteStatus GreaterEval(TfLiteContext* context, TfLiteNode* node) {
case kTfLiteInt64:
TF_LITE_COMPARISON(int64_t, Greater, requires_broadcast);
break;
+ case kTfLiteUInt8:
+ EvalQuantizedGreater(context, node, input1, input2, output,
+ requires_broadcast);
+ break;
default:
context->ReportError(context,
- "Does not support type other than float|int");
+ "Does not support type %d, requires float|int|uint8",
+ input1->type);
return kTfLiteError;
}
return kTfLiteOk;
@@ -96,7 +210,6 @@ TfLiteStatus GreaterEqualEval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
bool requires_broadcast = !HaveSameShapes(input1, input2);
- // TODO(renjieliu): Support quantized data.
switch (input1->type) {
case kTfLiteFloat32:
TF_LITE_COMPARISON(float, GreaterEqual, requires_broadcast);
@@ -107,9 +220,14 @@ TfLiteStatus GreaterEqualEval(TfLiteContext* context, TfLiteNode* node) {
case kTfLiteInt64:
TF_LITE_COMPARISON(int64_t, GreaterEqual, requires_broadcast);
break;
+ case kTfLiteUInt8:
+ EvalQuantizedGreaterEqual(context, node, input1, input2, output,
+ requires_broadcast);
+ break;
default:
context->ReportError(context,
- "Does not support type other than float|int");
+ "Does not support type %d, requires float|int|uint8",
+ input1->type);
return kTfLiteError;
}
return kTfLiteOk;
@@ -120,7 +238,6 @@ TfLiteStatus LessEval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
bool requires_broadcast = !HaveSameShapes(input1, input2);
- // TODO(renjieliu): Support quantized data.
switch (input1->type) {
case kTfLiteFloat32:
TF_LITE_COMPARISON(float, Less, requires_broadcast);
@@ -131,9 +248,14 @@ TfLiteStatus LessEval(TfLiteContext* context, TfLiteNode* node) {
case kTfLiteInt64:
TF_LITE_COMPARISON(int64_t, Less, requires_broadcast);
break;
+ case kTfLiteUInt8:
+ EvalQuantizedLess(context, node, input1, input2, output,
+ requires_broadcast);
+ break;
default:
context->ReportError(context,
- "Does not support type other than float|int");
+ "Does not support type %d, requires float|int|uint8",
+ input1->type);
return kTfLiteError;
}
return kTfLiteOk;
@@ -144,7 +266,6 @@ TfLiteStatus LessEqualEval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
bool requires_broadcast = !HaveSameShapes(input1, input2);
- // TODO(renjieliu): Support quantized data.
switch (input1->type) {
case kTfLiteFloat32:
TF_LITE_COMPARISON(float, LessEqual, requires_broadcast);
@@ -155,16 +276,35 @@ TfLiteStatus LessEqualEval(TfLiteContext* context, TfLiteNode* node) {
case kTfLiteInt64:
TF_LITE_COMPARISON(int64_t, LessEqual, requires_broadcast);
break;
+ case kTfLiteUInt8:
+ EvalQuantizedLessEqual(context, node, input1, input2, output,
+ requires_broadcast);
+ break;
default:
context->ReportError(context,
- "Does not support type other than float|int");
+ "Does not support type %d, requires float|int|uint8",
+ input1->type);
return kTfLiteError;
}
return kTfLiteOk;
}
+} // namespace
} // namespace comparisons
+TfLiteRegistration* Register_EQUAL() {
+ static TfLiteRegistration r = {
+ nullptr, nullptr, comparisons::ComparisonPrepare, comparisons::EqualEval};
+ return &r;
+}
+
+TfLiteRegistration* Register_NOT_EQUAL() {
+ static TfLiteRegistration r = {nullptr, nullptr,
+ comparisons::ComparisonPrepare,
+ comparisons::NotEqualEval};
+ return &r;
+}
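
Each Register_* function above returns a four-field TfLiteRegistration: init and free (null here, since these ops keep no per-node state), ComparisonPrepare, and the per-op Eval. For context, a resolver attaches such a registration to its builtin code roughly as follows; this is a sketch against the public BuiltinOpResolver API, and the stock resolver already does it internally for every builtin:

#include "tensorflow/contrib/lite/kernels/register.h"

// Sketch only: shown just to make the registration flow concrete.
void AddComparisonOps(tflite::ops::builtin::BuiltinOpResolver* resolver) {
  resolver->AddBuiltin(tflite::BuiltinOperator_EQUAL,
                       tflite::ops::builtin::Register_EQUAL());
  resolver->AddBuiltin(tflite::BuiltinOperator_NOT_EQUAL,
                       tflite::ops::builtin::Register_NOT_EQUAL());
}
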
+
TfLiteRegistration* Register_GREATER() {
static TfLiteRegistration r = {nullptr, nullptr,
comparisons::ComparisonPrepare,
diff --git a/tensorflow/contrib/lite/kernels/comparisons_test.cc b/tensorflow/contrib/lite/kernels/comparisons_test.cc
index 835d238d36..04c8bf2e30 100644
--- a/tensorflow/contrib/lite/kernels/comparisons_test.cc
+++ b/tensorflow/contrib/lite/kernels/comparisons_test.cc
@@ -21,21 +21,29 @@ limitations under the License.
namespace tflite {
namespace {
-using ::testing::ElementsAreArray;
+using ::testing::ElementsAre;
-class GreaterOpModel : public SingleOpModel {
+class ComparisonOpModel : public SingleOpModel {
public:
- GreaterOpModel(std::initializer_list<int> input1_shape,
- std::initializer_list<int> input2_shape,
- TensorType input_type) {
+ ComparisonOpModel(std::initializer_list<int> input1_shape,
+ std::initializer_list<int> input2_shape,
+ TensorType input_type, BuiltinOperator op) {
input1_ = AddInput(input_type);
input2_ = AddInput(input_type);
output_ = AddOutput(TensorType_BOOL);
- SetBuiltinOp(BuiltinOperator_GREATER, BuiltinOptions_GreaterOptions,
- CreateGreaterOptions(builder_).Union());
+ ConfigureBuiltinOp(op);
BuildInterpreter({input1_shape, input2_shape});
}
+ ComparisonOpModel(const TensorData& input1, const TensorData& input2,
+ TensorType input_type, BuiltinOperator op) {
+ input1_ = AddInput(input1);
+ input2_ = AddInput(input2);
+ output_ = AddOutput(TensorType_BOOL);
+ ConfigureBuiltinOp(op);
+ BuildInterpreter({GetShape(input1_), GetShape(input2_)});
+ }
+
int input1() { return input1_; }
int input2() { return input2_; }
@@ -46,245 +54,510 @@ class GreaterOpModel : public SingleOpModel {
int input1_;
int input2_;
int output_;
+
+ void ConfigureBuiltinOp(BuiltinOperator op) {
+ switch (op) {
+ case BuiltinOperator_EQUAL: {
+ SetBuiltinOp(op, BuiltinOptions_EqualOptions,
+ CreateEqualOptions(builder_).Union());
+ break;
+ }
+ case BuiltinOperator_NOT_EQUAL: {
+ SetBuiltinOp(op, BuiltinOptions_NotEqualOptions,
+ CreateNotEqualOptions(builder_).Union());
+ break;
+ }
+ case BuiltinOperator_GREATER: {
+ SetBuiltinOp(op, BuiltinOptions_GreaterOptions,
+ CreateGreaterOptions(builder_).Union());
+ break;
+ }
+ case BuiltinOperator_GREATER_EQUAL: {
+ SetBuiltinOp(op, BuiltinOptions_GreaterEqualOptions,
+ CreateGreaterEqualOptions(builder_).Union());
+ break;
+ }
+ case BuiltinOperator_LESS: {
+ SetBuiltinOp(op, BuiltinOptions_LessOptions,
+ CreateLessOptions(builder_).Union());
+ break;
+ }
+ case BuiltinOperator_LESS_EQUAL: {
+ SetBuiltinOp(op, BuiltinOptions_LessEqualOptions,
+ CreateLessEqualOptions(builder_).Union());
+ break;
+ }
+ default: { FAIL() << "We shouldn't get here."; }
+ }
+ }
};
-TEST(ComparisonsTest, GreaterFloat) {
- GreaterOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32);
+TEST(ComparisonsTest, EqualFloat) {
+ ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32,
+ BuiltinOperator_EQUAL);
model.PopulateTensor<float>(model.input1(), {0.1, 0.9, 0.7, 0.3});
model.PopulateTensor<float>(model.input2(), {0.1, 0.2, 0.6, 0.5});
model.Invoke();
- EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, true, true, false}));
- EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
+ EXPECT_THAT(model.GetOutput(), ElementsAre(true, false, false, false));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
}
-TEST(ComparisonsTest, GreaterInt) {
- GreaterOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32);
+TEST(ComparisonsTest, EqualInt) {
+ ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32,
+ BuiltinOperator_EQUAL);
model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3});
model.PopulateTensor<int>(model.input2(), {1, 2, 7, 5});
model.Invoke();
- EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, true, false, false}));
- EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
+ EXPECT_THAT(model.GetOutput(), ElementsAre(false, false, true, false));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
}
-TEST(ComparisonsTest, GreaterBroadcast) {
- GreaterOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32);
+TEST(ComparisonsTest, EqualBroadcast) {
+ ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32,
+ BuiltinOperator_EQUAL);
model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3});
model.PopulateTensor<int>(model.input2(), {7});
model.Invoke();
- EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, true, false, false}));
- EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
+ EXPECT_THAT(model.GetOutput(), ElementsAre(false, false, true, false));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
}
-TEST(ComparisonsTest, GreaterBroadcastTwoD) {
- GreaterOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32);
+TEST(ComparisonsTest, EqualBroadcastTwoD) {
+ ComparisonOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32,
+ BuiltinOperator_EQUAL);
model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3, 2, 4, 2, 8});
model.PopulateTensor<int>(model.input2(), {7, 1, 2, 4});
model.Invoke();
- EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, true, true, false,
- false, true, false, true}));
- EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2, 4}));
+ EXPECT_THAT(model.GetOutput(), ElementsAre(false, false, false, false, false,
+ false, true, false));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 2, 4));
}
-class GreaterEqualOpModel : public SingleOpModel {
- public:
- GreaterEqualOpModel(std::initializer_list<int> input1_shape,
- std::initializer_list<int> input2_shape,
- TensorType input_type) {
- input1_ = AddInput(input_type);
- input2_ = AddInput(input_type);
- output_ = AddOutput(TensorType_BOOL);
- SetBuiltinOp(BuiltinOperator_GREATER_EQUAL,
- BuiltinOptions_GreaterEqualOptions,
- CreateGreaterEqualOptions(builder_).Union());
- BuildInterpreter({input1_shape, input2_shape});
- }
+TEST(ComparisonsTest, NotEqualFloat) {
+ ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32,
+ BuiltinOperator_NOT_EQUAL);
+ model.PopulateTensor<float>(model.input1(), {0.1, 0.9, 0.7, 0.3});
+ model.PopulateTensor<float>(model.input2(), {0.1, 0.2, 0.6, 0.5});
+ model.Invoke();
- int input1() { return input1_; }
- int input2() { return input2_; }
+ EXPECT_THAT(model.GetOutput(), ElementsAre(false, true, true, true));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
+}
- std::vector<bool> GetOutput() { return ExtractVector<bool>(output_); }
- std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
+TEST(ComparisonsTest, NotEqualInt) {
+ ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32,
+ BuiltinOperator_NOT_EQUAL);
+ model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3});
+ model.PopulateTensor<int>(model.input2(), {1, 2, 7, 5});
+ model.Invoke();
- private:
- int input1_;
- int input2_;
- int output_;
-};
+ EXPECT_THAT(model.GetOutput(), ElementsAre(true, true, false, true));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
+}
+
+TEST(ComparisonsTest, NotEqualBroadcast) {
+ ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32,
+ BuiltinOperator_NOT_EQUAL);
+ model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3});
+ model.PopulateTensor<int>(model.input2(), {7});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(), ElementsAre(true, true, false, true));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
+}
+
+TEST(ComparisonsTest, NotEqualBroadcastTwoD) {
+ ComparisonOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32,
+ BuiltinOperator_NOT_EQUAL);
+ model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3, 2, 4, 2, 8});
+ model.PopulateTensor<int>(model.input2(), {7, 1, 2, 4});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(),
+ ElementsAre(true, true, true, true, true, true, false, true));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 2, 4));
+}
+
+TEST(ComparisonsTest, GreaterFloat) {
+ ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32,
+ BuiltinOperator_GREATER);
+ model.PopulateTensor<float>(model.input1(), {0.1, 0.9, 0.7, 0.3});
+ model.PopulateTensor<float>(model.input2(), {0.1, 0.2, 0.6, 0.5});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(), ElementsAre(false, true, true, false));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
+}
+
+TEST(ComparisonsTest, GreaterInt) {
+ ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32,
+ BuiltinOperator_GREATER);
+ model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3});
+ model.PopulateTensor<int>(model.input2(), {1, 2, 7, 5});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(), ElementsAre(false, true, false, false));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
+}
+
+TEST(ComparisonsTest, GreaterBroadcast) {
+ ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32,
+ BuiltinOperator_GREATER);
+ model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3});
+ model.PopulateTensor<int>(model.input2(), {7});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(), ElementsAre(false, true, false, false));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
+}
+
+TEST(ComparisonsTest, GreaterBroadcastTwoD) {
+ ComparisonOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32,
+ BuiltinOperator_GREATER);
+ model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3, 2, 4, 2, 8});
+ model.PopulateTensor<int>(model.input2(), {7, 1, 2, 4});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(),
+ ElementsAre(false, true, true, false, false, true, false, true));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 2, 4));
+}
TEST(ComparisonsTest, GreaterEqualFloat) {
- GreaterEqualOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32);
+ ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32,
+ BuiltinOperator_GREATER_EQUAL);
model.PopulateTensor<float>(model.input1(), {0.1, 0.9, 0.7, 0.3});
model.PopulateTensor<float>(model.input2(), {0.1, 0.2, 0.6, 0.5});
model.Invoke();
- EXPECT_THAT(model.GetOutput(), ElementsAreArray({true, true, true, false}));
- EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
+ EXPECT_THAT(model.GetOutput(), ElementsAre(true, true, true, false));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
}
TEST(ComparisonsTest, GreaterEqualInt) {
- GreaterEqualOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32);
+ ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32,
+ BuiltinOperator_GREATER_EQUAL);
model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3});
model.PopulateTensor<int>(model.input2(), {1, 2, 7, 5});
model.Invoke();
- EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, true, true, false}));
- EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
+ EXPECT_THAT(model.GetOutput(), ElementsAre(false, true, true, false));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
}
TEST(ComparisonsTest, GreaterEqualBroadcast) {
- GreaterEqualOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32);
+ ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32,
+ BuiltinOperator_GREATER_EQUAL);
model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3});
model.PopulateTensor<int>(model.input2(), {7});
model.Invoke();
- EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, true, true, false}));
- EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
+ EXPECT_THAT(model.GetOutput(), ElementsAre(false, true, true, false));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
}
TEST(ComparisonsTest, GreaterEqualBroadcastTwoD) {
- GreaterEqualOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32);
+ ComparisonOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32,
+ BuiltinOperator_GREATER_EQUAL);
model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3, 2, 4, 2, 8});
model.PopulateTensor<int>(model.input2(), {7, 1, 2, 4});
model.Invoke();
- EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, true, true, false,
- false, true, true, true}));
- EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2, 4}));
+ EXPECT_THAT(model.GetOutput(),
+ ElementsAre(false, true, true, false, false, true, true, true));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 2, 4));
}
-class LessOpModel : public SingleOpModel {
- public:
- LessOpModel(std::initializer_list<int> input1_shape,
- std::initializer_list<int> input2_shape, TensorType input_type) {
- input1_ = AddInput(input_type);
- input2_ = AddInput(input_type);
- output_ = AddOutput(TensorType_BOOL);
- SetBuiltinOp(BuiltinOperator_LESS, BuiltinOptions_LessOptions,
- CreateLessOptions(builder_).Union());
- BuildInterpreter({input1_shape, input2_shape});
- }
-
- int input1() { return input1_; }
- int input2() { return input2_; }
-
- std::vector<bool> GetOutput() { return ExtractVector<bool>(output_); }
- std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
-
- private:
- int input1_;
- int input2_;
- int output_;
-};
TEST(ComparisonsTest, LessFloat) {
- LessOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32);
+ ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32,
+ BuiltinOperator_LESS);
model.PopulateTensor<float>(model.input1(), {0.1, 0.9, 0.7, 0.3});
model.PopulateTensor<float>(model.input2(), {0.1, 0.2, 0.6, 0.5});
model.Invoke();
- EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, false, false, true}));
- EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
+ EXPECT_THAT(model.GetOutput(), ElementsAre(false, false, false, true));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
}
TEST(ComparisonsTest, LessInt) {
- LessOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32);
+ ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32,
+ BuiltinOperator_LESS);
model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3});
model.PopulateTensor<int>(model.input2(), {1, 2, 6, 5});
model.Invoke();
- EXPECT_THAT(model.GetOutput(), ElementsAreArray({true, false, false, true}));
- EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
+ EXPECT_THAT(model.GetOutput(), ElementsAre(true, false, false, true));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
}
TEST(ComparisonsTest, LessBroadcast) {
- LessOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32);
+ ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32,
+ BuiltinOperator_LESS);
model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3});
model.PopulateTensor<int>(model.input2(), {7});
model.Invoke();
- EXPECT_THAT(model.GetOutput(), ElementsAreArray({true, false, false, true}));
- EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
+ EXPECT_THAT(model.GetOutput(), ElementsAre(true, false, false, true));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
}
TEST(ComparisonsTest, LessBroadcastTwoD) {
- LessOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32);
+ ComparisonOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32,
+ BuiltinOperator_LESS);
model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3, 2, 4, 6, 8});
model.PopulateTensor<int>(model.input2(), {7, 1, 2, 4});
model.Invoke();
- EXPECT_THAT(model.GetOutput(), ElementsAreArray({true, false, false, true,
- true, false, false, false}));
- EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2, 4}));
+ EXPECT_THAT(model.GetOutput(),
+ ElementsAre(true, false, false, true, true, false, false, false));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 2, 4));
}
-class LessEqualOpModel : public SingleOpModel {
- public:
- LessEqualOpModel(std::initializer_list<int> input1_shape,
- std::initializer_list<int> input2_shape,
- TensorType input_type) {
- input1_ = AddInput(input_type);
- input2_ = AddInput(input_type);
- output_ = AddOutput(TensorType_BOOL);
- SetBuiltinOp(BuiltinOperator_LESS_EQUAL, BuiltinOptions_LessEqualOptions,
- CreateLessEqualOptions(builder_).Union());
- BuildInterpreter({input1_shape, input2_shape});
- }
-
- int input1() { return input1_; }
- int input2() { return input2_; }
-
- std::vector<bool> GetOutput() { return ExtractVector<bool>(output_); }
- std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
-
- private:
- int input1_;
- int input2_;
- int output_;
-};
-
TEST(ComparisonsTest, LessEqualFloat) {
- LessEqualOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32);
+ ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32,
+ BuiltinOperator_LESS_EQUAL);
model.PopulateTensor<float>(model.input1(), {0.1, 0.9, 0.7, 0.3});
model.PopulateTensor<float>(model.input2(), {0.1, 0.2, 0.6, 0.5});
model.Invoke();
- EXPECT_THAT(model.GetOutput(), ElementsAreArray({true, false, false, true}));
- EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
+ EXPECT_THAT(model.GetOutput(), ElementsAre(true, false, false, true));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
}
TEST(ComparisonsTest, LessEqualInt) {
- LessEqualOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32);
+ ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32,
+ BuiltinOperator_LESS_EQUAL);
model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3});
model.PopulateTensor<int>(model.input2(), {1, 2, 7, 5});
model.Invoke();
- EXPECT_THAT(model.GetOutput(), ElementsAreArray({true, false, true, true}));
- EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
+ EXPECT_THAT(model.GetOutput(), ElementsAre(true, false, true, true));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
}
TEST(ComparisonsTest, LessEqualBroadcast) {
- LessEqualOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32);
+ ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32,
+ BuiltinOperator_LESS_EQUAL);
model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3});
model.PopulateTensor<int>(model.input2(), {7});
model.Invoke();
- EXPECT_THAT(model.GetOutput(), ElementsAreArray({true, false, true, true}));
- EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
+ EXPECT_THAT(model.GetOutput(), ElementsAre(true, false, true, true));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
}
TEST(ComparisonsTest, LessEqualBroadcastTwoD) {
- LessEqualOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32);
+ ComparisonOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32,
+ BuiltinOperator_LESS_EQUAL);
model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3, 2, 4, 2, 8});
model.PopulateTensor<int>(model.input2(), {7, 1, 2, 4});
model.Invoke();
- EXPECT_THAT(model.GetOutput(), ElementsAreArray({true, false, false, true,
- true, false, true, false}));
- EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2, 4}));
+ EXPECT_THAT(model.GetOutput(),
+ ElementsAre(true, false, false, true, true, false, true, false));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 2, 4));
+}
+
+TEST(QuantizedComparisonsTest, EqualQuantized) {
+ const float kMin = -1.f;
+ const float kMax = 128.f;
+ ComparisonOpModel model({TensorType_UINT8, {1, 2, 2, 1}, kMin, kMax},
+ {TensorType_UINT8, {1, 2, 2, 1}, kMin, kMax},
+ TensorType_UINT8, BuiltinOperator_EQUAL);
+ model.QuantizeAndPopulate<uint8_t>(model.input1(), {1, 9, 7, 3});
+ model.QuantizeAndPopulate<uint8_t>(model.input2(), {1, 2, 7, 5});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(), ElementsAre(true, false, true, false));
+}
+
+TEST(QuantizedComparisonsTest, NotEqualQuantized) {
+ const float kMin = -1.f;
+ const float kMax = 128.f;
+ ComparisonOpModel model({TensorType_UINT8, {1, 2, 2, 1}, kMin, kMax},
+ {TensorType_UINT8, {1, 2, 2, 1}, kMin, kMax},
+ TensorType_UINT8, BuiltinOperator_NOT_EQUAL);
+ model.QuantizeAndPopulate<uint8_t>(model.input1(), {1, 9, 7, 3});
+ model.QuantizeAndPopulate<uint8_t>(model.input2(), {1, 2, 7, 0});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(), ElementsAre(false, true, false, true));
+}
+
+TEST(ComparisonsTest, GreaterQuantized) {
+ const float kMin = -1.f;
+ const float kMax = 128.f;
+ ComparisonOpModel model({TensorType_UINT8, {1, 2, 2, 1}, kMin, kMax},
+ {TensorType_UINT8, {1, 2, 2, 1}, kMin, kMax},
+ TensorType_UINT8, BuiltinOperator_GREATER);
+ model.QuantizeAndPopulate<uint8_t>(model.input1(), {1, 9, 7, 3});
+ model.QuantizeAndPopulate<uint8_t>(model.input2(), {1, 2, 6, 5});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(), ElementsAre(false, true, true, false));
+}
+
+TEST(ComparisonsTest, GreaterQuantizedSmallRange) {
+ ComparisonOpModel model({TensorType_UINT8, {1, 2, 2, 1}, 0.0, 1.0},
+ {TensorType_UINT8, {1, 2, 2, 1}, 0.0, 2.0},
+ TensorType_UINT8, BuiltinOperator_GREATER);
+ model.QuantizeAndPopulate<uint8_t>(model.input1(), {1.0, 0.5, 0.35, 0.1});
+ model.QuantizeAndPopulate<uint8_t>(model.input2(), {1.01, 0.25, 0.3, 0.4});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(), ElementsAre(false, true, true, false));
+}
+
+TEST(ComparisonsTest, GreaterEqualQuantized) {
+ const float kMin = -1.f;
+ const float kMax = 128.f;
+ ComparisonOpModel model({TensorType_UINT8, {1, 2, 2, 1}, kMin, kMax},
+ {TensorType_UINT8, {1, 2, 2, 1}, kMin, kMax},
+ TensorType_UINT8, BuiltinOperator_GREATER_EQUAL);
+ model.QuantizeAndPopulate<uint8_t>(model.input1(), {1, 9, 7, 3});
+ model.QuantizeAndPopulate<uint8_t>(model.input2(), {1, 2, 6, 5});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(), ElementsAre(true, true, true, false));
+}
+
+TEST(ComparisonsTest, LessQuantized) {
+ const float kMin = -1.f;
+ const float kMax = 128.f;
+ ComparisonOpModel model({TensorType_UINT8, {1, 2, 2, 1}, kMin, kMax},
+ {TensorType_UINT8, {1, 2, 2, 1}, kMin, kMax},
+ TensorType_UINT8, BuiltinOperator_LESS);
+ model.QuantizeAndPopulate<uint8_t>(model.input1(), {1, 9, 7, 3});
+ model.QuantizeAndPopulate<uint8_t>(model.input2(), {1, 2, 6, 5});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(), ElementsAre(false, false, false, true));
+}
+
+TEST(ComparisonsTest, LessEqualQuantized) {
+ const float kMin = -1.f;
+ const float kMax = 128.f;
+ ComparisonOpModel model({TensorType_UINT8, {1, 2, 2, 1}, kMin, kMax},
+ {TensorType_UINT8, {1, 2, 2, 1}, kMin, kMax},
+ TensorType_UINT8, BuiltinOperator_LESS_EQUAL);
+ model.QuantizeAndPopulate<uint8_t>(model.input1(), {1, 9, 7, 3});
+ model.QuantizeAndPopulate<uint8_t>(model.input2(), {1, 2, 6, 5});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(), ElementsAre(true, false, false, true));
+}
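
The QuantizeAndPopulate calls in these tests rely on the standard affine uint8 scheme, real value approximately equal to scale * (q - zero_point), with scale and zero_point derived from the declared [min, max] range. A small worked sketch of that mapping, assuming the same range as above (helper names are illustrative, not the SingleOpModel internals):

#include <algorithm>
#include <cmath>
#include <cstdint>

struct QuantParams {
  float scale;
  int32_t zero_point;
};

// Affine uint8 scheme: real value is approximately scale * (q - zero_point).
QuantParams ChooseQuantParams(float min, float max) {
  QuantParams p;
  p.scale = (max - min) / 255.0f;
  p.zero_point = static_cast<int32_t>(std::round(-min / p.scale));
  p.zero_point = std::min<int32_t>(255, std::max<int32_t>(0, p.zero_point));
  return p;
}

uint8_t Quantize(float x, const QuantParams& p) {
  const int32_t q =
      p.zero_point + static_cast<int32_t>(std::round(x / p.scale));
  return static_cast<uint8_t>(std::min<int32_t>(255, std::max<int32_t>(0, q)));
}

// With kMin = -1.f and kMax = 128.f as above: scale is roughly 0.506 and
// zero_point is 2, so {1, 9, 7, 3} quantize to {4, 20, 16, 8} under this
// scheme and the exact-equality expectations in EqualQuantized hold.
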
+
+TEST(ComparisonsTest, QuantizedEqualWithBroadcast) {
+ const float kMin = -1.f;
+ const float kMax = 128.f;
+ std::vector<std::initializer_list<int>> test_shapes = {
+ {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}};
+ for (int i = 0; i < test_shapes.size(); ++i) {
+ ComparisonOpModel model({TensorType_UINT8, test_shapes[i], kMin, kMax},
+ {TensorType_UINT8, {}, kMin, kMax},
+ TensorType_UINT8, BuiltinOperator_EQUAL);
+ model.QuantizeAndPopulate<uint8_t>(model.input1(), {20, 2, 7, 8, 11, 20});
+ model.QuantizeAndPopulate<uint8_t>(model.input2(), {2});
+ model.Invoke();
+ EXPECT_THAT(model.GetOutput(),
+ ElementsAre(false, true, false, false, false, false))
+ << "With shape number " << i;
+ }
+}
+
+TEST(ComparisonsTest, QuantizedNotEqualWithBroadcast) {
+ const float kMin = -1.f;
+ const float kMax = 128.f;
+ std::vector<std::initializer_list<int>> test_shapes = {
+ {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}};
+ for (int i = 0; i < test_shapes.size(); ++i) {
+ ComparisonOpModel model({TensorType_UINT8, test_shapes[i], kMin, kMax},
+ {TensorType_UINT8, {}, kMin, kMax},
+ TensorType_UINT8, BuiltinOperator_NOT_EQUAL);
+ model.QuantizeAndPopulate<uint8_t>(model.input1(), {20, 2, 7, 8, 11, 20});
+ model.QuantizeAndPopulate<uint8_t>(model.input2(), {2});
+ model.Invoke();
+ EXPECT_THAT(model.GetOutput(),
+ ElementsAre(true, false, true, true, true, true))
+ << "With shape number " << i;
+ }
+}
+
+TEST(ComparisonsTest, QuantizedGreaterWithBroadcast) {
+ const float kMin = -1.f;
+ const float kMax = 128.f;
+ std::vector<std::initializer_list<int>> test_shapes = {
+ {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}};
+ for (int i = 0; i < test_shapes.size(); ++i) {
+ ComparisonOpModel model({TensorType_UINT8, test_shapes[i], kMin, kMax},
+ {TensorType_UINT8, {}, kMin, kMax},
+ TensorType_UINT8, BuiltinOperator_GREATER);
+ model.QuantizeAndPopulate<uint8_t>(model.input1(), {20, 2, 7, 8, 11, 20});
+ model.QuantizeAndPopulate<uint8_t>(model.input2(), {8});
+ model.Invoke();
+ EXPECT_THAT(model.GetOutput(),
+ ElementsAre(true, false, false, false, true, true))
+ << "With shape number " << i;
+ }
+}
+
+TEST(ComparisonsTest, QuantizedGreaterEqualWithBroadcast) {
+ const float kMin = -1.f;
+ const float kMax = 128.f;
+ std::vector<std::initializer_list<int>> test_shapes = {
+ {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}};
+ for (int i = 0; i < test_shapes.size(); ++i) {
+ ComparisonOpModel model({TensorType_UINT8, test_shapes[i], kMin, kMax},
+ {TensorType_UINT8, {}, kMin, kMax},
+ TensorType_UINT8, BuiltinOperator_GREATER_EQUAL);
+ model.QuantizeAndPopulate<uint8_t>(model.input1(), {20, 2, 7, 8, 11, 20});
+ model.QuantizeAndPopulate<uint8_t>(model.input2(), {8});
+ model.Invoke();
+ EXPECT_THAT(model.GetOutput(),
+ ElementsAre(true, false, false, true, true, true))
+ << "With shape number " << i;
+ }
+}
+
+TEST(ComparisonsTest, QuantizedLessWithBroadcast) {
+ const float kMin = -1.f;
+ const float kMax = 128.f;
+ std::vector<std::initializer_list<int>> test_shapes = {
+ {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}};
+ for (int i = 0; i < test_shapes.size(); ++i) {
+ ComparisonOpModel model({TensorType_UINT8, test_shapes[i], kMin, kMax},
+ {TensorType_UINT8, {}, kMin, kMax},
+ TensorType_UINT8, BuiltinOperator_LESS);
+ model.QuantizeAndPopulate<uint8_t>(model.input1(), {20, 2, 7, 8, 11, 20});
+ model.QuantizeAndPopulate<uint8_t>(model.input2(), {8});
+ model.Invoke();
+ EXPECT_THAT(model.GetOutput(),
+ ElementsAre(false, true, true, false, false, false))
+ << "With shape number " << i;
+ }
+}
+
+TEST(ComparisonsTest, QuantizedLessEqualWithBroadcast) {
+ const float kMin = -1.f;
+ const float kMax = 128.f;
+ std::vector<std::initializer_list<int>> test_shapes = {
+ {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}};
+ for (int i = 0; i < test_shapes.size(); ++i) {
+ ComparisonOpModel model({TensorType_UINT8, test_shapes[i], kMin, kMax},
+ {TensorType_UINT8, {}, kMin, kMax},
+ TensorType_UINT8, BuiltinOperator_LESS_EQUAL);
+ model.QuantizeAndPopulate<uint8_t>(model.input1(), {20, 2, 7, 8, 11, 20});
+ model.QuantizeAndPopulate<uint8_t>(model.input2(), {8});
+ model.Invoke();
+ EXPECT_THAT(model.GetOutput(),
+ ElementsAre(false, true, true, true, false, false))
+ << "With shape number " << i;
+ }
}
} // namespace
diff --git a/tensorflow/contrib/lite/kernels/concatenation.cc b/tensorflow/contrib/lite/kernels/concatenation.cc
index 45ea8d0049..7ad3399ffd 100644
--- a/tensorflow/contrib/lite/kernels/concatenation.cc
+++ b/tensorflow/contrib/lite/kernels/concatenation.cc
@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include <unistd.h>
#include <cassert>
#include <cmath>
#include <cstdio>
@@ -20,8 +19,8 @@ limitations under the License.
#include <iostream>
#include <limits>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
@@ -58,7 +57,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE(context, t0->dims->size <= 4);
TF_LITE_ENSURE_EQ(context, params->activation, kTfLiteActNone);
TF_LITE_ENSURE(context,
- input_type == kTfLiteFloat32 || input_type == kTfLiteUInt8);
+ input_type == kTfLiteFloat32 || input_type == kTfLiteUInt8 ||
+ input_type == kTfLiteInt16 || input_type == kTfLiteInt32 ||
+ input_type == kTfLiteInt64);
  // Output dimensions will match input dimensions, except 'axis', which
  // will be the sum of the inputs' sizes along 'axis'.
@@ -99,20 +100,31 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
// allocate and populate these during Prepare().
// TODO(ycling): Activation function parameter is ignored. For now we don't have
// a model with a Concatenation with a fused activation function.
-#define TF_LITE_CONCATENATION(type, scalar) \
- VectorOfTensors<scalar> all_inputs(*context, *node->inputs); \
- type::Concatenation<FusedActivationFunctionType::kNone, scalar>( \
- RemapDim(NumDimensions(output), axis), all_inputs.data(), \
- all_inputs.dims(), node->inputs->size, GetTensorData<scalar>(output), \
- GetTensorDims(output))
-
-#define TF_LITE_CONCATENATION_QUANTIZED(type) \
- VectorOfQuantizedTensors all_inputs(*context, *node->inputs); \
- type::Concatenation( \
- RemapDim(NumDimensions(output), axis), all_inputs.data(), \
- all_inputs.dims(), all_inputs.zero_point(), all_inputs.scale(), \
- node->inputs->size, GetTensorData<uint8>(output), GetTensorDims(output), \
- output->params.zero_point, output->params.scale)
+#define TF_LITE_CONCATENATION(type, scalar) \
+ { \
+ VectorOfTensors<scalar> all_inputs(*context, *node->inputs); \
+ tflite::ConcatenationParams op_params; \
+ op_params.axis = axis; \
+ op_params.inputs_count = node->inputs->size; \
+ type::Concatenation(op_params, all_inputs.shapes(), all_inputs.data(), \
+ GetTensorShape(output), \
+ GetTensorData<scalar>(output)); \
+ }
+
+#define TF_LITE_CONCATENATION_QUANTIZED(type) \
+ { \
+ VectorOfQuantizedTensors all_inputs(*context, *node->inputs); \
+ tflite::ConcatenationParams op_params; \
+ op_params.axis = axis; \
+ op_params.input_zeropoint = all_inputs.zero_point(); \
+ op_params.input_scale = all_inputs.scale(); \
+ op_params.inputs_count = node->inputs->size; \
+ op_params.output_zeropoint = output->params.zero_point; \
+ op_params.output_scale = output->params.scale; \
+ type::ConcatenationWithScaling(op_params, all_inputs.shapes(), \
+ all_inputs.data(), GetTensorShape(output), \
+ GetTensorData<uint8>(output)); \
+ }
  switch (output->type) {  // Already know in/out types are the same.
case kTfLiteFloat32:
@@ -122,6 +134,13 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_CONCATENATION(optimized_ops, float);
}
break;
+ case kTfLiteInt32:
+ if (kernel_type == kReference) {
+ TF_LITE_CONCATENATION(reference_ops, int32);
+ } else {
+ TF_LITE_CONCATENATION(optimized_ops, int32);
+ }
+ break;
case kTfLiteUInt8:
if (kernel_type == kReference) {
TF_LITE_CONCATENATION_QUANTIZED(reference_ops);
@@ -129,6 +148,14 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_CONCATENATION_QUANTIZED(optimized_ops);
}
break;
+ case kTfLiteInt64:
+ if (kernel_type == kReference) {
+ TF_LITE_CONCATENATION(reference_ops, int64_t);
+ } else {
+ TF_LITE_CONCATENATION(optimized_ops, int64_t);
+ }
+ break;
+
default:
context->ReportError(context,
"Only float32 and uint8 are currently supported.");
diff --git a/tensorflow/contrib/lite/kernels/conv.cc b/tensorflow/contrib/lite/kernels/conv.cc
index 2b7e455e3e..dbcadbee14 100644
--- a/tensorflow/contrib/lite/kernels/conv.cc
+++ b/tensorflow/contrib/lite/kernels/conv.cc
@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include <unistd.h>
#include <algorithm>
#include <cassert>
#include <cmath>
@@ -21,8 +20,8 @@ limitations under the License.
#include <iostream>
#include <limits>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/eigen_support.h"
#include "tensorflow/contrib/lite/kernels/gemm_support.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/cblas_conv.h"
@@ -31,6 +30,7 @@ limitations under the License.
#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"
#include "tensorflow/contrib/lite/kernels/padding.h"
@@ -61,6 +61,8 @@ struct OpData {
// memory buffers.
int im2col_id = kTensorNotAllocated;
int hwcn_weights_id = kTensorNotAllocated;
+ int input_quantized_id = kTensorNotAllocated;
+ int scaling_factors_id = kTensorNotAllocated;
TfLitePaddingValues padding;
// The scaling factor from input to output (aka the 'real multiplier') can
@@ -75,6 +77,8 @@ struct OpData {
// of the allocated temporaries.
int32_t im2col_index;
int32_t hwcn_weights_index;
+ int32_t input_quantized_index;
+ int32_t scaling_factors_index;
bool need_hwcn_weights;
bool have_weights_been_transposed;
bool need_im2col;
@@ -82,6 +86,18 @@ struct OpData {
bool run_multithreaded_kernel;
};
+inline PaddingType RuntimePaddingType(TfLitePadding padding) {
+ switch (padding) {
+ case TfLitePadding::kTfLitePaddingSame:
+ return PaddingType::kSame;
+ case TfLitePadding::kTfLitePaddingValid:
+ return PaddingType::kValid;
+ case TfLitePadding::kTfLitePaddingUnknown:
+ default:
+ return PaddingType::kNone;
+ }
+}
+
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
// This is a builtin op, so we don't use the contents in 'buffer', if any.
// Instead, we allocate a new object to use as scratch space for im2col, and
@@ -126,6 +142,9 @@ static TfLiteStatus AllocateTemporaryTensorsIfRequired(TfLiteContext* context,
TfLiteTensor* input = &context->tensors[node->inputs->data[0]];
TfLiteTensor* filter = &context->tensors[node->inputs->data[1]];
+ const bool is_hybrid =
+ (input->type == kTfLiteFloat32 && filter->type == kTfLiteUInt8);
+
int filter_width = filter->dims->data[2];
int filter_height = filter->dims->data[1];
@@ -134,7 +153,9 @@ static TfLiteStatus AllocateTemporaryTensorsIfRequired(TfLiteContext* context,
// optimized_ops.h, in order to avoid a DCHECK(!im2col_data).
data->need_im2col =
(params->stride_width != 1 || params->stride_height != 1 ||
- filter_width != 1 || filter_height != 1);
+ params->dilation_width_factor != 1 ||
+ params->dilation_height_factor != 1 || filter_width != 1 ||
+ filter_height != 1);
// If we're using the optimized multithreaded EigenTensor implementation of
// convolution, it expects the filter weights to be transposed compared to
// the normal TF Lite buffer format. Typical TF Lite weights are
@@ -144,8 +165,8 @@ static TfLiteStatus AllocateTemporaryTensorsIfRequired(TfLiteContext* context,
// buffer to store the results.
// This path is only used for float processing, so only create the buffer if
// we're running with that data type.
- data->need_hwcn_weights =
- (input->type == kTfLiteFloat32 && data->run_multithreaded_kernel);
+ data->need_hwcn_weights = (input->type == kTfLiteFloat32 &&
+ data->run_multithreaded_kernel && !is_hybrid);
int temporaries_count = 0;
if (data->need_im2col) {
@@ -163,6 +184,25 @@ static TfLiteStatus AllocateTemporaryTensorsIfRequired(TfLiteContext* context,
++temporaries_count;
}
+ if (is_hybrid) {
+ // Allocate tensor to store the on-the-fly quantized inputs.
+ data->input_quantized_index = temporaries_count;
+ if (data->input_quantized_id == kTensorNotAllocated) {
+ TF_LITE_ENSURE_OK(
+ context, context->AddTensors(context, 1, &data->input_quantized_id));
+ }
+ ++temporaries_count;
+
+ // Allocate tensor to store the quantization params computed during
+ // on-the-fly input quantization.
+ data->scaling_factors_index = temporaries_count;
+ if (data->scaling_factors_id == kTensorNotAllocated) {
+ TF_LITE_ENSURE_OK(
+ context, context->AddTensors(context, 1, &data->scaling_factors_id));
+ }
+ ++temporaries_count;
+ }
+
TfLiteIntArrayFree(node->temporaries);
node->temporaries = TfLiteIntArrayCreate(temporaries_count);
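
The im2col temporary requested above is how these kernels lower convolution to a matrix multiply whenever stride, dilation, or the filter size is non-trivial: each output position's receptive field is copied into one row of a patch matrix, and the filter is applied as a GEMM against those rows. A compact single-channel, valid-padding illustration of that transform (illustrative only; the real buffer is 4-D and padding-aware):

#include <vector>

// Build the im2col patch matrix for an HxW input and KxK filter, given stride
// and dilation, with valid padding. Row r holds the r-th receptive field.
std::vector<std::vector<float>> Im2Col(const std::vector<float>& in, int H,
                                       int W, int K, int stride, int dilation) {
  const int eff = (K - 1) * dilation + 1;  // effective filter size
  const int out_h = (H - eff) / stride + 1;
  const int out_w = (W - eff) / stride + 1;
  std::vector<std::vector<float>> patches;
  for (int oy = 0; oy < out_h; ++oy) {
    for (int ox = 0; ox < out_w; ++ox) {
      std::vector<float> row;
      for (int ky = 0; ky < K; ++ky) {
        for (int kx = 0; kx < K; ++kx) {
          row.push_back(in[(oy * stride + ky * dilation) * W +
                           (ox * stride + kx * dilation)]);
        }
      }
      patches.push_back(row);
    }
  }
  return patches;  // Convolution is then patches times the flattened filter.
}
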
@@ -173,13 +213,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
auto* params = reinterpret_cast<TfLiteConvParams*>(node->builtin_data);
OpData* data = reinterpret_cast<OpData*>(node->user_data);
- data->run_multithreaded_kernel = context->recommended_num_threads != 1;
-
- TF_LITE_ENSURE_STATUS(AllocateTemporaryTensorsIfRequired(context, node));
-
- bool hasBias = node->inputs->size == 3;
+ bool has_bias = node->inputs->size == 3;
// Check number of inputs/outputs
- TF_LITE_ENSURE(context, hasBias || node->inputs->size == 2);
+ TF_LITE_ENSURE(context, has_bias || node->inputs->size == 2);
TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
TfLiteTensor* output = &context->tensors[node->outputs->data[0]];
TfLiteTensor* input = &context->tensors[node->inputs->data[0]];
@@ -192,29 +228,40 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, input->dims->data[3], filter->dims->data[3]);
// Check types. (We assume that UINT8 refers to quantized tensors)
- TfLiteType data_type = input->type;
+ TfLiteType input_type = input->type;
TF_LITE_ENSURE(context,
- data_type == kTfLiteFloat32 || data_type == kTfLiteUInt8);
- TF_LITE_ENSURE_EQ(context, output->type, data_type);
- TF_LITE_ENSURE_EQ(context, filter->type, data_type);
+ input_type == kTfLiteFloat32 || input_type == kTfLiteUInt8);
+ TF_LITE_ENSURE_EQ(context, output->type, input_type);
TfLiteTensor* bias = nullptr;
// TODO(ahentz): At this point the optimized versions require 'bias'. We can
// either change that or document that convolution requires it.
- TF_LITE_ENSURE(context, hasBias);
+ TF_LITE_ENSURE(context, has_bias);
- if (hasBias) {
+ if (has_bias) {
bias = &context->tensors[node->inputs->data[2]];
- if (data_type == kTfLiteUInt8) {
+ if (input_type == kTfLiteUInt8) {
TF_LITE_ENSURE_EQ(context, bias->type, kTfLiteInt32);
TF_LITE_ENSURE_EQ(context, bias->params.zero_point, 0);
} else {
- TF_LITE_ENSURE_EQ(context, bias->type, data_type);
+ TF_LITE_ENSURE_EQ(context, bias->type, input_type);
}
TF_LITE_ENSURE_EQ(context, NumElements(bias), SizeOfDimension(filter, 0));
}
+ const bool is_hybrid =
+ (input->type == kTfLiteFloat32 && filter->type == kTfLiteUInt8);
+
+ data->run_multithreaded_kernel = context->recommended_num_threads != 1;
+ // Hybrid kernels don't support multithreading yet.
+ if (is_hybrid) {
+ data->run_multithreaded_kernel = false;
+ }
+
+ TF_LITE_ENSURE_STATUS(AllocateTemporaryTensorsIfRequired(context, node));
+
+ int channels_in = filter->dims->data[3];
int channels_out = filter->dims->data[0];
int width = input->dims->data[2];
int height = input->dims->data[1];
@@ -224,38 +271,41 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Matching GetWindowedOutputSize in TensorFlow.
auto padding = params->padding;
- auto computeOutSize = [padding](int imageSize, int filterSize, int stride,
- int dilationRate) -> int {
- int effectiveFilterSize = (filterSize - 1) * dilationRate + 1;
+ auto compute_out_size = [padding](int image_size, int filter_size, int stride,
+ int dilation_rate) -> int {
+ int effective_filter_size = (filter_size - 1) * dilation_rate + 1;
return padding == kTfLitePaddingSame
- ? (imageSize + stride - 1) / stride
+ ? (image_size + stride - 1) / stride
: padding == kTfLitePaddingValid
- ? (imageSize - effectiveFilterSize + stride) / stride
+ ? (image_size - effective_filter_size + stride) / stride
: 0;
};
- int outWidth = computeOutSize(width, filter_width, params->stride_width,
- params->dilation_width_factor);
- int outHeight = computeOutSize(height, filter_height, params->stride_height,
- params->dilation_height_factor);
+ int out_width = compute_out_size(width, filter_width, params->stride_width,
+ params->dilation_width_factor);
+ int out_height =
+ compute_out_size(height, filter_height, params->stride_height,
+ params->dilation_height_factor);
data->padding.height =
ComputePadding(params->stride_height, params->dilation_height_factor,
- height, filter_height, outHeight);
+ height, filter_height, out_height);
data->padding.width =
ComputePadding(params->stride_width, params->dilation_width_factor, width,
- filter_width, outWidth);
+ filter_width, out_width);
- TF_LITE_ENSURE(context, hasBias);
+ TF_LITE_ENSURE(context, has_bias);
- // Note that quantized inference requires that all tensors have their
+ // Note that full fixed-point inference requires that all tensors have their
// parameters set. This is usually done during quantized training.
- if (data_type != kTfLiteFloat32) {
+ if (input_type != kTfLiteFloat32) {
double real_multiplier = 0.0;
TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler(
context, input, filter, bias, output, &real_multiplier));
- QuantizeMultiplierSmallerThanOne(real_multiplier, &data->output_multiplier,
- &data->output_shift);
+
+ int exponent;
+ QuantizeMultiplier(real_multiplier, &data->output_multiplier, &exponent);
+ data->output_shift = -exponent;
CalculateActivationRangeUint8(params->activation, output,
&data->output_activation_min,
&data->output_activation_max);
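
The Prepare change above replaces QuantizeMultiplierSmallerThanOne with the more general QuantizeMultiplier, which represents the real output multiplier M as a Q31 integer times a power of two, M being approximately (q31 / 2^31) * 2^exponent; the kernel then stores -exponent as a right shift. A hedged sketch of that decomposition, following the usual frexp-based approach rather than quoting the exact library routine:

#include <cmath>
#include <cstdint>

// Decompose real_multiplier into q31 * 2^(exponent - 31), with q31 in
// [2^30, 2^31) for positive multipliers.
void DecomposeMultiplier(double real_multiplier, int32_t* q31, int* exponent) {
  const double significand = std::frexp(real_multiplier, exponent);
  int64_t q = static_cast<int64_t>(std::round(significand * (1ll << 31)));
  if (q == (1ll << 31)) {  // Rounding pushed us to 2^31; renormalize.
    q /= 2;
    ++*exponent;
  }
  *q31 = static_cast<int32_t>(q);
}

// Example: real_multiplier = 0.00392 gives a significand of about 0.50176 and
// exponent = -7, so output_multiplier is about 0.50176 * 2^31 and the stored
// output_shift is 7.
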
@@ -263,8 +313,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteIntArray* output_size = TfLiteIntArrayCreate(4);
output_size->data[0] = batches;
- output_size->data[1] = outHeight;
- output_size->data[2] = outWidth;
+ output_size->data[1] = out_height;
+ output_size->data[2] = out_width;
output_size->data[3] = channels_out;
auto output_status = context->ResizeTensor(context, output, output_size);
@@ -283,7 +333,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* im2col =
&context->tensors[node->temporaries->data[data->im2col_index]];
- im2col->type = data_type;
+ im2col->type = input->type;
+ if (is_hybrid) {
+ im2col->type = kTfLiteUInt8;
+ }
im2col->allocation_type = kTfLiteArenaRw;
auto im2col_status = context->ResizeTensor(context, im2col, im2col_size);
if (im2col_status != kTfLiteOk) return im2col_status;
@@ -303,19 +356,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* hwcn_weights =
&context->tensors[node->temporaries->data[data->hwcn_weights_index]];
- hwcn_weights->type = data_type;
- hwcn_weights->allocation_type = kTfLiteDynamic;
- // Make sure we release any previous allocations before we reallocate.
- // TODO(petewarden): Persistent arenas would be a better fit for this, but
- // they aren't fully implemented yet.
- if (hwcn_weights->data.raw) {
- free(hwcn_weights->data.raw);
- hwcn_weights->data.raw = nullptr;
- }
+ hwcn_weights->type = input_type;
+ hwcn_weights->allocation_type = kTfLiteArenaRwPersistent;
- // Note that hwcn_weights_status is a kTfLiteDynamic tensor, and
- // ResizeTensor will actually allocate space for it. The would be more
- // efficient if we placed hwcn_weights_status in the persistent arena.
auto hwcn_weights_status =
context->ResizeTensor(context, hwcn_weights, hwcn_weights_size);
if (hwcn_weights_status != kTfLiteOk) return hwcn_weights_status;
@@ -325,6 +368,36 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
data->have_weights_been_transposed = false;
}
+ if (is_hybrid) {
+ node->temporaries->data[data->input_quantized_index] =
+ data->input_quantized_id;
+ TfLiteTensor* input_quantized =
+ GetTemporary(context, node, data->input_quantized_index);
+ input_quantized->type = kTfLiteUInt8;
+ input_quantized->allocation_type = kTfLiteArenaRw;
+ if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) {
+ TfLiteIntArray* input_quantized_size = TfLiteIntArrayCopy(input->dims);
+ TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized,
+ input_quantized_size));
+ }
+
+ node->temporaries->data[data->scaling_factors_index] =
+ data->scaling_factors_id;
+ TfLiteTensor* scaling_factors =
+ GetTemporary(context, node, data->scaling_factors_index);
+ scaling_factors->type = kTfLiteFloat32;
+ scaling_factors->allocation_type = kTfLiteArenaRw;
+ TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1);
+ // Only one scale factor per batch is typically necessary. See optimized
+ // implementation for why we need to allocate for the height of the inputs
+ // flattened to 2D.
+ scaling_factors_size->data[0] = NumElements(input) / channels_in;
+ if (!TfLiteIntArrayEqual(scaling_factors->dims, scaling_factors_size)) {
+ TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors,
+ scaling_factors_size));
+ }
+ }
+
return kTfLiteOk;
}
@@ -340,34 +413,70 @@ void EvalQuantized(TfLiteContext* context, TfLiteNode* node,
auto filter_offset = -filter->params.zero_point;
auto output_offset = output->params.zero_point;
- switch (kernel_type) {
- case kReference:
+ KernelType effective_kernel_type;
+ if ((kernel_type == kMultithreadOptimized ||
+ kernel_type == kCblasOptimized) &&
+ (params->dilation_width_factor != 1 ||
+ params->dilation_height_factor != 1)) {
+ // kMultithreadOptimized and kCblasOptimized do not support dilation.
+ // Therefore, fallback to optimized.
+ effective_kernel_type = kGenericOptimized;
+ } else {
+ effective_kernel_type = kernel_type;
+ }
+
+ switch (effective_kernel_type) {
+ case kReference: {
+ ConvParams op_params;
+ op_params.padding_type = PaddingType::kSame;
+ op_params.padding_values.width = data->padding.width;
+ op_params.padding_values.height = data->padding.height;
+ op_params.stride_width = params->stride_width;
+ op_params.stride_height = params->stride_height;
+ op_params.dilation_width_factor = params->dilation_width_factor;
+ op_params.dilation_height_factor = params->dilation_height_factor;
+ op_params.input_offset = input_offset;
+ op_params.weights_offset = filter_offset;
+ op_params.output_offset = output_offset;
+ op_params.output_multiplier = data->output_multiplier;
+ op_params.output_shift = -data->output_shift;
+ op_params.quantized_activation_min = data->output_activation_min;
+ op_params.quantized_activation_max = data->output_activation_max;
reference_ops::Conv(
- GetTensorData<uint8_t>(input), GetTensorDims(input), input_offset,
- GetTensorData<uint8_t>(filter), GetTensorDims(filter), filter_offset,
- GetTensorData<int32_t>(bias), GetTensorDims(bias),
- params->stride_width, params->stride_height, data->padding.width,
- data->padding.height, output_offset, data->output_multiplier,
- data->output_shift, data->output_activation_min,
- data->output_activation_max, GetTensorData<uint8_t>(output),
- GetTensorDims(output), GetTensorData<uint8_t>(im2col),
- GetTensorDims(im2col), gemm_context);
+ op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
+ GetTensorShape(filter), GetTensorData<uint8_t>(filter),
+ GetTensorShape(bias), GetTensorData<int32_t>(bias),
+ GetTensorShape(output), GetTensorData<uint8_t>(output),
+ GetTensorShape(im2col), GetTensorData<uint8_t>(im2col), gemm_context);
break;
+ }
case kGenericOptimized:
case kMultithreadOptimized:
- case kCblasOptimized:
+ case kCblasOptimized: {
// There is only one optimized implementation for Quantized Conv.
+ ConvParams op_params;
+ op_params.padding_type = PaddingType::kSame;
+ op_params.padding_values.width = data->padding.width;
+ op_params.padding_values.height = data->padding.height;
+ op_params.stride_width = params->stride_width;
+ op_params.stride_height = params->stride_height;
+ op_params.dilation_width_factor = params->dilation_width_factor;
+ op_params.dilation_height_factor = params->dilation_height_factor;
+ op_params.input_offset = input_offset;
+ op_params.weights_offset = filter_offset;
+ op_params.output_offset = output_offset;
+ op_params.output_multiplier = data->output_multiplier;
+ op_params.output_shift = -data->output_shift;
+ op_params.quantized_activation_min = data->output_activation_min;
+ op_params.quantized_activation_max = data->output_activation_max;
optimized_ops::Conv(
- GetTensorData<uint8_t>(input), GetTensorDims(input), input_offset,
- GetTensorData<uint8_t>(filter), GetTensorDims(filter), filter_offset,
- GetTensorData<int32_t>(bias), GetTensorDims(bias),
- params->stride_width, params->stride_height, data->padding.width,
- data->padding.height, output_offset, data->output_multiplier,
- data->output_shift, data->output_activation_min,
- data->output_activation_max, GetTensorData<uint8_t>(output),
- GetTensorDims(output), GetTensorData<uint8_t>(im2col),
- GetTensorDims(im2col), gemm_context);
+ op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
+ GetTensorShape(filter), GetTensorData<uint8_t>(filter),
+ GetTensorShape(bias), GetTensorData<int32_t>(bias),
+ GetTensorShape(output), GetTensorData<uint8_t>(output),
+ GetTensorShape(im2col), GetTensorData<uint8_t>(im2col), gemm_context);
break;
+ }
}
}
@@ -377,42 +486,46 @@ void EvalFloat(TfLiteContext* context, TfLiteNode* node,
TfLiteTensor* filter, TfLiteTensor* bias, TfLiteTensor* im2col,
TfLiteTensor* hwcn_weights, TfLiteTensor* output) {
float output_activation_min, output_activation_max;
- CalculateActivationRangeFloat(params->activation, &output_activation_min,
- &output_activation_max);
+ CalculateActivationRange(params->activation, &output_activation_min,
+ &output_activation_max);
KernelType effective_kernel_type;
- if (((kernel_type == kMultithreadOptimized) ||
- (kernel_type == kCblasOptimized)) &&
- ((params->dilation_width_factor != 1) ||
- (params->dilation_height_factor != 1))) {
+ if ((kernel_type == kMultithreadOptimized ||
+ kernel_type == kCblasOptimized) &&
+ (params->dilation_width_factor != 1 ||
+ params->dilation_height_factor != 1)) {
// kMultithreadOptimized and kCblasOptimized do not support dilation.
// Therefore, fallback to optimized.
effective_kernel_type = kGenericOptimized;
} else {
effective_kernel_type = kernel_type;
}
+ ConvParams op_params;
+ op_params.padding_type = RuntimePaddingType(params->padding);
+ op_params.padding_values.width = data->padding.width;
+ op_params.padding_values.height = data->padding.height;
+ op_params.stride_width = params->stride_width;
+ op_params.stride_height = params->stride_height;
+ op_params.dilation_width_factor = params->dilation_width_factor;
+ op_params.dilation_height_factor = params->dilation_height_factor;
+ op_params.float_activation_min = output_activation_min;
+ op_params.float_activation_max = output_activation_max;
switch (effective_kernel_type) {
case kReference: {
- reference_ops::Conv(
- GetTensorData<float>(input), GetTensorDims(input),
- GetTensorData<float>(filter), GetTensorDims(filter),
- GetTensorData<float>(bias), GetTensorDims(bias), params->stride_width,
- params->stride_height, params->dilation_width_factor,
- params->dilation_height_factor, data->padding.width,
- data->padding.height, output_activation_min, output_activation_max,
- GetTensorData<float>(output), GetTensorDims(output),
- GetTensorData<float>(im2col), GetTensorDims(im2col));
+ reference_ops::Conv(op_params, GetTensorShape(input),
+ GetTensorData<float>(input), GetTensorShape(filter),
+ GetTensorData<float>(filter), GetTensorShape(bias),
+ GetTensorData<float>(bias), GetTensorShape(output),
+ GetTensorData<float>(output), GetTensorShape(im2col),
+ GetTensorData<float>(im2col));
break;
}
case kGenericOptimized: {
- optimized_ops::Conv(
- GetTensorData<float>(input), GetTensorDims(input),
- GetTensorData<float>(filter), GetTensorDims(filter),
- GetTensorData<float>(bias), GetTensorDims(bias), params->stride_width,
- params->stride_height, params->dilation_width_factor,
- params->dilation_height_factor, data->padding.width,
- data->padding.height, output_activation_min, output_activation_max,
- GetTensorData<float>(output), GetTensorDims(output),
- GetTensorData<float>(im2col), GetTensorDims(im2col));
+ optimized_ops::Conv(op_params, GetTensorShape(input),
+ GetTensorData<float>(input), GetTensorShape(filter),
+ GetTensorData<float>(filter), GetTensorShape(bias),
+ GetTensorData<float>(bias), GetTensorShape(output),
+ GetTensorData<float>(output), GetTensorShape(im2col),
+ GetTensorData<float>(im2col));
break;
}
case kMultithreadOptimized: {
@@ -423,24 +536,84 @@ void EvalFloat(TfLiteContext* context, TfLiteNode* node,
filter_data = GetTensorData<float>(filter);
}
multithreaded_ops::Conv(
- GetTensorData<float>(input), GetTensorDims(input), filter_data,
- GetTensorDims(filter), GetTensorData<float>(bias),
- GetTensorDims(bias), params->stride_width, params->stride_height,
- data->padding.width, data->padding.height, params->padding,
- output_activation_min, output_activation_max,
- GetTensorData<float>(output), GetTensorDims(output),
- GetTensorData<float>(im2col), GetTensorDims(im2col));
+ *eigen_support::GetThreadPoolDevice(context), op_params,
+ GetTensorShape(input), GetTensorData<float>(input),
+ GetTensorShape(filter), filter_data, GetTensorShape(bias),
+ GetTensorData<float>(bias), GetTensorShape(output),
+ GetTensorData<float>(output), GetTensorShape(im2col),
+ GetTensorData<float>(im2col));
break;
}
case kCblasOptimized: {
- cblas_ops::Conv(GetTensorData<float>(input), GetTensorDims(input),
- GetTensorData<float>(filter), GetTensorDims(filter),
- GetTensorData<float>(bias), GetTensorDims(bias),
- params->stride_width, params->stride_height,
- data->padding.width, data->padding.height,
- output_activation_min, output_activation_max,
- GetTensorData<float>(output), GetTensorDims(output),
- GetTensorData<float>(im2col), GetTensorDims(im2col));
+ cblas_ops::Conv(op_params, GetTensorShape(input),
+ GetTensorData<float>(input), GetTensorShape(filter),
+ GetTensorData<float>(filter), GetTensorShape(bias),
+ GetTensorData<float>(bias), GetTensorShape(output),
+ GetTensorData<float>(output), GetTensorShape(im2col),
+ GetTensorData<float>(im2col));
+ break;
+ }
+ }
+}
+
+template <KernelType kernel_type>
+void EvalHybrid(TfLiteContext* context, TfLiteNode* node,
+ TfLiteConvParams* params, OpData* data, TfLiteTensor* input,
+ TfLiteTensor* filter, TfLiteTensor* bias, TfLiteTensor* im2col,
+ TfLiteTensor* hwcn_weights, TfLiteTensor* output) {
+ float output_activation_min, output_activation_max;
+ CalculateActivationRange(params->activation, &output_activation_min,
+ &output_activation_max);
+
+ const int input_size = NumElements(input) / SizeOfDimension(input, 0);
+ const int batch_size = SizeOfDimension(input, 0);
+
+ const TfLiteTensor* input_quantized =
+ GetTemporary(context, node, data->input_quantized_index);
+ int8_t* quantized_input_ptr_batch =
+ reinterpret_cast<int8_t*>(input_quantized->data.uint8);
+ float* scaling_factors_ptr =
+ GetTemporary(context, node, data->scaling_factors_index)->data.f;
+
+ // Per-batch input quantization for higher accuracy.
+ for (int b = 0; b < batch_size; ++b) {
+ float unused_min, unused_max;
+ const int offset = b * input_size;
+ tensor_utils::SymmetricQuantizeFloats(
+ input->data.f + offset, input_size, quantized_input_ptr_batch + offset,
+ &unused_min, &unused_max, &scaling_factors_ptr[b]);
+ scaling_factors_ptr[b] *= filter->params.scale;
+ }
+
+ int8_t* im2col_ptr = nullptr;
+ if (im2col != nullptr) {
+ im2col_ptr = reinterpret_cast<int8_t*>(im2col->data.uint8);
+ }
+ int8_t* filter_ptr = reinterpret_cast<int8_t*>(filter->data.uint8);
+
+ switch (kernel_type) {
+ case kReference:
+ case kGenericOptimized:
+ case kMultithreadOptimized:
+ case kCblasOptimized: {
+      // There is only one implementation for the hybrid kernel. Note that
+      // it does not make use of gemmlowp nor does it support multithreading.
+ ConvParams op_params;
+ op_params.padding_type = PaddingType::kSame;
+ op_params.padding_values.width = data->padding.width;
+ op_params.padding_values.height = data->padding.height;
+ op_params.stride_width = params->stride_width;
+ op_params.stride_height = params->stride_height;
+ op_params.dilation_width_factor = 1;
+ op_params.dilation_height_factor = 1;
+ op_params.float_activation_min = output_activation_min;
+ op_params.float_activation_max = output_activation_max;
+ optimized_ops::HybridConv(
+ op_params, scaling_factors_ptr, GetTensorShape(input),
+ quantized_input_ptr_batch, GetTensorShape(filter), filter_ptr,
+ GetTensorShape(bias), GetTensorData<float>(bias),
+ GetTensorShape(output), GetTensorData<float>(output),
+ GetTensorShape(im2col), im2col_ptr);
break;
}
}
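
The hybrid path above quantizes each input batch symmetrically to int8 and folds the filter scale into the per-batch scaling factor before the integer matmul. Below is a minimal, self-contained sketch of that per-batch quantization step; it illustrates the idea only and is not the tensor_utils::SymmetricQuantizeFloats implementation, and the [-127, 127] target range is an assumption of the sketch.

// Minimal sketch of symmetric per-batch quantization (assumed range [-127, 127]).
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>

// Returns the scale such that float_value ~= scale * quantized_value.
float SymmetricQuantizeSketch(const std::vector<float>& values,
                              std::vector<int8_t>* quantized) {
  float max_abs = 0.f;
  for (float v : values) max_abs = std::max(max_abs, std::fabs(v));
  quantized->assign(values.size(), 0);
  if (max_abs == 0.f) return 1.f;
  const float scale = max_abs / 127.f;
  for (size_t i = 0; i < values.size(); ++i) {
    (*quantized)[i] = static_cast<int8_t>(std::round(values[i] / scale));
  }
  return scale;
}

int main() {
  // One batch of the hybrid conv test input: 1 2 3 4 -> 32 64 95 127 with
  // scale 4/127 (the test comments quote the inverse factor, 127/4).
  std::vector<int8_t> q;
  const float scale = SymmetricQuantizeSketch({1, 2, 3, 4, 1, 2, 3, 4}, &q);
  std::printf("scale=%f q=%d %d %d %d\n", scale, q[0], q[1], q[2], q[3]);
  // EvalHybrid then multiplies this per-batch scale by filter->params.scale
  // so the int32 accumulator can be dequantized in a single step.
  return 0;
}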
@@ -454,9 +627,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* output = &context->tensors[node->outputs->data[0]];
TfLiteTensor* input = &context->tensors[node->inputs->data[0]];
TfLiteTensor* filter = &context->tensors[node->inputs->data[1]];
- bool hasBias = node->inputs->size == 3;
+ bool has_bias = node->inputs->size == 3;
TfLiteTensor* bias =
- hasBias ? &context->tensors[node->inputs->data[2]] : nullptr;
+ has_bias ? &context->tensors[node->inputs->data[2]] : nullptr;
TfLiteTensor* im2col =
data->need_im2col
? &context->tensors[node->temporaries->data[data->im2col_index]]
@@ -475,7 +648,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
// separate ops to avoid dispatch overhead here.
   switch (input->type) {  // Already know in/out types are same.
case kTfLiteFloat32:
- if (data->run_multithreaded_kernel) {
+ if (filter->type == kTfLiteUInt8) {
+ EvalHybrid<kernel_type>(context, node, params, data, input, filter,
+ bias, im2col, hwcn_weights, output);
+ } else if (data->run_multithreaded_kernel) {
EvalFloat<kernel_type>(context, node, params, data, input, filter, bias,
im2col, hwcn_weights, output);
} else {
@@ -488,7 +664,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
bias, im2col, hwcn_weights, output);
break;
default:
- context->ReportError(context, "Type not currently supported.");
+ context->ReportError(context, "Type %d not currently supported.",
+ input->type);
return kTfLiteError;
}
return kTfLiteOk;
diff --git a/tensorflow/contrib/lite/kernels/conv_test.cc b/tensorflow/contrib/lite/kernels/conv_test.cc
index 0dcfc826fd..f7e6f083ed 100644
--- a/tensorflow/contrib/lite/kernels/conv_test.cc
+++ b/tensorflow/contrib/lite/kernels/conv_test.cc
@@ -64,12 +64,6 @@ class BaseConvolutionOpModel : public SingleOpModel {
}
output_ = AddOutput(output);
- if (input.type != TensorType_FLOAT32) {
- // The following is required by quantized inference. It is the unittest's
- // responsibility to make sure the output scale falls into the correct
- // range.
- CHECK_LT(GetScale(input_) * GetScale(filter_), GetScale(output_));
- }
SetBuiltinOp(BuiltinOperator_CONV_2D, BuiltinOptions_Conv2DOptions,
CreateConv2DOptions(
@@ -148,6 +142,128 @@ TEST_P(ConvolutionOpTest, SimpleTestFloat32) {
}));
}
+// This test's output is equivalent to that of SimpleTestFloat32
+// because we break each input into two channels, each with half of the value,
+// while keeping the filters for each channel equivalent.
+//
+// 2 * (A/2) * B = A * B, where the left side is this new test.
+TEST_P(ConvolutionOpTest, SimpleTestFloat32WithChannels) {
+ ConvolutionOpModel m(GetRegistration(), {TensorType_FLOAT32, {2, 2, 4, 2}},
+ {TensorType_FLOAT32, {3, 2, 2, 2}},
+ {TensorType_FLOAT32, {}});
+
+ m.SetInput({
+ // First batch
+ 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, // row = 1
+ 1, 1, 1, 1, 1, 1, 1, 1, // row = 2
+ // Second batch
+ 0.5, 0.5, 1, 1, 1.5, 1.5, 2, 2, // row = 1
+ 0.5, 0.5, 1, 1, 1.5, 1.5, 2, 2 // row = 2
+ });
+ m.SetFilter({
+ 1, 1, 2, 2, 3, 3, 4, 4, // first 2x2 filter
+ -1, -1, 1, 1, -1, -1, 1, 1, // second 2x2 filter
+ -1, -1, -1, -1, 1, 1, 1, 1 // third 2x2 filter
+ });
+ m.SetBias({1, 2, 3});
+
+ m.Invoke();
+
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({
+ 18, 2, 5, // first batch, left
+ 18, 2, 5, // first batch, right
+ 17, 4, 3, // second batch, left
+ 37, 4, 3, // second batch, right
+ }));
+}
+
+TEST_P(ConvolutionOpTest, InputAndFilterSameWidthHeight) {
+ ConvolutionOpModel m(GetRegistration(), {TensorType_FLOAT32, {2, 2, 4, 1}},
+ {TensorType_FLOAT32, {1, 2, 4, 1}},
+ {TensorType_FLOAT32, {}});
+
+ m.SetInput({
+ // First batch
+ 1, 1, 1, 1, // row = 1
+ 2, 2, 2, 2, // row = 2
+ // Second batch
+ 1, 2, 3, 4, // row = 1
+ 1, 2, 3, 4, // row = 2
+ });
+ m.SetFilter({
+ 1, 2, 3, 4, // row = 1
+ -1, -1, 1, 1, // row = 2
+ });
+ m.SetBias({0});
+
+ m.Invoke();
+
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({10, 34}));
+}
+
+TEST_P(ConvolutionOpTest, PointwiseFloat32) {
+ ConvolutionOpModel m(GetRegistration(), {TensorType_FLOAT32, {2, 2, 4, 2}},
+ {TensorType_FLOAT32, {1, 1, 1, 2}},
+ {TensorType_FLOAT32, {}}, 1, 1);
+
+ m.SetInput({
+ // First batch
+ 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, // row = 1
+ 1, 1, 1, 1, 1, 1, 1, 1, // row = 2
+ // Second batch
+ 0.5, 0.5, 1, 1, 1.5, 1.5, 2, 2, // row = 1
+ 0.5, 0.5, 1, 1, 1.5, 1.5, 2, 2 // row = 2
+ });
+
+ m.SetFilter({
+ 1, 2, // first filter
+ });
+ m.SetBias({0});
+
+ m.Invoke();
+
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({
+ // First batch
+ 1.5, 1.5, 1.5, 1.5, // row = 1
+ 3., 3., 3., 3., // row = 2
+ // Second batch
+ 1.5, 3., 4.5, 6., // row = 1
+ 1.5, 3., 4.5, 6., // row = 2
+ }));
+}
+
+// TODO(alanchiao): this passes locally, but fails on continuous build system.
+// Re-enable when root cause found.
+TEST_P(ConvolutionOpTest, DISABLED_PointwiseMultifilterFloat32) {
+ ConvolutionOpModel m(GetRegistration(), {TensorType_FLOAT32, {2, 2, 4, 2}},
+ {TensorType_FLOAT32, {2, 1, 1, 2}},
+ {TensorType_FLOAT32, {}}, 1, 1);
+
+ m.SetInput({
+ // First batch
+ 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, // row = 1
+ 1, 1, 1, 1, 1, 1, 1, 1, // row = 2
+ // Second batch
+ 0.5, 0.5, 1, 1, 1.5, 1.5, 2, 2, // row = 1
+ 0.5, 0.5, 1, 1, 1.5, 1.5, 2, 2 // row = 2
+ });
+
+ m.SetFilter({
+ 1, 2, // first filter
+ 2, 3, // second filter
+ });
+ m.SetBias({0});
+
+ m.Invoke();
+
+ EXPECT_THAT(m.GetOutput(),
+ ElementsAreArray({
+ 1.5, 2.5, 1.5, 2.5, 1.5, 2.5, 1.5, 2.5, 3., 5., 3.,
+ 5., 3., 5., 3., 5., 1.5, 2.5, 3., 5., 4.5, 7.5,
+ 6., 10., 1.5, 2.5, 3., 5., 4.5, 7.5, 6., 10.,
+ }));
+}
+
TEST_P(ConvolutionOpTest, SimpleTestFloat32WithAnisotropicStrides) {
ConvolutionOpModel m(GetRegistration(), {TensorType_FLOAT32, {1, 3, 6, 1}},
{TensorType_FLOAT32, {1, 2, 2, 1}},
@@ -376,6 +492,65 @@ TEST_P(ConvolutionOpTest, HandCalculatedValidFloat32) {
EXPECT_THAT(m.GetOutput(), ElementsAreArray({312, 357}));
}
+TEST_P(ConvolutionOpTest, SimpleTestFloatWithDilation) {
+ const int depth = 1;
+ const int image_width = 9;
+ const int image_height = 9;
+ const int image_batch_count = 1;
+ const int filter_size = 3;
+ const int filter_count = 1;
+ const int stride_width = 1;
+ const int stride_height = 1;
+ const int dilation_width_factor = 3;
+ const int dilation_height_factor = 3;
+ const Padding padding = Padding_VALID;
+ ConvolutionOpModel m(
+ GetRegistration(),
+ {TensorType_FLOAT32,
+ {image_batch_count, image_height, image_width, depth}},
+ {TensorType_FLOAT32, {depth, filter_size, filter_size, filter_count}},
+ {TensorType_FLOAT32, {}}, stride_width, stride_height, padding,
+ ActivationFunctionType_NONE, dilation_width_factor,
+ dilation_height_factor);
+
+ // The image matrix is:
+ // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 1 | 1 | 1 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 1 | 1 | 1 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 1 | 1 | 1 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
+ // clang-format off
+ m.SetInput({0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 1, 1, 1, 0, 0, 0,
+ 0, 0, 0, 1, 1, 1, 0, 0, 0,
+ 0, 0, 0, 1, 1, 1, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0});
+ // clang-format on
+ // The filter matrix is:
+ // | 1 | 2 | 3 |
+ // | 4 | 5 | 6 |
+ // | 7 | 8 | 9 |
+ m.SetFilter({1, 2, 3, 4, 5, 6, 7, 8, 9});
+ // No bias for this test.
+ m.SetBias({0});
+ m.Invoke();
+
+  // Since the dilation rate is 3, this reduces the output from 7x7 (the
+  // undilated VALID size) to 3x3 of all 5s: only the center filter tap
+  // (value 5) ever lands on the 3x3 block of ones. Specifically:
+ // | 5 | 5 | 5 |
+ // | 5 | 5 | 5 |
+ // | 5 | 5 | 5 |
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({5, 5, 5, 5, 5, 5, 5, 5, 5}));
+}
+
class QuantizedConvolutionOpModel : public BaseConvolutionOpModel {
public:
using BaseConvolutionOpModel::BaseConvolutionOpModel;
@@ -441,6 +616,44 @@ TEST_P(ConvolutionOpTest, SimpleTestQuantized) {
}));
}
+TEST_P(ConvolutionOpTest, SimpleTestQuantizedOutputMultiplierGreaterThan1) {
+ // output_multiplier = 1.0118
+ QuantizedConvolutionOpModel quant_op(
+ GetRegistration(), {TensorType_UINT8, {2, 2, 4, 1}, -128.5, 128},
+ {TensorType_UINT8, {3, 2, 2, 1}, -128.5, 128},
+ {TensorType_UINT8, {}, -127, 128});
+ ConvolutionOpModel float_op(
+ GetRegistration(), {TensorType_FLOAT32, {2, 2, 4, 1}},
+ {TensorType_FLOAT32, {3, 2, 2, 1}}, {TensorType_FLOAT32, {}});
+ std::initializer_list<float> input = {
+ // First batch
+ 1, 1, 1, 1, // row = 1
+ 2, 2, 2, 2, // row = 2
+ // Second batch
+ 1, 2, 3, 4, // row = 1
+ 1, 2, 3, 4, // row = 2
+ };
+ std::initializer_list<float> filter = {
+ 1, 2, 3, 4, // first 2x2 filter
+ -1, 1, -1, 1, // second 2x2 filter
+ -1, -1, 1, 1, // third 2x2 filter
+ };
+ std::initializer_list<float> bias = {1, 2, 3};
+
+ quant_op.SetInput(input);
+ quant_op.SetFilter(filter);
+ quant_op.SetBias(bias);
+ quant_op.Invoke();
+
+ float_op.SetInput(input);
+ float_op.SetFilter(filter);
+ float_op.SetBias(bias);
+ float_op.Invoke();
+
+ EXPECT_THAT(quant_op.GetDequantizedOutput(),
+ ElementsAreArray(ArrayFloatNear(float_op.GetOutput(), 1)));
+}
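+
For context on the "output_multiplier = 1.0118" note above: assuming the usual uint8 mapping scale = (max - min) / 255 used by the test helpers (an assumption of this sketch, not something stated in the diff), the chosen ranges give an input-times-filter scale slightly larger than the output scale.

// Back-of-the-envelope check of the real multiplier implied by the test ranges.
#include <cstdio>

int main() {
  const double input_scale = (128.0 - (-128.5)) / 255.0;   // ~1.00588
  const double filter_scale = (128.0 - (-128.5)) / 255.0;  // ~1.00588
  const double output_scale = (128.0 - (-127.0)) / 255.0;  // 1.0
  const double real_multiplier = input_scale * filter_scale / output_scale;
  std::printf("real_multiplier = %f\n", real_multiplier);  // ~1.0118, i.e. > 1
  // Multipliers above 1 are exactly what the switch from
  // QuantizeMultiplierSmallerThanOne to QuantizeMultiplier (see the
  // depthwise_conv.cc hunk further down) is meant to handle.
  return 0;
}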
+
TEST_P(ConvolutionOpTest, SimpleTestQuantizedWithAnisotropicStrides) {
QuantizedConvolutionOpModel m(GetRegistration(),
{TensorType_UINT8, {1, 3, 6, 1}, -63.5, 64},
@@ -468,6 +681,257 @@ TEST_P(ConvolutionOpTest, SimpleTestQuantizedWithAnisotropicStrides) {
}));
}
+TEST_P(ConvolutionOpTest, SimpleTestQuantizedWithDilation) {
+ const int depth = 1;
+ const int image_width = 9;
+ const int image_height = 9;
+ const int image_batch_count = 1;
+ const int filter_size = 3;
+ const int filter_count = 1;
+ const int stride_width = 1;
+ const int stride_height = 1;
+ const int dilation_width_factor = 3;
+ const int dilation_height_factor = 3;
+ const Padding padding = Padding_VALID;
+ QuantizedConvolutionOpModel m(
+ GetRegistration(),
+ {TensorType_UINT8,
+ {image_batch_count, image_height, image_width, depth},
+ 0,
+ 255},
+ {TensorType_UINT8,
+ {depth, filter_size, filter_size, filter_count},
+ 0,
+ 255},
+ {TensorType_UINT8, {}, 0, 255}, stride_width, stride_height, padding,
+ ActivationFunctionType_NONE, dilation_width_factor,
+ dilation_height_factor);
+
+ // The image matrix is:
+ // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 1 | 1 | 1 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 1 | 1 | 1 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 1 | 1 | 1 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
+ // clang-format off
+ m.SetInput({0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 1, 1, 1, 0, 0, 0,
+ 0, 0, 0, 1, 1, 1, 0, 0, 0,
+ 0, 0, 0, 1, 1, 1, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0});
+ // clang-format on
+ // The filter matrix is:
+ // | 1 | 2 | 3 |
+ // | 4 | 5 | 6 |
+ // | 7 | 8 | 9 |
+ m.SetFilter({1, 2, 3, 4, 5, 6, 7, 8, 9});
+ // No bias for this test.
+ m.SetBias({0});
+ m.Invoke();
+
+  // Since the dilation rate is 3, this reduces the output from 7x7 (the
+  // undilated VALID size) to 3x3 of all 5s. Specifically:
+ // | 5 | 5 | 5 |
+ // | 5 | 5 | 5 |
+ // | 5 | 5 | 5 |
+ EXPECT_THAT(m.GetDequantizedOutput(),
+ ElementsAreArray({5, 5, 5, 5, 5, 5, 5, 5, 5}));
+}
+
+class HybridConvolutionOpModel : public BaseConvolutionOpModel {
+ public:
+ using BaseConvolutionOpModel::BaseConvolutionOpModel;
+
+ void SetInput(std::initializer_list<float> data) {
+ PopulateTensor(input_, data);
+ }
+
+ void SetFilter(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(filter_, f);
+ }
+
+ void SetBias(std::initializer_list<float> data) {
+ PopulateTensor(bias_, data);
+ }
+
+ std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+};
+
+TEST_P(ConvolutionOpTest, SimpleTestHybrid) {
+ HybridConvolutionOpModel m(
+ GetRegistration(), {TensorType_FLOAT32, {2, 2, 4, 1}},
+ {TensorType_UINT8, {3, 2, 2, 1}}, {TensorType_FLOAT32, {}});
+
+ m.SetInput({
+ // First batch
+ 1, 1, 1, 1, // row = 1
+ 2, 2, 2, 2, // row = 2
+ // Second batch
+ 1, 2, 3, 4, // row = 1
+ 1, 2, 3, 4, // row = 2
+ });
+ m.SetFilter({
+ 1, 2, 3, 4, // first 2x2 filter
+ -1, 1, -1, 1, // second 2x2 filter
+ -1, -1, 1, 1, // third 2x2 filter
+ });
+ m.SetBias({1, 2, 3});
+
+ m.Invoke();
+
+ // Example: we get 17.1577 instead of 17.
+ //
+ // Second batch:
+ // 1 2 3 4 -> 32 64 95 127 with scale factor 127/4.
+ // 1 2 3 4 32 64 95 127
+ //
+ // First filter:
+ // 1 2 -> 32 64 with scale factor of 127/4.
+ // 3 4 95 127
+ //
+  // The left half of the input gives us 16288. Multiplying by (4/127)^2 for
+  // dequantization and adding 1 for the bias gives us the result.
+ //
+ // The optimized kernel converts the input into this matrix via Im2Col
+ //
+ // 1 1 2 2
+ // 1 1 2 2
+ // 1 2 1 2
+ // 3 4 3 4
+ //
+ // and multiplies it with the filter directly.
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear(
+ {
+ 18, 2, 5, // first batch, left
+ 18, 2, 5, // first batch, right
+ 17, 4, 3, // second batch, left
+ 37, 4, 3, // second batch, right
+ },
+ 0.16)));
+}
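+
The arithmetic in the comment above can be checked in isolation. The snippet below only reproduces the dequantization math using the quantized values quoted in the comment; it is not how the kernel organizes the computation.

// Reproduces the "17.1577 instead of 17" example from the comment above.
#include <cstdio>

int main() {
  // Left 2x2 window of the second batch (scale 4/127): input 1 2 / 1 2 ->
  // 32 64 / 32 64, first filter 1 2 / 3 4 -> 32 64 / 95 127.
  const int acc = 32 * 32 + 64 * 64 + 32 * 95 + 64 * 127;           // 16288
  const double result = acc * (4.0 / 127.0) * (4.0 / 127.0) + 1.0;  // + bias
  std::printf("acc = %d, result = %f\n", acc, result);  // ~17.16
  return 0;
}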
+
+// This test's output is equivalent to that of SimpleTestHybrid
+// because we break each input into two channels, each with half of the value,
+// while keeping the filters for each channel equivalent.
+//
+// 2 * (A/2) * B = A * B, where the left side is this new test.
+TEST_P(ConvolutionOpTest, SimpleTestHybridWithChannels) {
+ HybridConvolutionOpModel m(
+ GetRegistration(), {TensorType_FLOAT32, {2, 2, 4, 2}},
+ {TensorType_UINT8, {3, 2, 2, 2}}, {TensorType_FLOAT32, {}});
+
+ m.SetInput({
+ // First batch
+ 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, // row = 1
+ 1, 1, 1, 1, 1, 1, 1, 1, // row = 2
+ // Second batch
+ 0.5, 0.5, 1, 1, 1.5, 1.5, 2, 2, // row = 1
+ 0.5, 0.5, 1, 1, 1.5, 1.5, 2, 2 // row = 2
+ });
+ m.SetFilter({
+ 1, 1, 2, 2, 3, 3, 4, 4, // first 2x2 filter
+ -1, -1, 1, 1, -1, -1, 1, 1, // second 2x2 filter
+ -1, -1, -1, -1, 1, 1, 1, 1 // third 2x2 filter
+ });
+ m.SetBias({1, 2, 3});
+
+ m.Invoke();
+
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear(
+ {
+ 18, 2, 5, // first batch, left
+ 18, 2, 5, // first batch, right
+ 17, 4, 3, // second batch, left
+ 37, 4, 3, // second batch, right
+ },
+ 0.16)));
+}
+
+TEST_P(ConvolutionOpTest, PointwiseHybrid) {
+ HybridConvolutionOpModel m(
+ GetRegistration(), {TensorType_FLOAT32, {2, 2, 4, 2}},
+ {TensorType_UINT8, {1, 1, 1, 2}}, {TensorType_FLOAT32, {}}, 1, 1);
+
+ m.SetInput({
+ // First batch
+ 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, // row = 1
+ 1, 1, 1, 1, 1, 1, 1, 1, // row = 2
+ // Second batch
+ 0.5, 0.5, 1, 1, 1.5, 1.5, 2, 2, // row = 1
+ 0.5, 0.5, 1, 1, 1.5, 1.5, 2, 2 // row = 2
+ });
+
+ m.SetFilter({
+ 1, 2, // first filter
+ });
+ m.SetBias({0});
+
+ m.Invoke();
+
+ // Example: we get 3.03156 instead of 3.
+ //
+ // Second batch:
+ // 0.5 0.5 1 1 1.5 1.5 2 2 -> 32 32 64 64 95 95 127 127 with scale factor
+ // 127/2. We care about the two 64's.
+ //
+ // Filter:
+ // 64 127 with scale factor of 127/2.
+ //
+ // (64 * 64 + 64 * 127) * (2/127)^2 gives us the expected result.
+ EXPECT_THAT(m.GetOutput(),
+ ElementsAreArray(ArrayFloatNear(
+ {
+ 1.5, 1.5, 1.5, 1.5, // first batch, row = 1
+ 3., 3., 3., 3., // first batch, row = 2
+ 1.5, 3., 4.5, 6., // second batch, row = 1
+ 1.5, 3., 4.5, 6., // second batch, row = 2
+ },
+ 0.0316)));
+}
+
+// TODO(alanchiao): this passes locally, but fails on continuous build system.
+// Re-enable when root cause found.
+TEST_P(ConvolutionOpTest, DISABLED_PointwiseMultifilterHybrid) {
+ HybridConvolutionOpModel m(
+ GetRegistration(), {TensorType_FLOAT32, {2, 2, 4, 2}},
+ {TensorType_UINT8, {2, 1, 1, 2}}, {TensorType_FLOAT32, {}}, 1, 1);
+
+ m.SetInput({
+ // First batch
+ 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, // row = 1
+ 1, 1, 1, 1, 1, 1, 1, 1, // row = 2
+ // Second batch
+ 0.5, 0.5, 1, 1, 1.5, 1.5, 2, 2, // row = 1
+ 0.5, 0.5, 1, 1, 1.5, 1.5, 2, 2 // row = 2
+ });
+
+ m.SetFilter({
+ 1, 2, // first filter
+ 2, 3, // second filter
+ });
+ m.SetBias({0});
+
+ m.Invoke();
+
+ EXPECT_THAT(m.GetOutput(),
+ ElementsAreArray(ArrayFloatNear(
+ {
+ 1.5, 2.5, 1.5, 2.5, 1.5, 2.5, 1.5, 2.5, 3., 5., 3.,
+ 5., 3., 5., 3., 5., 1.5, 2.5, 3., 5., 4.5, 7.5,
+ 6., 10., 1.5, 2.5, 3., 5., 4.5, 7.5, 6., 10.,
+ },
+ 0.0474)));
+}
+
INSTANTIATE_TEST_CASE_P(
ConvolutionOpTest, ConvolutionOpTest,
::testing::ValuesIn(SingleOpTest::GetKernelTags(*kKernelMap)));
diff --git a/tensorflow/contrib/lite/kernels/depthwise_conv.cc b/tensorflow/contrib/lite/kernels/depthwise_conv.cc
index 3ad8d7d4e1..19958844a1 100644
--- a/tensorflow/contrib/lite/kernels/depthwise_conv.cc
+++ b/tensorflow/contrib/lite/kernels/depthwise_conv.cc
@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include <unistd.h>
#include <cassert>
#include <cmath>
#include <cstdio>
@@ -20,8 +19,8 @@ limitations under the License.
#include <iostream>
#include <limits>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h"
#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
@@ -127,23 +126,28 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Matching GetWindowedOutputSize in TensorFlow.
auto padding = params->padding;
- auto compute_out_size = [padding](int imageSize, int filterSize,
- int stride) -> int {
+ auto compute_out_size = [padding](int image_size, int filter_size, int stride,
+ int dilation_rate) -> int {
+ int effective_filter_size = (filter_size - 1) * dilation_rate + 1;
return padding == kTfLitePaddingSame
- ? (imageSize + stride - 1) / stride
+ ? (image_size + stride - 1) / stride
: padding == kTfLitePaddingValid
- ? (imageSize - filterSize + stride) / stride
+ ? (image_size - effective_filter_size + stride) / stride
: 0;
};
- int out_width = compute_out_size(width, filter_width, params->stride_width);
+ int out_width = compute_out_size(width, filter_width, params->stride_width,
+ params->dilation_width_factor);
int out_height =
- compute_out_size(height, filter_height, params->stride_height);
+ compute_out_size(height, filter_height, params->stride_height,
+ params->dilation_height_factor);
- data->padding.height = ComputePadding(params->stride_height, 1, height,
- filter_height, out_height);
+ data->padding.height =
+ ComputePadding(params->stride_height, params->dilation_height_factor,
+ height, filter_height, out_height);
data->padding.width =
- ComputePadding(params->stride_width, 1, width, filter_width, out_width);
+ ComputePadding(params->stride_width, params->dilation_width_factor, width,
+ filter_width, out_width);
// Note that quantized inference requires that all tensors have their
// parameters set. This is usually done during quantized training.
@@ -151,8 +155,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
double real_multiplier = 0.0;
TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler(
context, input, filter, bias, output, &real_multiplier));
- QuantizeMultiplierSmallerThanOne(real_multiplier, &data->output_multiplier,
- &data->output_shift);
+ int exponent;
+ QuantizeMultiplier(real_multiplier, &data->output_multiplier, &exponent);
+ data->output_shift = -exponent;
CalculateActivationRangeUint8(params->activation, output,
&data->output_activation_min,
&data->output_activation_max);
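
The dilation-aware compute_out_size lambda above reduces to the effective filter size formula. As a quick standalone check against the 9x9 image / 3x3 filter / dilation 3 tests elsewhere in this diff (VALID branch only):

// Effective filter size and VALID output size, mirroring compute_out_size.
#include <cstdio>

int ValidOutSize(int image_size, int filter_size, int stride, int dilation) {
  const int effective_filter_size = (filter_size - 1) * dilation + 1;
  return (image_size - effective_filter_size + stride) / stride;
}

int main() {
  std::printf("effective filter = %d\n", (3 - 1) * 3 + 1);              // 7
  std::printf("out with dilation 3 = %d\n", ValidOutSize(9, 3, 1, 3));  // 3
  std::printf("out without dilation = %d\n", ValidOutSize(9, 3, 1, 1)); // 7
  return 0;
}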
@@ -172,25 +177,34 @@ void EvalFloat(TfLiteContext* context, TfLiteNode* node,
const TfLiteTensor* input, const TfLiteTensor* filter,
const TfLiteTensor* bias, TfLiteTensor* output) {
float output_activation_min, output_activation_max;
- CalculateActivationRangeFloat(params->activation, &output_activation_min,
- &output_activation_max);
+ CalculateActivationRange(params->activation, &output_activation_min,
+ &output_activation_max);
- void (*depthwise_conv)(const float*, const Dims<4>&, const float*,
- const Dims<4>&, const float*, const Dims<4>&, int, int,
- int, int, int, float, float, float*, const Dims<4>&);
+ void (*depthwise_conv)(const DepthwiseParams&, const RuntimeShape&,
+ const float*, const RuntimeShape&, const float*,
+ const RuntimeShape&, const float*, const RuntimeShape&,
+ float*);
if (kernel_type == kReference) {
depthwise_conv = &reference_ops::DepthwiseConv;
} else {
depthwise_conv = &optimized_ops::DepthwiseConv;
}
- depthwise_conv(
- GetTensorData<float>(input), GetTensorDims(input),
- GetTensorData<float>(filter), GetTensorDims(filter),
- GetTensorData<float>(bias), GetTensorDims(bias), params->stride_width,
- params->stride_height, data->padding.width, data->padding.height,
- params->depth_multiplier, output_activation_min, output_activation_max,
- GetTensorData<float>(output), GetTensorDims(output));
+ DepthwiseParams op_params;
+ op_params.padding_type = PaddingType::kSame;
+ op_params.padding_values.width = data->padding.width;
+ op_params.padding_values.height = data->padding.height;
+ op_params.stride_width = params->stride_width;
+ op_params.stride_height = params->stride_height;
+ op_params.dilation_width_factor = params->dilation_width_factor;
+ op_params.dilation_height_factor = params->dilation_height_factor;
+ op_params.depth_multiplier = params->depth_multiplier;
+ op_params.float_activation_min = output_activation_min;
+ op_params.float_activation_max = output_activation_max;
+ depthwise_conv(op_params, GetTensorShape(input), GetTensorData<float>(input),
+ GetTensorShape(filter), GetTensorData<float>(filter),
+ GetTensorShape(bias), GetTensorData<float>(bias),
+ GetTensorShape(output), GetTensorData<float>(output));
}
template <KernelType kernel_type>
@@ -202,25 +216,38 @@ void EvalQuantized(TfLiteContext* context, TfLiteNode* node,
auto filter_offset = -filter->params.zero_point;
auto output_offset = output->params.zero_point;
- void (*depthwise_conv)(const uint8*, const Dims<4>&, int32, const uint8*,
- const Dims<4>&, int32, const int32*, const Dims<4>&,
- int, int, int, int, int, int32, int32, int, int32,
- int32, uint8*, const Dims<4>&);
+ void (*depthwise_conv)(const DepthwiseParams&, const RuntimeShape&,
+ const uint8*, const RuntimeShape&, const uint8*,
+ const RuntimeShape&, const int32*, const RuntimeShape&,
+ uint8*);
+
if (kernel_type == kReference) {
depthwise_conv = &reference_ops::DepthwiseConv;
} else {
depthwise_conv = &optimized_ops::DepthwiseConv;
}
- depthwise_conv(
- GetTensorData<uint8_t>(input), GetTensorDims(input), input_offset,
- GetTensorData<uint8_t>(filter), GetTensorDims(filter), filter_offset,
- GetTensorData<int32_t>(bias), GetTensorDims(bias), params->stride_width,
- params->stride_height, data->padding.width, data->padding.height,
- params->depth_multiplier, output_offset, data->output_multiplier,
- data->output_shift, data->output_activation_min,
- data->output_activation_max, GetTensorData<uint8_t>(output),
- GetTensorDims(output));
+ DepthwiseParams op_params;
+ op_params.padding_type = PaddingType::kSame;
+ op_params.padding_values.width = data->padding.width;
+ op_params.padding_values.height = data->padding.height;
+ op_params.stride_width = params->stride_width;
+ op_params.stride_height = params->stride_height;
+ op_params.dilation_width_factor = params->dilation_width_factor;
+ op_params.dilation_height_factor = params->dilation_height_factor;
+ op_params.depth_multiplier = params->depth_multiplier;
+ op_params.input_offset = input_offset;
+ op_params.weights_offset = filter_offset;
+ op_params.output_offset = output_offset;
+ op_params.output_multiplier = data->output_multiplier;
+ op_params.output_shift = -data->output_shift;
+ op_params.quantized_activation_min = data->output_activation_min;
+ op_params.quantized_activation_max = data->output_activation_max;
+ depthwise_conv(op_params, GetTensorShape(input),
+ GetTensorData<uint8_t>(input), GetTensorShape(filter),
+ GetTensorData<uint8_t>(filter), GetTensorShape(bias),
+ GetTensorData<int32_t>(bias), GetTensorShape(output),
+ GetTensorData<uint8_t>(output));
}
template <KernelType kernel_type>
@@ -247,7 +274,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
bias, output);
break;
default:
- context->ReportError(context, "Type not currently supported.");
+ context->ReportError(context, "Type %d not currently supported.",
+ input->type);
return kTfLiteError;
}
return kTfLiteOk;
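
The QuantizeMultiplier change in the Prepare hunk above stores the real multiplier as a Q31 fixed-point value plus a power-of-two exponent, with output_shift holding the negated exponent (which is why EvalQuantized passes -data->output_shift). The sketch below illustrates that decomposition with std::frexp; it is a rough illustration of the representation, not the TF Lite QuantizeMultiplier code.

// Represents real_multiplier as q * 2^(exponent - 31), with q in [2^30, 2^31).
#include <cmath>
#include <cstdint>
#include <cstdio>

void QuantizeMultiplierSketch(double real_multiplier, int32_t* q,
                              int* exponent) {
  const double significand = std::frexp(real_multiplier, exponent);  // [0.5, 1)
  int64_t q64 = static_cast<int64_t>(std::round(significand * (1ll << 31)));
  if (q64 == (1ll << 31)) {  // Rounding may hit 2^31; renormalize.
    q64 /= 2;
    ++*exponent;
  }
  *q = static_cast<int32_t>(q64);
}

int main() {
  int32_t q;
  int exponent;
  QuantizeMultiplierSketch(1.0118, &q, &exponent);  // The conv test's value.
  std::printf("q = %d, exponent = %d, check = %f\n", q, exponent,
              q / 2147483648.0 * std::ldexp(1.0, exponent));  // ~1.0118
  return 0;
}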
diff --git a/tensorflow/contrib/lite/kernels/depthwise_conv_test.cc b/tensorflow/contrib/lite/kernels/depthwise_conv_test.cc
index 1439c8bce1..4a33a0319d 100644
--- a/tensorflow/contrib/lite/kernels/depthwise_conv_test.cc
+++ b/tensorflow/contrib/lite/kernels/depthwise_conv_test.cc
@@ -14,12 +14,24 @@ limitations under the License.
==============================================================================*/
#include <cstdarg>
#include <gtest/gtest.h>
+#include "absl/memory/memory.h"
#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/kernels/test_util.h"
#include "tensorflow/contrib/lite/model.h"
namespace tflite {
+
+namespace ops {
+namespace builtin {
+
+TfLiteRegistration* Register_DEPTHWISE_CONVOLUTION_REF();
+TfLiteRegistration* Register_DEPTHWISE_CONVOLUTION_GENERIC_OPT();
+TfLiteRegistration* Register_DEPTHWISE_CONVOLUTION_NEON_OPT();
+
+} // namespace builtin
+} // namespace ops
+
namespace {
using ::testing::ElementsAreArray;
@@ -28,9 +40,12 @@ class BaseDepthwiseConvolutionOpModel : public SingleOpModel {
public:
// TODO(ahentz): Also test different activation types, bias, padding types,
// stride values.
- BaseDepthwiseConvolutionOpModel(const TensorData& input,
+ BaseDepthwiseConvolutionOpModel(TfLiteRegistration* registration,
+ const TensorData& input,
const TensorData& filter,
- const TensorData& output) {
+ const TensorData& output,
+ Padding padding_type,
+ int dilation_factor = 1) {
input_ = AddInput(input);
filter_ = AddInput(filter);
@@ -47,12 +62,6 @@ class BaseDepthwiseConvolutionOpModel : public SingleOpModel {
}
output_ = AddOutput(output);
- if (input.type != TensorType_FLOAT32) {
- // The following is required by quantized inference. It is the unittest's
- // responsibility to make sure the output scale falls into the correct
- // range.
- CHECK_LT(GetScale(input_) * GetScale(filter_), GetScale(output_));
- }
int input_depth = GetShape(input_)[3];
int output_depth = GetShape(filter_)[3];
@@ -61,10 +70,14 @@ class BaseDepthwiseConvolutionOpModel : public SingleOpModel {
SetBuiltinOp(
BuiltinOperator_DEPTHWISE_CONV_2D,
BuiltinOptions_DepthwiseConv2DOptions,
- CreateDepthwiseConv2DOptions(builder_, Padding_VALID, 1, 1, depth_mul,
- ActivationFunctionType_NONE)
+ CreateDepthwiseConv2DOptions(builder_, padding_type, 1, 1, depth_mul,
+ ActivationFunctionType_NONE,
+ dilation_factor, dilation_factor)
.Union());
+ resolver_ = absl::make_unique<SingleOpResolver>(
+ BuiltinOperator_DEPTHWISE_CONV_2D, registration);
+
BuildInterpreter({GetShape(input_), GetShape(filter_), GetShape(bias_)});
}
@@ -90,10 +103,25 @@ class DepthwiseConvolutionOpModel : public BaseDepthwiseConvolutionOpModel {
std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
};
-TEST(DepthwiseConvolutionOpTest, SimpleTest) {
- DepthwiseConvolutionOpModel m({TensorType_FLOAT32, {1, 3, 2, 2}},
+const auto kKernelMap = new std::map<string, TfLiteRegistration*>({
+ {"Reference", ops::builtin::Register_DEPTHWISE_CONVOLUTION_REF()},
+ {"GenericOptimized",
+ ops::builtin::Register_DEPTHWISE_CONVOLUTION_GENERIC_OPT()},
+ {"NeonOptimized", ops::builtin::Register_DEPTHWISE_CONVOLUTION_NEON_OPT()},
+});
+
+class DepthwiseConvolutionOpTest : public SingleOpTest {
+ protected:
+ const std::map<string, TfLiteRegistration*>& GetKernelMap() override {
+ return *kKernelMap;
+ }
+};
+
+TEST_P(DepthwiseConvolutionOpTest, SimpleTest) {
+ DepthwiseConvolutionOpModel m(GetRegistration(),
+ {TensorType_FLOAT32, {1, 3, 2, 2}},
{TensorType_FLOAT32, {1, 2, 2, 4}},
- {TensorType_FLOAT32, {}});
+ {TensorType_FLOAT32, {}}, Padding_VALID);
m.SetInput({
1, 2, 7, 8, // column 1
@@ -116,6 +144,94 @@ TEST(DepthwiseConvolutionOpTest, SimpleTest) {
}));
}
+TEST_P(DepthwiseConvolutionOpTest, SimpleDilatedTestPaddingValid) {
+ const int depth = 1;
+ const int image_width = 9;
+ const int image_height = 9;
+ const int image_batch_count = 1;
+ const int filter_size = 3;
+ const int filter_count = 1;
+ const int dilation_factor = 3;
+ DepthwiseConvolutionOpModel m(
+ GetRegistration(),
+ {TensorType_FLOAT32,
+ {image_batch_count, image_height, image_width, depth}},
+ {TensorType_FLOAT32, {depth, filter_size, filter_size, filter_count}},
+ {TensorType_FLOAT32, {}}, Padding_VALID, dilation_factor);
+
+ // The image matrix is:
+ // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 1 | 1 | 1 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 1 | 1 | 1 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 1 | 1 | 1 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
+ // clang-format off
+ m.SetInput({0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 1, 1, 1, 0, 0, 0,
+ 0, 0, 0, 1, 1, 1, 0, 0, 0,
+ 0, 0, 0, 1, 1, 1, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0});
+ // clang-format on
+ // The filter matrix is:
+ // | 1 | 2 | 3 |
+ // | 4 | 5 | 6 |
+ // | 7 | 8 | 9 |
+ m.SetFilter({1, 2, 3, 4, 5, 6, 7, 8, 9});
+ // No bias for this test.
+ m.SetBias({0});
+ m.Invoke();
+
+  // Since the dilation rate is 3, this reduces the output from 7x7 (the
+  // undilated VALID size) to 3x3 of all 5s. Specifically:
+ // | 5 | 5 | 5 |
+ // | 5 | 5 | 5 |
+ // | 5 | 5 | 5 |
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({5, 5, 5, 5, 5, 5, 5, 5, 5}));
+}
+
+TEST_P(DepthwiseConvolutionOpTest, SimpleDilatedTestPaddingSame) {
+ const int depth = 1;
+ const int image_width = 3;
+ const int image_height = 3;
+ const int image_batch_count = 1;
+ const int filter_size = 2;
+ const int filter_count = 1;
+ const int dilation_factor = 2;
+ DepthwiseConvolutionOpModel m(
+ GetRegistration(),
+ {TensorType_FLOAT32,
+ {image_batch_count, image_height, image_width, depth}},
+ {TensorType_FLOAT32, {depth, filter_size, filter_size, filter_count}},
+ {TensorType_FLOAT32, {}}, Padding_SAME, dilation_factor);
+
+ // The image matrix is:
+ // | 1 | 1 | 1 |
+ // | 1 | 1 | 1 |
+ // | 1 | 1 | 1 |
+ m.SetInput({1, 1, 1, 1, 1, 1, 1, 1, 1});
+ // The filter matrix is:
+ // | 1 | 2 |
+ // | 3 | 4 |
+ m.SetFilter({1, 2, 3, 4});
+ // No bias for this test.
+ m.SetBias({0});
+ m.Invoke();
+
+  // With SAME padding and dilation 2, each output sums only the filter taps
+  // that land inside the 3x3 image of ones. Output:
+ // | 4 | 7 | 3 |
+ // | 6 |10 | 4 |
+ // | 2 | 3 | 1 |
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({4, 7, 3, 6, 10, 4, 2, 3, 1}));
+}
+
class QuantizedDepthwiseConvolutionOpModel
: public BaseDepthwiseConvolutionOpModel {
public:
@@ -140,13 +256,20 @@ class QuantizedDepthwiseConvolutionOpModel
}
};
+class QuantizedDepthwiseConvolutionOpTest : public SingleOpTest {
+ protected:
+ const std::map<string, TfLiteRegistration*>& GetKernelMap() override {
+ return *kKernelMap;
+ }
+};
+
// In this test we set the input and output scales so that the results match
// exactly the 'non-quantized' version.
-TEST(QuantizedDepthwiseConvolutionOpTest, SimpleTestQuantized) {
+TEST_P(QuantizedDepthwiseConvolutionOpTest, SimpleTestQuantized) {
QuantizedDepthwiseConvolutionOpModel m(
- {TensorType_UINT8, {1, 3, 2, 2}, -63.5, 64},
+ GetRegistration(), {TensorType_UINT8, {1, 3, 2, 2}, -63.5, 64},
{TensorType_UINT8, {1, 2, 2, 4}, -63.5, 64},
- {TensorType_UINT8, {}, -127, 128});
+ {TensorType_UINT8, {}, -127, 128}, Padding_VALID);
m.SetInput({
1, 2, 7, 8, // column 1
@@ -176,6 +299,152 @@ TEST(QuantizedDepthwiseConvolutionOpTest, SimpleTestQuantized) {
}));
}
+TEST_P(QuantizedDepthwiseConvolutionOpTest,
+ SimpleTestQuantizedFilterMultiplierGreaterThan1) {
+ QuantizedDepthwiseConvolutionOpModel quant_op(
+ GetRegistration(), {TensorType_UINT8, {1, 3, 2, 2}, -63.5, 64},
+ {TensorType_UINT8, {1, 2, 2, 4}, -128.5, 128},
+ {TensorType_UINT8, {}, -127, 128}, Padding_VALID);
+ DepthwiseConvolutionOpModel float_op(GetRegistration(),
+ {TensorType_FLOAT32, {1, 3, 2, 2}},
+ {TensorType_FLOAT32, {1, 2, 2, 4}},
+ {TensorType_FLOAT32, {}}, Padding_VALID);
+
+ std::initializer_list<float> input = {
+ 1, 2, 7, 8, // column 1
+ 3, 4, 9, 10, // column 2
+ 5, 6, 11, 12, // column 3
+ };
+ std::initializer_list<float> filter = {
+ 1, 2, 3, 4, //
+ -9, 10, -11, 12, //
+ 5, 6, 7, 8, //
+ 13, -14, 15, -16, //
+ };
+ std::initializer_list<float> bias = {1, 2, 3, 4};
+
+ quant_op.SetInput(input);
+ quant_op.SetFilter(filter);
+ quant_op.SetBias(bias);
+ quant_op.Invoke();
+
+ float_op.SetInput(input);
+ float_op.SetFilter(filter);
+ float_op.SetBias(bias);
+ float_op.Invoke();
+
+ EXPECT_THAT(quant_op.GetDequantizedOutput(),
+ ElementsAreArray(ArrayFloatNear(float_op.GetOutput(), 1)));
+}
+
+TEST_P(QuantizedDepthwiseConvolutionOpTest, SimpleDilatedTestPaddingValid) {
+ const int depth = 1;
+ const int image_width = 9;
+ const int image_height = 9;
+ const int image_batch_count = 1;
+ const int filter_size = 3;
+ const int filter_count = 1;
+ const int dilation_factor = 3;
+ QuantizedDepthwiseConvolutionOpModel m(
+ GetRegistration(),
+ {TensorType_UINT8,
+ {image_batch_count, image_height, image_width, depth},
+ 0,
+ 255},
+ {TensorType_UINT8,
+ {depth, filter_size, filter_size, filter_count},
+ 0,
+ 255},
+ {TensorType_UINT8, {}, 0, 255}, Padding_VALID, dilation_factor);
+
+ // The image matrix is:
+ // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 1 | 1 | 1 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 1 | 1 | 1 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 1 | 1 | 1 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
+ // clang-format off
+ m.SetInput({0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 1, 1, 1, 0, 0, 0,
+ 0, 0, 0, 1, 1, 1, 0, 0, 0,
+ 0, 0, 0, 1, 1, 1, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0});
+ // clang-format on
+ // The filter matrix is:
+ // | 1 | 2 | 3 |
+ // | 4 | 5 | 6 |
+ // | 7 | 8 | 9 |
+ m.SetFilter({1, 2, 3, 4, 5, 6, 7, 8, 9});
+ // No bias for this test.
+ m.SetBias({0});
+ m.Invoke();
+
+  // Since the dilation rate is 3, this reduces the output from 7x7 (the
+  // undilated VALID size) to 3x3 of all 5s. Specifically:
+ // | 5 | 5 | 5 |
+ // | 5 | 5 | 5 |
+ // | 5 | 5 | 5 |
+ EXPECT_THAT(m.GetDequantizedOutput(),
+ ElementsAreArray({5, 5, 5, 5, 5, 5, 5, 5, 5}));
+}
+
+TEST_P(QuantizedDepthwiseConvolutionOpTest, SimpleDilatedTestPaddingSame) {
+ const int depth = 1;
+ const int image_width = 3;
+ const int image_height = 3;
+ const int image_batch_count = 1;
+ const int filter_size = 2;
+ const int filter_count = 1;
+ const int dilation_factor = 2;
+ QuantizedDepthwiseConvolutionOpModel m(
+ GetRegistration(),
+ {TensorType_UINT8,
+ {image_batch_count, image_height, image_width, depth},
+ 0,
+ 255},
+ {TensorType_UINT8,
+ {depth, filter_size, filter_size, filter_count},
+ 0,
+ 255},
+ {TensorType_UINT8, {}, 0, 255}, Padding_SAME, dilation_factor);
+
+ // The image matrix is:
+ // | 1 | 1 | 1 |
+ // | 1 | 1 | 1 |
+ // | 1 | 1 | 1 |
+ m.SetInput({1, 1, 1, 1, 1, 1, 1, 1, 1});
+ // The filter matrix is:
+ // | 1 | 2 |
+ // | 3 | 4 |
+ m.SetFilter({1, 2, 3, 4});
+ // No bias for this test.
+ m.SetBias({0});
+ m.Invoke();
+
+ // Output:
+ // | 4 | 7 | 3 |
+ // | 6 |10 | 4 |
+ // | 2 | 3 | 1 |
+ EXPECT_THAT(m.GetDequantizedOutput(),
+ ElementsAreArray({4, 7, 3, 6, 10, 4, 2, 3, 1}));
+}
+
+INSTANTIATE_TEST_CASE_P(
+ DepthwiseConvolutionOpTest, DepthwiseConvolutionOpTest,
+ ::testing::ValuesIn(SingleOpTest::GetKernelTags(*kKernelMap)));
+
+INSTANTIATE_TEST_CASE_P(
+ QuantizedDepthwiseConvolutionOpTest, QuantizedDepthwiseConvolutionOpTest,
+ ::testing::ValuesIn(SingleOpTest::GetKernelTags(*kKernelMap)));
+
} // namespace
} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/dequantize.cc b/tensorflow/contrib/lite/kernels/dequantize.cc
index 672b2170e4..59bf64e0af 100644
--- a/tensorflow/contrib/lite/kernels/dequantize.cc
+++ b/tensorflow/contrib/lite/kernels/dequantize.cc
@@ -15,8 +15,8 @@ limitations under the License.
#include <string.h>
#include <vector>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
@@ -36,6 +36,21 @@ struct OpContext {
TfLiteTensor* output;
};
+struct OpData {
+ // This boolean value is only used when the input tensor is constant.
+ bool float_dequantized_weights_initialized;
+};
+
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+ auto* op_data = new OpData();
+ op_data->float_dequantized_weights_initialized = false;
+ return op_data;
+}
+
+void Free(TfLiteContext* context, void* buffer) {
+ delete reinterpret_cast<OpData*>(buffer);
+}
+
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
@@ -45,28 +60,43 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE(context, op_context.input->type == kTfLiteUInt8);
op_context.output->type = kTfLiteFloat32;
+  // If the input tensor is constant, we can persist the dequantized value in
+  // the output tensor. Otherwise we run dequantization on each eval.
+ if (IsConstantTensor(op_context.input)) {
+ op_context.output->allocation_type = kTfLiteArenaRwPersistent;
+ }
return context->ResizeTensor(context, op_context.output,
TfLiteIntArrayCopy(op_context.input->dims));
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
OpContext op_context(context, node);
+ if (IsConstantTensor(op_context.input) &&
+ op_data->float_dequantized_weights_initialized) {
+ return kTfLiteOk;
+ }
- auto zero_point = op_context.input->params.zero_point;
- auto scale = op_context.input->params.scale;
+ tflite::DequantizationParams op_params;
+ op_params.zero_point = op_context.input->params.zero_point;
+ op_params.scale = op_context.input->params.scale;
+ optimized_ops::Dequantize(op_params, GetTensorShape(op_context.input),
+ GetTensorData<uint8_t>(op_context.input),
+ GetTensorShape(op_context.output),
+ GetTensorData<float>(op_context.output));
+
+ if (IsConstantTensor(op_context.input)) {
+ op_data->float_dequantized_weights_initialized = true;
+ }
- optimized_ops::Dequantize(GetTensorData<uint8_t>(op_context.input),
- GetTensorDims(op_context.input), zero_point, scale,
- GetTensorData<float>(op_context.output),
- GetTensorDims(op_context.output));
return kTfLiteOk;
}
} // namespace dequantize
TfLiteRegistration* Register_DEQUANTIZE_OPT() {
- static TfLiteRegistration r = {nullptr, nullptr, dequantize::Prepare,
- dequantize::Eval};
+ static TfLiteRegistration r = {dequantize::Init, dequantize::Free,
+ dequantize::Prepare, dequantize::Eval};
return &r;
}
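
The new DequantizationParams wiring above just carries the tensor's zero point and scale into the standard affine dequantization formula, real_value = scale * (quantized_value - zero_point). A minimal sketch of that formula (not the optimized_ops::Dequantize implementation):

// Affine dequantization: real = scale * (quantized - zero_point).
#include <cstdint>
#include <cstdio>
#include <vector>

std::vector<float> DequantizeSketch(const std::vector<uint8_t>& q,
                                    int zero_point, float scale) {
  std::vector<float> out(q.size());
  for (size_t i = 0; i < q.size(); ++i) {
    out[i] = scale * (static_cast<int>(q[i]) - zero_point);
  }
  return out;
}

int main() {
  // Example: zero_point 128, scale 0.5 maps 128 -> 0.0, 130 -> 1.0, 126 -> -1.0.
  const auto out = DequantizeSketch({128, 130, 126}, 128, 0.5f);
  std::printf("%f %f %f\n", out[0], out[1], out[2]);
  return 0;
}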
diff --git a/tensorflow/contrib/lite/kernels/detection_postprocess.cc b/tensorflow/contrib/lite/kernels/detection_postprocess.cc
new file mode 100644
index 0000000000..e21dc5ced9
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/detection_postprocess.cc
@@ -0,0 +1,591 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <string.h>
+#include <numeric>
+#include <vector>
+#include "flatbuffers/flexbuffers.h" // TF:flatbuffers
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
+#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+
+namespace tflite {
+namespace ops {
+namespace custom {
+namespace detection_postprocess {
+
+// Input tensors
+constexpr int kInputTensorBoxEncodings = 0;
+constexpr int kInputTensorClassPredictions = 1;
+constexpr int kInputTensorAnchors = 2;
+
+// Output tensors
+constexpr int kOutputTensorDetectionBoxes = 0;
+constexpr int kOutputTensorDetectionClasses = 1;
+constexpr int kOutputTensorDetectionScores = 2;
+constexpr int kOutputTensorNumDetections = 3;
+
+constexpr int kNumCoordBox = 4;
+constexpr int kBatchSize = 1;
+
+// The object detection model produces axis-aligned boxes in two formats:
+// BoxCorner represents the minimum corner (xmin, ymin) and
+// the maximum corner (xmax, ymax).
+// CenterSize represents the center (xcenter, ycenter), height and width.
+// BoxCornerEncoding and CenterSizeEncoding are related as follows:
+// ycenter = y / y_scale * anchor.h + anchor.y;
+// xcenter = x / x_scale * anchor.w + anchor.x;
+// half_h = 0.5 * exp(h / h_scale) * anchor.h;
+// half_w = 0.5 * exp(w / w_scale) * anchor.w;
+// ymin = ycenter - half_h
+// ymax = ycenter + half_h
+// xmin = xcenter - half_w
+// xmax = xcenter + half_w
+struct BoxCornerEncoding {
+ float ymin;
+ float xmin;
+ float ymax;
+ float xmax;
+};
+
+struct CenterSizeEncoding {
+ float y;
+ float x;
+ float h;
+ float w;
+};
+// We make sure that the memory allocations are contiguous with static assert.
+static_assert(sizeof(BoxCornerEncoding) == sizeof(float) * kNumCoordBox,
+ "Size of BoxCornerEncoding is 4 float values");
+static_assert(sizeof(CenterSizeEncoding) == sizeof(float) * kNumCoordBox,
+ "Size of CenterSizeEncoding is 4 float values");
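
The relation spelled out in the comment above is what DecodeCenterSizeBoxes implements later in this file. The snippet below just pushes one set of numbers through those formulas; the scale values (10, 10, 5, 5) and the anchor are made up for the illustration.

// Decodes a single all-zero CenterSize encoding against a hypothetical anchor.
#include <cmath>
#include <cstdio>

int main() {
  const float anchor_y = 0.5f, anchor_x = 0.5f, anchor_h = 1.f, anchor_w = 1.f;
  const float y_scale = 10.f, x_scale = 10.f, h_scale = 5.f, w_scale = 5.f;
  const float y = 0.f, x = 0.f, h = 0.f, w = 0.f;  // encoded box

  const float ycenter = y / y_scale * anchor_h + anchor_y;
  const float xcenter = x / x_scale * anchor_w + anchor_x;
  const float half_h = 0.5f * std::exp(h / h_scale) * anchor_h;
  const float half_w = 0.5f * std::exp(w / w_scale) * anchor_w;
  // An all-zero encoding reproduces the anchor box: (0, 0) to (1, 1).
  std::printf("ymin=%f xmin=%f ymax=%f xmax=%f\n", ycenter - half_h,
              xcenter - half_w, ycenter + half_h, xcenter + half_w);
  return 0;
}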
+
+struct OpData {
+ int max_detections;
+ int max_classes_per_detection;
+ float non_max_suppression_score_threshold;
+ float intersection_over_union_threshold;
+ int num_classes;
+ CenterSizeEncoding scale_values;
+ // Indices of Temporary tensors
+ int decoded_boxes_index;
+ int scores_index;
+ int active_candidate_index;
+};
+
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+ auto* op_data = new OpData;
+ const uint8_t* buffer_t = reinterpret_cast<const uint8_t*>(buffer);
+ const flexbuffers::Map& m = flexbuffers::GetRoot(buffer_t, length).AsMap();
+ op_data->max_detections = m["max_detections"].AsInt32();
+ op_data->max_classes_per_detection = m["max_classes_per_detection"].AsInt32();
+ op_data->non_max_suppression_score_threshold =
+ m["nms_score_threshold"].AsFloat();
+ op_data->intersection_over_union_threshold = m["nms_iou_threshold"].AsFloat();
+ op_data->num_classes = m["num_classes"].AsInt32();
+ op_data->scale_values.y = m["y_scale"].AsFloat();
+ op_data->scale_values.x = m["x_scale"].AsFloat();
+ op_data->scale_values.h = m["h_scale"].AsFloat();
+ op_data->scale_values.w = m["w_scale"].AsFloat();
+ context->AddTensors(context, 1, &op_data->decoded_boxes_index);
+ context->AddTensors(context, 1, &op_data->scores_index);
+ context->AddTensors(context, 1, &op_data->active_candidate_index);
+ return op_data;
+}
+
+void Free(TfLiteContext* context, void* buffer) {
+ delete reinterpret_cast<OpData*>(buffer);
+}
+
+// TODO(chowdhery): Add to kernel_util.h
+TfLiteStatus SetTensorSizes(TfLiteContext* context, TfLiteTensor* tensor,
+ std::initializer_list<int> values) {
+ TfLiteIntArray* size = TfLiteIntArrayCreate(values.size());
+ int index = 0;
+ for (int v : values) {
+ size->data[index] = v;
+ ++index;
+ }
+ return context->ResizeTensor(context, tensor, size);
+}
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ auto* op_data = reinterpret_cast<OpData*>(node->user_data);
+ // Inputs: box_encodings, scores, anchors
+ TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
+ const TfLiteTensor* input_box_encodings =
+ GetInput(context, node, kInputTensorBoxEncodings);
+ const TfLiteTensor* input_class_predictions =
+ GetInput(context, node, kInputTensorClassPredictions);
+ const TfLiteTensor* input_anchors =
+ GetInput(context, node, kInputTensorAnchors);
+ TF_LITE_ENSURE_EQ(context, NumDimensions(input_box_encodings), 3);
+ TF_LITE_ENSURE_EQ(context, NumDimensions(input_class_predictions), 3);
+ TF_LITE_ENSURE_EQ(context, NumDimensions(input_anchors), 2);
+ // number of detected boxes
+ const int num_detected_boxes =
+ op_data->max_detections * op_data->max_classes_per_detection;
+
+ // Outputs: detection_boxes, detection_scores, detection_classes,
+ // num_detections
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), 4);
+ // Output Tensor detection_boxes: size is set to (1, num_detected_boxes, 4)
+ TfLiteTensor* detection_boxes =
+ GetOutput(context, node, kOutputTensorDetectionBoxes);
+ detection_boxes->type = kTfLiteFloat32;
+ SetTensorSizes(context, detection_boxes,
+ {kBatchSize, num_detected_boxes, kNumCoordBox});
+
+ // Output Tensor detection_classes: size is set to (1, num_detected_boxes)
+ TfLiteTensor* detection_classes =
+ GetOutput(context, node, kOutputTensorDetectionClasses);
+ detection_classes->type = kTfLiteFloat32;
+ SetTensorSizes(context, detection_classes, {kBatchSize, num_detected_boxes});
+
+ // Output Tensor detection_scores: size is set to (1, num_detected_boxes)
+ TfLiteTensor* detection_scores =
+ GetOutput(context, node, kOutputTensorDetectionScores);
+ detection_scores->type = kTfLiteFloat32;
+ SetTensorSizes(context, detection_scores, {kBatchSize, num_detected_boxes});
+
+ // Output Tensor num_detections: size is set to 1
+ TfLiteTensor* num_detections =
+ GetOutput(context, node, kOutputTensorNumDetections);
+ num_detections->type = kTfLiteFloat32;
+ // TODO (chowdhery): Make it a scalar when available
+ SetTensorSizes(context, num_detections, {1});
+
+ // Temporary tensors
+ TfLiteIntArrayFree(node->temporaries);
+ node->temporaries = TfLiteIntArrayCreate(3);
+ node->temporaries->data[0] = op_data->decoded_boxes_index;
+ node->temporaries->data[1] = op_data->scores_index;
+ node->temporaries->data[2] = op_data->active_candidate_index;
+
+ // decoded_boxes
+ TfLiteTensor* decoded_boxes = &context->tensors[op_data->decoded_boxes_index];
+ decoded_boxes->type = kTfLiteFloat32;
+ decoded_boxes->allocation_type = kTfLiteArenaRw;
+ SetTensorSizes(context, decoded_boxes,
+ {input_box_encodings->dims->data[1], kNumCoordBox});
+
+ // scores
+ TfLiteTensor* scores = &context->tensors[op_data->scores_index];
+ scores->type = kTfLiteFloat32;
+ scores->allocation_type = kTfLiteArenaRw;
+ SetTensorSizes(context, scores,
+ {input_class_predictions->dims->data[1],
+ input_class_predictions->dims->data[2]});
+
+ // active_candidate
+ TfLiteTensor* active_candidate =
+ &context->tensors[op_data->active_candidate_index];
+ active_candidate->type = kTfLiteUInt8;
+ active_candidate->allocation_type = kTfLiteArenaRw;
+ SetTensorSizes(context, active_candidate,
+ {input_box_encodings->dims->data[1]});
+
+ return kTfLiteOk;
+}
+
+class Dequantizer {
+ public:
+ Dequantizer(int zero_point, float scale)
+ : zero_point_(zero_point), scale_(scale) {}
+ float operator()(uint8 x) {
+ return (static_cast<float>(x) - zero_point_) * scale_;
+ }
+
+ private:
+ int zero_point_;
+ float scale_;
+};
+
+void DequantizeBoxEncodings(const TfLiteTensor* input_box_encodings, int idx,
+ float quant_zero_point, float quant_scale,
+ CenterSizeEncoding* box_centersize) {
+ const uint8* boxes =
+ GetTensorData<uint8>(input_box_encodings) + kNumCoordBox * idx;
+ Dequantizer dequantize(quant_zero_point, quant_scale);
+ box_centersize->y = dequantize(boxes[0]);
+ box_centersize->x = dequantize(boxes[1]);
+ box_centersize->h = dequantize(boxes[2]);
+ box_centersize->w = dequantize(boxes[3]);
+}
+
+template <class T>
+T ReInterpretTensor(const TfLiteTensor* tensor) {
+ // TODO (chowdhery): check float
+ const float* tensor_base = tensor->data.f;
+ return reinterpret_cast<T>(tensor_base);
+}
+
+template <class T>
+T ReInterpretTensor(TfLiteTensor* tensor) {
+ // TODO (chowdhery): check float
+ float* tensor_base = tensor->data.f;
+ return reinterpret_cast<T>(tensor_base);
+}
+
+TfLiteStatus DecodeCenterSizeBoxes(TfLiteContext* context, TfLiteNode* node,
+ OpData* op_data) {
+ // Parse input tensor boxencodings
+ const TfLiteTensor* input_box_encodings =
+ GetInput(context, node, kInputTensorBoxEncodings);
+ TF_LITE_ENSURE_EQ(context, input_box_encodings->dims->data[0], kBatchSize);
+ const int num_boxes = input_box_encodings->dims->data[1];
+ TF_LITE_ENSURE_EQ(context, input_box_encodings->dims->data[2], kNumCoordBox);
+ const TfLiteTensor* input_anchors =
+ GetInput(context, node, kInputTensorAnchors);
+
+ // Decode the boxes to get (ymin, xmin, ymax, xmax) based on the anchors
+ CenterSizeEncoding box_centersize;
+ CenterSizeEncoding scale_values = op_data->scale_values;
+ CenterSizeEncoding anchor;
+ for (int idx = 0; idx < num_boxes; ++idx) {
+ switch (input_box_encodings->type) {
+ // Quantized
+ case kTfLiteUInt8:
+ DequantizeBoxEncodings(
+ input_box_encodings, idx,
+ static_cast<float>(input_box_encodings->params.zero_point),
+ static_cast<float>(input_box_encodings->params.scale),
+ &box_centersize);
+ DequantizeBoxEncodings(
+ input_anchors, idx,
+ static_cast<float>(input_anchors->params.zero_point),
+ static_cast<float>(input_anchors->params.scale), &anchor);
+ break;
+ // Float
+ case kTfLiteFloat32:
+ box_centersize = ReInterpretTensor<const CenterSizeEncoding*>(
+ input_box_encodings)[idx];
+ anchor =
+ ReInterpretTensor<const CenterSizeEncoding*>(input_anchors)[idx];
+ break;
+ default:
+ // Unsupported type.
+ return kTfLiteError;
+ }
+
+ float ycenter = box_centersize.y / scale_values.y * anchor.h + anchor.y;
+ float xcenter = box_centersize.x / scale_values.x * anchor.w + anchor.x;
+ float half_h =
+ 0.5f * static_cast<float>(std::exp(box_centersize.h / scale_values.h)) *
+ anchor.h;
+ float half_w =
+ 0.5f * static_cast<float>(std::exp(box_centersize.w / scale_values.w)) *
+ anchor.w;
+ TfLiteTensor* decoded_boxes =
+ &context->tensors[op_data->decoded_boxes_index];
+ auto& box = ReInterpretTensor<BoxCornerEncoding*>(decoded_boxes)[idx];
+ box.ymin = ycenter - half_h;
+ box.xmin = xcenter - half_w;
+ box.ymax = ycenter + half_h;
+ box.xmax = xcenter + half_w;
+ }
+ return kTfLiteOk;
+}
+
+void DecreasingPartialArgSort(const float* values, int num_values,
+ int num_to_sort, int* indices) {
+ std::iota(indices, indices + num_values, 0);
+ std::partial_sort(
+ indices, indices + num_to_sort, indices + num_values,
+ [&values](const int i, const int j) { return values[i] > values[j]; });
+}
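+
DecreasingPartialArgSort above sorts indices rather than values, which is what both NMS helpers below rely on. A small standalone usage example (repeating the same two standard-library calls so it compiles on its own):

// Standalone demo of a descending partial arg-sort, matching the helper above.
#include <algorithm>
#include <cstdio>
#include <numeric>
#include <vector>

int main() {
  const std::vector<float> scores = {0.2f, 0.9f, 0.4f, 0.7f};
  std::vector<int> indices(scores.size());
  std::iota(indices.begin(), indices.end(), 0);
  const int num_to_sort = 2;  // Only the top-2 positions end up sorted.
  std::partial_sort(
      indices.begin(), indices.begin() + num_to_sort, indices.end(),
      [&scores](int i, int j) { return scores[i] > scores[j]; });
  std::printf("top-2 indices: %d %d\n", indices[0], indices[1]);  // 1 3
  return 0;
}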
+
+void SelectDetectionsAboveScoreThreshold(const std::vector<float>& values,
+ const float threshold,
+ std::vector<float>* keep_values,
+ std::vector<int>* keep_indices) {
+ for (int i = 0; i < values.size(); i++) {
+ if (values[i] >= threshold) {
+ keep_values->emplace_back(values[i]);
+ keep_indices->emplace_back(i);
+ }
+ }
+}
+
+bool ValidateBoxes(const TfLiteTensor* decoded_boxes, const int num_boxes) {
+ for (int i = 0; i < num_boxes; ++i) {
+ // ymax>=ymin, xmax>=xmin
+ auto& box = ReInterpretTensor<const BoxCornerEncoding*>(decoded_boxes)[i];
+ if (box.ymin >= box.ymax || box.xmin >= box.xmax) {
+ return false;
+ }
+ }
+ return true;
+}
+
+float ComputeIntersectionOverUnion(const TfLiteTensor* decoded_boxes,
+ const int i, const int j) {
+ auto& box_i = ReInterpretTensor<const BoxCornerEncoding*>(decoded_boxes)[i];
+ auto& box_j = ReInterpretTensor<const BoxCornerEncoding*>(decoded_boxes)[j];
+ const float area_i = (box_i.ymax - box_i.ymin) * (box_i.xmax - box_i.xmin);
+ const float area_j = (box_j.ymax - box_j.ymin) * (box_j.xmax - box_j.xmin);
+ if (area_i <= 0 || area_j <= 0) return 0.0;
+ const float intersection_ymin = std::max<float>(box_i.ymin, box_j.ymin);
+ const float intersection_xmin = std::max<float>(box_i.xmin, box_j.xmin);
+ const float intersection_ymax = std::min<float>(box_i.ymax, box_j.ymax);
+ const float intersection_xmax = std::min<float>(box_i.xmax, box_j.xmax);
+ const float intersection_area =
+ std::max<float>(intersection_ymax - intersection_ymin, 0.0) *
+ std::max<float>(intersection_xmax - intersection_xmin, 0.0);
+ return intersection_area / (area_i + area_j - intersection_area);
+}
+
+// NonMaxSuppressionSingleClassHelper() performs an O(n^2) pairwise comparison
+// between boxes. It assumes all boxes are candidates at the start and sorts
+// them by score. If a lower-scoring box overlaps too much with a
+// higher-scoring box, the lower-scoring box is discarded.
+TfLiteStatus NonMaxSuppressionSingleClassHelper(
+ TfLiteContext* context, TfLiteNode* node, OpData* op_data,
+ const std::vector<float>& scores, std::vector<int>* selected) {
+ const TfLiteTensor* input_box_encodings =
+ GetInput(context, node, kInputTensorBoxEncodings);
+ const TfLiteTensor* decoded_boxes =
+ &context->tensors[op_data->decoded_boxes_index];
+ const int num_boxes = input_box_encodings->dims->data[1];
+ const int max_detections = op_data->max_detections;
+ const float non_max_suppression_score_threshold =
+ op_data->non_max_suppression_score_threshold;
+ const float intersection_over_union_threshold =
+ op_data->intersection_over_union_threshold;
+ // Maximum detections should be positive.
+ TF_LITE_ENSURE(context, (max_detections >= 0));
+ // intersection_over_union_threshold should be positive
+ // and should be less than 1.
+ TF_LITE_ENSURE(context, (intersection_over_union_threshold > 0.0f) &&
+ (intersection_over_union_threshold <= 1.0f));
+ // Validate boxes
+ TF_LITE_ENSURE(context, ValidateBoxes(decoded_boxes, num_boxes));
+
+ // threshold scores
+ std::vector<int> keep_indices;
+ // TODO (chowdhery): Remove the dynamic allocation and replace it
+ // with temporaries, esp for std::vector<float>
+ std::vector<float> keep_scores;
+ SelectDetectionsAboveScoreThreshold(
+ scores, non_max_suppression_score_threshold, &keep_scores, &keep_indices);
+
+ int num_scores_kept = keep_scores.size();
+ std::vector<int> sorted_indices;
+ sorted_indices.resize(num_scores_kept);
+ DecreasingPartialArgSort(keep_scores.data(), num_scores_kept, num_scores_kept,
+ sorted_indices.data());
+
+ const int num_boxes_kept = num_scores_kept;
+ const int output_size = std::min(num_boxes_kept, max_detections);
+ selected->clear();
+ TfLiteTensor* active_candidate =
+ &context->tensors[op_data->active_candidate_index];
+ TF_LITE_ENSURE(context, (active_candidate->dims->data[0]) == num_boxes);
+ int num_active_candidate = num_boxes_kept;
+ uint8_t* active_box_candidate = (active_candidate->data.uint8);
+ for (int row = 0; row < num_boxes_kept; row++) {
+ active_box_candidate[row] = 1;
+ }
+
+ for (int i = 0; i < num_boxes_kept; ++i) {
+ if (num_active_candidate == 0 || selected->size() >= output_size) break;
+ if (active_box_candidate[i] == 1) {
+ selected->push_back(keep_indices[sorted_indices[i]]);
+ active_box_candidate[i] = 0;
+ num_active_candidate--;
+ } else {
+ continue;
+ }
+ for (int j = i + 1; j < num_boxes_kept; ++j) {
+ if (active_box_candidate[j] == 1) {
+ float intersection_over_union = ComputeIntersectionOverUnion(
+ decoded_boxes, keep_indices[sorted_indices[i]],
+ keep_indices[sorted_indices[j]]);
+
+ if (intersection_over_union > intersection_over_union_threshold) {
+ active_box_candidate[j] = 0;
+ num_active_candidate--;
+ }
+ }
+ }
+ }
+ return kTfLiteOk;
+}
+
+// This function implements a fast version of Non Maximal Suppression for
+// multiple classes where
+// 1) we keep the top-k scores for each anchor, and
+// 2) during NMS, each anchor only uses the highest class score for sorting.
+// As a result, the worst-case runtime of this version is O(N^2) instead of
+// O(KN^2), where N is the number of anchors and K is the number of classes.
+TfLiteStatus NonMaxSuppressionMultiClassFastHelper(TfLiteContext* context,
+ TfLiteNode* node,
+ OpData* op_data,
+ const float* scores) {
+ const TfLiteTensor* input_box_encodings =
+ GetInput(context, node, kInputTensorBoxEncodings);
+ const TfLiteTensor* decoded_boxes =
+ &context->tensors[op_data->decoded_boxes_index];
+
+ TfLiteTensor* detection_boxes =
+ GetOutput(context, node, kOutputTensorDetectionBoxes);
+ TfLiteTensor* detection_classes =
+ GetOutput(context, node, kOutputTensorDetectionClasses);
+ TfLiteTensor* detection_scores =
+ GetOutput(context, node, kOutputTensorDetectionScores);
+ TfLiteTensor* num_detections =
+ GetOutput(context, node, kOutputTensorNumDetections);
+
+ const int num_boxes = input_box_encodings->dims->data[1];
+ const int num_classes = op_data->num_classes;
+ const int max_categories_per_anchor = op_data->max_classes_per_detection;
+ // The row index offset is 1 if background class is included and 0 otherwise.
+ const int label_offset = 1;
+ TF_LITE_ENSURE(context, (label_offset != -1));
+ TF_LITE_ENSURE(context, (max_categories_per_anchor > 0));
+ const int num_classes_with_background = num_classes + label_offset;
+ const int num_categories_per_anchor =
+ std::min(max_categories_per_anchor, num_classes);
+ std::vector<float> max_scores;
+ max_scores.resize(num_boxes);
+ std::vector<int> sorted_class_indices;
+ sorted_class_indices.resize(num_boxes * num_classes);
+ for (int row = 0; row < num_boxes; row++) {
+ const float* box_scores =
+ scores + row * num_classes_with_background + label_offset;
+ int* class_indices = sorted_class_indices.data() + row * num_classes;
+ DecreasingPartialArgSort(box_scores, num_classes, num_categories_per_anchor,
+ class_indices);
+ max_scores[row] = box_scores[class_indices[0]];
+ }
+ // Perform non-maximal suppression on max scores
+ std::vector<int> selected;
+ NonMaxSuppressionSingleClassHelper(context, node, op_data, max_scores,
+ &selected);
+  // Fill the output tensors.
+ int output_box_index = 0;
+ for (const auto& selected_index : selected) {
+ const float* box_scores =
+ scores + selected_index * num_classes_with_background + label_offset;
+ const int* class_indices =
+ sorted_class_indices.data() + selected_index * num_classes;
+
+ for (int col = 0; col < num_categories_per_anchor; ++col) {
+ int box_offset = num_categories_per_anchor * output_box_index + col;
+ // detection_boxes
+ ReInterpretTensor<BoxCornerEncoding*>(detection_boxes)[box_offset] =
+ ReInterpretTensor<const BoxCornerEncoding*>(
+ decoded_boxes)[selected_index];
+ // detection_classes
+ detection_classes->data.f[box_offset] = class_indices[col];
+ // detection_scores
+ detection_scores->data.f[box_offset] = box_scores[class_indices[col]];
+ output_box_index++;
+ }
+ }
+ num_detections->data.f[0] = output_box_index;
+ return kTfLiteOk;
+}
+
+void DequantizeClassPredictions(const TfLiteTensor* input_class_predictions,
+ const int num_boxes,
+ const int num_classes_with_background,
+ const TfLiteTensor* scores) {
+ float quant_zero_point =
+ static_cast<float>(input_class_predictions->params.zero_point);
+ float quant_scale = static_cast<float>(input_class_predictions->params.scale);
+ Dequantizer dequantize(quant_zero_point, quant_scale);
+ const uint8* scores_quant = GetTensorData<uint8>(input_class_predictions);
+ for (int idx = 0; idx < num_boxes * num_classes_with_background; ++idx) {
+ scores->data.f[idx] = dequantize(scores_quant[idx]);
+ }
+}
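DequantizeClassPredictions relies on the Dequantizer helper defined earlier in this file; assuming it applies the usual affine uint8 mapping real = scale * (quantized - zero_point), the class scores in the quantized test below (range [0, 1], so scale ≈ 1/255 and zero_point = 0) round-trip as in this sketch (not the kernel code):

#include <cmath>
#include <cstdint>
#include <cstdio>

// Assumed affine mapping for uint8 tensors: real = scale * (q - zero_point).
float Dequantize(uint8_t q, float scale, int32_t zero_point) {
  return scale * (static_cast<int32_t>(q) - zero_point);
}

int main() {
  const float scale = 1.0f / 255.0f;  // [0, 1] range as in the quantized test
  const uint8_t q = static_cast<uint8_t>(std::round(0.95f / scale));  // 242
  printf("%f\n", Dequantize(q, scale, /*zero_point=*/0));  // ~0.949
  return 0;
}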
+
+TfLiteStatus NonMaxSuppressionMultiClass(TfLiteContext* context,
+ TfLiteNode* node, OpData* op_data) {
+ // Get the input tensors
+ const TfLiteTensor* input_box_encodings =
+ GetInput(context, node, kInputTensorBoxEncodings);
+ const TfLiteTensor* input_class_predictions =
+ GetInput(context, node, kInputTensorClassPredictions);
+ const int num_boxes = input_box_encodings->dims->data[1];
+ const int num_classes = op_data->num_classes;
+ TF_LITE_ENSURE_EQ(context, input_class_predictions->dims->data[0],
+ kBatchSize);
+ TF_LITE_ENSURE_EQ(context, input_class_predictions->dims->data[1], num_boxes);
+ const int num_classes_with_background =
+ input_class_predictions->dims->data[2];
+
+ TF_LITE_ENSURE(context, (num_classes_with_background == num_classes + 1));
+
+ const TfLiteTensor* scores;
+ switch (input_class_predictions->type) {
+ case kTfLiteUInt8: {
+ TfLiteTensor* temporary_scores = &context->tensors[op_data->scores_index];
+ DequantizeClassPredictions(input_class_predictions, num_boxes,
+ num_classes_with_background, temporary_scores);
+ scores = temporary_scores;
+ } break;
+ case kTfLiteFloat32:
+ scores = input_class_predictions;
+ break;
+ default:
+ // Unsupported type.
+ return kTfLiteError;
+ }
+ NonMaxSuppressionMultiClassFastHelper(context, node, op_data,
+ GetTensorData<float>(scores));
+ return kTfLiteOk;
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ // TODO(chowdhery): Generalize for any batch size
+ TF_LITE_ENSURE(context, (kBatchSize == 1));
+ auto* op_data = reinterpret_cast<OpData*>(node->user_data);
+ // These two functions correspond to two blocks in the Object Detection model.
+  // In the future, we would like to split the custom op into two ops, which is
+  // currently not feasible because we want to take quantized inputs while
+  // doing all calculations in float. Mixed quantized/float calculations are
+  // currently not supported in TFLite.
+
+ // This fills in temporary decoded_boxes
+ // by transforming input_box_encodings and input_anchors from
+ // CenterSizeEncodings to BoxCornerEncoding
+ DecodeCenterSizeBoxes(context, node, op_data);
+  // This fills in the output tensors
+  // by choosing an effective set of decoded boxes
+  // based on Non Maximal Suppression, i.e. selecting the
+  // highest-scoring non-overlapping boxes.
+ NonMaxSuppressionMultiClass(context, node, op_data);
+
+ return kTfLiteOk;
+}
+} // namespace detection_postprocess
+
+TfLiteRegistration* Register_DETECTION_POSTPROCESS() {
+ static TfLiteRegistration r = {detection_postprocess::Init,
+ detection_postprocess::Free,
+ detection_postprocess::Prepare,
+ detection_postprocess::Eval};
+ return &r;
+}
+
+} // namespace custom
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/detection_postprocess_test.cc b/tensorflow/contrib/lite/kernels/detection_postprocess_test.cc
new file mode 100644
index 0000000000..1e8caebd82
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/detection_postprocess_test.cc
@@ -0,0 +1,235 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <functional>
+#include <memory>
+#include <vector>
+
+#include <gtest/gtest.h>
+#include "flatbuffers/flexbuffers.h" // TF:flatbuffers
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace ops {
+namespace custom {
+
+TfLiteRegistration* Register_DETECTION_POSTPROCESS();
+
+namespace {
+
+using ::testing::ElementsAre;
+using ::testing::ElementsAreArray;
+
+class BaseDetectionPostprocessOpModel : public SingleOpModel {
+ public:
+ BaseDetectionPostprocessOpModel(const TensorData& input1,
+ const TensorData& input2,
+ const TensorData& input3,
+ const TensorData& output1,
+ const TensorData& output2,
+ const TensorData& output3,
+ const TensorData& output4) {
+ input1_ = AddInput(input1);
+ input2_ = AddInput(input2);
+ input3_ = AddInput(input3);
+ output1_ = AddOutput(output1);
+ output2_ = AddOutput(output2);
+ output3_ = AddOutput(output3);
+ output4_ = AddOutput(output4);
+
+ flexbuffers::Builder fbb;
+ fbb.Map([&]() {
+ fbb.Int("max_detections", 3);
+ fbb.Int("max_classes_per_detection", 1);
+ fbb.Float("nms_score_threshold", 0.0);
+ fbb.Float("nms_iou_threshold", 0.5);
+ fbb.Int("num_classes", 2);
+ fbb.Float("y_scale", 10.0);
+ fbb.Float("x_scale", 10.0);
+ fbb.Float("h_scale", 5.0);
+ fbb.Float("w_scale", 5.0);
+ });
+ fbb.Finish();
+ SetCustomOp("TFLite_Detection_PostProcess", fbb.GetBuffer(),
+ Register_DETECTION_POSTPROCESS);
+ BuildInterpreter({GetShape(input1_), GetShape(input2_), GetShape(input3_)});
+ }
+
+ int input1() { return input1_; }
+ int input2() { return input2_; }
+ int input3() { return input3_; }
+
+ template <class T>
+ void SetInput1(std::initializer_list<T> data) {
+ PopulateTensor<T>(input1_, data);
+ }
+
+ template <class T>
+ void SetInput2(std::initializer_list<T> data) {
+ PopulateTensor<T>(input2_, data);
+ }
+
+ template <class T>
+ void SetInput3(std::initializer_list<T> data) {
+ PopulateTensor<T>(input3_, data);
+ }
+
+ template <class T>
+ std::vector<T> GetOutput1() {
+ return ExtractVector<T>(output1_);
+ }
+
+ template <class T>
+ std::vector<T> GetOutput2() {
+ return ExtractVector<T>(output2_);
+ }
+
+ template <class T>
+ std::vector<T> GetOutput3() {
+ return ExtractVector<T>(output3_);
+ }
+
+ template <class T>
+ std::vector<T> GetOutput4() {
+ return ExtractVector<T>(output4_);
+ }
+
+ std::vector<int> GetOutputShape1() { return GetTensorShape(output1_); }
+ std::vector<int> GetOutputShape2() { return GetTensorShape(output2_); }
+ std::vector<int> GetOutputShape3() { return GetTensorShape(output3_); }
+ std::vector<int> GetOutputShape4() { return GetTensorShape(output4_); }
+
+ protected:
+ int input1_;
+ int input2_;
+ int input3_;
+ int output1_;
+ int output2_;
+ int output3_;
+ int output4_;
+};
+
+TEST(DetectionPostprocessOpTest, FloatTest) {
+ BaseDetectionPostprocessOpModel m(
+ {TensorType_FLOAT32, {1, 6, 4}}, {TensorType_FLOAT32, {1, 6, 3}},
+ {TensorType_FLOAT32, {6, 4}}, {TensorType_FLOAT32, {}},
+ {TensorType_FLOAT32, {}}, {TensorType_FLOAT32, {}},
+ {TensorType_FLOAT32, {}});
+
+ // six boxes in center-size encoding
+ m.SetInput1<float>({0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0,
+ 0.0, -1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
+ 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0});
+ // class scores - two classes with background
+ m.SetInput2<float>({0., .9, .8, 0., .75, .72, 0., .6, .5, 0., .93, .95, 0.,
+ .5, .4, 0., .3, .2});
+ // six anchors in center-size encoding
+ m.SetInput3<float>({0.5, 0.5, 1.0, 1.0, 0.5, 0.5, 1.0, 1.0,
+ 0.5, 0.5, 1.0, 1.0, 0.5, 10.5, 1.0, 1.0,
+ 0.5, 10.5, 1.0, 1.0, 0.5, 100.5, 1.0, 1.0});
+ // Same boxes in box-corner encoding:
+ // { 0.0, 0.0, 1.0, 1.0,
+ // 0.0, 0.1, 1.0, 1.1,
+ // 0.0, -0.1, 1.0, 0.9,
+ // 0.0, 10.0, 1.0, 11.0,
+ // 0.0, 10.1, 1.0, 11.1,
+ // 0.0, 100.0, 1.0, 101.0}
+ m.Invoke();
+ // detection_boxes
+ // in center-size
+ std::vector<int> output_shape1 = m.GetOutputShape1();
+ EXPECT_THAT(output_shape1, ElementsAre(1, 3, 4));
+ EXPECT_THAT(
+ m.GetOutput1<float>(),
+ ElementsAreArray(ArrayFloatNear(
+ {0.0, 10.0, 1.0, 11.0, 0.0, 0.0, 1.0, 1.0, 0.0, 100.0, 1.0, 101.0},
+ 1e-1)));
+ // detection_classes
+ std::vector<int> output_shape2 = m.GetOutputShape2();
+ EXPECT_THAT(output_shape2, ElementsAre(1, 3));
+ EXPECT_THAT(m.GetOutput2<float>(),
+ ElementsAreArray(ArrayFloatNear({1, 0, 0}, 1e-1)));
+ // detection_scores
+ std::vector<int> output_shape3 = m.GetOutputShape3();
+ EXPECT_THAT(output_shape3, ElementsAre(1, 3));
+ EXPECT_THAT(m.GetOutput3<float>(),
+ ElementsAreArray(ArrayFloatNear({0.95, 0.9, 0.3}, 1e-1)));
+ // num_detections
+ std::vector<int> output_shape4 = m.GetOutputShape4();
+ EXPECT_THAT(output_shape4, ElementsAre(1));
+ EXPECT_THAT(m.GetOutput4<float>(),
+ ElementsAreArray(ArrayFloatNear({3.0}, 1e-1)));
+}
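The expected outputs above can be reproduced by hand. Assuming DecodeCenterSizeBoxes (added earlier in this diff) applies the standard SSD center-size decoding, ycenter = y / y_scale * anchor_h + anchor_ycenter and h = exp(h_enc / h_scale) * anchor_h (and likewise for x and w), box 4 decodes as in this sketch:

#include <cmath>
#include <cstdio>

int main() {
  // Box encoding 4 is (0, 1, 0, 0) and anchor 4 is (0.5, 10.5, 1, 1), with
  // y_scale = x_scale = 10 and h_scale = w_scale = 5 as set in the test model.
  const float ycenter = 0.f / 10.f * 1.f + 0.5f;   // 0.5
  const float xcenter = 1.f / 10.f * 1.f + 10.5f;  // 10.6
  const float h = std::exp(0.f / 5.f) * 1.f;       // 1.0
  const float w = std::exp(0.f / 5.f) * 1.f;       // 1.0
  printf("[%g, %g, %g, %g]\n", ycenter - h / 2, xcenter - w / 2,
         ycenter + h / 2, xcenter + w / 2);        // [0, 10.1, 1, 11.1]
  return 0;
}

From there the fast NMS pass sorts the per-anchor max scores {0.9, 0.75, 0.6, 0.95, 0.5, 0.3}, selects box 3 (0.95), suppresses box 4 (IoU ≈ 0.82 with box 3), selects box 0 (0.9), suppresses boxes 1 and 2, and finally selects box 5 (0.3), which yields exactly the three detections, scores and classes asserted above.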
+
+TEST(DetectionPostprocessOpTest, QuantizedTest) {
+ BaseDetectionPostprocessOpModel m(
+ {TensorType_UINT8, {1, 6, 4}, -1.0, 1.0},
+ {TensorType_UINT8, {1, 6, 3}, 0.0, 1.0},
+ {TensorType_UINT8, {6, 4}, 0.0, 100.5}, {TensorType_FLOAT32, {}},
+ {TensorType_FLOAT32, {}}, {TensorType_FLOAT32, {}},
+ {TensorType_FLOAT32, {}});
+ // six boxes in center-size encoding
+ std::vector<std::initializer_list<float>> inputs1 = {
+ {0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, -1.0, 0.0, 0.0,
+ 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}};
+ m.QuantizeAndPopulate<uint8_t>(m.input1(), inputs1[0]);
+ // class scores - two classes with background
+ std::vector<std::initializer_list<float>> inputs2 = {
+ {0., .9, .8, 0., .75, .72, 0., .6, .5, 0., .93, .95, 0., .5, .4, 0., .3,
+ .2}};
+ m.QuantizeAndPopulate<uint8_t>(m.input2(), inputs2[0]);
+ // six anchors in center-size encoding
+ std::vector<std::initializer_list<float>> inputs3 = {
+ {0.5, 0.5, 1.0, 1.0, 0.5, 0.5, 1.0, 1.0, 0.5, 0.5, 1.0, 1.0,
+ 0.5, 10.5, 1.0, 1.0, 0.5, 10.5, 1.0, 1.0, 0.5, 100.5, 1.0, 1.0}};
+ m.QuantizeAndPopulate<uint8_t>(m.input3(), inputs3[0]);
+ m.Invoke();
+ // detection_boxes
+ // in center-size
+ std::vector<int> output_shape1 = m.GetOutputShape1();
+ EXPECT_THAT(output_shape1, ElementsAre(1, 3, 4));
+ EXPECT_THAT(
+ m.GetOutput1<float>(),
+ ElementsAreArray(ArrayFloatNear(
+ {0.0, 10.0, 1.0, 11.0, 0.0, 0.0, 1.0, 1.0, 0.0, 100.0, 1.0, 101.0},
+ 3e-1)));
+ // detection_classes
+ std::vector<int> output_shape2 = m.GetOutputShape2();
+ EXPECT_THAT(output_shape2, ElementsAre(1, 3));
+ EXPECT_THAT(m.GetOutput2<float>(),
+ ElementsAreArray(ArrayFloatNear({1, 0, 0}, 1e-1)));
+ // detection_scores
+ std::vector<int> output_shape3 = m.GetOutputShape3();
+ EXPECT_THAT(output_shape3, ElementsAre(1, 3));
+ EXPECT_THAT(m.GetOutput3<float>(),
+ ElementsAreArray(ArrayFloatNear({0.95, 0.9, 0.3}, 1e-1)));
+ // num_detections
+ std::vector<int> output_shape4 = m.GetOutputShape4();
+ EXPECT_THAT(output_shape4, ElementsAre(1));
+ EXPECT_THAT(m.GetOutput4<float>(),
+ ElementsAreArray(ArrayFloatNear({3.0}, 1e-1)));
+}
+} // namespace
+} // namespace custom
+} // namespace ops
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ ::tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/kernels/div.cc b/tensorflow/contrib/lite/kernels/div.cc
index e52e4fe535..8d4bb51006 100644
--- a/tensorflow/contrib/lite/kernels/div.cc
+++ b/tensorflow/contrib/lite/kernels/div.cc
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
@@ -78,29 +78,47 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
template <KernelType kernel_type>
-void EvalFloat(TfLiteContext* context, TfLiteNode* node,
- TfLiteDivParams* params, const OpData* data,
- const TfLiteTensor* input1, const TfLiteTensor* input2,
- TfLiteTensor* output) {
- float output_activation_min, output_activation_max;
- CalculateActivationRangeFloat(params->activation, &output_activation_min,
- &output_activation_max);
-#define TF_LITE_DIV(type, opname) \
- type::opname(GetTensorData<float>(input1), GetTensorDims(input1), \
- GetTensorData<float>(input2), GetTensorDims(input2), \
- output_activation_min, output_activation_max, \
- GetTensorData<float>(output), GetTensorDims(output))
- if (kernel_type == kReference) {
- if (data->requires_broadcast) {
- TF_LITE_DIV(reference_ops, BroadcastDiv);
+void EvalDiv(TfLiteContext* context, TfLiteNode* node, TfLiteDivParams* params,
+ const OpData* data, const TfLiteTensor* input1,
+ const TfLiteTensor* input2, TfLiteTensor* output) {
+#define TF_LITE_DIV(type, opname, data_type) \
+ tflite::ArithmeticParams op_params; \
+ data_type output_activation_min, output_activation_max; \
+ CalculateActivationRange(params->activation, &output_activation_min, \
+ &output_activation_max); \
+ SetActivationParams(output_activation_min, output_activation_max, \
+ &op_params); \
+ type::opname(op_params, GetTensorShape(input1), \
+ GetTensorData<data_type>(input1), GetTensorShape(input2), \
+ GetTensorData<data_type>(input2), GetTensorShape(output), \
+ GetTensorData<data_type>(output))
+ if (output->type == kTfLiteInt32) {
+ if (kernel_type == kReference) {
+ if (data->requires_broadcast) {
+ TF_LITE_DIV(reference_ops, BroadcastDiv4DSlow, int32_t);
+ } else {
+ TF_LITE_DIV(reference_ops, Div, int32_t);
+ }
} else {
- TF_LITE_DIV(reference_ops, Div);
+ if (data->requires_broadcast) {
+ TF_LITE_DIV(optimized_ops, BroadcastDiv4DSlow, int32_t);
+ } else {
+ TF_LITE_DIV(optimized_ops, Div, int32_t);
+ }
}
- } else {
- if (data->requires_broadcast) {
- TF_LITE_DIV(optimized_ops, BroadcastDiv);
+ } else if (output->type == kTfLiteFloat32) {
+ if (kernel_type == kReference) {
+ if (data->requires_broadcast) {
+ TF_LITE_DIV(reference_ops, BroadcastDiv4DSlow, float);
+ } else {
+ TF_LITE_DIV(reference_ops, Div, float);
+ }
} else {
- TF_LITE_DIV(optimized_ops, Div);
+ if (data->requires_broadcast) {
+ TF_LITE_DIV(optimized_ops, BroadcastDiv4DSlow, float);
+ } else {
+ TF_LITE_DIV(optimized_ops, Div, float);
+ }
}
}
#undef TF_LITE_DIV
@@ -115,11 +133,13 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
- if (output->type == kTfLiteFloat32) {
- EvalFloat<kernel_type>(context, node, params, data, input1, input2, output);
+ if (output->type == kTfLiteFloat32 || output->type == kTfLiteInt32) {
+ EvalDiv<kernel_type>(context, node, params, data, input1, input2, output);
} else {
- context->ReportError(context,
- "Div only supports FLOAT32 and quantized UINT8 now.");
+ context->ReportError(
+ context,
+ "Div only supports FLOAT32, INT32 and quantized UINT8 now, got %d.",
+ output->type);
return kTfLiteError;
}
diff --git a/tensorflow/contrib/lite/kernels/div_test.cc b/tensorflow/contrib/lite/kernels/div_test.cc
index 276b8289fb..97aa2fe04e 100644
--- a/tensorflow/contrib/lite/kernels/div_test.cc
+++ b/tensorflow/contrib/lite/kernels/div_test.cc
@@ -52,6 +52,13 @@ class FloatDivOpModel : public BaseDivOpModel {
std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
};
+class IntegerDivOpModel : public BaseDivOpModel {
+ public:
+ using BaseDivOpModel::BaseDivOpModel;
+
+ std::vector<int32_t> GetOutput() { return ExtractVector<int32_t>(output_); }
+};
+
TEST(FloatDivOpTest, NoActivation) {
FloatDivOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}},
{TensorType_FLOAT32, {1, 2, 2, 1}},
@@ -75,7 +82,7 @@ TEST(FloatDivOpTest, ActivationRELU_N1_TO_1) {
}
TEST(FloatDivOpTest, VariousInputShapes) {
- std::vector<std::initializer_list<int>> test_shapes = {
+ std::vector<std::vector<int>> test_shapes = {
{6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}};
for (int i = 0; i < test_shapes.size(); ++i) {
FloatDivOpModel m({TensorType_FLOAT32, test_shapes[i]},
@@ -92,7 +99,7 @@ TEST(FloatDivOpTest, VariousInputShapes) {
}
TEST(FloatDivOpTest, WithBroadcast) {
- std::vector<std::initializer_list<int>> test_shapes = {
+ std::vector<std::vector<int>> test_shapes = {
{6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}};
for (int i = 0; i < test_shapes.size(); ++i) {
FloatDivOpModel m({TensorType_FLOAT32, test_shapes[i]},
@@ -108,6 +115,56 @@ TEST(FloatDivOpTest, WithBroadcast) {
}
}
+TEST(IntegerDivOpTest, NoActivation) {
+ IntegerDivOpModel m({TensorType_INT32, {1, 2, 2, 1}},
+ {TensorType_INT32, {1, 2, 2, 1}}, {TensorType_INT32, {}},
+ ActivationFunctionType_NONE);
+ m.PopulateTensor<int32_t>(m.input1(), {-2, 2, -15, 8});
+ m.PopulateTensor<int32_t>(m.input2(), {5, -2, -3, 5});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, -1, 5, 1}));
+}
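The expected values above follow C++ integer division, which truncates toward zero (so -2 / 5 is 0, not the -1 that floor division would give); a minimal check of the same element-wise quotients:

#include <cstdio>

int main() {
  printf("%d %d %d %d\n", -2 / 5, 2 / -2, -15 / -3, 8 / 5);  // 0 -1 5 1
  return 0;
}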
+
+TEST(IntegerDivOpTest, ActivationRELU_N1_TO_1) {
+ IntegerDivOpModel m({TensorType_INT32, {1, 2, 2, 1}},
+ {TensorType_INT32, {1, 2, 2, 1}}, {TensorType_INT32, {}},
+ ActivationFunctionType_RELU_N1_TO_1);
+ m.PopulateTensor<int32_t>(m.input1(), {-2, 2, -12, 8});
+ m.PopulateTensor<int32_t>(m.input2(), {1, 2, -15, 5});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({-1, 1, 0, 1}));
+}
+
+TEST(IntegerDivOpTest, VariousInputShapes) {
+ std::vector<std::vector<int>> test_shapes = {
+ {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}};
+ for (int i = 0; i < test_shapes.size(); ++i) {
+ IntegerDivOpModel m({TensorType_INT32, test_shapes[i]},
+ {TensorType_INT32, test_shapes[i]},
+ {TensorType_INT32, {}}, ActivationFunctionType_NONE);
+ m.PopulateTensor<int32_t>(m.input1(), {-20, 2, 3, 8, 11, -20});
+ m.PopulateTensor<int32_t>(m.input2(), {1, 2, 6, 5, -11, -1});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({-20, 1, 0, 1, -1, 20}))
+ << "With shape number " << i;
+ }
+}
+
+TEST(IntegerDivOpTest, WithBroadcast) {
+ std::vector<std::vector<int>> test_shapes = {
+ {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}};
+ for (int i = 0; i < test_shapes.size(); ++i) {
+ IntegerDivOpModel m({TensorType_INT32, test_shapes[i]},
+ {TensorType_INT32, {}}, // always a scalar
+ {TensorType_INT32, {}}, ActivationFunctionType_NONE);
+ m.PopulateTensor<int32_t>(m.input1(), {-20, 21, 7, 8, 11, -123});
+ m.PopulateTensor<int32_t>(m.input2(), {3});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({-6, 7, 2, 2, 3, -41}))
+ << "With shape number " << i;
+ }
+}
+
} // namespace
} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/eigen_support.cc b/tensorflow/contrib/lite/kernels/eigen_support.cc
index f1fdb42624..e542ad0765 100644
--- a/tensorflow/contrib/lite/kernels/eigen_support.cc
+++ b/tensorflow/contrib/lite/kernels/eigen_support.cc
@@ -14,31 +14,100 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/lite/kernels/eigen_support.h"
-#include "third_party/eigen3/Eigen/Core"
+#include <utility>
+
+#include "tensorflow/contrib/lite/arena_planner.h"
+#include "tensorflow/contrib/lite/kernels/internal/optimized/eigen_spatial_convolutions.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"
namespace tflite {
namespace eigen_support {
+namespace {
+
+#ifndef EIGEN_DONT_ALIGN
+// Eigen may require buffers to be aligned to 16, 32 or 64 bytes depending on
+// the hardware architecture and build configuration.
+// If the static assertion fails, try increasing `kDefaultTensorAlignment` in
+// `arena_planner.h` to 32 or 64.
+static_assert(
+ kDefaultTensorAlignment % EIGEN_MAX_ALIGN_BYTES == 0,
+ "kDefaultArenaAlignment doesn't comply with Eigen alignment requirement.");
+#endif // EIGEN_DONT_ALIGN
+
+// We have a single global threadpool for all convolution operations. This means
+// that inferences started from different threads may block each other, but
+// since the operations are ultimately bound by the same pool of CPU cores
+// anyway, this shouldn't affect overall performance.
+class EigenThreadPoolWrapper : public Eigen::ThreadPoolInterface {
+ public:
+ // Takes ownership of 'pool'
+ explicit EigenThreadPoolWrapper(Eigen::ThreadPool* pool) : pool_(pool) {}
+ ~EigenThreadPoolWrapper() override {}
-struct RefCountedEigenContext {
+ void Schedule(std::function<void()> fn) override {
+ pool_->Schedule(std::move(fn));
+ }
+ int NumThreads() const override { return pool_->NumThreads(); }
+ int CurrentThreadId() const override { return pool_->CurrentThreadId(); }
+
+ private:
+ std::unique_ptr<Eigen::ThreadPool> pool_;
+};
+
+struct RefCountedEigenContext : public TfLiteExternalContext {
+ std::unique_ptr<Eigen::ThreadPoolInterface> thread_pool_wrapper;
+ std::unique_ptr<Eigen::ThreadPoolDevice> device;
int num_references = 0;
};
+RefCountedEigenContext* GetEigenContext(TfLiteContext* context) {
+ return reinterpret_cast<RefCountedEigenContext*>(
+ context->GetExternalContext(context, kTfLiteEigenContext));
+}
+
+void InitDevice(TfLiteContext* context, RefCountedEigenContext* ptr) {
+ int num_threads = 4;
+ if (context->recommended_num_threads != -1) {
+ num_threads = context->recommended_num_threads;
+ }
+ ptr->device.reset(); // destroy before we invalidate the thread pool
+ ptr->thread_pool_wrapper.reset(
+ new EigenThreadPoolWrapper(new Eigen::ThreadPool(num_threads)));
+ ptr->device.reset(
+ new Eigen::ThreadPoolDevice(ptr->thread_pool_wrapper.get(), num_threads));
+}
+
+TfLiteStatus Refresh(TfLiteContext* context) {
+ Eigen::setNbThreads(context->recommended_num_threads);
+
+ auto* ptr = GetEigenContext(context);
+ if (ptr != nullptr) {
+ InitDevice(context, ptr);
+ }
+
+ return kTfLiteOk;
+}
+
+} // namespace
+
void IncrementUsageCounter(TfLiteContext* context) {
- auto* ptr = reinterpret_cast<RefCountedEigenContext*>(context->eigen_context);
+ auto* ptr = GetEigenContext(context);
if (ptr == nullptr) {
if (context->recommended_num_threads != -1) {
Eigen::setNbThreads(context->recommended_num_threads);
}
ptr = new RefCountedEigenContext;
+ ptr->type = kTfLiteEigenContext;
+ ptr->Refresh = Refresh;
ptr->num_references = 0;
- context->eigen_context = ptr;
+ InitDevice(context, ptr);
+ context->SetExternalContext(context, kTfLiteEigenContext, ptr);
}
ptr->num_references++;
}
void DecrementUsageCounter(TfLiteContext* context) {
- auto* ptr = reinterpret_cast<RefCountedEigenContext*>(context->eigen_context);
+ auto* ptr = GetEigenContext(context);
if (ptr == nullptr) {
TF_LITE_FATAL(
"Call to DecrementUsageCounter() not preceded by "
@@ -46,14 +115,17 @@ void DecrementUsageCounter(TfLiteContext* context) {
}
if (--ptr->num_references == 0) {
delete ptr;
- context->eigen_context = nullptr;
+ context->SetExternalContext(context, kTfLiteEigenContext, nullptr);
}
}
-void SetNumThreads(TfLiteContext* context, int num_threads) {
- IncrementUsageCounter(context);
- Eigen::setNbThreads(num_threads);
- DecrementUsageCounter(context);
+const Eigen::ThreadPoolDevice* GetThreadPoolDevice(TfLiteContext* context) {
+ auto* ptr = GetEigenContext(context);
+ if (ptr == nullptr) {
+ TF_LITE_FATAL(
+ "Call to GetFromContext() not preceded by IncrementUsageCounter()");
+ }
+ return ptr->device.get();
}
} // namespace eigen_support
diff --git a/tensorflow/contrib/lite/kernels/eigen_support.h b/tensorflow/contrib/lite/kernels/eigen_support.h
index aa8c351fd8..feb1543f7b 100644
--- a/tensorflow/contrib/lite/kernels/eigen_support.h
+++ b/tensorflow/contrib/lite/kernels/eigen_support.h
@@ -15,7 +15,11 @@ limitations under the License.
#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_EIGEN_SUPPORT_H_
#define TENSORFLOW_CONTRIB_LITE_KERNELS_EIGEN_SUPPORT_H_
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
+
+namespace EigenForTFLite {
+struct ThreadPoolDevice;
+}
namespace tflite {
namespace eigen_support {
@@ -28,8 +32,8 @@ void IncrementUsageCounter(TfLiteContext* context);
// usages all temporary Eigen objects will be deleted.
void DecrementUsageCounter(TfLiteContext* context);
-// Set the number of threads that can be used by Eigen.
-void SetNumThreads(TfLiteContext* context, int num_threads);
+const EigenForTFLite::ThreadPoolDevice* GetThreadPoolDevice(
+ TfLiteContext* context);
} // namespace eigen_support
} // namespace tflite
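The replacement of SetNumThreads() with GetThreadPoolDevice() means kernels no longer poke a global Eigen setting; they pin the shared, reference-counted Eigen context and ask it for a thread-pool device. A sketch of the intended usage, assuming it is wired into a kernel's Init/Free/Eval the way conv.cc does (the MyInit/MyFree/MyEval names are illustrative only):

#include <cstddef>

#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/eigen_support.h"

namespace {

void* MyInit(TfLiteContext* context, const char* /*buffer*/, size_t /*length*/) {
  // Pin the shared Eigen context for the lifetime of this node.
  tflite::eigen_support::IncrementUsageCounter(context);
  return nullptr;
}

void MyFree(TfLiteContext* context, void* /*buffer*/) {
  tflite::eigen_support::DecrementUsageCounter(context);
}

TfLiteStatus MyEval(TfLiteContext* context, TfLiteNode* /*node*/) {
  // The device tracks context->recommended_num_threads via the Refresh() hook.
  const auto* device = tflite::eigen_support::GetThreadPoolDevice(context);
  (void)device;  // hand to Eigen spatial convolutions, etc.
  return kTfLiteOk;
}

}  // namespace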
diff --git a/tensorflow/contrib/lite/kernels/elementwise.cc b/tensorflow/contrib/lite/kernels/elementwise.cc
index b719a08394..8c624b3208 100644
--- a/tensorflow/contrib/lite/kernels/elementwise.cc
+++ b/tensorflow/contrib/lite/kernels/elementwise.cc
@@ -14,7 +14,7 @@ limitations under the License.
==============================================================================*/
#include <cmath>
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
@@ -22,43 +22,130 @@ namespace tflite {
namespace ops {
namespace builtin {
namespace elementwise {
+namespace {
-TfLiteStatus SinPrepare(TfLiteContext* context, TfLiteNode* node) {
+bool IsNumericSupportedType(const TfLiteType type) {
+ return type == kTfLiteFloat32;
+}
+
+bool IsLogicalSupportedType(const TfLiteType type) {
+ return type == kTfLiteBool;
+}
+
+typedef bool (*IsSupportedType)(TfLiteType);
+template <IsSupportedType>
+TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
TF_LITE_ENSURE_EQ(context, input->type, output->type);
- // Quantized float is not supported yet.
- TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
+ if (!IsSupportedType(input->type)) {
+ context->ReportError(context, "Current data type %d is not supported.",
+ input->type);
+ return kTfLiteError;
+ }
return context->ResizeTensor(context, output,
TfLiteIntArrayCopy(input->dims));
}
-TfLiteStatus SinEval(TfLiteContext* context, TfLiteNode* node) {
+template <typename T>
+inline TfLiteStatus EvalImpl(TfLiteContext* context, TfLiteNode* node,
+ T func(T), TfLiteType expected_type) {
const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
- switch (input->type) {
- case kTfLiteFloat32: {
- size_t elements = NumElements(input);
- const float* in = GetTensorData<float>(input);
- const float* in_end = in + elements;
- float* out = output->data.f;
- for (; in < in_end; in++, out++) *out = std::sin(*in);
- return kTfLiteOk;
- }
- default: {
- context->ReportError(context, "Only float32 is supported currently");
- return kTfLiteError;
- }
+ TF_LITE_ENSURE_EQ(context, input->type, expected_type);
+ const int64_t num_elements = NumElements(input);
+ const T* in_data = GetTensorData<T>(input);
+ T* out_data = GetTensorData<T>(output);
+ for (int64_t i = 0; i < num_elements; ++i) {
+ out_data[i] = func(in_data[i]);
}
+ return kTfLiteOk;
+}
+
+inline TfLiteStatus EvalNumeric(TfLiteContext* context, TfLiteNode* node,
+ float float_func(float)) {
+ return EvalImpl<float>(context, node, float_func, kTfLiteFloat32);
+}
+
+inline TfLiteStatus EvalLogical(TfLiteContext* context, TfLiteNode* node,
+ bool bool_func(bool)) {
+ return EvalImpl<bool>(context, node, bool_func, kTfLiteBool);
+}
+
+TfLiteStatus SinEval(TfLiteContext* context, TfLiteNode* node) {
+ return EvalNumeric(context, node, std::sin);
+}
+
+TfLiteStatus LogEval(TfLiteContext* context, TfLiteNode* node) {
+ return EvalNumeric(context, node, std::log);
+}
+
+TfLiteStatus SqrtEval(TfLiteContext* context, TfLiteNode* node) {
+ return EvalNumeric(context, node, std::sqrt);
+}
+
+TfLiteStatus RsqrtEval(TfLiteContext* context, TfLiteNode* node) {
+ return EvalNumeric(context, node, [](float f) { return 1.f / std::sqrt(f); });
}
+TfLiteStatus SquareEval(TfLiteContext* context, TfLiteNode* node) {
+ return EvalNumeric(context, node, [](float f) { return f * f; });
+}
+
+TfLiteStatus LogicalNotEval(TfLiteContext* context, TfLiteNode* node) {
+ return EvalLogical(context, node, [](bool v) { return !v; });
+}
+
+} // namespace
} // namespace elementwise
TfLiteRegistration* Register_SIN() {
- static TfLiteRegistration r = {nullptr, nullptr, elementwise::SinPrepare,
- elementwise::SinEval};
+ static TfLiteRegistration r = {
+ /*init=*/nullptr, /*free=*/nullptr,
+ elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
+ elementwise::SinEval};
+ return &r;
+}
+
+TfLiteRegistration* Register_LOG() {
+ static TfLiteRegistration r = {
+ /*init=*/nullptr, /*free=*/nullptr,
+ elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
+ elementwise::LogEval};
+ return &r;
+}
+
+TfLiteRegistration* Register_SQRT() {
+ static TfLiteRegistration r = {
+ /*init=*/nullptr, /*free=*/nullptr,
+ elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
+ elementwise::SqrtEval};
+ return &r;
+}
+
+TfLiteRegistration* Register_RSQRT() {
+ static TfLiteRegistration r = {
+ /*init=*/nullptr, /*free=*/nullptr,
+ elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
+ elementwise::RsqrtEval};
+ return &r;
+}
+
+TfLiteRegistration* Register_SQUARE() {
+ static TfLiteRegistration r = {
+ /*init=*/nullptr, /*free=*/nullptr,
+ elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
+ elementwise::SquareEval};
+ return &r;
+}
+
+TfLiteRegistration* Register_LOGICAL_NOT() {
+ static TfLiteRegistration r = {
+ /*init=*/nullptr, /*free=*/nullptr,
+ elementwise::GenericPrepare<elementwise::IsLogicalSupportedType>,
+ elementwise::LogicalNotEval};
return &r;
}
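With GenericPrepare and EvalNumeric in place, each unary float op above reduces to a one-line scalar functor plus a registration. A hypothetical further op would follow the same pattern; the CosEval and Register_COS_EXAMPLE names below are illustrative only, sketched as if added inside elementwise.cc so the helpers above are in scope:

namespace elementwise {
namespace {
TfLiteStatus CosEval(TfLiteContext* context, TfLiteNode* node) {
  return EvalNumeric(context, node, std::cos);
}
}  // namespace
}  // namespace elementwise

TfLiteRegistration* Register_COS_EXAMPLE() {
  static TfLiteRegistration r = {
      /*init=*/nullptr, /*free=*/nullptr,
      elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
      elementwise::CosEval};
  return &r;
}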
diff --git a/tensorflow/contrib/lite/kernels/elementwise_test.cc b/tensorflow/contrib/lite/kernels/elementwise_test.cc
index 412ffb04b9..5dd89a0eae 100644
--- a/tensorflow/contrib/lite/kernels/elementwise_test.cc
+++ b/tensorflow/contrib/lite/kernels/elementwise_test.cc
@@ -24,25 +24,40 @@ namespace {
using ::testing::ElementsAreArray;
-class SinOpModel : public SingleOpModel {
+class ElementWiseOpBaseModel : public SingleOpModel {
public:
- SinOpModel(std::initializer_list<int> input_shape) {
- input_ = AddInput(TensorType_FLOAT32);
- output_ = AddOutput(TensorType_FLOAT32);
- SetBuiltinOp(BuiltinOperator_SIN, BuiltinOptions_NONE, 0);
- BuildInterpreter({input_shape});
- }
-
int input() const { return input_; }
int output() const { return output_; }
- private:
+ protected:
int input_;
int output_;
};
+class ElementWiseOpFloatModel : public ElementWiseOpBaseModel {
+ public:
+ ElementWiseOpFloatModel(BuiltinOperator op,
+ std::initializer_list<int> input_shape) {
+ input_ = AddInput(TensorType_FLOAT32);
+ output_ = AddOutput(TensorType_FLOAT32);
+ SetBuiltinOp(op, BuiltinOptions_NONE, 0);
+ BuildInterpreter({input_shape});
+ }
+};
+
+class ElementWiseOpBoolModel : public ElementWiseOpBaseModel {
+ public:
+ ElementWiseOpBoolModel(BuiltinOperator op,
+ std::initializer_list<int> input_shape) {
+ input_ = AddInput(TensorType_BOOL);
+ output_ = AddOutput(TensorType_BOOL);
+ SetBuiltinOp(op, BuiltinOptions_NONE, 0);
+ BuildInterpreter({input_shape});
+ }
+};
+
TEST(ElementWise, Sin) {
- SinOpModel m({1, 1, 4, 1});
+ ElementWiseOpFloatModel m(BuiltinOperator_SIN, {1, 1, 4, 1});
m.PopulateTensor<float>(m.input(), {0, 3.1415926, -3.1415926, 1});
m.Invoke();
EXPECT_THAT(m.ExtractVector<float>(m.output()),
@@ -50,6 +65,51 @@ TEST(ElementWise, Sin) {
EXPECT_THAT(m.GetTensorShape(m.output()), ElementsAreArray({1, 1, 4, 1}));
}
+TEST(ElementWise, Log) {
+ ElementWiseOpFloatModel m(BuiltinOperator_LOG, {1, 1, 4, 1});
+ m.PopulateTensor<float>(m.input(), {1, 3.1415926, 1, 1});
+ m.Invoke();
+ EXPECT_THAT(m.ExtractVector<float>(m.output()),
+ ElementsAreArray(ArrayFloatNear({0, 1.14473, 0, 0})));
+ EXPECT_THAT(m.GetTensorShape(m.output()), ElementsAreArray({1, 1, 4, 1}));
+}
+
+TEST(ElementWise, Sqrt) {
+ ElementWiseOpFloatModel m(BuiltinOperator_SQRT, {1, 1, 4, 1});
+ m.PopulateTensor<float>(m.input(), {0, 1, 2, 4});
+ m.Invoke();
+ EXPECT_THAT(m.ExtractVector<float>(m.output()),
+ ElementsAreArray(ArrayFloatNear({0, 1, 1.41421, 2})));
+ EXPECT_THAT(m.GetTensorShape(m.output()), ElementsAreArray({1, 1, 4, 1}));
+}
+
+TEST(ElementWise, Rsqrt) {
+ ElementWiseOpFloatModel m(BuiltinOperator_RSQRT, {1, 1, 4, 1});
+ m.PopulateTensor<float>(m.input(), {1, 2, 4, 9});
+ m.Invoke();
+ EXPECT_THAT(m.ExtractVector<float>(m.output()),
+ ElementsAreArray(ArrayFloatNear({1, 0.7071, 0.5, 0.33333})));
+ EXPECT_THAT(m.GetTensorShape(m.output()), ElementsAreArray({1, 1, 4, 1}));
+}
+
+TEST(ElementWise, Square) {
+ ElementWiseOpFloatModel m(BuiltinOperator_SQUARE, {1, 1, 4, 1});
+ m.PopulateTensor<float>(m.input(), {1, 2, 0.5, -3.0});
+ m.Invoke();
+ EXPECT_THAT(m.ExtractVector<float>(m.output()),
+ ElementsAreArray(ArrayFloatNear({1, 4.0, 0.25, 9.0})));
+ EXPECT_THAT(m.GetTensorShape(m.output()), ElementsAreArray({1, 1, 4, 1}));
+}
+
+TEST(ElementWise, LogicalNot) {
+ ElementWiseOpBoolModel m(BuiltinOperator_LOGICAL_NOT, {1, 1, 4, 1});
+ m.PopulateTensor<bool>(m.input(), {true, false, true, false});
+ m.Invoke();
+ EXPECT_THAT(m.ExtractVector<bool>(m.output()),
+ ElementsAreArray({false, true, false, true}));
+ EXPECT_THAT(m.GetTensorShape(m.output()), ElementsAreArray({1, 1, 4, 1}));
+}
+
} // namespace
} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/embedding_lookup.cc b/tensorflow/contrib/lite/kernels/embedding_lookup.cc
index 7539c0b30d..fe33f98eb0 100644
--- a/tensorflow/contrib/lite/kernels/embedding_lookup.cc
+++ b/tensorflow/contrib/lite/kernels/embedding_lookup.cc
@@ -24,11 +24,11 @@ limitations under the License.
// Output:
// Output.dim[0] == Tensor[0].dim[0], num of lookups
// Output.dim[1] == Tensor[1].dim[1], num of items per row
-// Each item in output is a raw bytes copy of corresponding item in input.
+// Each item in output is a raw bytes copy of the corresponding item in input,
+// or a dequantized value in the case of a uint8 input.
// When indices are out of bound, the ops will not succeed.
//
-#include <unistd.h>
#include <cassert>
#include <cmath>
#include <cstdio>
@@ -37,8 +37,8 @@ limitations under the License.
#include <iostream>
#include <limits>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"
@@ -69,11 +69,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
return context->ResizeTensor(context, output, outputSize);
}
-TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
- TfLiteTensor* output = GetOutput(context, node, 0);
- const TfLiteTensor* lookup = GetInput(context, node, 0);
- const TfLiteTensor* value = GetInput(context, node, 1);
-
+TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node,
+ const TfLiteTensor* lookup, const TfLiteTensor* value,
+ TfLiteTensor* output) {
const int row_size = SizeOfDimension(value, 0);
const int row_bytes = value->bytes / row_size;
@@ -91,6 +89,53 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}
+TfLiteStatus EvalHybrid(TfLiteContext* context, TfLiteNode* node,
+ const TfLiteTensor* lookup, const TfLiteTensor* value,
+ TfLiteTensor* output) {
+ const int row_size = SizeOfDimension(value, 0);
+ const double scaling_factor = value->params.scale;
+
+  // col_size after we flatten the tensor into 2D.
+ int col_size = 1;
+ for (int i = 1; i < NumDimensions(value); i++) {
+ col_size *= SizeOfDimension(value, i);
+ }
+
+ for (int i = 0; i < SizeOfDimension(lookup, 0); i++) {
+ int idx = lookup->data.i32[i];
+ if (idx >= row_size || idx < 0) {
+ context->ReportError(context, "Embedding Lookup: index out of bounds.");
+ return kTfLiteError;
+ } else {
+ // Dequantize embedding values.
+ // TODO(alanchiao): refactor scalar multiply into separate function
+ // for ease of adding a neon equivalent if ever necessary.
+ for (int j = 0; j < col_size; j++) {
+ const int8_t* value_ptr = reinterpret_cast<int8_t*>(value->data.uint8);
+ output->data.f[j + i * col_size] =
+ value_ptr[j + idx * col_size] * scaling_factor;
+ }
+ }
+ }
+
+ return kTfLiteOk;
+}
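EvalHybrid reads the uint8 weight buffer as int8 values and multiplies by the tensor's scale. Assuming SymmetricQuantizeAndPopulate (used by the hybrid tests below) stores round(value / scale) with scale = max|value| / 127, a weight of 1.02 from the test rows round-trips as in this sketch:

#include <cmath>
#include <cstdint>
#include <cstdio>

int main() {
  const float scale = 2.13f / 127.0f;  // max |weight| in the test rows is 2.13
  const int8_t q = static_cast<int8_t>(std::round(1.02f / scale));  // 61
  printf("%f\n", q * scale);  // ~1.023, within the 7.41e-03 tolerance used below
  return 0;
}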
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ const TfLiteTensor* lookup = GetInput(context, node, 0);
+ const TfLiteTensor* value = GetInput(context, node, 1);
+ TfLiteTensor* output = GetOutput(context, node, 0);
+ switch (value->type) {
+ case kTfLiteFloat32:
+ return EvalFloat(context, node, lookup, value, output);
+ case kTfLiteUInt8:
+ return EvalHybrid(context, node, lookup, value, output);
+ default:
+ context->ReportError(context, "Type not currently supported.");
+ return kTfLiteError;
+ }
+}
+
} // namespace embedding_lookup
TfLiteRegistration* Register_EMBEDDING_LOOKUP() {
diff --git a/tensorflow/contrib/lite/kernels/embedding_lookup_sparse.cc b/tensorflow/contrib/lite/kernels/embedding_lookup_sparse.cc
index d3be36993c..aa75b03990 100644
--- a/tensorflow/contrib/lite/kernels/embedding_lookup_sparse.cc
+++ b/tensorflow/contrib/lite/kernels/embedding_lookup_sparse.cc
@@ -65,8 +65,8 @@ limitations under the License.
#include <algorithm>
#include <cmath>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"
diff --git a/tensorflow/contrib/lite/kernels/embedding_lookup_test.cc b/tensorflow/contrib/lite/kernels/embedding_lookup_test.cc
index 9b501878f1..4a88d168c6 100644
--- a/tensorflow/contrib/lite/kernels/embedding_lookup_test.cc
+++ b/tensorflow/contrib/lite/kernels/embedding_lookup_test.cc
@@ -7,13 +7,14 @@ You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
+distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License
+for the specific language governing permissions and limitations under the
+License.
==============================================================================*/
// Unit test for TFLite Lookup op.
+#include <initializer_list>
#include <iomanip>
#include <vector>
@@ -29,12 +30,13 @@ namespace {
using ::testing::ElementsAreArray;
-class EmbeddingLookupOpModel : public SingleOpModel {
+class BaseEmbeddingLookupOpModel : public SingleOpModel {
public:
- EmbeddingLookupOpModel(std::initializer_list<int> index_shape,
- std::initializer_list<int> weight_shape) {
+ BaseEmbeddingLookupOpModel(std::initializer_list<int> index_shape,
+ std::initializer_list<int> weight_shape,
+ TensorType weight_type = TensorType_FLOAT32) {
input_ = AddInput(TensorType_INT32);
- weight_ = AddInput(TensorType_FLOAT32);
+ weight_ = AddInput(weight_type);
output_ = AddOutput(TensorType_FLOAT32);
SetBuiltinOp(BuiltinOperator_EMBEDDING_LOOKUP, BuiltinOptions_NONE, 0);
BuildInterpreter({index_shape, weight_shape});
@@ -44,6 +46,18 @@ class EmbeddingLookupOpModel : public SingleOpModel {
PopulateTensor(input_, data);
}
+ std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+
+ protected:
+ int input_;
+ int weight_;
+ int output_;
+};
+
+class EmbeddingLookupOpModel : public BaseEmbeddingLookupOpModel {
+ public:
+ using BaseEmbeddingLookupOpModel::BaseEmbeddingLookupOpModel;
+
void Set3DWeightMatrix(const std::function<float(int, int, int)>& function) {
TfLiteTensor* tensor = interpreter_->tensor(weight_);
int rows = tensor->dims->data[0];
@@ -57,20 +71,25 @@ class EmbeddingLookupOpModel : public SingleOpModel {
}
}
}
+};
- std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+class HybridEmbeddingLookupOpModel : public BaseEmbeddingLookupOpModel {
+ public:
+ HybridEmbeddingLookupOpModel(std::initializer_list<int> index_shape,
+ std::initializer_list<int> weight_shape)
+ : BaseEmbeddingLookupOpModel(index_shape, weight_shape,
+ TensorType_UINT8) {}
- private:
- int input_;
- int weight_;
- int output_;
+ void SetWeight(std::initializer_list<float> data) {
+ SymmetricQuantizeAndPopulate(weight_, data);
+ }
};
// TODO(ahentz): write more tests that exercise the details of the op, such as
// lookup errors and variable input shapes.
TEST(EmbeddingLookupOpTest, SimpleTest) {
EmbeddingLookupOpModel m({3}, {3, 2, 4});
- m.PopulateTensor<int>(0, {1, 0, 2});
+ m.SetInput({1, 0, 2});
m.Set3DWeightMatrix(
[](int i, int j, int k) { return i + j / 10.0f + k / 100.0f; });
@@ -84,6 +103,69 @@ TEST(EmbeddingLookupOpTest, SimpleTest) {
})));
}
+TEST(HybridEmbeddingLookupHybridOpTest, Simple2DTest) {
+ HybridEmbeddingLookupOpModel m({3}, {3, 8});
+ m.SetInput({1, 0, 2});
+ m.SetWeight({
+ 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0
+ 1.00, -1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1
+ 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2
+ });
+
+ m.Invoke();
+
+ EXPECT_THAT(m.GetOutput(),
+ ElementsAreArray(ArrayFloatNear(
+ {
+ 1.00, -1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1
+ 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0
+ 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2
+ },
+ 7.41e-03)));
+}
+
+TEST(HybridEmbeddingLookupHybridOpTest, Simple3DTest) {
+ HybridEmbeddingLookupOpModel m({3}, {3, 2, 4});
+ m.SetInput({1, 0, 2});
+ m.SetWeight({
+ 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0
+ 1.00, -1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1
+ 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2
+ });
+
+ m.Invoke();
+
+ EXPECT_THAT(m.GetOutput(),
+ ElementsAreArray(ArrayFloatNear(
+ {
+ 1.00, -1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1
+ 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0
+ 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2
+ },
+ 7.41e-03)));
+}
+
+TEST(HybridEmbeddingLookupHybridOpTest, Simple4DTest) {
+ HybridEmbeddingLookupOpModel m({3}, {3, 2, 2, 2});
+ m.SetInput({1, 0, 2});
+ m.SetWeight({
+ 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0
+ 1.00, -1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1
+ 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2
+ });
+
+ m.Invoke();
+
+ EXPECT_THAT(m.GetOutput(),
+ ElementsAreArray(ArrayFloatNear(
+ {
+ 1.00, -1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1
+ 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0
+ 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2
+ },
+ 7.41e-03)));
+}
+
} // namespace
} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/exp.cc b/tensorflow/contrib/lite/kernels/exp.cc
index ce03cdfe26..673e7be90a 100644
--- a/tensorflow/contrib/lite/kernels/exp.cc
+++ b/tensorflow/contrib/lite/kernels/exp.cc
@@ -14,8 +14,8 @@ limitations under the License.
==============================================================================*/
#include <string.h>
#include <vector>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/kernels/expand_dims.cc b/tensorflow/contrib/lite/kernels/expand_dims.cc
new file mode 100644
index 0000000000..fa1140b19c
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/expand_dims.cc
@@ -0,0 +1,113 @@
+
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <string.h>
+#include <vector>
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace expand_dims {
+constexpr int kInput = 0;
+constexpr int kAxis = 1;
+constexpr int kOutput = 0;
+
+namespace {
+TfLiteStatus ExpandTensorDim(TfLiteContext* context, const TfLiteTensor& input,
+ int axis, TfLiteTensor* output) {
+ const TfLiteIntArray& input_dims = *input.dims;
+ if (axis < 0) {
+ axis = input_dims.size + 1 + axis;
+ }
+ TF_LITE_ENSURE(context, axis <= input_dims.size);
+
+ TfLiteIntArray* output_dims = TfLiteIntArrayCreate(input_dims.size + 1);
+ for (int i = 0; i < output_dims->size; ++i) {
+ if (i < axis) {
+ output_dims->data[i] = input_dims.data[i];
+ } else if (i == axis) {
+ output_dims->data[i] = 1;
+ } else {
+ output_dims->data[i] = input_dims.data[i - 1];
+ }
+ }
+
+ return context->ResizeTensor(context, output, output_dims);
+}
+
+TfLiteStatus GetAxisValueFromTensor(TfLiteContext* context,
+ const TfLiteTensor& axis, int* axis_value) {
+ TF_LITE_ENSURE_EQ(context, NumElements(&axis), 1);
+ switch (axis.type) {
+ case kTfLiteInt32:
+ *axis_value = *GetTensorData<int32_t>(&axis);
+ return kTfLiteOk;
+ case kTfLiteInt64:
+ *axis_value = *GetTensorData<int64_t>(&axis);
+ return kTfLiteOk;
+ default:
+ return kTfLiteError;
+ }
+}
+
+} // namespace
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+ const TfLiteTensor* input = GetInput(context, node, kInput);
+ const TfLiteTensor* axis = GetInput(context, node, kAxis);
+ TfLiteTensor* output = GetOutput(context, node, 0);
+ output->type = input->type;
+ if (IsConstantTensor(axis)) {
+ int axis_value;
+ TF_LITE_ENSURE_OK(context,
+ GetAxisValueFromTensor(context, *axis, &axis_value));
+ return ExpandTensorDim(context, *input, axis_value, output);
+ }
+ SetTensorToDynamic(output);
+ return kTfLiteOk;
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+  // Resize the output if needed, then copy the input data unchanged.
+ const TfLiteTensor* input = GetInput(context, node, kInput);
+ TfLiteTensor* output = GetOutput(context, node, 0);
+ const TfLiteTensor* axis = GetInput(context, node, kAxis);
+ if (IsDynamicTensor(output)) {
+ int axis_value;
+ TF_LITE_ENSURE_OK(context,
+ GetAxisValueFromTensor(context, *axis, &axis_value));
+ TF_LITE_ENSURE_OK(context,
+ ExpandTensorDim(context, *input, axis_value, output));
+ }
+ memcpy(output->data.raw, input->data.raw, input->bytes);
+ return kTfLiteOk;
+}
+
+} // namespace expand_dims
+TfLiteRegistration* Register_EXPAND_DIMS() {
+ static TfLiteRegistration r = {nullptr, nullptr, expand_dims::Prepare,
+ expand_dims::Eval};
+ return &r;
+}
+} // namespace builtin
+} // namespace ops
+} // namespace tflite
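ExpandTensorDim above normalizes a negative axis to input_dims.size + 1 + axis before inserting the new unit dimension, so for a {2, 2} input an axis of -1 becomes 2 and the output shape is {2, 2, 1}, matching the last case in the test file that follows. A standalone sketch of just the shape computation (std::vector instead of TfLiteIntArray):

#include <cstdio>
#include <vector>

// Shape-only analogue of ExpandTensorDim: insert a 1 at `axis`, counting
// negative axes from the end as the kernel does.
std::vector<int> ExpandDimsShape(std::vector<int> dims, int axis) {
  if (axis < 0) axis = static_cast<int>(dims.size()) + 1 + axis;
  dims.insert(dims.begin() + axis, 1);
  return dims;
}

int main() {
  for (int axis : {0, 1, 2, -1}) {
    const std::vector<int> out = ExpandDimsShape({2, 2}, axis);
    printf("axis %d -> {%d, %d, %d}\n", axis, out[0], out[1], out[2]);
  }
  // axis 0 -> {1, 2, 2}, axis 1 -> {2, 1, 2}, axis 2 -> {2, 2, 1}, axis -1 -> {2, 2, 1}
  return 0;
}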
diff --git a/tensorflow/contrib/lite/kernels/expand_dims_test.cc b/tensorflow/contrib/lite/kernels/expand_dims_test.cc
new file mode 100644
index 0000000000..a3bc1813db
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/expand_dims_test.cc
@@ -0,0 +1,83 @@
+
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAreArray;
+
+class ExpandDimsOpModel : public SingleOpModel {
+ public:
+ ExpandDimsOpModel(std::initializer_list<int> input_shape,
+ TensorType input_type) {
+ input_ = AddInput(input_type);
+ axis_ = AddInput(TensorType_INT32);
+ output_ = AddOutput(input_type);
+ SetBuiltinOp(BuiltinOperator_EXPAND_DIMS, BuiltinOptions_ExpandDimsOptions,
+ 0);
+ BuildInterpreter({input_shape, {1}});
+ }
+ void SetInputFloat(std::initializer_list<float> data) {
+ PopulateTensor<float>(input_, data);
+ }
+ void SetAxis(int axis) { PopulateTensor<int32_t>(axis_, {axis}); }
+ std::vector<float> GetValuesFloat() { return ExtractVector<float>(output_); }
+ std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
+
+ protected:
+ int input_;
+ int axis_;
+ int output_;
+};
+
+TEST(ExpandDimsOpTest, DifferentAxis) {
+ ExpandDimsOpModel m({2, 2}, TensorType_FLOAT32);
+ std::initializer_list<float> values = {-1.f, 1.f, -2.f, 2.f};
+ m.SetInputFloat(values);
+ m.SetAxis(0);
+ m.Invoke();
+ EXPECT_THAT(m.GetValuesFloat(), ElementsAreArray(values));
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2, 2}));
+
+ m.SetAxis(1);
+ m.Invoke();
+ EXPECT_THAT(m.GetValuesFloat(), ElementsAreArray(values));
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 1, 2}));
+
+ m.SetAxis(2);
+ m.Invoke();
+ EXPECT_THAT(m.GetValuesFloat(), ElementsAreArray(values));
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2, 1}));
+
+ m.SetAxis(-1);
+ m.Invoke();
+ EXPECT_THAT(m.GetValuesFloat(), ElementsAreArray(values));
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2, 1}));
+}
+} // namespace
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ ::tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/kernels/fake_quant.cc b/tensorflow/contrib/lite/kernels/fake_quant.cc
new file mode 100644
index 0000000000..b51af72fe6
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/fake_quant.cc
@@ -0,0 +1,95 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <string.h>
+#include <vector>
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace fake_quant {
+
+// This file has a reference implementation of the FakeQuant op.
+enum KernelType {
+ kReference,
+};
+
+struct OpContext {
+ OpContext(TfLiteContext* context, TfLiteNode* node) {
+ input = GetInput(context, node, 0);
+ output = GetOutput(context, node, 0);
+ }
+ const TfLiteTensor* input;
+ TfLiteTensor* output;
+};
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+
+ const auto* params =
+ reinterpret_cast<TfLiteFakeQuantParams*>(node->builtin_data);
+
+ if (params->narrow_range) {
+ context->ReportError(
+ context,
+ "narrow_range FakeQuant is not currently supported at runtime. "
+ "narrow_range is only meant to be applied to weights, not activations");
+ return kTfLiteError;
+ }
+
+ OpContext op_context(context, node);
+ TfLiteIntArray* output_dims = TfLiteIntArrayCopy(op_context.input->dims);
+ op_context.output->type = op_context.input->type;
+ return context->ResizeTensor(context, op_context.output, output_dims);
+}
+
+template <KernelType kernel_type>
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ OpContext op_context(context, node);
+
+ const auto* params =
+ reinterpret_cast<TfLiteFakeQuantParams*>(node->builtin_data);
+
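+ // FakeQuant quantizes the float input onto a uniform grid with 2^num_bits
+ // levels over the (possibly nudged) [min, max] range and dequantizes it back
+ // to float, emulating the error a real quantized kernel would introduce.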
+ tflite::FakeQuantParams op_params;
+ op_params.num_bits = params->num_bits;
+ op_params.minmax.min = params->min;
+ op_params.minmax.max = params->max;
+ reference_ops::FakeQuant(op_params, GetTensorShape(op_context.input),
+ GetTensorData<float>(op_context.input),
+ GetTensorShape(op_context.output),
+ GetTensorData<float>(op_context.output));
+
+ return kTfLiteOk;
+}
+
+} // namespace fake_quant
+
+TfLiteRegistration* Register_FAKE_QUANT_REF() {
+ static TfLiteRegistration r = {nullptr, nullptr, fake_quant::Prepare,
+ fake_quant::Eval<fake_quant::kReference>};
+ return &r;
+}
+
+TfLiteRegistration* Register_FAKE_QUANT() { return Register_FAKE_QUANT_REF(); }
+
+} // namespace builtin
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/fake_quant_test.cc b/tensorflow/contrib/lite/kernels/fake_quant_test.cc
new file mode 100644
index 0000000000..11a02f7ed7
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/fake_quant_test.cc
@@ -0,0 +1,112 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAreArray;
+
+class FakeQuantOpModel : public SingleOpModel {
+ public:
+ FakeQuantOpModel(const TensorData& input, const TensorType& output, float min,
+ float max, int num_bits) {
+ input_ = AddInput(input);
+ output_ = AddOutput(output);
+ SetBuiltinOp(BuiltinOperator_FAKE_QUANT, BuiltinOptions_FakeQuantOptions,
+ CreateFakeQuantOptions(builder_, min, max, num_bits).Union());
+ BuildInterpreter({GetShape(input_)});
+ }
+
+ template <class T>
+ void SetInput(std::initializer_list<T> data) {
+ PopulateTensor(input_, data);
+ }
+
+ template <class T>
+ std::vector<T> GetOutput() {
+ return ExtractVector<T>(output_);
+ }
+ std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
+
+ protected:
+ int input_;
+ int output_;
+};
+
+TEST(FakeQuantOpTest, FloatPositiveRange8Test) {
+ std::initializer_list<float> data = {0.0, 1.0, 0.25,
+ 0.50, 0.4444444, 0.00001};
+ FakeQuantOpModel m({TensorType_FLOAT32, {3, 1, 2}}, TensorType_FLOAT32, 0.0f,
+ 1.0f, 8);
+ m.SetInput<float>(data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 1, 2}));
+ EXPECT_THAT(
+ m.GetOutput<float>(),
+ ElementsAreArray(ArrayFloatNear({0, 1, 0.25098, 0.498039, 0.443137, 0})));
+}
+
+TEST(FakeQuantOpTest, FloatNegativeRange8Test) {
+ std::initializer_list<float> data = {0.0, -0.9, 0.25,
+ 0.50, 0.4444444, -0.00001};
+ FakeQuantOpModel m({TensorType_FLOAT32, {3, 1, 2}}, TensorType_FLOAT32, -0.9f,
+ 0.9f, 8);
+ m.SetInput<float>(data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 1, 2}));
+ EXPECT_THAT(m.GetOutput<float>(),
+ ElementsAreArray(ArrayFloatNear(
+ {0, -0.896471, 0.247059, 0.501176, 0.444706, 0})));
+}
+
+TEST(FakeQuantOpTest, FloatPositiveRange16Test) {
+ std::initializer_list<float> data = {0.0, 1.0, 0.25,
+ 0.50, 0.4444444, 0.00001};
+ FakeQuantOpModel m({TensorType_FLOAT32, {3, 1, 2}}, TensorType_FLOAT32, 0.0f,
+ 1.0f, 16);
+ m.SetInput<float>(data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 1, 2}));
+ EXPECT_THAT(m.GetOutput<float>(),
+ ElementsAreArray(ArrayFloatNear(
+ {0, 1, 0.250004, 0.500008, 0.44445, 1.5259e-05})));
+}
+
+TEST(FakeQuantOpTest, FloatNegativeRange16Test) {
+ std::initializer_list<float> data = {0.0, -0.9, 0.25,
+ 0.50, 0.4444444, -0.00001};
+ FakeQuantOpModel m({TensorType_FLOAT32, {3, 1, 2}}, TensorType_FLOAT32, -0.9f,
+ 0.9f, 16);
+ m.SetInput<float>(data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 1, 2}));
+ EXPECT_THAT(m.GetOutput<float>(),
+ ElementsAreArray(ArrayFloatNear(
+ {0, -0.900014, 0.249998, 0.499995, 0.444431, 0})));
+}
+
+} // namespace
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ ::tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/kernels/floor.cc b/tensorflow/contrib/lite/kernels/floor.cc
index 697b777693..59ff77f35b 100644
--- a/tensorflow/contrib/lite/kernels/floor.cc
+++ b/tensorflow/contrib/lite/kernels/floor.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
@@ -41,8 +41,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
- optimized_ops::Floor(GetTensorData<float>(input), GetTensorDims(input),
- GetTensorData<float>(output), GetTensorDims(output));
+ optimized_ops::Floor(GetTensorShape(input), GetTensorData<float>(input),
+ GetTensorShape(output), GetTensorData<float>(output));
+
return kTfLiteOk;
}
} // namespace floor
diff --git a/tensorflow/contrib/lite/kernels/floor_div.cc b/tensorflow/contrib/lite/kernels/floor_div.cc
new file mode 100644
index 0000000000..5d62cd2755
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/floor_div.cc
@@ -0,0 +1,146 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace floor_div {
+namespace {
+
+// Input/output tensor index.
+constexpr int kInputTensor1 = 0;
+constexpr int kInputTensor2 = 1;
+constexpr int kOutputTensor = 0;
+
+// Op data for floor_div op.
+struct OpData {
+ bool requires_broadcast;
+};
+
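+// Computes floor(input1 / input2) in double precision, so the result rounds
+// toward negative infinity for mixed-sign operands (e.g. FloorDiv(-9, 2) == -5).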
+template <typename T>
+T FloorDiv(T input1, T input2) {
+ return std::floor(std::divides<double>()(static_cast<double>(input1),
+ static_cast<double>(input2)));
+}
+
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+ auto* data = new OpData;
+ data->requires_broadcast = false;
+ return data;
+}
+
+void Free(TfLiteContext* context, void* buffer) {
+ delete reinterpret_cast<OpData*>(buffer);
+}
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+
+ // Reinterpret the opaque data provided by the user.
+ OpData* data = reinterpret_cast<OpData*>(node->user_data);
+
+ const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
+ const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+ TF_LITE_ENSURE_EQ(context, input1->type, input2->type);
+
+ const TfLiteType type = input1->type;
+ if (type != kTfLiteInt32) {
+ context->ReportError(context, "Currently floor_div only supports int32.");
+ return kTfLiteError;
+ }
+ output->type = type;
+
+ data->requires_broadcast = !HaveSameShapes(input1, input2);
+
+ TfLiteIntArray* output_size = nullptr;
+ if (data->requires_broadcast) {
+ TF_LITE_ENSURE_OK(context, CalculateShapeForBroadcast(
+ context, input1, input2, &output_size));
+ } else {
+ output_size = TfLiteIntArrayCopy(input1->dims);
+ }
+
+ return context->ResizeTensor(context, output, output_size);
+}
+
+template <typename T>
+TfLiteStatus EvalImpl(TfLiteContext* context, bool requires_broadcast,
+ const TfLiteTensor* input1, const TfLiteTensor* input2,
+ TfLiteTensor* output) {
+ const T* denominator_data = GetTensorData<T>(input2);
+
+ // Validate the denominator.
+ for (int i = 0; i < NumElements(input2); ++i) {
+ if (std::equal_to<T>()(denominator_data[i], 0)) {
+ context->ReportError(context, "Division by 0");
+ return kTfLiteError;
+ }
+ }
+ if (requires_broadcast) {
+ reference_ops::BroadcastBinaryFunction4DSlow<T, T, T>(
+ GetTensorShape(input1), GetTensorData<T>(input1),
+ GetTensorShape(input2), denominator_data, GetTensorShape(output),
+ GetTensorData<T>(output), FloorDiv<T>);
+ } else {
+ reference_ops::BinaryFunction<T, T, T>(
+ GetTensorShape(input1), GetTensorData<T>(input1),
+ GetTensorShape(input2), GetTensorData<T>(input2),
+ GetTensorShape(output), GetTensorData<T>(output), FloorDiv<T>);
+ }
+
+ return kTfLiteOk;
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ OpData* data = reinterpret_cast<OpData*>(node->user_data);
+
+ const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
+ const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+ switch (input1->type) {
+ case kTfLiteInt32: {
+ return EvalImpl<int32_t>(context, data->requires_broadcast, input1,
+ input2, output);
+ }
+ default: {
+ context->ReportError(context, "Currently floor_div only supports int32.");
+ return kTfLiteError;
+ }
+ }
+}
+
+} // namespace
+} // namespace floor_div
+
+TfLiteRegistration* Register_FLOOR_DIV() {
+ // Init, Free, Prepare, and Eval satisfy the interface required by
+ // TfLiteRegistration.
+ static TfLiteRegistration r = {floor_div::Init, floor_div::Free,
+ floor_div::Prepare, floor_div::Eval};
+ return &r;
+}
+
+} // namespace builtin
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/floor_div_test.cc b/tensorflow/contrib/lite/kernels/floor_div_test.cc
new file mode 100644
index 0000000000..eea69b61ac
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/floor_div_test.cc
@@ -0,0 +1,90 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAre;
+
+template <typename T>
+class FloorDivModel : public SingleOpModel {
+ public:
+ FloorDivModel(const TensorData& input1, const TensorData& input2,
+ const TensorData& output) {
+ input1_ = AddInput(input1);
+ input2_ = AddInput(input2);
+ output_ = AddOutput(output);
+ SetBuiltinOp(BuiltinOperator_FLOOR_DIV, BuiltinOptions_FloorDivOptions,
+ CreateFloorDivOptions(builder_).Union());
+ BuildInterpreter({GetShape(input1_), GetShape(input2_)});
+ }
+
+ int input1() { return input1_; }
+ int input2() { return input2_; }
+
+ std::vector<T> GetOutput() { return ExtractVector<T>(output_); }
+ std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
+
+ private:
+ int input1_;
+ int input2_;
+ int output_;
+};
+
+TEST(PowOpModel, Simple) {
+ FloorDivModel<int32_t> model({TensorType_INT32, {1, 2, 2, 1}},
+ {TensorType_INT32, {1, 2, 2, 1}},
+ {TensorType_INT32, {}});
+ model.PopulateTensor<int32_t>(model.input1(), {10, 9, 11, 3});
+ model.PopulateTensor<int32_t>(model.input2(), {2, 2, 3, 4});
+ model.Invoke();
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 2, 2, 1));
+ EXPECT_THAT(model.GetOutput(), ElementsAre(5, 4, 3, 0));
+}
+
+TEST(PowOpModel, NegativeValue) {
+ FloorDivModel<int32_t> model({TensorType_INT32, {1, 2, 2, 1}},
+ {TensorType_INT32, {1, 2, 2, 1}},
+ {TensorType_INT32, {}});
+ model.PopulateTensor<int32_t>(model.input1(), {10, -9, -11, 7});
+ model.PopulateTensor<int32_t>(model.input2(), {2, 2, -3, -4});
+ model.Invoke();
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 2, 2, 1));
+ EXPECT_THAT(model.GetOutput(), ElementsAre(5, -5, 3, -2));
+}
+
+TEST(PowOpModel, BroadcastFloorDiv) {
+ FloorDivModel<int32_t> model({TensorType_INT32, {1, 2, 2, 1}},
+ {TensorType_INT32, {1}}, {TensorType_INT32, {}});
+ model.PopulateTensor<int32_t>(model.input1(), {10, -9, -11, 7});
+ model.PopulateTensor<int32_t>(model.input2(), {-3});
+ model.Invoke();
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 2, 2, 1));
+ EXPECT_THAT(model.GetOutput(), ElementsAre(-4, 3, 3, -3));
+}
+
+} // namespace
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ ::tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/kernels/fully_connected.cc b/tensorflow/contrib/lite/kernels/fully_connected.cc
index a486b81d76..f6d2f76dbe 100644
--- a/tensorflow/contrib/lite/kernels/fully_connected.cc
+++ b/tensorflow/contrib/lite/kernels/fully_connected.cc
@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include <unistd.h>
#include <cassert>
#include <cmath>
#include <cstdio>
@@ -21,8 +20,8 @@ limitations under the License.
#include <iostream>
#include <limits>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/activation_functor.h"
#include "tensorflow/contrib/lite/kernels/gemm_support.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
@@ -63,6 +62,7 @@ constexpr int kInputTensor = 0;
constexpr int kWeightsTensor = 1;
constexpr int kBiasTensor = 2;
constexpr int kOutputTensor = 0;
+constexpr int kShuffledInputWorkspaceTensor = 1;
constexpr int kScratchBufferTensor = 1;
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
@@ -70,7 +70,7 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
// Instead, we allocate a new object to carry information from Prepare() to
// Eval().
gemm_support::IncrementUsageCounter(context);
- auto* op_data = new OpData;
+ auto* op_data = new OpData();
context->AddTensors(context, 1, &op_data->input_quantized_index);
return op_data;
}
@@ -87,7 +87,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Check we have all the inputs and outputs we need.
TF_LITE_ENSURE_EQ(context, node->inputs->size, 3);
- TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
+ // Shuffled formats need a workspace to store the shuffled input activations.
+ const int expected_outputs_count =
+ params->weights_format == kTfLiteFullyConnectedWeightsFormatDefault ? 1
+ : 2;
+ TF_LITE_ENSURE_EQ(context, node->outputs->size, expected_outputs_count);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* filter = GetInput(context, node, kWeightsTensor);
@@ -101,16 +105,15 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
input_size *= input->dims->data[i];
}
+ TF_LITE_ENSURE_EQ(context, NumDimensions(filter), 2);
const int batch_size = input_size / filter->dims->data[1];
const int num_units = filter->dims->data[0];
- TF_LITE_ASSERT_EQ(input_size, batch_size * filter->dims->data[1]);
+ TF_LITE_ENSURE_EQ(context, input_size, batch_size * filter->dims->data[1]);
if (bias) {
TF_LITE_ENSURE_EQ(context, NumElements(bias), SizeOfDimension(filter, 0));
}
- TF_LITE_ENSURE_EQ(context, NumDimensions(filter), 2);
-
// Note that quantized inference requires that all tensors have their
// parameters set. This is usually done during quantized training.
TfLiteType data_type = input->type;
@@ -118,11 +121,12 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
double real_multiplier = 0.0;
TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler(
context, input, filter, bias, output, &real_multiplier));
- QuantizeMultiplierSmallerThanOne(real_multiplier, &data->output_multiplier,
- &data->output_shift);
- CalculateActivationRangeUint8(params->activation, output,
- &data->output_activation_min,
- &data->output_activation_max);
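+ // Unlike QuantizeMultiplierSmallerThanOne, QuantizeMultiplier also handles
+ // real output multipliers >= 1 (see the OutputMultiplierGreaterThan1 tests).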
+ int exponent;
+ QuantizeMultiplier(real_multiplier, &data->output_multiplier, &exponent);
+ data->output_shift = -exponent;
+ TF_LITE_ENSURE_STATUS(CalculateActivationRangeQuantized(
+ context, params->activation, output, &data->output_activation_min,
+ &data->output_activation_max));
}
// If we have to perform on-the-fly quantization (with quantized weights and
@@ -218,11 +222,8 @@ TfLiteStatus EvalPieQuantized(TfLiteContext* context, TfLiteNode* node,
tensor_utils::ZeroVector(output->data.f, batch_size * num_units);
}
- // TODO(mirkov): change std::minmax_element with a vectorized call.
- auto minmax_element =
- std::minmax_element(input->data.f, input->data.f + total_input_size);
// Save matrix multiplication computation for all zero input.
- if (*minmax_element.first == 0.0 && *minmax_element.second == 0.0) {
+ if (tensor_utils::IsZeroVector(input->data.f, total_input_size)) {
tensor_utils::ApplyActivationToVector(output->data.f,
batch_size * num_units,
params->activation, output->data.f);
@@ -280,30 +281,57 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
int32_t input_offset = -input->params.zero_point;
int32_t filter_offset = -filter->params.zero_point;
int32_t output_offset = output->params.zero_point;
-#define TF_LITE_FULLY_CONNECTED(type) \
- type::FullyConnected( \
- GetTensorData<uint8_t>(input), GetTensorDims(input), input_offset, \
- GetTensorData<uint8_t>(filter), GetTensorDims(filter), filter_offset, \
- GetTensorData<int32_t>(bias), GetTensorDims(bias), output_offset, \
- data->output_multiplier, data->output_shift, \
- data->output_activation_min, data->output_activation_max, \
- GetTensorData<uint8_t>(output), GetTensorDims(output), gemm_context)
+#define TF_LITE_FULLY_CONNECTED(type, output_data_type) \
+ { \
+ FullyConnectedParams op_params; \
+ op_params.input_offset = input_offset; \
+ op_params.weights_offset = filter_offset; \
+ op_params.output_offset = output_offset; \
+ op_params.output_multiplier = data->output_multiplier; \
+ op_params.output_shift = -data->output_shift; \
+ op_params.quantized_activation_min = data->output_activation_min; \
+ op_params.quantized_activation_max = data->output_activation_max; \
+ type::FullyConnected( \
+ op_params, GetTensorShape(input), GetTensorData<uint8_t>(input), \
+ GetTensorShape(filter), GetTensorData<uint8_t>(filter), \
+ GetTensorShape(bias), GetTensorData<int32_t>(bias), \
+ GetTensorShape(output), GetTensorData<output_data_type>(output), \
+ gemm_context); \
+ }
if (kernel_type == kReference) {
- TF_LITE_FULLY_CONNECTED(reference_ops);
- } else if (kernel_type == kPie) {
- if (input->type == kTfLiteFloat32) {
- // Pie currently only supports quantized models and float inputs/outputs.
- TfLiteTensor* input_quantized =
- &context->tensors[node->temporaries->data[0]];
- return EvalPieQuantized(context, node, params, data, input, filter, bias,
- input_quantized, output);
- } else {
- // TODO(ahentz): we don't have a quantized version of the PIE kernels, so
- // we just defer to the MINI ones.
- TF_LITE_FULLY_CONNECTED(optimized_ops);
+ switch (output->type) {
+ case kTfLiteUInt8:
+ TF_LITE_FULLY_CONNECTED(reference_ops, uint8_t);
+ break;
+ case kTfLiteInt16:
+ TF_LITE_FULLY_CONNECTED(reference_ops, int16_t);
+ break;
+ default:
+ context->ReportError(
+ context,
+ "Quantized FullyConnected expects output data type uint8 or int16");
+ return kTfLiteError;
}
+ } else if (kernel_type == kPie && input->type == kTfLiteFloat32) {
+ // Pie currently only supports quantized models and float inputs/outputs.
+ TfLiteTensor* input_quantized =
+ &context->tensors[node->temporaries->data[0]];
+ return EvalPieQuantized(context, node, params, data, input, filter, bias,
+ input_quantized, output);
} else {
- TF_LITE_FULLY_CONNECTED(optimized_ops);
+ switch (output->type) {
+ case kTfLiteUInt8:
+ TF_LITE_FULLY_CONNECTED(optimized_ops, uint8_t);
+ break;
+ case kTfLiteInt16:
+ TF_LITE_FULLY_CONNECTED(optimized_ops, int16_t);
+ break;
+ default:
+ context->ReportError(
+ context,
+ "Quantized FullyConnected expects output data type uint8 or int16");
+ return kTfLiteError;
+ }
}
#undef TF_LITE_FULLY_CONNECTED
@@ -311,19 +339,67 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
}
template <KernelType kernel_type>
+TfLiteStatus EvalShuffledQuantized(TfLiteContext* context, TfLiteNode* node,
+ TfLiteFullyConnectedParams* params,
+ OpData* data, const TfLiteTensor* input,
+ const TfLiteTensor* filter,
+ const TfLiteTensor* bias,
+ TfLiteTensor* output,
+ TfLiteTensor* shuffled_input_workspace) {
+ gemmlowp::GemmContext* gemm_context = gemm_support::GetFromContext(context);
+
+ // TODO(b/110697972) decide more consistently if / how / where we want
+ // to perform this kind of runtime data type check.
+ if (input->type != kTfLiteUInt8 || filter->type != kTfLiteUInt8 ||
+ bias->type != kTfLiteInt32 || output->type != kTfLiteInt16 ||
+ shuffled_input_workspace->type != kTfLiteUInt8) {
+ context->ReportError(context, "Unexpected data type");
+ return kTfLiteError;
+ }
+
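+ // ShuffledFullyConnected stores a reordered copy of the input activations in
+ // the shuffled_input_workspace tensor to match the shuffled weights layout,
+ // then runs the quantized GEMM, producing int16 outputs.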
+#define TF_LITE_SHUFFLED_FULLY_CONNECTED(type) \
+ { \
+ FullyConnectedParams op_params; \
+ op_params.output_multiplier = data->output_multiplier; \
+ op_params.output_shift = -data->output_shift; \
+ op_params.quantized_activation_min = data->output_activation_min; \
+ op_params.quantized_activation_max = data->output_activation_max; \
+ type::ShuffledFullyConnected( \
+ op_params, GetTensorShape(input), GetTensorData<uint8_t>(input), \
+ GetTensorShape(filter), GetTensorData<uint8_t>(filter), \
+ GetTensorShape(bias), GetTensorData<int32_t>(bias), \
+ GetTensorShape(output), GetTensorData<int16_t>(output), \
+ GetTensorData<uint8_t>(shuffled_input_workspace), gemm_context); \
+ }
+ if (kernel_type == kReference) {
+ TF_LITE_SHUFFLED_FULLY_CONNECTED(reference_ops);
+ } else {
+ TF_LITE_SHUFFLED_FULLY_CONNECTED(optimized_ops);
+ }
+#undef TF_LITE_SHUFFLED_FULLY_CONNECTED
+
+ return kTfLiteOk;
+}
+
+template <KernelType kernel_type>
TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node,
TfLiteFullyConnectedParams* params, OpData* data,
const TfLiteTensor* input, const TfLiteTensor* filter,
const TfLiteTensor* bias, TfLiteTensor* output) {
float output_activation_min, output_activation_max;
- CalculateActivationRangeFloat(params->activation, &output_activation_min,
- &output_activation_max);
-#define TF_LITE_FULLY_CONNECTED(type) \
- type::FullyConnected(GetTensorData<float>(input), GetTensorDims(input), \
- GetTensorData<float>(filter), GetTensorDims(filter), \
- GetTensorData<float>(bias), GetTensorDims(bias), \
- output_activation_min, output_activation_max, \
- GetTensorData<float>(output), GetTensorDims(output))
+ CalculateActivationRange(params->activation, &output_activation_min,
+ &output_activation_max);
+#define TF_LITE_FULLY_CONNECTED(type) \
+ { \
+ FullyConnectedParams op_params; \
+ op_params.float_activation_min = output_activation_min; \
+ op_params.float_activation_max = output_activation_max; \
+ type::FullyConnected(op_params, GetTensorShape(input), \
+ GetTensorData<float>(input), GetTensorShape(filter), \
+ GetTensorData<float>(filter), GetTensorShape(bias), \
+ GetTensorData<float>(bias), GetTensorShape(output), \
+ GetTensorData<float>(output)); \
+ }
if (kernel_type == kReference) {
TF_LITE_FULLY_CONNECTED(reference_ops);
} else if (kernel_type == kPie) {
@@ -354,10 +430,25 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
return EvalFloat<kernel_type>(context, node, params, data, input, filter,
bias, output);
case kTfLiteUInt8:
- return EvalQuantized<kernel_type>(context, node, params, data, input,
- filter, bias, output);
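+ // The shuffled 4x16 int8 weights format goes through a dedicated kernel
+ // that also consumes the shuffled-input workspace output tensor.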
+ if (params->weights_format ==
+ kTfLiteFullyConnectedWeightsFormatShuffled4x16Int8) {
+ TfLiteTensor* shuffled_input_workspace =
+ GetOutput(context, node, kShuffledInputWorkspaceTensor);
+ return EvalShuffledQuantized<kernel_type>(context, node, params, data,
+ input, filter, bias, output,
+ shuffled_input_workspace);
+ } else if (params->weights_format ==
+ kTfLiteFullyConnectedWeightsFormatDefault) {
+ return EvalQuantized<kernel_type>(context, node, params, data, input,
+ filter, bias, output);
+ } else {
+ context->ReportError(context,
+ "Unhandled fully-connected weights format");
+ return kTfLiteError;
+ }
default:
- context->ReportError(context, "Type not currently supported.");
+ context->ReportError(context, "Type %d not currently supported.",
+ filter->type);
return kTfLiteError;
}
return kTfLiteOk;
diff --git a/tensorflow/contrib/lite/kernels/fully_connected_test.cc b/tensorflow/contrib/lite/kernels/fully_connected_test.cc
index 05dd028b48..08b4320946 100644
--- a/tensorflow/contrib/lite/kernels/fully_connected_test.cc
+++ b/tensorflow/contrib/lite/kernels/fully_connected_test.cc
@@ -15,6 +15,7 @@ limitations under the License.
// Unit test for TFLite FULLY_CONNECTED op.
#include <iomanip>
+#include <random>
#include <vector>
#include <gmock/gmock.h>
@@ -133,9 +134,12 @@ static float fully_connected_golden_output[] = {
class BaseFullyConnectedOpModel : public SingleOpModel {
public:
// TODO(ahentz): test different activation types too.
- BaseFullyConnectedOpModel(TfLiteRegistration* registration, int units,
- int batches, const TensorData& input,
- const TensorData& output = {TensorType_FLOAT32})
+ BaseFullyConnectedOpModel(
+ TfLiteRegistration* registration, int units, int batches,
+ const TensorData& input, const TensorData& output = {TensorType_FLOAT32},
+ ActivationFunctionType activation_func = ActivationFunctionType_RELU,
+ FullyConnectedOptionsWeightsFormat weights_format =
+ FullyConnectedOptionsWeightsFormat_DEFAULT)
: batches_(batches), units_(units) {
int total_input_size = 1;
for (int i = 0; i < input.shape.size(); ++i) {
@@ -159,10 +163,13 @@ class BaseFullyConnectedOpModel : public SingleOpModel {
}
output_ = AddOutput(output);
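+ // Non-default weights formats expose the shuffled input workspace as a
+ // second output tensor.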
+ if (weights_format != FullyConnectedOptionsWeightsFormat_DEFAULT) {
+ AddOutput({TensorType_UINT8, input.shape});
+ }
SetBuiltinOp(
BuiltinOperator_FULLY_CONNECTED, BuiltinOptions_FullyConnectedOptions,
- CreateFullyConnectedOptions(builder_, ActivationFunctionType_RELU)
+ CreateFullyConnectedOptions(builder_, activation_func, weights_format)
.Union());
resolver_ = absl::make_unique<SingleOpResolver>(
BuiltinOperator_FULLY_CONNECTED, registration);
@@ -188,13 +195,11 @@ class FloatFullyConnectedOpModel : public BaseFullyConnectedOpModel {
public:
using BaseFullyConnectedOpModel::BaseFullyConnectedOpModel;
- void SetBias(std::initializer_list<float> f) { PopulateTensor(bias_, f); }
+ void SetBias(const std::vector<float>& f) { PopulateTensor(bias_, f); }
- void SetWeights(std::initializer_list<float> f) {
- PopulateTensor(weights_, f);
- }
+ void SetWeights(const std::vector<float>& f) { PopulateTensor(weights_, f); }
- void SetInput(std::initializer_list<float> data) {
+ void SetInput(const std::vector<float>& data) {
PopulateTensor(input_, data);
}
void SetInput(int offset, float* begin, float* end) {
@@ -208,20 +213,50 @@ class QuantizedFullyConnectedOpModel : public BaseFullyConnectedOpModel {
public:
using BaseFullyConnectedOpModel::BaseFullyConnectedOpModel;
- void SetBias(std::initializer_list<float> data) {
+ void SetBias(const std::vector<float>& data) {
QuantizeAndPopulate<int32_t>(bias_, data);
}
- void SetWeights(std::initializer_list<float> data) {
+ void SetWeights(const std::vector<float>& data) {
QuantizeAndPopulate<uint8_t>(weights_, data);
}
- void SetInput(std::initializer_list<float> data) {
+ void ShuffleAndSetWeights(const std::vector<float>& data, int input_depth,
+ int output_depth) {
+ std::vector<float> shuffled_data(data.size());
+ CHECK_EQ(input_depth % 16, 0);
+ CHECK_EQ(output_depth % 4, 0);
+ float* shuffled_data_ptr = shuffled_data.data();
+ for (int block_o = 0; block_o < output_depth; block_o += 4) {
+ for (int block_i = 0; block_i < input_depth; block_i += 16) {
+ for (int o = 0; o < 4; o++) {
+ for (int i = 0; i < 16; i++) {
+ *shuffled_data_ptr++ =
+ data[(block_o + o) * input_depth + block_i + i];
+ }
+ }
+ }
+ }
+ TfLiteTensor* t = interpreter_->tensor(weights_);
+ auto quantized_data =
+ Quantize<uint8_t>(shuffled_data, t->params.scale, t->params.zero_point);
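+ // The shuffled weights format stores int8 values; XOR-ing with 0x80 converts
+ // each uint8 quantized value (zero point 128) to its int8 two's-complement bits.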
+ for (uint8_t& q : quantized_data) {
+ q ^= 0x80;
+ }
+ PopulateTensor(weights_, 0, quantized_data.data(),
+ quantized_data.data() + quantized_data.size());
+ }
+ void SetInput(const std::vector<float>& data) {
QuantizeAndPopulate<uint8_t>(input_, data);
}
- std::vector<uint8_t> GetOutput() { return ExtractVector<uint8_t>(output_); }
+ template <typename T>
+ std::vector<T> GetOutput() {
+ return ExtractVector<T>(output_);
+ }
+
+ template <typename T>
std::vector<float> GetDequantizedOutput() {
- return Dequantize<uint8_t>(ExtractVector<uint8_t>(output_),
- GetScale(output_), GetZeroPoint(output_));
+ return Dequantize<T>(ExtractVector<T>(output_), GetScale(output_),
+ GetZeroPoint(output_));
}
};
@@ -256,12 +291,12 @@ class HybridFullyConnectedOpModel : public SingleOpModel {
ops::builtin::Register_FULLY_CONNECTED_PIE());
BuildInterpreter({GetShape(input_), GetShape(weights_), GetShape(bias_)});
}
- void SetBias(std::initializer_list<float> f) { PopulateTensor(bias_, f); }
- void SetWeights(std::initializer_list<float> data) {
+ void SetBias(const std::vector<float>& f) { PopulateTensor(bias_, f); }
+ void SetWeights(const std::vector<float>& data) {
SymmetricQuantizeAndPopulate(weights_, data);
}
- void SetInput(std::initializer_list<float> f) { PopulateTensor(input_, f); }
+ void SetInput(const std::vector<float>& f) { PopulateTensor(input_, f); }
std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
int input_size() { return input_size_; }
@@ -340,6 +375,24 @@ TEST_P(FloatFullyConnectedOpTest, SimpleTest) {
EXPECT_THAT(m.GetOutput(), ElementsAre(24, 25, 26, 58, 59, 60));
}
+TEST_P(FloatFullyConnectedOpTest, SimpleTest2) {
+ FloatFullyConnectedOpModel m(GetRegistration(), /*units=*/1, /*batches=*/2,
+ /*input=*/{TensorType_FLOAT32, {2, 2}});
+ m.SetWeights({
+ 2, 4, // u = 0
+ });
+ m.SetBias({1});
+
+ m.SetInput({
+ 1, 2, // b = 0
+ 2, 1, // b = 1
+ });
+
+ m.Invoke();
+
+ EXPECT_THAT(m.GetOutput(), ElementsAre(11, 9));
+}
+
TEST_P(QuantizedFullyConnectedOpTest, SimpleTestQuantized) {
QuantizedFullyConnectedOpModel m(
GetRegistration(), /*units=*/3, /*batches*/ 2,
@@ -350,7 +403,38 @@ TEST_P(QuantizedFullyConnectedOpTest, SimpleTestQuantized) {
m.SetWeights({
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 0
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1
+ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 2
+ });
+ m.SetBias({1, 2, 3});
+
+ m.SetInput({
+ 1, 2, 3, 4, 5, 6, 7, 8, -9, -10, // b = 0
+ 1, 2, 3, 4, 5, 6, 7, -8, 9, -10, // b = 1
+ });
+
+ m.Invoke();
+
+ EXPECT_THAT(m.GetDequantizedOutput<uint8_t>(),
+ ElementsAreArray(ArrayFloatNear({
+ 24, 25, 26, //
+ 58, 59, 60, //
+ })));
+ EXPECT_THAT(m.GetOutput<uint8_t>(),
+ ElementsAre(151, 152, 153, 185, 186, 187));
+}
+
+TEST_P(QuantizedFullyConnectedOpTest,
+ SimpleTestQuantizedOutputMultiplierGreaterThan1) {
+ // real_multiplier = 2.
+ QuantizedFullyConnectedOpModel m(
+ GetRegistration(), /*units=*/3, /*batches*/ 2,
+ /*input=*/{TensorType_UINT8, {2, 10}, -127, 128},
+ /*output=*/{TensorType_UINT8, {}, -63.5, 64});
+
+ m.SetWeights({
+ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 0
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1
+ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 2
});
m.SetBias({1, 2, 3});
@@ -361,11 +445,136 @@ TEST_P(QuantizedFullyConnectedOpTest, SimpleTestQuantized) {
m.Invoke();
- EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear({
- 24, 25, 26, //
- 58, 59, 60, //
- })));
- EXPECT_THAT(m.GetOutput(), ElementsAre(151, 152, 153, 185, 186, 187));
+ EXPECT_THAT(m.GetDequantizedOutput<uint8_t>(),
+ ElementsAreArray(ArrayFloatNear({
+ 24, 25, 26, // first batch
+ 58, 59, 60, // second batch
+ })));
+ EXPECT_THAT(m.GetOutput<uint8_t>(),
+ ElementsAre(175, 177, 179, 243, 245, 247));
+}
+
+void SimpleTestQuantizedInt16OutputCase(
+ TfLiteRegistration* registration, int input_depth, int output_depth,
+ int batches, FullyConnectedOptionsWeightsFormat weights_format) {
+ const uint8_t kWeightsZeroPoint = 128;
+ const float kWeightsScale = 1.f / 128.f;
+ const uint8_t kInputZeroPoint = 128;
+ const float kInputScale = 1.f / 128.f;
+ const float kInputMin = (0 - kInputZeroPoint) * kInputScale;
+ const float kInputMax = (255 - kInputZeroPoint) * kInputScale;
+ // The output range is [-8, 8], encoded as int16.
+ const float kOutputScale = 8.f / 32768.f;
+ const float kOutputMin = -32768 * kOutputScale;
+ const float kOutputMax = 32767 * kOutputScale;
+
+ QuantizedFullyConnectedOpModel m(
+ registration, output_depth, batches,
+ /*input=*/
+ {TensorType_UINT8, {batches, input_depth}, kInputMin, kInputMax},
+ /*output=*/{TensorType_INT16, {}, kOutputMin, kOutputMax},
+ /*activation_func=*/ActivationFunctionType_NONE, weights_format);
+
+ std::mt19937 random_engine;
+ std::uniform_int_distribution<uint8_t> weights_dist;
+
+ std::vector<float> weights_data(input_depth * output_depth);
+ for (auto& w : weights_data) {
+ uint8_t q = weights_dist(random_engine);
+ w = (q - kWeightsZeroPoint) * kWeightsScale;
+ }
+
+ // Based on weights_format, enforce any shape requirement for that format/path
+ // and set the (possibly shuffled) weights.
+ switch (weights_format) {
+ case FullyConnectedOptionsWeightsFormat_DEFAULT:
+ m.SetWeights(weights_data);
+ break;
+ case FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8:
+ // The shuffled path currently supports only a restrictive subset of
+ // shapes, described by the following assertions:
+ CHECK_EQ(input_depth % 16, 0);
+ CHECK_EQ(output_depth % 4, 0);
+ CHECK(batches == 1 || batches == 4);
+ m.ShuffleAndSetWeights(weights_data, input_depth, output_depth);
+ break;
+ default:
+ LOG(FATAL) << "Unhandled weights format";
+ }
+
+ std::uniform_int_distribution<uint8_t> input_dist;
+ std::vector<float> input_data(input_depth * batches);
+ for (auto& i : input_data) {
+ uint8_t q = input_dist(random_engine);
+ i = (q - kInputZeroPoint) * kInputScale;
+ }
+
+ std::vector<float> bias_data(output_depth);
+ // As the output range is [-8, 8], it's reasonable to have bias values in
+ // [-1, 1]; this won't result in too much saturation.
+ std::uniform_real_distribution<float> bias_dist(-1.f, 1.f);
+ for (auto& b : bias_data) {
+ b = bias_dist(random_engine);
+ }
+
+ m.SetBias(bias_data);
+ m.SetInput(input_data);
+
+ m.Invoke();
+
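+ // Compute a float reference result and clamp it to the representable output
+ // range; the quantized kernel is expected to match it within a small tolerance.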
+ std::vector<float> expected_output_data(output_depth * batches);
+ for (int b = 0; b < batches; b++) {
+ for (int o = 0; o < output_depth; o++) {
+ float accum = bias_data[o];
+ for (int i = 0; i < input_depth; i++) {
+ accum +=
+ input_data[b * input_depth + i] * weights_data[o * input_depth + i];
+ }
+ accum = std::min(accum, kOutputMax);
+ accum = std::max(accum, kOutputMin);
+ expected_output_data[b * output_depth + o] = accum;
+ }
+ }
+
+ EXPECT_THAT(m.GetDequantizedOutput<int16_t>(),
+ ElementsAreArray(ArrayFloatNear(expected_output_data, 3e-4f)));
+}
+
+TEST_P(QuantizedFullyConnectedOpTest,
+ SimpleTestQuantizedInt16OutputDefaultWeights) {
+ for (int input_depth : {1, 3, 10, 100}) {
+ for (int output_depth : {1, 3, 10, 100}) {
+ for (int batch : {1, 3, 10, 100}) {
+ SimpleTestQuantizedInt16OutputCase(
+ GetRegistration(), input_depth, output_depth, batch,
+ FullyConnectedOptionsWeightsFormat_DEFAULT);
+ }
+ }
+ }
+}
+
+TEST_P(QuantizedFullyConnectedOpTest,
+ SimpleTestQuantizedInt16OutputShuffled4x16Int8Weights) {
+ // The shuffled weights block shape is 4x16. The shape of the weights matrix
+ // is: rows = output_depth, cols = input_depth. It must be a multiple of 4x16.
+ // This means that output_depth must be a multiple of 4, and input_depth must
+ // be a multiple of 16.
+ for (int input_depth_numblocks : {1, 3}) {
+ for (int output_depth_numblocks : {1, 3}) {
+ int input_depth = 16 * input_depth_numblocks;
+ int output_depth = 4 * output_depth_numblocks;
+ // The fast shuffled path currently supports only batch sizes of 1 and 4.
+ // The whole point of that path is to go as fast as possible for small
+ // batch sizes, which requires fully specializing it for each batch size;
+ // for larger batch sizes the generic gemmlowp-based implementation is
+ // fast enough.
+ for (int batch : {1, 4}) {
+ SimpleTestQuantizedInt16OutputCase(
+ GetRegistration(), input_depth, output_depth, batch,
+ FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8);
+ }
+ }
+ }
}
TEST(HybridFullyConnectedOpTest, SimpleTestQuantized) {
@@ -396,11 +605,11 @@ TEST(HybridFullyConnectedOpTest, SimpleTestQuantized) {
/*max_abs_error=*/1.3f)));
}
-TEST(FloatFullyConnectedOpTest, SimpleTest4DInput) {
+TEST_P(FloatFullyConnectedOpTest, SimpleTest4DInput) {
// Note that it is not required that the first dimension be the number of
// batches. All we care about is that the input can be evenly distributed in
// batches. In this case, we need the input to have multiples of '2'.
- FloatFullyConnectedOpModel m(ops::builtin::Register_FULLY_CONNECTED_PIE(),
+ FloatFullyConnectedOpModel m(GetRegistration(),
/*units=*/3, /*batches=*/2,
/*input=*/{TensorType_FLOAT32, {4, 1, 5, 1}});
m.SetWeights({
@@ -444,11 +653,44 @@ TEST_P(QuantizedFullyConnectedOpTest, SimpleTest4dInputQuantized) {
m.Invoke();
- EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear({
- 24, 25, 26, //
- 58, 59, 60, //
- })));
- EXPECT_THAT(m.GetOutput(), ElementsAre(151, 152, 153, 185, 186, 187));
+ EXPECT_THAT(m.GetDequantizedOutput<uint8_t>(),
+ ElementsAreArray(ArrayFloatNear({
+ 24, 25, 26, //
+ 58, 59, 60, //
+ })));
+ EXPECT_THAT(m.GetOutput<uint8_t>(),
+ ElementsAre(151, 152, 153, 185, 186, 187));
+}
+
+TEST_P(QuantizedFullyConnectedOpTest,
+ SimpleTest4dInputQuantizedOutputMultiplierGreaterThan1) {
+ // real_multiplier = 2.
+ QuantizedFullyConnectedOpModel m(
+ GetRegistration(), /*units=*/3, /*batches=*/2,
+ /*input=*/{TensorType_UINT8, {4, 1, 5, 1}, -127, 128},
+ /*output=*/{TensorType_UINT8, {}, -63.5, 64});
+
+ m.SetWeights({
+ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 0
+ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1
+ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 2
+ });
+ m.SetBias({1, 2, 3});
+
+ m.SetInput({
+ 1, 2, 3, 4, 5, 6, 7, 8, -9, -10, // b = 0
+ 1, 2, 3, 4, 5, 6, 7, -8, 9, -10, // b = 1
+ });
+
+ m.Invoke();
+
+ EXPECT_THAT(m.GetDequantizedOutput<uint8_t>(),
+ ElementsAreArray(ArrayFloatNear({
+ 24, 25, 26, // first batch
+ 58, 59, 60, // second batch
+ })));
+ EXPECT_THAT(m.GetOutput<uint8_t>(),
+ ElementsAre(175, 177, 179, 243, 245, 247));
}
INSTANTIATE_TEST_CASE_P(
diff --git a/tensorflow/contrib/lite/kernels/gather.cc b/tensorflow/contrib/lite/kernels/gather.cc
index c452d3ebac..b5afeb1a7b 100644
--- a/tensorflow/contrib/lite/kernels/gather.cc
+++ b/tensorflow/contrib/lite/kernels/gather.cc
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <string.h>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
@@ -40,10 +40,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
// Only INT32 positions are supported.
TF_LITE_ENSURE_EQ(context, positions->type, kTfLiteInt32);
- // Check that input and output types match.
- TF_LITE_ENSURE_EQ(context, input->type, output->type);
- // TODO(mgubin): only 0D or 1D positions are currently supported.
- TF_LITE_ENSURE(context, NumDimensions(positions) <= 1);
+ // Propagate the input type to the output.
+ output->type = input->type;
// TODO(mgubin): Only default axis == 0 is supported.
TF_LITE_ENSURE_EQ(context, params->axis, 0);
// Check conditions for different types.
@@ -59,8 +57,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumDimensions(input), 1);
} break;
default:
- context->ReportError(context,
- "Only float32 and string types are supported");
+ context->ReportError(
+ context, "Only float32 and string types are supported, got %d",
+ input->type);
return kTfLiteError;
}
const int num_dimensions =
@@ -85,11 +84,15 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* positions = GetInput(context, node, kInputPositions);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const int input_rank = NumDimensions(input);
-#define TF_LITE_GATHER(data_type, index_type) \
- optimized_ops::Gather( \
- GetTensorData<data_type>(input), GetTensorDims(input), input_rank, \
- GetTensorData<index_type>(positions), GetTensorDims(positions), \
- GetTensorData<data_type>(output), GetTensorDims(output));
+#define TF_LITE_GATHER(data_type, index_type) \
+ { \
+ tflite::GatherParams op_params; \
+ op_params.input_rank = input_rank; \
+ optimized_ops::Gather( \
+ op_params, GetTensorShape(input), GetTensorData<data_type>(input), \
+ GetTensorShape(positions), GetTensorData<index_type>(positions), \
+ GetTensorShape(output), GetTensorData<data_type>(output)); \
+ }
switch (input->type) {
case kTfLiteFloat32:
TF_LITE_GATHER(float, int32_t);
@@ -101,6 +104,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_GATHER(int32_t, int32_t);
break;
case kTfLiteString: {
+ // TODO(mgubin): Currently only 1D output tensors are supported.
DynamicBuffer buffer;
const int32* indexes = positions->data.i32;
const int num_strings = GetStringCount(input);
diff --git a/tensorflow/contrib/lite/kernels/gather_test.cc b/tensorflow/contrib/lite/kernels/gather_test.cc
index cdadbeda18..1b48884e09 100644
--- a/tensorflow/contrib/lite/kernels/gather_test.cc
+++ b/tensorflow/contrib/lite/kernels/gather_test.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <gtest/gtest.h>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/kernels/test_util.h"
@@ -96,6 +96,15 @@ TEST(GatherOpTest, Test0DIndexWith0DResult) {
EXPECT_TRUE(m.GetOutputShape().empty());
}
+TEST(GatherOpTest, Test2DIndexWith2DResult) {
+ GatherOpModel m({3}, TensorType_FLOAT32, {1, 2});
+ m.SetInputFloat({1.0, 2.0, 3.0});
+ m.SetPositions({1, 0});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputFloat(), ElementsAreArray(ArrayFloatNear({2.0, 1.0})));
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2}));
+}
+
TEST(FloatGatherOpTest, Duplicate) {
GatherOpModel m({1, 2, 2}, TensorType_FLOAT32, {2});
m.SetInputFloat({-2.0, 0.2, 0.7, 0.8});
diff --git a/tensorflow/contrib/lite/kernels/gemm_support.cc b/tensorflow/contrib/lite/kernels/gemm_support.cc
index 95f45ea768..ed334af2da 100644
--- a/tensorflow/contrib/lite/kernels/gemm_support.cc
+++ b/tensorflow/contrib/lite/kernels/gemm_support.cc
@@ -14,57 +14,70 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/lite/kernels/gemm_support.h"
+#include <memory>
+
#include "tensorflow/contrib/lite/kernels/op_macros.h"
namespace tflite {
namespace gemm_support {
+namespace {
-struct RefCountedGemmContext {
- gemmlowp::GemmContext* gemm_context_ = nullptr;
- int num_references_ = 0;
+struct RefCountedGemmContext : public TfLiteExternalContext {
+ std::unique_ptr<gemmlowp::GemmContext> gemm_context;
+ int num_references = 0;
};
+RefCountedGemmContext* GetGemmLowpContext(TfLiteContext* context) {
+ return reinterpret_cast<RefCountedGemmContext*>(
+ context->GetExternalContext(context, kTfLiteGemmLowpContext));
+}
+
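+// Refresh is invoked by the runtime when the recommended number of threads
+// changes, so the shared gemmlowp context can pick up the new value.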
+TfLiteStatus Refresh(TfLiteContext* context) {
+ auto* ptr = GetGemmLowpContext(context);
+ if (ptr != nullptr) {
+ ptr->gemm_context->set_max_num_threads(context->recommended_num_threads);
+ }
+ return kTfLiteOk;
+}
+
+} // namespace
+
void IncrementUsageCounter(TfLiteContext* context) {
- auto* ptr = reinterpret_cast<RefCountedGemmContext*>(context->gemm_context);
+ auto* ptr = GetGemmLowpContext(context);
if (ptr == nullptr) {
ptr = new RefCountedGemmContext;
- ptr->gemm_context_ = new gemmlowp::GemmContext();
+ ptr->type = kTfLiteGemmLowpContext;
+ ptr->Refresh = Refresh;
+ ptr->gemm_context.reset(new gemmlowp::GemmContext());
if (context->recommended_num_threads != -1) {
- ptr->gemm_context_->set_max_num_threads(context->recommended_num_threads);
+ ptr->gemm_context->set_max_num_threads(context->recommended_num_threads);
}
- ptr->num_references_ = 0;
- context->gemm_context = ptr;
+ ptr->num_references = 0;
+ context->SetExternalContext(context, kTfLiteGemmLowpContext, ptr);
}
- ptr->num_references_++;
+ ptr->num_references++;
}
void DecrementUsageCounter(TfLiteContext* context) {
- auto* ptr = reinterpret_cast<RefCountedGemmContext*>(context->gemm_context);
+ auto* ptr = GetGemmLowpContext(context);
if (ptr == nullptr) {
TF_LITE_FATAL(
"Call to DecrementUsageCounter() not preceded by "
"IncrementUsageCounter()");
}
- if (--ptr->num_references_ == 0) {
- delete ptr->gemm_context_;
+ if (--ptr->num_references == 0) {
delete ptr;
- context->gemm_context = nullptr;
+ context->SetExternalContext(context, kTfLiteGemmLowpContext, nullptr);
}
}
gemmlowp::GemmContext* GetFromContext(TfLiteContext* context) {
- auto* ptr = reinterpret_cast<RefCountedGemmContext*>(context->gemm_context);
+ auto* ptr = GetGemmLowpContext(context);
if (ptr == nullptr) {
TF_LITE_FATAL(
"Call to GetFromContext() not preceded by IncrementUsageCounter()");
}
- return ptr->gemm_context_;
-}
-
-void SetNumThreads(TfLiteContext* context, int num_threads) {
- IncrementUsageCounter(context);
- GetFromContext(context)->set_max_num_threads(num_threads);
- DecrementUsageCounter(context);
+ return ptr->gemm_context.get();
}
} // namespace gemm_support
diff --git a/tensorflow/contrib/lite/kernels/gemm_support.h b/tensorflow/contrib/lite/kernels/gemm_support.h
index f033501cb6..43cd2b3055 100644
--- a/tensorflow/contrib/lite/kernels/gemm_support.h
+++ b/tensorflow/contrib/lite/kernels/gemm_support.h
@@ -16,7 +16,7 @@ limitations under the License.
#define TENSORFLOW_CONTRIB_LITE_KERNELS_GEMM_SUPPORT_H_
#include "public/gemmlowp.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
namespace tflite {
namespace gemm_support {
@@ -45,9 +45,6 @@ void IncrementUsageCounter(TfLiteContext* context);
// 'context'. If there are no more usages the GemmContext will be deleted.
void DecrementUsageCounter(TfLiteContext* context);
-// Set the number of threads that can be used by gemmlowp.
-void SetNumThreads(TfLiteContext* context, int num_threads);
-
} // namespace gemm_support
} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/hashtable_lookup.cc b/tensorflow/contrib/lite/kernels/hashtable_lookup.cc
index 41211d41aa..c0b3c3c0c5 100644
--- a/tensorflow/contrib/lite/kernels/hashtable_lookup.cc
+++ b/tensorflow/contrib/lite/kernels/hashtable_lookup.cc
@@ -31,7 +31,6 @@ limitations under the License.
// Each item indicates whether the corresponding lookup has a returned value.
// 0 for missing key, 1 for found key.
-#include <unistd.h>
#include <cassert>
#include <cmath>
#include <cstdio>
@@ -40,8 +39,8 @@ limitations under the License.
#include <iostream>
#include <limits>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"
#include "tensorflow/contrib/lite/string_util.h"
diff --git a/tensorflow/contrib/lite/kernels/internal/BUILD b/tensorflow/contrib/lite/kernels/internal/BUILD
index aabbb0685c..afb5ec05df 100644
--- a/tensorflow/contrib/lite/kernels/internal/BUILD
+++ b/tensorflow/contrib/lite/kernels/internal/BUILD
@@ -43,6 +43,10 @@ cc_library(
"compatibility.h",
"types.h",
],
+ deps = [
+ "//tensorflow/contrib/lite/kernels:op_macros",
+ "@com_google_absl//absl/base:core_headers",
+ ],
)
config_setting(
@@ -160,9 +164,45 @@ cc_library(
":types",
":reference_base",
":round",
+ ":tensor_utils",
+ "//third_party/eigen3",
+ "@gemmlowp",
+ "//tensorflow/contrib/lite/c:c_api_internal",
+ ] + select({
+ ":haswell": tflite_deps_intel,
+ ":ios_x86_64": tflite_deps_intel,
+ ":k8": tflite_deps_intel,
+ ":x86": tflite_deps_intel,
+ ":x86_64": tflite_deps_intel,
+ ":darwin": tflite_deps_intel,
+ ":darwin_x86_64": tflite_deps_intel,
+ ":freebsd": tflite_deps_intel,
+ "//conditions:default": [],
+ }),
+)
+
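+# Like :optimized_base, but additionally exposes
+# optimized/legacy_optimized_ops.h and depends on :legacy_reference_base.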
+cc_library(
+ name = "legacy_optimized_base",
+ srcs = [],
+ hdrs = [
+ "common.h",
+ "optimized/depthwiseconv_float.h",
+ "optimized/depthwiseconv_uint8.h",
+ "optimized/depthwiseconv_uint8_3x3_filter.h",
+ "optimized/legacy_optimized_ops.h",
+ "optimized/optimized_ops.h",
+ ],
+ copts = tflite_copts(),
+ deps = [
+ ":quantization_util",
+ ":strided_slice_logic",
+ ":tensor_utils",
+ ":types",
+ ":legacy_reference_base",
+ ":round",
"//third_party/eigen3",
"@gemmlowp",
- "//tensorflow/contrib/lite:builtin_op_data",
+ "//tensorflow/contrib/lite/c:c_api_internal",
] + select({
":haswell": tflite_deps_intel,
":ios_x86_64": tflite_deps_intel,
@@ -184,13 +224,15 @@ cc_library(
"optimized/eigen_spatial_convolutions.h",
"optimized/eigen_tensor_reduced_instantiations_oss.h",
"optimized/multithreaded_conv.h",
+ # FIXME(petewarden) - This should be removed, since it's a header from the
+ # :tensor dependency below.
"tensor.h",
],
deps = [
":optimized_base",
+ ":tensor",
":types",
- "//tensorflow/contrib/lite:builtin_op_data",
- "//tensorflow/contrib/lite:context",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"//third_party/eigen3",
],
)
@@ -198,8 +240,9 @@ cc_library(
cc_test(
name = "tensor_test",
srcs = ["tensor_test.cc"],
+ tags = ["no_oss"],
deps = [
- ":reference",
+ ":tensor",
"@com_google_googletest//:gtest",
],
)
@@ -220,12 +263,14 @@ cc_library(
deps = [
":round",
":types",
+ "//tensorflow/contrib/lite/kernels:op_macros",
],
)
cc_test(
name = "quantization_util_test",
srcs = ["quantization_util_test.cc"],
+ tags = ["no_oss"],
deps = [
":quantization_util",
"@com_google_googletest//:gtest",
@@ -250,16 +295,18 @@ cc_library(
"common.h",
"reference/depthwiseconv_float.h",
"reference/depthwiseconv_uint8.h",
+ "reference/fully_connected.h",
"reference/reference_ops.h",
+ "reference/softmax.h",
],
deps = [
":quantization_util",
":round",
":strided_slice_logic",
":types",
- "//third_party/eigen3",
"@gemmlowp",
- "//tensorflow/contrib/lite:builtin_op_data",
+ "//tensorflow/contrib/lite/c:c_api_internal",
+ "//tensorflow/contrib/lite/kernels:op_macros",
] + select({
":haswell": tflite_deps_intel,
":ios_x86_64": tflite_deps_intel,
@@ -274,11 +321,60 @@ cc_library(
)
cc_library(
+ name = "legacy_reference_base",
+ srcs = [],
+ hdrs = [
+ "common.h",
+ "reference/depthwiseconv_float.h",
+ "reference/depthwiseconv_uint8.h",
+ "reference/fully_connected.h",
+ "reference/legacy_reference_ops.h",
+ "reference/reference_ops.h",
+ "reference/softmax.h",
+ ],
+ deps = [
+ ":quantization_util",
+ ":round",
+ ":strided_slice_logic",
+ ":types",
+ "@gemmlowp",
+ "//tensorflow/contrib/lite/c:c_api_internal",
+ "//tensorflow/contrib/lite/kernels:op_macros",
+ ] + select({
+ ":haswell": tflite_deps_intel,
+ ":ios_x86_64": tflite_deps_intel,
+ ":k8": tflite_deps_intel,
+ ":x86": tflite_deps_intel,
+ ":x86_64": tflite_deps_intel,
+ ":darwin": tflite_deps_intel,
+ ":darwin_x86_64": tflite_deps_intel,
+ ":freebsd": tflite_deps_intel,
+ "//conditions:default": [],
+ }),
+)
+
+cc_library(
+ name = "tensor",
+ hdrs = [
+ "tensor.h",
+ "tensor_ctypes.h",
+ ],
+ deps = [
+ ":types",
+ "//tensorflow/contrib/lite/c:c_api_internal",
+ ],
+)
+
+# Deprecated version of :tensor, kept for backwards compatibility.
+cc_library(
name = "reference",
- hdrs = ["tensor.h"],
+ hdrs = [
+ "tensor.h",
+ "tensor_ctypes.h",
+ ],
deps = [
":types",
- "//tensorflow/contrib/lite:context",
+ "//tensorflow/contrib/lite/c:c_api_internal",
],
)
@@ -292,7 +388,7 @@ cc_library(
],
deps = [
":round",
- "//tensorflow/contrib/lite:builtin_op_data",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite/kernels:activation_functor",
"//tensorflow/contrib/lite/kernels:op_macros",
],
@@ -317,7 +413,7 @@ cc_library(
":cpu_check",
":round",
":types",
- "//tensorflow/contrib/lite:builtin_op_data",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite/kernels:activation_functor",
"//tensorflow/contrib/lite/kernels:op_macros",
"@arm_neon_2_x86_sse",
@@ -331,7 +427,7 @@ cc_library(
hdrs = ["kernel_utils.h"],
deps = [
":tensor_utils",
- "//tensorflow/contrib/lite:builtin_op_data",
+ "//tensorflow/contrib/lite/c:c_api_internal",
],
)
@@ -373,9 +469,10 @@ cc_library(
],
copts = NEON_FLAGS_IF_APPLICABLE,
deps = [
- "//tensorflow/contrib/lite/kernels:activation_functor",
- "//tensorflow/contrib/lite:builtin_op_data",
+ "@com_google_absl//absl/base:core_headers",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"@arm_neon_2_x86_sse",
+ "//tensorflow/contrib/lite/kernels:op_macros",
"@gemmlowp",
] + select({
":arm": [
@@ -414,12 +511,25 @@ cc_library(
":darwin": [
":neon_tensor_utils",
],
+ ":darwin_x86_64": [
+ ":neon_tensor_utils",
+ ],
"//conditions:default": [
":portable_tensor_utils",
],
}),
)
+cc_library(
+ name = "test_util",
+ srcs = ["test_util.cc"],
+ hdrs = ["test_util.h"],
+ deps = [
+ ":types",
+ "//tensorflow/contrib/lite:string",
+ ],
+)
+
cc_test(
name = "tensor_utils_test",
srcs = ["tensor_utils_test.cc"],
@@ -431,15 +541,112 @@ cc_test(
"//conditions:default": [],
}),
linkstatic = 1,
- tags = ["tflite_not_portable_ios"],
+ tags = [
+ "no_oss",
+ "tflite_not_portable_ios",
+ ],
deps = [
":tensor_utils",
- "//tensorflow/contrib/lite:builtin_op_data",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite/kernels:test_util",
"@com_google_googletest//:gtest_main",
],
)
+cc_test(
+ name = "depthwiseconv_float_test",
+ srcs = ["depthwiseconv_float_test.cc"],
+ tags = ["no_oss"],
+ deps = [
+ ":optimized_base",
+ ":reference_base",
+ ":test_util",
+ ":types",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_test(
+ name = "depthwiseconv_quantized_test",
+ srcs = ["depthwiseconv_quantized_test.cc"],
+ tags = [
+ "no_oss",
+ "tflite_not_portable_ios",
+ ],
+ deps = [
+ ":optimized_base",
+ ":reference_base",
+ ":test_util",
+ ":types",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_test(
+ name = "resize_bilinear_test",
+ srcs = ["resize_bilinear_test.cc"],
+ tags = [
+ "no_oss",
+ "tflite_not_portable",
+ ],
+ deps = [
+ ":optimized_base",
+ ":reference_base",
+ ":test_util",
+ ":types",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_test(
+ name = "softmax_quantized_test",
+ timeout = "long",
+ srcs = [
+ "softmax_quantized_test.cc",
+ ],
+ tags = ["no_oss"],
+ deps = [
+ ":optimized_base",
+ ":quantization_util",
+ ":reference_base",
+ ":test_util",
+ "//tensorflow/contrib/lite:string",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_test(
+ name = "logsoftmax_quantized_test",
+ timeout = "long",
+ srcs = [
+ "logsoftmax_quantized_test.cc",
+ ],
+ tags = [
+ "no_oss",
+ "tflite_not_portable",
+ ],
+ deps = [
+ ":optimized_base",
+ ":quantization_util",
+ ":reference_base",
+ ":test_util",
+ "//tensorflow/contrib/lite:string",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_test(
+ name = "log_quantized_test",
+ srcs = ["log_quantized_test.cc"],
+ tags = ["no_oss"],
+ deps = [
+ ":optimized_base",
+ ":reference_base",
+ "//tensorflow/contrib/lite:string",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
cc_library(
name = "cpu_check",
hdrs = [
@@ -459,6 +666,7 @@ cc_library(
cc_test(
name = "batch_to_space_nd_test",
srcs = ["batch_to_space_nd_test.cc"],
+ tags = ["no_oss"],
deps = [
":optimized_base",
"@com_google_googletest//:gtest_main",
diff --git a/tensorflow/contrib/lite/kernels/internal/common.h b/tensorflow/contrib/lite/kernels/internal/common.h
index ede95dfee0..e67fee11b8 100644
--- a/tensorflow/contrib/lite/kernels/internal/common.h
+++ b/tensorflow/contrib/lite/kernels/internal/common.h
@@ -45,7 +45,7 @@ limitations under the License.
#endif
#endif
-#include "public/gemmlowp.h"
+#include "fixedpoint/fixedpoint.h"
#include "tensorflow/contrib/lite/kernels/internal/types.h"
namespace tflite {
@@ -87,12 +87,12 @@ float ActivationFunction(float x) {
output_activation_max);
}
-inline int32 MultiplyByQuantizedMultiplierSmallerThanOne(
- int32 x, int32 quantized_multiplier, int right_shift) {
+inline int32 MultiplyByQuantizedMultiplierSmallerThanOneExp(
+ int32 x, int32 quantized_multiplier, int left_shift) {
using gemmlowp::RoundingDivideByPOT;
using gemmlowp::SaturatingRoundingDoublingHighMul;
return RoundingDivideByPOT(
- SaturatingRoundingDoublingHighMul(x, quantized_multiplier), right_shift);
+ SaturatingRoundingDoublingHighMul(x, quantized_multiplier), -left_shift);
}
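A minimal sketch of the shift convention, for orientation only (the call sites below are hypothetical and not taken from any file in this change): because RoundingDivideByPOT receives -left_shift, a caller that used to pass an explicit right shift of 3 passes a left shift of -3 to get the same value.

    // Hypothetical call sites, shown only to illustrate the equivalent arguments:
    //   old: MultiplyByQuantizedMultiplierSmallerThanOne(x, quantized_multiplier, 3);
    //   new: MultiplyByQuantizedMultiplierSmallerThanOneExp(x, quantized_multiplier, -3);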
inline int32 MultiplyByQuantizedMultiplierGreaterThanOne(
@@ -117,6 +117,9 @@ template <typename T>
int CountLeadingZeros(T integer_input) {
static_assert(std::is_unsigned<T>::value,
"Only unsigned integer types handled.");
+#if defined(__GNUC__)
+ return integer_input ? __builtin_clz(integer_input) : 0;
+#else
const T one_in_leading_positive = static_cast<T>(1)
<< (std::numeric_limits<T>::digits - 1);
int leading_zeros = 0;
@@ -125,6 +128,140 @@ int CountLeadingZeros(T integer_input) {
++leading_zeros;
}
return leading_zeros;
+#endif
+}
+
+// DO NOT USE THIS STRUCT FOR NEW FUNCTIONALITY BEYOND IMPLEMENTING
+// BROADCASTING.
+//
+// NdArrayDesc<N> describes the shape and memory layout of an N-dimensional
+// rectangular array of numbers.
+//
+// NdArrayDesc<N> is basically identical to Dims<N> defined in types.h.
+// However, as Dims<N> is to be deprecated, this class exists as an adaptor
+// to enable simple unoptimized implementations of element-wise broadcasting
+// operations.
+template <int N>
+struct NdArrayDesc {
+ // The "extent" of each dimension. Indices along dimension d must be in the
+ // half-open interval [0, extents[d]).
+ int extents[N];
+
+ // The number of *elements* (not bytes) between consecutive indices of each
+ // dimension.
+ int strides[N];
+};
+
+// DO NOT USE THIS FUNCTION FOR NEW FUNCTIONALITY BEYOND IMPLEMENTING
+// BROADCASTING.
+//
+// Same as Offset(), except takes as NdArrayDesc<N> instead of Dims<N>.
+inline int SubscriptToIndex(const NdArrayDesc<4>& desc, int i0, int i1, int i2,
+ int i3) {
+ TFLITE_DCHECK(i0 >= 0 && i0 < desc.extents[0]);
+ TFLITE_DCHECK(i1 >= 0 && i1 < desc.extents[1]);
+ TFLITE_DCHECK(i2 >= 0 && i2 < desc.extents[2]);
+ TFLITE_DCHECK(i3 >= 0 && i3 < desc.extents[3]);
+ return i0 * desc.strides[0] + i1 * desc.strides[1] + i2 * desc.strides[2] +
+ i3 * desc.strides[3];
+}
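A small worked example of how the descriptor is read (the concrete numbers are illustrative and assume a dense, contiguous 1x2x2x3 buffer; they do not come from any caller in this change):

    // Dense 1x2x2x3 array: the last dimension is contiguous.
    NdArrayDesc<4> desc;
    desc.extents[0] = 1;  desc.extents[1] = 2;  desc.extents[2] = 2;  desc.extents[3] = 3;
    desc.strides[0] = 12; desc.strides[1] = 6;  desc.strides[2] = 3;  desc.strides[3] = 1;
    // Flat offset of element (0, 1, 0, 2): 0*12 + 1*6 + 0*3 + 2*1 == 8.
    const int offset = SubscriptToIndex(desc, 0, 1, 0, 2);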
+
+// Given the dimensions of the operands for an element-wise binary broadcast,
+// adjusts them so that they can be directly iterated over with simple loops.
+// Returns the adjusted dims as instances of NdArrayDesc in 'desc0_out' and
+// 'desc1_out'. 'desc0_out' and 'desc1_out' cannot be nullptr.
+//
+// This function assumes that the two input shapes are compatible up to
+// broadcasting and the shorter one has already been prepended with 1s to be the
+// same length. E.g., if shape0 is (1, 16, 16, 64) and shape1 is (1, 64),
+// shape1 must already have been prepended to be (1, 1, 1, 64). Recall that
+// Dims<N> refer to shapes in reverse order. In this case, input0_dims will be
+// (64, 16, 16, 1) and input1_dims will be (64, 1, 1, 1).
+//
+// When two shapes are compatible up to broadcasting, for each dimension d,
+// the input extents are either equal, or one of them is 1.
+//
+// This function performs the following for each dimension d:
+// - If the extents are equal, then do nothing since the loop that walks over
+// both of the input arrays is correct.
+// - Otherwise, one (and only one) of the extents must be 1. Say extent0 is 1
+// and extent1 is e1. Then set extent0 to e1 and stride0 *to 0*. This allows
+// array0 to be referenced *at any index* in dimension d and still access the
+// same slice.
+template <int N>
+inline void NdArrayDescsForElementwiseBroadcast(const Dims<N>& input0_dims,
+ const Dims<N>& input1_dims,
+ NdArrayDesc<N>* desc0_out,
+ NdArrayDesc<N>* desc1_out) {
+ TFLITE_DCHECK(desc0_out != nullptr);
+ TFLITE_DCHECK(desc1_out != nullptr);
+
+ // Copy dims to desc.
+ for (int i = 0; i < N; ++i) {
+ desc0_out->extents[i] = input0_dims.sizes[i];
+ desc0_out->strides[i] = input0_dims.strides[i];
+ desc1_out->extents[i] = input1_dims.sizes[i];
+ desc1_out->strides[i] = input1_dims.strides[i];
+ }
+
+ // Walk over each dimension. If the extents are equal do nothing.
+ // Otherwise, set the desc with extent 1 to have extent equal to the other and
+ // stride 0.
+ for (int i = 0; i < N; ++i) {
+ const int extent0 = ArraySize(input0_dims, i);
+ const int extent1 = ArraySize(input1_dims, i);
+ if (extent0 != extent1) {
+ if (extent0 == 1) {
+ desc0_out->strides[i] = 0;
+ desc0_out->extents[i] = extent1;
+ } else {
+ TFLITE_DCHECK_EQ(extent1, 1);
+ desc1_out->strides[i] = 0;
+ desc1_out->extents[i] = extent0;
+ }
+ }
+ }
+}
+
+template <int N>
+inline void NdArrayDescsForElementwiseBroadcast(
+ const RuntimeShape& input0_shape, const RuntimeShape& input1_shape,
+ NdArrayDesc<N>* desc0_out, NdArrayDesc<N>* desc1_out) {
+ TFLITE_DCHECK(desc0_out != nullptr);
+ TFLITE_DCHECK(desc1_out != nullptr);
+
+ auto extended_input0_shape = RuntimeShape::ExtendedShape(N, input0_shape);
+ auto extended_input1_shape = RuntimeShape::ExtendedShape(N, input1_shape);
+
+ // Copy dims to desc, calculating strides.
+ int desc0_stride = 1;
+ int desc1_stride = 1;
+ for (int i = N - 1; i >= 0; --i) {
+ desc0_out->extents[i] = extended_input0_shape.Dims(i);
+ desc0_out->strides[i] = desc0_stride;
+ desc0_stride *= extended_input0_shape.Dims(i);
+ desc1_out->extents[i] = extended_input1_shape.Dims(i);
+ desc1_out->strides[i] = desc1_stride;
+ desc1_stride *= extended_input1_shape.Dims(i);
+ }
+
+ // Walk over each dimension. If the extents are equal do nothing.
+ // Otherwise, set the desc with extent 1 to have extent equal to the other and
+ // stride 0.
+ for (int i = 0; i < N; ++i) {
+ const int extent0 = extended_input0_shape.Dims(i);
+ const int extent1 = extended_input1_shape.Dims(i);
+ if (extent0 != extent1) {
+ if (extent0 == 1) {
+ desc0_out->strides[i] = 0;
+ desc0_out->extents[i] = extent1;
+ } else {
+ TFLITE_DCHECK_EQ(extent1, 1);
+ desc1_out->strides[i] = 0;
+ desc1_out->extents[i] = extent0;
+ }
+ }
+ }
}
} // namespace tflite
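As a worked example of the two broadcast helpers above, using the shapes from their comment (RuntimeShape overload, so indices are in NHWC order; desc0/desc1 are illustrative names):

    NdArrayDesc<4> desc0, desc1;
    NdArrayDescsForElementwiseBroadcast<4>(RuntimeShape({1, 16, 16, 64}),
                                           RuntimeShape({1, 1, 1, 64}),
                                           &desc0, &desc1);
    // desc0 is left as-is: extents {1, 16, 16, 64}, strides {16384, 1024, 64, 1}.
    // desc1 becomes: extents {1, 16, 16, 64}, strides {64, 0, 0, 1}, so
    // SubscriptToIndex(desc1, b, y, x, c) == c for every (y, x): the smaller
    // operand is read as if it had been tiled across the broadcast dimensions.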
diff --git a/tensorflow/contrib/lite/kernels/internal/compatibility.h b/tensorflow/contrib/lite/kernels/internal/compatibility.h
index 93fc6b6a76..7c176e0fa1 100644
--- a/tensorflow/contrib/lite/kernels/internal/compatibility.h
+++ b/tensorflow/contrib/lite/kernels/internal/compatibility.h
@@ -15,65 +15,65 @@ limitations under the License.
#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_COMPATIBILITY_H_
#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_COMPATIBILITY_H_
-#include <cassert>
#include <cstdint>
-#include <cstdlib>
+
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
#ifndef TFLITE_DCHECK
-#define TFLITE_DCHECK(condition) (condition) ? (void)0 : assert(false)
+#define TFLITE_DCHECK(condition) (condition) ? (void)0 : TFLITE_ASSERT_FALSE
#endif
#ifndef TFLITE_DCHECK_EQ
-#define TFLITE_DCHECK_EQ(x, y) ((x) == (y)) ? (void)0 : assert(false)
+#define TFLITE_DCHECK_EQ(x, y) ((x) == (y)) ? (void)0 : TFLITE_ASSERT_FALSE
#endif
#ifndef TFLITE_DCHECK_NE
-#define TFLITE_DCHECK_NE(x, y) ((x) != (y)) ? (void)0 : assert(false)
+#define TFLITE_DCHECK_NE(x, y) ((x) != (y)) ? (void)0 : TFLITE_ASSERT_FALSE
#endif
#ifndef TFLITE_DCHECK_GE
-#define TFLITE_DCHECK_GE(x, y) ((x) >= (y)) ? (void)0 : assert(false)
+#define TFLITE_DCHECK_GE(x, y) ((x) >= (y)) ? (void)0 : TFLITE_ASSERT_FALSE
#endif
#ifndef TFLITE_DCHECK_GT
-#define TFLITE_DCHECK_GT(x, y) ((x) > (y)) ? (void)0 : assert(false)
+#define TFLITE_DCHECK_GT(x, y) ((x) > (y)) ? (void)0 : TFLITE_ASSERT_FALSE
#endif
#ifndef TFLITE_DCHECK_LE
-#define TFLITE_DCHECK_LE(x, y) ((x) <= (y)) ? (void)0 : assert(false)
+#define TFLITE_DCHECK_LE(x, y) ((x) <= (y)) ? (void)0 : TFLITE_ASSERT_FALSE
#endif
#ifndef TFLITE_DCHECK_LT
-#define TFLITE_DCHECK_LT(x, y) ((x) < (y)) ? (void)0 : assert(false)
+#define TFLITE_DCHECK_LT(x, y) ((x) < (y)) ? (void)0 : TFLITE_ASSERT_FALSE
#endif
// TODO(ahentz): Clean up: We should stick to the DCHECK versions.
#ifndef TFLITE_CHECK
-#define TFLITE_CHECK(condition) (condition) ? (void)0 : abort()
+#define TFLITE_CHECK(condition) (condition) ? (void)0 : TFLITE_ABORT
#endif
#ifndef TFLITE_CHECK_EQ
-#define TFLITE_CHECK_EQ(x, y) ((x) == (y)) ? (void)0 : abort()
+#define TFLITE_CHECK_EQ(x, y) ((x) == (y)) ? (void)0 : TFLITE_ABORT
#endif
#ifndef TFLITE_CHECK_NE
-#define TFLITE_CHECK_NE(x, y) ((x) != (y)) ? (void)0 : abort()
+#define TFLITE_CHECK_NE(x, y) ((x) != (y)) ? (void)0 : TFLITE_ABORT
#endif
#ifndef TFLITE_CHECK_GE
-#define TFLITE_CHECK_GE(x, y) ((x) >= (y)) ? (void)0 : abort()
+#define TFLITE_CHECK_GE(x, y) ((x) >= (y)) ? (void)0 : TFLITE_ABORT
#endif
#ifndef TFLITE_CHECK_GT
-#define TFLITE_CHECK_GT(x, y) ((x) > (y)) ? (void)0 : abort()
+#define TFLITE_CHECK_GT(x, y) ((x) > (y)) ? (void)0 : TFLITE_ABORT
#endif
#ifndef TFLITE_CHECK_LE
-#define TFLITE_CHECK_LE(x, y) ((x) <= (y)) ? (void)0 : abort()
+#define TFLITE_CHECK_LE(x, y) ((x) <= (y)) ? (void)0 : TFLITE_ABORT
#endif
#ifndef TFLITE_CHECK_LT
-#define TFLITE_CHECK_LT(x, y) ((x) < (y)) ? (void)0 : abort()
+#define TFLITE_CHECK_LT(x, y) ((x) < (y)) ? (void)0 : TFLITE_ABORT
#endif
// TODO(ahentz): Clean up.
@@ -84,4 +84,27 @@ using uint16 = std::uint16_t;
using int32 = std::int32_t;
using uint32 = std::uint32_t;
+// TFLITE_DEPRECATED()
+//
+// Duplicated from absl/base/macros.h to avoid pulling in that library.
+// Marks a deprecated class, struct, enum, function, method and variable
+// declarations. The macro argument is used as a custom diagnostic message (e.g.
+// suggestion of a better alternative).
+//
+// Example:
+//
+// class TFLITE_DEPRECATED("Use Bar instead") Foo {...};
+// TFLITE_DEPRECATED("Use Baz instead") void Bar() {...}
+//
+// Every usage of a deprecated entity will trigger a warning when compiled with
+// clang's `-Wdeprecated-declarations` option. This option is turned off by
+// default, but the warnings will be reported by clang-tidy.
+#if defined(__clang__) && __cplusplus >= 201103L
+#define TFLITE_DEPRECATED(message) __attribute__((deprecated(message)))
+#endif
+
+#ifndef TFLITE_DEPRECATED
+#define TFLITE_DEPRECATED(message)
+#endif
+
#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_COMPATIBILITY_H_
diff --git a/tensorflow/contrib/lite/kernels/internal/depthwiseconv_float_test.cc b/tensorflow/contrib/lite/kernels/internal/depthwiseconv_float_test.cc
new file mode 100644
index 0000000000..41862a21a6
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/internal/depthwiseconv_float_test.cc
@@ -0,0 +1,157 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <algorithm>
+#include <cmath>
+#include <vector>
+
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/kernels/internal/common.h"
+#include "tensorflow/contrib/lite/kernels/internal/test_util.h"
+#include "tensorflow/contrib/lite/kernels/internal/types.h"
+
+#define ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK
+#include "tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h"
+
+namespace tflite {
+namespace {
+
+// Runs the DepthwiseConv and compares against the reference implementation.
+void TestOneDepthwiseConv(
+ const DepthwiseParams& params, const RuntimeShape& input_shape,
+ const float* input_data, const RuntimeShape& filter_shape,
+ const float* filter_data, const RuntimeShape& bias_shape,
+ const float* bias_data, const RuntimeShape& output_shape) {
+ const int output_buffer_size = output_shape.FlatSize();
+ std::vector<float> output_data(output_buffer_size);
+ std::vector<float> reference_output_data(output_buffer_size);
+ reference_ops::DepthwiseConv(params, input_shape, input_data, filter_shape,
+ filter_data, bias_shape, bias_data, output_shape,
+ reference_output_data.data());
+ optimized_ops::DepthwiseConv(params, input_shape, input_data, filter_shape,
+ filter_data, bias_shape, bias_data, output_shape,
+ output_data.data());
+
+ double sum_abs_diff = 0;
+ float max_abs_val = 0;
+ for (int i = 0; i < output_buffer_size; i++) {
+ sum_abs_diff += std::abs(output_data[i] - reference_output_data[i]);
+ max_abs_val = std::max(max_abs_val, std::abs(reference_output_data[i]));
+ }
+ if (sum_abs_diff != 0.f) {
+ const float mean_diff =
+ static_cast<float>(sum_abs_diff / output_buffer_size);
+ const float relative_error = std::abs(mean_diff) / max_abs_val;
+ ASSERT_LT(relative_error, 1e-5f);
+ }
+}
+
+// This function picks some random DepthwiseConv params, which may or may not
+// be legal. If they're not legal, it returns false. If they're legal,
+// it runs the DepthwiseConv test and returns true. This allows the caller
+// to loop until a test has been run.
+bool TryTestOneDepthwiseConv() {
+ // We have to pick a lot of positive values, where we are particularly
+ // interested in small values because they are most likely to be special
+ // cases in optimized implementations, and secondarily because they allow
+ // tests to run fast, which means we can run more tests and get more
+ // coverage.
+ const int batch = ExponentialRandomPositiveInt(0.9f, 3, 20);
+ const int input_depth = ExponentialRandomPositiveInt(0.9f, 6, 50);
+ const int input_width = ExponentialRandomPositiveInt(0.9f, 20, 200);
+ const int input_height = ExponentialRandomPositiveInt(0.9f, 20, 200);
+ const int filter_width = ExponentialRandomPositiveInt(0.9f, 4, 10);
+ const int filter_height = ExponentialRandomPositiveInt(0.9f, 4, 10);
+ const int depth_multiplier = ExponentialRandomPositiveInt(0.8f, 6, 50);
+ const int stride = ExponentialRandomPositiveInt(0.9f, 3, 8);
+ const int output_depth = input_depth * depth_multiplier;
+ const int dilation_width_factor = RandomElement(std::vector<int>({1, 2, 4}));
+ const int dilation_height_factor = RandomElement(std::vector<int>({1, 2, 4}));
+ float output_activation_min, output_activation_max;
+ FusedActivationFunctionType ac =
+ RandomElement(std::vector<FusedActivationFunctionType>(
+ {FusedActivationFunctionType::kNone,
+ FusedActivationFunctionType::kRelu,
+ FusedActivationFunctionType::kRelu1,
+ FusedActivationFunctionType::kRelu6}));
+ GetActivationMinMax(ac, &output_activation_min, &output_activation_max);
+ // The optimized DepthwiseConv implementation currently uses a fixed-size
+ // accumulator buffer on the stack, whose size corresponds to
+ // kMaxSupportedOutputDepth below. This currently means that it does not
+ // support larger output depths. It CHECK's for that, so it's safe in the
+ // sense that if a larger output depth were encountered, it would fail
+ // explicitly. We just need to adjust our testing to that constraint.
+ const int kMaxSupportedOutputDepth = 1024;
+ if (output_depth > kMaxSupportedOutputDepth) {
+ return false;
+ }
+ RuntimeShape input_shape_inference(
+ {batch, input_height, input_width, input_depth});
+ RuntimeShape output_shape_inference;
+ int pad_width, pad_height;
+ const auto padding_type =
+ UniformRandomInt(0, 1) ? PaddingType::kSame : PaddingType::kValid;
+ if (!ComputeConvSizes(input_shape_inference, output_depth, filter_width,
+ filter_height, stride, dilation_width_factor,
+ dilation_height_factor, padding_type,
+ &output_shape_inference, &pad_width, &pad_height)) {
+ return false;
+ }
+ RuntimeShape filter_shape_inference(
+ {1, filter_height, filter_width, output_depth});
+ RuntimeShape bias_shape_inference({1, 1, 1, output_depth});
+ const int input_buffer_size = input_shape_inference.FlatSize();
+ const int filter_buffer_size = filter_shape_inference.FlatSize();
+ std::vector<float> input_data(input_buffer_size);
+ std::vector<float> filter_data(filter_buffer_size);
+ std::vector<float> bias_data(output_depth);
+ const float input_amplitude = 1.f;
+ const float filter_amplitude = 1.f;
+ const float bias_amplitude =
+ filter_width * filter_height * input_amplitude * filter_amplitude;
+ FillRandom(&input_data, -input_amplitude, input_amplitude);
+ FillRandom(&filter_data, -filter_amplitude, filter_amplitude);
+ FillRandom(&bias_data, -bias_amplitude, bias_amplitude);
+ DepthwiseParams op_params;
+ op_params.padding_type = PaddingType::kSame;
+ op_params.padding_values.width = pad_width;
+ op_params.padding_values.height = pad_height;
+ op_params.stride_width = stride;
+ op_params.stride_height = stride;
+ op_params.dilation_width_factor = dilation_width_factor;
+ op_params.dilation_height_factor = dilation_height_factor;
+ op_params.depth_multiplier = depth_multiplier;
+ op_params.float_activation_min = output_activation_min;
+ op_params.float_activation_max = output_activation_max;
+ TestOneDepthwiseConv(op_params, input_shape_inference, input_data.data(),
+ filter_shape_inference, filter_data.data(),
+ bias_shape_inference, bias_data.data(),
+ output_shape_inference);
+ return true;
+}
+
+void TestOneDepthwiseConv() {
+ while (!TryTestOneDepthwiseConv()) {
+ }
+}
+
+TEST(TestDepthwiseConv, TestDepthwiseConv) {
+ const int kTestsToRun = 100 * 1000;
+ for (int i = 0; i < kTestsToRun; i++) {
+ TestOneDepthwiseConv();
+ }
+}
+} // namespace
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/internal/depthwiseconv_quantized_test.cc b/tensorflow/contrib/lite/kernels/internal/depthwiseconv_quantized_test.cc
new file mode 100644
index 0000000000..9414e109c3
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/internal/depthwiseconv_quantized_test.cc
@@ -0,0 +1,349 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <sys/types.h>
+#include <algorithm>
+#include <cmath>
+#include <cstdint>
+#include <cstdlib>
+#include <iterator>
+#include <limits>
+#include <vector>
+
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/kernels/internal/test_util.h"
+#include "tensorflow/contrib/lite/kernels/internal/types.h"
+
+#define ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK
+#include "tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h"
+
+namespace tflite {
+namespace {
+
+// Runs the DepthwiseConv and compares against the reference implementation.
+template <FusedActivationFunctionType Ac>
+int TestOneDepthwiseConvWithGivenOutputShift(
+ const std::uint8_t* input_data, const RuntimeShape& input_shape,
+ std::int32_t input_offset, const std::uint8_t* filter_data,
+ const RuntimeShape& filter_shape, std::int32_t filter_offset,
+ const std::int32_t* bias_data, const RuntimeShape& bias_shape, int stride,
+ int pad_width, int pad_height, int depth_multiplier,
+ std::int32_t output_offset, std::int32_t output_multiplier,
+ int output_shift, std::int32_t output_activation_min,
+ std::int32_t output_activation_max, const RuntimeShape& output_shape) {
+ const int output_buffer_size = output_shape.FlatSize();
+ std::vector<std::uint8_t> output_data(output_buffer_size);
+ std::vector<std::uint8_t> reference_output_data(output_buffer_size);
+
+ tflite::DepthwiseParams op_params;
+ op_params.padding_type = PaddingType::kSame;
+ op_params.padding_values.width = pad_width;
+ op_params.padding_values.height = pad_height;
+ op_params.stride_width = stride;
+ op_params.stride_height = stride;
+ op_params.dilation_width_factor = 1;
+ op_params.dilation_height_factor = 1;
+ op_params.depth_multiplier = depth_multiplier;
+ op_params.quantized_activation_min = output_activation_min;
+ op_params.quantized_activation_max = output_activation_max;
+ op_params.input_offset = input_offset;
+ op_params.weights_offset = filter_offset;
+ op_params.output_offset = output_offset;
+ op_params.output_multiplier = output_multiplier;
+ op_params.output_shift = -output_shift;
+ reference_ops::DepthwiseConv(op_params, input_shape, input_data, filter_shape,
+ filter_data, bias_shape, bias_data, output_shape,
+ reference_output_data.data());
+ optimized_ops::DepthwiseConv(op_params, input_shape, input_data, filter_shape,
+ filter_data, bias_shape, bias_data, output_shape,
+ output_data.data());
+ int saturated_min = 0;
+ int saturated_max = 0;
+ std::vector<int> diff(output_buffer_size);
+ std::int64_t sum_diff = 0;
+ std::int64_t sum_abs_diff = 0;
+ for (int i = 0; i < output_buffer_size; i++) {
+ diff[i] = static_cast<int>(output_data[i]) -
+ static_cast<int>(reference_output_data[i]);
+ sum_diff += diff[i];
+ sum_abs_diff += std::abs(diff[i]);
+ saturated_min += output_data[i] == output_activation_min;
+ saturated_max += output_data[i] == output_activation_max;
+ }
+ // These stats help understand test failures.
+ std::sort(std::begin(diff), std::end(diff));
+ const int min_diff = diff.front();
+ const int max_diff = diff.back();
+ const int median_diff = diff[diff.size() / 2];
+ const float mean_diff = static_cast<float>(sum_diff) / output_buffer_size;
+ const float mean_abs_diff =
+ static_cast<float>(sum_abs_diff) / output_buffer_size;
+ // Normally we should require bit-for-bit exact results. Unfortunately a bug
+ // in the Intel arm_neon_sse.h translation header that we use for x86 tests
+ // causes 1-bit inaccuracy in the vqrdmulh_n_s32 intrinsic, which causes
+ // off-by-1 errors in quantized DepthwiseConv ops. So we have to live with a
+ // few off-by-one errors for now, yet still ensure that no more than a small
+ // minority of values are wrong.
+ EXPECT_TRUE(std::abs(mean_diff) < 1e-5f && mean_abs_diff < 1e-5f &&
+ std::abs(median_diff) == 0 && std::abs(min_diff) <= 1 &&
+ std::abs(max_diff) <= 1);
+ if (saturated_min > 2 * saturated_max) {
+ return -1;
+ }
+ if (saturated_max > 2 * saturated_min) {
+ return 1;
+ }
+ return 0;
+}
+
+// The point of this function is that we can't practically know which
+// output_shift value to pass to test DepthwiseConv. It's not easy to guess
+// (we could do some statistics for large sizes, but they would be fragile at
+// smaller sizes), and guessing wrong would mean that all the values get
+// saturated, so the test becomes vacuous. So we just bisect our way to
+// reasonable output_shift values.
+template <FusedActivationFunctionType Ac>
+void TestOneDepthwiseConvBisectOutputShift(
+ const std::uint8_t* input_data, const RuntimeShape& input_shape,
+ std::int32_t input_offset, const std::uint8_t* filter_data,
+ const RuntimeShape& filter_shape, std::int32_t filter_offset,
+ const std::int32_t* bias_data, const RuntimeShape& bias_shape, int stride,
+ int pad_width, int pad_height, int depth_multiplier,
+ std::int32_t output_offset, std::int32_t output_multiplier,
+ int output_activation_bisect_start, int output_activation_bisect_end,
+ std::int32_t output_activation_min, std::int32_t output_activation_max,
+ const RuntimeShape& output_shape) {
+ ASSERT_LT(output_activation_bisect_start, output_activation_bisect_end)
+ << "Bisection failed ?!?!";
+ int output_shift_bisect_midpoint =
+ (output_activation_bisect_start + output_activation_bisect_end) / 2;
+ int bisect_result = TestOneDepthwiseConvWithGivenOutputShift<Ac>(
+ input_data, input_shape, input_offset, filter_data, filter_shape,
+ filter_offset, bias_data, bias_shape, stride, pad_width, pad_height,
+ depth_multiplier, output_offset, output_multiplier,
+ output_shift_bisect_midpoint, output_activation_min,
+ output_activation_max, output_shape);
+ // At this point we know that the test succeeded (otherwise it would have
+ // aborted).
+ if (bisect_result == 0) {
+ // The result isn't particularly saturated on one or the other side.
+ // All good, we're done.
+ return;
+ }
+ if (output_activation_bisect_start == output_activation_bisect_end - 1) {
+ // There is still some saturation on one side, but the bisection is
+ // finished anyway. We're done; there is nothing more we can do about it.
+ // This happens in particular when using an activation with a narrow range.
+ return;
+ }
+ // Continue the bisection based on the present result.
+ int new_output_activation_bisect_start = bisect_result == 1
+ ? output_shift_bisect_midpoint
+ : output_activation_bisect_start;
+ int new_output_activation_bisect_end = bisect_result == 1
+ ? output_activation_bisect_end
+ : output_shift_bisect_midpoint;
+ TestOneDepthwiseConvBisectOutputShift<Ac>(
+ input_data, input_shape, input_offset, filter_data, filter_shape,
+ filter_offset, bias_data, bias_shape, stride, pad_width, pad_height,
+ depth_multiplier, output_offset, output_multiplier,
+ new_output_activation_bisect_start, new_output_activation_bisect_end,
+ output_activation_min, output_activation_max, output_shape);
+}
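As a concrete trace of the bisection (the saturation outcomes here are made up for illustration): starting from the bounds (0, 32) passed in by TestOneDepthwiseConv below, the first midpoint tried is 16; a bisect_result of 1 (outputs mostly clamped to output_activation_max) narrows the search to (16, 32), a result of -1 (mostly clamped to the minimum) narrows it to (0, 16), and the recursion stops once a shift produces output that is not dominated by either bound, or the interval can no longer be split.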
+
+template <FusedActivationFunctionType Ac>
+void TestOneDepthwiseConv(
+ const std::uint8_t* input_data, const RuntimeShape& input_shape,
+ std::int32_t input_offset, const std::uint8_t* filter_data,
+ const RuntimeShape& filter_shape, std::int32_t filter_offset,
+ const std::int32_t* bias_data, const RuntimeShape& bias_shape, int stride,
+ int pad_width, int pad_height, int depth_multiplier,
+ std::int32_t output_offset, std::int32_t output_multiplier,
+ std::int32_t output_activation_min, std::int32_t output_activation_max,
+ const RuntimeShape& output_shape) {
+ TestOneDepthwiseConvBisectOutputShift<Ac>(
+ input_data, input_shape, input_offset, filter_data, filter_shape,
+ filter_offset, bias_data, bias_shape, stride, pad_width, pad_height,
+ depth_multiplier, output_offset, output_multiplier, 0, 32,
+ output_activation_min, output_activation_max, output_shape);
+}
+
+void TestOneDepthwiseConv(
+ FusedActivationFunctionType Ac, const std::uint8_t* input_data,
+ const RuntimeShape& input_shape, std::int32_t input_offset,
+ const std::uint8_t* filter_data, const RuntimeShape& filter_shape,
+ std::int32_t filter_offset, const std::int32_t* bias_data,
+ const RuntimeShape& bias_shape, int stride, int pad_width, int pad_height,
+ int depth_multiplier, std::int32_t output_offset,
+ std::int32_t output_multiplier, std::int32_t output_activation_min,
+ std::int32_t output_activation_max, const RuntimeShape& output_shape) {
+#define TOCO_HANDLE_CASE(AC_TYPE) \
+ if (AC_TYPE == Ac) { \
+ TestOneDepthwiseConv<AC_TYPE>( \
+ input_data, input_shape, input_offset, filter_data, filter_shape, \
+ filter_offset, bias_data, bias_shape, stride, pad_width, pad_height, \
+ depth_multiplier, output_offset, output_multiplier, \
+ output_activation_min, output_activation_max, output_shape); \
+ return; \
+ }
+ TOCO_HANDLE_CASE(FusedActivationFunctionType::kNone)
+ TOCO_HANDLE_CASE(FusedActivationFunctionType::kRelu)
+ TOCO_HANDLE_CASE(FusedActivationFunctionType::kRelu1)
+ TOCO_HANDLE_CASE(FusedActivationFunctionType::kRelu6)
+#undef TOCO_HANDLE_CASE
+}
+
+bool TryTestDepthwiseConv(int batch, int input_depth, int input_width,
+ int input_height, int filter_width, int filter_height,
+ int depth_multiplier, int stride,
+ int dilation_width_factor, int dilation_height_factor,
+ PaddingType padding_type) {
+ const int output_depth = input_depth * depth_multiplier;
+ // The optimized DepthwiseConv implementation currently uses a fixed-size
+ // accumulator buffer on the stack, whose size corresponds to
+ // kMaxSupportedOutputDepth below. This currently means that it does not
+ // support larger output depths. It CHECK's for that, so it's safe in the
+ // sense that if a larger output depth were encountered, it would fail
+ // explicitly. We just need to adjust our testing to that constraint.
+ const int kMaxSupportedOutputDepth = 1024;
+ if (output_depth > kMaxSupportedOutputDepth) {
+ return false;
+ }
+ const auto ac = RandomElement(std::vector<FusedActivationFunctionType>(
+ {FusedActivationFunctionType::kNone, FusedActivationFunctionType::kRelu,
+ FusedActivationFunctionType::kRelu6,
+ FusedActivationFunctionType::kRelu1}));
+ int output_activation_min = 0;
+ int output_activation_max = 255;
+ if (ac != FusedActivationFunctionType::kNone && UniformRandomInt(0, 1)) {
+ output_activation_min = UniformRandomInt(0, 50);
+ output_activation_max = UniformRandomInt(200, 255);
+ }
+ const std::int32_t output_multiplier =
+ UniformRandomInt(1 << 29, std::numeric_limits<std::int32_t>::max());
+ const std::int32_t input_offset = UniformRandomInt(-256, 0);
+ const std::int32_t filter_offset = UniformRandomInt(-256, 0);
+ const std::int32_t output_offset = UniformRandomInt(-256, 0);
+ RuntimeShape input_shape_inference(
+ {batch, input_height, input_width, input_depth});
+ RuntimeShape output_shape_inference;
+ int pad_width, pad_height;
+ if (!ComputeConvSizes(input_shape_inference, output_depth, filter_width,
+ filter_height, stride, dilation_width_factor,
+ dilation_height_factor, padding_type,
+ &output_shape_inference, &pad_width, &pad_height)) {
+ return false;
+ }
+ RuntimeShape filter_shape_inference(
+ {1, filter_height, filter_width, output_depth});
+ RuntimeShape bias_shape_inference({1, 1, 1, output_depth});
+ const int input_buffer_size = input_shape_inference.FlatSize();
+ const int filter_buffer_size = filter_shape_inference.FlatSize();
+ std::vector<std::uint8_t> input_data(input_buffer_size);
+ std::vector<std::uint8_t> filter_data(filter_buffer_size);
+ std::vector<std::int32_t> bias_data(output_depth);
+ FillRandom(&input_data);
+ FillRandom(&filter_data);
+ FillRandom(&bias_data, -10000, 10000);
+ TestOneDepthwiseConv(ac, input_data.data(), input_shape_inference,
+ input_offset, filter_data.data(), filter_shape_inference,
+ filter_offset, bias_data.data(), bias_shape_inference,
+ stride, pad_width, pad_height, depth_multiplier,
+ output_offset, output_multiplier, output_activation_min,
+ output_activation_max, output_shape_inference);
+ return true;
+}
+
+// This function picks some random DepthwiseConv params, which may or may not
+// be legal. If they're not legal, it returns false. If they're legal,
+// it runs the DepthwiseConv test and returns true. This allows the caller
+// to loop until a test has been run.
+bool TryTestOneDepthwiseConv() {
+ // We have to pick a lot of positive values, where we are particularly
+ // interested in small values because they are most likely to be special
+ // cases in optimized implementations, and secondarily because they allow
+ // tests to run fast, which means we can run more tests and get more
+ // coverage.
+ const int batch = ExponentialRandomPositiveInt(0.9f, 3, 20);
+ const int input_depth = ExponentialRandomPositiveInt(0.9f, 6, 50);
+ const int input_width = ExponentialRandomPositiveInt(0.9f, 20, 200);
+ const int input_height = ExponentialRandomPositiveInt(0.9f, 20, 200);
+ const int filter_width = ExponentialRandomPositiveInt(0.9f, 4, 10);
+ const int filter_height = ExponentialRandomPositiveInt(0.9f, 4, 10);
+ const int depth_multiplier = ExponentialRandomPositiveInt(0.8f, 6, 50);
+ const int stride = ExponentialRandomPositiveInt(0.9f, 3, 8);
+ const int dilation_width_factor = RandomElement(std::vector<int>({1, 2, 4}));
+ const int dilation_height_factor = RandomElement(std::vector<int>({1, 2, 4}));
+ const auto padding_type =
+ UniformRandomInt(0, 1) ? PaddingType::kSame : PaddingType::kValid;
+
+ return TryTestDepthwiseConv(batch, input_depth, input_width, input_height,
+ filter_width, filter_height, depth_multiplier,
+ stride, dilation_width_factor,
+ dilation_height_factor, padding_type);
+}
+
+// Tests parameters for the 3x3 filter kernel.
+bool TryTestOneDepthwiseConv3x3Filter() {
+ const int batch = ExponentialRandomPositiveInt(0.9f, 3, 20);
+ const int input_depth = 8 * ExponentialRandomPositiveInt(0.9f, 10, 50);
+ const int input_width = ExponentialRandomPositiveInt(0.9f, 20, 200);
+ const int input_height = ExponentialRandomPositiveInt(0.9f, 20, 200);
+ const int filter_width = 3;
+ const int filter_height = 3;
+ const int depth_multiplier = 1;
+ const int stride = UniformRandomInt(1, 2);
+ // We don't support dilations in the 3x3 filter.
+ const int dilation_width_factor = 1;
+ const int dilation_height_factor = 1;
+ // Although the kernel supports only kValid padding, we test that kSame
+ // is using the correct code path.
+ const auto padding_type =
+ UniformRandomInt(0, 1) ? PaddingType::kSame : PaddingType::kValid;
+
+ return TryTestDepthwiseConv(batch, input_depth, input_width, input_height,
+ filter_width, filter_height, depth_multiplier,
+ stride, dilation_width_factor,
+ dilation_height_factor, padding_type);
+}
+
+void TestOneDepthwiseConv() {
+ while (!TryTestOneDepthwiseConv()) {
+ }
+}
+
+void TestOneDepthwiseConv3x3Filter() {
+ while (!TryTestOneDepthwiseConv3x3Filter()) {
+ }
+}
+
+TEST(TestDepthwiseConv, TestDepthwiseConv) {
+ const int kTestsToRun = 10 * 1000;
+ for (int i = 0; i < kTestsToRun; i++) {
+ TestOneDepthwiseConv();
+ }
+}
+
+TEST(TestDepthwiseConv3x3Filter, TestDepthwiseConv) {
+ const int kTestsToRun = 3 * 1000;
+ for (int i = 0; i < kTestsToRun; i++) {
+ TestOneDepthwiseConv3x3Filter();
+ }
+}
+
+} // namespace
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc b/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc
index 5f9cfc450d..083e5839bd 100644
--- a/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc
+++ b/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc
@@ -14,8 +14,6 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h"
-#include <algorithm>
-
#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
namespace tflite {
@@ -26,6 +24,21 @@ void RnnBatchStep(const float* input_ptr_batch, const float* input_weights_ptr,
int input_size, int num_units, int batch_size,
TfLiteFusedActivation activation,
float* hidden_state_ptr_batch, float* output_ptr_batch) {
+ RnnBatchStep(input_ptr_batch, input_weights_ptr,
+ /*aux_input_ptr_batch=*/nullptr,
+ /*aux_input_weights_ptr=*/nullptr, recurrent_weights_ptr,
+ bias_ptr, input_size, /*aux_input_size=*/0, num_units,
+ batch_size, activation, hidden_state_ptr_batch,
+ output_ptr_batch);
+}
+
+void RnnBatchStep(const float* input_ptr_batch, const float* input_weights_ptr,
+ const float* aux_input_ptr_batch,
+ const float* aux_input_weights_ptr,
+ const float* recurrent_weights_ptr, const float* bias_ptr,
+ int input_size, int aux_input_size, int num_units,
+ int batch_size, TfLiteFusedActivation activation,
+ float* hidden_state_ptr_batch, float* output_ptr_batch) {
// Output = bias
tensor_utils::VectorBatchVectorAssign(bias_ptr, num_units, batch_size,
output_ptr_batch);
@@ -33,6 +46,12 @@ void RnnBatchStep(const float* input_ptr_batch, const float* input_weights_ptr,
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
input_weights_ptr, num_units, input_size, input_ptr_batch, batch_size,
output_ptr_batch, /*result_stride=*/1);
+ // Output += aux_input * aux_input_weights (if they are not empty).
+ if (aux_input_size > 0) {
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ aux_input_weights_ptr, num_units, aux_input_size, aux_input_ptr_batch,
+ batch_size, output_ptr_batch, /*result_stride=*/1);
+ }
// Output += recurrent_weights * hidden_state
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
recurrent_weights_ptr, num_units, num_units, hidden_state_ptr_batch,
@@ -52,21 +71,41 @@ void RnnBatchStep(const float* input_ptr_batch, const int8_t* input_weights_ptr,
TfLiteFusedActivation activation,
int8_t* quantized_input_ptr_batch,
int8_t* quantized_hidden_state_ptr_batch,
- float* hidden_state_ptr_batch, float* output_ptr_batch) {
+ float* scaling_factors, float* hidden_state_ptr_batch,
+ float* output_ptr_batch) {
+ RnnBatchStep(input_ptr_batch, input_weights_ptr, input_weights_scale,
+ /*aux_input_ptr_batch=*/nullptr,
+ /*aux_input_weights_ptr=*/nullptr,
+ /*aux_input_weights_scale=*/0.0f, recurrent_weights_ptr,
+ recurrent_weights_scale, bias_ptr, input_size,
+ /*aux_input_size=*/0, num_units, batch_size, activation,
+ quantized_input_ptr_batch,
+ /*aux_quantized_input_ptr_batch=*/nullptr,
+ quantized_hidden_state_ptr_batch, scaling_factors,
+ hidden_state_ptr_batch, output_ptr_batch);
+}
+
+void RnnBatchStep(
+ const float* input_ptr_batch, const int8_t* input_weights_ptr,
+ float input_weights_scale, const float* aux_input_ptr_batch,
+ const int8_t* aux_input_weights_ptr, float aux_input_weights_scale,
+ const int8_t* recurrent_weights_ptr, float recurrent_weights_scale,
+ const float* bias_ptr, int input_size, int aux_input_size, int num_units,
+ int batch_size, TfLiteFusedActivation activation,
+ int8_t* quantized_input_ptr_batch, int8_t* aux_quantized_input_ptr_batch,
+ int8_t* quantized_hidden_state_ptr_batch, float* scaling_factors,
+ float* hidden_state_ptr_batch, float* output_ptr_batch) {
// Output = bias
tensor_utils::VectorBatchVectorAssign(bias_ptr, num_units, batch_size,
output_ptr_batch);
- // TODO(mirkov): change std::minmax_element with a vectorized call.
- auto minmax_element = std::minmax_element(
- input_ptr_batch, input_ptr_batch + batch_size * input_size);
-
// Save quantization and matmul computation for all zero input.
- if (!(*minmax_element.first == 0.0 && *minmax_element.second == 0.0)) {
+ if (!tensor_utils::IsZeroVector(input_ptr_batch, batch_size * input_size)) {
// Quantize input from float to uint8 + quantization params (scaling
// factor).
float unused_min, unused_max;
- float* scaling_factors = new float[batch_size];
+ // TODO(mirkov,raziel): replace this for-loop with a MACRO (or function),
+ // whichever is faster.
for (int b = 0; b < batch_size; ++b) {
const int offset = b * input_size;
tensor_utils::SymmetricQuantizeFloats(
@@ -80,16 +119,33 @@ void RnnBatchStep(const float* input_ptr_batch, const int8_t* input_weights_ptr,
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
input_weights_ptr, num_units, input_size, quantized_input_ptr_batch,
scaling_factors, batch_size, output_ptr_batch, /*result_stride=*/1);
- delete[] scaling_factors;
}
- minmax_element = std::minmax_element(
- hidden_state_ptr_batch, hidden_state_ptr_batch + batch_size * num_units);
+ if (aux_input_ptr_batch &&
+ !tensor_utils::IsZeroVector(aux_input_ptr_batch,
+ batch_size * aux_input_size)) {
+ float unused_min, unused_max;
+ for (int b = 0; b < batch_size; ++b) {
+ const int offset = b * aux_input_size;
+ tensor_utils::SymmetricQuantizeFloats(
+ aux_input_ptr_batch + offset, aux_input_size,
+ aux_quantized_input_ptr_batch + offset, &unused_min, &unused_max,
+ &scaling_factors[b]);
+ scaling_factors[b] *= aux_input_weights_scale;
+ }
+
+ // Output += aux_input * aux_input_weights
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ aux_input_weights_ptr, num_units, aux_input_size,
+ aux_quantized_input_ptr_batch, scaling_factors, batch_size,
+ output_ptr_batch, /*result_stride=*/1);
+ }
+
// Save quantization and matmul computation for all zero input.
- if (!(*minmax_element.first == 0.0 && *minmax_element.second == 0.0)) {
+ if (!tensor_utils::IsZeroVector(hidden_state_ptr_batch,
+ batch_size * num_units)) {
// Quantize hidden_state
float unused_min, unused_max;
- float* scaling_factors = new float[batch_size];
for (int b = 0; b < batch_size; ++b) {
const int offset = b * num_units;
tensor_utils::SymmetricQuantizeFloats(
@@ -104,7 +160,6 @@ void RnnBatchStep(const float* input_ptr_batch, const int8_t* input_weights_ptr,
recurrent_weights_ptr, num_units, num_units,
quantized_hidden_state_ptr_batch, scaling_factors, batch_size,
output_ptr_batch, /*result_stride=*/1);
- delete[] scaling_factors;
}
// Output = activation(Output) and update hidden_state
@@ -114,152 +169,5 @@ void RnnBatchStep(const float* input_ptr_batch, const int8_t* input_weights_ptr,
hidden_state_ptr_batch);
}
-void LstmStep(
- const float* input_ptr_batch, const float* input_to_input_weights_ptr,
- const float* input_to_forget_weights_ptr,
- const float* input_to_cell_weights_ptr,
- const float* input_to_output_weights_ptr,
- const float* recurrent_to_input_weights_ptr,
- const float* recurrent_to_forget_weights_ptr,
- const float* recurrent_to_cell_weights_ptr,
- const float* recurrent_to_output_weights_ptr,
- const float* cell_to_input_weights_ptr,
- const float* cell_to_forget_weights_ptr,
- const float* cell_to_output_weights_ptr, const float* input_gate_bias_ptr,
- const float* forget_gate_bias_ptr, const float* cell_bias_ptr,
- const float* output_gate_bias_ptr, const float* projection_weights_ptr,
- const float* projection_bias_ptr, const TfLiteLSTMParams* params,
- int n_batch, int n_cell, int n_input, int n_output, float* output_state_ptr,
- float* cell_state_ptr, float* input_gate_scratch,
- float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch,
- float* output_ptr_batch) {
- // Since we have already checked that weights are all there or none, we can
- // check the existense of only one to the get the condition.
- const bool use_cifg = (input_to_input_weights_ptr == nullptr);
- const bool use_peephole = (cell_to_output_weights_ptr != nullptr);
- // Initialize scratch buffers with bias.
- if (!use_cifg) {
- tensor_utils::VectorBatchVectorAssign(input_gate_bias_ptr, n_cell, n_batch,
- input_gate_scratch);
- }
- tensor_utils::VectorBatchVectorAssign(forget_gate_bias_ptr, n_cell, n_batch,
- forget_gate_scratch);
- tensor_utils::VectorBatchVectorAssign(cell_bias_ptr, n_cell, n_batch,
- cell_scratch);
- tensor_utils::VectorBatchVectorAssign(output_gate_bias_ptr, n_cell, n_batch,
- output_gate_scratch);
-
- // For each batch and cell: compute input_weight * input.
- if (!use_cifg) {
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- input_to_input_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch,
- input_gate_scratch, /*result_stride=*/1);
- }
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- input_to_forget_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch,
- forget_gate_scratch, /*result_stride=*/1);
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- input_to_cell_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch,
- cell_scratch, /*result_stride=*/1);
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- input_to_output_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch,
- output_gate_scratch, /*result_stride=*/1);
-
- // For each batch and cell: compute recurrent_weight * output_state.
- if (!use_cifg) {
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- recurrent_to_input_weights_ptr, n_cell, n_output, output_state_ptr,
- n_batch, input_gate_scratch,
- /*result_stride=*/1);
- }
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- recurrent_to_forget_weights_ptr, n_cell, n_output, output_state_ptr,
- n_batch, forget_gate_scratch,
- /*result_stride=*/1);
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- recurrent_to_cell_weights_ptr, n_cell, n_output, output_state_ptr,
- n_batch, cell_scratch, /*result_stride=*/1);
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- recurrent_to_output_weights_ptr, n_cell, n_output, output_state_ptr,
- n_batch, output_gate_scratch,
- /*result_stride=*/1);
-
- // For each batch and cell: update input gate.
- if (!use_cifg) {
- if (use_peephole) {
- tensor_utils::VectorBatchVectorCwiseProductAccumulate(
- cell_to_input_weights_ptr, n_cell, cell_state_ptr, n_batch,
- input_gate_scratch);
- }
- tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch,
- input_gate_scratch);
- }
-
- // For each batch and cell: update forget gate.
- if (use_peephole) {
- tensor_utils::VectorBatchVectorCwiseProductAccumulate(
- cell_to_forget_weights_ptr, n_cell, cell_state_ptr, n_batch,
- forget_gate_scratch);
- }
- tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch,
- forget_gate_scratch);
-
- // For each batch and cell: update the cell.
- tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, cell_state_ptr,
- n_batch * n_cell, cell_state_ptr);
- tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell,
- params->activation, cell_scratch);
- if (use_cifg) {
- tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell,
- forget_gate_scratch);
- tensor_utils::VectorVectorCwiseProductAccumulate(
- cell_scratch, forget_gate_scratch, n_batch * n_cell, cell_state_ptr);
- } else {
- tensor_utils::VectorVectorCwiseProductAccumulate(
- cell_scratch, input_gate_scratch, n_batch * n_cell, cell_state_ptr);
- }
- if (params->cell_clip > 0.0) {
- tensor_utils::ClipVector(cell_state_ptr, n_batch * n_cell,
- params->cell_clip, cell_state_ptr);
- }
-
- // For each batch and cell: update the output gate.
- if (use_peephole) {
- tensor_utils::VectorBatchVectorCwiseProductAccumulate(
- cell_to_output_weights_ptr, n_cell, cell_state_ptr, n_batch,
- output_gate_scratch);
- }
- tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell,
- output_gate_scratch);
- tensor_utils::ApplyActivationToVector(cell_state_ptr, n_batch * n_cell,
- params->activation, cell_scratch);
- tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch,
- n_batch * n_cell, output_gate_scratch);
-
- // For each batch: update the projection and output_state.
- const bool use_projection_weight = (projection_weights_ptr != nullptr);
- const bool use_projection_bias = (projection_bias_ptr != nullptr);
- if (use_projection_weight) {
- if (use_projection_bias) {
- tensor_utils::VectorBatchVectorAssign(projection_bias_ptr, n_output,
- n_batch, output_ptr_batch);
- } else {
- tensor_utils::ZeroVector(output_ptr_batch, n_batch * n_output);
- }
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- projection_weights_ptr, n_output, n_cell, output_gate_scratch, n_batch,
- output_ptr_batch, /*result_stride=*/1);
- if (params->proj_clip > 0.0) {
- tensor_utils::ClipVector(output_ptr_batch, n_batch * n_output,
- params->proj_clip, output_ptr_batch);
- }
- } else {
- tensor_utils::CopyVector(output_gate_scratch, n_batch * n_output,
- output_ptr_batch);
- }
- tensor_utils::CopyVector(output_ptr_batch, n_batch * n_output,
- output_state_ptr);
-}
-
} // namespace kernel_utils
} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/internal/kernel_utils.h b/tensorflow/contrib/lite/kernels/internal/kernel_utils.h
index cbfbcbeefc..74e0a4a53d 100644
--- a/tensorflow/contrib/lite/kernels/internal/kernel_utils.h
+++ b/tensorflow/contrib/lite/kernels/internal/kernel_utils.h
@@ -15,7 +15,7 @@ limitations under the License.
#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_KERNEL_UTILS_H_
#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_KERNEL_UTILS_H_
-#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
namespace tflite {
namespace kernel_utils {
@@ -35,12 +35,24 @@ void RnnBatchStep(const float* input_ptr_batch, const float* input_weights_ptr,
TfLiteFusedActivation activation,
float* hidden_state_ptr_batch, float* output_ptr_batch);
+// Same as above but includes an auxiliary input with the corresponding weights.
+void RnnBatchStep(const float* input_ptr_batch, const float* input_weights_ptr,
+ const float* aux_input_ptr_batch,
+ const float* aux_input_weights_ptr,
+ const float* recurrent_weights_ptr, const float* bias_ptr,
+ int input_size, int aux_input_size, int num_units,
+ int batch_size, TfLiteFusedActivation activation,
+ float* hidden_state_ptr_batch, float* output_ptr_batch);
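Putting together the steps from kernel_utils.cc above, each call amounts to the following (a summary sketch, not a new interface; the aux term is simply skipped when aux_input_size is 0):

    // output = activation(bias
    //                     + input_weights * input
    //                     + aux_input_weights * aux_input
    //                     + recurrent_weights * hidden_state)
    // and the activated output is also written back into hidden_state for the
    // next time step.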
+
// Performs a quantized RNN batch inference step. Same as above, but for
// quantization purposes, we also pass in quantized_hidden_state_ptr_batch and
// quantized_input_ptr_batch pointers for temporary storage of the quantized
// values of hidden_state_ptr_batch and input_ptr_batch, respectively.
// These temporary storages are expected to be preallocated to the same size as
// the respective pointers.
+// An additional preallocated temporary storage 'scaling_factors' (of size
+// batch_size) is used to store the scaling factors of the quantization (used
+// for recovery).
// {input,recurrent}_weights_scale params are used for dequantization/recovery.
void RnnBatchStep(const float* input_ptr_batch, const int8_t* input_weights_ptr,
float input_weights_scale,
@@ -50,43 +62,19 @@ void RnnBatchStep(const float* input_ptr_batch, const int8_t* input_weights_ptr,
TfLiteFusedActivation activation,
int8_t* quantized_input_ptr_batch,
int8_t* quantized_hidden_state_ptr_batch,
- float* hidden_state_ptr_batch, float* output_ptr_batch);
+ float* scaling_factors, float* hidden_state_ptr_batch,
+ float* output_ptr_batch);
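A caller-side sketch of the preallocation described in the comment above (the buffer names are invented for illustration and are not part of the API):

    std::vector<int8_t> quantized_input(batch_size * input_size);
    std::vector<int8_t> quantized_hidden_state(batch_size * num_units);
    std::vector<float> scaling_factors(batch_size);  // one scaling factor per batch row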
-// Performs an LSTM batch inference step for input specified by input_ptr_batch.
-// The LSTM cell is specified by the pointers to its weights (*_weights_ptr) and
-// biases (*_bias_ptr), and buffers (*_scratch), along with additional
-// parameters:
-// - params: various LSTM params including activation, clipping, etc.,
-// - n_batch: size of batch,
-// - n_cell: number of cells (or units),
-// - n_input: the input size,
-// - n_output: the output size.
-//
-// The pointers to the cell and output state and the output are updated. Unless
-// projection is specified output and output state contain the same data.
-//
-// The pointers with the suffix "_batch" point to data aligned in batch_major
-// order, and each step processes batch_size many inputs from input_ptr_batch,
-// and updates batch_size many cell and output states.
-void LstmStep(
- const float* input_ptr_batch, const float* input_to_input_weights_ptr,
- const float* input_to_forget_weights_ptr,
- const float* input_to_cell_weights_ptr,
- const float* input_to_output_weights_ptr,
- const float* recurrent_to_input_weights_ptr,
- const float* recurrent_to_forget_weights_ptr,
- const float* recurrent_to_cell_weights_ptr,
- const float* recurrent_to_output_weights_ptr,
- const float* cell_to_input_weights_ptr,
- const float* cell_to_forget_weights_ptr,
- const float* cell_to_output_weights_ptr, const float* input_gate_bias_ptr,
- const float* forget_gate_bias_ptr, const float* cell_bias_ptr,
- const float* output_gate_bias_ptr, const float* projection_weights_ptr,
- const float* projection_bias_ptr, const TfLiteLSTMParams* params,
- int n_batch, int n_cell, int n_input, int n_output, float* output_state_ptr,
- float* cell_state_ptr, float* input_gate_scratch,
- float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch,
- float* output_ptr_batch);
+void RnnBatchStep(
+ const float* input_ptr_batch, const int8_t* input_weights_ptr,
+ float input_weights_scale, const float* aux_input_ptr_batch,
+ const int8_t* aux_input_weights_ptr, float aux_input_weights_scale,
+ const int8_t* recurrent_weights_ptr, float recurrent_weights_scale,
+ const float* bias_ptr, int input_size, int aux_input_size, int num_units,
+ int batch_size, TfLiteFusedActivation activation,
+ int8_t* quantized_input_ptr_batch, int8_t* aux_quantized_input_ptr_batch,
+ int8_t* quantized_hidden_state_ptr_batch, float* scaling_factors,
+ float* hidden_state_ptr_batch, float* output_ptr_batch);
} // namespace kernel_utils
} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/internal/log_quantized_test.cc b/tensorflow/contrib/lite/kernels/internal/log_quantized_test.cc
new file mode 100644
index 0000000000..8963abb9af
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/internal/log_quantized_test.cc
@@ -0,0 +1,334 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <algorithm>
+#include <cmath>
+#include <cstdlib>
+#include <functional>
+#include <iterator>
+#include <limits>
+#include <random>
+#include <sstream>
+#include <string>
+#include <vector>
+
+#define GEMMLOWP_ENABLE_FIXEDPOINT_CONSTANTS_CHECKS
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
+#include "tensorflow/contrib/lite/string.h"
+
+namespace tflite {
+
+class NumberGenerator {
+ public:
+ std::vector<int> RandomIntVector(int n, int min_val, int max_val) {
+ std::vector<int> vec(n);
+    // Divide by max() + 1 so the largest engine value maps to max_val rather
+    // than max_val + 1.
+    double scale =
+        static_cast<double>(max_val + 1 - min_val) / (engine_.max() + 1.0);
+ for (auto& it : vec) {
+ it = min_val + std::floor(engine_() * scale);
+ }
+ return vec;
+ }
+
+ std::mt19937 engine_;
+};
+
+class LogQuantizedTest : public ::testing::Test {
+ public:
+ NumberGenerator generator_;
+};
+
+// input_integer_bits <= 30. output_integer_bits > 0.
+inline int32 LogPositiveValuesViaFloat(int32 input_val, int input_integer_bits,
+ int output_integer_bits) {
+ const double float_log_sum_of_exps = std::log(
+ static_cast<double>(input_val) * 0.5 / (1 << (30 - input_integer_bits)));
+ static constexpr double min_int =
+ static_cast<double>(std::numeric_limits<int32>::min());
+ static constexpr double max_int =
+ static_cast<double>(std::numeric_limits<int32>::max());
+ double double_result = tflite::TfLiteRound(float_log_sum_of_exps *
+ (1 << (31 - output_integer_bits)));
+ return static_cast<std::int32_t>(
+ std::min(max_int, std::max(min_int, double_result)));
+}
+
+void CheckOutputData(const std::vector<int32>& test_output,
+ const std::vector<int32>& reference_output,
+ const std::vector<int32>& test_input,
+ const string& check_label, int input_integer_bits,
+ int output_integer_bits, int tolerance) {
+  // In the special case of a small input, specifically a raw value of 5,
+  // rounding up leads to a difference in the output. We do not aim to be
+  // accurate for very small input values, and there should be enough input
+  // fractional bits that such a raw value really is a small input.
+ static constexpr double error_from_rounding_up = 0.0224585;
+ const int n = test_output.size();
+ ASSERT_EQ(n, reference_output.size());
+ for (int i = 0; i < n; ++i) {
+ // Adjust tolerance when input <= 5*2^-(31-input_integer_bits).
+ const int adjusted_tolerance =
+ test_input[i] > 5
+ ? tolerance
+ : std::max(tolerance, static_cast<int>(std::ceil(
+ error_from_rounding_up *
+ (1 << (31 - output_integer_bits)))));
+ ASSERT_LE(std::abs(test_output[i] - reference_output[i]),
+ adjusted_tolerance)
+ << "Failure in \"" << check_label << "\" at i=" << i
+ << ", test_input[i]=" << test_input[i] << "="
+ << static_cast<double>(test_input[i]) / (1 << (31 - input_integer_bits))
+ << ", test_output[i]=" << test_output[i] << "="
+ << static_cast<double>(test_output[i]) /
+ (1 << (31 - output_integer_bits))
+ << ", reference_output[i]=" << reference_output[i] << "="
+ << static_cast<double>(reference_output[i]) /
+ (1 << (31 - output_integer_bits))
+ << ", difference[i]=" << std::abs(reference_output[i] - test_output[i])
+ << "="
+ << static_cast<double>(std::abs(reference_output[i] - test_output[i])) /
+ (1 << (31 - output_integer_bits))
+ << "; tolerance=" << tolerance
+ << ", adj tolerance=" << adjusted_tolerance;
+ }
+}
+
+void RightShiftVector(const std::vector<int32>& shifts,
+ std::vector<int32>* vec) {
+ const int n = vec->size();
+ ASSERT_EQ(n, shifts.size());
+ for (int i = 0; i < n; ++i) {
+ vec->at(i) = std::max(1, vec->at(i) >> shifts[i]);
+ }
+}
+
+template <int OutputIntegerBits, int InputIntegerBits>
+void RunSingleTest(const std::vector<int32>& test_input,
+ const string& check_label, int tolerance) {
+ const int n = test_input.size();
+ std::vector<int32> float_gen_output(n, 0);
+ std::vector<int32> reference_output(n, 0);
+ std::vector<int32> optimized_output(n, 0);
+
+  // Work around the fact that __builtin_clz(0u) is undefined and may yield 31
+  // instead of 32, by clamping every input up to at least 2.
+ std::vector<int32> fudged_input(n, 0);
+ for (int i = 0; i < n; ++i) {
+ fudged_input[i] = std::max(test_input[i], 2);
+ }
+
+ for (int i = 0; i < n; ++i) {
+ reference_output[i] =
+ tflite::reference_ops::log_x_for_x_greater_than_or_equal_to_1_impl<
+ OutputIntegerBits, InputIntegerBits>(
+ gemmlowp::FixedPoint<int32, InputIntegerBits>::FromRaw(
+ fudged_input[i]))
+ .raw();
+ optimized_output[i] =
+ tflite::optimized_ops::log_x_for_x_greater_than_or_equal_to_1_impl<
+ OutputIntegerBits, InputIntegerBits>(
+ gemmlowp::FixedPoint<int32, InputIntegerBits>::FromRaw(
+ fudged_input[i]))
+ .raw();
+ float_gen_output[i] = LogPositiveValuesViaFloat(
+ fudged_input[i], InputIntegerBits, OutputIntegerBits);
+ }
+  // Note that the first check uses a tolerance of zero.
+ {
+ std::ostringstream label;
+ label << check_label << " / optimized vs reference / InputIntegerBits="
+ << InputIntegerBits << ", OutputIntegerBits=" << OutputIntegerBits;
+ CheckOutputData(
+ optimized_output, reference_output, test_input, label.str(),
+ InputIntegerBits, OutputIntegerBits, 0);
+ }
+ {
+ std::ostringstream label;
+ label << check_label << " / reference vs float-gen / InputIntegerBits="
+ << InputIntegerBits << ", OutputIntegerBits=" << OutputIntegerBits;
+ CheckOutputData(
+ reference_output, float_gen_output, test_input, label.str(),
+ InputIntegerBits, OutputIntegerBits, tolerance);
+ }
+ {
+ std::ostringstream label;
+    label << check_label << " / optimized vs float-gen / InputIntegerBits="
+ << InputIntegerBits << ", OutputIntegerBits=" << OutputIntegerBits;
+ CheckOutputData(
+ optimized_output, float_gen_output, test_input, label.str(),
+ InputIntegerBits, OutputIntegerBits, tolerance);
+ }
+}
+
+template <int OutputIntegerBits>
+void RunSingleTest(const std::vector<int32>& test_input, int input_integer_bits,
+ const string& check_label, int tolerance) {
+#define INPUT_CASE(K) \
+ case K: \
+ return RunSingleTest<OutputIntegerBits, K>(test_input, check_label, \
+ tolerance)
+ switch (input_integer_bits) {
+ INPUT_CASE(0);
+ INPUT_CASE(1);
+ INPUT_CASE(2);
+ INPUT_CASE(3);
+ INPUT_CASE(4);
+ INPUT_CASE(5);
+ INPUT_CASE(6);
+ INPUT_CASE(7);
+ INPUT_CASE(8);
+ INPUT_CASE(9);
+ INPUT_CASE(10);
+ INPUT_CASE(11);
+ INPUT_CASE(12);
+ INPUT_CASE(13);
+ INPUT_CASE(14);
+ INPUT_CASE(15);
+ INPUT_CASE(16);
+ INPUT_CASE(17);
+ INPUT_CASE(18);
+ INPUT_CASE(19);
+ INPUT_CASE(20);
+ INPUT_CASE(21);
+ INPUT_CASE(22);
+ INPUT_CASE(23);
+ INPUT_CASE(24);
+ INPUT_CASE(25);
+ INPUT_CASE(26);
+ INPUT_CASE(27);
+ INPUT_CASE(28);
+ INPUT_CASE(29);
+ default:
+ ASSERT_LE(input_integer_bits, 30)
+ << "Input integer bits not handled: " << input_integer_bits;
+ }
+#undef INPUT_CASE
+}
+
+void RunSingleTest(const std::vector<int32>& test_input, int input_integer_bits,
+ int output_integer_bits, const string& check_label,
+ int tolerance) {
+#define OUTPUT_CASE(K) \
+ case K: \
+ return RunSingleTest<K>(test_input, input_integer_bits, check_label, \
+ tolerance)
+ switch (output_integer_bits) {
+ OUTPUT_CASE(0);
+ OUTPUT_CASE(1);
+ OUTPUT_CASE(2);
+ OUTPUT_CASE(3);
+ OUTPUT_CASE(4);
+ OUTPUT_CASE(5);
+ OUTPUT_CASE(6);
+ OUTPUT_CASE(7);
+ OUTPUT_CASE(8);
+ OUTPUT_CASE(9);
+ OUTPUT_CASE(10);
+ OUTPUT_CASE(11);
+ OUTPUT_CASE(12);
+ OUTPUT_CASE(13);
+ OUTPUT_CASE(14);
+ OUTPUT_CASE(15);
+ OUTPUT_CASE(16);
+ OUTPUT_CASE(17);
+ OUTPUT_CASE(18);
+ OUTPUT_CASE(19);
+ OUTPUT_CASE(20);
+ OUTPUT_CASE(21);
+ OUTPUT_CASE(22);
+ OUTPUT_CASE(23);
+ OUTPUT_CASE(24);
+ OUTPUT_CASE(25);
+ OUTPUT_CASE(26);
+ OUTPUT_CASE(27);
+ OUTPUT_CASE(28);
+ OUTPUT_CASE(29);
+ default:
+      ASSERT_LE(output_integer_bits, 30)
+          << "Output integer bits not handled: " << output_integer_bits;
+ }
+#undef OUTPUT_CASE
+}
+
+void RunUniformTest(int test_size, int input_integer_bits,
+ int output_integer_bits, const string& check_label,
+ int tolerance, NumberGenerator* generator) {
+ std::vector<int> test_data = generator->RandomIntVector(
+ test_size, 2, std::numeric_limits<int>::max() - 1);
+ test_data[0] = 2;
+ test_data[1] = 3;
+ test_data[2] = 4;
+ test_data[3] = std::numeric_limits<int32>::max() - 2;
+ test_data[4] = std::numeric_limits<int32>::max() - 1;
+ test_data[5] = std::numeric_limits<int32>::max();
+
+ RunSingleTest(test_data, input_integer_bits, output_integer_bits,
+ check_label + " / uniform test", tolerance);
+}
+
+void RunUniformShiftUniformTest(int test_size, int input_integer_bits,
+ int output_integer_bits,
+ const string& check_label, int tolerance,
+ NumberGenerator* generator) {
+ std::vector<int> test_data = generator->RandomIntVector(
+ test_size, 2, std::numeric_limits<int>::max() - 1);
+ std::vector<int> shifts = generator->RandomIntVector(test_size, 0, 29);
+ RightShiftVector(shifts, &test_data);
+
+ RunSingleTest(test_data, input_integer_bits, output_integer_bits,
+ check_label + " / shifted test", tolerance);
+}
+
+TEST_F(LogQuantizedTest, VariedIntegerBits) {
+ static constexpr int kVariations = 250;
+ static constexpr int kRunSize = 250;
+ static constexpr int kIntegerTolerance = 8;
+ static constexpr double kOutputFloatTolerance = 7.0e-7;
+
+ std::vector<int> input_integer_bits =
+ generator_.RandomIntVector(kVariations, 0, 24);
+ std::vector<int> output_integer_bits =
+ generator_.RandomIntVector(kVariations, 1, 10);
+
+ for (int i = 0; i < kVariations; ++i) {
+ int var_output_integer_bits = output_integer_bits[i];
+ int tolerance =
+ std::max(1.0 * kIntegerTolerance,
+ (1 << (31 - var_output_integer_bits)) * kOutputFloatTolerance);
+
+ RunUniformTest(kRunSize, input_integer_bits[i], var_output_integer_bits,
+ "VariedIntegerBits", tolerance, &generator_);
+ RunUniformShiftUniformTest(kRunSize, input_integer_bits[i],
+ var_output_integer_bits, "VariedIntegerBits",
+ tolerance, &generator_);
+ }
+}
+
+TEST_F(LogQuantizedTest, SelectedIntegerBits) {
+ static constexpr int kInputBits = 12;
+ static constexpr int kOutputBits = 5;
+ static constexpr int kRunSize = 100000;
+ static constexpr int kIntegerTolerance = 4;
+
+ RunUniformTest(kRunSize, kInputBits, kOutputBits, "SelectedIntegerBits",
+ kIntegerTolerance, &generator_);
+ RunUniformShiftUniformTest(kRunSize, kInputBits, kOutputBits,
+ "SelectedIntegerBits", kIntegerTolerance,
+ &generator_);
+}
+
+} // namespace tflite
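The float reference above relies on the convention that a raw int32 with N integer bits represents raw / 2^(31 - N); multiplying by 0.5 and dividing by 1 << (30 - input_integer_bits) performs exactly that scaling. A small standalone sketch of the round trip (the helper names and the Q5.26 choice are illustrative only):

#include <cmath>
#include <cstdint>
#include <cstdio>

// A raw value with `integer_bits` integer bits has 31 - integer_bits
// fractional bits, so it represents raw / 2^(31 - integer_bits).
double FromFixedPoint(int32_t raw, int integer_bits) {
  return static_cast<double>(raw) / (1u << (31 - integer_bits));
}

int32_t ToFixedPoint(double real, int integer_bits) {
  return static_cast<int32_t>(std::lround(real * (1u << (31 - integer_bits))));
}

int main() {
  const int kInputIntegerBits = 5;   // Q5.26 input.
  const int kOutputIntegerBits = 5;  // Q5.26 output.
  const int32_t raw_in = ToFixedPoint(2.718281828, kInputIntegerBits);
  const double log_val = std::log(FromFixedPoint(raw_in, kInputIntegerBits));
  const int32_t raw_out = ToFixedPoint(log_val, kOutputIntegerBits);
  // log(e) == 1, so raw_out should be close to 1 << 26.
  std::printf("raw_in=%d log=%f raw_out=%d (1 << 26 = %d)\n",
              static_cast<int>(raw_in), log_val, static_cast<int>(raw_out),
              1 << 26);
  return 0;
}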
diff --git a/tensorflow/contrib/lite/kernels/internal/logsoftmax_quantized_test.cc b/tensorflow/contrib/lite/kernels/internal/logsoftmax_quantized_test.cc
new file mode 100644
index 0000000000..2252ca1bcc
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/internal/logsoftmax_quantized_test.cc
@@ -0,0 +1,251 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <algorithm>
+#include <cmath>
+#include <cstdlib>
+#include <functional>
+#include <iterator>
+#include <limits>
+#include <random>
+#include <string>
+#include <vector>
+
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/test_util.h"
+#include "tensorflow/contrib/lite/string.h"
+
+namespace tflite {
+namespace {
+
+void RunLogSoftmaxFloatReference(const uint8* input_data,
+ const RuntimeShape& shape_common,
+ int32 input_offset, const double input_scale,
+ int stride, float beta,
+ uint8* reference_output_data) {
+ const int ref_buffer_size = shape_common.FlatSize();
+ std::vector<float> reference_dequant_data(ref_buffer_size);
+ std::vector<float> reference_output_float_data(ref_buffer_size);
+
+  // Reference data is generated by dequantizing the input to float and then
+  // applying the float LogSoftmax.
+ DequantizationParams dq_params;
+ dq_params.zero_point = input_offset;
+ dq_params.scale = input_scale;
+ reference_ops::Dequantize(dq_params, shape_common, input_data, shape_common,
+ reference_dequant_data.data());
+ SoftmaxParams sm_params;
+ optimized_ops::LogSoftmax(sm_params, shape_common,
+ reference_dequant_data.data(), shape_common,
+ reference_output_float_data.data());
+ // Work with quantized scaling for LogSoftmax, under which 255 represents 0,
+ // and -16 gets nudged up to 0.
+ for (int i = 0; i < ref_buffer_size; i++) {
+ reference_output_data[i] = std::max(
+ 0, static_cast<int>(
+ 255 + std::round(16.0f * reference_output_float_data[i])));
+ }
+}
+
+void CheckOutputData(const uint8* test_output, const uint8* reference_output,
+ const RuntimeShape& shape_common,
+ const string& check_label, bool be_exacting) {
+ const int buffer_size = shape_common.FlatSize();
+ // While calculating some metrics in floating point, we work with quantized
+ // scaling.
+ std::vector<int> diff(buffer_size);
+ int64_t sum_diff = 0;
+ int64_t sum_abs_diff = 0;
+ for (int i = 0; i < buffer_size; i++) {
+ diff[i] = static_cast<int>(test_output[i]) - reference_output[i];
+ sum_diff += diff[i];
+ sum_abs_diff += std::abs(diff[i]);
+ }
+ // These stats help understand test failures.
+ std::sort(std::begin(diff), std::end(diff));
+ const int min_diff = diff.front();
+ const int max_diff = diff.back();
+ const int median_diff = diff[diff.size() / 2];
+ const float mean_diff = static_cast<float>(sum_diff) / buffer_size;
+ const float mean_abs_diff = static_cast<float>(sum_abs_diff) / buffer_size;
+ // We either check for bit exactness (against the reference quantized version)
+ // or for general accuracy, allowing off-by-one (against the float reference).
+ if (be_exacting) {
+ ASSERT_TRUE(std::abs(min_diff) == 0 && std::abs(max_diff) == 0)
+ << check_label << ": "
+ << "std::abs(min_diff)=" << std::abs(min_diff)
+ << ", std::abs(max_diff)=" << std::abs(max_diff);
+ } else {
+    // For small numbers of samples, the estimates of the means vary more, so
+    // rather than widen the tolerances we skip the mean-based checks for
+    // small buffers.
+ ASSERT_TRUE(((std::abs(mean_diff) < 2e-2f && mean_abs_diff < 3e-2f) ||
+ buffer_size < 10000) &&
+ std::abs(median_diff) == 0 && std::abs(min_diff) <= 1 &&
+ std::abs(max_diff) <= 1)
+ << check_label << ": "
+ << "buffer_size=" << buffer_size << ", mean_diff=" << mean_diff
+ << ", mean_abs_diff=" << mean_abs_diff
+ << ", median_diff=" << median_diff << ", min_diff=" << min_diff
+ << ", max_diff=" << max_diff;
+ }
+}
+
+// Runs the LogSoftmax and compares against the float reference implementation
+// and the quantized reference implementation.
+void RunOneLogSoftmaxTest(const uint8* input_data,
+ const RuntimeShape& shape_common, int32 input_offset,
+ const double input_scale, int stride, float beta) {
+ const int buffer_size = shape_common.FlatSize();
+ std::vector<uint8> optimized_logsoftmax_output(buffer_size);
+ std::vector<uint8> reference_float_logsoftmax_output(buffer_size);
+ std::vector<uint8> reference_quant_logsoftmax_output(buffer_size);
+
+ RunLogSoftmaxFloatReference(input_data, shape_common, input_offset,
+ input_scale, stride, beta,
+ reference_float_logsoftmax_output.data());
+
+ int32 input_beta_multiplier;
+ int input_beta_left_shift;
+ int32 reverse_scaling_divisor;
+ int reverse_scaling_right_shift;
+ static const int kScaledDiffIntegerBits = 5;
+ tflite::PreprocessLogSoftmaxScalingExp(
+ beta, input_scale, kScaledDiffIntegerBits, &input_beta_multiplier,
+ &input_beta_left_shift, &reverse_scaling_divisor,
+ &reverse_scaling_right_shift);
+ reverse_scaling_right_shift *= -1;
+ // diff_min has a negative value, and is used to limit the maximum magnitude
+ // of the diffs, which are <= 0.
+ const int diff_min = -tflite::CalculateInputRadius(kScaledDiffIntegerBits,
+ input_beta_left_shift);
+
+ SoftmaxParams params;
+ params.input_multiplier = input_beta_multiplier;
+ params.input_left_shift = input_beta_left_shift;
+ params.reverse_scaling_divisor = reverse_scaling_divisor;
+ params.reverse_scaling_right_shift = reverse_scaling_right_shift;
+ params.diff_min = diff_min;
+ optimized_ops::LogSoftmax(params, shape_common, input_data, shape_common,
+ optimized_logsoftmax_output.data());
+ reference_ops::LogSoftmax(params, shape_common, input_data, shape_common,
+ reference_quant_logsoftmax_output.data());
+
+ CheckOutputData(optimized_logsoftmax_output.data(),
+ reference_float_logsoftmax_output.data(), shape_common,
+ "Optimized vs float reference", false);
+ CheckOutputData(optimized_logsoftmax_output.data(),
+ reference_quant_logsoftmax_output.data(), shape_common,
+ "Optimized vs quant reference", true);
+ CheckOutputData(reference_quant_logsoftmax_output.data(),
+ reference_float_logsoftmax_output.data(), shape_common,
+ "Quant reference vs float reference", false);
+}
+
+// This function picks some random LogSoftmax params, which are checked for
+// desirability. If not acceptable, it returns false. If they're OK,
+// it runs the LogSoftmax test and returns true. This allows the caller
+// to loop until a test has been run.
+//
+// Currently we do not reject for any reason.
+bool TryOneUniformLogSoftmax() {
+ // We pick mostly positive values, on the whole emphasizing smaller values and
+ // therefore faster tests. We test a wider range of depths. In the case of
+ // LogSoftmax, the width and height really just create test repetitions.
+ const int batch = ExponentialRandomPositiveInt(0.9f, 3, 20);
+ const int input_depth = ExponentialRandomPositiveInt(0.75f, 175, 500);
+ const int input_width = ExponentialRandomPositiveInt(0.8f, 20, 200);
+ const int input_height = ExponentialRandomPositiveInt(0.8f, 20, 200);
+ const int stride = ExponentialRandomPositiveInt(0.9f, 3, 8);
+ const double input_scale = std::pow(10.0, UniformRandomFloat(-2.0, 1.0));
+ const int32 input_offset = UniformRandomInt(-256, 0);
+ static constexpr float beta = 1.0f;
+
+ auto shape_common =
+ RuntimeShape({batch, input_height, input_width, input_depth});
+ const int buffer_size = shape_common.FlatSize();
+
+ std::vector<uint8> input_data(buffer_size);
+ FillRandom(&input_data);
+ RunOneLogSoftmaxTest(input_data.data(), shape_common, input_offset,
+ input_scale, stride, beta);
+ return true;
+}
+
+// See TryOneUniformLogSoftmax() for a general description.
+//
+// Tests with "skyscraper" input patterns are included for two reasons. (a)
+// Bimodal distributions are potentially challenging and perhaps more
+// realistic than simple uniform random inputs. (b) Some implementations of
+// LogSoftmax may adapt as they traverse the depth, and so we test handling of
+// cases where relatively small values are encountered at the beginning and end.
+bool TryOneSkyscraperLogSoftmax(bool small_depth) {
+ // We pick mostly positive values, on the whole emphasizing smaller values and
+ // therefore faster tests. We test a wider range of depths. In the case of
+ // LogSoftmax, the width and height really just create test repetitions.
+ const int batch = ExponentialRandomPositiveInt(0.9f, 3, 20);
+ const int input_depth = small_depth
+ ? ExponentialRandomPositiveInt(0.75f, 40, 500)
+ : ExponentialRandomPositiveInt(0.75f, 175, 500);
+ const int input_width = ExponentialRandomPositiveInt(0.7f, 20, 200);
+ const int input_height = ExponentialRandomPositiveInt(0.7f, 20, 200);
+ const int stride = ExponentialRandomPositiveInt(0.9f, 3, 8);
+ const double input_scale = std::pow(10.0, UniformRandomFloat(-2.0, 1.0));
+ const int32 input_offset = UniformRandomInt(-256, 0);
+ static constexpr float beta = 1.0f;
+ // Extra parameters for skyscraper input patterns.
+ const double middle_proportion =
+ ExponentialRandomPositiveFloat(0.65f, 0.1, 1.0);
+ const int middle_min = UniformRandomInt(0, 255);
+ const int sides_max = UniformRandomInt(0, middle_min);
+
+ auto shape_common =
+ RuntimeShape({batch, input_height, input_width, input_depth});
+ const int buffer_size = shape_common.FlatSize();
+
+ std::vector<uint8> input_data(buffer_size);
+ FillRandomSkyscraper(&input_data, input_depth, middle_proportion, middle_min,
+ sides_max);
+ RunOneLogSoftmaxTest(input_data.data(), shape_common, input_offset,
+ input_scale, stride, beta);
+ return true;
+}
+
+TEST(TestQuantizedLogSoftmax, UniformLogSoftmaxTests) {
+ const int kTestsToRun = 1000;
+ for (int i = 0; i < kTestsToRun; i++) {
+ while (!TryOneUniformLogSoftmax()) {
+ }
+ }
+}
+
+TEST(TestQuantizedLogSoftmax, SkyscraperLogSoftmaxTests) {
+ const int kTestsToRun = 1000;
+ for (int i = 0; i < kTestsToRun; i++) {
+ while (!TryOneSkyscraperLogSoftmax(false)) {
+ }
+ }
+}
+
+TEST(TestQuantizedLogSoftmax, SmallSkyscraperLogSoftmaxTests) {
+ const int kTestsToRun = 1000;
+ for (int i = 0; i < kTestsToRun; i++) {
+ while (!TryOneSkyscraperLogSoftmax(true)) {
+ }
+ }
+}
+} // namespace
+} // namespace tflite
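The float-reference path above pins the quantization of the LogSoftmax output to scale 1/16 with zero point 255: 255 encodes a log-probability of 0 and the representable range bottoms out near -16. A tiny sketch of reading such an output back as a float under that same convention (the helper name is made up):

#include <cstdint>
#include <cstdio>

// Dequantize a LogSoftmax output byte under the convention used by the test:
// scale = 1/16, zero_point = 255, so a byte q represents (q - 255) / 16.
float DequantizedLogSoftmax(uint8_t q) {
  return (static_cast<int>(q) - 255) / 16.0f;
}

int main() {
  std::printf("q=255 -> %f\n", DequantizedLogSoftmax(255));  // 0.0
  std::printf("q=239 -> %f\n", DequantizedLogSoftmax(239));  // -1.0
  std::printf("q=0   -> %f\n", DequantizedLogSoftmax(0));    // -15.9375
  return 0;
}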
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/cblas_conv.h b/tensorflow/contrib/lite/kernels/internal/optimized/cblas_conv.h
index 4a90e7e640..2d96da65c3 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/cblas_conv.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/cblas_conv.h
@@ -31,33 +31,50 @@ limitations under the License.
namespace tflite {
namespace cblas_ops {
-inline void Conv(const float* input_data, const Dims<4>& input_dims,
- const float* filter_data, const Dims<4>& filter_dims,
- const float* bias_data, const Dims<4>& bias_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, float output_activation_min,
- float output_activation_max, float* output_data,
- const Dims<4>& output_dims, float* im2col_data,
- const Dims<4>& im2col_dims) {
+inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
+ const float* input_data, const RuntimeShape& filter_shape,
+ const float* filter_data, const RuntimeShape& bias_shape,
+ const float* bias_data, const RuntimeShape& output_shape,
+ float* output_data, const RuntimeShape& im2col_shape,
+ float* im2col_data) {
+ const int stride_width = params.stride_width;
+ const int stride_height = params.stride_height;
+ const int pad_width = params.padding_values.width;
+ const int pad_height = params.padding_values.height;
+ const int dilation_width_factor = params.dilation_width_factor;
+ const int dilation_height_factor = params.dilation_height_factor;
+ const float output_activation_min = params.float_activation_min;
+ const float output_activation_max = params.float_activation_max;
+ TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
gemmlowp::ScopedProfilingLabel label("Conv/cblas");
const float* gemm_input_data = nullptr;
- const Dims<4>* gemm_input_dims = nullptr;
- const int filter_width = ArraySize(filter_dims, 1);
- const int filter_height = ArraySize(filter_dims, 2);
+ const RuntimeShape* gemm_input_shape = nullptr;
+ const int filter_width = filter_shape.Dims(2);
+ const int filter_height = filter_shape.Dims(1);
const bool need_im2col = stride_width != 1 || stride_height != 1 ||
filter_width != 1 || filter_height != 1;
if (need_im2col) {
TFLITE_DCHECK(im2col_data);
- optimized_ops::Im2col(input_data, input_dims, stride_width, stride_height,
- pad_width, pad_height, filter_height, filter_width, 0,
- im2col_data, im2col_dims);
+ ConvParams op_params;
+ op_params.padding_type = PaddingType::kSame;
+ op_params.padding_values.width = pad_width;
+ op_params.padding_values.height = pad_height;
+ op_params.stride_width = stride_width;
+ op_params.stride_height = stride_height;
+ op_params.dilation_width_factor = dilation_width_factor;
+ op_params.dilation_height_factor = dilation_height_factor;
+ optimized_ops::Im2col(op_params, filter_height, filter_width, 0,
+ input_shape, input_data, im2col_shape, im2col_data);
+
gemm_input_data = im2col_data;
- gemm_input_dims = &im2col_dims;
+ gemm_input_shape = &im2col_shape;
} else {
TFLITE_DCHECK(!im2col_data);
gemm_input_data = input_data;
- gemm_input_dims = &input_dims;
+ gemm_input_shape = &input_shape;
}
  // The following code computes matrix multiplication c = a * transpose(b)
@@ -69,10 +86,10 @@ inline void Conv(const float* input_data, const Dims<4>& input_dims,
const float* a = gemm_input_data;
const float* b = filter_data;
float* c = output_data;
- int m = gemm_input_dims->sizes[1] * gemm_input_dims->sizes[2] *
- gemm_input_dims->sizes[3];
- int n = output_dims.sizes[0];
- int k = gemm_input_dims->sizes[0];
+ const int gemm_input_dims = gemm_input_shape->DimensionsCount();
+ int m = FlatSizeSkipDim(*gemm_input_shape, gemm_input_dims - 1);
+ int n = output_shape.Dims(3);
+ int k = gemm_input_shape->Dims(gemm_input_dims - 1);
// The stride of matrix a, b and c respectively.
int stride_a = k;
int stride_b = k;
@@ -82,8 +99,8 @@ inline void Conv(const float* input_data, const Dims<4>& input_dims,
stride_a, b, stride_b, 0.0f, c, stride_c);
optimized_ops::AddBiasAndEvalActivationFunction(
- bias_data, bias_dims, output_data, output_dims, output_activation_min,
- output_activation_max);
+ output_activation_min, output_activation_max, bias_shape, bias_data,
+ output_shape, output_data);
}
} // namespace cblas_ops
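The new cblas Conv above takes a ConvParams struct and RuntimeShape objects instead of the old scalar argument list. Below is a sketch of the no-im2col case (1x1 filter, stride 1); the concrete shapes, the kValid padding type, and the default-constructed im2col shape are illustrative assumptions, and linking still requires a CBLAS implementation.

#include <limits>
#include <vector>

#include "tensorflow/contrib/lite/kernels/internal/optimized/cblas_conv.h"
#include "tensorflow/contrib/lite/kernels/internal/types.h"

void ExampleCblasConv1x1() {
  using tflite::RuntimeShape;
  // NHWC shapes: 1x4x4x3 input, eight 1x1x3 filters, 1x4x4x8 output.
  const RuntimeShape input_shape({1, 4, 4, 3});
  const RuntimeShape filter_shape({8, 1, 1, 3});
  const RuntimeShape bias_shape({1, 1, 1, 8});
  const RuntimeShape output_shape({1, 4, 4, 8});

  std::vector<float> input(input_shape.FlatSize(), 1.0f);
  std::vector<float> filter(filter_shape.FlatSize(), 0.1f);
  std::vector<float> bias(bias_shape.FlatSize(), 0.0f);
  std::vector<float> output(output_shape.FlatSize());

  tflite::ConvParams params;
  params.padding_type = tflite::PaddingType::kValid;
  params.padding_values.width = 0;
  params.padding_values.height = 0;
  params.stride_width = 1;
  params.stride_height = 1;
  params.dilation_width_factor = 1;
  params.dilation_height_factor = 1;
  params.float_activation_min = 0.0f;  // e.g. a fused ReLU lower bound
  params.float_activation_max = std::numeric_limits<float>::max();

  // A 1x1 filter with stride 1 needs no im2col buffer, so pass an empty
  // shape and a null pointer.
  tflite::cblas_ops::Conv(params, input_shape, input.data(), filter_shape,
                          filter.data(), bias_shape, bias.data(), output_shape,
                          output.data(), RuntimeShape(), nullptr);
}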
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/cpu_check.h b/tensorflow/contrib/lite/kernels/internal/optimized/cpu_check.h
index 3a53d3ab07..934308ef29 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/cpu_check.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/cpu_check.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_CPU_CHECK_
-#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_CPU_CHECK_
+#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_CPU_CHECK_H_
+#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_CPU_CHECK_H_
namespace tflite {
@@ -58,4 +58,4 @@ inline bool TestCPUFeatureNeon() { return false; }
: Portable##funcname(__VA_ARGS__)
#endif
-#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_CPU_CHECK_
+#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_CPU_CHECK_H_
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h
index 7f6eea2d5d..d8dd7bba89 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h
@@ -761,7 +761,8 @@ struct FloatDepthwiseConvKernel<true, 4, 1> {
// Accumulates the effect of one row of the filter, on a segment of one row
// of the output, accessing the corresponding one row of the input.
template <bool kAllowStrided, int kFixedInputDepth, int kFixedDepthMultiplier>
-void FloatDepthwiseConvAccumRow(int stride, int input_depth, int input_width,
+void FloatDepthwiseConvAccumRow(int stride, int dilation_factor,
+ int input_depth, int input_width,
const float* input_data, int pad_width,
int depth_multiplier, int filter_width,
const float* filter_data,
@@ -835,10 +836,10 @@ void FloatDepthwiseConvAccumRow(int stride, int input_depth, int input_width,
// generic fallback of FloatDepthwiseConvAccumRow, portable, non-templatized.
inline void FloatDepthwiseConvAccumRowGeneric(
- int stride, int input_depth, int input_width, const float* input_data,
- int pad_width, int depth_multiplier, int filter_width,
- const float* filter_data, int out_x_buffer_start, int out_x_buffer_end,
- int output_depth, float* acc_buffer) {
+ int stride, int dilation_factor, int input_depth, int input_width,
+ const float* input_data, int pad_width, int depth_multiplier,
+ int filter_width, const float* filter_data, int out_x_buffer_start,
+ int out_x_buffer_end, int output_depth, float* acc_buffer) {
gemmlowp::ScopedProfilingLabel label("DepthwiseConvAccumRowGeneric (slow)");
#ifdef TFLITE_PREVENT_SLOW_GENERIC_DEPTHWISECONV_FALLBACK
#ifndef ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK
@@ -860,6 +861,7 @@ inline void FloatDepthwiseConvAccumRowGeneric(
<< "* stride = " << stride << "\n"
<< "* input_depth = " << input_depth << "\n"
<< "* depth_multiplier = " << depth_multiplier << "\n"
+ << "* dilation_factor = " << dilation_factor << "\n"
<< "*\n"
<< "* Please do not hesitate to contact benoitjacob@ with this\n"
<< "* information.\n"
@@ -869,14 +871,17 @@ inline void FloatDepthwiseConvAccumRowGeneric(
const float* filter_base_ptr = filter_data;
for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
const int out_x_loop_start = std::max(
- out_x_buffer_start, (pad_width - filter_x + stride - 1) / stride);
- const int out_x_loop_end =
- std::min(out_x_buffer_end,
- (pad_width + input_width - filter_x + stride - 1) / stride);
+ out_x_buffer_start,
+ (pad_width - dilation_factor * filter_x + stride - 1) / stride);
+ const int out_x_loop_end = std::min(
+ out_x_buffer_end,
+ (pad_width + input_width - dilation_factor * filter_x + stride - 1) /
+ stride);
float* acc_buffer_ptr =
acc_buffer + (out_x_loop_start - out_x_buffer_start) * output_depth;
- const int in_x_origin = (out_x_loop_start * stride) - pad_width + filter_x;
+ const int in_x_origin =
+ (out_x_loop_start * stride) - pad_width + dilation_factor * filter_x;
const float* input_ptr = input_data + in_x_origin * input_depth;
const int input_ptr_increment = (stride - 1) * input_depth;
for (int out_x = out_x_loop_start; out_x < out_x_loop_end; out_x++) {
@@ -907,25 +912,37 @@ inline void DepthwiseConvInitAccBuffer(int num_output_pixels, int output_depth,
}
}
-inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
- const float* filter_data, const Dims<4>& filter_dims,
- const float* bias_data, const Dims<4>& bias_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, int depth_multiplier,
- float output_activation_min,
- float output_activation_max, float* output_data,
- const Dims<4>& output_dims) {
+inline void DepthwiseConv(
+ const DepthwiseParams& params, const RuntimeShape& input_shape,
+ const float* input_data, const RuntimeShape& filter_shape,
+ const float* filter_data, const RuntimeShape& bias_shape,
+ const float* bias_data, const RuntimeShape& output_shape,
+ float* output_data) {
gemmlowp::ScopedProfilingLabel label("DepthwiseConv");
- const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
- const int output_depth = MatchingArraySize(filter_dims, 0, output_dims, 0);
- const int input_height = ArraySize(input_dims, 2);
- const int input_width = ArraySize(input_dims, 1);
- const int input_depth = ArraySize(input_dims, 0);
- const int filter_height = ArraySize(filter_dims, 2);
- const int filter_width = ArraySize(filter_dims, 1);
- const int output_height = ArraySize(output_dims, 2);
- const int output_width = ArraySize(output_dims, 1);
- TFLITE_DCHECK(output_depth == input_depth * depth_multiplier);
+ const int stride_width = params.stride_width;
+ const int stride_height = params.stride_height;
+ const int pad_width = params.padding_values.width;
+ const int pad_height = params.padding_values.height;
+ const int depth_multiplier = params.depth_multiplier;
+ const float output_activation_min = params.float_activation_min;
+ const float output_activation_max = params.float_activation_max;
+ const int dilation_width_factor = params.dilation_width_factor;
+ const int dilation_height_factor = params.dilation_height_factor;
+ TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+
+ const int batches = MatchingDim(input_shape, 0, output_shape, 0);
+ const int output_depth = MatchingDim(filter_shape, 3, output_shape, 3);
+ const int input_height = input_shape.Dims(1);
+ const int input_width = input_shape.Dims(2);
+ const int input_depth = input_shape.Dims(3);
+ const int filter_height = filter_shape.Dims(1);
+ const int filter_width = filter_shape.Dims(2);
+ const int output_height = output_shape.Dims(1);
+ const int output_width = output_shape.Dims(2);
+ TFLITE_DCHECK_EQ(output_depth, input_depth * depth_multiplier);
+ TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth);
static const int kAccBufferMaxSize = 2048;
float acc_buffer[kAccBufferMaxSize];
@@ -946,7 +963,8 @@ inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
FIXED_DEPTH_MULTIPLIER) \
if (!row_accum_func && (stride_width == 1 || ALLOW_STRIDED) && \
(input_depth == FIXED_INPUT_DEPTH || FIXED_INPUT_DEPTH == 0) && \
- depth_multiplier == FIXED_DEPTH_MULTIPLIER) { \
+ depth_multiplier == FIXED_DEPTH_MULTIPLIER && \
+ dilation_height_factor == 1 && dilation_width_factor == 1) { \
row_accum_func = \
FloatDepthwiseConvAccumRow<ALLOW_STRIDED, FIXED_INPUT_DEPTH, \
FIXED_DEPTH_MULTIPLIER>; \
@@ -990,14 +1008,22 @@ inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
row_accum_func = FloatDepthwiseConvAccumRowGeneric;
}
+ const int input_height_stride = input_shape.Dims(3) * input_shape.Dims(2);
+ const int input_batch_stride = input_height_stride * input_shape.Dims(1);
+ const int filter_height_stride = filter_shape.Dims(3) * filter_shape.Dims(2);
+
// Now that we have determined row_accum_func, we can start work.
float* output_ptr = output_data;
for (int b = 0; b < batches; ++b) {
for (int out_y = 0; out_y < output_height; ++out_y) {
const int in_y_origin = (out_y * stride_height) - pad_height;
- const int filter_y_start = std::max(0, -in_y_origin);
+ const int filter_y_start =
+ std::max(0, (-in_y_origin + dilation_height_factor - 1) /
+ dilation_height_factor);
const int filter_y_end =
- std::min(filter_height, input_height - in_y_origin);
+ std::min(filter_height,
+ (input_height - in_y_origin + dilation_height_factor - 1) /
+ dilation_height_factor);
for (int out_x_buffer_start = 0; out_x_buffer_start < output_width;
out_x_buffer_start += kOutputPixelsInAccBuffer) {
const int out_x_buffer_end = std::min(
@@ -1013,14 +1039,13 @@ inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
// Accumulation loop. Most of the time should be spent in here.
for (int filter_y = filter_y_start; filter_y < filter_y_end;
++filter_y) {
- const int in_y = in_y_origin + filter_y;
- row_accum_func(stride_width, input_depth, input_width,
- input_data + in_y * input_dims.strides[2] +
- b * input_dims.strides[3],
- pad_width, depth_multiplier, filter_width,
- filter_data + filter_y * filter_dims.strides[2],
- out_x_buffer_start, out_x_buffer_end, output_depth,
- acc_buffer);
+ const int in_y = in_y_origin + dilation_height_factor * filter_y;
+ row_accum_func(
+ stride_width, dilation_width_factor, input_depth, input_width,
+ input_data + in_y * input_height_stride + b * input_batch_stride,
+ pad_width, depth_multiplier, filter_width,
+ filter_data + filter_y * filter_height_stride, out_x_buffer_start,
+ out_x_buffer_end, output_depth, acc_buffer);
}
// Finished accumulating. Now store to destination.
const int num_output_values = output_depth * num_output_pixels;
@@ -1067,34 +1092,6 @@ inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
}
}
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
- const float* filter_data, const Dims<4>& filter_dims,
- const float* bias_data, const Dims<4>& bias_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, int depth_multiplier, float* output_data,
- const Dims<4>& output_dims) {
- float output_activation_min, output_activation_max;
- GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
- DepthwiseConv(input_data, input_dims, filter_data, filter_dims, bias_data,
- bias_dims, stride_width, stride_height, pad_width, pad_height,
- depth_multiplier, output_activation_min, output_activation_max,
- output_data, output_dims);
-}
-
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
- const float* filter_data, const Dims<4>& filter_dims,
- const float* bias_data, const Dims<4>& bias_dims, int stride,
- int pad_width, int pad_height, int depth_multiplier,
- float* output_data, const Dims<4>& output_dims) {
- DepthwiseConv<Ac>(input_data, input_dims, filter_data, filter_dims, bias_data,
- bias_dims, stride, stride, pad_width, pad_height,
- depth_multiplier, output_data, output_dims);
-}
-
} // namespace optimized_ops
} // namespace tflite
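The dilation support above replaces the plain filter_y bounds with ceiling divisions by the dilation factor, so that only filter taps landing inside the input are visited. A short sketch with made-up numbers that reproduces the bound computation and checks that invariant:

#include <algorithm>
#include <cassert>
#include <cstdio>

int main() {
  // Hypothetical case: 3-tap filter, dilation 2, input height 5, and an
  // output row whose receptive field starts two rows above the input.
  const int filter_height = 3, dilation_height_factor = 2, input_height = 5;
  const int in_y_origin = -2;

  // Same bounds as in the dilated DepthwiseConv loop above.
  const int filter_y_start = std::max(
      0, (-in_y_origin + dilation_height_factor - 1) / dilation_height_factor);
  const int filter_y_end =
      std::min(filter_height,
               (input_height - in_y_origin + dilation_height_factor - 1) /
                   dilation_height_factor);

  for (int filter_y = filter_y_start; filter_y < filter_y_end; ++filter_y) {
    const int in_y = in_y_origin + dilation_height_factor * filter_y;
    assert(in_y >= 0 && in_y < input_height);  // every selected tap is valid
    std::printf("filter_y=%d -> in_y=%d\n", filter_y, in_y);
  }
  return 0;  // prints filter_y=1 -> in_y=0 and filter_y=2 -> in_y=2
}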
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h
index dd6932ffe7..803eff292a 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h
@@ -1466,11 +1466,14 @@ struct QuantizedDepthwiseConvKernel<false, 12, 1> {
// Accumulates the effect of one row of the filter, on a segment of one row
// of the output, accessing the corresponding one row of the input.
template <bool kAllowStrided, int kFixedInputDepth, int kFixedDepthMultiplier>
-void QuantizedDepthwiseConvAccumRow(
- int stride, int input_depth, int input_width, const uint8* input_data,
- int16 input_offset, int pad_width, int depth_multiplier, int filter_width,
- const uint8* filter_data, int16 filter_offset, int out_x_buffer_start,
- int out_x_buffer_end, int output_depth, int32* acc_buffer) {
+void QuantizedDepthwiseConvAccumRow(int stride, int dilation_factor,
+ int input_depth, int input_width,
+ const uint8* input_data, int16 input_offset,
+ int pad_width, int depth_multiplier,
+ int filter_width, const uint8* filter_data,
+ int16 filter_offset, int out_x_buffer_start,
+ int out_x_buffer_end, int output_depth,
+ int32* acc_buffer) {
#ifdef GEMMLOWP_PROFILING
gemmlowp::ScopedProfilingLabel label(__PRETTY_FUNCTION__);
#endif
@@ -1537,10 +1540,11 @@ void QuantizedDepthwiseConvAccumRow(
// generic fallback of DepthwiseConvAccumRow, portable, non-templatized.
inline void QuantizedDepthwiseConvAccumRowGeneric(
- int stride, int input_depth, int input_width, const uint8* input_data,
- int16 input_offset, int pad_width, int depth_multiplier, int filter_width,
- const uint8* filter_data, int16 filter_offset, int out_x_buffer_start,
- int out_x_buffer_end, int output_depth, int32* acc_buffer) {
+ int stride, int dilation_factor, int input_depth, int input_width,
+ const uint8* input_data, int16 input_offset, int pad_width,
+ int depth_multiplier, int filter_width, const uint8* filter_data,
+ int16 filter_offset, int out_x_buffer_start, int out_x_buffer_end,
+ int output_depth, int32* acc_buffer) {
gemmlowp::ScopedProfilingLabel label("DepthwiseConvAccumRowGeneric (slow)");
#ifdef TFLITE_PREVENT_SLOW_GENERIC_DEPTHWISECONV_FALLBACK
#ifndef ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK
@@ -1562,6 +1566,7 @@ inline void QuantizedDepthwiseConvAccumRowGeneric(
<< "* stride = " << stride << "\n"
<< "* input_depth = " << input_depth << "\n"
<< "* depth_multiplier = " << depth_multiplier << "\n"
+ << "* dilation_factor = " << dilation_factor << "\n"
<< "*\n"
<< "* Please do not hesitate to contact benoitjacob@ with this\n"
<< "* information.\n"
@@ -1571,14 +1576,17 @@ inline void QuantizedDepthwiseConvAccumRowGeneric(
const uint8* filter_base_ptr = filter_data;
for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
const int out_x_loop_start = std::max(
- out_x_buffer_start, (pad_width - filter_x + stride - 1) / stride);
- const int out_x_loop_end =
- std::min(out_x_buffer_end,
- (pad_width + input_width - filter_x + stride - 1) / stride);
+ out_x_buffer_start,
+ (pad_width - dilation_factor * filter_x + stride - 1) / stride);
+ const int out_x_loop_end = std::min(
+ out_x_buffer_end,
+ (pad_width + input_width - dilation_factor * filter_x + stride - 1) /
+ stride);
int32* acc_buffer_ptr =
acc_buffer + (out_x_loop_start - out_x_buffer_start) * output_depth;
- const int in_x_origin = (out_x_loop_start * stride) - pad_width + filter_x;
+ const int in_x_origin =
+ (out_x_loop_start * stride) - pad_width + dilation_factor * filter_x;
const uint8* input_ptr = input_data + in_x_origin * input_depth;
const int input_ptr_increment = (stride - 1) * input_depth;
for (int out_x = out_x_loop_start; out_x < out_x_loop_end; out_x++) {
@@ -1669,42 +1677,61 @@ inline void DepthwiseConvInitAccBuffer(int num_output_pixels, int output_depth,
}
}
-inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
- int32 input_offset, const uint8* filter_data,
- const Dims<4>& filter_dims, int32 filter_offset,
- const int32* bias_data, const Dims<4>& bias_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, int depth_multiplier,
- int32 output_offset, int32 output_multiplier,
- int output_shift, int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims) {
+inline void DepthwiseConv(
+ const DepthwiseParams& params, const RuntimeShape& input_shape,
+ const uint8* input_data, const RuntimeShape& filter_shape,
+ const uint8* filter_data, const RuntimeShape& bias_shape,
+ const int32* bias_data, const RuntimeShape& output_shape,
+ uint8* output_data) {
gemmlowp::ScopedProfilingLabel label("DepthwiseConv/8bit");
+ const int stride_width = params.stride_width;
+ const int stride_height = params.stride_height;
+ const int pad_width = params.padding_values.width;
+ const int pad_height = params.padding_values.height;
+ const int depth_multiplier = params.depth_multiplier;
+ const int32 output_activation_min = params.quantized_activation_min;
+ const int32 output_activation_max = params.quantized_activation_max;
+ const int32 input_offset = params.input_offset;
+ const int32 filter_offset = params.weights_offset;
+ const int32 output_offset = params.output_offset;
+ const int32 output_multiplier = params.output_multiplier;
+ const int output_shift = params.output_shift;
+ const int dilation_width_factor = params.dilation_width_factor;
+ const int dilation_height_factor = params.dilation_height_factor;
+ TFLITE_DCHECK_GE(dilation_width_factor, 1);
+ TFLITE_DCHECK_GE(dilation_height_factor, 1);
+ TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
+ const int batches = MatchingDim(input_shape, 0, output_shape, 0);
+ const int output_depth = MatchingDim(filter_shape, 3, output_shape, 3);
+ const int input_height = input_shape.Dims(1);
+ const int input_width = input_shape.Dims(2);
+ const int input_depth = input_shape.Dims(3);
+ const int filter_height = filter_shape.Dims(1);
+ const int filter_width = filter_shape.Dims(2);
+ const int output_height = output_shape.Dims(1);
+ const int output_width = output_shape.Dims(2);
+#ifdef USE_NEON
+ const bool shift_left = (output_shift > 0);
+ const int32 multiplier_power_of_two = shift_left ? (1 << output_shift) : 1;
+#endif
+ TFLITE_DCHECK_EQ(output_depth, input_depth * depth_multiplier);
+ TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth);
- const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
- const int output_depth = MatchingArraySize(filter_dims, 0, output_dims, 0);
- const int input_height = ArraySize(input_dims, 2);
- const int input_width = ArraySize(input_dims, 1);
- const int input_depth = ArraySize(input_dims, 0);
- const int filter_height = ArraySize(filter_dims, 2);
- const int filter_width = ArraySize(filter_dims, 1);
- const int output_height = ArraySize(output_dims, 2);
- const int output_width = ArraySize(output_dims, 1);
- TFLITE_DCHECK(output_depth == input_depth * depth_multiplier);
-
-#ifdef __aarch64__
+// Enable for arm64 except for the Nvidia Linux 4 Tegra (L4T) running on
+// Jetson TX-2. This compiler does not support the offsetof() macro.
+#if defined(__aarch64__) && !defined(GOOGLE_L4T)
// Call kernel optimized for depthwise convolutions using 3x3 filters if
// parameters are supported.
- if (Fast3x3FilterKernelSupported(input_dims, filter_dims, stride_width,
- stride_height, pad_width, pad_height,
- depth_multiplier, output_dims)) {
- DepthwiseConv3x3Filter(input_data, input_dims, input_offset, filter_data,
- filter_dims, filter_offset, bias_data, bias_dims,
- stride_width, stride_height, pad_width, pad_height,
- depth_multiplier, output_offset, output_multiplier,
- output_shift, output_activation_min,
- output_activation_max, output_data, output_dims);
+ if (Fast3x3FilterKernelSupported(
+ input_shape, filter_shape, stride_width, stride_height,
+ dilation_width_factor, dilation_height_factor, pad_width, pad_height,
+ depth_multiplier, output_shape, output_shift)) {
+ DepthwiseConv3x3Filter(params, input_shape, input_data, filter_shape,
+ filter_data, bias_shape, bias_data, output_shape,
+ output_data);
return;
}
#endif
@@ -1728,7 +1755,8 @@ inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
FIXED_DEPTH_MULTIPLIER) \
if (!row_accum_func && (stride_width == 1 || ALLOW_STRIDED) && \
(input_depth == FIXED_INPUT_DEPTH || FIXED_INPUT_DEPTH == 0) && \
- depth_multiplier == FIXED_DEPTH_MULTIPLIER) { \
+ depth_multiplier == FIXED_DEPTH_MULTIPLIER && \
+ dilation_width_factor == 1 && dilation_height_factor == 1) { \
row_accum_func = \
QuantizedDepthwiseConvAccumRow<ALLOW_STRIDED, FIXED_INPUT_DEPTH, \
FIXED_DEPTH_MULTIPLIER>; \
@@ -1779,14 +1807,22 @@ inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
#undef TFMINI_USE_DEPTHWISECONV_KERNEL
+ const int input_height_stride = input_shape.Dims(3) * input_shape.Dims(2);
+ const int input_batch_stride = input_height_stride * input_shape.Dims(1);
+ const int filter_height_stride = filter_shape.Dims(3) * filter_shape.Dims(2);
+
// Now that we have determined row_accum_func, we can start work.
uint8* output_ptr = output_data;
for (int b = 0; b < batches; ++b) {
for (int out_y = 0; out_y < output_height; ++out_y) {
const int in_y_origin = (out_y * stride_height) - pad_height;
- const int filter_y_start = std::max(0, -in_y_origin);
+ const int filter_y_start =
+ std::max(0, (-in_y_origin + dilation_height_factor - 1) /
+ dilation_height_factor);
const int filter_y_end =
- std::min(filter_height, input_height - in_y_origin);
+ std::min(filter_height,
+ (input_height - in_y_origin + dilation_height_factor - 1) /
+ dilation_height_factor);
for (int out_x_buffer_start = 0; out_x_buffer_start < output_width;
out_x_buffer_start += kOutputPixelsInAccBuffer) {
const int out_x_buffer_end = std::min(
@@ -1802,13 +1838,12 @@ inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
// Accumulation loop. Most of the time should be spent in here.
for (int filter_y = filter_y_start; filter_y < filter_y_end;
++filter_y) {
- const int in_y = in_y_origin + filter_y;
+ const int in_y = in_y_origin + dilation_height_factor * filter_y;
row_accum_func(
- stride_width, input_depth, input_width,
- input_data + in_y * input_dims.strides[2] +
- b * input_dims.strides[3],
+ stride_width, dilation_width_factor, input_depth, input_width,
+ input_data + in_y * input_height_stride + b * input_batch_stride,
input_offset, pad_width, depth_multiplier, filter_width,
- filter_data + filter_y * filter_dims.strides[2], filter_offset,
+ filter_data + filter_y * filter_height_stride, filter_offset,
out_x_buffer_start, out_x_buffer_end, output_depth, acc_buffer);
}
// Finished accumulating int32 values. Now need to convert them to
@@ -1833,12 +1868,20 @@ inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
acc[j] = vld1q_s32(acc_buffer + i + 4 * j);
}
- // Fixed-point multiplication.
- for (int j = 0; j < 4; j++) {
- acc[j] = vqrdmulhq_n_s32(acc[j], output_multiplier);
- }
- for (int j = 0; j < 4; j++) {
- acc[j] = RoundingDivideByPOT(acc[j], output_shift);
+ if (!shift_left) {
+ // Fixed-point multiplication.
+ for (int j = 0; j < 4; j++) {
+ acc[j] = vqrdmulhq_n_s32(acc[j], output_multiplier);
+ }
+ for (int j = 0; j < 4; j++) {
+ acc[j] = RoundingDivideByPOT(acc[j], -output_shift);
+ }
+ } else {
+ // Fixed-point multiplication.
+ for (int j = 0; j < 4; j++) {
+ acc[j] = vmulq_n_s32(acc[j], multiplier_power_of_two);
+ acc[j] = vqrdmulhq_n_s32(acc[j], output_multiplier);
+ }
}
// Add the output offset.
for (int j = 0; j < 4; j++) {
@@ -1870,12 +1913,21 @@ inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
for (; i <= num_output_values - 8; i += 8) {
int32x4_t acc0 = vld1q_s32(acc_buffer + i);
int32x4_t acc1 = vld1q_s32(acc_buffer + i + 4);
- // Fixed-point multiplication.
- acc0 = vqrdmulhq_n_s32(acc0, output_multiplier);
- acc1 = vqrdmulhq_n_s32(acc1, output_multiplier);
- // Rounding right shift.
- acc0 = RoundingDivideByPOT(acc0, output_shift);
- acc1 = RoundingDivideByPOT(acc1, output_shift);
+ if (!shift_left) {
+ // Fixed-point multiplication.
+ acc0 = vqrdmulhq_n_s32(acc0, output_multiplier);
+ acc1 = vqrdmulhq_n_s32(acc1, output_multiplier);
+ // Rounding right shift.
+ acc0 = RoundingDivideByPOT(acc0, -output_shift);
+ acc1 = RoundingDivideByPOT(acc1, -output_shift);
+ } else {
+ // Fixed-point multiplication.
+ acc0 = vmulq_n_s32(acc0, multiplier_power_of_two);
+ acc0 = vqrdmulhq_n_s32(acc0, output_multiplier);
+
+ acc1 = vmulq_n_s32(acc1, multiplier_power_of_two);
+ acc1 = vqrdmulhq_n_s32(acc1, output_multiplier);
+ }
// Add the output offset.
acc0 = vaddq_s32(acc0, output_offset_vec);
acc1 = vaddq_s32(acc1, output_offset_vec);
@@ -1899,10 +1951,16 @@ inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
// that will have to go through the very slow scalar code.
for (; i <= num_output_values - 4; i += 4) {
int32x4_t acc = vld1q_s32(acc_buffer + i);
- // Fixed-point multiplication.
- acc = vqrdmulhq_n_s32(acc, output_multiplier);
- // Rounding right shift.
- acc = RoundingDivideByPOT(acc, output_shift);
+ if (!shift_left) {
+ // Fixed-point multiplication.
+ acc = vqrdmulhq_n_s32(acc, output_multiplier);
+ // Rounding right shift.
+ acc = RoundingDivideByPOT(acc, -output_shift);
+ } else {
+ // Fixed-point multiplication.
+ acc = vmulq_n_s32(acc, multiplier_power_of_two);
+ acc = vqrdmulhq_n_s32(acc, output_multiplier);
+ }
// Add the output offset.
acc = vaddq_s32(acc, output_offset_vec);
// Apply the activation function.
@@ -1923,8 +1981,8 @@ inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
// Handle leftover values, one by one. This is very slow.
for (; i < num_output_values; i++) {
int32 acc = acc_buffer[i];
- acc = MultiplyByQuantizedMultiplierSmallerThanOne(
- acc, output_multiplier, output_shift);
+ acc = MultiplyByQuantizedMultiplier(acc, output_multiplier,
+ output_shift);
acc += output_offset;
acc = std::max(acc, output_activation_min);
acc = std::min(acc, output_activation_max);
@@ -1935,48 +1993,6 @@ inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
}
}
-// Legacy, for compatibility with old checked-in code.
-template <FusedActivationFunctionType Ac>
-void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
- int32 input_offset, const uint8* filter_data,
- const Dims<4>& filter_dims, int32 filter_offset,
- const int32* bias_data, const Dims<4>& bias_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, int depth_multiplier, int32 output_offset,
- int32 output_multiplier, int output_shift,
- int32 output_activation_min, int32 output_activation_max,
- uint8* output_data, const Dims<4>& output_dims) {
- if (Ac == FusedActivationFunctionType::kNone) {
- TFLITE_DCHECK_EQ(output_activation_min, 0);
- TFLITE_DCHECK_EQ(output_activation_max, 255);
- }
- DepthwiseConv(input_data, input_dims, input_offset, filter_data, filter_dims,
- filter_offset, bias_data, bias_dims, stride_width,
- stride_height, pad_width, pad_height, depth_multiplier,
- output_offset, output_multiplier, output_shift,
- output_activation_min, output_activation_max, output_data,
- output_dims);
-}
-
-// Legacy, for compatibility with old checked-in code.
-template <FusedActivationFunctionType Ac>
-void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
- int32 input_offset, const uint8* filter_data,
- const Dims<4>& filter_dims, int32 filter_offset,
- const int32* bias_data, const Dims<4>& bias_dims, int stride,
- int pad_width, int pad_height, int depth_multiplier,
- int32 output_offset, int32 output_multiplier,
- int output_shift, int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims) {
- DepthwiseConv<Ac>(input_data, input_dims, input_offset, filter_data,
- filter_dims, filter_offset, bias_data, bias_dims, stride,
- stride, pad_width, pad_height, depth_multiplier,
- output_offset, output_multiplier, output_shift,
- output_activation_min, output_activation_max, output_data,
- output_dims);
-}
-
} // namespace optimized_ops
} // namespace tflite
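The uint8 path above also adopts the new output_shift sign convention: a positive shift becomes a power-of-two multiplication applied before the doubling high multiply, and a non-positive shift becomes a rounding right shift by -output_shift afterwards, which is what the scalar leftover loop's MultiplyByQuantizedMultiplier call expresses. A scalar sketch of that requantization, assuming gemmlowp's usual fixed-point helpers:

#include <cstdint>

#include "fixedpoint/fixedpoint.h"  // gemmlowp's scalar fixed-point helpers

// Scalar mirror of the NEON requantization above, under the convention that
// output_shift > 0 means a left shift and output_shift <= 0 a right shift.
int32_t RequantizeAccumulator(int32_t acc, int32_t output_multiplier,
                              int output_shift) {
  if (output_shift > 0) {
    // Apply the power-of-two factor first, then the doubling high multiply.
    acc = gemmlowp::SaturatingRoundingDoublingHighMul(acc * (1 << output_shift),
                                                      output_multiplier);
  } else {
    // Doubling high multiply first, then a rounding right shift.
    acc = gemmlowp::SaturatingRoundingDoublingHighMul(acc, output_multiplier);
    acc = gemmlowp::RoundingDivideByPOT(acc, -output_shift);
  }
  return acc;
}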
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h
index 55e0d5c3aa..4809ddd02a 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h
@@ -23,3848 +23,2912 @@ limitations under the License.
namespace tflite {
namespace optimized_ops {
-#ifdef __aarch64__
-
-inline void preload_l1_keep(const uint8* ptr) {
-#ifdef GEMMLOWP_ARM_64
- asm volatile("prfm pldl1keep, [%[ptr]]\n" ::[ptr] "r"(ptr) :);
-#else
- gemmlowp::Prefetch(ptr);
-#endif
-}
-
-// Implementation of quantized DepthwiseConv for 3x3 filters.
-
-// Below are helper structs to remove the use of arrays.
-// There is an llvm bug that causes significant slowdown when using arrays for
-// NEON intrinsics vector data types.
-// See: https://bugs.llvm.org/show_bug.cgi?id=34945
-
-struct Int32x8 {
- int32x4_t low, high;
-};
-
-struct Filter3x3x8 {
- int16x8_t f0, f1, f2, f3, f4, f5, f6, f7, f8;
-};
-
-// Loads 3x3 filter of depth 8 and adds filter offsets.
-inline Filter3x3x8 Load3x3Filter(const uint8* filter_ptr, int32 filter_offset,
- int output_depth) {
- Filter3x3x8 filter;
-
- uint8x8_t temp_u8_0, temp_u8_1, temp_u8_2, temp_u8_3, temp_u8_4, temp_u8_5,
- temp_u8_6, temp_u8_7, temp_u8_8;
- int16x8_t filter_offset_vec = vdupq_n_s16(filter_offset);
-
- temp_u8_0 = vld1_u8(filter_ptr + 0 * output_depth);
- temp_u8_1 = vld1_u8(filter_ptr + 1 * output_depth);
- temp_u8_2 = vld1_u8(filter_ptr + 2 * output_depth);
- temp_u8_3 = vld1_u8(filter_ptr + 3 * output_depth);
- temp_u8_4 = vld1_u8(filter_ptr + 4 * output_depth);
- temp_u8_5 = vld1_u8(filter_ptr + 5 * output_depth);
- temp_u8_6 = vld1_u8(filter_ptr + 6 * output_depth);
- temp_u8_7 = vld1_u8(filter_ptr + 7 * output_depth);
- temp_u8_8 = vld1_u8(filter_ptr + 8 * output_depth);
-
- filter.f0 = vreinterpretq_s16_u16(vmovl_u8(temp_u8_0));
- filter.f1 = vreinterpretq_s16_u16(vmovl_u8(temp_u8_1));
- filter.f2 = vreinterpretq_s16_u16(vmovl_u8(temp_u8_2));
- filter.f3 = vreinterpretq_s16_u16(vmovl_u8(temp_u8_3));
- filter.f4 = vreinterpretq_s16_u16(vmovl_u8(temp_u8_4));
- filter.f5 = vreinterpretq_s16_u16(vmovl_u8(temp_u8_5));
- filter.f6 = vreinterpretq_s16_u16(vmovl_u8(temp_u8_6));
- filter.f7 = vreinterpretq_s16_u16(vmovl_u8(temp_u8_7));
- filter.f8 = vreinterpretq_s16_u16(vmovl_u8(temp_u8_8));
-
- filter.f0 = vaddq_s16(filter.f0, filter_offset_vec);
- filter.f1 = vaddq_s16(filter.f1, filter_offset_vec);
- filter.f2 = vaddq_s16(filter.f2, filter_offset_vec);
- filter.f3 = vaddq_s16(filter.f3, filter_offset_vec);
- filter.f4 = vaddq_s16(filter.f4, filter_offset_vec);
- filter.f5 = vaddq_s16(filter.f5, filter_offset_vec);
- filter.f6 = vaddq_s16(filter.f6, filter_offset_vec);
- filter.f7 = vaddq_s16(filter.f7, filter_offset_vec);
- filter.f8 = vaddq_s16(filter.f8, filter_offset_vec);
-
- return filter;
-}
-
-// Applies activation, offset and downquantize on a set of accumulator
-// registers that correspond to a 2x2 output of depth 8.
-// Stores results to output.
-inline void DownquantizeAndStore2x2Output(
- Int32x8 acc_0, Int32x8 acc_1, Int32x8 acc_2, Int32x8 acc_3,
- int32 output_offset, int32 output_multiplier, int output_shift,
- int32 output_activation_min, int32 output_activation_max, uint8* output_ptr,
- int output_depth, int output_width) {
- using gemmlowp::RoundingDivideByPOT;
- const int32x4_t output_offset_vec = vdupq_n_s32(output_offset);
- const int32x4_t output_activation_min_vec =
- vdupq_n_s32(output_activation_min);
- const int32x4_t output_activation_max_vec =
- vdupq_n_s32(output_activation_max);
-
- // Fixed-point multiplication.
- acc_0.low = vqrdmulhq_n_s32(acc_0.low, output_multiplier);
- acc_0.high = vqrdmulhq_n_s32(acc_0.high, output_multiplier);
- acc_1.low = vqrdmulhq_n_s32(acc_1.low, output_multiplier);
- acc_1.high = vqrdmulhq_n_s32(acc_1.high, output_multiplier);
- acc_2.low = vqrdmulhq_n_s32(acc_2.low, output_multiplier);
- acc_2.high = vqrdmulhq_n_s32(acc_2.high, output_multiplier);
- acc_3.low = vqrdmulhq_n_s32(acc_3.low, output_multiplier);
- acc_3.high = vqrdmulhq_n_s32(acc_3.high, output_multiplier);
-
- acc_0.low = RoundingDivideByPOT(acc_0.low, output_shift);
- acc_0.high = RoundingDivideByPOT(acc_0.high, output_shift);
- acc_1.low = RoundingDivideByPOT(acc_1.low, output_shift);
- acc_1.high = RoundingDivideByPOT(acc_1.high, output_shift);
- acc_2.low = RoundingDivideByPOT(acc_2.low, output_shift);
- acc_2.high = RoundingDivideByPOT(acc_2.high, output_shift);
- acc_3.low = RoundingDivideByPOT(acc_3.low, output_shift);
- acc_3.high = RoundingDivideByPOT(acc_3.high, output_shift);
-
- // Add the output offset.
- acc_0.low = vaddq_s32(acc_0.low, output_offset_vec);
- acc_0.high = vaddq_s32(acc_0.high, output_offset_vec);
- acc_1.low = vaddq_s32(acc_1.low, output_offset_vec);
- acc_1.high = vaddq_s32(acc_1.high, output_offset_vec);
- acc_2.low = vaddq_s32(acc_2.low, output_offset_vec);
- acc_2.high = vaddq_s32(acc_2.high, output_offset_vec);
- acc_3.low = vaddq_s32(acc_3.low, output_offset_vec);
- acc_3.high = vaddq_s32(acc_3.high, output_offset_vec);
-
- // Apply the activation function.
- acc_0.low = vmaxq_s32(acc_0.low, output_activation_min_vec);
- acc_0.high = vmaxq_s32(acc_0.high, output_activation_min_vec);
- acc_1.low = vmaxq_s32(acc_1.low, output_activation_min_vec);
- acc_1.high = vmaxq_s32(acc_1.high, output_activation_min_vec);
- acc_2.low = vmaxq_s32(acc_2.low, output_activation_min_vec);
- acc_2.high = vmaxq_s32(acc_2.high, output_activation_min_vec);
- acc_3.low = vmaxq_s32(acc_3.low, output_activation_min_vec);
- acc_3.high = vmaxq_s32(acc_3.high, output_activation_min_vec);
-
- acc_0.low = vminq_s32(acc_0.low, output_activation_max_vec);
- acc_0.high = vminq_s32(acc_0.high, output_activation_max_vec);
- acc_1.low = vminq_s32(acc_1.low, output_activation_max_vec);
- acc_1.high = vminq_s32(acc_1.high, output_activation_max_vec);
- acc_2.low = vminq_s32(acc_2.low, output_activation_max_vec);
- acc_2.high = vminq_s32(acc_2.high, output_activation_max_vec);
- acc_3.low = vminq_s32(acc_3.low, output_activation_max_vec);
- acc_3.high = vminq_s32(acc_3.high, output_activation_max_vec);
-
- // Saturating cast to uint8 and store to destination.
- int16x4_t acc_0_low_s16 = vqmovn_s32(acc_0.low);
- int16x4_t acc_0_high_s16 = vqmovn_s32(acc_0.high);
- int16x4_t acc_1_low_s16 = vqmovn_s32(acc_1.low);
- int16x4_t acc_1_high_s16 = vqmovn_s32(acc_1.high);
- int16x4_t acc_2_low_s16 = vqmovn_s32(acc_2.low);
- int16x4_t acc_2_high_s16 = vqmovn_s32(acc_2.high);
- int16x4_t acc_3_low_s16 = vqmovn_s32(acc_3.low);
- int16x4_t acc_3_high_s16 = vqmovn_s32(acc_3.high);
-
- int16x8_t res_0_s16 = vcombine_s16(acc_0_low_s16, acc_0_high_s16);
- int16x8_t res_1_s16 = vcombine_s16(acc_1_low_s16, acc_1_high_s16);
- int16x8_t res_2_s16 = vcombine_s16(acc_2_low_s16, acc_2_high_s16);
- int16x8_t res_3_s16 = vcombine_s16(acc_3_low_s16, acc_3_high_s16);
-
- uint8x8_t res_0_u8 = vqmovun_s16(res_0_s16);
- uint8x8_t res_1_u8 = vqmovun_s16(res_1_s16);
- uint8x8_t res_2_u8 = vqmovun_s16(res_2_s16);
- uint8x8_t res_3_u8 = vqmovun_s16(res_3_s16);
-
- vst1_u8(output_ptr, res_0_u8);
- vst1_u8(output_ptr + output_depth, res_1_u8);
- vst1_u8(output_ptr + output_depth * output_width, res_2_u8);
- vst1_u8(output_ptr + output_depth * output_width + output_depth, res_3_u8);
-}
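To make the fixed-point pipeline above easier to follow, here is a minimal scalar sketch of what happens to a single 32-bit accumulator lane. The helpers re-implement gemmlowp-style semantics for illustration only (vqrdmulhq_n_s32 is a saturating rounding doubling high multiply); the names below are not the library's API.

#include <algorithm>
#include <cstdint>
#include <limits>

// Scalar stand-in for vqrdmulhq_n_s32 on one lane.
inline std::int32_t ScalarSaturatingRoundingDoublingHighMul(std::int32_t a,
                                                            std::int32_t b) {
  const bool overflow =
      a == b && a == std::numeric_limits<std::int32_t>::min();
  const std::int64_t ab_64 = static_cast<std::int64_t>(a) * b;
  const std::int32_t nudge = ab_64 >= 0 ? (1 << 30) : (1 - (1 << 30));
  const std::int32_t ab_x2_high32 =
      static_cast<std::int32_t>((ab_64 + nudge) / (1ll << 31));
  return overflow ? std::numeric_limits<std::int32_t>::max() : ab_x2_high32;
}

// Scalar stand-in for RoundingDivideByPOT: rounding arithmetic shift right.
inline std::int32_t ScalarRoundingDivideByPOT(std::int32_t x, int exponent) {
  const std::int32_t mask = static_cast<std::int32_t>((1ll << exponent) - 1);
  const std::int32_t remainder = x & mask;
  const std::int32_t threshold = (mask >> 1) + (x < 0 ? 1 : 0);
  return (x >> exponent) + (remainder > threshold ? 1 : 0);
}

// One accumulator lane through the full pipeline: scale, shift, add the
// output offset, clamp to the activation range, saturate to uint8.
inline std::uint8_t DownquantizeOneLane(std::int32_t acc,
                                        std::int32_t output_multiplier,
                                        int output_shift,
                                        std::int32_t output_offset,
                                        std::int32_t output_activation_min,
                                        std::int32_t output_activation_max) {
  acc = ScalarSaturatingRoundingDoublingHighMul(acc, output_multiplier);
  acc = ScalarRoundingDivideByPOT(acc, output_shift);
  acc += output_offset;
  acc = std::max(acc, output_activation_min);
  acc = std::min(acc, output_activation_max);
  return static_cast<std::uint8_t>(acc);
}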
-
-inline void DownquantizeAndStore(Int32x8 acc, int32 output_offset,
- int32 output_multiplier, int output_shift,
- int32 output_activation_min,
- int32 output_activation_max,
- uint8* output_ptr) {
- using gemmlowp::RoundingDivideByPOT;
- const int32x4_t output_offset_vec = vdupq_n_s32(output_offset);
- const int32x4_t output_activation_min_vec =
- vdupq_n_s32(output_activation_min);
- const int32x4_t output_activation_max_vec =
- vdupq_n_s32(output_activation_max);
-
- acc.low = vqrdmulhq_n_s32(acc.low, output_multiplier);
- acc.high = vqrdmulhq_n_s32(acc.high, output_multiplier);
-
- acc.low = RoundingDivideByPOT(acc.low, output_shift);
- acc.high = RoundingDivideByPOT(acc.high, output_shift);
-
- acc.low = vaddq_s32(acc.low, output_offset_vec);
- acc.high = vaddq_s32(acc.high, output_offset_vec);
-
- acc.low = vmaxq_s32(acc.low, output_activation_min_vec);
- acc.high = vmaxq_s32(acc.high, output_activation_min_vec);
-
- acc.low = vminq_s32(acc.low, output_activation_max_vec);
- acc.high = vminq_s32(acc.high, output_activation_max_vec);
-
- int16x4_t acc_low_s16 = vqmovn_s32(acc.low);
- int16x4_t acc_high_s16 = vqmovn_s32(acc.high);
-
- int16x8_t res_s16 = vcombine_s16(acc_low_s16, acc_high_s16);
- uint8x8_t res_u8 = vqmovun_s16(res_s16);
- vst1_u8(output_ptr, res_u8);
-}
+// Enable for arm64, except for the Nvidia Linux for Tegra (L4T) toolchain
+// running on the Jetson TX-2: its compiler does not support the offsetof()
+// macro.
+#if defined(__aarch64__) && !defined(GOOGLE_L4T)
+#include <stddef.h>
+// clang-format gets confused by this file and ends up formatting lines to be
+// longer than 80 characters. Turn it off here and back on at the end of the
+// file.
-inline void DownquantizeAndStore2Output(
- Int32x8 acc_0, Int32x8 acc_1, int32 output_offset, int32 output_multiplier,
- int output_shift, int32 output_activation_min, int32 output_activation_max,
- uint8* output_ptr, int output_ptr_offset) {
- {
- using gemmlowp::RoundingDivideByPOT;
- const int32x4_t output_offset_vec = vdupq_n_s32(output_offset);
- const int32x4_t output_activation_min_vec =
- vdupq_n_s32(output_activation_min);
- const int32x4_t output_activation_max_vec =
- vdupq_n_s32(output_activation_max);
-
- // Fixed-point multiplication.
- acc_0.low = vqrdmulhq_n_s32(acc_0.low, output_multiplier);
- acc_0.high = vqrdmulhq_n_s32(acc_0.high, output_multiplier);
- acc_1.low = vqrdmulhq_n_s32(acc_1.low, output_multiplier);
- acc_1.high = vqrdmulhq_n_s32(acc_1.high, output_multiplier);
-
- acc_0.low = RoundingDivideByPOT(acc_0.low, output_shift);
- acc_0.high = RoundingDivideByPOT(acc_0.high, output_shift);
- acc_1.low = RoundingDivideByPOT(acc_1.low, output_shift);
- acc_1.high = RoundingDivideByPOT(acc_1.high, output_shift);
-
- // Add the output offset.
- acc_0.low = vaddq_s32(acc_0.low, output_offset_vec);
- acc_0.high = vaddq_s32(acc_0.high, output_offset_vec);
- acc_1.low = vaddq_s32(acc_1.low, output_offset_vec);
- acc_1.high = vaddq_s32(acc_1.high, output_offset_vec);
-
- // Apply the activation function.
- acc_0.low = vmaxq_s32(acc_0.low, output_activation_min_vec);
- acc_0.high = vmaxq_s32(acc_0.high, output_activation_min_vec);
- acc_1.low = vmaxq_s32(acc_1.low, output_activation_min_vec);
- acc_1.high = vmaxq_s32(acc_1.high, output_activation_min_vec);
-
- acc_0.low = vminq_s32(acc_0.low, output_activation_max_vec);
- acc_0.high = vminq_s32(acc_0.high, output_activation_max_vec);
- acc_1.low = vminq_s32(acc_1.low, output_activation_max_vec);
- acc_1.high = vminq_s32(acc_1.high, output_activation_max_vec);
- }
-
- // Saturating cast to uint8 and store to destination.
- int16x8_t res_0_s16;
- {
- int16x4_t acc_0_low_s16 = vqmovn_s32(acc_0.low);
- int16x4_t acc_0_high_s16 = vqmovn_s32(acc_0.high);
- res_0_s16 = vcombine_s16(acc_0_low_s16, acc_0_high_s16);
- }
-
- int16x8_t res_1_s16;
- {
- int16x4_t acc_1_low_s16 = vqmovn_s32(acc_1.low);
- int16x4_t acc_1_high_s16 = vqmovn_s32(acc_1.high);
- res_1_s16 = vcombine_s16(acc_1_low_s16, acc_1_high_s16);
- }
-
- uint8x8_t res_0_u8 = vqmovun_s16(res_0_s16);
- uint8x8_t res_1_u8 = vqmovun_s16(res_1_s16);
- vst1_u8(output_ptr, res_0_u8);
- vst1_u8(output_ptr + output_ptr_offset, res_1_u8);
-}
-
-// Performs a multiply-accumulate of one 3-tap filter row over 3 inputs of
-// depth 8.
-inline Int32x8 MultiplyAccumulateRow(Int32x8 accum, int16x8_t f0, int16x8_t f1,
- int16x8_t f2, int16x8_t i0, int16x8_t i1,
- int16x8_t i2) {
- accum.low = vmlal_s16(accum.low, vget_low_s16(f0), vget_low_s16(i0));
- accum.high = vmlal_s16(accum.high, vget_high_s16(f0), vget_high_s16(i0));
- accum.low = vmlal_s16(accum.low, vget_low_s16(f1), vget_low_s16(i1));
- accum.high = vmlal_s16(accum.high, vget_high_s16(f1), vget_high_s16(i1));
- accum.low = vmlal_s16(accum.low, vget_low_s16(f2), vget_low_s16(i2));
- accum.high = vmlal_s16(accum.high, vget_high_s16(f2), vget_high_s16(i2));
- return accum;
-}
-
-// Performs a multiply-accumulate of the full 3x3 filter over 9 inputs of
-// depth 8.
-inline Int32x8 MultiplyAccumulate3x3Filter(const Filter3x3x8& f, int16x8_t i0,
- int16x8_t i1, int16x8_t i2,
- int16x8_t i3, int16x8_t i4,
- int16x8_t i5, int16x8_t i6,
- int16x8_t i7, int16x8_t i8,
- Int32x8 accum) {
- accum.low = vmlal_s16(accum.low, vget_low_s16(f.f0), vget_low_s16(i0));
- accum.high = vmlal_s16(accum.high, vget_high_s16(f.f0), vget_high_s16(i0));
- accum.low = vmlal_s16(accum.low, vget_low_s16(f.f1), vget_low_s16(i1));
- accum.high = vmlal_s16(accum.high, vget_high_s16(f.f1), vget_high_s16(i1));
- accum.low = vmlal_s16(accum.low, vget_low_s16(f.f2), vget_low_s16(i2));
- accum.high = vmlal_s16(accum.high, vget_high_s16(f.f2), vget_high_s16(i2));
- accum.low = vmlal_s16(accum.low, vget_low_s16(f.f3), vget_low_s16(i3));
- accum.high = vmlal_s16(accum.high, vget_high_s16(f.f3), vget_high_s16(i3));
- accum.low = vmlal_s16(accum.low, vget_low_s16(f.f4), vget_low_s16(i4));
- accum.high = vmlal_s16(accum.high, vget_high_s16(f.f4), vget_high_s16(i4));
- accum.low = vmlal_s16(accum.low, vget_low_s16(f.f5), vget_low_s16(i5));
- accum.high = vmlal_s16(accum.high, vget_high_s16(f.f5), vget_high_s16(i5));
- accum.low = vmlal_s16(accum.low, vget_low_s16(f.f6), vget_low_s16(i6));
- accum.high = vmlal_s16(accum.high, vget_high_s16(f.f6), vget_high_s16(i6));
- accum.low = vmlal_s16(accum.low, vget_low_s16(f.f7), vget_low_s16(i7));
- accum.high = vmlal_s16(accum.high, vget_high_s16(f.f7), vget_high_s16(i7));
- accum.low = vmlal_s16(accum.low, vget_low_s16(f.f8), vget_low_s16(i8));
- accum.high = vmlal_s16(accum.high, vget_high_s16(f.f8), vget_high_s16(i8));
- return accum;
-}
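Per depth channel, the routine above is a plain 9-tap dot product. A scalar sketch, with the 8-channel width and the array layout assumed purely for illustration:

#include <cstdint>

// Illustrative scalar equivalent of MultiplyAccumulate3x3Filter: accumulate
// the dot product of the 9 offset-adjusted filter taps with the 9
// offset-adjusted input samples, independently for each of 8 channels.
inline void MultiplyAccumulate3x3Scalar(const std::int16_t filter[9][8],
                                        const std::int16_t input[9][8],
                                        std::int32_t acc[8]) {
  for (int c = 0; c < 8; ++c) {
    for (int k = 0; k < 9; ++k) {
      acc[c] += static_cast<std::int32_t>(filter[k][c]) * input[k][c];
    }
  }
}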
-
-inline void DotProductAndStore(const Filter3x3x8& filter, int16x8_t i0,
- int16x8_t i1, int16x8_t i2, int16x8_t i3,
- int16x8_t i4, int16x8_t i5, int16x8_t i6,
- int16x8_t i7, int16x8_t i8,
- const int32* bias_ptr, int32 output_offset,
- int32 output_multiplier, int output_shift,
- int32 output_activation_min,
- int32 output_activation_max, uint8* output_ptr) {
- Int32x8 acc;
- acc.low = vld1q_s32(bias_ptr);
- acc.high = vld1q_s32(bias_ptr + 4);
-
- acc = MultiplyAccumulate3x3Filter(filter, i0, i1, i2, i3, i4, i5, i6, i7, i8,
- acc);
-
- DownquantizeAndStore(acc, output_offset, output_multiplier, output_shift,
- output_activation_min, output_activation_max,
- output_ptr);
-}
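DotProductAndStore is the composition of the two scalar sketches above: initialize the accumulators from the bias, run the 9-tap accumulate, then downquantize and store. For completeness, a sketch using the hypothetical helpers defined earlier (not library functions):

#include <cstdint>

// Illustrative scalar composition mirroring DotProductAndStore.
inline void DotProductAndStoreScalar(
    const std::int16_t filter[9][8], const std::int16_t input[9][8],
    const std::int32_t* bias_ptr, std::int32_t output_offset,
    std::int32_t output_multiplier, int output_shift,
    std::int32_t output_activation_min, std::int32_t output_activation_max,
    std::uint8_t* output_ptr) {
  std::int32_t acc[8];
  for (int c = 0; c < 8; ++c) acc[c] = bias_ptr[c];
  MultiplyAccumulate3x3Scalar(filter, input, acc);
  for (int c = 0; c < 8; ++c) {
    output_ptr[c] =
        DownquantizeOneLane(acc[c], output_multiplier, output_shift,
                            output_offset, output_activation_min,
                            output_activation_max);
  }
}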
-
-// Performs multiply-accumulate on a 3x4 input for 2 horizontal outputs.
-inline void DotProductAndStore2xStride1(
- const Filter3x3x8& filter, int16x8_t i0, int16x8_t i1, int16x8_t i2,
- int16x8_t i3, int16x8_t i4, int16x8_t i5, int16x8_t i6, int16x8_t i7,
- int16x8_t i8, int16x8_t i9, int16x8_t i10, int16x8_t i11,
- const int32* bias_ptr, int32 output_offset, int32 output_multiplier,
- int output_shift, int32 output_activation_min, int32 output_activation_max,
- uint8* output_ptr, int output_ptr_offset) {
- Int32x8 acc_0, acc_1;
- acc_0.low = vld1q_s32(bias_ptr);
- acc_1.low = vld1q_s32(bias_ptr);
- acc_0.high = vld1q_s32(bias_ptr + 4);
- acc_1.high = vld1q_s32(bias_ptr + 4);
-
- acc_0 = MultiplyAccumulate3x3Filter(filter, i0, i1, i2, i4, i5, i6, i8, i9,
- i10, acc_0);
- acc_1 = MultiplyAccumulate3x3Filter(filter, i1, i2, i3, i5, i6, i7, i9, i10,
- i11, acc_1);
- DownquantizeAndStore2Output(acc_0, acc_1, output_offset, output_multiplier,
- output_shift, output_activation_min,
- output_activation_max, output_ptr,
- output_ptr_offset);
-}
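The 2x, stride-1 variant relies on two horizontally adjacent outputs sharing two of their three input columns: within one 3x4 window, output 0 reads columns 0..2 and output 1 reads columns 1..3, so only one new column pair has to be loaded per step. A scalar sketch of that reuse, with a window[row][col][channel] layout assumed for illustration:

#include <cstdint>

// Illustrative scalar form of the shared-window trick in
// DotProductAndStore2xStride1: one 3x4 window of widened inputs feeds two
// horizontally adjacent outputs per channel.
inline void TwoOutputsFromSharedWindow(const std::int16_t filter[3][3][8],
                                       const std::int16_t window[3][4][8],
                                       const std::int32_t bias[8],
                                       std::int32_t acc0[8],
                                       std::int32_t acc1[8]) {
  for (int c = 0; c < 8; ++c) {
    acc0[c] = bias[c];
    acc1[c] = bias[c];
    for (int ky = 0; ky < 3; ++ky) {
      for (int kx = 0; kx < 3; ++kx) {
        const std::int32_t f = filter[ky][kx][c];
        acc0[c] += f * window[ky][kx][c];      // window columns 0..2
        acc1[c] += f * window[ky][kx + 1][c];  // window columns 1..3
      }
    }
    // Both accumulators then go through the downquantize pipeline before
    // being stored as uint8.
  }
}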
-
-// Performs multiply-accumulate on a 4x3 input for 2 vertical outputs.
-inline void DotProductAndStore2yStride1(
- const Filter3x3x8& filter, int16x8_t i0, int16x8_t i1, int16x8_t i2,
- int16x8_t i3, int16x8_t i4, int16x8_t i5, int16x8_t i6, int16x8_t i7,
- int16x8_t i8, int16x8_t i9, int16x8_t i10, int16x8_t i11,
- const int32* bias_ptr, int32 output_offset, int32 output_multiplier,
- int output_shift, int32 output_activation_min, int32 output_activation_max,
- uint8* output_ptr, int output_ptr_offset) {
- Int32x8 acc_0, acc_1;
- acc_0.low = vld1q_s32(bias_ptr);
- acc_1.low = vld1q_s32(bias_ptr);
- acc_0.high = vld1q_s32(bias_ptr + 4);
- acc_1.high = vld1q_s32(bias_ptr + 4);
-
- acc_0 = MultiplyAccumulate3x3Filter(filter, i0, i1, i2, i3, i4, i5, i6, i7,
- i8, acc_0);
- acc_1 = MultiplyAccumulate3x3Filter(filter, i3, i4, i5, i6, i7, i8, i9, i10,
- i11, acc_1);
- DownquantizeAndStore2Output(acc_0, acc_1, output_offset, output_multiplier,
- output_shift, output_activation_min,
- output_activation_max, output_ptr,
- output_ptr_offset);
-}
-
-// A kernel that is specialized on the number of output cells in the x and y
-// directions and on the stride. Assumes a 3x3 filter of depth 8.
-template <int kFixedOutputY, int kFixedOutputX, int kFixedStrideWidth,
- int kFixedStrideHeight>
-struct ConvKernel3x3FilterDepth8 {};
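The empty primary template is the usual compile-time dispatch idiom: every (output-tile, stride) shape supplies its own specialization with a Run() method, and the caller selects the shape statically. A bare-bones sketch of the pattern, with hypothetical names and the per-tile arguments elided:

// Sketch of the specialization pattern (hypothetical names, arguments elided).
template <int kOutputY, int kOutputX, int kStrideW, int kStrideH>
struct TileKernelSketch {};  // primary template intentionally left empty

template <>
struct TileKernelSketch<2, 2, 1, 1> {
  // A 2x2 output tile at stride 1 would implement its tile-specific code here.
  static inline void Run() {}
};

template <>
struct TileKernelSketch<1, 1, 2, 2> {
  // A single output at stride 2 gets its own specialization.
  static inline void Run() {}
};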
-
-template <>
-struct ConvKernel3x3FilterDepth8<8, 8, 1, 1> {
- static inline void Run(const uint8* input_ptr, int input_depth,
- int32 input_offset, int input_row_size,
- const uint8* filter_ptr, int32 filter_offset,
- const int32* bias_ptr, int32 output_offset,
- int32 output_multiplier, int output_shift,
- int32 output_activation_min,
- int32 output_activation_max, uint8* output_ptr,
- int output_depth, int output_width) {
- Filter3x3x8 filter = Load3x3Filter(filter_ptr, filter_offset, output_depth);
-
- const int16x8_t input_offset_vec = vdupq_n_s16(input_offset);
- const int output_row_size = output_depth * output_width;
-
- // To process 8x8 outputs using a 3x3 filter, we require 10x10 inputs.
-    // Load the inputs for the first 2 output positions at the top left, then
-    // slide to the right, down, left, down, right, etc., in a snake-like
-    // path. This minimizes the total number of loads.
- //
- // INPUT OUTPUT
- // |\----------------\ |\------------\
- // | \ \ | \ \
- // | \----------------\ | \------------\
- // | | 0 ... 9 | | | 0 ... 7 |
- // | | 10 ... 19 | ---> | | 8 ... 15 |
- // | | 20 ... 29 | \ | .. ... .. |
- // \ | .. ... .. | \| 56 ... 63 |
-    //   \|  90 ...     99 |       |------------|
- // |----------------|
- //
- // The first set of loads corresponds to:
- //
- // INPUT OUTPUT
- // |\----------------- |\-----------
- // | \ | \
- // | \----------------- | \----------
- // | | 0 1 2 3 ... | | 0 1 ...
- // | | 10 11 12 13 ... ---> | | .. ...
- // | | 20 21 22 23 ... | .. ...
- // | | .. ... ...
- //
-    // The next set of loads corresponds to a sliding window to the right.
-    // It loads inputs 4, 5, 14, 15, 24, 25 and keeps 2, 3, 12, 13, 22, 23:
- //
- // INPUT OUTPUT
- // |\------------------- |\-------------
- // | \ | \
- // | \------------------- | \------------
- // | | .. 2 3 4 5 ... | | .. 2 3 ...
- // | | .. 12 13 14 15 ... ---> | | .. ...
-    //  | |  .. 22 23 24 25 ...      |    ..   ...
- // | | .. ... ...
- //
- // And so on...
-
- int16x8_t input_0, input_1, input_2, input_3, input_4, input_5, input_6,
- input_7, input_8, input_9, input_10, input_11;
-
- // Load inputs for 1x2 outputs starting from the top left. Referring to the
- // indexes in the diagram above, this corresponds to outputs (0) and (1).
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3;
-
- const uint8* ptr = input_ptr;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- temp_3 = vld1_u8(ptr + 3 * input_depth);
-
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
-
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- input_2 = vaddq_s16(input_2, input_offset_vec);
- input_3 = vaddq_s16(input_3, input_offset_vec);
-
- ptr += input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- temp_3 = vld1_u8(ptr + 3 * input_depth);
-
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
-
- input_4 = vaddq_s16(input_4, input_offset_vec);
- input_5 = vaddq_s16(input_5, input_offset_vec);
- input_6 = vaddq_s16(input_6, input_offset_vec);
- input_7 = vaddq_s16(input_7, input_offset_vec);
-
- ptr += input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- temp_3 = vld1_u8(ptr + 3 * input_depth);
-
- input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_10 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_11 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
-
- input_8 = vaddq_s16(input_8, input_offset_vec);
- input_9 = vaddq_s16(input_9, input_offset_vec);
- input_10 = vaddq_s16(input_10, input_offset_vec);
- input_11 = vaddq_s16(input_11, input_offset_vec);
- }
-
- DotProductAndStore2xStride1(
- filter, input_0, input_1, input_2, input_3, input_4, input_5, input_6,
- input_7, input_8, input_9, input_10, input_11, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_ptr, output_depth);
-
- // Slide to the right for outputs x = [2, 3], y = 0. Referring to the
- // indexes in the diagram above, this corresponds to outputs (2) and (3).
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5;
-
- const uint8* ptr = input_ptr + 4 * input_depth;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_2 = vld1_u8(ptr);
- temp_3 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_4 = vld1_u8(ptr);
- temp_5 = vld1_u8(ptr + input_depth);
-
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
- input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_5));
-
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- input_4 = vaddq_s16(input_4, input_offset_vec);
- input_5 = vaddq_s16(input_5, input_offset_vec);
- input_8 = vaddq_s16(input_8, input_offset_vec);
- input_9 = vaddq_s16(input_9, input_offset_vec);
- }
-
- DotProductAndStore2xStride1(
- filter, input_2, input_3, input_0, input_1, input_6, input_7, input_4,
- input_5, input_10, input_11, input_8, input_9, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_ptr + 2 * output_depth, output_depth);
-
- // Slide to the right again for outputs x = [4, 5], y = 0. Referring to the
- // indexes in the diagram above, this corresponds to outputs (4) and (5).
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5;
-
- const uint8* ptr = input_ptr + 6 * input_depth;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_2 = vld1_u8(ptr);
- temp_3 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_4 = vld1_u8(ptr);
- temp_5 = vld1_u8(ptr + input_depth);
-
- input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_10 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
- input_11 = vreinterpretq_s16_u16(vmovl_u8(temp_5));
-
- input_2 = vaddq_s16(input_2, input_offset_vec);
- input_3 = vaddq_s16(input_3, input_offset_vec);
- input_6 = vaddq_s16(input_6, input_offset_vec);
- input_7 = vaddq_s16(input_7, input_offset_vec);
- input_10 = vaddq_s16(input_10, input_offset_vec);
- input_11 = vaddq_s16(input_11, input_offset_vec);
- }
-
- DotProductAndStore2xStride1(
- filter, input_0, input_1, input_2, input_3, input_4, input_5, input_6,
- input_7, input_8, input_9, input_10, input_11, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_ptr + 4 * output_depth, output_depth);
-
- // Slide to the right one last time for outputs x = [6, 7], y = 0.
- // Referring to the indexes in the diagram above, this corresponds to
- // outputs (6) and (7).
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5;
-
- const uint8* ptr = input_ptr + 8 * input_depth;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_2 = vld1_u8(ptr);
- temp_3 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_4 = vld1_u8(ptr);
- temp_5 = vld1_u8(ptr + input_depth);
-
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
- input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_5));
-
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- input_4 = vaddq_s16(input_4, input_offset_vec);
- input_5 = vaddq_s16(input_5, input_offset_vec);
- input_8 = vaddq_s16(input_8, input_offset_vec);
- input_9 = vaddq_s16(input_9, input_offset_vec);
- }
-
- DotProductAndStore2xStride1(
- filter, input_2, input_3, input_0, input_1, input_6, input_7, input_4,
- input_5, input_10, input_11, input_8, input_9, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_ptr + 6 * output_depth, output_depth);
-
-    // Slide down for outputs x = [6, 7], y = 1. Referring to the indexes in
- // the diagram above, this corresponds to outputs (14) and (15).
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3;
-
- const uint8* ptr = input_ptr + 6 * input_depth + 3 * input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- temp_3 = vld1_u8(ptr + 3 * input_depth);
-
- input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
-
- input_2 = vaddq_s16(input_2, input_offset_vec);
- input_3 = vaddq_s16(input_3, input_offset_vec);
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- }
-
- DotProductAndStore2xStride1(
- filter, input_6, input_7, input_4, input_5, input_10, input_11, input_8,
- input_9, input_2, input_3, input_0, input_1, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_ptr + 6 * output_depth + output_row_size,
- output_depth);
-
- // Slide left for outputs x = [4, 5], y = 1. Referring to the indexes in
- // the diagram above, this corresponds to outputs (12) and (13).
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5;
-
- const uint8* ptr = input_ptr + 4 * input_depth + input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_2 = vld1_u8(ptr);
- temp_3 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_4 = vld1_u8(ptr);
- temp_5 = vld1_u8(ptr + input_depth);
-
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_5));
-
- input_4 = vaddq_s16(input_4, input_offset_vec);
- input_5 = vaddq_s16(input_5, input_offset_vec);
- input_8 = vaddq_s16(input_8, input_offset_vec);
- input_9 = vaddq_s16(input_9, input_offset_vec);
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- }
-
- DotProductAndStore2xStride1(
- filter, input_4, input_5, input_6, input_7, input_8, input_9, input_10,
- input_11, input_0, input_1, input_2, input_3, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_ptr + 4 * output_depth + output_row_size,
- output_depth);
-
- // Slide left again for outputs x = [2, 3], y = 1. Referring to the indexes
- // in the diagram above, this corresponds to outputs (10) and (11).
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5;
-
- const uint8* ptr = input_ptr + 2 * input_depth + input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_2 = vld1_u8(ptr);
- temp_3 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_4 = vld1_u8(ptr);
- temp_5 = vld1_u8(ptr + input_depth);
-
- input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_10 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_11 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
- input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_5));
-
- input_6 = vaddq_s16(input_6, input_offset_vec);
- input_7 = vaddq_s16(input_7, input_offset_vec);
- input_10 = vaddq_s16(input_10, input_offset_vec);
- input_11 = vaddq_s16(input_11, input_offset_vec);
- input_2 = vaddq_s16(input_2, input_offset_vec);
- input_3 = vaddq_s16(input_3, input_offset_vec);
- }
-
- DotProductAndStore2xStride1(
- filter, input_6, input_7, input_4, input_5, input_10, input_11, input_8,
- input_9, input_2, input_3, input_0, input_1, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_ptr + 2 * output_depth + output_row_size,
- output_depth);
-
- // Slide left one more time for outputs x = [0, 1], y = 1. Referring to the
- // indexes in the diagram above, this corresponds to outputs (8) and (9).
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5;
-
- const uint8* ptr = input_ptr + input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_2 = vld1_u8(ptr);
- temp_3 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_4 = vld1_u8(ptr);
- temp_5 = vld1_u8(ptr + input_depth);
-
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_5));
-
- input_4 = vaddq_s16(input_4, input_offset_vec);
- input_5 = vaddq_s16(input_5, input_offset_vec);
- input_8 = vaddq_s16(input_8, input_offset_vec);
- input_9 = vaddq_s16(input_9, input_offset_vec);
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- }
-
- DotProductAndStore2xStride1(
- filter, input_4, input_5, input_6, input_7, input_8, input_9, input_10,
- input_11, input_0, input_1, input_2, input_3, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_ptr + output_row_size, output_depth);
-
- // Slide down for outputs x = [0, 1], y = 2. Referring to the
- // indexes in the diagram above, this corresponds to outputs (16) and (17).
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3;
-
- const uint8* ptr = input_ptr + 4 * input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- temp_3 = vld1_u8(ptr + 3 * input_depth);
-
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
-
- input_4 = vaddq_s16(input_4, input_offset_vec);
- input_5 = vaddq_s16(input_5, input_offset_vec);
- input_6 = vaddq_s16(input_6, input_offset_vec);
- input_7 = vaddq_s16(input_7, input_offset_vec);
- }
-
- DotProductAndStore2xStride1(
- filter, input_8, input_9, input_10, input_11, input_0, input_1, input_2,
- input_3, input_4, input_5, input_6, input_7, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_ptr + 2 * output_row_size, output_depth);
-
- // Slide right for outputs x = [2, 3], y = 2. Referring to the
- // indexes in the diagram above, this corresponds to outputs (18) and (19).
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5;
-
- const uint8* ptr = input_ptr + 4 * input_depth + 2 * input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_2 = vld1_u8(ptr);
- temp_3 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_4 = vld1_u8(ptr);
- temp_5 = vld1_u8(ptr + input_depth);
-
- input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
- input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_5));
-
- input_8 = vaddq_s16(input_8, input_offset_vec);
- input_9 = vaddq_s16(input_9, input_offset_vec);
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- input_4 = vaddq_s16(input_4, input_offset_vec);
- input_5 = vaddq_s16(input_5, input_offset_vec);
- }
-
- DotProductAndStore2xStride1(
- filter, input_10, input_11, input_8, input_9, input_2, input_3, input_0,
- input_1, input_6, input_7, input_4, input_5, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max,
- output_ptr + 2 * output_depth + 2 * output_row_size, output_depth);
-
- // Slide right for outputs x = [4, 5], y = 2. Referring to the
- // indexes in the diagram above, this corresponds to outputs (20) and (21).
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5;
-
- const uint8* ptr = input_ptr + 6 * input_depth + 2 * input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_2 = vld1_u8(ptr);
- temp_3 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_4 = vld1_u8(ptr);
- temp_5 = vld1_u8(ptr + input_depth);
-
- input_10 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_11 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
- input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_5));
-
- input_10 = vaddq_s16(input_10, input_offset_vec);
- input_11 = vaddq_s16(input_11, input_offset_vec);
- input_2 = vaddq_s16(input_2, input_offset_vec);
- input_3 = vaddq_s16(input_3, input_offset_vec);
- input_6 = vaddq_s16(input_6, input_offset_vec);
- input_7 = vaddq_s16(input_7, input_offset_vec);
- }
-
- DotProductAndStore2xStride1(
- filter, input_8, input_9, input_10, input_11, input_0, input_1, input_2,
- input_3, input_4, input_5, input_6, input_7, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max,
- output_ptr + 4 * output_depth + 2 * output_row_size, output_depth);
-
- // Slide right one more time for outputs x = [6, 7], y = 2. Referring to the
- // indexes in the diagram above, this corresponds to outputs (22) and (23).
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5;
-
- const uint8* ptr = input_ptr + 8 * input_depth + 2 * input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_2 = vld1_u8(ptr);
- temp_3 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_4 = vld1_u8(ptr);
- temp_5 = vld1_u8(ptr + input_depth);
-
- input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
- input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_5));
-
- input_8 = vaddq_s16(input_8, input_offset_vec);
- input_9 = vaddq_s16(input_9, input_offset_vec);
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- input_4 = vaddq_s16(input_4, input_offset_vec);
- input_5 = vaddq_s16(input_5, input_offset_vec);
- }
-
- DotProductAndStore2xStride1(
- filter, input_10, input_11, input_8, input_9, input_2, input_3, input_0,
- input_1, input_6, input_7, input_4, input_5, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max,
- output_ptr + 6 * output_depth + 2 * output_row_size, output_depth);
-
- // Slide down for outputs x = [6, 7], y = 3. Referring to the indexes in
- // the diagram above, this corresponds to outputs (30) and (31).
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3;
-
- const uint8* ptr = input_ptr + 6 * input_depth + 5 * input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- temp_3 = vld1_u8(ptr + 3 * input_depth);
-
- input_10 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_11 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
-
- input_10 = vaddq_s16(input_10, input_offset_vec);
- input_11 = vaddq_s16(input_11, input_offset_vec);
- input_8 = vaddq_s16(input_8, input_offset_vec);
- input_9 = vaddq_s16(input_9, input_offset_vec);
- }
-
- DotProductAndStore2xStride1(
- filter, input_2, input_3, input_0, input_1, input_6, input_7, input_4,
- input_5, input_10, input_11, input_8, input_9, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max,
- output_ptr + 6 * output_depth + 3 * output_row_size, output_depth);
-
- // Slide left for outputs x = [4, 5], y = 3. Referring to the indexes in
- // the diagram above, this corresponds to outputs (28) and (29).
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5;
-
- const uint8* ptr = input_ptr + 4 * input_depth + 3 * input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_2 = vld1_u8(ptr);
- temp_3 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_4 = vld1_u8(ptr);
- temp_5 = vld1_u8(ptr + input_depth);
-
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
- input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_5));
-
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- input_4 = vaddq_s16(input_4, input_offset_vec);
- input_5 = vaddq_s16(input_5, input_offset_vec);
- input_8 = vaddq_s16(input_8, input_offset_vec);
- input_9 = vaddq_s16(input_9, input_offset_vec);
- }
-
- DotProductAndStore2xStride1(
- filter, input_0, input_1, input_2, input_3, input_4, input_5, input_6,
- input_7, input_8, input_9, input_10, input_11, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max,
- output_ptr + 4 * output_depth + 3 * output_row_size, output_depth);
-
- // Slide left for outputs x = [2, 3], y = 3. Referring to the indexes in
- // the diagram above, this corresponds to outputs (26) and (27).
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5;
-
- const uint8* ptr = input_ptr + 2 * input_depth + 3 * input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_2 = vld1_u8(ptr);
- temp_3 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_4 = vld1_u8(ptr);
- temp_5 = vld1_u8(ptr + input_depth);
-
- input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_10 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
- input_11 = vreinterpretq_s16_u16(vmovl_u8(temp_5));
-
- input_2 = vaddq_s16(input_2, input_offset_vec);
- input_3 = vaddq_s16(input_3, input_offset_vec);
- input_6 = vaddq_s16(input_6, input_offset_vec);
- input_7 = vaddq_s16(input_7, input_offset_vec);
- input_10 = vaddq_s16(input_10, input_offset_vec);
- input_11 = vaddq_s16(input_11, input_offset_vec);
- }
-
- DotProductAndStore2xStride1(
- filter, input_2, input_3, input_0, input_1, input_6, input_7, input_4,
- input_5, input_10, input_11, input_8, input_9, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max,
- output_ptr + 2 * output_depth + 3 * output_row_size, output_depth);
-
- // Slide left one more time for outputs x = [0, 1], y = 3. Referring to the
- // indexes in the diagram above, this corresponds to outputs (24) and (25).
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5;
-
- const uint8* ptr = input_ptr + 3 * input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_2 = vld1_u8(ptr);
- temp_3 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_4 = vld1_u8(ptr);
- temp_5 = vld1_u8(ptr + input_depth);
-
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
- input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_5));
-
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- input_4 = vaddq_s16(input_4, input_offset_vec);
- input_5 = vaddq_s16(input_5, input_offset_vec);
- input_8 = vaddq_s16(input_8, input_offset_vec);
- input_9 = vaddq_s16(input_9, input_offset_vec);
- }
-
- DotProductAndStore2xStride1(
- filter, input_0, input_1, input_2, input_3, input_4, input_5, input_6,
- input_7, input_8, input_9, input_10, input_11, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_ptr + 3 * output_row_size, output_depth);
-
- // Slide down for outputs x = [0, 1], y = 4. Referring to the indexes in
- // the diagram above, this corresponds to outputs (32) and (33).
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3;
-
- const uint8* ptr = input_ptr + 6 * input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- temp_3 = vld1_u8(ptr + 3 * input_depth);
-
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
-
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- input_2 = vaddq_s16(input_2, input_offset_vec);
- input_3 = vaddq_s16(input_3, input_offset_vec);
- }
-
- DotProductAndStore2xStride1(
- filter, input_4, input_5, input_6, input_7, input_8, input_9, input_10,
- input_11, input_0, input_1, input_2, input_3, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_ptr + 4 * output_row_size, output_depth);
-
- // Slide right for outputs x = [2, 3], y = 4. Referring to the indexes in
- // the diagram above, this corresponds to outputs (34) and (35).
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5;
-
- const uint8* ptr = input_ptr + 4 * input_depth + 4 * input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_2 = vld1_u8(ptr);
- temp_3 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_4 = vld1_u8(ptr);
- temp_5 = vld1_u8(ptr + input_depth);
-
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_5));
-
- input_4 = vaddq_s16(input_4, input_offset_vec);
- input_5 = vaddq_s16(input_5, input_offset_vec);
- input_8 = vaddq_s16(input_8, input_offset_vec);
- input_9 = vaddq_s16(input_9, input_offset_vec);
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- }
-
- DotProductAndStore2xStride1(
- filter, input_6, input_7, input_4, input_5, input_10, input_11, input_8,
- input_9, input_2, input_3, input_0, input_1, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max,
- output_ptr + 2 * output_depth + 4 * output_row_size, output_depth);
-
- // Slide right for outputs x = [4, 5], y = 4. Referring to the indexes in
- // the diagram above, this corresponds to outputs (36) and (37).
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5;
-
- const uint8* ptr = input_ptr + 6 * input_depth + 4 * input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_2 = vld1_u8(ptr);
- temp_3 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_4 = vld1_u8(ptr);
- temp_5 = vld1_u8(ptr + input_depth);
-
- input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_10 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_11 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
- input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_5));
-
- input_6 = vaddq_s16(input_6, input_offset_vec);
- input_7 = vaddq_s16(input_7, input_offset_vec);
- input_10 = vaddq_s16(input_10, input_offset_vec);
- input_11 = vaddq_s16(input_11, input_offset_vec);
- input_2 = vaddq_s16(input_2, input_offset_vec);
- input_3 = vaddq_s16(input_3, input_offset_vec);
- }
-
- DotProductAndStore2xStride1(
- filter, input_4, input_5, input_6, input_7, input_8, input_9, input_10,
- input_11, input_0, input_1, input_2, input_3, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max,
- output_ptr + 4 * output_depth + 4 * output_row_size, output_depth);
-
- // Slide right one more time for outputs x = [6, 7], y = 4. Referring to the
- // indexes in the diagram above, this corresponds to outputs (38) and (39).
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5;
-
- const uint8* ptr = input_ptr + 8 * input_depth + 4 * input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_2 = vld1_u8(ptr);
- temp_3 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_4 = vld1_u8(ptr);
- temp_5 = vld1_u8(ptr + input_depth);
-
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_5));
-
- input_4 = vaddq_s16(input_4, input_offset_vec);
- input_5 = vaddq_s16(input_5, input_offset_vec);
- input_8 = vaddq_s16(input_8, input_offset_vec);
- input_9 = vaddq_s16(input_9, input_offset_vec);
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- }
-
- DotProductAndStore2xStride1(
- filter, input_6, input_7, input_4, input_5, input_10, input_11, input_8,
- input_9, input_2, input_3, input_0, input_1, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max,
- output_ptr + 6 * output_depth + 4 * output_row_size, output_depth);
-
- // Slide down for outputs x = [6, 7], y = 5. Referring to the indexes in
- // the diagram above, this corresponds to outputs (46) and (47).
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3;
-
- const uint8* ptr = input_ptr + 6 * input_depth + 7 * input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- temp_3 = vld1_u8(ptr + 3 * input_depth);
-
- input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
-
- input_6 = vaddq_s16(input_6, input_offset_vec);
- input_7 = vaddq_s16(input_7, input_offset_vec);
- input_4 = vaddq_s16(input_4, input_offset_vec);
- input_5 = vaddq_s16(input_5, input_offset_vec);
- }
-
- DotProductAndStore2xStride1(
- filter, input_10, input_11, input_8, input_9, input_2, input_3, input_0,
- input_1, input_6, input_7, input_4, input_5, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max,
- output_ptr + 6 * output_depth + 5 * output_row_size, output_depth);
-
- // Slide left for outputs x = [4, 5], y = 5. Referring to the indexes in
- // the diagram above, this corresponds to outputs (44) and (45).
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5;
-
- const uint8* ptr = input_ptr + 4 * input_depth + 5 * input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_2 = vld1_u8(ptr);
- temp_3 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_4 = vld1_u8(ptr);
- temp_5 = vld1_u8(ptr + input_depth);
-
- input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
- input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_5));
-
- input_8 = vaddq_s16(input_8, input_offset_vec);
- input_9 = vaddq_s16(input_9, input_offset_vec);
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- input_4 = vaddq_s16(input_4, input_offset_vec);
- input_5 = vaddq_s16(input_5, input_offset_vec);
- }
-
- DotProductAndStore2xStride1(
- filter, input_8, input_9, input_10, input_11, input_0, input_1, input_2,
- input_3, input_4, input_5, input_6, input_7, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max,
- output_ptr + 4 * output_depth + 5 * output_row_size, output_depth);
-
- // Slide left for outputs x = [2, 3], y = 5. Referring to the indexes in
- // the diagram above, this corresponds to outputs (42) and (43).
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5;
-
- const uint8* ptr = input_ptr + 2 * input_depth + 5 * input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_2 = vld1_u8(ptr);
- temp_3 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_4 = vld1_u8(ptr);
- temp_5 = vld1_u8(ptr + input_depth);
-
- input_10 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_11 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
- input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_5));
-
- input_10 = vaddq_s16(input_10, input_offset_vec);
- input_11 = vaddq_s16(input_11, input_offset_vec);
- input_2 = vaddq_s16(input_2, input_offset_vec);
- input_3 = vaddq_s16(input_3, input_offset_vec);
- input_6 = vaddq_s16(input_6, input_offset_vec);
- input_7 = vaddq_s16(input_7, input_offset_vec);
- }
-
- DotProductAndStore2xStride1(
- filter, input_10, input_11, input_8, input_9, input_2, input_3, input_0,
- input_1, input_6, input_7, input_4, input_5, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max,
- output_ptr + 2 * output_depth + 5 * output_row_size, output_depth);
-
- // Slide left one more time for outputs x = [0, 1], y = 5. Referring to the
- // indexes in the diagram above, this corresponds to outputs (40) and (41).
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5;
-
- const uint8* ptr = input_ptr + 5 * input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_2 = vld1_u8(ptr);
- temp_3 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_4 = vld1_u8(ptr);
- temp_5 = vld1_u8(ptr + input_depth);
-
- input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
- input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_5));
-
- input_8 = vaddq_s16(input_8, input_offset_vec);
- input_9 = vaddq_s16(input_9, input_offset_vec);
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- input_4 = vaddq_s16(input_4, input_offset_vec);
- input_5 = vaddq_s16(input_5, input_offset_vec);
- }
-
- DotProductAndStore2xStride1(
- filter, input_8, input_9, input_10, input_11, input_0, input_1, input_2,
- input_3, input_4, input_5, input_6, input_7, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_ptr + 5 * output_row_size, output_depth);
-
- // Slide down for outputs x = [0, 1], y = 6. Referring to the indexes in
- // the diagram above, this corresponds to outputs (48) and (49).
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3;
-
- const uint8* ptr = input_ptr + 8 * input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- temp_3 = vld1_u8(ptr + 3 * input_depth);
-
- input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_10 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_11 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
-
- input_8 = vaddq_s16(input_8, input_offset_vec);
- input_9 = vaddq_s16(input_9, input_offset_vec);
- input_10 = vaddq_s16(input_10, input_offset_vec);
- input_11 = vaddq_s16(input_11, input_offset_vec);
- }
-
- DotProductAndStore2xStride1(
- filter, input_0, input_1, input_2, input_3, input_4, input_5, input_6,
- input_7, input_8, input_9, input_10, input_11, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_ptr + 6 * output_row_size, output_depth);
-
- // Slide right for outputs x = [2, 3], y = 6. Referring to the indexes in
- // the diagram above, this corresponds to outputs (50) and (51).
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5;
-
- const uint8* ptr = input_ptr + 4 * input_depth + 6 * input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_2 = vld1_u8(ptr);
- temp_3 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_4 = vld1_u8(ptr);
- temp_5 = vld1_u8(ptr + input_depth);
-
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
- input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_5));
-
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- input_4 = vaddq_s16(input_4, input_offset_vec);
- input_5 = vaddq_s16(input_5, input_offset_vec);
- input_8 = vaddq_s16(input_8, input_offset_vec);
- input_9 = vaddq_s16(input_9, input_offset_vec);
- }
-
- DotProductAndStore2xStride1(
- filter, input_2, input_3, input_0, input_1, input_6, input_7, input_4,
- input_5, input_10, input_11, input_8, input_9, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max,
- output_ptr + 2 * output_depth + 6 * output_row_size, output_depth);
-
- // Slide right for outputs x = [4, 5], y = 6. Referring to the indexes in
- // the diagram above, this corresponds to outputs (52) and (53).
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5;
-
- const uint8* ptr = input_ptr + 6 * input_depth + 6 * input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_2 = vld1_u8(ptr);
- temp_3 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_4 = vld1_u8(ptr);
- temp_5 = vld1_u8(ptr + input_depth);
-
- input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_10 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
- input_11 = vreinterpretq_s16_u16(vmovl_u8(temp_5));
-
- input_2 = vaddq_s16(input_2, input_offset_vec);
- input_3 = vaddq_s16(input_3, input_offset_vec);
- input_6 = vaddq_s16(input_6, input_offset_vec);
- input_7 = vaddq_s16(input_7, input_offset_vec);
- input_10 = vaddq_s16(input_10, input_offset_vec);
- input_11 = vaddq_s16(input_11, input_offset_vec);
- }
-
- DotProductAndStore2xStride1(
- filter, input_0, input_1, input_2, input_3, input_4, input_5, input_6,
- input_7, input_8, input_9, input_10, input_11, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max,
- output_ptr + 4 * output_depth + 6 * output_row_size, output_depth);
-
- // Slide right one more time for outputs x = [6, 7], y = 6. Referring to the
- // indexes in the diagram above, this corresponds to outputs (54) and (55).
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5;
-
- const uint8* ptr = input_ptr + 8 * input_depth + 6 * input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_2 = vld1_u8(ptr);
- temp_3 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_4 = vld1_u8(ptr);
- temp_5 = vld1_u8(ptr + input_depth);
-
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
- input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_5));
-
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- input_4 = vaddq_s16(input_4, input_offset_vec);
- input_5 = vaddq_s16(input_5, input_offset_vec);
- input_8 = vaddq_s16(input_8, input_offset_vec);
- input_9 = vaddq_s16(input_9, input_offset_vec);
- }
-
- DotProductAndStore2xStride1(
- filter, input_2, input_3, input_0, input_1, input_6, input_7, input_4,
- input_5, input_10, input_11, input_8, input_9, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max,
- output_ptr + 6 * output_depth + 6 * output_row_size, output_depth);
-
- // Slide down for outputs x = [6, 7], y = 7. Referring to the indexes in the
- // diagram above, this corresponds to outputs (62) and (63).
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3;
-
- const uint8* ptr = input_ptr + 6 * input_depth + 9 * input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- temp_3 = vld1_u8(ptr + 3 * input_depth);
-
- input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
-
- input_2 = vaddq_s16(input_2, input_offset_vec);
- input_3 = vaddq_s16(input_3, input_offset_vec);
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- }
-
- DotProductAndStore2xStride1(
- filter, input_6, input_7, input_4, input_5, input_10, input_11, input_8,
- input_9, input_2, input_3, input_0, input_1, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max,
- output_ptr + 6 * output_depth + 7 * output_row_size, output_depth);
-
- // Slide left for outputs x = [4, 5], y = 7. Referring to the indexes in the
- // diagram above, this corresponds to outputs (60) and (61).
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5;
-
- const uint8* ptr = input_ptr + 4 * input_depth + 7 * input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_2 = vld1_u8(ptr);
- temp_3 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_4 = vld1_u8(ptr);
- temp_5 = vld1_u8(ptr + input_depth);
-
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_5));
-
- input_4 = vaddq_s16(input_4, input_offset_vec);
- input_5 = vaddq_s16(input_5, input_offset_vec);
- input_8 = vaddq_s16(input_8, input_offset_vec);
- input_9 = vaddq_s16(input_9, input_offset_vec);
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- }
-
- DotProductAndStore2xStride1(
- filter, input_4, input_5, input_6, input_7, input_8, input_9, input_10,
- input_11, input_0, input_1, input_2, input_3, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max,
- output_ptr + 4 * output_depth + 7 * output_row_size, output_depth);
-
- // Slide left for outputs x = [2, 3], y = 7. Referring to the indexes in the
- // diagram above, this corresponds to outputs (58) and (59).
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5;
-
- const uint8* ptr = input_ptr + 2 * input_depth + 7 * input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_2 = vld1_u8(ptr);
- temp_3 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_4 = vld1_u8(ptr);
- temp_5 = vld1_u8(ptr + input_depth);
-
- input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_10 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_11 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
- input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_5));
-
- input_6 = vaddq_s16(input_6, input_offset_vec);
- input_7 = vaddq_s16(input_7, input_offset_vec);
- input_10 = vaddq_s16(input_10, input_offset_vec);
- input_11 = vaddq_s16(input_11, input_offset_vec);
- input_2 = vaddq_s16(input_2, input_offset_vec);
- input_3 = vaddq_s16(input_3, input_offset_vec);
- }
-
- DotProductAndStore2xStride1(
- filter, input_6, input_7, input_4, input_5, input_10, input_11, input_8,
- input_9, input_2, input_3, input_0, input_1, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max,
- output_ptr + 2 * output_depth + 7 * output_row_size, output_depth);
-
- // Slide left one more time for outputs x = [0, 1], y = 7. Referring to the
- // indexes in the diagram above, this corresponds to outputs (56) and (57).
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5;
-
- const uint8* ptr = input_ptr + 7 * input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_2 = vld1_u8(ptr);
- temp_3 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_4 = vld1_u8(ptr);
- temp_5 = vld1_u8(ptr + input_depth);
-
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_5));
-
- input_4 = vaddq_s16(input_4, input_offset_vec);
- input_5 = vaddq_s16(input_5, input_offset_vec);
- input_8 = vaddq_s16(input_8, input_offset_vec);
- input_9 = vaddq_s16(input_9, input_offset_vec);
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- }
-
- DotProductAndStore2xStride1(
- filter, input_4, input_5, input_6, input_7, input_8, input_9, input_10,
- input_11, input_0, input_1, input_2, input_3, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_ptr + 7 * output_row_size, output_depth);
- }
-};
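The removed kernel above refers to output positions both as (x, y) pairs and as single indices such as "(54) and (55)". A minimal standalone sketch (not TFLite code; the 8-wide output tile is the only detail taken from the kernel above) of that numbering, index = y * 8 + x:

#include <cstdio>

int main() {
  constexpr int kTileWidth = 8;
  // Output-pair positions visited near the end of the snake path above.
  const int xs[][2] = {{6, 7}, {6, 7}, {4, 5}, {2, 3}, {0, 1}};
  const int ys[] = {6, 7, 7, 7, 7};
  for (int i = 0; i < 5; ++i) {
    printf("x=[%d,%d], y=%d -> outputs (%d) and (%d)\n", xs[i][0], xs[i][1],
           ys[i], ys[i] * kTileWidth + xs[i][0], ys[i] * kTileWidth + xs[i][1]);
  }
  return 0;
}

Running it reproduces the pairs named in the comments above: (54, 55), (62, 63), (60, 61), (58, 59), (56, 57).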
-
-template <>
-struct ConvKernel3x3FilterDepth8<4, 4, 1, 1> {
- static inline void Run(const uint8* input_ptr, int input_depth,
- int32 input_offset, int input_row_size,
- const uint8* filter_ptr, int32 filter_offset,
- const int32* bias_ptr, int32 output_offset,
- int32 output_multiplier, int output_shift,
- int32 output_activation_min,
- int32 output_activation_max, uint8* output_ptr,
- int output_depth, int output_width) {
- Filter3x3x8 filter = Load3x3Filter(filter_ptr, filter_offset, output_depth);
-
- const int16x8_t input_offset_vec = vdupq_n_s16(input_offset);
- const int output_row_size = output_depth * output_width;
-
- // To process 4x4 outputs using a 3x3 filter, we require 6x6 inputs.
- // Load inputs for the first 2 filters on the top left, then slide to
- // the right, down, left, down, right, etc. in a snake-like path. This
- // minimizes the total number of loads.
- int16x8_t input_0, input_1, input_2, input_3, input_4, input_5, input_6,
- input_7, input_8, input_9, input_10, input_11;
-
- // Load inputs for 1x2 outputs starting from the top left.
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3;
-
- const uint8* ptr = input_ptr;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- temp_3 = vld1_u8(ptr + 3 * input_depth);
-
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
-
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- input_2 = vaddq_s16(input_2, input_offset_vec);
- input_3 = vaddq_s16(input_3, input_offset_vec);
-
- ptr += input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- temp_3 = vld1_u8(ptr + 3 * input_depth);
-
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
-
- input_4 = vaddq_s16(input_4, input_offset_vec);
- input_5 = vaddq_s16(input_5, input_offset_vec);
- input_6 = vaddq_s16(input_6, input_offset_vec);
- input_7 = vaddq_s16(input_7, input_offset_vec);
-
- ptr += input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- temp_3 = vld1_u8(ptr + 3 * input_depth);
-
- input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_10 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_11 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
-
- input_8 = vaddq_s16(input_8, input_offset_vec);
- input_9 = vaddq_s16(input_9, input_offset_vec);
- input_10 = vaddq_s16(input_10, input_offset_vec);
- input_11 = vaddq_s16(input_11, input_offset_vec);
- }
-
- DotProductAndStore2xStride1(
- filter, input_0, input_1, input_2, input_3, input_4, input_5, input_6,
- input_7, input_8, input_9, input_10, input_11, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_ptr, output_depth);
-
- // Now load 1x2 inputs on the top right.
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5;
-
- const uint8* ptr = input_ptr + 4 * input_depth;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_2 = vld1_u8(ptr);
- temp_3 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_4 = vld1_u8(ptr);
- temp_5 = vld1_u8(ptr + input_depth);
-
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
- input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_5));
-
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- input_4 = vaddq_s16(input_4, input_offset_vec);
- input_5 = vaddq_s16(input_5, input_offset_vec);
- input_8 = vaddq_s16(input_8, input_offset_vec);
- input_9 = vaddq_s16(input_9, input_offset_vec);
- }
-
- DotProductAndStore2xStride1(
- filter, input_2, input_3, input_0, input_1, input_6, input_7, input_4,
- input_5, input_10, input_11, input_8, input_9, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_ptr + 2 * output_depth, output_depth);
-
- // Now load next inputs when sliding window down.
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3;
-
- const uint8* ptr = input_ptr + 2 * input_depth + 3 * input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- temp_3 = vld1_u8(ptr + 3 * input_depth);
-
- input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
-
- input_2 = vaddq_s16(input_2, input_offset_vec);
- input_3 = vaddq_s16(input_3, input_offset_vec);
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- }
-
- DotProductAndStore2xStride1(
- filter, input_6, input_7, input_4, input_5, input_10, input_11, input_8,
- input_9, input_2, input_3, input_0, input_1, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_ptr + 2 * output_depth + output_row_size,
- output_depth);
-
- // Now load next inputs when sliding window left.
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5;
-
- const uint8* ptr = input_ptr + input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_2 = vld1_u8(ptr);
- temp_3 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_4 = vld1_u8(ptr);
- temp_5 = vld1_u8(ptr + input_depth);
-
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_5));
-
- input_4 = vaddq_s16(input_4, input_offset_vec);
- input_5 = vaddq_s16(input_5, input_offset_vec);
- input_8 = vaddq_s16(input_8, input_offset_vec);
- input_9 = vaddq_s16(input_9, input_offset_vec);
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- }
-
- DotProductAndStore2xStride1(
- filter, input_4, input_5, input_6, input_7, input_8, input_9, input_10,
- input_11, input_0, input_1, input_2, input_3, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_ptr + output_row_size, output_depth);
-
- // Now load next inputs when sliding window down.
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3;
-
- const uint8* ptr = input_ptr + 4 * input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- temp_3 = vld1_u8(ptr + 3 * input_depth);
-
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
-
- input_4 = vaddq_s16(input_4, input_offset_vec);
- input_5 = vaddq_s16(input_5, input_offset_vec);
- input_6 = vaddq_s16(input_6, input_offset_vec);
- input_7 = vaddq_s16(input_7, input_offset_vec);
- }
-
- DotProductAndStore2xStride1(
- filter, input_8, input_9, input_10, input_11, input_0, input_1, input_2,
- input_3, input_4, input_5, input_6, input_7, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_ptr + 2 * output_row_size, output_depth);
-
- // Now load next inputs when sliding window right.
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5;
-
- const uint8* ptr = input_ptr + 4 * input_depth + 2 * input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_2 = vld1_u8(ptr);
- temp_3 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_4 = vld1_u8(ptr);
- temp_5 = vld1_u8(ptr + input_depth);
-
- input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
- input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_5));
-
- input_8 = vaddq_s16(input_8, input_offset_vec);
- input_9 = vaddq_s16(input_9, input_offset_vec);
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- input_4 = vaddq_s16(input_4, input_offset_vec);
- input_5 = vaddq_s16(input_5, input_offset_vec);
- }
-
- DotProductAndStore2xStride1(
- filter, input_10, input_11, input_8, input_9, input_2, input_3, input_0,
- input_1, input_6, input_7, input_4, input_5, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max,
- output_ptr + 2 * output_depth + 2 * output_row_size, output_depth);
-
- // Now load next inputs when sliding window down.
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3;
-
- const uint8* ptr = input_ptr + 2 * input_depth + 5 * input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- temp_3 = vld1_u8(ptr + 3 * input_depth);
-
- input_10 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_11 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
-
- input_10 = vaddq_s16(input_10, input_offset_vec);
- input_11 = vaddq_s16(input_11, input_offset_vec);
- input_8 = vaddq_s16(input_8, input_offset_vec);
- input_9 = vaddq_s16(input_9, input_offset_vec);
- }
-
- DotProductAndStore2xStride1(
- filter, input_2, input_3, input_0, input_1, input_6, input_7, input_4,
- input_5, input_10, input_11, input_8, input_9, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max,
- output_ptr + 2 * output_depth + 3 * output_row_size, output_depth);
-
- // Now load next inputs when sliding window left.
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5;
-
- const uint8* ptr = input_ptr + 3 * input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_2 = vld1_u8(ptr);
- temp_3 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_4 = vld1_u8(ptr);
- temp_5 = vld1_u8(ptr + input_depth);
-
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
- input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_5));
-
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- input_4 = vaddq_s16(input_4, input_offset_vec);
- input_5 = vaddq_s16(input_5, input_offset_vec);
- input_8 = vaddq_s16(input_8, input_offset_vec);
- input_9 = vaddq_s16(input_9, input_offset_vec);
- }
-
- DotProductAndStore2xStride1(
- filter, input_0, input_1, input_2, input_3, input_4, input_5, input_6,
- input_7, input_8, input_9, input_10, input_11, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_ptr + 3 * output_row_size, output_depth);
- }
-};
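The <4, 4, 1, 1> specialization above visits its 1x2 output pairs in the snake-like order its comment describes, so consecutive steps share most of their 3x3 input window. A minimal sketch (assuming only the 4x4 tile and the two-outputs-per-step grouping used above) that enumerates the same visit order:

#include <cstdio>

int main() {
  constexpr int kOutputHeight = 4;
  constexpr int kOutputWidth = 4;
  for (int y = 0; y < kOutputHeight; ++y) {
    const bool left_to_right = (y % 2 == 0);  // Alternate direction per row.
    for (int i = 0; i < kOutputWidth / 2; ++i) {
      // x is the left column of the 1x2 output pair handled in this step.
      const int x = left_to_right ? 2 * i : kOutputWidth - 2 - 2 * i;
      printf("outputs x=[%d,%d], y=%d\n", x, x + 1, y);
    }
  }
  return 0;
}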
-
-template <>
-struct ConvKernel3x3FilterDepth8<4, 2, 1, 1> {
- static inline void Run(const uint8* input_ptr, int input_depth,
- int32 input_offset, int input_row_size,
- const uint8* filter_ptr, int32 filter_offset,
- const int32* bias_ptr, int32 output_offset,
- int32 output_multiplier, int output_shift,
- int32 output_activation_min,
- int32 output_activation_max, uint8* output_ptr,
- int output_depth, int output_width) {
- Filter3x3x8 filter = Load3x3Filter(filter_ptr, filter_offset, output_depth);
-
- const int16x8_t input_offset_vec = vdupq_n_s16(input_offset);
- const int output_row_size = output_depth * output_width;
-
- int16x8_t input_0, input_1, input_2, input_3, input_4, input_5, input_6,
- input_7, input_8, input_9, input_10, input_11;
-
- // Load inputs for 1x2 outputs starting from the top.
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3;
-
- const uint8* ptr = input_ptr;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- temp_3 = vld1_u8(ptr + 3 * input_depth);
-
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
-
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- input_2 = vaddq_s16(input_2, input_offset_vec);
- input_3 = vaddq_s16(input_3, input_offset_vec);
-
- ptr += input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- temp_3 = vld1_u8(ptr + 3 * input_depth);
-
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
-
- input_4 = vaddq_s16(input_4, input_offset_vec);
- input_5 = vaddq_s16(input_5, input_offset_vec);
- input_6 = vaddq_s16(input_6, input_offset_vec);
- input_7 = vaddq_s16(input_7, input_offset_vec);
-
- ptr += input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- temp_3 = vld1_u8(ptr + 3 * input_depth);
-
- input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_10 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_11 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
-
- input_8 = vaddq_s16(input_8, input_offset_vec);
- input_9 = vaddq_s16(input_9, input_offset_vec);
- input_10 = vaddq_s16(input_10, input_offset_vec);
- input_11 = vaddq_s16(input_11, input_offset_vec);
- }
-
- DotProductAndStore2xStride1(
- filter, input_0, input_1, input_2, input_3, input_4, input_5, input_6,
- input_7, input_8, input_9, input_10, input_11, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_ptr, output_depth);
-
- output_ptr += output_row_size;
-
- // Now load next inputs one row down.
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3;
-
- const uint8* ptr = input_ptr + 3 * input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- temp_3 = vld1_u8(ptr + 3 * input_depth);
-
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
-
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- input_2 = vaddq_s16(input_2, input_offset_vec);
- input_3 = vaddq_s16(input_3, input_offset_vec);
- }
-
- DotProductAndStore2xStride1(
- filter, input_4, input_5, input_6, input_7, input_8, input_9, input_10,
- input_11, input_0, input_1, input_2, input_3, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_ptr, output_depth);
-
- output_ptr += output_row_size;
-
- // Now load next row.
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3;
-
- const uint8* ptr = input_ptr + 4 * input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- temp_3 = vld1_u8(ptr + 3 * input_depth);
-
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
-
- input_4 = vaddq_s16(input_4, input_offset_vec);
- input_5 = vaddq_s16(input_5, input_offset_vec);
- input_6 = vaddq_s16(input_6, input_offset_vec);
- input_7 = vaddq_s16(input_7, input_offset_vec);
- }
-
- DotProductAndStore2xStride1(
- filter, input_8, input_9, input_10, input_11, input_0, input_1, input_2,
- input_3, input_4, input_5, input_6, input_7, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_ptr, output_depth);
-
- output_ptr += output_row_size;
-
- // Now load last row.
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3;
-
- const uint8* ptr = input_ptr + 5 * input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- temp_3 = vld1_u8(ptr + 3 * input_depth);
-
- input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_10 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_11 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
-
- input_8 = vaddq_s16(input_8, input_offset_vec);
- input_9 = vaddq_s16(input_9, input_offset_vec);
- input_10 = vaddq_s16(input_10, input_offset_vec);
- input_11 = vaddq_s16(input_11, input_offset_vec);
- }
-
- DotProductAndStore2xStride1(
- filter, input_0, input_1, input_2, input_3, input_4, input_5, input_6,
- input_7, input_8, input_9, input_10, input_11, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_ptr, output_depth);
- }
-};
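The <4, 2, 1, 1> specialization above loads three input rows up front and then only one new row per additional output row. A minimal sketch (assuming a 3x3 filter, stride 1, no dilation) of that row-reuse arithmetic:

#include <cstdio>

int main() {
  constexpr int kFilterHeight = 3;
  constexpr int kOutputRows = 4;  // As in the 4x2 kernel above.
  int loaded_input_rows = 0;
  for (int out_y = 0; out_y < kOutputRows; ++out_y) {
    // Output row out_y needs input rows [out_y, out_y + kFilterHeight).
    const int needed_through = out_y + kFilterHeight;
    const int newly_loaded = needed_through - loaded_input_rows;
    loaded_input_rows = needed_through;
    printf("output row %d: load %d new input row(s), %d loaded in total\n",
           out_y, newly_loaded, loaded_input_rows);
  }
  return 0;
}

This prints 3, 1, 1, 1 new rows, i.e. 4 output rows need 4 + 2 = 6 input rows in total.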
-
-template <>
-struct ConvKernel3x3FilterDepth8<4, 1, 1, 1> {
- static inline void Run(const uint8* input_ptr, int input_depth,
- int32 input_offset, int input_row_size,
- const uint8* filter_ptr, int32 filter_offset,
- const int32* bias_ptr, int32 output_offset,
- int32 output_multiplier, int output_shift,
- int32 output_activation_min,
- int32 output_activation_max, uint8* output_ptr,
- int output_depth, int output_width) {
- Filter3x3x8 filter = Load3x3Filter(filter_ptr, filter_offset, output_depth);
-
- const int16x8_t input_offset_vec = vdupq_n_s16(input_offset);
- const int output_row_size = output_depth * output_width;
-
- int16x8_t input_0, input_1, input_2, input_3, input_4, input_5, input_6,
- input_7, input_8, input_9, input_10, input_11;
-
- // Load inputs for 2x1 outputs starting from the top.
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5;
-
- const uint8* ptr = input_ptr;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- ptr += input_row_size;
- temp_3 = vld1_u8(ptr);
- temp_4 = vld1_u8(ptr + input_depth);
- temp_5 = vld1_u8(ptr + 2 * input_depth);
-
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
- input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_5));
-
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- input_2 = vaddq_s16(input_2, input_offset_vec);
- input_3 = vaddq_s16(input_3, input_offset_vec);
- input_4 = vaddq_s16(input_4, input_offset_vec);
- input_5 = vaddq_s16(input_5, input_offset_vec);
-
- ptr += input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- ptr += input_row_size;
- temp_3 = vld1_u8(ptr);
- temp_4 = vld1_u8(ptr + input_depth);
- temp_5 = vld1_u8(ptr + 2 * input_depth);
-
- input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_10 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
- input_11 = vreinterpretq_s16_u16(vmovl_u8(temp_5));
-
- input_6 = vaddq_s16(input_6, input_offset_vec);
- input_7 = vaddq_s16(input_7, input_offset_vec);
- input_8 = vaddq_s16(input_8, input_offset_vec);
- input_9 = vaddq_s16(input_9, input_offset_vec);
- input_10 = vaddq_s16(input_10, input_offset_vec);
- input_11 = vaddq_s16(input_11, input_offset_vec);
- }
-
- DotProductAndStore2yStride1(
- filter, input_0, input_1, input_2, input_3, input_4, input_5, input_6,
- input_7, input_8, input_9, input_10, input_11, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_ptr, output_row_size);
-
- // Load inputs for bottom 2 rows.
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5;
-
- const uint8* ptr = input_ptr + 4 * input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- ptr += input_row_size;
- temp_3 = vld1_u8(ptr);
- temp_4 = vld1_u8(ptr + input_depth);
- temp_5 = vld1_u8(ptr + 2 * input_depth);
-
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
- input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_5));
-
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- input_2 = vaddq_s16(input_2, input_offset_vec);
- input_3 = vaddq_s16(input_3, input_offset_vec);
- input_4 = vaddq_s16(input_4, input_offset_vec);
- input_5 = vaddq_s16(input_5, input_offset_vec);
- }
-
- DotProductAndStore2yStride1(
- filter, input_6, input_7, input_8, input_9, input_10, input_11, input_0,
- input_1, input_2, input_3, input_4, input_5, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_ptr + 2 * output_row_size,
- output_row_size);
- }
-};
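Every load block in these kernels repeats the same vmovl_u8 / vaddq_s16 pair. A minimal scalar sketch (one lane only; the NEON code above does this for eight depth channels at once) of that widen-and-offset step:

#include <cstdint>
#include <cstdio>

// Scalar equivalent of one lane of vmovl_u8 (zero-extend to 16 bits) followed
// by vaddq_s16 (add the quantization input_offset, which may be negative).
int16_t WidenAndOffset(uint8_t input_value, int32_t input_offset) {
  return static_cast<int16_t>(static_cast<int32_t>(input_value) + input_offset);
}

int main() {
  // Example: an input zero point of 128 gives input_offset = -128.
  printf("%d\n", WidenAndOffset(200, -128));  // Prints 72.
  return 0;
}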
-
-template <>
-struct ConvKernel3x3FilterDepth8<2, 2, 1, 1> {
- static inline void Run(const uint8* input_ptr, int input_depth,
- int32 input_offset, int input_row_size,
- const uint8* filter_ptr, int32 filter_offset,
- const int32* bias_ptr, int32 output_offset,
- int32 output_multiplier, int output_shift,
- int32 output_activation_min,
- int32 output_activation_max, uint8* output_ptr,
- int output_depth, int output_width) {
- Filter3x3x8 filter = Load3x3Filter(filter_ptr, filter_offset, output_depth);
-
- Int32x8 acc_0, acc_1, acc_2, acc_3;
-
- acc_0.low = vld1q_s32(bias_ptr);
- acc_1.low = vld1q_s32(bias_ptr);
- acc_2.low = vld1q_s32(bias_ptr);
- acc_3.low = vld1q_s32(bias_ptr);
-
- bias_ptr += 4;
- acc_0.high = vld1q_s32(bias_ptr);
- acc_1.high = vld1q_s32(bias_ptr);
- acc_2.high = vld1q_s32(bias_ptr);
- acc_3.high = vld1q_s32(bias_ptr);
-
- const int16x8_t input_offset_vec = vdupq_n_s16(input_offset);
-
- // Add a scope for the input registers so the compiler knows they are
- // no longer needed afterwards.

- {
- // To process 2x2 outputs using a 3x3 filter, we require 4x4 inputs.
- // Load inputs for the top two filters first.
- int16x8_t input_0, input_1, input_2, input_3, input_4, input_5, input_6,
- input_7, input_8, input_9, input_10, input_11;
-
- const uint8* ptr = input_ptr;
-
- // Load top 3 rows.
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3;
-
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- temp_3 = vld1_u8(ptr + 3 * input_depth);
-
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
-
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- input_2 = vaddq_s16(input_2, input_offset_vec);
- input_3 = vaddq_s16(input_3, input_offset_vec);
-
- ptr += input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- temp_3 = vld1_u8(ptr + 3 * input_depth);
-
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
-
- input_4 = vaddq_s16(input_4, input_offset_vec);
- input_5 = vaddq_s16(input_5, input_offset_vec);
- input_6 = vaddq_s16(input_6, input_offset_vec);
- input_7 = vaddq_s16(input_7, input_offset_vec);
-
- ptr += input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- temp_3 = vld1_u8(ptr + 3 * input_depth);
-
- input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_10 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_11 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
-
- input_8 = vaddq_s16(input_8, input_offset_vec);
- input_9 = vaddq_s16(input_9, input_offset_vec);
- input_10 = vaddq_s16(input_10, input_offset_vec);
- input_11 = vaddq_s16(input_11, input_offset_vec);
- }
-
- // Multiply-accum for top-left output.
- acc_0 = MultiplyAccumulate3x3Filter(filter, input_0, input_1, input_2,
- input_4, input_5, input_6, input_8,
- input_9, input_10, acc_0);
-
- // Multiply-accum for top-right output.
- acc_1 = MultiplyAccumulate3x3Filter(filter, input_1, input_2, input_3,
- input_5, input_6, input_7, input_9,
- input_10, input_11, acc_1);
-
- // Now load the bottom row.
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3;
-
- ptr += input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- temp_3 = vld1_u8(ptr + 3 * input_depth);
-
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
-
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- input_2 = vaddq_s16(input_2, input_offset_vec);
- input_3 = vaddq_s16(input_3, input_offset_vec);
- }
+// clang-format off
- // Multiply-accum for bottom-left output.
- acc_2 = MultiplyAccumulate3x3Filter(filter, input_4, input_5, input_6,
- input_8, input_9, input_10, input_0,
- input_1, input_2, acc_2);
-
- // Multiply-accum for bottom-right output.
- acc_3 = MultiplyAccumulate3x3Filter(filter, input_5, input_6, input_7,
- input_9, input_10, input_11, input_1,
- input_2, input_3, acc_3);
- }
-
- DownquantizeAndStore2x2Output(acc_0, acc_1, acc_2, acc_3, output_offset,
- output_multiplier, output_shift,
- output_activation_min, output_activation_max,
- output_ptr, output_depth, output_width);
- }
-};
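MultiplyAccumulate3x3Filter above effectively performs, for eight depth channels in parallel, the plain 3x3 multiply-accumulate sketched below per channel (bias already sits in the accumulator, as in the kernel above). A minimal scalar reference with a hypothetical argument layout (row-major 9-element arrays), not the actual helper's signature:

#include <cstdint>
#include <cstdio>

// inputs and filter hold the nine offset-adjusted values for one depth
// channel, row-major; acc starts at the bias value.
int32_t MultiplyAccumulate3x3OneChannel(const int16_t inputs[9],
                                        const int16_t filter[9], int32_t acc) {
  for (int i = 0; i < 9; ++i) {
    acc += static_cast<int32_t>(inputs[i]) * static_cast<int32_t>(filter[i]);
  }
  return acc;
}

int main() {
  const int16_t inputs[9] = {1, 2, 3, 4, 5, 6, 7, 8, 9};
  const int16_t filter[9] = {1, 1, 1, 1, 1, 1, 1, 1, 1};
  printf("%d\n", MultiplyAccumulate3x3OneChannel(inputs, filter, 10));  // 55.
  return 0;
}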
-
-template <>
-struct ConvKernel3x3FilterDepth8<2, 4, 1, 1> {
- static inline void Run(const uint8* input_ptr, int input_depth,
- int32 input_offset, int input_row_size,
- const uint8* filter_ptr, int32 filter_offset,
- const int32* bias_ptr, int32 output_offset,
- int32 output_multiplier, int output_shift,
- int32 output_activation_min,
- int32 output_activation_max, uint8* output_ptr,
- int output_depth, int output_width) {
- Filter3x3x8 filter = Load3x3Filter(filter_ptr, filter_offset, output_depth);
-
- const int16x8_t input_offset_vec = vdupq_n_s16(input_offset);
- const int output_row_size = output_depth * output_width;
-
- int16x8_t input_0, input_1, input_2, input_3, input_4, input_5, input_6,
- input_7, input_8, input_9, input_10, input_11;
-
- // Load inputs for 1x2 outputs starting from the top left.
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3;
-
- const uint8* ptr = input_ptr;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- temp_3 = vld1_u8(ptr + 3 * input_depth);
-
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
-
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- input_2 = vaddq_s16(input_2, input_offset_vec);
- input_3 = vaddq_s16(input_3, input_offset_vec);
-
- ptr += input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- temp_3 = vld1_u8(ptr + 3 * input_depth);
-
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
-
- input_4 = vaddq_s16(input_4, input_offset_vec);
- input_5 = vaddq_s16(input_5, input_offset_vec);
- input_6 = vaddq_s16(input_6, input_offset_vec);
- input_7 = vaddq_s16(input_7, input_offset_vec);
-
- ptr += input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- temp_3 = vld1_u8(ptr + 3 * input_depth);
-
- input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_10 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_11 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
-
- input_8 = vaddq_s16(input_8, input_offset_vec);
- input_9 = vaddq_s16(input_9, input_offset_vec);
- input_10 = vaddq_s16(input_10, input_offset_vec);
- input_11 = vaddq_s16(input_11, input_offset_vec);
- }
-
- DotProductAndStore2xStride1(
- filter, input_0, input_1, input_2, input_3, input_4, input_5, input_6,
- input_7, input_8, input_9, input_10, input_11, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_ptr, output_depth);
-
- // Now load 1x2 inputs on the top right.
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5;
-
- const uint8* ptr = input_ptr + 4 * input_depth;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_2 = vld1_u8(ptr);
- temp_3 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_4 = vld1_u8(ptr);
- temp_5 = vld1_u8(ptr + input_depth);
-
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
- input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_5));
-
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- input_4 = vaddq_s16(input_4, input_offset_vec);
- input_5 = vaddq_s16(input_5, input_offset_vec);
- input_8 = vaddq_s16(input_8, input_offset_vec);
- input_9 = vaddq_s16(input_9, input_offset_vec);
- }
-
- DotProductAndStore2xStride1(
- filter, input_2, input_3, input_0, input_1, input_6, input_7, input_4,
- input_5, input_10, input_11, input_8, input_9, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_ptr + 2 * output_depth, output_depth);
-
- // Now load next inputs when sliding window down.
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3;
-
- const uint8* ptr = input_ptr + 2 * input_depth + 3 * input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- temp_3 = vld1_u8(ptr + 3 * input_depth);
-
- input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
-
- input_2 = vaddq_s16(input_2, input_offset_vec);
- input_3 = vaddq_s16(input_3, input_offset_vec);
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- }
-
- DotProductAndStore2xStride1(
- filter, input_6, input_7, input_4, input_5, input_10, input_11, input_8,
- input_9, input_2, input_3, input_0, input_1, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_ptr + 2 * output_depth + output_row_size,
- output_depth);
-
- // Now load next inputs when sliding window left.
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5;
-
- const uint8* ptr = input_ptr + input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_2 = vld1_u8(ptr);
- temp_3 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_4 = vld1_u8(ptr);
- temp_5 = vld1_u8(ptr + input_depth);
-
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_5));
-
- input_4 = vaddq_s16(input_4, input_offset_vec);
- input_5 = vaddq_s16(input_5, input_offset_vec);
- input_8 = vaddq_s16(input_8, input_offset_vec);
- input_9 = vaddq_s16(input_9, input_offset_vec);
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- }
-
- DotProductAndStore2xStride1(
- filter, input_4, input_5, input_6, input_7, input_8, input_9, input_10,
- input_11, input_0, input_1, input_2, input_3, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_ptr + output_row_size, output_depth);
- }
-};
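The DownquantizeAndStore* helpers used by these kernels scale the 32-bit accumulators back to uint8. A simplified scalar sketch of that requantization step (fixed-point multiply, right shift, add the output offset, clamp); the exact rounding and saturating fixed-point helpers used by the real code are deliberately omitted here:

#include <algorithm>
#include <cstdint>
#include <cstdio>

// Simplified: a plain 64-bit product and truncating shifts stand in for the
// saturating rounding-doubling multiply and rounding shift of the real code.
uint8_t DownquantizeSketch(int32_t acc, int32_t output_multiplier,
                           int output_shift, int32_t output_offset,
                           int32_t activation_min, int32_t activation_max) {
  const int64_t scaled = (static_cast<int64_t>(acc) * output_multiplier) >> 31;
  int32_t result = static_cast<int32_t>(scaled >> output_shift) + output_offset;
  result = std::min(std::max(result, activation_min), activation_max);
  return static_cast<uint8_t>(result);
}

int main() {
  // acc=1000 scaled by ~0.5 (multiplier 1<<30), shifted right by 2, offset 128.
  printf("%d\n", DownquantizeSketch(1000, 1 << 30, 2, 128, 0, 255));  // 253.
  return 0;
}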
-
-template <>
-struct ConvKernel3x3FilterDepth8<1, 4, 1, 1> {
- static inline void Run(const uint8* input_ptr, int input_depth,
- int32 input_offset, int input_row_size,
- const uint8* filter_ptr, int32 filter_offset,
- const int32* bias_ptr, int32 output_offset,
- int32 output_multiplier, int output_shift,
- int32 output_activation_min,
- int32 output_activation_max, uint8* output_ptr,
- int output_depth, int output_width) {
- Filter3x3x8 filter = Load3x3Filter(filter_ptr, filter_offset, output_depth);
-
- const int16x8_t input_offset_vec = vdupq_n_s16(input_offset);
-
- int16x8_t input_0, input_1, input_2, input_3, input_4, input_5, input_6,
- input_7, input_8, input_9, input_10, input_11;
-
- // Load inputs for 1x2 outputs starting from the left.
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3;
-
- const uint8* ptr = input_ptr;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- temp_3 = vld1_u8(ptr + 3 * input_depth);
-
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
-
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- input_2 = vaddq_s16(input_2, input_offset_vec);
- input_3 = vaddq_s16(input_3, input_offset_vec);
-
- ptr += input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- temp_3 = vld1_u8(ptr + 3 * input_depth);
-
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
-
- input_4 = vaddq_s16(input_4, input_offset_vec);
- input_5 = vaddq_s16(input_5, input_offset_vec);
- input_6 = vaddq_s16(input_6, input_offset_vec);
- input_7 = vaddq_s16(input_7, input_offset_vec);
-
- ptr += input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- temp_3 = vld1_u8(ptr + 3 * input_depth);
-
- input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_10 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_11 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
-
- input_8 = vaddq_s16(input_8, input_offset_vec);
- input_9 = vaddq_s16(input_9, input_offset_vec);
- input_10 = vaddq_s16(input_10, input_offset_vec);
- input_11 = vaddq_s16(input_11, input_offset_vec);
- }
-
- DotProductAndStore2xStride1(
- filter, input_0, input_1, input_2, input_3, input_4, input_5, input_6,
- input_7, input_8, input_9, input_10, input_11, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_ptr, output_depth);
-
- // Now load 1x2 inputs on the right.
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5;
-
- const uint8* ptr = input_ptr + input_depth * 4;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_2 = vld1_u8(ptr);
- temp_3 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_4 = vld1_u8(ptr);
- temp_5 = vld1_u8(ptr + input_depth);
-
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
- input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_5));
-
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- input_4 = vaddq_s16(input_4, input_offset_vec);
- input_5 = vaddq_s16(input_5, input_offset_vec);
- input_8 = vaddq_s16(input_8, input_offset_vec);
- input_9 = vaddq_s16(input_9, input_offset_vec);
- }
+#define DEPTHWISECONV_SHUFFLE_WORKSPACE_SIZE (10 * 10 * 64)
- DotProductAndStore2xStride1(
- filter, input_2, input_3, input_0, input_1, input_6, input_7, input_4,
- input_5, input_10, input_11, input_8, input_9, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_ptr + 2 * output_depth, output_depth);
- }
+// Encapsulates constant parameters used in DepthwiseConv.
+// 64-bit types are used for fields that will be added to 64-bit addresses
+// in asm.
+struct DepthwiseConvParams {
+ int64_t input_depth;
+ int64_t input_row_size;
+ int64_t output_depth;
+ int64_t output_row_size;
+ int64_t filter_row_size;
+ int32 input_offset;
+ int32 output_offset;
+ int32 filter_offset;
+ int32 output_multiplier;
+ int32 output_activation_min;
+ int32 output_activation_max;
+ int32 output_right_shift;
+ int32 input_width;
+ int32 input_height;
+ int32 stride_width;
+ int32 stride_height;
+ int32 output_width;
+ int32 output_height;
};
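For illustration only, a hypothetical helper showing how such a params struct might be filled in before being handed to the assembly kernels; it assumes the DepthwiseConvParams definition above is in scope, and the row-size formulas (depth times width, in elements, 3x3 filter, depth multiplier 1) are assumptions made for this sketch, not taken from this file:

inline DepthwiseConvParams MakeParamsSketch(int32 input_depth,
                                            int32 input_width,
                                            int32 input_height, int32 stride,
                                            int32 output_width,
                                            int32 output_height) {
  DepthwiseConvParams params = {};  // Quantization fields left at zero here.
  params.input_depth = input_depth;
  params.input_row_size = static_cast<int64_t>(input_depth) * input_width;
  params.output_depth = input_depth;  // Depth multiplier of 1 assumed.
  params.output_row_size = static_cast<int64_t>(input_depth) * output_width;
  params.filter_row_size = static_cast<int64_t>(input_depth) * 3;  // 3x3.
  params.input_width = input_width;
  params.input_height = input_height;
  params.stride_width = stride;
  params.stride_height = stride;
  params.output_width = output_width;
  params.output_height = output_height;
  return params;
}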
-template <>
-struct ConvKernel3x3FilterDepth8<2, 1, 1, 1> {
- static inline void Run(const uint8* input_ptr, int input_depth,
- int32 input_offset, int input_row_size,
- const uint8* filter_ptr, int32 filter_offset,
- const int32* bias_ptr, int32 output_offset,
- int32 output_multiplier, int output_shift,
- int32 output_activation_min,
- int32 output_activation_max, uint8* output_ptr,
- int output_depth, int output_width) {
- Filter3x3x8 filter = Load3x3Filter(filter_ptr, filter_offset, output_depth);
-
- // To process 2x1 outputs using a 3x3 filter, we require 4x3 inputs.
- // Load all inputs at the beginning.
- int16x8_t input_0, input_1, input_2, input_3, input_4, input_5, input_6,
- input_7, input_8, input_9, input_10, input_11;
-
- // Load all inputs for the 2x1 outputs, starting from the top left.
- {
- const int16x8_t input_offset_vec = vdupq_n_s16(input_offset);
- uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5;
-
- const uint8* ptr = input_ptr;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- ptr += input_row_size;
- temp_3 = vld1_u8(ptr);
- temp_4 = vld1_u8(ptr + input_depth);
- temp_5 = vld1_u8(ptr + 2 * input_depth);
-
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
- input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_5));
-
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- input_2 = vaddq_s16(input_2, input_offset_vec);
- input_3 = vaddq_s16(input_3, input_offset_vec);
- input_4 = vaddq_s16(input_4, input_offset_vec);
- input_5 = vaddq_s16(input_5, input_offset_vec);
-
- ptr += input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- ptr += input_row_size;
- temp_3 = vld1_u8(ptr);
- temp_4 = vld1_u8(ptr + input_depth);
- temp_5 = vld1_u8(ptr + 2 * input_depth);
-
- input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_10 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
- input_11 = vreinterpretq_s16_u16(vmovl_u8(temp_5));
-
- input_6 = vaddq_s16(input_6, input_offset_vec);
- input_7 = vaddq_s16(input_7, input_offset_vec);
- input_8 = vaddq_s16(input_8, input_offset_vec);
- input_9 = vaddq_s16(input_9, input_offset_vec);
- input_10 = vaddq_s16(input_10, input_offset_vec);
- input_11 = vaddq_s16(input_11, input_offset_vec);
- }
-
- DotProductAndStore2yStride1(
- filter, input_0, input_1, input_2, input_3, input_4, input_5, input_6,
- input_7, input_8, input_9, input_10, input_11, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_ptr, output_depth * output_width);
- }
-};
+#define STR(s) STR_UNEXPANDED(s)
+#define STR_UNEXPANDED(s) #s
+
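The STR / STR_UNEXPANDED pair above is the standard two-level stringization idiom: the extra level forces the argument to be macro-expanded before # turns it into a string, which is presumably how the numeric OFFSET_* values below end up inside asm strings. A minimal standalone sketch (macro names prefixed with MY_ so it does not clash with the definitions above):

#include <cstdio>

#define MY_STR(s) MY_STR_UNEXPANDED(s)
#define MY_STR_UNEXPANDED(s) #s
#define MY_OFFSET_INPUT_DEPTH 0

int main() {
  // Without the extra level, the argument is stringized unexpanded.
  printf("%s\n", MY_STR_UNEXPANDED(MY_OFFSET_INPUT_DEPTH));  // "MY_OFFSET_INPUT_DEPTH"
  // With it, the argument expands first, yielding the numeric text.
  printf("%s\n", MY_STR(MY_OFFSET_INPUT_DEPTH));             // "0"
  return 0;
}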
+// These macros give the byte offset of each field from the start of the
+// DepthwiseConvParams struct, and are used in the asm to load parameters.
+// Keep the values in sync with the static_asserts below.
+#define OFFSET_INPUT_DEPTH 0
+#define OFFSET_INPUT_ROW_SIZE 8
+#define OFFSET_OUTPUT_DEPTH 16
+#define OFFSET_OUTPUT_ROW_SIZE 24
+#define OFFSET_FILTER_ROW_SIZE 32
+#define OFFSET_INPUT_OFFSET 40
+#define OFFSET_OUTPUT_OFFSET 44
+#define OFFSET_FILTER_OFFSET 48
+#define OFFSET_OUTPUT_MULTIPLIER 52
+#define OFFSET_OUTPUT_ACTIVATION_MIN 56
+#define OFFSET_OUTPUT_ACTIVATION_MAX 60
+#define OFFSET_OUTPUT_RIGHT_SHIFT 64
+#define OFFSET_INPUT_WIDTH 68
+#define OFFSET_INPUT_HEIGHT 72
+#define OFFSET_STRIDE_WIDTH 76
+#define OFFSET_STRIDE_HEIGHT 80
+#define OFFSET_OUTPUT_WIDTH 84
+#define OFFSET_OUTPUT_HEIGHT 88
+
+static_assert(offsetof(DepthwiseConvParams, input_depth) ==
+ OFFSET_INPUT_DEPTH, "");
+static_assert(offsetof(DepthwiseConvParams, input_row_size) ==
+ OFFSET_INPUT_ROW_SIZE, "");
+static_assert(offsetof(DepthwiseConvParams, output_depth) ==
+ OFFSET_OUTPUT_DEPTH, "");
+static_assert(offsetof(DepthwiseConvParams, output_row_size) ==
+ OFFSET_OUTPUT_ROW_SIZE, "");
+static_assert(offsetof(DepthwiseConvParams, filter_row_size) ==
+ OFFSET_FILTER_ROW_SIZE, "");
+static_assert(offsetof(DepthwiseConvParams, input_offset) ==
+ OFFSET_INPUT_OFFSET, "");
+static_assert(offsetof(DepthwiseConvParams, output_offset) ==
+ OFFSET_OUTPUT_OFFSET, "");
+static_assert(offsetof(DepthwiseConvParams, filter_offset) ==
+ OFFSET_FILTER_OFFSET, "");
+static_assert(offsetof(DepthwiseConvParams, output_multiplier) ==
+ OFFSET_OUTPUT_MULTIPLIER, "");
+static_assert(offsetof(DepthwiseConvParams, output_activation_min) ==
+ OFFSET_OUTPUT_ACTIVATION_MIN, "");
+static_assert(offsetof(DepthwiseConvParams, output_activation_max) ==
+ OFFSET_OUTPUT_ACTIVATION_MAX, "");
+static_assert(offsetof(DepthwiseConvParams, output_right_shift) ==
+ OFFSET_OUTPUT_RIGHT_SHIFT, "");
+static_assert(offsetof(DepthwiseConvParams, input_width) ==
+ OFFSET_INPUT_WIDTH, "");
+static_assert(offsetof(DepthwiseConvParams, input_height) ==
+ OFFSET_INPUT_HEIGHT, "");
+static_assert(offsetof(DepthwiseConvParams, stride_width) ==
+ OFFSET_STRIDE_WIDTH, "");
+static_assert(offsetof(DepthwiseConvParams, stride_height) ==
+ OFFSET_STRIDE_HEIGHT, "");
+static_assert(offsetof(DepthwiseConvParams, output_width) ==
+ OFFSET_OUTPUT_WIDTH, "");
+static_assert(offsetof(DepthwiseConvParams, output_height) ==
+ OFFSET_OUTPUT_HEIGHT, "");
+
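The static_asserts above pin the struct layout to the OFFSET_* constants hard-coded in the assembly. A minimal portable sketch (no asm, standalone Params type rather than DepthwiseConvParams) of the same technique: verify a byte offset at compile time, then read the field through base pointer + offset, just as the asm loads do:

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstring>

struct Params {            // Standalone stand-in, not DepthwiseConvParams.
  int64_t input_depth;     // Byte offset 0.
  int64_t input_row_size;  // Byte offset 8.
  int32_t input_offset;    // Byte offset 16.
};

#define PARAMS_OFFSET_INPUT_OFFSET 16
static_assert(offsetof(Params, input_offset) == PARAMS_OFFSET_INPUT_OFFSET, "");

int main() {
  Params params = {64, 64 * 10, -128};
  const char* base = reinterpret_cast<const char*>(&params);
  int32_t value;
  std::memcpy(&value, base + PARAMS_OFFSET_INPUT_OFFSET, sizeof(value));
  printf("input_offset = %d\n", value);  // Prints -128.
  return 0;
}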
+template <int32 kDepth, int32 kStrideWidth, int32 kStrideHeight>
+struct DepthwiseConvWindow {};
template <>
-struct ConvKernel3x3FilterDepth8<4, 2, 2, 2> {
- static inline void Run(const uint8* input_ptr, int input_depth,
- int32 input_offset, int input_row_size,
- const uint8* filter_ptr, int32 filter_offset,
- const int32* bias_ptr, int32 output_offset,
- int32 output_multiplier, int output_shift,
- int32 output_activation_min,
- int32 output_activation_max, uint8* output_ptr,
- int output_depth, int output_width) {
- const int output_row_size = output_depth * output_width;
-
- Filter3x3x8 filter = Load3x3Filter(filter_ptr, filter_offset, output_depth);
-
- Int32x8 acc_0, acc_1;
- acc_0.low = vld1q_s32(bias_ptr);
- acc_1.low = vld1q_s32(bias_ptr);
- acc_0.high = vld1q_s32(bias_ptr + 4);
- acc_1.high = vld1q_s32(bias_ptr + 4);
-
- const int16x8_t input_offset_vec = vdupq_n_s16(input_offset);
-
- int16x8_t input_0, input_1, input_2, input_3, input_4, input_5, input_6,
- input_7, input_8, input_9;
-
- const uint8* ptr = input_ptr;
- uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4;
-
- // Load first 2 rows.
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- temp_3 = vld1_u8(ptr + 3 * input_depth);
- temp_4 = vld1_u8(ptr + 4 * input_depth);
-
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
-
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- input_2 = vaddq_s16(input_2, input_offset_vec);
- input_3 = vaddq_s16(input_3, input_offset_vec);
- input_4 = vaddq_s16(input_4, input_offset_vec);
-
- ptr += input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- temp_3 = vld1_u8(ptr + 3 * input_depth);
- temp_4 = vld1_u8(ptr + 4 * input_depth);
-
- input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
-
- input_5 = vaddq_s16(input_5, input_offset_vec);
- input_6 = vaddq_s16(input_6, input_offset_vec);
- input_7 = vaddq_s16(input_7, input_offset_vec);
- input_8 = vaddq_s16(input_8, input_offset_vec);
- input_9 = vaddq_s16(input_9, input_offset_vec);
-
- acc_0 = MultiplyAccumulateRow(acc_0, filter.f0, filter.f1, filter.f2,
- input_0, input_1, input_2);
-
- acc_1 = MultiplyAccumulateRow(acc_1, filter.f0, filter.f1, filter.f2,
- input_2, input_3, input_4);
-
- acc_0 = MultiplyAccumulateRow(acc_0, filter.f3, filter.f4, filter.f5,
- input_5, input_6, input_7);
-
- acc_1 = MultiplyAccumulateRow(acc_1, filter.f3, filter.f4, filter.f5,
- input_7, input_8, input_9);
-
- // Load next 2 rows.
- ptr += input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- temp_3 = vld1_u8(ptr + 3 * input_depth);
- temp_4 = vld1_u8(ptr + 4 * input_depth);
-
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
-
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- input_2 = vaddq_s16(input_2, input_offset_vec);
- input_3 = vaddq_s16(input_3, input_offset_vec);
- input_4 = vaddq_s16(input_4, input_offset_vec);
-
- ptr += input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- temp_3 = vld1_u8(ptr + 3 * input_depth);
- temp_4 = vld1_u8(ptr + 4 * input_depth);
-
- input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
-
- input_5 = vaddq_s16(input_5, input_offset_vec);
- input_6 = vaddq_s16(input_6, input_offset_vec);
- input_7 = vaddq_s16(input_7, input_offset_vec);
- input_8 = vaddq_s16(input_8, input_offset_vec);
- input_9 = vaddq_s16(input_9, input_offset_vec);
-
- acc_0 = MultiplyAccumulateRow(acc_0, filter.f6, filter.f7, filter.f8,
- input_0, input_1, input_2);
-
- acc_1 = MultiplyAccumulateRow(acc_1, filter.f6, filter.f7, filter.f8,
- input_2, input_3, input_4);
-
- DownquantizeAndStore2Output(
- acc_0, acc_1, output_offset, output_multiplier, output_shift,
- output_activation_min, output_activation_max, output_ptr, output_depth);
-
- output_ptr += output_row_size;
-
- // Moving onto the next row of outputs.
- acc_0.low = vld1q_s32(bias_ptr);
- acc_1.low = vld1q_s32(bias_ptr);
- acc_0.high = vld1q_s32(bias_ptr + 4);
- acc_1.high = vld1q_s32(bias_ptr + 4);
-
- acc_0 = MultiplyAccumulateRow(acc_0, filter.f0, filter.f1, filter.f2,
- input_0, input_1, input_2);
-
- acc_1 = MultiplyAccumulateRow(acc_1, filter.f0, filter.f1, filter.f2,
- input_2, input_3, input_4);
-
- acc_0 = MultiplyAccumulateRow(acc_0, filter.f3, filter.f4, filter.f5,
- input_5, input_6, input_7);
-
- acc_1 = MultiplyAccumulateRow(acc_1, filter.f3, filter.f4, filter.f5,
- input_7, input_8, input_9);
-
- // Load next 2 rows.
- ptr += input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- temp_3 = vld1_u8(ptr + 3 * input_depth);
- temp_4 = vld1_u8(ptr + 4 * input_depth);
-
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
-
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- input_2 = vaddq_s16(input_2, input_offset_vec);
- input_3 = vaddq_s16(input_3, input_offset_vec);
- input_4 = vaddq_s16(input_4, input_offset_vec);
-
- ptr += input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- temp_3 = vld1_u8(ptr + 3 * input_depth);
- temp_4 = vld1_u8(ptr + 4 * input_depth);
-
- input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
-
- input_5 = vaddq_s16(input_5, input_offset_vec);
- input_6 = vaddq_s16(input_6, input_offset_vec);
- input_7 = vaddq_s16(input_7, input_offset_vec);
- input_8 = vaddq_s16(input_8, input_offset_vec);
- input_9 = vaddq_s16(input_9, input_offset_vec);
-
- acc_0 = MultiplyAccumulateRow(acc_0, filter.f6, filter.f7, filter.f8,
- input_0, input_1, input_2);
-
- acc_1 = MultiplyAccumulateRow(acc_1, filter.f6, filter.f7, filter.f8,
- input_2, input_3, input_4);
-
- DownquantizeAndStore2Output(
- acc_0, acc_1, output_offset, output_multiplier, output_shift,
- output_activation_min, output_activation_max, output_ptr, output_depth);
-
- output_ptr += output_row_size;
-
- // Moving onto the next row of outputs.
- acc_0.low = vld1q_s32(bias_ptr);
- acc_1.low = vld1q_s32(bias_ptr);
- acc_0.high = vld1q_s32(bias_ptr + 4);
- acc_1.high = vld1q_s32(bias_ptr + 4);
-
- acc_0 = MultiplyAccumulateRow(acc_0, filter.f0, filter.f1, filter.f2,
- input_0, input_1, input_2);
-
- acc_1 = MultiplyAccumulateRow(acc_1, filter.f0, filter.f1, filter.f2,
- input_2, input_3, input_4);
-
- acc_0 = MultiplyAccumulateRow(acc_0, filter.f3, filter.f4, filter.f5,
- input_5, input_6, input_7);
-
- acc_1 = MultiplyAccumulateRow(acc_1, filter.f3, filter.f4, filter.f5,
- input_7, input_8, input_9);
-
- // Load next 2 rows.
- ptr += input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- temp_3 = vld1_u8(ptr + 3 * input_depth);
- temp_4 = vld1_u8(ptr + 4 * input_depth);
-
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
-
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- input_2 = vaddq_s16(input_2, input_offset_vec);
- input_3 = vaddq_s16(input_3, input_offset_vec);
- input_4 = vaddq_s16(input_4, input_offset_vec);
-
- ptr += input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- temp_3 = vld1_u8(ptr + 3 * input_depth);
- temp_4 = vld1_u8(ptr + 4 * input_depth);
-
- input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
-
- input_5 = vaddq_s16(input_5, input_offset_vec);
- input_6 = vaddq_s16(input_6, input_offset_vec);
- input_7 = vaddq_s16(input_7, input_offset_vec);
- input_8 = vaddq_s16(input_8, input_offset_vec);
- input_9 = vaddq_s16(input_9, input_offset_vec);
-
- acc_0 = MultiplyAccumulateRow(acc_0, filter.f6, filter.f7, filter.f8,
- input_0, input_1, input_2);
-
- acc_1 = MultiplyAccumulateRow(acc_1, filter.f6, filter.f7, filter.f8,
- input_2, input_3, input_4);
-
- DownquantizeAndStore2Output(
- acc_0, acc_1, output_offset, output_multiplier, output_shift,
- output_activation_min, output_activation_max, output_ptr, output_depth);
-
- output_ptr += output_row_size;
-
- // Moving onto the next row of outputs.
- acc_0.low = vld1q_s32(bias_ptr);
- acc_1.low = vld1q_s32(bias_ptr);
- acc_0.high = vld1q_s32(bias_ptr + 4);
- acc_1.high = vld1q_s32(bias_ptr + 4);
-
- acc_0 = MultiplyAccumulateRow(acc_0, filter.f0, filter.f1, filter.f2,
- input_0, input_1, input_2);
-
- acc_1 = MultiplyAccumulateRow(acc_1, filter.f0, filter.f1, filter.f2,
- input_2, input_3, input_4);
-
- acc_0 = MultiplyAccumulateRow(acc_0, filter.f3, filter.f4, filter.f5,
- input_5, input_6, input_7);
-
- acc_1 = MultiplyAccumulateRow(acc_1, filter.f3, filter.f4, filter.f5,
- input_7, input_8, input_9);
-
- // Load last row.
- ptr += input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- temp_3 = vld1_u8(ptr + 3 * input_depth);
- temp_4 = vld1_u8(ptr + 4 * input_depth);
-
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
-
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- input_2 = vaddq_s16(input_2, input_offset_vec);
- input_3 = vaddq_s16(input_3, input_offset_vec);
- input_4 = vaddq_s16(input_4, input_offset_vec);
-
- acc_0 = MultiplyAccumulateRow(acc_0, filter.f6, filter.f7, filter.f8,
- input_0, input_1, input_2);
-
- acc_1 = MultiplyAccumulateRow(acc_1, filter.f6, filter.f7, filter.f8,
- input_2, input_3, input_4);
-
- DownquantizeAndStore2Output(
- acc_0, acc_1, output_offset, output_multiplier, output_shift,
- output_activation_min, output_activation_max, output_ptr, output_depth);
+struct DepthwiseConvWindow<8, 1, 1> {
+ public:
+ static inline void Run(const uint8* input_ptr, const uint8* filter_ptr,
+ const int32* bias_ptr, uint8* output_ptr, int64_t input_depth,
+ int64_t input_row_size, int32 output_window_height,
+ int32 output_window_width,
+ const DepthwiseConvParams* params_ptr) {
+ const int64_t input_width_increment = 2 * input_depth;
+ const int64_t input_height_increment = 2 * input_row_size;
+ const int64_t output_height_increment = 2 * params_ptr->output_row_size;
+
+#define DEPTHWISECONV_LABEL_HEIGHT_2_LOOP "1"
+#define DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_LOOP "2"
+#define DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_1_LEFTOVER "3"
+#define DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_LEFTOVER "4"
+#define DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_AFTER_LOOP "5"
+#define DEPTHWISECONV_LABEL_HEIGHT_2_AFTER_LOOP "6"
+#define DEPTHWISECONV_LABEL_HEIGHT_1 "7"
+#define DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_2_LOOP "8"
+#define DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_1_LEFTOVER "9"
+#define DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_2_LEFTOVER "10"
+#define DEPTHWISECONV_LABEL_HEIGHT_1_END "11"
+
+ asm volatile(
+ // Performs depthwise convolutions for a window specified by
+      // |output_window_height| and |output_window_width|. The innermost loop
+      // processes 2x2 outputs; any leftovers are handled at the end.
+      //
+      // The algorithm works as follows:
+ //
+ // 1. Load filters of 8 depth (8x3x3). Registers v0--v8 hold filter
+ // values.
+ // 2. For 2 output heights at a time:
+ // i. For 2 output widths at a time, load inputs for a 2x1 (2
+ // height, 1 width) output window (4x3 input window).
+ // Registers v9--v20 hold input values. Mul-add with
+ // accumulators v21--v24. Then run activation, downquantize
+ // and store. Repeat for the next 2x1 output window,
+ // leveraging overlapping inputs.
+ // ii. Handle single leftover width if exists.
+ // 3. Handle single leftover height if exists.
+ // i. For 2 output widths at a time, load inputs for a 1x2 (1
+ // height, 2 width) output window (3x4 input window).
+ // Registers v9--v20 hold input values. Mul-add with
+ // accumulators v21--v24. Then run activation, downquantize
+ // and store. Repeat for the next 1x2 output window,
+ // leveraging overlapping inputs.
+ // ii. Handle single leftover width if exists.
+ //
+ // Loads are placed as soon as the register is no longer needed and
+ // interleaved with arithmetic operations to take advantage of
+ // dual-issue pipelines. We also add input offsets as far from the loads
+ // as possible to give loads enough cycles to fetch data from memory.
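+      //
+      // In scalar terms, each int32 accumulator computed below is
+      //   acc[c] = bias[c]
+      //            + sum over (ky, kx) in [0, 3) x [0, 3) of
+      //              (filter[ky][kx][c] + filter_offset) *
+      //              (input[y + ky][x + kx][c] + input_offset)
+      // for channel c in [0, 8) at stride 1, and is then downquantized with
+      // the output multiplier, shift, offset and activation bounds before
+      // being stored as uint8 (see the scalar sketch after this struct).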
+
+ // Set "constant" registers. These registers may be replaced with temp
+ // values from time to time when there are not enough NEON registers.
+ // We use x9--x15 general purpose registers as they are caller-saved
+ // temporary registers (see http://infocenter.arm.com/help/topic/com.arm.doc.ihi0055b/IHI0055B_aapcs64.pdf). // NOLINT
+ "ldr w9, [%[params_ptr], #" STR(OFFSET_INPUT_OFFSET) "]\n"
+ "ldr x3, [%[params_ptr], #" STR(OFFSET_OUTPUT_DEPTH) "]\n"
+ "cmp %w[output_window_height], #2\n"
+ "dup v26.8h, w9\n"
+ "ldr w9, [%[params_ptr], #" STR(OFFSET_OUTPUT_MULTIPLIER) "]\n"
+ "ldr w2, [%[params_ptr], #" STR(OFFSET_OUTPUT_OFFSET) "]\n"
+ "dup v27.4s, w9\n"
+ "ldr w9, [%[params_ptr], #" STR(OFFSET_OUTPUT_RIGHT_SHIFT) "]\n"
+ "dup v29.4s, w2\n"
+ "ldr w4, [%[params_ptr], #" STR(OFFSET_OUTPUT_ACTIVATION_MIN) "]\n"
+ "dup v30.4s, w4\n"
+ "ldr w0, [%[params_ptr], #" STR(OFFSET_OUTPUT_ACTIVATION_MAX) "]\n"
+ "dup v31.4s, w0\n"
+ "neg w9, w9\n"
+ "dup v28.4s, w9\n"
+ "ldr w9, [%[params_ptr], #" STR(OFFSET_FILTER_OFFSET) "]\n"
+ "add x10, %[bias_ptr], #16\n"
+ "ldr x1, [%[params_ptr], #" STR(OFFSET_OUTPUT_ROW_SIZE) "]\n"
+ "dup v9.8h, w9\n"
+
+ // Load filters and add offsets.
+ "ld1 {v0.8b}, [%[filter_ptr]], x3\n"
+ "ld1 {v1.8b}, [%[filter_ptr]], x3\n"
+ "uaddw v0.8h, v9.8h, v0.8b\n"
+ "ld1 {v2.8b}, [%[filter_ptr]], x3\n"
+ "uaddw v1.8h, v9.8h, v1.8b\n"
+ "ld1 {v3.8b}, [%[filter_ptr]], x3\n"
+ "uaddw v2.8h, v9.8h, v2.8b\n"
+ "ld1 {v4.8b}, [%[filter_ptr]], x3\n"
+ "uaddw v3.8h, v9.8h, v3.8b\n"
+ "ld1 {v5.8b}, [%[filter_ptr]], x3\n"
+ "uaddw v4.8h, v9.8h, v4.8b\n"
+ "ld1 {v6.8b}, [%[filter_ptr]], x3\n"
+ "uaddw v5.8h, v9.8h, v5.8b\n"
+ "ld1 {v7.8b}, [%[filter_ptr]], x3\n"
+ "uaddw v6.8h, v9.8h, v6.8b\n"
+ "ld1 {v8.8b}, [%[filter_ptr]], x3\n"
+ "uaddw v7.8h, v9.8h, v7.8b\n"
+ "uaddw v8.8h, v9.8h, v8.8b\n"
+
+ "blt " DEPTHWISECONV_LABEL_HEIGHT_2_AFTER_LOOP "f\n"
+
+ //"loop_%=:\n"
+ DEPTHWISECONV_LABEL_HEIGHT_2_LOOP ":\n"
+      // This loop processes 2x2 outputs. To avoid register exhaustion,
+      // inputs for the left 2 outputs are loaded first, then those for
+      // the right 2 outputs.
+ "mov x11, %[input_ptr]\n"
+ "mov x12, x11\n"
+ "ld1 {v9.8b}, [x12], %[input_depth]\n"
+ "add x13, x11, %[input_row_size]\n"
+ "ld1 {v10.8b}, [x12], %[input_depth]\n"
+ "add x14, x13, %[input_row_size]\n"
+ "ld1 {v11.8b}, [x12], %[input_depth]\n"
+ "add x15, x14, %[input_row_size]\n"
+ "ld1 {v12.8b}, [x13], %[input_depth]\n"
+ "mov w5, %w[output_window_width]\n"
+ "ld1 {v13.8b}, [x13], %[input_depth]\n"
+ "mov x6, %[output_ptr]\n"
+ "ld1 {v14.8b}, [x13], %[input_depth]\n"
+ "add x7, %[output_ptr], x1\n"
+ "ld1 {v15.8b}, [x14], %[input_depth]\n"
+      // The height 2 / width 2 loop loads inputs for an extra 2x1
+      // (2 height, 1 width) output in anticipation of the next iteration.
+      // Make sure |output_window_width| is large enough to handle the
+      // additional loads; otherwise jump to the appropriate label to
+      // handle smaller widths.
+ "cmp w5, #2\n"
+ "uaddw v9.8h, v26.8h, v9.8b\n"
+ "ld1 {v16.8b}, [x14], %[input_depth]\n"
+ "uaddw v10.8h, v26.8h, v10.8b\n"
+ "ld1 {v17.8b}, [x14], %[input_depth]\n"
+ "uaddw v11.8h, v26.8h, v11.8b\n"
+ "ld1 {v18.8b}, [x15], %[input_depth]\n"
+ "uaddw v12.8h, v26.8h, v12.8b\n"
+ "ld1 {v19.8b}, [x15], %[input_depth]\n"
+ "uaddw v13.8h, v26.8h, v13.8b\n"
+ "ld1 {v20.8b}, [x15], %[input_depth]\n"
+ "uaddw v14.8h, v26.8h, v14.8b\n"
+ "ld1 {v21.4s}, [%[bias_ptr]]\n"
+ "uaddw v15.8h, v26.8h, v15.8b\n"
+ "ld1 {v22.4s}, [x10]\n"
+ "uaddw v16.8h, v26.8h, v16.8b\n"
+ "ld1 {v23.4s}, [%[bias_ptr]]\n"
+ "uaddw v17.8h, v26.8h, v17.8b\n"
+ "ld1 {v24.4s}, [x10]\n"
+ "uaddw v18.8h, v26.8h, v18.8b\n"
+ "uaddw v19.8h, v26.8h, v19.8b\n"
+ "uaddw v20.8h, v26.8h, v20.8b\n"
+
+ "beq " DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_LEFTOVER "f\n"
+ "cmp w5, #1\n"
+ "beq " DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_1_LEFTOVER "f\n"
+
+ //"loop_%=:\n"
+ DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_LOOP ":\n"
+ // Mul-add left outputs.
+ "smlal v21.4s, v0.4h, v9.4h\n"
+ "subs w5, w5, #2\n"
+ "smlal2 v22.4s, v0.8h, v9.8h\n"
+ "cmp w5, #3\n"
+ "smlal v23.4s, v0.4h, v12.4h\n"
+ "ld1 {v9.8b}, [x12]\n"
+ "smlal2 v24.4s, v0.8h, v12.8h\n"
+ "smlal v21.4s, v1.4h, v10.4h\n"
+ "smlal2 v22.4s, v1.8h, v10.8h\n"
+ "smlal v23.4s, v1.4h, v13.4h\n"
+ "smlal2 v24.4s, v1.8h, v13.8h\n"
+ "smlal v21.4s, v2.4h, v11.4h\n"
+ "smlal2 v22.4s, v2.8h, v11.8h\n"
+ "smlal v23.4s, v2.4h, v14.4h\n"
+ "smlal2 v24.4s, v2.8h, v14.8h\n"
+ "smlal v21.4s, v3.4h, v12.4h\n"
+ "smlal2 v22.4s, v3.8h, v12.8h\n"
+ "ld1 {v12.8b}, [x13]\n"
+ "smlal v23.4s, v3.4h, v15.4h\n"
+ "smlal2 v24.4s, v3.8h, v15.8h\n"
+ "smlal v21.4s, v4.4h, v13.4h\n"
+ "smlal2 v22.4s, v4.8h, v13.8h\n"
+ "smlal v23.4s, v4.4h, v16.4h\n"
+ "smlal2 v24.4s, v4.8h, v16.8h\n"
+ "smlal v21.4s, v5.4h, v14.4h\n"
+ "smlal2 v22.4s, v5.8h, v14.8h\n"
+ "smlal v23.4s, v5.4h, v17.4h\n"
+ "smlal2 v24.4s, v5.8h, v17.8h\n"
+ "smlal v21.4s, v6.4h, v15.4h\n"
+ "smlal2 v22.4s, v6.8h, v15.8h\n"
+ "ld1 {v15.8b}, [x14]\n"
+ "smlal v23.4s, v6.4h, v18.4h\n"
+ "smlal2 v24.4s, v6.8h, v18.8h\n"
+ "ld1 {v18.8b}, [x15]\n"
+ "smlal v21.4s, v7.4h, v16.4h\n"
+ "smlal2 v22.4s, v7.8h, v16.8h\n"
+ "smlal v23.4s, v7.4h, v19.4h\n"
+ "smlal2 v24.4s, v7.8h, v19.8h\n"
+ "smlal v21.4s, v8.4h, v17.4h\n"
+ "smlal2 v22.4s, v8.8h, v17.8h\n"
+ "smlal v23.4s, v8.4h, v20.4h\n"
+ "smlal2 v24.4s, v8.8h, v20.8h\n"
+
+ "sqrdmulh v21.4s, v21.4s, v27.4s\n"
+ "sqrdmulh v22.4s, v22.4s, v27.4s\n"
+ "sqrdmulh v23.4s, v23.4s, v27.4s\n"
+ "sqrdmulh v24.4s, v24.4s, v27.4s\n"
+ "and v25.16b, v21.16b, v28.16b\n"
+ "and v29.16b, v22.16b, v28.16b\n"
+ "and v30.16b, v23.16b, v28.16b\n"
+ "and v31.16b, v24.16b, v28.16b\n"
+ "sshr v25.4s, v25.4s, #31\n"
+ "sshr v29.4s, v29.4s, #31\n"
+ "sshr v30.4s, v30.4s, #31\n"
+ "sshr v31.4s, v31.4s, #31\n"
+ "sqadd v21.4s, v21.4s, v25.4s\n"
+ "sqadd v22.4s, v22.4s, v29.4s\n"
+ "dup v29.4s, w2\n"
+ "sqadd v23.4s, v23.4s, v30.4s\n"
+ "dup v30.4s, w4\n"
+ "sqadd v24.4s, v24.4s, v31.4s\n"
+ "dup v31.4s, w0\n"
+ "srshl v21.4s, v21.4s, v28.4s\n"
+ "srshl v22.4s, v22.4s, v28.4s\n"
+ "srshl v23.4s, v23.4s, v28.4s\n"
+ "srshl v24.4s, v24.4s, v28.4s\n"
+ "add v21.4s, v21.4s, v29.4s\n"
+ "add v22.4s, v22.4s, v29.4s\n"
+ "add v23.4s, v23.4s, v29.4s\n"
+ "add v24.4s, v24.4s, v29.4s\n"
+ "smax v21.4s, v21.4s, v30.4s\n"
+ "smax v22.4s, v22.4s, v30.4s\n"
+ "smax v23.4s, v23.4s, v30.4s\n"
+ "smax v24.4s, v24.4s, v30.4s\n"
+ "smin v21.4s, v21.4s, v31.4s\n"
+ "smin v22.4s, v22.4s, v31.4s\n"
+ "smin v23.4s, v23.4s, v31.4s\n"
+ "smin v24.4s, v24.4s, v31.4s\n"
+ "sqxtn v21.4h, v21.4s\n"
+ "sqxtn v23.4h, v23.4s\n"
+ "sqxtn2 v21.8h, v22.4s\n"
+ "ld1 {v22.4s}, [x10]\n"
+ "sqxtn2 v23.8h, v24.4s\n"
+ "ld1 {v24.4s}, [x10]\n"
+ "sqxtun v21.8b, v21.8h\n"
+ "sqxtun v23.8b, v23.8h\n"
+ "uaddw v9.8h, v26.8h, v9.8b\n"
+ "st1 {v21.8b}, [x6], x3\n"
+ "uaddw v12.8h, v26.8h, v12.8b\n"
+ "st1 {v23.8b}, [x7], x3\n"
+ "uaddw v15.8h, v26.8h, v15.8b\n"
+ "ld1 {v21.4s}, [%[bias_ptr]]\n"
+ "uaddw v18.8h, v26.8h, v18.8b\n"
+ "ld1 {v23.4s}, [%[bias_ptr]]\n"
+
+ // Mul-add right outputs.
+ "smlal v21.4s, v0.4h, v10.4h\n"
+ "add x11, x11, %[input_width_increment]\n"
+ "smlal2 v22.4s, v0.8h, v10.8h\n"
+ "mov x12, x11\n"
+ "smlal v23.4s, v0.4h, v13.4h\n"
+ "add x13, x11, %[input_row_size]\n"
+ "smlal2 v24.4s, v0.8h, v13.8h\n"
+ "add x14, x13, %[input_row_size]\n"
+ "smlal v21.4s, v1.4h, v11.4h\n"
+ "add x15, x14, %[input_row_size]\n"
+ "smlal2 v22.4s, v1.8h, v11.8h\n"
+ "smlal v23.4s, v1.4h, v14.4h\n"
+ "smlal2 v24.4s, v1.8h, v14.8h\n"
+ "smlal v21.4s, v2.4h, v9.4h\n"
+ "smlal2 v22.4s, v2.8h, v9.8h\n"
+ "ld1 {v9.8b}, [x12], %[input_depth]\n"
+ "smlal v23.4s, v2.4h, v12.4h\n"
+ "ld1 {v10.8b}, [x12], %[input_depth]\n"
+ "smlal2 v24.4s, v2.8h, v12.8h\n"
+ "ld1 {v11.8b}, [x12], %[input_depth]\n"
+ "smlal v21.4s, v3.4h, v13.4h\n"
+ "smlal2 v22.4s, v3.8h, v13.8h\n"
+ "smlal v23.4s, v3.4h, v16.4h\n"
+ "smlal2 v24.4s, v3.8h, v16.8h\n"
+ "smlal v21.4s, v4.4h, v14.4h\n"
+ "smlal2 v22.4s, v4.8h, v14.8h\n"
+ "smlal v23.4s, v4.4h, v17.4h\n"
+ "smlal2 v24.4s, v4.8h, v17.8h\n"
+ "smlal v21.4s, v5.4h, v12.4h\n"
+ "smlal2 v22.4s, v5.8h, v12.8h\n"
+ "ld1 {v12.8b}, [x13], %[input_depth]\n"
+ "smlal v23.4s, v5.4h, v15.4h\n"
+ "ld1 {v13.8b}, [x13], %[input_depth]\n"
+ "smlal2 v24.4s, v5.8h, v15.8h\n"
+ "ld1 {v14.8b}, [x13], %[input_depth]\n"
+ "smlal v21.4s, v6.4h, v16.4h\n"
+ "smlal2 v22.4s, v6.8h, v16.8h\n"
+ "smlal v23.4s, v6.4h, v19.4h\n"
+ "smlal2 v24.4s, v6.8h, v19.8h\n"
+ "smlal v21.4s, v7.4h, v17.4h\n"
+ "smlal2 v22.4s, v7.8h, v17.8h\n"
+ "smlal v23.4s, v7.4h, v20.4h\n"
+ "smlal2 v24.4s, v7.8h, v20.8h\n"
+ "smlal v21.4s, v8.4h, v15.4h\n"
+ "smlal2 v22.4s, v8.8h, v15.8h\n"
+ "ld1 {v15.8b}, [x14], %[input_depth]\n"
+ "smlal v23.4s, v8.4h, v18.4h\n"
+ "ld1 {v16.8b}, [x14], %[input_depth]\n"
+ "smlal2 v24.4s, v8.8h, v18.8h\n"
+ "ld1 {v17.8b}, [x14], %[input_depth]\n"
+
+ "sqrdmulh v21.4s, v21.4s, v27.4s\n"
+ "ld1 {v18.8b}, [x15], %[input_depth]\n"
+ "sqrdmulh v22.4s, v22.4s, v27.4s\n"
+ "ld1 {v19.8b}, [x15], %[input_depth]\n"
+ "sqrdmulh v23.4s, v23.4s, v27.4s\n"
+ "ld1 {v20.8b}, [x15], %[input_depth]\n"
+ "sqrdmulh v24.4s, v24.4s, v27.4s\n"
+ "and v25.16b, v21.16b, v28.16b\n"
+ "and v29.16b, v22.16b, v28.16b\n"
+ "and v30.16b, v23.16b, v28.16b\n"
+ "and v31.16b, v24.16b, v28.16b\n"
+ "sshr v25.4s, v25.4s, #31\n"
+ "sshr v29.4s, v29.4s, #31\n"
+ "sshr v30.4s, v30.4s, #31\n"
+ "sshr v31.4s, v31.4s, #31\n"
+ "sqadd v21.4s, v21.4s, v25.4s\n"
+ "sqadd v22.4s, v22.4s, v29.4s\n"
+ "dup v29.4s, w2\n"
+ "sqadd v23.4s, v23.4s, v30.4s\n"
+ "dup v30.4s, w4\n"
+ "sqadd v24.4s, v24.4s, v31.4s\n"
+ "dup v31.4s, w0\n"
+ "srshl v21.4s, v21.4s, v28.4s\n"
+ "srshl v22.4s, v22.4s, v28.4s\n"
+ "srshl v23.4s, v23.4s, v28.4s\n"
+ "srshl v24.4s, v24.4s, v28.4s\n"
+ "add v21.4s, v21.4s, v29.4s\n"
+ "add v22.4s, v22.4s, v29.4s\n"
+ "add v23.4s, v23.4s, v29.4s\n"
+ "add v24.4s, v24.4s, v29.4s\n"
+ "smax v21.4s, v21.4s, v30.4s\n"
+ "smax v22.4s, v22.4s, v30.4s\n"
+ "smax v23.4s, v23.4s, v30.4s\n"
+ "smax v24.4s, v24.4s, v30.4s\n"
+ "smin v21.4s, v21.4s, v31.4s\n"
+ "smin v22.4s, v22.4s, v31.4s\n"
+ "smin v23.4s, v23.4s, v31.4s\n"
+ "smin v24.4s, v24.4s, v31.4s\n"
+ "sqxtn v21.4h, v21.4s\n"
+ "sqxtn v23.4h, v23.4s\n"
+ "sqxtn2 v21.8h, v22.4s\n"
+ "ld1 {v22.4s}, [x10]\n"
+ "sqxtn2 v23.8h, v24.4s\n"
+ "ld1 {v24.4s}, [x10]\n"
+ "sqxtun v21.8b, v21.8h\n"
+ "sqxtun v23.8b, v23.8h\n"
+ "uaddw v9.8h, v26.8h, v9.8b\n"
+ "st1 {v21.8b}, [x6], x3\n"
+ "uaddw v10.8h, v26.8h, v10.8b\n"
+ "st1 {v23.8b}, [x7], x3\n"
+ "uaddw v11.8h, v26.8h, v11.8b\n"
+ "uaddw v12.8h, v26.8h, v12.8b\n"
+ "uaddw v13.8h, v26.8h, v13.8b\n"
+ "uaddw v14.8h, v26.8h, v14.8b\n"
+ "uaddw v15.8h, v26.8h, v15.8b\n"
+ "ld1 {v21.4s}, [%[bias_ptr]]\n"
+ "uaddw v16.8h, v26.8h, v16.8b\n"
+ "ld1 {v23.4s}, [%[bias_ptr]]\n"
+ "uaddw v17.8h, v26.8h, v17.8b\n"
+ "uaddw v18.8h, v26.8h, v18.8b\n"
+ "uaddw v19.8h, v26.8h, v19.8b\n"
+ "uaddw v20.8h, v26.8h, v20.8b\n"
+
+ "bge " DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_LOOP "b\n"
+
+      // At this point, there will be either a 2-wide or a 1-wide leftover,
+      // not both.
+ "cmp w5, #2\n"
+ "blt " DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_1_LEFTOVER "f\n"
+
+      // Handle the last 2 columns if they exist.
+ DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_LEFTOVER ":\n"
+ // Mul-add left outputs.
+ "smlal v21.4s, v0.4h, v9.4h\n"
+ "smlal2 v22.4s, v0.8h, v9.8h\n"
+ "smlal v23.4s, v0.4h, v12.4h\n"
+ "ld1 {v9.8b}, [x12]\n"
+ "smlal2 v24.4s, v0.8h, v12.8h\n"
+ "smlal v21.4s, v1.4h, v10.4h\n"
+ "smlal2 v22.4s, v1.8h, v10.8h\n"
+ "smlal v23.4s, v1.4h, v13.4h\n"
+ "smlal2 v24.4s, v1.8h, v13.8h\n"
+ "smlal v21.4s, v2.4h, v11.4h\n"
+ "smlal2 v22.4s, v2.8h, v11.8h\n"
+ "smlal v23.4s, v2.4h, v14.4h\n"
+ "smlal2 v24.4s, v2.8h, v14.8h\n"
+ "smlal v21.4s, v3.4h, v12.4h\n"
+ "smlal2 v22.4s, v3.8h, v12.8h\n"
+ "ld1 {v12.8b}, [x13]\n"
+ "smlal v23.4s, v3.4h, v15.4h\n"
+ "smlal2 v24.4s, v3.8h, v15.8h\n"
+ "smlal v21.4s, v4.4h, v13.4h\n"
+ "smlal2 v22.4s, v4.8h, v13.8h\n"
+ "smlal v23.4s, v4.4h, v16.4h\n"
+ "smlal2 v24.4s, v4.8h, v16.8h\n"
+ "smlal v21.4s, v5.4h, v14.4h\n"
+ "smlal2 v22.4s, v5.8h, v14.8h\n"
+ "smlal v23.4s, v5.4h, v17.4h\n"
+ "smlal2 v24.4s, v5.8h, v17.8h\n"
+ "smlal v21.4s, v6.4h, v15.4h\n"
+ "smlal2 v22.4s, v6.8h, v15.8h\n"
+ "ld1 {v15.8b}, [x14]\n"
+ "smlal v23.4s, v6.4h, v18.4h\n"
+ "smlal2 v24.4s, v6.8h, v18.8h\n"
+ "ld1 {v18.8b}, [x15]\n"
+ "smlal v21.4s, v7.4h, v16.4h\n"
+ "smlal2 v22.4s, v7.8h, v16.8h\n"
+ "smlal v23.4s, v7.4h, v19.4h\n"
+ "smlal2 v24.4s, v7.8h, v19.8h\n"
+ "smlal v21.4s, v8.4h, v17.4h\n"
+ "smlal2 v22.4s, v8.8h, v17.8h\n"
+ "smlal v23.4s, v8.4h, v20.4h\n"
+ "smlal2 v24.4s, v8.8h, v20.8h\n"
+
+ "sqrdmulh v21.4s, v21.4s, v27.4s\n"
+ "sqrdmulh v22.4s, v22.4s, v27.4s\n"
+ "sqrdmulh v23.4s, v23.4s, v27.4s\n"
+ "sqrdmulh v24.4s, v24.4s, v27.4s\n"
+ "and v25.16b, v21.16b, v28.16b\n"
+ "and v29.16b, v22.16b, v28.16b\n"
+ "and v30.16b, v23.16b, v28.16b\n"
+ "and v31.16b, v24.16b, v28.16b\n"
+ "sshr v25.4s, v25.4s, #31\n"
+ "sshr v29.4s, v29.4s, #31\n"
+ "sshr v30.4s, v30.4s, #31\n"
+ "sshr v31.4s, v31.4s, #31\n"
+ "sqadd v21.4s, v21.4s, v25.4s\n"
+ "sqadd v22.4s, v22.4s, v29.4s\n"
+ "dup v29.4s, w2\n"
+ "sqadd v23.4s, v23.4s, v30.4s\n"
+ "dup v30.4s, w4\n"
+ "sqadd v24.4s, v24.4s, v31.4s\n"
+ "dup v31.4s, w0\n"
+ "srshl v21.4s, v21.4s, v28.4s\n"
+ "srshl v22.4s, v22.4s, v28.4s\n"
+ "srshl v23.4s, v23.4s, v28.4s\n"
+ "srshl v24.4s, v24.4s, v28.4s\n"
+ "add v21.4s, v21.4s, v29.4s\n"
+ "add v22.4s, v22.4s, v29.4s\n"
+ "add v23.4s, v23.4s, v29.4s\n"
+ "add v24.4s, v24.4s, v29.4s\n"
+ "smax v21.4s, v21.4s, v30.4s\n"
+ "smax v22.4s, v22.4s, v30.4s\n"
+ "smax v23.4s, v23.4s, v30.4s\n"
+ "smax v24.4s, v24.4s, v30.4s\n"
+ "smin v21.4s, v21.4s, v31.4s\n"
+ "smin v22.4s, v22.4s, v31.4s\n"
+ "smin v23.4s, v23.4s, v31.4s\n"
+ "smin v24.4s, v24.4s, v31.4s\n"
+ "sqxtn v21.4h, v21.4s\n"
+ "sqxtn v23.4h, v23.4s\n"
+ "sqxtn2 v21.8h, v22.4s\n"
+ "ld1 {v22.4s}, [x10]\n"
+ "sqxtn2 v23.8h, v24.4s\n"
+ "ld1 {v24.4s}, [x10]\n"
+ "sqxtun v21.8b, v21.8h\n"
+ "sqxtun v23.8b, v23.8h\n"
+ "uaddw v9.8h, v26.8h, v9.8b\n"
+ "st1 {v21.8b}, [x6], x3\n"
+ "uaddw v12.8h, v26.8h, v12.8b\n"
+ "st1 {v23.8b}, [x7], x3\n"
+ "uaddw v15.8h, v26.8h, v15.8b\n"
+ "ld1 {v21.4s}, [%[bias_ptr]]\n"
+ "uaddw v18.8h, v26.8h, v18.8b\n"
+ "ld1 {v23.4s}, [%[bias_ptr]]\n"
+
+ // Mul-add right outputs.
+ "smlal v21.4s, v0.4h, v10.4h\n"
+ "smlal2 v22.4s, v0.8h, v10.8h\n"
+ "smlal v23.4s, v0.4h, v13.4h\n"
+ "smlal2 v24.4s, v0.8h, v13.8h\n"
+ "smlal v21.4s, v1.4h, v11.4h\n"
+ "smlal2 v22.4s, v1.8h, v11.8h\n"
+ "smlal v23.4s, v1.4h, v14.4h\n"
+ "smlal2 v24.4s, v1.8h, v14.8h\n"
+ "smlal v21.4s, v2.4h, v9.4h\n"
+ "smlal2 v22.4s, v2.8h, v9.8h\n"
+ "smlal v23.4s, v2.4h, v12.4h\n"
+ "smlal2 v24.4s, v2.8h, v12.8h\n"
+ "smlal v21.4s, v3.4h, v13.4h\n"
+ "smlal2 v22.4s, v3.8h, v13.8h\n"
+ "smlal v23.4s, v3.4h, v16.4h\n"
+ "smlal2 v24.4s, v3.8h, v16.8h\n"
+ "smlal v21.4s, v4.4h, v14.4h\n"
+ "smlal2 v22.4s, v4.8h, v14.8h\n"
+ "smlal v23.4s, v4.4h, v17.4h\n"
+ "smlal2 v24.4s, v4.8h, v17.8h\n"
+ "smlal v21.4s, v5.4h, v12.4h\n"
+ "smlal2 v22.4s, v5.8h, v12.8h\n"
+ "smlal v23.4s, v5.4h, v15.4h\n"
+ "smlal2 v24.4s, v5.8h, v15.8h\n"
+ "smlal v21.4s, v6.4h, v16.4h\n"
+ "smlal2 v22.4s, v6.8h, v16.8h\n"
+ "smlal v23.4s, v6.4h, v19.4h\n"
+ "smlal2 v24.4s, v6.8h, v19.8h\n"
+ "smlal v21.4s, v7.4h, v17.4h\n"
+ "smlal2 v22.4s, v7.8h, v17.8h\n"
+ "smlal v23.4s, v7.4h, v20.4h\n"
+ "smlal2 v24.4s, v7.8h, v20.8h\n"
+ "smlal v21.4s, v8.4h, v15.4h\n"
+ "smlal2 v22.4s, v8.8h, v15.8h\n"
+ "smlal v23.4s, v8.4h, v18.4h\n"
+ "smlal2 v24.4s, v8.8h, v18.8h\n"
+
+ "sqrdmulh v21.4s, v21.4s, v27.4s\n"
+ "sqrdmulh v22.4s, v22.4s, v27.4s\n"
+ "sqrdmulh v23.4s, v23.4s, v27.4s\n"
+ "sqrdmulh v24.4s, v24.4s, v27.4s\n"
+ "and v25.16b, v21.16b, v28.16b\n"
+ "and v29.16b, v22.16b, v28.16b\n"
+ "and v30.16b, v23.16b, v28.16b\n"
+ "and v31.16b, v24.16b, v28.16b\n"
+ "sshr v25.4s, v25.4s, #31\n"
+ "sshr v29.4s, v29.4s, #31\n"
+ "sshr v30.4s, v30.4s, #31\n"
+ "sshr v31.4s, v31.4s, #31\n"
+ "sqadd v21.4s, v21.4s, v25.4s\n"
+ "sqadd v22.4s, v22.4s, v29.4s\n"
+ "dup v29.4s, w2\n"
+ "sqadd v23.4s, v23.4s, v30.4s\n"
+ "dup v30.4s, w4\n"
+ "sqadd v24.4s, v24.4s, v31.4s\n"
+ "dup v31.4s, w0\n"
+ "srshl v21.4s, v21.4s, v28.4s\n"
+ "srshl v22.4s, v22.4s, v28.4s\n"
+ "srshl v23.4s, v23.4s, v28.4s\n"
+ "srshl v24.4s, v24.4s, v28.4s\n"
+ "add v21.4s, v21.4s, v29.4s\n"
+ "add v22.4s, v22.4s, v29.4s\n"
+ "add v23.4s, v23.4s, v29.4s\n"
+ "add v24.4s, v24.4s, v29.4s\n"
+ "smax v21.4s, v21.4s, v30.4s\n"
+ "smax v22.4s, v22.4s, v30.4s\n"
+ "smax v23.4s, v23.4s, v30.4s\n"
+ "smax v24.4s, v24.4s, v30.4s\n"
+ "smin v21.4s, v21.4s, v31.4s\n"
+ "smin v22.4s, v22.4s, v31.4s\n"
+ "smin v23.4s, v23.4s, v31.4s\n"
+ "smin v24.4s, v24.4s, v31.4s\n"
+ "sqxtn v21.4h, v21.4s\n"
+ "sqxtn v23.4h, v23.4s\n"
+ "sqxtn2 v21.8h, v22.4s\n"
+ "sqxtn2 v23.8h, v24.4s\n"
+ "sqxtun v21.8b, v21.8h\n"
+ "sqxtun v23.8b, v23.8h\n"
+ "st1 {v21.8b}, [x6], x3\n"
+ "st1 {v23.8b}, [x7], x3\n"
+ "b " DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_AFTER_LOOP "f\n"
+
+ DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_1_LEFTOVER ":\n"
+ "smlal v21.4s, v0.4h, v9.4h\n"
+ "smlal2 v22.4s, v0.8h, v9.8h\n"
+ "smlal v23.4s, v0.4h, v12.4h\n"
+ "smlal2 v24.4s, v0.8h, v12.8h\n"
+ "smlal v21.4s, v1.4h, v10.4h\n"
+ "smlal2 v22.4s, v1.8h, v10.8h\n"
+ "smlal v23.4s, v1.4h, v13.4h\n"
+ "smlal2 v24.4s, v1.8h, v13.8h\n"
+ "smlal v21.4s, v2.4h, v11.4h\n"
+ "smlal2 v22.4s, v2.8h, v11.8h\n"
+ "smlal v23.4s, v2.4h, v14.4h\n"
+ "smlal2 v24.4s, v2.8h, v14.8h\n"
+ "smlal v21.4s, v3.4h, v12.4h\n"
+ "smlal2 v22.4s, v3.8h, v12.8h\n"
+ "smlal v23.4s, v3.4h, v15.4h\n"
+ "smlal2 v24.4s, v3.8h, v15.8h\n"
+ "smlal v21.4s, v4.4h, v13.4h\n"
+ "smlal2 v22.4s, v4.8h, v13.8h\n"
+ "smlal v23.4s, v4.4h, v16.4h\n"
+ "smlal2 v24.4s, v4.8h, v16.8h\n"
+ "smlal v21.4s, v5.4h, v14.4h\n"
+ "smlal2 v22.4s, v5.8h, v14.8h\n"
+ "smlal v23.4s, v5.4h, v17.4h\n"
+ "smlal2 v24.4s, v5.8h, v17.8h\n"
+ "smlal v21.4s, v6.4h, v15.4h\n"
+ "smlal2 v22.4s, v6.8h, v15.8h\n"
+ "smlal v23.4s, v6.4h, v18.4h\n"
+ "smlal2 v24.4s, v6.8h, v18.8h\n"
+ "smlal v21.4s, v7.4h, v16.4h\n"
+ "smlal2 v22.4s, v7.8h, v16.8h\n"
+ "smlal v23.4s, v7.4h, v19.4h\n"
+ "smlal2 v24.4s, v7.8h, v19.8h\n"
+ "smlal v21.4s, v8.4h, v17.4h\n"
+ "smlal2 v22.4s, v8.8h, v17.8h\n"
+ "smlal v23.4s, v8.4h, v20.4h\n"
+ "smlal2 v24.4s, v8.8h, v20.8h\n"
+
+ "sqrdmulh v21.4s, v21.4s, v27.4s\n"
+ "sqrdmulh v22.4s, v22.4s, v27.4s\n"
+ "sqrdmulh v23.4s, v23.4s, v27.4s\n"
+ "sqrdmulh v24.4s, v24.4s, v27.4s\n"
+ "and v9.16b, v21.16b, v28.16b\n"
+ "and v12.16b, v22.16b, v28.16b\n"
+ "and v15.16b, v23.16b, v28.16b\n"
+ "and v18.16b, v24.16b, v28.16b\n"
+ "sshr v9.4s, v9.4s, #31\n"
+ "sshr v12.4s, v12.4s, #31\n"
+ "sshr v15.4s, v15.4s, #31\n"
+ "sshr v18.4s, v18.4s, #31\n"
+ "sqadd v21.4s, v21.4s, v9.4s\n"
+ "sqadd v22.4s, v22.4s, v12.4s\n"
+ "sqadd v23.4s, v23.4s, v15.4s\n"
+ "sqadd v24.4s, v24.4s, v18.4s\n"
+ "srshl v21.4s, v21.4s, v28.4s\n"
+ "srshl v22.4s, v22.4s, v28.4s\n"
+ "srshl v23.4s, v23.4s, v28.4s\n"
+ "srshl v24.4s, v24.4s, v28.4s\n"
+ "add v21.4s, v21.4s, v29.4s\n"
+ "add v22.4s, v22.4s, v29.4s\n"
+ "add v23.4s, v23.4s, v29.4s\n"
+ "add v24.4s, v24.4s, v29.4s\n"
+ "smax v21.4s, v21.4s, v30.4s\n"
+ "smax v22.4s, v22.4s, v30.4s\n"
+ "smax v23.4s, v23.4s, v30.4s\n"
+ "smax v24.4s, v24.4s, v30.4s\n"
+ "smin v21.4s, v21.4s, v31.4s\n"
+ "smin v22.4s, v22.4s, v31.4s\n"
+ "smin v23.4s, v23.4s, v31.4s\n"
+ "smin v24.4s, v24.4s, v31.4s\n"
+ "sqxtn v21.4h, v21.4s\n"
+ "sqxtn v23.4h, v23.4s\n"
+ "sqxtn2 v21.8h, v22.4s\n"
+ "sqxtn2 v23.8h, v24.4s\n"
+ "sqxtun v21.8b, v21.8h\n"
+ "sqxtun v23.8b, v23.8h\n"
+ "st1 {v21.8b}, [x6], x3\n"
+ "st1 {v23.8b}, [x7], x3\n"
+
+ DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_AFTER_LOOP ":\n"
+ "subs %w[output_window_height], %w[output_window_height], #2\n"
+ "add %[input_ptr], %[input_ptr], %[input_height_increment]\n"
+ "cmp %w[output_window_height], #2\n"
+ "add %[output_ptr], %[output_ptr], %[output_height_increment]\n"
+ "bge " DEPTHWISECONV_LABEL_HEIGHT_2_LOOP "b\n"
+
+ DEPTHWISECONV_LABEL_HEIGHT_2_AFTER_LOOP ":\n"
+ "cmp %w[output_window_height], #1\n"
+ "blt " DEPTHWISECONV_LABEL_HEIGHT_1_END "f\n"
+
+ DEPTHWISECONV_LABEL_HEIGHT_1 ":\n"
+ "mov x12, %[input_ptr]\n"
+ "ld1 {v9.8b}, [x12], %[input_depth]\n"
+ "add x13, %[input_ptr], %[input_row_size]\n"
+ "ld1 {v10.8b}, [x12], %[input_depth]\n"
+ "add x14, x13, %[input_row_size]\n"
+ "ld1 {v11.8b}, [x12], %[input_depth]\n"
+ "add x15, x14, %[input_row_size]\n"
+ "mov w5, %w[output_window_width]\n"
+ "ld1 {v13.8b}, [x13], %[input_depth]\n"
+ "mov x6, %[output_ptr]\n"
+ "ld1 {v14.8b}, [x13], %[input_depth]\n"
+ "add x7, %[output_ptr], x1\n"
+ "ld1 {v15.8b}, [x13], %[input_depth]\n"
+      // The height 1 / width 2 loop loads inputs for an extra 1x1 output in
+      // anticipation of the next iteration. Make sure |output_window_width|
+      // is large enough to handle the additional load; otherwise jump to
+      // the appropriate label to handle smaller widths.
+ "cmp w5, #2\n"
+ "ld1 {v17.8b}, [x14], %[input_depth]\n"
+ "ld1 {v18.8b}, [x14], %[input_depth]\n"
+ "ld1 {v19.8b}, [x14], %[input_depth]\n"
+ "ld1 {v21.4s}, [%[bias_ptr]]\n"
+ "ld1 {v22.4s}, [x10]\n"
+ "ld1 {v23.4s}, [%[bias_ptr]]\n"
+ "ld1 {v24.4s}, [x10]\n"
+
+ "uaddw v9.8h, v26.8h, v9.8b\n"
+ "uaddw v10.8h, v26.8h, v10.8b\n"
+ "uaddw v11.8h, v26.8h, v11.8b\n"
+ "uaddw v13.8h, v26.8h, v13.8b\n"
+ "uaddw v14.8h, v26.8h, v14.8b\n"
+ "uaddw v15.8h, v26.8h, v15.8b\n"
+ "uaddw v17.8h, v26.8h, v17.8b\n"
+ "uaddw v18.8h, v26.8h, v18.8b\n"
+ "uaddw v19.8h, v26.8h, v19.8b\n"
+
+ "beq " DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_2_LEFTOVER "f\n"
+ "cmp w5, #1\n"
+ "beq " DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_1_LEFTOVER "f\n"
+
+ //"loop_%=:\n"
+ DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_2_LOOP ":\n"
+      // Load inputs for the 3x4 input window, which corresponds to a 1x2
+      // output window.
+ "smlal v21.4s, v0.4h, v9.4h\n"
+ "ld1 {v12.8b}, [x12]\n"
+ "smlal2 v22.4s, v0.8h, v9.8h\n"
+ "ld1 {v16.8b}, [x13]\n"
+ "smlal v23.4s, v0.4h, v10.4h\n"
+ "ld1 {v20.8b}, [x14]\n"
+ "smlal2 v24.4s, v0.8h, v10.8h\n"
+ "subs w5, w5, #2\n"
+ "smlal v21.4s, v1.4h, v10.4h\n"
+ "cmp w5, #3\n"
+ "smlal2 v22.4s, v1.8h, v10.8h\n"
+ "add %[input_ptr], %[input_ptr], %[input_width_increment]\n"
+ "smlal v23.4s, v1.4h, v11.4h\n"
+ "mov x12, %[input_ptr]\n"
+ "smlal2 v24.4s, v1.8h, v11.8h\n"
+ "ld1 {v9.8b}, [x12], %[input_depth]\n"
+ "smlal v21.4s, v2.4h, v11.4h\n"
+ "ld1 {v10.8b}, [x12], %[input_depth]\n"
+ "uaddw v12.8h, v26.8h, v12.8b\n"
+ "smlal2 v22.4s, v2.8h, v11.8h\n"
+ "ld1 {v11.8b}, [x12], %[input_depth]\n"
+ "add x13, %[input_ptr], %[input_row_size]\n"
+ "smlal v23.4s, v2.4h, v12.4h\n"
+ "add x14, x13, %[input_row_size]\n"
+ "smlal2 v24.4s, v2.8h, v12.8h\n"
+ "smlal v21.4s, v3.4h, v13.4h\n"
+ "add x15, x14, %[input_row_size]\n"
+ "smlal2 v22.4s, v3.8h, v13.8h\n"
+ "ld1 {v13.8b}, [x13], %[input_depth]\n"
+ "smlal v23.4s, v3.4h, v14.4h\n"
+ "smlal2 v24.4s, v3.8h, v14.8h\n"
+ "smlal v21.4s, v4.4h, v14.4h\n"
+ "smlal2 v22.4s, v4.8h, v14.8h\n"
+ "ld1 {v14.8b}, [x13], %[input_depth]\n"
+ "smlal v23.4s, v4.4h, v15.4h\n"
+ "smlal2 v24.4s, v4.8h, v15.8h\n"
+ "smlal v21.4s, v5.4h, v15.4h\n"
+ "uaddw v16.8h, v26.8h, v16.8b\n"
+ "smlal2 v22.4s, v5.8h, v15.8h\n"
+ "ld1 {v15.8b}, [x13], %[input_depth]\n"
+ "smlal v23.4s, v5.4h, v16.4h\n"
+ "smlal2 v24.4s, v5.8h, v16.8h\n"
+ "smlal v21.4s, v6.4h, v17.4h\n"
+ "smlal2 v22.4s, v6.8h, v17.8h\n"
+ "ld1 {v17.8b}, [x14], %[input_depth]\n"
+ "smlal v23.4s, v6.4h, v18.4h\n"
+ "smlal2 v24.4s, v6.8h, v18.8h\n"
+ "smlal v21.4s, v7.4h, v18.4h\n"
+ "smlal2 v22.4s, v7.8h, v18.8h\n"
+ "ld1 {v18.8b}, [x14], %[input_depth]\n"
+ "smlal v23.4s, v7.4h, v19.4h\n"
+ "smlal2 v24.4s, v7.8h, v19.8h\n"
+ "smlal v21.4s, v8.4h, v19.4h\n"
+ "uaddw v20.8h, v26.8h, v20.8b\n"
+ "smlal2 v22.4s, v8.8h, v19.8h\n"
+ "ld1 {v19.8b}, [x14], %[input_depth]\n"
+ "smlal v23.4s, v8.4h, v20.4h\n"
+ "smlal2 v24.4s, v8.8h, v20.8h\n"
+
+ "sqrdmulh v21.4s, v21.4s, v27.4s\n"
+ "sqrdmulh v22.4s, v22.4s, v27.4s\n"
+ "sqrdmulh v23.4s, v23.4s, v27.4s\n"
+ "sqrdmulh v24.4s, v24.4s, v27.4s\n"
+ "and v25.16b, v21.16b, v28.16b\n"
+ "and v29.16b, v22.16b, v28.16b\n"
+ "and v30.16b, v23.16b, v28.16b\n"
+ "and v31.16b, v24.16b, v28.16b\n"
+ "sshr v25.4s, v25.4s, #31\n"
+ "sshr v29.4s, v29.4s, #31\n"
+ "sshr v30.4s, v30.4s, #31\n"
+ "sshr v31.4s, v31.4s, #31\n"
+ "sqadd v21.4s, v21.4s, v25.4s\n"
+ "sqadd v22.4s, v22.4s, v29.4s\n"
+ "dup v29.4s, w2\n"
+ "sqadd v23.4s, v23.4s, v30.4s\n"
+ "dup v30.4s, w4\n"
+ "sqadd v24.4s, v24.4s, v31.4s\n"
+ "dup v31.4s, w0\n"
+ "srshl v21.4s, v21.4s, v28.4s\n"
+ "srshl v22.4s, v22.4s, v28.4s\n"
+ "srshl v23.4s, v23.4s, v28.4s\n"
+ "srshl v24.4s, v24.4s, v28.4s\n"
+ "add v21.4s, v21.4s, v29.4s\n"
+ "add v22.4s, v22.4s, v29.4s\n"
+ "add v23.4s, v23.4s, v29.4s\n"
+ "add v24.4s, v24.4s, v29.4s\n"
+ "smax v21.4s, v21.4s, v30.4s\n"
+ "smax v22.4s, v22.4s, v30.4s\n"
+ "smax v23.4s, v23.4s, v30.4s\n"
+ "smax v24.4s, v24.4s, v30.4s\n"
+ "smin v21.4s, v21.4s, v31.4s\n"
+ "smin v22.4s, v22.4s, v31.4s\n"
+ "smin v23.4s, v23.4s, v31.4s\n"
+ "smin v24.4s, v24.4s, v31.4s\n"
+ "sqxtn v21.4h, v21.4s\n"
+ "sqxtn v23.4h, v23.4s\n"
+ "sqxtn2 v21.8h, v22.4s\n"
+ "ld1 {v22.4s}, [x10]\n"
+ "sqxtn2 v23.8h, v24.4s\n"
+ "ld1 {v24.4s}, [x10]\n"
+ "sqxtun v21.8b, v21.8h\n"
+ "sqxtun v23.8b, v23.8h\n"
+ "uaddw v9.8h, v26.8h, v9.8b\n"
+ "st1 {v21.8b}, [%[output_ptr]], x3\n"
+ "uaddw v10.8h, v26.8h, v10.8b\n"
+ "st1 {v23.8b}, [%[output_ptr]], x3\n"
+ "uaddw v11.8h, v26.8h, v11.8b\n"
+ "uaddw v12.8h, v26.8h, v12.8b\n"
+ "uaddw v13.8h, v26.8h, v13.8b\n"
+ "uaddw v14.8h, v26.8h, v14.8b\n"
+ "uaddw v15.8h, v26.8h, v15.8b\n"
+ "ld1 {v21.4s}, [%[bias_ptr]]\n"
+ "uaddw v16.8h, v26.8h, v16.8b\n"
+ "ld1 {v23.4s}, [%[bias_ptr]]\n"
+ "uaddw v17.8h, v26.8h, v17.8b\n"
+ "uaddw v18.8h, v26.8h, v18.8b\n"
+ "uaddw v19.8h, v26.8h, v19.8b\n"
+ "uaddw v20.8h, v26.8h, v20.8b\n"
+
+ "bge " DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_2_LOOP "b\n"
+
+      // At this point, there will be either a 2-wide or a 1-wide leftover,
+      // not both.
+ "cmp w5, #2\n"
+ "blt " DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_1_LEFTOVER "f\n"
+
+      // Handle the last two horizontal outputs if they exist.
+ DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_2_LEFTOVER ":\n"
+ "smlal v21.4s, v0.4h, v9.4h\n"
+ "ld1 {v12.8b}, [x12], %[input_depth]\n"
+ "smlal2 v22.4s, v0.8h, v9.8h\n"
+ "ld1 {v16.8b}, [x13], %[input_depth]\n"
+ "smlal v23.4s, v0.4h, v10.4h\n"
+ "ld1 {v20.8b}, [x14], %[input_depth]\n"
+ "smlal2 v24.4s, v0.8h, v10.8h\n"
+ "smlal v21.4s, v1.4h, v10.4h\n"
+ "smlal2 v22.4s, v1.8h, v10.8h\n"
+ "smlal v23.4s, v1.4h, v11.4h\n"
+ "smlal2 v24.4s, v1.8h, v11.8h\n"
+ "smlal v21.4s, v2.4h, v11.4h\n"
+ "uaddw v12.8h, v26.8h, v12.8b\n"
+ "smlal2 v22.4s, v2.8h, v11.8h\n"
+ "smlal v23.4s, v2.4h, v12.4h\n"
+ "smlal2 v24.4s, v2.8h, v12.8h\n"
+ "smlal v21.4s, v3.4h, v13.4h\n"
+ "smlal2 v22.4s, v3.8h, v13.8h\n"
+ "smlal v23.4s, v3.4h, v14.4h\n"
+ "smlal2 v24.4s, v3.8h, v14.8h\n"
+ "smlal v21.4s, v4.4h, v14.4h\n"
+ "smlal2 v22.4s, v4.8h, v14.8h\n"
+ "smlal v23.4s, v4.4h, v15.4h\n"
+ "smlal2 v24.4s, v4.8h, v15.8h\n"
+ "smlal v21.4s, v5.4h, v15.4h\n"
+ "uaddw v16.8h, v26.8h, v16.8b\n"
+ "smlal2 v22.4s, v5.8h, v15.8h\n"
+ "smlal v23.4s, v5.4h, v16.4h\n"
+ "smlal2 v24.4s, v5.8h, v16.8h\n"
+ "smlal v21.4s, v6.4h, v17.4h\n"
+ "smlal2 v22.4s, v6.8h, v17.8h\n"
+ "smlal v23.4s, v6.4h, v18.4h\n"
+ "smlal2 v24.4s, v6.8h, v18.8h\n"
+ "smlal v21.4s, v7.4h, v18.4h\n"
+ "smlal2 v22.4s, v7.8h, v18.8h\n"
+ "smlal v23.4s, v7.4h, v19.4h\n"
+ "smlal2 v24.4s, v7.8h, v19.8h\n"
+ "smlal v21.4s, v8.4h, v19.4h\n"
+ "uaddw v20.8h, v26.8h, v20.8b\n"
+ "smlal2 v22.4s, v8.8h, v19.8h\n"
+ "smlal v23.4s, v8.4h, v20.4h\n"
+ "smlal2 v24.4s, v8.8h, v20.8h\n"
+
+ "sqrdmulh v21.4s, v21.4s, v27.4s\n"
+ "sqrdmulh v22.4s, v22.4s, v27.4s\n"
+ "sqrdmulh v23.4s, v23.4s, v27.4s\n"
+ "sqrdmulh v24.4s, v24.4s, v27.4s\n"
+ "and v25.16b, v21.16b, v28.16b\n"
+ "and v29.16b, v22.16b, v28.16b\n"
+ "and v30.16b, v23.16b, v28.16b\n"
+ "and v31.16b, v24.16b, v28.16b\n"
+ "sshr v25.4s, v25.4s, #31\n"
+ "sshr v29.4s, v29.4s, #31\n"
+ "sshr v30.4s, v30.4s, #31\n"
+ "sshr v31.4s, v31.4s, #31\n"
+ "sqadd v21.4s, v21.4s, v25.4s\n"
+ "sqadd v22.4s, v22.4s, v29.4s\n"
+ "dup v29.4s, w2\n"
+ "sqadd v23.4s, v23.4s, v30.4s\n"
+ "dup v30.4s, w4\n"
+ "sqadd v24.4s, v24.4s, v31.4s\n"
+ "dup v31.4s, w0\n"
+ "srshl v21.4s, v21.4s, v28.4s\n"
+ "srshl v22.4s, v22.4s, v28.4s\n"
+ "srshl v23.4s, v23.4s, v28.4s\n"
+ "srshl v24.4s, v24.4s, v28.4s\n"
+ "add v21.4s, v21.4s, v29.4s\n"
+ "add v22.4s, v22.4s, v29.4s\n"
+ "add v23.4s, v23.4s, v29.4s\n"
+ "add v24.4s, v24.4s, v29.4s\n"
+ "smax v21.4s, v21.4s, v30.4s\n"
+ "smax v22.4s, v22.4s, v30.4s\n"
+ "smax v23.4s, v23.4s, v30.4s\n"
+ "smax v24.4s, v24.4s, v30.4s\n"
+ "smin v21.4s, v21.4s, v31.4s\n"
+ "smin v22.4s, v22.4s, v31.4s\n"
+ "smin v23.4s, v23.4s, v31.4s\n"
+ "smin v24.4s, v24.4s, v31.4s\n"
+ "sqxtn v21.4h, v21.4s\n"
+ "sqxtn v23.4h, v23.4s\n"
+ "sqxtn2 v21.8h, v22.4s\n"
+ "sqxtn2 v23.8h, v24.4s\n"
+ "sqxtun v21.8b, v21.8h\n"
+ "sqxtun v23.8b, v23.8h\n"
+ "st1 {v21.8b}, [%[output_ptr]], x3\n"
+ "st1 {v23.8b}, [%[output_ptr]], x3\n"
+ "b " DEPTHWISECONV_LABEL_HEIGHT_1_END "f\n"
+
+      // Handle the bottom-right output if it exists.
+ DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_1_LEFTOVER ":\n"
+ "smlal v21.4s, v0.4h, v9.4h\n"
+ "smlal2 v22.4s, v0.8h, v9.8h\n"
+ "smlal v21.4s, v1.4h, v10.4h\n"
+ "smlal2 v22.4s, v1.8h, v10.8h\n"
+ "smlal v21.4s, v2.4h, v11.4h\n"
+ "smlal2 v22.4s, v2.8h, v11.8h\n"
+ "smlal v21.4s, v3.4h, v13.4h\n"
+ "smlal2 v22.4s, v3.8h, v13.8h\n"
+ "smlal v21.4s, v4.4h, v14.4h\n"
+ "smlal2 v22.4s, v4.8h, v14.8h\n"
+ "smlal v21.4s, v5.4h, v15.4h\n"
+ "smlal2 v22.4s, v5.8h, v15.8h\n"
+ "smlal v21.4s, v6.4h, v17.4h\n"
+ "smlal2 v22.4s, v6.8h, v17.8h\n"
+ "smlal v21.4s, v7.4h, v18.4h\n"
+ "smlal2 v22.4s, v7.8h, v18.8h\n"
+ "smlal v21.4s, v8.4h, v19.4h\n"
+ "smlal2 v22.4s, v8.8h, v19.8h\n"
+
+ "sqrdmulh v21.4s, v21.4s, v27.4s\n"
+ "sqrdmulh v22.4s, v22.4s, v27.4s\n"
+ "and v9.16b, v21.16b, v28.16b\n"
+ "and v12.16b, v22.16b, v28.16b\n"
+ "sshr v9.4s, v9.4s, #31\n"
+ "sshr v12.4s, v12.4s, #31\n"
+ "sqadd v21.4s, v21.4s, v9.4s\n"
+ "sqadd v22.4s, v22.4s, v12.4s\n"
+ "srshl v21.4s, v21.4s, v28.4s\n"
+ "srshl v22.4s, v22.4s, v28.4s\n"
+ "add v21.4s, v21.4s, v29.4s\n"
+ "add v22.4s, v22.4s, v29.4s\n"
+ "smax v21.4s, v21.4s, v30.4s\n"
+ "smax v22.4s, v22.4s, v30.4s\n"
+ "smin v21.4s, v21.4s, v31.4s\n"
+ "smin v22.4s, v22.4s, v31.4s\n"
+ "sqxtn v21.4h, v21.4s\n"
+ "sqxtn2 v21.8h, v22.4s\n"
+ "sqxtun v21.8b, v21.8h\n"
+ "st1 {v21.8b}, [%[output_ptr]]\n"
+ DEPTHWISECONV_LABEL_HEIGHT_1_END ":\n"
+ :
+ // Outputs.
+ [filter_ptr] "+r"(filter_ptr), [input_ptr] "+r"(input_ptr),
+ [output_ptr] "+r"(output_ptr),
+ [output_window_height] "+r"(output_window_height)
+ :
+ // Inputs.
+ [bias_ptr] "r"(bias_ptr), [input_row_size] "r"(input_row_size),
+ [input_depth] "r"(input_depth),
+ [output_window_width] "r"(output_window_width),
+ [input_width_increment] "r"(input_width_increment),
+ [input_height_increment] "r"(input_height_increment),
+ [output_height_increment] "r"(output_height_increment),
+ [params_ptr] "r"(params_ptr)
+ :
+ // Clobbers.
+ "cc", "memory",
+ // We use these NEON registers.
+ "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9",
+ "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
+ "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29",
+ "v30", "v31",
+ // We use these general-purpose registers.
+ "x0", "x1", "x2", "x3", "x4", "x5", "x6", "x7",
+ "x9", "x10", "x11", "x12", "x13", "x14", "x15");
+#undef DEPTHWISECONV_LABEL_HEIGHT_2_LOOP
+#undef DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_LOOP
+#undef DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_1_LEFTOVER
+#undef DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_LEFTOVER
+#undef DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_AFTER_LOOP
+#undef DEPTHWISECONV_LABEL_HEIGHT_2_AFTER_LOOP
+#undef DEPTHWISECONV_LABEL_HEIGHT_1
+#undef DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_2_LOOP
+#undef DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_1_LEFTOVER
+#undef DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_2_LEFTOVER
+#undef DEPTHWISECONV_LABEL_HEIGHT_1_END
}
};
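The downquantize sequence used throughout these kernels (sqrdmulh, then and/sshr/sqadd, srshl, add, smax/smin, sqxtn/sqxtun) is the NEON form of the usual fixed-point requantization, following the same approach as gemmlowp's SaturatingRoundingDoublingHighMul and RoundingDivideByPOT reference helpers. Below is a minimal scalar sketch of the per-lane math, assuming a non-negative right shift; the function name and signature are illustrative and not taken from these sources.

#include <algorithm>
#include <cstdint>
#include <limits>

// Scalar equivalent of the per-lane NEON downquantize above:
//   sqrdmulh           -> saturating rounding doubling multiply-high
//   and/sshr/sqadd     -> fixup so the next shift rounds to nearest
//   srshl (by -shift)  -> rounding arithmetic right shift by `shift`
//   add/smax/smin      -> output offset and activation clamping
//   sqxtn/sqxtun       -> narrowing to uint8
inline uint8_t RequantizeLane(int32_t acc, int32_t output_multiplier,
                              int output_shift,  // right shift, >= 0
                              int32_t output_offset, int32_t activation_min,
                              int32_t activation_max) {
  int32_t x;
  if (acc == std::numeric_limits<int32_t>::min() &&
      output_multiplier == std::numeric_limits<int32_t>::min()) {
    x = std::numeric_limits<int32_t>::max();  // the one saturating case
  } else {
    const int64_t ab = static_cast<int64_t>(acc) * output_multiplier;
    const int64_t nudge = ab >= 0 ? (1ll << 30) : (1 - (1ll << 30));
    x = static_cast<int32_t>((ab + nudge) / (1ll << 31));
  }
  // Rounding divide by 2^output_shift, rounding to nearest with ties away
  // from zero (this is what the and/sshr/sqadd/srshl sequence computes).
  const int32_t mask = static_cast<int32_t>((1ll << output_shift) - 1);
  const int32_t remainder = x & mask;
  const int32_t threshold = (mask >> 1) + (x < 0 ? 1 : 0);
  x = (x >> output_shift) + (remainder > threshold ? 1 : 0);
  // Add the output offset, clamp to the activation range, narrow to uint8.
  x += output_offset;
  x = std::max(x, activation_min);
  x = std::min(x, activation_max);
  return static_cast<uint8_t>(x);
}

The neg applied to OFFSET_OUTPUT_RIGHT_SHIFT in both kernels is what lets srshl, which shifts left by a signed per-lane amount, act as this rounding right shift.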
template <>
-struct ConvKernel3x3FilterDepth8<4, 4, 2, 2> {
- static inline void Run(const uint8* input_ptr, int input_depth,
- int32 input_offset, int input_row_size,
- const uint8* filter_ptr, int32 filter_offset,
- const int32* bias_ptr, int32 output_offset,
- int32 output_multiplier, int output_shift,
- int32 output_activation_min,
- int32 output_activation_max, uint8* output_ptr,
- int output_depth, int output_width) {
- // Reuse 4x2 kernel twice.
- ConvKernel3x3FilterDepth8<4, 2, 2, 2>::Run(
- input_ptr, input_depth, input_offset, input_row_size, filter_ptr,
- filter_offset, bias_ptr, output_offset, output_multiplier, output_shift,
- output_activation_min, output_activation_max, output_ptr, output_depth,
- output_width);
-
- ConvKernel3x3FilterDepth8<4, 2, 2, 2>::Run(
- input_ptr + 4 * input_depth, input_depth, input_offset, input_row_size,
- filter_ptr, filter_offset, bias_ptr, output_offset, output_multiplier,
- output_shift, output_activation_min, output_activation_max,
- output_ptr + 2 * output_depth, output_depth, output_width);
+struct DepthwiseConvWindow<8, 2, 2> {
+ static inline void Run(const uint8* input_ptr, const uint8* filter_ptr,
+ const int32* bias_ptr, uint8* output_ptr, int64_t input_depth,
+ int64_t input_row_size, int32 output_window_height,
+ int32 output_window_width,
+ const DepthwiseConvParams* params_ptr) {
+ const int64_t input_width_increment = 4 * input_depth;
+ const int64_t input_height_increment = 4 * input_row_size;
+ const int64_t output_height_increment = 2 * params_ptr->output_row_size;
+
+#define DEPTHWISECONV_LABEL_HEIGHT_2_LOOP "1"
+#define DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_LOOP "2"
+#define DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_1_LEFTOVER "3"
+#define DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_LEFTOVER "4"
+#define DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_AFTER_LOOP "5"
+#define DEPTHWISECONV_LABEL_HEIGHT_2_AFTER_LOOP "6"
+#define DEPTHWISECONV_LABEL_HEIGHT_1 "7"
+#define DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_2_LOOP "8"
+#define DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_1_LEFTOVER "9"
+#define DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_2_LEFTOVER "10"
+#define DEPTHWISECONV_LABEL_HEIGHT_1_END "11"
+
+ asm volatile(
+ // Performs depthwise convolutions for a window specified by
+      // |output_window_height| and |output_window_width|. The innermost loop
+      // processes 2x2 outputs; any leftovers are handled at the end.
+      //
+      // The algorithm works as follows:
+ //
+ // 1. Load filters of 8 depth (8x3x3). Registers v0--v8 hold filter
+ // values.
+ // 2. For 2 output heights at a time:
+ // i. For 2 output widths at a time at stride 2, a 5x5 input
+ // window is required. To avoid register exhaustion, we load
+ // the first 2 rows of the 5x5 input window into registers
+ // v9--v18, and use the same registers to load the next 2
+ // rows, and finally v9--v13 to load the last row.
+      //          Accumulators for all 2x2 outputs are held in registers
+      //          v21--v22 (top left output), v23--v24 (top right output),
+      //          v19--v20 (bottom left output), and v25--v26 (bottom
+      //          right output).
+ // ii. Handle single leftover width if exists.
+ // 3. Handle single leftover height if exists.
+ // i. For 2 output widths at a time at stride 2, load inputs for
+ // a 1x2 (1 height, 2 width) output window (3x5 input
+ // window). Registers v9--v24 hold input values. Mul-add with
+ // accumulators v24--v27.
+ // ii. Handle single leftover width if exists.
+ //
+ // Loads are placed as soon as the register is no longer needed and
+ // interleaved with arithmetic operations to take advantage of
+ // dual-issue pipelines. We also add input offsets as far from the loads
+ // as possible to give loads enough cycles to fetch data from memory.
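+      //
+      // Note on the stride-2 geometry: the two horizontally adjacent
+      // outputs read input columns [0, 2] and [2, 4] and share only the
+      // middle column; likewise the two output rows share only the middle
+      // input row, which is why a full 5x5 input window is needed per 2x2
+      // output block.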
+
+ // Set "constant" registers. These registers may be replaced with temp
+ // values from time to time when there are not enough NEON registers.
+ // We use x9--x15 general purpose registers as they are caller-saved
+ // temporary registers (see http://infocenter.arm.com/help/topic/com.arm.doc.ihi0055b/IHI0055B_aapcs64.pdf). // NOLINT
+ "ldr w9, [%[params_ptr], #" STR(OFFSET_OUTPUT_RIGHT_SHIFT) "]\n"
+ "ldr w0, [%[params_ptr], #" STR(OFFSET_INPUT_OFFSET) "]\n"
+ "cmp %w[output_window_height], #2\n"
+ "dup v28.8h, w0\n"
+ "neg w9, w9\n"
+ "ldr w1, [%[params_ptr], #" STR(OFFSET_OUTPUT_MULTIPLIER) "]\n"
+ "dup v26.4s, w9\n"
+ "ldr w2, [%[params_ptr], #" STR(OFFSET_OUTPUT_OFFSET) "]\n"
+ "dup v27.4s, w1\n"
+ "ldr w3, [%[params_ptr], #" STR(OFFSET_OUTPUT_ACTIVATION_MIN) "]\n"
+ "dup v29.4s, w2\n"
+ "ldr w4, [%[params_ptr], #" STR(OFFSET_OUTPUT_ACTIVATION_MAX) "]\n"
+ "dup v30.4s, w3\n"
+ "ldr x5, [%[params_ptr], #" STR(OFFSET_OUTPUT_DEPTH) "]\n"
+ "dup v31.4s, w4\n"
+ "ldr x19, [%[params_ptr], #" STR(OFFSET_OUTPUT_ROW_SIZE) "]\n"
+ "ldr w20, [%[params_ptr], #" STR(OFFSET_FILTER_OFFSET) "]\n"
+
+ // Load filters and add offsets.
+ "add x10, %[bias_ptr], #16\n"
+ "ld1 {v0.8b}, [%[filter_ptr]], x5\n"
+ "dup v9.8h, w20\n"
+ "ld1 {v1.8b}, [%[filter_ptr]], x5\n"
+ "uaddw v0.8h, v9.8h, v0.8b\n"
+ "ld1 {v2.8b}, [%[filter_ptr]], x5\n"
+ "uaddw v1.8h, v9.8h, v1.8b\n"
+ "ld1 {v3.8b}, [%[filter_ptr]], x5\n"
+ "uaddw v2.8h, v9.8h, v2.8b\n"
+ "ld1 {v4.8b}, [%[filter_ptr]], x5\n"
+ "uaddw v3.8h, v9.8h, v3.8b\n"
+ "ld1 {v5.8b}, [%[filter_ptr]], x5\n"
+ "uaddw v4.8h, v9.8h, v4.8b\n"
+ "ld1 {v6.8b}, [%[filter_ptr]], x5\n"
+ "uaddw v5.8h, v9.8h, v5.8b\n"
+ "ld1 {v7.8b}, [%[filter_ptr]], x5\n"
+ "uaddw v6.8h, v9.8h, v6.8b\n"
+ "ld1 {v8.8b}, [%[filter_ptr]]\n"
+ "uaddw v7.8h, v9.8h, v7.8b\n"
+ "uaddw v8.8h, v9.8h, v8.8b\n"
+
+ "blt " DEPTHWISECONV_LABEL_HEIGHT_2_AFTER_LOOP "f\n"
+
+ //"loop_%=:\n"
+ DEPTHWISECONV_LABEL_HEIGHT_2_LOOP ":\n"
+ // Load the first two rows of the 5x5 input window, then reuse the
+ // same registers to load subsequent rows as they become available.
+ "mov x11, %[input_ptr]\n"
+ "mov x12, x11\n"
+ "add x13, x12, %[input_row_size]\n"
+ "ld1 {v9.8b}, [x12], %[input_depth]\n"
+ "mov w14, %w[output_window_width]\n"
+ "ld1 {v10.8b}, [x12], %[input_depth]\n"
+      // The height 2 / width 2 loop loads inputs for one extra output
+      // horizontally in anticipation of the next iteration. Make sure
+      // |output_window_width| is large enough to handle the additional
+      // load; otherwise jump to the appropriate label to handle smaller
+      // widths.
+ "cmp w14, #2\n"
+ "ld1 {v11.8b}, [x12], %[input_depth]\n"
+ "add x15, x13, %[input_row_size]\n"
+ "ld1 {v14.8b}, [x13], %[input_depth]\n"
+ "mov x6, %[output_ptr]\n"
+ "ld1 {v15.8b}, [x13], %[input_depth]\n"
+ "add x7, %[output_ptr], x19\n"
+ "ld1 {v16.8b}, [x13], %[input_depth]\n"
+ "ld1 {v21.4s}, [%[bias_ptr]]\n"
+ "ld1 {v22.4s}, [x10]\n"
+ "ld1 {v23.4s}, [%[bias_ptr]]\n"
+ "uaddw v9.8h, v28.8h, v9.8b\n"
+ "ld1 {v24.4s}, [x10]\n"
+ "uaddw v10.8h, v28.8h, v10.8b\n"
+ "ld1 {v19.4s}, [%[bias_ptr]]\n"
+ "uaddw v11.8h, v28.8h, v11.8b\n"
+ "ld1 {v20.4s}, [x10]\n"
+ "uaddw v14.8h, v28.8h, v14.8b\n"
+ "ld1 {v25.4s}, [%[bias_ptr]]\n"
+ "uaddw v15.8h, v28.8h, v15.8b\n"
+ "ld1 {v26.4s}, [x10]\n"
+ "uaddw v16.8h, v28.8h, v16.8b\n"
+
+ "beq " DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_LEFTOVER "f\n"
+ "cmp w14, #1\n"
+ "beq " DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_1_LEFTOVER "f\n"
+
+ //"loop_%=:\n"
+ DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_LOOP ":\n"
+ "smlal v21.4s, v0.4h, v9.4h\n"
+ "ld1 {v12.8b}, [x12], %[input_depth]\n"
+ "smlal2 v22.4s, v0.8h, v9.8h\n"
+ "ld1 {v13.8b}, [x12]\n"
+ "add x12, x15, %[input_row_size]\n"
+ "smlal v23.4s, v0.4h, v11.4h\n"
+ "ld1 {v17.8b}, [x13], %[input_depth]\n"
+ "smlal2 v24.4s, v0.8h, v11.8h\n"
+ "ld1 {v18.8b}, [x13]\n"
+ "add x13, x12, %[input_row_size]\n"
+ "smlal v21.4s, v1.4h, v10.4h\n"
+ "ld1 {v9.8b}, [x15], %[input_depth]\n"
+ "smlal2 v22.4s, v1.8h, v10.8h\n"
+ "ld1 {v10.8b}, [x15], %[input_depth]\n"
+ "smlal v21.4s, v2.4h, v11.4h\n"
+ "smlal2 v22.4s, v2.8h, v11.8h\n"
+ "ld1 {v11.8b}, [x15], %[input_depth]\n"
+ "smlal v21.4s, v3.4h, v14.4h\n"
+ "smlal2 v22.4s, v3.8h, v14.8h\n"
+ "ld1 {v14.8b}, [x12], %[input_depth]\n"
+ "smlal v23.4s, v3.4h, v16.4h\n"
+ "subs w14, w14, #2\n"
+ "smlal2 v24.4s, v3.8h, v16.8h\n"
+ "cmp w14, #3\n"
+ "smlal v21.4s, v4.4h, v15.4h\n"
+ "uaddw v12.8h, v28.8h, v12.8b\n"
+ "smlal2 v22.4s, v4.8h, v15.8h\n"
+ "ld1 {v15.8b}, [x12], %[input_depth]\n"
+ "smlal v21.4s, v5.4h, v16.4h\n"
+ "uaddw v13.8h, v28.8h, v13.8b\n"
+ "smlal2 v22.4s, v5.8h, v16.8h\n"
+ "ld1 {v16.8b}, [x12], %[input_depth]\n"
+ "smlal v23.4s, v1.4h, v12.4h\n"
+ "uaddw v17.8h, v28.8h, v17.8b\n"
+ "smlal2 v24.4s, v1.8h, v12.8h\n"
+ "ld1 {v12.8b}, [x15], %[input_depth]\n"
+ "smlal v23.4s, v2.4h, v13.4h\n"
+ "uaddw v18.8h, v28.8h, v18.8b\n"
+ "smlal2 v24.4s, v2.8h, v13.8h\n"
+ "ld1 {v13.8b}, [x15]\n"
+ "smlal v23.4s, v4.4h, v17.4h\n"
+ "uaddw v9.8h, v28.8h, v9.8b\n"
+ "smlal2 v24.4s, v4.8h, v17.8h\n"
+ "ld1 {v17.8b}, [x12], %[input_depth]\n"
+ "smlal v23.4s, v5.4h, v18.4h\n"
+ "uaddw v10.8h, v28.8h, v10.8b\n"
+ "smlal2 v24.4s, v5.8h, v18.8h\n"
+ "ld1 {v18.8b}, [x12]\n"
+
+ "smlal v21.4s, v6.4h, v9.4h\n"
+ "smlal2 v22.4s, v6.8h, v9.8h\n"
+ "smlal v19.4s, v0.4h, v9.4h\n"
+ "uaddw v11.8h, v28.8h, v11.8b\n"
+ "smlal2 v20.4s, v0.8h, v9.8h\n"
+ "ld1 {v9.8b}, [x13], %[input_depth]\n"
+ "smlal v23.4s, v6.4h, v11.4h\n"
+ "smlal2 v24.4s, v6.8h, v11.8h\n"
+ "smlal v21.4s, v7.4h, v10.4h\n"
+ "smlal2 v22.4s, v7.8h, v10.8h\n"
+ "uaddw v12.8h, v28.8h, v12.8b\n"
+ "smlal v19.4s, v1.4h, v10.4h\n"
+ "smlal2 v20.4s, v1.8h, v10.8h\n"
+ "ld1 {v10.8b}, [x13], %[input_depth]\n"
+ "smlal v23.4s, v7.4h, v12.4h\n"
+ "smlal2 v24.4s, v7.8h, v12.8h\n"
+ "smlal v25.4s, v1.4h, v12.4h\n"
+ "smlal2 v26.4s, v1.8h, v12.8h\n"
+ "smlal v21.4s, v8.4h, v11.4h\n"
+ "smlal2 v22.4s, v8.8h, v11.8h\n"
+ "add x11, x11, %[input_width_increment]\n"
+ "smlal v19.4s, v2.4h, v11.4h\n"
+ "mov x12, x11\n"
+ "smlal2 v20.4s, v2.8h, v11.8h\n"
+ "uaddw v13.8h, v28.8h, v13.8b\n"
+ "smlal v25.4s, v0.4h, v11.4h\n"
+ "smlal2 v26.4s, v0.8h, v11.8h\n"
+ "ld1 {v11.8b}, [x13], %[input_depth]\n"
+ "smlal v23.4s, v8.4h, v13.4h\n"
+ "ld1 {v12.8b}, [x13], %[input_depth]\n"
+ "smlal2 v24.4s, v8.8h, v13.8h\n"
+ "smlal v25.4s, v2.4h, v13.4h\n"
+ "smlal2 v26.4s, v2.8h, v13.8h\n"
+ "ld1 {v13.8b}, [x13]\n"
+ "add x13, x12, %[input_row_size]\n"
+ "add x15, x13, %[input_row_size]\n"
+
+ "dup v28.4s, w9\n"
+ "sqrdmulh v21.4s, v21.4s, v27.4s\n"
+ "sqrdmulh v22.4s, v22.4s, v27.4s\n"
+ "sqrdmulh v23.4s, v23.4s, v27.4s\n"
+ "sqrdmulh v24.4s, v24.4s, v27.4s\n"
+ "and v27.16b, v21.16b, v28.16b\n"
+ "and v29.16b, v22.16b, v28.16b\n"
+ "and v30.16b, v23.16b, v28.16b\n"
+ "and v31.16b, v24.16b, v28.16b\n"
+ "sshr v27.4s, v27.4s, #31\n"
+ "sshr v29.4s, v29.4s, #31\n"
+ "sshr v30.4s, v30.4s, #31\n"
+ "sshr v31.4s, v31.4s, #31\n"
+ "sqadd v21.4s, v21.4s, v27.4s\n"
+ "dup v27.4s, w1\n"
+ "sqadd v22.4s, v22.4s, v29.4s\n"
+ "dup v29.4s, w2\n"
+ "sqadd v23.4s, v23.4s, v30.4s\n"
+ "dup v30.4s, w3\n"
+ "sqadd v24.4s, v24.4s, v31.4s\n"
+ "dup v31.4s, w4\n"
+ "srshl v21.4s, v21.4s, v28.4s\n"
+ "srshl v22.4s, v22.4s, v28.4s\n"
+ "srshl v23.4s, v23.4s, v28.4s\n"
+ "srshl v24.4s, v24.4s, v28.4s\n"
+ "dup v28.8h, w0\n"
+ "add v21.4s, v21.4s, v29.4s\n"
+ "add v22.4s, v22.4s, v29.4s\n"
+ "add v23.4s, v23.4s, v29.4s\n"
+ "add v24.4s, v24.4s, v29.4s\n"
+ "smax v21.4s, v21.4s, v30.4s\n"
+ "smax v22.4s, v22.4s, v30.4s\n"
+ "smax v23.4s, v23.4s, v30.4s\n"
+ "smax v24.4s, v24.4s, v30.4s\n"
+ "smin v21.4s, v21.4s, v31.4s\n"
+ "smin v22.4s, v22.4s, v31.4s\n"
+ "smin v23.4s, v23.4s, v31.4s\n"
+ "smin v24.4s, v24.4s, v31.4s\n"
+ "sqxtn v21.4h, v21.4s\n"
+ "sqxtn v23.4h, v23.4s\n"
+ "sqxtn2 v21.8h, v22.4s\n"
+ "ld1 {v22.4s}, [x10]\n"
+ "sqxtn2 v23.8h, v24.4s\n"
+ "ld1 {v24.4s}, [x10]\n"
+ "sqxtun v21.8b, v21.8h\n"
+ "sqxtun v23.8b, v23.8h\n"
+ "uaddw v9.8h, v28.8h, v9.8b\n"
+ "st1 {v21.8b}, [x6], x5\n"
+ "uaddw v10.8h, v28.8h, v10.8b\n"
+ "st1 {v23.8b}, [x6], x5\n"
+ "uaddw v11.8h, v28.8h, v11.8b\n"
+
+ "smlal v19.4s, v6.4h, v9.4h\n"
+ "smlal2 v20.4s, v6.8h, v9.8h\n"
+ "ld1 {v9.8b}, [x12], %[input_depth]\n"
+ "smlal v25.4s, v6.4h, v11.4h\n"
+ "smlal2 v26.4s, v6.8h, v11.8h\n"
+ "smlal v19.4s, v7.4h, v10.4h\n"
+ "uaddw v12.8h, v28.8h, v12.8b\n"
+ "smlal2 v20.4s, v7.8h, v10.8h\n"
+ "ld1 {v10.8b}, [x12], %[input_depth]\n"
+ "smlal v25.4s, v7.4h, v12.4h\n"
+ "smlal2 v26.4s, v7.8h, v12.8h\n"
+ "smlal v19.4s, v8.4h, v11.4h\n"
+ "uaddw v13.8h, v28.8h, v13.8b\n"
+ "smlal2 v20.4s, v8.8h, v11.8h\n"
+ "ld1 {v11.8b}, [x12], %[input_depth]\n"
+ "smlal v25.4s, v8.4h, v13.4h\n"
+ "uaddw v14.8h, v28.8h, v14.8b\n"
+ "smlal2 v26.4s, v8.8h, v13.8h\n"
+ "uaddw v16.8h, v28.8h, v16.8b\n"
+ "smlal v19.4s, v3.4h, v14.4h\n"
+ "uaddw v15.8h, v28.8h, v15.8b\n"
+ "smlal2 v20.4s, v3.8h, v14.8h\n"
+ "ld1 {v14.8b}, [x13], %[input_depth]\n"
+ "smlal v25.4s, v3.4h, v16.4h\n"
+ "ld1 {v21.4s}, [%[bias_ptr]]\n"
+ "smlal2 v26.4s, v3.8h, v16.8h\n"
+ "ld1 {v23.4s}, [%[bias_ptr]]\n"
+ "smlal v19.4s, v4.4h, v15.4h\n"
+ "uaddw v17.8h, v28.8h, v17.8b\n"
+ "smlal2 v20.4s, v4.8h, v15.8h\n"
+ "ld1 {v15.8b}, [x13], %[input_depth]\n"
+ "smlal v25.4s, v4.4h, v17.4h\n"
+ "smlal2 v26.4s, v4.8h, v17.8h\n"
+ "smlal v19.4s, v5.4h, v16.4h\n"
+ "uaddw v18.8h, v28.8h, v18.8b\n"
+ "smlal2 v20.4s, v5.8h, v16.8h\n"
+ "ld1 {v16.8b}, [x13], %[input_depth]\n"
+ "smlal v25.4s, v5.4h, v18.4h\n"
+ "smlal2 v26.4s, v5.8h, v18.8h\n"
+
+ "dup v28.4s, w9\n"
+ "sqrdmulh v19.4s, v19.4s, v27.4s\n"
+ "sqrdmulh v20.4s, v20.4s, v27.4s\n"
+ "sqrdmulh v25.4s, v25.4s, v27.4s\n"
+ "sqrdmulh v26.4s, v26.4s, v27.4s\n"
+ "and v27.16b, v19.16b, v28.16b\n"
+ "and v29.16b, v20.16b, v28.16b\n"
+ "and v30.16b, v25.16b, v28.16b\n"
+ "and v31.16b, v26.16b, v28.16b\n"
+ "sshr v27.4s, v27.4s, #31\n"
+ "sshr v29.4s, v29.4s, #31\n"
+ "sshr v30.4s, v30.4s, #31\n"
+ "sshr v31.4s, v31.4s, #31\n"
+ "sqadd v19.4s, v19.4s, v27.4s\n"
+ "dup v27.4s, w1\n"
+ "sqadd v20.4s, v20.4s, v29.4s\n"
+ "dup v29.4s, w2\n"
+ "sqadd v25.4s, v25.4s, v30.4s\n"
+ "dup v30.4s, w3\n"
+ "sqadd v26.4s, v26.4s, v31.4s\n"
+ "dup v31.4s, w4\n"
+ "srshl v19.4s, v19.4s, v28.4s\n"
+ "srshl v20.4s, v20.4s, v28.4s\n"
+ "srshl v25.4s, v25.4s, v28.4s\n"
+ "srshl v26.4s, v26.4s, v28.4s\n"
+ "dup v28.8h, w0\n"
+ "add v19.4s, v19.4s, v29.4s\n"
+ "add v20.4s, v20.4s, v29.4s\n"
+ "add v25.4s, v25.4s, v29.4s\n"
+ "add v26.4s, v26.4s, v29.4s\n"
+ "smax v19.4s, v19.4s, v30.4s\n"
+ "smax v20.4s, v20.4s, v30.4s\n"
+ "smax v25.4s, v25.4s, v30.4s\n"
+ "smax v26.4s, v26.4s, v30.4s\n"
+ "smin v19.4s, v19.4s, v31.4s\n"
+ "smin v20.4s, v20.4s, v31.4s\n"
+ "smin v25.4s, v25.4s, v31.4s\n"
+ "smin v26.4s, v26.4s, v31.4s\n"
+ "sqxtn v19.4h, v19.4s\n"
+ "sqxtn v25.4h, v25.4s\n"
+ "sqxtn2 v19.8h, v20.4s\n"
+ "ld1 {v20.4s}, [x10]\n"
+ "sqxtn2 v25.8h, v26.4s\n"
+ "ld1 {v26.4s}, [x10]\n"
+ "sqxtun v19.8b, v19.8h\n"
+ "sqxtun v25.8b, v25.8h\n"
+ "uaddw v9.8h, v28.8h, v9.8b\n"
+ "st1 {v19.8b}, [x7], x5\n"
+ "uaddw v10.8h, v28.8h, v10.8b\n"
+ "st1 {v25.8b}, [x7], x5\n"
+ "uaddw v11.8h, v28.8h, v11.8b\n"
+ "ld1 {v19.4s}, [%[bias_ptr]]\n"
+ "uaddw v14.8h, v28.8h, v14.8b\n"
+ "ld1 {v25.4s}, [%[bias_ptr]]\n"
+ "uaddw v15.8h, v28.8h, v15.8b\n"
+ "uaddw v16.8h, v28.8h, v16.8b\n"
+
+ "bge " DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_LOOP "b\n"
+
+      // At this point, there will be either a 2-wide or a 1-wide leftover,
+      // not both.
+ "cmp w14, #2\n"
+ "blt " DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_1_LEFTOVER "f\n"
+
+      // Handle the last 2 columns if they exist.
+ DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_LEFTOVER ":\n"
+ "smlal v21.4s, v0.4h, v9.4h\n"
+ "ld1 {v12.8b}, [x12], %[input_depth]\n"
+ "smlal2 v22.4s, v0.8h, v9.8h\n"
+ "ld1 {v13.8b}, [x12]\n"
+ "add x12, x15, %[input_row_size]\n"
+ "smlal v23.4s, v0.4h, v11.4h\n"
+ "ld1 {v17.8b}, [x13], %[input_depth]\n"
+ "smlal2 v24.4s, v0.8h, v11.8h\n"
+ "ld1 {v18.8b}, [x13]\n"
+ "add x13, x12, %[input_row_size]\n"
+ "smlal v21.4s, v1.4h, v10.4h\n"
+ "ld1 {v9.8b}, [x15], %[input_depth]\n"
+ "smlal2 v22.4s, v1.8h, v10.8h\n"
+ "ld1 {v10.8b}, [x15], %[input_depth]\n"
+ "smlal v21.4s, v2.4h, v11.4h\n"
+ "smlal2 v22.4s, v2.8h, v11.8h\n"
+ "ld1 {v11.8b}, [x15], %[input_depth]\n"
+ "smlal v21.4s, v3.4h, v14.4h\n"
+ "smlal2 v22.4s, v3.8h, v14.8h\n"
+ "ld1 {v14.8b}, [x12], %[input_depth]\n"
+ "smlal v23.4s, v3.4h, v16.4h\n"
+ "smlal2 v24.4s, v3.8h, v16.8h\n"
+ "smlal v21.4s, v4.4h, v15.4h\n"
+ "uaddw v12.8h, v28.8h, v12.8b\n"
+ "smlal2 v22.4s, v4.8h, v15.8h\n"
+ "ld1 {v15.8b}, [x12], %[input_depth]\n"
+ "smlal v21.4s, v5.4h, v16.4h\n"
+ "uaddw v13.8h, v28.8h, v13.8b\n"
+ "smlal2 v22.4s, v5.8h, v16.8h\n"
+ "ld1 {v16.8b}, [x12], %[input_depth]\n"
+ "smlal v23.4s, v1.4h, v12.4h\n"
+ "uaddw v17.8h, v28.8h, v17.8b\n"
+ "smlal2 v24.4s, v1.8h, v12.8h\n"
+ "ld1 {v12.8b}, [x15], %[input_depth]\n"
+ "smlal v23.4s, v2.4h, v13.4h\n"
+ "uaddw v18.8h, v28.8h, v18.8b\n"
+ "smlal2 v24.4s, v2.8h, v13.8h\n"
+ "ld1 {v13.8b}, [x15]\n"
+ "smlal v23.4s, v4.4h, v17.4h\n"
+ "uaddw v9.8h, v28.8h, v9.8b\n"
+ "smlal2 v24.4s, v4.8h, v17.8h\n"
+ "ld1 {v17.8b}, [x12], %[input_depth]\n"
+ "smlal v23.4s, v5.4h, v18.4h\n"
+ "uaddw v10.8h, v28.8h, v10.8b\n"
+ "smlal2 v24.4s, v5.8h, v18.8h\n"
+ "ld1 {v18.8b}, [x12]\n"
+
+ "smlal v21.4s, v6.4h, v9.4h\n"
+ "smlal2 v22.4s, v6.8h, v9.8h\n"
+ "smlal v19.4s, v0.4h, v9.4h\n"
+ "uaddw v11.8h, v28.8h, v11.8b\n"
+ "smlal2 v20.4s, v0.8h, v9.8h\n"
+ "ld1 {v9.8b}, [x13], %[input_depth]\n"
+ "smlal v23.4s, v6.4h, v11.4h\n"
+ "smlal2 v24.4s, v6.8h, v11.8h\n"
+ "smlal v21.4s, v7.4h, v10.4h\n"
+ "smlal2 v22.4s, v7.8h, v10.8h\n"
+ "uaddw v12.8h, v28.8h, v12.8b\n"
+ "smlal v19.4s, v1.4h, v10.4h\n"
+ "smlal2 v20.4s, v1.8h, v10.8h\n"
+ "ld1 {v10.8b}, [x13], %[input_depth]\n"
+ "smlal v23.4s, v7.4h, v12.4h\n"
+ "smlal2 v24.4s, v7.8h, v12.8h\n"
+ "smlal v25.4s, v1.4h, v12.4h\n"
+ "smlal2 v26.4s, v1.8h, v12.8h\n"
+ "smlal v21.4s, v8.4h, v11.4h\n"
+ "smlal2 v22.4s, v8.8h, v11.8h\n"
+ "smlal v19.4s, v2.4h, v11.4h\n"
+ "smlal2 v20.4s, v2.8h, v11.8h\n"
+ "uaddw v13.8h, v28.8h, v13.8b\n"
+ "smlal v25.4s, v0.4h, v11.4h\n"
+ "smlal2 v26.4s, v0.8h, v11.8h\n"
+ "ld1 {v11.8b}, [x13], %[input_depth]\n"
+ "smlal v23.4s, v8.4h, v13.4h\n"
+ "ld1 {v12.8b}, [x13], %[input_depth]\n"
+ "smlal2 v24.4s, v8.8h, v13.8h\n"
+ "smlal v25.4s, v2.4h, v13.4h\n"
+ "smlal2 v26.4s, v2.8h, v13.8h\n"
+ "ld1 {v13.8b}, [x13]\n"
+
+ "dup v28.4s, w9\n"
+ "sqrdmulh v21.4s, v21.4s, v27.4s\n"
+ "sqrdmulh v22.4s, v22.4s, v27.4s\n"
+ "sqrdmulh v23.4s, v23.4s, v27.4s\n"
+ "sqrdmulh v24.4s, v24.4s, v27.4s\n"
+ "and v27.16b, v21.16b, v28.16b\n"
+ "and v29.16b, v22.16b, v28.16b\n"
+ "and v30.16b, v23.16b, v28.16b\n"
+ "and v31.16b, v24.16b, v28.16b\n"
+ "sshr v27.4s, v27.4s, #31\n"
+ "sshr v29.4s, v29.4s, #31\n"
+ "sshr v30.4s, v30.4s, #31\n"
+ "sshr v31.4s, v31.4s, #31\n"
+ "sqadd v21.4s, v21.4s, v27.4s\n"
+ "dup v27.4s, w1\n"
+ "sqadd v22.4s, v22.4s, v29.4s\n"
+ "dup v29.4s, w2\n"
+ "sqadd v23.4s, v23.4s, v30.4s\n"
+ "dup v30.4s, w3\n"
+ "sqadd v24.4s, v24.4s, v31.4s\n"
+ "dup v31.4s, w4\n"
+ "srshl v21.4s, v21.4s, v28.4s\n"
+ "srshl v22.4s, v22.4s, v28.4s\n"
+ "srshl v23.4s, v23.4s, v28.4s\n"
+ "srshl v24.4s, v24.4s, v28.4s\n"
+ "dup v28.8h, w0\n"
+ "add v21.4s, v21.4s, v29.4s\n"
+ "add v22.4s, v22.4s, v29.4s\n"
+ "add v23.4s, v23.4s, v29.4s\n"
+ "add v24.4s, v24.4s, v29.4s\n"
+ "smax v21.4s, v21.4s, v30.4s\n"
+ "smax v22.4s, v22.4s, v30.4s\n"
+ "smax v23.4s, v23.4s, v30.4s\n"
+ "smax v24.4s, v24.4s, v30.4s\n"
+ "smin v21.4s, v21.4s, v31.4s\n"
+ "smin v22.4s, v22.4s, v31.4s\n"
+ "smin v23.4s, v23.4s, v31.4s\n"
+ "smin v24.4s, v24.4s, v31.4s\n"
+ "sqxtn v21.4h, v21.4s\n"
+ "sqxtn v23.4h, v23.4s\n"
+ "sqxtn2 v21.8h, v22.4s\n"
+ "ld1 {v22.4s}, [x10]\n"
+ "sqxtn2 v23.8h, v24.4s\n"
+ "ld1 {v24.4s}, [x10]\n"
+ "sqxtun v21.8b, v21.8h\n"
+ "sqxtun v23.8b, v23.8h\n"
+ "uaddw v9.8h, v28.8h, v9.8b\n"
+ "st1 {v21.8b}, [x6], x5\n"
+ "uaddw v10.8h, v28.8h, v10.8b\n"
+ "st1 {v23.8b}, [x6]\n"
+ "uaddw v11.8h, v28.8h, v11.8b\n"
+
+ "smlal v19.4s, v6.4h, v9.4h\n"
+ "smlal2 v20.4s, v6.8h, v9.8h\n"
+ "smlal v25.4s, v6.4h, v11.4h\n"
+ "smlal2 v26.4s, v6.8h, v11.8h\n"
+ "smlal v19.4s, v7.4h, v10.4h\n"
+ "uaddw v12.8h, v28.8h, v12.8b\n"
+ "smlal2 v20.4s, v7.8h, v10.8h\n"
+ "smlal v25.4s, v7.4h, v12.4h\n"
+ "smlal2 v26.4s, v7.8h, v12.8h\n"
+ "smlal v19.4s, v8.4h, v11.4h\n"
+ "uaddw v13.8h, v28.8h, v13.8b\n"
+ "smlal2 v20.4s, v8.8h, v11.8h\n"
+ "smlal v25.4s, v8.4h, v13.4h\n"
+ "uaddw v14.8h, v28.8h, v14.8b\n"
+ "smlal2 v26.4s, v8.8h, v13.8h\n"
+ "uaddw v16.8h, v28.8h, v16.8b\n"
+ "smlal v19.4s, v3.4h, v14.4h\n"
+ "uaddw v15.8h, v28.8h, v15.8b\n"
+ "smlal2 v20.4s, v3.8h, v14.8h\n"
+ "smlal v25.4s, v3.4h, v16.4h\n"
+ "smlal2 v26.4s, v3.8h, v16.8h\n"
+ "smlal v19.4s, v4.4h, v15.4h\n"
+ "uaddw v17.8h, v28.8h, v17.8b\n"
+ "smlal2 v20.4s, v4.8h, v15.8h\n"
+ "smlal v25.4s, v4.4h, v17.4h\n"
+ "smlal2 v26.4s, v4.8h, v17.8h\n"
+ "smlal v19.4s, v5.4h, v16.4h\n"
+ "uaddw v18.8h, v28.8h, v18.8b\n"
+ "smlal2 v20.4s, v5.8h, v16.8h\n"
+ "smlal v25.4s, v5.4h, v18.4h\n"
+ "smlal2 v26.4s, v5.8h, v18.8h\n"
+
+ "dup v28.4s, w9\n"
+ "sqrdmulh v19.4s, v19.4s, v27.4s\n"
+ "sqrdmulh v20.4s, v20.4s, v27.4s\n"
+ "sqrdmulh v25.4s, v25.4s, v27.4s\n"
+ "sqrdmulh v26.4s, v26.4s, v27.4s\n"
+ "and v27.16b, v19.16b, v28.16b\n"
+ "and v29.16b, v20.16b, v28.16b\n"
+ "and v30.16b, v25.16b, v28.16b\n"
+ "and v31.16b, v26.16b, v28.16b\n"
+ "sshr v27.4s, v27.4s, #31\n"
+ "sshr v29.4s, v29.4s, #31\n"
+ "sshr v30.4s, v30.4s, #31\n"
+ "sshr v31.4s, v31.4s, #31\n"
+ "sqadd v19.4s, v19.4s, v27.4s\n"
+ "dup v27.4s, w1\n"
+ "sqadd v20.4s, v20.4s, v29.4s\n"
+ "dup v29.4s, w2\n"
+ "sqadd v25.4s, v25.4s, v30.4s\n"
+ "dup v30.4s, w3\n"
+ "sqadd v26.4s, v26.4s, v31.4s\n"
+ "dup v31.4s, w4\n"
+ "srshl v19.4s, v19.4s, v28.4s\n"
+ "srshl v20.4s, v20.4s, v28.4s\n"
+ "srshl v25.4s, v25.4s, v28.4s\n"
+ "srshl v26.4s, v26.4s, v28.4s\n"
+ "dup v28.8h, w0\n"
+ "add v19.4s, v19.4s, v29.4s\n"
+ "add v20.4s, v20.4s, v29.4s\n"
+ "add v25.4s, v25.4s, v29.4s\n"
+ "add v26.4s, v26.4s, v29.4s\n"
+ "smax v19.4s, v19.4s, v30.4s\n"
+ "smax v20.4s, v20.4s, v30.4s\n"
+ "smax v25.4s, v25.4s, v30.4s\n"
+ "smax v26.4s, v26.4s, v30.4s\n"
+ "smin v19.4s, v19.4s, v31.4s\n"
+ "smin v20.4s, v20.4s, v31.4s\n"
+ "smin v25.4s, v25.4s, v31.4s\n"
+ "smin v26.4s, v26.4s, v31.4s\n"
+ "sqxtn v19.4h, v19.4s\n"
+ "sqxtn v25.4h, v25.4s\n"
+ "sqxtn2 v19.8h, v20.4s\n"
+ "sqxtn2 v25.8h, v26.4s\n"
+ "sqxtun v19.8b, v19.8h\n"
+ "sqxtun v25.8b, v25.8h\n"
+ "st1 {v19.8b}, [x7], x5\n"
+ "st1 {v25.8b}, [x7]\n"
+ "b " DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_AFTER_LOOP "f\n"
+
+ // Handle the last column if it exists.
+ DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_1_LEFTOVER ":\n"
+ // Registers v9, v10, v11, v14, v15, and v16 have already been loaded
+ // with the correct values at this point. This corresponds to the
+ // first two input rows of the top left output. Now load the last
+ // input row for this output. Once these inputs are no longer needed,
+ // load the input rows for the bottom left output.
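+ // v21/v22 accumulate the top output of this column and v23/v24 the
+ // bottom output.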
+ "add x12, x15, %[input_row_size]\n"
+ "add x13, x12, %[input_row_size]\n"
+
+ "ld1 {v12.8b}, [x15], %[input_depth]\n"
+ "smlal v21.4s, v0.4h, v9.4h\n"
+ "ld1 {v13.8b}, [x15], %[input_depth]\n"
+ "smlal2 v22.4s, v0.8h, v9.8h\n"
+ "ld1 {v17.8b}, [x15]\n"
+ "smlal v21.4s, v1.4h, v10.4h\n"
+ "ld1 {v9.8b}, [x12], %[input_depth]\n"
+ "smlal2 v22.4s, v1.8h, v10.8h\n"
+ "ld1 {v10.8b}, [x12], %[input_depth]\n"
+ "smlal v21.4s, v2.4h, v11.4h\n"
+ "smlal2 v22.4s, v2.8h, v11.8h\n"
+ "ld1 {v11.8b}, [x12]\n"
+ "smlal v21.4s, v3.4h, v14.4h\n"
+ "smlal2 v22.4s, v3.8h, v14.8h\n"
+ "ld1 {v14.8b}, [x13], %[input_depth]\n"
+ "smlal v21.4s, v4.4h, v15.4h\n"
+ "smlal2 v22.4s, v4.8h, v15.8h\n"
+ "ld1 {v15.8b}, [x13], %[input_depth]\n"
+ "smlal v21.4s, v5.4h, v16.4h\n"
+ "uaddw v12.8h, v28.8h, v12.8b\n"
+ "smlal2 v22.4s, v5.8h, v16.8h\n"
+ "uaddw v13.8h, v28.8h, v13.8b\n"
+ "ld1 {v16.8b}, [x13]\n"
+
+ "smlal v21.4s, v6.4h, v12.4h\n"
+ "smlal2 v22.4s, v6.8h, v12.8h\n"
+ "smlal v23.4s, v0.4h, v12.4h\n"
+ "uaddw v17.8h, v28.8h, v17.8b\n"
+ "smlal2 v24.4s, v0.8h, v12.8h\n"
+ "smlal v21.4s, v7.4h, v13.4h\n"
+ "smlal2 v22.4s, v7.8h, v13.8h\n"
+ "smlal v23.4s, v1.4h, v13.4h\n"
+ "smlal2 v24.4s, v1.8h, v13.8h\n"
+ "smlal v21.4s, v8.4h, v17.4h\n"
+ "smlal2 v22.4s, v8.8h, v17.8h\n"
+ "smlal v23.4s, v2.4h, v17.4h\n"
+ "smlal2 v24.4s, v2.8h, v17.8h\n"
+
+ "dup v26.4s, w9\n"
+ "sqrdmulh v21.4s, v21.4s, v27.4s\n"
+ "sqrdmulh v22.4s, v22.4s, v27.4s\n"
+ "and v18.16b, v21.16b, v26.16b\n"
+ "and v19.16b, v22.16b, v26.16b\n"
+ "sshr v18.4s, v18.4s, #31\n"
+ "sshr v19.4s, v19.4s, #31\n"
+ "sqadd v21.4s, v21.4s, v18.4s\n"
+ "sqadd v22.4s, v22.4s, v19.4s\n"
+ "srshl v21.4s, v21.4s, v26.4s\n"
+ "srshl v22.4s, v22.4s, v26.4s\n"
+ "add v21.4s, v21.4s, v29.4s\n"
+ "add v22.4s, v22.4s, v29.4s\n"
+ "smax v21.4s, v21.4s, v30.4s\n"
+ "smax v22.4s, v22.4s, v30.4s\n"
+ "smin v21.4s, v21.4s, v31.4s\n"
+ "smin v22.4s, v22.4s, v31.4s\n"
+ "sqxtn v21.4h, v21.4s\n"
+ "sqxtn2 v21.8h, v22.4s\n"
+ "sqxtun v21.8b, v21.8h\n"
+ "uaddw v9.8h, v28.8h, v9.8b\n"
+ "st1 {v21.8b}, [x6]\n"
+ "uaddw v10.8h, v28.8h, v10.8b\n"
+
+ "smlal v23.4s, v3.4h, v9.4h\n"
+ "uaddw v11.8h, v28.8h, v11.8b\n"
+ "smlal2 v24.4s, v3.8h, v9.8h\n"
+ "uaddw v14.8h, v28.8h, v14.8b\n"
+ "smlal v23.4s, v4.4h, v10.4h\n"
+ "uaddw v15.8h, v28.8h, v15.8b\n"
+ "smlal2 v24.4s, v4.8h, v10.8h\n"
+ "uaddw v16.8h, v28.8h, v16.8b\n"
+ "smlal v23.4s, v5.4h, v11.4h\n"
+ "smlal2 v24.4s, v5.8h, v11.8h\n"
+
+ "smlal v23.4s, v6.4h, v14.4h\n"
+ "smlal2 v24.4s, v6.8h, v14.8h\n"
+ "smlal v23.4s, v7.4h, v15.4h\n"
+ "smlal2 v24.4s, v7.8h, v15.8h\n"
+ "smlal v23.4s, v8.4h, v16.4h\n"
+ "smlal2 v24.4s, v8.8h, v16.8h\n"
+
+ "sqrdmulh v23.4s, v23.4s, v27.4s\n"
+ "sqrdmulh v24.4s, v24.4s, v27.4s\n"
+ "and v18.16b, v23.16b, v26.16b\n"
+ "and v19.16b, v24.16b, v26.16b\n"
+ "sshr v18.4s, v18.4s, #31\n"
+ "sshr v19.4s, v19.4s, #31\n"
+ "sqadd v23.4s, v23.4s, v18.4s\n"
+ "sqadd v24.4s, v24.4s, v19.4s\n"
+ "srshl v23.4s, v23.4s, v26.4s\n"
+ "srshl v24.4s, v24.4s, v26.4s\n"
+ "add v23.4s, v23.4s, v29.4s\n"
+ "add v24.4s, v24.4s, v29.4s\n"
+ "smax v23.4s, v23.4s, v30.4s\n"
+ "smax v24.4s, v24.4s, v30.4s\n"
+ "smin v23.4s, v23.4s, v31.4s\n"
+ "smin v24.4s, v24.4s, v31.4s\n"
+ "sqxtn v23.4h, v23.4s\n"
+ "sqxtn2 v23.8h, v24.4s\n"
+ "sqxtun v23.8b, v23.8h\n"
+ "st1 {v23.8b}, [x7]\n"
+
+ DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_AFTER_LOOP ":\n"
+ "subs %w[output_window_height], %w[output_window_height], #2\n"
+ "add %[input_ptr], %[input_ptr], %[input_height_increment]\n"
+ "cmp %w[output_window_height], #2\n"
+ "add %[output_ptr], %[output_ptr], %[output_height_increment]\n"
+ "bge " DEPTHWISECONV_LABEL_HEIGHT_2_LOOP "b\n"
+
+ DEPTHWISECONV_LABEL_HEIGHT_2_AFTER_LOOP ":\n"
+ "cmp %w[output_window_height], #1\n"
+ "blt " DEPTHWISECONV_LABEL_HEIGHT_1_END "f\n"
+
+ DEPTHWISECONV_LABEL_HEIGHT_1 ":\n"
+ "mov x11, %[input_ptr]\n"
+ "mov x12, x11\n"
+ "add x13, x12, %[input_row_size]\n"
+ "ld1 {v9.8b}, [x12], %[input_depth]\n"
+ "add x15, x13, %[input_row_size]\n"
+ "ld1 {v10.8b}, [x12], %[input_depth]\n"
+ "mov x6, %[output_ptr]\n"
+ "ld1 {v11.8b}, [x12], %[input_depth]\n"
+ "mov w14, %w[output_window_width]\n"
+ // The height 1 / width 2 loop loads an extra 1x1 output in anticipation
+ // of the next iteration. Make sure |output_window_width| is large
+ // enough to handle the additional load; otherwise, jump to the
+ // appropriate label to handle smaller widths.
+ "cmp w14, #2\n"
+ "ld1 {v12.8b}, [x13], %[input_depth]\n"
+ "ld1 {v13.8b}, [x13], %[input_depth]\n"
+ "ld1 {v14.8b}, [x13], %[input_depth]\n"
+ "ld1 {v15.8b}, [x15], %[input_depth]\n"
+ "ld1 {v16.8b}, [x15], %[input_depth]\n"
+ "ld1 {v17.8b}, [x15], %[input_depth]\n"
+
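+ // v24/v25 and v26/v27 accumulate the left and right outputs of this
+ // row and are seeded with the per-channel bias values loaded below.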
+ "uaddw v9.8h, v28.8h, v9.8b\n"
+ "ld1 {v24.4s}, [%[bias_ptr]]\n"
+ "uaddw v10.8h, v28.8h, v10.8b\n"
+ "ld1 {v25.4s}, [x10]\n"
+ "uaddw v11.8h, v28.8h, v11.8b\n"
+ "ld1 {v26.4s}, [%[bias_ptr]]\n"
+ "ld1 {v27.4s}, [x10]\n"
+ "uaddw v12.8h, v28.8h, v12.8b\n"
+ "uaddw v13.8h, v28.8h, v13.8b\n"
+ "uaddw v14.8h, v28.8h, v14.8b\n"
+ "uaddw v15.8h, v28.8h, v15.8b\n"
+ "uaddw v16.8h, v28.8h, v16.8b\n"
+ "uaddw v17.8h, v28.8h, v17.8b\n"
+
+ "beq " DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_2_LEFTOVER "f\n"
+ "cmp w14, #1\n"
+ "beq " DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_1_LEFTOVER "f\n"
+
+ //"loop_%=:\n"
+ DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_2_LOOP ":\n"
+ "smlal v24.4s, v0.4h, v9.4h\n"
+ "ld1 {v18.8b}, [x12], %[input_depth]\n"
+ "smlal2 v25.4s, v0.8h, v9.8h\n"
+ "ld1 {v19.8b}, [x12]\n"
+ "smlal v26.4s, v0.4h, v11.4h\n"
+ "ld1 {v20.8b}, [x13], %[input_depth]\n"
+ "smlal2 v27.4s, v0.8h, v11.8h\n"
+ "ld1 {v21.8b}, [x13]\n"
+ "smlal v24.4s, v1.4h, v10.4h\n"
+ "ld1 {v22.8b}, [x15], %[input_depth]\n"
+ "smlal2 v25.4s, v1.8h, v10.8h\n"
+ "ld1 {v23.8b}, [x15]\n"
+ "smlal v24.4s, v2.4h, v11.4h\n"
+ "subs w14, w14, #2\n"
+ "smlal2 v25.4s, v2.8h, v11.8h\n"
+ "cmp w14, #3\n"
+ "smlal v24.4s, v3.4h, v12.4h\n"
+ "add x11, x11, %[input_width_increment]\n"
+ "smlal2 v25.4s, v3.8h, v12.8h\n"
+ "mov x12, x11\n"
+ "smlal v26.4s, v3.4h, v14.4h\n"
+ "add x13, x12, %[input_row_size]\n"
+ "smlal2 v27.4s, v3.8h, v14.8h\n"
+ "add x15, x13, %[input_row_size]\n"
+ "smlal v24.4s, v4.4h, v13.4h\n"
+ "ld1 {v9.8b}, [x12], %[input_depth]\n"
+ "smlal2 v25.4s, v4.8h, v13.8h\n"
+ "ld1 {v10.8b}, [x12], %[input_depth]\n"
+ "smlal v24.4s, v5.4h, v14.4h\n"
+ "ld1 {v11.8b}, [x12], %[input_depth]\n"
+ "smlal2 v25.4s, v5.8h, v14.8h\n"
+ "ld1 {v12.8b}, [x13], %[input_depth]\n"
+ "smlal v24.4s, v6.4h, v15.4h\n"
+ "ld1 {v13.8b}, [x13], %[input_depth]\n"
+ "smlal2 v25.4s, v6.8h, v15.8h\n"
+ "ld1 {v14.8b}, [x13], %[input_depth]\n"
+ "smlal v26.4s, v6.4h, v17.4h\n"
+ "ld1 {v15.8b}, [x15], %[input_depth]\n"
+ "smlal2 v27.4s, v6.8h, v17.8h\n"
+ "smlal v24.4s, v7.4h, v16.4h\n"
+ "smlal2 v25.4s, v7.8h, v16.8h\n"
+ "ld1 {v16.8b}, [x15], %[input_depth]\n"
+ "smlal v24.4s, v8.4h, v17.4h\n"
+ "uaddw v18.8h, v28.8h, v18.8b\n"
+ "smlal2 v25.4s, v8.8h, v17.8h\n"
+ "ld1 {v17.8b}, [x15], %[input_depth]\n"
+ "uaddw v19.8h, v28.8h, v19.8b\n"
+
+ "smlal v26.4s, v1.4h, v18.4h\n"
+ "uaddw v20.8h, v28.8h, v20.8b\n"
+ "smlal2 v27.4s, v1.8h, v18.8h\n"
+ "smlal v26.4s, v2.4h, v19.4h\n"
+ "uaddw v21.8h, v28.8h, v21.8b\n"
+ "smlal2 v27.4s, v2.8h, v19.8h\n"
+ "smlal v26.4s, v4.4h, v20.4h\n"
+ "smlal v26.4s, v5.4h, v21.4h\n"
+ "smlal2 v27.4s, v4.8h, v20.8h\n"
+ "uaddw v22.8h, v28.8h, v22.8b\n"
+ "smlal2 v27.4s, v5.8h, v21.8h\n"
+ "uaddw v23.8h, v28.8h, v23.8b\n"
+ "smlal v26.4s, v7.4h, v22.4h\n"
+ "smlal2 v27.4s, v7.8h, v22.8h\n"
+ "smlal v26.4s, v8.4h, v23.4h\n"
+ "smlal2 v27.4s, v8.8h, v23.8h\n"
+
+ "dup v28.4s, w1\n"
+ "dup v29.4s, w9\n"
+ "sqrdmulh v24.4s, v24.4s, v28.4s\n"
+ "sqrdmulh v25.4s, v25.4s, v28.4s\n"
+ "sqrdmulh v26.4s, v26.4s, v28.4s\n"
+ "sqrdmulh v27.4s, v27.4s, v28.4s\n"
+ "dup v28.4s, w2\n"
+ "and v30.16b, v24.16b, v29.16b\n"
+ "and v31.16b, v25.16b, v29.16b\n"
+ "sshr v30.4s, v30.4s, #31\n"
+ "sshr v31.4s, v31.4s, #31\n"
+ "sqadd v24.4s, v24.4s, v30.4s\n"
+ "sqadd v25.4s, v25.4s, v31.4s\n"
+ "and v30.16b, v26.16b, v29.16b\n"
+ "and v31.16b, v27.16b, v29.16b\n"
+ "sshr v30.4s, v30.4s, #31\n"
+ "sshr v31.4s, v31.4s, #31\n"
+ "sqadd v26.4s, v26.4s, v30.4s\n"
+ "dup v30.4s, w3\n"
+ "sqadd v27.4s, v27.4s, v31.4s\n"
+ "dup v31.4s, w4\n"
+ "srshl v24.4s, v24.4s, v29.4s\n"
+ "srshl v25.4s, v25.4s, v29.4s\n"
+ "srshl v26.4s, v26.4s, v29.4s\n"
+ "srshl v27.4s, v27.4s, v29.4s\n"
+ "add v24.4s, v24.4s, v28.4s\n"
+ "add v25.4s, v25.4s, v28.4s\n"
+ "add v26.4s, v26.4s, v28.4s\n"
+ "add v27.4s, v27.4s, v28.4s\n"
+ "dup v28.8h, w0\n"
+ "smax v24.4s, v24.4s, v30.4s\n"
+ "smax v25.4s, v25.4s, v30.4s\n"
+ "smax v26.4s, v26.4s, v30.4s\n"
+ "smax v27.4s, v27.4s, v30.4s\n"
+ "smin v24.4s, v24.4s, v31.4s\n"
+ "smin v25.4s, v25.4s, v31.4s\n"
+ "smin v26.4s, v26.4s, v31.4s\n"
+ "smin v27.4s, v27.4s, v31.4s\n"
+ "sqxtn v24.4h, v24.4s\n"
+ "sqxtn v26.4h, v26.4s\n"
+ "sqxtn2 v24.8h, v25.4s\n"
+ "ld1 {v25.4s}, [x10]\n"
+ "sqxtn2 v26.8h, v27.4s\n"
+ "ld1 {v27.4s}, [x10]\n"
+ "sqxtun v24.8b, v24.8h\n"
+ "sqxtun v26.8b, v26.8h\n"
+ "uaddw v9.8h, v28.8h, v9.8b\n"
+ "st1 {v24.8b}, [x6], x5\n"
+ "uaddw v10.8h, v28.8h, v10.8b\n"
+ "st1 {v26.8b}, [x6], x5\n"
+ "uaddw v11.8h, v28.8h, v11.8b\n"
+ "uaddw v12.8h, v28.8h, v12.8b\n"
+ "uaddw v13.8h, v28.8h, v13.8b\n"
+ "uaddw v14.8h, v28.8h, v14.8b\n"
+ "ld1 {v24.4s}, [%[bias_ptr]]\n"
+ "uaddw v15.8h, v28.8h, v15.8b\n"
+ "ld1 {v26.4s}, [%[bias_ptr]]\n"
+ "uaddw v16.8h, v28.8h, v16.8b\n"
+ "uaddw v17.8h, v28.8h, v17.8b\n"
+
+ "bge " DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_2_LOOP "b\n"
+
+ // At this point there is a leftover of either 2 columns or 1 column,
+ // never both.
+ "cmp w14, #2\n"
+ "blt " DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_1_LEFTOVER "f\n"
+
+ // Handle the last two horizontal outputs if they exist.
+ DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_2_LEFTOVER ":\n"
+ "smlal v24.4s, v0.4h, v9.4h\n"
+ "ld1 {v18.8b}, [x12], %[input_depth]\n"
+ "smlal2 v25.4s, v0.8h, v9.8h\n"
+ "ld1 {v19.8b}, [x12]\n"
+ "smlal v26.4s, v0.4h, v11.4h\n"
+ "ld1 {v20.8b}, [x13], %[input_depth]\n"
+ "smlal2 v27.4s, v0.8h, v11.8h\n"
+ "ld1 {v21.8b}, [x13]\n"
+ "smlal v24.4s, v1.4h, v10.4h\n"
+ "ld1 {v22.8b}, [x15], %[input_depth]\n"
+ "smlal2 v25.4s, v1.8h, v10.8h\n"
+ "ld1 {v23.8b}, [x15]\n"
+ "smlal v24.4s, v2.4h, v11.4h\n"
+ "smlal2 v25.4s, v2.8h, v11.8h\n"
+ "smlal v24.4s, v3.4h, v12.4h\n"
+ "smlal2 v25.4s, v3.8h, v12.8h\n"
+ "smlal v26.4s, v3.4h, v14.4h\n"
+ "smlal2 v27.4s, v3.8h, v14.8h\n"
+ "smlal v24.4s, v4.4h, v13.4h\n"
+ "smlal2 v25.4s, v4.8h, v13.8h\n"
+ "smlal v24.4s, v5.4h, v14.4h\n"
+ "smlal2 v25.4s, v5.8h, v14.8h\n"
+ "smlal v24.4s, v6.4h, v15.4h\n"
+ "smlal2 v25.4s, v6.8h, v15.8h\n"
+ "smlal v26.4s, v6.4h, v17.4h\n"
+ "smlal2 v27.4s, v6.8h, v17.8h\n"
+ "smlal v24.4s, v7.4h, v16.4h\n"
+ "smlal2 v25.4s, v7.8h, v16.8h\n"
+ "smlal v24.4s, v8.4h, v17.4h\n"
+ "uaddw v18.8h, v28.8h, v18.8b\n"
+ "smlal2 v25.4s, v8.8h, v17.8h\n"
+ "uaddw v19.8h, v28.8h, v19.8b\n"
+
+ "smlal v26.4s, v1.4h, v18.4h\n"
+ "uaddw v20.8h, v28.8h, v20.8b\n"
+ "smlal2 v27.4s, v1.8h, v18.8h\n"
+ "smlal v26.4s, v2.4h, v19.4h\n"
+ "uaddw v21.8h, v28.8h, v21.8b\n"
+ "smlal2 v27.4s, v2.8h, v19.8h\n"
+ "smlal v26.4s, v4.4h, v20.4h\n"
+ "smlal v26.4s, v5.4h, v21.4h\n"
+ "smlal2 v27.4s, v4.8h, v20.8h\n"
+ "uaddw v22.8h, v28.8h, v22.8b\n"
+ "smlal2 v27.4s, v5.8h, v21.8h\n"
+ "uaddw v23.8h, v28.8h, v23.8b\n"
+ "smlal v26.4s, v7.4h, v22.4h\n"
+ "smlal2 v27.4s, v7.8h, v22.8h\n"
+ "smlal v26.4s, v8.4h, v23.4h\n"
+ "smlal2 v27.4s, v8.8h, v23.8h\n"
+
+ "dup v28.4s, w1\n"
+ "dup v29.4s, w9\n"
+ "sqrdmulh v24.4s, v24.4s, v28.4s\n"
+ "sqrdmulh v25.4s, v25.4s, v28.4s\n"
+ "sqrdmulh v26.4s, v26.4s, v28.4s\n"
+ "sqrdmulh v27.4s, v27.4s, v28.4s\n"
+ "dup v28.4s, w2\n"
+ "and v30.16b, v24.16b, v29.16b\n"
+ "and v31.16b, v25.16b, v29.16b\n"
+ "sshr v30.4s, v30.4s, #31\n"
+ "sshr v31.4s, v31.4s, #31\n"
+ "sqadd v24.4s, v24.4s, v30.4s\n"
+ "sqadd v25.4s, v25.4s, v31.4s\n"
+ "and v30.16b, v26.16b, v29.16b\n"
+ "and v31.16b, v27.16b, v29.16b\n"
+ "sshr v30.4s, v30.4s, #31\n"
+ "sshr v31.4s, v31.4s, #31\n"
+ "sqadd v26.4s, v26.4s, v30.4s\n"
+ "dup v30.4s, w3\n"
+ "sqadd v27.4s, v27.4s, v31.4s\n"
+ "dup v31.4s, w4\n"
+ "srshl v24.4s, v24.4s, v29.4s\n"
+ "srshl v25.4s, v25.4s, v29.4s\n"
+ "srshl v26.4s, v26.4s, v29.4s\n"
+ "srshl v27.4s, v27.4s, v29.4s\n"
+ "add v24.4s, v24.4s, v28.4s\n"
+ "add v25.4s, v25.4s, v28.4s\n"
+ "add v26.4s, v26.4s, v28.4s\n"
+ "add v27.4s, v27.4s, v28.4s\n"
+ "dup v28.8h, w0\n"
+ "smax v24.4s, v24.4s, v30.4s\n"
+ "smax v25.4s, v25.4s, v30.4s\n"
+ "smax v26.4s, v26.4s, v30.4s\n"
+ "smax v27.4s, v27.4s, v30.4s\n"
+ "smin v24.4s, v24.4s, v31.4s\n"
+ "smin v25.4s, v25.4s, v31.4s\n"
+ "smin v26.4s, v26.4s, v31.4s\n"
+ "smin v27.4s, v27.4s, v31.4s\n"
+ "sqxtn v24.4h, v24.4s\n"
+ "sqxtn v26.4h, v26.4s\n"
+ "sqxtn2 v24.8h, v25.4s\n"
+ "sqxtn2 v26.8h, v27.4s\n"
+ "sqxtun v24.8b, v24.8h\n"
+ "sqxtun v26.8b, v26.8h\n"
+ "st1 {v24.8b}, [x6], x5\n"
+ "st1 {v26.8b}, [x6]\n"
+ "b " DEPTHWISECONV_LABEL_HEIGHT_1_END "f\n"
+
+ // Handle the bottom-right output if it exists.
+ DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_1_LEFTOVER ":\n"
+ "dup v26.4s, w9\n"
+ "dup v27.4s, w1\n"
+ "dup v29.4s, w2\n"
+
+ "smlal v24.4s, v0.4h, v9.4h\n"
+ "smlal2 v25.4s, v0.8h, v9.8h\n"
+ "smlal v24.4s, v1.4h, v10.4h\n"
+ "smlal2 v25.4s, v1.8h, v10.8h\n"
+ "smlal v24.4s, v2.4h, v11.4h\n"
+ "smlal2 v25.4s, v2.8h, v11.8h\n"
+ "smlal v24.4s, v3.4h, v12.4h\n"
+ "smlal2 v25.4s, v3.8h, v12.8h\n"
+ "smlal v24.4s, v4.4h, v13.4h\n"
+ "smlal2 v25.4s, v4.8h, v13.8h\n"
+ "smlal v24.4s, v5.4h, v14.4h\n"
+ "smlal2 v25.4s, v5.8h, v14.8h\n"
+ "smlal v24.4s, v6.4h, v15.4h\n"
+ "smlal2 v25.4s, v6.8h, v15.8h\n"
+ "smlal v24.4s, v7.4h, v16.4h\n"
+ "smlal2 v25.4s, v7.8h, v16.8h\n"
+ "smlal v24.4s, v8.4h, v17.4h\n"
+ "smlal2 v25.4s, v8.8h, v17.8h\n"
+
+ "sqrdmulh v24.4s, v24.4s, v27.4s\n"
+ "sqrdmulh v25.4s, v25.4s, v27.4s\n"
+ "and v18.16b, v24.16b, v26.16b\n"
+ "and v19.16b, v25.16b, v26.16b\n"
+ "sshr v18.4s, v18.4s, #31\n"
+ "sshr v19.4s, v19.4s, #31\n"
+ "sqadd v24.4s, v24.4s, v18.4s\n"
+ "sqadd v25.4s, v25.4s, v19.4s\n"
+ "srshl v24.4s, v24.4s, v26.4s\n"
+ "srshl v25.4s, v25.4s, v26.4s\n"
+ "add v24.4s, v24.4s, v29.4s\n"
+ "add v25.4s, v25.4s, v29.4s\n"
+ "smax v24.4s, v24.4s, v30.4s\n"
+ "smax v25.4s, v25.4s, v30.4s\n"
+ "smin v24.4s, v24.4s, v31.4s\n"
+ "smin v25.4s, v25.4s, v31.4s\n"
+ "sqxtn v24.4h, v24.4s\n"
+ "sqxtn2 v24.8h, v25.4s\n"
+ "sqxtun v24.8b, v24.8h\n"
+ "st1 {v24.8b}, [x6]\n"
+
+ DEPTHWISECONV_LABEL_HEIGHT_1_END ":\n"
+ :
+ // Outputs.
+ [filter_ptr] "+r"(filter_ptr), [input_ptr] "+r"(input_ptr),
+ [output_ptr] "+r"(output_ptr),
+ [output_window_height] "+r"(output_window_height)
+ :
+ // Inputs.
+ [bias_ptr] "r"(bias_ptr), [input_row_size] "r"(input_row_size),
+ [input_depth] "r"(input_depth),
+ [output_window_width] "r"(output_window_width),
+ [input_width_increment] "r"(input_width_increment),
+ [input_height_increment] "r"(input_height_increment),
+ [output_height_increment] "r"(output_height_increment),
+ [params_ptr] "r"(params_ptr)
+ :
+ // Clobbers.
+ "cc", "memory",
+ // We use these NEON registers.
+ "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9",
+ "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
+ "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29",
+ "v30", "v31",
+ // We use these general-purpose registers.
+ "x0", "x1", "x2", "x3", "x4", "x5", "x6", "x7",
+ "x9", "x10", "x11", "x12", "x13", "x14", "x15",
+ "x19", "x20");
+#undef DEPTHWISECONV_LABEL_HEIGHT_2_LOOP
+#undef DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_LOOP
+#undef DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_1_LEFTOVER
+#undef DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_LEFTOVER
+#undef DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_AFTER_LOOP
+#undef DEPTHWISECONV_LABEL_HEIGHT_2_AFTER_LOOP
+#undef DEPTHWISECONV_LABEL_HEIGHT_1
+#undef DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_2_LOOP
+#undef DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_1_LEFTOVER
+#undef DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_2_LEFTOVER
+#undef DEPTHWISECONV_LABEL_HEIGHT_1_END
}
};
-template <>
-struct ConvKernel3x3FilterDepth8<4, 1, 2, 2> {
- static inline void Run(const uint8* input_ptr, int input_depth,
- int32 input_offset, int input_row_size,
- const uint8* filter_ptr, int32 filter_offset,
- const int32* bias_ptr, int32 output_offset,
- int32 output_multiplier, int output_shift,
- int32 output_activation_min,
- int32 output_activation_max, uint8* output_ptr,
- int output_depth, int output_width) {
- const int output_row_size = output_depth * output_width;
-
- Filter3x3x8 filter = Load3x3Filter(filter_ptr, filter_offset, output_depth);
-
- const int16x8_t input_offset_vec = vdupq_n_s16(input_offset);
- int16x8_t input_0, input_1, input_2, input_3, input_4, input_5, input_6,
- input_7, input_8;
- uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5, temp_6, temp_7,
- temp_8;
-
- const uint8* ptr = input_ptr;
-
- // Load all inputs for top output.
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- ptr += input_row_size;
- temp_3 = vld1_u8(ptr);
- temp_4 = vld1_u8(ptr + input_depth);
- temp_5 = vld1_u8(ptr + 2 * input_depth);
- ptr += input_row_size;
- temp_6 = vld1_u8(ptr);
- temp_7 = vld1_u8(ptr + input_depth);
- temp_8 = vld1_u8(ptr + 2 * input_depth);
-
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
- input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_5));
- input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_6));
- input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_7));
- input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_8));
-
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- input_2 = vaddq_s16(input_2, input_offset_vec);
- input_3 = vaddq_s16(input_3, input_offset_vec);
- input_4 = vaddq_s16(input_4, input_offset_vec);
- input_5 = vaddq_s16(input_5, input_offset_vec);
- input_6 = vaddq_s16(input_6, input_offset_vec);
- input_7 = vaddq_s16(input_7, input_offset_vec);
- input_8 = vaddq_s16(input_8, input_offset_vec);
-
- DotProductAndStore(
- filter, input_0, input_1, input_2, input_3, input_4, input_5, input_6,
- input_7, input_8, bias_ptr, output_offset, output_multiplier,
- output_shift, output_activation_min, output_activation_max, output_ptr);
-
- // Second output.
- output_ptr += output_row_size;
-
- ptr += input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- ptr += input_row_size;
- temp_3 = vld1_u8(ptr);
- temp_4 = vld1_u8(ptr + input_depth);
- temp_5 = vld1_u8(ptr + 2 * input_depth);
-
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
- input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_5));
-
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- input_2 = vaddq_s16(input_2, input_offset_vec);
- input_3 = vaddq_s16(input_3, input_offset_vec);
- input_4 = vaddq_s16(input_4, input_offset_vec);
- input_5 = vaddq_s16(input_5, input_offset_vec);
-
- DotProductAndStore(
- filter, input_6, input_7, input_8, input_0, input_1, input_2, input_3,
- input_4, input_5, bias_ptr, output_offset, output_multiplier,
- output_shift, output_activation_min, output_activation_max, output_ptr);
-
- // Third output.
- output_ptr += output_row_size;
-
- ptr += input_row_size;
- temp_6 = vld1_u8(ptr);
- temp_7 = vld1_u8(ptr + input_depth);
- temp_8 = vld1_u8(ptr + 2 * input_depth);
- ptr += input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
-
- input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_6));
- input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_7));
- input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_8));
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
-
- input_6 = vaddq_s16(input_6, input_offset_vec);
- input_7 = vaddq_s16(input_7, input_offset_vec);
- input_8 = vaddq_s16(input_8, input_offset_vec);
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- input_2 = vaddq_s16(input_2, input_offset_vec);
-
- DotProductAndStore(
- filter, input_3, input_4, input_5, input_6, input_7, input_8, input_0,
- input_1, input_2, bias_ptr, output_offset, output_multiplier,
- output_shift, output_activation_min, output_activation_max, output_ptr);
-
- // Fourth output.
- output_ptr += output_row_size;
-
- ptr += input_row_size;
- temp_3 = vld1_u8(ptr);
- temp_4 = vld1_u8(ptr + input_depth);
- temp_5 = vld1_u8(ptr + 2 * input_depth);
- ptr += input_row_size;
- temp_6 = vld1_u8(ptr);
- temp_7 = vld1_u8(ptr + input_depth);
- temp_8 = vld1_u8(ptr + 2 * input_depth);
-
- input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
- input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_5));
- input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_6));
- input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_7));
- input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_8));
-
- input_3 = vaddq_s16(input_3, input_offset_vec);
- input_4 = vaddq_s16(input_4, input_offset_vec);
- input_5 = vaddq_s16(input_5, input_offset_vec);
- input_6 = vaddq_s16(input_6, input_offset_vec);
- input_7 = vaddq_s16(input_7, input_offset_vec);
- input_8 = vaddq_s16(input_8, input_offset_vec);
-
- DotProductAndStore(
- filter, input_0, input_1, input_2, input_3, input_4, input_5, input_6,
- input_7, input_8, bias_ptr, output_offset, output_multiplier,
- output_shift, output_activation_min, output_activation_max, output_ptr);
- }
-};
-
-template <>
-struct ConvKernel3x3FilterDepth8<2, 2, 2, 2> {
- static inline void Run(const uint8* input_ptr, int input_depth,
- int32 input_offset, int input_row_size,
- const uint8* filter_ptr, int32 filter_offset,
- const int32* bias_ptr, int32 output_offset,
- int32 output_multiplier, int output_shift,
- int32 output_activation_min,
- int32 output_activation_max, uint8* output_ptr,
- int output_depth, int output_width) {
- Filter3x3x8 filter = Load3x3Filter(filter_ptr, filter_offset, output_depth);
-
- Int32x8 acc_0, acc_1, acc_2, acc_3;
- acc_0.low = vld1q_s32(bias_ptr);
- acc_1.low = vld1q_s32(bias_ptr);
- acc_2.low = vld1q_s32(bias_ptr);
- acc_3.low = vld1q_s32(bias_ptr);
-
- bias_ptr += 4;
- acc_0.high = vld1q_s32(bias_ptr);
- acc_1.high = vld1q_s32(bias_ptr);
- acc_2.high = vld1q_s32(bias_ptr);
- acc_3.high = vld1q_s32(bias_ptr);
-
- const int16x8_t input_offset_vec = vdupq_n_s16(input_offset);
-
- // Add scope for input registers to help the compiler know that it is
- // not needed.
- {
- // To process 2x2 outputs using a 3x3 filter at stride 2, we require
- // 5x5 inputs. We load the first 5x2 inputs at a time.
- int16x8_t input_0, input_1, input_2, input_3, input_4, input_5, input_6,
- input_7, input_8, input_9;
-
- const uint8* ptr = input_ptr;
-
- // Load inputs.
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4;
-
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- temp_3 = vld1_u8(ptr + 3 * input_depth);
- temp_4 = vld1_u8(ptr + 4 * input_depth);
-
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
-
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- input_2 = vaddq_s16(input_2, input_offset_vec);
- input_3 = vaddq_s16(input_3, input_offset_vec);
- input_4 = vaddq_s16(input_4, input_offset_vec);
-
- ptr += input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- temp_3 = vld1_u8(ptr + 3 * input_depth);
- temp_4 = vld1_u8(ptr + 4 * input_depth);
-
- input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
-
- input_5 = vaddq_s16(input_5, input_offset_vec);
- input_6 = vaddq_s16(input_6, input_offset_vec);
- input_7 = vaddq_s16(input_7, input_offset_vec);
- input_8 = vaddq_s16(input_8, input_offset_vec);
- input_9 = vaddq_s16(input_9, input_offset_vec);
- }
-
- acc_0 = MultiplyAccumulateRow(acc_0, filter.f0, filter.f1, filter.f2,
- input_0, input_1, input_2);
-
- acc_1 = MultiplyAccumulateRow(acc_1, filter.f0, filter.f1, filter.f2,
- input_2, input_3, input_4);
-
- acc_0 = MultiplyAccumulateRow(acc_0, filter.f3, filter.f4, filter.f5,
- input_5, input_6, input_7);
-
- acc_1 = MultiplyAccumulateRow(acc_1, filter.f3, filter.f4, filter.f5,
- input_7, input_8, input_9);
-
- // Load next inputs.
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4;
-
- ptr += input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- temp_3 = vld1_u8(ptr + 3 * input_depth);
- temp_4 = vld1_u8(ptr + 4 * input_depth);
-
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
-
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- input_2 = vaddq_s16(input_2, input_offset_vec);
- input_3 = vaddq_s16(input_3, input_offset_vec);
- input_4 = vaddq_s16(input_4, input_offset_vec);
-
- ptr += input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- temp_3 = vld1_u8(ptr + 3 * input_depth);
- temp_4 = vld1_u8(ptr + 4 * input_depth);
-
- input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
-
- input_5 = vaddq_s16(input_5, input_offset_vec);
- input_6 = vaddq_s16(input_6, input_offset_vec);
- input_7 = vaddq_s16(input_7, input_offset_vec);
- input_8 = vaddq_s16(input_8, input_offset_vec);
- input_9 = vaddq_s16(input_9, input_offset_vec);
- }
-
- acc_0 = MultiplyAccumulateRow(acc_0, filter.f6, filter.f7, filter.f8,
- input_0, input_1, input_2);
-
- acc_1 = MultiplyAccumulateRow(acc_1, filter.f6, filter.f7, filter.f8,
- input_2, input_3, input_4);
-
- // Moving onto the two bottom outputs.
- acc_2 = MultiplyAccumulateRow(acc_2, filter.f0, filter.f1, filter.f2,
- input_0, input_1, input_2);
-
- acc_3 = MultiplyAccumulateRow(acc_3, filter.f0, filter.f1, filter.f2,
- input_2, input_3, input_4);
-
- acc_2 = MultiplyAccumulateRow(acc_2, filter.f3, filter.f4, filter.f5,
- input_5, input_6, input_7);
-
- acc_3 = MultiplyAccumulateRow(acc_3, filter.f3, filter.f4, filter.f5,
- input_7, input_8, input_9);
-
- // Load last input row.
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4;
- ptr += input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- temp_3 = vld1_u8(ptr + 3 * input_depth);
- temp_4 = vld1_u8(ptr + 4 * input_depth);
-
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
-
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- input_2 = vaddq_s16(input_2, input_offset_vec);
- input_3 = vaddq_s16(input_3, input_offset_vec);
- input_4 = vaddq_s16(input_4, input_offset_vec);
- }
-
- acc_2 = MultiplyAccumulateRow(acc_2, filter.f6, filter.f7, filter.f8,
- input_0, input_1, input_2);
-
- acc_3 = MultiplyAccumulateRow(acc_3, filter.f6, filter.f7, filter.f8,
- input_2, input_3, input_4);
- }
-
- DownquantizeAndStore2x2Output(acc_0, acc_1, acc_2, acc_3, output_offset,
- output_multiplier, output_shift,
- output_activation_min, output_activation_max,
- output_ptr, output_depth, output_width);
- }
-};
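+// Each DepthwiseConvPartial specialization below computes one output point
+// whose 3x3 filter window overlaps the zero padding, applying only the filter
+// taps that land on real input: a 2x2 input window for kCorner, a 2x3 window
+// for kHorizontal, and a single input value for kCenter (a 1x1 input with
+// padding of 1).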
+enum class EdgeType { kCorner, kHorizontal, kVertical, kCenter };
+
+template <EdgeType kEdgeType, int kPadWidth, int kPadHeight>
+struct DepthwiseConvPartial {};
template <>
-struct ConvKernel3x3FilterDepth8<2, 4, 2, 2> {
- static inline void Run(const uint8* input_ptr, int input_depth,
- int32 input_offset, int input_row_size,
- const uint8* filter_ptr, int32 filter_offset,
- const int32* bias_ptr, int32 output_offset,
- int32 output_multiplier, int output_shift,
- int32 output_activation_min,
- int32 output_activation_max, uint8* output_ptr,
- int output_depth, int output_width) {
- // Reuse 2x2 kernel twice.
- ConvKernel3x3FilterDepth8<2, 2, 2, 2>::Run(
- input_ptr, input_depth, input_offset, input_row_size, filter_ptr,
- filter_offset, bias_ptr, output_offset, output_multiplier, output_shift,
- output_activation_min, output_activation_max, output_ptr, output_depth,
- output_width);
-
- ConvKernel3x3FilterDepth8<2, 2, 2, 2>::Run(
- input_ptr + 4 * input_depth, input_depth, input_offset, input_row_size,
- filter_ptr, filter_offset, bias_ptr, output_offset, output_multiplier,
- output_shift, output_activation_min, output_activation_max,
- output_ptr + 2 * output_depth, output_depth, output_width);
+struct DepthwiseConvPartial<EdgeType::kCenter, 1, 1> {
+ static inline void Run(const uint8* input_ptr, const uint8* filter_ptr,
+ const int32* bias_ptr, uint8* output_ptr,
+ const DepthwiseConvParams* params_ptr) {
+#define DEPTHWISECONV_LABEL_DEPTH_8_LOOP "1"
+#define DEPTHWISECONV_LABEL_DEPTH_8_AFTER_LOOP "2"
+ asm volatile(
+ // Performs depthwise convolutions for an input window of size 1x1 and
+ // padding of 1 across the full depth. Expects |input_ptr| and
+ // |filter_ptr| to be pointing to the 1x1 input and filter values.
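+ // Per 8-channel slice, the loop below computes
+ //   acc = bias + (input + input_offset) * (filter + filter_offset)
+ // and then requantizes: saturating rounding doubling high multiply by
+ // the output multiplier, rounding right shift by the output shift,
+ // add the output offset, clamp to the activation range, and narrow
+ // back to uint8.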
+ "ld1 {v8.8b}, [%[input_ptr]], #8\n"
+ "ldr w9, [%[params_ptr], #" STR(OFFSET_INPUT_OFFSET) "]\n"
+ "ldr x11, [%[params_ptr], #" STR(OFFSET_OUTPUT_DEPTH) "]\n"
+ "ldr w10, [%[params_ptr], #" STR(OFFSET_OUTPUT_MULTIPLIER) "]\n"
+ "dup v26.8h, w9\n"
+ "ldr w9, [%[params_ptr], #" STR(OFFSET_OUTPUT_OFFSET) "]\n"
+ "dup v27.4s, w10\n"
+ "ld1 {v0.8b}, [%[filter_ptr]], #8\n"
+ "cmp x11, #16\n"
+ "ldr w10, [%[params_ptr], #" STR(OFFSET_OUTPUT_RIGHT_SHIFT) "]\n"
+ "dup v28.4s, w9\n"
+ "ldr w9, [%[params_ptr], #" STR(OFFSET_OUTPUT_ACTIVATION_MIN) "]\n"
+ "neg w10, w10\n"
+ "dup v29.4s, w10\n"
+ "ldr w10, [%[params_ptr], #" STR(OFFSET_OUTPUT_ACTIVATION_MAX) "]\n"
+ "dup v30.4s, w9\n"
+ "ldr w9, [%[params_ptr], #" STR(OFFSET_FILTER_OFFSET) "]\n"
+ "dup v31.4s, w10\n"
+ "dup v25.8h, w9\n"
+
+ "ld1 {v16.4s}, [%[bias_ptr]], #16\n"
+ "uaddw v8.8h, v26.8h, v8.8b\n"
+ "ld1 {v17.4s}, [%[bias_ptr]], #16\n"
+ "uaddw v0.8h, v25.8h, v0.8b\n"
+
+ "blt " DEPTHWISECONV_LABEL_DEPTH_8_AFTER_LOOP "f\n"
+
+ //"loop_%=:\n"
+ DEPTHWISECONV_LABEL_DEPTH_8_LOOP ":\n"
+ "smlal v16.4s, v0.4h, v8.4h\n"
+ "subs x11, x11, #8\n"
+ "smlal2 v17.4s, v0.8h, v8.8h\n"
+ "ld1 {v8.8b}, [%[input_ptr]], #8\n"
+ "cmp x11, #16\n"
+ "ld1 {v0.8b}, [%[filter_ptr]], #8\n"
+
+ "sqrdmulh v16.4s, v16.4s, v27.4s\n"
+ "sqrdmulh v17.4s, v17.4s, v27.4s\n"
+ "and v18.16b, v16.16b, v29.16b\n"
+ "and v19.16b, v17.16b, v29.16b\n"
+ "sshr v18.4s, v18.4s, #31\n"
+ "sshr v19.4s, v19.4s, #31\n"
+ "sqadd v16.4s, v16.4s, v18.4s\n"
+ "sqadd v17.4s, v17.4s, v19.4s\n"
+ "srshl v16.4s, v16.4s, v29.4s\n"
+ "srshl v17.4s, v17.4s, v29.4s\n"
+ "add v16.4s, v16.4s, v28.4s\n"
+ "add v17.4s, v17.4s, v28.4s\n"
+ "smax v16.4s, v16.4s, v30.4s\n"
+ "smax v17.4s, v17.4s, v30.4s\n"
+ "smin v16.4s, v16.4s, v31.4s\n"
+ "smin v17.4s, v17.4s, v31.4s\n"
+ "sqxtn v16.4h, v16.4s\n"
+ "sqxtn2 v16.8h, v17.4s\n"
+ "sqxtun v16.8b, v16.8h\n"
+ "st1 {v16.8b}, [%[output_ptr]], #8\n"
+ "uaddw v8.8h, v26.8h, v8.8b\n"
+ "ld1 {v16.4s}, [%[bias_ptr]], #16\n"
+ "uaddw v0.8h, v25.8h, v0.8b\n"
+ "ld1 {v17.4s}, [%[bias_ptr]], #16\n"
+
+ "bge " DEPTHWISECONV_LABEL_DEPTH_8_LOOP "b\n"
+
+ DEPTHWISECONV_LABEL_DEPTH_8_AFTER_LOOP ":\n"
+ "smlal v16.4s, v0.4h, v8.4h\n"
+ "smlal2 v17.4s, v0.8h, v8.8h\n"
+
+ "sqrdmulh v16.4s, v16.4s, v27.4s\n"
+ "sqrdmulh v17.4s, v17.4s, v27.4s\n"
+ "and v18.16b, v16.16b, v29.16b\n"
+ "and v19.16b, v17.16b, v29.16b\n"
+ "sshr v18.4s, v18.4s, #31\n"
+ "sshr v19.4s, v19.4s, #31\n"
+ "sqadd v16.4s, v16.4s, v18.4s\n"
+ "sqadd v17.4s, v17.4s, v19.4s\n"
+ "srshl v16.4s, v16.4s, v29.4s\n"
+ "srshl v17.4s, v17.4s, v29.4s\n"
+
+ "add v16.4s, v16.4s, v28.4s\n"
+ "add v17.4s, v17.4s, v28.4s\n"
+ "smax v16.4s, v16.4s, v30.4s\n"
+ "smax v17.4s, v17.4s, v30.4s\n"
+ "smin v16.4s, v16.4s, v31.4s\n"
+ "smin v17.4s, v17.4s, v31.4s\n"
+ "sqxtn v16.4h, v16.4s\n"
+ "sqxtn2 v16.8h, v17.4s\n"
+ "sqxtun v16.8b, v16.8h\n"
+ "st1 {v16.8b}, [%[output_ptr]]\n"
+ :
+ // Outputs.
+ [filter_ptr] "+r"(filter_ptr), [input_ptr] "+r"(input_ptr),
+ [output_ptr] "+r"(output_ptr), [bias_ptr] "+r"(bias_ptr)
+ :
+ // Inputs.
+ [params_ptr] "r"(params_ptr)
+ :
+ // Clobbers.
+ "cc", "memory",
+ // We use these NEON registers.
+ "v0", "v8", "v16", "v17", "v18", "v19", "v25", "v26", "v27", "v28",
+ "v29", "v30", "v31",
+ // We use these general-purpose registers.
+ "x9", "x10", "x11");
+#undef DEPTHWISECONV_LABEL_DEPTH_8_LOOP
+#undef DEPTHWISECONV_LABEL_DEPTH_8_AFTER_LOOP
}
};
template <>
-struct ConvKernel3x3FilterDepth8<2, 1, 2, 2> {
- static inline void Run(const uint8* input_ptr, int input_depth,
- int32 input_offset, int input_row_size,
- const uint8* filter_ptr, int32 filter_offset,
- const int32* bias_ptr, int32 output_offset,
- int32 output_multiplier, int output_shift,
- int32 output_activation_min,
- int32 output_activation_max, uint8* output_ptr,
- int output_depth, int output_width) {
- const int output_row_size = output_depth * output_width;
-
- Filter3x3x8 filter = Load3x3Filter(filter_ptr, filter_offset, output_depth);
-
- const int16x8_t input_offset_vec = vdupq_n_s16(input_offset);
- int16x8_t input_0, input_1, input_2, input_3, input_4, input_5, input_6,
- input_7, input_8;
- uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5, temp_6, temp_7,
- temp_8;
-
- const uint8* ptr = input_ptr;
-
- // Load all inputs for top output.
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- ptr += input_row_size;
- temp_3 = vld1_u8(ptr);
- temp_4 = vld1_u8(ptr + input_depth);
- temp_5 = vld1_u8(ptr + 2 * input_depth);
- ptr += input_row_size;
- temp_6 = vld1_u8(ptr);
- temp_7 = vld1_u8(ptr + input_depth);
- temp_8 = vld1_u8(ptr + 2 * input_depth);
-
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
- input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_5));
- input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_6));
- input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_7));
- input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_8));
-
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- input_2 = vaddq_s16(input_2, input_offset_vec);
- input_3 = vaddq_s16(input_3, input_offset_vec);
- input_4 = vaddq_s16(input_4, input_offset_vec);
- input_5 = vaddq_s16(input_5, input_offset_vec);
- input_6 = vaddq_s16(input_6, input_offset_vec);
- input_7 = vaddq_s16(input_7, input_offset_vec);
- input_8 = vaddq_s16(input_8, input_offset_vec);
-
- DotProductAndStore(
- filter, input_0, input_1, input_2, input_3, input_4, input_5, input_6,
- input_7, input_8, bias_ptr, output_offset, output_multiplier,
- output_shift, output_activation_min, output_activation_max, output_ptr);
-
- // Second output.
- output_ptr += output_row_size;
-
- ptr += input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- ptr += input_row_size;
- temp_3 = vld1_u8(ptr);
- temp_4 = vld1_u8(ptr + input_depth);
- temp_5 = vld1_u8(ptr + 2 * input_depth);
-
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
- input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_5));
-
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- input_2 = vaddq_s16(input_2, input_offset_vec);
- input_3 = vaddq_s16(input_3, input_offset_vec);
- input_4 = vaddq_s16(input_4, input_offset_vec);
- input_5 = vaddq_s16(input_5, input_offset_vec);
-
- DotProductAndStore(
- filter, input_6, input_7, input_8, input_0, input_1, input_2, input_3,
- input_4, input_5, bias_ptr, output_offset, output_multiplier,
- output_shift, output_activation_min, output_activation_max, output_ptr);
+struct DepthwiseConvPartial<EdgeType::kCorner, 1, 1> {
+ static inline void Run(const uint8* input_ptr, const uint8* filter_ptr,
+ const int32* bias_ptr, uint8* output_ptr,
+ const DepthwiseConvParams* params_ptr) {
+#define DEPTHWISECONV_LABEL_DEPTH_8_LOOP "1"
+#define DEPTHWISECONV_LABEL_DEPTH_8_AFTER_LOOP "2"
+ asm volatile(
+ // Performs depthwise convolutions for an input window of size 2x2 and
+ // padding of 1 across the full depth. Expects |input_ptr| and
+ // |filter_ptr| to be pointing to the beginning of the 2x2 input and
+ // filter values.
+
+ // Load input and filter values.
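+ // x12, x13 and x14 point at the remaining three pixels of the 2x2
+ // input window (same row/next column, next row, and next row/next
+ // column); v0-v3 are loaded with the four corresponding filter taps.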
+ "ldr x15, [%[params_ptr], #" STR(OFFSET_OUTPUT_DEPTH) "]\n"
+ "ldr x9, [%[params_ptr], #" STR(OFFSET_INPUT_ROW_SIZE) "]\n"
+ "cmp x15, #16\n"
+ "add x12, %[input_ptr], x15\n"
+ "add x13, %[input_ptr], x9\n"
+ "ld1 {v8.8b}, [%[input_ptr]], #8\n"
+ "add x14, x13, x15\n"
+ "ld1 {v9.8b}, [x12], #8\n"
+ "ldr x6, [%[params_ptr], #" STR(OFFSET_FILTER_ROW_SIZE) "]\n"
+
+ "add x9, %[filter_ptr], x15\n"
+ "ld1 {v10.8b}, [x13], #8\n"
+ "add x10, %[filter_ptr], x6\n"
+ "ld1 {v11.8b}, [x14], #8\n"
+ "ld1 {v0.8b}, [%[filter_ptr]], #8\n"
+ "add x11, x10, x15\n"
+ "ld1 {v1.8b}, [x9], #8\n"
+ "ld1 {v2.8b}, [x10], #8\n"
+ "ld1 {v3.8b}, [x11], #8\n"
+
+ // Load constants.
+ "ldr w6, [%[params_ptr], #" STR(OFFSET_INPUT_OFFSET) "]\n"
+ "ldr w7, [%[params_ptr], #" STR(OFFSET_OUTPUT_MULTIPLIER) "]\n"
+ "dup v26.8h, w6\n"
+ "ldr w6, [%[params_ptr], #" STR(OFFSET_OUTPUT_OFFSET) "]\n"
+ "dup v27.4s, w7\n"
+ "ldr w7, [%[params_ptr], #" STR(OFFSET_OUTPUT_RIGHT_SHIFT) "]\n"
+ "dup v28.4s, w6\n"
+ "ldr w6, [%[params_ptr], #" STR(OFFSET_OUTPUT_ACTIVATION_MIN) "]\n"
+ "neg w7, w7\n"
+ "dup v29.4s, w7\n"
+ "ldr w7, [%[params_ptr], #" STR(OFFSET_OUTPUT_ACTIVATION_MAX) "]\n"
+ "dup v30.4s, w6\n"
+ "ldr w6, [%[params_ptr], #" STR(OFFSET_FILTER_OFFSET) "]\n"
+ "dup v31.4s, w7\n"
+ "dup v25.8h, w6\n"
+
+ // Add input and filter offsets.
+ "uaddw v8.8h, v26.8h, v8.8b\n"
+ "ld1 {v16.4s}, [%[bias_ptr]], #16\n"
+ "uaddw v9.8h, v26.8h, v9.8b\n"
+ "ld1 {v17.4s}, [%[bias_ptr]], #16\n"
+ "uaddw v10.8h, v26.8h, v10.8b\n"
+ "uaddw v11.8h, v26.8h, v11.8b\n"
+
+ "uaddw v0.8h, v25.8h, v0.8b\n"
+ "uaddw v1.8h, v25.8h, v1.8b\n"
+ "uaddw v2.8h, v25.8h, v2.8b\n"
+ "uaddw v3.8h, v25.8h, v3.8b\n"
+
+ "blt " DEPTHWISECONV_LABEL_DEPTH_8_AFTER_LOOP "f\n"
+
+ //"loop_%=:\n"
+ DEPTHWISECONV_LABEL_DEPTH_8_LOOP ":\n"
+ "smlal v16.4s, v0.4h, v8.4h\n"
+ "subs x15, x15, #8\n"
+ "smlal2 v17.4s, v0.8h, v8.8h\n"
+ "ld1 {v8.8b}, [%[input_ptr]], #8\n"
+ "cmp x15, #16\n"
+ "ld1 {v0.8b}, [%[filter_ptr]], #8\n"
+ "smlal v16.4s, v1.4h, v9.4h\n"
+ "smlal2 v17.4s, v1.8h, v9.8h\n"
+ "ld1 {v9.8b}, [x12], #8\n"
+ "smlal v16.4s, v2.4h, v10.4h\n"
+ "ld1 {v1.8b}, [x9], #8\n"
+ "smlal2 v17.4s, v2.8h, v10.8h\n"
+ "ld1 {v10.8b}, [x13], #8\n"
+ "smlal v16.4s, v3.4h, v11.4h\n"
+ "ld1 {v2.8b}, [x10], #8\n"
+ "smlal2 v17.4s, v3.8h, v11.8h\n"
+ "ld1 {v11.8b}, [x14], #8\n"
+ "ld1 {v3.8b}, [x11], #8\n"
+
+ "sqrdmulh v16.4s, v16.4s, v27.4s\n"
+ "sqrdmulh v17.4s, v17.4s, v27.4s\n"
+ "and v18.16b, v16.16b, v29.16b\n"
+ "and v19.16b, v17.16b, v29.16b\n"
+ "sshr v18.4s, v18.4s, #31\n"
+ "sshr v19.4s, v19.4s, #31\n"
+ "sqadd v16.4s, v16.4s, v18.4s\n"
+ "sqadd v17.4s, v17.4s, v19.4s\n"
+ "srshl v16.4s, v16.4s, v29.4s\n"
+ "srshl v17.4s, v17.4s, v29.4s\n"
+ "add v16.4s, v16.4s, v28.4s\n"
+ "add v17.4s, v17.4s, v28.4s\n"
+ "smax v16.4s, v16.4s, v30.4s\n"
+ "smax v17.4s, v17.4s, v30.4s\n"
+ "smin v16.4s, v16.4s, v31.4s\n"
+ "smin v17.4s, v17.4s, v31.4s\n"
+ "sqxtn v16.4h, v16.4s\n"
+ "sqxtn2 v16.8h, v17.4s\n"
+ "sqxtun v16.8b, v16.8h\n"
+ "st1 {v16.8b}, [%[output_ptr]], #8\n"
+ "uaddw v8.8h, v26.8h, v8.8b\n"
+ "ld1 {v16.4s}, [%[bias_ptr]], #16\n"
+ "uaddw v9.8h, v26.8h, v9.8b\n"
+ "ld1 {v17.4s}, [%[bias_ptr]], #16\n"
+ "uaddw v10.8h, v26.8h, v10.8b\n"
+ "uaddw v11.8h, v26.8h, v11.8b\n"
+ "uaddw v0.8h, v25.8h, v0.8b\n"
+ "uaddw v1.8h, v25.8h, v1.8b\n"
+ "uaddw v2.8h, v25.8h, v2.8b\n"
+ "uaddw v3.8h, v25.8h, v3.8b\n"
+
+ "bge " DEPTHWISECONV_LABEL_DEPTH_8_LOOP "b\n"
+
+ DEPTHWISECONV_LABEL_DEPTH_8_AFTER_LOOP ":\n"
+ "smlal v16.4s, v0.4h, v8.4h\n"
+ "smlal2 v17.4s, v0.8h, v8.8h\n"
+ "smlal v16.4s, v1.4h, v9.4h\n"
+ "smlal2 v17.4s, v1.8h, v9.8h\n"
+ "smlal v16.4s, v2.4h, v10.4h\n"
+ "smlal2 v17.4s, v2.8h, v10.8h\n"
+ "smlal v16.4s, v3.4h, v11.4h\n"
+ "smlal2 v17.4s, v3.8h, v11.8h\n"
+
+ "sqrdmulh v16.4s, v16.4s, v27.4s\n"
+ "sqrdmulh v17.4s, v17.4s, v27.4s\n"
+ "and v18.16b, v16.16b, v29.16b\n"
+ "and v19.16b, v17.16b, v29.16b\n"
+ "sshr v18.4s, v18.4s, #31\n"
+ "sshr v19.4s, v19.4s, #31\n"
+ "sqadd v16.4s, v16.4s, v18.4s\n"
+ "sqadd v17.4s, v17.4s, v19.4s\n"
+ "srshl v16.4s, v16.4s, v29.4s\n"
+ "srshl v17.4s, v17.4s, v29.4s\n"
+
+ "add v16.4s, v16.4s, v28.4s\n"
+ "add v17.4s, v17.4s, v28.4s\n"
+ "smax v16.4s, v16.4s, v30.4s\n"
+ "smax v17.4s, v17.4s, v30.4s\n"
+ "smin v16.4s, v16.4s, v31.4s\n"
+ "smin v17.4s, v17.4s, v31.4s\n"
+ "sqxtn v16.4h, v16.4s\n"
+ "sqxtn2 v16.8h, v17.4s\n"
+ "sqxtun v16.8b, v16.8h\n"
+ "st1 {v16.8b}, [%[output_ptr]]\n"
+ :
+ // Outputs.
+ [filter_ptr] "+r"(filter_ptr), [input_ptr] "+r"(input_ptr),
+ [output_ptr] "+r"(output_ptr), [bias_ptr] "+r"(bias_ptr)
+ :
+ // Inputs.
+ [params_ptr] "r"(params_ptr)
+ :
+ // Clobbers.
+ "cc", "memory",
+ // We use these NEON registers.
+ "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "v16", "v17", "v18",
+ "v19", "v25", "v26", "v27", "v28", "v29", "v30", "v31",
+ // We use these general-purpose registers.
+ "x6", "x7", "x9", "x10", "x11", "x12", "x13", "x14", "x15");
+#undef DEPTHWISECONV_LABEL_DEPTH_8_LOOP
+#undef DEPTHWISECONV_LABEL_DEPTH_8_AFTER_LOOP
}
};
template <>
-struct ConvKernel3x3FilterDepth8<1, 2, 2, 2> {
- static inline void Run(const uint8* input_ptr, int input_depth,
- int32 input_offset, int input_row_size,
- const uint8* filter_ptr, int32 filter_offset,
- const int32* bias_ptr, int32 output_offset,
- int32 output_multiplier, int output_shift,
- int32 output_activation_min,
- int32 output_activation_max, uint8* output_ptr,
- int output_depth, int output_width) {
- Filter3x3x8 filter = Load3x3Filter(filter_ptr, filter_offset, output_depth);
-
- const int16x8_t input_offset_vec = vdupq_n_s16(input_offset);
- int16x8_t input_0, input_1, input_2, input_3, input_4, input_5, input_6,
- input_7, input_8;
- uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5, temp_6, temp_7,
- temp_8;
-
- const uint8* ptr = input_ptr;
-
- // Load all inputs for top output.
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- ptr += input_row_size;
- temp_3 = vld1_u8(ptr);
- temp_4 = vld1_u8(ptr + input_depth);
- temp_5 = vld1_u8(ptr + 2 * input_depth);
- ptr += input_row_size;
- temp_6 = vld1_u8(ptr);
- temp_7 = vld1_u8(ptr + input_depth);
- temp_8 = vld1_u8(ptr + 2 * input_depth);
-
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
- input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_5));
- input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_6));
- input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_7));
- input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_8));
-
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- input_2 = vaddq_s16(input_2, input_offset_vec);
- input_3 = vaddq_s16(input_3, input_offset_vec);
- input_4 = vaddq_s16(input_4, input_offset_vec);
- input_5 = vaddq_s16(input_5, input_offset_vec);
- input_6 = vaddq_s16(input_6, input_offset_vec);
- input_7 = vaddq_s16(input_7, input_offset_vec);
- input_8 = vaddq_s16(input_8, input_offset_vec);
-
- DotProductAndStore(
- filter, input_0, input_1, input_2, input_3, input_4, input_5, input_6,
- input_7, input_8, bias_ptr, output_offset, output_multiplier,
- output_shift, output_activation_min, output_activation_max, output_ptr);
-
- // Second output.
- output_ptr += output_depth;
-
- ptr = input_ptr + 3 * input_depth;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- ptr += input_row_size;
- temp_3 = vld1_u8(ptr);
- temp_4 = vld1_u8(ptr + input_depth);
- ptr += input_row_size;
- temp_6 = vld1_u8(ptr);
- temp_7 = vld1_u8(ptr + input_depth);
-
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
- input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_6));
- input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_7));
-
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- input_3 = vaddq_s16(input_3, input_offset_vec);
- input_4 = vaddq_s16(input_4, input_offset_vec);
- input_6 = vaddq_s16(input_6, input_offset_vec);
- input_7 = vaddq_s16(input_7, input_offset_vec);
-
- DotProductAndStore(
- filter, input_2, input_0, input_1, input_5, input_3, input_4, input_8,
- input_6, input_7, bias_ptr, output_offset, output_multiplier,
- output_shift, output_activation_min, output_activation_max, output_ptr);
+struct DepthwiseConvPartial<EdgeType::kHorizontal, 1, 1> {
+ static inline void Run(const uint8* input_ptr, const uint8* filter_ptr,
+ const int32* bias_ptr, uint8* output_ptr,
+ const DepthwiseConvParams* params_ptr) {
+#define DEPTHWISECONV_LABEL_DEPTH_8_LOOP "1"
+#define DEPTHWISECONV_LABEL_DEPTH_8_AFTER_LOOP "2"
+ asm volatile(
+ // Performs depthwise convolutions for an input window of size 2x3 and
+ // padding of 1 across the full depth. Expects |input_ptr| and
+ // |filter_ptr| to be pointing to the beginning of the 2x3 input and
+ // filter values.
+
+ // Load input and filter values.
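+ // v8-v13 are loaded with the 2x3 input window (two rows of three
+ // pixels) and v0-v5 with the six corresponding filter taps.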
+ "ldr x7, [%[params_ptr], #" STR(OFFSET_INPUT_DEPTH) "]\n"
+ "mov x12, %[input_ptr]\n"
+ "ldr x11, [%[params_ptr], #" STR(OFFSET_INPUT_ROW_SIZE) "]\n"
+ "mov x9, %[filter_ptr]\n"
+ "ldr x14, [%[params_ptr], #" STR(OFFSET_FILTER_ROW_SIZE) "]\n"
+ "add x13, x12, x11\n"
+ "ldr x15, [%[params_ptr], #" STR(OFFSET_OUTPUT_DEPTH) "]\n"
+
+ "ld1 {v8.8b}, [x12], x7\n"
+ "add x10, x9, x14\n"
+ "ld1 {v9.8b}, [x12], x7\n"
+ "cmp x15, #16\n"
+ "ld1 {v10.8b}, [x12]\n"
+ "add %[input_ptr], %[input_ptr], #8\n"
+ "ld1 {v11.8b}, [x13], x7\n"
+ "add %[filter_ptr], %[filter_ptr], #8\n"
+ "ld1 {v12.8b}, [x13], x7\n"
+ "ld1 {v13.8b}, [x13]\n"
+
+ "ld1 {v0.8b}, [x9], x7\n"
+ "ld1 {v1.8b}, [x9], x7\n"
+ "ld1 {v2.8b}, [x9]\n"
+ "ld1 {v3.8b}, [x10], x7\n"
+ "ld1 {v4.8b}, [x10], x7\n"
+ "ld1 {v5.8b}, [x10]\n"
+
+ // Load constants.
+ "ldr w12, [%[params_ptr], #" STR(OFFSET_INPUT_OFFSET) "]\n"
+ "ldr w13, [%[params_ptr], #" STR(OFFSET_OUTPUT_MULTIPLIER) "]\n"
+ "dup v26.8h, w12\n"
+ "ldr w12, [%[params_ptr], #" STR(OFFSET_OUTPUT_OFFSET) "]\n"
+ "dup v27.4s, w13\n"
+ "ldr w13, [%[params_ptr], #" STR(OFFSET_OUTPUT_RIGHT_SHIFT) "]\n"
+ "dup v28.4s, w12\n"
+ "ldr w12, [%[params_ptr], #" STR(OFFSET_OUTPUT_ACTIVATION_MIN) "]\n"
+ "neg w13, w13\n"
+ "dup v29.4s, w13\n"
+ "ldr w13, [%[params_ptr], #" STR(OFFSET_OUTPUT_ACTIVATION_MAX) "]\n"
+ "dup v30.4s, w12\n"
+ "ldr w12, [%[params_ptr], #" STR(OFFSET_FILTER_OFFSET) "]\n"
+ "dup v31.4s, w13\n"
+ "dup v25.8h, w12\n"
+
+ // Add input and filter offsets.
+ "uaddw v8.8h, v26.8h, v8.8b\n"
+ "ld1 {v16.4s}, [%[bias_ptr]], #16\n"
+ "uaddw v9.8h, v26.8h, v9.8b\n"
+ "ld1 {v17.4s}, [%[bias_ptr]], #16\n"
+ "uaddw v10.8h, v26.8h, v10.8b\n"
+ "uaddw v11.8h, v26.8h, v11.8b\n"
+ "uaddw v12.8h, v26.8h, v12.8b\n"
+ "uaddw v13.8h, v26.8h, v13.8b\n"
+
+ "uaddw v0.8h, v25.8h, v0.8b\n"
+ "uaddw v1.8h, v25.8h, v1.8b\n"
+ "uaddw v2.8h, v25.8h, v2.8b\n"
+ "uaddw v3.8h, v25.8h, v3.8b\n"
+ "uaddw v4.8h, v25.8h, v4.8b\n"
+ "uaddw v5.8h, v25.8h, v5.8b\n"
+
+ "blt " DEPTHWISECONV_LABEL_DEPTH_8_AFTER_LOOP "f\n"
+
+ //"loop_%=:\n"
+ DEPTHWISECONV_LABEL_DEPTH_8_LOOP ":\n"
+ "mov x12, %[input_ptr]\n"
+ "subs x15, x15, #8\n"
+ "add x13, x12, x11\n"
+ "cmp x15, #16\n"
+ "add %[input_ptr], %[input_ptr], #8\n"
+
+ "smlal v16.4s, v0.4h, v8.4h\n"
+ "mov x9, %[filter_ptr]\n"
+ "smlal2 v17.4s, v0.8h, v8.8h\n"
+ "ld1 {v8.8b}, [x12], x7\n"
+ "smlal v16.4s, v1.4h, v9.4h\n"
+ "add x10, x9, x14\n"
+ "smlal2 v17.4s, v1.8h, v9.8h\n"
+ "ld1 {v9.8b}, [x12], x7\n"
+ "smlal v16.4s, v2.4h, v10.4h\n"
+ "add %[filter_ptr], %[filter_ptr], #8\n"
+ "smlal2 v17.4s, v2.8h, v10.8h\n"
+ "ld1 {v10.8b}, [x12]\n"
+ "smlal v16.4s, v3.4h, v11.4h\n"
+ "ld1 {v0.8b}, [x9], x7\n"
+ "smlal2 v17.4s, v3.8h, v11.8h\n"
+ "ld1 {v11.8b}, [x13], x7\n"
+ "smlal v16.4s, v4.4h, v12.4h\n"
+ "ld1 {v1.8b}, [x9], x7\n"
+ "smlal2 v17.4s, v4.8h, v12.8h\n"
+ "ld1 {v12.8b}, [x13], x7\n"
+ "smlal v16.4s, v5.4h, v13.4h\n"
+ "ld1 {v2.8b}, [x9]\n"
+ "smlal2 v17.4s, v5.8h, v13.8h\n"
+ "ld1 {v13.8b}, [x13]\n"
+
+ "sqrdmulh v16.4s, v16.4s, v27.4s\n"
+ "ld1 {v3.8b}, [x10], x7\n"
+ "sqrdmulh v17.4s, v17.4s, v27.4s\n"
+ "ld1 {v4.8b}, [x10], x7\n"
+ "and v18.16b, v16.16b, v29.16b\n"
+ "ld1 {v5.8b}, [x10]\n"
+ "and v19.16b, v17.16b, v29.16b\n"
+ "sshr v18.4s, v18.4s, #31\n"
+ "sshr v19.4s, v19.4s, #31\n"
+ "sqadd v16.4s, v16.4s, v18.4s\n"
+ "sqadd v17.4s, v17.4s, v19.4s\n"
+ "srshl v16.4s, v16.4s, v29.4s\n"
+ "srshl v17.4s, v17.4s, v29.4s\n"
+ "add v16.4s, v16.4s, v28.4s\n"
+ "add v17.4s, v17.4s, v28.4s\n"
+ "smax v16.4s, v16.4s, v30.4s\n"
+ "smax v17.4s, v17.4s, v30.4s\n"
+ "smin v16.4s, v16.4s, v31.4s\n"
+ "smin v17.4s, v17.4s, v31.4s\n"
+ "sqxtn v16.4h, v16.4s\n"
+ "sqxtn2 v16.8h, v17.4s\n"
+ "sqxtun v16.8b, v16.8h\n"
+ "uaddw v8.8h, v26.8h, v8.8b\n"
+ "st1 {v16.8b}, [%[output_ptr]], #8\n"
+ "uaddw v9.8h, v26.8h, v9.8b\n"
+ "uaddw v10.8h, v26.8h, v10.8b\n"
+ "uaddw v11.8h, v26.8h, v11.8b\n"
+ "uaddw v12.8h, v26.8h, v12.8b\n"
+ "uaddw v13.8h, v26.8h, v13.8b\n"
+
+ "uaddw v0.8h, v25.8h, v0.8b\n"
+ "uaddw v1.8h, v25.8h, v1.8b\n"
+ "uaddw v2.8h, v25.8h, v2.8b\n"
+ "ld1 {v16.4s}, [%[bias_ptr]], #16\n"
+ "uaddw v3.8h, v25.8h, v3.8b\n"
+ "ld1 {v17.4s}, [%[bias_ptr]], #16\n"
+ "uaddw v4.8h, v25.8h, v4.8b\n"
+ "uaddw v5.8h, v25.8h, v5.8b\n"
+
+ "bge " DEPTHWISECONV_LABEL_DEPTH_8_LOOP "b\n"
+
+ DEPTHWISECONV_LABEL_DEPTH_8_AFTER_LOOP ":\n"
+ "smlal v16.4s, v0.4h, v8.4h\n"
+ "smlal2 v17.4s, v0.8h, v8.8h\n"
+ "smlal v16.4s, v1.4h, v9.4h\n"
+ "smlal2 v17.4s, v1.8h, v9.8h\n"
+ "smlal v16.4s, v2.4h, v10.4h\n"
+ "smlal2 v17.4s, v2.8h, v10.8h\n"
+ "smlal v16.4s, v3.4h, v11.4h\n"
+ "smlal2 v17.4s, v3.8h, v11.8h\n"
+ "smlal v16.4s, v4.4h, v12.4h\n"
+ "smlal2 v17.4s, v4.8h, v12.8h\n"
+ "smlal v16.4s, v5.4h, v13.4h\n"
+ "smlal2 v17.4s, v5.8h, v13.8h\n"
+
+ "sqrdmulh v16.4s, v16.4s, v27.4s\n"
+ "sqrdmulh v17.4s, v17.4s, v27.4s\n"
+ "and v18.16b, v16.16b, v29.16b\n"
+ "and v19.16b, v17.16b, v29.16b\n"
+ "sshr v18.4s, v18.4s, #31\n"
+ "sshr v19.4s, v19.4s, #31\n"
+ "sqadd v16.4s, v16.4s, v18.4s\n"
+ "sqadd v17.4s, v17.4s, v19.4s\n"
+ "srshl v16.4s, v16.4s, v29.4s\n"
+ "srshl v17.4s, v17.4s, v29.4s\n"
+ "add v16.4s, v16.4s, v28.4s\n"
+ "add v17.4s, v17.4s, v28.4s\n"
+ "smax v16.4s, v16.4s, v30.4s\n"
+ "smax v17.4s, v17.4s, v30.4s\n"
+ "smin v16.4s, v16.4s, v31.4s\n"
+ "smin v17.4s, v17.4s, v31.4s\n"
+ "sqxtn v16.4h, v16.4s\n"
+ "sqxtn2 v16.8h, v17.4s\n"
+ "sqxtun v16.8b, v16.8h\n"
+ "st1 {v16.8b}, [%[output_ptr]]\n"
+ :
+ // Outputs.
+ [filter_ptr] "+r"(filter_ptr), [input_ptr] "+r"(input_ptr),
+ [output_ptr] "+r"(output_ptr), [bias_ptr] "+r"(bias_ptr)
+ :
+ // Inputs.
+ [params_ptr] "r"(params_ptr)
+ :
+ // Clobbers.
+ "cc", "memory",
+ // We use these NEON registers.
+ "v0", "v1", "v2", "v3", "v4", "v5", "v8", "v9", "v10", "v11", "v12",
+ "v13", "v16", "v17", "v18", "v19", "v25", "v26", "v27", "v28", "v29",
+ "v30", "v31",
+ // We use these general-purpose registers.
+ "x7", "x9", "x10", "x11", "x12", "x13", "x14", "x15");
+#undef DEPTHWISECONV_LABEL_DEPTH_8_LOOP
+#undef DEPTHWISECONV_LABEL_DEPTH_8_AFTER_LOOP
}
};
template <>
-struct ConvKernel3x3FilterDepth8<1, 4, 2, 2> {
- static inline void Run(const uint8* input_ptr, int input_depth,
- int32 input_offset, int input_row_size,
- const uint8* filter_ptr, int32 filter_offset,
- const int32* bias_ptr, int32 output_offset,
- int32 output_multiplier, int output_shift,
- int32 output_activation_min,
- int32 output_activation_max, uint8* output_ptr,
- int output_depth, int output_width) {
- Filter3x3x8 filter = Load3x3Filter(filter_ptr, filter_offset, output_depth);
-
- const int16x8_t input_offset_vec = vdupq_n_s16(input_offset);
- int16x8_t input_0, input_1, input_2, input_3, input_4, input_5, input_6,
- input_7, input_8;
- uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5, temp_6, temp_7,
- temp_8;
-
- const uint8* ptr = input_ptr;
-
- // Load all inputs for top output.
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- ptr += input_row_size;
- temp_3 = vld1_u8(ptr);
- temp_4 = vld1_u8(ptr + input_depth);
- temp_5 = vld1_u8(ptr + 2 * input_depth);
- ptr += input_row_size;
- temp_6 = vld1_u8(ptr);
- temp_7 = vld1_u8(ptr + input_depth);
- temp_8 = vld1_u8(ptr + 2 * input_depth);
-
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
- input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_5));
- input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_6));
- input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_7));
- input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_8));
-
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- input_2 = vaddq_s16(input_2, input_offset_vec);
- input_3 = vaddq_s16(input_3, input_offset_vec);
- input_4 = vaddq_s16(input_4, input_offset_vec);
- input_5 = vaddq_s16(input_5, input_offset_vec);
- input_6 = vaddq_s16(input_6, input_offset_vec);
- input_7 = vaddq_s16(input_7, input_offset_vec);
- input_8 = vaddq_s16(input_8, input_offset_vec);
-
- DotProductAndStore(
- filter, input_0, input_1, input_2, input_3, input_4, input_5, input_6,
- input_7, input_8, bias_ptr, output_offset, output_multiplier,
- output_shift, output_activation_min, output_activation_max, output_ptr);
-
- // Second output.
- output_ptr += output_depth;
-
- ptr = input_ptr + 3 * input_depth;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- ptr += input_row_size;
- temp_3 = vld1_u8(ptr);
- temp_4 = vld1_u8(ptr + input_depth);
- ptr += input_row_size;
- temp_6 = vld1_u8(ptr);
- temp_7 = vld1_u8(ptr + input_depth);
-
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
- input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_6));
- input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_7));
-
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- input_3 = vaddq_s16(input_3, input_offset_vec);
- input_4 = vaddq_s16(input_4, input_offset_vec);
- input_6 = vaddq_s16(input_6, input_offset_vec);
- input_7 = vaddq_s16(input_7, input_offset_vec);
-
- DotProductAndStore(
- filter, input_2, input_0, input_1, input_5, input_3, input_4, input_8,
- input_6, input_7, bias_ptr, output_offset, output_multiplier,
- output_shift, output_activation_min, output_activation_max, output_ptr);
-
- // Third output.
- output_ptr += output_depth;
-
- ptr = input_ptr + 5 * input_depth;
- temp_2 = vld1_u8(ptr);
- temp_0 = vld1_u8(ptr + input_depth);
- ptr += input_row_size;
- temp_5 = vld1_u8(ptr);
- temp_3 = vld1_u8(ptr + input_depth);
- ptr += input_row_size;
- temp_8 = vld1_u8(ptr);
- temp_6 = vld1_u8(ptr + input_depth);
-
- input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_5));
- input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_8));
- input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_6));
-
- input_2 = vaddq_s16(input_2, input_offset_vec);
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_5 = vaddq_s16(input_5, input_offset_vec);
- input_3 = vaddq_s16(input_3, input_offset_vec);
- input_8 = vaddq_s16(input_8, input_offset_vec);
- input_6 = vaddq_s16(input_6, input_offset_vec);
-
- DotProductAndStore(
- filter, input_1, input_2, input_0, input_4, input_5, input_3, input_7,
- input_8, input_6, bias_ptr, output_offset, output_multiplier,
- output_shift, output_activation_min, output_activation_max, output_ptr);
-
- // Fourth output.
- output_ptr += output_depth;
-
- ptr = input_ptr + 7 * input_depth;
- temp_1 = vld1_u8(ptr);
- temp_2 = vld1_u8(ptr + input_depth);
- ptr += input_row_size;
- temp_4 = vld1_u8(ptr);
- temp_5 = vld1_u8(ptr + input_depth);
- ptr += input_row_size;
- temp_7 = vld1_u8(ptr);
- temp_8 = vld1_u8(ptr + input_depth);
-
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
- input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_5));
- input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_7));
- input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_8));
-
- input_1 = vaddq_s16(input_1, input_offset_vec);
- input_2 = vaddq_s16(input_2, input_offset_vec);
- input_4 = vaddq_s16(input_4, input_offset_vec);
- input_5 = vaddq_s16(input_5, input_offset_vec);
- input_7 = vaddq_s16(input_7, input_offset_vec);
- input_8 = vaddq_s16(input_8, input_offset_vec);
-
- DotProductAndStore(
- filter, input_0, input_1, input_2, input_3, input_4, input_5, input_6,
- input_7, input_8, bias_ptr, output_offset, output_multiplier,
- output_shift, output_activation_min, output_activation_max, output_ptr);
+struct DepthwiseConvPartial<EdgeType::kVertical, 1, 1> {
+ static inline void Run(const uint8* input_ptr, const uint8* filter_ptr,
+ const int32* bias_ptr, uint8* output_ptr,
+ const DepthwiseConvParams* params_ptr) {
+#define DEPTHWISECONV_LABEL_DEPTH_8_LOOP "1"
+#define DEPTHWISECONV_LABEL_DEPTH_8_AFTER_LOOP "2"
+ asm volatile(
+      // Performs depthwise convolutions for a 3x2 input window with padding
+      // of 1, across the full depth. Expects |input_ptr| and |filter_ptr| to
+      // point to the beginning of the 3x2 input and filter values.
+
+ // Load input and filter values.
+ "ldr x6, [%[params_ptr], #" STR(OFFSET_INPUT_DEPTH) "]\n"
+ "mov x12, %[input_ptr]\n"
+ "ldr x11, [%[params_ptr], #" STR(OFFSET_INPUT_ROW_SIZE) "]\n"
+ "mov x7, %[filter_ptr]\n"
+ "ldr x5, [%[params_ptr], #" STR(OFFSET_FILTER_ROW_SIZE) "]\n"
+ "add x13, x12, x11\n"
+ "ldr x15, [%[params_ptr], #" STR(OFFSET_OUTPUT_DEPTH) "]\n"
+ "add x14, x13, x11\n"
+
+ "ld1 {v8.8b}, [x12], x6\n"
+ "add x9, x7, x5\n"
+ "ld1 {v9.8b}, [x12]\n"
+ "cmp x15, #16\n"
+ "add x10, x9, x5\n"
+ "ld1 {v10.8b}, [x13], x6\n"
+ "add %[input_ptr], %[input_ptr], #8\n"
+ "ld1 {v11.8b}, [x13]\n"
+ "add %[filter_ptr], %[filter_ptr], #8\n"
+ "ld1 {v12.8b}, [x14], x6\n"
+ "ld1 {v13.8b}, [x14]\n"
+
+ "ld1 {v0.8b}, [x7], x6\n"
+ "ld1 {v1.8b}, [x7]\n"
+ "ld1 {v2.8b}, [x9], x6\n"
+ "ld1 {v3.8b}, [x9]\n"
+ "ld1 {v4.8b}, [x10], x6\n"
+ "ld1 {v5.8b}, [x10]\n"
+
+ // Load constants.
+ "ldr w12, [%[params_ptr], #" STR(OFFSET_INPUT_OFFSET) "]\n"
+ "ldr w13, [%[params_ptr], #" STR(OFFSET_OUTPUT_MULTIPLIER) "]\n"
+ "dup v26.8h, w12\n"
+ "ldr w12, [%[params_ptr], #" STR(OFFSET_OUTPUT_OFFSET) "]\n"
+ "dup v27.4s, w13\n"
+ "ldr w13, [%[params_ptr], #" STR(OFFSET_OUTPUT_RIGHT_SHIFT) "]\n"
+ "dup v28.4s, w12\n"
+ "ldr w12, [%[params_ptr], #" STR(OFFSET_OUTPUT_ACTIVATION_MIN) "]\n"
+ "neg w13, w13\n"
+ "dup v29.4s, w13\n"
+ "ldr w13, [%[params_ptr], #" STR(OFFSET_OUTPUT_ACTIVATION_MAX) "]\n"
+ "dup v30.4s, w12\n"
+ "ldr w12, [%[params_ptr], #" STR(OFFSET_FILTER_OFFSET) "]\n"
+ "dup v31.4s, w13\n"
+ "dup v25.8h, w12\n"
+
+ // Add input and filter offsets.
+ "uaddw v8.8h, v26.8h, v8.8b\n"
+ "ld1 {v16.4s}, [%[bias_ptr]], #16\n"
+ "uaddw v9.8h, v26.8h, v9.8b\n"
+ "ld1 {v17.4s}, [%[bias_ptr]], #16\n"
+ "uaddw v10.8h, v26.8h, v10.8b\n"
+ "uaddw v11.8h, v26.8h, v11.8b\n"
+ "uaddw v12.8h, v26.8h, v12.8b\n"
+ "uaddw v13.8h, v26.8h, v13.8b\n"
+
+ "uaddw v0.8h, v25.8h, v0.8b\n"
+ "uaddw v1.8h, v25.8h, v1.8b\n"
+ "uaddw v2.8h, v25.8h, v2.8b\n"
+ "uaddw v3.8h, v25.8h, v3.8b\n"
+ "uaddw v4.8h, v25.8h, v4.8b\n"
+ "uaddw v5.8h, v25.8h, v5.8b\n"
+
+ "blt " DEPTHWISECONV_LABEL_DEPTH_8_AFTER_LOOP "f\n"
+
+ DEPTHWISECONV_LABEL_DEPTH_8_LOOP ":\n"
+ "mov x12, %[input_ptr]\n"
+ "subs x15, x15, #8\n"
+ "add x13, x12, x11\n"
+ "cmp x15, #16\n"
+ "add x14, x13, x11\n"
+ "add %[input_ptr], %[input_ptr], #8\n"
+
+ "smlal v16.4s, v0.4h, v8.4h\n"
+ "mov x7, %[filter_ptr]\n"
+ "smlal2 v17.4s, v0.8h, v8.8h\n"
+ "ld1 {v8.8b}, [x12], x6\n"
+ "smlal v16.4s, v1.4h, v9.4h\n"
+ "add x9, x7, x5\n"
+ "smlal2 v17.4s, v1.8h, v9.8h\n"
+ "add x10, x9, x5\n"
+ "ld1 {v9.8b}, [x12]\n"
+ "smlal v16.4s, v2.4h, v10.4h\n"
+ "add %[filter_ptr], %[filter_ptr], #8\n"
+ "smlal2 v17.4s, v2.8h, v10.8h\n"
+ "ld1 {v10.8b}, [x13], x6\n"
+ "smlal v16.4s, v3.4h, v11.4h\n"
+ "ld1 {v0.8b}, [x7], x6\n"
+ "smlal2 v17.4s, v3.8h, v11.8h\n"
+ "ld1 {v11.8b}, [x13]\n"
+ "smlal v16.4s, v4.4h, v12.4h\n"
+ "ld1 {v1.8b}, [x7]\n"
+ "smlal2 v17.4s, v4.8h, v12.8h\n"
+ "ld1 {v12.8b}, [x14], x6\n"
+ "smlal v16.4s, v5.4h, v13.4h\n"
+ "ld1 {v2.8b}, [x9], x6\n"
+ "smlal2 v17.4s, v5.8h, v13.8h\n"
+ "ld1 {v13.8b}, [x14]\n"
+
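+      // Requantize the accumulators: saturating doubling multiply by the
+      // output multiplier, rounding right shift, add the output offset,
+      // clamp to the activation range, and narrow to uint8.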
+ "sqrdmulh v16.4s, v16.4s, v27.4s\n"
+ "ld1 {v3.8b}, [x9]\n"
+ "sqrdmulh v17.4s, v17.4s, v27.4s\n"
+ "ld1 {v4.8b}, [x10], x6\n"
+ "and v18.16b, v16.16b, v29.16b\n"
+ "ld1 {v5.8b}, [x10]\n"
+ "and v19.16b, v17.16b, v29.16b\n"
+ "sshr v18.4s, v18.4s, #31\n"
+ "sshr v19.4s, v19.4s, #31\n"
+ "sqadd v16.4s, v16.4s, v18.4s\n"
+ "sqadd v17.4s, v17.4s, v19.4s\n"
+ "srshl v16.4s, v16.4s, v29.4s\n"
+ "srshl v17.4s, v17.4s, v29.4s\n"
+ "add v16.4s, v16.4s, v28.4s\n"
+ "add v17.4s, v17.4s, v28.4s\n"
+ "smax v16.4s, v16.4s, v30.4s\n"
+ "smax v17.4s, v17.4s, v30.4s\n"
+ "smin v16.4s, v16.4s, v31.4s\n"
+ "smin v17.4s, v17.4s, v31.4s\n"
+ "sqxtn v16.4h, v16.4s\n"
+ "sqxtn2 v16.8h, v17.4s\n"
+ "sqxtun v16.8b, v16.8h\n"
+ "uaddw v8.8h, v26.8h, v8.8b\n"
+ "st1 {v16.8b}, [%[output_ptr]], #8\n"
+ "uaddw v9.8h, v26.8h, v9.8b\n"
+ "uaddw v10.8h, v26.8h, v10.8b\n"
+ "uaddw v11.8h, v26.8h, v11.8b\n"
+ "uaddw v12.8h, v26.8h, v12.8b\n"
+ "uaddw v13.8h, v26.8h, v13.8b\n"
+
+ "uaddw v0.8h, v25.8h, v0.8b\n"
+ "uaddw v1.8h, v25.8h, v1.8b\n"
+ "uaddw v2.8h, v25.8h, v2.8b\n"
+ "ld1 {v16.4s}, [%[bias_ptr]], #16\n"
+ "uaddw v3.8h, v25.8h, v3.8b\n"
+ "ld1 {v17.4s}, [%[bias_ptr]], #16\n"
+ "uaddw v4.8h, v25.8h, v4.8b\n"
+ "uaddw v5.8h, v25.8h, v5.8b\n"
+
+ "bge " DEPTHWISECONV_LABEL_DEPTH_8_LOOP "b\n"
+
+ DEPTHWISECONV_LABEL_DEPTH_8_AFTER_LOOP ":\n"
+ "smlal v16.4s, v0.4h, v8.4h\n"
+ "smlal2 v17.4s, v0.8h, v8.8h\n"
+ "smlal v16.4s, v1.4h, v9.4h\n"
+ "smlal2 v17.4s, v1.8h, v9.8h\n"
+ "smlal v16.4s, v2.4h, v10.4h\n"
+ "smlal2 v17.4s, v2.8h, v10.8h\n"
+ "smlal v16.4s, v3.4h, v11.4h\n"
+ "smlal2 v17.4s, v3.8h, v11.8h\n"
+ "smlal v16.4s, v4.4h, v12.4h\n"
+ "smlal2 v17.4s, v4.8h, v12.8h\n"
+ "smlal v16.4s, v5.4h, v13.4h\n"
+ "smlal2 v17.4s, v5.8h, v13.8h\n"
+
+ "sqrdmulh v16.4s, v16.4s, v27.4s\n"
+ "sqrdmulh v17.4s, v17.4s, v27.4s\n"
+ "and v18.16b, v16.16b, v29.16b\n"
+ "and v19.16b, v17.16b, v29.16b\n"
+ "sshr v18.4s, v18.4s, #31\n"
+ "sshr v19.4s, v19.4s, #31\n"
+ "sqadd v16.4s, v16.4s, v18.4s\n"
+ "sqadd v17.4s, v17.4s, v19.4s\n"
+ "srshl v16.4s, v16.4s, v29.4s\n"
+ "srshl v17.4s, v17.4s, v29.4s\n"
+ "add v16.4s, v16.4s, v28.4s\n"
+ "add v17.4s, v17.4s, v28.4s\n"
+ "smax v16.4s, v16.4s, v30.4s\n"
+ "smax v17.4s, v17.4s, v30.4s\n"
+ "smin v16.4s, v16.4s, v31.4s\n"
+ "smin v17.4s, v17.4s, v31.4s\n"
+ "sqxtn v16.4h, v16.4s\n"
+ "sqxtn2 v16.8h, v17.4s\n"
+ "sqxtun v16.8b, v16.8h\n"
+ "st1 {v16.8b}, [%[output_ptr]]\n"
+ :
+ // Outputs.
+ [filter_ptr] "+r"(filter_ptr), [input_ptr] "+r"(input_ptr),
+ [output_ptr] "+r"(output_ptr), [bias_ptr] "+r"(bias_ptr)
+ :
+ // Inputs.
+ [params_ptr] "r"(params_ptr)
+ :
+ // Clobbers.
+ "cc", "memory",
+ // We use these NEON registers.
+ "v0", "v1", "v2", "v3", "v4", "v5", "v8", "v9", "v10", "v11", "v12",
+ "v13", "v16", "v17", "v18", "v19", "v25", "v26", "v27", "v28", "v29",
+ "v30", "v31",
+ // We use these general-purpose registers.
+ "x5", "x6", "x7", "x9", "x10", "x11", "x12", "x13", "x14", "x15");
+#undef DEPTHWISECONV_LABEL_DEPTH_8_LOOP
+#undef DEPTHWISECONV_LABEL_DEPTH_8_AFTER_LOOP
}
};
-template <int kFixedStrideWidth, int kFixedStrideHeight>
-struct ConvKernel3x3FilterDepth8<1, 1, kFixedStrideWidth, kFixedStrideHeight> {
- static inline void Run(const uint8* input_ptr, int input_depth,
- int32 input_offset, int input_row_size,
- const uint8* filter_ptr, int32 filter_offset,
- const int32* bias_ptr, int32 output_offset,
- int32 output_multiplier, int output_shift,
- int32 output_activation_min,
- int32 output_activation_max, uint8* output_ptr,
- int output_depth, int output_width) {
- Filter3x3x8 filter = Load3x3Filter(filter_ptr, filter_offset, output_depth);
-
- int16x8_t input_0, input_1, input_2, input_3, input_4, input_5, input_6,
- input_7, input_8;
-
- uint8x8_t temp_0 = vld1_u8(input_ptr);
- uint8x8_t temp_1 = vld1_u8(input_ptr + input_depth);
- uint8x8_t temp_2 = vld1_u8(input_ptr + 2 * input_depth);
-
- input_ptr += input_row_size;
- uint8x8_t temp_3 = vld1_u8(input_ptr);
- uint8x8_t temp_4 = vld1_u8(input_ptr + input_depth);
- uint8x8_t temp_5 = vld1_u8(input_ptr + 2 * input_depth);
-
- input_ptr += input_row_size;
- uint8x8_t temp_6 = vld1_u8(input_ptr);
- uint8x8_t temp_7 = vld1_u8(input_ptr + input_depth);
- uint8x8_t temp_8 = vld1_u8(input_ptr + 2 * input_depth);
-
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
- input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_5));
- input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_6));
- input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_7));
- input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_8));
-
- const int16x8_t input_offset_vec = vdupq_n_s16(input_offset);
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- input_2 = vaddq_s16(input_2, input_offset_vec);
- input_3 = vaddq_s16(input_3, input_offset_vec);
- input_4 = vaddq_s16(input_4, input_offset_vec);
- input_5 = vaddq_s16(input_5, input_offset_vec);
- input_6 = vaddq_s16(input_6, input_offset_vec);
- input_7 = vaddq_s16(input_7, input_offset_vec);
- input_8 = vaddq_s16(input_8, input_offset_vec);
-
- DotProductAndStore(
- filter, input_0, input_1, input_2, input_3, input_4, input_5, input_6,
- input_7, input_8, bias_ptr, output_offset, output_multiplier,
- output_shift, output_activation_min, output_activation_max, output_ptr);
- }
-};
-
-inline void ShuffleInput(const uint8* input_ptr, int input_depth,
- int input_width, int input_height, int output_depth,
- int output_width, int output_height,
- uint8* output_ptr) {
- const int input_row_size = input_depth * input_width;
-
- for (int y = 0; y < output_height; y++) {
+#undef OFFSET_INPUT_DEPTH
+#undef OFFSET_INPUT_ROW_SIZE
+#undef OFFSET_OUTPUT_DEPTH
+#undef OFFSET_OUTPUT_ROW_SIZE
+#undef OFFSET_INPUT_OFFSET
+#undef OFFSET_OUTPUT_OFFSET
+#undef OFFSET_FILTER_OFFSET
+#undef OFFSET_OUTPUT_MULTIPLIER
+#undef OFFSET_OUTPUT_ACTIVATION_MIN
+#undef OFFSET_OUTPUT_ACTIVATION_MAX
+#undef OFFSET_OUTPUT_RIGHT_SHIFT
+#undef OFFSET_INPUT_WIDTH
+#undef OFFSET_INPUT_HEIGHT
+#undef OFFSET_OUTPUT_WIDTH
+#undef OFFSET_OUTPUT_HEIGHT
+#undef STR
+#undef STR_UNEXPANDED
+
+// Copies a subset of the input designated by |input_ptr| into |output_ptr|
+// with the specified output dimensions. Only an output depth of 64 is
+// supported, since 64 matches the cache line size.
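+// For example, shuffling a 64-deep 10x10 window out of a deeper input packs
+// 64 * 10 * 10 contiguous bytes into |output_ptr|.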
+inline void ShuffleInput(const uint8* input_ptr, int64_t input_depth,
+ int32 input_width, int32 input_height,
+ int64_t output_depth, int32 output_width,
+ int32 output_height, uint8* output_ptr) {
+ const int64_t input_row_size = input_depth * input_width;
+ for (int32 y = 0; y < output_height; y++) {
const uint8* ptr = input_ptr;
- for (int x = 0; x < output_width; x++) {
+ for (int32 x = 0; x < output_width; x++) {
memcpy(output_ptr, ptr, output_depth);
output_ptr += output_depth;
ptr += input_depth;
@@ -3873,561 +2937,265 @@ inline void ShuffleInput(const uint8* input_ptr, int input_depth,
}
}
-template <int kFixedHeight, int kFixedStrideWidth, int kFixedStrideHeight>
-struct ConvRow3x3FilterDepth8 {};
-
-template <int kFixedStrideWidth, int kFixedStrideHeight>
-struct ConvRow3x3FilterDepth8<1, kFixedStrideWidth, kFixedStrideHeight> {
- static inline void Run(const uint8* input_data, int start_x, int start_y,
- int input_depth, int input_width, int input_height,
- int input_row_size, int32 input_offset,
- const uint8* filter_data, int32 filter_offset,
- const int32* bias_data, int32 output_offset,
- int32 output_multiplier, int output_shift,
- int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- int output_depth, int output_width,
- uint8* shuffle_workspace) {
- int out_x = start_x;
-
- // 1x4 at a time.
- for (; out_x <= output_width - 4; out_x += 4) {
- const int32* bias_ptr = bias_data;
- const uint8* filter_ptr = filter_data;
-
- const uint8* input_ptr = input_data;
- uint8* output_ptr = output_data;
-
- for (int depth = 0; depth <= output_depth - 8; depth += 8) {
- ConvKernel3x3FilterDepth8<1, 4, kFixedStrideWidth, kFixedStrideHeight>::
- Run(input_ptr, input_depth, input_offset, input_row_size,
- filter_ptr, filter_offset, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_ptr, output_depth, output_width);
-
- input_ptr += 8;
- output_ptr += 8;
- filter_ptr += 8;
- bias_ptr += 8;
- }
-
- input_data += 4 * kFixedStrideWidth * input_depth;
- output_data += 4 * output_depth;
- }
-
- // 1x1 at a time.
- for (; out_x < output_width; out_x++) {
- const int32* bias_ptr = bias_data;
- const uint8* filter_ptr = filter_data;
-
- const uint8* input_ptr = input_data;
- uint8* output_ptr = output_data;
-
- for (int depth = 0; depth <= output_depth - 8; depth += 8) {
- ConvKernel3x3FilterDepth8<1, 1, kFixedStrideWidth, kFixedStrideHeight>::
- Run(input_ptr, input_depth, input_offset, input_row_size,
- filter_ptr, filter_offset, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_ptr, output_depth, output_width);
-
- input_ptr += 8;
- output_ptr += 8;
- filter_ptr += 8;
- bias_ptr += 8;
- }
+// Calculates the input size along one dimension that a 3x3 filter needs in
+// order to produce |output| values at the given |stride|.
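+// For example, 6 inputs at stride 1, or 9 inputs at stride 2, produce 4
+// outputs.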
+inline int32 get_shuffle_input_size(int32 stride, int32 output) {
+ return stride * (output - 1) + 3;
+}
- input_data += kFixedStrideWidth * input_depth;
- output_data += output_depth;
- }
+// Indicates the input and output dimensions used when shuffling input
+// activations.
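+// For example, ShuffleParams(8, 2, 2, 2) describes an 8x2 block of outputs
+// computed from a 17x5 block of inputs at stride 2.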
+struct ShuffleParams {
+ int32 output_width;
+ int32 output_height;
+ int32 input_width;
+ int32 input_height;
+
+ ShuffleParams() = default;
+ ShuffleParams(int32 output_width, int32 output_height, int32 stride_width,
+ int32 stride_height)
+ : output_width(output_width)
+ , output_height(output_height)
+ , input_width(get_shuffle_input_size(stride_width, output_width))
+ , input_height(get_shuffle_input_size(stride_height, output_height)) {
}
};
-template <int kFixedStrideWidth, int kFixedStrideHeight>
-struct ConvRow3x3FilterDepth8<2, kFixedStrideWidth, kFixedStrideHeight> {
- static inline void Run(const uint8* input_data, int start_x, int start_y,
- int input_depth, int input_width, int input_height,
- int input_row_size, int32 input_offset,
- const uint8* filter_data, int32 filter_offset,
- const int32* bias_data, int32 output_offset,
- int32 output_multiplier, int output_shift,
- int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- int output_depth, int output_width,
- uint8* shuffle_workspace) {
- int out_x = start_x;
-
- // 2x4 at a time.
- for (; out_x <= output_width - 4; out_x += 4) {
- const int32* bias_ptr = bias_data;
- const uint8* filter_ptr = filter_data;
-
- const uint8* input_ptr = input_data;
- uint8* output_ptr = output_data;
-
- for (int depth = 0; depth <= output_depth - 8; depth += 8) {
- ConvKernel3x3FilterDepth8<2, 4, kFixedStrideWidth, kFixedStrideHeight>::
- Run(input_ptr, input_depth, input_offset, input_row_size,
- filter_ptr, filter_offset, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_ptr, output_depth, output_width);
-
- input_ptr += 8;
- output_ptr += 8;
- filter_ptr += 8;
- bias_ptr += 8;
- }
-
- input_data += 4 * kFixedStrideWidth * input_depth;
- output_data += 4 * output_depth;
- }
-
- // 2x2 at a time.
- for (; out_x <= output_width - 2; out_x += 2) {
- const int32* bias_ptr = bias_data;
- const uint8* filter_ptr = filter_data;
-
- const uint8* input_ptr = input_data;
- uint8* output_ptr = output_data;
-
- for (int depth = 0; depth <= output_depth - 8; depth += 8) {
- ConvKernel3x3FilterDepth8<2, 2, kFixedStrideWidth, kFixedStrideHeight>::
- Run(input_ptr, input_depth, input_offset, input_row_size,
- filter_ptr, filter_offset, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_ptr, output_depth, output_width);
-
- input_ptr += 8;
- output_ptr += 8;
- filter_ptr += 8;
- bias_ptr += 8;
- }
-
- input_data += 2 * kFixedStrideWidth * input_depth;
- output_data += 2 * output_depth;
- }
-
- // 2x1 at a time.
- for (; out_x < output_width; out_x++) {
- const int32* bias_ptr = bias_data;
- const uint8* filter_ptr = filter_data;
-
- const uint8* input_ptr = input_data;
- uint8* output_ptr = output_data;
-
- for (int depth = 0; depth <= output_depth - 8; depth += 8) {
- ConvKernel3x3FilterDepth8<2, 1, kFixedStrideWidth, kFixedStrideHeight>::
- Run(input_ptr, input_depth, input_offset, input_row_size,
- filter_ptr, filter_offset, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_ptr, output_depth, output_width);
-
- input_ptr += 8;
- output_ptr += 8;
- filter_ptr += 8;
- bias_ptr += 8;
- }
-
- input_data += kFixedStrideWidth * input_depth;
- output_data += output_depth;
+template <int32 kStrideWidth, int32 kStrideHeight>
+struct DepthwiseConvThroughDepth {
+  // Runs the DepthwiseConvWindow kernels through the depth dimension from
+  // |start_depth| to |end_depth|. This function is deliberately not inlined
+  // to keep the binary size small. Read-only parameters are passed through a
+  // DepthwiseConvParams struct to minimize call overhead.
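+  // For example, running depths 0 through 24 invokes the depth-8 window
+  // kernel three times, advancing the input, filter, bias and output
+  // pointers by 8 channels between calls.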
+ static __attribute__((noinline)) void Run(const uint8* input_ptr,
+ const uint8* filter_ptr, const int32* bias_ptr, uint8* output_ptr,
+ int64_t start_depth, int64_t end_depth, int64_t input_depth,
+ int64_t input_row_size, int32 output_window_height,
+ int32 output_window_width, const DepthwiseConvParams& params) {
+ for (; start_depth <= end_depth - 8; start_depth += 8) {
+ DepthwiseConvWindow<8, kStrideWidth, kStrideHeight>::Run(
+ input_ptr, filter_ptr, bias_ptr, output_ptr, input_depth,
+ input_row_size, output_window_height, output_window_width, &params);
+ input_ptr += 8;
+ output_ptr += 8;
+ filter_ptr += 8;
+ bias_ptr += 8;
}
}
};
-template <>
-struct ConvRow3x3FilterDepth8<4, 1, 1> {
- static inline void Run(const uint8* input_data, int start_x, int start_y,
- int input_depth, int input_width, int input_height,
- int input_row_size, int32 input_offset,
- const uint8* filter_data, int32 filter_offset,
- const int32* bias_data, int32 output_offset,
- int32 output_multiplier, int output_shift,
- int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- int output_depth, int output_width,
- uint8* shuffle_workspace) {
- int out_x = start_x;
-
- // 4x4 at a time.
- for (; out_x <= output_width - 4; out_x += 4) {
- const int32* bias_ptr = bias_data;
- const uint8* filter_ptr = filter_data;
-
- const uint8* input_ptr = input_data;
- uint8* output_ptr = output_data;
-
- for (int depth = 0; depth <= output_depth - 8; depth += 8) {
- ConvKernel3x3FilterDepth8<4, 4, 1, 1>::Run(
- input_ptr, input_depth, input_offset, input_row_size, filter_ptr,
- filter_offset, bias_ptr, output_offset, output_multiplier,
- output_shift, output_activation_min, output_activation_max,
- output_ptr, output_depth, output_width);
-
- input_ptr += 8;
- output_ptr += 8;
- filter_ptr += 8;
- bias_ptr += 8;
- }
+template <int32 kStrideWidth, int32 kStrideHeight>
+struct DepthwiseConvMultiRow {
+ using ConvKernel = DepthwiseConvThroughDepth<kStrideWidth, kStrideHeight>;
- input_data += 4 * input_depth;
- output_data += 4 * output_depth;
- }
-
- // Handle the rest of the right side.
- // 4x2 at a time.
- for (; out_x <= output_width - 2; out_x += 2) {
- const int32* bias_ptr = bias_data;
- const uint8* filter_ptr = filter_data;
-
- const uint8* input_ptr = input_data;
- uint8* output_ptr = output_data;
-
- for (int depth = 0; depth <= output_depth - 8; depth += 8) {
- ConvKernel3x3FilterDepth8<4, 2, 1, 1>::Run(
- input_ptr, input_depth, input_offset, input_row_size, filter_ptr,
- filter_offset, bias_ptr, output_offset, output_multiplier,
- output_shift, output_activation_min, output_activation_max,
- output_ptr, output_depth, output_width);
-
- input_ptr += 8;
- output_ptr += 8;
- filter_ptr += 8;
- bias_ptr += 8;
- }
-
- input_data += 2 * input_depth;
- output_data += 2 * output_depth;
- }
-
- // 4x1 at a time.
- for (; out_x < output_width; out_x++) {
- const int32* bias_ptr = bias_data;
- const uint8* filter_ptr = filter_data;
-
- const uint8* input_ptr = input_data;
- uint8* output_ptr = output_data;
-
- for (int depth = 0; depth <= output_depth - 8; depth += 8) {
- ConvKernel3x3FilterDepth8<4, 1, 1, 1>::Run(
- input_ptr, input_depth, input_offset, input_row_size, filter_ptr,
- filter_offset, bias_ptr, output_offset, output_multiplier,
- output_shift, output_activation_min, output_activation_max,
- output_ptr, output_depth, output_width);
-
- input_ptr += 8;
- output_ptr += 8;
- filter_ptr += 8;
- bias_ptr += 8;
- }
-
- input_data += input_depth;
- output_data += output_depth;
- }
- }
-};
-
-template <>
-struct ConvRow3x3FilterDepth8<4, 2, 2> {
- // The buffer size of the shuffled input.
- static inline constexpr int ShuffleWorkspaceSize() { return 64 * 9 * 9; }
-
- static inline void Run(const uint8* input_data, int start_x, int start_y,
- int input_depth, int input_width, int input_height,
- int input_row_size, int32 input_offset,
- const uint8* filter_data, int32 filter_offset,
- const int32* bias_data, int32 output_offset,
- int32 output_multiplier, int output_shift,
- int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- int output_depth, int output_width,
+ static inline void Run(const uint8* input_data, int32 start_x, int32 end_x,
+ const uint8* filter_data, const int32* bias_data,
+ uint8* output_data, const DepthwiseConvParams& params,
+ const ShuffleParams& shuffle_params,
uint8* shuffle_workspace) {
- // Branch and cache misses increase substantially with stride 2 kernels.
- // Adding prefetching reduces latency by as much as 2x.
- const int i0 = 0;
- const int i1 = input_depth;
- const int i2 = 2 * input_depth;
- const int i3 = 3 * input_depth;
- const int i4 = 4 * input_depth;
- const int i5 = 5 * input_depth;
- const int i6 = 6 * input_depth;
- const int i7 = 7 * input_depth;
- const int i8 = 8 * input_depth;
-
-#define DEPTHWISECONV_PRELOAD_ROW(input_ptr, i) \
- preload_l1_keep(input_ptr + i * input_row_size + i0); \
- preload_l1_keep(input_ptr + i * input_row_size + i1); \
- preload_l1_keep(input_ptr + i * input_row_size + i2); \
- preload_l1_keep(input_ptr + i * input_row_size + i3); \
- preload_l1_keep(input_ptr + i * input_row_size + i4); \
- preload_l1_keep(input_ptr + i * input_row_size + i5); \
- preload_l1_keep(input_ptr + i * input_row_size + i6); \
- preload_l1_keep(input_ptr + i * input_row_size + i7); \
- preload_l1_keep(input_ptr + i * input_row_size + i8);
-
- int out_x = start_x;
- // 4x4 at a time.
- for (; out_x <= output_width - 4; out_x += 4) {
- const int32* bias_ptr = bias_data;
- const uint8* filter_ptr = filter_data;
-
- const uint8* input_ptr = input_data;
- uint8* output_ptr = output_data;
-
- int depth = 0;
- for (; depth <= output_depth - 64; depth += 64) {
- // Preload 9x9 input.
- DEPTHWISECONV_PRELOAD_ROW(input_ptr, 0);
- DEPTHWISECONV_PRELOAD_ROW(input_ptr, 1);
- DEPTHWISECONV_PRELOAD_ROW(input_ptr, 2);
- DEPTHWISECONV_PRELOAD_ROW(input_ptr, 3);
- DEPTHWISECONV_PRELOAD_ROW(input_ptr, 4);
- DEPTHWISECONV_PRELOAD_ROW(input_ptr, 5);
- DEPTHWISECONV_PRELOAD_ROW(input_ptr, 6);
- DEPTHWISECONV_PRELOAD_ROW(input_ptr, 7);
- DEPTHWISECONV_PRELOAD_ROW(input_ptr, 8);
-
- // For a large input window (64x9x9) that is small enough to fit in L1
- // cache, copy the input into a separate buffer and run the kernel on
- // this new buffer. This reduces the likelihood of cache misses when
- // the kernel is loading input data. If this size is ever changed,
- // update the ShuffleWorkspaceSize() function to return the new size.
- ShuffleInput(input_ptr, input_depth, input_width, input_height, 64, 9,
- 9, shuffle_workspace);
- const uint8* shuffled_ptr = &shuffle_workspace[0];
-
- for (int micro_depth = 0; micro_depth <= 64 - 8; micro_depth += 8) {
- ConvKernel3x3FilterDepth8<4, 4, 2, 2>::Run(
- shuffled_ptr, 64, input_offset, 64 * 9, filter_ptr, filter_offset,
- bias_ptr, output_offset, output_multiplier, output_shift,
- output_activation_min, output_activation_max, output_ptr,
- output_depth, output_width);
-
- shuffled_ptr += 8;
- output_ptr += 8;
- filter_ptr += 8;
- bias_ptr += 8;
+ TFLITE_DCHECK(shuffle_params.input_height ==
+ get_shuffle_input_size(kStrideHeight, shuffle_params.output_height));
+ TFLITE_DCHECK(shuffle_params.input_width ==
+ get_shuffle_input_size(kStrideWidth, shuffle_params.output_width));
+ TFLITE_DCHECK(64 * shuffle_params.input_width * shuffle_params.input_height
+ <= DEPTHWISECONV_SHUFFLE_WORKSPACE_SIZE);
+
+ int32 out_x = start_x;
+
+    // Run shuffling on inputs with sufficiently large depth and width. When
+    // these parameters are large enough, loading inputs from memory takes a
+    // larger share of the time, so it becomes worthwhile to prefetch and
+    // pre-shuffle the input data to maximize locality.
+ if (params.output_depth > 64 ||
+ (params.output_depth <= 64 && params.input_width > 150)) {
+ for (; out_x <= (end_x - shuffle_params.output_width);
+ out_x += shuffle_params.output_width) {
+ const uint8* input_ptr = input_data;
+ const int32* bias_ptr = bias_data;
+ const uint8* filter_ptr = filter_data;
+ uint8* output_ptr = output_data;
+ int64_t depth = 0;
+ const int64_t shuffle_row_size = 64 * shuffle_params.input_width;
+
+ for (; depth <= params.output_depth - 64; depth += 64) {
+ // Preload.
+ const uint8* h_ptr = input_ptr;
+ for (int32 i = 0; i < shuffle_params.input_height; i++) {
+ const uint8* ptr = h_ptr;
+ for (int32 j = 0; j < shuffle_params.input_width; j++) {
+ asm volatile("prfm pldl1keep, [%[ptr]]\n" ::[ptr] "r"(ptr) :);
+ ptr += params.input_depth;
+ }
+ h_ptr += params.input_row_size;
+ }
+
+          // For a large enough input, copy this depth-64 slice into the
+          // shuffle workspace.
+ ShuffleInput(input_ptr, params.input_depth, params.input_width,
+ params.input_height, 64, shuffle_params.input_width,
+ shuffle_params.input_height, shuffle_workspace);
+ ConvKernel::Run(shuffle_workspace, filter_ptr, bias_ptr, output_ptr,
+ 0, 64, 64, shuffle_row_size,
+ shuffle_params.output_height,
+ shuffle_params.output_width, params);
+ input_ptr += 64;
+ output_ptr += 64;
+ filter_ptr += 64;
+ bias_ptr += 64;
}
- input_ptr += 64;
- }
-
- // Preload 9x9 input one more time for the rest of the depth.
- DEPTHWISECONV_PRELOAD_ROW(input_ptr, 0);
- DEPTHWISECONV_PRELOAD_ROW(input_ptr, 1);
- DEPTHWISECONV_PRELOAD_ROW(input_ptr, 2);
- DEPTHWISECONV_PRELOAD_ROW(input_ptr, 3);
- DEPTHWISECONV_PRELOAD_ROW(input_ptr, 4);
- DEPTHWISECONV_PRELOAD_ROW(input_ptr, 5);
- DEPTHWISECONV_PRELOAD_ROW(input_ptr, 6);
- DEPTHWISECONV_PRELOAD_ROW(input_ptr, 7);
- DEPTHWISECONV_PRELOAD_ROW(input_ptr, 8);
-
- for (; depth <= output_depth - 8; depth += 8) {
- ConvKernel3x3FilterDepth8<4, 4, 2, 2>::Run(
- input_ptr, input_depth, input_offset, input_row_size, filter_ptr,
- filter_offset, bias_ptr, output_offset, output_multiplier,
- output_shift, output_activation_min, output_activation_max,
- output_ptr, output_depth, output_width);
-
- input_ptr += 8;
- output_ptr += 8;
- filter_ptr += 8;
- bias_ptr += 8;
- }
-
- input_data += 4 * 2 * input_depth;
- output_data += 4 * output_depth;
- }
-
-#undef DEPTHWISECONV_PRELOAD_ROW
-
- // Handle the rest of the right side.
- // 4x2 at a time.
- for (; out_x <= output_width - 2; out_x += 2) {
- const int32* bias_ptr = bias_data;
- const uint8* filter_ptr = filter_data;
- const uint8* input_ptr = input_data;
- uint8* output_ptr = output_data;
+ // Preload.
+ const uint8* h_ptr = input_ptr;
+ for (int32 i = 0; i < shuffle_params.input_height; i++) {
+ const uint8* ptr = h_ptr;
+ for (int32 j = 0; j < shuffle_params.input_width; j++) {
+ asm volatile("prfm pldl1keep, [%[ptr]]\n" ::[ptr] "r"(ptr) :);
+ ptr += params.input_depth;
+ }
+ h_ptr += params.input_row_size;
+ }
- for (int depth = 0; depth <= output_depth - 8; depth += 8) {
- ConvKernel3x3FilterDepth8<4, 2, 2, 2>::Run(
- input_ptr, input_depth, input_offset, input_row_size, filter_ptr,
- filter_offset, bias_ptr, output_offset, output_multiplier,
- output_shift, output_activation_min, output_activation_max,
- output_ptr, output_depth, output_width);
+ // Handle leftover depth.
+ ConvKernel::Run(input_ptr, filter_ptr, bias_ptr, output_ptr,
+ depth, params.output_depth, params.input_depth,
+ params.input_row_size, shuffle_params.output_height,
+ shuffle_params.output_width, params);
- input_ptr += 8;
- output_ptr += 8;
- filter_ptr += 8;
- bias_ptr += 8;
+ input_data +=
+ shuffle_params.output_width * kStrideWidth * params.input_depth;
+ output_data += shuffle_params.output_width * params.output_depth;
}
-
- input_data += 2 * 2 * input_depth;
- output_data += 2 * output_depth;
}
- // 4x1 at a time.
- for (; out_x < output_width; out_x++) {
- const int32* bias_ptr = bias_data;
- const uint8* filter_ptr = filter_data;
-
- const uint8* input_ptr = input_data;
- uint8* output_ptr = output_data;
-
- for (int depth = 0; depth <= output_depth - 8; depth += 8) {
- ConvKernel3x3FilterDepth8<4, 1, 2, 2>::Run(
- input_ptr, input_depth, input_offset, input_row_size, filter_ptr,
- filter_offset, bias_ptr, output_offset, output_multiplier,
- output_shift, output_activation_min, output_activation_max,
- output_ptr, output_depth, output_width);
-
- input_ptr += 8;
- output_ptr += 8;
- filter_ptr += 8;
- bias_ptr += 8;
- }
-
- input_data += 2 * input_depth;
- output_data += output_depth;
+ const int32 output_leftover_width = end_x - out_x;
+ if (output_leftover_width > 0) {
+ ConvKernel::Run(input_data, filter_data, bias_data, output_data, 0,
+ params.output_depth, params.input_depth,
+ params.input_row_size, shuffle_params.output_height,
+ output_leftover_width, params);
}
}
};
-template <>
-struct ConvRow3x3FilterDepth8<8, 2, 2> {
- static inline void Run(const uint8* input_data, int start_x, int start_y,
- int input_depth, int input_width, int input_height,
- int input_row_size, int32 input_offset,
- const uint8* filter_data, int32 filter_offset,
- const int32* bias_data, int32 output_offset,
- int32 output_multiplier, int output_shift,
- int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- int output_depth, int output_width,
- uint8* shuffle_workspace) {
- // Reuse 4 row kernels twice.
- ConvRow3x3FilterDepth8<4, 2, 2>::Run(
- input_data, start_x, start_y, input_depth, input_width, input_height,
- input_row_size, input_offset, filter_data, filter_offset, bias_data,
- output_offset, output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_data, output_depth, output_width,
- shuffle_workspace);
-
- ConvRow3x3FilterDepth8<4, 2, 2>::Run(
- input_data + 2 * 4 * input_row_size, start_x, start_y + 4, input_depth,
- input_width, input_height, input_row_size, input_offset, filter_data,
- filter_offset, bias_data, output_offset, output_multiplier,
- output_shift, output_activation_min, output_activation_max,
- output_data + 4 * output_depth * output_width, output_depth,
- output_width, shuffle_workspace);
+// Processes the borders of the input when pad_width and pad_height are 1.
+// Calls 4 asm kernels:
+//   * Center (1x1 input shape).
+//   * Corners.
+//   * Horizontal edges.
+//   * Vertical edges.
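+// For example, a 5x5 input with stride 1 yields a 5x5 output; this pass
+// computes the outer ring of that output (top row, bottom row, and the left
+// and right columns), while the main DepthwiseConv3x3Filter loops handle the
+// interior 3x3 block.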
+inline void DepthwiseConvHandlePadding(const uint8* input_data,
+ const uint8* filter_data, const int32* bias_data, uint8* output_data,
+ const DepthwiseConvParams& params) {
+ if (params.input_width == 1 && params.input_height == 1) {
+ const uint8* filter_ptr = filter_data + params.filter_row_size
+ + params.output_depth;
+ DepthwiseConvPartial<EdgeType::kCenter, 1, 1>::Run(input_data, filter_ptr,
+ bias_data, output_data, &params);
+ return;
}
-};
-template <>
-struct ConvRow3x3FilterDepth8<8, 1, 1> {
- // The buffer size of the shuffled input.
- static inline constexpr int ShuffleWorkspaceSize() { return 64 * 10 * 10; }
-
- static inline void Run(const uint8* input_data, int start_x, int start_y,
- int input_depth, int input_width, int input_height,
- int input_row_size, int32 input_offset,
- const uint8* filter_data, int32 filter_offset,
- const int32* bias_data, int32 output_offset,
- int32 output_multiplier, int output_shift,
- int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- int output_depth, int output_width,
- uint8* shuffle_workspace) {
- int out_x = start_x;
- // 8x8 at a time.
- for (; out_x <= output_width - 8; out_x += 8) {
- const int32* bias_ptr = bias_data;
- const uint8* filter_ptr = filter_data;
-
- const uint8* input_ptr = input_data;
- uint8* output_ptr = output_data;
-
- int depth = 0;
- for (; depth <= output_depth - 64; depth += 64) {
- // For a large input window (64x10x10) that is small enough to fit in L1
- // cache, copy the input into a separate buffer and run the kernel on
- // this new buffer. This reduces the likelihood of cache misses when
- // the kernel is loading input data. If the size of the input window
- // changes, update the function ShuffleWorkspaceSize() with the new
- // size.
- ShuffleInput(input_ptr, input_depth, input_width, input_height, 64, 10,
- 10, shuffle_workspace);
- const uint8* shuffled_ptr = shuffle_workspace;
-
- for (int micro_depth = 0; micro_depth <= 64 - 8; micro_depth += 8) {
- ConvKernel3x3FilterDepth8<8, 8, 1, 1>::Run(
- shuffled_ptr, 64, input_offset, 64 * 10, filter_ptr,
- filter_offset, bias_ptr, output_offset, output_multiplier,
- output_shift, output_activation_min, output_activation_max,
- output_ptr, output_depth, output_width);
-
- shuffled_ptr += 8;
- output_ptr += 8;
- filter_ptr += 8;
- bias_ptr += 8;
- }
- input_ptr += 64;
- }
+ const int32 out_x_start_corner = 0;
+ const int32 out_x_end_corner = params.output_width - 1;
+ const int32 out_y_start_corner = 0;
+ const int32 out_y_end_corner = params.output_height - 1;
+
+ // Handle top row.
+ const uint8* input_ptr = input_data;
+ const uint8* filter_ptr = filter_data + params.filter_row_size
+ + params.output_depth;
+ uint8* output_ptr = output_data;
+
+ DepthwiseConvPartial<EdgeType::kCorner, 1, 1>::Run(input_ptr, filter_ptr,
+ bias_data, output_ptr, &params);
+
+ input_ptr += (params.stride_width - 1) * params.input_depth;
+ filter_ptr = filter_data + params.filter_row_size;
+ output_ptr += params.output_depth;
+
+ for (int32 out_x = out_x_start_corner + 1; out_x < out_x_end_corner;
+ out_x++) {
+ DepthwiseConvPartial<EdgeType::kHorizontal, 1, 1>::Run(
+ input_ptr, filter_ptr, bias_data, output_ptr, &params);
+ input_ptr += params.stride_width * params.input_depth;
+ output_ptr += params.output_depth;
+ }
- for (; depth <= output_depth - 8; depth += 8) {
- ConvKernel3x3FilterDepth8<8, 8, 1, 1>::Run(
- input_ptr, input_depth, input_offset, input_row_size, filter_ptr,
- filter_offset, bias_ptr, output_offset, output_multiplier,
- output_shift, output_activation_min, output_activation_max,
- output_ptr, output_depth, output_width);
-
- input_ptr += 8;
- output_ptr += 8;
- filter_ptr += 8;
- bias_ptr += 8;
- }
+ DepthwiseConvPartial<EdgeType::kCorner, 1, 1>::Run(input_ptr, filter_ptr,
+ bias_data, output_ptr, &params);
- input_data += 8 * input_depth;
- output_data += 8 * output_depth;
- }
+ // Handle left side.
+ input_ptr = input_data + (params.stride_width - 1) * params.input_row_size;
+ filter_ptr = filter_data + params.input_depth;
+ output_ptr = output_data + params.output_row_size;
- // Handle the rest of the right side by re-using 4 row kernels twice.
- ConvRow3x3FilterDepth8<4, 1, 1>::Run(
- input_data, out_x, start_y, input_depth, input_width, input_height,
- input_row_size, input_offset, filter_data, filter_offset, bias_data,
- output_offset, output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_data, output_depth, output_width,
- shuffle_workspace);
-
- ConvRow3x3FilterDepth8<4, 1, 1>::Run(
- input_data + 4 * input_row_size, out_x, start_y + 4, input_depth,
- input_width, input_height, input_row_size, input_offset, filter_data,
- filter_offset, bias_data, output_offset, output_multiplier,
- output_shift, output_activation_min, output_activation_max,
- output_data + 4 * output_depth * output_width, output_depth,
- output_width, shuffle_workspace);
+ for (int32 out_y = out_y_start_corner + 1; out_y < out_y_end_corner;
+ out_y++) {
+ DepthwiseConvPartial<EdgeType::kVertical, 1, 1>::Run(
+ input_ptr, filter_ptr, bias_data, output_ptr, &params);
+ input_ptr += params.stride_width * params.input_row_size;
+ output_ptr += params.output_row_size;
}
-};
-inline bool Fast3x3FilterKernelSupported(const Dims<4>& input_dims,
- const Dims<4>& filter_dims,
- int stride_width, int stride_height,
- int pad_width, int pad_height,
- int depth_multiplier,
- const Dims<4>& output_dims) {
- const int input_height = ArraySize(input_dims, 2);
- const int input_width = ArraySize(input_dims, 1);
- const int input_depth = ArraySize(input_dims, 0);
- const int filter_height = ArraySize(filter_dims, 2);
- const int filter_width = ArraySize(filter_dims, 1);
- const int output_height = ArraySize(output_dims, 2);
- const int output_width = ArraySize(output_dims, 1);
-
- bool supported = filter_width == 3 && filter_height == 3 &&
- depth_multiplier == 1 &&
- (stride_width == 1 || stride_width == 2) &&
- (stride_height == 1 || stride_height == 2) &&
- (stride_width == stride_height) && pad_width == 0 &&
- pad_height == 0 && (input_depth % 8) == 0;
+ // Handle right side.
+ input_ptr = input_data + (params.input_width - 2) * params.input_depth
+ + (params.stride_width - 1) * params.input_row_size;
+ filter_ptr = filter_data;
+ output_ptr = output_data + params.output_row_size +
+ (params.output_width - 1) * params.output_depth;
+
+ for (int32 out_y = out_y_start_corner + 1; out_y < out_y_end_corner;
+ out_y++) {
+ DepthwiseConvPartial<EdgeType::kVertical, 1, 1>::Run(
+ input_ptr, filter_ptr, bias_data, output_ptr, &params);
+ input_ptr += params.stride_width * params.input_row_size;
+ output_ptr += params.output_row_size;
+ }
+
+ // Handle bottom row.
+ input_ptr = input_data + (params.input_height - 2) * params.input_row_size;
+ filter_ptr = filter_data + params.output_depth;
+ output_ptr = output_data +
+ (params.output_height - 1) * params.output_row_size;
+
+ DepthwiseConvPartial<EdgeType::kCorner, 1, 1>::Run(input_ptr, filter_ptr,
+ bias_data, output_ptr, &params);
+
+ input_ptr += (params.stride_width == 1) ? 0 : params.input_depth;
+ filter_ptr = filter_data;
+ output_ptr += params.output_depth;
+
+ for (int32 out_x = out_x_start_corner + 1; out_x < out_x_end_corner;
+ out_x++) {
+ DepthwiseConvPartial<EdgeType::kHorizontal, 1, 1>::Run(
+ input_ptr, filter_ptr, bias_data, output_ptr, &params);
+ input_ptr += params.stride_width * params.input_depth;
+ output_ptr += params.output_depth;
+ }
+
+ DepthwiseConvPartial<EdgeType::kCorner, 1, 1>::Run(input_ptr, filter_ptr,
+ bias_data, output_ptr, &params);
+}
+
+inline bool Fast3x3FilterKernelSupported(
+ const RuntimeShape& input_shape, const RuntimeShape& filter_shape,
+ int32 stride_width, int32 stride_height, int32 dilation_width_factor,
+ int32 dilation_height_factor, int32 pad_width, int32 pad_height,
+ int32 depth_multiplier, const RuntimeShape& output_shape,
+ int32 output_shift) {
+ const int32 input_height = input_shape.Dims(1);
+ const int32 input_width = input_shape.Dims(2);
+ const int32 input_depth = input_shape.Dims(3);
+ const int32 filter_height = filter_shape.Dims(1);
+ const int32 filter_width = filter_shape.Dims(2);
+ const int32 output_height = output_shape.Dims(1);
+ const int32 output_width = output_shape.Dims(2);
+
+ bool supported =
+ filter_width == 3 && filter_height == 3 && depth_multiplier == 1 &&
+ (stride_width == 1 || stride_width == 2) &&
+ (stride_height == 1 || stride_height == 2) &&
+ (stride_width == stride_height) && (pad_width == 0 || pad_width == 1) &&
+ (pad_height == 0 || pad_height == 1) && (pad_width == pad_height) &&
+ (input_depth % 8) == 0 && (output_shift <= 0) &&
+ dilation_width_factor == 1 && dilation_height_factor == 1;
if (!supported) {
return false;
@@ -4436,145 +3204,205 @@ inline bool Fast3x3FilterKernelSupported(const Dims<4>& input_dims,
// Handle case where padding is zero but padding type is not kValid.
// This would require special boundary case handling that is not supported.
- const int out_x = output_width - 1;
- const int out_y = output_height - 1;
+ const int32 out_x = output_width - 1;
+ const int32 out_y = output_height - 1;
- const int in_x_origin = (out_x * stride_width) - pad_width;
- const int in_y_origin = (out_y * stride_height) - pad_height;
+ const int32 in_x_origin = (out_x * stride_width) - pad_width;
+ const int32 in_y_origin = (out_y * stride_height) - pad_height;
- const int in_x_end = in_x_origin + filter_width;
- const int in_y_end = in_y_origin + filter_height;
+ const int32 in_x_end = in_x_origin + filter_width;
+ const int32 in_y_end = in_y_origin + filter_height;
// Supported only if filter on the right and bottom boundary lies completely
- // within the input.
- return in_x_end <= input_width && in_y_end <= input_height;
+  // within the input when padding is zero.
+ if (pad_width == 0 && pad_height == 0) {
+ return in_x_end <= input_width && in_y_end <= input_height;
+ }
+
+  // Otherwise, with padding of 1, the kernel is supported if the filter at
+  // the bottom-right boundary extends at most one element past the input
+  // width and height.
+ supported = in_x_end <= (input_width + 1) && in_y_end <= (input_height + 1);
+
+ if (!supported) {
+ return false;
+ }
+
+  // Shapes with width 1 and height > 1, and vice versa, are not yet
+  // supported.
+ if (input_width == 1) {
+ supported = (input_width == input_height);
+ } else if (input_height == 1) {
+ supported = (input_width == input_height);
+ }
+ return supported;
}
inline void DepthwiseConv3x3Filter(
- const uint8* input_data, const Dims<4>& input_dims, int32 input_offset,
- const uint8* filter_data, const Dims<4>& filter_dims, int32 filter_offset,
- const int32* bias_data, const Dims<4>& bias_dims, int stride_width,
- int stride_height, int pad_width, int pad_height, int depth_multiplier,
- int32 output_offset, int32 output_multiplier, int output_shift,
- int32 output_activation_min, int32 output_activation_max,
- uint8* output_data, const Dims<4>& output_dims) {
- const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
- const int output_depth = MatchingArraySize(filter_dims, 0, output_dims, 0);
- const int input_height = ArraySize(input_dims, 2);
- const int input_width = ArraySize(input_dims, 1);
- const int input_depth = ArraySize(input_dims, 0);
- const int filter_height = ArraySize(filter_dims, 2);
- const int filter_width = ArraySize(filter_dims, 1);
- const int output_height = ArraySize(output_dims, 2);
- const int output_width = ArraySize(output_dims, 1);
-
- // Algorithm assumes below constraints. It is optimized for depth multiplier
- // of 1, 3x3 filter, no padding and strides 1 and 2.
- TFLITE_DCHECK(output_depth == input_depth * depth_multiplier);
+ const DepthwiseParams& rt_params, const RuntimeShape& input_shape,
+ const uint8* input_data, const RuntimeShape& filter_shape,
+ const uint8* filter_data, const RuntimeShape& bias_shape,
+ const int32* bias_data, const RuntimeShape& output_shape,
+ uint8* output_data) {
+ gemmlowp::ScopedProfilingLabel label(__PRETTY_FUNCTION__);
+ DepthwiseConvParams params;
+
+ const int32 stride_width = rt_params.stride_width;
+ const int32 stride_height = rt_params.stride_height;
+ const int32 pad_width = rt_params.padding_values.width;
+ const int32 pad_height = rt_params.padding_values.height;
+ const int32 depth_multiplier = rt_params.depth_multiplier;
+ const int32 output_activation_min = rt_params.quantized_activation_min;
+ const int32 output_activation_max = rt_params.quantized_activation_max;
+ const int32 input_offset = rt_params.input_offset;
+ const int32 filter_offset = rt_params.weights_offset;
+ const int32 output_offset = rt_params.output_offset;
+ const int32 output_multiplier = rt_params.output_multiplier;
+ const int32 output_shift = rt_params.output_shift;
+
+ params.input_depth = input_shape.Dims(3);
+ params.input_width = input_shape.Dims(2);
+ params.input_height = input_shape.Dims(1);
+ params.input_row_size = params.input_depth * params.input_width;
+ params.input_offset = input_offset;
+ params.stride_width = stride_width;
+ params.stride_height = stride_height;
+ params.output_depth = MatchingDim(filter_shape, 3, output_shape, 3);
+ params.output_width = output_shape.Dims(2);
+ params.output_height = output_shape.Dims(1);
+ params.output_row_size = params.output_depth * params.output_width;
+ params.output_offset = output_offset;
+ params.filter_offset = filter_offset;
+ params.output_multiplier = output_multiplier;
+ params.output_right_shift = -output_shift;
+ params.output_activation_min = output_activation_min;
+ params.output_activation_max = output_activation_max;
+
+ const int32 filter_height = filter_shape.Dims(1);
+ const int32 filter_width = filter_shape.Dims(2);
+ params.filter_row_size = params.output_depth * filter_width;
+
+  // The algorithm assumes the constraints checked below. It is optimized for
+  // a depth multiplier of 1, a 3x3 filter, padding of 0 or 1, and strides of
+  // 1 and 2.
+ TFLITE_DCHECK(params.output_depth == params.input_depth * depth_multiplier);
TFLITE_DCHECK(depth_multiplier == 1);
TFLITE_DCHECK(filter_height == 3);
TFLITE_DCHECK(filter_width == 3);
- TFLITE_DCHECK(pad_height == 0);
- TFLITE_DCHECK(pad_width == 0);
TFLITE_DCHECK(stride_height == 1 || stride_height == 2);
TFLITE_DCHECK(stride_width == 1 || stride_width == 2);
TFLITE_DCHECK(stride_width == stride_height);
+ TFLITE_DCHECK(pad_height == 0 || pad_height == 1);
+ TFLITE_DCHECK(pad_width == 0 || pad_width == 1);
+ TFLITE_DCHECK(pad_width == pad_height);
+
+ const int32 batches = MatchingDim(input_shape, 0, output_shape, 0);
+ const int64_t input_batch_size = params.input_row_size * params.input_height;
+ const int64_t output_batch_size =
+ params.output_row_size * params.output_height;
+
+ ShuffleParams one_row_shuffle_params, two_row_shuffle_params,
+ four_row_shuffle_params, eight_row_shuffle_params;
+ if (stride_width == 1) {
+ one_row_shuffle_params = ShuffleParams(30, 1, 1, 1);
+ two_row_shuffle_params = ShuffleParams(22, 2, 1, 1);
+ four_row_shuffle_params = ShuffleParams(14, 4, 1, 1);
+ eight_row_shuffle_params = ShuffleParams(8, 8, 1, 1);
+ } else {
+ one_row_shuffle_params = ShuffleParams(14, 1, 2, 2);
+ two_row_shuffle_params = ShuffleParams(8, 2, 2, 2);
+ four_row_shuffle_params = ShuffleParams(4, 4, 2, 2);
+ eight_row_shuffle_params = ShuffleParams(2, 8, 2, 2);
+ }
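+  // Each shape above is sized so that a depth-64 shuffled block fits in the
+  // shuffle workspace (DepthwiseConvMultiRow checks 64 * input_width *
+  // input_height against DEPTHWISECONV_SHUFFLE_WORKSPACE_SIZE); e.g. the
+  // stride-1 eight-row shape covers a 10x10 input window.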
- const int input_row_size = input_depth * (input_width + 2 * pad_width);
- const int output_row_size = output_depth * output_width;
- const int input_batch_size = input_row_size * (input_height + 2 * pad_height);
- const int output_batch_size = output_depth * output_width * output_height;
-
- using conv_row_func_t = decltype(&ConvRow3x3FilterDepth8<1, 1, 1>::Run);
- conv_row_func_t conv_1_output_row = ConvRow3x3FilterDepth8<1, 1, 1>::Run;
- conv_row_func_t conv_2_output_rows = ConvRow3x3FilterDepth8<2, 1, 1>::Run;
- conv_row_func_t conv_4_output_rows = ConvRow3x3FilterDepth8<4, 1, 1>::Run;
- conv_row_func_t conv_8_output_rows = ConvRow3x3FilterDepth8<8, 1, 1>::Run;
-
+ using conv_multirow_func_t = decltype(&DepthwiseConvMultiRow<1, 1>::Run);
+ conv_multirow_func_t conv_multirow_func = DepthwiseConvMultiRow<1, 1>::Run;
if (stride_width == 2) {
- conv_1_output_row = ConvRow3x3FilterDepth8<1, 2, 2>::Run;
- conv_2_output_rows = ConvRow3x3FilterDepth8<2, 2, 2>::Run;
- conv_4_output_rows = ConvRow3x3FilterDepth8<4, 2, 2>::Run;
- conv_8_output_rows = ConvRow3x3FilterDepth8<8, 2, 2>::Run;
+ conv_multirow_func = DepthwiseConvMultiRow<2, 2>::Run;
}
// Allocate maximum memory needed for shuffled input.
// TODO(mariewhite): The size of this workspace is small enough to be
// allocated on the stack. Eventually we will want to move it to the heap
- // and have it allocated outside of this function, like the im2col_array used
- // in gemmlowp.
-#define DEPTHWISECONV_SHUFFLE_WORKSPACE_SIZE 10 * 10 * 64
+ // and have it allocated outside of this function, like the im2col_array
+ // used in gemmlowp.
uint8 shuffle_workspace[DEPTHWISECONV_SHUFFLE_WORKSPACE_SIZE];
- // Make sure the kernels using this buffer will not run out of bounds.
- static_assert(ConvRow3x3FilterDepth8<8, 1, 1>::ShuffleWorkspaceSize() <=
- DEPTHWISECONV_SHUFFLE_WORKSPACE_SIZE,
- "Shuffle workspace size is too small.");
- static_assert(ConvRow3x3FilterDepth8<4, 2, 2>::ShuffleWorkspaceSize() <=
- DEPTHWISECONV_SHUFFLE_WORKSPACE_SIZE,
- "Shuffle workspace size is too small.");
-
-#undef DEPTHWISECONV_SHUFFLE_WORKSPACE_SIZE
-
- for (int b = 0; b < batches; ++b) {
+ for (int32 b = 0; b < batches; ++b) {
const uint8* input_ptr = input_data + b * input_batch_size;
uint8* output_ptr = output_data + b * output_batch_size;
- int out_y = 0;
+ int32 out_x = 0;
+ int32 out_y = 0;
+ int32 end_x = params.output_width;
+ int32 end_y = params.output_height;
+
+ if (pad_width == 1 && pad_height == 1) {
+ DepthwiseConvHandlePadding(input_ptr, filter_data, bias_data, output_ptr,
+ params);
+
+ // Update extents now that the edges have been handled.
+ out_x = 1;
+ end_x = params.output_width - 1;
+ out_y = 1;
+ end_y = params.output_height - 1;
+ const int in_x = (out_x * stride_width) - pad_width;
+ const int in_y = (out_y * stride_height) - pad_height;
+ input_ptr += in_y * params.input_row_size + in_x * params.input_depth;
+ output_ptr += out_y * params.output_row_size
+ + out_x * params.output_depth;
+ }
+
+    // Shuffle shapes that maximize width within the shuffle workspace
+    // perform better, since the inputs lie closer together and shuffling
+    // takes less time.
+    //
+    // If the input shape is wide enough for the 2-row kernels, we prefer
+    // them: their innermost loop handles 2 rows x 2 columns of output, so
+    // this is the fastest path.
+    //
+    // If the input shape is narrower but taller, shuffling is still useful
+    // and can benefit from the 4-row and 8-row kernels.
// Handle 8 rows at a time.
- for (; out_y <= output_height - 8; out_y += 8) {
- conv_8_output_rows(input_ptr, 0, out_y, input_depth, input_width,
- input_height, input_row_size, input_offset,
- filter_data, filter_offset, bias_data, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_ptr, output_depth,
- output_width, shuffle_workspace);
-
- input_ptr += 8 * stride_height * input_row_size;
- output_ptr += 8 * output_row_size;
+ if (params.input_width < four_row_shuffle_params.input_width) {
+ for (; out_y <= end_y - 8; out_y += 8) {
+ conv_multirow_func(input_ptr, out_x, end_x, filter_data, bias_data,
+ output_ptr, params, eight_row_shuffle_params,
+ shuffle_workspace);
+ input_ptr += 8 * stride_height * params.input_row_size;
+ output_ptr += 8 * params.output_row_size;
+ }
}
// Handle 4 rows at a time.
- for (; out_y <= output_height - 4; out_y += 4) {
- conv_4_output_rows(input_ptr, 0, out_y, input_depth, input_width,
- input_height, input_row_size, input_offset,
- filter_data, filter_offset, bias_data, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_ptr, output_depth,
- output_width, shuffle_workspace);
-
- input_ptr += 4 * stride_height * input_row_size;
- output_ptr += 4 * output_row_size;
+ if (params.input_width < two_row_shuffle_params.input_width) {
+ for (; out_y <= end_y - 4; out_y += 4) {
+ conv_multirow_func(input_ptr, out_x, end_x, filter_data, bias_data,
+ output_ptr, params, four_row_shuffle_params,
+ shuffle_workspace);
+ input_ptr += 4 * stride_height * params.input_row_size;
+ output_ptr += 4 * params.output_row_size;
+ }
}
// Handle 2 rows at a time.
- for (; out_y <= output_height - 2; out_y += 2) {
- conv_2_output_rows(input_ptr, 0, out_y, input_depth, input_width,
- input_height, input_row_size, input_offset,
- filter_data, filter_offset, bias_data, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_ptr, output_depth,
- output_width, shuffle_workspace);
-
- input_ptr += 2 * stride_height * input_row_size;
- output_ptr += 2 * output_row_size;
+ for (; out_y <= end_y - 2; out_y += 2) {
+ conv_multirow_func(input_ptr, out_x, end_x, filter_data, bias_data,
+ output_ptr, params, two_row_shuffle_params,
+ shuffle_workspace);
+ input_ptr += 2 * stride_height * params.input_row_size;
+ output_ptr += 2 * params.output_row_size;
}
// Handle one row at a time.
- for (; out_y < output_height; out_y++) {
- conv_1_output_row(input_ptr, 0, out_y, input_depth, input_width,
- input_height, input_row_size, input_offset, filter_data,
- filter_offset, bias_data, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_ptr, output_depth,
- output_width, shuffle_workspace);
-
- input_ptr += stride_height * input_row_size;
- output_ptr += output_row_size;
+ for (; out_y < end_y; out_y++) {
+ conv_multirow_func(input_ptr, out_x, end_x, filter_data, bias_data,
+ output_ptr, params, one_row_shuffle_params,
+ shuffle_workspace);
+ input_ptr += stride_height * params.input_row_size;
+ output_ptr += params.output_row_size;
}
}
}
+// clang-format on
#endif // __aarch64__
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_google.h b/tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_google.h
index d85e06a5d5..6443f425b7 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_google.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_google.h
@@ -33,7 +33,7 @@ limitations under the License.
#include <functional>
#ifdef _WIN32
-#include <winbase.h>
+#include <windows.h>
#elif defined(__APPLE__)
#include <mach/mach_time.h>
#else
@@ -140,4 +140,4 @@ limitations under the License.
#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorIO.h"
#include "Eigen/src/Core/util/ReenableStupidWarnings.h"
-#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_TENSOR_REDUCED_INSTANTIATIONS_H
+#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_TENSOR_REDUCED_INSTANTIATIONS_GOOGLE_H_
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h
new file mode 100644
index 0000000000..4218be20a4
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h
@@ -0,0 +1,1872 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_LEGACY_OPTIMIZED_OPS_H_
+#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_LEGACY_OPTIMIZED_OPS_H_
+
+#include <stdint.h>
+#include <sys/types.h>
+
+#include "tensorflow/contrib/lite/kernels/internal/common.h"
+#include "tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h"
+#include "tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h"
+#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/types.h"
+
+namespace tflite {
+namespace optimized_ops {
+
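+// Legacy Dims<4>-based entry points. Each wrapper below packs its flat
+// argument list into the corresponding *Params struct and forwards to the
+// newer shape-based implementations pulled in from the headers above; ops
+// without an optimized implementation are aliased from reference_ops below.
+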
+// Unoptimized reference ops:
+using reference_ops::ArgMax;
+using reference_ops::ArgMinMax;
+using reference_ops::Broadcast4DSlowGreater;
+using reference_ops::Broadcast4DSlowGreaterEqual;
+using reference_ops::Broadcast4DSlowGreaterEqualWithScaling;
+using reference_ops::Broadcast4DSlowGreaterWithScaling;
+using reference_ops::Broadcast4DSlowLess;
+using reference_ops::Broadcast4DSlowLessEqual;
+using reference_ops::Broadcast4DSlowLessEqualWithScaling;
+using reference_ops::Broadcast4DSlowLessWithScaling;
+using reference_ops::BroadcastAdd4DSlow;
+using reference_ops::BroadcastGreater;
+using reference_ops::BroadcastGreaterEqual;
+using reference_ops::BroadcastLess;
+using reference_ops::BroadcastLessEqual;
+using reference_ops::BroadcastMul4DSlow;
+using reference_ops::BroadcastSub4DSlow;
+using reference_ops::Concatenation;
+using reference_ops::ConcatenationWithScaling;
+using reference_ops::DepthConcatenation;
+using reference_ops::Dequantize;
+using reference_ops::Div;
+using reference_ops::FakeQuant;
+using reference_ops::Gather;
+using reference_ops::Greater;
+using reference_ops::GreaterEqual;
+using reference_ops::GreaterEqualWithScaling;
+using reference_ops::GreaterWithScaling;
+using reference_ops::Less;
+using reference_ops::LessEqual;
+using reference_ops::LessEqualWithScaling;
+using reference_ops::LessWithScaling;
+using reference_ops::Mean;
+using reference_ops::RankOneSelect;
+using reference_ops::Relu1;
+using reference_ops::Relu6;
+using reference_ops::ReluX;
+using reference_ops::Select;
+using reference_ops::SpaceToBatchND;
+using reference_ops::Split;
+using reference_ops::StridedSlice;
+using reference_ops::TensorFlowSplit;
+using reference_ops::Transpose;
+
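+// kDepthwiseReverseShift flips the sign of the shift amounts taken by the
+// legacy depthwise entry points so they match the new convention where
+// positive means a left shift; e.g. a legacy output_shift of 3 becomes -3.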
+static constexpr int kDepthwiseReverseShift = -1;
+
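+// The MapAs* helpers below wrap a flat buffer described by a Dims<N> in an
+// Eigen-style view without copying; e.g. MapAsMatrixWithFirstDimAsRows on
+// sizes {d0, d1, d2, d3} yields a d0 x (d1 * d2 * d3) matrix over the same
+// data.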
+template <typename Scalar, int N>
+VectorMap<Scalar> MapAsVector(Scalar* data, const Dims<N>& dims) {
+ const int size = FlatSize(dims);
+ return VectorMap<Scalar>(data, size, 1);
+}
+
+template <typename Scalar, int N>
+MatrixMap<Scalar> MapAsMatrixWithFirstDimAsRows(Scalar* data,
+ const Dims<N>& dims) {
+ const int rows = dims.sizes[0];
+ int cols = 1;
+ for (int d = 1; d < N; d++) {
+ cols *= dims.sizes[d];
+ }
+ return MatrixMap<Scalar>(data, rows, cols);
+}
+
+template <typename Scalar, int N>
+MatrixMap<Scalar> MapAsMatrixWithLastDimAsCols(Scalar* data,
+ const Dims<N>& dims) {
+ const int cols = dims.sizes[N - 1];
+ int rows = 1;
+ for (int d = 0; d < N - 1; d++) {
+ rows *= dims.sizes[d];
+ }
+ return MatrixMap<Scalar>(data, rows, cols);
+}
+
+template <typename Scalar, int N>
+ArrayMap<Scalar> MapAsArrayWithFirstDimAsRows(Scalar* data,
+ const Dims<N>& dims) {
+ const int rows = dims.sizes[0];
+ int cols = 1;
+ for (int d = 1; d < N; d++) {
+ cols *= dims.sizes[d];
+ }
+ return ArrayMap<Scalar>(data, rows, cols);
+}
+
+// TODO(b/62193649): this function is only needed as long
+// as we have the --variable_batch hack.
+template <typename Scalar, int N>
+MatrixMap<Scalar> MapAsMatrixWithGivenNumberOfRows(Scalar* data,
+ const Dims<N>& dims,
+ int rows) {
+ const int flatsize = FlatSize(dims);
+ TFLITE_DCHECK((flatsize % rows) == 0);
+ const int cols = flatsize / rows;
+ return MatrixMap<Scalar>(data, rows, cols);
+}
+
+inline bool AreSameDims(const Dims<4>& dims1, const Dims<4>& dims2) {
+ for (int i = 0; i < 4; i++) {
+ if (dims1.sizes[i] != dims2.sizes[i]) {
+ return false;
+ }
+ }
+ return true;
+}
+
+inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
+ const float* filter_data, const Dims<4>& filter_dims,
+ const float* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height,
+ int dilation_width_factor, int dilation_height_factor,
+ int pad_width, int pad_height, int depth_multiplier,
+ float output_activation_min,
+ float output_activation_max, float* output_data,
+ const Dims<4>& output_dims) {
+ tflite::DepthwiseParams op_params;
+ // Padding type is ignored, but still set.
+ op_params.padding_type = PaddingType::kSame;
+ op_params.padding_values.width = pad_width;
+ op_params.padding_values.height = pad_height;
+ op_params.stride_width = stride_width;
+ op_params.stride_height = stride_height;
+ op_params.dilation_width_factor = dilation_width_factor;
+ op_params.dilation_height_factor = dilation_height_factor;
+ op_params.depth_multiplier = depth_multiplier;
+ op_params.float_activation_min = output_activation_min;
+ op_params.float_activation_max = output_activation_max;
+
+ DepthwiseConv(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(filter_dims), filter_data, DimsToShape(bias_dims),
+ bias_data, DimsToShape(output_dims), output_data);
+}
+
+inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
+ const float* filter_data, const Dims<4>& filter_dims,
+ const float* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int depth_multiplier,
+ float output_activation_min,
+ float output_activation_max, float* output_data,
+ const Dims<4>& output_dims) {
+ DepthwiseConv(input_data, input_dims, filter_data, filter_dims, bias_data,
+ bias_dims, stride_width, stride_height, 1, 1, pad_width,
+ pad_height, depth_multiplier, output_activation_min,
+ output_activation_max, output_data, output_dims);
+}
+
+// legacy, for compatibility with old checked-in code
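+// The Ac template parameter is resolved to a numeric clamp via
+// GetActivationMinMax (kRelu clamps below at zero, kRelu6 to [0, 6], kRelu1
+// to [-1, 1]; kNone leaves the range effectively unbounded).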
+template <FusedActivationFunctionType Ac>
+void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
+ const float* filter_data, const Dims<4>& filter_dims,
+ const float* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int depth_multiplier, float* output_data,
+ const Dims<4>& output_dims) {
+ float output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+ DepthwiseConv(input_data, input_dims, filter_data, filter_dims, bias_data,
+ bias_dims, stride_width, stride_height, pad_width, pad_height,
+ depth_multiplier, output_activation_min, output_activation_max,
+ output_data, output_dims);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
+ const float* filter_data, const Dims<4>& filter_dims,
+ const float* bias_data, const Dims<4>& bias_dims, int stride,
+ int pad_width, int pad_height, int depth_multiplier,
+ float* output_data, const Dims<4>& output_dims) {
+ DepthwiseConv<Ac>(input_data, input_dims, filter_data, filter_dims, bias_data,
+ bias_dims, stride, stride, pad_width, pad_height,
+ depth_multiplier, output_data, output_dims);
+}
+
+inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height,
+ int dilation_width_factor, int dilation_height_factor,
+ int pad_width, int pad_height, int depth_multiplier,
+ int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ tflite::DepthwiseParams op_params;
+ // Padding type is ignored, but still set.
+ op_params.padding_type = PaddingType::kSame;
+ op_params.padding_values.width = pad_width;
+ op_params.padding_values.height = pad_height;
+ op_params.stride_width = stride_width;
+ op_params.stride_height = stride_height;
+ op_params.dilation_width_factor = dilation_width_factor;
+ op_params.dilation_height_factor = dilation_height_factor;
+ op_params.depth_multiplier = depth_multiplier;
+ op_params.quantized_activation_min = output_activation_min;
+ op_params.quantized_activation_max = output_activation_max;
+ op_params.input_offset = input_offset;
+ op_params.weights_offset = filter_offset;
+ op_params.output_offset = output_offset;
+ op_params.output_multiplier = output_multiplier;
+ // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
+ op_params.output_shift = kDepthwiseReverseShift * output_shift;
+
+ DepthwiseConv(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(filter_dims), filter_data, DimsToShape(bias_dims),
+ bias_data, DimsToShape(output_dims), output_data);
+}
+
+inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int depth_multiplier,
+ int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ DepthwiseConv(input_data, input_dims, input_offset, filter_data, filter_dims,
+ filter_offset, bias_data, bias_dims, stride_width,
+ stride_height, 1, 1, pad_width, pad_height, depth_multiplier,
+ output_offset, output_multiplier, output_shift,
+ output_activation_min, output_activation_max, output_data,
+ output_dims);
+}
+
+// Legacy, for compatibility with old checked-in code.
+template <FusedActivationFunctionType Ac>
+void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int depth_multiplier, int32 output_offset,
+ int32 output_multiplier, int output_shift,
+ int32 output_activation_min, int32 output_activation_max,
+ uint8* output_data, const Dims<4>& output_dims) {
+ if (Ac == FusedActivationFunctionType::kNone) {
+ TFLITE_DCHECK_EQ(output_activation_min, 0);
+ TFLITE_DCHECK_EQ(output_activation_max, 255);
+ }
+ DepthwiseConv(input_data, input_dims, input_offset, filter_data, filter_dims,
+ filter_offset, bias_data, bias_dims, stride_width,
+ stride_height, pad_width, pad_height, depth_multiplier,
+ output_offset, output_multiplier, output_shift,
+ output_activation_min, output_activation_max, output_data,
+ output_dims);
+}
+
+// Legacy, for compatibility with old checked-in code.
+template <FusedActivationFunctionType Ac>
+void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims, int stride,
+ int pad_width, int pad_height, int depth_multiplier,
+ int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ DepthwiseConv<Ac>(input_data, input_dims, input_offset, filter_data,
+ filter_dims, filter_offset, bias_data, bias_dims, stride,
+ stride, pad_width, pad_height, depth_multiplier,
+ output_offset, output_multiplier, output_shift,
+ output_activation_min, output_activation_max, output_data,
+ output_dims);
+}
+
+inline void AddBiasAndEvalActivationFunction(const float* bias_data,
+ const Dims<4>& bias_dims,
+ float* array_data,
+ const Dims<4>& array_dims,
+ float output_activation_min,
+ float output_activation_max) {
+ AddBiasAndEvalActivationFunction(output_activation_min, output_activation_max,
+ DimsToShape(bias_dims), bias_data,
+ DimsToShape(array_dims), array_data);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void AddBiasAndEvalActivationFunction(const float* bias_data,
+ const Dims<4>& bias_dims,
+ float* array_data,
+ const Dims<4>& array_dims) {
+ float output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+ AddBiasAndEvalActivationFunction(bias_data, bias_dims, array_data, array_dims,
+ output_activation_min,
+ output_activation_max);
+}
+
+inline void FullyConnected(const float* input_data, const Dims<4>& input_dims,
+ const float* weights_data,
+ const Dims<4>& weights_dims, const float* bias_data,
+ const Dims<4>& bias_dims,
+ float output_activation_min,
+ float output_activation_max, float* output_data,
+ const Dims<4>& output_dims) {
+ tflite::FullyConnectedParams op_params;
+ op_params.float_activation_min = output_activation_min;
+ op_params.float_activation_max = output_activation_max;
+
+ FullyConnected(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(weights_dims), weights_data,
+ DimsToShape(bias_dims), bias_data, DimsToShape(output_dims),
+ output_data);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void FullyConnected(const float* input_data, const Dims<4>& input_dims,
+ const float* weights_data, const Dims<4>& weights_dims,
+ const float* bias_data, const Dims<4>& bias_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ float output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+ FullyConnected(input_data, input_dims, weights_data, weights_dims, bias_data,
+ bias_dims, output_activation_min, output_activation_max,
+ output_data, output_dims);
+}
+
+inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims,
+ int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims,
+ gemmlowp::GemmContext* gemm_context) {
+ tflite::FullyConnectedParams op_params;
+ op_params.input_offset = input_offset;
+ op_params.weights_offset = filter_offset;
+ op_params.output_offset = output_offset;
+ op_params.output_multiplier = output_multiplier;
+ // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
+ op_params.output_shift = kReverseShift * output_shift;
+ op_params.quantized_activation_min = output_activation_min;
+ op_params.quantized_activation_max = output_activation_max;
+
+ FullyConnected(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(filter_dims), filter_data, DimsToShape(bias_dims),
+ bias_data, DimsToShape(output_dims), output_data,
+ gemm_context);
+}
+
+inline void FullyConnected(
+ const uint8* input_data, const Dims<4>& input_dims, int32 input_offset,
+ const uint8* filter_data, const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data_int32, const Dims<4>& bias_dims, int32 output_offset,
+ int32 output_multiplier, int output_shift, int32 output_activation_min,
+ int32 output_activation_max, int16* output_data, const Dims<4>& output_dims,
+ gemmlowp::GemmContext* gemm_context) {
+ tflite::FullyConnectedParams op_params;
+ op_params.input_offset = input_offset;
+ op_params.weights_offset = filter_offset;
+ op_params.output_offset = output_offset;
+ op_params.output_multiplier = output_multiplier;
+ // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
+ op_params.output_shift = kReverseShift * output_shift;
+ op_params.quantized_activation_min = output_activation_min;
+ op_params.quantized_activation_max = output_activation_max;
+
+ FullyConnected(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(filter_dims), filter_data, DimsToShape(bias_dims),
+ bias_data_int32, DimsToShape(output_dims), output_data,
+ gemm_context);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims,
+ int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims,
+ gemmlowp::GemmContext* gemm_context) {
+ static_assert(Ac == FusedActivationFunctionType::kNone ||
+ Ac == FusedActivationFunctionType::kRelu ||
+ Ac == FusedActivationFunctionType::kRelu6 ||
+ Ac == FusedActivationFunctionType::kRelu1,
+ "");
+ FullyConnected(input_data, input_dims, input_offset, filter_data, filter_dims,
+ filter_offset, bias_data, bias_dims, output_offset,
+ output_multiplier, output_shift, output_activation_min,
+ output_activation_max, output_data, output_dims, gemm_context);
+}
+
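+// The shuffled variant takes weights that have been pre-rearranged for the
+// optimized kernel and a caller-provided workspace
+// (shuffled_input_workspace_data) used to rearrange the input to match.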
+inline void ShuffledFullyConnected(
+ const uint8* input_data, const Dims<4>& input_dims,
+ const uint8* shuffled_weights_data, const Dims<4>& weights_dims,
+ const int32* bias_data, const Dims<4>& bias_dims, int32 output_multiplier,
+ int output_shift, int32 output_activation_min, int32 output_activation_max,
+ int16* output_data, const Dims<4>& output_dims,
+ uint8* shuffled_input_workspace_data, gemmlowp::GemmContext* gemm_context) {
+ tflite::FullyConnectedParams op_params;
+ op_params.output_multiplier = output_multiplier;
+ // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
+ op_params.output_shift = kReverseShift * output_shift;
+ op_params.quantized_activation_min = output_activation_min;
+ op_params.quantized_activation_max = output_activation_max;
+
+ ShuffledFullyConnected(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(weights_dims), shuffled_weights_data,
+ DimsToShape(bias_dims), bias_data,
+ DimsToShape(output_dims), output_data,
+ shuffled_input_workspace_data, gemm_context);
+}
+
+template <typename T>
+inline void ExtractPatchIntoBufferColumn(
+ const Dims<4>& input_dims, int w, int h, int b, int kheight, int kwidth,
+ int stride_width, int stride_height, int pad_width, int pad_height,
+ int in_width, int in_height, int in_depth, int single_buffer_length,
+ int buffer_id, const T* in_data, T* conv_buffer_data, uint8 zero_byte) {
+ ExtractPatchIntoBufferColumn(
+ DimsToShape(input_dims), w, h, b, kheight, kwidth, stride_width,
+ stride_height, pad_width, pad_height, in_width, in_height, in_depth,
+ single_buffer_length, buffer_id, in_data, conv_buffer_data, zero_byte);
+}
+
+template <typename T>
+void DilatedIm2col(const T* input_data, const Dims<4>& input_dims,
+ const Dims<4>& filter_dims, int stride_width,
+ int stride_height, int dilation_width_factor,
+ int dilation_height_factor, int pad_width, int pad_height,
+ const Dims<4>& output_dims, uint8 zero_byte,
+ T* im2col_data) {
+ tflite::ConvParams op_params;
+ // Padding type is ignored, but still set.
+ op_params.padding_type = PaddingType::kSame;
+ op_params.padding_values.width = pad_width;
+ op_params.padding_values.height = pad_height;
+ op_params.stride_width = stride_width;
+ op_params.stride_height = stride_height;
+ op_params.dilation_width_factor = dilation_width_factor;
+ op_params.dilation_height_factor = dilation_height_factor;
+
+ DilatedIm2col(op_params, zero_byte, DimsToShape(input_dims), input_data,
+ DimsToShape(filter_dims), DimsToShape(output_dims),
+ im2col_data);
+}
+
+template <typename T>
+void Im2col(const T* input_data, const Dims<4>& input_dims, int stride_width,
+ int stride_height, int pad_width, int pad_height, int kheight,
+ int kwidth, uint8 zero_byte, T* output_data,
+ const Dims<4>& output_dims) {
+ tflite::ConvParams op_params;
+ // Padding type is ignored, but still set.
+ op_params.padding_type = PaddingType::kSame;
+ op_params.padding_values.width = pad_width;
+ op_params.padding_values.height = pad_height;
+ op_params.stride_width = stride_width;
+ op_params.stride_height = stride_height;
+ op_params.dilation_width_factor = 1;
+ op_params.dilation_height_factor = 1;
+
+ Im2col(op_params, kheight, kwidth, zero_byte, DimsToShape(input_dims),
+ input_data, DimsToShape(output_dims), output_data);
+}
+
+// legacy, for compatibility with old checked-in code
+template <typename T>
+void Im2col(const T* input_data, const Dims<4>& input_dims, int stride,
+ int pad_width, int pad_height, int kheight, int kwidth,
+ uint8 zero_byte, T* output_data, const Dims<4>& output_dims) {
+ Im2col(input_data, input_dims, stride, stride, pad_width, pad_height, kheight,
+ kwidth, zero_byte, output_data, output_dims);
+}
+
+inline void Conv(const float* input_data, const Dims<4>& input_dims,
+ const float* filter_data, const Dims<4>& filter_dims,
+ const float* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height, int dilation_width_factor,
+ int dilation_height_factor, int pad_width, int pad_height,
+ float output_activation_min, float output_activation_max,
+ float* output_data, const Dims<4>& output_dims,
+ float* im2col_data, const Dims<4>& im2col_dims) {
+ tflite::ConvParams op_params;
+ // Padding type is ignored, but still set.
+ op_params.padding_type = PaddingType::kSame;
+ op_params.padding_values.width = pad_width;
+ op_params.padding_values.height = pad_height;
+ op_params.stride_width = stride_width;
+ op_params.stride_height = stride_height;
+ op_params.dilation_width_factor = dilation_width_factor;
+ op_params.dilation_height_factor = dilation_height_factor;
+ op_params.float_activation_min = output_activation_min;
+ op_params.float_activation_max = output_activation_max;
+
+ Conv(op_params, DimsToShape(input_dims), input_data, DimsToShape(filter_dims),
+ filter_data, DimsToShape(bias_dims), bias_data, DimsToShape(output_dims),
+ output_data, DimsToShape(im2col_dims), im2col_data);
+}
+
+inline void HybridConv(const int8_t* input_data, const Dims<4>& input_dims,
+ const int8_t* filter_data, const Dims<4>& filter_dims,
+ const float* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, float* scaling_factors_ptr,
+ float output_activation_min, float output_activation_max,
+ float* output_data, const Dims<4>& output_dims,
+ int8_t* im2col_data, const Dims<4>& im2col_dims) {
+ tflite::ConvParams op_params;
+ // Padding type is ignored, but still set.
+ op_params.padding_type = PaddingType::kSame;
+ op_params.padding_values.width = pad_width;
+ op_params.padding_values.height = pad_height;
+ op_params.stride_width = stride_width;
+ op_params.stride_height = stride_height;
+ op_params.float_activation_min = output_activation_min;
+ op_params.float_activation_max = output_activation_max;
+
+ HybridConv(op_params, scaling_factors_ptr, DimsToShape(input_dims),
+ input_data, DimsToShape(filter_dims), filter_data,
+ DimsToShape(bias_dims), bias_data, DimsToShape(output_dims),
+ output_data, DimsToShape(im2col_dims), im2col_data);
+}
+
+template <FusedActivationFunctionType Ac>
+void Conv(const float* input_data, const Dims<4>& input_dims,
+ const float* filter_data, const Dims<4>& filter_dims,
+ const float* bias_data, const Dims<4>& bias_dims, int stride_width,
+ int stride_height, int dilation_width_factor,
+ int dilation_height_factor, int pad_width, int pad_height,
+ float* output_data, const Dims<4>& output_dims, float* im2col_data,
+ const Dims<4>& im2col_dims) {
+ float output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+ Conv(input_data, input_dims, filter_data, filter_dims, bias_data, bias_dims,
+ stride_width, stride_height, dilation_width_factor,
+ dilation_height_factor, pad_width, pad_height, output_activation_min,
+ output_activation_max, output_data, output_dims, im2col_data,
+ im2col_dims);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void Conv(const float* input_data, const Dims<4>& input_dims,
+ const float* filter_data, const Dims<4>& filter_dims,
+ const float* bias_data, const Dims<4>& bias_dims, int stride_width,
+ int stride_height, int pad_width, int pad_height, float* output_data,
+ const Dims<4>& output_dims, float* im2col_data,
+ const Dims<4>& im2col_dims) {
+ float output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+ Conv(input_data, input_dims, filter_data, filter_dims, bias_data, bias_dims,
+ stride_width, stride_height, 1, 1, pad_width, pad_height,
+ output_activation_min, output_activation_max, output_data, output_dims,
+ im2col_data, im2col_dims);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void Conv(const float* input_data, const Dims<4>& input_dims,
+ const float* filter_data, const Dims<4>& filter_dims,
+ const float* bias_data, const Dims<4>& bias_dims, int stride,
+ int pad_width, int pad_height, float* output_data,
+ const Dims<4>& output_dims, float* im2col_data,
+ const Dims<4>& im2col_dims) {
+ Conv<Ac>(input_data, input_dims, filter_data, filter_dims, bias_data,
+ bias_dims, stride, stride, 1, 1, pad_width, pad_height, output_data,
+ output_dims, im2col_data, im2col_dims);
+}
+
+inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height, int dilation_width_factor,
+ int dilation_height_factor, int pad_width, int pad_height,
+ int32 output_offset, int32 output_multiplier, int output_shift,
+ int32 output_activation_min, int32 output_activation_max,
+ uint8* output_data, const Dims<4>& output_dims,
+ uint8* im2col_data, const Dims<4>& im2col_dims,
+ gemmlowp::GemmContext* gemm_context) {
+ tflite::ConvParams op_params;
+ // Padding type is ignored, but still set.
+ op_params.padding_type = PaddingType::kSame;
+ op_params.padding_values.width = pad_width;
+ op_params.padding_values.height = pad_height;
+ op_params.stride_width = stride_width;
+ op_params.stride_height = stride_height;
+ op_params.dilation_width_factor = dilation_width_factor;
+ op_params.dilation_height_factor = dilation_height_factor;
+ op_params.input_offset = input_offset;
+ op_params.weights_offset = filter_offset;
+ op_params.output_offset = output_offset;
+ op_params.output_multiplier = output_multiplier;
+ // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
+ op_params.output_shift = kReverseShift * output_shift;
+ op_params.quantized_activation_min = output_activation_min;
+ op_params.quantized_activation_max = output_activation_max;
+
+ Conv(op_params, DimsToShape(input_dims), input_data, DimsToShape(filter_dims),
+ filter_data, DimsToShape(bias_dims), bias_data, DimsToShape(output_dims),
+ output_data, DimsToShape(im2col_dims), im2col_data, gemm_context);
+}
+
+inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims, uint8* im2col_data,
+ const Dims<4>& im2col_dims,
+ gemmlowp::GemmContext* gemm_context) {
+ Conv(input_data, input_dims, input_offset, filter_data, filter_dims,
+ filter_offset, bias_data, bias_dims, stride_width, stride_height, 1, 1,
+ pad_width, pad_height, output_offset, output_multiplier, output_shift,
+ output_activation_min, output_activation_max, output_data, output_dims,
+ im2col_data, im2col_dims, gemm_context);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims, uint8* im2col_data,
+ const Dims<4>& im2col_dims,
+ gemmlowp::GemmContext* gemm_context) {
+ static_assert(Ac == FusedActivationFunctionType::kNone ||
+ Ac == FusedActivationFunctionType::kRelu ||
+ Ac == FusedActivationFunctionType::kRelu6 ||
+ Ac == FusedActivationFunctionType::kRelu1,
+ "");
+ if (Ac == FusedActivationFunctionType::kNone) {
+ TFLITE_DCHECK_EQ(output_activation_min, 0);
+ TFLITE_DCHECK_EQ(output_activation_max, 255);
+ }
+ Conv(input_data, input_dims, input_offset, filter_data, filter_dims,
+ filter_offset, bias_data, bias_dims, stride_width, stride_height,
+ pad_width, pad_height, output_offset, output_multiplier, output_shift,
+ output_activation_min, output_activation_max, output_data, output_dims,
+ im2col_data, im2col_dims, gemm_context);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void Conv(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims, int stride,
+ int pad_width, int pad_height, int32 output_offset,
+ int32 output_multiplier, int output_shift,
+ int32 output_activation_min, int32 output_activation_max,
+ uint8* output_data, const Dims<4>& output_dims, uint8* im2col_data,
+ const Dims<4>& im2col_dims, gemmlowp::GemmContext* gemm_context) {
+ static_assert(Ac == FusedActivationFunctionType::kNone ||
+ Ac == FusedActivationFunctionType::kRelu ||
+ Ac == FusedActivationFunctionType::kRelu6 ||
+ Ac == FusedActivationFunctionType::kRelu1,
+ "");
+ Conv(input_data, input_dims, input_offset, filter_data, filter_dims,
+ filter_offset, bias_data, bias_dims, stride, stride, pad_width,
+ pad_height, output_offset, output_multiplier, output_shift,
+ output_activation_min, output_activation_max, output_data, output_dims,
+ im2col_data, im2col_dims, gemm_context);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac, typename T>
+void Im2col(const T* input_data, const Dims<4>& input_dims, int stride,
+ int pad_width, int pad_height, int kheight, int kwidth,
+ uint8 zero_byte, T* output_data, const Dims<4>& output_dims) {
+ Im2col(input_data, input_dims, stride, stride, pad_width, pad_height, kheight,
+ kwidth, zero_byte, output_data, output_dims);
+}
+
+// legacy, for compatibility with old checked-in code
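+// ConvAsGemm maps the filter with its last dimension as columns and the input
+// with its first dimension as rows, computes output = filter^T * input via
+// Gemm, then applies the fused bias/activation pass; this matches a
+// convolution only in the degenerate case where it is a plain matrix multiply
+// (1x1 filters, unit stride, no padding).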
+template <FusedActivationFunctionType Ac>
+void ConvAsGemm(const float* input_data, const Dims<4>& input_dims,
+ const float* filter_data, const Dims<4>& filter_dims,
+ const float* bias_data, const Dims<4>& bias_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("ConvAsGemm");
+
+ const auto input_matrix_map =
+ MapAsMatrixWithFirstDimAsRows(input_data, input_dims);
+ const auto filter_matrix_map =
+ MapAsMatrixWithLastDimAsCols(filter_data, filter_dims);
+ auto output_matrix_map =
+ MapAsMatrixWithFirstDimAsRows(output_data, output_dims);
+
+ Gemm(filter_matrix_map.transpose(), input_matrix_map, &output_matrix_map);
+
+ AddBiasAndEvalActivationFunction<Ac>(bias_data, bias_dims, output_data,
+ output_dims);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void ConvAsGemm(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims,
+ int32 output_offset, int32 output_multiplier, int output_shift,
+ int32 output_activation_min, int32 output_activation_max,
+ uint8* output_data, const Dims<4>& output_dims,
+ gemmlowp::GemmContext* gemm_context) {
+ gemmlowp::ScopedProfilingLabel label("ConvAsGemm/8bit");
+ static_assert(Ac == FusedActivationFunctionType::kNone ||
+ Ac == FusedActivationFunctionType::kRelu ||
+ Ac == FusedActivationFunctionType::kRelu6 ||
+ Ac == FusedActivationFunctionType::kRelu1,
+ "");
+ const int input_rows = input_dims.sizes[0];
+ const int input_cols = FlatSizeSkipDim(input_dims, 0);
+ const int filter_rows = filter_dims.sizes[3];
+ const int filter_cols = FlatSizeSkipDim(filter_dims, 3);
+ const int output_rows = output_dims.sizes[0];
+ const int output_cols = FlatSizeSkipDim(output_dims, 0);
+ TFLITE_DCHECK_EQ(output_rows, filter_rows);
+ TFLITE_DCHECK_EQ(output_cols, input_cols);
+ TFLITE_DCHECK_EQ(filter_cols, input_rows);
+ TFLITE_DCHECK_EQ(bias_dims.sizes[0], output_rows);
+ TFLITE_DCHECK_EQ(bias_dims.sizes[1], 1);
+ TFLITE_DCHECK_EQ(bias_dims.sizes[2], 1);
+ TFLITE_DCHECK_EQ(bias_dims.sizes[3], 1);
+ gemmlowp::MatrixMap<const uint8, gemmlowp::MapOrder::RowMajor> filter_matrix(
+ filter_data, output_rows, filter_cols, filter_cols);
+ gemmlowp::MatrixMap<const uint8, gemmlowp::MapOrder::ColMajor> input_matrix(
+ input_data, filter_cols, output_cols, filter_cols);
+ gemmlowp::MatrixMap<uint8, gemmlowp::MapOrder::ColMajor> output_matrix(
+ output_data, output_rows, output_cols, output_rows);
+ const auto& output_pipeline = GemmlowpOutputPipeline::MakeExp(
+ bias_data, output_rows, output_offset, output_multiplier, -output_shift,
+ output_activation_min, output_activation_max);
+ gemmlowp::GemmWithOutputPipeline<uint8, uint8,
+ gemmlowp::L8R8WithLhsNonzeroBitDepthParams>(
+ gemm_context, filter_matrix, input_matrix, &output_matrix, filter_offset,
+ input_offset, output_pipeline);
+}
+
+inline void TransposeConv(const float* input_data, const Dims<4>& input_dims,
+ const float* filter_data, const Dims<4>& filter_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, float* output_data,
+ const Dims<4>& output_dims, float* im2col_data,
+ const Dims<4>& im2col_dims) {
+ tflite::ConvParams op_params;
+ // Padding type is ignored, but still set.
+ op_params.padding_type = PaddingType::kSame;
+ op_params.padding_values.width = pad_width;
+ op_params.padding_values.height = pad_height;
+ op_params.stride_width = stride_width;
+ op_params.stride_height = stride_height;
+
+ TransposeConv(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(filter_dims), filter_data, DimsToShape(output_dims),
+ output_data, DimsToShape(im2col_dims), im2col_data);
+}
+
+template <typename T>
+void TransposeIm2col(const T* input_data, const Dims<4>& input_dims,
+ const Dims<4>& filter_dims, int stride_width,
+ int stride_height, int pad_width, int pad_height,
+ const Dims<4>& output_dims, uint8 zero_byte,
+ T* im2col_data) {
+ tflite::ConvParams op_params;
+ // Padding type is ignored, but still set.
+ op_params.padding_type = PaddingType::kSame;
+ op_params.padding_values.width = pad_width;
+ op_params.padding_values.height = pad_height;
+ op_params.stride_width = stride_width;
+ op_params.stride_height = stride_height;
+
+ TransposeIm2col(op_params, zero_byte, DimsToShape(input_dims), input_data,
+ DimsToShape(filter_dims), DimsToShape(output_dims),
+ im2col_data);
+}
+
+inline void LstmCell(const float* input_data, const Dims<4>& input_dims,
+ const float* prev_activ_data,
+ const Dims<4>& prev_activ_dims, const float* weights_data,
+ const Dims<4>& weights_dims, const float* bias_data,
+ const Dims<4>& bias_dims, const float* prev_state_data,
+ const Dims<4>& prev_state_dims, float* output_state_data,
+ const Dims<4>& output_state_dims, float* output_activ_data,
+ const Dims<4>& output_activ_dims, float* concat_temp_data,
+ const Dims<4>& concat_temp_dims, float* activ_temp_data,
+ const Dims<4>& activ_temp_dims) {
+ tflite::LstmCellParams op_params;
+ // Float LSTM cell does not need parameters to be set: leave untouched.
+
+ LstmCell(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(prev_activ_dims), prev_activ_data,
+ DimsToShape(weights_dims), weights_data, DimsToShape(bias_dims),
+ bias_data, DimsToShape(prev_state_dims), prev_state_data,
+ DimsToShape(output_state_dims), output_state_data,
+ DimsToShape(output_activ_dims), output_activ_data,
+ DimsToShape(concat_temp_dims), concat_temp_data,
+ DimsToShape(activ_temp_dims), activ_temp_data);
+}
+
+template <int StateIntegerBits>
+void LstmCell(const uint8* input_data_uint8, const Dims<4>& input_dims,
+ const uint8* prev_activ_data_uint8,
+ const Dims<4>& prev_activ_dims, const uint8* weights_data_uint8,
+ const Dims<4>& weights_dims, const int32* bias_data_int32,
+ const Dims<4>& bias_dims, const int16* prev_state_data_int16,
+ const Dims<4>& prev_state_dims, int16* output_state_data_int16,
+ const Dims<4>& output_state_dims, uint8* output_activ_data_uint8,
+ const Dims<4>& output_activ_dims, uint8* concat_temp_data_uint8,
+ const Dims<4>& concat_temp_dims, int16* activ_temp_data_int16,
+ const Dims<4>& activ_temp_dims, int32 weights_zero_point,
+ int32 accum_multiplier, int accum_shift,
+ gemmlowp::GemmContext* gemm_context) {
+ tflite::LstmCellParams op_params;
+ op_params.weights_zero_point = weights_zero_point;
+ op_params.accum_multiplier = accum_multiplier;
+ op_params.accum_shift = accum_shift;
+
+ LstmCell<StateIntegerBits>(
+ op_params, DimsToShape(input_dims), input_data_uint8,
+ DimsToShape(prev_activ_dims), prev_activ_data_uint8,
+ DimsToShape(weights_dims), weights_data_uint8, DimsToShape(bias_dims),
+ bias_data_int32, DimsToShape(prev_state_dims), prev_state_data_int16,
+ DimsToShape(output_state_dims), output_state_data_int16,
+ DimsToShape(output_activ_dims), output_activ_data_uint8,
+ DimsToShape(concat_temp_dims), concat_temp_data_uint8,
+ DimsToShape(activ_temp_dims), activ_temp_data_int16, gemm_context);
+}
+
+template <typename T>
+void BroadcastDiv(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, const Dims<4>& input2_dims,
+ T output_activation_min, T output_activation_max,
+ T* output_data, const Dims<4>& output_dims) {
+ tflite::ArithmeticParams op_params;
+ SetActivationParams(output_activation_min, output_activation_max, &op_params);
+
+ BroadcastDiv4DSlow(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data,
+ DimsToShape(output_dims), output_data);
+}
+
+template <FusedActivationFunctionType Ac>
+void L2Normalization(const float* input_data, const RuntimeShape& input_shape,
+ float* output_data, const RuntimeShape& output_shape) {
+ static_assert(Ac == FusedActivationFunctionType::kNone, "");
+ tflite::L2NormalizationParams op_params;
+ // No params need to be set for float, but reserved in signature for future
+ // activations.
+
+ L2Normalization(op_params, input_shape, input_data, output_shape,
+ output_data);
+}
+
+inline void L2Normalization(const uint8* input_data,
+ const RuntimeShape& input_shape,
+ int32 input_zero_point, uint8* output_data,
+ const RuntimeShape& output_shape) {
+ tflite::L2NormalizationParams op_params;
+ op_params.input_zero_point = input_zero_point;
+
+ L2Normalization(op_params, input_shape, input_data, output_shape,
+ output_data);
+}
+
+template <FusedActivationFunctionType Ac>
+void L2Normalization(const float* input_data, const Dims<4>& input_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ L2Normalization<Ac>(input_data, DimsToShape(input_dims), output_data,
+ DimsToShape(output_dims));
+}
+
+inline void L2Normalization(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_zero_point, uint8* output_data,
+ const Dims<4>& output_dims) {
+ L2Normalization(input_data, DimsToShape(input_dims), input_zero_point,
+ output_data, DimsToShape(output_dims));
+}
+
+inline void Relu(const float* input_data, const Dims<4>& input_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ Relu(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
+ output_data);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void Add(const float* input1_data, const Dims<4>& input1_dims,
+ const float* input2_data, const Dims<4>& input2_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ float output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+
+ tflite::ArithmeticParams op_params;
+ op_params.float_activation_min = output_activation_min;
+ op_params.float_activation_max = output_activation_max;
+ Add(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
+ output_data);
+}
+
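+// For the quantized Add below, left_shift scales both inputs up before the
+// per-input multiplier/shift pairs bring them onto a common scale; the output
+// multiplier/shift and offset then requantize the sum into the output's scale
+// before clamping to the activation range.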
+template <FusedActivationFunctionType Ac>
+inline void Add(int left_shift, const uint8* input1_data,
+ const Dims<4>& input1_dims, int32 input1_offset,
+ int32 input1_multiplier, int input1_shift,
+ const uint8* input2_data, const Dims<4>& input2_dims,
+ int32 input2_offset, int32 input2_multiplier, int input2_shift,
+ int32 output_offset, int32 output_multiplier, int output_shift,
+ int32 output_activation_min, int32 output_activation_max,
+ uint8* output_data, const Dims<4>& output_dims) {
+ constexpr int kReverseShift = -1;
+ static_assert(Ac == FusedActivationFunctionType::kNone ||
+ Ac == FusedActivationFunctionType::kRelu ||
+ Ac == FusedActivationFunctionType::kRelu6 ||
+ Ac == FusedActivationFunctionType::kRelu1,
+ "");
+ TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
+ if (Ac == FusedActivationFunctionType::kNone) {
+ TFLITE_DCHECK_EQ(output_activation_min, 0);
+ TFLITE_DCHECK_EQ(output_activation_max, 255);
+ }
+
+ tflite::ArithmeticParams op_params;
+ op_params.left_shift = left_shift;
+ op_params.input1_offset = input1_offset;
+ op_params.input1_multiplier = input1_multiplier;
+ op_params.input1_shift = kReverseShift * input1_shift;
+ op_params.input2_offset = input2_offset;
+ op_params.input2_multiplier = input2_multiplier;
+ op_params.input2_shift = kReverseShift * input2_shift;
+ op_params.output_offset = output_offset;
+ op_params.output_multiplier = output_multiplier;
+ op_params.output_shift = kReverseShift * output_shift;
+ op_params.quantized_activation_min = output_activation_min;
+ op_params.quantized_activation_max = output_activation_max;
+ Add(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
+ output_data);
+}
+
+template <FusedActivationFunctionType Ac>
+void Add(const int32* input1_data, const Dims<4>& input1_dims,
+ const int32* input2_data, const Dims<4>& input2_dims,
+ int32* output_data, const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("Add/int32");
+ TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone);
+
+ tflite::ArithmeticParams op_params;
+ op_params.quantized_activation_min = std::numeric_limits<int32>::min();
+ op_params.quantized_activation_max = std::numeric_limits<int32>::max();
+ Add(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
+ output_data);
+}
+
+template <typename T>
+void BroadcastAdd(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, const Dims<4>& input2_dims,
+ T output_activation_min, T output_activation_max,
+ T* output_data, const Dims<4>& output_dims) {
+ tflite::ArithmeticParams op_params;
+ op_params.float_activation_min = output_activation_min;
+ op_params.float_activation_max = output_activation_max;
+ BroadcastAdd4DSlow(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data,
+ DimsToShape(output_dims), output_data);
+}
+
+template <FusedActivationFunctionType Ac>
+inline void BroadcastAdd(int left_shift, const uint8* input1_data,
+ const Dims<4>& input1_dims, int32 input1_offset,
+ int32 input1_multiplier, int input1_shift,
+ const uint8* input2_data, const Dims<4>& input2_dims,
+ int32 input2_offset, int32 input2_multiplier,
+ int input2_shift, int32 output_offset,
+ int32 output_multiplier, int output_shift,
+ int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ constexpr int kReverseShift = -1;
+ static_assert(Ac == FusedActivationFunctionType::kNone ||
+ Ac == FusedActivationFunctionType::kRelu ||
+ Ac == FusedActivationFunctionType::kRelu6 ||
+ Ac == FusedActivationFunctionType::kRelu1,
+ "");
+ TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
+ if (Ac == FusedActivationFunctionType::kNone) {
+ TFLITE_DCHECK_EQ(output_activation_min, 0);
+ TFLITE_DCHECK_EQ(output_activation_max, 255);
+ }
+
+ tflite::ArithmeticParams op_params;
+ op_params.left_shift = left_shift;
+ op_params.input1_offset = input1_offset;
+ op_params.input1_multiplier = input1_multiplier;
+ op_params.input1_shift = kReverseShift * input1_shift;
+ op_params.input2_offset = input2_offset;
+ op_params.input2_multiplier = input2_multiplier;
+ op_params.input2_shift = kReverseShift * input2_shift;
+ op_params.output_offset = output_offset;
+ op_params.output_multiplier = output_multiplier;
+ op_params.output_shift = kReverseShift * output_shift;
+ op_params.quantized_activation_min = output_activation_min;
+ op_params.quantized_activation_max = output_activation_max;
+ BroadcastAdd4DSlow(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data,
+ DimsToShape(output_dims), output_data);
+}
+
+template <FusedActivationFunctionType Ac>
+inline void BroadcastAddFivefold(
+ int y0, int y1, int y2, int y3, int y4, int left_shift,
+ const uint8* input1_data, const Dims<4>& input1_dims, int32 input1_offset,
+ int32 input1_multiplier, int input1_shift, const uint8* input2_data,
+ const Dims<4>& input2_dims, int32 input2_offset, int32 input2_multiplier,
+ int input2_shift, int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min, int32 output_activation_max,
+ uint8* output_data, const Dims<4>& output_dims) {
+ constexpr int kReverseShift = -1;
+ static_assert(Ac == FusedActivationFunctionType::kNone ||
+ Ac == FusedActivationFunctionType::kRelu ||
+ Ac == FusedActivationFunctionType::kRelu6 ||
+ Ac == FusedActivationFunctionType::kRelu1,
+ "");
+ TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
+ if (Ac == FusedActivationFunctionType::kNone) {
+ TFLITE_DCHECK_EQ(output_activation_min, 0);
+ TFLITE_DCHECK_EQ(output_activation_max, 255);
+ }
+ tflite::ArithmeticParams op_params;
+ op_params.broadcast_category =
+ tflite::BroadcastableOpCategory::kFirstInputBroadcastsFast;
+ op_params.left_shift = left_shift;
+ op_params.input1_offset = input1_offset;
+ op_params.input1_multiplier = input1_multiplier;
+ op_params.input1_shift = kReverseShift * input1_shift;
+ op_params.input2_offset = input2_offset;
+ op_params.input2_multiplier = input2_multiplier;
+ op_params.input2_shift = kReverseShift * input2_shift;
+ op_params.output_offset = output_offset;
+ op_params.output_multiplier = output_multiplier;
+ op_params.output_shift = kReverseShift * output_shift;
+ op_params.quantized_activation_min = output_activation_min;
+ op_params.quantized_activation_max = output_activation_max;
+ op_params.broadcast_shape[4] = y0;
+ op_params.broadcast_shape[3] = y1;
+ op_params.broadcast_shape[2] = y2;
+ op_params.broadcast_shape[1] = y3;
+ op_params.broadcast_shape[0] = y4;
+ BroadcastAddFivefold(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data,
+ DimsToShape(output_dims), output_data);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac, typename T>
+void BroadcastAdd(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, const Dims<4>& input2_dims,
+ T* output_data, const Dims<4>& output_dims) {
+ T output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+
+ BroadcastAdd(input1_data, input1_dims, input2_data, input2_dims,
+ output_activation_min, output_activation_max, output_data,
+ output_dims);
+}
+
+template <FusedActivationFunctionType Ac>
+inline void Add(const int16* input1_data, const Dims<4>& input1_dims,
+ int input1_shift, const int16* input2_data,
+ const Dims<4>& input2_dims, int input2_shift,
+ int16 output_activation_min, int16 output_activation_max,
+ int16* output_data, const Dims<4>& output_dims) {
+ constexpr int kReverseShift = -1;
+ static_assert(Ac == FusedActivationFunctionType::kNone ||
+ Ac == FusedActivationFunctionType::kRelu ||
+ Ac == FusedActivationFunctionType::kRelu6 ||
+ Ac == FusedActivationFunctionType::kRelu1,
+ "");
+ TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
+ if (Ac == FusedActivationFunctionType::kNone) {
+ TFLITE_DCHECK_EQ(output_activation_min, -32768);
+ TFLITE_DCHECK_EQ(output_activation_max, 32767);
+ }
+
+ tflite::ArithmeticParams op_params;
+ op_params.input1_shift = kReverseShift * input1_shift;
+ op_params.input2_shift = kReverseShift * input2_shift;
+ op_params.quantized_activation_min = output_activation_min;
+ op_params.quantized_activation_max = output_activation_max;
+ Add(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
+ output_data);
+}
+
+inline void Sub(const float* input1_data, const Dims<4>& input1_dims,
+ const float* input2_data, const Dims<4>& input2_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ float output_activation_min, output_activation_max;
+ GetActivationMinMax(FusedActivationFunctionType::kNone,
+ &output_activation_min, &output_activation_max);
+ tflite::ArithmeticParams op_params;
+ op_params.float_activation_min = output_activation_min;
+ op_params.float_activation_max = output_activation_max;
+ Sub(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
+ output_data);
+}
+
+template <typename T>
+void Sub(const T* input1_data, const Dims<4>& input1_dims, const T* input2_data,
+ const Dims<4>& input2_dims, T* output_data,
+ const Dims<4>& output_dims) {
+ T output_activation_min, output_activation_max;
+ GetActivationMinMax(FusedActivationFunctionType::kNone,
+ &output_activation_min, &output_activation_max);
+ tflite::ArithmeticParams op_params;
+ op_params.quantized_activation_min = output_activation_min;
+ op_params.quantized_activation_max = output_activation_max;
+ Sub(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
+ output_data);
+}
+
+inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims,
+ int32 input1_offset, const uint8* input2_data,
+ const Dims<4>& input2_dims, int32 input2_offset,
+ int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ tflite::ArithmeticParams op_params;
+ SetActivationParams(output_activation_min, output_activation_max, &op_params);
+ op_params.input1_offset = input1_offset;
+ op_params.input2_offset = input2_offset;
+ op_params.output_offset = output_offset;
+ op_params.output_multiplier = output_multiplier;
+ op_params.output_shift = kReverseShift * output_shift;
+
+ BroadcastMul4DSlow(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data,
+ DimsToShape(output_dims), output_data);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims,
+ int32 input1_offset, const uint8* input2_data,
+ const Dims<4>& input2_dims, int32 input2_offset,
+ int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ BroadcastMul(input1_data, input1_dims, input1_offset, input2_data,
+ input2_dims, input2_offset, output_offset, output_multiplier,
+ output_shift, output_activation_min, output_activation_max,
+ output_data, output_dims);
+}
+
+inline void AveragePool(const float* input_data, const Dims<4>& input_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int kwidth, int kheight,
+ float output_activation_min,
+ float output_activation_max, float* output_data,
+ const Dims<4>& output_dims) {
+ tflite::PoolParams params;
+ params.stride_height = stride_height;
+ params.stride_width = stride_width;
+ params.filter_height = kheight;
+ params.filter_width = kwidth;
+ params.padding_values.height = pad_height;
+ params.padding_values.width = pad_width;
+ params.float_activation_min = output_activation_min;
+ params.float_activation_max = output_activation_max;
+ AveragePool(params, DimsToShape(input_dims), input_data,
+ DimsToShape(output_dims), output_data);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void AveragePool(const float* input_data, const Dims<4>& input_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int kwidth, int kheight, float* output_data,
+ const Dims<4>& output_dims) {
+ float output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+
+ AveragePool(input_data, input_dims, stride_width, stride_height, pad_width,
+ pad_height, kwidth, kheight, output_activation_min,
+ output_activation_max, output_data, output_dims);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void AveragePool(const float* input_data, const Dims<4>& input_dims, int stride,
+ int pad_width, int pad_height, int filter_width,
+ int filter_height, float* output_data,
+ const Dims<4>& output_dims) {
+ AveragePool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
+ filter_width, filter_height, output_data, output_dims);
+}
+
+inline void AveragePool(const uint8* input_data, const Dims<4>& input_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int filter_width, int filter_height,
+ int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ tflite::PoolParams params;
+ params.stride_height = stride_height;
+ params.stride_width = stride_width;
+ params.filter_height = filter_height;
+ params.filter_width = filter_width;
+ params.padding_values.height = pad_height;
+ params.padding_values.width = pad_width;
+ params.quantized_activation_min = output_activation_min;
+ params.quantized_activation_max = output_activation_max;
+ AveragePool(params, DimsToShape(input_dims), input_data,
+ DimsToShape(output_dims), output_data);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void AveragePool(const uint8* input_data, const Dims<4>& input_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int filter_width, int filter_height,
+ int32 output_activation_min, int32 output_activation_max,
+ uint8* output_data, const Dims<4>& output_dims) {
+ static_assert(Ac == FusedActivationFunctionType::kNone ||
+ Ac == FusedActivationFunctionType::kRelu ||
+ Ac == FusedActivationFunctionType::kRelu6 ||
+ Ac == FusedActivationFunctionType::kRelu1,
+ "");
+ if (Ac == FusedActivationFunctionType::kNone) {
+ TFLITE_DCHECK_EQ(output_activation_min, 0);
+ TFLITE_DCHECK_EQ(output_activation_max, 255);
+ }
+ AveragePool(input_data, input_dims, stride_width, stride_height, pad_width,
+ pad_height, filter_width, filter_height, output_activation_min,
+ output_activation_max, output_data, output_dims);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void AveragePool(const uint8* input_data, const Dims<4>& input_dims, int stride,
+ int pad_width, int pad_height, int filter_width,
+ int filter_height, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ AveragePool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
+ filter_width, filter_height, output_activation_min,
+ output_activation_max, output_data, output_dims);
+}
+
+inline void MaxPool(const float* input_data, const Dims<4>& input_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int kwidth, int kheight,
+ float output_activation_min, float output_activation_max,
+ float* output_data, const Dims<4>& output_dims) {
+ tflite::PoolParams params;
+ params.stride_height = stride_height;
+ params.stride_width = stride_width;
+ params.filter_height = kheight;
+ params.filter_width = kwidth;
+ params.padding_values.height = pad_height;
+ params.padding_values.width = pad_width;
+ params.float_activation_min = output_activation_min;
+ params.float_activation_max = output_activation_max;
+ MaxPool(params, DimsToShape(input_dims), input_data, DimsToShape(output_dims),
+ output_data);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void MaxPool(const float* input_data, const Dims<4>& input_dims,
+ int stride_width, int stride_height, int pad_width, int pad_height,
+ int kwidth, int kheight, float* output_data,
+ const Dims<4>& output_dims) {
+ float output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+ MaxPool(input_data, input_dims, stride_width, stride_height, pad_width,
+ pad_height, kwidth, kheight, output_activation_min,
+ output_activation_max, output_data, output_dims);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void MaxPool(const float* input_data, const Dims<4>& input_dims, int stride,
+ int pad_width, int pad_height, int filter_width, int filter_height,
+ float* output_data, const Dims<4>& output_dims) {
+ MaxPool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
+ filter_width, filter_height, output_data, output_dims);
+}
+
+inline void MaxPool(const uint8* input_data, const Dims<4>& input_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int filter_width, int filter_height,
+ int32 output_activation_min, int32 output_activation_max,
+ uint8* output_data, const Dims<4>& output_dims) {
+ PoolParams params;
+ params.stride_height = stride_height;
+ params.stride_width = stride_width;
+ params.filter_height = filter_height;
+ params.filter_width = filter_width;
+ params.padding_values.height = pad_height;
+ params.padding_values.width = pad_width;
+ params.quantized_activation_min = output_activation_min;
+ params.quantized_activation_max = output_activation_max;
+ MaxPool(params, DimsToShape(input_dims), input_data, DimsToShape(output_dims),
+ output_data);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void MaxPool(const uint8* input_data, const Dims<4>& input_dims,
+ int stride_width, int stride_height, int pad_width, int pad_height,
+ int filter_width, int filter_height, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ static_assert(Ac == FusedActivationFunctionType::kNone ||
+ Ac == FusedActivationFunctionType::kRelu ||
+ Ac == FusedActivationFunctionType::kRelu6 ||
+ Ac == FusedActivationFunctionType::kRelu1,
+ "");
+ if (Ac == FusedActivationFunctionType::kNone) {
+ TFLITE_DCHECK_EQ(output_activation_min, 0);
+ TFLITE_DCHECK_EQ(output_activation_max, 255);
+ }
+ MaxPool(input_data, input_dims, stride_width, stride_height, pad_width,
+ pad_height, filter_width, filter_height, output_activation_min,
+ output_activation_max, output_data, output_dims);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void MaxPool(const uint8* input_data, const Dims<4>& input_dims, int stride,
+ int pad_width, int pad_height, int filter_width, int filter_height,
+ int32 output_activation_min, int32 output_activation_max,
+ uint8* output_data, const Dims<4>& output_dims) {
+ MaxPool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
+ filter_width, filter_height, output_activation_min,
+ output_activation_max, output_data, output_dims);
+}
+
+inline void L2Pool(const float* input_data, const Dims<4>& input_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int filter_width, int filter_height,
+ float output_activation_min, float output_activation_max,
+ float* output_data, const Dims<4>& output_dims) {
+ PoolParams params;
+ params.stride_height = stride_height;
+ params.stride_width = stride_width;
+ params.filter_height = filter_height;
+ params.filter_width = filter_width;
+ params.padding_values.height = pad_height;
+ params.padding_values.width = pad_width;
+ params.float_activation_min = output_activation_min;
+ params.float_activation_max = output_activation_max;
+ L2Pool(params, DimsToShape(input_dims), input_data, DimsToShape(output_dims),
+ output_data);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void L2Pool(const float* input_data, const Dims<4>& input_dims,
+ int stride_width, int stride_height, int pad_width, int pad_height,
+ int filter_width, int filter_height, float* output_data,
+ const Dims<4>& output_dims) {
+ float output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+ L2Pool(input_data, input_dims, stride_width, stride_height, pad_width,
+ pad_height, filter_width, filter_height, output_activation_min,
+ output_activation_max, output_data, output_dims);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void L2Pool(const float* input_data, const Dims<4>& input_dims, int stride,
+ int pad_width, int pad_height, int filter_width, int filter_height,
+ float* output_data, const Dims<4>& output_dims) {
+ L2Pool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
+ filter_width, filter_height, output_data, output_dims);
+}
+
+inline void Softmax(const float* input_data, const RuntimeShape& input_shape,
+ float beta, float* output_data,
+ const RuntimeShape& output_shape) {
+ SoftmaxParams params;
+ params.beta = beta;
+ Softmax(params, input_shape, input_data, output_shape, output_data);
+}
+
+inline void Softmax(const float* input_data, const Dims<4>& input_dims,
+ float beta, float* output_data,
+ const Dims<4>& output_dims) {
+ Softmax(input_data, DimsToShape(input_dims), beta, output_data,
+ DimsToShape(output_dims));
+}
+
+inline void Softmax(const uint8* input_data, const RuntimeShape& input_shape,
+ int32 input_beta_multiplier, int32 input_beta_left_shift,
+ int diff_min, uint8* output_data,
+ const RuntimeShape& output_shape) {
+ SoftmaxParams params;
+ params.input_multiplier = input_beta_multiplier;
+ params.input_left_shift = input_beta_left_shift;
+ params.diff_min = diff_min;
+ Softmax(params, input_shape, input_data, output_shape, output_data);
+}
+
+inline void Softmax(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_beta_multiplier, int32 input_beta_left_shift,
+ int diff_min, uint8* output_data,
+ const Dims<4>& output_dims) {
+ Softmax(input_data, DimsToShape(input_dims), input_beta_multiplier,
+ input_beta_left_shift, diff_min, output_data,
+ DimsToShape(output_dims));
+}
+
+inline void LogSoftmax(const float* input_data, const RuntimeShape& input_shape,
+ float* output_data, const RuntimeShape& output_shape) {
+ SoftmaxParams params;
+ // No params currently used for float LogSoftmax.
+ LogSoftmax(params, input_shape, input_data, output_shape, output_data);
+}
+
+inline void LogSoftmax(const float* input_data, const Dims<4>& input_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ LogSoftmax(input_data, DimsToShape(input_dims), output_data,
+ DimsToShape(output_dims));
+}
+
+inline void LogSoftmax(const uint8* input_data, const RuntimeShape& input_shape,
+ int32 input_multiplier, int32 input_left_shift,
+ int32 reverse_scaling_divisor,
+ int32 reverse_scaling_right_shift, int diff_min,
+ uint8* output_data, const RuntimeShape& output_shape) {
+ SoftmaxParams params;
+ params.input_multiplier = input_multiplier;
+ params.input_left_shift = input_left_shift;
+ params.reverse_scaling_divisor = reverse_scaling_divisor;
+ params.reverse_scaling_right_shift = reverse_scaling_right_shift;
+ params.diff_min = diff_min;
+ LogSoftmax(params, input_shape, input_data, output_shape, output_data);
+}
+
+inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_multiplier, int32 input_left_shift,
+ int32 reverse_scaling_divisor,
+ int32 reverse_scaling_right_shift, int diff_min,
+ uint8* output_data, const Dims<4>& output_dims) {
+ LogSoftmax(input_data, DimsToShape(input_dims), input_multiplier,
+ input_left_shift, reverse_scaling_divisor,
+ reverse_scaling_right_shift, diff_min, output_data,
+ DimsToShape(output_dims));
+}
+
+inline void Logistic(const uint8* input_data, const RuntimeShape& input_shape,
+ int32 input_zero_point, int32 input_range_radius,
+ int32 input_multiplier, int input_left_shift,
+ uint8* output_data, const RuntimeShape& output_shape) {
+ LogisticParams params;
+ params.input_zero_point = input_zero_point;
+ params.input_range_radius = input_range_radius;
+ params.input_multiplier = input_multiplier;
+ params.input_left_shift = input_left_shift;
+ Logistic(params, input_shape, input_data, output_shape, output_data);
+}
+
+inline void Logistic(const float* input_data, const Dims<4>& input_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ Logistic(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
+ output_data);
+}
+
+inline void Logistic(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_zero_point, int32 input_range_radius,
+ int32 input_multiplier, int input_left_shift,
+ uint8* output_data, const Dims<4>& output_dims) {
+ Logistic(input_data, DimsToShape(input_dims), input_zero_point,
+ input_range_radius, input_multiplier, input_left_shift, output_data,
+ DimsToShape(output_dims));
+}
+
+inline void Logistic(const RuntimeShape& input_shape, const int16* input_data,
+ const RuntimeShape& output_shape, int16* output_data) {
+ LogisticParams params;
+ // No params currently needed by int16 Logistic.
+ Logistic(params, input_shape, input_data, output_shape, output_data);
+}
+
+inline void Logistic(const int16* input_data, const RuntimeShape& input_shape,
+ int16* output_data, const RuntimeShape& output_shape) {
+ LogisticParams params;
+ // No params currently needed by int16 Logistic.
+ Logistic(params, input_shape, input_data, output_shape, output_data);
+}
+
+inline void Logistic(const int16* input_data, const Dims<4>& input_dims,
+ int16* output_data, const Dims<4>& output_dims) {
+ Logistic(input_data, DimsToShape(input_dims), output_data,
+ DimsToShape(output_dims));
+}
+
+inline void Tanh(const float* input_data, const Dims<4>& input_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ Tanh(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
+ output_data);
+}
+
+inline void Tanh(const uint8* input_data, const RuntimeShape& input_shape,
+ int32 input_zero_point, int32 input_range_radius,
+ int32 input_multiplier, int input_left_shift,
+ uint8* output_data, const RuntimeShape& output_shape) {
+ TanhParams params;
+ params.input_zero_point = input_zero_point;
+ params.input_range_radius = input_range_radius;
+ params.input_multiplier = input_multiplier;
+ params.input_left_shift = input_left_shift;
+ Tanh(params, input_shape, input_data, output_shape, output_data);
+}
+
+inline void Tanh(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_zero_point, int32 input_range_radius,
+ int32 input_multiplier, int input_left_shift,
+ uint8* output_data, const Dims<4>& output_dims) {
+ Tanh(input_data, DimsToShape(input_dims), input_zero_point,
+ input_range_radius, input_multiplier, input_left_shift, output_data,
+ DimsToShape(output_dims));
+}
+
+inline void Tanh(const int16* input_data, const RuntimeShape& input_shape,
+ int input_left_shift, int16* output_data,
+ const RuntimeShape& output_shape) {
+ TanhParams params;
+ params.input_left_shift = input_left_shift;
+ Tanh(params, input_shape, input_data, output_shape, output_data);
+}
+
+inline void Tanh(const int16* input_data, const Dims<4>& input_dims,
+ int input_left_shift, int16* output_data,
+ const Dims<4>& output_dims) {
+ Tanh(input_data, DimsToShape(input_dims), input_left_shift, output_data,
+ DimsToShape(output_dims));
+}
+
+template <typename T>
+inline void DepthToSpace(const T* input_data, const Dims<4>& input_dims,
+ int block_size, T* output_data,
+ const Dims<4>& output_dims) {
+ tflite::DepthToSpaceParams op_params;
+ op_params.block_size = block_size;
+
+ DepthToSpace(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(output_dims), output_data);
+}
+
+template <typename T>
+inline void SpaceToDepth(const T* input_data, const Dims<4>& input_dims,
+ int block_size, T* output_data,
+ const Dims<4>& output_dims) {
+ tflite::SpaceToDepthParams op_params;
+ op_params.block_size = block_size;
+
+ SpaceToDepth(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(output_dims), output_data);
+}
+
+inline void Mul(const float* input1_data, const Dims<4>& input1_dims,
+ const float* input2_data, const Dims<4>& input2_dims,
+ float output_activation_min, float output_activation_max,
+ float* output_data, const Dims<4>& output_dims) {
+ tflite::ArithmeticParams op_params;
+ op_params.float_activation_min = output_activation_min;
+ op_params.float_activation_max = output_activation_max;
+
+ Mul(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
+ output_data);
+}
+
+template <FusedActivationFunctionType Ac>
+void Mul(const float* input1_data, const Dims<4>& input1_dims,
+ const float* input2_data, const Dims<4>& input2_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ float output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+
+ Mul(input1_data, input1_dims, input2_data, input2_dims, output_activation_min,
+ output_activation_max, output_data, output_dims);
+}
+
+inline void Mul(const int32* input1_data, const Dims<4>& input1_dims,
+ const int32* input2_data, const Dims<4>& input2_dims,
+ int32 output_activation_min, int32 output_activation_max,
+ int32* output_data, const Dims<4>& output_dims) {
+ tflite::ArithmeticParams op_params;
+ op_params.quantized_activation_min = output_activation_min;
+ op_params.quantized_activation_max = output_activation_max;
+
+ Mul(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
+ output_data);
+}
+
+template <FusedActivationFunctionType Ac>
+void Mul(const int32* input1_data, const Dims<4>& input1_dims,
+ const int32* input2_data, const Dims<4>& input2_dims,
+ int32* output_data, const Dims<4>& output_dims) {
+ TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone);
+ tflite::ArithmeticParams op_params;
+ // No parameters needed.
+
+ MulNoActivation(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data,
+ DimsToShape(output_dims), output_data);
+}
+
+inline void Mul(const int16* input1_data, const Dims<4>& input1_dims,
+ const int16* input2_data, const Dims<4>& input2_dims,
+ int16* output_data, const Dims<4>& output_dims) {
+ tflite::ArithmeticParams op_params;
+ // No parameters needed.
+
+ Mul(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
+ output_data);
+}
+
+inline void Mul(const int16* input1_data, const Dims<4>& input1_dims,
+ const int16* input2_data, const Dims<4>& input2_dims,
+ int32 output_offset, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ tflite::ArithmeticParams op_params;
+ op_params.output_offset = output_offset;
+ op_params.quantized_activation_min = output_activation_min;
+ op_params.quantized_activation_max = output_activation_max;
+
+ Mul(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
+ output_data);
+}
+
+template <typename T>
+void BroadcastMul(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, const Dims<4>& input2_dims,
+ T output_activation_min, T output_activation_max,
+ T* output_data, const Dims<4>& output_dims) {
+ tflite::ArithmeticParams op_params;
+ SetActivationParams(output_activation_min, output_activation_max, &op_params);
+
+ BroadcastMul4DSlow(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data,
+ DimsToShape(output_dims), output_data);
+}
+
+// For compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+inline void BroadcastMul(const float* input1_data, const Dims<4>& input1_dims,
+ const float* input2_data, const Dims<4>& input2_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ tflite::ArithmeticParams op_params;
+ float float_activation_min;
+ float float_activation_max;
+ GetActivationMinMax(Ac, &float_activation_min, &float_activation_max);
+ SetActivationParams(float_activation_min, float_activation_max, &op_params);
+
+ BroadcastMul4DSlow(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data,
+ DimsToShape(output_dims), output_data);
+}
+
+inline void LocalResponseNormalization(const float* input_data,
+ const Dims<4>& input_dims, int range,
+ float bias, float alpha, float beta,
+ float* output_data,
+ const Dims<4>& output_dims) {
+ tflite::LocalResponseNormalizationParams op_params;
+ op_params.range = range;
+ op_params.bias = bias;
+ op_params.alpha = alpha;
+ op_params.beta = beta;
+
+ LocalResponseNormalization(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(output_dims), output_data);
+}
+
+template <typename SrcT, typename DstT>
+void Cast(const SrcT* input_data, const Dims<4>& input_dims, DstT* output_data,
+ const Dims<4>& output_dims) {
+ Cast(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
+ output_data);
+}
+
+inline void Floor(const float* input_data, const Dims<4>& input_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ Floor(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
+ output_data);
+}
+
+inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims,
+ const int32* output_size_data,
+ const Dims<4>& output_size_dims, float* output_data,
+ const Dims<4>& output_dims, bool align_corners) {
+ tflite::ResizeBilinearParams op_params;
+ op_params.align_corners = align_corners;
+ ResizeBilinear(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(output_size_dims), output_size_data,
+ DimsToShape(output_dims), output_data);
+}
+
+inline void ResizeBilinear(const uint8* input_data, const Dims<4>& input_dims,
+ const int32* output_size_data,
+ const Dims<4>& output_size_dims, uint8* output_data,
+ const Dims<4>& output_dims, bool align_corners) {
+ tflite::ResizeBilinearParams op_params;
+ op_params.align_corners = align_corners;
+ ResizeBilinear(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(output_size_dims), output_size_data,
+ DimsToShape(output_dims), output_data);
+}
+
+// legacy, for compatibility with old checked-in code
+inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims,
+ const int32* output_size_data,
+ const Dims<4>& output_size_dims, float* output_data,
+ const Dims<4>& output_dims) {
+ ResizeBilinear(input_data, input_dims, output_size_data, output_size_dims,
+ output_data, output_dims, /*align_corners=*/false);
+}
+
+// legacy, for compatibility with old checked-in code
+inline void ResizeBilinear(const uint8* input_data, const Dims<4>& input_dims,
+ const int32* output_size_data,
+ const Dims<4>& output_size_dims, uint8* output_data,
+ const Dims<4>& output_dims) {
+ ResizeBilinear(input_data, input_dims, output_size_data, output_size_dims,
+ output_data, output_dims, /*align_corners=*/false);
+}
+
+template <typename T>
+inline void BatchToSpaceND(const T* input_data, const Dims<4>& input_dims,
+ const int32* block_shape_data,
+ const Dims<4>& block_shape_dims,
+ const int32* crops_data, const Dims<4>& crops_dims,
+ T* output_data, const Dims<4>& output_dims) {
+ BatchToSpaceND(DimsToShape(input_dims), input_data,
+ DimsToShape(block_shape_dims), block_shape_data,
+ DimsToShape(crops_dims), crops_data, DimsToShape(output_dims),
+ output_data);
+}
+
+// Legacy signature; this function covered both Pad and PadV2.
+template <typename T>
+inline void PadV2(const T* input_data, const Dims<4>& input_dims,
+ const std::vector<int>& left_paddings,
+ const std::vector<int>& right_paddings, T* output_data,
+ const Dims<4>& output_dims, const T pad_value) {
+ TFLITE_DCHECK_EQ(left_paddings.size(), 4);
+ TFLITE_DCHECK_EQ(right_paddings.size(), 4);
+ tflite::PadParams op_params;
+ op_params.left_padding_count = 4;
+ op_params.right_padding_count = 4;
+ for (int i = 0; i < 4; ++i) {
+ op_params.left_padding[i] = left_paddings[3 - i];
+ op_params.right_padding[i] = right_paddings[3 - i];
+ }
+ const T pad_value_copy = pad_value;
+
+ Pad(op_params, DimsToShape(input_dims), input_data, &pad_value_copy,
+ DimsToShape(output_dims), output_data);
+}
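
The left_paddings[3 - i] indexing above reverses paddings from the legacy Dims<4> ordering (innermost dimension first) into the NHWC ordering expected by the new parameter struct. A minimal, self-contained sketch of that reversal with made-up padding values, independent of the TFLite types:

#include <array>
#include <cstdio>
#include <vector>

int main() {
  // Paddings listed in reversed (Dims<4>) order: {depth, width, height, batch}.
  const std::vector<int> left_paddings = {0, 2, 1, 0};
  // Reorder into NHWC order: {batch, height, width, depth}.
  std::array<int, 4> nhwc_left_padding;
  for (int i = 0; i < 4; ++i) {
    nhwc_left_padding[i] = left_paddings[3 - i];
  }
  for (int p : nhwc_left_padding) std::printf("%d ", p);  // prints: 0 1 2 0
  std::printf("\n");
  return 0;
}
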
+
+// Old Pad that calls legacy PadV2.
+template <typename T>
+inline void Pad(const T* input_data, const Dims<4>& input_dims,
+ const std::vector<int>& left_paddings,
+ const std::vector<int>& right_paddings, T* output_data,
+ const Dims<4>& output_dims, const int32_t pad_value) {
+ const T converted_pad_value = static_cast<T>(pad_value);
+ PadV2<T>(input_data, input_dims, left_paddings, right_paddings, output_data,
+ output_dims, converted_pad_value);
+}
+
+// Old Pad that only padded with 0.
+template <typename T>
+inline void Pad(const T* input_data, const Dims<4>& input_dims,
+ const std::vector<int>& left_paddings,
+ const std::vector<int>& right_paddings, T* output_data,
+ const Dims<4>& output_dims) {
+ const T pad_value = static_cast<T>(0);
+ PadV2<T>(input_data, input_dims, left_paddings, right_paddings, output_data,
+ output_dims, pad_value);
+}
+
+template <typename T>
+inline void Slice(const T* input_data, const Dims<4>& input_dims,
+ const std::vector<int>& begin, const std::vector<int>& size,
+ T* output_data, const Dims<4>& output_dims) {
+ tflite::SliceParams op_params;
+ op_params.begin_count = 4;
+ op_params.size_count = 4;
+ for (int i = 0; i < 4; ++i) {
+ op_params.begin[i] = begin[3 - i];
+ op_params.size[i] = size[3 - i];
+ }
+
+ Slice(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(output_dims), output_data);
+}
+
+template <typename T>
+void TensorFlowMinimum(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, T* output_data,
+ const Dims<4>& output_dims) {
+ Minimum(DimsToShape(input1_dims), input1_data, input2_data,
+ DimsToShape(output_dims), output_data);
+}
+
+template <typename T>
+void TensorFlowMaximum(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, T* output_data,
+ const Dims<4>& output_dims) {
+ Maximum(DimsToShape(input1_dims), input1_data, input2_data,
+ DimsToShape(output_dims), output_data);
+}
+
+} // namespace optimized_ops
+} // namespace tflite
+#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_LEGACY_OPTIMIZED_OPS_H_
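
Every legacy entry point in this header follows the same adapter shape: pack the old flat argument list into a parameter struct, then forward to the new-style kernel. The sketch below shows that shape with invented names (ExamplePool*); its body only applies the fused activation clamp, so it illustrates the adapter pattern rather than a real pooling kernel.

#include <algorithm>
#include <cstdio>

// New-style parameter struct (hypothetical; mirrors the shape of tflite::PoolParams).
struct ExamplePoolParams {
  int stride_width = 1, stride_height = 1;
  int filter_width = 1, filter_height = 1;
  float activation_min = 0.f, activation_max = 6.f;
};

// New-style kernel: takes the struct. For brevity it only clamps to the
// fused-activation range instead of pooling.
void ExamplePool(const ExamplePoolParams& params, const float* in, int size,
                 float* out) {
  for (int i = 0; i < size; ++i) {
    out[i] = std::min(std::max(in[i], params.activation_min),
                      params.activation_max);
  }
}

// Legacy adapter: keeps the old flat signature, fills the struct, forwards.
void ExamplePool(const float* in, int size, int stride_w, int stride_h,
                 int filter_w, int filter_h, float act_min, float act_max,
                 float* out) {
  ExamplePoolParams params;
  params.stride_width = stride_w;
  params.stride_height = stride_h;
  params.filter_width = filter_w;
  params.filter_height = filter_h;
  params.activation_min = act_min;
  params.activation_max = act_max;
  ExamplePool(params, in, size, out);
}

int main() {
  const float in[3] = {-1.f, 3.f, 9.f};
  float out[3];
  ExamplePool(in, 3, 1, 1, 2, 2, 0.f, 6.f, out);
  std::printf("%g %g %g\n", out[0], out[1], out[2]);  // 0 3 6
  return 0;
}
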
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h b/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h
index 27d9224512..4139cf4eba 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_MULTITHREAD_CONV
-#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_MULTITHREAD_CONV
+#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_MULTITHREADED_CONV_H_
+#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_MULTITHREADED_CONV_H_
#include <assert.h>
#include <stdint.h>
@@ -26,7 +26,7 @@ limitations under the License.
#include <tuple>
#include <type_traits>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
#include "tensorflow/contrib/lite/kernels/internal/common.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/eigen_spatial_convolutions.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
@@ -35,35 +35,6 @@ limitations under the License.
namespace tflite {
namespace multithreaded_ops {
-class EigenThreadPoolWrapper : public Eigen::ThreadPoolInterface {
- public:
- explicit EigenThreadPoolWrapper(Eigen::ThreadPool* pool) : pool_(pool) {}
- ~EigenThreadPoolWrapper() override {}
-
- void Schedule(std::function<void()> fn) override {
- pool_->Schedule(std::move(fn));
- }
- int NumThreads() const override { return pool_->NumThreads(); }
- int CurrentThreadId() const override { return pool_->CurrentThreadId(); }
-
- private:
- Eigen::ThreadPool* pool_ = nullptr;
-};
-
-// We have a single global threadpool for all convolution operations. This means
-// that inferences started from different threads may block each other, but
-// since the underlying resource of CPU cores should be consumed by the
-// operations anyway, it shouldn't affect overall performance.
-const Eigen::ThreadPoolDevice& GetThreadPoolDevice() {
- const int thread_count = 4;
- static Eigen::ThreadPool* tp = new Eigen::ThreadPool(thread_count);
- static EigenThreadPoolWrapper* thread_pool_wrapper =
- new EigenThreadPoolWrapper(tp);
- static Eigen::ThreadPoolDevice* device =
- new Eigen::ThreadPoolDevice(thread_pool_wrapper, thread_count);
- return *device;
-}
-
// Shorthands for the types we need when interfacing with the EigenTensor
// library.
typedef Eigen::TensorMap<
@@ -98,13 +69,13 @@ struct MatMulConvFunctor {
template <class T>
class EigenTensorConvFunctor {
private:
- Eigen::PaddingType TfLitePadding2EigenPadding(TfLitePadding padding) {
+ Eigen::PaddingType RuntimePadding2EigenPadding(PaddingType padding) {
switch (padding) {
- case kTfLitePaddingValid:
+ case PaddingType::kValid:
return Eigen::PADDING_VALID;
- case kTfLitePaddingSame:
+ case PaddingType::kSame:
return Eigen::PADDING_SAME;
- case kTfLitePaddingUnknown:
+ case PaddingType::kNone:
assert(false); // should never get here.
return Eigen::PADDING_VALID;
}
@@ -113,14 +84,13 @@ class EigenTensorConvFunctor {
}
public:
- void operator()(const T* input_data, T* im2col_buffer, int input_batches,
- int input_height, int input_width, int input_depth,
- const T* filter_data, int filter_height, int filter_width,
- int filter_count, int stride_rows, int stride_cols,
- int pad_width, int pad_height, TfLitePadding padding,
- T* output_data, int output_height, int output_width) {
- const Eigen::ThreadPoolDevice& device = GetThreadPoolDevice();
-
+ void operator()(const Eigen::ThreadPoolDevice& device, const T* input_data,
+ T* im2col_buffer, int input_batches, int input_height,
+ int input_width, int input_depth, const T* filter_data,
+ int filter_height, int filter_width, int filter_count,
+ int stride_rows, int stride_cols, int pad_width,
+ int pad_height, PaddingType padding, T* output_data,
+ int output_height, int output_width) {
const bool is_1x1_kernel = (filter_height == 1 && filter_width == 1 &&
stride_rows == 1 && stride_cols == 1);
if (is_1x1_kernel) {
@@ -143,8 +113,8 @@ class EigenTensorConvFunctor {
filter_width * filter_height * input_depth;
Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair;
dim_pair[0] = Eigen::IndexPair<Eigen::DenseIndex>(1, 0);
- EigenMatrix output(output_data, 1, filter_count);
- ConstEigenMatrix input(input_data, 1, k);
+ EigenMatrix output(output_data, input_batches, filter_count);
+ ConstEigenMatrix input(input_data, input_batches, k);
ConstEigenMatrix filter(filter_data, k, filter_count);
MatMulConvFunctor<Eigen::ThreadPoolDevice, T>()(device, output, input,
filter, dim_pair);
@@ -157,40 +127,51 @@ class EigenTensorConvFunctor {
input_depth, filter_count);
output.device(device) =
Eigen::SpatialConvolution(input, filter, stride_cols, stride_rows,
- TfLitePadding2EigenPadding(padding));
+ RuntimePadding2EigenPadding(padding));
}
}
};
-inline void Conv(const float* input_data, const Dims<4>& input_dims,
- const float* filter_data, const Dims<4>& filter_dims,
- const float* bias_data, const Dims<4>& bias_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, TfLitePadding padding,
- float output_activation_min, float output_activation_max,
- float* output_data, const Dims<4>& output_dims,
- float* im2col_data, const Dims<4>& im2col_dims) {
- const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
- const int input_depth = MatchingArraySize(input_dims, 0, filter_dims, 0);
- const int output_depth = MatchingArraySize(filter_dims, 3, output_dims, 0);
- const int input_height = ArraySize(input_dims, 2);
- const int input_width = ArraySize(input_dims, 1);
- const int filter_height = ArraySize(filter_dims, 2);
- const int filter_width = ArraySize(filter_dims, 1);
- const int output_height = ArraySize(output_dims, 2);
- const int output_width = ArraySize(output_dims, 1);
+inline void Conv(const Eigen::ThreadPoolDevice& device,
+ const ConvParams& params, const RuntimeShape& input_shape,
+ const float* input_data, const RuntimeShape& filter_shape,
+ const float* filter_data, const RuntimeShape& bias_shape,
+ const float* bias_data, const RuntimeShape& output_shape,
+ float* output_data, const RuntimeShape& im2col_shape,
+ float* im2col_data) {
+ const int stride_width = params.stride_width;
+ const int stride_height = params.stride_height;
+ const PaddingType padding = params.padding_type;
+ const int pad_width = params.padding_values.width;
+ const int pad_height = params.padding_values.height;
+ const float output_activation_min = params.float_activation_min;
+ const float output_activation_max = params.float_activation_max;
+ TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+
+ const int batches = MatchingDim(input_shape, 0, output_shape, 0);
+ const int input_depth = MatchingDim(input_shape, 3, filter_shape, 3);
+ const int output_depth = MatchingDim(filter_shape, 0, output_shape, 3);
+ const int input_height = input_shape.Dims(1);
+ const int input_width = input_shape.Dims(2);
+ const int filter_height = filter_shape.Dims(1);
+ const int filter_width = filter_shape.Dims(2);
+ const int output_height = output_shape.Dims(1);
+ const int output_width = output_shape.Dims(2);
EigenTensorConvFunctor<float> conv_functor;
- conv_functor(input_data, im2col_data, batches, input_height, input_width,
- input_depth, filter_data, filter_height, filter_width,
- output_depth, stride_height, stride_width, pad_height, pad_width,
- padding, output_data, output_height, output_width);
+ conv_functor(device, input_data, im2col_data, batches, input_height,
+ input_width, input_depth, filter_data, filter_height,
+ filter_width, output_depth, stride_height, stride_width,
+ pad_height, pad_width, padding, output_data, output_height,
+ output_width);
optimized_ops::AddBiasAndEvalActivationFunction(
- bias_data, bias_dims, output_data, output_dims, output_activation_min,
- output_activation_max);
+ output_activation_min, output_activation_max, bias_shape, bias_data,
+ output_shape, output_data);
}
} // namespace multithreaded_ops
} // namespace tflite
-#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_MULTITHREAD_CONV
+#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_MULTITHREADED_CONV_H_
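
With this change, multithreaded_ops::Conv no longer builds its own global thread pool; callers supply an Eigen::ThreadPoolDevice as the first argument. The sketch below shows one way a caller could construct such a device, modeled on the wrapper the diff removes; it assumes Eigen's unsupported Tensor/ThreadPool headers are available and is illustrative only.

#define EIGEN_USE_THREADS
#include <functional>
#include <utility>

// Assumed include paths; adjust to the local Eigen layout.
#include "unsupported/Eigen/CXX11/Tensor"
#include "unsupported/Eigen/CXX11/ThreadPool"

// Adapter mirroring the EigenThreadPoolWrapper removed from this header; it
// forwards scheduling calls to an Eigen::ThreadPool.
class ThreadPoolAdapter : public Eigen::ThreadPoolInterface {
 public:
  explicit ThreadPoolAdapter(Eigen::ThreadPool* pool) : pool_(pool) {}
  void Schedule(std::function<void()> fn) override {
    pool_->Schedule(std::move(fn));
  }
  int NumThreads() const override { return pool_->NumThreads(); }
  int CurrentThreadId() const override { return pool_->CurrentThreadId(); }

 private:
  Eigen::ThreadPool* pool_;
};

int main() {
  const int thread_count = 4;
  Eigen::ThreadPool pool(thread_count);
  ThreadPoolAdapter adapter(&pool);
  Eigen::ThreadPoolDevice device(&adapter, thread_count);
  // `device` would now be passed as the first argument of the new
  // multithreaded_ops::Conv(device, params, ...) signature.
  (void)device;
  return 0;
}
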
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc
index 08f7cfa5a5..36c15dbc57 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc
@@ -15,7 +15,7 @@ limitations under the License.
#include <stdlib.h>
#include <string.h>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
#include "tensorflow/contrib/lite/kernels/activation_functor.h"
#include "tensorflow/contrib/lite/kernels/internal/common.h"
#include "tensorflow/contrib/lite/kernels/internal/compatibility.h"
@@ -55,83 +55,33 @@ void NeonMatrixBatchVectorMultiplyAccumulate(const float* matrix, int m_rows,
const int postamble_start =
m_cols - (m_cols & (kFloatWeightsPerNeonLane - 1));
- // The arrays used to cache the vector.
- void* aligned_vector_cache_free = nullptr;
- float32x4_t* vector_cache_float32x4 =
- reinterpret_cast<float32x4_t*>(aligned_alloc(
- sizeof(float32x4_t), (postamble_start >> 2) * sizeof(float32x4_t),
- &aligned_vector_cache_free));
-
- const int kUnrollSize = 2;
for (int b = 0; b < n_batch; b++) {
float* result_in_batch = result + b * m_rows * result_stride;
const float* vector_in_batch = vector + b * m_cols;
+ const float* matrix_row = matrix;
- const float* matrix_ptr0 = matrix;
- // If there is only 1 row, we don't want to assign an illegal pointer.
- const float* matrix_ptr1 = nullptr;
- if (m_rows > 1) {
- matrix_ptr1 = matrix + m_cols;
- }
-
- // Cache the vector.
- for (int c = 0; c < postamble_start; c += kFloatWeightsPerNeonLane) {
- vector_cache_float32x4[c >> 2] = vld1q_f32(vector_in_batch + c);
- }
-
- // Main matrix by vector multiplication loop, which handles two rows of
- // matrix by vector multiplication.
- for (int r = 0; r < (m_rows & ~(kUnrollSize - 1)); r += kUnrollSize) {
- float32x4_t acc0_32x4 = vmovq_n_f32(0.0);
- float32x4_t acc1_32x4 = vmovq_n_f32(0.0);
+    // Main matrix by vector multiplication loop; handles one matrix row per
+    // iteration.
+ for (int r = 0; r < m_rows; r++) {
+ float32x4_t acc_32x4 = vmovq_n_f32(0.0);
for (int c = 0; c < postamble_start; c += kFloatWeightsPerNeonLane) {
- float32x4_t temp = vector_cache_float32x4[c >> 2];
- // Load 4 float values from vector1 and vector2 and accumulator.
- float32x4_t v0_f32x4 = vld1q_f32(matrix_ptr0 + c);
- float32x4_t v1_f32x4 = vld1q_f32(matrix_ptr1 + c);
- // Vector multiply-accumulate 4 float
- acc0_32x4 = vmlaq_f32(acc0_32x4, v0_f32x4, temp);
- acc1_32x4 = vmlaq_f32(acc1_32x4, v1_f32x4, temp);
+ // Load 4 float values from vector and matrix row.
+ float32x4_t vector_f32x4 = vld1q_f32(vector_in_batch + c);
+ float32x4_t matrix_f32x4 = vld1q_f32(matrix_row + c);
+ // Multiply the vector and matrix row and add to accumulator.
+ acc_32x4 = vmlaq_f32(acc_32x4, matrix_f32x4, vector_f32x4);
}
// Add the 4 intermediate sum values to get the final dot-prod value for
// this column.
*result_in_batch +=
- (vgetq_lane_f32(acc0_32x4, 0) + vgetq_lane_f32(acc0_32x4, 1) +
- vgetq_lane_f32(acc0_32x4, 2) + vgetq_lane_f32(acc0_32x4, 3));
- *(result_in_batch + result_stride) +=
- (vgetq_lane_f32(acc1_32x4, 0) + vgetq_lane_f32(acc1_32x4, 1) +
- vgetq_lane_f32(acc1_32x4, 2) + vgetq_lane_f32(acc1_32x4, 3));
+ (vgetq_lane_f32(acc_32x4, 0) + vgetq_lane_f32(acc_32x4, 1) +
+ vgetq_lane_f32(acc_32x4, 2) + vgetq_lane_f32(acc_32x4, 3));
for (int c = postamble_start; c < m_cols; c++) {
- *result_in_batch += matrix_ptr0[c] * vector_in_batch[c];
- *(result_in_batch + result_stride) +=
- matrix_ptr1[c] * vector_in_batch[c];
- }
- matrix_ptr0 += kUnrollSize * m_cols;
- matrix_ptr1 += kUnrollSize * m_cols;
- result_in_batch += kUnrollSize * result_stride;
- }
- for (int r = (m_rows & ~(kUnrollSize - 1)); r < m_rows; r++) {
- float32x4_t acc0_32x4 = vmovq_n_f32(0.0);
- for (int c = 0; c < postamble_start; c += kFloatWeightsPerNeonLane) {
- float32x4_t temp = vector_cache_float32x4[c >> 2];
- // Load 4 float values from vector1 and vector2 and accumulator.
- float32x4_t v0_f32x4 = vld1q_f32(matrix_ptr0 + c);
- // Vector multiply-accumulate 4 float
- acc0_32x4 = vmlaq_f32(acc0_32x4, v0_f32x4, temp);
+ *result_in_batch += matrix_row[c] * vector_in_batch[c];
}
- // Add the 4 intermediate sum values to get the final dot-prod value for
- // this column.
- *result_in_batch +=
- (vgetq_lane_f32(acc0_32x4, 0) + vgetq_lane_f32(acc0_32x4, 1) +
- vgetq_lane_f32(acc0_32x4, 2) + vgetq_lane_f32(acc0_32x4, 3));
- for (int c = postamble_start; c < m_cols; c++) {
- *result_in_batch += matrix_ptr0[c] * vector_in_batch[c];
- }
- matrix_ptr0 += m_cols;
+ matrix_row += m_cols;
result_in_batch += result_stride;
}
}
- free(aligned_vector_cache_free);
}
void NeonMatrixBatchVectorMultiplyAccumulate(
@@ -162,7 +112,7 @@ void NeonMatrixBatchVectorMultiplyAccumulate(
int batch, row, col;
for (batch = 0; batch < n_batch; ++batch) {
- const float batch_scaling_factor_inv = 1.0 / scaling_factors[batch];
+ const float batch_scaling_factor = scaling_factors[batch];
// Copy the vector data to an aligned vector.
memcpy(aligned_vec, vectors + batch * m_cols, sizeof(int8) * m_cols);
// Compute dot-product for every column.
@@ -232,7 +182,7 @@ void NeonMatrixBatchVectorMultiplyAccumulate(
int32 neon_sum =
vgetq_lane_s64(pairwiseAdded, 0) + vgetq_lane_s64(pairwiseAdded, 1);
- *result += ((neon_sum + postable_sum) * batch_scaling_factor_inv);
+ *result += ((neon_sum + postable_sum) * batch_scaling_factor);
} // for row
} // for batch
@@ -286,6 +236,35 @@ void NeonVectorVectorCwiseProductAccumulate(const float* vector1,
}
}
+void NeonVectorBatchVectorCwiseProduct(const float* vector, int v_size,
+ const float* batch_vector, int n_batch,
+ float* result) {
+  // If v_size is not divisible by kFloatWeightsPerNeonLane, we cannot use the
+  // main vectorized loop, and we need to process sequentially. postamble_start
+  // shows the start index where this should happen.
+ const int postamble_start =
+ v_size - (v_size & (kFloatWeightsPerNeonLane - 1));
+
+ for (int b = 0; b < n_batch; b++) {
+ for (int v = 0; v < postamble_start; v += kFloatWeightsPerNeonLane) {
+ // Load from memory to vectors.
+ float32x4_t batch_vector_f32x4 = vld1q_f32(batch_vector + v);
+ float32x4_t vector_f32x4 = vld1q_f32(vector + v);
+ // Multiply.
+ float32x4_t result_f32x4 = vmulq_f32(batch_vector_f32x4, vector_f32x4);
+ // Store.
+ vst1q_f32(result + v, result_f32x4);
+ }
+ // Postamble loop
+ for (int v = postamble_start; v < v_size; v++) {
+ result[v] = vector[v] * batch_vector[v];
+ }
+ // Update the pointers.
+ result += v_size;
+ batch_vector += v_size;
+ }
+}
+
void NeonVectorBatchVectorCwiseProductAccumulate(const float* vector,
int v_size,
const float* batch_vector,
@@ -296,17 +275,6 @@ void NeonVectorBatchVectorCwiseProductAccumulate(const float* vector,
const int postamble_start =
v_size - (v_size & (kFloatWeightsPerNeonLane - 1));
- // The arrays used to cache the vector.
- void* aligned_vector_cache_free = nullptr;
- float32x4_t* vector_cache_float32x4 =
- reinterpret_cast<float32x4_t*>(aligned_alloc(
- sizeof(float32x4_t), (postamble_start >> 2) * sizeof(float32x4_t),
- &aligned_vector_cache_free));
-
- for (int v = 0; v < postamble_start; v += kFloatWeightsPerNeonLane) {
- vector_cache_float32x4[v >> 2] = vld1q_f32(vector + v);
- }
-
float* result_ptr = result;
const float* batch_vector_ptr = batch_vector;
for (int b = 0; b < n_batch; b++) {
@@ -314,9 +282,9 @@ void NeonVectorBatchVectorCwiseProductAccumulate(const float* vector,
// Load from memory to vectors.
float32x4_t result_f32x4 = vld1q_f32(result_ptr + v);
float32x4_t batch_vector_f32x4 = vld1q_f32(batch_vector_ptr + v);
+ float32x4_t vector_f32x4 = vld1q_f32(vector + v);
// Multiply-accumulate.
- result_f32x4 = vmlaq_f32(result_f32x4, batch_vector_f32x4,
- vector_cache_float32x4[v >> 2]);
+ result_f32x4 = vmlaq_f32(result_f32x4, batch_vector_f32x4, vector_f32x4);
// Store.
vst1q_f32(result_ptr + v, result_f32x4);
}
@@ -328,7 +296,6 @@ void NeonVectorBatchVectorCwiseProductAccumulate(const float* vector,
result_ptr += v_size;
batch_vector_ptr += v_size;
}
- free(aligned_vector_cache_free);
}
void NeonSub1Vector(const float* vector, int v_size, float* result) {
@@ -352,6 +319,30 @@ void NeonSub1Vector(const float* vector, int v_size, float* result) {
}
}
+bool NeonIsZeroVector(const float* vector, int v_size) {
+ // If v_size is not divisible by kFloatWeightsPerNeonLane, we cannot
+ // use the main vectorized loop, and we need to process sequentially.
+ // postamble_start shows the start index where this should happen.
+ const int postamble_start =
+ v_size - (v_size & (kFloatWeightsPerNeonLane - 1));
+
+ const float32x4_t zero_x4_float = vmovq_n_f32(0.0f);
+ for (int v = 0; v < postamble_start; v += kFloatWeightsPerNeonLane) {
+ const float32x4_t i_x4_float = vld1q_f32(vector + v);
+ uint32x4_t cmp_result = vceqq_f32(i_x4_float, zero_x4_float);
+ if (vgetq_lane_u32(cmp_result, 0) == 0) return false;
+ if (vgetq_lane_u32(cmp_result, 1) == 0) return false;
+ if (vgetq_lane_u32(cmp_result, 2) == 0) return false;
+ if (vgetq_lane_u32(cmp_result, 3) == 0) return false;
+ }
+
+ // Postamble loop
+ for (int v = postamble_start; v < v_size; ++v) {
+ if (vector[v] != 0.0) return false;
+ }
+ return true;
+}
+
void NeonClipVector(const float* vector, int v_size, float abs_limit,
float* result) {
// If v_size is not divisible by kWeightsPerNeonLane, we cannot use the main
@@ -380,6 +371,77 @@ void NeonClipVector(const float* vector, int v_size, float abs_limit,
}
}
+void NeonVectorScalarMultiply(const int8_t* vector, const int v_size,
+ const float scale, float* result) {
+ // Here the assumption is that each buffer is 4-byte aligned.
+ const int kWeightsPerUint32 = 4;
+ TFLITE_CHECK_EQ((intptr_t)(&vector[0]) & (kWeightsPerUint32 - 1), 0);
+ // If v_size is not divisible by kWeightsPerNeonLane, we cannot use the main
+ // vectorized loop, and we need to process sequentially. postamble_start shows
+ // the start index where this should happen.
+ const int kWeightsPerNeonLane = 16;
+ const int postamble_start = v_size - (v_size & (kWeightsPerNeonLane - 1));
+
+ // Create a vector of 4 floats with the scale value.
+ const float32x4_t scale_f32x4 = vdupq_n_f32(scale);
+ int v = 0;
+ for (; v < postamble_start; v += kWeightsPerNeonLane) {
+ // Load int8 values, sixteen at a time.
+ const int8x16_t v_i8x16 = vld1q_s8(vector + v);
+ // Split it into two components of size eight.
+ const int8x8_t v0_i8x8 = vget_low_s8(v_i8x16);
+ const int8x8_t v1_i8x8 = vget_high_s8(v_i8x16);
+ // Convert both components to int16 first.
+ const int16x8_t v0_i16x8 = vmovl_s8(v0_i8x8);
+ const int16x8_t v1_i16x8 = vmovl_s8(v1_i8x8);
+    // Split each of them into two components.
+ const int16x4_t v0_i16x4 = vget_low_s16(v0_i16x8);
+ const int16x4_t v1_i16x4 = vget_high_s16(v0_i16x8);
+ const int16x4_t v2_i16x4 = vget_low_s16(v1_i16x8);
+ const int16x4_t v3_i16x4 = vget_high_s16(v1_i16x8);
+ // Convert these to int32 and then to float.
+ float32x4_t v0_f32x4 = vcvtq_f32_s32(vmovl_s16(v0_i16x4));
+ float32x4_t v1_f32x4 = vcvtq_f32_s32(vmovl_s16(v1_i16x4));
+ float32x4_t v2_f32x4 = vcvtq_f32_s32(vmovl_s16(v2_i16x4));
+ float32x4_t v3_f32x4 = vcvtq_f32_s32(vmovl_s16(v3_i16x4));
+ // Vector multiply four floats at a time.
+ v0_f32x4 = vmulq_f32(v0_f32x4, scale_f32x4);
+ v1_f32x4 = vmulq_f32(v1_f32x4, scale_f32x4);
+ v2_f32x4 = vmulq_f32(v2_f32x4, scale_f32x4);
+ v3_f32x4 = vmulq_f32(v3_f32x4, scale_f32x4);
+ // Store the results.
+ vst1q_f32(result + v, v0_f32x4);
+ vst1q_f32(result + v + 4, v1_f32x4);
+ vst1q_f32(result + v + 8, v2_f32x4);
+ vst1q_f32(result + v + 12, v3_f32x4);
+ }
+
+ if (v_size - postamble_start >= (kWeightsPerNeonLane >> 1)) {
+    // Load eight int8 values, if there are at least eight remaining.
+ const int8x8_t v_i8x8 = vld1_s8(vector + v);
+ // Convert them to int16 first.
+ const int16x8_t v_i16x8 = vmovl_s8(v_i8x8);
+ // Split it into two components.
+ const int16x4_t v0_i16x4 = vget_low_s16(v_i16x8);
+ const int16x4_t v1_i16x4 = vget_high_s16(v_i16x8);
+    // Convert the components to floats.
+ float32x4_t v0_f32x4 = vcvtq_f32_s32(vmovl_s16(v0_i16x4));
+ float32x4_t v1_f32x4 = vcvtq_f32_s32(vmovl_s16(v1_i16x4));
+ // Vector multiply four floats at a time.
+ v0_f32x4 = vmulq_f32(v0_f32x4, scale_f32x4);
+ v1_f32x4 = vmulq_f32(v1_f32x4, scale_f32x4);
+ // Store the results.
+ vst1q_f32(result + v, v0_f32x4);
+ vst1q_f32(result + v + 4, v1_f32x4);
+ v += (kWeightsPerNeonLane >> 1);
+ }
+
+ // Postamble loop.
+ for (; v < v_size; v++) {
+ result[v] = scale * vector[v];
+ }
+}
+
void NeonSymmetricQuantizeFloats(const float* values, const int size,
int8_t* quantized_values, float* min,
float* max, float* scaling_factor) {
@@ -394,13 +456,14 @@ void NeonSymmetricQuantizeFloats(const float* values, const int size,
*scaling_factor = 1;
return;
}
- *scaling_factor = kScale / range;
+ *scaling_factor = range / kScale;
+ const float scaling_factor_inv = kScale / range;
const int postamble_start =
size - (size & (2 * kFloatWeightsPerNeonLane - 1));
// Vectorized constants.
- const float32x4_t q_factor_f32x4 = vmovq_n_f32(*scaling_factor);
+ const float32x4_t q_factor_f32x4 = vmovq_n_f32(scaling_factor_inv);
const float32x4_t point5_f32x4 = vmovq_n_f32(0.5);
const float32x4_t zero_f32x4 = vmovq_n_f32(0.0);
const int32x4_t scale_i32x4 = vmovq_n_s32(kScale);
@@ -452,7 +515,7 @@ void NeonSymmetricQuantizeFloats(const float* values, const int size,
for (int i = postamble_start; i < size; ++i) {
const int32 quantized_value =
- static_cast<int32>(TfLiteRound(*scaling_factor * values[i]));
+ static_cast<int32>(TfLiteRound(scaling_factor_inv * values[i]));
quantized_values[i] = std::min(kScale, std::max(-kScale, quantized_value));
}
}
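
Most kernels in this file share one structure: a vectorized main loop over the largest multiple of the NEON lane width, followed by a scalar postamble loop for the remainder, with postamble_start = size - (size & (lane_width - 1)) marking the split. A portable sketch of that split without intrinsics (illustrative, not TFLite code):

#include <cstdio>
#include <vector>

// Element-wise product with the same main-loop/postamble split used by the
// NEON kernels; here the "lane width" is 4 floats, but the body stays scalar.
void CwiseProduct(const float* a, const float* b, int size, float* out) {
  constexpr int kLanes = 4;  // must be a power of two for the mask trick below
  // Largest multiple of kLanes not exceeding size; the remainder is handled
  // sequentially in the postamble loop.
  const int postamble_start = size - (size & (kLanes - 1));
  for (int i = 0; i < postamble_start; i += kLanes) {
    for (int lane = 0; lane < kLanes; ++lane) {  // stands in for vmulq_f32
      out[i + lane] = a[i + lane] * b[i + lane];
    }
  }
  for (int i = postamble_start; i < size; ++i) {  // postamble loop
    out[i] = a[i] * b[i];
  }
}

int main() {
  std::vector<float> a(10, 2.f), b(10, 3.f), out(10);
  CwiseProduct(a.data(), b.data(), 10, out.data());
  std::printf("%g %g\n", out[0], out[9]);  // 6 6
  return 0;
}
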
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h
index 9e60d0657b..630a6bbf29 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h
@@ -17,7 +17,7 @@ limitations under the License.
// TODO(ghodrat): Remove this header file and the dependency to internal data
// structure.
-#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/cpu_check.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h"
@@ -52,6 +52,13 @@ void VectorVectorCwiseProductAccumulate(const float* vector1,
result);
}
+void VectorBatchVectorCwiseProduct(const float* vector, int v_size,
+ const float* batch_vector, int n_batch,
+ float* result) {
+ NEON_OR_PORTABLE(VectorBatchVectorCwiseProduct, vector, v_size, batch_vector,
+ n_batch, result);
+}
+
void VectorBatchVectorCwiseProductAccumulate(const float* vector, int v_size,
const float* batch_vector,
int n_batch, float* result) {
@@ -72,6 +79,11 @@ void BatchVectorBatchVectorDotProduct(const float* vector1,
n_batch, result, result_stride);
}
+void VectorBatchVectorAdd(const float* vector, int v_size, int n_batch,
+ float* batch_vector) {
+ PortableVectorBatchVectorAdd(vector, v_size, n_batch, batch_vector);
+}
+
void VectorBatchVectorAssign(const float* vector, int v_size, int n_batch,
float* batch_vector) {
PortableVectorBatchVectorAssign(vector, v_size, n_batch, batch_vector);
@@ -100,16 +112,25 @@ void ZeroVector(float* vector, int v_size) {
float Clip(float f, float abs_limit) { return PortableClip(f, abs_limit); }
+// Check if all entries of a vector are zero.
+bool IsZeroVector(const float* vector, int v_size) {
+ return NEON_OR_PORTABLE(IsZeroVector, vector, v_size);
+}
+
+void VectorScalarMultiply(const int8_t* vector, int v_size, float scale,
+ float* result) {
+ NEON_OR_PORTABLE(VectorScalarMultiply, vector, v_size, scale, result);
+}
void ClipVector(const float* vector, int v_size, float abs_limit,
float* result) {
NEON_OR_PORTABLE(ClipVector, vector, v_size, abs_limit, result);
}
void SymmetricQuantizeFloats(const float* values, const int size,
- int8_t* quantized_values, float* min, float* max,
- float* scaling_factor) {
- NEON_OR_PORTABLE(SymmetricQuantizeFloats, values, size, quantized_values, min,
- max, scaling_factor);
+ int8_t* quantized_values, float* min_value,
+ float* max_value, float* scaling_factor) {
+ NEON_OR_PORTABLE(SymmetricQuantizeFloats, values, size, quantized_values,
+ min_value, max_value, scaling_factor);
}
void VectorShiftLeft(float* vector, int v_size, float shift_value) {
@@ -122,6 +143,13 @@ void ReductionSumVector(const float* input_vector, float* output_vector,
reduction_size);
}
+void MeanStddevNormalization(const float* input_vector, float* output_vector,
+ int v_size, int n_batch,
+ float normalization_epsilon) {
+ PortableMeanStddevNormalization(input_vector, output_vector, v_size, n_batch,
+ normalization_epsilon);
+}
+
} // namespace tensor_utils
} // namespace tflite
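
The SymmetricQuantizeFloats change earlier in this diff flips the stored scaling_factor to range / kScale (the dequantization step), while the inverse kScale / range is applied when rounding to int8. A small numeric sketch of that convention, kept independent of the TFLite helpers:

#include <algorithm>
#include <cmath>
#include <cstdio>

int main() {
  const float kScale = 127.0f;
  const float values[4] = {-0.5f, 0.1f, 0.25f, 0.5f};
  // Symmetric range: the largest absolute value in the buffer.
  const float range = 0.5f;
  const float scaling_factor = range / kScale;      // dequantization step
  const float scaling_factor_inv = kScale / range;  // used when quantizing
  for (float v : values) {
    int q = static_cast<int>(std::round(v * scaling_factor_inv));
    q = std::min(127, std::max(-127, q));  // clamp to the symmetric int8 range
    std::printf("%+.2f -> %+d -> %+.4f\n", v, q, q * scaling_factor);
  }
  return 0;
}
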
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
index 64ba5e62f6..77f84e0c1c 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
@@ -1,4 +1,4 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_OPS_H_
-#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_OPS_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_OPTIMIZED_OPS_H_
+#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_OPTIMIZED_OPS_H_
#include <assert.h>
#include <stdint.h>
@@ -34,22 +34,58 @@ limitations under the License.
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/round.h"
#include "tensorflow/contrib/lite/kernels/internal/strided_slice_logic.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/contrib/lite/kernels/internal/types.h"
namespace tflite {
namespace optimized_ops {
// Unoptimized reference ops:
-using reference_ops::BroadcastGreater;
-using reference_ops::BroadcastGreaterEqual;
-using reference_ops::BroadcastLess;
-using reference_ops::BroadcastLessEqual;
+using reference_ops::ArgMax;
+using reference_ops::ArgMinMax;
+using reference_ops::Broadcast4DSlowGreater;
+using reference_ops::Broadcast4DSlowGreaterEqual;
+using reference_ops::Broadcast4DSlowGreaterEqualWithScaling;
+using reference_ops::Broadcast4DSlowGreaterWithScaling;
+using reference_ops::Broadcast4DSlowLess;
+using reference_ops::Broadcast4DSlowLessEqual;
+using reference_ops::Broadcast4DSlowLessEqualWithScaling;
+using reference_ops::Broadcast4DSlowLessWithScaling;
+using reference_ops::BroadcastAdd4DSlow;
+using reference_ops::BroadcastMul4DSlow;
+using reference_ops::BroadcastSub4DSlow;
+using reference_ops::Concatenation;
+using reference_ops::ConcatenationWithScaling;
+using reference_ops::DepthConcatenation;
+using reference_ops::Dequantize;
+using reference_ops::Div;
+using reference_ops::FakeQuant;
+using reference_ops::Gather;
using reference_ops::Greater;
using reference_ops::GreaterEqual;
+using reference_ops::GreaterEqualWithScaling;
+using reference_ops::GreaterWithScaling;
using reference_ops::Less;
using reference_ops::LessEqual;
+using reference_ops::LessEqualWithScaling;
+using reference_ops::LessWithScaling;
+using reference_ops::Mean;
using reference_ops::RankOneSelect;
+using reference_ops::Relu1;
+using reference_ops::Relu6;
+using reference_ops::ReluX;
using reference_ops::Select;
+using reference_ops::SpaceToBatchND;
+using reference_ops::Split;
+using reference_ops::StridedSlice;
+using reference_ops::Transpose;
+
+// TODO(b/80247582) Remove this constant.
+// This will be phased out as the shifts are revised with more thought. Use of a
+// constant enables us to track progress on this work.
+//
+// Used to convert from old-style shifts (right) to new-style (left).
+static constexpr int kReverseShift = -1;
// Make a local VectorMap typedef allowing to map a float array
// as a Eigen vector expression. The std::conditional here is to
@@ -65,9 +101,9 @@ using VectorMap = typename std::conditional<
Eigen::Dynamic, 1>>,
Eigen::Map<Eigen::Matrix<Scalar, Eigen::Dynamic, 1>>>::type;
-template <typename Scalar, int N>
-VectorMap<Scalar> MapAsVector(Scalar* data, const Dims<N>& dims) {
- const int size = FlatSize(dims);
+template <typename Scalar>
+VectorMap<Scalar> MapAsVector(Scalar* data, const RuntimeShape& shape) {
+ const int size = shape.FlatSize();
return VectorMap<Scalar>(data, size, 1);
}
@@ -81,25 +117,20 @@ using MatrixMap = typename std::conditional<
Eigen::Dynamic, Eigen::Dynamic>>,
Eigen::Map<Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic>>>::type;
-template <typename Scalar, int N>
-MatrixMap<Scalar> MapAsMatrixWithFirstDimAsRows(Scalar* data,
- const Dims<N>& dims) {
- const int rows = dims.sizes[0];
- int cols = 1;
- for (int d = 1; d < N; d++) {
- cols *= dims.sizes[d];
- }
+template <typename Scalar>
+MatrixMap<Scalar> MapAsMatrixWithLastDimAsRows(Scalar* data,
+ const RuntimeShape& shape) {
+ const int dims_count = shape.DimensionsCount();
+ const int rows = shape.Dims(dims_count - 1);
+ const int cols = FlatSizeSkipDim(shape, dims_count - 1);
return MatrixMap<Scalar>(data, rows, cols);
}
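+
+// Sketch of the mapping above (hypothetical shape): a RuntimeShape of
+// {2, 3, 4} makes the last dimension the rows, so the data is viewed as a
+// 4 x 6 matrix (rows = 4, cols = 2 * 3 = 6).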
-template <typename Scalar, int N>
-MatrixMap<Scalar> MapAsMatrixWithLastDimAsCols(Scalar* data,
- const Dims<N>& dims) {
- const int cols = dims.sizes[N - 1];
- int rows = 1;
- for (int d = 0; d < N - 1; d++) {
- rows *= dims.sizes[d];
- }
+template <typename Scalar>
+MatrixMap<Scalar> MapAsMatrixWithFirstDimAsCols(Scalar* data,
+ const RuntimeShape& shape) {
+ const int cols = shape.Dims(0);
+ const int rows = FlatSizeSkipDim(shape, 0);
return MatrixMap<Scalar>(data, rows, cols);
}
@@ -110,147 +141,88 @@ using ArrayMap = typename std::conditional<
Eigen::Dynamic, Eigen::Dynamic>>,
Eigen::Map<Eigen::Array<Scalar, Eigen::Dynamic, Eigen::Dynamic>>>::type;
-template <typename Scalar, int N>
-ArrayMap<Scalar> MapAsArrayWithFirstDimAsRows(Scalar* data,
- const Dims<N>& dims) {
- const int rows = dims.sizes[0];
- int cols = 1;
- for (int d = 1; d < N; d++) {
- cols *= dims.sizes[d];
- }
+template <typename Scalar>
+ArrayMap<Scalar> MapAsArrayWithLastDimAsRows(Scalar* data,
+ const RuntimeShape& shape) {
+ const int dims_count = shape.DimensionsCount();
+ const int rows = shape.Dims(dims_count - 1);
+ const int cols = FlatSizeSkipDim(shape, dims_count - 1);
return ArrayMap<Scalar>(data, rows, cols);
}
+// Copied from tensorflow/core/framework/tensor_types.h
+template <typename T, int NDIMS = 1, typename IndexType = Eigen::DenseIndex>
+struct TTypes {
+ // Rank-1 tensor (vector) of scalar type T.
+ typedef Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor, IndexType>,
+ Eigen::Aligned>
+ Flat;
+ typedef Eigen::TensorMap<
+ Eigen::Tensor<const T, 2, Eigen::RowMajor, IndexType>>
+ UnalignedConstMatrix;
+};
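+
+// Usage sketch (illustrative, with a hypothetical local buffer):
+//   float buffer[8];
+//   TTypes<float>::Flat flat(buffer, 8);                    // rank-1 view
+//   TTypes<float>::UnalignedConstMatrix mat(buffer, 2, 4);  // 2 x 4 view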
+
// TODO(b/62193649): this function is only needed as long
// as we have the --variable_batch hack.
-template <typename Scalar, int N>
+template <typename Scalar>
MatrixMap<Scalar> MapAsMatrixWithGivenNumberOfRows(Scalar* data,
- const Dims<N>& dims,
+ const RuntimeShape& shape,
int rows) {
- int cols = 1;
- bool matched_rows = false;
- for (int d = 0; d < N; d++) {
- cols *= dims.sizes[d];
- if (cols == rows) {
- matched_rows = true;
- cols = 1;
- }
- }
- TFLITE_DCHECK(matched_rows);
+ const int flatsize = shape.FlatSize();
+ TFLITE_DCHECK_EQ(flatsize % rows, 0);
+ const int cols = flatsize / rows;
return MatrixMap<Scalar>(data, rows, cols);
}
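+
+// Example of the shape math above (hypothetical values): a shape of {2, 3, 8}
+// has FlatSize() == 48, so with rows == 8 the data is mapped as an 8 x 6
+// matrix; the DCHECK holds because 48 % 8 == 0.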
-// DO NOT USE THIS STRUCT FOR NEW FUNCTIONALITY BEYOND IMPLEMENTING ELEMENT-WISE
-// BROADCASTING.
-//
-// NdArrayDesc<N> describes the shape and memory layout of an N-dimensional
-// rectangular array of numbers.
-//
-// NdArrayDesc<N> is basically identical to Dims<N> defined in types.h.
-// However, as Dims<N> is to be deprecated, this class exists as an adaptor
-// to enable simple unoptimized implementations of element-wise broadcasting
-// operations.
-template <int N>
-struct NdArrayDesc {
- // The "extent" of each dimension. Indices along dimension d must be in the
- // half-open interval [0, extents[d]).
- int extents[N];
-
- // The number of *elements* (not bytes) between consecutive indices of each
- // dimension.
- int strides[N];
-};
-
-// DO NOT USE THIS FUNCTION FOR NEW FUNCTIONALITY BEYOND IMPLEMENTING
-// ELEMENT-WISE BROADCASTING.
-//
-// Same as Offset(), except takes as NdArrayDesc<N> instead of Dims<N>.
-inline int SubscriptToIndex(const NdArrayDesc<4>& desc, int i0, int i1, int i2,
- int i3) {
- TFLITE_DCHECK(i0 >= 0 && i0 < desc.extents[0]);
- TFLITE_DCHECK(i1 >= 0 && i1 < desc.extents[1]);
- TFLITE_DCHECK(i2 >= 0 && i2 < desc.extents[2]);
- TFLITE_DCHECK(i3 >= 0 && i3 < desc.extents[3]);
- return i0 * desc.strides[0] + i1 * desc.strides[1] + i2 * desc.strides[2] +
- i3 * desc.strides[3];
-}
-
-// Given the dimensions of the operands for an element-wise binary broadcast,
-// adjusts them so that they can be directly iterated over with simple loops.
-// Returns the adjusted dims as instances of NdArrayDesc in 'desc0_out' and
-// 'desc1_out'. 'desc0_out' and 'desc1_out' cannot be nullptr.
-//
-// This function assumes that the two input shapes are compatible up to
-// broadcasting and the shorter one has already been prepended with 1s to be the
-// same length. E.g., if shape0 is (1, 16, 16, 64) and shape1 is (1, 64),
-// shape1 must already have been prepended to be (1, 1, 1, 64). Recall that
-// Dims<N> refer to shapes in reverse order. In this case, input0_dims will be
-// (64, 16, 16, 1) and input1_dims will be (64, 1, 1, 1).
-//
-// When two shapes are compatible up to broadcasting, for each dimension d,
-// the input extents are either equal, or one of them is 1.
-//
-// This function performs the following for each dimension d:
-// - If the extents are equal, then do nothing since the loop that walks over
-// both of the input arrays is correct.
-// - Otherwise, one (and only one) of the extents must be 1. Say extent0 is 1
-// and extent1 is e1. Then set extent0 to e1 and stride0 *to 0*. This allows
-// array0 to be referenced *at any index* in dimension d and still access the
-// same slice.
-template <int N>
-inline void NdArrayDescsForElementwiseBroadcast(const Dims<N>& input0_dims,
- const Dims<N>& input1_dims,
- NdArrayDesc<N>* desc0_out,
- NdArrayDesc<N>* desc1_out) {
- TFLITE_DCHECK(desc0_out != nullptr);
- TFLITE_DCHECK(desc1_out != nullptr);
-
- // Copy dims to desc.
- for (int i = 0; i < N; ++i) {
- desc0_out->extents[i] = input0_dims.sizes[i];
- desc0_out->strides[i] = input0_dims.strides[i];
- desc1_out->extents[i] = input1_dims.sizes[i];
- desc1_out->strides[i] = input1_dims.strides[i];
- }
-
- // Walk over each dimension. If the extents are equal do nothing.
- // Otherwise, set the desc with extent 1 to have extent equal to the other and
- // stride 0.
- for (int i = 0; i < N; ++i) {
- const int extent0 = ArraySize(input0_dims, i);
- const int extent1 = ArraySize(input1_dims, i);
- if (extent0 != extent1) {
- if (extent0 == 1) {
- desc0_out->strides[i] = 0;
- desc0_out->extents[i] = extent1;
- } else {
- TFLITE_DCHECK_EQ(extent1, 1);
- desc1_out->strides[i] = 0;
- desc1_out->extents[i] = extent0;
- }
- }
- }
-}
-
-inline bool AreSameDims(const Dims<4>& dims1, const Dims<4>& dims2) {
- for (int i = 0; i < 4; i++) {
- if (dims1.sizes[i] != dims2.sizes[i]) {
- return false;
- }
- }
- return true;
-}
-
-inline void AddBiasAndEvalActivationFunction(const float* bias_data,
- const Dims<4>& bias_dims,
- float* array_data,
- const Dims<4>& array_dims,
- float output_activation_min,
- float output_activation_max) {
+// This is like the template-parameter version, except that the power-of-two is
+// passed as a function parameter. The template version is to be preferred,
+// since some target hardware optimizations depend on the range of the exponent.
+template <typename IntegerType>
+IntegerType SaturatingRoundingMultiplyByPOTParam(IntegerType x, int exponent) {
+ if (exponent == 0) {
+ return x;
+ }
+ using ScalarIntegerType =
+ typename gemmlowp::FixedPointRawTypeTraits<IntegerType>::ScalarRawType;
+ const IntegerType min =
+ gemmlowp::Dup<IntegerType>(std::numeric_limits<ScalarIntegerType>::min());
+ const IntegerType max =
+ gemmlowp::Dup<IntegerType>(std::numeric_limits<ScalarIntegerType>::max());
+ const int ScalarIntegerTypeBits = 8 * sizeof(ScalarIntegerType);
+
+ const std::int32_t threshold =
+ ((1 << (ScalarIntegerTypeBits - 1 - exponent)) - 1);
+ const IntegerType positive_mask =
+ gemmlowp::MaskIfGreaterThan(x, gemmlowp::Dup<IntegerType>(threshold));
+ const IntegerType negative_mask =
+ gemmlowp::MaskIfLessThan(x, gemmlowp::Dup<IntegerType>(-threshold));
+
+ IntegerType result = gemmlowp::ShiftLeft(x, exponent);
+ result = gemmlowp::SelectUsingMask(positive_mask, max, result);
+ result = gemmlowp::SelectUsingMask(negative_mask, min, result);
+ return result;
+}
+
+// This is like the template-parameter version, except that the power-of-two is
+// passed as a function parameter. See raw-integer version for further comments.
+template <typename tRawType, int tIntegerBits>
+gemmlowp::FixedPoint<tRawType, tIntegerBits>
+SaturatingRoundingMultiplyByPOTParam(
+ gemmlowp::FixedPoint<tRawType, tIntegerBits> a, int exponent) {
+ return gemmlowp::FixedPoint<tRawType, tIntegerBits>::FromRaw(
+ SaturatingRoundingMultiplyByPOTParam(a.raw(), exponent));
+}
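+
+// Rough usage sketch (hypothetical values) for the raw-integer variant above:
+//   std::int32_t y = SaturatingRoundingMultiplyByPOTParam(std::int32_t{5}, 3);
+//   // y == 40, i.e. 5 * 2^3; values near the int32 limits saturate to
+//   // std::numeric_limits<std::int32_t>::min()/max() instead of overflowing.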
+
+inline void AddBiasAndEvalActivationFunction(float output_activation_min,
+ float output_activation_max,
+ const RuntimeShape& bias_shape,
+ const float* bias_data,
+ const RuntimeShape& array_shape,
+ float* array_data) {
#ifdef USE_NEON
gemmlowp::ScopedProfilingLabel label("AddBiasAndEvalActivationFunction");
- const int bias_size = FlatSize(bias_dims);
- const int array_size = FlatSize(array_dims);
+ const int bias_size = bias_shape.FlatSize();
+ const int array_size = array_shape.FlatSize();
TFLITE_DCHECK_EQ((array_size % bias_size), 0);
float* array_ptr = array_data;
float* array_end_ptr = array_ptr + array_size;
@@ -300,8 +272,8 @@ inline void AddBiasAndEvalActivationFunction(const float* bias_data,
}
#else // not NEON
gemmlowp::ScopedProfilingLabel label("AddBiasAndEvalActivationFunction");
- const int bias_size = FlatSize(bias_dims);
- const int array_size = FlatSize(array_dims);
+ const int bias_size = bias_shape.FlatSize();
+ const int array_size = array_shape.FlatSize();
TFLITE_DCHECK_EQ((array_size % bias_size), 0);
for (int array_offset = 0; array_offset < array_size;
array_offset += bias_size) {
@@ -314,19 +286,6 @@ inline void AddBiasAndEvalActivationFunction(const float* bias_data,
#endif
}
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void AddBiasAndEvalActivationFunction(const float* bias_data,
- const Dims<4>& bias_dims,
- float* array_data,
- const Dims<4>& array_dims) {
- float output_activation_min, output_activation_max;
- GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
- AddBiasAndEvalActivationFunction(bias_data, bias_dims, array_data, array_dims,
- output_activation_min,
- output_activation_max);
-}
-
template <typename Lhs, typename Rhs, typename Result>
void Gemm(const Eigen::MatrixBase<Lhs>& lhs, const Eigen::MatrixBase<Rhs>& rhs,
Eigen::MatrixBase<Result>* result) {
@@ -360,21 +319,24 @@ inline void optimized_ops_preload_l1_keep(const uint8* ptr) {
// to a matrix*vector product. LSTM cells contain a fully-connected node;
// when quantized, this becomes a special type of GEMV operation where
// the output is 16bit-quantized, thus needs its own special path.
-inline void GEMVForLstmCell(const uint8* input_data, const Dims<4>& input_dims,
- const uint8* weights_data,
- const Dims<4>& weights_dims,
- uint8 weights_zero_point, const int32* bias_data,
- const Dims<4>& bias_dims, int32 accum_multiplier,
- int accum_shift, int16* output_data,
- const Dims<4>& output_dims) {
+inline void GEMVForLstmCell(const RuntimeShape& input_shape,
+ const uint8* input_data,
+ const RuntimeShape& weights_shape,
+ const uint8* weights_data, uint8 weights_zero_point,
+ const RuntimeShape& bias_shape,
+ const int32* bias_data, int32 accum_multiplier,
+ int accum_shift, const RuntimeShape& output_shape,
+ int16* output_data) {
gemmlowp::ScopedProfilingLabel label("GEMVForLstmCell");
- TFLITE_DCHECK(IsPackedWithoutStrides(input_dims));
- TFLITE_DCHECK(IsPackedWithoutStrides(weights_dims));
- TFLITE_DCHECK(IsPackedWithoutStrides(bias_dims));
- TFLITE_DCHECK(IsPackedWithoutStrides(output_dims));
- TFLITE_DCHECK_EQ(FlatSizeSkipDim(output_dims, 0), 1);
- const int input_size = FlatSizeSkipDim(input_dims, 3);
- const int output_size = MatchingArraySize(weights_dims, 1, output_dims, 0);
+ TFLITE_DCHECK_GE(input_shape.DimensionsCount(), 1);
+ TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);
+ TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
+ const int output_dim_count = output_shape.DimensionsCount();
+ const int weights_dim_count = weights_shape.DimensionsCount();
+ TFLITE_DCHECK_EQ(FlatSizeSkipDim(output_shape, output_dim_count - 1), 1);
+ const int input_size = FlatSizeSkipDim(input_shape, 0);
+ const int output_size = MatchingDim(weights_shape, weights_dim_count - 2,
+ output_shape, output_dim_count - 1);
// This special fast path for quantized LSTM cells does not try to support
// odd sizes that we haven't encountered in any LSTM cell, that would
// require special code (that would go untested until any LSTM cell
@@ -547,18 +509,21 @@ inline void GEMVForLstmCell(const uint8* input_data, const Dims<4>& input_dims,
#ifdef GEMMLOWP_NEON
inline void GEMVForLstmCellWithSymmetricRange(
- const uint8* input_data, const Dims<4>& input_dims,
- const uint8* weights_data, const Dims<4>& weights_dims,
- const int32* bias_data, const Dims<4>& bias_dims, int32 accum_multiplier,
- int accum_shift, int16* output_data, const Dims<4>& output_dims) {
+ const RuntimeShape& input_shape, const uint8* input_data,
+ const RuntimeShape& weights_shape, const uint8* weights_data,
+ const RuntimeShape& bias_shape, const int32* bias_data,
+ int32 accum_multiplier, int accum_shift, const RuntimeShape& output_shape,
+ int16* output_data) {
gemmlowp::ScopedProfilingLabel label("GEMVForLstmCellWithSymmetricRange");
- TFLITE_DCHECK(IsPackedWithoutStrides(input_dims));
- TFLITE_DCHECK(IsPackedWithoutStrides(weights_dims));
- TFLITE_DCHECK(IsPackedWithoutStrides(bias_dims));
- TFLITE_DCHECK(IsPackedWithoutStrides(output_dims));
- TFLITE_DCHECK_EQ(FlatSizeSkipDim(output_dims, 0), 1);
- const int input_size = FlatSizeSkipDim(input_dims, 3);
- const int output_size = MatchingArraySize(weights_dims, 1, output_dims, 0);
+ TFLITE_DCHECK_GE(input_shape.DimensionsCount(), 1);
+ TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);
+ TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
+ const int output_dim_count = output_shape.DimensionsCount();
+ const int weights_dim_count = weights_shape.DimensionsCount();
+ TFLITE_DCHECK_EQ(FlatSizeSkipDim(output_shape, output_dim_count - 1), 1);
+ const int input_size = FlatSizeSkipDim(input_shape, 0);
+ const int output_size = MatchingDim(weights_shape, weights_dim_count - 2,
+ output_shape, output_dim_count - 1);
// This special fast path for quantized LSTM cells does not try to support
// odd sizes that we haven't encountered in any LSTM cell, that would
// require special code (that would go untested until any LSTM cell
@@ -834,14 +799,16 @@ inline void GEMVForLstmCellWithSymmetricRange(
}
#endif
-inline void FullyConnected(const float* input_data, const Dims<4>& input_dims,
- const float* weights_data,
- const Dims<4>& weights_dims, const float* bias_data,
- const Dims<4>& bias_dims,
- float output_activation_min,
- float output_activation_max, float* output_data,
- const Dims<4>& output_dims) {
+inline void FullyConnected(
+ const FullyConnectedParams& params, const RuntimeShape& input_shape,
+ const float* input_data, const RuntimeShape& weights_shape,
+ const float* weights_data, const RuntimeShape& bias_shape,
+ const float* bias_data, const RuntimeShape& output_shape,
+ float* output_data) {
gemmlowp::ScopedProfilingLabel label("FullyConnected");
+ const float output_activation_min = params.float_activation_min;
+ const float output_activation_max = params.float_activation_max;
+
// TODO(b/62193649): this convoluted shape computation (determining
// input_rows from the weights_dims, then MapAsMatrixWithGivenNumberOfRows)
// is because the current --variable_batch hack consists in overwriting the
@@ -850,50 +817,42 @@ inline void FullyConnected(const float* input_data, const Dims<4>& input_dims,
// When that is fixed, this should become:
// const auto input_matrix_map =
// MapAsMatrixWithFirstDimAsRows(input_data, input_dims);
- const int input_rows = ArraySize(weights_dims, 0);
+ const int dims_count = weights_shape.DimensionsCount();
+ const int input_rows = weights_shape.Dims(dims_count - 1);
const auto input_matrix_map =
- MapAsMatrixWithGivenNumberOfRows(input_data, input_dims, input_rows);
+ MapAsMatrixWithGivenNumberOfRows(input_data, input_shape, input_rows);
const auto filter_matrix_map =
- MapAsMatrixWithFirstDimAsRows(weights_data, weights_dims);
+ MapAsMatrixWithLastDimAsRows(weights_data, weights_shape);
auto output_matrix_map =
- MapAsMatrixWithFirstDimAsRows(output_data, output_dims);
+ MapAsMatrixWithLastDimAsRows(output_data, output_shape);
Gemm(filter_matrix_map.transpose(), input_matrix_map, &output_matrix_map);
- AddBiasAndEvalActivationFunction(bias_data, bias_dims, output_data,
- output_dims, output_activation_min,
- output_activation_max);
-}
-
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void FullyConnected(const float* input_data, const Dims<4>& input_dims,
- const float* weights_data, const Dims<4>& weights_dims,
- const float* bias_data, const Dims<4>& bias_dims,
- float* output_data, const Dims<4>& output_dims) {
- float output_activation_min, output_activation_max;
- GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
- FullyConnected(input_data, input_dims, weights_data, weights_dims, bias_data,
- bias_dims, output_activation_min, output_activation_max,
- output_data, output_dims);
+ AddBiasAndEvalActivationFunction(output_activation_min, output_activation_max,
+ bias_shape, bias_data, output_shape,
+ output_data);
}
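+
+// Call sketch for the params-struct convention above (shapes and buffers are
+// hypothetical; the float path only reads float_activation_min/max):
+//   FullyConnectedParams fc_params;
+//   fc_params.float_activation_min = 0.0f;  // e.g. a fused ReLU lower bound
+//   fc_params.float_activation_max = std::numeric_limits<float>::max();
+//   FullyConnected(fc_params, input_shape, input_data, weights_shape,
+//                  weights_data, bias_shape, bias_data, output_shape,
+//                  output_data);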
#ifdef USE_NEON
inline void FullyConnectedAsGEMV(
- const uint8* input_data, const Dims<4>& input_dims, int32 input_offset,
- const uint8* filter_data, const Dims<4>& filter_dims, int32 filter_offset,
- const int32* bias_data, const Dims<4>& bias_dims, int32 output_offset,
+ const RuntimeShape& input_shape, const uint8* input_data,
+ int32 input_offset, const RuntimeShape& filter_shape,
+ const uint8* filter_data, int32 filter_offset,
+ const RuntimeShape& bias_shape, const int32* bias_data, int32 output_offset,
int32 output_multiplier, int output_shift, int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims) {
+ int32 output_activation_max, const RuntimeShape& output_shape,
+ uint8* output_data) {
gemmlowp::ScopedProfilingLabel label("FullyConnectedAsGEMV/8bit");
- TFLITE_DCHECK(IsPackedWithoutStrides(input_dims));
- TFLITE_DCHECK(IsPackedWithoutStrides(filter_dims));
- TFLITE_DCHECK(IsPackedWithoutStrides(bias_dims));
- TFLITE_DCHECK(IsPackedWithoutStrides(output_dims));
- TFLITE_DCHECK_EQ(FlatSizeSkipDim(output_dims, 0), 1);
- const int input_size = FlatSizeSkipDim(input_dims, 3);
- const int output_size = MatchingArraySize(filter_dims, 1, output_dims, 0);
+ TFLITE_DCHECK_GE(input_shape.DimensionsCount(), 1);
+ TFLITE_DCHECK_GE(filter_shape.DimensionsCount(), 2);
+ TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
+ const int output_dim_count = output_shape.DimensionsCount();
+ const int filter_dim_count = filter_shape.DimensionsCount();
+ TFLITE_DCHECK_EQ(FlatSizeSkipDim(output_shape, output_dim_count - 1), 1);
+ const int input_size = FlatSizeSkipDim(input_shape, 0);
+ const int output_size = MatchingDim(filter_shape, filter_dim_count - 2,
+ output_shape, output_dim_count - 1);
static constexpr int kPeel = 4;
+ const bool shift_left = (output_shift > 0);
for (int k = 0; k < input_size; k += 64) {
optimized_ops_preload_l1_stream(input_data + k);
}
@@ -1005,11 +964,17 @@ inline void FullyConnectedAsGEMV(
int32x4_t bias_vec = vld1q_s32(bias_ptr);
bias_ptr += 4;
reduced = vaddq_s32(reduced, bias_vec);
- // Multiply by the fixed-point multiplier.
- reduced = vqrdmulhq_n_s32(reduced, output_multiplier);
- // Rounding-shift-right.
- using gemmlowp::RoundingDivideByPOT;
- reduced = RoundingDivideByPOT(reduced, output_shift);
+ if (shift_left) {
+ const int32 multiplier_power_of_two = 1 << output_shift;
+ reduced = vmulq_n_s32(reduced, multiplier_power_of_two);
+ reduced = vqrdmulhq_n_s32(reduced, output_multiplier);
+ } else {
+ // Multiply by the fixed-point multiplier.
+ reduced = vqrdmulhq_n_s32(reduced, output_multiplier);
+ // Rounding-shift-right.
+ using gemmlowp::RoundingDivideByPOT;
+ reduced = RoundingDivideByPOT(reduced, -output_shift);
+ }
// Add the output offset.
const int32x4_t output_offset_vec = vdupq_n_s32(output_offset);
reduced = vaddq_s32(reduced, output_offset_vec);
@@ -1031,23 +996,22 @@ inline void FullyConnectedAsGEMV(
struct GemmlowpOutputPipeline {
typedef gemmlowp::VectorMap<const int32, gemmlowp::VectorShape::Col>
ColVectorMap;
- typedef std::tuple<
- gemmlowp::OutputStageBiasAddition<ColVectorMap>,
- gemmlowp::OutputStageQuantizeDownInt32ToUint8ScaleByFixedPoint,
- gemmlowp::OutputStageClamp, gemmlowp::OutputStageSaturatingCastToUint8>
+ typedef std::tuple<gemmlowp::OutputStageBiasAddition<ColVectorMap>,
+ gemmlowp::OutputStageScaleInt32ByFixedPointAndExponent,
+ gemmlowp::OutputStageClamp,
+ gemmlowp::OutputStageSaturatingCastToUint8>
Pipeline;
- static Pipeline Make(const int32* bias_data, int output_rows,
- int32 output_offset, int32 output_multiplier,
- int output_shift, int32 output_activation_min,
- int32 output_activation_max) {
+ static Pipeline MakeExp(const int32* bias_data, int output_rows,
+ int32 output_offset, int32 output_multiplier,
+ int output_left_shift, int32 output_activation_min,
+ int32 output_activation_max) {
ColVectorMap bias_vector(bias_data, output_rows);
gemmlowp::OutputStageBiasAddition<ColVectorMap> bias_addition_stage;
bias_addition_stage.bias_vector = bias_vector;
- gemmlowp::OutputStageQuantizeDownInt32ToUint8ScaleByFixedPoint
- quantize_down_stage;
+ gemmlowp::OutputStageScaleInt32ByFixedPointAndExponent quantize_down_stage;
quantize_down_stage.result_offset_after_shift = output_offset;
quantize_down_stage.result_fixedpoint_multiplier = output_multiplier;
- quantize_down_stage.result_shift = output_shift;
+ quantize_down_stage.result_exponent = output_left_shift;
gemmlowp::OutputStageClamp clamp_stage;
clamp_stage.min = output_activation_min;
clamp_stage.max = output_activation_max;
@@ -1057,42 +1021,47 @@ struct GemmlowpOutputPipeline {
}
};
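+
+// Shift convention used by MakeExp (illustrative values): result_exponent is a
+// signed left shift, so result_exponent == -7 reproduces the rounding right
+// shift by 7 that the old result_shift == 7 performed, while +2 multiplies the
+// accumulator by 1 << 2 before clamping and the saturating cast.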
-inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
- int32 input_offset, const uint8* filter_data,
- const Dims<4>& filter_dims, int32 filter_offset,
- const int32* bias_data, const Dims<4>& bias_dims,
- int32 output_offset, int32 output_multiplier,
- int output_shift, int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims,
- gemmlowp::GemmContext* gemm_context) {
+inline void FullyConnected(
+ const FullyConnectedParams& params, const RuntimeShape& input_shape,
+ const uint8* input_data, const RuntimeShape& filter_shape,
+ const uint8* filter_data, const RuntimeShape& bias_shape,
+ const int32* bias_data, const RuntimeShape& output_shape,
+ uint8* output_data, gemmlowp::GemmContext* gemm_context) {
gemmlowp::ScopedProfilingLabel label("FullyConnected/8bit");
+ const int32 input_offset = params.input_offset;
+ const int32 filter_offset = params.weights_offset;
+ const int32 output_offset = params.output_offset;
+ const int32 output_multiplier = params.output_multiplier;
+ const int output_shift = params.output_shift;
+ const int32 output_activation_min = params.quantized_activation_min;
+ const int32 output_activation_max = params.quantized_activation_max;
+ TFLITE_DCHECK_GE(filter_shape.DimensionsCount(), 2);
+ TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
// TODO(benoitjacob): This really should be:
// const int batches = ArraySize(output_dims, 1);
// but the current --variable_batch hack consists in overwriting the 3rd
// dimension with the runtime batch size, as we don't keep track for each
// array of which dimension is the batch dimension in it.
- const int batches = FlatSizeSkipDim(output_dims, 0);
+ const int output_dim_count = output_shape.DimensionsCount();
+ const int filter_dim_count = filter_shape.DimensionsCount();
+ const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1);
#ifdef USE_NEON
- const int output_size = MatchingArraySize(filter_dims, 1, output_dims, 0);
+ const int output_size = MatchingDim(filter_shape, filter_dim_count - 2,
+ output_shape, output_dim_count - 1);
if (batches == 1 && !(output_size % 4)) {
return FullyConnectedAsGEMV(
- input_data, input_dims, input_offset, filter_data, filter_dims,
- filter_offset, bias_data, bias_dims, output_offset, output_multiplier,
- output_shift, output_activation_min, output_activation_max, output_data,
- output_dims);
+ input_shape, input_data, input_offset, filter_shape, filter_data,
+ filter_offset, bias_shape, bias_data, output_offset, output_multiplier,
+ output_shift, output_activation_min, output_activation_max,
+ output_shape, output_data);
}
#endif // USE_NEON
- const int filter_rows = filter_dims.sizes[1];
- const int filter_cols = filter_dims.sizes[0];
- TFLITE_DCHECK_EQ(filter_dims.sizes[2], 1);
- TFLITE_DCHECK_EQ(filter_dims.sizes[3], 1);
- const int output_rows = output_dims.sizes[0];
+ const int filter_rows = filter_shape.Dims(filter_dim_count - 2);
+ const int filter_cols = filter_shape.Dims(filter_dim_count - 1);
+ TFLITE_DCHECK_EQ(filter_shape.FlatSize(), filter_rows * filter_cols);
+ const int output_rows = output_shape.Dims(output_dim_count - 1);
TFLITE_DCHECK_EQ(output_rows, filter_rows);
- TFLITE_DCHECK_EQ(bias_dims.sizes[0], output_rows);
- TFLITE_DCHECK_EQ(bias_dims.sizes[1], 1);
- TFLITE_DCHECK_EQ(bias_dims.sizes[2], 1);
- TFLITE_DCHECK_EQ(bias_dims.sizes[3], 1);
+ TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_rows);
gemmlowp::MatrixMap<const uint8, gemmlowp::MapOrder::RowMajor> filter_matrix(
filter_data, output_rows, filter_cols, filter_cols);
@@ -1100,7 +1069,7 @@ inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
input_data, filter_cols, batches, filter_cols);
gemmlowp::MatrixMap<uint8, gemmlowp::MapOrder::ColMajor> output_matrix(
output_data, output_rows, batches, output_rows);
- const auto& output_pipeline = GemmlowpOutputPipeline::Make(
+ const auto& output_pipeline = GemmlowpOutputPipeline::MakeExp(
bias_data, output_rows, output_offset, output_multiplier, output_shift,
output_activation_min, output_activation_max);
gemmlowp::GemmWithOutputPipeline<uint8, uint8,
@@ -1110,29 +1079,38 @@ inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
}
inline void FullyConnected(
- const uint8* input_data, const Dims<4>& input_dims, int32 input_offset,
- const uint8* filter_data, const Dims<4>& filter_dims, int32 filter_offset,
- const int32* bias_data_int32, const Dims<4>& bias_dims, int32 output_offset,
- int32 output_multiplier, int output_shift, int32 output_activation_min,
- int32 output_activation_max, int16* output_data, const Dims<4>& output_dims,
- gemmlowp::GemmContext* gemm_context) {
+ const FullyConnectedParams& params, const RuntimeShape& input_shape,
+ const uint8* input_data, const RuntimeShape& filter_shape,
+ const uint8* filter_data, const RuntimeShape& bias_shape,
+ const int32* bias_data_int32, const RuntimeShape& output_shape,
+ int16* output_data, gemmlowp::GemmContext* gemm_context) {
gemmlowp::ScopedProfilingLabel label("FullyConnected/Uint8Int16");
+ const int32 input_offset = params.input_offset;
+ const int32 filter_offset = params.weights_offset;
+ const int32 output_offset = params.output_offset;
+ const int32 output_multiplier = params.output_multiplier;
+ const int output_shift = params.output_shift;
+ const int32 output_activation_min = params.quantized_activation_min;
+ const int32 output_activation_max = params.quantized_activation_max;
// This is a copy of the reference implementation. We do not currently have a
// properly optimized version.
(void)gemm_context; // only used in properly optimized code.
TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
TFLITE_DCHECK_EQ(output_offset, 0);
+ TFLITE_DCHECK_GE(filter_shape.DimensionsCount(), 2);
+ TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
// TODO(benoitjacob): This really should be:
// const int batches = ArraySize(output_dims, 1);
// but the current --variable_batch hack consists in overwriting the 3rd
// dimension with the runtime batch size, as we don't keep track for each
// array of which dimension is the batch dimension in it.
- const int batches = FlatSizeSkipDim(output_dims, 0);
- const int output_depth = MatchingArraySize(filter_dims, 1, output_dims, 0);
- const int accum_depth = ArraySize(filter_dims, 0);
- TFLITE_DCHECK(IsPackedWithoutStrides(input_dims));
- TFLITE_DCHECK(IsPackedWithoutStrides(filter_dims));
+ const int output_dim_count = output_shape.DimensionsCount();
+ const int filter_dim_count = filter_shape.DimensionsCount();
+ const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1);
+ const int output_depth = MatchingDim(filter_shape, filter_dim_count - 2,
+ output_shape, output_dim_count - 1);
+ const int accum_depth = filter_shape.Dims(filter_dim_count - 1);
// Implementation of the fully connected node suited to the inside of an LSTM
// cell. The operands are 8-bit integers, the accumulators are internally
@@ -1143,17 +1121,17 @@ inline void FullyConnected(
if (batches == 1 && input_offset == -128 && output_activation_min == -32768 &&
output_activation_max == 32767) {
if (filter_offset == -128 && !(output_depth % 4) && !(accum_depth % 64)) {
- GEMVForLstmCellWithSymmetricRange(input_data, input_dims, filter_data,
- filter_dims, bias_data_int32, bias_dims,
- output_multiplier, -output_shift,
- output_data, output_dims);
+ GEMVForLstmCellWithSymmetricRange(
+ input_shape, input_data, filter_shape, filter_data, bias_shape,
+ bias_data_int32, output_multiplier, output_shift, output_shape,
+ output_data);
return;
}
if (!(output_depth % 4) && !(accum_depth % 8)) {
- GEMVForLstmCell(input_data, input_dims, filter_data, filter_dims,
- filter_offset, bias_data_int32, bias_dims,
- output_multiplier, -output_shift, output_data,
- output_dims);
+ GEMVForLstmCell(input_shape, input_data, filter_shape, filter_data,
+ filter_offset, bias_shape, bias_data_int32,
+ output_multiplier, output_shift, output_shape,
+ output_data);
return;
}
}
@@ -1173,7 +1151,7 @@ inline void FullyConnected(
scale_stage.result_offset_after_shift = 0;
scale_stage.result_fixedpoint_multiplier = output_multiplier;
// Note that this shift is negated wrt ordinary FC.
- scale_stage.result_exponent = -output_shift;
+ scale_stage.result_exponent = output_shift;
gemmlowp::OutputStageClamp clamp_stage;
clamp_stage.min = output_activation_min;
clamp_stage.max = output_activation_max;
@@ -1187,34 +1165,12 @@ inline void FullyConnected(
input_offset, output_pipeline);
}
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
- int32 input_offset, const uint8* filter_data,
- const Dims<4>& filter_dims, int32 filter_offset,
- const int32* bias_data, const Dims<4>& bias_dims,
- int32 output_offset, int32 output_multiplier,
- int output_shift, int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims,
- gemmlowp::GemmContext* gemm_context) {
- static_assert(Ac == FusedActivationFunctionType::kNone ||
- Ac == FusedActivationFunctionType::kRelu ||
- Ac == FusedActivationFunctionType::kRelu6 ||
- Ac == FusedActivationFunctionType::kRelu1,
- "");
- FullyConnected(input_data, input_dims, input_offset, filter_data, filter_dims,
- filter_offset, bias_data, bias_dims, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_data, output_dims, gemm_context);
-}
-
// Internal function doing the actual arithmetic work for
-// ExperimentalShuffledFullyConnected.
+// ShuffledFullyConnected.
// May be called either directly by it (single-threaded case) or may be used
// as the 'task' for worker threads to run (multi-threaded case, see
-// ExperimentalShuffledFullyConnectedWorkerTask below).
-inline void ExperimentalShuffledFullyConnectedWorkerImpl(
+// ShuffledFullyConnectedWorkerTask below).
+inline void ShuffledFullyConnectedWorkerImpl(
const uint8* shuffled_input_workspace_data,
const int8* shuffled_weights_data, int batches, int output_depth,
int output_stride, int accum_depth, const int32* bias_data,
@@ -1222,8 +1178,8 @@ inline void ExperimentalShuffledFullyConnectedWorkerImpl(
#if defined USE_NEON
const int8* shuffled_weights_ptr = shuffled_weights_data;
if (batches == 1) {
- const int right_shift = output_shift > 0 ? output_shift : 0;
- const int left_shift = output_shift > 0 ? 0 : -output_shift;
+ const int right_shift = output_shift > 0 ? 0 : -output_shift;
+ const int left_shift = output_shift > 0 ? output_shift : 0;
for (int c = 0; c < output_depth; c += 4) {
// Accumulation loop.
int32x4_t row_accum0 = vdupq_n_s32(0);
@@ -1289,8 +1245,8 @@ inline void ExperimentalShuffledFullyConnectedWorkerImpl(
vst1_s16(output_data + c, res16);
}
} else if (batches == 4) {
- const int right_shift = output_shift > 0 ? output_shift : 0;
- const int left_shift = output_shift > 0 ? 0 : -output_shift;
+ const int right_shift = output_shift > 0 ? 0 : -output_shift;
+ const int left_shift = output_shift > 0 ? output_shift : 0;
for (int c = 0; c < output_depth; c += 4) {
const int8* shuffled_input_ptr =
reinterpret_cast<const int8*>(shuffled_input_workspace_data);
@@ -1421,8 +1377,8 @@ inline void ExperimentalShuffledFullyConnectedWorkerImpl(
// (16-bit, typically 3 integer bits) fixed-point format. The quantized
// multiplier and shift here have been pre-computed offline
// (e.g. by toco).
- acc = MultiplyByQuantizedMultiplier(acc, output_multiplier,
- -output_shift);
+ acc =
+ MultiplyByQuantizedMultiplier(acc, output_multiplier, output_shift);
// Saturate, cast to int16, and store to output array.
acc = std::max(acc, -32768);
acc = std::min(acc, 32767);
@@ -1473,7 +1429,7 @@ inline void ExperimentalShuffledFullyConnectedWorkerImpl(
// quantized multiplier and shift here have been pre-computed offline
// (e.g. by toco).
acc = MultiplyByQuantizedMultiplier(acc, output_multiplier,
- -output_shift);
+ output_shift);
// Saturate, cast to int16, and store to output array.
acc = std::max(acc, -32768);
acc = std::min(acc, 32767);
@@ -1488,14 +1444,16 @@ inline void ExperimentalShuffledFullyConnectedWorkerImpl(
#endif
}
-// Wraps ExperimentalShuffledFullyConnectedWorkerImpl into a Task class
+// Wraps ShuffledFullyConnectedWorkerImpl into a Task class
// to allow using gemmlowp's threadpool.
-struct ExperimentalShuffledFullyConnectedWorkerTask : gemmlowp::Task {
- ExperimentalShuffledFullyConnectedWorkerTask(
- const uint8* input_data, const int8* shuffled_weights_data, int batches,
- int output_depth, int output_stride, int accum_depth,
- const int32* bias_data, int32 output_multiplier, int output_shift,
- int16* output_data)
+struct ShuffledFullyConnectedWorkerTask : gemmlowp::Task {
+ ShuffledFullyConnectedWorkerTask(const uint8* input_data,
+ const int8* shuffled_weights_data,
+ int batches, int output_depth,
+ int output_stride, int accum_depth,
+ const int32* bias_data,
+ int32 output_multiplier, int output_shift,
+ int16* output_data)
: input_data_(input_data),
shuffled_weights_data_(shuffled_weights_data),
batches_(batches),
@@ -1508,7 +1466,7 @@ struct ExperimentalShuffledFullyConnectedWorkerTask : gemmlowp::Task {
output_data_(output_data) {}
void Run() override {
- ExperimentalShuffledFullyConnectedWorkerImpl(
+ ShuffledFullyConnectedWorkerImpl(
input_data_, shuffled_weights_data_, batches_, output_depth_,
output_stride_, accum_depth_, bias_data_, output_multiplier_,
output_shift_, output_data_);
@@ -1526,28 +1484,35 @@ struct ExperimentalShuffledFullyConnectedWorkerTask : gemmlowp::Task {
int16* output_data_;
};
-inline void ExperimentalShuffledFullyConnected(
- const uint8* input_data, const Dims<4>& input_dims,
- const uint8* shuffled_weights_data, const Dims<4>& weights_dims,
- const int32* bias_data, const Dims<4>& bias_dims, int32 output_multiplier,
- int output_shift, int32 output_activation_min, int32 output_activation_max,
- int16* output_data, const Dims<4>& output_dims,
- uint8* shuffled_input_workspace_data, gemmlowp::GemmContext* gemm_context) {
- gemmlowp::ScopedProfilingLabel label(
- "ExperimentalShuffledFullyConnected/8bit");
+inline void ShuffledFullyConnected(
+ const FullyConnectedParams& params, const RuntimeShape& input_shape,
+ const uint8* input_data, const RuntimeShape& weights_shape,
+ const uint8* shuffled_weights_data, const RuntimeShape& bias_shape,
+ const int32* bias_data, const RuntimeShape& output_shape,
+ int16* output_data, uint8* shuffled_input_workspace_data,
+ gemmlowp::GemmContext* gemm_context) {
+ gemmlowp::ScopedProfilingLabel label("ShuffledFullyConnected/8bit");
+ const int32 output_multiplier = params.output_multiplier;
+ const int output_shift = params.output_shift;
+ const int32 output_activation_min = params.quantized_activation_min;
+ const int32 output_activation_max = params.quantized_activation_max;
(void)gemm_context; // only used in optimized code.
TFLITE_DCHECK_EQ(output_activation_min, -32768);
TFLITE_DCHECK_EQ(output_activation_max, 32767);
+ TFLITE_DCHECK_GE(input_shape.DimensionsCount(), 1);
+ TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);
+ TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
// TODO(benoitjacob): This really should be:
// const int batches = ArraySize(output_dims, 1);
// but the current --variable_batch hack consists in overwriting the 3rd
// dimension with the runtime batch size, as we don't keep track for each
// array of which dimension is the batch dimension in it.
- const int batches = FlatSizeSkipDim(output_dims, 0);
- const int output_depth = MatchingArraySize(weights_dims, 1, output_dims, 0);
- const int accum_depth = ArraySize(weights_dims, 0);
- TFLITE_DCHECK(IsPackedWithoutStrides(input_dims));
- TFLITE_DCHECK(IsPackedWithoutStrides(weights_dims));
+ const int output_dim_count = output_shape.DimensionsCount();
+ const int weights_dim_count = weights_shape.DimensionsCount();
+ const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1);
+ const int output_depth = MatchingDim(weights_shape, weights_dim_count - 2,
+ output_shape, output_dim_count - 1);
+ const int accum_depth = weights_shape.Dims(weights_dim_count - 1);
TFLITE_DCHECK((accum_depth % 16) == 0);
TFLITE_DCHECK((output_depth % 4) == 0);
// Shuffled weights have had their sign bit (0x80) pre-flipped (xor'd)
@@ -1618,7 +1583,7 @@ inline void ExperimentalShuffledFullyConnected(
if (thread_count == 1) {
// Single-thread case: do the computation on the current thread, don't
// use a threadpool
- ExperimentalShuffledFullyConnectedWorkerImpl(
+ ShuffledFullyConnectedWorkerImpl(
shuffled_input_workspace_data, int8_shuffled_weights_data, batches,
output_depth, output_depth, accum_depth, bias_data, output_multiplier,
output_shift, output_data);
@@ -1633,7 +1598,7 @@ inline void ExperimentalShuffledFullyConnected(
int row_start = 0;
for (int i = 0; i < thread_count; i++) {
int row_end = std::min(output_depth, row_start + kRowsPerWorker);
- tasks[i] = new ExperimentalShuffledFullyConnectedWorkerTask(
+ tasks[i] = new ShuffledFullyConnectedWorkerTask(
shuffled_input_workspace_data,
int8_shuffled_weights_data + row_start * accum_depth, batches,
row_end - row_start, output_depth, accum_depth, bias_data + row_start,
@@ -1645,12 +1610,16 @@ inline void ExperimentalShuffledFullyConnected(
}
template <typename T>
-inline void ExtractPatchIntoBufferColumn(
- const Dims<4>& input_dims, int w, int h, int b, int kheight, int kwidth,
- int stride_width, int stride_height, int pad_width, int pad_height,
- int in_width, int in_height, int in_depth, int single_buffer_length,
- int buffer_id, const T* in_data, T* conv_buffer_data, uint8 byte_zero) {
+inline void ExtractPatchIntoBufferColumn(const RuntimeShape& input_shape, int w,
+ int h, int b, int kheight, int kwidth,
+ int stride_width, int stride_height,
+ int pad_width, int pad_height,
+ int in_width, int in_height,
+ int in_depth, int single_buffer_length,
+ int buffer_id, const T* in_data,
+ T* conv_buffer_data, uint8 zero_byte) {
gemmlowp::ScopedProfilingLabel label("ExtractPatchIntoBufferColumn");
+ TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
// This chunk of code reshapes all the inputs corresponding to
// output (b, h, w) to a column vector in conv_buffer(:, buffer_id).
const int kwidth_times_indepth = kwidth * in_depth;
@@ -1672,7 +1641,7 @@ inline void ExtractPatchIntoBufferColumn(
const int output_row_offset = (buffer_id * single_buffer_length);
int out_offset =
output_row_offset + (h_offset * kwidth + w_offset) * in_depth;
- int in_offset = Offset(input_dims, 0, iw_start, ih_start, b);
+ int in_offset = Offset(input_shape, b, ih_start, iw_start, 0);
// Express all of the calculations as padding around the input patch.
const int top_padding = h_offset;
@@ -1686,7 +1655,7 @@ inline void ExtractPatchIntoBufferColumn(
// patch that are off the edge of the input image.
if (top_padding > 0) {
const int top_row_elements = (top_padding * kwidth * in_depth);
- memset(conv_buffer_data + output_row_offset, byte_zero,
+ memset(conv_buffer_data + output_row_offset, zero_byte,
(top_row_elements * sizeof(T)));
}
@@ -1703,14 +1672,14 @@ inline void ExtractPatchIntoBufferColumn(
for (int ih = ih_start; ih < ih_end; ++ih) {
if (left_padding > 0) {
const int left_start = (out_offset - (left_padding * in_depth));
- memset(conv_buffer_data + left_start, byte_zero,
+ memset(conv_buffer_data + left_start, zero_byte,
(left_padding * in_depth * sizeof(T)));
}
memcpy(conv_buffer_data + out_offset, in_data + in_offset,
single_row_num * sizeof(T));
if (right_padding > 0) {
const int right_start = (out_offset + single_row_num);
- memset(conv_buffer_data + right_start, byte_zero,
+ memset(conv_buffer_data + right_start, zero_byte,
(right_padding * in_depth * sizeof(T)));
}
out_offset += kwidth_times_indepth;
@@ -1725,26 +1694,113 @@ inline void ExtractPatchIntoBufferColumn(
const int bottom_start =
output_row_offset +
((top_padding + (ih_end - ih_start)) * kwidth * in_depth);
- memset(conv_buffer_data + bottom_start, byte_zero,
+ memset(conv_buffer_data + bottom_start, zero_byte,
(bottom_row_elements * sizeof(T)));
}
}
template <typename T>
-void Im2col(const T* input_data, const Dims<4>& input_dims, int stride_width,
- int stride_height, int pad_width, int pad_height, int kheight,
- int kwidth, uint8 byte_zero, T* output_data,
- const Dims<4>& output_dims) {
+void DilatedIm2col(const ConvParams& params, uint8 zero_byte,
+ const RuntimeShape& input_shape, const T* input_data,
+ const RuntimeShape& filter_shape,
+ const RuntimeShape& output_shape, T* im2col_data) {
+ const int stride_width = params.stride_width;
+ const int stride_height = params.stride_height;
+ const int dilation_width_factor = params.dilation_width_factor;
+ const int dilation_height_factor = params.dilation_height_factor;
+ const int pad_width = params.padding_values.width;
+ const int pad_height = params.padding_values.height;
+ TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+
+ // For dilated convolution, the input pixels are not contiguous, so we can't
+ // use the same optimizations as Im2Col(). Note that this code would also work
+ // for the non-dilated case (though likely a bit slower).
+ gemmlowp::ScopedProfilingLabel label("DilatedIm2col");
+ TFLITE_DCHECK(dilation_width_factor != 1 || dilation_height_factor != 1);
+ TFLITE_DCHECK(im2col_data);
+ const int batches = MatchingDim(input_shape, 0, output_shape, 0);
+ const int input_height = input_shape.Dims(1);
+ const int input_width = input_shape.Dims(2);
+ const int input_depth = MatchingDim(input_shape, 3, filter_shape, 3);
+ const int filter_height = filter_shape.Dims(1);
+ const int filter_width = filter_shape.Dims(2);
+ const int output_height = output_shape.Dims(1);
+ const int output_width = output_shape.Dims(2);
+ MatchingDim(output_shape, 3, filter_shape, 0);
+
+ // Construct the MxN sized im2col matrix.
+ // The rows, M, are sub-ordered B x H x W
+ const RuntimeShape row_shape({1, batches, output_height, output_width});
+ // The columns, N, are sub-ordered Kh x Kw x Din
+ const RuntimeShape col_shape({1, filter_height, filter_width, input_depth});
+ // Use dimensions M and N to construct dims for indexing directly into im2col
+ const RuntimeShape im2col_shape(
+ {1, 1, row_shape.FlatSize(), col_shape.FlatSize()});
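+ // Worked example (hypothetical sizes): batches = 1, output = 4 x 4,
+ // filter = 3 x 3, input_depth = 8 gives M = 1 * 4 * 4 = 16 rows and
+ // N = 3 * 3 * 8 = 72 columns, i.e. an im2col_shape of {1, 1, 16, 72}.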
+
+ // Loop through the output rows (B x H x W)
+ for (int batch = 0; batch < batches; ++batch) {
+ for (int out_y = 0; out_y < output_height; ++out_y) {
+ for (int out_x = 0; out_x < output_width; ++out_x) {
+ // Each im2col row is an output pixel. Arrange the input data in this
+ // row in an order we can conveniently multiply with the filter data.
+ int row_offset = Offset(row_shape, 0, batch, out_y, out_x);
+ const int in_x_origin = (out_x * stride_width) - pad_width;
+ const int in_y_origin = (out_y * stride_height) - pad_height;
+ // Loop through all the pixels of the filter (Kh x Kw)
+ for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
+ const int in_y = in_y_origin + dilation_height_factor * filter_y;
+ if ((in_y >= 0) && (in_y < input_height)) {
+ // Filter row is within the input data.
+ // Loop through all the filter pixels in this row.
+ for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
+ const int in_x = in_x_origin + dilation_width_factor * filter_x;
+ int col_offset = Offset(col_shape, 0, filter_y, filter_x, 0);
+ T* dst = im2col_data +
+ Offset(im2col_shape, 0, 0, row_offset, col_offset);
+ if ((in_x >= 0) && (in_x < input_width)) {
+ // Filter pixel is within the input, copy the input data.
+ T const* src =
+ input_data + Offset(input_shape, batch, in_y, in_x, 0);
+ memcpy(dst, src, input_depth * sizeof(T));
+ } else {
+ // Filter pixel is outside the input, zero it out.
+ memset(dst, zero_byte, input_depth * sizeof(T));
+ }
+ }
+ } else {
+ // Filter row is outside the input, zero out the entire filter row.
+ int col_offset = Offset(col_shape, 0, filter_y, 0, 0);
+ T* dst = im2col_data +
+ Offset(im2col_shape, 0, 0, row_offset, col_offset);
+ memset(dst, zero_byte, filter_width * input_depth * sizeof(T));
+ }
+ }
+ }
+ }
+ }
+}
+
+template <typename T>
+void Im2col(const ConvParams& params, int kheight, int kwidth, uint8 zero_byte,
+ const RuntimeShape& input_shape, const T* input_data,
+ const RuntimeShape& output_shape, T* output_data) {
gemmlowp::ScopedProfilingLabel label("Im2col");
- TFLITE_DCHECK(IsPackedWithoutStrides(input_dims));
- TFLITE_DCHECK(IsPackedWithoutStrides(output_dims));
- const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
- const int input_depth = ArraySize(input_dims, 0);
- const int input_width = ArraySize(input_dims, 1);
- const int input_height = ArraySize(input_dims, 2);
- const int output_depth = ArraySize(output_dims, 0);
- const int output_width = ArraySize(output_dims, 1);
- const int output_height = ArraySize(output_dims, 2);
+ const int stride_width = params.stride_width;
+ const int stride_height = params.stride_height;
+ const int pad_width = params.padding_values.width;
+ const int pad_height = params.padding_values.height;
+ TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+
+ const int batches = MatchingDim(input_shape, 0, output_shape, 0);
+ const int input_depth = input_shape.Dims(3);
+ const int input_width = input_shape.Dims(2);
+ const int input_height = input_shape.Dims(1);
+ const int output_depth = output_shape.Dims(3);
+ const int output_width = output_shape.Dims(2);
+ const int output_height = output_shape.Dims(1);
int buffer_id = 0;
// Loop over the output nodes.
@@ -1752,252 +1808,241 @@ void Im2col(const T* input_data, const Dims<4>& input_dims, int stride_width,
for (int h = 0; h < output_height; ++h) {
for (int w = 0; w < output_width; ++w) {
ExtractPatchIntoBufferColumn(
- input_dims, w, h, b, kheight, kwidth, stride_width, stride_height,
+ input_shape, w, h, b, kheight, kwidth, stride_width, stride_height,
pad_width, pad_height, input_width, input_height, input_depth,
- output_depth, buffer_id, input_data, output_data, byte_zero);
+ output_depth, buffer_id, input_data, output_data, zero_byte);
++buffer_id;
}
}
}
}
-// legacy, for compatibility with old checked-in code
-template <typename T>
-void Im2col(const T* input_data, const Dims<4>& input_dims, int stride,
- int pad_width, int pad_height, int kheight, int kwidth,
- uint8 byte_zero, T* output_data, const Dims<4>& output_dims) {
- Im2col(input_data, input_dims, stride, stride, pad_width, pad_height, kheight,
- kwidth, byte_zero, output_data, output_dims);
-}
-
-inline void DilatedConv(const float* input_data, const Dims<4>& input_dims,
- const float* filter_data, const Dims<4>& filter_dims,
- const float* bias_data, const Dims<4>& bias_dims,
- int stride_width, int stride_height,
- int dilation_width_factor, int dilation_height_factor,
- int pad_width, int pad_height,
- float output_activation_min,
- float output_activation_max, float* output_data,
- const Dims<4>& output_dims, float* im2col_data,
- const Dims<4>& im2col_dims) {
- gemmlowp::ScopedProfilingLabel label("DilatedConv");
- // This is a copy of the reference Conv implementation. We do not currently
- // have an optimized path for dilation.
- (void)im2col_data; // only used in optimized code.
- (void)im2col_dims; // only used in optimized code.
- const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
- const int input_depth = MatchingArraySize(input_dims, 0, filter_dims, 0);
- const int output_depth = MatchingArraySize(filter_dims, 3, output_dims, 0);
- if (bias_data) {
- TFLITE_DCHECK_EQ(ArraySize(filter_dims, 3), ArraySize(bias_dims, 0));
- }
- const int input_height = ArraySize(input_dims, 2);
- const int input_width = ArraySize(input_dims, 1);
- const int filter_height = ArraySize(filter_dims, 2);
- const int filter_width = ArraySize(filter_dims, 1);
- const int output_height = ArraySize(output_dims, 2);
- const int output_width = ArraySize(output_dims, 1);
- for (int batch = 0; batch < batches; ++batch) {
- for (int out_y = 0; out_y < output_height; ++out_y) {
- for (int out_x = 0; out_x < output_width; ++out_x) {
- for (int out_channel = 0; out_channel < output_depth; ++out_channel) {
- const int in_x_origin = (out_x * stride_width) - pad_width;
- const int in_y_origin = (out_y * stride_height) - pad_height;
- float total = 0.f;
- for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
- for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
- for (int in_channel = 0; in_channel < input_depth; ++in_channel) {
- const int in_x = in_x_origin + dilation_width_factor * filter_x;
- const int in_y =
- in_y_origin + dilation_height_factor * filter_y;
- // If the location is outside the bounds of the input image,
- // use zero as a default value.
- if ((in_x >= 0) && (in_x < input_width) && (in_y >= 0) &&
- (in_y < input_height)) {
- float input_value = input_data[Offset(input_dims, in_channel,
- in_x, in_y, batch)];
- float filter_value =
- filter_data[Offset(filter_dims, in_channel, filter_x,
- filter_y, out_channel)];
- total += (input_value * filter_value);
- }
- }
- }
- }
- float bias_value = 0.0f;
- if (bias_data) {
- bias_value = bias_data[Offset(bias_dims, out_channel, 0, 0, 0)];
- }
- output_data[Offset(output_dims, out_channel, out_x, out_y, batch)] =
- ActivationFunctionWithMinMax(total + bias_value,
- output_activation_min,
- output_activation_max);
- }
- }
- }
- }
-}
-
-inline void Conv(const float* input_data, const Dims<4>& input_dims,
- const float* filter_data, const Dims<4>& filter_dims,
- const float* bias_data, const Dims<4>& bias_dims,
- int stride_width, int stride_height, int dilation_width_factor,
- int dilation_height_factor, int pad_width, int pad_height,
- float output_activation_min, float output_activation_max,
- float* output_data, const Dims<4>& output_dims,
- float* im2col_data, const Dims<4>& im2col_dims) {
- if ((dilation_width_factor != 1) || (dilation_height_factor != 1)) {
- return DilatedConv(input_data, input_dims, filter_data, filter_dims,
- bias_data, bias_dims, stride_width, stride_height,
- dilation_width_factor, dilation_height_factor, pad_width,
- pad_height, output_activation_min, output_activation_max,
- output_data, output_dims, im2col_data, im2col_dims);
- }
+inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
+ const float* input_data, const RuntimeShape& filter_shape,
+ const float* filter_data, const RuntimeShape& bias_shape,
+ const float* bias_data, const RuntimeShape& output_shape,
+ float* output_data, const RuntimeShape& im2col_shape,
+ float* im2col_data) {
+ const int stride_width = params.stride_width;
+ const int stride_height = params.stride_height;
+ const int dilation_width_factor = params.dilation_width_factor;
+ const int dilation_height_factor = params.dilation_height_factor;
+ const float output_activation_min = params.float_activation_min;
+ const float output_activation_max = params.float_activation_max;
+ TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
(void)im2col_data;
- (void)im2col_dims;
+ (void)im2col_shape;
gemmlowp::ScopedProfilingLabel label("Conv");
+ // NB: a float whose bit pattern is all zeros (0x00000000) is 0.0f.
+ const uint8 float_zero_byte = 0x00;
const float* gemm_input_data = nullptr;
- const Dims<4>* gemm_input_dims = nullptr;
- const int filter_width = ArraySize(filter_dims, 1);
- const int filter_height = ArraySize(filter_dims, 2);
+ const RuntimeShape* gemm_input_shape = nullptr;
+ const int filter_width = filter_shape.Dims(2);
+ const int filter_height = filter_shape.Dims(1);
+ const bool need_dilated_im2col =
+ dilation_width_factor != 1 || dilation_height_factor != 1;
const bool need_im2col = stride_width != 1 || stride_height != 1 ||
filter_width != 1 || filter_height != 1;
- if (need_im2col) {
+ if (need_dilated_im2col) {
+ DilatedIm2col(params, float_zero_byte, input_shape, input_data,
+ filter_shape, output_shape, im2col_data);
+ gemm_input_data = im2col_data;
+ gemm_input_shape = &im2col_shape;
+ } else if (need_im2col) {
TFLITE_DCHECK(im2col_data);
- Im2col(input_data, input_dims, stride_width, stride_height, pad_width,
- pad_height, filter_height, filter_width, 0, im2col_data,
- im2col_dims);
+ Im2col(params, filter_height, filter_width, float_zero_byte, input_shape,
+ input_data, im2col_shape, im2col_data);
gemm_input_data = im2col_data;
- gemm_input_dims = &im2col_dims;
+ gemm_input_shape = &im2col_shape;
} else {
// TODO(aselle): We need to make sure to not send im2col if it is not
// needed.
TFLITE_DCHECK(!im2col_data);
gemm_input_data = input_data;
- gemm_input_dims = &input_dims;
+ gemm_input_shape = &input_shape;
}
const auto im2col_matrix_map =
- MapAsMatrixWithFirstDimAsRows(gemm_input_data, *gemm_input_dims);
+ MapAsMatrixWithLastDimAsRows(gemm_input_data, *gemm_input_shape);
const auto filter_matrix_map =
- MapAsMatrixWithLastDimAsCols(filter_data, filter_dims);
+ MapAsMatrixWithFirstDimAsCols(filter_data, filter_shape);
auto output_matrix_map =
- MapAsMatrixWithFirstDimAsRows(output_data, output_dims);
+ MapAsMatrixWithLastDimAsRows(output_data, output_shape);
Gemm(filter_matrix_map.transpose(), im2col_matrix_map, &output_matrix_map);
- AddBiasAndEvalActivationFunction(bias_data, bias_dims, output_data,
- output_dims, output_activation_min,
- output_activation_max);
-}
+ AddBiasAndEvalActivationFunction(output_activation_min, output_activation_max,
+ bias_shape, bias_data, output_shape,
+ output_data);
+}
+
+inline void HybridConv(const ConvParams& params, float* scaling_factors_ptr,
+ const RuntimeShape& input_shape,
+ const int8_t* input_data,
+ const RuntimeShape& filter_shape,
+ const int8_t* filter_data,
+ const RuntimeShape& bias_shape, const float* bias_data,
+ const RuntimeShape& output_shape, float* output_data,
+ const RuntimeShape& im2col_shape, int8_t* im2col_data) {
+ const int stride_width = params.stride_width;
+ const int stride_height = params.stride_height;
+ const float output_activation_min = params.float_activation_min;
+ const float output_activation_max = params.float_activation_max;
+ TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+
+ const int batch_size = input_shape.Dims(0);
+ const int filter_width = filter_shape.Dims(2);
+ const int filter_height = filter_shape.Dims(1);
+
+ const int8_t* gemm_input_data = nullptr;
+ int num_input;
+ const bool need_im2col = stride_width != 1 || stride_height != 1 ||
+ filter_width != 1 || filter_height != 1;
-template <FusedActivationFunctionType Ac>
-void Conv(const float* input_data, const Dims<4>& input_dims,
- const float* filter_data, const Dims<4>& filter_dims,
- const float* bias_data, const Dims<4>& bias_dims, int stride_width,
- int stride_height, int dilation_width_factor,
- int dilation_height_factor, int pad_width, int pad_height,
- float* output_data, const Dims<4>& output_dims, float* im2col_data,
- const Dims<4>& im2col_dims) {
- float output_activation_min, output_activation_max;
- GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
- Conv(input_data, input_dims, filter_data, filter_dims, bias_data, bias_dims,
- stride_width, stride_height, dilation_width_factor,
- dilation_height_factor, pad_width, pad_height, output_activation_min,
- output_activation_max, output_data, output_dims, im2col_data,
- im2col_dims);
-}
+ if (need_im2col) {
+ TFLITE_DCHECK(im2col_data);
+ // Symmetric quantization assumes a zero point of 0.
+ const int input_zero_point = 0;
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void Conv(const float* input_data, const Dims<4>& input_dims,
- const float* filter_data, const Dims<4>& filter_dims,
- const float* bias_data, const Dims<4>& bias_dims, int stride_width,
- int stride_height, int pad_width, int pad_height, float* output_data,
- const Dims<4>& output_dims, float* im2col_data,
- const Dims<4>& im2col_dims) {
- float output_activation_min, output_activation_max;
- GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
- Conv(input_data, input_dims, filter_data, filter_dims, bias_data, bias_dims,
- stride_width, stride_height, 1, 1, pad_width, pad_height,
- output_activation_min, output_activation_max, output_data, output_dims,
- im2col_data, im2col_dims);
-}
+ Im2col(params, filter_height, filter_width, input_zero_point, input_shape,
+ input_data, im2col_shape, im2col_data);
+ gemm_input_data = im2col_data;
+ num_input = im2col_shape.FlatSize();
+ } else {
+ TFLITE_DCHECK(!im2col_data);
+ gemm_input_data = input_data;
+ num_input = input_shape.FlatSize();
+ }
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void Conv(const float* input_data, const Dims<4>& input_dims,
- const float* filter_data, const Dims<4>& filter_dims,
- const float* bias_data, const Dims<4>& bias_dims, int stride,
- int pad_width, int pad_height, float* output_data,
- const Dims<4>& output_dims, float* im2col_data,
- const Dims<4>& im2col_dims) {
- Conv<Ac>(input_data, input_dims, filter_data, filter_dims, bias_data,
- bias_dims, stride, stride, 1, 1, pad_width, pad_height, output_data,
- output_dims, im2col_data, im2col_dims);
-}
+ // Flatten 4D matrices into 2D matrices for matrix multiplication.
-inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
- int32 input_offset, const uint8* filter_data,
- const Dims<4>& filter_dims, int32 filter_offset,
- const int32* bias_data, const Dims<4>& bias_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, int32 output_offset, int32 output_multiplier,
- int output_shift, int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims, uint8* im2col_data,
- const Dims<4>& im2col_dims,
- gemmlowp::GemmContext* gemm_context) {
- gemmlowp::ScopedProfilingLabel label("Conv/8bit");
+ // Flatten so that each filter has its own row.
+ const int filter_rows = filter_shape.Dims(0);
+ const int filter_cols = FlatSizeSkipDim(filter_shape, 0);
- TFLITE_DCHECK(IsPackedWithoutStrides(input_dims));
- TFLITE_DCHECK(IsPackedWithoutStrides(filter_dims));
- TFLITE_DCHECK(IsPackedWithoutStrides(output_dims));
+ // In MatrixBatchVectorMultiplyAccumulate, each output value is the
+ // dot product of one row of the first matrix with one row of the second
+ // matrix. Therefore, the number of cols in each matrix must be the same.
+ //
+ // After Im2Col, each input patch becomes a row.
+ const int gemm_input_cols = filter_cols;
+ const int gemm_input_rows = num_input / gemm_input_cols;
+
+ const int output_cols = output_shape.Dims(3);
+ const int output_rows = FlatSizeSkipDim(output_shape, 3);
+ TFLITE_DCHECK_EQ(output_cols, filter_rows);
+ TFLITE_DCHECK_EQ(output_rows, gemm_input_rows);
+ TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_cols);
+
+ // MatrixBatchVectorMultiplyAccumulate assumes that each row of the second
+ // input matrix has its own scale factor. This code duplicates the scale
+ // factors for each row in the same batch.
+ const int rows_per_batch = gemm_input_rows / batch_size;
+ for (int i = gemm_input_rows - 1; i >= 0; --i) {
+ scaling_factors_ptr[i] = scaling_factors_ptr[i / rows_per_batch];
+ }
+
+ tensor_utils::ZeroVector(output_data, output_rows * output_cols);
+
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ filter_data, filter_rows, filter_cols, gemm_input_data,
+ scaling_factors_ptr, /*n_batch=*/gemm_input_rows, output_data,
+ /*result_stride=*/1);
+
+ AddBiasAndEvalActivationFunction(output_activation_min, output_activation_max,
+ bias_shape, bias_data, output_shape,
+ output_data);
+}
+
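
MatrixBatchVectorMultiplyAccumulate expects one scaling factor per row of the quantized GEMM input, while the caller of HybridConv supplies one factor per batch; the backwards loop above fans the per-batch values out in place. A minimal sketch of that expansion with made-up numbers:

// Sketch only: in-place expansion of per-batch scale factors to per-row
// factors, mirroring the backwards loop in HybridConv above.
#include <cstdio>

int main() {
  const int batch_size = 2;
  const int gemm_input_rows = 6;                 // 3 im2col rows per batch
  const int rows_per_batch = gemm_input_rows / batch_size;

  // First batch_size entries hold the per-batch factors; the rest is scratch.
  float scaling_factors[6] = {0.5f, 0.25f, 0.f, 0.f, 0.f, 0.f};

  // Walking backwards lets each slot read its source before it is overwritten.
  for (int i = gemm_input_rows - 1; i >= 0; --i) {
    scaling_factors[i] = scaling_factors[i / rows_per_batch];
  }
  for (int i = 0; i < gemm_input_rows; ++i) {
    std::printf("row %d -> %.2f\n", i, scaling_factors[i]);  // 0.50 x3, 0.25 x3
  }
  return 0;
}
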
+inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
+ const uint8* input_data, const RuntimeShape& filter_shape,
+ const uint8* filter_data, const RuntimeShape& bias_shape,
+ const int32* bias_data, const RuntimeShape& output_shape,
+ uint8* output_data, const RuntimeShape& im2col_shape,
+ uint8* im2col_data, gemmlowp::GemmContext* gemm_context) {
+ gemmlowp::ScopedProfilingLabel label("Conv/8bit");
+ const int stride_width = params.stride_width;
+ const int stride_height = params.stride_height;
+ const int dilation_width_factor = params.dilation_width_factor;
+ const int dilation_height_factor = params.dilation_height_factor;
+ const int32 input_offset = params.input_offset;
+ const int32 filter_offset = params.weights_offset;
+ const int32 output_offset = params.output_offset;
+ const int32 output_multiplier = params.output_multiplier;
+ const int output_shift = params.output_shift;
+ const int32 output_activation_min = params.quantized_activation_min;
+ const int32 output_activation_max = params.quantized_activation_max;
+ TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
const uint8* gemm_input_data = nullptr;
- const Dims<4>* gemm_input_dims = nullptr;
- const int filter_width = ArraySize(filter_dims, 1);
- const int filter_height = ArraySize(filter_dims, 2);
+ const RuntimeShape* gemm_input_shape = nullptr;
+ const int filter_width = filter_shape.Dims(2);
+ const int filter_height = filter_shape.Dims(1);
+ const bool need_dilated_im2col =
+ dilation_width_factor != 1 || dilation_height_factor != 1;
const bool need_im2col = stride_width != 1 || stride_height != 1 ||
filter_width != 1 || filter_height != 1;
- if (need_im2col) {
+ if (need_dilated_im2col) {
TFLITE_DCHECK(im2col_data);
const int input_zero_point = -input_offset;
TFLITE_DCHECK_GE(input_zero_point, 0);
TFLITE_DCHECK_LE(input_zero_point, 255);
- Im2col(input_data, input_dims, stride_width, stride_height, pad_width,
- pad_height, filter_height, filter_width, input_zero_point,
- im2col_data, im2col_dims);
+ DilatedIm2col(params, input_zero_point, input_shape, input_data,
+ filter_shape, output_shape, im2col_data);
gemm_input_data = im2col_data;
- gemm_input_dims = &im2col_dims;
+ gemm_input_shape = &im2col_shape;
+ } else if (need_im2col) {
+ TFLITE_DCHECK(im2col_data);
+ const int input_zero_point = -input_offset;
+ TFLITE_DCHECK_GE(input_zero_point, 0);
+ TFLITE_DCHECK_LE(input_zero_point, 255);
+ Im2col(params, filter_height, filter_width, input_zero_point, input_shape,
+ input_data, im2col_shape, im2col_data);
+ gemm_input_data = im2col_data;
+ gemm_input_shape = &im2col_shape;
} else {
TFLITE_DCHECK(!im2col_data);
gemm_input_data = input_data;
- gemm_input_dims = &input_dims;
- }
-
- const int gemm_input_rows = gemm_input_dims->sizes[0];
- const int gemm_input_cols = FlatSizeSkipDim(*gemm_input_dims, 0);
- const int filter_rows = filter_dims.sizes[3];
- const int filter_cols = FlatSizeSkipDim(filter_dims, 3);
- const int output_rows = output_dims.sizes[0];
- const int output_cols = FlatSizeSkipDim(output_dims, 0);
+ gemm_input_shape = &input_shape;
+ }
+
+ const int gemm_input_rows = gemm_input_shape->Dims(3);
+ // Using FlatSizeSkipDim causes segfault in some contexts (see b/79927784).
+ // The root cause has not yet been identified though. Same applies below for
+ // the other calls commented out. This is a partial rollback of cl/196819423.
+ // const int gemm_input_cols = FlatSizeSkipDim(*gemm_input_shape, 3);
+ const int gemm_input_cols = gemm_input_shape->Dims(0) *
+ gemm_input_shape->Dims(1) *
+ gemm_input_shape->Dims(2);
+ const int filter_rows = filter_shape.Dims(0);
+ // See b/79927784.
+ // const int filter_cols = FlatSizeSkipDim(filter_shape, 0);
+ const int filter_cols =
+ filter_shape.Dims(1) * filter_shape.Dims(2) * filter_shape.Dims(3);
+ const int output_rows = output_shape.Dims(3);
+ // See b/79927784.
+ // const int output_cols = FlatSizeSkipDim(output_shape, 3);
+ const int output_cols =
+ output_shape.Dims(0) * output_shape.Dims(1) * output_shape.Dims(2);
TFLITE_DCHECK_EQ(output_rows, filter_rows);
TFLITE_DCHECK_EQ(output_cols, gemm_input_cols);
TFLITE_DCHECK_EQ(filter_cols, gemm_input_rows);
- TFLITE_DCHECK_EQ(bias_dims.sizes[0], output_rows);
- TFLITE_DCHECK_EQ(bias_dims.sizes[1], 1);
- TFLITE_DCHECK_EQ(bias_dims.sizes[2], 1);
- TFLITE_DCHECK_EQ(bias_dims.sizes[3], 1);
+ TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_rows);
gemmlowp::MatrixMap<const uint8, gemmlowp::MapOrder::RowMajor> filter_matrix(
filter_data, filter_rows, filter_cols);
gemmlowp::MatrixMap<const uint8, gemmlowp::MapOrder::ColMajor> input_matrix(
gemm_input_data, gemm_input_rows, gemm_input_cols);
gemmlowp::MatrixMap<uint8, gemmlowp::MapOrder::ColMajor> output_matrix(
output_data, output_rows, output_cols);
- const auto& output_pipeline = GemmlowpOutputPipeline::Make(
+ const auto& output_pipeline = GemmlowpOutputPipeline::MakeExp(
bias_data, output_rows, output_offset, output_multiplier, output_shift,
output_activation_min, output_activation_max);
gemmlowp::GemmWithOutputPipeline<uint8, uint8,
@@ -2006,78 +2051,35 @@ inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
input_offset, output_pipeline);
}
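
The hand-written Dims(0) * Dims(1) * Dims(2) products above are the 4-D special case of FlatSizeSkipDim(shape, 3), inlined to sidestep b/79927784. A short sketch of the equivalence with a hypothetical shape:

// Sketch only: manual dim products versus "flat size skipping the last dim".
#include <cassert>

int main() {
  const int dims[4] = {2, 5, 7, 16};             // hypothetical NHWC shape

  // FlatSizeSkipDim(shape, 3) == product of all dims except the last one.
  int flat_size_skip_last = 1;
  for (int d = 0; d < 4; ++d) {
    if (d != 3) flat_size_skip_last *= dims[d];
  }
  const int manual = dims[0] * dims[1] * dims[2];
  assert(manual == flat_size_skip_last);         // 2 * 5 * 7 == 70
  return 0;
}
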
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
- int32 input_offset, const uint8* filter_data,
- const Dims<4>& filter_dims, int32 filter_offset,
- const int32* bias_data, const Dims<4>& bias_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, int32 output_offset, int32 output_multiplier,
- int output_shift, int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims, uint8* im2col_data,
- const Dims<4>& im2col_dims,
- gemmlowp::GemmContext* gemm_context) {
- static_assert(Ac == FusedActivationFunctionType::kNone ||
- Ac == FusedActivationFunctionType::kRelu ||
- Ac == FusedActivationFunctionType::kRelu6 ||
- Ac == FusedActivationFunctionType::kRelu1,
- "");
- if (Ac == FusedActivationFunctionType::kNone) {
- TFLITE_DCHECK_EQ(output_activation_min, 0);
- TFLITE_DCHECK_EQ(output_activation_max, 255);
- }
- Conv(input_data, input_dims, input_offset, filter_data, filter_dims,
- filter_offset, bias_data, bias_dims, stride_width, stride_height,
- pad_width, pad_height, output_offset, output_multiplier, output_shift,
- output_activation_min, output_activation_max, output_data, output_dims,
- im2col_data, im2col_dims, gemm_context);
-}
-
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void Conv(const uint8* input_data, const Dims<4>& input_dims,
- int32 input_offset, const uint8* filter_data,
- const Dims<4>& filter_dims, int32 filter_offset,
- const int32* bias_data, const Dims<4>& bias_dims, int stride,
- int pad_width, int pad_height, int32 output_offset,
- int32 output_multiplier, int output_shift,
- int32 output_activation_min, int32 output_activation_max,
- uint8* output_data, const Dims<4>& output_dims, uint8* im2col_data,
- const Dims<4>& im2col_dims, gemmlowp::GemmContext* gemm_context) {
- static_assert(Ac == FusedActivationFunctionType::kNone ||
- Ac == FusedActivationFunctionType::kRelu ||
- Ac == FusedActivationFunctionType::kRelu6 ||
- Ac == FusedActivationFunctionType::kRelu1,
- "");
- Conv(input_data, input_dims, input_offset, filter_data, filter_dims,
- filter_offset, bias_data, bias_dims, stride, stride, pad_width,
- pad_height, output_offset, output_multiplier, output_shift,
- output_activation_min, output_activation_max, output_data, output_dims,
- im2col_data, im2col_dims, gemm_context);
-}
-
template <typename T>
-inline void DepthToSpace(const T* input_data, const Dims<4>& input_dims,
- int block_size, T* output_data,
- const Dims<4>& output_dims) {
+inline void DepthToSpace(const tflite::DepthToSpaceParams& op_params,
+ const RuntimeShape& unextended_input_shape,
+ const T* input_data,
+ const RuntimeShape& unextended_output_shape,
+ T* output_data) {
gemmlowp::ScopedProfilingLabel label("DepthToSpace");
- const int input_depth = ArraySize(input_dims, 0);
- const int input_width = ArraySize(input_dims, 1);
- const int input_height = ArraySize(input_dims, 2);
+ TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+ const RuntimeShape input_shape =
+ RuntimeShape::ExtendedShape(4, unextended_input_shape);
+ const RuntimeShape output_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_shape);
- const int output_depth = ArraySize(output_dims, 0);
- const int batch_size = ArraySize(output_dims, 3);
+ const int input_depth = input_shape.Dims(3);
+ const int input_width = input_shape.Dims(2);
+ const int input_height = input_shape.Dims(1);
+
+ const int output_depth = output_shape.Dims(3);
+ const int batch_size = output_shape.Dims(0);
  // Number of contiguous values that we can copy in one iteration.
- const int stride = block_size * output_depth;
+ const int stride = op_params.block_size * output_depth;
for (int batch = 0; batch < batch_size; ++batch) {
for (int in_h = 0; in_h < input_height; ++in_h) {
- const T* input_ptr = input_data + Offset(input_dims, 0, 0, in_h, batch);
- for (int offset_h = 0; offset_h < block_size; ++offset_h) {
+ const T* input_ptr = input_data + Offset(input_shape, batch, in_h, 0, 0);
+ for (int offset_h = 0; offset_h < op_params.block_size; ++offset_h) {
const T* src = input_ptr;
for (int in_w = 0; in_w < input_width; ++in_w) {
memcpy(output_data, src, stride * sizeof(T));
@@ -2090,100 +2092,35 @@ inline void DepthToSpace(const T* input_data, const Dims<4>& input_dims,
}
}
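
For NHWC data, each memcpy in DepthToSpace above copies block_size * output_depth contiguous values, i.e. one block-row of one output pixel. A brief sketch of the shape relationship, assuming a made-up input shape and block size:

// Sketch only: shape bookkeeping for the DepthToSpace kernel above.
#include <cstdio>

int main() {
  const int block_size = 2;
  const int batch = 1, in_h = 4, in_w = 4, in_d = 12;     // input  [1, 4, 4, 12]
  const int out_d = in_d / (block_size * block_size);     // 3
  const int out_h = in_h * block_size;                    // 8
  const int out_w = in_w * block_size;                    // 8

  // Contiguous run copied by each memcpy in the inner loop.
  const int stride = block_size * out_d;                  // 6 values
  std::printf("output shape: [%d, %d, %d, %d], memcpy run: %d values\n",
              batch, out_h, out_w, out_d, stride);
  return 0;
}
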
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac, typename T>
-void Im2col(const T* input_data, const Dims<4>& input_dims, int stride,
- int pad_width, int pad_height, int kheight, int kwidth,
- uint8 byte_zero, T* output_data, const Dims<4>& output_dims) {
- Im2col(input_data, input_dims, stride, stride, pad_width, pad_height, kheight,
- kwidth, byte_zero, output_data, output_dims);
-}
-
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void ConvAsGemm(const float* input_data, const Dims<4>& input_dims,
- const float* filter_data, const Dims<4>& filter_dims,
- const float* bias_data, const Dims<4>& bias_dims,
- float* output_data, const Dims<4>& output_dims) {
- gemmlowp::ScopedProfilingLabel label("ConvAsGemm");
-
- const auto input_matrix_map =
- MapAsMatrixWithFirstDimAsRows(input_data, input_dims);
- const auto filter_matrix_map =
- MapAsMatrixWithLastDimAsCols(filter_data, filter_dims);
- auto output_matrix_map =
- MapAsMatrixWithFirstDimAsRows(output_data, output_dims);
-
- Gemm(filter_matrix_map.transpose(), input_matrix_map, &output_matrix_map);
-
- AddBiasAndEvalActivationFunction<Ac>(bias_data, bias_dims, output_data,
- output_dims);
-}
-
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void ConvAsGemm(const uint8* input_data, const Dims<4>& input_dims,
- int32 input_offset, const uint8* filter_data,
- const Dims<4>& filter_dims, int32 filter_offset,
- const int32* bias_data, const Dims<4>& bias_dims,
- int32 output_offset, int32 output_multiplier, int output_shift,
- int32 output_activation_min, int32 output_activation_max,
- uint8* output_data, const Dims<4>& output_dims,
- gemmlowp::GemmContext* gemm_context) {
- gemmlowp::ScopedProfilingLabel label("ConvAsGemm/8bit");
- static_assert(Ac == FusedActivationFunctionType::kNone ||
- Ac == FusedActivationFunctionType::kRelu ||
- Ac == FusedActivationFunctionType::kRelu6 ||
- Ac == FusedActivationFunctionType::kRelu1,
- "");
- const int input_rows = input_dims.sizes[0];
- const int input_cols = FlatSizeSkipDim(input_dims, 0);
- const int filter_rows = filter_dims.sizes[3];
- const int filter_cols = FlatSizeSkipDim(filter_dims, 3);
- const int output_rows = output_dims.sizes[0];
- const int output_cols = FlatSizeSkipDim(output_dims, 0);
- TFLITE_DCHECK_EQ(output_rows, filter_rows);
- TFLITE_DCHECK_EQ(output_cols, input_cols);
- TFLITE_DCHECK_EQ(filter_cols, input_rows);
- TFLITE_DCHECK_EQ(bias_dims.sizes[0], output_rows);
- TFLITE_DCHECK_EQ(bias_dims.sizes[1], 1);
- TFLITE_DCHECK_EQ(bias_dims.sizes[2], 1);
- TFLITE_DCHECK_EQ(bias_dims.sizes[3], 1);
- gemmlowp::MatrixMap<const uint8, gemmlowp::MapOrder::RowMajor> filter_matrix(
- filter_data, output_rows, filter_cols, filter_cols);
- gemmlowp::MatrixMap<const uint8, gemmlowp::MapOrder::ColMajor> input_matrix(
- input_data, filter_cols, output_cols, filter_cols);
- gemmlowp::MatrixMap<uint8, gemmlowp::MapOrder::ColMajor> output_matrix(
- output_data, output_rows, output_cols, output_rows);
- const auto& output_pipeline = GemmlowpOutputPipeline::Make(
- bias_data, output_rows, output_offset, output_multiplier, output_shift,
- output_activation_min, output_activation_max);
- gemmlowp::GemmWithOutputPipeline<uint8, uint8,
- gemmlowp::L8R8WithLhsNonzeroBitDepthParams>(
- gemm_context, filter_matrix, input_matrix, &output_matrix, filter_offset,
- input_offset, output_pipeline);
-}
-
template <typename T>
-inline void SpaceToDepth(const T* input_data, const Dims<4>& input_dims,
- int block_size, T* output_data,
- const Dims<4>& output_dims) {
+inline void SpaceToDepth(const tflite::SpaceToDepthParams& op_params,
+ const RuntimeShape& unextended_input_shape,
+ const T* input_data,
+ const RuntimeShape& unextended_output_shape,
+ T* output_data) {
gemmlowp::ScopedProfilingLabel label("SpaceToDepth");
- const int output_depth = ArraySize(output_dims, 0);
- const int output_width = ArraySize(output_dims, 1);
- const int output_height = ArraySize(output_dims, 2);
+ TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+ const RuntimeShape input_shape =
+ RuntimeShape::ExtendedShape(4, unextended_input_shape);
+ const RuntimeShape output_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_shape);
+
+ const int output_depth = output_shape.Dims(3);
+ const int output_width = output_shape.Dims(2);
+ const int output_height = output_shape.Dims(1);
- const int input_depth = ArraySize(input_dims, 0);
- const int batch_size = ArraySize(input_dims, 3);
+ const int input_depth = input_shape.Dims(3);
+ const int batch_size = input_shape.Dims(0);
  // Number of contiguous values that we can copy in one iteration.
- const int stride = block_size * input_depth;
+ const int stride = op_params.block_size * input_depth;
for (int batch = 0; batch < batch_size; ++batch) {
for (int out_h = 0; out_h < output_height; ++out_h) {
- T* output_ptr = output_data + Offset(output_dims, 0, 0, out_h, batch);
- for (int offset_h = 0; offset_h < block_size; ++offset_h) {
+ T* output_ptr = output_data + Offset(output_shape, batch, out_h, 0, 0);
+ for (int offset_h = 0; offset_h < op_params.block_size; ++offset_h) {
T* dst = output_ptr;
for (int out_w = 0; out_w < output_width; ++out_w) {
memcpy(dst, input_data, stride * sizeof(T));
@@ -2196,95 +2133,26 @@ inline void SpaceToDepth(const T* input_data, const Dims<4>& input_dims,
}
}
-template <FusedActivationFunctionType Ac>
-void NonGlobalBatchNormalization(
- const float* input_data, const Dims<4>& input_dims, const float* mean_data,
- const Dims<4>& mean_dims, const float* multiplier_data,
- const Dims<4>& multiplier_dims, const float* offset_data,
- const Dims<4>& offset_dims, float* output_data,
- const Dims<4>& output_dims) {
- gemmlowp::ScopedProfilingLabel label("NonGlobalBatchNormalization");
- const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
- const int inner_size = MatchingFlatSizeSkipDim(
- input_dims, 3, mean_dims, multiplier_dims, offset_dims, output_dims);
-
- for (int b = 0; b < batches; ++b) {
- for (int i = 0; i < inner_size; ++i) {
- *output_data = ActivationFunction<Ac>(
- (*input_data - mean_data[i]) * multiplier_data[i] + offset_data[i]);
- ++output_data;
- ++input_data;
- }
- }
-}
-
-template <FusedActivationFunctionType Ac>
-void GlobalBatchNormalization(const float* input_data,
- const Dims<4>& input_dims, const float* mean_data,
- const Dims<4>& mean_dims,
- const float* multiplier_data,
- const Dims<4>& multiplier_dims,
- const float* offset_data,
- const Dims<4>& offset_dims, float* output_data,
- const Dims<4>& output_dims) {
- gemmlowp::ScopedProfilingLabel label("GlobalBatchNormalization");
- const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims);
- const int depth =
- MatchingArraySize(input_dims, 0, mean_dims, 0, multiplier_dims, 0,
- offset_dims, 0, output_dims, 0);
-
- for (int i = 0; i < outer_size; ++i) {
- for (int c = 0; c < depth; ++c) {
- *output_data = ActivationFunction<Ac>(
- (*input_data - mean_data[c]) * multiplier_data[c] + offset_data[c]);
- ++output_data;
- ++input_data;
- }
- }
-}
-
-inline void Relu(const float* input_data, const Dims<4>& input_dims,
- float* output_data, const Dims<4>& output_dims) {
+inline void Relu(const RuntimeShape& input_shape, const float* input_data,
+ const RuntimeShape& output_shape, float* output_data) {
gemmlowp::ScopedProfilingLabel label("Relu (not fused)");
- const auto input = MapAsVector(input_data, input_dims);
- auto output = MapAsVector(output_data, output_dims);
+ const auto input = MapAsVector(input_data, input_shape);
+ auto output = MapAsVector(output_data, output_shape);
output = input.cwiseMax(0.0f);
}
-inline void Relu1(const float* input_data, const Dims<4>& input_dims,
- float* output_data, const Dims<4>& output_dims) {
- gemmlowp::ScopedProfilingLabel label("Relu1 (not fused)");
- const int flat_size = MatchingFlatSize(input_dims, output_dims);
- for (int i = 0; i < flat_size; ++i) {
- const float val = input_data[i];
- const float upper = 1;
- const float lower = -1;
- const float clamped = val > upper ? upper : val < lower ? lower : val;
- output_data[i] = clamped;
- }
-}
-
-inline void Relu6(const float* input_data, const Dims<4>& input_dims,
- float* output_data, const Dims<4>& output_dims) {
- gemmlowp::ScopedProfilingLabel label("Relu6 (not fused)");
- const int flat_size = MatchingFlatSize(input_dims, output_dims);
- for (int i = 0; i < flat_size; ++i) {
- const float val = input_data[i];
- const float upper = 6;
- const float lower = 0;
- const float clamped = val > upper ? upper : val < lower ? lower : val;
- output_data[i] = clamped;
- }
-}
-
-template <FusedActivationFunctionType Ac>
-void L2Normalization(const float* input_data, const Dims<4>& input_dims,
- float* output_data, const Dims<4>& output_dims) {
+inline void L2Normalization(const tflite::L2NormalizationParams& op_params,
+ const RuntimeShape& input_shape,
+ const float* input_data,
+ const RuntimeShape& output_shape,
+ float* output_data) {
gemmlowp::ScopedProfilingLabel label("L2Normalization");
- static_assert(Ac == FusedActivationFunctionType::kNone, "");
- const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims);
- const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
+ const int trailing_dim = input_shape.DimensionsCount() - 1;
+ const int outer_size =
+ MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
+ const int depth =
+ MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
for (int i = 0; i < outer_size; ++i) {
float squared_l2_norm = 0;
for (int c = 0; c < depth; ++c) {
@@ -2300,15 +2168,17 @@ void L2Normalization(const float* input_data, const Dims<4>& input_dims,
}
}
-inline void GetInvSqrtQuantizedMultiplier(int32 input, int32* output_inv_sqrt,
- int* output_shift) {
+inline void GetInvSqrtQuantizedMultiplierExp(int32 input,
+ int32* output_inv_sqrt,
+ int* output_shift) {
*output_shift = 11;
while (input >= (1 << 29)) {
input /= 4;
++*output_shift;
}
TFLITE_DCHECK_GT(input, 0);
- const unsigned max_left_shift_bits = __builtin_clz(input) - 1;
+ const unsigned max_left_shift_bits =
+ CountLeadingZeros(static_cast<uint32>(input)) - 1;
const unsigned max_left_shift_bit_pairs = max_left_shift_bits / 2;
const unsigned left_shift_bit_pairs = max_left_shift_bit_pairs - 1;
*output_shift -= left_shift_bit_pairs;
@@ -2343,51 +2213,58 @@ inline void GetInvSqrtQuantizedMultiplier(int32 input, int32* output_inv_sqrt,
*output_inv_sqrt <<= -*output_shift;
*output_shift = 0;
}
+ // Convert the right shift (positive = shift right) into a left-shift exponent.
+ *output_shift *= kReverseShift;
}
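
The final *output_shift *= kReverseShift line flips the shift convention: this routine computes a right shift (positive means shift right), while the *Exp interfaces expect a left-shift exponent (negative means shift right); kReverseShift is -1 in these kernels. A small sketch of the two conventions:

// Sketch only: right-shift amount versus left-shift exponent.
#include <cassert>
#include <cstdint>

// Applies a "left-shift exponent": positive shifts left, negative shifts right.
int32_t ApplyExpShift(int32_t x, int exp_shift) {
  return exp_shift >= 0 ? (x << exp_shift) : (x >> -exp_shift);
}

int main() {
  const int right_shift = 3;            // old convention: shift right by 3
  const int exp_shift = -right_shift;   // new convention after kReverseShift
  assert(ApplyExpShift(800, exp_shift) == (800 >> 3));
  return 0;
}
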
-inline void L2Normalization(const uint8* input_data, const Dims<4>& input_dims,
- int32 input_zero_point, uint8* output_data,
- const Dims<4>& output_dims) {
+inline void L2Normalization(const tflite::L2NormalizationParams& op_params,
+ const RuntimeShape& input_shape,
+ const uint8* input_data,
+ const RuntimeShape& output_shape,
+ uint8* output_data) {
gemmlowp::ScopedProfilingLabel label("L2Normalization/8bit");
- TFLITE_DCHECK(IsPackedWithoutStrides(input_dims));
- TFLITE_DCHECK(IsPackedWithoutStrides(output_dims));
- const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
- const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims);
- TFLITE_DCHECK_EQ(outer_size, 1);
- int32 square_l2_norm = 0;
- for (int i = 0; i < depth; i++) {
- int32 diff = input_data[i] - input_zero_point;
- square_l2_norm += diff * diff;
- }
- int32 inv_l2norm_multiplier;
- int inv_l2norm_shift;
- GetInvSqrtQuantizedMultiplier(square_l2_norm, &inv_l2norm_multiplier,
- &inv_l2norm_shift);
-
- for (int i = 0; i < depth; i++) {
- int32 diff = input_data[i] - input_zero_point;
- int32 rescaled_diff = MultiplyByQuantizedMultiplierSmallerThanOne(
- 128 * diff, inv_l2norm_multiplier, inv_l2norm_shift);
- int32 unclamped_output_val = 128 + rescaled_diff;
- int32 output_val = std::min(255, std::max(0, unclamped_output_val));
- output_data[i] = static_cast<uint8>(output_val);
+ const int trailing_dim = input_shape.DimensionsCount() - 1;
+ const int depth =
+ MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
+ const int outer_size =
+ MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
+ const int32 input_zero_point = op_params.input_zero_point;
+ for (int i = 0; i < outer_size; ++i) {
+ int32 square_l2_norm = 0;
+ for (int c = 0; c < depth; c++) {
+ // Note that input_data advances by depth in the second pass below.
+ int32 diff = input_data[c] - input_zero_point;
+ square_l2_norm += diff * diff;
+ }
+ int32 inv_l2norm_multiplier;
+ int inv_l2norm_shift;
+ GetInvSqrtQuantizedMultiplierExp(square_l2_norm, &inv_l2norm_multiplier,
+ &inv_l2norm_shift);
+
+ for (int c = 0; c < depth; c++) {
+ int32 diff = *input_data - input_zero_point;
+ int32 rescaled_diff = MultiplyByQuantizedMultiplierSmallerThanOneExp(
+ 128 * diff, inv_l2norm_multiplier, inv_l2norm_shift);
+ int32 unclamped_output_val = 128 + rescaled_diff;
+ int32 output_val = std::min(255, std::max(0, unclamped_output_val));
+ *output_data = static_cast<uint8>(output_val);
+ ++input_data;
+ ++output_data;
+ }
}
}
-inline void Add(const float* input1_data, const Dims<4>& input1_dims,
- const float* input2_data, const Dims<4>& input2_dims,
- float output_activation_min, float output_activation_max,
- float* output_data, const Dims<4>& output_dims) {
+inline void Add(const ArithmeticParams& params,
+ const RuntimeShape& input1_shape, const float* input1_data,
+ const RuntimeShape& input2_shape, const float* input2_data,
+ const RuntimeShape& output_shape, float* output_data) {
gemmlowp::ScopedProfilingLabel label("Add");
- TFLITE_DCHECK(IsPackedWithoutStrides(input1_dims));
- TFLITE_DCHECK(IsPackedWithoutStrides(input2_dims));
- TFLITE_DCHECK(IsPackedWithoutStrides(output_dims));
int i = 0;
- const int size = MatchingFlatSize(input1_dims, input2_dims, output_dims);
+ const int size = MatchingFlatSize(input1_shape, input2_shape, output_shape);
#ifdef USE_NEON
- const auto activation_min = vdupq_n_f32(output_activation_min);
- const auto activation_max = vdupq_n_f32(output_activation_max);
+ const auto activation_min = vdupq_n_f32(params.float_activation_min);
+ const auto activation_max = vdupq_n_f32(params.float_activation_max);
for (; i <= size - 16; i += 16) {
auto a10 = vld1q_f32(input1_data + i);
auto a11 = vld1q_f32(input1_data + i + 4);
@@ -2426,29 +2303,26 @@ inline void Add(const float* input1_data, const Dims<4>& input1_dims,
for (; i < size; i++) {
auto x = input1_data[i] + input2_data[i];
- output_data[i] = ActivationFunctionWithMinMax(x, output_activation_min,
- output_activation_max);
+ output_data[i] = ActivationFunctionWithMinMax(
+ x, params.float_activation_min, params.float_activation_max);
}
}
// Element-wise add that can often be used for inner loop of broadcast add as
// well as the non-broadcast add.
-inline void AddElementwise(int size, int left_shift, const uint8* input1_data,
- int32 input1_offset, int32 input1_multiplier,
- int input1_shift, const uint8* input2_data,
- int32 input2_offset, int32 input2_multiplier,
- int input2_shift, int32 output_offset,
- int32 output_multiplier, int output_shift,
- int32 output_activation_min,
- int32 output_activation_max, uint8* output_data) {
+inline void AddElementwise(int size, const ArithmeticParams& params,
+ const uint8* input1_data, const uint8* input2_data,
+ uint8* output_data) {
int i = 0;
- TFLITE_DCHECK_GT(input1_offset, -256);
- TFLITE_DCHECK_GT(input2_offset, -256);
- TFLITE_DCHECK_LT(input1_offset, 256);
- TFLITE_DCHECK_LT(input2_offset, 256);
+ TFLITE_DCHECK_GT(params.input1_offset, -256);
+ TFLITE_DCHECK_GT(params.input2_offset, -256);
+ TFLITE_DCHECK_LT(params.input1_offset, 256);
+ TFLITE_DCHECK_LT(params.input2_offset, 256);
#ifdef USE_NEON
- const auto output_activation_min_vector = vdup_n_u8(output_activation_min);
- const auto output_activation_max_vector = vdup_n_u8(output_activation_max);
+ const auto output_activation_min_vector =
+ vdup_n_u8(params.quantized_activation_min);
+ const auto output_activation_max_vector =
+ vdup_n_u8(params.quantized_activation_max);
for (; i <= size - 8; i += 8) {
const auto input1_val_original = vld1_u8(input1_data + i);
const auto input2_val_original = vld1_u8(input2_data + i);
@@ -2457,9 +2331,9 @@ inline void AddElementwise(int size, int left_shift, const uint8* input1_data,
const auto input2_val_s16 =
vreinterpretq_s16_u16(vmovl_u8(input2_val_original));
const auto input1_val =
- vaddq_s16(input1_val_s16, vdupq_n_s16(input1_offset));
+ vaddq_s16(input1_val_s16, vdupq_n_s16(params.input1_offset));
const auto input2_val =
- vaddq_s16(input2_val_s16, vdupq_n_s16(input2_offset));
+ vaddq_s16(input2_val_s16, vdupq_n_s16(params.input2_offset));
const auto input1_val_high = vget_high_s16(input1_val);
const auto input1_val_low = vget_low_s16(input1_val);
const auto input2_val_high = vget_high_s16(input2_val);
@@ -2468,32 +2342,32 @@ inline void AddElementwise(int size, int left_shift, const uint8* input1_data,
auto x12 = vmovl_s16(input1_val_high);
auto x21 = vmovl_s16(input2_val_low);
auto x22 = vmovl_s16(input2_val_high);
- const auto left_shift_dup = vdupq_n_s32(left_shift);
+ const auto left_shift_dup = vdupq_n_s32(params.left_shift);
x11 = vshlq_s32(x11, left_shift_dup);
x12 = vshlq_s32(x12, left_shift_dup);
x21 = vshlq_s32(x21, left_shift_dup);
x22 = vshlq_s32(x22, left_shift_dup);
- x11 = vqrdmulhq_n_s32(x11, input1_multiplier);
- x12 = vqrdmulhq_n_s32(x12, input1_multiplier);
- x21 = vqrdmulhq_n_s32(x21, input2_multiplier);
- x22 = vqrdmulhq_n_s32(x22, input2_multiplier);
- const auto input1_shift_dup = vdupq_n_s32(-input1_shift);
- const auto input2_shift_dup = vdupq_n_s32(-input2_shift);
+ x11 = vqrdmulhq_n_s32(x11, params.input1_multiplier);
+ x12 = vqrdmulhq_n_s32(x12, params.input1_multiplier);
+ x21 = vqrdmulhq_n_s32(x21, params.input2_multiplier);
+ x22 = vqrdmulhq_n_s32(x22, params.input2_multiplier);
+ const auto input1_shift_dup = vdupq_n_s32(params.input1_shift);
+ const auto input2_shift_dup = vdupq_n_s32(params.input2_shift);
x11 = vshlq_s32(x11, input1_shift_dup);
x12 = vshlq_s32(x12, input1_shift_dup);
x21 = vshlq_s32(x21, input2_shift_dup);
x22 = vshlq_s32(x22, input2_shift_dup);
auto s1 = vaddq_s32(x11, x21);
auto s2 = vaddq_s32(x12, x22);
- s1 = vqrdmulhq_n_s32(s1, output_multiplier);
- s2 = vqrdmulhq_n_s32(s2, output_multiplier);
+ s1 = vqrdmulhq_n_s32(s1, params.output_multiplier);
+ s2 = vqrdmulhq_n_s32(s2, params.output_multiplier);
using gemmlowp::RoundingDivideByPOT;
- s1 = RoundingDivideByPOT(s1, output_shift);
- s2 = RoundingDivideByPOT(s2, output_shift);
+ s1 = RoundingDivideByPOT(s1, -params.output_shift);
+ s2 = RoundingDivideByPOT(s2, -params.output_shift);
const auto s1_narrowed = vmovn_s32(s1);
const auto s2_narrowed = vmovn_s32(s2);
const auto s = vaddq_s16(vcombine_s16(s1_narrowed, s2_narrowed),
- vdupq_n_s16(output_offset));
+ vdupq_n_s16(params.output_offset));
const auto clamped =
vmax_u8(output_activation_min_vector,
vmin_u8(output_activation_max_vector, vqmovun_s16(s)));
@@ -2502,108 +2376,74 @@ inline void AddElementwise(int size, int left_shift, const uint8* input1_data,
#endif // NEON
for (; i < size; ++i) {
- const int32 input1_val = input1_offset + input1_data[i];
- const int32 input2_val = input2_offset + input2_data[i];
- const int32 shifted_input1_val = input1_val * (1 << left_shift);
- const int32 shifted_input2_val = input2_val * (1 << left_shift);
- const int32 scaled_input1_val = MultiplyByQuantizedMultiplierSmallerThanOne(
- shifted_input1_val, input1_multiplier, input1_shift);
- const int32 scaled_input2_val = MultiplyByQuantizedMultiplierSmallerThanOne(
- shifted_input2_val, input2_multiplier, input2_shift);
+ const int32 input1_val = params.input1_offset + input1_data[i];
+ const int32 input2_val = params.input2_offset + input2_data[i];
+ const int32 shifted_input1_val = input1_val * (1 << params.left_shift);
+ const int32 shifted_input2_val = input2_val * (1 << params.left_shift);
+ const int32 scaled_input1_val =
+ MultiplyByQuantizedMultiplierSmallerThanOneExp(
+ shifted_input1_val, params.input1_multiplier, params.input1_shift);
+ const int32 scaled_input2_val =
+ MultiplyByQuantizedMultiplierSmallerThanOneExp(
+ shifted_input2_val, params.input2_multiplier, params.input2_shift);
const int32 raw_sum = scaled_input1_val + scaled_input2_val;
- const int32 raw_output = MultiplyByQuantizedMultiplierSmallerThanOne(
- raw_sum, output_multiplier, output_shift) +
- output_offset;
- const int32 clamped_output = std::min(
- output_activation_max, std::max(output_activation_min, raw_output));
+ const int32 raw_output =
+ MultiplyByQuantizedMultiplierSmallerThanOneExp(
+ raw_sum, params.output_multiplier, params.output_shift) +
+ params.output_offset;
+ const int32 clamped_output =
+ std::min(params.quantized_activation_max,
+ std::max(params.quantized_activation_min, raw_output));
output_data[i] = static_cast<uint8>(clamped_output);
}
}
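
The scalar tail of AddElementwise is the usual quantized-add recipe: apply the (negated zero-point) input offsets, shift left for headroom, rescale each input onto a shared grid with its multiplier/shift, add, requantize with the output multiplier/shift, add the output offset, and clamp. In real-number terms it approximates s1*(q1-z1) + s2*(q2-z2) requantized with the output scale and zero point; a float-domain sketch with made-up quantization parameters (a mental model, not the kernel's integer arithmetic):

// Sketch only: the real-number arithmetic that the quantized AddElementwise
// above approximates, with hypothetical scales and zero points.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

int main() {
  const float s1 = 0.05f, s2 = 0.10f, s_out = 0.08f;  // input/output scales
  const int z1 = 120, z2 = 128, z_out = 127;          // zero points

  const uint8_t q1 = 150, q2 = 100;

  // Dequantize, add in the real domain, requantize, clamp to the uint8 range.
  const float real_sum = s1 * (q1 - z1) + s2 * (q2 - z2);
  const int unclamped = z_out + static_cast<int>(std::lround(real_sum / s_out));
  const int clamped = std::min(255, std::max(0, unclamped));
  std::printf("q1=%d q2=%d -> real %.3f -> q_out %d\n", q1, q2, real_sum, clamped);
  return 0;
}

The integer kernel reproduces this with fixed-point multipliers (the *_multiplier / *_shift pairs) so that no float math is needed at runtime.
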
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void Add(const float* input1_data, const Dims<4>& input1_dims,
- const float* input2_data, const Dims<4>& input2_dims,
- float* output_data, const Dims<4>& output_dims) {
- float output_activation_min, output_activation_max;
- GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
-
- Add(input1_data, input1_dims, input2_data, input2_dims, output_activation_min,
- output_activation_max, output_data, output_dims);
-}
-
-template <FusedActivationFunctionType Ac>
-inline void Add(int left_shift, const uint8* input1_data,
- const Dims<4>& input1_dims, int32 input1_offset,
- int32 input1_multiplier, int input1_shift,
- const uint8* input2_data, const Dims<4>& input2_dims,
- int32 input2_offset, int32 input2_multiplier, int input2_shift,
- int32 output_offset, int32 output_multiplier, int output_shift,
- int32 output_activation_min, int32 output_activation_max,
- uint8* output_data, const Dims<4>& output_dims) {
- static_assert(Ac == FusedActivationFunctionType::kNone ||
- Ac == FusedActivationFunctionType::kRelu ||
- Ac == FusedActivationFunctionType::kRelu6 ||
- Ac == FusedActivationFunctionType::kRelu1,
- "");
- TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
- if (Ac == FusedActivationFunctionType::kNone) {
- TFLITE_DCHECK_EQ(output_activation_min, 0);
- TFLITE_DCHECK_EQ(output_activation_max, 255);
- }
+inline void Add(const ArithmeticParams& params,
+ const RuntimeShape& input1_shape, const uint8* input1_data,
+ const RuntimeShape& input2_shape, const uint8* input2_data,
+ const RuntimeShape& output_shape, uint8* output_data) {
+ TFLITE_DCHECK_LE(params.quantized_activation_min,
+ params.quantized_activation_max);
gemmlowp::ScopedProfilingLabel label("Add/8bit");
- const int flat_size = MatchingFlatSize(input1_dims, input2_dims, output_dims);
- TFLITE_DCHECK(IsPackedWithoutStrides(input1_dims));
- TFLITE_DCHECK(IsPackedWithoutStrides(input2_dims));
- TFLITE_DCHECK(IsPackedWithoutStrides(output_dims));
-
- TFLITE_DCHECK_GT(input1_offset, -256);
- TFLITE_DCHECK_GT(input2_offset, -256);
- TFLITE_DCHECK_LT(input1_offset, 256);
- TFLITE_DCHECK_LT(input2_offset, 256);
- AddElementwise(flat_size, left_shift, input1_data, input1_offset,
- input1_multiplier, input1_shift, input2_data, input2_offset,
- input2_multiplier, input2_shift, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_data);
+ const int flat_size =
+ MatchingFlatSize(input1_shape, input2_shape, output_shape);
+
+ TFLITE_DCHECK_GT(params.input1_offset, -256);
+ TFLITE_DCHECK_GT(params.input2_offset, -256);
+ TFLITE_DCHECK_LT(params.input1_offset, 256);
+ TFLITE_DCHECK_LT(params.input2_offset, 256);
+ AddElementwise(flat_size, params, input1_data, input2_data, output_data);
}
-template <FusedActivationFunctionType Ac>
-inline void Add(const int16* input1_data, const Dims<4>& input1_dims,
- int input1_shift, const int16* input2_data,
- const Dims<4>& input2_dims, int input2_shift,
- int16 output_activation_min, int16 output_activation_max,
- int16* output_data, const Dims<4>& output_dims) {
+inline void Add(const ArithmeticParams& params,
+ const RuntimeShape& input1_shape, const int16* input1_data,
+ const RuntimeShape& input2_shape, const int16* input2_data,
+ const RuntimeShape& output_shape, int16* output_data) {
gemmlowp::ScopedProfilingLabel label("Add/Int16");
- // This is a copy of the reference implementation. We do not currently have a
- // properly optimized version.
- static_assert(Ac == FusedActivationFunctionType::kNone ||
- Ac == FusedActivationFunctionType::kRelu ||
- Ac == FusedActivationFunctionType::kRelu6 ||
- Ac == FusedActivationFunctionType::kRelu1,
- "");
- TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
- if (Ac == FusedActivationFunctionType::kNone) {
- TFLITE_DCHECK_EQ(output_activation_min, -32768);
- TFLITE_DCHECK_EQ(output_activation_max, 32767);
- }
-
- const int flat_size = MatchingFlatSize(output_dims, input1_dims, input2_dims);
-
- TFLITE_DCHECK(input1_shift == 0 || input2_shift == 0);
- TFLITE_DCHECK_GE(input1_shift, 0);
- TFLITE_DCHECK_GE(input2_shift, 0);
+ TFLITE_DCHECK_LE(params.quantized_activation_min,
+ params.quantized_activation_max);
+
+ const int input1_shift = params.input1_shift;
+ const int flat_size =
+ MatchingFlatSize(output_shape, input1_shape, input2_shape);
+ const int16 output_activation_min = params.quantized_activation_min;
+ const int16 output_activation_max = params.quantized_activation_max;
+
+ TFLITE_DCHECK(input1_shift == 0 || params.input2_shift == 0);
+ TFLITE_DCHECK_LE(input1_shift, 0);
+ TFLITE_DCHECK_LE(params.input2_shift, 0);
const int16* not_shift_input = input1_shift == 0 ? input1_data : input2_data;
const int16* shift_input = input1_shift == 0 ? input2_data : input1_data;
- const int input_shift = input1_shift == 0 ? input2_shift : input1_shift;
+ const int input_right_shift =
+ input1_shift == 0 ? -params.input2_shift : -input1_shift;
for (int i = 0; i < flat_size; i++) {
// F0 uses 0 integer bits, range [-1, 1].
using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
F0 input_ready_scaled = F0::FromRaw(not_shift_input[i]);
- F0 scaled_input =
- F0::FromRaw(gemmlowp::RoundingDivideByPOT(shift_input[i], input_shift));
+ F0 scaled_input = F0::FromRaw(
+ gemmlowp::RoundingDivideByPOT(shift_input[i], input_right_shift));
F0 result = gemmlowp::SaturatingAdd(scaled_input, input_ready_scaled);
const int16 raw_output = result.raw();
const int16 clamped_output = std::min(
@@ -2612,157 +2452,59 @@ inline void Add(const int16* input1_data, const Dims<4>& input1_dims,
}
}
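
The int16 Add works in gemmlowp's Q0.15 format ("0 integer bits"): a raw int16 value r represents r / 32768, so the representable range is [-1, 1) and SaturatingAdd clamps at the int16 bounds instead of wrapping. A compact sketch of that representation (assuming gemmlowp's saturating behaviour, written independently of the library):

// Sketch only: the Q0.15 representation behind FixedPoint<int16_t, 0>.
// A raw int16 value r stands for r / 32768.
#include <algorithm>
#include <cstdint>
#include <cstdio>

// Saturating add on raw Q0.15 values (mirrors what gemmlowp::SaturatingAdd is
// assumed to do for int16: clamp instead of wrapping around).
int16_t SaturatingAddQ0_15(int16_t a, int16_t b) {
  const int32_t sum = static_cast<int32_t>(a) + static_cast<int32_t>(b);
  return static_cast<int16_t>(
      std::min<int32_t>(32767, std::max<int32_t>(-32768, sum)));
}

int main() {
  const int16_t half = 16384;  // 0.5 in Q0.15
  const int16_t r = SaturatingAddQ0_15(half, half);
  std::printf("0.5 + 0.5 saturates to raw %d (%.5f)\n", r, r / 32768.0f);
  return 0;
}
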
-template <FusedActivationFunctionType Ac>
-void Add(const int32* input1_data, const Dims<4>& input1_dims,
- const int32* input2_data, const Dims<4>& input2_dims,
- int32* output_data, const Dims<4>& output_dims) {
+inline void Add(const ArithmeticParams& params,
+ const RuntimeShape& input1_shape, const int32* input1_data,
+ const RuntimeShape& input2_shape, const int32* input2_data,
+ const RuntimeShape& output_shape, int32* output_data) {
gemmlowp::ScopedProfilingLabel label("Add/int32");
- TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone);
- auto input1_map = MapAsVector(input1_data, input1_dims);
- auto input2_map = MapAsVector(input2_data, input2_dims);
- auto output_map = MapAsVector(output_data, output_dims);
- if (AreSameDims(input1_dims, input2_dims)) {
+ auto input1_map = MapAsVector(input1_data, input1_shape);
+ auto input2_map = MapAsVector(input2_data, input2_shape);
+ auto output_map = MapAsVector(output_data, output_shape);
+ if (input1_shape == input2_shape) {
output_map.array() = input1_map.array() + input2_map.array();
- } else if (FlatSize(input2_dims) == 1) {
+ } else if (input2_shape.FlatSize() == 1) {
auto scalar = input2_data[0];
output_map.array() = input1_map.array() + scalar;
- } else if (FlatSize(input1_dims) == 1) {
+ } else if (input1_shape.FlatSize() == 1) {
auto scalar = input1_data[0];
output_map.array() = scalar + input2_map.array();
} else {
// Should not come here.
TFLITE_DCHECK(false);
}
+ output_map = output_map.cwiseMax(params.quantized_activation_min);
+ output_map = output_map.cwiseMin(params.quantized_activation_max);
}
-// TODO(jiawen): We can implement BroadcastAdd on buffers of arbitrary
-// dimensionality if the runtime code does a single loop over one dimension
-// that handles broadcasting as the base case. The code generator would then
-// generate max(D1, D2) nested for loops.
-// TODO(benoitjacob): BroadcastAdd is intentionally duplicated from
-// reference_ops.h. Once an optimized version is implemented and NdArrayDesc<T>
-// is no longer referenced in this file, move NdArrayDesc<T> from types.h to
-// reference_ops.h.
-template <typename T>
-void BroadcastAdd(const T* input1_data, const Dims<4>& input1_dims,
- const T* input2_data, const Dims<4>& input2_dims,
- T output_activation_min, T output_activation_max,
- T* output_data, const Dims<4>& output_dims) {
- gemmlowp::ScopedProfilingLabel label("BroadcastAdd");
-
- NdArrayDesc<4> desc1;
- NdArrayDesc<4> desc2;
- NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
-
- // In Tensorflow, the dimensions are canonically named (batch_number, row,
- // col, channel), with extents (batches, height, width, depth), with the
- // trailing dimension changing most rapidly (channels has the smallest stride,
- // typically 1 element).
- //
- // In generated C code, we store arrays with the dimensions reversed. The
- // first dimension has smallest stride.
- //
- // We name our variables by their Tensorflow convention, but generate C code
- // nesting loops such that the innermost loop has the smallest stride for the
- // best cache behavior.
- for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
- for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
- for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
- for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
- output_data[Offset(output_dims, c, x, y, b)] =
- ActivationFunctionWithMinMax(
- input1_data[SubscriptToIndex(desc1, c, x, y, b)] +
- input2_data[SubscriptToIndex(desc2, c, x, y, b)],
- output_activation_min, output_activation_max);
- }
- }
- }
- }
-}
-
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac, typename T>
-void BroadcastAdd(const T* input1_data, const Dims<4>& input1_dims,
- const T* input2_data, const Dims<4>& input2_dims,
- T* output_data, const Dims<4>& output_dims) {
- T output_activation_min, output_activation_max;
- GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
-
- BroadcastAdd(input1_data, input1_dims, input2_data, input2_dims,
- output_activation_min, output_activation_max, output_data,
- output_dims);
-}
-
-inline void BroadcastAdd(int left_shift, const uint8* input1_data,
- const Dims<4>& input1_dims, int32 input1_offset,
- int32 input1_multiplier, int input1_shift,
- const uint8* input2_data, const Dims<4>& input2_dims,
- int32 input2_offset, int32 input2_multiplier,
- int input2_shift, int32 output_offset,
- int32 output_multiplier, int output_shift,
- int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims) {
- gemmlowp::ScopedProfilingLabel label("BroadcastAddGeneric/8bit");
-
- NdArrayDesc<4> desc1;
- NdArrayDesc<4> desc2;
- NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
-
- // In Tensorflow, the dimensions are canonically named (batch_number, row,
- // col, channel), with extents (batches, height, width, depth), with the
- // trailing dimension changing most rapidly (channels has the smallest stride,
- // typically 1 element).
- //
- // In generated C code, we store arrays with the dimensions reversed. The
- // first dimension has smallest stride.
- //
- // We name our variables by their Tensorflow convention, but generate C code
- // nesting loops such that the innermost loop has the smallest stride for the
- // best cache behavior.
- for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
- for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
- for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
- for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
- const int32 input1_val =
- input1_offset + input1_data[SubscriptToIndex(desc1, c, x, y, b)];
- const int32 input2_val =
- input2_offset + input2_data[SubscriptToIndex(desc2, c, x, y, b)];
- const int32 shifted_input1_val = input1_val * (1 << left_shift);
- const int32 shifted_input2_val = input2_val * (1 << left_shift);
- const int32 scaled_input1_val =
- MultiplyByQuantizedMultiplierSmallerThanOne(
- shifted_input1_val, input1_multiplier, input1_shift);
- const int32 scaled_input2_val =
- MultiplyByQuantizedMultiplierSmallerThanOne(
- shifted_input2_val, input2_multiplier, input2_shift);
- const int32 raw_sum = scaled_input1_val + scaled_input2_val;
- const int32 raw_output =
- MultiplyByQuantizedMultiplierSmallerThanOne(
- raw_sum, output_multiplier, output_shift) +
- output_offset;
- const int32 clamped_output =
- std::min(output_activation_max,
- std::max(output_activation_min, raw_output));
- output_data[Offset(output_dims, c, x, y, b)] =
- static_cast<uint8>(clamped_output);
- }
- }
- }
- }
-}
-
-inline void BroadcastAddFivefold(
- int y0, int y1, int y2, int y3, int y4, int left_shift,
- const uint8* input1_data, const Dims<4>& input1_dims, int32 input1_offset,
- int32 input1_multiplier, int input1_shift, const uint8* input2_data,
- const Dims<4>& input2_dims, int32 input2_offset, int32 input2_multiplier,
- int input2_shift, int32 output_offset, int32 output_multiplier,
- int output_shift, int32 output_activation_min, int32 output_activation_max,
- uint8* output_data, const Dims<4>& output_dims) {
+inline void BroadcastAddFivefold(const ArithmeticParams& unswitched_params,
+ const RuntimeShape& unswitched_input1_shape,
+ const uint8* unswitched_input1_data,
+ const RuntimeShape& unswitched_input2_shape,
+ const uint8* unswitched_input2_data,
+ const RuntimeShape& output_shape,
+ uint8* output_data) {
gemmlowp::ScopedProfilingLabel label("BroadcastAddFivefold/8bit");
+ ArithmeticParams switched_params = unswitched_params;
+ switched_params.input1_offset = unswitched_params.input2_offset;
+ switched_params.input1_multiplier = unswitched_params.input2_multiplier;
+ switched_params.input1_shift = unswitched_params.input2_shift;
+ switched_params.input2_offset = unswitched_params.input1_offset;
+ switched_params.input2_multiplier = unswitched_params.input1_multiplier;
+ switched_params.input2_shift = unswitched_params.input1_shift;
+
+ const bool use_unswitched =
+ unswitched_params.broadcast_category ==
+ tflite::BroadcastableOpCategory::kFirstInputBroadcastsFast;
+
+ const ArithmeticParams& params =
+ use_unswitched ? unswitched_params : switched_params;
+ const uint8* input1_data =
+ use_unswitched ? unswitched_input1_data : unswitched_input2_data;
+ const uint8* input2_data =
+ use_unswitched ? unswitched_input2_data : unswitched_input1_data;
+
// Fivefold nested loops. The second input resets its position for each
// iteration of the second loop. The first input resets its position at the
// beginning of the fourth loop. The innermost loop is an elementwise add of
@@ -2770,93 +2512,39 @@ inline void BroadcastAddFivefold(
uint8* output_data_ptr = output_data;
const uint8* input1_data_ptr = input1_data;
const uint8* input2_data_reset = input2_data;
- for (int i4 = 0; i4 < y4; ++i4) {
+ int y0 = params.broadcast_shape[0];
+ int y1 = params.broadcast_shape[1];
+ int y2 = params.broadcast_shape[2];
+ int y3 = params.broadcast_shape[3];
+ int y4 = params.broadcast_shape[4];
+ for (int i0 = 0; i0 < y0; ++i0) {
const uint8* input2_data_ptr;
- for (int i3 = 0; i3 < y3; ++i3) {
+ for (int i1 = 0; i1 < y1; ++i1) {
input2_data_ptr = input2_data_reset;
for (int i2 = 0; i2 < y2; ++i2) {
- for (int i1 = 0; i1 < y1; ++i1) {
- AddElementwise(
- y0, left_shift, input1_data_ptr, input1_offset, input1_multiplier,
- input1_shift, input2_data_ptr, input2_offset, input2_multiplier,
- input2_shift, output_offset, output_multiplier, output_shift,
- output_activation_min, output_activation_max, output_data_ptr);
- input2_data_ptr += y0;
- output_data_ptr += y0;
+ for (int i3 = 0; i3 < y3; ++i3) {
+ AddElementwise(y4, params, input1_data_ptr, input2_data_ptr,
+ output_data_ptr);
+ input2_data_ptr += y4;
+ output_data_ptr += y4;
}
- input1_data_ptr += y0;
+ input1_data_ptr += y4;
}
}
input2_data_reset = input2_data_ptr;
}
}
-template <FusedActivationFunctionType Ac>
-inline void BroadcastAdd(int left_shift, const uint8* input1_data,
- const Dims<4>& input1_dims, int32 input1_offset,
- int32 input1_multiplier, int input1_shift,
- const uint8* input2_data, const Dims<4>& input2_dims,
- int32 input2_offset, int32 input2_multiplier,
- int input2_shift, int32 output_offset,
- int32 output_multiplier, int output_shift,
- int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims) {
- static_assert(Ac == FusedActivationFunctionType::kNone ||
- Ac == FusedActivationFunctionType::kRelu ||
- Ac == FusedActivationFunctionType::kRelu6 ||
- Ac == FusedActivationFunctionType::kRelu1,
- "");
- TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
- if (Ac == FusedActivationFunctionType::kNone) {
- TFLITE_DCHECK_EQ(output_activation_min, 0);
- TFLITE_DCHECK_EQ(output_activation_max, 255);
- }
- BroadcastAdd(left_shift, input1_data, input1_dims, input1_offset,
- input1_multiplier, input1_shift, input2_data, input2_dims,
- input2_offset, input2_multiplier, input2_shift, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_data, output_dims);
-}
-
-template <FusedActivationFunctionType Ac>
-inline void BroadcastAddFivefold(
- int y0, int y1, int y2, int y3, int y4, int left_shift,
- const uint8* input1_data, const Dims<4>& input1_dims, int32 input1_offset,
- int32 input1_multiplier, int input1_shift, const uint8* input2_data,
- const Dims<4>& input2_dims, int32 input2_offset, int32 input2_multiplier,
- int input2_shift, int32 output_offset, int32 output_multiplier,
- int output_shift, int32 output_activation_min, int32 output_activation_max,
- uint8* output_data, const Dims<4>& output_dims) {
- static_assert(Ac == FusedActivationFunctionType::kNone ||
- Ac == FusedActivationFunctionType::kRelu ||
- Ac == FusedActivationFunctionType::kRelu6 ||
- Ac == FusedActivationFunctionType::kRelu1,
- "");
- TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
- if (Ac == FusedActivationFunctionType::kNone) {
- TFLITE_DCHECK_EQ(output_activation_min, 0);
- TFLITE_DCHECK_EQ(output_activation_max, 255);
- }
- BroadcastAddFivefold(y0, y1, y2, y3, y4, left_shift, input1_data, input1_dims,
- input1_offset, input1_multiplier, input1_shift,
- input2_data, input2_dims, input2_offset,
- input2_multiplier, input2_shift, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_data, output_dims);
-}
-
-inline void Mul(const float* input1_data, const Dims<4>& input1_dims,
- const float* input2_data, const Dims<4>& input2_dims,
- float output_activation_min, float output_activation_max,
- float* output_data, const Dims<4>& output_dims) {
+inline void Mul(const ArithmeticParams& params,
+ const RuntimeShape& input1_shape, const float* input1_data,
+ const RuntimeShape& input2_shape, const float* input2_data,
+ const RuntimeShape& output_shape, float* output_data) {
gemmlowp::ScopedProfilingLabel label("Mul");
- TFLITE_DCHECK(IsPackedWithoutStrides(input1_dims));
- TFLITE_DCHECK(IsPackedWithoutStrides(input2_dims));
- TFLITE_DCHECK(IsPackedWithoutStrides(output_dims));
+ const float output_activation_min = params.float_activation_min;
+ const float output_activation_max = params.float_activation_max;
int i = 0;
- const int size = MatchingFlatSize(input1_dims, input2_dims, output_dims);
+ const int size = MatchingFlatSize(input1_shape, input2_shape, output_shape);
#ifdef USE_NEON
const auto activation_min = vdupq_n_f32(output_activation_min);
const auto activation_max = vdupq_n_f32(output_activation_max);
@@ -2907,34 +2595,41 @@ inline void Mul(const float* input1_data, const Dims<4>& input1_dims,
}
}
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void Mul(const float* input1_data, const Dims<4>& input1_dims,
- const float* input2_data, const Dims<4>& input2_dims,
- float* output_data, const Dims<4>& output_dims) {
- float output_activation_min, output_activation_max;
- GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+inline void Mul(const ArithmeticParams& params,
+ const RuntimeShape& input1_shape, const int32* input1_data,
+ const RuntimeShape& input2_shape, const int32* input2_data,
+ const RuntimeShape& output_shape, int32* output_data) {
+ gemmlowp::ScopedProfilingLabel label("Mul/int32/activation");
- Mul(input1_data, input1_dims, input2_data, input2_dims, output_activation_min,
- output_activation_max, output_data, output_dims);
+ const int flat_size =
+ MatchingFlatSize(input1_shape, input2_shape, output_shape);
+ const int32 output_activation_min = params.quantized_activation_min;
+ const int32 output_activation_max = params.quantized_activation_max;
+ for (int i = 0; i < flat_size; ++i) {
+ output_data[i] = ActivationFunctionWithMinMax(
+ input1_data[i] * input2_data[i], output_activation_min,
+ output_activation_max);
+ }
}
-template <FusedActivationFunctionType Ac>
-void Mul(const int32* input1_data, const Dims<4>& input1_dims,
- const int32* input2_data, const Dims<4>& input2_dims,
- int32* output_data, const Dims<4>& output_dims) {
+inline void MulNoActivation(const ArithmeticParams& params,
+ const RuntimeShape& input1_shape,
+ const int32* input1_data,
+ const RuntimeShape& input2_shape,
+ const int32* input2_data,
+ const RuntimeShape& output_shape,
+ int32* output_data) {
gemmlowp::ScopedProfilingLabel label("Mul/int32");
- TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone);
- auto input1_map = MapAsVector(input1_data, input1_dims);
- auto input2_map = MapAsVector(input2_data, input2_dims);
- auto output_map = MapAsVector(output_data, output_dims);
- if (AreSameDims(input1_dims, input2_dims)) {
+ auto input1_map = MapAsVector(input1_data, input1_shape);
+ auto input2_map = MapAsVector(input2_data, input2_shape);
+ auto output_map = MapAsVector(output_data, output_shape);
+ if (input1_shape == input2_shape) {
output_map.array() = input1_map.array() * input2_map.array();
- } else if (FlatSize(input2_dims) == 1) {
+ } else if (input2_shape.FlatSize() == 1) {
auto scalar = input2_data[0];
output_map.array() = input1_map.array() * scalar;
- } else if (FlatSize(input1_dims) == 1) {
+ } else if (input1_shape.FlatSize() == 1) {
auto scalar = input1_data[0];
output_map.array() = scalar * input2_map.array();
} else {
@@ -2943,14 +2638,16 @@ void Mul(const int32* input1_data, const Dims<4>& input1_dims,
}
}
-inline void Mul(const int16* input1_data, const Dims<4>& input1_dims,
- const int16* input2_data, const Dims<4>& input2_dims,
- int16* output_data, const Dims<4>& output_dims) {
- gemmlowp::ScopedProfilingLabel label("Mul/Int16");
+inline void Mul(const ArithmeticParams& params,
+ const RuntimeShape& input1_shape, const int16* input1_data,
+ const RuntimeShape& input2_shape, const int16* input2_data,
+ const RuntimeShape& output_shape, int16* output_data) {
+ gemmlowp::ScopedProfilingLabel label("Mul/Int16/NoActivation");
// This is a copy of the reference implementation. We do not currently have a
// properly optimized version.
- const int flat_size = MatchingFlatSize(output_dims, input1_dims, input2_dims);
+ const int flat_size =
+ MatchingFlatSize(input1_shape, input2_shape, output_shape);
for (int i = 0; i < flat_size; i++) {
// F0 uses 0 integer bits, range [-1, 1].
@@ -2962,17 +2659,20 @@ inline void Mul(const int16* input1_data, const Dims<4>& input1_dims,
}
}
-inline void Mul(const int16* input1_data, const Dims<4>& input1_dims,
- const int16* input2_data, const Dims<4>& input2_dims,
- int32 output_offset, int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims) {
+inline void Mul(const ArithmeticParams& params,
+ const RuntimeShape& input1_shape, const int16* input1_data,
+ const RuntimeShape& input2_shape, const int16* input2_data,
+ const RuntimeShape& output_shape, uint8* output_data) {
gemmlowp::ScopedProfilingLabel label("Mul/Int16Uint8");
// This is a copy of the reference implementation. We do not currently have a
// properly optimized version.
+ const int32 output_activation_min = params.quantized_activation_min;
+ const int32 output_activation_max = params.quantized_activation_max;
+ const int32 output_offset = params.output_offset;
TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
- const int flat_size = MatchingFlatSize(output_dims, input1_dims, input2_dims);
+ const int flat_size =
+ MatchingFlatSize(input1_shape, input2_shape, output_shape);
for (int i = 0; i < flat_size; i++) {
// F0 uses 0 integer bits, range [-1, 1].
@@ -2990,216 +2690,256 @@ inline void Mul(const int16* input1_data, const Dims<4>& input1_dims,
}
}
-// TODO(jiawen): We can implement BroadcastMul on buffers of arbitrary
-// dimensionality if the runtime code does a single loop over one dimension
-// that handles broadcasting as the base case. The code generator would then
-// generate max(D1, D2) nested for loops.
-// TODO(benoitjacob): BroadcastMul is intentionally duplicated from
-// reference_ops.h. Once an optimized version is implemented and NdArrayDesc<T>
-// is no longer referenced in this file, move NdArrayDesc<T> from types.h to
-// reference_ops.h.
-template <typename T>
-void BroadcastMul(const T* input1_data, const Dims<4>& input1_dims,
- const T* input2_data, const Dims<4>& input2_dims,
- T output_activation_min, T output_activation_max,
- T* output_data, const Dims<4>& output_dims) {
- gemmlowp::ScopedProfilingLabel label("BroadcastMul");
+// Element-wise mul that can often be used for the inner loop of broadcast Mul
+// as well as for the non-broadcast Mul.
+inline void MulElementwise(int size, const ArithmeticParams& params,
+ const uint8* input1_data, const uint8* input2_data,
+ uint8* output_data) {
+ int i = 0;
+ TFLITE_DCHECK_GT(params.input1_offset, -256);
+ TFLITE_DCHECK_LT(params.input1_offset, 256);
+ TFLITE_DCHECK_GT(params.input2_offset, -256);
+ TFLITE_DCHECK_LT(params.input2_offset, 256);
+ TFLITE_DCHECK_GT(params.output_offset, -256);
+ TFLITE_DCHECK_LT(params.output_offset, 256);
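+ // The offsets are checked above to lie in (-256, 256), so adding them to
+ // uint8 values always fits in int16; the NEON path below relies on this when
+ // widening the inputs.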
+#ifdef USE_NEON
+ const auto input1_offset_vector = vdupq_n_s16(params.input1_offset);
+ const auto input2_offset_vector = vdupq_n_s16(params.input2_offset);
+ const auto output_offset_vector = vdupq_n_s16(params.output_offset);
+ const auto output_activation_min_vector =
+ vdup_n_u8(params.quantized_activation_min);
+ const auto output_activation_max_vector =
+ vdup_n_u8(params.quantized_activation_max);
+ for (; i <= size - 8; i += 8) {
+ // We load / store 8 at a time, multiplying as two sets of 4 int32s.
+ const auto input1_val_original = vld1_u8(input1_data + i);
+ const auto input2_val_original = vld1_u8(input2_data + i);
+ const auto input1_val_s16 =
+ vreinterpretq_s16_u16(vmovl_u8(input1_val_original));
+ const auto input2_val_s16 =
+ vreinterpretq_s16_u16(vmovl_u8(input2_val_original));
+ const auto input1_val = vaddq_s16(input1_val_s16, input1_offset_vector);
+ const auto input2_val = vaddq_s16(input2_val_s16, input2_offset_vector);
- NdArrayDesc<4> desc1;
- NdArrayDesc<4> desc2;
- NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
+ const auto input1_val_low = vget_low_s16(input1_val);
+ const auto input1_val_high = vget_high_s16(input1_val);
+ const auto input2_val_low = vget_low_s16(input2_val);
+ const auto input2_val_high = vget_high_s16(input2_val);
- // In Tensorflow, the dimensions are canonically named (batch_number, row,
- // col, channel), with extents (batches, height, width, depth), with the
- // trailing dimension changing most rapidly (channels has the smallest stride,
- // typically 1 element).
- //
- // In generated C code, we store arrays with the dimensions reversed. The
- // first dimension has smallest stride.
- //
- // We name our variables by their Tensorflow convention, but generate C code
- // nesting loops such that the innermost loop has the smallest stride for the
- // best cache behavior.
- for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
- for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
- for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
- for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
- output_data[Offset(output_dims, c, x, y, b)] =
- ActivationFunctionWithMinMax(
- input1_data[SubscriptToIndex(desc1, c, x, y, b)] *
- input2_data[SubscriptToIndex(desc2, c, x, y, b)],
- output_activation_min, output_activation_max);
- }
- }
- }
+ auto p1 = vmull_s16(input2_val_low, input1_val_low);
+ auto p2 = vmull_s16(input2_val_high, input1_val_high);
+
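+ // Requantize: a saturating rounding doubling high multiply by
+ // output_multiplier followed by a rounding right shift by -output_shift,
+ // the vector counterpart of MultiplyByQuantizedMultiplierSmallerThanOneExp
+ // used in the scalar tail below.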
+ p1 = vqrdmulhq_n_s32(p1, params.output_multiplier);
+ p2 = vqrdmulhq_n_s32(p2, params.output_multiplier);
+ using gemmlowp::RoundingDivideByPOT;
+ p1 = RoundingDivideByPOT(p1, -params.output_shift);
+ p2 = RoundingDivideByPOT(p2, -params.output_shift);
+
+ const auto p1_narrowed = vmovn_s32(p1);
+ const auto p2_narrowed = vmovn_s32(p2);
+ const auto p =
+ vaddq_s16(vcombine_s16(p1_narrowed, p2_narrowed), output_offset_vector);
+ const auto clamped =
+ vmax_u8(output_activation_min_vector,
+ vmin_u8(output_activation_max_vector, vqmovun_s16(p)));
+ vst1_u8(output_data + i, clamped);
}
-}
+#endif // USE_NEON
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac, typename T>
-void BroadcastMul(const T* input1_data, const Dims<4>& input1_dims,
- const T* input2_data, const Dims<4>& input2_dims,
- T* output_data, const Dims<4>& output_dims) {
- T output_activation_min, output_activation_max;
- GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
-
- BroadcastMul(input1_data, input1_dims, input2_data, input2_dims,
- output_activation_min, output_activation_max, output_data,
- output_dims);
+ for (; i < size; ++i) {
+ const int32 input1_val = params.input1_offset + input1_data[i];
+ const int32 input2_val = params.input2_offset + input2_data[i];
+ const int32 unclamped_result =
+ params.output_offset +
+ MultiplyByQuantizedMultiplierSmallerThanOneExp(input1_val * input2_val,
+ params.output_multiplier,
+ params.output_shift);
+ const int32 clamped_output =
+ std::min(params.quantized_activation_max,
+ std::max(params.quantized_activation_min, unclamped_result));
+ output_data[i] = static_cast<uint8>(clamped_output);
+ }
}
-inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims,
- int32 input1_offset, const uint8* input2_data,
- const Dims<4>& input2_dims, int32 input2_offset,
- int32 output_offset, int32 output_multiplier,
- int output_shift, int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims) {
- gemmlowp::ScopedProfilingLabel label("BroadcastMul/8bit");
+// Broadcast (scalar * vector) mul that can often be used for the inner loop
+// of broadcast Mul.
+inline void MulSimpleBroadcast(int size, const ArithmeticParams& params,
+ const uint8 broadcast_value,
+ const uint8* input2_data, uint8* output_data) {
+ const int16 input1_val = params.input1_offset + broadcast_value;
- NdArrayDesc<4> desc1;
- NdArrayDesc<4> desc2;
- NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
+ int i = 0;
+ TFLITE_DCHECK_GT(params.input1_offset, -256);
+ TFLITE_DCHECK_LT(params.input1_offset, 256);
+ TFLITE_DCHECK_GT(params.input2_offset, -256);
+ TFLITE_DCHECK_LT(params.input2_offset, 256);
+ TFLITE_DCHECK_GT(params.output_offset, -256);
+ TFLITE_DCHECK_LT(params.output_offset, 256);
+#ifdef USE_NEON
+ const auto input2_offset_vector = vdupq_n_s16(params.input2_offset);
+ const auto output_offset_vector = vdupq_n_s16(params.output_offset);
+ const auto output_activation_min_vector =
+ vdup_n_u8(params.quantized_activation_min);
+ const auto output_activation_max_vector =
+ vdup_n_u8(params.quantized_activation_max);
+ for (; i <= size - 8; i += 8) {
+ // We load / store 8 at a time, multiplying as two sets of 4 int32s.
+ const auto input2_val_original = vld1_u8(input2_data + i);
+ const auto input2_val_s16 =
+ vreinterpretq_s16_u16(vmovl_u8(input2_val_original));
+ const auto input2_val = vaddq_s16(input2_val_s16, input2_offset_vector);
- // In Tensorflow, the dimensions are canonically named (batch_number, row,
- // col, channel), with extents (batches, height, width, depth), with the
- // trailing dimension changing most rapidly (channels has the smallest stride,
- // typically 1 element).
- //
- // In generated C code, we store arrays with the dimensions reversed. The
- // first dimension has smallest stride.
- //
- // We name our variables by their Tensorflow convention, but generate C code
- // nesting loops such that the innermost loop has the smallest stride for the
- // best cache behavior.
- for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
- for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
- for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
- for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
- const int32 input1_val =
- input1_offset + input1_data[SubscriptToIndex(desc1, c, x, y, b)];
- const int32 input2_val =
- input2_offset + input2_data[SubscriptToIndex(desc2, c, x, y, b)];
- const int32 unclamped_result =
- output_offset +
- MultiplyByQuantizedMultiplierSmallerThanOne(
- input1_val * input2_val, output_multiplier, output_shift);
- const int32 clamped_output =
- std::min(output_activation_max,
- std::max(output_activation_min, unclamped_result));
- output_data[Offset(output_dims, c, x, y, b)] =
- static_cast<uint8>(clamped_output);
- }
- }
- }
+ const auto input2_val_low = vget_low_s16(input2_val);
+ const auto input2_val_high = vget_high_s16(input2_val);
+
+ auto p1 = vmull_n_s16(input2_val_low, input1_val);
+ auto p2 = vmull_n_s16(input2_val_high, input1_val);
+
+ p1 = vqrdmulhq_n_s32(p1, params.output_multiplier);
+ p2 = vqrdmulhq_n_s32(p2, params.output_multiplier);
+ using gemmlowp::RoundingDivideByPOT;
+ p1 = RoundingDivideByPOT(p1, -params.output_shift);
+ p2 = RoundingDivideByPOT(p2, -params.output_shift);
+
+ const auto p1_narrowed = vmovn_s32(p1);
+ const auto p2_narrowed = vmovn_s32(p2);
+ const auto p =
+ vaddq_s16(vcombine_s16(p1_narrowed, p2_narrowed), output_offset_vector);
+ const auto clamped =
+ vmax_u8(output_activation_min_vector,
+ vmin_u8(output_activation_max_vector, vqmovun_s16(p)));
+ vst1_u8(output_data + i, clamped);
}
-}
+#endif // USE_NEON
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims,
- int32 input1_offset, const uint8* input2_data,
- const Dims<4>& input2_dims, int32 input2_offset,
- int32 output_offset, int32 output_multiplier,
- int output_shift, int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims) {
- BroadcastMul(input1_data, input1_dims, input1_offset, input2_data,
- input2_dims, input2_offset, output_offset, output_multiplier,
- output_shift, output_activation_min, output_activation_max,
- output_data, output_dims);
+ for (; i < size; ++i) {
+ const int32 input2_val = params.input2_offset + input2_data[i];
+ const int32 unclamped_result =
+ params.output_offset +
+ MultiplyByQuantizedMultiplierSmallerThanOneExp(input1_val * input2_val,
+ params.output_multiplier,
+ params.output_shift);
+ const int32 clamped_output =
+ std::min(params.quantized_activation_max,
+ std::max(params.quantized_activation_min, unclamped_result));
+ output_data[i] = static_cast<uint8>(clamped_output);
+ }
}
-// TODO(aselle): This is not actually optimized yet.
-inline void Div(const float* input1_data, const Dims<4>& input1_dims,
- const float* input2_data, const Dims<4>& input2_dims,
- float output_activation_min, float output_activation_max,
- float* output_data, const Dims<4>& output_dims) {
- const int flat_size = MatchingFlatSize(output_dims, input1_dims, input2_dims);
- for (int i = 0; i < flat_size; i++) {
- output_data[i] = ActivationFunctionWithMinMax(
- input1_data[i] / input2_data[i], output_activation_min,
- output_activation_max);
- }
+inline void Mul(const ArithmeticParams& params,
+ const RuntimeShape& input1_shape, const uint8* input1_data,
+ const RuntimeShape& input2_shape, const uint8* input2_data,
+ const RuntimeShape& output_shape, uint8* output_data) {
+ TFLITE_DCHECK_LE(params.quantized_activation_min,
+ params.quantized_activation_max);
+ gemmlowp::ScopedProfilingLabel label("Mul/8bit");
+ const int flat_size =
+ MatchingFlatSize(input1_shape, input2_shape, output_shape);
+
+ MulElementwise(flat_size, params, input1_data, input2_data, output_data);
}
-// TODO(jiawen): We can implement BroadcastDiv on buffers of arbitrary
-// dimensionality if the runtime code does a single loop over one dimension
-// that handles broadcasting as the base case. The code generator would then
-// generate max(D1, D2) nested for loops.
-// TODO(benoitjacob): BroadcastDiv is intentionally duplicated from
-// reference_ops.h. Once an optimized version is implemented and NdArrayDesc<T>
-// is no longer referenced in this file, move NdArrayDesc<T> from types.h to
-// reference_ops.h.
-template <typename T>
-void BroadcastDiv(const T* input1_data, const Dims<4>& input1_dims,
- const T* input2_data, const Dims<4>& input2_dims,
- T output_activation_min, T output_activation_max,
- T* output_data, const Dims<4>& output_dims) {
- gemmlowp::ScopedProfilingLabel label("BroadcastDiv");
+inline void BroadcastMulFivefold(const ArithmeticParams& unswitched_params,
+ const RuntimeShape& unswitched_input1_shape,
+ const uint8* unswitched_input1_data,
+ const RuntimeShape& unswitched_input2_shape,
+ const uint8* unswitched_input2_data,
+ const RuntimeShape& output_shape,
+ uint8* output_data) {
+ gemmlowp::ScopedProfilingLabel label("BroadcastMulFivefold/8bit");
- NdArrayDesc<4> desc1;
- NdArrayDesc<4> desc2;
- NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
+ ArithmeticParams switched_params = unswitched_params;
+ switched_params.input1_offset = unswitched_params.input2_offset;
+ switched_params.input2_offset = unswitched_params.input1_offset;
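+ // Multiplication is commutative, so when the first input is not the one
+ // that broadcasts fast, the operands (and their offsets) are swapped so that
+ // the loops below only have to handle a single layout.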
- // In Tensorflow, the dimensions are canonically named (batch_number, row,
- // col, channel), with extents (batches, height, width, depth), with the
- // trailing dimension changing most rapidly (channels has the smallest stride,
- // typically 1 element).
- //
- // In generated C code, we store arrays with the dimensions reversed. The
- // first dimension has smallest stride.
- //
- // We name our variables by their Tensorflow convention, but generate C code
- // nesting loops such that the innermost loop has the smallest stride for the
- // best cache behavior.
- for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
- for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
- for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
- for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
- output_data[Offset(output_dims, c, x, y, b)] =
- ActivationFunctionWithMinMax(
- input1_data[SubscriptToIndex(desc1, c, x, y, b)] /
- input2_data[SubscriptToIndex(desc2, c, x, y, b)],
- output_activation_min, output_activation_max);
+ const bool use_unswitched =
+ unswitched_params.broadcast_category ==
+ tflite::BroadcastableOpCategory::kFirstInputBroadcastsFast;
+
+ const ArithmeticParams& params =
+ use_unswitched ? unswitched_params : switched_params;
+ const uint8* input1_data =
+ use_unswitched ? unswitched_input1_data : unswitched_input2_data;
+ const uint8* input2_data =
+ use_unswitched ? unswitched_input2_data : unswitched_input1_data;
+
+ // Fivefold nested loops. The second input resets its position for each
+ // iteration of the second loop. The first input resets its position at the
+ // beginning of the fourth loop. The innermost loop is an elementwise Mul of
+ // sections of the arrays.
+ uint8* output_data_ptr = output_data;
+ const uint8* input1_data_ptr = input1_data;
+ const uint8* input2_data_reset = input2_data;
+ int y0 = params.broadcast_shape[0];
+ int y1 = params.broadcast_shape[1];
+ int y2 = params.broadcast_shape[2];
+ int y3 = params.broadcast_shape[3];
+ int y4 = params.broadcast_shape[4];
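+ // When y4 > 1, the innermost runs of y4 elements are multiplied elementwise;
+ // when y4 == 1, each input1 value is broadcast across a run of y3 input2
+ // values via MulSimpleBroadcast.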
+ if (y4 > 1) {
+ for (int i0 = 0; i0 < y0; ++i0) {
+ const uint8* input2_data_ptr;
+ for (int i1 = 0; i1 < y1; ++i1) {
+ input2_data_ptr = input2_data_reset;
+ for (int i2 = 0; i2 < y2; ++i2) {
+ for (int i3 = 0; i3 < y3; ++i3) {
+ MulElementwise(y4, params, input1_data_ptr, input2_data_ptr,
+ output_data_ptr);
+ input2_data_ptr += y4;
+ output_data_ptr += y4;
+ }
+ input1_data_ptr += y4;
}
}
+ input2_data_reset = input2_data_ptr;
+ }
+ } else {
+ for (int i0 = 0; i0 < y0; ++i0) {
+ const uint8* input2_data_ptr;
+ for (int i1 = 0; i1 < y1; ++i1) {
+ input2_data_ptr = input2_data_reset;
+ for (int i2 = 0; i2 < y2; ++i2) {
+ MulSimpleBroadcast(y3, params, *input1_data_ptr, input2_data_ptr,
+ output_data_ptr);
+ input2_data_ptr += y3;
+ output_data_ptr += y3;
+ ++input1_data_ptr;
+ }
+ }
+ input2_data_reset = input2_data_ptr;
}
}
}
-// TODO(aselle): This is not actually optimized yet.
-inline void Sub(const float* input1_data, const Dims<4>& input1_dims,
- const float* input2_data, const Dims<4>& input2_dims,
- float output_activation_min, float output_activation_max,
- float* output_data, const Dims<4>& output_dims) {
- gemmlowp::ScopedProfilingLabel label("Sub");
- const int flat_size = MatchingFlatSize(input1_dims, input2_dims, output_dims);
- for (int i = 0; i < flat_size; ++i) {
- output_data[i] = ActivationFunctionWithMinMax(
- input1_data[i] - input2_data[i], output_activation_min,
- output_activation_max);
- }
-}
-
-// TODO(jiawen): We can implement BroadcastSub on buffers of arbitrary
+// TODO(jiawen): We can implement BroadcastDiv on buffers of arbitrary
// dimensionality if the runtime code does a single loop over one dimension
// that handles broadcasting as the base case. The code generator would then
// generate max(D1, D2) nested for loops.
-// TODO(benoitjacob): BroadcastSub is intentionally duplicated from
+// TODO(benoitjacob): BroadcastDiv is intentionally duplicated from
// reference_ops.h. Once an optimized version is implemented and NdArrayDesc<T>
// is no longer referenced in this file, move NdArrayDesc<T> from types.h to
// reference_ops.h.
template <typename T>
-void BroadcastSub(const T* input1_data, const Dims<4>& input1_dims,
- const T* input2_data, const Dims<4>& input2_dims,
- T output_activation_min, T output_activation_max,
- T* output_data, const Dims<4>& output_dims) {
- gemmlowp::ScopedProfilingLabel label("BroadcastSub");
+void BroadcastDiv4DSlow(const ArithmeticParams& params,
+ const RuntimeShape& unextended_input1_shape,
+ const T* input1_data,
+ const RuntimeShape& unextended_input2_shape,
+ const T* input2_data,
+ const RuntimeShape& unextended_output_shape,
+ T* output_data) {
+ gemmlowp::ScopedProfilingLabel label("BroadcastDiv4DSlow");
+ T output_activation_min;
+ T output_activation_max;
+ GetActivationParams(params, &output_activation_min, &output_activation_max);
+
+ TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+ const RuntimeShape output_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_shape);
NdArrayDesc<4> desc1;
NdArrayDesc<4> desc2;
- NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
+ NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
+ unextended_input2_shape, &desc1, &desc2);
// In Tensorflow, the dimensions are canonically named (batch_number, row,
// col, channel), with extents (batches, height, width, depth), with the
@@ -3212,14 +2952,14 @@ void BroadcastSub(const T* input1_data, const Dims<4>& input1_dims,
// We name our variables by their Tensorflow convention, but generate C code
// nesting loops such that the innermost loop has the smallest stride for the
// best cache behavior.
- for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
- for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
- for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
- for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
- output_data[Offset(output_dims, c, x, y, b)] =
+ for (int b = 0; b < output_shape.Dims(0); ++b) {
+ for (int y = 0; y < output_shape.Dims(1); ++y) {
+ for (int x = 0; x < output_shape.Dims(2); ++x) {
+ for (int c = 0; c < output_shape.Dims(3); ++c) {
+ output_data[Offset(output_shape, b, y, x, c)] =
ActivationFunctionWithMinMax(
- input1_data[SubscriptToIndex(desc1, c, x, y, b)] -
- input2_data[SubscriptToIndex(desc2, c, x, y, b)],
+ input1_data[SubscriptToIndex(desc1, b, y, x, c)] /
+ input2_data[SubscriptToIndex(desc2, b, y, x, c)],
output_activation_min, output_activation_max);
}
}
@@ -3227,220 +2967,172 @@ void BroadcastSub(const T* input1_data, const Dims<4>& input1_dims,
}
}
-inline void BroadcastSub(int left_shift, const uint8* input1_data,
- const Dims<4>& input1_dims, int32 input1_offset,
- int32 input1_multiplier, int input1_shift,
- const uint8* input2_data, const Dims<4>& input2_dims,
- int32 input2_offset, int32 input2_multiplier,
- int input2_shift, int32 output_offset,
- int32 output_multiplier, int output_shift,
- int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims) {
- gemmlowp::ScopedProfilingLabel label("BroadcastSub/8bit");
-
- NdArrayDesc<4> desc1;
- NdArrayDesc<4> desc2;
- NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
-
- // In Tensorflow, the dimensions are canonically named (batch_number, row,
- // col, channel), with extents (batches, height, width, depth), with the
- // trailing dimension changing most rapidly (channels has the smallest stride,
- // typically 1 element).
- //
- // In generated C code, we store arrays with the dimensions reversed. The
- // first dimension has smallest stride.
- //
- // We name our variables by their Tensorflow convention, but generate C code
- // nesting loops such that the innermost loop has the smallest stride for the
- // best cache behavior.
- for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
- for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
- for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
- for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
- const int32 input1_val =
- input1_offset + input1_data[SubscriptToIndex(desc1, c, x, y, b)];
- const int32 input2_val =
- input2_offset + input2_data[SubscriptToIndex(desc2, c, x, y, b)];
- const int32 shifted_input1_val = input1_val * (1 << left_shift);
- const int32 shifted_input2_val = input2_val * (1 << left_shift);
- const int32 scaled_input1_val =
- MultiplyByQuantizedMultiplierSmallerThanOne(
- shifted_input1_val, input1_multiplier, input1_shift);
- const int32 scaled_input2_val =
- MultiplyByQuantizedMultiplierSmallerThanOne(
- shifted_input2_val, input2_multiplier, input2_shift);
- const int32 raw_sub = scaled_input1_val - scaled_input2_val;
- const int32 raw_output =
- MultiplyByQuantizedMultiplierSmallerThanOne(
- raw_sub, output_multiplier, output_shift) +
- output_offset;
- const int32 clamped_output =
- std::min(output_activation_max,
- std::max(output_activation_min, raw_output));
- output_data[Offset(output_dims, c, x, y, b)] =
- static_cast<uint8>(clamped_output);
- }
- }
- }
+// TODO(aselle): This is not actually optimized yet.
+inline void SubNonBroadcast(const ArithmeticParams& params,
+ const RuntimeShape& input1_shape,
+ const float* input1_data,
+ const RuntimeShape& input2_shape,
+ const float* input2_data,
+ const RuntimeShape& output_shape,
+ float* output_data) {
+ gemmlowp::ScopedProfilingLabel label("SubNonBroadcast");
+ const int flat_size =
+ MatchingFlatSize(input1_shape, input2_shape, output_shape);
+ for (int i = 0; i < flat_size; ++i) {
+ output_data[i] = ActivationFunctionWithMinMax(
+ input1_data[i] - input2_data[i], params.float_activation_min,
+ params.float_activation_max);
}
}
-template <FusedActivationFunctionType Ac, typename Scalar>
-void Concatenation(int concat_dim, const Scalar* const* input_data,
- const Dims<4>* const* input_dims, int inputs_count,
- Scalar* output_data, const Dims<4>& output_dims) {
- gemmlowp::ScopedProfilingLabel label("Concatenation");
- int concat_size = 0;
- for (int i = 0; i < inputs_count; i++) {
- for (int j = 0; j < 4; j++) {
- if (j != concat_dim) {
- MatchingArraySize(*input_dims[i], j, output_dims, j);
- }
- }
- concat_size += ArraySize(*input_dims[i], concat_dim);
- }
- TFLITE_DCHECK_EQ(concat_size, ArraySize(output_dims, concat_dim));
- TFLITE_DCHECK(IsPackedWithoutStrides(output_dims));
-  // For now we don't have a model with a Concatenation
- // with fused activation function.
- TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone);
- int outer_size = 1;
- for (int i = concat_dim + 1; i < 4; i++) {
- outer_size *= output_dims.sizes[i];
- }
- Scalar* output_ptr = output_data;
- for (int k = 0; k < outer_size; k++) {
- for (int i = 0; i < inputs_count; ++i) {
- const int copy_size =
- input_dims[i]->sizes[concat_dim] * input_dims[i]->strides[concat_dim];
- memcpy(output_ptr, input_data[i] + k * copy_size,
- copy_size * sizeof(Scalar));
- output_ptr += copy_size;
- }
+inline void SubWithActivation(const ArithmeticParams& params,
+ const RuntimeShape& input1_shape,
+ const int32* input1_data,
+ const RuntimeShape& input2_shape,
+ const int32* input2_data,
+ const RuntimeShape& output_shape,
+ int32* output_data) {
+ gemmlowp::ScopedProfilingLabel label("SubWithActivation/int32");
+ const int flat_size =
+ MatchingFlatSize(input1_shape, input2_shape, output_shape);
+ for (int i = 0; i < flat_size; ++i) {
+ output_data[i] = ActivationFunctionWithMinMax(
+ input1_data[i] - input2_data[i], params.quantized_activation_min,
+ params.quantized_activation_max);
}
}
-// TODO(prabhumk): This is the same as the reference implementation.
-// TODO(prabhumk): The quantized implementation of concatenation isn't fully
-// quantized as it takes scale as a floating point value. This should be fixed
-// when optimizing this routine further.
-inline void Concatenation(int concat_dim, const uint8* const* input_data,
- const Dims<4>* const* input_dims,
- const int32* input_zeropoint,
- const float* input_scale, int inputs_count,
- uint8* output_data, const Dims<4>& output_dims,
- const int32 output_zeropoint,
- const float output_scale) {
- // The arguments input_zeropoint and input_scale are expected to be an array
- // that have the quantization parameters for all the inputs to the concat
- // operator.
- gemmlowp::ScopedProfilingLabel label("Concatenation");
- TFLITE_DCHECK_GT(inputs_count, 1);
- int concat_size = 0;
- for (int i = 0; i < inputs_count; i++) {
- for (int j = 0; j < 4; j++) {
- if (j != concat_dim) {
- MatchingArraySize(*input_dims[i], j, output_dims, j);
- }
- }
- concat_size += ArraySize(*input_dims[i], concat_dim);
- }
- TFLITE_DCHECK_EQ(concat_size, ArraySize(output_dims, concat_dim));
- int outer_size = 1;
- for (int i = concat_dim + 1; i < 4; i++) {
- outer_size *= output_dims.sizes[i];
- }
- const float inverse_output_scale = 1.f / output_scale;
- uint8* output_ptr = output_data;
- for (int k = 0; k < outer_size; k++) {
- for (int i = 0; i < inputs_count; ++i) {
- const int copy_size =
- input_dims[i]->sizes[concat_dim] * input_dims[i]->strides[concat_dim];
- const uint8* input_ptr = input_data[i] + k * copy_size;
- if (input_zeropoint[i] == output_zeropoint &&
- input_scale[i] == output_scale) {
- memcpy(output_ptr, input_ptr, copy_size);
- } else {
- const float scale = input_scale[i] * inverse_output_scale;
- const float bias = -input_zeropoint[i] * scale;
- for (int j = 0; j < copy_size; ++j) {
- const int32_t value =
- static_cast<int32_t>(round(input_ptr[j] * scale + bias)) +
- output_zeropoint;
- output_ptr[j] =
- static_cast<uint8_t>(std::max(std::min(255, value), 0));
- }
- }
- output_ptr += copy_size;
- }
+inline void SubWithActivation(const ArithmeticParams& params,
+ const RuntimeShape& input1_shape,
+ const float* input1_data,
+ const RuntimeShape& input2_shape,
+ const float* input2_data,
+ const RuntimeShape& output_shape,
+ float* output_data) {
+ gemmlowp::ScopedProfilingLabel label("SubWithActivation/float");
+ const int flat_size =
+ MatchingFlatSize(input1_shape, input2_shape, output_shape);
+ for (int i = 0; i < flat_size; ++i) {
+ output_data[i] = ActivationFunctionWithMinMax(
+ input1_data[i] - input2_data[i], params.float_activation_min,
+ params.float_activation_max);
}
}
-template <FusedActivationFunctionType Ac, typename Scalar>
-void DepthConcatenation(const Scalar* const* input_data,
- const Dims<4>* const* input_dims, int inputs_count,
- Scalar* output_data, const Dims<4>& output_dims) {
- Concatenation<Ac, Scalar>(0, input_data, input_dims, inputs_count,
- output_data, output_dims);
-}
+template <typename T>
+void Sub(const ArithmeticParams& params, const RuntimeShape& input1_shape,
+ const T* input1_data, const RuntimeShape& input2_shape,
+ const T* input2_data, const RuntimeShape& output_shape,
+ T* output_data) {
+ gemmlowp::ScopedProfilingLabel label("Sub");
-inline void LstmCell(const float* input_data, const Dims<4>& input_dims,
- const float* prev_activ_data,
- const Dims<4>& prev_activ_dims, const float* weights_data,
- const Dims<4>& weights_dims, const float* bias_data,
- const Dims<4>& bias_dims, const float* prev_state_data,
- const Dims<4>& prev_state_dims, float* output_state_data,
- const Dims<4>& output_state_dims, float* output_activ_data,
- const Dims<4>& output_activ_dims, float* concat_temp_data,
- const Dims<4>& concat_temp_dims, float* activ_temp_data,
- const Dims<4>& activ_temp_dims) {
+ auto input1_map = MapAsVector(input1_data, input1_shape);
+ auto input2_map = MapAsVector(input2_data, input2_shape);
+ auto output_map = MapAsVector(output_data, output_shape);
+ if (input1_shape == input2_shape) {
+ output_map.array() = input1_map.array() - input2_map.array();
+ } else if (input1_shape.FlatSize() == 1) {
+ auto scalar = input1_data[0];
+ output_map.array() = scalar - input2_map.array();
+ } else if (input2_shape.FlatSize() == 1) {
+ auto scalar = input2_data[0];
+ output_map.array() = input1_map.array() - scalar;
+ } else {
+ BroadcastSub4DSlow(params, input1_shape, input1_data, input2_shape,
+ input2_data, output_shape, output_data);
+ }
+}
+
+inline void LstmCell(
+ const LstmCellParams& params, const RuntimeShape& unextended_input_shape,
+ const float* input_data, const RuntimeShape& unextended_prev_activ_shape,
+ const float* prev_activ_data, const RuntimeShape& weights_shape,
+ const float* weights_data, const RuntimeShape& unextended_bias_shape,
+ const float* bias_data, const RuntimeShape& unextended_prev_state_shape,
+ const float* prev_state_data,
+ const RuntimeShape& unextended_output_state_shape, float* output_state_data,
+ const RuntimeShape& unextended_output_activ_shape, float* output_activ_data,
+ const RuntimeShape& unextended_concat_temp_shape, float* concat_temp_data,
+ const RuntimeShape& unextended_activ_temp_shape, float* activ_temp_data) {
gemmlowp::ScopedProfilingLabel label("LstmCell");
- MatchingArraySize( // batches
- input_dims, 3, prev_activ_dims, 3, prev_state_dims, 3, output_state_dims,
- 3, output_activ_dims, 3);
- MatchingArraySize( // height
- input_dims, 2, prev_activ_dims, 2, prev_state_dims, 2, output_state_dims,
- 2, output_activ_dims, 2);
- MatchingArraySize( // width
- input_dims, 1, prev_activ_dims, 1, prev_state_dims, 1, output_state_dims,
- 1, output_activ_dims, 1);
- TFLITE_CHECK_EQ(ArraySize(weights_dims, 2), 1);
- TFLITE_CHECK_EQ(ArraySize(weights_dims, 3), 1);
- const int input_depth = ArraySize(input_dims, 0);
- const int prev_activ_depth = ArraySize(prev_activ_dims, 0);
+ TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_prev_activ_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_bias_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_prev_state_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_state_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_activ_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_concat_temp_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_activ_temp_shape.DimensionsCount(), 4);
+ const RuntimeShape input_shape =
+ RuntimeShape::ExtendedShape(4, unextended_input_shape);
+ const RuntimeShape prev_activ_shape =
+ RuntimeShape::ExtendedShape(4, unextended_prev_activ_shape);
+ const RuntimeShape bias_shape =
+ RuntimeShape::ExtendedShape(4, unextended_bias_shape);
+ const RuntimeShape prev_state_shape =
+ RuntimeShape::ExtendedShape(4, unextended_prev_state_shape);
+ const RuntimeShape output_state_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_state_shape);
+ const RuntimeShape output_activ_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_activ_shape);
+ const RuntimeShape concat_temp_shape =
+ RuntimeShape::ExtendedShape(4, unextended_concat_temp_shape);
+ const RuntimeShape activ_temp_shape =
+ RuntimeShape::ExtendedShape(4, unextended_activ_temp_shape);
+ TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);
+
+ const int weights_dim_count = weights_shape.DimensionsCount();
+ MatchingDim( // batches
+ input_shape, 0, prev_activ_shape, 0, prev_state_shape, 0,
+ output_state_shape, 0, output_activ_shape, 0);
+ MatchingDim( // height
+ input_shape, 1, prev_activ_shape, 1, prev_state_shape, 1,
+ output_state_shape, 1, output_activ_shape, 1);
+ MatchingDim( // width
+ input_shape, 2, prev_activ_shape, 2, prev_state_shape, 2,
+ output_state_shape, 2, output_activ_shape, 2);
+ const int input_depth = input_shape.Dims(3);
+ const int prev_activ_depth = prev_activ_shape.Dims(3);
const int total_input_depth = prev_activ_depth + input_depth;
- TFLITE_CHECK_EQ(ArraySize(weights_dims, 0), total_input_depth);
- TFLITE_CHECK_EQ(MatchingArraySize(bias_dims, 1, bias_dims, 2, bias_dims, 3),
- 1);
+ TFLITE_DCHECK_EQ(weights_shape.Dims(weights_dim_count - 1),
+ total_input_depth);
+ TFLITE_DCHECK_EQ(FlatSizeSkipDim(bias_shape, 3), 1);
const int intern_activ_depth =
- MatchingArraySize(weights_dims, 1, bias_dims, 0);
- TFLITE_CHECK_EQ(intern_activ_depth % 4, 0);
+ MatchingDim(weights_shape, weights_dim_count - 2, bias_shape, 3);
+ TFLITE_DCHECK_EQ(weights_shape.FlatSize(),
+ intern_activ_depth * total_input_depth);
+ TFLITE_DCHECK_EQ(intern_activ_depth % 4, 0);
const int output_depth =
- MatchingArraySize(prev_state_dims, 0, prev_activ_dims, 0,
- output_state_dims, 0, output_activ_dims, 0);
- TFLITE_CHECK_EQ(output_depth, intern_activ_depth / 4);
+ MatchingDim(prev_state_shape, 3, prev_activ_shape, 3, output_state_shape,
+ 3, output_activ_shape, 3);
+ TFLITE_DCHECK_EQ(output_depth, intern_activ_depth / 4);
// Concatenate prev_activ and input data together
std::vector<float const*> concat_input_arrays_data;
- std::vector<Dims<4> const*> concat_input_arrays_dims;
+ std::vector<RuntimeShape const*> concat_input_arrays_shapes;
concat_input_arrays_data.push_back(input_data);
concat_input_arrays_data.push_back(prev_activ_data);
- concat_input_arrays_dims.push_back(&input_dims);
- concat_input_arrays_dims.push_back(&prev_activ_dims);
- Concatenation<FusedActivationFunctionType::kNone, float>(
- 0, &(concat_input_arrays_data[0]), &(concat_input_arrays_dims[0]),
- concat_input_arrays_data.size(), concat_temp_data, concat_temp_dims);
+ concat_input_arrays_shapes.push_back(&input_shape);
+ concat_input_arrays_shapes.push_back(&prev_activ_shape);
+ tflite::ConcatenationParams concat_params;
+ concat_params.axis = 3;
+ concat_params.inputs_count = concat_input_arrays_data.size();
+ Concatenation(concat_params, &(concat_input_arrays_shapes[0]),
+ &(concat_input_arrays_data[0]), concat_temp_shape,
+ concat_temp_data);
// Fully connected
- FullyConnected<FusedActivationFunctionType::kNone>(
- concat_temp_data, concat_temp_dims, weights_data, weights_dims, bias_data,
- bias_dims, activ_temp_data, activ_temp_dims);
+ tflite::FullyConnectedParams fc_params;
+ fc_params.float_activation_min = std::numeric_limits<float>::lowest();
+ fc_params.float_activation_max = std::numeric_limits<float>::max();
+ FullyConnected(fc_params, concat_temp_shape, concat_temp_data, weights_shape,
+ weights_data, bias_shape, bias_data, activ_temp_shape,
+ activ_temp_data);
// Map raw arrays to Eigen arrays so we can use Eigen's optimized array
// operations.
ArrayMap<float> activ_temp_map =
- MapAsArrayWithFirstDimAsRows(activ_temp_data, activ_temp_dims);
+ MapAsArrayWithLastDimAsRows(activ_temp_data, activ_temp_shape);
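+ // activ_temp stacks 4 * output_depth rows: the input gate, new input,
+ // forget gate and output gate blocks extracted below.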
auto input_gate_sm = activ_temp_map.block(0 * output_depth, 0, output_depth,
activ_temp_map.cols());
auto new_input_sm = activ_temp_map.block(1 * output_depth, 0, output_depth,
@@ -3450,11 +3142,11 @@ inline void LstmCell(const float* input_data, const Dims<4>& input_dims,
auto output_gate_sm = activ_temp_map.block(3 * output_depth, 0, output_depth,
activ_temp_map.cols());
ArrayMap<const float> prev_state_map =
- MapAsArrayWithFirstDimAsRows(prev_state_data, prev_state_dims);
+ MapAsArrayWithLastDimAsRows(prev_state_data, prev_state_shape);
ArrayMap<float> output_state_map =
- MapAsArrayWithFirstDimAsRows(output_state_data, output_state_dims);
+ MapAsArrayWithLastDimAsRows(output_state_data, output_state_shape);
ArrayMap<float> output_activ_map =
- MapAsArrayWithFirstDimAsRows(output_activ_data, output_activ_dims);
+ MapAsArrayWithLastDimAsRows(output_activ_data, output_activ_shape);
// Combined memory state and final output calculation
gemmlowp::ScopedProfilingLabel label2("MemoryStateAndFinalOutput");
@@ -3472,52 +3164,91 @@ inline void LstmCell(const float* input_data, const Dims<4>& input_dims,
// reference_ops.h. See the big function comment there, not replicating it
// here.
template <int StateIntegerBits>
-void LstmCell(const uint8* input_data_uint8, const Dims<4>& input_dims,
- const uint8* prev_activ_data_uint8,
- const Dims<4>& prev_activ_dims, const uint8* weights_data_uint8,
- const Dims<4>& weights_dims, const int32* bias_data_int32,
- const Dims<4>& bias_dims, const int16* prev_state_data_int16,
- const Dims<4>& prev_state_dims, int16* output_state_data_int16,
- const Dims<4>& output_state_dims, uint8* output_activ_data_uint8,
- const Dims<4>& output_activ_dims, uint8* concat_temp_data_uint8,
- const Dims<4>& concat_temp_dims, int16* activ_temp_data_int16,
- const Dims<4>& activ_temp_dims, int32 weights_zero_point,
- int32 accum_multiplier, int accum_shift,
- gemmlowp::GemmContext* gemm_context) {
+inline void LstmCell(
+ const LstmCellParams& params, const RuntimeShape& unextended_input_shape,
+ const uint8* input_data_uint8,
+ const RuntimeShape& unextended_prev_activ_shape,
+ const uint8* prev_activ_data_uint8, const RuntimeShape& weights_shape,
+ const uint8* weights_data_uint8, const RuntimeShape& unextended_bias_shape,
+ const int32* bias_data_int32,
+ const RuntimeShape& unextended_prev_state_shape,
+ const int16* prev_state_data_int16,
+ const RuntimeShape& unextended_output_state_shape,
+ int16* output_state_data_int16,
+ const RuntimeShape& unextended_output_activ_shape,
+ uint8* output_activ_data_uint8,
+ const RuntimeShape& unextended_concat_temp_shape,
+ uint8* concat_temp_data_uint8,
+ const RuntimeShape& unextended_activ_temp_shape,
+ int16* activ_temp_data_int16, gemmlowp::GemmContext* gemm_context) {
gemmlowp::ScopedProfilingLabel label(
"LstmCell/quantized (8bit external, 16bit internal)");
+ int32 weights_zero_point = params.weights_zero_point;
+ int32 accum_multiplier = params.accum_multiplier;
+ int accum_shift = params.accum_shift;
+ TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_prev_activ_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_bias_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_prev_state_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_state_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_activ_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_concat_temp_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_activ_temp_shape.DimensionsCount(), 4);
+ const RuntimeShape input_shape =
+ RuntimeShape::ExtendedShape(4, unextended_input_shape);
+ const RuntimeShape prev_activ_shape =
+ RuntimeShape::ExtendedShape(4, unextended_prev_activ_shape);
+ const RuntimeShape bias_shape =
+ RuntimeShape::ExtendedShape(4, unextended_bias_shape);
+ const RuntimeShape prev_state_shape =
+ RuntimeShape::ExtendedShape(4, unextended_prev_state_shape);
+ const RuntimeShape output_state_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_state_shape);
+ const RuntimeShape output_activ_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_activ_shape);
+ const RuntimeShape concat_temp_shape =
+ RuntimeShape::ExtendedShape(4, unextended_concat_temp_shape);
+ const RuntimeShape activ_temp_shape =
+ RuntimeShape::ExtendedShape(4, unextended_activ_temp_shape);
+ TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);
+
// Gather dimensions information, and perform consistency checks.
- const int outer_size =
- MatchingFlatSizeSkipDim(input_dims, 0, prev_activ_dims, prev_state_dims,
- output_state_dims, output_activ_dims);
- TFLITE_CHECK_EQ(ArraySize(weights_dims, 2), 1);
- TFLITE_CHECK_EQ(ArraySize(weights_dims, 3), 1);
- const int input_depth = ArraySize(input_dims, 0);
- const int prev_activ_depth = ArraySize(prev_activ_dims, 0);
+ const int weights_dim_count = weights_shape.DimensionsCount();
+ const int outer_size = MatchingFlatSizeSkipDim(
+ input_shape, 3, prev_activ_shape, prev_state_shape, output_state_shape,
+ output_activ_shape);
+ const int input_depth = input_shape.Dims(3);
+ const int prev_activ_depth = prev_activ_shape.Dims(3);
const int total_input_depth = prev_activ_depth + input_depth;
- TFLITE_CHECK_EQ(ArraySize(weights_dims, 0), total_input_depth);
- TFLITE_CHECK_EQ(MatchingArraySize(bias_dims, 1, bias_dims, 2, bias_dims, 3),
- 1);
+ TFLITE_DCHECK_EQ(weights_shape.Dims(weights_dim_count - 1),
+ total_input_depth);
const int intern_activ_depth =
- MatchingArraySize(weights_dims, 1, bias_dims, 0);
- TFLITE_CHECK_EQ(intern_activ_depth % 4, 0);
+ MatchingDim(weights_shape, weights_dim_count - 2, bias_shape, 3);
+ TFLITE_DCHECK_EQ(weights_shape.FlatSize(),
+ intern_activ_depth * total_input_depth);
+ TFLITE_DCHECK_EQ(FlatSizeSkipDim(bias_shape, 3), 1);
+ TFLITE_DCHECK_EQ(intern_activ_depth % 4, 0);
const int output_depth =
- MatchingArraySize(prev_state_dims, 0, prev_activ_dims, 0,
- output_state_dims, 0, output_activ_dims, 0);
- TFLITE_CHECK_EQ(output_depth, intern_activ_depth / 4);
- const int fc_batches = FlatSizeSkipDim(activ_temp_dims, 0);
+ MatchingDim(prev_state_shape, 3, prev_activ_shape, 3, output_state_shape,
+ 3, output_activ_shape, 3);
+ TFLITE_DCHECK_EQ(output_depth, intern_activ_depth / 4);
+ const int fc_batches = FlatSizeSkipDim(activ_temp_shape, 3);
const int fc_output_depth =
- MatchingArraySize(weights_dims, 1, activ_temp_dims, 0);
- const int fc_accum_depth = ArraySize(weights_dims, 0);
- TFLITE_CHECK_EQ(fc_output_depth, 4 * output_depth);
+ MatchingDim(weights_shape, weights_dim_count - 2, activ_temp_shape, 3);
+ const int fc_accum_depth = total_input_depth;
+ TFLITE_DCHECK_EQ(fc_output_depth, 4 * output_depth);
// Depth-concatenate prev_activ and input data together.
uint8 const* concat_input_arrays_data[2] = {input_data_uint8,
prev_activ_data_uint8};
- Dims<4> const* concat_input_arrays_dims[2] = {&input_dims, &prev_activ_dims};
- Concatenation<FusedActivationFunctionType::kNone, uint8>(
- 0, concat_input_arrays_data, concat_input_arrays_dims, 2,
- concat_temp_data_uint8, concat_temp_dims);
+ const RuntimeShape* concat_input_arrays_shapes[2] = {&input_shape,
+ &prev_activ_shape};
+ tflite::ConcatenationParams concat_params;
+ concat_params.axis = 3;
+ concat_params.inputs_count = 2;
+ Concatenation(concat_params, concat_input_arrays_shapes,
+ concat_input_arrays_data, concat_temp_shape,
+ concat_temp_data_uint8);
// Implementation of the fully connected node inside the LSTM cell.
// The operands are 8-bit integers, the accumulators are internally 32bit
@@ -3527,10 +3258,10 @@ void LstmCell(const uint8* input_data_uint8, const Dims<4>& input_dims,
bool gemm_already_performed = false;
#ifdef GEMMLOWP_NEON
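+ // A single batch reduces the fully connected step to a matrix * vector
+ // product, so the dedicated GEMV kernel is used when the output and
+ // accumulation depths meet its divisibility requirements.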
if (fc_batches == 1 && !(fc_output_depth % 4) && !(fc_accum_depth % 8)) {
- GEMVForLstmCell(concat_temp_data_uint8, concat_temp_dims,
- weights_data_uint8, weights_dims, weights_zero_point,
- bias_data_int32, bias_dims, accum_multiplier, accum_shift,
- activ_temp_data_int16, activ_temp_dims);
+ GEMVForLstmCell(concat_temp_shape, concat_temp_data_uint8, weights_shape,
+ weights_data_uint8, weights_zero_point, bias_shape,
+ bias_data_int32, accum_multiplier, accum_shift,
+ activ_temp_shape, activ_temp_data_int16);
gemm_already_performed = true;
}
#endif
@@ -3719,51 +3450,28 @@ void LstmCell(const uint8* input_data_uint8, const Dims<4>& input_dims,
}
}
-template <FusedActivationFunctionType Ac, typename Scalar>
-void TensorFlowSplit(const Scalar* input_data, const Dims<4>& input_dims,
- int outputs_count, Scalar* const* output_data,
- const Dims<4>* const* output_dims) {
- gemmlowp::ScopedProfilingLabel label("TensorFlowSplit");
- TFLITE_DCHECK_GE(outputs_count, 1);
- for (int i = 0; i < outputs_count; i++) {
- MatchingFlatSizeSkipDim(*output_dims[i], 0, input_dims);
- }
- const int outer_size = FlatSizeSkipDim(input_dims, 0);
- TFLITE_DCHECK(IsPackedWithoutStrides(input_dims));
- // For now we don't have a model with a TensorFlowSplit
- // with fused activation function.
- TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone);
- const Scalar* input_ptr = input_data;
- for (int k = 0; k < outer_size; k++) {
- for (int i = 0; i < outputs_count; ++i) {
- memcpy(output_data[i] + k * output_dims[i]->sizes[0], input_ptr,
- output_dims[i]->sizes[0] * sizeof(Scalar));
- input_ptr += output_dims[i]->sizes[0];
- }
- }
-}
-
inline int NodeOffset(int b, int h, int w, int height, int width) {
return (b * height + h) * width + w;
}
-inline void AveragePool(const float* input_data, const Dims<4>& input_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, int kwidth, int kheight,
- float output_activation_min,
- float output_activation_max, float* output_data,
- const Dims<4>& output_dims) {
+inline void AveragePool(const PoolParams& params,
+ const RuntimeShape& input_shape,
+ const float* input_data,
+ const RuntimeShape& output_shape, float* output_data) {
gemmlowp::ScopedProfilingLabel label("AveragePool");
- const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
- const int input_height = ArraySize(input_dims, 2);
- const int input_width = ArraySize(input_dims, 1);
- const int output_height = ArraySize(output_dims, 2);
- const int output_width = ArraySize(output_dims, 1);
- const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
+ TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+ const int batches = MatchingDim(input_shape, 0, output_shape, 0);
+ const int input_height = input_shape.Dims(1);
+ const int input_width = input_shape.Dims(2);
+ const int output_height = output_shape.Dims(1);
+ const int output_width = output_shape.Dims(2);
+ const int stride_height = params.stride_height;
+ const int stride_width = params.stride_width;
// TODO(benoitjacob) make this a proper reference impl without Eigen!
- const auto in_mat = MapAsMatrixWithFirstDimAsRows(input_data, input_dims);
- auto out_mat = MapAsMatrixWithFirstDimAsRows(output_data, output_dims);
+ const auto in_mat = MapAsMatrixWithLastDimAsRows(input_data, input_shape);
+ auto out_mat = MapAsMatrixWithLastDimAsRows(output_data, output_shape);
// TODO(benoitjacob) get rid of the dynamic memory allocation here!
Eigen::VectorXf out_count(out_mat.cols());
out_count.setZero();
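+ // out_count records how many input cells contribute to each output cell so
+ // that the accumulated sums can be normalized below, including for partially
+ // padded windows.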
@@ -3774,12 +3482,15 @@ inline void AveragePool(const float* input_data, const Dims<4>& input_dims,
for (int w = 0; w < input_width; ++w) {
// (h_start, h_end) * (w_start, w_end) is the range that the input
// vector projects to.
- int hpad = h + pad_height;
- int wpad = w + pad_width;
- int h_start =
- (hpad < kheight) ? 0 : (hpad - kheight) / stride_height + 1;
+ int hpad = h + params.padding_values.height;
+ int wpad = w + params.padding_values.width;
+ int h_start = (hpad < params.filter_height)
+ ? 0
+ : (hpad - params.filter_height) / stride_height + 1;
int h_end = std::min(hpad / stride_height + 1, output_height);
- int w_start = (wpad < kwidth) ? 0 : (wpad - kwidth) / stride_width + 1;
+ int w_start = (wpad < params.filter_width)
+ ? 0
+ : (wpad - params.filter_width) / stride_width + 1;
int w_end = std::min(wpad / stride_width + 1, output_width);
// compute elementwise sum
for (int ph = h_start; ph < h_end; ++ph) {
@@ -3797,69 +3508,44 @@ inline void AveragePool(const float* input_data, const Dims<4>& input_dims,
TFLITE_DCHECK_GT(out_count.minCoeff(), 0);
out_mat.array().rowwise() /= out_count.transpose().array();
- for (int b = 0; b < batches; ++b) {
- for (int y = 0; y < output_height; ++y) {
- for (int x = 0; x < output_width; ++x) {
- for (int c = 0; c < depth; ++c) {
- output_data[Offset(output_dims, c, x, y, b)] =
- ActivationFunctionWithMinMax(
- output_data[Offset(output_dims, c, x, y, b)],
- output_activation_min, output_activation_max);
- }
- }
- }
+ const int flat_size = output_shape.FlatSize();
+ for (int i = 0; i < flat_size; ++i) {
+ output_data[i] = ActivationFunctionWithMinMax(output_data[i],
+ params.float_activation_min,
+ params.float_activation_max);
}
}
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void AveragePool(const float* input_data, const Dims<4>& input_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, int kwidth, int kheight, float* output_data,
- const Dims<4>& output_dims) {
- float output_activation_min, output_activation_max;
- GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
-
- AveragePool(input_data, input_dims, stride_width, stride_height, pad_width,
- pad_height, kwidth, kheight, output_activation_min,
- output_activation_max, output_data, output_dims);
-}
-
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void AveragePool(const float* input_data, const Dims<4>& input_dims, int stride,
- int pad_width, int pad_height, int filter_width,
- int filter_height, float* output_data,
- const Dims<4>& output_dims) {
- AveragePool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
- filter_width, filter_height, output_data, output_dims);
-}
-
-inline void AveragePool(const uint8* input_data, const Dims<4>& input_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, int filter_width, int filter_height,
- int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims) {
+inline void AveragePool(const PoolParams& params,
+ const RuntimeShape& input_shape,
+ const uint8* input_data,
+ const RuntimeShape& output_shape, uint8* output_data) {
gemmlowp::ScopedProfilingLabel label("AveragePool/8bit");
- TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
- const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
- const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
- const int input_height = ArraySize(input_dims, 2);
- const int input_width = ArraySize(input_dims, 1);
- const int output_height = ArraySize(output_dims, 2);
- const int output_width = ArraySize(output_dims, 1);
+ TFLITE_DCHECK_LE(params.quantized_activation_min,
+ params.quantized_activation_max);
+ TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+ const int batches = MatchingDim(input_shape, 0, output_shape, 0);
+ const int depth = MatchingDim(input_shape, 3, output_shape, 3);
+ const int input_height = input_shape.Dims(1);
+ const int input_width = input_shape.Dims(2);
+ const int output_height = output_shape.Dims(1);
+ const int output_width = output_shape.Dims(2);
+ const int stride_height = params.stride_height;
+ const int stride_width = params.stride_width;
for (int batch = 0; batch < batches; ++batch) {
for (int out_y = 0; out_y < output_height; ++out_y) {
for (int out_x = 0; out_x < output_width; ++out_x) {
- const int in_x_origin = (out_x * stride_width) - pad_width;
- const int in_y_origin = (out_y * stride_height) - pad_height;
+ const int in_x_origin =
+ (out_x * stride_width) - params.padding_values.width;
+ const int in_y_origin =
+ (out_y * stride_height) - params.padding_values.height;
const int filter_x_start = std::max(0, -in_x_origin);
const int filter_x_end =
- std::min(filter_width, input_width - in_x_origin);
+ std::min(params.filter_width, input_width - in_x_origin);
const int filter_y_start = std::max(0, -in_y_origin);
const int filter_y_end =
- std::min(filter_height, input_height - in_y_origin);
+ std::min(params.filter_height, input_height - in_y_origin);
const int filter_count =
(filter_x_end - filter_x_start) * (filter_y_end - filter_y_start);
// 1280 required by Inception v3
@@ -3868,11 +3554,12 @@ inline void AveragePool(const uint8* input_data, const Dims<4>& input_dims,
uint16 acc[kAccBufferMaxSize];
memset(acc, 0, depth * sizeof(acc[0]));
const uint8* input_ptr =
- input_data + input_dims.strides[1] * in_x_origin +
- input_dims.strides[2] * in_y_origin + input_dims.strides[3] * batch;
+ input_data +
+ depth * (in_x_origin +
+ input_width * (in_y_origin + input_height * batch));
for (int fy = filter_y_start; fy < filter_y_end; fy++) {
- const uint8* input_row_ptr = input_ptr + fy * input_dims.strides[2] +
- filter_x_start * input_dims.strides[1];
+ const uint8* input_row_ptr =
+ input_ptr + depth * (fy * input_width + filter_x_start);
for (int fx = filter_x_start; fx < filter_x_end; fx++) {
int channel = 0;
#ifdef USE_NEON
@@ -3903,21 +3590,21 @@ inline void AveragePool(const uint8* input_data, const Dims<4>& input_dims,
}
}
uint8* output_ptr =
- output_data + Offset(output_dims, 0, out_x, out_y, batch);
+ output_data + Offset(output_shape, batch, out_y, out_x, 0);
int channel = 0;
#ifdef USE_NEON
-#define AVGPOOL_DIVIDING_BY(FILTER_COUNT) \
- if (filter_count == FILTER_COUNT) { \
- for (; channel <= depth - 8; channel += 8) { \
- uint16 buf[8]; \
- for (int i = 0; i < 8; i++) { \
- buf[i] = (acc[channel + i] + FILTER_COUNT / 2) / FILTER_COUNT; \
- } \
- uint8x8_t buf8 = vqmovn_u16(vld1q_u16(buf)); \
- buf8 = vmin_u8(buf8, vdup_n_u8(output_activation_max)); \
- buf8 = vmax_u8(buf8, vdup_n_u8(output_activation_min)); \
- vst1_u8(output_ptr + channel, buf8); \
- } \
+#define AVGPOOL_DIVIDING_BY(FILTER_COUNT) \
+ if (filter_count == FILTER_COUNT) { \
+ for (; channel <= depth - 8; channel += 8) { \
+ uint16 buf[8]; \
+ for (int i = 0; i < 8; i++) { \
+ buf[i] = (acc[channel + i] + FILTER_COUNT / 2) / FILTER_COUNT; \
+ } \
+ uint8x8_t buf8 = vqmovn_u16(vld1q_u16(buf)); \
+ buf8 = vmin_u8(buf8, vdup_n_u8(params.quantized_activation_max)); \
+ buf8 = vmax_u8(buf8, vdup_n_u8(params.quantized_activation_min)); \
+ vst1_u8(output_ptr + channel, buf8); \
+ } \
}
AVGPOOL_DIVIDING_BY(9)
AVGPOOL_DIVIDING_BY(15)
@@ -3928,15 +3615,15 @@ inline void AveragePool(const uint8* input_data, const Dims<4>& input_dims,
buf[i] = (acc[channel + i] + filter_count / 2) / filter_count;
}
uint8x8_t buf8 = vqmovn_u16(vld1q_u16(buf));
- buf8 = vmin_u8(buf8, vdup_n_u8(output_activation_max));
- buf8 = vmax_u8(buf8, vdup_n_u8(output_activation_min));
+ buf8 = vmin_u8(buf8, vdup_n_u8(params.quantized_activation_max));
+ buf8 = vmax_u8(buf8, vdup_n_u8(params.quantized_activation_min));
vst1_u8(output_ptr + channel, buf8);
}
#endif
for (; channel < depth; ++channel) {
uint16 a = (acc[channel] + filter_count / 2) / filter_count;
- a = std::max<uint16>(a, output_activation_min);
- a = std::min<uint16>(a, output_activation_max);
+ a = std::max<uint16>(a, params.quantized_activation_min);
+ a = std::min<uint16>(a, params.quantized_activation_max);
output_ptr[channel] = static_cast<uint8>(a);
}
}
@@ -3944,54 +3631,22 @@ inline void AveragePool(const uint8* input_data, const Dims<4>& input_dims,
}
}
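// A minimal caller-side sketch of the new PoolParams-based entry point
// (illustrative only; it assumes nothing about PoolParams beyond the fields
// referenced above, and the local names mirror the legacy argument list):
//
//   PoolParams op_params;
//   op_params.stride_width = stride_width;
//   op_params.stride_height = stride_height;
//   op_params.padding_values.width = pad_width;
//   op_params.padding_values.height = pad_height;
//   op_params.filter_width = filter_width;
//   op_params.filter_height = filter_height;
//   op_params.quantized_activation_min = output_activation_min;
//   op_params.quantized_activation_max = output_activation_max;
//   AveragePool(op_params, input_shape, input_data, output_shape, output_data);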
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void AveragePool(const uint8* input_data, const Dims<4>& input_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, int filter_width, int filter_height,
- int32 output_activation_min, int32 output_activation_max,
- uint8* output_data, const Dims<4>& output_dims) {
- static_assert(Ac == FusedActivationFunctionType::kNone ||
- Ac == FusedActivationFunctionType::kRelu ||
- Ac == FusedActivationFunctionType::kRelu6 ||
- Ac == FusedActivationFunctionType::kRelu1,
- "");
- if (Ac == FusedActivationFunctionType::kNone) {
- TFLITE_DCHECK_EQ(output_activation_min, 0);
- TFLITE_DCHECK_EQ(output_activation_max, 255);
- }
- AveragePool(input_data, input_dims, stride_width, stride_height, pad_width,
- pad_height, filter_width, filter_height, output_activation_min,
- output_activation_max, output_data, output_dims);
-}
-
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void AveragePool(const uint8* input_data, const Dims<4>& input_dims, int stride,
- int pad_width, int pad_height, int filter_width,
- int filter_height, int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims) {
- AveragePool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
- filter_width, filter_height, output_activation_min,
- output_activation_max, output_data, output_dims);
-}
-
-inline void MaxPool(const float* input_data, const Dims<4>& input_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, int kwidth, int kheight,
- float output_activation_min, float output_activation_max,
- float* output_data, const Dims<4>& output_dims) {
+inline void MaxPool(const PoolParams& params, const RuntimeShape& input_shape,
+ const float* input_data, const RuntimeShape& output_shape,
+ float* output_data) {
gemmlowp::ScopedProfilingLabel label("MaxPool");
- const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
- const int input_height = ArraySize(input_dims, 2);
- const int input_width = ArraySize(input_dims, 1);
- const int output_height = ArraySize(output_dims, 2);
- const int output_width = ArraySize(output_dims, 1);
- const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
-
- const auto in_mat = MapAsMatrixWithFirstDimAsRows(input_data, input_dims);
- auto out_mat = MapAsMatrixWithFirstDimAsRows(output_data, output_dims);
+ TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+ const int batches = MatchingDim(input_shape, 0, output_shape, 0);
+ const int input_height = input_shape.Dims(1);
+ const int input_width = input_shape.Dims(2);
+ const int output_height = output_shape.Dims(1);
+ const int output_width = output_shape.Dims(2);
+ const int stride_height = params.stride_height;
+ const int stride_width = params.stride_width;
+
+ const auto in_mat = MapAsMatrixWithLastDimAsRows(input_data, input_shape);
+ auto out_mat = MapAsMatrixWithLastDimAsRows(output_data, output_shape);
  // Prefill the output with the minimum representable float value.
out_mat.setConstant(std::numeric_limits<float>::lowest());
for (int b = 0; b < batches; ++b) {
@@ -3999,12 +3654,15 @@ inline void MaxPool(const float* input_data, const Dims<4>& input_dims,
for (int w = 0; w < input_width; ++w) {
// (h_start, h_end) * (w_start, w_end) is the range that the input
// vector projects to.
- int hpad = h + pad_height;
- int wpad = w + pad_width;
- int h_start =
- (hpad < kheight) ? 0 : (hpad - kheight) / stride_height + 1;
+ int hpad = h + params.padding_values.height;
+ int wpad = w + params.padding_values.width;
+ int h_start = (hpad < params.filter_height)
+ ? 0
+ : (hpad - params.filter_height) / stride_height + 1;
int h_end = std::min(hpad / stride_height + 1, output_height);
- int w_start = (wpad < kwidth) ? 0 : (wpad - kwidth) / stride_width + 1;
+ int w_start = (wpad < params.filter_width)
+ ? 0
+ : (wpad - params.filter_width) / stride_width + 1;
int w_end = std::min(wpad / stride_width + 1, output_width);
        // compute elementwise max
for (int ph = h_start; ph < h_end; ++ph) {
@@ -4019,78 +3677,55 @@ inline void MaxPool(const float* input_data, const Dims<4>& input_dims,
}
}
}
-
- for (int b = 0; b < batches; ++b) {
- for (int y = 0; y < output_height; ++y) {
- for (int x = 0; x < output_width; ++x) {
- for (int c = 0; c < depth; ++c) {
- output_data[Offset(output_dims, c, x, y, b)] =
- ActivationFunctionWithMinMax(
- output_data[Offset(output_dims, c, x, y, b)],
- output_activation_min, output_activation_max);
- }
- }
- }
+ const int flat_size = output_shape.FlatSize();
+ for (int i = 0; i < flat_size; ++i) {
+ output_data[i] = ActivationFunctionWithMinMax(output_data[i],
+ params.float_activation_min,
+ params.float_activation_max);
}
}
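// Worked example of the forward-mode projection above (illustrative): with
// filter_height = 3, stride_height = 2, padding_values.height = 1 and
// output_height = 3, input row h = 2 gives hpad = 3, so
//   h_start = (3 - 3) / 2 + 1 = 1 and h_end = min(3 / 2 + 1, 3) = 2,
// i.e. that input row only feeds output row 1 (which covers input rows 1..3).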
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void MaxPool(const float* input_data, const Dims<4>& input_dims,
- int stride_width, int stride_height, int pad_width, int pad_height,
- int kwidth, int kheight, float* output_data,
- const Dims<4>& output_dims) {
- float output_activation_min, output_activation_max;
- GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
- MaxPool(input_data, input_dims, stride_width, stride_height, pad_width,
- pad_height, kwidth, kheight, output_activation_min,
- output_activation_max, output_data, output_dims);
-}
-
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void MaxPool(const float* input_data, const Dims<4>& input_dims, int stride,
- int pad_width, int pad_height, int filter_width, int filter_height,
- float* output_data, const Dims<4>& output_dims) {
- MaxPool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
- filter_width, filter_height, output_data, output_dims);
-}
-
-inline void MaxPool(const uint8* input_data, const Dims<4>& input_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, int filter_width, int filter_height,
- int32 output_activation_min, int32 output_activation_max,
- uint8* output_data, const Dims<4>& output_dims) {
+inline void MaxPool(const PoolParams& params, const RuntimeShape& input_shape,
+ const uint8* input_data, const RuntimeShape& output_shape,
+ uint8* output_data) {
gemmlowp::ScopedProfilingLabel label("MaxPool/8bit");
- TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
- const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
- const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
- const int input_height = ArraySize(input_dims, 2);
- const int input_width = ArraySize(input_dims, 1);
- const int output_height = ArraySize(output_dims, 2);
- const int output_width = ArraySize(output_dims, 1);
+ TFLITE_DCHECK_LE(params.quantized_activation_min,
+ params.quantized_activation_max);
+ TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+ const int batches = MatchingDim(input_shape, 0, output_shape, 0);
+ const int depth = MatchingDim(input_shape, 3, output_shape, 3);
+ const int input_height = input_shape.Dims(1);
+ const int input_width = input_shape.Dims(2);
+ const int output_height = output_shape.Dims(1);
+ const int output_width = output_shape.Dims(2);
+ const int stride_height = params.stride_height;
+ const int stride_width = params.stride_width;
for (int batch = 0; batch < batches; ++batch) {
for (int out_y = 0; out_y < output_height; ++out_y) {
for (int out_x = 0; out_x < output_width; ++out_x) {
- const int in_x_origin = (out_x * stride_width) - pad_width;
- const int in_y_origin = (out_y * stride_height) - pad_height;
+ const int in_x_origin =
+ (out_x * stride_width) - params.padding_values.width;
+ const int in_y_origin =
+ (out_y * stride_height) - params.padding_values.height;
const int filter_x_start = std::max(0, -in_x_origin);
const int filter_x_end =
- std::min(filter_width, input_width - in_x_origin);
+ std::min(params.filter_width, input_width - in_x_origin);
const int filter_y_start = std::max(0, -in_y_origin);
const int filter_y_end =
- std::min(filter_height, input_height - in_y_origin);
+ std::min(params.filter_height, input_height - in_y_origin);
// 2048 required by Inception v3
static constexpr int kAccBufferMaxSize = 2048;
TFLITE_DCHECK_LE(depth, kAccBufferMaxSize);
uint8 acc[kAccBufferMaxSize];
memset(acc, 0, depth * sizeof(acc[0]));
const uint8* input_ptr =
- input_data + input_dims.strides[1] * in_x_origin +
- input_dims.strides[2] * in_y_origin + input_dims.strides[3] * batch;
+ input_data +
+ depth * (in_x_origin +
+ input_width * (in_y_origin + input_height * batch));
for (int fy = filter_y_start; fy < filter_y_end; fy++) {
- const uint8* input_row_ptr = input_ptr + fy * input_dims.strides[2] +
- filter_x_start * input_dims.strides[1];
+ const uint8* input_row_ptr =
+ input_ptr + depth * (fy * input_width + filter_x_start);
for (int fx = filter_x_start; fx < filter_x_end; fx++) {
int channel = 0;
#ifdef USE_NEON
@@ -4116,26 +3751,26 @@ inline void MaxPool(const uint8* input_data, const Dims<4>& input_dims,
}
}
uint8* output_ptr =
- output_data + Offset(output_dims, 0, out_x, out_y, batch);
+ output_data + Offset(output_shape, batch, out_y, out_x, 0);
int channel = 0;
#ifdef USE_NEON
for (; channel <= depth - 16; channel += 16) {
uint8x16_t a = vld1q_u8(acc + channel);
- a = vminq_u8(a, vdupq_n_u8(output_activation_max));
- a = vmaxq_u8(a, vdupq_n_u8(output_activation_min));
+ a = vminq_u8(a, vdupq_n_u8(params.quantized_activation_max));
+ a = vmaxq_u8(a, vdupq_n_u8(params.quantized_activation_min));
vst1q_u8(output_ptr + channel, a);
}
for (; channel <= depth - 8; channel += 8) {
uint8x8_t a = vld1_u8(acc + channel);
- a = vmin_u8(a, vdup_n_u8(output_activation_max));
- a = vmax_u8(a, vdup_n_u8(output_activation_min));
+ a = vmin_u8(a, vdup_n_u8(params.quantized_activation_max));
+ a = vmax_u8(a, vdup_n_u8(params.quantized_activation_min));
vst1_u8(output_ptr + channel, a);
}
#endif
for (; channel < depth; ++channel) {
uint8 a = acc[channel];
- a = std::max<uint8>(a, output_activation_min);
- a = std::min<uint8>(a, output_activation_max);
+ a = std::max<uint8>(a, params.quantized_activation_min);
+ a = std::min<uint8>(a, params.quantized_activation_max);
output_ptr[channel] = static_cast<uint8>(a);
}
}
@@ -4143,53 +3778,23 @@ inline void MaxPool(const uint8* input_data, const Dims<4>& input_dims,
}
}
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void MaxPool(const uint8* input_data, const Dims<4>& input_dims,
- int stride_width, int stride_height, int pad_width, int pad_height,
- int filter_width, int filter_height, int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims) {
- static_assert(Ac == FusedActivationFunctionType::kNone ||
- Ac == FusedActivationFunctionType::kRelu ||
- Ac == FusedActivationFunctionType::kRelu6 ||
- Ac == FusedActivationFunctionType::kRelu1,
- "");
- if (Ac == FusedActivationFunctionType::kNone) {
- TFLITE_DCHECK_EQ(output_activation_min, 0);
- TFLITE_DCHECK_EQ(output_activation_max, 255);
- }
- MaxPool(input_data, input_dims, stride_width, stride_height, pad_width,
- pad_height, filter_width, filter_height, output_activation_min,
- output_activation_max, output_data, output_dims);
-}
-
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void MaxPool(const uint8* input_data, const Dims<4>& input_dims, int stride,
- int pad_width, int pad_height, int filter_width, int filter_height,
- int32 output_activation_min, int32 output_activation_max,
- uint8* output_data, const Dims<4>& output_dims) {
- MaxPool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
- filter_width, filter_height, output_activation_min,
- output_activation_max, output_data, output_dims);
-}
-
-inline void L2Pool(const float* input_data, const Dims<4>& input_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, int filter_width, int filter_height,
- float output_activation_min, float output_activation_max,
- float* output_data, const Dims<4>& output_dims) {
+inline void L2Pool(const PoolParams& params, const RuntimeShape& input_shape,
+ const float* input_data, const RuntimeShape& output_shape,
+ float* output_data) {
gemmlowp::ScopedProfilingLabel label("L2Pool");
- const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
- const int input_height = ArraySize(input_dims, 2);
- const int input_width = ArraySize(input_dims, 1);
- const int output_height = ArraySize(output_dims, 2);
- const int output_width = ArraySize(output_dims, 1);
+ TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+ const int batches = MatchingDim(input_shape, 0, output_shape, 0);
+ const int input_height = input_shape.Dims(1);
+ const int input_width = input_shape.Dims(2);
+ const int output_height = output_shape.Dims(1);
+ const int output_width = output_shape.Dims(2);
+ const int stride_height = params.stride_height;
+ const int stride_width = params.stride_width;
  // Actually carry out L2 Pool. Code is written in forward mode: we go through
  // each input value once, and write to all the pooled regions it maps to.
- const auto in_mat = MapAsMatrixWithFirstDimAsRows(input_data, input_dims);
- auto out_mat = MapAsMatrixWithFirstDimAsRows(output_data, output_dims);
+ const auto in_mat = MapAsMatrixWithLastDimAsRows(input_data, input_shape);
+ auto out_mat = MapAsMatrixWithLastDimAsRows(output_data, output_shape);
Eigen::VectorXf in_square(in_mat.rows());
Eigen::VectorXf out_count(out_mat.cols());
out_count.setZero();
@@ -4200,15 +3805,17 @@ inline void L2Pool(const float* input_data, const Dims<4>& input_dims,
for (int w = 0; w < input_width; ++w) {
// (h_start, h_end) * (w_start, w_end) is the range that the input
// vector projects to.
- const int hpad = h + pad_height;
- const int wpad = w + pad_width;
- const int h_start = (hpad < filter_height)
- ? 0
- : (hpad - filter_height) / stride_height + 1;
+ const int hpad = h + params.padding_values.height;
+ const int wpad = w + params.padding_values.width;
+ const int h_start =
+ (hpad < params.filter_height)
+ ? 0
+ : (hpad - params.filter_height) / stride_height + 1;
const int h_end = std::min(hpad / stride_height + 1, output_height);
- const int w_start = (wpad < filter_width)
- ? 0
- : (wpad - filter_width) / stride_width + 1;
+ const int w_start =
+ (wpad < params.filter_width)
+ ? 0
+ : (wpad - params.filter_width) / stride_width + 1;
const int w_end = std::min(wpad / stride_width + 1, output_width);
// pre-compute square
const int in_offset = w + input_width * (h + input_height * b);
@@ -4229,53 +3836,37 @@ inline void L2Pool(const float* input_data, const Dims<4>& input_dims,
out_count = out_count.array().inverse();
out_mat =
(out_mat.array().rowwise() * out_count.transpose().array()).cwiseSqrt();
-}
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void L2Pool(const float* input_data, const Dims<4>& input_dims,
- int stride_width, int stride_height, int pad_width, int pad_height,
- int filter_width, int filter_height, float* output_data,
- const Dims<4>& output_dims) {
- float output_activation_min, output_activation_max;
- GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
- L2Pool(input_data, input_dims, stride_width, stride_height, pad_width,
- pad_height, filter_width, filter_height, output_activation_min,
- output_activation_max, output_data, output_dims);
-}
-
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void L2Pool(const float* input_data, const Dims<4>& input_dims, int stride,
- int pad_width, int pad_height, int filter_width, int filter_height,
- float* output_data, const Dims<4>& output_dims) {
- L2Pool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
- filter_width, filter_height, output_data, output_dims);
+ const int flat_size = output_shape.FlatSize();
+ for (int i = 0; i < flat_size; ++i) {
+ output_data[i] = ActivationFunctionWithMinMax(output_data[i],
+ params.float_activation_min,
+ params.float_activation_max);
+ }
}
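// In effect, for every output cell the code above computes (illustrative
// summary, matching the accumulation, the division by out_count and the final
// cwiseSqrt / activation clamp):
//   output = clamp(sqrt(sum_over_window(input^2) / window_count),
//                  float_activation_min, float_activation_max)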
-inline void LocalResponseNormalization(const float* input_data,
- const Dims<4>& input_dims, int range,
- float bias, float alpha, float beta,
- float* output_data,
- const Dims<4>& output_dims) {
+inline void LocalResponseNormalization(
+ const tflite::LocalResponseNormalizationParams& op_params,
+ const RuntimeShape& input_shape, const float* input_data,
+ const RuntimeShape& output_shape, float* output_data) {
gemmlowp::ScopedProfilingLabel label("LocalResponseNormalization");
- MatchingFlatSize(input_dims, output_dims);
+ MatchingFlatSize(input_shape, output_shape);
- const auto data_in = MapAsMatrixWithFirstDimAsRows(input_data, input_dims);
- auto data_out = MapAsMatrixWithFirstDimAsRows(output_data, output_dims);
+ const auto data_in = MapAsMatrixWithLastDimAsRows(input_data, input_shape);
+ auto data_out = MapAsMatrixWithLastDimAsRows(output_data, output_shape);
  // Carry out local response normalization, vector by vector.
  // Since the data are stored column major, a row-wise operation would
  // probably not be memory efficient anyway, so we use an explicit for loop
  // over the columns.
- const int double_range = range * 2;
+ const int double_range = op_params.range * 2;
Eigen::VectorXf padded_square(data_in.rows() + double_range);
padded_square.setZero();
for (int r = 0; r < data_in.cols(); ++r) {
// Do local response normalization for data_in(:, r)
    // first, compute the squares and store them in a buffer for repeated use
- padded_square.block(range, 0, data_in.rows(), 1) =
- data_in.col(r).cwiseProduct(data_in.col(r)) * alpha;
+ padded_square.block(op_params.range, 0, data_in.rows(), 1) =
+ data_in.col(r).cwiseProduct(data_in.col(r)) * op_params.alpha;
    // Then, compute the scales and write them to data_out
float accumulated_scale = 0;
for (int i = 0; i < double_range; ++i) {
@@ -4283,32 +3874,33 @@ inline void LocalResponseNormalization(const float* input_data,
}
for (int i = 0; i < data_in.rows(); ++i) {
accumulated_scale += padded_square(i + double_range);
- data_out(i, r) = bias + accumulated_scale;
+ data_out(i, r) = op_params.bias + accumulated_scale;
accumulated_scale -= padded_square(i);
}
}
// In a few cases, the pow computation could benefit from speedups.
- if (beta == 1) {
+ if (op_params.beta == 1) {
data_out.array() = data_in.array() * data_out.array().inverse();
- } else if (beta == 0.5) {
+ } else if (op_params.beta == 0.5) {
data_out.array() = data_in.array() * data_out.array().sqrt().inverse();
} else {
- data_out.array() = data_in.array() * data_out.array().pow(-beta);
+ data_out.array() = data_in.array() * data_out.array().pow(-op_params.beta);
}
}
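// Summarizing the computation above (illustrative): for each element x_i the
// scale is s_i = bias + alpha * sum of x_j^2 over the 2*range+1 neighborhood
// along the depth axis, and the output is x_i * s_i^(-beta). The beta == 1 and
// beta == 0.5 branches are just faster ways of evaluating that power.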
-inline void Softmax(const float* input_data, const Dims<4>& input_dims,
- float beta, float* output_data,
- const Dims<4>& output_dims) {
+inline void Softmax(const SoftmaxParams& params,
+ const RuntimeShape& input_shape, const float* input_data,
+ const RuntimeShape& output_shape, float* output_data) {
gemmlowp::ScopedProfilingLabel label("Softmax");
- MatchingFlatSize(input_dims, output_dims);
+ MatchingFlatSize(input_shape, output_shape);
- const auto in_mat = MapAsMatrixWithFirstDimAsRows(input_data, input_dims);
- auto out_mat = MapAsMatrixWithFirstDimAsRows(output_data, output_dims);
+ const auto in_mat = MapAsMatrixWithLastDimAsRows(input_data, input_shape);
+ auto out_mat = MapAsMatrixWithLastDimAsRows(output_data, output_shape);
// Compute the exponential first, removing the max coefficient for numerical
// stability.
- out_mat = (in_mat.rowwise() - in_mat.colwise().maxCoeff()).array() * beta;
+ out_mat =
+ (in_mat.rowwise() - in_mat.colwise().maxCoeff()).array() * params.beta;
// We are separating out the exp function so that exp can be vectorized.
out_mat = out_mat.array().exp();
// Normalize to get the activations.
@@ -4317,10 +3909,12 @@ inline void Softmax(const float* input_data, const Dims<4>& input_dims,
out_mat.array().rowwise() *= scale;
}
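// Equivalent scalar formulation of the three steps above (illustrative):
//   output_i = exp(beta * (input_i - max_j input_j)) /
//              sum_k exp(beta * (input_k - max_j input_j))
// Subtracting the per-column max before exponentiating keeps exp() in a
// numerically safe range without changing the result.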
-inline void Softmax(const uint8* input_data, const Dims<4>& input_dims,
- int32 input_beta_multiplier, int32 input_beta_left_shift,
- int diff_min, uint8* output_data,
- const Dims<4>& output_dims) {
+inline void Softmax(const SoftmaxParams& params,
+ const RuntimeShape& input_shape, const uint8* input_data,
+ const RuntimeShape& output_shape, uint8* output_data) {
+ const int32 input_beta_multiplier = params.input_multiplier;
+ const int32 input_beta_left_shift = params.input_left_shift;
+ const int diff_min = params.diff_min;
// The representation chosen for the input to the exp() function is Q5.26.
// We need to leave extra space since values that we skip might be as large as
// -32 before multiplying by input_beta_multiplier, and therefore as large as
@@ -4334,8 +3928,11 @@ inline void Softmax(const uint8* input_data, const Dims<4>& input_dims,
using FixedPoint0 = gemmlowp::FixedPoint<int32, 0>;
gemmlowp::ScopedProfilingLabel label("Softmax/8bit");
- const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims);
- const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
+ const int trailing_dim = input_shape.DimensionsCount() - 1;
+ const int outer_size =
+ MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
+ const int depth =
+ MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
for (int b = 0; b < outer_size; ++b) {
const uint8* input_data_ptr = input_data + b * depth;
@@ -4452,7 +4049,7 @@ inline void Softmax(const uint8* input_data, const Dims<4>& input_dims,
// perform a division by the above-computed sum-of-exponentials.
int32 fixed_sum_of_exps = sum_of_exps.raw();
int headroom_plus_one =
- __builtin_clz(static_cast<uint32>(fixed_sum_of_exps));
+ CountLeadingZeros(static_cast<uint32>(fixed_sum_of_exps));
// This is the number of bits to the left of the binary point above 1.0.
// Consider fixed_sum_of_exps=1.25. In that case shifted_scale=0.8 and
// no later adjustment will be needed.
@@ -4525,11 +4122,15 @@ inline void Softmax(const uint8* input_data, const Dims<4>& input_dims,
// TODO(myenik): This is the same as the reference implementation, not actually
// optimized yet.
-inline void LogSoftmax(const float* input_data, const Dims<4>& input_dims,
- float* output_data, const Dims<4>& output_dims) {
+inline void LogSoftmax(const SoftmaxParams& params,
+ const RuntimeShape& input_shape, const float* input_data,
+ const RuntimeShape& output_shape, float* output_data) {
gemmlowp::ScopedProfilingLabel label("LogSoftmax");
- const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims);
- const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
+ const int trailing_dim = input_shape.DimensionsCount() - 1;
+ const int outer_size =
+ MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
+ const int depth =
+ MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
for (int i = 0; i < outer_size; ++i) {
const float* block_input_data = input_data + i * depth;
@@ -4556,13 +4157,129 @@ inline void LogSoftmax(const float* input_data, const Dims<4>& input_dims,
}
}
+template <int OutputIntegerBits, int InputIntegerBits>
+inline gemmlowp::FixedPoint<int32, OutputIntegerBits>
+log_x_for_x_greater_than_or_equal_to_1_impl(
+ gemmlowp::FixedPoint<int32, InputIntegerBits> input_val) {
+ // assert(__builtin_clz(0u) >= std::numeric_limits<uint32>::digits - 1);
+ // assert(__builtin_clz(0u) <= std::numeric_limits<uint32>::digits);
+ using FixedPoint0 = gemmlowp::FixedPoint<int32, 0>;
+ // The reason for accumulating the result with an extra bit of headroom is
+ // that z_pow_2_adj * log_2 might be saturated, and adding num_scaled *
+ // recip_denom will otherwise introduce an error.
+ static constexpr int kAccumIntegerBits = OutputIntegerBits + 1;
+ using FixedPointAccum = gemmlowp::FixedPoint<int32, kAccumIntegerBits>;
+
+ const FixedPoint0 log_2 = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
+ FixedPoint0, 1488522236, std::log(2.0));
+ const FixedPoint0 sqrt_sqrt_half = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
+ FixedPoint0, 1805811301, std::sqrt(std::sqrt(0.5)));
+ const FixedPoint0 sqrt_half = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
+ FixedPoint0, 1518500250, std::sqrt(0.5));
+ const FixedPoint0 one_quarter =
+ GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(FixedPoint0, 536870912, 1.0 / 4.0);
+
+ const FixedPoint0 alpha_n = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
+ FixedPoint0, 117049297, 11.0 / 240.0 * std::sqrt(std::sqrt(2.0)));
+ const FixedPoint0 alpha_d = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
+ FixedPoint0, 127690142, 1.0 / 20.0 * std::sqrt(std::sqrt(2.0)));
+ const FixedPoint0 alpha_i = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
+ FixedPoint0, 1057819769,
+ 2.0 / std::sqrt(std::sqrt(2.0)) - std::sqrt(std::sqrt(2.0)));
+ const FixedPoint0 alpha_f = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
+ FixedPoint0, 638450708, 1.0 / 4.0 * std::sqrt(std::sqrt(2.0)));
+
+ const FixedPointAccum shifted_quarter =
+ gemmlowp::Rescale<kAccumIntegerBits>(one_quarter);
+
+ // Reinterpret the input value as Q0.31, because we will figure out the
+ // required shift "ourselves" instead of using, say, Rescale.
+ FixedPoint0 z_a = FixedPoint0::FromRaw(input_val.raw());
+ // z_a_pow_2 = input_integer_bits - z_a_headroom;
+ int z_a_headroom_plus_1 = CountLeadingZeros(static_cast<uint32>(z_a.raw()));
+ FixedPoint0 r_a_tmp =
+ SaturatingRoundingMultiplyByPOTParam(z_a, (z_a_headroom_plus_1 - 1));
+ const int32 r_a_raw =
+ SaturatingRoundingMultiplyByPOTParam((r_a_tmp * sqrt_half).raw(), 1);
+ // z_pow_2_adj = max(z_pow_2_a - 0.75, z_pow_2_b - 0.25);
+ // z_pow_2_adj = max(InputIntegerBits - z_a_headroom_plus_1 + 0.25,
+ // InputIntegerBits - z_b_headroom - 0.25);
+ const FixedPointAccum z_a_pow_2_adj = SaturatingAddNonGemmlowp(
+ FixedPointAccum::FromRaw(SaturatingRoundingMultiplyByPOTParam(
+ InputIntegerBits - z_a_headroom_plus_1, 31 - kAccumIntegerBits)),
+ shifted_quarter);
+
+ // z_b is treated like z_a, but premultiplying by sqrt(0.5).
+ FixedPoint0 z_b = z_a * sqrt_half;
+ int z_b_headroom = CountLeadingZeros(static_cast<uint32>(z_b.raw())) - 1;
+ const int32 r_b_raw =
+ SaturatingRoundingMultiplyByPOTParam(z_a.raw(), z_b_headroom);
+ const FixedPointAccum z_b_pow_2_adj = SaturatingSub(
+ FixedPointAccum::FromRaw(SaturatingRoundingMultiplyByPOTParam(
+ InputIntegerBits - z_b_headroom, 31 - kAccumIntegerBits)),
+ shifted_quarter);
+
+ const FixedPoint0 r = FixedPoint0::FromRaw(std::min(r_a_raw, r_b_raw));
+ const FixedPointAccum z_pow_2_adj = FixedPointAccum::FromRaw(
+ std::max(z_a_pow_2_adj.raw(), z_b_pow_2_adj.raw()));
+
+ const FixedPoint0 p = gemmlowp::RoundingHalfSum(r, sqrt_sqrt_half);
+ FixedPoint0 q = r - sqrt_sqrt_half;
+ q = q + q;
+
+ const FixedPoint0 common_sq = q * q;
+ const FixedPoint0 num = q * r + q * common_sq * alpha_n;
+ const FixedPoint0 denom_minus_one_0 =
+ p * (alpha_i + q + alpha_d * common_sq) + alpha_f * q;
+ const FixedPoint0 recip_denom =
+ one_over_one_plus_x_for_x_in_0_1(denom_minus_one_0);
+
+ const FixedPointAccum num_scaled = gemmlowp::Rescale<kAccumIntegerBits>(num);
+ return gemmlowp::Rescale<OutputIntegerBits>(z_pow_2_adj * log_2 +
+ num_scaled * recip_denom);
+}
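// Rough reading of the implementation above (illustrative): the input is
// rewritten as x = 2^k * r with r normalized into a narrow interval around 1
// (the z_a / z_b candidates differ by a sqrt(0.5) factor and the better one is
// kept), so that log(x) = k * log(2) + log(r). The k * log(2) part corresponds
// to z_pow_2_adj * log_2, and log(r) is evaluated with the small rational
// approximation built from alpha_n / alpha_d / alpha_i / alpha_f.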
+
+// Minimum output bits to accommodate log of maximum input range. It actually
+// does not matter if one considers, say, [-64,64] or [-64,64).
+//
+// For example, run this through Octave:
+// [0:127; ...
+// ceil(log(abs( log(2.^(0:127))+1 ))/log(2)); ...
+// ceil(log(abs( log(2.^(0:127))+1 ))/log(2))]
+constexpr int min_log_x_output_bits(int input_bits) {
+ return input_bits > 90
+ ? 7
+ : input_bits > 44
+ ? 6
+ : input_bits > 21
+ ? 5
+ : input_bits > 10
+ ? 4
+ : input_bits > 4 ? 3 : input_bits > 1 ? 2 : 1;
+}
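// A few sample values of the table above (illustrative):
//   min_log_x_output_bits(4) == 2, min_log_x_output_bits(5) == 3,
//   min_log_x_output_bits(21) == 4, min_log_x_output_bits(22) == 5,
// matching ceil(log2(log(2^input_bits) + 1)) from the Octave expression above.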
+
+template <int OutputIntegerBits, int InputIntegerBits>
+inline gemmlowp::FixedPoint<int32, OutputIntegerBits>
+log_x_for_x_greater_than_or_equal_to_1(
+ gemmlowp::FixedPoint<int32, InputIntegerBits> input_val) {
+ static_assert(
+ OutputIntegerBits >= min_log_x_output_bits(InputIntegerBits),
+      "Output integer bits must be sufficient to accommodate logs of inputs.");
+ return log_x_for_x_greater_than_or_equal_to_1_impl<OutputIntegerBits,
+ InputIntegerBits>(
+ input_val);
+}
+
// Currently just a copy of the reference code.
-inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims,
- int32 input_multiplier, int32 input_left_shift,
- int32 reverse_scaling_divisor,
- int32 reverse_scaling_right_shift, int diff_min,
- uint8* output_data, const Dims<4>& output_dims) {
+inline void LogSoftmax(const SoftmaxParams& params,
+ const RuntimeShape& input_shape, const uint8* input_data,
+ const RuntimeShape& output_shape, uint8* output_data) {
gemmlowp::ScopedProfilingLabel label("LogSoftmax/Uint8");
+ const int32 input_multiplier = params.input_multiplier;
+ const int32 input_left_shift = params.input_left_shift;
+ const int32 reverse_scaling_divisor = params.reverse_scaling_divisor;
+ const int32 reverse_scaling_right_shift = params.reverse_scaling_right_shift;
+ const int diff_min = params.diff_min;
// The representation chosen for the input to the exp() function is Q5.26.
// We need to leave extra space since values that we skip might be as large as
// -32 before multiplying by input_beta_multiplier, and therefore as large as
@@ -4576,8 +4293,11 @@ inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims,
using FixedPointAccum = gemmlowp::FixedPoint<int32, kAccumulationIntegerBits>;
using FixedPoint0 = gemmlowp::FixedPoint<int32, 0>;
- const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims);
- const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
+ const int trailing_dim = input_shape.DimensionsCount() - 1;
+ const int outer_size =
+ MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
+ const int depth =
+ MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
for (int i = 0; i < outer_size; ++i) {
const uint8* block_input_data = input_data + i * depth;
@@ -4601,13 +4321,10 @@ inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims,
}
}
- // TODO(b/77858996): Implement fixed-point log().
- // Not a fully-quantized implementation: floating-point log().
- const float float_log_sum_of_exps =
- std::log(static_cast<float>(sum_of_exps.raw()) /
- (1 << (31 - kAccumulationIntegerBits)));
- const int32 fixed_log_sum_of_exps = static_cast<int32>(TfLiteRound(
- float_log_sum_of_exps * (1 << (31 - kScaledDiffIntegerBits))));
+ const int32 fixed_log_sum_of_exps =
+ log_x_for_x_greater_than_or_equal_to_1<kScaledDiffIntegerBits>(
+ sum_of_exps)
+ .raw();
// rescaled_diff_min is smallest representable in
// Q(kScaledDiffIntegerBits).(31-kScaledDiffIntegerBits) plus the
@@ -4618,9 +4335,9 @@ inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims,
fixed_log_sum_of_exps + std::numeric_limits<int32>::lowest();
const int adjusted_diff_min =
std::max(diff_min - 1, // Note use of > below instead of >= above.
- MultiplyByQuantizedMultiplierSmallerThanOne(
+ MultiplyByQuantizedMultiplierSmallerThanOneExp(
rescaled_diff_min, reverse_scaling_divisor,
- reverse_scaling_right_shift));
+ -reverse_scaling_right_shift));
for (int c = 0; c < depth; ++c) {
int32 input_diff = static_cast<int32>(block_input_data[c]) - max_in_row;
@@ -4644,21 +4361,33 @@ inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims,
}
}
-inline void Logistic(const float* input_data, const Dims<4>& input_dims,
- float* output_data, const Dims<4>& output_dims) {
+inline void Logistic(const RuntimeShape& input_shape, const float* input_data,
+ const RuntimeShape& output_shape, float* output_data) {
gemmlowp::ScopedProfilingLabel label("Logistic");
- auto input_map = MapAsVector(input_data, input_dims);
- auto output_map = MapAsVector(output_data, output_dims);
+ auto input_map = MapAsVector(input_data, input_shape);
+ auto output_map = MapAsVector(output_data, output_shape);
output_map.array() =
input_map.array().unaryExpr(Eigen::internal::scalar_sigmoid_op<float>());
}
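// Elementwise this is the standard sigmoid (illustrative):
//   output_i = 1 / (1 + exp(-input_i))
// evaluated through Eigen's vectorized scalar_sigmoid_op.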
-inline void Logistic(const uint8* input_data, const Dims<4>& input_dims,
- int32 input_zero_point, int32 input_range_radius,
- int32 input_multiplier, int input_left_shift,
- uint8* output_data, const Dims<4>& output_dims) {
+// Convenience version that allows, for example, generated-code calls to be
+// uniform between data types.
+inline void Logistic(const LogisticParams&, const RuntimeShape& input_shape,
+ const float* input_data, const RuntimeShape& output_shape,
+ float* output_data) {
+ // Drop params: not needed.
+ Logistic(input_shape, input_data, output_shape, output_data);
+}
+
+inline void Logistic(const LogisticParams& params,
+ const RuntimeShape& input_shape, const uint8* input_data,
+ const RuntimeShape& output_shape, uint8* output_data) {
gemmlowp::ScopedProfilingLabel label("Logistic/Uint8");
- const int size = MatchingFlatSize(input_dims, output_dims);
+ const int32 input_zero_point = params.input_zero_point;
+ const int32 input_range_radius = params.input_range_radius;
+ const int32 input_multiplier = params.input_multiplier;
+ const int input_left_shift = params.input_left_shift;
+ const int size = MatchingFlatSize(input_shape, output_shape);
int c = 0;
#ifdef USE_NEON
@@ -4790,10 +4519,11 @@ inline void Logistic(const uint8* input_data, const Dims<4>& input_dims,
}
}
-inline void Logistic(const int16* input_data, const Dims<4>& input_dims,
- int16* output_data, const Dims<4>& output_dims) {
+inline void Logistic(const LogisticParams& params,
+ const RuntimeShape& input_shape, const int16* input_data,
+ const RuntimeShape& output_shape, int16* output_data) {
gemmlowp::ScopedProfilingLabel label("Logistic/Int16");
- const int flat_size = MatchingFlatSize(output_dims, input_dims);
+ const int flat_size = MatchingFlatSize(input_shape, output_shape);
for (int i = 0; i < flat_size; i++) {
}
@@ -4850,21 +4580,33 @@ inline void Logistic(const int16* input_data, const Dims<4>& input_dims,
}
}
-inline void Tanh(const float* input_data, const Dims<4>& input_dims,
- float* output_data, const Dims<4>& output_dims) {
+inline void Tanh(const RuntimeShape& input_shape, const float* input_data,
+ const RuntimeShape& output_shape, float* output_data) {
gemmlowp::ScopedProfilingLabel label("Tanh");
- auto input_map = MapAsVector(input_data, input_dims);
- auto output_map = MapAsVector(output_data, output_dims);
+ auto input_map = MapAsVector(input_data, input_shape);
+ auto output_map = MapAsVector(output_data, output_shape);
output_map.array() = input_map.array().tanh();
}
-inline void Tanh(const uint8* input_data, const Dims<4>& input_dims,
- int32 input_zero_point, int32 input_range_radius,
- int32 input_multiplier, int input_left_shift,
- uint8* output_data, const Dims<4>& output_dims) {
+// Convenience version that allows, for example, generated-code calls to be
+// uniform between data types.
+inline void Tanh(const TanhParams&, const RuntimeShape& input_shape,
+ const float* input_data, const RuntimeShape& output_shape,
+ float* output_data) {
+ // Drop params: not needed.
+ Tanh(input_shape, input_data, output_shape, output_data);
+}
+
+inline void Tanh(const TanhParams& params, const RuntimeShape& input_shape,
+ const uint8* input_data, const RuntimeShape& output_shape,
+ uint8* output_data) {
// Note that this is almost the exact same code as in Logistic().
gemmlowp::ScopedProfilingLabel label("Tanh");
- const int size = MatchingFlatSize(input_dims, output_dims);
+ const int32 input_zero_point = params.input_zero_point;
+ const int32 input_range_radius = params.input_range_radius;
+ const int32 input_multiplier = params.input_multiplier;
+ const int input_left_shift = params.input_left_shift;
+ const int size = MatchingFlatSize(input_shape, output_shape);
int c = 0;
int32_t output_zero_point = 128;
@@ -5005,16 +4747,17 @@ inline void Tanh(const uint8* input_data, const Dims<4>& input_dims,
}
}
-inline void Tanh(const int16* input_data, const Dims<4>& input_dims,
- int input_left_shift, int16* output_data,
- const Dims<4>& output_dims) {
+inline void Tanh(const TanhParams& params, const RuntimeShape& input_shape,
+ const int16* input_data, const RuntimeShape& output_shape,
+ int16* output_data) {
gemmlowp::ScopedProfilingLabel label("Tanh/Int16");
+ const int input_left_shift = params.input_left_shift;
// Support for shifts is limited until we have a parameterized version of
// SaturatingRoundingMultiplyByPOT().
TFLITE_DCHECK_GE(input_left_shift, 0);
TFLITE_DCHECK_LE(input_left_shift, 1);
- const int flat_size = MatchingFlatSize(output_dims, input_dims);
+ const int flat_size = MatchingFlatSize(input_shape, output_shape);
int c = 0;
const int16* input_data_ptr = input_data;
@@ -5105,86 +4848,23 @@ inline void Tanh(const int16* input_data, const Dims<4>& input_dims,
}
}
-inline void Dequantize(const uint8* input_data, const Dims<4>& input_dims,
- int32 zero_point, double scale, float* output_data,
- const Dims<4>& output_dims) {
- gemmlowp::ScopedProfilingLabel label("Dequantize");
- const int flat_size = MatchingFlatSize(output_dims, input_dims);
- for (int i = 0; i < flat_size; ++i) {
- int32 val = input_data[i];
- float result = static_cast<float>(scale * (val - zero_point));
- output_data[i] = result;
- }
-}
-
-inline void FakeQuant(const float* input_data, const Dims<4>& input_dims,
- float rmin, float rmax, int num_bits, float* output_data,
- const Dims<4>& output_dims) {
- gemmlowp::ScopedProfilingLabel label("FakeQuant");
-
- // 0 should always be a representable value. Let's assume that the initial
- // min,max range contains 0.
- TFLITE_DCHECK_LE(rmin, 0.0f);
- TFLITE_DCHECK_GE(rmax, 0.0f);
- TFLITE_DCHECK_LT(rmin, rmax);
-
- // Code matches tensorflow's FakeQuantWithMinMaxArgsFunctor.
- int quant_min = 0;
- int quant_max = (1 << num_bits) - 1;
- float nudged_min, nudged_max, nudged_scale;
- NudgeQuantizationRange(rmin, rmax, quant_min, quant_max, &nudged_min,
- &nudged_max, &nudged_scale);
- const float inv_nudged_scale = 1.0f / nudged_scale;
-
- const int flat_size = MatchingFlatSize(output_dims, input_dims);
- for (int i = 0; i < flat_size; ++i) {
- const float src_val = input_data[i];
- const float clamped = std::min(nudged_max, std::max(nudged_min, src_val));
- const float clamped_shifted = clamped - nudged_min;
- const float dst_val =
- TfLiteRound(clamped_shifted * inv_nudged_scale) * nudged_scale +
- nudged_min;
- output_data[i] = dst_val;
- }
-}
-
template <typename SrcT, typename DstT>
-inline void Cast(const SrcT* input_data, const Dims<4>& input_dims,
- DstT* output_data, const Dims<4>& output_dims) {
+inline void Cast(const RuntimeShape& input_shape, const SrcT* input_data,
+ const RuntimeShape& output_shape, DstT* output_data) {
gemmlowp::ScopedProfilingLabel label("Cast");
- auto input_map = MapAsVector(input_data, input_dims);
- auto output_map = MapAsVector(output_data, output_dims);
+ auto input_map = MapAsVector(input_data, input_shape);
+ auto output_map = MapAsVector(output_data, output_shape);
output_map.array() = input_map.array().template cast<DstT>();
}
-inline void Floor(const float* input_data, const Dims<4>& input_dims,
- float* output_data, const Dims<4>& output_dims) {
+inline void Floor(const RuntimeShape& input_shape, const float* input_data,
+ const RuntimeShape& output_shape, float* output_data) {
gemmlowp::ScopedProfilingLabel label("Floor");
- auto input_map = MapAsVector(input_data, input_dims);
- auto output_map = MapAsVector(output_data, output_dims);
+ auto input_map = MapAsVector(input_data, input_shape);
+ auto output_map = MapAsVector(output_data, output_shape);
output_map.array() = Eigen::floor(input_map.array());
}
-template <typename T>
-inline void Gather(const T* input_data, const Dims<4>& input_dims,
- int input_rank, const int32* coords_data,
- const Dims<4>& coords_dims, T* output_data,
- const Dims<4>& output_dims) {
- gemmlowp::ScopedProfilingLabel label("Gather");
-
- TFLITE_DCHECK(coords_dims.sizes[0] == output_dims.sizes[input_rank - 1]);
- int stride = input_dims.strides[input_rank - 1];
- T* out = output_data;
-
- for (int i = 0; i < coords_dims.sizes[0]; i++) {
- TFLITE_DCHECK_GE(coords_data[i], 0);
- TFLITE_DCHECK_LT(coords_data[i], input_dims.sizes[input_rank - 1]);
- const T* in = input_data + coords_data[i] * stride;
- memcpy(out, in, sizeof(T) * stride);
- out += stride;
- }
-}
-
#ifdef USE_NEON
inline void ResizeBilinearKernel(const float* input_ptr, int32 depth,
float scale, float* output_ptr) {
@@ -5284,12 +4964,14 @@ inline void ResizeBilinearKernel(const float* input_ptr, int32 depth,
inline void ResizeBilinearKernel2x2(int32 x0, int32 x1, int32 y0, int32 y1,
int32 x, int32 y, int32 depth, int32 batch,
+ const RuntimeShape& input_shape,
const float* input_data,
- const Dims<4>& input_dims,
- float* output_data,
- const Dims<4>& output_dims) {
- const int32 input_width = ArraySize(input_dims, 1);
- const int32 output_width = ArraySize(output_dims, 1);
+ const RuntimeShape& output_shape,
+ float* output_data) {
+ TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+ const int32 input_width = input_shape.Dims(2);
+ const int32 output_width = output_shape.Dims(2);
const int32 input_x_offset = (x1 - x0) * depth;
const int32 input_y_offset = (y1 - y0) * depth * input_width;
@@ -5297,7 +4979,6 @@ inline void ResizeBilinearKernel2x2(int32 x0, int32 x1, int32 y0, int32 y1,
const int32 output_y_offset = depth * output_width;
#ifdef USE_NEON
- TFLITE_DCHECK(IsPackedWithoutStrides(input_dims));
TFLITE_DCHECK(x1 >= x0);
TFLITE_DCHECK(y1 >= y0);
@@ -5307,7 +4988,7 @@ inline void ResizeBilinearKernel2x2(int32 x0, int32 x1, int32 y0, int32 y1,
const float* input_ptr = nullptr;
float32x4x2_t x0y0;
- input_ptr = &input_data[Offset(input_dims, ic, x0, y0, batch)];
+ input_ptr = &input_data[Offset(input_shape, batch, y0, x0, ic)];
x0y0.val[0] = vld1q_f32(input_ptr);
x0y0.val[1] = vld1q_f32(input_ptr + 4);
@@ -5327,7 +5008,7 @@ inline void ResizeBilinearKernel2x2(int32 x0, int32 x1, int32 y0, int32 y1,
x1y1.val[1] = vld1q_f32(input_ptr + 4);
// Top left corner.
- float* output_ptr = &output_data[Offset(output_dims, ic, x, y, batch)];
+ float* output_ptr = &output_data[Offset(output_shape, batch, y, x, ic)];
vst1q_f32(output_ptr, x0y0.val[0]);
vst1q_f32(output_ptr + 4, x0y0.val[1]);
@@ -5366,14 +5047,15 @@ inline void ResizeBilinearKernel2x2(int32 x0, int32 x1, int32 y0, int32 y1,
}
// Handle 4 input channels at a time.
for (; ic <= depth - 4; ic += 4) {
- const float* input_ptr = &input_data[Offset(input_dims, ic, x0, y0, batch)];
+ const float* input_ptr =
+ &input_data[Offset(input_shape, batch, y0, x0, ic)];
float32x4_t x0y0 = vld1q_f32(input_ptr);
float32x4_t x1y0 = vld1q_f32(input_ptr + input_x_offset);
float32x4_t x0y1 = vld1q_f32(input_ptr + input_y_offset);
float32x4_t x1y1 = vld1q_f32(input_ptr + input_x_offset + input_y_offset);
// Top left corner.
- float* output_ptr = &output_data[Offset(output_dims, ic, x, y, batch)];
+ float* output_ptr = &output_data[Offset(output_shape, batch, y, x, ic)];
vst1q_f32(output_ptr, x0y0);
// Top right corner.
@@ -5397,7 +5079,7 @@ inline void ResizeBilinearKernel2x2(int32 x0, int32 x1, int32 y0, int32 y1,
}
// Handle one input channel at a time.
for (; ic < depth; ic++) {
- const int32 input_offset = Offset(input_dims, ic, x0, y0, batch);
+ const int32 input_offset = Offset(input_shape, batch, y0, x0, ic);
float x0y0 = input_data[input_offset];
float x1y0 = input_data[input_offset + input_x_offset];
@@ -5405,7 +5087,7 @@ inline void ResizeBilinearKernel2x2(int32 x0, int32 x1, int32 y0, int32 y1,
float x1y1 = input_data[input_offset + input_x_offset + input_y_offset];
// Top left corner.
- const int32 output_offset = Offset(output_dims, ic, x, y, batch);
+ const int32 output_offset = Offset(output_shape, batch, y, x, ic);
output_data[output_offset] = x0y0;
// Top right corner.
@@ -5421,7 +5103,7 @@ inline void ResizeBilinearKernel2x2(int32 x0, int32 x1, int32 y0, int32 y1,
}
#else
for (int ch = 0; ch < depth; ch++) {
- const int32 input_offset = Offset(input_dims, ch, x0, y0, batch);
+ const int32 input_offset = Offset(input_shape, batch, y0, x0, ch);
float x0y0 = input_data[input_offset];
float x1y0 = input_data[input_offset + input_x_offset];
@@ -5429,7 +5111,7 @@ inline void ResizeBilinearKernel2x2(int32 x0, int32 x1, int32 y0, int32 y1,
float x1y1 = input_data[input_offset + input_x_offset + input_y_offset];
// Top left corner.
- const int32 output_offset = Offset(output_dims, ch, x, y, batch);
+ const int32 output_offset = Offset(output_shape, batch, y, x, ch);
output_data[output_offset] = x0y0;
// Top right corner.
@@ -5446,31 +5128,30 @@ inline void ResizeBilinearKernel2x2(int32 x0, int32 x1, int32 y0, int32 y1,
#endif
}
-inline void ResizeBilinear2x2(const float* input_data,
- const Dims<4>& input_dims, float* output_data,
- const Dims<4>& output_dims, int32 batches,
- int32 input_height, int32 input_width,
- int32 depth, int32 output_height,
- int32 output_width) {
+inline void ResizeBilinear2x2(int32 batches, int32 input_height,
+ int32 input_width, int32 depth,
+ int32 output_height, int32 output_width,
+ const RuntimeShape& input_shape,
+ const float* input_data,
+ const RuntimeShape& output_shape,
+ float* output_data) {
for (int b = 0; b < batches; b++) {
for (int y0 = 0, y = 0; y <= output_height - 2; y += 2, y0++) {
for (int x0 = 0, x = 0; x <= output_width - 2; x += 2, x0++) {
int32 x1 = std::min(x0 + 1, input_width - 1);
int32 y1 = std::min(y0 + 1, input_height - 1);
- ResizeBilinearKernel2x2(x0, x1, y0, y1, x, y, depth, b, input_data,
- input_dims, output_data, output_dims);
+ ResizeBilinearKernel2x2(x0, x1, y0, y1, x, y, depth, b, input_shape,
+ input_data, output_shape, output_data);
}
}
}
}
-inline void ResizeBilinearGeneric(const float* input_data,
- const Dims<4>& input_dims, float* output_data,
- const Dims<4>& output_dims, int32 batches,
- int32 input_height, int32 input_width,
- int32 depth, int32 output_height,
- int32 output_width, float height_scale,
- float width_scale) {
+inline void ResizeBilinearGeneric(
+ int32 batches, int32 input_height, int32 input_width, int32 depth,
+ int32 output_height, int32 output_width, float height_scale,
+ float width_scale, const RuntimeShape& input_shape, const float* input_data,
+ const RuntimeShape& output_shape, float* output_data) {
memset(output_data, 0,
batches * output_height * output_width * depth * sizeof(float));
@@ -5487,22 +5168,22 @@ inline void ResizeBilinearGeneric(const float* input_data,
float* output_ptr = &output_data[output_offset];
// Run kernel on the 4 corners of the bilinear resize algorithm.
- int32 input_offset = Offset(input_dims, 0, x0, y0, b);
+ int32 input_offset = Offset(input_shape, b, y0, x0, 0);
float scale = (1 - (input_y - y0)) * (1 - (input_x - x0));
const float* input_ptr = &input_data[input_offset];
ResizeBilinearKernel(input_ptr, depth, scale, output_ptr);
- input_offset = Offset(input_dims, 0, x1, y0, b);
+ input_offset = Offset(input_shape, b, y0, x1, 0);
scale = (1 - (input_y - y0)) * (input_x - x0);
input_ptr = &input_data[input_offset];
ResizeBilinearKernel(input_ptr, depth, scale, output_ptr);
- input_offset = Offset(input_dims, 0, x0, y1, b);
+ input_offset = Offset(input_shape, b, y1, x0, 0);
scale = (input_y - y0) * (1 - (input_x - x0));
input_ptr = &input_data[input_offset];
ResizeBilinearKernel(input_ptr, depth, scale, output_ptr);
- input_offset = Offset(input_dims, 0, x1, y1, b);
+ input_offset = Offset(input_shape, b, y1, x1, 0);
scale = (input_y - y0) * (input_x - x0);
input_ptr = &input_data[input_offset];
ResizeBilinearKernel(input_ptr, depth, scale, output_ptr);
@@ -5513,102 +5194,134 @@ inline void ResizeBilinearGeneric(const float* input_data,
}
}
-inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims,
+template <typename T>
+inline void ResizeBilinearGenericSmallChannel(
+ int32 batches, int32 input_height, int32 input_width, int32 depth,
+ int32 output_height, int32 output_width, float height_scale,
+ float width_scale, const RuntimeShape& input_shape, const T* input_data,
+ const RuntimeShape& output_shape, T* output_data) {
+ memset(output_data, 0,
+ batches * output_height * output_width * depth * sizeof(T));
+
+ T* output_ptr = &output_data[0];
+ for (int b = 0; b < batches; ++b) {
+ for (int y = 0; y < output_height; ++y) {
+ float input_y = y * height_scale;
+ int32 y0 = static_cast<int32>(std::floor(input_y));
+ int32 y1 = std::min(y0 + 1, input_height - 1);
+ for (int x = 0; x < output_width; ++x) {
+ float input_x = x * width_scale;
+ int32 x0 = static_cast<int32>(input_x);
+ int32 x1 = std::min(x0 + 1, input_width - 1);
+
+ int32 input_offset[4] = {Offset(input_shape, b, y0, x0, 0),
+ Offset(input_shape, b, y0, x1, 0),
+ Offset(input_shape, b, y1, x0, 0),
+ Offset(input_shape, b, y1, x1, 0)};
+ float scale[4] = {(1 - (input_y - y0)) * (1 - (input_x - x0)),
+ (1 - (input_y - y0)) * (input_x - x0),
+ (input_y - y0) * (1 - (input_x - x0)),
+ (input_y - y0) * (input_x - x0)};
+
+ for (int d = 0; d < depth; d++) {
+ const T* input_ptr = &input_data[d];
+ *output_ptr++ = static_cast<T>(input_ptr[input_offset[0]] * scale[0] +
+ input_ptr[input_offset[1]] * scale[1] +
+ input_ptr[input_offset[2]] * scale[2] +
+ input_ptr[input_offset[3]] * scale[3]);
+ }
+ }
+ }
+ }
+}
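// Worked example of the four bilinear weights above (illustrative): for
// input_y = 2.25 (y0 = 2) and input_x = 3.75 (x0 = 3),
//   scale = {0.75 * 0.25, 0.75 * 0.75, 0.25 * 0.25, 0.25 * 0.75}
//         = {0.1875, 0.5625, 0.0625, 0.1875},
// which sums to 1, so each output element is a convex combination of its four
// nearest input neighbors.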
+
+inline void ResizeBilinear(const tflite::ResizeBilinearParams& op_params,
+ const RuntimeShape& unextended_input_shape,
+ const float* input_data,
+ const RuntimeShape& output_size_shape,
const int32* output_size_data,
- const Dims<4>& output_size_dims, float* output_data,
- const Dims<4>& output_dims, bool align_corners) {
+ const RuntimeShape& unextended_output_shape,
+ float* output_data) {
gemmlowp::ScopedProfilingLabel label("ResizeBilinear");
- int32 batches = MatchingArraySize(input_dims, 3, output_dims, 3);
- int32 input_height = ArraySize(input_dims, 2);
- int32 input_width = ArraySize(input_dims, 1);
- int32 depth = MatchingArraySize(input_dims, 0, output_dims, 0);
-
- TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 3), 1);
- TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 2), 1);
- TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 1), 1);
- TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 0), 2);
- int32 output_height = output_size_data[Offset(output_size_dims, 0, 0, 0, 0)];
- int32 output_width = output_size_data[Offset(output_size_dims, 1, 0, 0, 0)];
+ TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+ const RuntimeShape input_shape =
+ RuntimeShape::ExtendedShape(4, unextended_input_shape);
+ const RuntimeShape output_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_shape);
+
+ int32 batches = MatchingDim(input_shape, 0, output_shape, 0);
+ int32 input_height = input_shape.Dims(1);
+ int32 input_width = input_shape.Dims(2);
+ int32 depth = MatchingDim(input_shape, 3, output_shape, 3);
+
+ TFLITE_DCHECK_EQ(output_size_shape.FlatSize(), 2);
+ int32 output_height = output_size_data[0];
+ int32 output_width = output_size_data[1];
// Specialize for 2x2 upsample.
- if (!align_corners && output_height == 2 * input_height &&
+ if (!op_params.align_corners && output_height == 2 * input_height &&
output_width == 2 * input_width) {
- ResizeBilinear2x2(input_data, input_dims, output_data, output_dims, batches,
- input_height, input_width, depth, output_height,
- output_width);
+ ResizeBilinear2x2(batches, input_height, input_width, depth, output_height,
+ output_width, input_shape, input_data, output_shape,
+ output_data);
} else {
float height_scale = static_cast<float>(input_height) / output_height;
float width_scale = static_cast<float>(input_width) / output_width;
- if (align_corners && output_height > 1) {
+ if (op_params.align_corners && output_height > 1) {
height_scale = static_cast<float>(input_height - 1) / (output_height - 1);
}
- if (align_corners && output_width > 1) {
+ if (op_params.align_corners && output_width > 1) {
width_scale = static_cast<float>(input_width - 1) / (output_width - 1);
}
- ResizeBilinearGeneric(input_data, input_dims, output_data, output_dims,
- batches, input_height, input_width, depth,
+ ResizeBilinearGeneric(batches, input_height, input_width, depth,
output_height, output_width, height_scale,
- width_scale);
+ width_scale, input_shape, input_data, output_shape,
+ output_data);
}
}
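// Scale selection above, worked through (illustrative): resizing
// input_height = 4 to output_height = 8 uses height_scale = 4 / 8 = 0.5 when
// align_corners is false, and height_scale = (4 - 1) / (8 - 1) = 3/7 when
// align_corners is true, so in the latter case the corner pixels of the input
// and output grids line up exactly.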
-// legacy, for compatibility with old checked-in code
-inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims,
+// TODO(prabhumk): This is not a real quantized bilinear. It does not use int8
+// or int16 arithmetic.
+inline void ResizeBilinear(const tflite::ResizeBilinearParams& op_params,
+ const RuntimeShape& unextended_input_shape,
+ const uint8* input_data,
+ const RuntimeShape& output_size_shape,
const int32* output_size_data,
- const Dims<4>& output_size_dims, float* output_data,
- const Dims<4>& output_dims) {
- ResizeBilinear(input_data, input_dims, output_size_data, output_size_dims,
- output_data, output_dims, /*align_corners=*/false);
-}
-
-template <typename T>
-inline void SpaceToBatchND(const T* input_data, const Dims<4>& input_dims,
- const int32* block_shape_data,
- const Dims<4>& block_shape_dims,
- const int32* paddings_data,
- const Dims<4>& paddings_dims, T* output_data,
- const Dims<4>& output_dims) {
- // Unoptimized - Straight copy from reference ops.
- gemmlowp::ScopedProfilingLabel label("SpaceToBatchND");
-
- const int output_batch_size = ArraySize(output_dims, 3);
- const int output_height = ArraySize(output_dims, 2);
- const int output_width = ArraySize(output_dims, 1);
- const int input_batch_size = ArraySize(input_dims, 3);
- const int input_height = ArraySize(input_dims, 2);
- const int input_width = ArraySize(input_dims, 1);
- const int depth = ArraySize(input_dims, 0);
- const int block_shape_height = block_shape_data[0];
- const int block_shape_width = block_shape_data[1];
- const int padding_top = paddings_data[0];
- const int padding_left = paddings_data[2];
-
- for (int out_b = 0; out_b < output_batch_size; ++out_b) {
- int input_batch = out_b % input_batch_size;
- int shift_w = (out_b / input_batch_size) % block_shape_width;
- int shift_h = (out_b / input_batch_size) / block_shape_width;
- for (int out_h = 0; out_h < output_height; ++out_h) {
- for (int out_w = 0; out_w < output_width; ++out_w) {
- T* out = output_data + Offset(output_dims, 0, out_w, out_h, out_b);
- if (out_h * block_shape_height + shift_h < padding_top ||
- out_h * block_shape_height + shift_h >=
- padding_top + input_height ||
- out_w * block_shape_width + shift_w < padding_left ||
- out_w * block_shape_width + shift_w >= padding_left + input_width) {
- memset(out, 0, depth * sizeof(T));
- } else {
- const T* in =
- input_data +
- Offset(input_dims, 0,
- (out_w * block_shape_width + shift_w) - padding_left,
- (out_h * block_shape_height + shift_h) - padding_top,
- input_batch);
- memcpy(out, in, depth * sizeof(T));
- }
- }
- }
- }
+ const RuntimeShape& unextended_output_shape,
+ uint8* output_data) {
+ gemmlowp::ScopedProfilingLabel label("ResizeBilinear");
+ TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+ const RuntimeShape input_shape =
+ RuntimeShape::ExtendedShape(4, unextended_input_shape);
+ const RuntimeShape output_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_shape);
+
+ int32 batches = MatchingDim(input_shape, 0, output_shape, 0);
+ int32 input_height = input_shape.Dims(1);
+ int32 input_width = input_shape.Dims(2);
+ int32 depth = MatchingDim(input_shape, 3, output_shape, 3);
+
+ TFLITE_DCHECK_EQ(output_size_shape.FlatSize(), 2);
+ int32 output_height = output_size_data[0];
+ int32 output_width = output_size_data[1];
+
+ float height_scale =
+ (op_params.align_corners && output_height > 1)
+ ? (static_cast<float>(input_height - 1) / (output_height - 1))
+ : (static_cast<float>(input_height) / output_height);
+
+ float width_scale =
+ (op_params.align_corners && output_width > 1)
+ ? (static_cast<float>(input_width - 1) / (output_width - 1))
+ : (static_cast<float>(input_width) / output_width);
+
+ ResizeBilinearGenericSmallChannel<uint8>(
+ batches, input_height, input_width, depth, output_height, output_width,
+ height_scale, width_scale, input_shape, input_data, output_shape,
+ output_data);
}
// Helper methods for BatchToSpaceND.
@@ -5633,20 +5346,29 @@ inline void GetIndexRange(int spatial_index_dim, int block_shape_dim,
}
template <typename T>
-inline void BatchToSpaceND(const T* input_data, const Dims<4>& input_dims,
- const int32* block_shape_data,
- const Dims<4>& block_shape_dims,
- const int32* crops_data, const Dims<4>& crops_dims,
- T* output_data, const Dims<4>& output_dims) {
+inline void BatchToSpaceND(
+ const RuntimeShape& unextended_input1_shape, const T* input1_data,
+ const RuntimeShape& unextended_input2_shape, const int32* block_shape_data,
+ const RuntimeShape& unextended_input3_shape, const int32* crops_data,
+ const RuntimeShape& unextended_output_shape, T* output_data) {
gemmlowp::ScopedProfilingLabel label("BatchToSpaceND");
- const int output_batch_size = ArraySize(output_dims, 3);
- const int output_height = ArraySize(output_dims, 2);
- const int output_width = ArraySize(output_dims, 1);
- const int input_batch_size = ArraySize(input_dims, 3);
- const int input_height = ArraySize(input_dims, 2);
- const int input_width = ArraySize(input_dims, 1);
- const int depth = ArraySize(input_dims, 0);
+ TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+ const RuntimeShape input1_shape =
+ RuntimeShape::ExtendedShape(4, unextended_input1_shape);
+ const RuntimeShape output_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_shape);
+
+ const int output_width = output_shape.Dims(2);
+ const int output_height = output_shape.Dims(1);
+ const int output_batch_size = output_shape.Dims(0);
+
+ const int depth = input1_shape.Dims(3);
+ const int input_width = input1_shape.Dims(2);
+ const int input_height = input1_shape.Dims(1);
+ const int input_batch_size = input1_shape.Dims(0);
+
const int block_shape_width = block_shape_data[1];
const int block_shape_height = block_shape_data[0];
const int crops_top = crops_data[0];
@@ -5681,8 +5403,9 @@ inline void BatchToSpaceND(const T* input_data, const Dims<4>& input_dims,
spatial_offset % block_shape_width - crops_left;
TFLITE_DCHECK_GE(out_w, 0);
TFLITE_DCHECK_LT(out_w, output_width);
- T* out = output_data + Offset(output_dims, 0, out_w, out_h, out_batch);
- const T* in = input_data + Offset(input_dims, 0, in_w, in_h, in_batch);
+ T* out = output_data + Offset(output_shape, out_batch, out_h, out_w, 0);
+ const T* in =
+ input1_data + Offset(input1_shape, in_batch, in_h, in_w, 0);
memcpy(out, in, depth * sizeof(T));
}
}
@@ -5705,31 +5428,56 @@ void TypedMemset(void* ptr, T value, size_t num) {
}
}
-template <typename T>
-inline void PadV2(const T* input_data, const Dims<4>& input_dims,
- const std::vector<int>& left_paddings,
- const std::vector<int>& right_paddings, T* output_data,
- const Dims<4>& output_dims, const T pad_value) {
+// There are two versions of pad: Pad and PadV2. In PadV2 there is a second
+// scalar input that provides the padding value. Therefore pad_value_ptr can be
+// equivalent to a simple input1_data. For Pad, it should point to a zero
+// value.
+//
+// Note that two typenames are required, so that the T=P=int32 case selects a
+// specialization distinct from the generic overload whose P is int32.
+template <typename T, typename P>
+inline void PadImpl(const tflite::PadParams& op_params,
+ const RuntimeShape& input_shape, const T* input_data,
+ const P* pad_value_ptr, const RuntimeShape& output_shape,
+ T* output_data) {
gemmlowp::ScopedProfilingLabel label("Pad");
- TFLITE_DCHECK_EQ(left_paddings.size(), 4);
- TFLITE_DCHECK_EQ(right_paddings.size(), 4);
-
- const int output_batch = ArraySize(output_dims, 3);
- const int output_height = ArraySize(output_dims, 2);
- const int output_width = ArraySize(output_dims, 1);
- const int output_depth = ArraySize(output_dims, 0);
-
- const int left_b_padding = left_paddings[3];
- const int left_h_padding = left_paddings[2];
- const int left_w_padding = left_paddings[1];
- const int left_d_padding = left_paddings[0];
-
- const int right_b_padding = right_paddings[3];
- const int right_h_padding = right_paddings[2];
- const int right_w_padding = right_paddings[1];
- const int right_d_padding = right_paddings[0];
-
- const int input_depth = ArraySize(input_dims, 0);
+ const RuntimeShape ext_input_shape =
+ RuntimeShape::ExtendedShape(4, input_shape);
+ const RuntimeShape ext_output_shape =
+ RuntimeShape::ExtendedShape(4, output_shape);
+ TFLITE_DCHECK_LE(op_params.left_padding_count, 4);
+ TFLITE_DCHECK_LE(op_params.right_padding_count, 4);
+
+ // Runtime calls are currently fixed at 4 dimensions. Copy inputs so
+ // we can pad them to 4 dims (yes, we are "padding the padding").
+ std::vector<int> left_padding_copy(4, 0);
+ const int left_padding_extend = 4 - op_params.left_padding_count;
+ for (int i = 0; i < op_params.left_padding_count; ++i) {
+ left_padding_copy[left_padding_extend + i] = op_params.left_padding[i];
+ }
+ std::vector<int> right_padding_copy(4, 0);
+ const int right_padding_extend = 4 - op_params.right_padding_count;
+ for (int i = 0; i < op_params.right_padding_count; ++i) {
+ right_padding_copy[right_padding_extend + i] = op_params.right_padding[i];
+ }
+
+ const int output_batch = ext_output_shape.Dims(0);
+ const int output_height = ext_output_shape.Dims(1);
+ const int output_width = ext_output_shape.Dims(2);
+ const int output_depth = ext_output_shape.Dims(3);
+
+ const int left_b_padding = left_padding_copy[0];
+ const int left_h_padding = left_padding_copy[1];
+ const int left_w_padding = left_padding_copy[2];
+ const int left_d_padding = left_padding_copy[3];
+
+ const int right_b_padding = right_padding_copy[0];
+ const int right_h_padding = right_padding_copy[1];
+ const int right_w_padding = right_padding_copy[2];
+ const int right_d_padding = right_padding_copy[3];
+
+ const int input_depth = ext_input_shape.Dims(3);
+ const T pad_value = *pad_value_ptr;
if (left_b_padding != 0) {
TypedMemset<T>(
@@ -5739,147 +5487,118 @@ inline void PadV2(const T* input_data, const Dims<4>& input_dims,
for (int out_b = left_b_padding; out_b < output_batch - right_b_padding;
++out_b) {
if (left_h_padding != 0) {
- TypedMemset<T>(output_data + Offset(output_dims, 0, 0, 0, out_b),
+ TypedMemset<T>(output_data + Offset(ext_output_shape, out_b, 0, 0, 0),
pad_value, left_h_padding * output_width * output_depth);
}
for (int out_h = left_h_padding; out_h < output_height - right_h_padding;
++out_h) {
if (left_w_padding != 0) {
- TypedMemset<T>(output_data + Offset(output_dims, 0, 0, out_h, out_b),
- pad_value, left_w_padding * output_depth);
+ TypedMemset<T>(
+ output_data + Offset(ext_output_shape, out_b, out_h, 0, 0),
+ pad_value, left_w_padding * output_depth);
}
for (int out_w = left_w_padding; out_w < output_width - right_w_padding;
++out_w) {
if (left_d_padding != 0) {
TypedMemset<T>(
- output_data + Offset(output_dims, 0, out_w, out_h, out_b),
+ output_data + Offset(ext_output_shape, out_b, out_h, out_w, 0),
pad_value, left_d_padding);
}
T* out = output_data +
- Offset(output_dims, left_d_padding, out_w, out_h, out_b);
- const T* in =
- input_data + Offset(input_dims, 0, out_w - left_w_padding,
- out_h - left_h_padding, out_b - left_b_padding);
+ Offset(ext_output_shape, out_b, out_h, out_w, left_d_padding);
+ const T* in = input_data +
+ Offset(ext_input_shape, out_b - left_b_padding,
+ out_h - left_h_padding, out_w - left_w_padding, 0);
memcpy(out, in, input_depth * sizeof(T));
if (right_d_padding != 0) {
TypedMemset<T>(
- output_data + Offset(output_dims, output_depth - right_d_padding,
- out_w, out_h, out_b),
+ output_data + Offset(ext_output_shape, out_b, out_h, out_w,
+ output_depth - right_d_padding),
pad_value, right_d_padding);
}
}
if (right_w_padding != 0) {
- TypedMemset<T>(
- output_data + Offset(output_dims, 0, output_width - right_w_padding,
- out_h, out_b),
- pad_value, right_w_padding * output_depth);
+ TypedMemset<T>(output_data + Offset(ext_output_shape, out_b, out_h,
+ output_width - right_w_padding, 0),
+ pad_value, right_w_padding * output_depth);
}
}
if (right_h_padding != 0) {
TypedMemset<T>(
- output_data +
- Offset(output_dims, 0, 0, output_height - right_h_padding, out_b),
+ output_data + Offset(ext_output_shape, out_b,
+ output_height - right_h_padding, 0, 0),
pad_value, right_h_padding * output_width * output_depth);
}
}
if (right_b_padding != 0) {
TypedMemset<T>(
output_data +
- Offset(output_dims, 0, 0, 0, output_batch - right_b_padding),
+ Offset(ext_output_shape, output_batch - right_b_padding, 0, 0, 0),
pad_value,
right_b_padding * output_height * output_width * output_depth);
}
}
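
Because the runtime loops above are fixed at four dimensions, the incoming left/right padding vectors are front-extended with zeros before use (the "padding the padding" copy near the top of PadImpl). A minimal sketch of that extension with a hypothetical rank-2 padding spec; the helper name here is made up for illustration and is not part of the patch:

#include <vector>

// Front-extend a padding vector to the fixed 4-D layout used by PadImpl.
std::vector<int> ExtendPaddingTo4D(const int* padding, int count) {
  std::vector<int> copy(4, 0);  // synthesized leading dims get no padding
  const int extend = 4 - count;
  for (int i = 0; i < count; ++i) {
    copy[extend + i] = padding[i];
  }
  return copy;
}

// e.g. padding = {1, 2} with count = 2 yields {0, 0, 1, 2}: the synthesized
// batch and height dimensions get no padding, and the two original dimensions
// land in the width and depth slots of the extended shape.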
-// Legacy Pad() method that casts an int32_t to T before padding.
-template <typename T>
-inline void Pad(const T* input_data, const Dims<4>& input_dims,
- const std::vector<int>& left_paddings,
- const std::vector<int>& right_paddings, T* output_data,
- const Dims<4>& output_dims, const int32_t pad_value) {
- const T converted_pad_value = static_cast<T>(pad_value);
- PadV2<T>(input_data, input_dims, left_paddings, right_paddings, output_data,
- output_dims, converted_pad_value);
+template <typename T, typename P>
+inline void Pad(const tflite::PadParams& op_params,
+ const RuntimeShape& input_shape, const T* input_data,
+ const P* pad_value_ptr, const RuntimeShape& output_shape,
+ T* output_data) {
+ PadImpl(op_params, input_shape, input_data, pad_value_ptr, output_shape,
+ output_data);
}
+// The second (pad-value) input can be int32 when, say, the first is uint8.
template <typename T>
-inline void Pad(const T* input_data, const Dims<4>& input_dims,
- const std::vector<int>& left_paddings,
- const std::vector<int>& right_paddings, T* output_data,
- const Dims<4>& output_dims) {
- Pad(input_data, input_dims, left_paddings, right_paddings, output_data,
- output_dims, 0);
+inline void Pad(const tflite::PadParams& op_params,
+ const RuntimeShape& input_shape, const T* input_data,
+ const int32* pad_value_ptr, const RuntimeShape& output_shape,
+ T* output_data) {
+ const T converted_pad_value = static_cast<T>(*pad_value_ptr);
+ PadImpl(op_params, input_shape, input_data, &converted_pad_value,
+ output_shape, output_data);
}
-// UNOPTIMIZED COPY of StridedSlice from reference_ops.h.
-template <typename T>
-inline void StridedSlice(const T* input_data, const Dims<4>& input_dims,
- int begin_mask, int end_mask,
- const std::vector<int>& start_indices,
- const std::vector<int>& stop_indices,
- const std::vector<int>& strides, T* output_data,
- const Dims<4>& output_dims) {
- TFLITE_DCHECK_EQ(start_indices.size(), 4);
- TFLITE_DCHECK_EQ(stop_indices.size(), 4);
- TFLITE_DCHECK_EQ(strides.size(), 4);
- const int start_b = strided_slice::StartForAxis(begin_mask, start_indices,
- strides, input_dims.sizes, 3);
- const int stop_b = strided_slice::StopForAxis(end_mask, stop_indices, strides,
- input_dims.sizes, 3);
- const int start_h = strided_slice::StartForAxis(begin_mask, start_indices,
- strides, input_dims.sizes, 2);
- const int stop_h = strided_slice::StopForAxis(end_mask, stop_indices, strides,
- input_dims.sizes, 2);
- const int start_w = strided_slice::StartForAxis(begin_mask, start_indices,
- strides, input_dims.sizes, 1);
- const int stop_w = strided_slice::StopForAxis(end_mask, stop_indices, strides,
- input_dims.sizes, 1);
- const int start_d = strided_slice::StartForAxis(begin_mask, start_indices,
- strides, input_dims.sizes, 0);
- const int stop_d = strided_slice::StopForAxis(end_mask, stop_indices, strides,
- input_dims.sizes, 0);
-
- T* out_ptr = output_data;
- for (int in_b = start_b;
- !strided_slice::LoopCondition(in_b, stop_b, strides[3]);
- in_b += strides[3]) {
- for (int in_h = start_h;
- !strided_slice::LoopCondition(in_h, stop_h, strides[2]);
- in_h += strides[2]) {
- for (int in_w = start_w;
- !strided_slice::LoopCondition(in_w, stop_w, strides[1]);
- in_w += strides[1]) {
- for (int in_d = start_d;
- !strided_slice::LoopCondition(in_d, stop_d, strides[0]);
- in_d += strides[0]) {
- *out_ptr++ = input_data[Offset(input_dims, in_d, in_w, in_h, in_b)];
- }
- }
- }
- }
+// This version avoids conflicting template matching.
+template <>
+inline void Pad(const tflite::PadParams& op_params,
+ const RuntimeShape& input_shape, const int32* input_data,
+ const int32* pad_value_ptr, const RuntimeShape& output_shape,
+ int32* output_data) {
+ PadImpl(op_params, input_shape, input_data, pad_value_ptr, output_shape,
+ output_data);
}
template <typename T>
-inline void Slice(const T* input_data, const Dims<4>& input_dims,
- const std::vector<int>& begin, const std::vector<int>& size,
- T* output_data, const Dims<4>& output_dims) {
- // TODO(dkalenichenko): This op only supports 4D tensors.
- TFLITE_DCHECK_EQ(begin.size(), 4);
- TFLITE_DCHECK_EQ(size.size(), 4);
- const int start_b = begin[3];
- const int stop_b =
- size[3] == -1 ? input_dims.sizes[3] - start_b : start_b + size[3];
- const int start_h = begin[2];
- const int stop_h =
- size[2] == -1 ? input_dims.sizes[2] - start_h : start_h + size[2];
- const int start_w = begin[1];
- const int stop_w =
- size[1] == -1 ? input_dims.sizes[1] - start_w : start_w + size[1];
- const int start_d = begin[0];
- const int stop_d =
- size[0] == -1 ? input_dims.sizes[0] - start_d : start_d + size[0];
+inline void Slice(const tflite::SliceParams& op_params,
+ const RuntimeShape& input_shape, const T* input_data,
+ const RuntimeShape& output_shape, T* output_data) {
+ gemmlowp::ScopedProfilingLabel label("Slice");
+ const RuntimeShape ext_shape = RuntimeShape::ExtendedShape(4, input_shape);
+ // TODO(dkalenichenko): This op only supports 4D tensors or smaller.
+ TFLITE_DCHECK_LE(op_params.begin_count, 4);
+ TFLITE_DCHECK_LE(op_params.size_count, 4);
+ const int begin_count = op_params.begin_count;
+ const int size_count = op_params.size_count;
+ // We front-pad the begin and size vectors.
+ const int start_b = 4 - begin_count > 0 ? 0 : op_params.begin[0];
+ const int stop_b = (4 - size_count > 0 || op_params.size[0] == -1)
+ ? ext_shape.Dims(0) - start_b
+ : start_b + op_params.size[0];
+ const int start_h = begin_count < 3 ? 0 : op_params.begin[begin_count - 3];
+ const int stop_h = (size_count < 3 || op_params.size[size_count - 3] == -1)
+ ? ext_shape.Dims(1) - start_h
+ : start_h + op_params.size[size_count - 3];
+ const int start_w = begin_count < 2 ? 0 : op_params.begin[begin_count - 2];
+ const int stop_w = (size_count < 2 || op_params.size[size_count - 2] == -1)
+ ? ext_shape.Dims(2) - start_w
+ : start_w + op_params.size[size_count - 2];
+ const int start_d = begin_count < 1 ? 0 : op_params.begin[begin_count - 1];
+ const int stop_d = (size_count < 1 || op_params.size[size_count - 1] == -1)
+ ? ext_shape.Dims(3) - start_d
+ : start_d + op_params.size[size_count - 1];
T* out_ptr = output_data;
for (int in_b = start_b; in_b < stop_b; ++in_b) {
@@ -5887,7 +5606,7 @@ inline void Slice(const T* input_data, const Dims<4>& input_dims,
for (int in_w = start_w; in_w < stop_w; ++in_w) {
const int len = stop_d - start_d;
memcpy(out_ptr,
- input_data + Offset(input_dims, start_d, in_w, in_h, in_b),
+ input_data + Offset(ext_shape, in_b, in_h, in_w, start_d),
len * sizeof(T));
out_ptr += len;
}
@@ -5896,243 +5615,112 @@ inline void Slice(const T* input_data, const Dims<4>& input_dims,
}
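
The front-padding of the begin/size vectors above is easiest to see on a concrete case. A worked trace with hypothetical values (illustrative only, not part of the patch):

// A rank-2 input of shape {6, 8} is extended to {1, 1, 6, 8}. With
// begin = {2, 1}, size = {3, -1} (begin_count = size_count = 2):
//
//   start_b = 0, stop_b = 1          // synthesized batch dim: full range
//   start_h = 0, stop_h = 1          // synthesized height dim: full range
//   start_w = 2, stop_w = 2 + 3 = 5
//   start_d = 1, stop_d = 8 - 1 = 7  // size of -1 means "to the end"
//
// i.e. rows 2..4 and columns 1..6 of the 6x8 input are copied, giving a
// {3, 7} output.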
template <typename T>
-inline void Mean(const T* input_data, const Dims<4>& input_dims,
- const std::vector<int>& reduction_indices, T* output_data,
- const Dims<4>& output_dims) {
- gemmlowp::ScopedProfilingLabel label("Mean");
- const int output_batch = ArraySize(output_dims, 3);
- const int output_height = ArraySize(output_dims, 2);
- const int output_width = ArraySize(output_dims, 1);
- const int output_depth = ArraySize(output_dims, 0);
-
- const int input_height = ArraySize(input_dims, 2);
- const int input_width = ArraySize(input_dims, 1);
-
- // The current implementation only supports simultaneous reduction over
- // width and height.
- TFLITE_DCHECK_EQ(reduction_indices.size(), 2);
- TFLITE_DCHECK((reduction_indices[0] == 1 && reduction_indices[1] == 2) ||
- (reduction_indices[0] == 2 && reduction_indices[1] == 1));
- TFLITE_DCHECK_EQ(output_height, 1);
- TFLITE_DCHECK_EQ(output_width, 1);
-
- for (int out_b = 0; out_b < output_batch; ++out_b) {
- for (int out_d = 0; out_d < output_depth; ++out_d) {
- float value = 0;
- for (int in_h = 0; in_h < input_height; ++in_h) {
- for (int in_w = 0; in_w < input_width; ++in_w) {
- value += input_data[Offset(input_dims, out_d, in_w, in_h, out_b)];
- }
- }
- output_data[Offset(output_dims, out_d, 0, 0, out_b)] =
- value / (input_width * input_height);
- }
- }
-}
-
-template <typename T>
-void GenericBroadcastSub(const T* input1_data, const Dims<4>& input1_dims,
- const T* input2_data, const Dims<4>& input2_dims,
- T* output_data, const Dims<4>& output_dims) {
- gemmlowp::ScopedProfilingLabel label("GenericBroadcastSub");
-
- NdArrayDesc<4> desc1;
- NdArrayDesc<4> desc2;
- NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
-
- // In Tensorflow, the dimensions are canonically named (batch_number, row,
- // col, channel), with extents (batches, height, width, depth), with the
- // trailing dimension changing most rapidly (channels has the smallest stride,
- // typically 1 element).
- //
- // In generated C code, we store arrays with the dimensions reversed. The
- // first dimension has smallest stride.
- //
- // We name our variables by their Tensorflow convention, but generate C code
- // nesting loops such that the innermost loop has the smallest stride for the
- // best cache behavior.
- for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
- for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
- for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
- for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
- output_data[Offset(output_dims, c, x, y, b)] =
- input1_data[SubscriptToIndex(desc1, c, x, y, b)] -
- input2_data[SubscriptToIndex(desc2, c, x, y, b)];
- }
- }
- }
- }
-}
-
-template <typename T>
-void Sub(const T* input1_data, const Dims<4>& input1_dims, const T* input2_data,
- const Dims<4>& input2_dims, T* output_data,
- const Dims<4>& output_dims) {
- gemmlowp::ScopedProfilingLabel label("Sub");
-
- auto input1_map = MapAsVector(input1_data, input1_dims);
- auto input2_map = MapAsVector(input2_data, input2_dims);
- auto output_map = MapAsVector(output_data, output_dims);
- if (AreSameDims(input1_dims, input2_dims)) {
- output_map.array() = input1_map.array() - input2_map.array();
- } else if (FlatSize(input1_dims) == 1) {
- auto scalar = input1_data[0];
- output_map.array() = scalar - input2_map.array();
- } else if (FlatSize(input2_dims) == 1) {
- auto scalar = input2_data[0];
- output_map.array() = input1_map.array() - scalar;
- } else {
- GenericBroadcastSub(input1_data, input1_dims, input2_data, input2_dims,
- output_data, output_dims);
- }
-}
-
-template <typename T>
-void TensorFlowMinimum(const T* input1_data, const Dims<4>& input1_dims,
- const T* input2_data, T* output_data,
- const Dims<4>& output_dims) {
+void Minimum(const RuntimeShape& input1_shape, const T* input1_data,
+ const T* input2_data, const RuntimeShape& output_shape,
+ T* output_data) {
gemmlowp::ScopedProfilingLabel label("TensorFlowMinimum");
- auto input1_map = MapAsVector(input1_data, input1_dims);
- auto output_map = MapAsVector(output_data, output_dims);
+ auto input1_map = MapAsVector(input1_data, input1_shape);
+ auto output_map = MapAsVector(output_data, output_shape);
auto min_value = input2_data[0];
output_map.array() = input1_map.array().min(min_value);
}
+// Convenience version that allows, for example, generated-code calls to be
+// the same as other binary ops.
template <typename T>
-void TensorFlowMaximum(const T* input1_data, const Dims<4>& input1_dims,
- const T* input2_data, T* output_data,
- const Dims<4>& output_dims) {
+inline void Minimum(const RuntimeShape& input1_shape, const T* input1_data,
+ const RuntimeShape&, const T* input2_data,
+ const RuntimeShape& output_shape, T* output_data) {
+ // Drop shape of second input: not needed.
+ Minimum(input1_shape, input1_data, input2_data, output_shape, output_data);
+}
+
+template <typename T>
+void Maximum(const RuntimeShape& input1_shape, const T* input1_data,
+ const T* input2_data, const RuntimeShape& output_shape,
+ T* output_data) {
gemmlowp::ScopedProfilingLabel label("TensorFlowMaximum");
- auto input1_map = MapAsVector(input1_data, input1_dims);
- auto output_map = MapAsVector(output_data, output_dims);
+ auto input1_map = MapAsVector(input1_data, input1_shape);
+ auto output_map = MapAsVector(output_data, output_shape);
auto max_value = input2_data[0];
output_map.array() = input1_map.array().max(max_value);
}
-template <typename T1, typename T2, typename T3>
-void ArgMax(const T3* axis, const T1* input_data, const Dims<4>& input_dims,
- T2* output_data, const Dims<4>& output_dims) {
- gemmlowp::ScopedProfilingLabel label("ArgMax");
-
- // The current ArgMax implemention can only determine the index of the maximum
- // value in the last dimension. So the axis argument is ignored.
-
- // For ArgMax, the number of output dimensions = (number of input dimensions -
- // 1). For the sake of simplicity, the output dimensions are equal to the
- // input dimensions here. We enforce the constraint that the last dimension
- // must always be 1.
- TFLITE_DCHECK_EQ(ArraySize(output_dims, 0), 1);
- const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims);
- const int depth = ArraySize(input_dims, 0);
- for (int i = 0; i < outer_size; ++i) {
- auto max_value = *input_data;
- ++input_data;
- int max_index = 0;
- for (int d = 1; d < depth; ++d) {
- const auto& curr_value = *input_data;
- if (curr_value > max_value) {
- max_value = curr_value;
- max_index = d;
- }
- ++input_data;
- }
- *output_data = max_index;
- ++output_data;
- }
-}
-
+// Convenience version that allows, for example, generated-code calls to be
+// the same as other binary ops.
template <typename T>
-void Transpose(const T* input, const Dims<4>& input_dims, T* output,
- const Dims<4>& output_dims, const int* permuted_axes) {
- int out_sizes[4];
- // Compute the inverse permutation array so we can do an output centered
- // transpose. Also, check to make sure output_dims is matching input_dims.
- for (int k = 0; k < 4; k++) {
- out_sizes[k] =
- MatchingArraySize(input_dims, permuted_axes[k], output_dims, k);
- }
-
- // Naive transpose loop (iterate on output index and compute input index).
- int o[4]; // loop index (on output).
- int i[4];
- for (o[3] = 0; o[3] < out_sizes[3]; o[3]++) {
- i[permuted_axes[3]] = o[3];
- for (o[2] = 0; o[2] < out_sizes[2]; o[2]++) {
- i[permuted_axes[2]] = o[2];
- for (o[1] = 0; o[1] < out_sizes[1]; o[1]++) {
- i[permuted_axes[1]] = o[1];
- for (o[0] = 0; o[0] < out_sizes[0]; o[0]++) {
- i[permuted_axes[0]] = o[0];
- output[Offset(output_dims, o)] = input[Offset(input_dims, i)];
- }
- }
- }
- }
+inline void Maximum(const RuntimeShape& input1_shape, const T* input1_data,
+ const RuntimeShape&, const T* input2_data,
+ const RuntimeShape& output_shape, T* output_data) {
+ // Drop shape of second input: not needed.
+ Maximum(input1_shape, input1_data, input2_data, output_shape, output_data);
}
-inline void TransposeConv(const float* input_data, const Dims<4>& input_dims,
- const float* filter_data, const Dims<4>& filter_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, float* output_data,
- const Dims<4>& output_dims) {
- gemmlowp::ScopedProfilingLabel label("TransposeConv");
- // THIS FUNCTION IS A COPY FROM reference_ops.h.
- // To optimize, start by using the conv code with transposed weights for the
- // case of stride_height = stride_width = 1.
- const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
- const int input_depth = MatchingArraySize(input_dims, 0, filter_dims, 3);
- const int output_depth = MatchingArraySize(filter_dims, 0, output_dims, 0);
- const int input_height = ArraySize(input_dims, 2);
- const int input_width = ArraySize(input_dims, 1);
- const int filter_height = ArraySize(filter_dims, 2);
- const int filter_width = ArraySize(filter_dims, 1);
- const int output_height = ArraySize(output_dims, 2);
- const int output_width = ArraySize(output_dims, 1);
-
- // Although transpose convolution simplifies to convolution with transposed
- // weights for strides of 1, non-unitary striding complicates matters. To
- // keep this reference implementation as clear as possible, we use a "scatter"
- // access pattern, where we loop through all the input elements, computing
- // their influence on the output, rather than looping through the output
- // elements in the typical "gather" access pattern of a conv. We therefore
- // must initialize the output array to zero.
- for (int batch = 0; batch < batches; ++batch) {
- for (int out_y = 0; out_y < output_height; ++out_y) {
- for (int out_x = 0; out_x < output_width; ++out_x) {
- for (int out_channel = 0; out_channel < output_depth; ++out_channel) {
- output_data[Offset(output_dims, out_channel, out_x, out_y, batch)] =
- 0.0f;
- }
- }
- }
- }
-
- // Loop through input elements one at a time.
+template <typename T>
+void TransposeIm2col(const ConvParams& params, uint8 zero_byte,
+ const RuntimeShape& input_shape, const T* input_data,
+ const RuntimeShape& filter_shape,
+ const RuntimeShape& output_shape, T* im2col_data) {
+ gemmlowp::ScopedProfilingLabel label("TransposeIm2col");
+ const int stride_width = params.stride_width;
+ const int stride_height = params.stride_height;
+ const int pad_width = params.padding_values.width;
+ const int pad_height = params.padding_values.height;
+ TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK(im2col_data);
+
+ const int batches = MatchingDim(input_shape, 0, output_shape, 0);
+ const int input_height = input_shape.Dims(1);
+ const int input_width = input_shape.Dims(2);
+ const int input_depth = MatchingDim(input_shape, 3, filter_shape, 0);
+ const int filter_height = filter_shape.Dims(1);
+ const int filter_width = filter_shape.Dims(2);
+ const int output_height = output_shape.Dims(1);
+ const int output_width = output_shape.Dims(2);
+ MatchingDim(output_shape, 3, filter_shape, 3); // output_depth
+
+ // Construct the MxN sized im2col matrix.
+ // The rows, M, are sub-ordered B x H x W
+ const RuntimeShape row_shape({1, batches, output_height, output_width});
+ // The columns, N, are sub-ordered Kh x Kw x Din
+ const RuntimeShape col_shape({1, filter_height, filter_width, input_depth});
+ // Use dimensions M and N to construct dims for indexing directly into im2col
+ const RuntimeShape im2col_shape(
+ {1, 1, row_shape.FlatSize(), col_shape.FlatSize()});
+
+ // Build the im2col matrix by looping through all the input pixels,
+ // computing their influence on the output, rather than looping through all
+ // the output pixels. We therefore must initialize the im2col array to zero.
+ // This is potentially inefficient because we subsequently overwrite bytes
+ // set here. However, in practice memset is very fast and its cost is negligible.
+ memset(im2col_data, zero_byte, im2col_shape.FlatSize() * sizeof(T));
+
+ // Loop through the output batches
for (int batch = 0; batch < batches; ++batch) {
+ // Loop through input pixels one at a time.
for (int in_y = 0; in_y < input_height; ++in_y) {
for (int in_x = 0; in_x < input_width; ++in_x) {
- for (int in_channel = 0; in_channel < input_depth; ++in_channel) {
- // Loop through the output elements it will influence
- const int out_x_origin = (in_x * stride_width) - pad_width;
- const int out_y_origin = (in_y * stride_height) - pad_height;
- for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
+ // Loop through the output pixels it will influence
+ const int out_x_origin = (in_x * stride_width) - pad_width;
+ const int out_y_origin = (in_y * stride_height) - pad_height;
+ for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
+ const int out_y = out_y_origin + filter_y;
+ // Is output pixel within height bounds?
+ if ((out_y >= 0) && (out_y < output_height)) {
for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
- for (int out_channel = 0; out_channel < output_depth;
- ++out_channel) {
- // Compute output element location
- const int out_x = out_x_origin + filter_x;
- const int out_y = out_y_origin + filter_y;
- // We cannot accumulate out of bounds
- if ((out_x >= 0) && (out_x < output_width) && (out_y >= 0) &&
- (out_y < output_height)) {
- float input_value = input_data[Offset(input_dims, in_channel,
- in_x, in_y, batch)];
- float filter_value =
- filter_data[Offset(filter_dims, out_channel, filter_x,
- filter_y, in_channel)];
- output_data[Offset(output_dims, out_channel, out_x, out_y,
- batch)] += input_value * filter_value;
- }
+ const int out_x = out_x_origin + filter_x;
+ // Is output pixel within width bounds?
+ if ((out_x >= 0) && (out_x < output_width)) {
+ // Copy the input elements of this pixel
+ T const* src =
+ input_data + Offset(input_shape, batch, in_y, in_x, 0);
+ int row_offset = Offset(row_shape, 0, batch, out_y, out_x);
+ int col_offset = Offset(col_shape, 0, filter_y, filter_x, 0);
+ T* dst = im2col_data +
+ Offset(im2col_shape, 0, 0, row_offset, col_offset);
+ memcpy(dst, src, input_depth * sizeof(T));
}
}
}
@@ -6142,6 +5730,29 @@ inline void TransposeConv(const float* input_data, const Dims<4>& input_dims,
}
}
+inline void TransposeConv(
+ const ConvParams& params, const RuntimeShape& input_shape,
+ const float* input_data, const RuntimeShape& filter_shape,
+ const float* filter_data, const RuntimeShape& output_shape,
+ float* output_data, const RuntimeShape& im2col_shape, float* im2col_data) {
+ gemmlowp::ScopedProfilingLabel label("TransposeConv");
+
+ // Note we could use transposed weights with forward conv for unstrided
+ // cases. But we are already getting good performance with this code as-is.
+ TFLITE_DCHECK(im2col_data);
+ TransposeIm2col(params, 0, input_shape, input_data, filter_shape,
+ output_shape, im2col_data);
+
+ const auto im2col_matrix_map =
+ MapAsMatrixWithLastDimAsRows(im2col_data, im2col_shape);
+ const auto filter_matrix_map =
+ MapAsMatrixWithFirstDimAsCols(filter_data, filter_shape);
+ auto output_matrix_map =
+ MapAsMatrixWithLastDimAsRows(output_data, output_shape);
+
+ Gemm(filter_matrix_map.transpose(), im2col_matrix_map, &output_matrix_map);
+}
+
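
The GEMM above multiplies the transposed filter map by the im2col matrix built in TransposeIm2col. A worked sizing example with hypothetical dimensions, assuming the usual row/column conventions of the matrix-map helpers (illustrative only, not part of the patch):

// Input 1x8x8x3, 4x4 filters, 16 output channels, stride 2, SAME padding,
// so the output is 1x16x16x16.
//
//   im2col logical shape {1, 1, M, N} with
//       M = batches * out_h * out_w        = 1 * 16 * 16 = 256
//       N = filter_h * filter_w * in_depth = 4 * 4 * 3   = 48
//
//   im2col map (last dim as rows):   48 x 256
//   filter map (first dim as cols):  48 x 16, transposed to 16 x 48
//   output map (last dim as rows):   16 x 256
//
// so each column of the result holds the 16 output channels of one output
// pixel.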
} // namespace optimized_ops
} // namespace tflite
@@ -6150,4 +5761,4 @@ inline void TransposeConv(const float* input_data, const Dims<4>& input_dims,
#pragma GCC diagnostic pop
#endif
-#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_OPS_H_
+#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_OPTIMIZED_OPS_H_
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h b/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h
index d570dadd86..f87760a6c3 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h
@@ -17,7 +17,11 @@ limitations under the License.
// TODO(ghodrat): Remove this header file and the dependency to internal data
// structure.
-#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+
+#if defined(_MSC_VER)
+#define __restrict__ __restrict
+#endif
#ifndef USE_NEON
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
@@ -82,6 +86,14 @@ void NeonBatchVectorBatchVectorDotProduct(const float* vector1,
int n_batch, float* result,
int result_stride);
+// Cwise product of a vector and a batch-vector.
+void PortableVectorBatchVectorCwiseProduct(const float* vector, int v_size,
+ const float* batch_vector,
+ int n_batch, float* result);
+void NeonVectorBatchVectorCwiseProduct(const float* vector, int v_size,
+ const float* batch_vector, int n_batch,
+ float* result);
+
// Cwise product and accumulate of a vector and a batch-vector. Since it's a MAC
// operation, the assumption here is that result array is initialized to valid
// values.
@@ -105,6 +117,10 @@ void PortableClipVector(const float* vector, int v_size, float abs_limit,
void NeonClipVector(const float* vector, int v_size, float abs_limit,
float* result);
+// Add another vector for each batch in the batch vector.
+void PortableVectorBatchVectorAdd(const float* vector, int v_size, int n_batch,
+ float* batch_vector);
+
// Batch vector initialization with another vector.
void PortableVectorBatchVectorAssign(const float* vector, int v_size,
int n_batch, float* batch_vector);
@@ -124,9 +140,19 @@ void PortableCopyVector(const float* vector, int v_size, float* result);
// Fill vector with 0.f.
void PortableZeroVector(float* vector, int v_size);
+// Multiply all elements of vector with a scalar.
+void PortableVectorScalarMultiply(const int8_t* vector, int v_size, float scale,
+ float* result);
+void NeonVectorScalarMultiply(const int8_t* vector, int v_size, float scale,
+ float* result);
+
// Limit a float input f between +abs_limit and -abs_limit.
float PortableClip(float f, float abs_limit);
+// Check if all entries of a vector are zero.
+bool PortableIsZeroVector(const float* vector, int v_size);
+bool NeonIsZeroVector(const float* vector, int v_size);
+
// Symmetric quantizer.
void PortableSymmetricQuantizeFloats(const float* values, const int size,
int8_t* quantized_values, float* min,
@@ -150,6 +176,10 @@ void PortableReductionSumVector(const float* input_vector, float* output_vector,
void NeonReductionSumVector(const float* input_vector, float* output_vector,
int output_size, int reduction_size);
+void PortableMeanStddevNormalization(const float* input_vector,
+ float* output_vector, int v_size,
+ int n_batch, float normalization_epsilon);
+
} // namespace tensor_utils
} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/internal/quantization_util.cc b/tensorflow/contrib/lite/kernels/internal/quantization_util.cc
index b0951aac8c..544ef16ce1 100644
--- a/tensorflow/contrib/lite/kernels/internal/quantization_util.cc
+++ b/tensorflow/contrib/lite/kernels/internal/quantization_util.cc
@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+
#include <algorithm>
#include <cmath>
#include <limits>
@@ -22,6 +23,32 @@ limitations under the License.
namespace tflite {
+namespace {
+// These constants are used to manipulate the binary representation of doubles.
+// Double-precision binary64 floating point format is:
+// Bit | 63 | 62-52 | 51-0 |
+// | Sign | Exponent | Fraction |
+// To avoid 64-bit integers as much as possible, I break this into high and
+// low 32-bit chunks. High is:
+// Bit | 31 | 30-20 | 19-0 |
+// | Sign | Exponent | High Fraction |
+// Low is:
+// Bit | 31-0 |
+// | Low Fraction |
+// We then access the components through logical bit-wise operations to
+// extract the parts needed, with the positions and masks derived from the
+// layout shown above.
+constexpr uint64_t kSignMask = 0x8000000000000000LL;
+constexpr uint64_t kExponentMask = 0x7ff0000000000000LL;
+constexpr int32_t kExponentShift = 52;
+constexpr int32_t kExponentBias = 1023;
+constexpr uint32_t kExponentIsBadNum = 0x7ff;
+constexpr uint64_t kFractionMask = 0x000fffffffc00000LL;
+constexpr uint32_t kFractionShift = 22;
+constexpr uint32_t kFractionRoundingMask = 0x003fffff;
+constexpr uint32_t kFractionRoundingThreshold = 0x00200000;
+} // namespace
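
As a sanity check of the layout described above, 1.0 encodes as 0x3FF0000000000000: sign 0, biased exponent 0x3FF (1023), fraction 0 (the leading 1 bit is implicit). A minimal standalone sketch that pulls those fields out (illustrative only, not part of the patch; it reads the full 52-bit fraction rather than the truncated kFractionMask above):

#include <cstdint>
#include <cstdio>
#include <cstring>

int main() {
  const double value = 1.0;
  uint64_t bits;
  std::memcpy(&bits, &value, sizeof(bits));  // well-defined bit inspection
  const uint64_t sign = bits >> 63;
  const uint32_t exponent = static_cast<uint32_t>((bits >> 52) & 0x7ff);
  const uint64_t fraction = bits & 0x000fffffffffffffULL;
  // Prints: sign=0 exponent=1023 fraction=0, i.e. (-1)^0 * 1.0 * 2^(1023-1023).
  std::printf("sign=%llu exponent=%u fraction=%llu\n",
              static_cast<unsigned long long>(sign), exponent,
              static_cast<unsigned long long>(fraction));
  return 0;
}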
+
void QuantizeMultiplier(double double_multiplier, int32_t* quantized_multiplier,
int* shift) {
if (double_multiplier == 0.) {
@@ -29,8 +56,16 @@ void QuantizeMultiplier(double double_multiplier, int32_t* quantized_multiplier,
*shift = 0;
return;
}
+#ifdef TFLITE_EMULATE_FLOAT
+ // If we're trying to avoid the use of floating-point instructions (for
+ // example on microcontrollers) then use an alternative implementation
+ // that only requires integer and bitwise operations. To enable this, you
+ // need to set the define during the build process for your platform.
+ int64_t q_fixed = IntegerFrExp(double_multiplier, shift);
+#else // TFLITE_EMULATE_FLOAT
const double q = std::frexp(double_multiplier, shift);
auto q_fixed = static_cast<int64_t>(TfLiteRound(q * (1ll << 31)));
+#endif // TFLITE_EMULATE_FLOAT
TFLITE_CHECK(q_fixed <= (1ll << 31));
if (q_fixed == (1ll << 31)) {
q_fixed /= 2;
@@ -48,15 +83,172 @@ void QuantizeMultiplierGreaterThanOne(double double_multiplier,
TFLITE_CHECK_GE(*left_shift, 0);
}
-void QuantizeMultiplierSmallerThanOne(double double_multiplier,
- int32_t* quantized_multiplier,
- int* right_shift) {
+void QuantizeMultiplierSmallerThanOneExp(double double_multiplier,
+ int32_t* quantized_multiplier,
+ int* left_shift) {
TFLITE_CHECK_LT(double_multiplier, 1.);
TFLITE_CHECK_GT(double_multiplier, 0.);
int shift;
QuantizeMultiplier(double_multiplier, quantized_multiplier, &shift);
TFLITE_CHECK_LE(shift, 0);
- *right_shift = -shift;
+ *left_shift = shift;
+}
+
+int64_t IntegerFrExp(double input, int* shift) {
+ // Make sure our assumptions about the double layout hold.
+ TFLITE_CHECK_EQ(8, sizeof(double));
+
+ // We want to access the bits of the input double value directly, which is
+ // tricky to do safely, so use a union to handle the casting.
+ union {
+ double double_value;
+ uint64_t double_as_uint;
+ } cast_union;
+ cast_union.double_value = input;
+ const uint64_t u = cast_union.double_as_uint;
+
+ // If the bitfield is all zeros apart from the sign bit, this is a normalized
+ // zero value, so return standard values for this special case.
+ if ((u & ~kSignMask) == 0) {
+ *shift = 0;
+ return 0;
+ }
+
+ // Deal with NaNs and Infs, which are always indicated with a fixed pattern in
+ // the exponent, and distinguished by whether the fractions are zero or
+ // non-zero.
+ const uint32_t exponent_part = ((u & kExponentMask) >> kExponentShift);
+ if (exponent_part == kExponentIsBadNum) {
+ *shift = std::numeric_limits<int>::max();
+ if (u & kFractionMask) {
+ // NaN, so just return zero (with the exponent set to INT_MAX).
+ return 0;
+ } else {
+ // Infinity, so return +/- INT_MAX.
+ if (u & kSignMask) {
+ return std::numeric_limits<int64_t>::min();
+ } else {
+ return std::numeric_limits<int64_t>::max();
+ }
+ }
+ }
+
+ // The shift is fairly easy to extract from the high bits of the double value,
+ // just by masking it out and applying a bias. The std::frexp() implementation
+ // always returns values between 0.5 and 1.0 though, whereas the exponent
+ // assumes 1.0 to 2.0 is the standard range, so I add on one to match that
+ // interface.
+ *shift = (exponent_part - kExponentBias) + 1;
+
+ // There's an implicit high bit in the double format definition, so make sure
+ // we include that at the top, and then reconstruct the rest of the fractional
+ // value from the remaining fragments.
+ int64_t fraction = 0x40000000 + ((u & kFractionMask) >> kFractionShift);
+
+ // We're cutting off some bits at the bottom, so to exactly match the standard
+ // frexp implementation here we'll apply rounding by adding one to the least
+ // significant bit of the result if the discarded portion is over half of the
+ // maximum.
+ if ((u & kFractionRoundingMask) > kFractionRoundingThreshold) {
+ fraction += 1;
+ }
+ // Negate the fraction if the sign bit was set.
+ if (u & kSignMask) {
+ fraction *= -1;
+ }
+
+ return fraction;
+}
+
+double DoubleFromFractionAndShift(int64_t fraction, int shift) {
+ union {
+ double double_value;
+ uint64_t double_as_uint;
+ } result;
+
+ // Detect NaNs and infinities.
+ if (shift == std::numeric_limits<int>::max()) {
+ if (fraction == 0) {
+ return NAN;
+ } else if (fraction > 0) {
+ return INFINITY;
+ } else {
+ return -INFINITY;
+ }
+ }
+
+ // Return a normalized zero for a zero fraction.
+ if (fraction == 0) {
+ result.double_as_uint = 0;
+ return result.double_value;
+ }
+
+ bool is_negative = (fraction < 0);
+ int64_t encoded_fraction = is_negative ? -fraction : fraction;
+ int64_t encoded_shift = (shift - 1);
+ while (encoded_fraction < 0x40000000) {
+ encoded_fraction *= 2;
+ encoded_shift -= 1;
+ }
+ while (encoded_fraction > 0x80000000) {
+ encoded_fraction /= 2;
+ encoded_shift += 1;
+ }
+ encoded_fraction -= 0x40000000;
+ if (encoded_shift < -1022) {
+ encoded_shift = -1023;
+ } else if (encoded_shift > 1022) {
+ encoded_shift = 1023;
+ }
+ encoded_shift += kExponentBias;
+ uint64_t encoded_sign = is_negative ? kSignMask : 0;
+ result.double_as_uint = encoded_sign | (encoded_shift << kExponentShift) |
+ (encoded_fraction << kFractionShift);
+ return result.double_value;
+}
+
+double IntegerDoubleMultiply(double a, double b) {
+ int a_shift;
+ const int64_t a_fraction = IntegerFrExp(a, &a_shift);
+ int b_shift;
+ const int64_t b_fraction = IntegerFrExp(b, &b_shift);
+ // Detect NaNs and infinities.
+ if (a_shift == std::numeric_limits<int>::max() ||
+ (b_shift == std::numeric_limits<int>::max())) {
+ return NAN;
+ }
+ const int result_shift = a_shift + b_shift + 1;
+ const int64_t result_fraction = (a_fraction * b_fraction) >> 32;
+ return DoubleFromFractionAndShift(result_fraction, result_shift);
+}
+
+int IntegerDoubleCompare(double a, double b) {
+ int a_shift;
+ const int64_t a_fraction = IntegerFrExp(a, &a_shift);
+ int b_shift;
+ const int64_t b_fraction = IntegerFrExp(b, &b_shift);
+
+ // Detect NaNs and infinities.
+ if (a_shift == std::numeric_limits<int>::max() ||
+ (b_shift == std::numeric_limits<int>::max())) {
+ return 1;
+ }
+
+ if ((a_fraction == 0) && (b_fraction < 0)) {
+ return 1;
+ } else if ((a_fraction < 0) && (b_fraction == 0)) {
+ return -1;
+ } else if (a_shift < b_shift) {
+ return -1;
+ } else if (a_shift > b_shift) {
+ return 1;
+ } else if (a_fraction < b_fraction) {
+ return -1;
+ } else if (a_fraction > b_fraction) {
+ return 1;
+ } else {
+ return 0;
+ }
}
void PreprocessSoftmaxScaling(double beta, double input_scale,
@@ -71,30 +263,49 @@ void PreprocessSoftmaxScaling(double beta, double input_scale,
// result is double equivalent of Q0.31 (actually with more precision). Thus
// this generates a Q(input_integer_bits).(31-input_integer_bits)
// representation.
+#ifdef TFLITE_EMULATE_FLOAT
+ const double input_beta = IntegerDoubleMultiply(beta, input_scale);
+ int shift;
+ int64_t fraction = IntegerFrExp(input_beta, &shift);
+ shift += (31 - input_integer_bits);
+ double input_beta_real_multiplier =
+ DoubleFromFractionAndShift(fraction, shift);
+ if (IntegerDoubleCompare(input_beta_real_multiplier, (1ll << 31) - 1.0) > 0) {
+ input_beta_real_multiplier = (1ll << 31) - 1.0;
+ }
+#else // TFLITE_EMULATE_FLOAT
const double input_beta_real_multiplier = std::min(
beta * input_scale * (1 << (31 - input_integer_bits)), (1ll << 31) - 1.0);
+#endif // TFLITE_EMULATE_FLOAT
QuantizeMultiplierGreaterThanOne(input_beta_real_multiplier,
quantized_multiplier, left_shift);
}
-void PreprocessLogSoftmaxScaling(double beta, double input_scale,
- int input_integer_bits,
- int32_t* quantized_multiplier, int* left_shift,
- int32_t* reverse_scaling_divisor,
- int* reverse_scaling_right_shift) {
+void PreprocessLogSoftmaxScalingExp(double beta, double input_scale,
+ int input_integer_bits,
+ int32_t* quantized_multiplier,
+ int* left_shift,
+ int32_t* reverse_scaling_divisor,
+ int* reverse_scaling_left_shift) {
PreprocessSoftmaxScaling(beta, input_scale, input_integer_bits,
quantized_multiplier, left_shift);
// Also calculate what amounts to the inverse scaling factor for the input.
const double real_reverse_scaling_divisor =
(1 << (31 - *left_shift)) / static_cast<double>(*quantized_multiplier);
- tflite::QuantizeMultiplierSmallerThanOne(real_reverse_scaling_divisor,
- reverse_scaling_divisor,
- reverse_scaling_right_shift);
+ tflite::QuantizeMultiplierSmallerThanOneExp(real_reverse_scaling_divisor,
+ reverse_scaling_divisor,
+ reverse_scaling_left_shift);
}
int CalculateInputRadius(int input_integer_bits, int input_left_shift) {
+#ifdef TFLITE_EMULATE_FLOAT
+ int64_t result = (1 << input_integer_bits) - 1;
+ result <<= (31 - input_integer_bits);
+ result >>= input_left_shift;
+ return result;
+#else // TFLITE_EMULATE_FLOAT
const double max_input_rescaled = 1.0 * ((1 << input_integer_bits) - 1) *
(1ll << (31 - input_integer_bits)) /
(1ll << input_left_shift);
@@ -102,17 +313,18 @@ int CalculateInputRadius(int input_integer_bits, int input_left_shift) {
// After scaling the difference, the result would be at the maximum. Thus we
// must ensure that our value has lower magnitude.
return static_cast<int>(std::floor(max_input_rescaled));
+#endif // TFLITE_EMULATE_FLOAT
}
void NudgeQuantizationRange(const float min, const float max,
const int quant_min, const int quant_max,
float* nudged_min, float* nudged_max,
- float* scale) {
+ float* nudged_scale) {
// This code originates from tensorflow/core/kernels/fake_quant_ops_functor.h.
const float quant_min_float = static_cast<float>(quant_min);
const float quant_max_float = static_cast<float>(quant_max);
- *scale = (max - min) / (quant_max_float - quant_min_float);
- const float zero_point_from_min = quant_min_float - min / *scale;
+ *nudged_scale = (max - min) / (quant_max_float - quant_min_float);
+ const float zero_point_from_min = quant_min_float - min / *nudged_scale;
uint16 nudged_zero_point;
if (zero_point_from_min < quant_min_float) {
nudged_zero_point = static_cast<uint16>(quant_min);
@@ -121,8 +333,37 @@ void NudgeQuantizationRange(const float min, const float max,
} else {
nudged_zero_point = static_cast<uint16>(TfLiteRound(zero_point_from_min));
}
- *nudged_min = (quant_min_float - nudged_zero_point) * (*scale);
- *nudged_max = (quant_max_float - nudged_zero_point) * (*scale);
+ *nudged_min = (quant_min_float - nudged_zero_point) * (*nudged_scale);
+ *nudged_max = (quant_max_float - nudged_zero_point) * (*nudged_scale);
+}
+
+void FakeQuantizeArray(const float nudged_scale, const float nudged_min,
+ const float nudged_max, const float* input_data,
+ float* output_data, const float size) {
+ // This code originates from tensorflow/core/kernels/fake_quant_ops_functor.h.
+ const float inv_nudged_scale = 1.0f / nudged_scale;
+
+ for (int i = 0; i < size; i++) {
+ const float src_val = input_data[i];
+ const float clamped = std::min(nudged_max, std::max(nudged_min, src_val));
+ const float clamped_shifted = clamped - nudged_min;
+ const float dst_val =
+ TfLiteRound(clamped_shifted * inv_nudged_scale) * nudged_scale +
+ nudged_min;
+ output_data[i] = dst_val;
+ }
+}
+
+bool CheckedLog2(const float x, int* log2_result) {
+ // Using TfLiteRound instead of std::round and std::log instead of
+ // std::log2 to work around these functions being missing in a toolchain
+ // used in some TensorFlow tests as of May 2018.
+ const float x_log2 = std::log(x) * (1.0f / std::log(2.0f));
+ const float x_log2_rounded = TfLiteRound(x_log2);
+ const float x_log2_fracpart = x_log2 - x_log2_rounded;
+
+ *log2_result = static_cast<int>(x_log2_rounded);
+ return std::abs(x_log2_fracpart) < 1e-3;
}
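
A short usage trace for CheckedLog2 (illustrative only, not part of the patch):

//   int exponent;
//   CheckedLog2(0.25f, &exponent);  // true,  exponent == -2
//   CheckedLog2(8.0f, &exponent);   // true,  exponent == 3
//   CheckedLog2(3.0f, &exponent);   // false: log2(3) is not within 1e-3 of
//                                   // an integer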
} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/internal/quantization_util.h b/tensorflow/contrib/lite/kernels/internal/quantization_util.h
index 4a217515f1..d74a1bac97 100644
--- a/tensorflow/contrib/lite/kernels/internal/quantization_util.h
+++ b/tensorflow/contrib/lite/kernels/internal/quantization_util.h
@@ -28,8 +28,9 @@ namespace tflite {
// Given the min and max values of a float array, return
// reasonable quantization parameters to use for this array.
template <typename T>
-QuantizationParams ChooseQuantizationParams(double rmin, double rmax) {
- const T qmin = std::numeric_limits<T>::min();
+QuantizationParams ChooseQuantizationParams(double rmin, double rmax,
+ bool narrow_range) {
+ const T qmin = std::numeric_limits<T>::min() + (narrow_range ? 1 : 0);
const T qmax = std::numeric_limits<T>::max();
const double qmin_double = qmin;
const double qmax_double = qmax;
@@ -97,6 +98,11 @@ QuantizationParams ChooseQuantizationParams(double rmin, double rmax) {
return quantization_params;
}
+template <typename T>
+QuantizationParams ChooseQuantizationParams(double rmin, double rmax) {
+ return ChooseQuantizationParams<T>(rmin, rmax, false);
+}
+
// Converts a floating-point number to an integer. For all inputs x where
// static_cast<IntOut>(x) is legal according to the C++ standard, the result
// is identical to that cast (i.e. the result is x with its fractional part
@@ -167,9 +173,9 @@ IntOut SafeCast(FloatIn x) {
// this is intended as a RIGHT-shift.
//
// Restricted to the case where the multiplier < 1 (and non-negative).
-void QuantizeMultiplierSmallerThanOne(double double_multiplier,
- int32_t* quantized_multiplier,
- int* right_shift);
+void QuantizeMultiplierSmallerThanOneExp(double double_multiplier,
+ int32_t* quantized_multiplier,
+ int* left_shift);
// Decompose a double multiplier into a Q0.31 int32 representation of its
// significand, and shift representation of its exponent.
@@ -189,6 +195,44 @@ void QuantizeMultiplierGreaterThanOne(double double_multiplier,
void QuantizeMultiplier(double double_multiplier, int32_t* quantized_multiplier,
int* shift);
+// Splits a double input value into a returned fraction, and a shift value from
+// the exponent, using only bitwise and integer operations to support
+// microcontrollers and other environments without floating-point support.
+//
+// This is designed to be a replacement for how std::frexp() is used within the
+// QuantizeMultiplier() function, and so has a different signature than the
+// standard version, returning a 64-bit integer rather than a double. This
+// result has a maximum value of 1<<31, with the fraction expressed as a
+// proportion of that maximum.
+//
+// std::frexp() returns NaNs and infinities unmodified, but since we're
+// returning integers that can't represent those values, instead we return
+// a shift of std::numeric_limits<int>::max() for all bad numbers, with an int64
+// result of 0 for NaNs, std::numeric_limits<int64_t>::max() for +INFINITY, and
+// std::numeric_limits<int64_t>::min() for -INFINITY. Denormalized inputs will
+// result in return values that end up truncating some bits at the end,
+// reflecting the loss of precision inherent in denormalization.
+int64_t IntegerFrExp(double input, int* shift);
+
+// Converts an integer fraction in the format produced by IntegerFrExp (where
+// 0x40000000 is 1.0) and an exponent shift (between -1022 and +1022) into an
+// IEEE binary64 double format result. The implementation uses only integer and
+// bitwise operators, so no floating point hardware support or emulation is
+// needed. This is here so quantized operations can run non-time-critical
+// preparation calculations on microcontrollers and other platforms without
+// float support.
+double DoubleFromFractionAndShift(int64_t fraction, int shift);
+
+// Performs a multiplication of two numbers in double format, using only integer
+// and bitwise instructions. This is aimed at supporting housekeeping functions
+// for quantized operations on microcontrollers without floating-point hardware.
+double IntegerDoubleMultiply(double a, double b);
+
+// Returns -1 if a is less than b, 0 if a and b are equal, and +1 if a is
+// greater than b. It is implemented using only integer and logical instructions
+// so that it can be easily run on microcontrollers for quantized operations.
+int IntegerDoubleCompare(double a, double b);
+
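
A minimal standalone usage sketch of the integer-only helpers declared above, round-tripping one value (illustrative only, not part of the patch; the expected numbers follow the unit tests in quantization_util_test.cc):

#include <cstdint>
#include <cstdio>

#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"

int main() {
  int shift;
  // 0.25 == 0.5 * 2^-1, so this yields fraction 0x40000000 and shift -1,
  // mirroring std::frexp(0.25, ...).
  const int64_t fraction = tflite::IntegerFrExp(0.25, &shift);
  const double restored = tflite::DoubleFromFractionAndShift(fraction, shift);
  std::printf("fraction=0x%llx shift=%d restored=%f\n",
              static_cast<unsigned long long>(fraction), shift, restored);
  // restored is 0.25 again, so comparing it against the original returns 0.
  std::printf("compare=%d\n", tflite::IntegerDoubleCompare(0.25, restored));
  return 0;
}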
// This first creates a multiplier in a double equivalent of
// Q(input_integer_bits).(31-input_integer_bits) representation, with extra
// precision in the double's fractional bits. It then splits the result into
@@ -197,11 +241,12 @@ void PreprocessSoftmaxScaling(double beta, double input_scale,
int input_integer_bits,
int32_t* quantized_multiplier, int* left_shift);
// Like PreprocessSoftmaxScaling, but inverse scaling factors also calculated.
-void PreprocessLogSoftmaxScaling(double beta, double input_scale,
- int input_integer_bits,
- int32_t* quantized_multiplier, int* left_shift,
- int32_t* reverse_scaling_divisor,
- int* reverse_scaling_right_shift);
+void PreprocessLogSoftmaxScalingExp(double beta, double input_scale,
+ int input_integer_bits,
+ int32_t* quantized_multiplier,
+ int* left_shift,
+ int32_t* reverse_scaling_divisor,
+ int* reverse_scaling_left_shift);
// Calculate the largest input that will result in a within-bounds intermediate
// result within MultiplyByQuantizedMultiplierGreaterThanOne. In other words,
// it must not overflow before we reduce the value by multiplication by the
@@ -215,7 +260,20 @@ int CalculateInputRadius(int input_integer_bits, int input_left_shift);
// Outputs nudged_min, nudged_max, nudged_scale.
void NudgeQuantizationRange(const float min, const float max,
const int quant_min, const int quant_max,
- float* nudged_min, float* nudged_max, float* scale);
+ float* nudged_min, float* nudged_max,
+ float* nudged_scale);
+
+// Fake quantizes (quantizes and dequantizes) input_data using the scale,
+// nudged_min, and nudged_max from NudgeQuantizationRange. This matches the code
+// in TensorFlow's FakeQuantizeWithMinMaxVarsFunctor.
+void FakeQuantizeArray(const float nudged_scale, const float nudged_min,
+ const float nudged_max, const float* input_data,
+ float* output_data, const float size);
+
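
A minimal usage sketch of the two fake-quantization helpers above (illustrative only, not part of the patch; the input values are hypothetical):

#include <cstdio>

#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"

int main() {
  float nudged_min, nudged_max, nudged_scale;
  // An 8-bit fake-quant over the float range [-1, 1].
  tflite::NudgeQuantizationRange(/*min=*/-1.0f, /*max=*/1.0f,
                                 /*quant_min=*/0, /*quant_max=*/255,
                                 &nudged_min, &nudged_max, &nudged_scale);
  const float input[4] = {-2.0f, -0.5f, 0.1f, 3.0f};
  float output[4];
  tflite::FakeQuantizeArray(nudged_scale, nudged_min, nudged_max, input,
                            output, /*size=*/4);
  for (int i = 0; i < 4; ++i) {
    // Out-of-range values are clamped to [nudged_min, nudged_max]; in-range
    // values snap to the nearest of the 256 representable levels.
    std::printf("%f -> %f\n", input[i], output[i]);
  }
  return 0;
}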
+// If x is approximately a power of two (with any positive or negative
+// exponent), stores that exponent (i.e. log2(x)) in *log2_result and returns
+// true; otherwise returns false.
+bool CheckedLog2(const float x, int* log2_result);
} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc b/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc
index 2d74b3d384..25ea72b886 100644
--- a/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc
+++ b/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc
@@ -130,22 +130,22 @@ void RunSafeCastTests() {
}
TEST(QuantizationUtilTest, SafeCast) {
- RunSafeCastTests<float, int8>();
- RunSafeCastTests<double, int8>();
- RunSafeCastTests<float, int16>();
- RunSafeCastTests<double, int16>();
- RunSafeCastTests<float, int32>();
- RunSafeCastTests<double, int32>();
- RunSafeCastTests<float, int64>();
- RunSafeCastTests<double, int64>();
- RunSafeCastTests<float, uint8>();
- RunSafeCastTests<double, uint8>();
- RunSafeCastTests<float, uint16>();
- RunSafeCastTests<double, uint16>();
- RunSafeCastTests<float, uint32>();
- RunSafeCastTests<double, uint32>();
- RunSafeCastTests<float, uint64>();
- RunSafeCastTests<double, uint64>();
+ RunSafeCastTests<float, int8_t>();
+ RunSafeCastTests<double, int8_t>();
+ RunSafeCastTests<float, int16_t>();
+ RunSafeCastTests<double, int16_t>();
+ RunSafeCastTests<float, int32_t>();
+ RunSafeCastTests<double, int32_t>();
+ RunSafeCastTests<float, int64_t>();
+ RunSafeCastTests<double, int64_t>();
+ RunSafeCastTests<float, uint8_t>();
+ RunSafeCastTests<double, uint8_t>();
+ RunSafeCastTests<float, uint16_t>();
+ RunSafeCastTests<double, uint16_t>();
+ RunSafeCastTests<float, uint32_t>();
+ RunSafeCastTests<double, uint32_t>();
+ RunSafeCastTests<float, uint64_t>();
+ RunSafeCastTests<double, uint64_t>();
}
// Example taken from http://www.tensorflow.org/performance/quantization
@@ -191,26 +191,159 @@ TEST(QuantizationUtilTest, ChooseQuantizationParamsZeroPointOnMaxBoundary) {
EXPECT_EQ(qp.zero_point, 255);
}
+TEST(QuantizationUtilTest, IntegerFrExp) {
+ int shift;
+ int64_t result = IntegerFrExp(0.0, &shift);
+ EXPECT_EQ(0, result);
+ EXPECT_EQ(0, shift);
+
+ result = IntegerFrExp(1.0, &shift);
+ EXPECT_NEAR(0x40000000, result, 1);
+ EXPECT_EQ(1, shift);
+
+ result = IntegerFrExp(0.25, &shift);
+ EXPECT_NEAR(0x40000000, result, 1);
+ EXPECT_EQ(-1, shift);
+
+ result = IntegerFrExp(-1.0, &shift);
+ EXPECT_NEAR(-(1 << 30), result, 1);
+ EXPECT_EQ(1, shift);
+
+ result = IntegerFrExp(123.45, &shift);
+ EXPECT_NEAR(2071147315, result, 1);
+ EXPECT_EQ(7, shift);
+
+ result = IntegerFrExp(NAN, &shift);
+ EXPECT_NEAR(0, result, 1);
+ EXPECT_EQ(0x7fffffff, shift);
+
+ result = IntegerFrExp(INFINITY, &shift);
+ EXPECT_NEAR(std::numeric_limits<int64_t>::max(), result, 1);
+ EXPECT_EQ(0x7fffffff, shift);
+
+ result = IntegerFrExp(-INFINITY, &shift);
+ EXPECT_NEAR(std::numeric_limits<int64_t>::min(), result, 1);
+ EXPECT_EQ(0x7fffffff, shift);
+}
+
+TEST(QuantizationUtilTest, IntegerFrExpVersusDouble) {
+ int shift;
+ int32_t result = IntegerFrExp(0.0, &shift);
+ EXPECT_EQ(result, 0);
+ EXPECT_EQ(shift, 0);
+
+ int double_shift;
+ double double_result = std::frexp(0.0, &double_shift);
+ EXPECT_EQ(double_result, 0);
+ EXPECT_EQ(double_shift, 0);
+
+ result = IntegerFrExp(1.0, &shift);
+ EXPECT_NEAR(result, 0x40000000, 1);
+ EXPECT_EQ(shift, 1);
+ double_result = std::frexp(1.0, &double_shift);
+ EXPECT_NEAR(double_result, 0.5, 1e-5);
+ EXPECT_EQ(double_shift, 1);
+
+ result = IntegerFrExp(0.25, &shift);
+ EXPECT_NEAR(result, 0x40000000, 1);
+ EXPECT_EQ(shift, -1);
+ double_result = std::frexp(0.25, &double_shift);
+ EXPECT_NEAR(double_result, 0.5, 1e-5);
+ EXPECT_EQ(double_shift, -1);
+
+ result = IntegerFrExp(-1.0, &shift);
+ EXPECT_NEAR(result, -(1 << 30), 1);
+ EXPECT_EQ(shift, 1);
+ double_result = std::frexp(-1.0, &double_shift);
+ EXPECT_NEAR(double_result, -0.5, 1e-5);
+ EXPECT_EQ(double_shift, 1);
+
+ result = IntegerFrExp(123.45, &shift);
+ EXPECT_NEAR(result, (0.964453 * (1LL << 31)), 1000);
+ EXPECT_EQ(shift, 7);
+ double_result = std::frexp(123.45, &double_shift);
+ EXPECT_NEAR(double_result, 0.964453, 1e-5);
+ EXPECT_EQ(double_shift, 7);
+}
+
+TEST(QuantizationUtilTest, DoubleFromFractionAndShift) {
+ double result = DoubleFromFractionAndShift(0, 0);
+ EXPECT_EQ(0, result);
+
+ result = DoubleFromFractionAndShift(0x40000000, 1);
+ EXPECT_NEAR(1.0, result, 1e-5);
+
+ result = DoubleFromFractionAndShift(0x40000000, 2);
+ EXPECT_NEAR(2.0, result, 1e-5);
+
+ int shift;
+ int64_t fraction = IntegerFrExp(3.0, &shift);
+ result = DoubleFromFractionAndShift(fraction, shift);
+ EXPECT_NEAR(3.0, result, 1e-5);
+
+ fraction = IntegerFrExp(123.45, &shift);
+ result = DoubleFromFractionAndShift(fraction, shift);
+ EXPECT_NEAR(123.45, result, 1e-5);
+
+ fraction = IntegerFrExp(-23.232323, &shift);
+ result = DoubleFromFractionAndShift(fraction, shift);
+ EXPECT_NEAR(-23.232323, result, 1e-5);
+
+ fraction = IntegerFrExp(NAN, &shift);
+ result = DoubleFromFractionAndShift(fraction, shift);
+ EXPECT_TRUE(std::isnan(result));
+
+ fraction = IntegerFrExp(INFINITY, &shift);
+ result = DoubleFromFractionAndShift(fraction, shift);
+ EXPECT_FALSE(std::isfinite(result));
+}
+
+TEST(QuantizationUtilTest, IntegerDoubleMultiply) {
+ EXPECT_NEAR(1.0, IntegerDoubleMultiply(1.0, 1.0), 1e-5);
+ EXPECT_NEAR(2.0, IntegerDoubleMultiply(1.0, 2.0), 1e-5);
+ EXPECT_NEAR(2.0, IntegerDoubleMultiply(2.0, 1.0), 1e-5);
+ EXPECT_NEAR(4.0, IntegerDoubleMultiply(2.0, 2.0), 1e-5);
+ EXPECT_NEAR(0.5, IntegerDoubleMultiply(1.0, 0.5), 1e-5);
+ EXPECT_NEAR(0.25, IntegerDoubleMultiply(0.5, 0.5), 1e-5);
+ EXPECT_NEAR(-1.0, IntegerDoubleMultiply(1.0, -1.0), 1e-5);
+ EXPECT_NEAR(-1.0, IntegerDoubleMultiply(-1.0, 1.0), 1e-5);
+ EXPECT_NEAR(1.0, IntegerDoubleMultiply(-1.0, -1.0), 1e-5);
+ EXPECT_NEAR(15000000.0, IntegerDoubleMultiply(3000.0, 5000.0), 1e-5);
+ EXPECT_TRUE(std::isnan(IntegerDoubleMultiply(NAN, 5000.0)));
+ EXPECT_TRUE(std::isnan(IntegerDoubleMultiply(3000.0, NAN)));
+}
+
+TEST(QuantizationUtilTest, IntegerDoubleCompare) {
+ EXPECT_EQ(-1, IntegerDoubleCompare(0.0, 1.0));
+ EXPECT_EQ(1, IntegerDoubleCompare(1.0, 0.0));
+ EXPECT_EQ(0, IntegerDoubleCompare(1.0, 1.0));
+ EXPECT_EQ(0, IntegerDoubleCompare(0.0, 0.0));
+ EXPECT_EQ(-1, IntegerDoubleCompare(-10.0, 10.0));
+ EXPECT_EQ(1, IntegerDoubleCompare(123.45, 10.0));
+ EXPECT_EQ(1, IntegerDoubleCompare(NAN, INFINITY));
+ EXPECT_EQ(1, IntegerDoubleCompare(INFINITY, NAN));
+}
+
#ifdef GTEST_HAS_DEATH_TEST
TEST(QuantizationUtilTest, ChooseQuantizationParamsInvalidRange) {
EXPECT_DEATH(ChooseQuantizationParams<uint8>(10.0, -30.0), "");
}
-TEST(QuantizationUtilTest, QuantizeMultiplierSmallerThanOne) {
+TEST(QuantizationUtilTest, QuantizeMultiplierSmallerThanOneExp) {
auto quantize = [](double d) {
int32_t q;
int s;
- QuantizeMultiplierSmallerThanOne(d, &q, &s);
+ QuantizeMultiplierSmallerThanOneExp(d, &q, &s);
return std::pair<int32_t, int>{q, s};
};
EXPECT_DEATH(quantize(-0.1), "");
EXPECT_DEATH(quantize(0.0), "");
- EXPECT_THAT(quantize(0.25), Pair(1073741824, 1));
+ EXPECT_THAT(quantize(0.25), Pair(1073741824, -1));
// Around 0.5 we can see the change in exponent and how we try hard to
// avoid hitting max int32.
- EXPECT_THAT(quantize(0.50 - 5e-9), Pair(2147483627, 1));
+ EXPECT_THAT(quantize(0.50 - 5e-9), Pair(2147483627, -1));
EXPECT_THAT(quantize(0.50 - 1e-10), Pair(1073741824, 0));
EXPECT_THAT(quantize(0.50), Pair(1073741824, 0));
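As a worked check of the renamed function's convention (the flipped expectations above): 0.25 is represented as 1073741824 / 2^31 = 0.5 scaled by 2^-1, so QuantizeMultiplierSmallerThanOneExp now reports shift = -1, where a negative shift means a right shift; the old QuantizeMultiplierSmallerThanOne reported the same quantity as a right shift of +1.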
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h b/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h
index 9aabee5000..11224270a4 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h
@@ -22,24 +22,36 @@ limitations under the License.
namespace tflite {
namespace reference_ops {
-inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
- const float* filter_data, const Dims<4>& filter_dims,
- const float* bias_data, const Dims<4>& bias_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, int depth_multiplier,
- float output_activation_min,
- float output_activation_max, float* output_data,
- const Dims<4>& output_dims) {
- const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
- const int output_depth = MatchingArraySize(filter_dims, 0, output_dims, 0);
- const int input_height = ArraySize(input_dims, 2);
- const int input_width = ArraySize(input_dims, 1);
- const int input_depth = ArraySize(input_dims, 0);
- const int filter_height = ArraySize(filter_dims, 2);
- const int filter_width = ArraySize(filter_dims, 1);
- const int output_height = ArraySize(output_dims, 2);
- const int output_width = ArraySize(output_dims, 1);
- TFLITE_DCHECK(output_depth == input_depth * depth_multiplier);
+inline void DepthwiseConv(
+ const DepthwiseParams& params, const RuntimeShape& input_shape,
+ const float* input_data, const RuntimeShape& filter_shape,
+ const float* filter_data, const RuntimeShape& bias_shape,
+ const float* bias_data, const RuntimeShape& output_shape,
+ float* output_data) {
+ const int stride_width = params.stride_width;
+ const int stride_height = params.stride_height;
+ const int dilation_width_factor = params.dilation_width_factor;
+ const int dilation_height_factor = params.dilation_height_factor;
+ const int pad_width = params.padding_values.width;
+ const int pad_height = params.padding_values.height;
+ const int depth_multiplier = params.depth_multiplier;
+ const float output_activation_min = params.float_activation_min;
+ const float output_activation_max = params.float_activation_max;
+ TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+
+ const int batches = MatchingDim(input_shape, 0, output_shape, 0);
+ const int output_depth = MatchingDim(filter_shape, 3, output_shape, 3);
+ const int input_height = input_shape.Dims(1);
+ const int input_width = input_shape.Dims(2);
+ const int input_depth = input_shape.Dims(3);
+ const int filter_height = filter_shape.Dims(1);
+ const int filter_width = filter_shape.Dims(2);
+ const int output_height = output_shape.Dims(1);
+ const int output_width = output_shape.Dims(2);
+ TFLITE_DCHECK_EQ(output_depth, input_depth * depth_multiplier);
+ TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth);
for (int b = 0; b < batches; ++b) {
for (int out_y = 0; out_y < output_height; ++out_y) {
@@ -52,25 +64,26 @@ inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
float total = 0.f;
for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
- const int in_x = in_x_origin + filter_x;
- const int in_y = in_y_origin + filter_y;
+ const int in_x = in_x_origin + dilation_width_factor * filter_x;
+ const int in_y =
+ in_y_origin + dilation_height_factor * filter_y;
// If the location is outside the bounds of the input image,
// use zero as a default value.
if ((in_x >= 0) && (in_x < input_width) && (in_y >= 0) &&
(in_y < input_height)) {
float input_value =
- input_data[Offset(input_dims, ic, in_x, in_y, b)];
+ input_data[Offset(input_shape, b, in_y, in_x, ic)];
float filter_value = filter_data[Offset(
- filter_dims, oc, filter_x, filter_y, 0)];
+ filter_shape, 0, filter_y, filter_x, oc)];
total += (input_value * filter_value);
}
}
}
float bias_value = 0.0f;
if (bias_data) {
- bias_value = bias_data[Offset(bias_dims, oc, 0, 0, 0)];
+ bias_value = bias_data[oc];
}
- output_data[Offset(output_dims, oc, out_x, out_y, b)] =
+ output_data[Offset(output_shape, b, out_y, out_x, oc)] =
ActivationFunctionWithMinMax(total + bias_value,
output_activation_min,
output_activation_max);
@@ -81,34 +94,6 @@ inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
}
}
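A quick worked example of the dilated indexing added above, with hypothetical parameters: with dilation_width_factor = 2, the filter tap filter_x = 1 reads in_x = in_x_origin + 2 * 1, two columns to the right of the window origin rather than one, so consecutive taps sample every other input column; dilation factors of 1 reproduce the old dense sampling.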
-// Legacy, for compatibility with old checked-in code.
-template <FusedActivationFunctionType Ac>
-void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
- const float* filter_data, const Dims<4>& filter_dims,
- const float* bias_data, const Dims<4>& bias_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, int depth_multiplier, float* output_data,
- const Dims<4>& output_dims) {
- float output_activation_min, output_activation_max;
- GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
- DepthwiseConv(input_data, input_dims, filter_data, filter_dims, bias_data,
- bias_dims, stride_width, stride_height, pad_width, pad_height,
- depth_multiplier, output_activation_min, output_activation_max,
- output_data, output_dims);
-}
-
-// Legacy, for compatibility with old checked-in code.
-template <FusedActivationFunctionType Ac>
-void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
- const float* filter_data, const Dims<4>& filter_dims,
- const float* bias_data, const Dims<4>& bias_dims, int stride,
- int pad_width, int pad_height, int depth_multiplier,
- float* output_data, const Dims<4>& output_dims) {
- DepthwiseConv<Ac>(input_data, input_dims, filter_data, filter_dims, bias_data,
- bias_dims, stride, stride, pad_width, pad_height,
- depth_multiplier, output_data, output_dims);
-}
-
} // end namespace reference_ops
} // end namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h b/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h
index e9b6baeaee..eab28e6c84 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h
@@ -18,7 +18,6 @@ limitations under the License.
#include <algorithm>
#include "fixedpoint/fixedpoint.h"
-#include "public/gemmlowp.h"
#include "tensorflow/contrib/lite/kernels/internal/common.h"
#include "tensorflow/contrib/lite/kernels/internal/compatibility.h"
#include "tensorflow/contrib/lite/kernels/internal/types.h"
@@ -26,26 +25,42 @@ limitations under the License.
namespace tflite {
namespace reference_ops {
-inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
- int32 input_offset, const uint8* filter_data,
- const Dims<4>& filter_dims, int32 filter_offset,
- const int32* bias_data, const Dims<4>& bias_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, int depth_multiplier,
- int32 output_offset, int32 output_multiplier,
- int output_shift, int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims) {
- const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
- const int output_depth = MatchingArraySize(filter_dims, 0, output_dims, 0);
- const int input_height = ArraySize(input_dims, 2);
- const int input_width = ArraySize(input_dims, 1);
- const int input_depth = ArraySize(input_dims, 0);
- const int filter_height = ArraySize(filter_dims, 2);
- const int filter_width = ArraySize(filter_dims, 1);
- const int output_height = ArraySize(output_dims, 2);
- const int output_width = ArraySize(output_dims, 1);
- TFLITE_DCHECK(output_depth == input_depth * depth_multiplier);
+inline void DepthwiseConv(
+ const DepthwiseParams& params, const RuntimeShape& input_shape,
+ const uint8* input_data, const RuntimeShape& filter_shape,
+ const uint8* filter_data, const RuntimeShape& bias_shape,
+ const int32* bias_data, const RuntimeShape& output_shape,
+ uint8* output_data) {
+ const int stride_width = params.stride_width;
+ const int stride_height = params.stride_height;
+ const int dilation_width_factor = params.dilation_width_factor;
+ const int dilation_height_factor = params.dilation_height_factor;
+ const int pad_width = params.padding_values.width;
+ const int pad_height = params.padding_values.height;
+ const int depth_multiplier = params.depth_multiplier;
+ const int32 output_activation_min = params.quantized_activation_min;
+ const int32 output_activation_max = params.quantized_activation_max;
+ const int32 input_offset = params.input_offset;
+ const int32 filter_offset = params.weights_offset;
+ const int32 output_offset = params.output_offset;
+ const int32 output_multiplier = params.output_multiplier;
+ const int output_shift = params.output_shift;
+ TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+
+ TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
+ const int batches = MatchingDim(input_shape, 0, output_shape, 0);
+ const int output_depth = MatchingDim(filter_shape, 3, output_shape, 3);
+ const int input_height = input_shape.Dims(1);
+ const int input_width = input_shape.Dims(2);
+ const int input_depth = input_shape.Dims(3);
+ const int filter_height = filter_shape.Dims(1);
+ const int filter_width = filter_shape.Dims(2);
+ const int output_height = output_shape.Dims(1);
+ const int output_width = output_shape.Dims(2);
+ TFLITE_DCHECK_EQ(output_depth, input_depth * depth_multiplier);
+ TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth);
for (int b = 0; b < batches; ++b) {
for (int out_y = 0; out_y < output_height; ++out_y) {
@@ -58,30 +73,31 @@ inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
int32 acc = 0;
for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
- const int in_x = in_x_origin + filter_x;
- const int in_y = in_y_origin + filter_y;
+ const int in_x = in_x_origin + dilation_width_factor * filter_x;
+ const int in_y =
+ in_y_origin + dilation_height_factor * filter_y;
// If the location is outside the bounds of the input image,
// use zero as a default value.
if ((in_x >= 0) && (in_x < input_width) && (in_y >= 0) &&
(in_y < input_height)) {
int32 input_val =
- input_data[Offset(input_dims, ic, in_x, in_y, b)];
- int32 filter_val = filter_data[Offset(filter_dims, oc,
- filter_x, filter_y, 0)];
+ input_data[Offset(input_shape, b, in_y, in_x, ic)];
+ int32 filter_val = filter_data[Offset(
+ filter_shape, 0, filter_y, filter_x, oc)];
acc +=
(filter_val + filter_offset) * (input_val + input_offset);
}
}
}
if (bias_data) {
- acc += bias_data[Offset(bias_dims, oc, 0, 0, 0)];
+ acc += bias_data[oc];
}
- acc = MultiplyByQuantizedMultiplierSmallerThanOne(
- acc, output_multiplier, output_shift);
+ acc = MultiplyByQuantizedMultiplier(acc, output_multiplier,
+ output_shift);
acc += output_offset;
acc = std::max(acc, output_activation_min);
acc = std::min(acc, output_activation_max);
- output_data[Offset(output_dims, oc, out_x, out_y, b)] =
+ output_data[Offset(output_shape, b, out_y, out_x, oc)] =
static_cast<uint8>(acc);
}
}
@@ -90,48 +106,6 @@ inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
}
}
-// Legacy, for compatibility with old checked-in code.
-template <FusedActivationFunctionType Ac>
-void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
- int32 input_offset, const uint8* filter_data,
- const Dims<4>& filter_dims, int32 filter_offset,
- const int32* bias_data, const Dims<4>& bias_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, int depth_multiplier, int32 output_offset,
- int32 output_multiplier, int output_shift,
- int32 output_activation_min, int32 output_activation_max,
- uint8* output_data, const Dims<4>& output_dims) {
- if (Ac == FusedActivationFunctionType::kNone) {
- TFLITE_DCHECK_EQ(output_activation_min, 0);
- TFLITE_DCHECK_EQ(output_activation_max, 255);
- }
- DepthwiseConv(input_data, input_dims, input_offset, filter_data, filter_dims,
- filter_offset, bias_data, bias_dims, stride_width,
- stride_height, pad_width, pad_height, depth_multiplier,
- output_offset, output_multiplier, output_shift,
- output_activation_min, output_activation_max, output_data,
- output_dims);
-}
-
-// Legacy, for compatibility with old checked-in code.
-template <FusedActivationFunctionType Ac>
-void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
- int32 input_offset, const uint8* filter_data,
- const Dims<4>& filter_dims, int32 filter_offset,
- const int32* bias_data, const Dims<4>& bias_dims, int stride,
- int pad_width, int pad_height, int depth_multiplier,
- int32 output_offset, int32 output_multiplier,
- int output_shift, int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims) {
- DepthwiseConv<Ac>(input_data, input_dims, input_offset, filter_data,
- filter_dims, filter_offset, bias_data, bias_dims, stride,
- stride, pad_width, pad_height, depth_multiplier,
- output_offset, output_multiplier, output_shift,
- output_activation_min, output_activation_max, output_data,
- output_dims);
-}
-
} // end namespace reference_ops
} // end namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/fully_connected.h b/tensorflow/contrib/lite/kernels/internal/reference/fully_connected.h
new file mode 100644
index 0000000000..3c7fd29256
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/internal/reference/fully_connected.h
@@ -0,0 +1,326 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_FULLY_CONNECTED_H_
+#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_FULLY_CONNECTED_H_
+
+#include "fixedpoint/fixedpoint.h"
+#include "tensorflow/contrib/lite/kernels/internal/common.h"
+#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
+#include "tensorflow/contrib/lite/kernels/internal/round.h"
+#include "tensorflow/contrib/lite/kernels/internal/types.h"
+
+namespace tflite {
+namespace reference_ops {
+
+const int kReverseShift = -1;
+
+inline void FullyConnected(
+ const FullyConnectedParams& params, const RuntimeShape& input_shape,
+ const float* input_data, const RuntimeShape& weights_shape,
+ const float* weights_data, const RuntimeShape& bias_shape,
+ const float* bias_data, const RuntimeShape& output_shape,
+ float* output_data) {
+ const float output_activation_min = params.float_activation_min;
+ const float output_activation_max = params.float_activation_max;
+ // TODO(benoitjacob): This really should be:
+ // const int batches = ArraySize(output_dims, 1);
+ // but the current --variable_batch hack consists in overwriting the 3rd
+ // dimension with the runtime batch size, as we don't keep track for each
+ // array of which dimension is the batch dimension in it.
+ const int output_dims_count = output_shape.DimensionsCount();
+ const int weights_dims_count = weights_shape.DimensionsCount();
+ const int batches = FlatSizeSkipDim(output_shape, output_dims_count - 1);
+ const int output_depth = MatchingDim(weights_shape, weights_dims_count - 2,
+ output_shape, output_dims_count - 1);
+ const int accum_depth = weights_shape.Dims(weights_dims_count - 1);
+ for (int b = 0; b < batches; ++b) {
+ for (int out_c = 0; out_c < output_depth; ++out_c) {
+ float total = 0.f;
+ for (int d = 0; d < accum_depth; ++d) {
+ total += input_data[b * accum_depth + d] *
+ weights_data[out_c * accum_depth + d];
+ }
+ float bias_value = 0.0f;
+ if (bias_data) {
+ bias_value = bias_data[out_c];
+ }
+ output_data[out_c + output_depth * b] = ActivationFunctionWithMinMax(
+ total + bias_value, output_activation_min, output_activation_max);
+ }
+ }
+}
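A minimal, hedged sketch of calling this float reference kernel; the shapes, values, and the RuntimeShape initializer-list construction used here are assumptions for illustration, not part of the header:

    FullyConnectedParams params;
    params.float_activation_min = -1e30f;  // effectively no clamping
    params.float_activation_max = 1e30f;
    const float input[3] = {1.f, 2.f, 3.f};   // 1 batch, accum_depth 3
    const float weights[6] = {1.f, 0.f, 0.f,  // output_depth 2 x accum_depth 3
                              0.f, 1.f, 0.f};
    const float bias[2] = {0.5f, -0.5f};
    float output[2];
    FullyConnected(params, RuntimeShape({1, 3}), input, RuntimeShape({2, 3}),
                   weights, RuntimeShape({2}), bias, RuntimeShape({1, 2}),
                   output);
    // output is {1 * 1 + 0.5, 2 * 1 - 0.5} == {1.5f, 1.5f}.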
+
+inline void FullyConnected(
+ const FullyConnectedParams& params, const RuntimeShape& input_shape,
+ const uint8* input_data, const RuntimeShape& filter_shape,
+ const uint8* filter_data, const RuntimeShape& bias_shape,
+ const int32* bias_data, const RuntimeShape& output_shape,
+ uint8* output_data, void* gemm_context) {
+ (void)gemm_context; // only used in optimized code.
+ const int32 input_offset = params.input_offset;
+ const int32 filter_offset = params.weights_offset;
+ const int32 output_offset = params.output_offset;
+ const int32 output_multiplier = params.output_multiplier;
+ const int output_shift = params.output_shift;
+ const int32 output_activation_min = params.quantized_activation_min;
+ const int32 output_activation_max = params.quantized_activation_max;
+ TFLITE_DCHECK_GE(filter_shape.DimensionsCount(), 2);
+ TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
+
+ TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
+ // TODO(benoitjacob): This really should be:
+ // const int batches = ArraySize(output_dims, 1);
+ // but the current --variable_batch hack consists in overwriting the 3rd
+ // dimension with the runtime batch size, as we don't keep track for each
+ // array of which dimension is the batch dimension in it.
+ const int output_dim_count = output_shape.DimensionsCount();
+ const int filter_dim_count = filter_shape.DimensionsCount();
+ const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1);
+ const int output_depth = MatchingDim(filter_shape, filter_dim_count - 2,
+ output_shape, output_dim_count - 1);
+ const int accum_depth = filter_shape.Dims(filter_dim_count - 1);
+ for (int b = 0; b < batches; ++b) {
+ for (int out_c = 0; out_c < output_depth; ++out_c) {
+ int32 acc = 0;
+ for (int d = 0; d < accum_depth; ++d) {
+ int32 input_val = input_data[b * accum_depth + d];
+ int32 filter_val = filter_data[out_c * accum_depth + d];
+ acc += (filter_val + filter_offset) * (input_val + input_offset);
+ }
+ if (bias_data) {
+ acc += bias_data[out_c];
+ }
+ acc = MultiplyByQuantizedMultiplier(acc, output_multiplier, output_shift);
+ acc += output_offset;
+ acc = std::max(acc, output_activation_min);
+ acc = std::min(acc, output_activation_max);
+ output_data[out_c + output_depth * b] = static_cast<uint8>(acc);
+ }
+ }
+}
+
+inline void FullyConnected(
+ const FullyConnectedParams& params, const RuntimeShape& input_shape,
+ const uint8* input_data, const RuntimeShape& filter_shape,
+ const uint8* filter_data, const RuntimeShape& bias_shape,
+ const int32* bias_data, const RuntimeShape& output_shape,
+ int16* output_data, void* gemm_context) {
+ (void)gemm_context; // only used in optimized code.
+ const int32 input_offset = params.input_offset;
+ const int32 filter_offset = params.weights_offset;
+ const int32 output_offset = params.output_offset;
+ const int32 output_multiplier = params.output_multiplier;
+ const int output_shift = params.output_shift;
+ const int32 output_activation_min = params.quantized_activation_min;
+ const int32 output_activation_max = params.quantized_activation_max;
+
+ TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
+ TFLITE_DCHECK_EQ(output_offset, 0);
+ // TODO(benoitjacob): This really should be:
+ // const int batches = ArraySize(output_dims, 1);
+ // but the current --variable_batch hack consists in overwriting the 3rd
+ // dimension with the runtime batch size, as we don't keep track for each
+ // array of which dimension is the batch dimension in it.
+ const int output_dim_count = output_shape.DimensionsCount();
+ const int filter_dim_count = filter_shape.DimensionsCount();
+ const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1);
+ const int output_depth = MatchingDim(filter_shape, filter_dim_count - 2,
+ output_shape, output_dim_count - 1);
+ const int accum_depth = filter_shape.Dims(filter_dim_count - 1);
+ for (int b = 0; b < batches; ++b) {
+ for (int out_c = 0; out_c < output_depth; ++out_c) {
+ // Internal accumulation.
+ // Initialize accumulator with the bias-value.
+ int32 accum = bias_data[out_c];
+ // Accumulation loop.
+ for (int d = 0; d < accum_depth; ++d) {
+ int16 input_val = input_data[b * accum_depth + d] + input_offset;
+ int16 filter_val = filter_data[out_c * accum_depth + d] + filter_offset;
+ accum += filter_val * input_val;
+ }
+ // Down-scale the final int32 accumulator to the scale used by our
+ // (16-bit, typically 3 integer bits) fixed-point format. The quantized
+ // multiplier and shift here have been pre-computed offline
+ // (e.g. by toco).
+ accum =
+ MultiplyByQuantizedMultiplier(accum, output_multiplier, output_shift);
+ // Saturate, cast to int16, and store to output array.
+ accum = std::max(accum, output_activation_min - output_offset);
+ accum = std::min(accum, output_activation_max - output_offset);
+ accum += output_offset;
+ output_data[out_c + output_depth * b] = accum;
+ }
+ }
+}
+
+inline void ShuffledFullyConnected(
+ const FullyConnectedParams& params, const RuntimeShape& input_shape,
+ const uint8* input_data, const RuntimeShape& weights_shape,
+ const uint8* shuffled_weights_data, const RuntimeShape& bias_shape,
+ const int32* bias_data, const RuntimeShape& output_shape,
+ int16* output_data, uint8* shuffled_input_workspace_data,
+ void* gemm_context) {
+ (void)gemm_context; // only used in optimized code.
+ const int32 output_multiplier = params.output_multiplier;
+ const int output_shift = params.output_shift;
+ const int32 output_activation_min = params.quantized_activation_min;
+ const int32 output_activation_max = params.quantized_activation_max;
+ TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
+
+ TFLITE_DCHECK_GE(input_shape.DimensionsCount(), 1);
+ TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);
+ TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
+ // TODO(benoitjacob): This really should be:
+ // const int batches = ArraySize(output_dims, 1);
+ // but the current --variable_batch hack consists in overwriting the 3rd
+ // dimension with the runtime batch size, as we don't keep track for each
+ // array of which dimension is the batch dimension in it.
+ const int output_dim_count = output_shape.DimensionsCount();
+ const int weights_dim_count = weights_shape.DimensionsCount();
+ const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1);
+ const int output_depth = MatchingDim(weights_shape, weights_dim_count - 2,
+ output_shape, output_dim_count - 1);
+ const int accum_depth = weights_shape.Dims(weights_dim_count - 1);
+ TFLITE_DCHECK((accum_depth % 16) == 0);
+ TFLITE_DCHECK((output_depth % 4) == 0);
+
+ // Shuffling and xoring of input activations into the workspace buffer
+ uint8* shuffled_input_workspace_ptr = shuffled_input_workspace_data;
+ if (batches == 1) {
+ for (int i = 0; i < accum_depth; i++) {
+ shuffled_input_workspace_data[i] = input_data[i] ^ 0x80;
+ }
+ } else if (batches == 4) {
+ for (int c = 0; c < accum_depth; c += 16) {
+ for (int b = 0; b < 4; b++) {
+ const uint8* src_data_ptr = input_data + b * accum_depth + c;
+ for (int j = 0; j < 16; j++) {
+ uint8 src_val = *src_data_ptr++;
+ // Flip the sign bit, so that the kernel will only need to
+ // reinterpret these uint8 values as int8, getting for free the
+ // subtraction of the zero_point value 128.
+ uint8 dst_val = src_val ^ 0x80;
+ *shuffled_input_workspace_ptr++ = dst_val;
+ }
+ }
+ }
+ } else {
+ TFLITE_DCHECK(false);
+ return;
+ }
+
+ // Actual computation
+ if (batches == 1) {
+ int16* output_ptr = output_data;
+ // Shuffled weights have had their sign bit (0x80) pre-flipped (xor'd)
+ // so that just reinterpreting them as int8 values is equivalent to
+ // subtracting 128 from them, thus implementing for free the subtraction of
+ // the zero_point value 128.
+ const int8* shuffled_weights_ptr =
+ reinterpret_cast<const int8*>(shuffled_weights_data);
+ // Likewise, we preshuffled and pre-xored the input data above.
+ const int8* shuffled_input_data =
+ reinterpret_cast<const int8*>(shuffled_input_workspace_data);
+ for (int c = 0; c < output_depth; c += 4) {
+      // Internal accumulation.
+      // Accumulators start at zero; the bias value is added after the loop.
+ int32 accum[4] = {0};
+ // Accumulation loop.
+ for (int d = 0; d < accum_depth; d += 16) {
+ for (int i = 0; i < 4; i++) {
+ for (int j = 0; j < 16; j++) {
+ int8 input_val = shuffled_input_data[d + j];
+ int8 weights_val = *shuffled_weights_ptr++;
+ accum[i] += weights_val * input_val;
+ }
+ }
+ }
+ for (int i = 0; i < 4; i++) {
+ // Add bias value
+ int32 acc = accum[i] + bias_data[c + i];
+ // Down-scale the final int32 accumulator to the scale used by our
+ // (16-bit, typically 3 integer bits) fixed-point format. The quantized
+ // multiplier and shift here have been pre-computed offline
+ // (e.g. by toco).
+ acc =
+ MultiplyByQuantizedMultiplier(acc, output_multiplier, output_shift);
+ // Saturate, cast to int16, and store to output array.
+ acc = std::max(acc, output_activation_min);
+ acc = std::min(acc, output_activation_max);
+ output_ptr[c + i] = acc;
+ }
+ }
+ } else if (batches == 4) {
+ int16* output_ptr = output_data;
+ // Shuffled weights have had their sign bit (0x80) pre-flipped (xor'd)
+ // so that just reinterpreting them as int8 values is equivalent to
+ // subtracting 128 from them, thus implementing for free the subtraction of
+ // the zero_point value 128.
+ const int8* shuffled_weights_ptr =
+ reinterpret_cast<const int8*>(shuffled_weights_data);
+ // Likewise, we preshuffled and pre-xored the input data above.
+ const int8* shuffled_input_data =
+ reinterpret_cast<const int8*>(shuffled_input_workspace_data);
+ for (int c = 0; c < output_depth; c += 4) {
+ const int8* shuffled_input_ptr = shuffled_input_data;
+      // Internal accumulation.
+      // Accumulators start at zero; the bias value is added after the loop.
+ int32 accum[4][4];
+ for (int i = 0; i < 4; i++) {
+ for (int b = 0; b < 4; b++) {
+ accum[i][b] = 0;
+ }
+ }
+ for (int d = 0; d < accum_depth; d += 16) {
+ for (int i = 0; i < 4; i++) {
+ for (int b = 0; b < 4; b++) {
+ for (int j = 0; j < 16; j++) {
+ int8 input_val = shuffled_input_ptr[16 * b + j];
+ int8 weights_val = shuffled_weights_ptr[16 * i + j];
+ accum[i][b] += weights_val * input_val;
+ }
+ }
+ }
+ shuffled_input_ptr += 64;
+ shuffled_weights_ptr += 64;
+ }
+ for (int i = 0; i < 4; i++) {
+ for (int b = 0; b < 4; b++) {
+ // Add bias value
+ int32 acc = accum[i][b] + bias_data[c + i];
+ // Down-scale the final int32 accumulator to the scale used by our
+ // (16-bit, typically 3 integer bits) fixed-point format. The
+ // quantized multiplier and shift here have been pre-computed offline
+ // (e.g. by toco).
+ acc = MultiplyByQuantizedMultiplier(acc, output_multiplier,
+ output_shift);
+ // Saturate, cast to int16, and store to output array.
+ acc = std::max(acc, output_activation_min);
+ acc = std::min(acc, output_activation_max);
+ output_ptr[b * output_depth + c + i] = acc;
+ }
+ }
+ }
+ } else {
+ TFLITE_DCHECK(false);
+ return;
+ }
+}
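A worked numeric check of the sign-bit trick the comments above rely on (the value 200 is chosen arbitrarily): a uint8 activation of 200 with zero point 128 becomes 200 ^ 0x80 = 0x48 = 72, and reinterpreting that byte as int8 also yields 72 = 200 - 128, so the zero-point subtraction really does come for free from the XOR plus reinterpretation.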
+
+} // namespace reference_ops
+} // namespace tflite
+
+#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_FULLY_CONNECTED_H_
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h
new file mode 100644
index 0000000000..be99240b1f
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h
@@ -0,0 +1,2120 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_LEGACY_REFERENCE_OPS_H_
+#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_LEGACY_REFERENCE_OPS_H_
+
+#include <stdint.h>
+#include <sys/types.h>
+
+#include "tensorflow/contrib/lite/kernels/internal/common.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/types.h"
+
+namespace tflite {
+
+namespace reference_ops {
+
+static constexpr int kDepthwiseReverseShift = -1;
+
+inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
+ const float* filter_data, const Dims<4>& filter_dims,
+ const float* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height,
+ int dilation_width_factor, int dilation_height_factor,
+ int pad_width, int pad_height, int depth_multiplier,
+ float output_activation_min,
+ float output_activation_max, float* output_data,
+ const Dims<4>& output_dims) {
+ tflite::DepthwiseParams op_params;
+ // Padding type is ignored, but still set.
+ op_params.padding_type = PaddingType::kSame;
+ op_params.padding_values.width = pad_width;
+ op_params.padding_values.height = pad_height;
+ op_params.stride_width = stride_width;
+ op_params.stride_height = stride_height;
+ op_params.dilation_width_factor = dilation_width_factor;
+ op_params.dilation_height_factor = dilation_height_factor;
+ op_params.depth_multiplier = depth_multiplier;
+ op_params.float_activation_min = output_activation_min;
+ op_params.float_activation_max = output_activation_max;
+
+ DepthwiseConv(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(filter_dims), filter_data, DimsToShape(bias_dims),
+ bias_data, DimsToShape(output_dims), output_data);
+}
+
+inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
+ const float* filter_data, const Dims<4>& filter_dims,
+ const float* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int depth_multiplier,
+ float output_activation_min,
+ float output_activation_max, float* output_data,
+ const Dims<4>& output_dims) {
+ DepthwiseConv(input_data, input_dims, filter_data, filter_dims, bias_data,
+ bias_dims, stride_width, stride_height, 1, 1, pad_width,
+ pad_height, depth_multiplier, output_activation_min,
+ output_activation_max, output_data, output_dims);
+}
+
+// Legacy, for compatibility with old checked-in code.
+template <FusedActivationFunctionType Ac>
+void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
+ const float* filter_data, const Dims<4>& filter_dims,
+ const float* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int depth_multiplier, float* output_data,
+ const Dims<4>& output_dims) {
+ float output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+ DepthwiseConv(input_data, input_dims, filter_data, filter_dims, bias_data,
+ bias_dims, stride_width, stride_height, pad_width, pad_height,
+ depth_multiplier, output_activation_min, output_activation_max,
+ output_data, output_dims);
+}
+
+// Legacy, for compatibility with old checked-in code.
+template <FusedActivationFunctionType Ac>
+void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
+ const float* filter_data, const Dims<4>& filter_dims,
+ const float* bias_data, const Dims<4>& bias_dims, int stride,
+ int pad_width, int pad_height, int depth_multiplier,
+ float* output_data, const Dims<4>& output_dims) {
+ DepthwiseConv<Ac>(input_data, input_dims, filter_data, filter_dims, bias_data,
+ bias_dims, stride, stride, pad_width, pad_height,
+ depth_multiplier, output_data, output_dims);
+}
+
+inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height,
+ int dilation_width_factor, int dilation_height_factor,
+ int pad_width, int pad_height, int depth_multiplier,
+ int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ tflite::DepthwiseParams op_params;
+ // Padding type is ignored, but still set.
+ op_params.padding_type = PaddingType::kSame;
+ op_params.padding_values.width = pad_width;
+ op_params.padding_values.height = pad_height;
+ op_params.stride_width = stride_width;
+ op_params.stride_height = stride_height;
+ op_params.dilation_width_factor = dilation_width_factor;
+ op_params.dilation_height_factor = dilation_height_factor;
+ op_params.depth_multiplier = depth_multiplier;
+ op_params.quantized_activation_min = output_activation_min;
+ op_params.quantized_activation_max = output_activation_max;
+ op_params.input_offset = input_offset;
+ op_params.weights_offset = filter_offset;
+ op_params.output_offset = output_offset;
+ op_params.output_multiplier = output_multiplier;
+ // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
+ op_params.output_shift = kDepthwiseReverseShift * output_shift;
+
+ DepthwiseConv(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(filter_dims), filter_data, DimsToShape(bias_dims),
+ bias_data, DimsToShape(output_dims), output_data);
+}
+
+inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int depth_multiplier,
+ int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ DepthwiseConv(input_data, input_dims, input_offset, filter_data, filter_dims,
+ filter_offset, bias_data, bias_dims, stride_width,
+ stride_height, 1, 1, pad_width, pad_height, depth_multiplier,
+ output_offset, output_multiplier, output_shift,
+ output_activation_min, output_activation_max, output_data,
+ output_dims);
+}
+
+// Legacy, for compatibility with old checked-in code.
+template <FusedActivationFunctionType Ac>
+void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int depth_multiplier, int32 output_offset,
+ int32 output_multiplier, int output_shift,
+ int32 output_activation_min, int32 output_activation_max,
+ uint8* output_data, const Dims<4>& output_dims) {
+ if (Ac == FusedActivationFunctionType::kNone) {
+ TFLITE_DCHECK_EQ(output_activation_min, 0);
+ TFLITE_DCHECK_EQ(output_activation_max, 255);
+ }
+ DepthwiseConv(input_data, input_dims, input_offset, filter_data, filter_dims,
+ filter_offset, bias_data, bias_dims, stride_width,
+ stride_height, pad_width, pad_height, depth_multiplier,
+ output_offset, output_multiplier, output_shift,
+ output_activation_min, output_activation_max, output_data,
+ output_dims);
+}
+
+// Legacy, for compatibility with old checked-in code.
+template <FusedActivationFunctionType Ac>
+void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims, int stride,
+ int pad_width, int pad_height, int depth_multiplier,
+ int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ DepthwiseConv<Ac>(input_data, input_dims, input_offset, filter_data,
+ filter_dims, filter_offset, bias_data, bias_dims, stride,
+ stride, pad_width, pad_height, depth_multiplier,
+ output_offset, output_multiplier, output_shift,
+ output_activation_min, output_activation_max, output_data,
+ output_dims);
+}
+
+inline void Conv(const float* input_data, const Dims<4>& input_dims,
+ const float* filter_data, const Dims<4>& filter_dims,
+ const float* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height, int dilation_width_factor,
+ int dilation_height_factor, int pad_width, int pad_height,
+ float output_activation_min, float output_activation_max,
+ float* output_data, const Dims<4>& output_dims,
+ float* im2col_data, const Dims<4>& im2col_dims) {
+ tflite::ConvParams op_params;
+ // Padding type is ignored, but still set.
+ op_params.padding_type = PaddingType::kSame;
+ op_params.padding_values.width = pad_width;
+ op_params.padding_values.height = pad_height;
+ op_params.stride_width = stride_width;
+ op_params.stride_height = stride_height;
+ op_params.dilation_width_factor = dilation_width_factor;
+ op_params.dilation_height_factor = dilation_height_factor;
+ op_params.float_activation_min = output_activation_min;
+ op_params.float_activation_max = output_activation_max;
+
+ Conv(op_params, DimsToShape(input_dims), input_data, DimsToShape(filter_dims),
+ filter_data, DimsToShape(bias_dims), bias_data, DimsToShape(output_dims),
+ output_data, DimsToShape(im2col_dims), im2col_data);
+}
+
+template <FusedActivationFunctionType Ac>
+void Conv(const float* input_data, const Dims<4>& input_dims,
+ const float* filter_data, const Dims<4>& filter_dims,
+ const float* bias_data, const Dims<4>& bias_dims, int stride_width,
+ int stride_height, int dilation_width_factor,
+ int dilation_height_factor, int pad_width, int pad_height,
+ float* output_data, const Dims<4>& output_dims, float* im2col_data,
+ const Dims<4>& im2col_dims) {
+ float output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+ Conv(input_data, input_dims, filter_data, filter_dims, bias_data, bias_dims,
+ stride_width, stride_height, dilation_width_factor,
+ dilation_height_factor, pad_width, pad_height, output_activation_min,
+ output_activation_max, output_data, output_dims, im2col_data,
+ im2col_dims);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void Conv(const float* input_data, const Dims<4>& input_dims,
+ const float* filter_data, const Dims<4>& filter_dims,
+ const float* bias_data, const Dims<4>& bias_dims, int stride_width,
+ int stride_height, int pad_width, int pad_height, float* output_data,
+ const Dims<4>& output_dims, float* im2col_data,
+ const Dims<4>& im2col_dims) {
+ float output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+ Conv(input_data, input_dims, filter_data, filter_dims, bias_data, bias_dims,
+ stride_width, stride_height, 1, 1, pad_width, pad_height,
+ output_activation_min, output_activation_max, output_data, output_dims,
+ im2col_data, im2col_dims);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void Conv(const float* input_data, const Dims<4>& input_dims,
+ const float* filter_data, const Dims<4>& filter_dims,
+ const float* bias_data, const Dims<4>& bias_dims, int stride,
+ int pad_width, int pad_height, float* output_data,
+ const Dims<4>& output_dims, float* im2col_data,
+ const Dims<4>& im2col_dims) {
+ Conv<Ac>(input_data, input_dims, filter_data, filter_dims, bias_data,
+ bias_dims, stride, stride, 1, 1, pad_width, pad_height, output_data,
+ output_dims, im2col_data, im2col_dims);
+}
+
+inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height, int dilation_width_factor,
+ int dilation_height_factor, int pad_width, int pad_height,
+ int32 output_offset, int32 output_multiplier, int output_shift,
+ int32 output_activation_min, int32 output_activation_max,
+ uint8* output_data, const Dims<4>& output_dims,
+ uint8* im2col_data, const Dims<4>& im2col_dims,
+ gemmlowp::GemmContext* gemm_context) {
+ tflite::ConvParams op_params;
+ // Padding type is ignored, but still set.
+ op_params.padding_type = PaddingType::kSame;
+ op_params.padding_values.width = pad_width;
+ op_params.padding_values.height = pad_height;
+ op_params.stride_width = stride_width;
+ op_params.stride_height = stride_height;
+ op_params.dilation_width_factor = dilation_width_factor;
+ op_params.dilation_height_factor = dilation_height_factor;
+ op_params.input_offset = input_offset;
+ op_params.weights_offset = filter_offset;
+ op_params.output_offset = output_offset;
+ op_params.output_multiplier = output_multiplier;
+ // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
+ op_params.output_shift = kReverseShift * output_shift;
+ op_params.quantized_activation_min = output_activation_min;
+ op_params.quantized_activation_max = output_activation_max;
+
+ Conv(op_params, DimsToShape(input_dims), input_data, DimsToShape(filter_dims),
+ filter_data, DimsToShape(bias_dims), bias_data, DimsToShape(output_dims),
+ output_data, DimsToShape(im2col_dims), im2col_data, gemm_context);
+}
+
+inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims, uint8* im2col_data,
+ const Dims<4>& im2col_dims,
+ gemmlowp::GemmContext* gemm_context) {
+ Conv(input_data, input_dims, input_offset, filter_data, filter_dims,
+ filter_offset, bias_data, bias_dims, stride_width, stride_height, 1, 1,
+ pad_width, pad_height, output_offset, output_multiplier, output_shift,
+ output_activation_min, output_activation_max, output_data, output_dims,
+ im2col_data, im2col_dims, gemm_context);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims, uint8* im2col_data,
+ const Dims<4>& im2col_dims,
+ gemmlowp::GemmContext* gemm_context) {
+ static_assert(Ac == FusedActivationFunctionType::kNone ||
+ Ac == FusedActivationFunctionType::kRelu ||
+ Ac == FusedActivationFunctionType::kRelu6 ||
+ Ac == FusedActivationFunctionType::kRelu1,
+ "");
+ if (Ac == FusedActivationFunctionType::kNone) {
+ TFLITE_DCHECK_EQ(output_activation_min, 0);
+ TFLITE_DCHECK_EQ(output_activation_max, 255);
+ }
+ Conv(input_data, input_dims, input_offset, filter_data, filter_dims,
+ filter_offset, bias_data, bias_dims, stride_width, stride_height,
+ pad_width, pad_height, output_offset, output_multiplier, output_shift,
+ output_activation_min, output_activation_max, output_data, output_dims,
+ im2col_data, im2col_dims, gemm_context);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void Conv(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims, int stride,
+ int pad_width, int pad_height, int32 output_offset,
+ int32 output_multiplier, int output_shift,
+ int32 output_activation_min, int32 output_activation_max,
+ uint8* output_data, const Dims<4>& output_dims, uint8* im2col_data,
+ const Dims<4>& im2col_dims, gemmlowp::GemmContext* gemm_context) {
+ Conv<Ac>(input_data, input_dims, input_offset, filter_data, filter_dims,
+ filter_offset, bias_data, bias_dims, stride, stride, pad_width,
+ pad_height, output_offset, output_multiplier, output_shift,
+ output_activation_min, output_activation_max, output_data,
+ output_dims, im2col_data, im2col_dims, gemm_context);
+}
+
+inline void TransposeConv(const float* input_data, const Dims<4>& input_dims,
+ const float* filter_data, const Dims<4>& filter_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, float* output_data,
+ const Dims<4>& output_dims, float* im2col_data,
+ const Dims<4>& im2col_dims) {
+ tflite::ConvParams op_params;
+ // Padding type is ignored, but still set.
+ op_params.padding_type = PaddingType::kSame;
+ op_params.padding_values.width = pad_width;
+ op_params.padding_values.height = pad_height;
+ op_params.stride_width = stride_width;
+ op_params.stride_height = stride_height;
+
+ TransposeConv(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(filter_dims), filter_data, DimsToShape(output_dims),
+ output_data, DimsToShape(im2col_dims), im2col_data);
+}
+
+inline void FullyConnected(const float* input_data, const Dims<4>& input_dims,
+ const float* weights_data,
+ const Dims<4>& weights_dims, const float* bias_data,
+ const Dims<4>& bias_dims,
+ float output_activation_min,
+ float output_activation_max, float* output_data,
+ const Dims<4>& output_dims) {
+ tflite::FullyConnectedParams op_params;
+ op_params.float_activation_min = output_activation_min;
+ op_params.float_activation_max = output_activation_max;
+
+ FullyConnected(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(weights_dims), weights_data,
+ DimsToShape(bias_dims), bias_data, DimsToShape(output_dims),
+ output_data);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void FullyConnected(const float* input_data, const Dims<4>& input_dims,
+ const float* weights_data, const Dims<4>& weights_dims,
+ const float* bias_data, const Dims<4>& bias_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ float output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+ FullyConnected(input_data, input_dims, weights_data, weights_dims, bias_data,
+ bias_dims, output_activation_min, output_activation_max,
+ output_data, output_dims);
+}
+
+inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims,
+ int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims,
+ gemmlowp::GemmContext* gemm_context) {
+ tflite::FullyConnectedParams op_params;
+ op_params.input_offset = input_offset;
+ op_params.weights_offset = filter_offset;
+ op_params.output_offset = output_offset;
+ op_params.output_multiplier = output_multiplier;
+ // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
+ op_params.output_shift = kReverseShift * output_shift;
+ op_params.quantized_activation_min = output_activation_min;
+ op_params.quantized_activation_max = output_activation_max;
+
+ FullyConnected(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(filter_dims), filter_data, DimsToShape(bias_dims),
+ bias_data, DimsToShape(output_dims), output_data,
+ gemm_context);
+}
+
+inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims,
+ int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, int16* output_data,
+ const Dims<4>& output_dims,
+ gemmlowp::GemmContext* gemm_context) {
+ tflite::FullyConnectedParams op_params;
+ op_params.input_offset = input_offset;
+ op_params.weights_offset = filter_offset;
+ op_params.output_offset = output_offset;
+ op_params.output_multiplier = output_multiplier;
+ // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
+ op_params.output_shift = kReverseShift * output_shift;
+ op_params.quantized_activation_min = output_activation_min;
+ op_params.quantized_activation_max = output_activation_max;
+
+ FullyConnected(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(filter_dims), filter_data, DimsToShape(bias_dims),
+ bias_data, DimsToShape(output_dims), output_data,
+ gemm_context);
+}
+
+inline void ShuffledFullyConnected(
+ const uint8* input_data, const Dims<4>& input_dims,
+ const uint8* shuffled_weights_data, const Dims<4>& weights_dims,
+ const int32* bias_data, const Dims<4>& bias_dims, int32 output_multiplier,
+ int output_shift, int32 output_activation_min, int32 output_activation_max,
+ int16* output_data, const Dims<4>& output_dims,
+ uint8* shuffled_input_workspace_data, gemmlowp::GemmContext* gemm_context) {
+ tflite::FullyConnectedParams op_params;
+ op_params.output_multiplier = output_multiplier;
+ // Legacy ops used mixed left and right shifts. Now positive always means a left shift.
+ op_params.output_shift = kReverseShift * output_shift;
+ op_params.quantized_activation_min = output_activation_min;
+ op_params.quantized_activation_max = output_activation_max;
+
+ ShuffledFullyConnected(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(weights_dims), shuffled_weights_data,
+ DimsToShape(bias_dims), bias_data,
+ DimsToShape(output_dims), output_data,
+ shuffled_input_workspace_data, gemm_context);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims,
+ int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims,
+ gemmlowp::GemmContext* gemm_context) {
+ static_assert(Ac == FusedActivationFunctionType::kNone ||
+ Ac == FusedActivationFunctionType::kRelu ||
+ Ac == FusedActivationFunctionType::kRelu6 ||
+ Ac == FusedActivationFunctionType::kRelu1,
+ "");
+ if (Ac == FusedActivationFunctionType::kNone) {
+ TFLITE_DCHECK_EQ(output_activation_min, 0);
+ TFLITE_DCHECK_EQ(output_activation_max, 255);
+ }
+ FullyConnected(input_data, input_dims, input_offset, filter_data, filter_dims,
+ filter_offset, bias_data, bias_dims, output_offset,
+ output_multiplier, output_shift, output_activation_min,
+ output_activation_max, output_data, output_dims, gemm_context);
+}
+
+inline void LstmCell(const float* input_data, const Dims<4>& input_dims,
+ const float* prev_activ_data,
+ const Dims<4>& prev_activ_dims, const float* weights_data,
+ const Dims<4>& weights_dims, const float* bias_data,
+ const Dims<4>& bias_dims, const float* prev_state_data,
+ const Dims<4>& prev_state_dims, float* output_state_data,
+ const Dims<4>& output_state_dims, float* output_activ_data,
+ const Dims<4>& output_activ_dims, float* concat_temp_data,
+ const Dims<4>& concat_temp_dims, float* activ_temp_data,
+ const Dims<4>& activ_temp_dims) {
+ tflite::LstmCellParams op_params;
+ // The float LSTM cell needs no parameters; leave op_params default-initialized.
+
+ LstmCell(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(prev_activ_dims), prev_activ_data,
+ DimsToShape(weights_dims), weights_data, DimsToShape(bias_dims),
+ bias_data, DimsToShape(prev_state_dims), prev_state_data,
+ DimsToShape(output_state_dims), output_state_data,
+ DimsToShape(output_activ_dims), output_activ_data,
+ DimsToShape(concat_temp_dims), concat_temp_data,
+ DimsToShape(activ_temp_dims), activ_temp_data);
+}
+
+template <int StateIntegerBits>
+void LstmCell(const uint8* input_data_uint8, const Dims<4>& input_dims,
+ const uint8* prev_activ_data_uint8,
+ const Dims<4>& prev_activ_dims, const uint8* weights_data_uint8,
+ const Dims<4>& weights_dims, const int32* bias_data_int32,
+ const Dims<4>& bias_dims, const int16* prev_state_data_int16,
+ const Dims<4>& prev_state_dims, int16* output_state_data_int16,
+ const Dims<4>& output_state_dims, uint8* output_activ_data_uint8,
+ const Dims<4>& output_activ_dims, uint8* concat_temp_data_uint8,
+ const Dims<4>& concat_temp_dims, int16* activ_temp_data_int16,
+ const Dims<4>& activ_temp_dims, int32 weights_zero_point,
+ int32 accum_multiplier, int accum_shift,
+ gemmlowp::GemmContext* gemm_context) {
+ tflite::LstmCellParams op_params;
+ op_params.weights_zero_point = weights_zero_point;
+ op_params.accum_multiplier = accum_multiplier;
+ op_params.accum_shift = accum_shift;
+
+ LstmCell<StateIntegerBits>(
+ op_params, DimsToShape(input_dims), input_data_uint8,
+ DimsToShape(prev_activ_dims), prev_activ_data_uint8,
+ DimsToShape(weights_dims), weights_data_uint8, DimsToShape(bias_dims),
+ bias_data_int32, DimsToShape(prev_state_dims), prev_state_data_int16,
+ DimsToShape(output_state_dims), output_state_data_int16,
+ DimsToShape(output_activ_dims), output_activ_data_uint8,
+ DimsToShape(concat_temp_dims), concat_temp_data_uint8,
+ DimsToShape(activ_temp_dims), activ_temp_data_int16, gemm_context);
+}
+
+template <typename T>
+void BroadcastDiv(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, const Dims<4>& input2_dims,
+ T output_activation_min, T output_activation_max,
+ T* output_data, const Dims<4>& output_dims) {
+ tflite::ArithmeticParams op_params;
+ SetActivationParams(output_activation_min, output_activation_max, &op_params);
+
+ BroadcastDiv4DSlow(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data,
+ DimsToShape(output_dims), output_data);
+}
+
+template <typename T>
+inline void Div(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, const Dims<4>& input2_dims,
+ T output_activation_min, T output_activation_max,
+ T* output_data, const Dims<4>& output_dims) {
+ tflite::ArithmeticParams op_params;
+ SetActivationParams(output_activation_min, output_activation_max, &op_params);
+
+ Div(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
+ output_data);
+}
+
+template <FusedActivationFunctionType Ac, typename Scalar>
+inline void Concatenation(int concat_dim, const Scalar* const* input_data,
+ const Dims<4>* const* input_dims, int inputs_count,
+ Scalar* output_data, const Dims<4>& output_dims) {
+ // For now we don't have a model with a Concatenation with fused activation.
+ TFLITE_DCHECK_EQ(Ac, FusedActivationFunctionType::kNone);
+
+ std::vector<RuntimeShape> input_shapes(inputs_count);
+ std::vector<const RuntimeShape*> input_shapes_indirect(inputs_count);
+ for (int i = 0; i < inputs_count; ++i) {
+ ShapeFromDims(*input_dims[i], &input_shapes[i]);
+ input_shapes_indirect[i] = &input_shapes[i];
+ }
+ tflite::ConcatenationParams op_params;
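+ // Legacy concat_dim counts dimensions in the reversed Dims<4> order, so the
+ // new-style axis is its mirror within the four dimensions.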
+ op_params.axis = 3 - concat_dim;
+ op_params.inputs_count = inputs_count;
+
+ Concatenation(op_params, input_shapes_indirect.data(), input_data,
+ DimsToShape(output_dims), output_data);
+}
+
+inline void Concatenation(int concat_dim, const uint8* const* input_data,
+ const Dims<4>* const* input_dims,
+ const int32* input_zeropoint,
+ const float* input_scale, int inputs_count,
+ uint8* output_data, const Dims<4>& output_dims,
+ const int32 output_zeropoint,
+ const float output_scale) {
+ std::vector<RuntimeShape> input_shapes(inputs_count);
+ std::vector<const RuntimeShape*> input_shapes_indirect(inputs_count);
+ for (int i = 0; i < inputs_count; ++i) {
+ ShapeFromDims(*input_dims[i], &input_shapes[i]);
+ input_shapes_indirect[i] = &input_shapes[i];
+ }
+ tflite::ConcatenationParams op_params;
+ op_params.axis = 3 - concat_dim;
+ op_params.input_zeropoint = input_zeropoint;
+ op_params.input_scale = input_scale;
+ op_params.inputs_count = inputs_count;
+ op_params.output_zeropoint = output_zeropoint;
+ op_params.output_scale = output_scale;
+
+ ConcatenationWithScaling(op_params, input_shapes_indirect.data(), input_data,
+ DimsToShape(output_dims), output_data);
+}
+
+template <FusedActivationFunctionType Ac, typename Scalar>
+void DepthConcatenation(const Scalar* const* input_data,
+ const Dims<4>* const* input_dims, int inputs_count,
+ Scalar* output_data, const Dims<4>& output_dims) {
+ // For now we don't have a model with a Concatenation with fused activation.
+ TFLITE_DCHECK_EQ(Ac, FusedActivationFunctionType::kNone);
+
+ std::vector<RuntimeShape> input_shapes(inputs_count);
+ std::vector<const RuntimeShape*> input_shapes_indirect(inputs_count);
+ for (int i = 0; i < inputs_count; ++i) {
+ ShapeFromDims(*input_dims[i], &input_shapes[i]);
+ input_shapes_indirect[i] = &input_shapes[i];
+ }
+ tflite::ConcatenationParams op_params;
+ op_params.inputs_count = inputs_count;
+
+ DepthConcatenation(op_params, input_shapes_indirect.data(), input_data,
+ DimsToShape(output_dims), output_data);
+}
+
+template <typename Scalar>
+void TensorFlowSplit(const Scalar* input_data, const Dims<4>& input_dims,
+ int axis, int outputs_count, Scalar* const* output_data,
+ const Dims<4>* const* output_dims) {
+ std::vector<RuntimeShape> output_shapes(outputs_count);
+ std::vector<const RuntimeShape*> output_shapes_indirect(outputs_count);
+ for (int i = 0; i < outputs_count; ++i) {
+ ShapeFromDims(*output_dims[i], &output_shapes[i]);
+ output_shapes_indirect[i] = &output_shapes[i];
+ }
+ tflite::SplitParams op_params;
+ op_params.axis = 3 - axis;
+ op_params.num_split = outputs_count;
+
+ Split(op_params, DimsToShape(input_dims), input_data,
+ output_shapes_indirect.data(), output_data);
+}
+
+template <FusedActivationFunctionType Ac, typename Scalar>
+void TensorFlowSplit(const Scalar* input_data, const Dims<4>& input_dims,
+ int outputs_count, Scalar* const* output_data,
+ const Dims<4>* const* output_dims) {
+ TFLITE_DCHECK_GE(outputs_count, 1);
+ for (int i = 0; i < outputs_count; i++) {
+ /* batches = */ MatchingArraySize(*output_dims[i], 3, input_dims, 3);
+ /* height = */ MatchingArraySize(*output_dims[i], 2, input_dims, 2);
+ /* width = */ MatchingArraySize(*output_dims[i], 1, input_dims, 1);
+ }
+ // For now we don't have a model with a Split with fused activation.
+ TFLITE_DCHECK_EQ(Ac, FusedActivationFunctionType::kNone);
+
+ TensorFlowSplit(input_data, input_dims, /*axis=*/0, outputs_count,
+ output_data, output_dims);
+}
+
+inline void Softmax(const float* input_data, const RuntimeShape& input_shape,
+ float beta, float* output_data,
+ const RuntimeShape& output_shape) {
+ SoftmaxParams params;
+ params.beta = beta;
+ Softmax(params, input_shape, input_data, output_shape, output_data);
+}
+
+inline void Softmax(const uint8* input_data, const RuntimeShape& input_shape,
+ int32 input_beta_multiplier, int32 input_beta_left_shift,
+ int diff_min, uint8* output_data,
+ const RuntimeShape& output_shape) {
+ SoftmaxParams params;
+ params.input_multiplier = input_beta_multiplier;
+ params.input_left_shift = input_beta_left_shift;
+ params.diff_min = diff_min;
+ Softmax(params, input_shape, input_data, output_shape, output_data);
+}
+
+inline void LogSoftmax(const float* input_data, const RuntimeShape& input_shape,
+ float* output_data, const RuntimeShape& output_shape) {
+ SoftmaxParams params;
+ // No params currently used for float LogSoftmax.
+ LogSoftmax(params, input_shape, input_data, output_shape, output_data);
+}
+
+inline void LogSoftmax(const uint8* input_data, const RuntimeShape& input_shape,
+ int32 input_multiplier, int32 input_left_shift,
+ int32 reverse_scaling_divisor,
+ int32 reverse_scaling_right_shift, int diff_min,
+ uint8* output_data, const RuntimeShape& output_shape) {
+ SoftmaxParams params;
+ params.input_multiplier = input_multiplier;
+ params.input_left_shift = input_left_shift;
+ params.reverse_scaling_divisor = reverse_scaling_divisor;
+ params.reverse_scaling_right_shift = reverse_scaling_right_shift;
+ params.diff_min = diff_min;
+ LogSoftmax(params, input_shape, input_data, output_shape, output_data);
+}
+
+inline void Logistic(const uint8* input_data, const RuntimeShape& input_shape,
+ int32 input_zero_point, int32 input_range_radius,
+ int32 input_multiplier, int input_left_shift,
+ uint8* output_data, const RuntimeShape& output_shape) {
+ LogisticParams params;
+ params.input_zero_point = input_zero_point;
+ params.input_range_radius = input_range_radius;
+ params.input_multiplier = input_multiplier;
+ params.input_left_shift = input_left_shift;
+ Logistic(params, input_shape, input_data, output_shape, output_data);
+}
+
+inline void Logistic(const RuntimeShape& input_shape, const int16* input_data,
+ const RuntimeShape& output_shape, int16* output_data) {
+ LogisticParams params;
+ // No params currently needed by int16 Logistic.
+ Logistic(params, input_shape, input_data, output_shape, output_data);
+}
+
+inline void Tanh(const uint8* input_data, const RuntimeShape& input_shape,
+ int32 input_zero_point, int32 input_range_radius,
+ int32 input_multiplier, int input_left_shift,
+ uint8* output_data, const RuntimeShape& output_shape) {
+ TanhParams params;
+ params.input_zero_point = input_zero_point;
+ params.input_range_radius = input_range_radius;
+ params.input_multiplier = input_multiplier;
+ params.input_left_shift = input_left_shift;
+ Tanh(params, input_shape, input_data, output_shape, output_data);
+}
+
+inline void Tanh(const int16* input_data, const RuntimeShape& input_shape,
+ int input_left_shift, int16* output_data,
+ const RuntimeShape& output_shape) {
+ TanhParams params;
+ params.input_left_shift = input_left_shift;
+ Tanh(params, input_shape, input_data, output_shape, output_data);
+}
+
+inline void Dequantize(const uint8* input_data, const Dims<4>& input_dims,
+ int32 zero_point, double scale, float* output_data,
+ const Dims<4>& output_dims) {
+ tflite::DequantizationParams op_params;
+ op_params.zero_point = zero_point;
+ op_params.scale = scale;
+
+ Dequantize(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(output_dims), output_data);
+}
+
+inline void FakeQuant(const float* input_data, const Dims<4>& input_dims,
+ float rmin, float rmax, int num_bits, float* output_data,
+ const Dims<4>& output_dims) {
+ tflite::FakeQuantParams op_params;
+ op_params.num_bits = num_bits;
+ op_params.minmax.min = rmin;
+ op_params.minmax.max = rmax;
+
+ FakeQuant(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(output_dims), output_data);
+}
+
+template <typename T>
+inline void Gather(const T* input_data, const Dims<4>& input_dims,
+ int input_rank, const int32* coords_data,
+ const Dims<4>& coords_dims, T* output_data,
+ const Dims<4>& output_dims) {
+ tflite::GatherParams op_params;
+ op_params.input_rank = input_rank;
+
+ Gather(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(coords_dims), coords_data, DimsToShape(output_dims),
+ output_data);
+}
+
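+// Reverses the bit order of a 32-bit word (bit 0 swaps with bit 31, and so on).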
+inline uint32 LegacyReverseBits32(uint32 n) {
+ n = ((n >> 1) & 0x55555555) | ((n & 0x55555555) << 1);
+ n = ((n >> 2) & 0x33333333) | ((n & 0x33333333) << 2);
+ n = ((n >> 4) & 0x0F0F0F0F) | ((n & 0x0F0F0F0F) << 4);
+ return (((n & 0xFF) << 24) | ((n & 0xFF00) << 8) | ((n & 0xFF0000) >> 8) |
+ ((n & 0xFF000000) >> 24));
+}
+
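+// Converts StridedSliceParams built from legacy reversed-dimension arguments to
+// the forward convention: the index/stride arrays are reversed in place, and
+// each mask is bit-reversed and re-aligned to the number of dimensions used.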
+inline void StridedSliceReverseIndices(tflite::StridedSliceParams* p) {
+ TFLITE_CHECK_EQ(p->start_indices_count, p->stop_indices_count);
+ TFLITE_CHECK_EQ(p->stop_indices_count, p->strides_count);
+
+ std::reverse(p->start_indices, p->start_indices + p->start_indices_count);
+ std::reverse(p->stop_indices, p->stop_indices + p->stop_indices_count);
+ std::reverse(p->strides, p->strides + p->strides_count);
+
+ p->begin_mask = LegacyReverseBits32(static_cast<uint32>(p->begin_mask)) >>
+ (32 - p->start_indices_count);
+ p->ellipsis_mask =
+ LegacyReverseBits32(static_cast<uint32>(p->ellipsis_mask)) >>
+ (32 - p->start_indices_count);
+ p->end_mask = LegacyReverseBits32(static_cast<uint32>(p->end_mask)) >>
+ (32 - p->start_indices_count);
+ p->new_axis_mask =
+ LegacyReverseBits32(static_cast<uint32>(p->new_axis_mask)) >>
+ (32 - p->start_indices_count);
+ p->shrink_axis_mask =
+ LegacyReverseBits32(static_cast<uint32>(p->shrink_axis_mask)) >>
+ (32 - p->start_indices_count);
+}
+
+template <typename T>
+inline void StridedSlice(const T* input_data, const Dims<4>& input_dims,
+ int begin_mask, int end_mask, int shrink_axis_mask,
+ const std::vector<int>& start_indices,
+ const std::vector<int>& stop_indices,
+ const std::vector<int>& strides, T* output_data,
+ const Dims<4>& output_dims) {
+ TFLITE_DCHECK_EQ(start_indices.size(), 4);
+ auto op_params = strided_slice::BuildStridedSliceParams(
+ begin_mask, end_mask, shrink_axis_mask, start_indices, stop_indices,
+ strides);
+ StridedSliceReverseIndices(&op_params);
+
+ StridedSlice(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(output_dims), output_data);
+}
+
+template <typename T>
+inline void Mean(const T* input_data, const Dims<4>& input_dims,
+ const std::vector<int>& reduction_indices, T* output_data,
+ const Dims<4>& output_dims) {
+ tflite::MeanParams op_params;
+ op_params.axis_count = reduction_indices.size();
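+ // The legacy reduction indices are copied back-to-front into op_params.axis.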
+ for (int i = 0; i < op_params.axis_count; ++i) {
+ op_params.axis[i] = reduction_indices[op_params.axis_count - 1 - i];
+ }
+
+ Mean(op_params, DimsToShape(input_dims), input_data, DimsToShape(output_dims),
+ output_data);
+}
+
+template <typename T>
+void Transpose(const T* input, const Dims<4>& input_dims, T* output,
+ const Dims<4>& output_dims, const int* permuted_axes) {
+ TransposeParams params;
+ params.perm_count = 4;
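+ // Mirror both the axis position and the axis value to convert the legacy
+ // reversed permutation into the forward convention.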
+ for (int i = 0; i < 4; ++i) {
+ params.perm[i] = 3 - permuted_axes[3 - i];
+ }
+ Transpose(params, DimsToShape(input_dims), input, DimsToShape(output_dims),
+ output);
+}
+
+template <typename T, ComparisonFn<T> F>
+inline void Comparison(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, const Dims<4>& input2_dims,
+ bool* output_data, const Dims<4>& output_dims) {
+ ComparisonParams op_params;
+ // No parameters needed.
+ ComparisonImpl<T, F>(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data,
+ DimsToShape(output_dims), output_data);
+}
+
+template <typename T, ComparisonFn<int32> F>
+inline void Comparison(int left_shift, const T* input1_data,
+ const Dims<4>& input1_dims, int32 input1_offset,
+ int32 input1_multiplier, int input1_shift,
+ const T* input2_data, const Dims<4>& input2_dims,
+ int32 input2_offset, int32 input2_multiplier,
+ int input2_shift, bool* output_data,
+ const Dims<4>& output_dims) {
+ tflite::ComparisonParams op_params;
+ op_params.left_shift = left_shift;
+ op_params.input1_offset = input1_offset;
+ op_params.input1_multiplier = input1_multiplier;
+ // Legacy ops used mixed left and right shifts. Now positive always means a left shift.
+ op_params.input1_shift = kReverseShift * input1_shift;
+ op_params.input2_offset = input2_offset;
+ op_params.input2_multiplier = input2_multiplier;
+ // Legacy ops used mixed left and right shifts. Now positive always means a left shift.
+ op_params.input2_shift = kReverseShift * input2_shift;
+
+ ComparisonWithScaling<T, F>(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data,
+ DimsToShape(output_dims), output_data);
+}
+
+template <typename T, ComparisonFn<T> F>
+inline void BroadcastComparison(const T* input1_data,
+ const Dims<4>& input1_dims,
+ const T* input2_data,
+ const Dims<4>& input2_dims, bool* output_data,
+ const Dims<4>& output_dims) {
+ ComparisonParams op_params;
+ // No parameters needed.
+ BroadcastComparison4DSlowImpl<T, F>(op_params, DimsToShape(input1_dims),
+ input1_data, DimsToShape(input2_dims),
+ input2_data, DimsToShape(output_dims),
+ output_data);
+}
+
+template <typename T, ComparisonFn<int32> F>
+inline void BroadcastComparison(int left_shift, const T* input1_data,
+ const Dims<4>& input1_dims, int32 input1_offset,
+ int32 input1_multiplier, int input1_shift,
+ const T* input2_data,
+ const Dims<4>& input2_dims, int32 input2_offset,
+ int32 input2_multiplier, int input2_shift,
+ bool* output_data, const Dims<4>& output_dims) {
+ ComparisonParams op_params;
+
+ op_params.left_shift = left_shift;
+ op_params.input1_offset = input1_offset;
+ op_params.input1_multiplier = input1_multiplier;
+ // Legacy ops used mixed left and right shifts. Now positive always means a left shift.
+ op_params.input1_shift = kReverseShift * input1_shift;
+ op_params.input2_offset = input2_offset;
+ op_params.input2_multiplier = input2_multiplier;
+ // Legacy ops used mixed left and right shifts. Now positive always means a left shift.
+ op_params.input2_shift = kReverseShift * input2_shift;
+
+ BroadcastComparison4DSlowWithScaling<T, F>(
+ op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
+ output_data);
+}
+
+#define TFLITE_LEGACY_COMPARISON_OP(name) \
+ template <typename T> \
+ inline void name(const T* input1_data, const Dims<4>& input1_dims, \
+ const T* input2_data, const Dims<4>& input2_dims, \
+ bool* output_data, const Dims<4>& output_dims) { \
+ gemmlowp::ScopedProfilingLabel label(#name); \
+ Comparison<T, name##Fn>(input1_data, input1_dims, input2_data, \
+ input2_dims, output_data, output_dims); \
+ } \
+ template <typename T> \
+ inline void name( \
+ int left_shift, const T* input1_data, const Dims<4>& input1_dims, \
+ int32 input1_offset, int32 input1_multiplier, int input1_shift, \
+ const T* input2_data, const Dims<4>& input2_dims, int32 input2_offset, \
+ int32 input2_multiplier, int input2_shift, bool* output_data, \
+ const Dims<4>& output_dims) { \
+ gemmlowp::ScopedProfilingLabel label(#name "/8bit"); \
+ Comparison<T, name##Fn>(left_shift, input1_data, input1_dims, \
+ input1_offset, input1_multiplier, input1_shift, \
+ input2_data, input2_dims, input2_offset, \
+ input2_multiplier, input2_shift, output_data, \
+ output_dims); \
+ } \
+ template <typename T> \
+ inline void Broadcast##name( \
+ const T* input1_data, const Dims<4>& input1_dims, const T* input2_data, \
+ const Dims<4>& input2_dims, bool* output_data, \
+ const Dims<4>& output_dims) { \
+ gemmlowp::ScopedProfilingLabel label("Broadcast" #name); \
+ BroadcastComparison<T, name##Fn>(input1_data, input1_dims, input2_data, \
+ input2_dims, output_data, output_dims); \
+ } \
+ template <typename T> \
+ inline void Broadcast##name( \
+ int left_shift, const T* input1_data, const Dims<4>& input1_dims, \
+ int32 input1_offset, int32 input1_multiplier, int input1_shift, \
+ const T* input2_data, const Dims<4>& input2_dims, int32 input2_offset, \
+ int32 input2_multiplier, int input2_shift, bool* output_data, \
+ const Dims<4>& output_dims) { \
+ gemmlowp::ScopedProfilingLabel label("Broadcast" #name "/8bit"); \
+ BroadcastComparison<T, name##Fn>(left_shift, input1_data, input1_dims, \
+ input1_offset, input1_multiplier, \
+ input1_shift, input2_data, input2_dims, \
+ input2_offset, input2_multiplier, \
+ input2_shift, output_data, output_dims); \
+ }
+TFLITE_LEGACY_COMPARISON_OP(Equal);
+TFLITE_LEGACY_COMPARISON_OP(NotEqual);
+TFLITE_LEGACY_COMPARISON_OP(Greater);
+TFLITE_LEGACY_COMPARISON_OP(GreaterEqual);
+TFLITE_LEGACY_COMPARISON_OP(Less);
+TFLITE_LEGACY_COMPARISON_OP(LessEqual);
+#undef TFLITE_LEGACY_COMPARISON_OP
+
+template <typename D, typename T>
+inline void Select(const D* input_condition_data,
+ const Dims<4>& input_condition_dims, const T* input_x_data,
+ const Dims<4>& input_x_dims, const T* input_y_data,
+ const Dims<4>& input_y_dims, T* output_data,
+ const Dims<4>& output_dims) {
+ Select(DimsToShape(input_condition_dims), input_condition_data,
+ DimsToShape(input_x_dims), input_x_data, DimsToShape(input_y_dims),
+ input_y_data, DimsToShape(output_dims), output_data);
+}
+
+template <typename D, typename T>
+inline void RankOneSelect(const D* input_condition_data,
+ const Dims<4>& input_condition_dims,
+ const T* input_x_data, const Dims<4>& input_x_dims,
+ const T* input_y_data, const Dims<4>& input_y_dims,
+ T* output_data, const Dims<4>& output_dims) {
+ RankOneSelect(DimsToShape(input_condition_dims), input_condition_data,
+ DimsToShape(input_x_dims), input_x_data,
+ DimsToShape(input_y_dims), input_y_data,
+ DimsToShape(output_dims), output_data);
+}
+
+template <typename T, typename TI>
+inline void SparseToDense(const std::vector<std::vector<TI>>& indices,
+ const T* values, T default_value, T* output_data,
+ const Dims<4>& output_dims, bool value_is_scalar) {
+ SparseToDense(indices, values, default_value, value_is_scalar,
+ DimsToShape(output_dims), output_data);
+}
+
+template <typename Scalar>
+void Pack(int dim, const Scalar* const* input_data,
+ const Dims<4>* const* input_dims, int inputs_count,
+ Scalar* output_data, const Dims<4>& output_dims) {
+ std::vector<RuntimeShape> input_shapes(inputs_count);
+ std::vector<const RuntimeShape*> input_shapes_indirect(inputs_count);
+ for (int i = 0; i < inputs_count; ++i) {
+ ShapeFromDims(*input_dims[i], &input_shapes[i]);
+ input_shapes_indirect[i] = &input_shapes[i];
+ }
+ tflite::PackParams op_params;
+ op_params.axis = 3 - dim;
+ op_params.inputs_count = inputs_count;
+
+ Pack(op_params, input_shapes_indirect.data(), input_data,
+ DimsToShape(output_dims), output_data);
+}
+
+template <typename Scalar>
+void Unpack(int axis, const Scalar* input_data, const Dims<4>& input_dims,
+ int dimensions, int outputs_count, Scalar* const* output_datas,
+ const Dims<4>& output_dims) {
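+ // Note: the `dimensions` argument is not used by this wrapper.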
+ tflite::UnpackParams op_params;
+ op_params.axis = 3 - axis;
+ op_params.num_split = outputs_count;
+
+ Unpack(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(output_dims), output_datas);
+}
+
+template <typename Scalar>
+void Pack(int dim, const Scalar* const* input_data,
+ const Dims<4>* const* input_dims, const int32* input_zeropoint,
+ const float* input_scale, int inputs_count, Scalar* output_data,
+ const Dims<4>& output_dims, const int32 output_zeropoint,
+ const float output_scale) {
+ std::vector<RuntimeShape> input_shapes(inputs_count);
+ std::vector<const RuntimeShape*> input_shapes_indirect(inputs_count);
+ for (int i = 0; i < inputs_count; ++i) {
+ ShapeFromDims(*input_dims[i], &input_shapes[i]);
+ input_shapes_indirect[i] = &input_shapes[i];
+ }
+ tflite::PackParams op_params;
+ op_params.axis = 3 - dim;
+ op_params.input_zeropoint = input_zeropoint;
+ op_params.input_scale = input_scale;
+ op_params.inputs_count = inputs_count;
+ op_params.output_zeropoint = output_zeropoint;
+ op_params.output_scale = output_scale;
+
+ PackWithScaling(op_params, input_shapes_indirect.data(), input_data,
+ DimsToShape(output_dims), output_data);
+}
+
+template <FusedActivationFunctionType Ac>
+void L2Normalization(const float* input_data, const RuntimeShape& input_shape,
+ float* output_data, const RuntimeShape& output_shape) {
+ static_assert(Ac == FusedActivationFunctionType::kNone, "");
+ tflite::L2NormalizationParams op_params;
+ // No params need to be set for float.
+
+ L2Normalization(op_params, input_shape, input_data, output_shape,
+ output_data);
+}
+
+inline void L2Normalization(const uint8* input_data,
+ const RuntimeShape& input_shape,
+ int32 input_zero_point, uint8* output_data,
+ const RuntimeShape& output_shape) {
+ tflite::L2NormalizationParams op_params;
+ op_params.input_zero_point = input_zero_point;
+
+ L2Normalization(op_params, input_shape, input_data, output_shape,
+ output_data);
+}
+
+template <FusedActivationFunctionType Ac>
+void L2Normalization(const float* input_data, const Dims<4>& input_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ L2Normalization<Ac>(input_data, DimsToShape(input_dims), output_data,
+ DimsToShape(output_dims));
+}
+
+inline void L2Normalization(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_zero_point, uint8* output_data,
+ const Dims<4>& output_dims) {
+ L2Normalization(input_data, DimsToShape(input_dims), input_zero_point,
+ output_data, DimsToShape(output_dims));
+}
+
+inline void Relu(const float* input_data, const Dims<4>& input_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ Relu(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
+ output_data);
+}
+
+inline void Relu1(const float* input_data, const Dims<4>& input_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ Relu1(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
+ output_data);
+}
+
+inline void Relu6(const float* input_data, const Dims<4>& input_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ Relu6(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
+ output_data);
+}
+
+inline void ReluX(uint8 min_value, uint8 max_value, const uint8* input_data,
+ const RuntimeShape& input_shape, uint8* output_data,
+ const RuntimeShape& output_shape) {
+ tflite::ActivationParams params;
+ params.quantized_activation_max = max_value;
+ params.quantized_activation_min = min_value;
+ ReluX(params, input_shape, input_data, output_shape, output_data);
+}
+
+template <FusedActivationFunctionType Ac>
+inline void Add(int left_shift, const uint8* input1_data,
+ const Dims<4>& input1_dims, int32 input1_offset,
+ int32 input1_multiplier, int input1_shift,
+ const uint8* input2_data, const Dims<4>& input2_dims,
+ int32 input2_offset, int32 input2_multiplier, int input2_shift,
+ int32 output_offset, int32 output_multiplier, int output_shift,
+ int32 output_activation_min, int32 output_activation_max,
+ uint8* output_data, const Dims<4>& output_dims) {
+ constexpr int kReverseShift = -1;
+ static_assert(Ac == FusedActivationFunctionType::kNone ||
+ Ac == FusedActivationFunctionType::kRelu ||
+ Ac == FusedActivationFunctionType::kRelu6 ||
+ Ac == FusedActivationFunctionType::kRelu1,
+ "");
+ TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
+ if (Ac == FusedActivationFunctionType::kNone) {
+ TFLITE_DCHECK_EQ(output_activation_min, 0);
+ TFLITE_DCHECK_EQ(output_activation_max, 255);
+ }
+
+ tflite::ArithmeticParams op_params;
+ op_params.left_shift = left_shift;
+ op_params.input1_offset = input1_offset;
+ op_params.input1_multiplier = input1_multiplier;
+ op_params.input1_shift = kReverseShift * input1_shift;
+ op_params.input2_offset = input2_offset;
+ op_params.input2_multiplier = input2_multiplier;
+ op_params.input2_shift = kReverseShift * input2_shift;
+ op_params.output_offset = output_offset;
+ op_params.output_multiplier = output_multiplier;
+ op_params.output_shift = kReverseShift * output_shift;
+ op_params.quantized_activation_min = output_activation_min;
+ op_params.quantized_activation_max = output_activation_max;
+ Add(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
+ output_data);
+}
+
+template <FusedActivationFunctionType Ac>
+void Add(const int32* input1_data, const Dims<4>& input1_dims,
+ const int32* input2_data, const Dims<4>& input2_dims,
+ int32* output_data, const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("Add/int32");
+ TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone);
+
+ tflite::ArithmeticParams op_params;
+ op_params.quantized_activation_min = std::numeric_limits<int32>::min();
+ op_params.quantized_activation_max = std::numeric_limits<int32>::max();
+ Add(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
+ output_data);
+}
+
+template <FusedActivationFunctionType Ac>
+inline void BroadcastAdd(int left_shift, const uint8* input1_data,
+ const Dims<4>& input1_dims, int32 input1_offset,
+ int32 input1_multiplier, int input1_shift,
+ const uint8* input2_data, const Dims<4>& input2_dims,
+ int32 input2_offset, int32 input2_multiplier,
+ int input2_shift, int32 output_offset,
+ int32 output_multiplier, int output_shift,
+ int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ constexpr int kReverseShift = -1;
+ static_assert(Ac == FusedActivationFunctionType::kNone ||
+ Ac == FusedActivationFunctionType::kRelu ||
+ Ac == FusedActivationFunctionType::kRelu6 ||
+ Ac == FusedActivationFunctionType::kRelu1,
+ "");
+ TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
+ if (Ac == FusedActivationFunctionType::kNone) {
+ TFLITE_DCHECK_EQ(output_activation_min, 0);
+ TFLITE_DCHECK_EQ(output_activation_max, 255);
+ }
+
+ tflite::ArithmeticParams op_params;
+ op_params.left_shift = left_shift;
+ op_params.input1_offset = input1_offset;
+ op_params.input1_multiplier = input1_multiplier;
+ op_params.input1_shift = kReverseShift * input1_shift;
+ op_params.input2_offset = input2_offset;
+ op_params.input2_multiplier = input2_multiplier;
+ op_params.input2_shift = kReverseShift * input2_shift;
+ op_params.output_offset = output_offset;
+ op_params.output_multiplier = output_multiplier;
+ op_params.output_shift = kReverseShift * output_shift;
+ op_params.quantized_activation_min = output_activation_min;
+ op_params.quantized_activation_max = output_activation_max;
+ BroadcastAdd4DSlow(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data,
+ DimsToShape(output_dims), output_data);
+}
+
+template <FusedActivationFunctionType Ac>
+void Add(const float* input1_data, const Dims<4>& input1_dims,
+ const float* input2_data, const Dims<4>& input2_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ float output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+
+ tflite::ArithmeticParams op_params;
+ op_params.float_activation_min = output_activation_min;
+ op_params.float_activation_max = output_activation_max;
+ Add(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
+ output_data);
+}
+
+template <typename T>
+void BroadcastAdd(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, const Dims<4>& input2_dims,
+ T output_activation_min, T output_activation_max,
+ T* output_data, const Dims<4>& output_dims) {
+ tflite::ArithmeticParams op_params;
+ SetActivationParams(output_activation_min, output_activation_max, &op_params);
+ BroadcastAdd4DSlow(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data,
+ DimsToShape(output_dims), output_data);
+}
+
+template <FusedActivationFunctionType Ac>
+inline void BroadcastAddFivefold(
+ int y0, int y1, int y2, int y3, int y4, int left_shift,
+ const uint8* input1_data, const Dims<4>& input1_dims, int32 input1_offset,
+ int32 input1_multiplier, int input1_shift, const uint8* input2_data,
+ const Dims<4>& input2_dims, int32 input2_offset, int32 input2_multiplier,
+ int input2_shift, int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min, int32 output_activation_max,
+ uint8* output_data, const Dims<4>& output_dims) {
+ constexpr int kReverseShift = -1;
+ static_assert(Ac == FusedActivationFunctionType::kNone ||
+ Ac == FusedActivationFunctionType::kRelu ||
+ Ac == FusedActivationFunctionType::kRelu6 ||
+ Ac == FusedActivationFunctionType::kRelu1,
+ "");
+ TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
+ if (Ac == FusedActivationFunctionType::kNone) {
+ TFLITE_DCHECK_EQ(output_activation_min, 0);
+ TFLITE_DCHECK_EQ(output_activation_max, 255);
+ }
+ tflite::ArithmeticParams op_params;
+ op_params.broadcast_category =
+ tflite::BroadcastableOpCategory::kFirstInputBroadcastsFast;
+ op_params.left_shift = left_shift;
+ op_params.input1_offset = input1_offset;
+ op_params.input1_multiplier = input1_multiplier;
+ op_params.input1_shift = kReverseShift * input1_shift;
+ op_params.input2_offset = input2_offset;
+ op_params.input2_multiplier = input2_multiplier;
+ op_params.input2_shift = kReverseShift * input2_shift;
+ op_params.output_offset = output_offset;
+ op_params.output_multiplier = output_multiplier;
+ op_params.output_shift = kReverseShift * output_shift;
+ op_params.quantized_activation_min = output_activation_min;
+ op_params.quantized_activation_max = output_activation_max;
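+ // The five legacy broadcast factors y0..y4 are stored back-to-front into
+ // broadcast_shape.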
+ op_params.broadcast_shape[4] = y0;
+ op_params.broadcast_shape[3] = y1;
+ op_params.broadcast_shape[2] = y2;
+ op_params.broadcast_shape[1] = y3;
+ op_params.broadcast_shape[0] = y4;
+ BroadcastAddFivefold(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data,
+ DimsToShape(output_dims), output_data);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac, typename T>
+void BroadcastAdd(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, const Dims<4>& input2_dims,
+ T* output_data, const Dims<4>& output_dims) {
+ T output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+
+ BroadcastAdd(input1_data, input1_dims, input2_data, input2_dims,
+ output_activation_min, output_activation_max, output_data,
+ output_dims);
+}
+
+template <FusedActivationFunctionType Ac>
+inline void Add(const int16* input1_data, const Dims<4>& input1_dims,
+ int input1_shift, const int16* input2_data,
+ const Dims<4>& input2_dims, int input2_shift,
+ int16 output_activation_min, int16 output_activation_max,
+ int16* output_data, const Dims<4>& output_dims) {
+ static_assert(Ac == FusedActivationFunctionType::kNone ||
+ Ac == FusedActivationFunctionType::kRelu ||
+ Ac == FusedActivationFunctionType::kRelu6 ||
+ Ac == FusedActivationFunctionType::kRelu1,
+ "");
+ TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
+ if (Ac == FusedActivationFunctionType::kNone) {
+ TFLITE_DCHECK_EQ(output_activation_min, -32768);
+ TFLITE_DCHECK_EQ(output_activation_max, 32767);
+ }
+
+ tflite::ArithmeticParams op_params;
+ op_params.input1_shift = kReverseShift * input1_shift;
+ op_params.input2_shift = kReverseShift * input2_shift;
+ op_params.quantized_activation_min = output_activation_min;
+ op_params.quantized_activation_max = output_activation_max;
+ Add(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
+ output_data);
+}
+
+inline void Sub(const float* input1_data, const Dims<4>& input1_dims,
+ const float* input2_data, const Dims<4>& input2_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ float output_activation_min, output_activation_max;
+ GetActivationMinMax(FusedActivationFunctionType::kNone,
+ &output_activation_min, &output_activation_max);
+ tflite::ArithmeticParams op_params;
+ op_params.float_activation_min = output_activation_min;
+ op_params.float_activation_max = output_activation_max;
+ Sub(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
+ output_data);
+}
+
+template <typename T>
+void Sub(const T* input1_data, const Dims<4>& input1_dims, const T* input2_data,
+ const Dims<4>& input2_dims, T* output_data,
+ const Dims<4>& output_dims) {
+ tflite::ArithmeticParams op_params;
+ op_params.quantized_activation_min = std::numeric_limits<T>::min();
+ op_params.quantized_activation_max = std::numeric_limits<T>::max();
+ Sub(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
+ output_data);
+}
+
+inline void AveragePool(const float* input_data, const Dims<4>& input_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int kwidth, int kheight,
+ float output_activation_min,
+ float output_activation_max, float* output_data,
+ const Dims<4>& output_dims) {
+ tflite::PoolParams params;
+ params.stride_height = stride_height;
+ params.stride_width = stride_width;
+ params.filter_height = kheight;
+ params.filter_width = kwidth;
+ params.padding_values.height = pad_height;
+ params.padding_values.width = pad_width;
+ params.float_activation_min = output_activation_min;
+ params.float_activation_max = output_activation_max;
+ AveragePool(params, DimsToShape(input_dims), input_data,
+ DimsToShape(output_dims), output_data);
+}
+
+// Transitional version that will be moved shortly to legacy_reference_ops, as
+// part of RuntimeShape revisions.
+inline void BroadcastMul4DSlow(const uint8* input1_data,
+ const Dims<4>& input1_dims, int32 input1_offset,
+ const uint8* input2_data,
+ const Dims<4>& input2_dims, int32 input2_offset,
+ int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ tflite::ArithmeticParams op_params;
+ SetActivationParams(output_activation_min, output_activation_max, &op_params);
+ op_params.input1_offset = input1_offset;
+ op_params.input2_offset = input2_offset;
+ op_params.output_offset = output_offset;
+ op_params.output_multiplier = output_multiplier;
+ op_params.output_shift = output_shift;
+
+ BroadcastMul4DSlow(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data,
+ DimsToShape(output_dims), output_data);
+}
+
+inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims,
+ int32 input1_offset, const uint8* input2_data,
+ const Dims<4>& input2_dims, int32 input2_offset,
+ int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ BroadcastMul4DSlow(input1_data, input1_dims, input1_offset, input2_data,
+ input2_dims, input2_offset, output_offset, output_multiplier,
+ kReverseShift * output_shift, output_activation_min,
+ output_activation_max, output_data, output_dims);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims,
+ int32 input1_offset, const uint8* input2_data,
+ const Dims<4>& input2_dims, int32 input2_offset,
+ int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ BroadcastMul(input1_data, input1_dims, input1_offset, input2_data,
+ input2_dims, input2_offset, output_offset, output_multiplier,
+ output_shift, output_activation_min, output_activation_max,
+ output_data, output_dims);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void AveragePool(const float* input_data, const Dims<4>& input_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int kwidth, int kheight, float* output_data,
+ const Dims<4>& output_dims) {
+ float output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+
+ AveragePool(input_data, input_dims, stride_width, stride_height, pad_width,
+ pad_height, kwidth, kheight, output_activation_min,
+ output_activation_max, output_data, output_dims);
+}
+
+// legacy, for compatibility with old checked-in code
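+// Single-stride overload (like the similar overloads further below): the same
+// stride is applied to both width and height.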
+template <FusedActivationFunctionType Ac>
+void AveragePool(const float* input_data, const Dims<4>& input_dims, int stride,
+ int pad_width, int pad_height, int filter_width,
+ int filter_height, float* output_data,
+ const Dims<4>& output_dims) {
+ AveragePool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
+ filter_width, filter_height, output_data, output_dims);
+}
+
+inline void AveragePool(const uint8* input_data, const Dims<4>& input_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int filter_width, int filter_height,
+ int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ tflite::PoolParams params;
+ params.stride_height = stride_height;
+ params.stride_width = stride_width;
+ params.filter_height = filter_height;
+ params.filter_width = filter_width;
+ params.padding_values.height = pad_height;
+ params.padding_values.width = pad_width;
+ params.quantized_activation_min = output_activation_min;
+ params.quantized_activation_max = output_activation_max;
+ AveragePool(params, DimsToShape(input_dims), input_data,
+ DimsToShape(output_dims), output_data);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void AveragePool(const uint8* input_data, const Dims<4>& input_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int filter_width, int filter_height,
+ int32 output_activation_min, int32 output_activation_max,
+ uint8* output_data, const Dims<4>& output_dims) {
+ static_assert(Ac == FusedActivationFunctionType::kNone ||
+ Ac == FusedActivationFunctionType::kRelu ||
+ Ac == FusedActivationFunctionType::kRelu6 ||
+ Ac == FusedActivationFunctionType::kRelu1,
+ "");
+ if (Ac == FusedActivationFunctionType::kNone) {
+ TFLITE_DCHECK_EQ(output_activation_min, 0);
+ TFLITE_DCHECK_EQ(output_activation_max, 255);
+ }
+ AveragePool(input_data, input_dims, stride_width, stride_height, pad_width,
+ pad_height, filter_width, filter_height, output_activation_min,
+ output_activation_max, output_data, output_dims);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void AveragePool(const uint8* input_data, const Dims<4>& input_dims, int stride,
+ int pad_width, int pad_height, int filter_width,
+ int filter_height, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ AveragePool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
+ filter_width, filter_height, output_activation_min,
+ output_activation_max, output_data, output_dims);
+}
+
+inline void MaxPool(const float* input_data, const Dims<4>& input_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int kwidth, int kheight,
+ float output_activation_min, float output_activation_max,
+ float* output_data, const Dims<4>& output_dims) {
+ tflite::PoolParams params;
+ params.stride_height = stride_height;
+ params.stride_width = stride_width;
+ params.filter_height = kheight;
+ params.filter_width = kwidth;
+ params.padding_values.height = pad_height;
+ params.padding_values.width = pad_width;
+ params.float_activation_min = output_activation_min;
+ params.float_activation_max = output_activation_max;
+ MaxPool(params, DimsToShape(input_dims), input_data, DimsToShape(output_dims),
+ output_data);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void MaxPool(const float* input_data, const Dims<4>& input_dims,
+ int stride_width, int stride_height, int pad_width, int pad_height,
+ int kwidth, int kheight, float* output_data,
+ const Dims<4>& output_dims) {
+ float output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+ MaxPool(input_data, input_dims, stride_width, stride_height, pad_width,
+ pad_height, kwidth, kheight, output_activation_min,
+ output_activation_max, output_data, output_dims);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void MaxPool(const float* input_data, const Dims<4>& input_dims, int stride,
+ int pad_width, int pad_height, int filter_width, int filter_height,
+ float* output_data, const Dims<4>& output_dims) {
+ MaxPool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
+ filter_width, filter_height, output_data, output_dims);
+}
+
+inline void MaxPool(const uint8* input_data, const Dims<4>& input_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int filter_width, int filter_height,
+ int32 output_activation_min, int32 output_activation_max,
+ uint8* output_data, const Dims<4>& output_dims) {
+ PoolParams params;
+ params.stride_height = stride_height;
+ params.stride_width = stride_width;
+ params.filter_height = filter_height;
+ params.filter_width = filter_width;
+ params.padding_values.height = pad_height;
+ params.padding_values.width = pad_width;
+ params.quantized_activation_min = output_activation_min;
+ params.quantized_activation_max = output_activation_max;
+ MaxPool(params, DimsToShape(input_dims), input_data, DimsToShape(output_dims),
+ output_data);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void MaxPool(const uint8* input_data, const Dims<4>& input_dims,
+ int stride_width, int stride_height, int pad_width, int pad_height,
+ int filter_width, int filter_height, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ static_assert(Ac == FusedActivationFunctionType::kNone ||
+ Ac == FusedActivationFunctionType::kRelu ||
+ Ac == FusedActivationFunctionType::kRelu6 ||
+ Ac == FusedActivationFunctionType::kRelu1,
+ "");
+ if (Ac == FusedActivationFunctionType::kNone) {
+ TFLITE_DCHECK_EQ(output_activation_min, 0);
+ TFLITE_DCHECK_EQ(output_activation_max, 255);
+ }
+ MaxPool(input_data, input_dims, stride_width, stride_height, pad_width,
+ pad_height, filter_width, filter_height, output_activation_min,
+ output_activation_max, output_data, output_dims);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void MaxPool(const uint8* input_data, const Dims<4>& input_dims, int stride,
+ int pad_width, int pad_height, int filter_width, int filter_height,
+ int32 output_activation_min, int32 output_activation_max,
+ uint8* output_data, const Dims<4>& output_dims) {
+ MaxPool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
+ filter_width, filter_height, output_activation_min,
+ output_activation_max, output_data, output_dims);
+}
+
+inline void L2Pool(const float* input_data, const Dims<4>& input_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int filter_width, int filter_height,
+ float output_activation_min, float output_activation_max,
+ float* output_data, const Dims<4>& output_dims) {
+ PoolParams params;
+ params.stride_height = stride_height;
+ params.stride_width = stride_width;
+ params.filter_height = filter_height;
+ params.filter_width = filter_width;
+ params.padding_values.height = pad_height;
+ params.padding_values.width = pad_width;
+ params.float_activation_min = output_activation_min;
+ params.float_activation_max = output_activation_max;
+ L2Pool(params, DimsToShape(input_dims), input_data, DimsToShape(output_dims),
+ output_data);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void L2Pool(const float* input_data, const Dims<4>& input_dims,
+ int stride_width, int stride_height, int pad_width, int pad_height,
+ int filter_width, int filter_height, float* output_data,
+ const Dims<4>& output_dims) {
+ float output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+ L2Pool(input_data, input_dims, stride_width, stride_height, pad_width,
+ pad_height, filter_width, filter_height, output_activation_min,
+ output_activation_max, output_data, output_dims);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void L2Pool(const float* input_data, const Dims<4>& input_dims, int stride,
+ int pad_width, int pad_height, int filter_width, int filter_height,
+ float* output_data, const Dims<4>& output_dims) {
+ L2Pool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
+ filter_width, filter_height, output_data, output_dims);
+}
+
+inline void Softmax(const float* input_data, const Dims<4>& input_dims,
+ float beta, float* output_data,
+ const Dims<4>& output_dims) {
+ Softmax(input_data, DimsToShape(input_dims), beta, output_data,
+ DimsToShape(output_dims));
+}
+
+inline void Softmax(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_beta_multiplier, int32 input_beta_left_shift,
+ int diff_min, uint8* output_data,
+ const Dims<4>& output_dims) {
+ Softmax(input_data, DimsToShape(input_dims), input_beta_multiplier,
+ input_beta_left_shift, diff_min, output_data,
+ DimsToShape(output_dims));
+}
+
+inline void LogSoftmax(const float* input_data, const Dims<4>& input_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ LogSoftmax(input_data, DimsToShape(input_dims), output_data,
+ DimsToShape(output_dims));
+}
+
+inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_multiplier, int32 input_left_shift,
+ int32 reverse_scaling_divisor,
+ int32 reverse_scaling_right_shift, int diff_min,
+ uint8* output_data, const Dims<4>& output_dims) {
+ LogSoftmax(input_data, DimsToShape(input_dims), input_multiplier,
+ input_left_shift, reverse_scaling_divisor,
+ reverse_scaling_right_shift, diff_min, output_data,
+ DimsToShape(output_dims));
+}
+
+inline void Logistic(const float* input_data, const Dims<4>& input_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ Logistic(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
+ output_data);
+}
+
+inline void Logistic(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_zero_point, int32 input_range_radius,
+ int32 input_multiplier, int input_left_shift,
+ uint8* output_data, const Dims<4>& output_dims) {
+ Logistic(input_data, DimsToShape(input_dims), input_zero_point,
+ input_range_radius, input_multiplier, input_left_shift, output_data,
+ DimsToShape(output_dims));
+}
+
+inline void Logistic(const int16* input_data, const Dims<4>& input_dims,
+ int16* output_data, const Dims<4>& output_dims) {
+ Logistic(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
+ output_data);
+}
+
+inline void Tanh(const float* input_data, const Dims<4>& input_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ Tanh(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
+ output_data);
+}
+
+inline void Tanh(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_zero_point, int32 input_range_radius,
+ int32 input_multiplier, int input_left_shift,
+ uint8* output_data, const Dims<4>& output_dims) {
+ Tanh(input_data, DimsToShape(input_dims), input_zero_point,
+ input_range_radius, input_multiplier, input_left_shift, output_data,
+ DimsToShape(output_dims));
+}
+
+inline void Tanh(const int16* input_data, const Dims<4>& input_dims,
+ int input_left_shift, int16* output_data,
+ const Dims<4>& output_dims) {
+ Tanh(input_data, DimsToShape(input_dims), input_left_shift, output_data,
+ DimsToShape(output_dims));
+}
+
+template <typename T>
+inline void DepthToSpace(const T* input_data, const Dims<4>& input_dims,
+ int block_size, T* output_data,
+ const Dims<4>& output_dims) {
+ tflite::DepthToSpaceParams op_params;
+ op_params.block_size = block_size;
+
+ DepthToSpace(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(output_dims), output_data);
+}
+
+template <typename T>
+inline void SpaceToDepth(const T* input_data, const Dims<4>& input_dims,
+ int block_size, T* output_data,
+ const Dims<4>& output_dims) {
+ tflite::SpaceToDepthParams op_params;
+ op_params.block_size = block_size;
+
+ SpaceToDepth(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(output_dims), output_data);
+}
+
+template <typename T>
+inline void Mul(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, const Dims<4>& input2_dims,
+ T output_activation_min, T output_activation_max,
+ T* output_data, const Dims<4>& output_dims) {
+ tflite::ArithmeticParams op_params;
+ SetActivationParams(output_activation_min, output_activation_max, &op_params);
+
+ Mul(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
+ output_data);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void Mul(const float* input1_data, const Dims<4>& input1_dims,
+ const float* input2_data, const Dims<4>& input2_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ float output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+
+ tflite::ArithmeticParams op_params;
+ SetActivationParams(output_activation_min, output_activation_max, &op_params);
+
+ Mul(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
+ output_data);
+}
+
+template <typename T>
+void BroadcastMul(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, const Dims<4>& input2_dims,
+ T output_activation_min, T output_activation_max,
+ T* output_data, const Dims<4>& output_dims) {
+ tflite::ArithmeticParams op_params;
+ SetActivationParams(output_activation_min, output_activation_max, &op_params);
+
+ BroadcastMul4DSlow(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data,
+ DimsToShape(output_dims), output_data);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac, typename T>
+void BroadcastMul(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, const Dims<4>& input2_dims,
+ T* output_data, const Dims<4>& output_dims) {
+ T output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+
+ tflite::ArithmeticParams op_params;
+ SetActivationParams(output_activation_min, output_activation_max, &op_params);
+
+ BroadcastMul4DSlow(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data,
+ DimsToShape(output_dims), output_data);
+}
+
+inline void Mul(const int16* input1_data, const Dims<4>& input1_dims,
+ const int16* input2_data, const Dims<4>& input2_dims,
+ int16* output_data, const Dims<4>& output_dims) {
+ tflite::ArithmeticParams op_params;
+ // No params in this version.
+
+ Mul(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
+ output_data);
+}
+
+inline void Mul(const int16* input1_data, const Dims<4>& input1_dims,
+ const int16* input2_data, const Dims<4>& input2_dims,
+ int32 output_offset, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ tflite::ArithmeticParams op_params;
+ op_params.quantized_activation_min = output_activation_min;
+ op_params.quantized_activation_max = output_activation_max;
+ op_params.output_offset = output_offset;
+
+ Mul(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
+ output_data);
+}
+
+inline void LocalResponseNormalization(const float* input_data,
+ const Dims<4>& input_dims, int range,
+ float bias, float alpha, float beta,
+ float* output_data,
+ const Dims<4>& output_dims) {
+ tflite::LocalResponseNormalizationParams op_params;
+ op_params.range = range;
+ op_params.bias = bias;
+ op_params.alpha = alpha;
+ op_params.beta = beta;
+
+ LocalResponseNormalization(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(output_dims), output_data);
+}
+
+template <typename SrcT, typename DstT>
+void Cast(const SrcT* input_data, const Dims<4>& input_dims, DstT* output_data,
+ const Dims<4>& output_dims) {
+ Cast(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
+ output_data);
+}
+
+inline void Floor(const float* input_data, const Dims<4>& input_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ Floor(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
+ output_data);
+}
+
+template <typename T>
+inline void ResizeBilinear(const T* input_data, const Dims<4>& input_dims,
+ const int32* output_size_data,
+ const Dims<4>& output_size_dims, T* output_data,
+ const Dims<4>& output_dims, bool align_corners) {
+ tflite::ResizeBilinearParams op_params;
+ op_params.align_corners = align_corners;
+ ResizeBilinear(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(output_size_dims), output_size_data,
+ DimsToShape(output_dims), output_data);
+}
+
+// legacy, for compatibility with old checked-in code
+inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims,
+ const int32* output_size_data,
+ const Dims<4>& output_size_dims, float* output_data,
+ const Dims<4>& output_dims) {
+ ResizeBilinear<float>(input_data, input_dims, output_size_data,
+ output_size_dims, output_data, output_dims,
+ /*align_corners=*/false);
+}
+
+inline void ResizeBilinear(const uint8* input_data, const Dims<4>& input_dims,
+ const int32* output_size_data,
+ const Dims<4>& output_size_dims, uint8* output_data,
+ const Dims<4>& output_dims) {
+ ResizeBilinear<uint8>(input_data, input_dims, output_size_data,
+ output_size_dims, output_data, output_dims,
+ /*align_corners=*/false);
+}
+
+template <typename T>
+inline void SpaceToBatchND(const T* input_data, const Dims<4>& input_dims,
+ const int32* block_shape_data,
+ const Dims<4>& block_shape_dims,
+ const int32* paddings_data,
+ const Dims<4>& paddings_dims, T* output_data,
+ const Dims<4>& output_dims,
+ const int32_t pad_value) {
+ tflite::SpaceToBatchParams op_params;
+ op_params.output_offset = pad_value;
+
+ SpaceToBatchND(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(block_shape_dims), block_shape_data,
+ DimsToShape(paddings_dims), paddings_data,
+ DimsToShape(output_dims), output_data);
+}
+
+template <typename T>
+inline void SpaceToBatchND(const T* input_data, const Dims<4>& input_dims,
+ const int32* block_shape_data,
+ const Dims<4>& block_shape_dims,
+ const int32* paddings_data,
+ const Dims<4>& paddings_dims, T* output_data,
+ const Dims<4>& output_dims) {
+ tflite::SpaceToBatchParams op_params;
+ op_params.output_offset = 0;
+
+ SpaceToBatchND(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(block_shape_dims), block_shape_data,
+ DimsToShape(paddings_dims), paddings_data,
+ DimsToShape(output_dims), output_data);
+}
+
+template <typename T>
+inline void BatchToSpaceND(const T* input_data, const Dims<4>& input_dims,
+ const int32* block_shape_data,
+ const Dims<4>& block_shape_dims,
+ const int32* crops_data, const Dims<4>& crops_dims,
+ T* output_data, const Dims<4>& output_dims) {
+ BatchToSpaceND(DimsToShape(input_dims), input_data,
+ DimsToShape(block_shape_dims), block_shape_data,
+ DimsToShape(crops_dims), crops_data, DimsToShape(output_dims),
+ output_data);
+}
+
+// Legacy signature; this function covered both Pad and PadV2.
+template <typename T>
+inline void PadV2(const T* input_data, const Dims<4>& input_dims,
+ const std::vector<int>& left_paddings,
+ const std::vector<int>& right_paddings, T* output_data,
+ const Dims<4>& output_dims, const T pad_value) {
+ TFLITE_DCHECK_EQ(left_paddings.size(), 4);
+ TFLITE_DCHECK_EQ(right_paddings.size(), 4);
+ tflite::PadParams op_params;
+ op_params.left_padding_count = 4;
+ op_params.right_padding_count = 4;
+ for (int i = 0; i < 4; ++i) {
+ op_params.left_padding[i] = left_paddings[3 - i];
+ op_params.right_padding[i] = right_paddings[3 - i];
+ }
+ // SetFloatOrInt(pad_value, &op_params.pad_value);
+ const T pad_value_copy = pad_value;
+
+ Pad(op_params, DimsToShape(input_dims), input_data, &pad_value_copy,
+ DimsToShape(output_dims), output_data);
+}
+
+// Old Pad that calls legacy PadV2.
+template <typename T>
+inline void Pad(const T* input_data, const Dims<4>& input_dims,
+ const std::vector<int>& left_paddings,
+ const std::vector<int>& right_paddings, T* output_data,
+ const Dims<4>& output_dims, const int32_t pad_value) {
+ const T converted_pad_value = static_cast<T>(pad_value);
+ PadV2<T>(input_data, input_dims, left_paddings, right_paddings, output_data,
+ output_dims, converted_pad_value);
+}
+
+// Old Pad that only padded with 0.
+template <typename T>
+inline void Pad(const T* input_data, const Dims<4>& input_dims,
+ const std::vector<int>& left_paddings,
+ const std::vector<int>& right_paddings, T* output_data,
+ const Dims<4>& output_dims) {
+ const T pad_value = static_cast<T>(0);
+ PadV2<T>(input_data, input_dims, left_paddings, right_paddings, output_data,
+ output_dims, pad_value);
+}
+
+template <typename T>
+void TensorFlowMinimum(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, T* output_data,
+ const Dims<4>& output_dims) {
+ Minimum(DimsToShape(input1_dims), input1_data, input2_data,
+ DimsToShape(output_dims), output_data);
+}
+
+template <typename T>
+void TensorFlowMaximum(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, T* output_data,
+ const Dims<4>& output_dims) {
+ Maximum(DimsToShape(input1_dims), input1_data, input2_data,
+ DimsToShape(output_dims), output_data);
+}
+
+template <typename T, typename Op>
+void TensorFlowMaximumMinimum(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, const Dims<4>& input2_dims,
+ T* output_data, const Dims<4>& output_dims,
+ Op op) {
+ MaximumMinimumBroadcast4DSlow(DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data,
+ DimsToShape(output_dims), output_data, op);
+}
+
+template <typename T1, typename T2, typename T3>
+void ArgMax(const T3* axis, const T1* input_data,
+ const tflite::Dims<4>& input_dims, T2* output_data,
+ const tflite::Dims<4>& output_dims) {
+ ArgMinMax(DimsToShape(input_dims), input_data, axis, DimsToShape(output_dims),
+ output_data, std::greater<T1>());
+}
+
+template <typename T1, typename T2, typename T3, typename Cmp>
+void ArgMinMax(const T3* axis, const T1* input_data, const Dims<4>& input_dims,
+ T2* output_data, const Dims<4>& output_dims, const Cmp& cmp) {
+ ArgMinMax(axis, DimsToShape(input_dims), input_data, DimsToShape(output_dims),
+ output_data, cmp);
+}
+
+template <typename T>
+inline void Pow(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, const Dims<4>& input2_dims,
+ T* output_data, const Dims<4>& output_dims) {
+ Pow(DimsToShape(input1_dims), input1_data, DimsToShape(input2_dims),
+ input2_data, DimsToShape(output_dims), output_data);
+}
+
+template <typename T>
+inline void BroadcastPow(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, const Dims<4>& input2_dims,
+ T* output_data, const Dims<4>& output_dims) {
+ BroadcastPow4DSlow(DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data,
+ DimsToShape(output_dims), output_data);
+}
+
+inline void Logical(const bool* input1_data, const Dims<4>& input1_dims,
+ const bool* input2_data, const Dims<4>& input2_dims,
+ bool* output_data, const Dims<4>& output_dims,
+ const std::function<bool(bool, bool)>& func) {
+ Logical(DimsToShape(input1_dims), input1_data, DimsToShape(input2_dims),
+ input2_data, DimsToShape(output_dims), output_data, func);
+}
+
+inline void BroadcastLogical(const bool* input1_data,
+ const Dims<4>& input1_dims,
+ const bool* input2_data,
+ const Dims<4>& input2_dims, bool* output_data,
+ const Dims<4>& output_dims,
+ const std::function<bool(bool, bool)>& func) {
+ BroadcastLogical4DSlow(DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data,
+ DimsToShape(output_dims), output_data, func);
+}
+
+// R: Result type. T1: Input 1 type. T2: Input 2 type.
+template <typename R, typename T1, typename T2>
+inline void BroadcastBinaryFunction(const T1* input1_data,
+ const Dims<4>& input1_dims,
+ const T2* input2_data,
+ const Dims<4>& input2_dims, R* output_data,
+ const Dims<4>& output_dims,
+ R (*func)(T1, T2)) {
+ BroadcastBinaryFunction(DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data,
+ DimsToShape(output_dims), output_data, func);
+}
+
+// R: Result type. T1: Input 1 type. T2: Input 2 type.
+template <typename R, typename T1, typename T2>
+inline void BinaryFunction(const T1* input1_data, const Dims<4>& input1_dims,
+ const T2* input2_data, const Dims<4>& input2_dims,
+ R* output_data, const Dims<4>& output_dims,
+ R (*func)(T1, T2)) {
+ BinaryFunction(DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data,
+ DimsToShape(output_dims), output_data, func);
+}
+
+template <typename T>
+inline void Slice(const T* input_data, const Dims<4>& input_dims,
+ const std::vector<int>& begin, const std::vector<int>& size,
+ T* output_data, const Dims<4>& output_dims) {
+ tflite::SliceParams op_params;
+ op_params.begin_count = 4;
+ op_params.size_count = 4;
+ for (int i = 0; i < 4; ++i) {
+ op_params.begin[i] = begin[3 - i];
+ op_params.size[i] = size[3 - i];
+ }
+
+ Slice(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(output_dims), output_data);
+}
+
+} // namespace reference_ops
+} // namespace tflite
+#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_LEGACY_REFERENCE_OPS_H_
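
The wrappers in this legacy header exist only to keep old Dims<4>-based call sites compiling: each one reverses the Dims<4> ordering into a RuntimeShape (via DimsToShape) and packs the loose arguments (block_size, activation bounds, paddings, and so on) into the corresponding *Params struct before forwarding to the new-style kernel. A minimal standalone sketch of that shape reversal, using small hypothetical stand-ins rather than the real TFLite types:

#include <array>
#include <cstdio>

// Hypothetical stand-ins: Dims4 stores sizes innermost-first (d, w, h, b),
// Shape4 stores them outermost-first (b, h, w, d).
struct Dims4 { std::array<int, 4> sizes; };
struct Shape4 { std::array<int, 4> dims; };

Shape4 DimsToShapeSketch(const Dims4& dims) {
  // Reverse the order: {d, w, h, b} -> {b, h, w, d}.
  return Shape4{{dims.sizes[3], dims.sizes[2], dims.sizes[1], dims.sizes[0]}};
}

int main() {
  Dims4 legacy{{64, 16, 16, 1}};  // depth=64, width=16, height=16, batch=1
  Shape4 shape = DimsToShapeSketch(legacy);
  std::printf("NHWC shape: %d %d %d %d\n", shape.dims[0], shape.dims[1],
              shape.dims[2], shape.dims[3]);  // prints: 1 16 16 64
  return 0;
}
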
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc
index 2607adc0c1..70d25c4bd9 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc
+++ b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc
@@ -14,12 +14,17 @@ limitations under the License.
==============================================================================*/
#include <stdlib.h>
#include <string.h>
+#include <algorithm>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
#include "tensorflow/contrib/lite/kernels/activation_functor.h"
#include "tensorflow/contrib/lite/kernels/internal/round.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"
+#if defined(_MSC_VER)
+#define __restrict__ __restrict
+#endif
+
namespace tflite {
namespace tensor_utils {
@@ -29,23 +34,31 @@ float PortableClip(float f, float abs_limit) {
return result;
}
+bool PortableIsZeroVector(const float* vector, int v_size) {
+ for (int i = 0; i < v_size; ++i) {
+ if (*vector++ != 0.0f) return false;
+ }
+ return true;
+}
+
void PortableSymmetricQuantizeFloats(const float* values, const int size,
- int8_t* quantized_values, float* min,
- float* max, float* scaling_factor) {
+ int8_t* quantized_values, float* min_value,
+ float* max_value, float* scaling_factor) {
auto minmax = std::minmax_element(values, values + size);
- *min = *minmax.first;
- *max = *minmax.second;
+ *min_value = *minmax.first;
+ *max_value = *minmax.second;
const int kScale = 127;
- const float range = std::max(std::abs(*min), std::abs(*max));
+ const float range = std::max(std::abs(*min_value), std::abs(*max_value));
if (range == 0) {
memset(quantized_values, 0, size * sizeof(int8_t));
*scaling_factor = 1;
return;
}
- *scaling_factor = kScale / range;
+ *scaling_factor = range / kScale;
+ const float scaling_factor_inv = kScale / range;
for (int i = 0; i < size; ++i) {
const int32_t quantized_value =
- static_cast<int32_t>(TfLiteRound(*scaling_factor * values[i]));
+ static_cast<int32_t>(TfLiteRound(values[i] * scaling_factor_inv));
    // Clamp, just in case of some odd numeric offset.
quantized_values[i] = std::min(kScale, std::max(-kScale, quantized_value));
}
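
Note the convention change in PortableSymmetricQuantizeFloats: scaling_factor is now range / 127 (the real value represented by one quantized step) rather than 127 / range, and rounding uses the separate scaling_factor_inv. A minimal standalone sketch of the resulting round-trip, with hypothetical names and no TFLite dependencies:

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>

// Symmetric 8-bit quantization sketch: `scale` maps one int8 step back to float.
void SymmetricQuantizeSketch(const std::vector<float>& values,
                             std::vector<int8_t>* quantized, float* scale) {
  const int kScale = 127;
  float range = 0.f;
  for (float v : values) range = std::max(range, std::abs(v));
  if (range == 0.f) {
    quantized->assign(values.size(), 0);
    *scale = 1.f;
    return;
  }
  *scale = range / kScale;           // dequantization factor
  const float inv = kScale / range;  // quantization factor
  quantized->resize(values.size());
  for (int i = 0; i < static_cast<int>(values.size()); ++i) {
    const int q = static_cast<int>(std::round(values[i] * inv));
    (*quantized)[i] =
        static_cast<int8_t>(std::min(kScale, std::max(-kScale, q)));
  }
}

int main() {
  const std::vector<float> values = {-2.54f, 0.f, 1.27f};
  std::vector<int8_t> q;
  float scale = 0.f;
  SymmetricQuantizeSketch(values, &q, &scale);
  for (int i = 0; i < static_cast<int>(q.size()); ++i) {
    std::printf("%f ~ %d * %f = %f\n", values[i], q[i], scale, q[i] * scale);
  }
  return 0;
}
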
@@ -60,10 +73,12 @@ void PortableMatrixBatchVectorMultiplyAccumulate(const float* matrix,
for (int b = 0; b < n_batch; b++) {
const float* matrix_ptr = matrix;
for (int r = 0; r < m_rows; r++) {
+ float dot_prod = 0.0f;
const float* vector_in_batch = vector + b * m_cols;
for (int c = 0; c < m_cols; c++) {
- *result_in_batch += *matrix_ptr++ * *vector_in_batch++;
+ dot_prod += *matrix_ptr++ * *vector_in_batch++;
}
+ *result_in_batch += dot_prod;
result_in_batch += result_stride;
}
}
@@ -75,20 +90,22 @@ void PortableMatrixBatchVectorMultiplyAccumulate(
int n_batch, float* __restrict__ result, int result_stride) {
int batch, row, col;
for (batch = 0; batch < n_batch; ++batch, vectors += m_cols) {
- const float batch_scaling_factor_inv = 1.0 / scaling_factors[batch];
+ const float batch_scaling_factor = scaling_factors[batch];
// Get the address of the first row.
- int8_t* row_ptr = (int8_t*)matrix; // NOLINT
+ const int8_t* row_ptr = matrix;
for (row = 0; row < m_rows; ++row, result += result_stride) {
// Initialize the dot product sum for the row to 0.
int32_t dotprod = 0;
+#if defined(__GNUC__)
// Prefetch the row to cache.
__builtin_prefetch(row_ptr, 0 /* prefetch for read */,
3 /* temporal locality */);
+#endif
// For every block of 16 8-bit elements (128-bit register) from each row.
for (col = 0; col < m_cols; ++col, ++row_ptr) {
dotprod += (*row_ptr) * (vectors[col]);
} // for col
- *result += (dotprod * batch_scaling_factor_inv);
+ *result += (dotprod * batch_scaling_factor);
} // for row
} // for batch
}
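
This hunk is the consumer of that convention: the int32 dot product of int8 weights and quantized inputs is now multiplied by scaling_factors[batch] (no longer divided by it), which maps the accumulator straight back to float units. A short standalone sketch of one row of that accumulation, with hypothetical names:

#include <cstdint>
#include <cstdio>

// One row of a quantized matrix-vector product: accumulate in int32,
// then rescale to float with the batch's scaling factor.
float QuantizedRowDotSketch(const int8_t* row, const int8_t* vec, int cols,
                            float scaling_factor) {
  int32_t dotprod = 0;
  for (int c = 0; c < cols; ++c) {
    dotprod += static_cast<int32_t>(row[c]) * static_cast<int32_t>(vec[c]);
  }
  return dotprod * scaling_factor;
}

int main() {
  const int8_t row[3] = {10, -20, 30};
  const int8_t vec[3] = {127, 0, -127};
  // With scaling_factor = range / 127, the product lands back in float units.
  std::printf("%f\n", QuantizedRowDotSketch(row, vec, 3, 0.02f));
  return 0;
}
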
@@ -134,6 +151,16 @@ void PortableVectorVectorCwiseProductAccumulate(const float* vector1,
}
}
+void PortableVectorBatchVectorCwiseProduct(const float* vector, int v_size,
+ const float* batch_vector,
+ int n_batch, float* result) {
+ for (int b = 0; b < n_batch; b++) {
+ for (int v = 0; v < v_size; v++) {
+ *result++ = vector[v] * *batch_vector++;
+ }
+ }
+}
+
void PortableVectorBatchVectorCwiseProductAccumulate(const float* vector,
int v_size,
const float* batch_vector,
@@ -146,6 +173,16 @@ void PortableVectorBatchVectorCwiseProductAccumulate(const float* vector,
}
}
+void PortableVectorBatchVectorAdd(const float* vector, int v_size, int n_batch,
+ float* batch_vector) {
+ for (int b = 0; b < n_batch; b++) {
+ for (int i = 0; i < v_size; ++i) {
+ batch_vector[i] += vector[i];
+ }
+ batch_vector += v_size;
+ }
+}
+
void PortableVectorBatchVectorAssign(const float* vector, int v_size,
int n_batch, float* batch_vector) {
for (int b = 0; b < n_batch; b++) {
@@ -184,6 +221,13 @@ void PortableZeroVector(float* vector, int v_size) {
memset(vector, 0, v_size * sizeof(float));
}
+void PortableVectorScalarMultiply(const int8_t* vector, const int v_size,
+ const float scale, float* result) {
+ for (int v = 0; v < v_size; ++v) {
+ *result++ = scale * *vector++;
+ }
+}
+
void PortableClipVector(const float* vector, int v_size, float abs_limit,
float* result) {
for (int v = 0; v < v_size; v++) {
@@ -209,5 +253,31 @@ void PortableReductionSumVector(const float* input_vector, float* output_vector,
}
}
+void PortableMeanStddevNormalization(const float* input_vector,
+ float* output_vector, int v_size,
+ int n_batch, float normalization_epsilon) {
+ for (int batch = 0; batch < n_batch; ++batch) {
+ float sum = 0.0f;
+ float sum_sq = 0.0f;
+ for (int i = 0; i < v_size; ++i) {
+ sum += input_vector[i];
+ sum_sq += input_vector[i] * input_vector[i];
+ }
+ const float mean = sum / v_size;
+ float stddev_inv = 0.0f;
+ const float variance = sum_sq / v_size - mean * mean;
+ if (variance == 0) {
+ stddev_inv = 1.0f / sqrt(normalization_epsilon);
+ } else {
+ stddev_inv = 1.0f / sqrt(variance);
+ }
+ for (int i = 0; i < v_size; ++i) {
+ output_vector[i] = (input_vector[i] - mean) * stddev_inv;
+ }
+ input_vector += v_size;
+ output_vector += v_size;
+ }
+}
+
} // namespace tensor_utils
} // namespace tflite
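
PortableMeanStddevNormalization is a per-batch layer-norm helper: for each batch it subtracts the mean and divides by the standard deviation, falling back to 1/sqrt(normalization_epsilon) when the variance is exactly zero. A minimal standalone sketch of the same computation, with a tiny usage example:

#include <cmath>
#include <cstdio>

// Normalize each batch of `v_size` values to zero mean and unit variance.
void MeanStddevNormalizationSketch(const float* input, float* output,
                                   int v_size, int n_batch, float epsilon) {
  for (int b = 0; b < n_batch; ++b) {
    float sum = 0.f, sum_sq = 0.f;
    for (int i = 0; i < v_size; ++i) {
      sum += input[i];
      sum_sq += input[i] * input[i];
    }
    const float mean = sum / v_size;
    const float variance = sum_sq / v_size - mean * mean;
    const float stddev_inv =
        1.f / std::sqrt(variance == 0.f ? epsilon : variance);
    for (int i = 0; i < v_size; ++i) {
      output[i] = (input[i] - mean) * stddev_inv;
    }
    input += v_size;
    output += v_size;
  }
}

int main() {
  const float in[4] = {1.f, 2.f, 3.f, 4.f};
  float out[4];
  MeanStddevNormalizationSketch(in, out, 4, 1, 1e-8f);
  std::printf("%f %f %f %f\n", out[0], out[1], out[2], out[3]);
  return 0;
}
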
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h
index 1757a9f5e5..714b1164ee 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h
@@ -17,7 +17,11 @@ limitations under the License.
// TODO(ghodrat): Remove this header file and the dependency on internal data
// structure.
-#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+
+#if defined(_MSC_VER)
+#define __restrict__ __restrict
+#endif
namespace tflite {
namespace tensor_utils {
@@ -25,9 +29,11 @@ namespace tensor_utils {
// Limit a float input f between +abs_limit and -abs_limit.
float PortableClip(float f, float abs_limit);
+bool PortableIsZeroVector(const float* vector, int v_size);
+
void PortableSymmetricQuantizeFloats(const float* values, const int size,
- int8_t* quantized_values, float* min,
- float* max, float* scaling_factor);
+ int8_t* quantized_values, float* min_value,
+ float* max_value, float* scaling_factor);
// Multiply a matrix by a batch vector, and store results in a batch-size
// vector.
@@ -63,6 +69,11 @@ void PortableBatchVectorBatchVectorDotProduct(const float* vector1,
int n_batch, float* result,
int result_stride);
+// Cwise product of a vector and a batch-vector.
+void PortableVectorBatchVectorCwiseProduct(const float* vector, int v_size,
+ const float* batch_vector,
+ int n_batch, float* result);
+
// Cwise product and accumulate of a vector and a batch-vector. Since it's a MAC
// operation, the assumption here is that result array is initialized to valid
// values.
@@ -76,6 +87,10 @@ void PortableVectorBatchVectorCwiseProductAccumulate(const float* vector,
void PortableVectorBatchVectorAssign(const float* vector, int v_size,
int n_batch, float* batch_vector);
+// Add the vector to each batch in the batch vector.
+void PortableVectorBatchVectorAdd(const float* vector, int v_size, int n_batch,
+ float* batch_vector);
+
// Apply sigmoid to elements of a vector.
void PortableApplySigmoidToVector(const float* vector, int v_size,
float* result);
@@ -94,6 +109,10 @@ void PortableSub1Vector(const float* vector, int v_size, float* result);
// Fill vector with 0.f.
void PortableZeroVector(float* vector, int v_size);
+// Multiply all elements of the vector by a scalar.
+void PortableVectorScalarMultiply(const int8_t* vector, int v_size, float scale,
+ float* result);
+
// Clip elements of a vector using an abs_limit value.
void PortableClipVector(const float* vector, int v_size, float abs_limit,
float* result);
@@ -110,8 +129,18 @@ void PortableVectorShiftLeft(float* vector, int v_size, float shift_value);
void PortableReductionSumVector(const float* input_vector, float* output_vector,
int output_size, int reduction_size);
+// Layer norm for each batch.
+// normalization_epsilon is used in place of a zero variance to avoid
+// division by zero.
+void PortableMeanStddevNormalization(const float* input_vector,
+ float* output_vector, int v_size,
+ int n_batch, float normalization_epsilon);
+
float Clip(float f, float abs_limit) { return PortableClip(f, abs_limit); }
+bool IsZeroVector(const float* vector, int v_size) {
+ return PortableIsZeroVector(vector, v_size);
+}
+
void SymmetricQuantizeFloats(const float* values, const int size,
int8_t* quantized_values, float* min, float* max,
float* scaling_factor) {
@@ -147,6 +176,13 @@ void VectorVectorCwiseProductAccumulate(const float* vector1,
PortableVectorVectorCwiseProductAccumulate(vector1, vector2, v_size, result);
}
+void VectorBatchVectorCwiseProduct(const float* vector, int v_size,
+ const float* batch_vector, int n_batch,
+ float* result) {
+ PortableVectorBatchVectorCwiseProduct(vector, v_size, batch_vector, n_batch,
+ result);
+}
+
void VectorBatchVectorCwiseProductAccumulate(const float* vector, int v_size,
const float* batch_vector,
int n_batch, float* result) {
@@ -167,6 +203,11 @@ void BatchVectorBatchVectorDotProduct(const float* vector1,
result, result_stride);
}
+void VectorBatchVectorAdd(const float* vector, int v_size, int n_batch,
+ float* batch_vector) {
+ PortableVectorBatchVectorAdd(vector, v_size, n_batch, batch_vector);
+}
+
void VectorBatchVectorAssign(const float* vector, int v_size, int n_batch,
float* batch_vector) {
PortableVectorBatchVectorAssign(vector, v_size, n_batch, batch_vector);
@@ -193,6 +234,12 @@ void ZeroVector(float* vector, int v_size) {
PortableZeroVector(vector, v_size);
}
+// Multiply all elements of the vector by a scalar.
+void VectorScalarMultiply(const int8_t* vector, int v_size, float scale,
+ float* result) {
+ PortableVectorScalarMultiply(vector, v_size, scale, result);
+}
+
void ClipVector(const float* vector, int v_size, float abs_limit,
float* result) {
PortableClipVector(vector, v_size, abs_limit, result);
@@ -208,6 +255,13 @@ void ReductionSumVector(const float* input_vector, float* output_vector,
reduction_size);
}
+void MeanStddevNormalization(const float* input_vector, float* output_vector,
+ int v_size, int n_batch,
+ float normalization_epsilon) {
+ PortableMeanStddevNormalization(input_vector, output_vector, v_size, n_batch,
+ normalization_epsilon);
+}
+
} // namespace tensor_utils
} // namespace tflite
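
Note the pattern in this header: each operation is declared as a Portable* implementation and then wrapped by an unprefixed entry point (IsZeroVector, VectorBatchVectorAdd, MeanStddevNormalization, ...) that simply forwards to it, so callers always use the unprefixed name while this header stays a pure forwarding layer; an optimized build can back the same entry points from a different header. A minimal sketch of the pattern with a hypothetical function:

#include <cstdio>

// Portable reference implementation (hypothetical example of the pattern).
float PortableSumVector(const float* v, int v_size) {
  float sum = 0.f;
  for (int i = 0; i < v_size; ++i) sum += v[i];
  return sum;
}

// Unprefixed entry point that callers use; a SIMD backend could provide a
// different definition under the same name from another header.
float SumVector(const float* v, int v_size) {
  return PortableSumVector(v, v_size);
}

int main() {
  const float v[3] = {1.f, 2.f, 3.f};
  std::printf("%f\n", SumVector(v, 3));
  return 0;
}
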
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
index e70d8e5454..59f17ae854 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
@@ -19,136 +19,177 @@ limitations under the License.
#include <sys/types.h>
#include <algorithm>
#include <cmath>
+#include <functional>
#include <limits>
#include <memory>
#include <type_traits>
-#include "third_party/eigen3/Eigen/Core"
#include "fixedpoint/fixedpoint.h"
#include "public/gemmlowp.h"
#include "tensorflow/contrib/lite/kernels/internal/common.h"
#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/fully_connected.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/softmax.h"
#include "tensorflow/contrib/lite/kernels/internal/round.h"
#include "tensorflow/contrib/lite/kernels/internal/strided_slice_logic.h"
#include "tensorflow/contrib/lite/kernels/internal/types.h"
namespace tflite {
+
+// TODO(b/77858996): Add these to gemmlowp.
+template <typename IntegerType>
+IntegerType SaturatingAddNonGemmlowp(IntegerType a, IntegerType b) {
+ static_assert(std::is_same<IntegerType, void>::value, "unimplemented");
+ return a;
+}
+
+template <>
+inline std::int32_t SaturatingAddNonGemmlowp(std::int32_t a, std::int32_t b) {
+ std::int64_t a64 = a;
+ std::int64_t b64 = b;
+ std::int64_t sum = a64 + b64;
+ return static_cast<std::int32_t>(std::min(
+ static_cast<std::int64_t>(std::numeric_limits<std::int32_t>::max()),
+ std::max(
+ static_cast<std::int64_t>(std::numeric_limits<std::int32_t>::min()),
+ sum)));
+}
+
+template <typename tRawType, int tIntegerBits>
+gemmlowp::FixedPoint<tRawType, tIntegerBits> SaturatingAddNonGemmlowp(
+ gemmlowp::FixedPoint<tRawType, tIntegerBits> a,
+ gemmlowp::FixedPoint<tRawType, tIntegerBits> b) {
+ return gemmlowp::FixedPoint<tRawType, tIntegerBits>::FromRaw(
+ SaturatingAddNonGemmlowp(a.raw(), b.raw()));
+}
+
+template <typename IntegerType>
+IntegerType SaturatingSub(IntegerType a, IntegerType b) {
+ static_assert(std::is_same<IntegerType, void>::value, "unimplemented");
+ return a;
+}
+
+template <>
+inline std::int16_t SaturatingSub(std::int16_t a, std::int16_t b) {
+ std::int32_t a32 = a;
+ std::int32_t b32 = b;
+ std::int32_t diff = a32 - b32;
+ return static_cast<std::int16_t>(std::min(32767, std::max(-32768, diff)));
+}
+
+template <>
+inline std::int32_t SaturatingSub(std::int32_t a, std::int32_t b) {
+ std::int64_t a64 = a;
+ std::int64_t b64 = b;
+ std::int64_t diff = a64 - b64;
+ return static_cast<std::int32_t>(std::min(
+ static_cast<std::int64_t>(std::numeric_limits<std::int32_t>::max()),
+ std::max(
+ static_cast<std::int64_t>(std::numeric_limits<std::int32_t>::min()),
+ diff)));
+}
+
+template <typename tRawType, int tIntegerBits>
+gemmlowp::FixedPoint<tRawType, tIntegerBits> SaturatingSub(
+ gemmlowp::FixedPoint<tRawType, tIntegerBits> a,
+ gemmlowp::FixedPoint<tRawType, tIntegerBits> b) {
+ return gemmlowp::FixedPoint<tRawType, tIntegerBits>::FromRaw(
+ SaturatingSub(a.raw(), b.raw()));
+}
+// End section to be moved to gemmlowp.
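
SaturatingAddNonGemmlowp and SaturatingSub widen to the next integer size, do the arithmetic there, and clamp back to the original type's range; the FixedPoint overloads apply the same clamping to the raw representation. A standalone sketch of the int16 case:

#include <algorithm>
#include <cstdint>
#include <cstdio>

// Saturating int16 subtraction: compute in int32, then clamp to [-32768, 32767].
int16_t SaturatingSubSketch(int16_t a, int16_t b) {
  const int32_t diff = static_cast<int32_t>(a) - static_cast<int32_t>(b);
  return static_cast<int16_t>(std::min(32767, std::max(-32768, diff)));
}

int main() {
  // Plain int16 subtraction would wrap around; the saturating version clamps.
  std::printf("%d\n", SaturatingSubSketch(-30000, 10000));  // prints -32768
  std::printf("%d\n", SaturatingSubSketch(20000, -20000));  // prints 32767
  return 0;
}
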
+
namespace reference_ops {
-// DO NOT USE THIS STRUCT FOR NEW FUNCTIONALITY BEYOND IMPLEMENTING ELEMENT-WISE
-// BROADCASTING.
-//
-// NdArrayDesc<N> describes the shape and memory layout of an N-dimensional
-// rectangular array of numbers.
-//
-// NdArrayDesc<N> is basically identical to Dims<N> defined in types.h.
-// However, as Dims<N> is to be deprecated, this class exists as an adaptor
-// to enable simple unoptimized implementations of element-wise broadcasting
-// operations.
-template <int N>
-struct NdArrayDesc {
- // The "extent" of each dimension. Indices along dimension d must be in the
- // half-open interval [0, extents[d]).
- int extents[N];
-
- // The number of *elements* (not bytes) between consecutive indices of each
- // dimension.
- int strides[N];
-};
-
-// DO NOT USE THIS FUNCTION FOR NEW FUNCTIONALITY BEYOND IMPLEMENTING
-// ELEMENT-WISE BROADCASTING.
-//
-// Same as Offset(), except takes as NdArrayDesc<N> instead of Dims<N>.
-inline int SubscriptToIndex(const NdArrayDesc<4>& desc, int i0, int i1, int i2,
- int i3) {
- TFLITE_DCHECK(i0 >= 0 && i0 < desc.extents[0]);
- TFLITE_DCHECK(i1 >= 0 && i1 < desc.extents[1]);
- TFLITE_DCHECK(i2 >= 0 && i2 < desc.extents[2]);
- TFLITE_DCHECK(i3 >= 0 && i3 < desc.extents[3]);
- return i0 * desc.strides[0] + i1 * desc.strides[1] + i2 * desc.strides[2] +
- i3 * desc.strides[3];
-}
-
-// Given the dimensions of the operands for an element-wise binary broadcast,
-// adjusts them so that they can be directly iterated over with simple loops.
-// Returns the adjusted dims as instances of NdArrayDesc in 'desc0_out' and
-// 'desc1_out'. 'desc0_out' and 'desc1_out' cannot be nullptr.
-//
-// This function assumes that the two input shapes are compatible up to
-// broadcasting and the shorter one has already been prepended with 1s to be the
-// same length. E.g., if shape0 is (1, 16, 16, 64) and shape1 is (1, 64),
-// shape1 must already have been prepended to be (1, 1, 1, 64). Recall that
-// Dims<N> refer to shapes in reverse order. In this case, input0_dims will be
-// (64, 16, 16, 1) and input1_dims will be (64, 1, 1, 1).
-//
-// When two shapes are compatible up to broadcasting, for each dimension d,
-// the input extents are either equal, or one of them is 1.
-//
-// This function performs the following for each dimension d:
-// - If the extents are equal, then do nothing since the loop that walks over
-// both of the input arrays is correct.
-// - Otherwise, one (and only one) of the extents must be 1. Say extent0 is 1
-// and extent1 is e1. Then set extent0 to e1 and stride0 *to 0*. This allows
-// array0 to be referenced *at any index* in dimension d and still access the
-// same slice.
-template <int N>
-inline void NdArrayDescsForElementwiseBroadcast(const Dims<N>& input0_dims,
- const Dims<N>& input1_dims,
- NdArrayDesc<N>* desc0_out,
- NdArrayDesc<N>* desc1_out) {
- TFLITE_DCHECK(desc0_out != nullptr);
- TFLITE_DCHECK(desc1_out != nullptr);
-
- // Copy dims to desc.
- for (int i = 0; i < N; ++i) {
- desc0_out->extents[i] = input0_dims.sizes[i];
- desc0_out->strides[i] = input0_dims.strides[i];
- desc1_out->extents[i] = input1_dims.sizes[i];
- desc1_out->strides[i] = input1_dims.strides[i];
- }
-
- // Walk over each dimension. If the extents are equal do nothing.
- // Otherwise, set the desc with extent 1 to have extent equal to the other and
- // stride 0.
- for (int i = 0; i < N; ++i) {
- const int extent0 = ArraySize(input0_dims, i);
- const int extent1 = ArraySize(input1_dims, i);
- if (extent0 != extent1) {
- if (extent0 == 1) {
- desc0_out->strides[i] = 0;
- desc0_out->extents[i] = extent1;
- } else {
- TFLITE_DCHECK_EQ(extent1, 1);
- desc1_out->strides[i] = 0;
- desc1_out->extents[i] = extent0;
- }
- }
+inline void ShapeFromDims(const tflite::Dims<4>& dims, RuntimeShape* shape) {
+ shape->BuildFrom(
+ {dims.sizes[3], dims.sizes[2], dims.sizes[1], dims.sizes[0]});
+}
+
+template <typename T>
+int CountLeadingZeros(T integer_input) {
+ static_assert(std::is_unsigned<T>::value,
+ "Only unsigned integer types handled.");
+ if (integer_input == 0) {
+ return std::numeric_limits<T>::digits;
+ }
+ const T one_in_leading_positive = static_cast<T>(1)
+ << (std::numeric_limits<T>::digits - 1);
+ int leading_zeros = 0;
+ while (integer_input < one_in_leading_positive) {
+ integer_input <<= 1;
+ ++leading_zeros;
+ }
+ return leading_zeros;
+}
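
CountLeadingZeros shifts the value left until its top bit is set, counting the steps, and returns the full bit width for a zero input. A small standalone sketch with example values:

#include <cstdint>
#include <cstdio>
#include <limits>
#include <type_traits>

// Shift left until the most significant bit is set, counting the steps.
template <typename T>
int CountLeadingZerosSketch(T x) {
  static_assert(std::is_unsigned<T>::value, "unsigned types only");
  if (x == 0) return std::numeric_limits<T>::digits;
  const T top_bit = static_cast<T>(1) << (std::numeric_limits<T>::digits - 1);
  int n = 0;
  while (x < top_bit) {
    x <<= 1;
    ++n;
  }
  return n;
}

int main() {
  std::printf("%d\n", CountLeadingZerosSketch<uint8_t>(0x10));  // prints 3
  std::printf("%d\n", CountLeadingZerosSketch<uint32_t>(1u));   // prints 31
  std::printf("%d\n", CountLeadingZerosSketch<uint32_t>(0u));   // prints 32
  return 0;
}
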
+
+template <typename IntegerType>
+IntegerType SaturatingRoundingMultiplyByPOTParam(IntegerType x, int exponent) {
+ if (exponent == 0) {
+ return x;
}
+ using ScalarIntegerType =
+ typename gemmlowp::FixedPointRawTypeTraits<IntegerType>::ScalarRawType;
+ const IntegerType min =
+ gemmlowp::Dup<IntegerType>(std::numeric_limits<ScalarIntegerType>::min());
+ const IntegerType max =
+ gemmlowp::Dup<IntegerType>(std::numeric_limits<ScalarIntegerType>::max());
+ const int ScalarIntegerTypeBits = 8 * sizeof(ScalarIntegerType);
+
+ const std::int32_t threshold =
+ ((1 << (ScalarIntegerTypeBits - 1 - exponent)) - 1);
+ const IntegerType positive_mask =
+ gemmlowp::MaskIfGreaterThan(x, gemmlowp::Dup<IntegerType>(threshold));
+ const IntegerType negative_mask =
+ gemmlowp::MaskIfLessThan(x, gemmlowp::Dup<IntegerType>(-threshold));
+
+ IntegerType result = gemmlowp::ShiftLeft(x, exponent);
+ result = gemmlowp::SelectUsingMask(positive_mask, max, result);
+ result = gemmlowp::SelectUsingMask(negative_mask, min, result);
+ return result;
+}
+
+// If we want to leave IntegerBits fixed, then multiplication
+// by a power of two has to be saturating/rounding, not exact anymore.
+template <typename tRawType, int tIntegerBits>
+gemmlowp::FixedPoint<tRawType, tIntegerBits>
+SaturatingRoundingMultiplyByPOTParam(
+ gemmlowp::FixedPoint<tRawType, tIntegerBits> a, int exponent) {
+ return gemmlowp::FixedPoint<tRawType, tIntegerBits>::FromRaw(
+ SaturatingRoundingMultiplyByPOTParam(a.raw(), exponent));
}
-inline void Conv(const float* input_data, const Dims<4>& input_dims,
- const float* filter_data, const Dims<4>& filter_dims,
- const float* bias_data, const Dims<4>& bias_dims,
- int stride_width, int stride_height, int dilation_width_factor,
- int dilation_height_factor, int pad_width, int pad_height,
- float output_activation_min, float output_activation_max,
- float* output_data, const Dims<4>& output_dims,
- float* im2col_data, const Dims<4>& im2col_dims) {
- (void)im2col_data; // only used in optimized code.
- (void)im2col_dims; // only used in optimized code.
- const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
- const int input_depth = MatchingArraySize(input_dims, 0, filter_dims, 0);
- const int output_depth = MatchingArraySize(filter_dims, 3, output_dims, 0);
+inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
+ const float* input_data, const RuntimeShape& filter_shape,
+ const float* filter_data, const RuntimeShape& bias_shape,
+ const float* bias_data, const RuntimeShape& output_shape,
+ float* output_data, const RuntimeShape& im2col_shape,
+ float* im2col_data) {
+ const int stride_width = params.stride_width;
+ const int stride_height = params.stride_height;
+ const int dilation_width_factor = params.dilation_width_factor;
+ const int dilation_height_factor = params.dilation_height_factor;
+ const int pad_width = params.padding_values.width;
+ const int pad_height = params.padding_values.height;
+ const float output_activation_min = params.float_activation_min;
+ const float output_activation_max = params.float_activation_max;
+ TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+
+ (void)im2col_data; // only used in optimized code.
+ (void)im2col_shape; // only used in optimized code.
+ const int batches = MatchingDim(input_shape, 0, output_shape, 0);
+ const int input_depth = MatchingDim(input_shape, 3, filter_shape, 3);
+ const int output_depth = MatchingDim(filter_shape, 0, output_shape, 3);
if (bias_data) {
- TFLITE_DCHECK_EQ(ArraySize(filter_dims, 3), ArraySize(bias_dims, 0));
- }
- const int input_height = ArraySize(input_dims, 2);
- const int input_width = ArraySize(input_dims, 1);
- const int filter_height = ArraySize(filter_dims, 2);
- const int filter_width = ArraySize(filter_dims, 1);
- const int output_height = ArraySize(output_dims, 2);
- const int output_width = ArraySize(output_dims, 1);
+ TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth);
+ }
+ const int input_height = input_shape.Dims(1);
+ const int input_width = input_shape.Dims(2);
+ const int filter_height = filter_shape.Dims(1);
+ const int filter_width = filter_shape.Dims(2);
+ const int output_height = output_shape.Dims(1);
+ const int output_width = output_shape.Dims(2);
for (int batch = 0; batch < batches; ++batch) {
for (int out_y = 0; out_y < output_height; ++out_y) {
for (int out_x = 0; out_x < output_width; ++out_x) {
@@ -166,11 +207,11 @@ inline void Conv(const float* input_data, const Dims<4>& input_dims,
// use zero as a default value.
if ((in_x >= 0) && (in_x < input_width) && (in_y >= 0) &&
(in_y < input_height)) {
- float input_value = input_data[Offset(input_dims, in_channel,
- in_x, in_y, batch)];
+ float input_value = input_data[Offset(
+ input_shape, batch, in_y, in_x, in_channel)];
float filter_value =
- filter_data[Offset(filter_dims, in_channel, filter_x,
- filter_y, out_channel)];
+ filter_data[Offset(filter_shape, out_channel, filter_y,
+ filter_x, in_channel)];
total += (input_value * filter_value);
}
}
@@ -178,9 +219,9 @@ inline void Conv(const float* input_data, const Dims<4>& input_dims,
}
float bias_value = 0.0f;
if (bias_data) {
- bias_value = bias_data[Offset(bias_dims, out_channel, 0, 0, 0)];
+ bias_value = bias_data[out_channel];
}
- output_data[Offset(output_dims, out_channel, out_x, out_y, batch)] =
+ output_data[Offset(output_shape, batch, out_y, out_x, out_channel)] =
ActivationFunctionWithMinMax(total + bias_value,
output_activation_min,
output_activation_max);
@@ -190,77 +231,45 @@ inline void Conv(const float* input_data, const Dims<4>& input_dims,
}
}
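
The rewritten float Conv takes all of its hyper-parameters from ConvParams (stride_*, dilation_*_factor, padding_values, float_activation_min/max) and indexes every tensor as NHWC through Offset(shape, batch, y, x, channel). For a dense NHWC buffer that offset is the usual row-major flattening; a small standalone sketch of the arithmetic, assuming a densely packed buffer:

#include <cstdio>

// Flat index into a dense NHWC buffer: ((b * H + y) * W + x) * C + c.
int OffsetNHWC(int H, int W, int C, int b, int y, int x, int c) {
  return ((b * H + y) * W + x) * C + c;
}

int main() {
  // A 1x4x4x3 tensor: element (batch 0, y 2, x 1, channel 2).
  std::printf("%d\n", OffsetNHWC(4, 4, 3, 0, 2, 1, 2));  // prints 29
  return 0;
}
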
-template <FusedActivationFunctionType Ac>
-void Conv(const float* input_data, const Dims<4>& input_dims,
- const float* filter_data, const Dims<4>& filter_dims,
- const float* bias_data, const Dims<4>& bias_dims, int stride_width,
- int stride_height, int dilation_width_factor,
- int dilation_height_factor, int pad_width, int pad_height,
- float* output_data, const Dims<4>& output_dims, float* im2col_data,
- const Dims<4>& im2col_dims) {
- float output_activation_min, output_activation_max;
- GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
- Conv(input_data, input_dims, filter_data, filter_dims, bias_data, bias_dims,
- stride_width, stride_height, dilation_width_factor,
- dilation_height_factor, pad_width, pad_height, output_activation_min,
- output_activation_max, output_data, output_dims, im2col_data,
- im2col_dims);
-}
-
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void Conv(const float* input_data, const Dims<4>& input_dims,
- const float* filter_data, const Dims<4>& filter_dims,
- const float* bias_data, const Dims<4>& bias_dims, int stride_width,
- int stride_height, int pad_width, int pad_height, float* output_data,
- const Dims<4>& output_dims, float* im2col_data,
- const Dims<4>& im2col_dims) {
- float output_activation_min, output_activation_max;
- GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
- Conv(input_data, input_dims, filter_data, filter_dims, bias_data, bias_dims,
- stride_width, stride_height, 1, 1, pad_width, pad_height,
- output_activation_min, output_activation_max, output_data, output_dims,
- im2col_data, im2col_dims);
-}
-
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void Conv(const float* input_data, const Dims<4>& input_dims,
- const float* filter_data, const Dims<4>& filter_dims,
- const float* bias_data, const Dims<4>& bias_dims, int stride,
- int pad_width, int pad_height, float* output_data,
- const Dims<4>& output_dims, float* im2col_data,
- const Dims<4>& im2col_dims) {
- Conv<Ac>(input_data, input_dims, filter_data, filter_dims, bias_data,
- bias_dims, stride, stride, 1, 1, pad_width, pad_height, output_data,
- output_dims, im2col_data, im2col_dims);
-}
-
-inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
- int32 input_offset, const uint8* filter_data,
- const Dims<4>& filter_dims, int32 filter_offset,
- const int32* bias_data, const Dims<4>& bias_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, int32 output_offset, int32 output_multiplier,
- int output_shift, int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims, uint8* im2col_data,
- const Dims<4>& im2col_dims,
- gemmlowp::GemmContext* gemm_context) {
+inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
+ const uint8* input_data, const RuntimeShape& filter_shape,
+ const uint8* filter_data, const RuntimeShape& bias_shape,
+ const int32* bias_data, const RuntimeShape& output_shape,
+ uint8* output_data, const RuntimeShape& im2col_shape,
+ uint8* im2col_data, gemmlowp::GemmContext* gemm_context) {
(void)im2col_data; // only used in optimized code.
- (void)im2col_dims; // only used in optimized code.
+ (void)im2col_shape; // only used in optimized code.
(void)gemm_context; // only used in optimized code.
+ const int stride_width = params.stride_width;
+ const int stride_height = params.stride_height;
+ const int dilation_width_factor = params.dilation_width_factor;
+ const int dilation_height_factor = params.dilation_height_factor;
+ const int pad_width = params.padding_values.width;
+ const int pad_height = params.padding_values.height;
+ const int32 input_offset = params.input_offset;
+ const int32 filter_offset = params.weights_offset;
+ const int32 output_offset = params.output_offset;
+ const int32 output_multiplier = params.output_multiplier;
+ const int output_shift = params.output_shift;
+ const int32 output_activation_min = params.quantized_activation_min;
+ const int32 output_activation_max = params.quantized_activation_max;
TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
- const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
- const int input_depth = MatchingArraySize(input_dims, 0, filter_dims, 0);
- const int output_depth =
- MatchingArraySize(filter_dims, 3, bias_dims, 0, output_dims, 0);
- const int input_height = ArraySize(input_dims, 2);
- const int input_width = ArraySize(input_dims, 1);
- const int filter_height = ArraySize(filter_dims, 2);
- const int filter_width = ArraySize(filter_dims, 1);
- const int output_height = ArraySize(output_dims, 2);
- const int output_width = ArraySize(output_dims, 1);
+
+ TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+ const int batches = MatchingDim(input_shape, 0, output_shape, 0);
+ const int input_depth = MatchingDim(input_shape, 3, filter_shape, 3);
+ const int output_depth = MatchingDim(filter_shape, 0, output_shape, 3);
+ if (bias_data) {
+ TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth);
+ }
+ const int input_height = input_shape.Dims(1);
+ const int input_width = input_shape.Dims(2);
+ const int filter_height = filter_shape.Dims(1);
+ const int filter_width = filter_shape.Dims(2);
+ const int output_height = output_shape.Dims(1);
+ const int output_width = output_shape.Dims(2);
for (int batch = 0; batch < batches; ++batch) {
for (int out_y = 0; out_y < output_height; ++out_y) {
for (int out_x = 0; out_x < output_width; ++out_x) {
@@ -271,17 +280,18 @@ inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
for (int in_channel = 0; in_channel < input_depth; ++in_channel) {
- const int in_x = in_x_origin + filter_x;
- const int in_y = in_y_origin + filter_y;
+ const int in_x = in_x_origin + dilation_width_factor * filter_x;
+ const int in_y =
+ in_y_origin + dilation_height_factor * filter_y;
// If the location is outside the bounds of the input image,
// use zero as a default value.
if ((in_x >= 0) && (in_x < input_width) && (in_y >= 0) &&
(in_y < input_height)) {
- int32 input_val = input_data[Offset(input_dims, in_channel,
- in_x, in_y, batch)];
+ int32 input_val = input_data[Offset(input_shape, batch, in_y,
+ in_x, in_channel)];
int32 filter_val =
- filter_data[Offset(filter_dims, in_channel, filter_x,
- filter_y, out_channel)];
+ filter_data[Offset(filter_shape, out_channel, filter_y,
+ filter_x, in_channel)];
acc +=
(filter_val + filter_offset) * (input_val + input_offset);
}
@@ -289,14 +299,14 @@ inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
}
}
if (bias_data) {
- acc += bias_data[Offset(bias_dims, out_channel, 0, 0, 0)];
+ acc += bias_data[out_channel];
}
- acc = MultiplyByQuantizedMultiplierSmallerThanOne(
- acc, output_multiplier, output_shift);
+ acc = MultiplyByQuantizedMultiplier(acc, output_multiplier,
+ output_shift);
acc += output_offset;
acc = std::max(acc, output_activation_min);
acc = std::min(acc, output_activation_max);
- output_data[Offset(output_dims, out_channel, out_x, out_y, batch)] =
+ output_data[Offset(output_shape, batch, out_y, out_x, out_channel)] =
static_cast<uint8>(acc);
}
}
@@ -304,66 +314,30 @@ inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
}
}
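
In the quantized path each accumulator gathers (filter_val + filter_offset) * (input_val + input_offset) in int32, is scaled back to the output's quantized range by MultiplyByQuantizedMultiplier with the precomputed output_multiplier/output_shift, then has output_offset added and is clamped to the activation bounds. A rough standalone sketch of that requantization step, using a double as a stand-in for the fixed-point multiplier pair:

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

// Requantize an int32 accumulator to uint8. `real_multiplier` stands in for
// the fixed-point (output_multiplier, output_shift) pair used by the kernel.
uint8_t RequantizeSketch(int32_t acc, double real_multiplier,
                         int32_t output_offset, int32_t act_min,
                         int32_t act_max) {
  int32_t scaled = static_cast<int32_t>(std::lround(acc * real_multiplier));
  scaled += output_offset;
  scaled = std::max(scaled, act_min);
  scaled = std::min(scaled, act_max);
  return static_cast<uint8_t>(scaled);
}

int main() {
  // An accumulator of 12345 with an effective scale of 0.01 and offset 128.
  const uint8_t out = RequantizeSketch(12345, 0.01, 128, 0, 255);
  std::printf("%d\n", static_cast<int>(out));  // prints 251
  return 0;
}
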
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
- int32 input_offset, const uint8* filter_data,
- const Dims<4>& filter_dims, int32 filter_offset,
- const int32* bias_data, const Dims<4>& bias_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, int32 output_offset, int32 output_multiplier,
- int output_shift, int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims, uint8* im2col_data,
- const Dims<4>& im2col_dims,
- gemmlowp::GemmContext* gemm_context) {
- static_assert(Ac == FusedActivationFunctionType::kNone ||
- Ac == FusedActivationFunctionType::kRelu ||
- Ac == FusedActivationFunctionType::kRelu6 ||
- Ac == FusedActivationFunctionType::kRelu1,
- "");
- if (Ac == FusedActivationFunctionType::kNone) {
- TFLITE_DCHECK_EQ(output_activation_min, 0);
- TFLITE_DCHECK_EQ(output_activation_max, 255);
- }
- Conv(input_data, input_dims, input_offset, filter_data, filter_dims,
- filter_offset, bias_data, bias_dims, stride_width, stride_height,
- pad_width, pad_height, output_offset, output_multiplier, output_shift,
- output_activation_min, output_activation_max, output_data, output_dims,
- im2col_data, im2col_dims, gemm_context);
-}
-
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void Conv(const uint8* input_data, const Dims<4>& input_dims,
- int32 input_offset, const uint8* filter_data,
- const Dims<4>& filter_dims, int32 filter_offset,
- const int32* bias_data, const Dims<4>& bias_dims, int stride,
- int pad_width, int pad_height, int32 output_offset,
- int32 output_multiplier, int output_shift,
- int32 output_activation_min, int32 output_activation_max,
- uint8* output_data, const Dims<4>& output_dims, uint8* im2col_data,
- const Dims<4>& im2col_dims, gemmlowp::GemmContext* gemm_context) {
- Conv<Ac>(input_data, input_dims, input_offset, filter_data, filter_dims,
- filter_offset, bias_data, bias_dims, stride, stride, pad_width,
- pad_height, output_offset, output_multiplier, output_shift,
- output_activation_min, output_activation_max, output_data,
- output_dims, im2col_data, im2col_dims, gemm_context);
-}
-
template <typename T>
-inline void DepthToSpace(const T* input_data, const Dims<4>& input_dims,
- int block_size, T* output_data,
- const Dims<4>& output_dims) {
- const int input_depth = ArraySize(input_dims, 0);
- const int input_width = ArraySize(input_dims, 1);
- const int input_height = ArraySize(input_dims, 2);
- const int input_batch = ArraySize(input_dims, 3);
-
- const int output_depth = ArraySize(output_dims, 0);
- const int output_width = ArraySize(output_dims, 1);
- const int output_height = ArraySize(output_dims, 2);
- const int output_batch = ArraySize(output_dims, 3);
+inline void DepthToSpace(const tflite::DepthToSpaceParams& op_params,
+ const RuntimeShape& unextended_input_shape,
+ const T* input_data,
+ const RuntimeShape& unextended_output_shape,
+ T* output_data) {
+ TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+ const RuntimeShape input_shape =
+ RuntimeShape::ExtendedShape(4, unextended_input_shape);
+ const RuntimeShape output_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_shape);
+
+ const int input_depth = input_shape.Dims(3);
+ const int input_width = input_shape.Dims(2);
+ const int input_height = input_shape.Dims(1);
+ const int input_batch = input_shape.Dims(0);
+
+ const int output_depth = output_shape.Dims(3);
+ const int output_width = output_shape.Dims(2);
+ const int output_height = output_shape.Dims(1);
+ const int output_batch = output_shape.Dims(0);
+
+ const int32 block_size = op_params.block_size;
TFLITE_DCHECK_EQ(input_width * block_size, output_width);
TFLITE_DCHECK_EQ(input_height * block_size, output_height);
@@ -382,9 +356,9 @@ inline void DepthToSpace(const T* input_data, const Dims<4>& input_dims,
const int in_h = out_h / block_size;
const int in_b = out_b;
+ const int input_index = Offset(input_shape, in_b, in_h, in_w, in_d);
const int output_index =
- Offset(output_dims, out_d, out_w, out_h, out_b);
- const int input_index = Offset(input_dims, in_d, in_w, in_h, in_b);
+ Offset(output_shape, out_b, out_h, out_w, out_d);
output_data[output_index] = input_data[input_index];
}
@@ -394,18 +368,29 @@ inline void DepthToSpace(const T* input_data, const Dims<4>& input_dims,
}
template <typename T>
-inline void SpaceToDepth(const T* input_data, const Dims<4>& input_dims,
- int block_size, T* output_data,
- const Dims<4>& output_dims) {
- const int input_depth = ArraySize(input_dims, 0);
- const int input_width = ArraySize(input_dims, 1);
- const int input_height = ArraySize(input_dims, 2);
- const int input_batch = ArraySize(input_dims, 3);
-
- const int output_depth = ArraySize(output_dims, 0);
- const int output_width = ArraySize(output_dims, 1);
- const int output_height = ArraySize(output_dims, 2);
- const int output_batch = ArraySize(output_dims, 3);
+inline void SpaceToDepth(const tflite::SpaceToDepthParams& op_params,
+ const RuntimeShape& unextended_input_shape,
+ const T* input_data,
+ const RuntimeShape& unextended_output_shape,
+ T* output_data) {
+ TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+ const RuntimeShape input_shape =
+ RuntimeShape::ExtendedShape(4, unextended_input_shape);
+ const RuntimeShape output_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_shape);
+
+ const int input_depth = input_shape.Dims(3);
+ const int input_width = input_shape.Dims(2);
+ const int input_height = input_shape.Dims(1);
+ const int input_batch = input_shape.Dims(0);
+
+ const int output_depth = output_shape.Dims(3);
+ const int output_width = output_shape.Dims(2);
+ const int output_height = output_shape.Dims(1);
+ const int output_batch = output_shape.Dims(0);
+
+ const int32 block_size = op_params.block_size;
TFLITE_DCHECK_EQ(input_width, output_width * block_size);
TFLITE_DCHECK_EQ(input_height, output_height * block_size);
@@ -423,9 +408,9 @@ inline void SpaceToDepth(const T* input_data, const Dims<4>& input_dims,
const int out_h = in_h / block_size;
const int out_b = in_b;
+ const int input_index = Offset(input_shape, in_b, in_h, in_w, in_d);
const int output_index =
- Offset(output_dims, out_d, out_w, out_h, out_b);
- const int input_index = Offset(input_dims, in_d, in_w, in_h, in_b);
+ Offset(output_shape, out_b, out_h, out_w, out_d);
output_data[output_index] = input_data[input_index];
}
@@ -434,366 +419,9 @@ inline void SpaceToDepth(const T* input_data, const Dims<4>& input_dims,
}
}
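
Both DepthToSpace and SpaceToDepth now accept shapes with up to four dimensions and normalize them with RuntimeShape::ExtendedShape(4, ...), used here to left-pad smaller shapes with 1s before the usual NHWC loops. A small standalone sketch of that padding, with a hypothetical shape representation:

#include <cstdio>
#include <vector>

// Left-pad a shape with 1s until it has `target_dims` dimensions.
// Assumes the input shape has at most `target_dims` dimensions.
std::vector<int> ExtendShapeSketch(std::vector<int> shape, int target_dims) {
  const int pad = target_dims - static_cast<int>(shape.size());
  shape.insert(shape.begin(), pad > 0 ? pad : 0, 1);
  return shape;
}

int main() {
  const std::vector<int> extended = ExtendShapeSketch({16, 16, 64}, 4);
  for (int d : extended) std::printf("%d ", d);  // prints: 1 16 16 64
  std::printf("\n");
  return 0;
}
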
-inline void FullyConnected(const float* input_data, const Dims<4>& input_dims,
- const float* weights_data,
- const Dims<4>& weights_dims, const float* bias_data,
- const Dims<4>& bias_dims,
- float output_activation_min,
- float output_activation_max, float* output_data,
- const Dims<4>& output_dims) {
- // TODO(benoitjacob): This really should be:
- // const int batches = ArraySize(output_dims, 1);
- // but the current --variable_batch hack consists in overwriting the 3rd
- // dimension with the runtime batch size, as we don't keep track for each
- // array of which dimension is the batch dimension in it.
- const int batches = ArraySize(output_dims, 1) * ArraySize(output_dims, 2) *
- ArraySize(output_dims, 3);
- const int output_depth = MatchingArraySize(weights_dims, 1, output_dims, 0);
- const int accum_depth = ArraySize(weights_dims, 0);
- TFLITE_DCHECK(IsPackedWithoutStrides(input_dims));
- TFLITE_DCHECK(IsPackedWithoutStrides(weights_dims));
- for (int b = 0; b < batches; ++b) {
- for (int out_c = 0; out_c < output_depth; ++out_c) {
- float total = 0.f;
- for (int d = 0; d < accum_depth; ++d) {
- total += input_data[b * accum_depth + d] *
- weights_data[out_c * accum_depth + d];
- }
- float bias_value = 0.0f;
- if (bias_data) {
- bias_value = bias_data[Offset(bias_dims, out_c, 0, 0, 0)];
- }
- output_data[out_c + output_depth * b] = ActivationFunctionWithMinMax(
- total + bias_value, output_activation_min, output_activation_max);
- }
- }
-}
-
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void FullyConnected(const float* input_data, const Dims<4>& input_dims,
- const float* weights_data, const Dims<4>& weights_dims,
- const float* bias_data, const Dims<4>& bias_dims,
- float* output_data, const Dims<4>& output_dims) {
- float output_activation_min, output_activation_max;
- GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
- FullyConnected(input_data, input_dims, weights_data, weights_dims, bias_data,
- bias_dims, output_activation_min, output_activation_max,
- output_data, output_dims);
-}
-
-inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
- int32 input_offset, const uint8* filter_data,
- const Dims<4>& filter_dims, int32 filter_offset,
- const int32* bias_data, const Dims<4>& bias_dims,
- int32 output_offset, int32 output_multiplier,
- int output_shift, int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims,
- gemmlowp::GemmContext* gemm_context) {
- (void)gemm_context; // only used in optimized code.
- TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
- // TODO(benoitjacob): This really should be:
- // const int batches = ArraySize(output_dims, 1);
- // but the current --variable_batch hack consists in overwriting the 3rd
- // dimension with the runtime batch size, as we don't keep track for each
- // array of which dimension is the batch dimension in it.
- const int batches = ArraySize(output_dims, 1) * ArraySize(output_dims, 2) *
- ArraySize(output_dims, 3);
- const int output_depth = MatchingArraySize(filter_dims, 1, output_dims, 0);
- const int accum_depth = ArraySize(filter_dims, 0);
- TFLITE_DCHECK(IsPackedWithoutStrides(input_dims));
- TFLITE_DCHECK(IsPackedWithoutStrides(filter_dims));
- for (int b = 0; b < batches; ++b) {
- for (int out_c = 0; out_c < output_depth; ++out_c) {
- int32 acc = 0;
- for (int d = 0; d < accum_depth; ++d) {
- int32 input_val = input_data[b * accum_depth + d];
- int32 filter_val = filter_data[out_c * accum_depth + d];
- acc += (filter_val + filter_offset) * (input_val + input_offset);
- }
- if (bias_data) {
- acc += bias_data[Offset(bias_dims, out_c, 0, 0, 0)];
- }
- acc = MultiplyByQuantizedMultiplierSmallerThanOne(acc, output_multiplier,
- output_shift);
- acc += output_offset;
- acc = std::max(acc, output_activation_min);
- acc = std::min(acc, output_activation_max);
- output_data[out_c + output_depth * b] = static_cast<uint8>(acc);
- }
- }
-}
-
-inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
- int32 input_offset, const uint8* filter_data,
- const Dims<4>& filter_dims, int32 filter_offset,
- const int32* bias_data, const Dims<4>& bias_dims,
- int32 output_offset, int32 output_multiplier,
- int output_shift, int32 output_activation_min,
- int32 output_activation_max, int16* output_data,
- const Dims<4>& output_dims,
- gemmlowp::GemmContext* gemm_context) {
- (void)gemm_context; // only used in optimized code.
- TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
- TFLITE_DCHECK_EQ(output_offset, 0);
- // TODO(benoitjacob): This really should be:
- // const int batches = ArraySize(output_dims, 1);
- // but the current --variable_batch hack consists in overwriting the 3rd
- // dimension with the runtime batch size, as we don't keep track for each
- // array of which dimension is the batch dimension in it.
- const int batches = ArraySize(output_dims, 1) * ArraySize(output_dims, 2) *
- ArraySize(output_dims, 3);
- const int output_depth = MatchingArraySize(filter_dims, 1, output_dims, 0);
- const int accum_depth = ArraySize(filter_dims, 0);
- TFLITE_DCHECK(IsPackedWithoutStrides(input_dims));
- TFLITE_DCHECK(IsPackedWithoutStrides(filter_dims));
- for (int b = 0; b < batches; ++b) {
- for (int out_c = 0; out_c < output_depth; ++out_c) {
- // Internal accumulation.
- // Initialize accumulator with the bias-value.
- int32 accum = bias_data[out_c];
- // Accumulation loop.
- for (int d = 0; d < accum_depth; ++d) {
- int16 input_val = input_data[b * accum_depth + d] + input_offset;
- int16 filter_val = filter_data[out_c * accum_depth + d] + filter_offset;
- accum += filter_val * input_val;
- }
- // Down-scale the final int32 accumulator to the scale used by our
- // (16-bit, typically 3 integer bits) fixed-point format. The quantized
- // multiplier and shift here have been pre-computed offline
- // (e.g. by toco).
- accum = MultiplyByQuantizedMultiplier(accum, output_multiplier,
- -output_shift);
- // Saturate, cast to int16, and store to output array.
- accum = std::max(accum, output_activation_min - output_offset);
- accum = std::min(accum, output_activation_max - output_offset);
- accum += output_offset;
- output_data[out_c + output_depth * b] = accum;
- }
- }
-}
-
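On the "16-bit, typically 3 integer bits" format mentioned in the comments above: my reading is a Q3.12-style fixed point, i.e. an int16 raw value r representing r / 4096 with range roughly [-8, 8). A small conversion sketch under that assumption (not code from this change):

#include <algorithm>
#include <cmath>
#include <cstdint>

// Assumed Q3.12 interpretation of the 16-bit, 3-integer-bit format:
// real value = raw / 4096, saturating at the int16 range.
inline std::int16_t FloatToQ3_12(float x) {
  long v = std::lround(x * 4096.0f);
  v = std::max<long>(-32768, std::min<long>(32767, v));  // saturate
  return static_cast<std::int16_t>(v);
}

inline float Q3_12ToFloat(std::int16_t raw) { return raw / 4096.0f; }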
-inline void ExperimentalShuffledFullyConnected(
- const uint8* input_data, const Dims<4>& input_dims,
- const uint8* shuffled_weights_data, const Dims<4>& weights_dims,
- const int32* bias_data, const Dims<4>& bias_dims, int32 output_multiplier,
- int output_shift, int32 output_activation_min, int32 output_activation_max,
- int16* output_data, const Dims<4>& output_dims,
- uint8* shuffled_input_workspace_data, gemmlowp::GemmContext* gemm_context) {
- (void)gemm_context; // only used in optimized code.
-
- TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
- // TODO(benoitjacob): This really should be:
- // const int batches = ArraySize(output_dims, 1);
- // but the current --variable_batch hack consists in overwriting the 3rd
- // dimension with the runtime batch size, as we don't keep track for each
- // array of which dimension is the batch dimension in it.
- const int batches = ArraySize(output_dims, 1) * ArraySize(output_dims, 2) *
- ArraySize(output_dims, 3);
- const int output_depth = MatchingArraySize(weights_dims, 1, output_dims, 0);
- const int accum_depth = ArraySize(weights_dims, 0);
- TFLITE_DCHECK(IsPackedWithoutStrides(input_dims));
- TFLITE_DCHECK(IsPackedWithoutStrides(weights_dims));
- TFLITE_DCHECK((accum_depth % 16) == 0);
- TFLITE_DCHECK((output_depth % 4) == 0);
-
- // Shuffling and xoring of input activations into the workspace buffer
- uint8* shuffled_input_workspace_ptr = shuffled_input_workspace_data;
- if (batches == 1) {
- for (int i = 0; i < accum_depth; i++) {
- shuffled_input_workspace_data[i] = input_data[i] ^ 0x80;
- }
- } else if (batches == 4) {
- for (int c = 0; c < accum_depth; c += 16) {
- for (int b = 0; b < 4; b++) {
- const uint8* src_data_ptr = input_data + b * accum_depth + c;
- for (int j = 0; j < 16; j++) {
- uint8 src_val = *src_data_ptr++;
- // Flip the sign bit, so that the kernel will only need to
- // reinterpret these uint8 values as int8, getting for free the
- // subtraction of the zero_point value 128.
- uint8 dst_val = src_val ^ 0x80;
- *shuffled_input_workspace_ptr++ = dst_val;
- }
- }
- }
- } else {
- TFLITE_DCHECK(false);
- return;
- }
-
- // Actual computation
- if (batches == 1) {
- int16* output_ptr = output_data;
- // Shuffled weights have had their sign bit (0x80) pre-flipped (xor'd)
- // so that just reinterpreting them as int8 values is equivalent to
- // subtracting 128 from them, thus implementing for free the subtraction of
- // the zero_point value 128.
- const int8* shuffled_weights_ptr =
- reinterpret_cast<const int8*>(shuffled_weights_data);
- // Likewise, we preshuffled and pre-xored the input data above.
- const int8* shuffled_input_data =
- reinterpret_cast<const int8*>(shuffled_input_workspace_data);
- for (int c = 0; c < output_depth; c += 4) {
- // Internal accumulation.
- // Initialize accumulator with the bias-value.
- int32 accum[4] = {0};
- // Accumulation loop.
- for (int d = 0; d < accum_depth; d += 16) {
- for (int i = 0; i < 4; i++) {
- for (int j = 0; j < 16; j++) {
- int8 input_val = shuffled_input_data[d + j];
- int8 weights_val = *shuffled_weights_ptr++;
- accum[i] += weights_val * input_val;
- }
- }
- }
- for (int i = 0; i < 4; i++) {
- // Add bias value
- int acc = accum[i] + bias_data[c + i];
- // Down-scale the final int32 accumulator to the scale used by our
- // (16-bit, typically 3 integer bits) fixed-point format. The quantized
- // multiplier and shift here have been pre-computed offline
- // (e.g. by toco).
- acc = MultiplyByQuantizedMultiplier(acc, output_multiplier,
- -output_shift);
- // Saturate, cast to int16, and store to output array.
- acc = std::max(acc, output_activation_min);
- acc = std::min(acc, output_activation_max);
- output_ptr[c + i] = acc;
- }
- }
- } else if (batches == 4) {
- int16* output_ptr = output_data;
- // Shuffled weights have had their sign bit (0x80) pre-flipped (xor'd)
- // so that just reinterpreting them as int8 values is equivalent to
- // subtracting 128 from them, thus implementing for free the subtraction of
- // the zero_point value 128.
- const int8* shuffled_weights_ptr =
- reinterpret_cast<const int8*>(shuffled_weights_data);
- // Likewise, we preshuffled and pre-xored the input data above.
- const int8* shuffled_input_data =
- reinterpret_cast<const int8*>(shuffled_input_workspace_data);
- for (int c = 0; c < output_depth; c += 4) {
- const int8* shuffled_input_ptr = shuffled_input_data;
- // Accumulation loop.
- // Internal accumulation.
- // Initialize accumulator with the bias-value.
- int32 accum[4][4];
- for (int i = 0; i < 4; i++) {
- for (int b = 0; b < 4; b++) {
- accum[i][b] = 0;
- }
- }
- for (int d = 0; d < accum_depth; d += 16) {
- for (int i = 0; i < 4; i++) {
- for (int b = 0; b < 4; b++) {
- for (int j = 0; j < 16; j++) {
- int8 input_val = shuffled_input_ptr[16 * b + j];
- int8 weights_val = shuffled_weights_ptr[16 * i + j];
- accum[i][b] += weights_val * input_val;
- }
- }
- }
- shuffled_input_ptr += 64;
- shuffled_weights_ptr += 64;
- }
- for (int i = 0; i < 4; i++) {
- for (int b = 0; b < 4; b++) {
- // Add bias value
- int acc = accum[i][b] + bias_data[c + i];
- // Down-scale the final int32 accumulator to the scale used by our
- // (16-bit, typically 3 integer bits) fixed-point format. The
- // quantized multiplier and shift here have been pre-computed offline
- // (e.g. by toco).
- acc = MultiplyByQuantizedMultiplier(acc, output_multiplier,
- -output_shift);
- // Saturate, cast to int16, and store to output array.
- acc = std::max(acc, output_activation_min);
- acc = std::min(acc, output_activation_max);
- output_ptr[b * output_depth + c + i] = acc;
- }
- }
- }
- } else {
- TFLITE_DCHECK(false);
- return;
- }
-}
-
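The shuffled kernel above relies on the sign-bit trick described in its comments. A quick standalone check of that identity (illustration, not part of the diff), assuming the usual two's-complement, modulo-2^8 unsigned-to-signed conversion:

#include <cassert>
#include <cstdint>

// XOR-ing a uint8 with 0x80 and reinterpreting the result as int8 is
// equivalent to subtracting the zero point 128.
int main() {
  for (int v = 0; v < 256; ++v) {
    const std::uint8_t flipped = static_cast<std::uint8_t>(v) ^ 0x80;
    const std::int8_t reinterpreted = static_cast<std::int8_t>(flipped);
    assert(static_cast<int>(reinterpreted) == v - 128);
  }
  return 0;
}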
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
- int32 input_offset, const uint8* filter_data,
- const Dims<4>& filter_dims, int32 filter_offset,
- const int32* bias_data, const Dims<4>& bias_dims,
- int32 output_offset, int32 output_multiplier,
- int output_shift, int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims,
- gemmlowp::GemmContext* gemm_context) {
- static_assert(Ac == FusedActivationFunctionType::kNone ||
- Ac == FusedActivationFunctionType::kRelu ||
- Ac == FusedActivationFunctionType::kRelu6 ||
- Ac == FusedActivationFunctionType::kRelu1,
- "");
- if (Ac == FusedActivationFunctionType::kNone) {
- TFLITE_DCHECK_EQ(output_activation_min, 0);
- TFLITE_DCHECK_EQ(output_activation_max, 255);
- }
- FullyConnected(input_data, input_dims, input_offset, filter_data, filter_dims,
- filter_offset, bias_data, bias_dims, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_data, output_dims, gemm_context);
-}
-
-template <FusedActivationFunctionType Ac>
-void NonGlobalBatchNormalization(
- const float* input_data, const Dims<4>& input_dims, const float* mean_data,
- const Dims<4>& mean_dims, const float* multiplier_data,
- const Dims<4>& multiplier_dims, const float* offset_data,
- const Dims<4>& offset_dims, float* output_data,
- const Dims<4>& output_dims) {
- const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
- const int inner_size = MatchingFlatSizeSkipDim(
- input_dims, 3, mean_dims, multiplier_dims, offset_dims, output_dims);
-
- for (int b = 0; b < batches; ++b) {
- for (int i = 0; i < inner_size; ++i) {
- output_data[b * inner_size + i] = ActivationFunction<Ac>(
- (input_data[b * inner_size + i] - mean_data[i]) * multiplier_data[i] +
- offset_data[i]);
- }
- }
-}
-
-template <FusedActivationFunctionType Ac>
-void GlobalBatchNormalization(const float* input_data,
- const Dims<4>& input_dims, const float* mean_data,
- const Dims<4>& mean_dims,
- const float* multiplier_data,
- const Dims<4>& multiplier_dims,
- const float* offset_data,
- const Dims<4>& offset_dims, float* output_data,
- const Dims<4>& output_dims) {
- const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims);
- const int depth =
- MatchingArraySize(input_dims, 0, mean_dims, 0, multiplier_dims, 0,
- offset_dims, 0, output_dims, 0);
-
- for (int i = 0; i < outer_size; ++i) {
- for (int c = 0; c < depth; ++c) {
- output_data[depth * i + c] = ActivationFunction<Ac>(
- (input_data[depth * i + c] - mean_data[c]) * multiplier_data[c] +
- offset_data[c]);
- }
- }
-}
-
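The (mean, multiplier, offset) inputs consumed by the batch-normalization kernels above are assumed to be pre-folded per-channel parameters; the folding itself happens outside this file (e.g. in the converter) and is not shown in this diff. A hedged sketch of the usual derivation, with gamma/beta/variance/epsilon as illustrative names:

#include <cmath>

// Sketch only: fold standard batch-norm parameters into the triple used by
// the kernels above, which compute Ac((input - mean) * multiplier + offset).
inline void FoldBatchNormParams(float gamma, float beta, float variance,
                                float epsilon, float* multiplier,
                                float* offset) {
  *multiplier = gamma / std::sqrt(variance + epsilon);
  *offset = beta;
}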
-inline void Relu(const float* input_data, const Dims<4>& input_dims,
- float* output_data, const Dims<4>& output_dims) {
- const int flat_size = MatchingFlatSize(input_dims, output_dims);
+inline void Relu(const RuntimeShape& input_shape, const float* input_data,
+ const RuntimeShape& output_shape, float* output_data) {
+ const int flat_size = MatchingFlatSize(input_shape, output_shape);
for (int i = 0; i < flat_size; ++i) {
const float val = input_data[i];
const float lower = 0;
@@ -802,9 +430,10 @@ inline void Relu(const float* input_data, const Dims<4>& input_dims,
}
}
-inline void Relu1(const float* input_data, const Dims<4>& input_dims,
- float* output_data, const Dims<4>& output_dims) {
- const int flat_size = MatchingFlatSize(input_dims, output_dims);
+inline void Relu1(const RuntimeShape& input_shape, const float* input_data,
+ const RuntimeShape& output_shape, float* output_data) {
+ gemmlowp::ScopedProfilingLabel label("Relu1 (not fused)");
+ const int flat_size = MatchingFlatSize(input_shape, output_shape);
for (int i = 0; i < flat_size; ++i) {
const float val = input_data[i];
const float upper = 1;
@@ -814,9 +443,10 @@ inline void Relu1(const float* input_data, const Dims<4>& input_dims,
}
}
-inline void Relu6(const float* input_data, const Dims<4>& input_dims,
- float* output_data, const Dims<4>& output_dims) {
- const int flat_size = MatchingFlatSize(input_dims, output_dims);
+inline void Relu6(const RuntimeShape& input_shape, const float* input_data,
+ const RuntimeShape& output_shape, float* output_data) {
+ gemmlowp::ScopedProfilingLabel label("Relu6 (not fused)");
+ const int flat_size = MatchingFlatSize(input_shape, output_shape);
for (int i = 0; i < flat_size; ++i) {
const float val = input_data[i];
const float upper = 6;
@@ -826,12 +456,31 @@ inline void Relu6(const float* input_data, const Dims<4>& input_dims,
}
}
-template <FusedActivationFunctionType Ac>
-void L2Normalization(const float* input_data, const Dims<4>& input_dims,
- float* output_data, const Dims<4>& output_dims) {
- static_assert(Ac == FusedActivationFunctionType::kNone, "");
- const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims);
- const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
+inline void ReluX(const tflite::ActivationParams& params,
+ const RuntimeShape& input_shape, const uint8* input_data,
+ const RuntimeShape& output_shape, uint8* output_data) {
+ gemmlowp::ScopedProfilingLabel label("Quantized ReluX (not fused)");
+ const int flat_size = MatchingFlatSize(input_shape, output_shape);
+ const uint8 max_value = params.quantized_activation_max;
+ const uint8 min_value = params.quantized_activation_min;
+ for (int i = 0; i < flat_size; ++i) {
+ const uint8 val = input_data[i];
+ const uint8 clamped =
+ val > max_value ? max_value : val < min_value ? min_value : val;
+ output_data[i] = clamped;
+ }
+}
+
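The quantized_activation_min/max consumed by ReluX are received precomputed via ActivationParams. As a hedged sketch of where such bounds typically come from (an assumed derivation, not part of this change), for a Relu6 whose output is quantized with (scale, zero_point):

#include <algorithm>
#include <cmath>
#include <cstdint>

// Assumed helper, illustration only: quantize the real clamp range [0, 6]
// and intersect it with the uint8 range.
inline void QuantizedRelu6Range(float scale, std::int32_t zero_point,
                                std::uint8_t* act_min, std::uint8_t* act_max) {
  const std::int32_t qmin = std::max<std::int32_t>(0, zero_point);
  const std::int32_t qmax = std::min<std::int32_t>(
      255, zero_point + static_cast<std::int32_t>(std::lround(6.0f / scale)));
  *act_min = static_cast<std::uint8_t>(qmin);
  *act_max = static_cast<std::uint8_t>(qmax);
}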
+inline void L2Normalization(const tflite::L2NormalizationParams& op_params,
+ const RuntimeShape& input_shape,
+ const float* input_data,
+ const RuntimeShape& output_shape,
+ float* output_data) {
+ const int trailing_dim = input_shape.DimensionsCount() - 1;
+ const int outer_size =
+ MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
+ const int depth =
+ MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
for (int i = 0; i < outer_size; ++i) {
float squared_l2_norm = 0;
for (int c = 0; c < depth; ++c) {
@@ -845,15 +494,17 @@ void L2Normalization(const float* input_data, const Dims<4>& input_dims,
}
}
-inline void GetInvSqrtQuantizedMultiplier(int32 input, int32* output_inv_sqrt,
- int* output_shift) {
+inline void GetInvSqrtQuantizedMultiplierExp(int32 input,
+ int32* output_inv_sqrt,
+ int* output_shift) {
*output_shift = 11;
while (input >= (1 << 29)) {
input /= 4;
++*output_shift;
}
TFLITE_DCHECK_GT(input, 0);
- const unsigned max_left_shift_bits = __builtin_clz(input) - 1;
+ const unsigned max_left_shift_bits =
+ CountLeadingZeros(static_cast<uint32>(input)) - 1;
const unsigned max_left_shift_bit_pairs = max_left_shift_bits / 2;
const unsigned left_shift_bit_pairs = max_left_shift_bit_pairs - 1;
*output_shift -= left_shift_bit_pairs;
@@ -888,151 +539,146 @@ inline void GetInvSqrtQuantizedMultiplier(int32 input, int32* output_inv_sqrt,
*output_inv_sqrt <<= -*output_shift;
*output_shift = 0;
}
+ // Convert right shift (right is positive) to left shift.
+ *output_shift *= kReverseShift;
}
-inline void L2Normalization(const uint8* input_data, const Dims<4>& input_dims,
- int32 input_zero_point, uint8* output_data,
- const Dims<4>& output_dims) {
- const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
- const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims);
- TFLITE_DCHECK_EQ(outer_size, 1);
- int32 square_l2_norm = 0;
- for (int i = 0; i < depth; i++) {
- int32 diff = input_data[Offset(input_dims, i, 0, 0, 0)] - input_zero_point;
- square_l2_norm += diff * diff;
- }
- int32 inv_l2norm_multiplier;
- int inv_l2norm_shift;
- GetInvSqrtQuantizedMultiplier(square_l2_norm, &inv_l2norm_multiplier,
- &inv_l2norm_shift);
-
- for (int i = 0; i < depth; i++) {
- int32 diff = input_data[Offset(input_dims, i, 0, 0, 0)] - input_zero_point;
- int32 rescaled_diff = MultiplyByQuantizedMultiplierSmallerThanOne(
- 128 * diff, inv_l2norm_multiplier, inv_l2norm_shift);
- int32 unclamped_output_val = 128 + rescaled_diff;
- int32 output_val = std::min(255, std::max(0, unclamped_output_val));
- output_data[Offset(output_dims, i, 0, 0, 0)] =
- static_cast<uint8>(output_val);
+inline void L2Normalization(const tflite::L2NormalizationParams& op_params,
+ const RuntimeShape& input_shape,
+ const uint8* input_data,
+ const RuntimeShape& output_shape,
+ uint8* output_data) {
+ const int trailing_dim = input_shape.DimensionsCount() - 1;
+ const int depth =
+ MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
+ const int outer_size =
+ MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
+ const int32 input_zero_point = op_params.input_zero_point;
+ for (int i = 0; i < outer_size; ++i) {
+ int32 square_l2_norm = 0;
+ for (int c = 0; c < depth; c++) {
+ int32 diff = input_data[depth * i + c] - input_zero_point;
+ square_l2_norm += diff * diff;
+ }
+ int32 inv_l2norm_multiplier;
+ int inv_l2norm_shift;
+ GetInvSqrtQuantizedMultiplierExp(square_l2_norm, &inv_l2norm_multiplier,
+ &inv_l2norm_shift);
+
+ for (int c = 0; c < depth; c++) {
+ int32 diff = input_data[depth * i + c] - input_zero_point;
+ int32 rescaled_diff = MultiplyByQuantizedMultiplierSmallerThanOneExp(
+ 128 * diff, inv_l2norm_multiplier, inv_l2norm_shift);
+ int32 unclamped_output_val = 128 + rescaled_diff;
+ int32 output_val = std::min(255, std::max(0, unclamped_output_val));
+ output_data[depth * i + c] = static_cast<uint8>(output_val);
+ }
}
}
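For reference, the quantized L2Normalization above approximates, per row, 128 + 128 * (x[c] - zero_point) / ||x - zero_point||_2 clamped to [0, 255], i.e. the normalized value in [-1, 1] mapped onto uint8 with zero point 128. A float sketch of that computation (not the diff's fixed-point code):

#include <algorithm>
#include <cmath>
#include <cstdint>

// Float reference for one row of the quantized L2 normalization.
inline void L2NormalizeRowReference(const std::uint8_t* input, int depth,
                                    std::int32_t input_zero_point,
                                    std::uint8_t* output) {
  float sum_sq = 0.f;
  for (int c = 0; c < depth; ++c) {
    const float diff = static_cast<float>(input[c]) - input_zero_point;
    sum_sq += diff * diff;
  }
  if (sum_sq == 0.f) sum_sq = 1.f;  // avoid division by zero in this sketch
  const float inv_norm = 1.0f / std::sqrt(sum_sq);
  for (int c = 0; c < depth; ++c) {
    const float diff = static_cast<float>(input[c]) - input_zero_point;
    const int val =
        128 + static_cast<int>(std::lround(128.0f * diff * inv_norm));
    output[c] = static_cast<std::uint8_t>(std::min(255, std::max(0, val)));
  }
}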
-inline void Add(const float* input1_data, const Dims<4>& input1_dims,
- const float* input2_data, const Dims<4>& input2_dims,
- float output_activation_min, float output_activation_max,
- float* output_data, const Dims<4>& output_dims) {
- const int flat_size = MatchingFlatSize(input1_dims, input2_dims, output_dims);
+template <typename T>
+inline void Add(const ArithmeticParams& params,
+ const RuntimeShape& input1_shape, const T* input1_data,
+ const RuntimeShape& input2_shape, const T* input2_data,
+ const RuntimeShape& output_shape, T* output_data) {
+ const int flat_size =
+ MatchingFlatSize(input1_shape, input2_shape, output_shape);
for (int i = 0; i < flat_size; ++i) {
output_data[i] = ActivationFunctionWithMinMax(
- input1_data[i] + input2_data[i], output_activation_min,
- output_activation_max);
+ input1_data[i] + input2_data[i], params.quantized_activation_min,
+ params.quantized_activation_max);
}
}
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void Add(const float* input1_data, const Dims<4>& input1_dims,
- const float* input2_data, const Dims<4>& input2_dims,
- float* output_data, const Dims<4>& output_dims) {
- float output_activation_min, output_activation_max;
- GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
-
- Add(input1_data, input1_dims, input2_data, input2_dims, output_activation_min,
- output_activation_max, output_data, output_dims);
-}
-
-template <FusedActivationFunctionType Ac>
-inline void Add(int left_shift, const uint8* input1_data,
- const Dims<4>& input1_dims, int32 input1_offset,
- int32 input1_multiplier, int input1_shift,
- const uint8* input2_data, const Dims<4>& input2_dims,
- int32 input2_offset, int32 input2_multiplier, int input2_shift,
- int32 output_offset, int32 output_multiplier, int output_shift,
- int32 output_activation_min, int32 output_activation_max,
- uint8* output_data, const Dims<4>& output_dims) {
- static_assert(Ac == FusedActivationFunctionType::kNone ||
- Ac == FusedActivationFunctionType::kRelu ||
- Ac == FusedActivationFunctionType::kRelu6 ||
- Ac == FusedActivationFunctionType::kRelu1,
- "");
- TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
- if (Ac == FusedActivationFunctionType::kNone) {
- TFLITE_DCHECK_EQ(output_activation_min, 0);
- TFLITE_DCHECK_EQ(output_activation_max, 255);
- }
- const int batches =
- MatchingArraySize(input1_dims, 3, input2_dims, 3, output_dims, 3);
- const int height =
- MatchingArraySize(input1_dims, 2, input2_dims, 2, output_dims, 2);
- const int width =
- MatchingArraySize(input1_dims, 1, input2_dims, 1, output_dims, 1);
- const int depth =
- MatchingArraySize(input1_dims, 0, input2_dims, 0, output_dims, 0);
- for (int b = 0; b < batches; ++b) {
- for (int y = 0; y < height; ++y) {
- for (int x = 0; x < width; ++x) {
- for (int c = 0; c < depth; ++c) {
- const int32 input1_val =
- input1_offset + input1_data[Offset(input1_dims, c, x, y, b)];
- const int32 input2_val =
- input2_offset + input2_data[Offset(input2_dims, c, x, y, b)];
- const int32 shifted_input1_val = input1_val * (1 << left_shift);
- const int32 shifted_input2_val = input2_val * (1 << left_shift);
- const int32 scaled_input1_val =
- MultiplyByQuantizedMultiplierSmallerThanOne(
- shifted_input1_val, input1_multiplier, input1_shift);
- const int32 scaled_input2_val =
- MultiplyByQuantizedMultiplierSmallerThanOne(
- shifted_input2_val, input2_multiplier, input2_shift);
- const int32 raw_sum = scaled_input1_val + scaled_input2_val;
- const int32 raw_output =
- MultiplyByQuantizedMultiplierSmallerThanOne(
- raw_sum, output_multiplier, output_shift) +
- output_offset;
- const int32 clamped_output =
- std::min(output_activation_max,
- std::max(output_activation_min, raw_output));
- output_data[Offset(output_dims, c, x, y, b)] =
- static_cast<uint8>(clamped_output);
- }
- }
- }
+inline void Add(const ArithmeticParams& params,
+ const RuntimeShape& input1_shape, const float* input1_data,
+ const RuntimeShape& input2_shape, const float* input2_data,
+ const RuntimeShape& output_shape, float* output_data) {
+ const int size = MatchingFlatSize(input1_shape, input2_shape, output_shape);
+ for (int i = 0; i < size; i++) {
+ auto x = input1_data[i] + input2_data[i];
+ output_data[i] = ActivationFunctionWithMinMax(
+ x, params.float_activation_min, params.float_activation_max);
}
}
-template <FusedActivationFunctionType Ac>
-inline void Add(const int16* input1_data, const Dims<4>& input1_dims,
- int input1_shift, const int16* input2_data,
- const Dims<4>& input2_dims, int input2_shift,
- int16 output_activation_min, int16 output_activation_max,
- int16* output_data, const Dims<4>& output_dims) {
- static_assert(Ac == FusedActivationFunctionType::kNone ||
- Ac == FusedActivationFunctionType::kRelu ||
- Ac == FusedActivationFunctionType::kRelu6 ||
- Ac == FusedActivationFunctionType::kRelu1,
- "");
- TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
- if (Ac == FusedActivationFunctionType::kNone) {
- TFLITE_DCHECK_EQ(output_activation_min, -32768);
- TFLITE_DCHECK_EQ(output_activation_max, 32767);
+// Element-wise add that can often be used for inner loop of broadcast add as
+// well as the non-broadcast add.
+inline void AddElementwise(int size, const ArithmeticParams& params,
+ const uint8* input1_data, const uint8* input2_data,
+ uint8* output_data) {
+ TFLITE_DCHECK_GT(params.input1_offset, -256);
+ TFLITE_DCHECK_GT(params.input2_offset, -256);
+ TFLITE_DCHECK_LT(params.input1_offset, 256);
+ TFLITE_DCHECK_LT(params.input2_offset, 256);
+
+ for (int i = 0; i < size; ++i) {
+ const int32 input1_val = params.input1_offset + input1_data[i];
+ const int32 input2_val = params.input2_offset + input2_data[i];
+ const int32 shifted_input1_val = input1_val * (1 << params.left_shift);
+ const int32 shifted_input2_val = input2_val * (1 << params.left_shift);
+ const int32 scaled_input1_val =
+ MultiplyByQuantizedMultiplierSmallerThanOneExp(
+ shifted_input1_val, params.input1_multiplier, params.input1_shift);
+ const int32 scaled_input2_val =
+ MultiplyByQuantizedMultiplierSmallerThanOneExp(
+ shifted_input2_val, params.input2_multiplier, params.input2_shift);
+ const int32 raw_sum = scaled_input1_val + scaled_input2_val;
+ const int32 raw_output =
+ MultiplyByQuantizedMultiplierSmallerThanOneExp(
+ raw_sum, params.output_multiplier, params.output_shift) +
+ params.output_offset;
+ const int32 clamped_output =
+ std::min(params.quantized_activation_max,
+ std::max(params.quantized_activation_min, raw_output));
+ output_data[i] = static_cast<uint8>(clamped_output);
}
+}
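As a reality check on the fixed-point arithmetic in AddElementwise: it approximates dequantize, add, requantize. A float sketch of that computation follows (names are illustrative; the assumption is that the input offsets correspond to negated input zero points and the output offset to the output zero point):

#include <algorithm>
#include <cmath>
#include <cstdint>

// Float reference for quantized addition, sketch only.
inline std::uint8_t QuantizedAddReference(std::uint8_t q1, std::uint8_t q2,
                                          float scale1, std::int32_t zp1,
                                          float scale2, std::int32_t zp2,
                                          float out_scale,
                                          std::int32_t out_zp) {
  const float real_sum = scale1 * (q1 - zp1) + scale2 * (q2 - zp2);
  const std::int32_t q =
      out_zp + static_cast<std::int32_t>(std::lround(real_sum / out_scale));
  return static_cast<std::uint8_t>(std::min(255, std::max(0, q)));
}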
- const int flat_size = MatchingFlatSize(output_dims, input1_dims, input2_dims);
+inline void Add(const ArithmeticParams& params,
+ const RuntimeShape& input1_shape, const uint8* input1_data,
+ const RuntimeShape& input2_shape, const uint8* input2_data,
+ const RuntimeShape& output_shape, uint8* output_data) {
+ TFLITE_DCHECK_LE(params.quantized_activation_min,
+ params.quantized_activation_max);
+ const int flat_size =
+ MatchingFlatSize(input1_shape, input2_shape, output_shape);
+
+ TFLITE_DCHECK_GT(params.input1_offset, -256);
+ TFLITE_DCHECK_GT(params.input2_offset, -256);
+ TFLITE_DCHECK_LT(params.input1_offset, 256);
+ TFLITE_DCHECK_LT(params.input2_offset, 256);
+ AddElementwise(flat_size, params, input1_data, input2_data, output_data);
+}
- TFLITE_DCHECK(input1_shift == 0 || input2_shift == 0);
- TFLITE_DCHECK_GE(input1_shift, 0);
- TFLITE_DCHECK_GE(input2_shift, 0);
+inline void Add(const ArithmeticParams& params,
+ const RuntimeShape& input1_shape, const int16* input1_data,
+ const RuntimeShape& input2_shape, const int16* input2_data,
+ const RuntimeShape& output_shape, int16* output_data) {
+ TFLITE_DCHECK_LE(params.quantized_activation_min,
+ params.quantized_activation_max);
+
+ const int input1_shift = params.input1_shift;
+ const int flat_size =
+ MatchingFlatSize(output_shape, input1_shape, input2_shape);
+ const int16 output_activation_min = params.quantized_activation_min;
+ const int16 output_activation_max = params.quantized_activation_max;
+
+ TFLITE_DCHECK(input1_shift == 0 || params.input2_shift == 0);
+ TFLITE_DCHECK_LE(input1_shift, 0);
+ TFLITE_DCHECK_LE(params.input2_shift, 0);
const int16* not_shift_input = input1_shift == 0 ? input1_data : input2_data;
const int16* shift_input = input1_shift == 0 ? input2_data : input1_data;
- const int input_shift = input1_shift == 0 ? input2_shift : input1_shift;
+ const int input_right_shift =
+ input1_shift == 0 ? -params.input2_shift : -input1_shift;
for (int i = 0; i < flat_size; i++) {
// F0 uses 0 integer bits, range [-1, 1].
using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
F0 input_ready_scaled = F0::FromRaw(not_shift_input[i]);
- F0 scaled_input =
- F0::FromRaw(gemmlowp::RoundingDivideByPOT(shift_input[i], input_shift));
+ F0 scaled_input = F0::FromRaw(
+ gemmlowp::RoundingDivideByPOT(shift_input[i], input_right_shift));
F0 result = gemmlowp::SaturatingAdd(scaled_input, input_ready_scaled);
const int16 raw_output = result.raw();
const int16 clamped_output = std::min(
@@ -1045,16 +691,24 @@ inline void Add(const int16* input1_data, const Dims<4>& input1_dims,
// dimensionality if the runtime code does a single loop over one dimension
// that handles broadcasting as the base case. The code generator would then
// generate max(D1, D2) nested for loops.
-template <typename T>
-void BroadcastAdd(const T* input1_data, const Dims<4>& input1_dims,
- const T* input2_data, const Dims<4>& input2_dims,
- T output_activation_min, T output_activation_max,
- T* output_data, const Dims<4>& output_dims) {
- gemmlowp::ScopedProfilingLabel label("BroadcastAdd");
-
+// TODO(benoitjacob): BroadcastAdd is intentionally duplicated from
+// reference_ops.h. Once an optimized version is implemented and NdArrayDesc<T>
+// is no longer referenced in this file, move NdArrayDesc<T> from types.h to
+// reference_ops.h.
+inline void BroadcastAdd4DSlow(const ArithmeticParams& params,
+ const RuntimeShape& input1_shape,
+ const float* input1_data,
+ const RuntimeShape& input2_shape,
+ const float* input2_data,
+ const RuntimeShape& output_shape,
+ float* output_data) {
+ gemmlowp::ScopedProfilingLabel label("BroadcastAdd4DSlow/float");
NdArrayDesc<4> desc1;
NdArrayDesc<4> desc2;
- NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
+ NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
+ &desc2);
+ const RuntimeShape extended_output_shape =
+ RuntimeShape::ExtendedShape(4, output_shape);
// In Tensorflow, the dimensions are canonically named (batch_number, row,
// col, channel), with extents (batches, height, width, depth), with the
@@ -1067,49 +721,77 @@ void BroadcastAdd(const T* input1_data, const Dims<4>& input1_dims,
// We name our variables by their Tensorflow convention, but generate C code
// nesting loops such that the innermost loop has the smallest stride for the
// best cache behavior.
- for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
- for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
- for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
- for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
- output_data[Offset(output_dims, c, x, y, b)] =
+ for (int b = 0; b < extended_output_shape.Dims(0); ++b) {
+ for (int y = 0; y < extended_output_shape.Dims(1); ++y) {
+ for (int x = 0; x < extended_output_shape.Dims(2); ++x) {
+ for (int c = 0; c < extended_output_shape.Dims(3); ++c) {
+ output_data[Offset(extended_output_shape, b, y, x, c)] =
ActivationFunctionWithMinMax(
- input1_data[SubscriptToIndex(desc1, c, x, y, b)] +
- input2_data[SubscriptToIndex(desc2, c, x, y, b)],
- output_activation_min, output_activation_max);
+ input1_data[SubscriptToIndex(desc1, b, y, x, c)] +
+ input2_data[SubscriptToIndex(desc2, b, y, x, c)],
+ params.float_activation_min, params.float_activation_max);
}
}
}
}
}
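For readers following the NdArrayDesc/SubscriptToIndex machinery above, the essence of 4-D broadcasting is that a dimension whose extent is 1 in an input is always read at index 0, so the input is reused along that axis. A minimal standalone sketch of that indexing rule (not the actual descriptor code):

// Row-major flat index into a 4-D array with broadcasting semantics.
inline int BroadcastFlatIndex(const int dims[4], int b, int y, int x, int c) {
  const int idx[4] = {b, y, x, c};
  int flat = 0;
  for (int d = 0; d < 4; ++d) {
    const int i = (dims[d] == 1) ? 0 : idx[d];
    flat = flat * dims[d] + i;
  }
  return flat;
}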
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac, typename T>
-void BroadcastAdd(const T* input1_data, const Dims<4>& input1_dims,
- const T* input2_data, const Dims<4>& input2_dims,
- T* output_data, const Dims<4>& output_dims) {
- T output_activation_min, output_activation_max;
- GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+inline void BroadcastAdd4DSlow(const ArithmeticParams& params,
+ const RuntimeShape& input1_shape,
+ const int32* input1_data,
+ const RuntimeShape& input2_shape,
+ const int32* input2_data,
+ const RuntimeShape& output_shape,
+ int32* output_data) {
+ gemmlowp::ScopedProfilingLabel label("BroadcastAdd4DSlow/int32");
+ NdArrayDesc<4> desc1;
+ NdArrayDesc<4> desc2;
+ NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
+ &desc2);
+ const RuntimeShape extended_output_shape =
+ RuntimeShape::ExtendedShape(4, output_shape);
- BroadcastAdd(input1_data, input1_dims, input2_data, input2_dims,
- output_activation_min, output_activation_max, output_data,
- output_dims);
+ // In Tensorflow, the dimensions are canonically named (batch_number, row,
+ // col, channel), with extents (batches, height, width, depth), with the
+ // trailing dimension changing most rapidly (channels has the smallest stride,
+ // typically 1 element).
+ //
+ // In generated C code, we store arrays with the dimensions reversed. The
+ // first dimension has smallest stride.
+ //
+ // We name our variables by their Tensorflow convention, but generate C code
+ // nesting loops such that the innermost loop has the smallest stride for the
+ // best cache behavior.
+ for (int b = 0; b < extended_output_shape.Dims(0); ++b) {
+ for (int y = 0; y < extended_output_shape.Dims(1); ++y) {
+ for (int x = 0; x < extended_output_shape.Dims(2); ++x) {
+ for (int c = 0; c < extended_output_shape.Dims(3); ++c) {
+ output_data[Offset(extended_output_shape, b, y, x, c)] =
+ ActivationFunctionWithMinMax(
+ input1_data[SubscriptToIndex(desc1, b, y, x, c)] +
+ input2_data[SubscriptToIndex(desc2, b, y, x, c)],
+ params.quantized_activation_min,
+ params.quantized_activation_max);
+ }
+ }
+ }
+ }
}
-inline void BroadcastAdd(int left_shift, const uint8* input1_data,
- const Dims<4>& input1_dims, int32 input1_offset,
- int32 input1_multiplier, int input1_shift,
- const uint8* input2_data, const Dims<4>& input2_dims,
- int32 input2_offset, int32 input2_multiplier,
- int input2_shift, int32 output_offset,
- int32 output_multiplier, int output_shift,
- int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims) {
- gemmlowp::ScopedProfilingLabel label("BroadcastAdd/8bit");
-
+inline void BroadcastAdd4DSlow(const ArithmeticParams& params,
+ const RuntimeShape& input1_shape,
+ const uint8* input1_data,
+ const RuntimeShape& input2_shape,
+ const uint8* input2_data,
+ const RuntimeShape& output_shape,
+ uint8* output_data) {
+ gemmlowp::ScopedProfilingLabel label("BroadcastAdd4DSlow/uint8");
NdArrayDesc<4> desc1;
NdArrayDesc<4> desc2;
- NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
+ NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
+ &desc2);
+ const RuntimeShape extended_output_shape =
+ RuntimeShape::ExtendedShape(4, output_shape);
// In Tensorflow, the dimensions are canonically named (batch_number, row,
// col, channel), with extents (batches, height, width, depth), with the
@@ -1122,31 +804,37 @@ inline void BroadcastAdd(int left_shift, const uint8* input1_data,
// We name our variables by their Tensorflow convention, but generate C code
// nesting loops such that the innermost loop has the smallest stride for the
// best cache behavior.
- for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
- for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
- for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
- for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
+ for (int b = 0; b < extended_output_shape.Dims(0); ++b) {
+ for (int y = 0; y < extended_output_shape.Dims(1); ++y) {
+ for (int x = 0; x < extended_output_shape.Dims(2); ++x) {
+ for (int c = 0; c < extended_output_shape.Dims(3); ++c) {
const int32 input1_val =
- input1_offset + input1_data[SubscriptToIndex(desc1, c, x, y, b)];
+ params.input1_offset +
+ input1_data[SubscriptToIndex(desc1, b, y, x, c)];
const int32 input2_val =
- input2_offset + input2_data[SubscriptToIndex(desc2, c, x, y, b)];
- const int32 shifted_input1_val = input1_val * (1 << left_shift);
- const int32 shifted_input2_val = input2_val * (1 << left_shift);
+ params.input2_offset +
+ input2_data[SubscriptToIndex(desc2, b, y, x, c)];
+ const int32 shifted_input1_val =
+ input1_val * (1 << params.left_shift);
+ const int32 shifted_input2_val =
+ input2_val * (1 << params.left_shift);
const int32 scaled_input1_val =
- MultiplyByQuantizedMultiplierSmallerThanOne(
- shifted_input1_val, input1_multiplier, input1_shift);
+ MultiplyByQuantizedMultiplierSmallerThanOneExp(
+ shifted_input1_val, params.input1_multiplier,
+ params.input1_shift);
const int32 scaled_input2_val =
- MultiplyByQuantizedMultiplierSmallerThanOne(
- shifted_input2_val, input2_multiplier, input2_shift);
+ MultiplyByQuantizedMultiplierSmallerThanOneExp(
+ shifted_input2_val, params.input2_multiplier,
+ params.input2_shift);
const int32 raw_sum = scaled_input1_val + scaled_input2_val;
const int32 raw_output =
- MultiplyByQuantizedMultiplierSmallerThanOne(
- raw_sum, output_multiplier, output_shift) +
- output_offset;
+ MultiplyByQuantizedMultiplierSmallerThanOneExp(
+ raw_sum, params.output_multiplier, params.output_shift) +
+ params.output_offset;
const int32 clamped_output =
- std::min(output_activation_max,
- std::max(output_activation_min, raw_output));
- output_data[Offset(output_dims, c, x, y, b)] =
+ std::min(params.quantized_activation_max,
+ std::max(params.quantized_activation_min, raw_output));
+ output_data[Offset(extended_output_shape, b, y, x, c)] =
static_cast<uint8>(clamped_output);
}
}
@@ -1154,120 +842,73 @@ inline void BroadcastAdd(int left_shift, const uint8* input1_data,
}
}
-inline void BroadcastAddFivefold(
- int y0, int y1, int y2, int y3, int y4, int left_shift,
- const uint8* input1_data, const Dims<4>& input1_dims, int32 input1_offset,
- int32 input1_multiplier, int input1_shift, const uint8* input2_data,
- const Dims<4>& input2_dims, int32 input2_offset, int32 input2_multiplier,
- int input2_shift, int32 output_offset, int32 output_multiplier,
- int output_shift, int32 output_activation_min, int32 output_activation_max,
- uint8* output_data, const Dims<4>& output_dims) {
- gemmlowp::ScopedProfilingLabel label("BroadcastAddFivefold/8bit");
-
- int sb1 = y0;
- int sa2 = y0;
- int sb2 = y0 * y1;
- int sa3 = y0 * y2;
- int sa4 = y0 * y2 * y3;
- int sb4 = y0 * y1 * y2;
-
+inline void BroadcastAddFivefold(const ArithmeticParams& unswitched_params,
+ const RuntimeShape& unswitched_input1_shape,
+ const uint8* unswitched_input1_data,
+ const RuntimeShape& unswitched_input2_shape,
+ const uint8* unswitched_input2_data,
+ const RuntimeShape& output_shape,
+ uint8* output_data) {
+ ArithmeticParams switched_params = unswitched_params;
+ switched_params.input1_offset = unswitched_params.input2_offset;
+ switched_params.input1_multiplier = unswitched_params.input2_multiplier;
+ switched_params.input1_shift = unswitched_params.input2_shift;
+ switched_params.input2_offset = unswitched_params.input1_offset;
+ switched_params.input2_multiplier = unswitched_params.input1_multiplier;
+ switched_params.input2_shift = unswitched_params.input1_shift;
+
+ const bool use_unswitched =
+ unswitched_params.broadcast_category ==
+ tflite::BroadcastableOpCategory::kFirstInputBroadcastsFast;
+
+ const ArithmeticParams& params =
+ use_unswitched ? unswitched_params : switched_params;
+ const uint8* input1_data =
+ use_unswitched ? unswitched_input1_data : unswitched_input2_data;
+ const uint8* input2_data =
+ use_unswitched ? unswitched_input2_data : unswitched_input1_data;
+
+ // Fivefold nested loops. The second input resets its position for each
+ // iteration of the second loop. The first input resets its position at the
+ // beginning of the fourth loop. The innermost loop is an elementwise add of
+ // sections of the arrays.
uint8* output_data_ptr = output_data;
- for (int i4 = 0; i4 < y4; ++i4) {
- for (int i3 = 0; i3 < y3; ++i3) {
+ const uint8* input1_data_ptr = input1_data;
+ const uint8* input2_data_reset = input2_data;
+ int y0 = params.broadcast_shape[0];
+ int y1 = params.broadcast_shape[1];
+ int y2 = params.broadcast_shape[2];
+ int y3 = params.broadcast_shape[3];
+ int y4 = params.broadcast_shape[4];
+ for (int i0 = 0; i0 < y0; ++i0) {
+ const uint8* input2_data_ptr;
+ for (int i1 = 0; i1 < y1; ++i1) {
+ input2_data_ptr = input2_data_reset;
for (int i2 = 0; i2 < y2; ++i2) {
- for (int i1 = 0; i1 < y1; ++i1) {
- for (int i0 = 0; i0 < y0; ++i0) {
- const int32 input1_val =
- input1_offset +
- input1_data[i4 * sa4 + i3 * sa3 + i2 * sa2 + i0];
- const int32 input2_val =
- input2_offset +
- input2_data[i4 * sb4 + i2 * sb2 + i1 * sb1 + i0];
- const int32 shifted_input1_val = input1_val * (1 << left_shift);
- const int32 shifted_input2_val = input2_val * (1 << left_shift);
- const int32 scaled_input1_val =
- MultiplyByQuantizedMultiplierSmallerThanOne(
- shifted_input1_val, input1_multiplier, input1_shift);
- const int32 scaled_input2_val =
- MultiplyByQuantizedMultiplierSmallerThanOne(
- shifted_input2_val, input2_multiplier, input2_shift);
- const int32 raw_sum = scaled_input1_val + scaled_input2_val;
- const int32 raw_output =
- MultiplyByQuantizedMultiplierSmallerThanOne(
- raw_sum, output_multiplier, output_shift) +
- output_offset;
- const int32 clamped_output =
- std::min(output_activation_max,
- std::max(output_activation_min, raw_output));
- *output_data_ptr = static_cast<uint8>(clamped_output);
- ++output_data_ptr;
- }
+ for (int i3 = 0; i3 < y3; ++i3) {
+ AddElementwise(y4, params, input1_data_ptr, input2_data_ptr,
+ output_data_ptr);
+ input2_data_ptr += y4;
+ output_data_ptr += y4;
}
+ input1_data_ptr += y4;
}
}
+ input2_data_reset = input2_data_ptr;
}
}
-template <FusedActivationFunctionType Ac>
-inline void BroadcastAdd(int left_shift, const uint8* input1_data,
- const Dims<4>& input1_dims, int32 input1_offset,
- int32 input1_multiplier, int input1_shift,
- const uint8* input2_data, const Dims<4>& input2_dims,
- int32 input2_offset, int32 input2_multiplier,
- int input2_shift, int32 output_offset,
- int32 output_multiplier, int output_shift,
- int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims) {
- static_assert(Ac == FusedActivationFunctionType::kNone ||
- Ac == FusedActivationFunctionType::kRelu ||
- Ac == FusedActivationFunctionType::kRelu6 ||
- Ac == FusedActivationFunctionType::kRelu1,
- "");
- TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
- if (Ac == FusedActivationFunctionType::kNone) {
- TFLITE_DCHECK_EQ(output_activation_min, 0);
- TFLITE_DCHECK_EQ(output_activation_max, 255);
- }
- BroadcastAdd(left_shift, input1_data, input1_dims, input1_offset,
- input1_multiplier, input1_shift, input2_data, input2_dims,
- input2_offset, input2_multiplier, input2_shift, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_data, output_dims);
-}
-
-template <FusedActivationFunctionType Ac>
-inline void BroadcastAddFivefold(
- int y0, int y1, int y2, int y3, int y4, int left_shift,
- const uint8* input1_data, const Dims<4>& input1_dims, int32 input1_offset,
- int32 input1_multiplier, int input1_shift, const uint8* input2_data,
- const Dims<4>& input2_dims, int32 input2_offset, int32 input2_multiplier,
- int input2_shift, int32 output_offset, int32 output_multiplier,
- int output_shift, int32 output_activation_min, int32 output_activation_max,
- uint8* output_data, const Dims<4>& output_dims) {
- static_assert(Ac == FusedActivationFunctionType::kNone ||
- Ac == FusedActivationFunctionType::kRelu ||
- Ac == FusedActivationFunctionType::kRelu6 ||
- Ac == FusedActivationFunctionType::kRelu1,
- "");
- TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
- if (Ac == FusedActivationFunctionType::kNone) {
- TFLITE_DCHECK_EQ(output_activation_min, 0);
- TFLITE_DCHECK_EQ(output_activation_max, 255);
- }
- BroadcastAddFivefold(y0, y1, y2, y3, y4, left_shift, input1_data, input1_dims,
- input1_offset, input1_multiplier, input1_shift,
- input2_data, input2_dims, input2_offset,
- input2_multiplier, input2_shift, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_data, output_dims);
-}
-
-inline void Mul(const float* input1_data, const Dims<4>& input1_dims,
- const float* input2_data, const Dims<4>& input2_dims,
- float output_activation_min, float output_activation_max,
- float* output_data, const Dims<4>& output_dims) {
- const int flat_size = MatchingFlatSize(input1_dims, input2_dims, output_dims);
+template <typename T>
+inline void Mul(const ArithmeticParams& params,
+ const RuntimeShape& input1_shape, const T* input1_data,
+ const RuntimeShape& input2_shape, const T* input2_data,
+ const RuntimeShape& output_shape, T* output_data) {
+ T output_activation_min;
+ T output_activation_max;
+ GetActivationParams(params, &output_activation_min, &output_activation_max);
+
+ const int flat_size =
+ MatchingFlatSize(input1_shape, input2_shape, output_shape);
for (int i = 0; i < flat_size; ++i) {
output_data[i] = ActivationFunctionWithMinMax(
input1_data[i] * input2_data[i], output_activation_min,
@@ -1275,52 +916,57 @@ inline void Mul(const float* input1_data, const Dims<4>& input1_dims,
}
}
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void Mul(const float* input1_data, const Dims<4>& input1_dims,
- const float* input2_data, const Dims<4>& input2_dims,
- float* output_data, const Dims<4>& output_dims) {
- float output_activation_min, output_activation_max;
- GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
-
- Mul(input1_data, input1_dims, input2_data, input2_dims, output_activation_min,
- output_activation_max, output_data, output_dims);
-}
-
// TODO(jiawen): We can implement BroadcastMul on buffers of arbitrary
// dimensionality if the runtime code does a single loop over one dimension
// that handles broadcasting as the base case. The code generator would then
// generate max(D1, D2) nested for loops.
+// TODO(benoitjacob): BroadcastMul is intentionally duplicated from
+// reference_ops.h. Once an optimized version is implemented and NdArrayDesc<T>
+// is no longer referenced in this file, move NdArrayDesc<T> from types.h to
+// reference_ops.h.
template <typename T>
-void BroadcastMul(const T* input1_data, const Dims<4>& input1_dims,
- const T* input2_data, const Dims<4>& input2_dims,
- T output_activation_min, T output_activation_max,
- T* output_data, const Dims<4>& output_dims) {
- gemmlowp::ScopedProfilingLabel label("BroadcastMul");
+void BroadcastMul4DSlow(const ArithmeticParams& params,
+ const RuntimeShape& unextended_input1_shape,
+ const T* input1_data,
+ const RuntimeShape& unextended_input2_shape,
+ const T* input2_data,
+ const RuntimeShape& unextended_output_shape,
+ T* output_data) {
+ gemmlowp::ScopedProfilingLabel label("BroadcastMul4DSlow");
+ T output_activation_min;
+ T output_activation_max;
+ GetActivationParams(params, &output_activation_min, &output_activation_max);
+
+ TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+ const RuntimeShape output_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_shape);
NdArrayDesc<4> desc1;
NdArrayDesc<4> desc2;
- NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
+ NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
+ unextended_input2_shape, &desc1, &desc2);
// In Tensorflow, the dimensions are canonically named (batch_number, row,
// col, channel), with extents (batches, height, width, depth), with the
- // trailing dimension changing most rapidly (channels has the smallest
- // stride, typically 1 element).
+ // trailing dimension changing most rapidly (channels has the smallest stride,
+ // typically 1 element).
//
// In generated C code, we store arrays with the dimensions reversed. The
// first dimension has smallest stride.
//
// We name our variables by their Tensorflow convention, but generate C code
- // nesting loops such that the innermost loop has the smallest stride for
- // the best cache behavior.
- for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
- for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
- for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
- for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
- output_data[Offset(output_dims, c, x, y, b)] =
+ // nesting loops such that the innermost loop has the smallest stride for the
+ // best cache behavior.
+ for (int b = 0; b < output_shape.Dims(0); ++b) {
+ for (int y = 0; y < output_shape.Dims(1); ++y) {
+ for (int x = 0; x < output_shape.Dims(2); ++x) {
+ for (int c = 0; c < output_shape.Dims(3); ++c) {
+ output_data[Offset(output_shape, b, y, x, c)] =
ActivationFunctionWithMinMax(
- input1_data[SubscriptToIndex(desc1, c, x, y, b)] *
- input2_data[SubscriptToIndex(desc2, c, x, y, b)],
+ input1_data[SubscriptToIndex(desc1, b, y, x, c)] *
+ input2_data[SubscriptToIndex(desc2, b, y, x, c)],
output_activation_min, output_activation_max);
}
}
@@ -1328,59 +974,127 @@ void BroadcastMul(const T* input1_data, const Dims<4>& input1_dims,
}
}
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac, typename T>
-void BroadcastMul(const T* input1_data, const Dims<4>& input1_dims,
- const T* input2_data, const Dims<4>& input2_dims,
- T* output_data, const Dims<4>& output_dims) {
- T output_activation_min, output_activation_max;
- GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+// Element-wise mul that can often be used for inner loop of broadcast Mul as
+// well as the non-broadcast Mul.
+inline void MulElementwise(int size, const ArithmeticParams& params,
+ const uint8* input1_data, const uint8* input2_data,
+ uint8* output_data) {
+ for (int i = 0; i < size; ++i) {
+ const int32 input1_val = params.input1_offset + input1_data[i];
+ const int32 input2_val = params.input2_offset + input2_data[i];
+ const int32 unclamped_result =
+ params.output_offset +
+ MultiplyByQuantizedMultiplierSmallerThanOneExp(input1_val * input2_val,
+ params.output_multiplier,
+ params.output_shift);
+ const int32 clamped_output =
+ std::min(params.quantized_activation_max,
+ std::max(params.quantized_activation_min, unclamped_result));
+ output_data[i] = static_cast<uint8>(clamped_output);
+ }
+}
+
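Analogous to the add case, MulElementwise approximates dequantize, multiply, requantize; here the kernel's output_multiplier/output_shift are assumed to encode scale1 * scale2 / out_scale as a fixed-point value (the setup is not shown in this diff). A float sketch:

#include <algorithm>
#include <cmath>
#include <cstdint>

// Float reference for quantized multiplication, sketch only.
inline std::uint8_t QuantizedMulReference(std::uint8_t q1, std::uint8_t q2,
                                          float scale1, std::int32_t zp1,
                                          float scale2, std::int32_t zp2,
                                          float out_scale,
                                          std::int32_t out_zp) {
  const float real_product = (scale1 * (q1 - zp1)) * (scale2 * (q2 - zp2));
  const std::int32_t q = out_zp + static_cast<std::int32_t>(
                                      std::lround(real_product / out_scale));
  return static_cast<std::uint8_t>(std::min(255, std::max(0, q)));
}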
+inline void Mul(const ArithmeticParams& params,
+ const RuntimeShape& input1_shape, const uint8* input1_data,
+ const RuntimeShape& input2_shape, const uint8* input2_data,
+ const RuntimeShape& output_shape, uint8* output_data) {
+ TFLITE_DCHECK_LE(params.quantized_activation_min,
+ params.quantized_activation_max);
+ gemmlowp::ScopedProfilingLabel label("Mul/8bit");
+ const int flat_size =
+ MatchingFlatSize(input1_shape, input2_shape, output_shape);
+
+ MulElementwise(flat_size, params, input1_data, input2_data, output_data);
+}
- BroadcastMul(input1_data, input1_dims, input2_data, input2_dims,
- output_activation_min, output_activation_max, output_data,
- output_dims);
+inline void BroadcastMulFivefold(const ArithmeticParams& unswitched_params,
+ const RuntimeShape& unswitched_input1_shape,
+ const uint8* unswitched_input1_data,
+ const RuntimeShape& unswitched_input2_shape,
+ const uint8* unswitched_input2_data,
+ const RuntimeShape& output_shape,
+ uint8* output_data) {
+ ArithmeticParams switched_params = unswitched_params;
+ switched_params.input1_offset = unswitched_params.input2_offset;
+ switched_params.input2_offset = unswitched_params.input1_offset;
+
+ const bool use_unswitched =
+ unswitched_params.broadcast_category ==
+ tflite::BroadcastableOpCategory::kFirstInputBroadcastsFast;
+
+ const ArithmeticParams& params =
+ use_unswitched ? unswitched_params : switched_params;
+ const uint8* input1_data =
+ use_unswitched ? unswitched_input1_data : unswitched_input2_data;
+ const uint8* input2_data =
+ use_unswitched ? unswitched_input2_data : unswitched_input1_data;
+
+ // Fivefold nested loops. The second input resets its position for each
+ // iteration of the second loop. The first input resets its position at the
+ // beginning of the fourth loop. The innermost loop is an elementwise Mul of
+ // sections of the arrays.
+ uint8* output_data_ptr = output_data;
+ const uint8* input1_data_ptr = input1_data;
+ const uint8* input2_data_reset = input2_data;
+ int y0 = params.broadcast_shape[0];
+ int y1 = params.broadcast_shape[1];
+ int y2 = params.broadcast_shape[2];
+ int y3 = params.broadcast_shape[3];
+ int y4 = params.broadcast_shape[4];
+ for (int i0 = 0; i0 < y0; ++i0) {
+ const uint8* input2_data_ptr;
+ for (int i1 = 0; i1 < y1; ++i1) {
+ input2_data_ptr = input2_data_reset;
+ for (int i2 = 0; i2 < y2; ++i2) {
+ for (int i3 = 0; i3 < y3; ++i3) {
+ MulElementwise(y4, params, input1_data_ptr, input2_data_ptr,
+ output_data_ptr);
+ input2_data_ptr += y4;
+ output_data_ptr += y4;
+ }
+ input1_data_ptr += y4;
+ }
+ }
+ input2_data_reset = input2_data_ptr;
+ }
}
-inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims,
- int32 input1_offset, const uint8* input2_data,
- const Dims<4>& input2_dims, int32 input2_offset,
- int32 output_offset, int32 output_multiplier,
- int output_shift, int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims) {
- gemmlowp::ScopedProfilingLabel label("BroadcastMul/8bit");
+inline void BroadcastMul4DSlow(const ArithmeticParams& params,
+ const RuntimeShape& input1_shape,
+ const uint8* input1_data,
+ const RuntimeShape& input2_shape,
+ const uint8* input2_data,
+ const RuntimeShape& output_shape,
+ uint8* output_data) {
+ gemmlowp::ScopedProfilingLabel label("BroadcastMul4DSlow/8bit");
NdArrayDesc<4> desc1;
NdArrayDesc<4> desc2;
- NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
-
- // In Tensorflow, the dimensions are canonically named (batch_number, row,
- // col, channel), with extents (batches, height, width, depth), with the
- // trailing dimension changing most rapidly (channels has the smallest
- // stride, typically 1 element).
- //
- // In generated C code, we store arrays with the dimensions reversed. The
- // first dimension has smallest stride.
- //
- // We name our variables by their Tensorflow convention, but generate C code
- // nesting loops such that the innermost loop has the smallest stride for
- // the best cache behavior.
- for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
- for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
- for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
- for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
+ // The input shapes are extended as part of NdArrayDesc initialization.
+ NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
+ &desc2);
+ const RuntimeShape extended_output_shape =
+ RuntimeShape::ExtendedShape(4, output_shape);
+
+ for (int b = 0; b < extended_output_shape.Dims(0); ++b) {
+ for (int y = 0; y < extended_output_shape.Dims(1); ++y) {
+ for (int x = 0; x < extended_output_shape.Dims(2); ++x) {
+ for (int c = 0; c < extended_output_shape.Dims(3); ++c) {
const int32 input1_val =
- input1_offset + input1_data[SubscriptToIndex(desc1, c, x, y, b)];
+ params.input1_offset +
+ input1_data[SubscriptToIndex(desc1, b, y, x, c)];
const int32 input2_val =
- input2_offset + input2_data[SubscriptToIndex(desc2, c, x, y, b)];
+ params.input2_offset +
+ input2_data[SubscriptToIndex(desc2, b, y, x, c)];
const int32 unclamped_result =
- output_offset +
- MultiplyByQuantizedMultiplierSmallerThanOne(
- input1_val * input2_val, output_multiplier, output_shift);
- const int32 clamped_output =
- std::min(output_activation_max,
- std::max(output_activation_min, unclamped_result));
- output_data[Offset(output_dims, c, x, y, b)] =
+ params.output_offset +
+ MultiplyByQuantizedMultiplierSmallerThanOneExp(
+ input1_val * input2_val, params.output_multiplier,
+ params.output_shift);
+ const int32 clamped_output = std::min(
+ params.quantized_activation_max,
+ std::max(params.quantized_activation_min, unclamped_result));
+ output_data[Offset(extended_output_shape, b, y, x, c)] =
static_cast<uint8>(clamped_output);
}
}
@@ -1388,12 +1102,14 @@ inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims,
}
}
-inline void Mul(const int16* input1_data, const Dims<4>& input1_dims,
- const int16* input2_data, const Dims<4>& input2_dims,
- int16* output_data, const Dims<4>& output_dims) {
+inline void Mul(const ArithmeticParams& params,
+ const RuntimeShape& input1_shape, const int16* input1_data,
+ const RuntimeShape& input2_shape, const int16* input2_data,
+ const RuntimeShape& output_shape, int16* output_data) {
gemmlowp::ScopedProfilingLabel label("Mul/Int16");
- const int flat_size = MatchingFlatSize(output_dims, input1_dims, input2_dims);
+ const int flat_size =
+ MatchingFlatSize(input1_shape, input2_shape, output_shape);
for (int i = 0; i < flat_size; i++) {
// F0 uses 0 integer bits, range [-1, 1].
@@ -1405,15 +1121,18 @@ inline void Mul(const int16* input1_data, const Dims<4>& input1_dims,
}
}
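The F0 type referenced in the int16 Mul comments is gemmlowp's 0-integer-bit fixed point: an int16 raw value r represents r / 32768, covering [-1, 1). A hedged sketch of a rounded Q0.15 product, which approximately matches the fixed-point multiply used in the elided loop body (the real implementation also saturates the single overflow case):

#include <cstdint>

// Sketch only: multiply two Q0.15 values (real value = raw / 32768) with
// round-to-nearest; raw1 == raw2 == -32768 would overflow and is saturated
// in the real gemmlowp code.
inline std::int16_t Q0_15Mul(std::int16_t a, std::int16_t b) {
  const std::int32_t product = static_cast<std::int32_t>(a) * b;  // Q0.30
  const std::int32_t rounded = product + (1 << 14);               // round
  return static_cast<std::int16_t>(rounded >> 15);                // to Q0.15
}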
-inline void Mul(const int16* input1_data, const Dims<4>& input1_dims,
- const int16* input2_data, const Dims<4>& input2_dims,
- int32 output_offset, int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims) {
+inline void Mul(const ArithmeticParams& params,
+ const RuntimeShape& input1_shape, const int16* input1_data,
+ const RuntimeShape& input2_shape, const int16* input2_data,
+ const RuntimeShape& output_shape, uint8* output_data) {
gemmlowp::ScopedProfilingLabel label("Mul/Int16Uint8");
+ int32 output_offset = params.output_offset;
+ int32 output_activation_min = params.quantized_activation_min;
+ int32 output_activation_max = params.quantized_activation_max;
TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
- const int flat_size = MatchingFlatSize(output_dims, input1_dims, input2_dims);
+ const int flat_size =
+ MatchingFlatSize(input1_shape, input2_shape, output_shape);
for (int i = 0; i < flat_size; i++) {
// F0 uses 0 integer bits, range [-1, 1].
@@ -1431,35 +1150,32 @@ inline void Mul(const int16* input1_data, const Dims<4>& input1_dims,
}
}
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims,
- int32 input1_offset, const uint8* input2_data,
- const Dims<4>& input2_dims, int32 input2_offset,
- int32 output_offset, int32 output_multiplier,
- int output_shift, int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims) {
- BroadcastMul(input1_data, input1_dims, input1_offset, input2_data,
- input2_dims, input2_offset, output_offset, output_multiplier,
- output_shift, output_activation_min, output_activation_max,
- output_data, output_dims);
-}
-
// TODO(jiawen): We can implement BroadcastDiv on buffers of arbitrary
// dimensionality if the runtime code does a single loop over one dimension
// that handles broadcasting as the base case. The code generator would then
// generate max(D1, D2) nested for loops.
template <typename T>
-void BroadcastDiv(const T* input1_data, const Dims<4>& input1_dims,
- const T* input2_data, const Dims<4>& input2_dims,
- T output_activation_min, T output_activation_max,
- T* output_data, const Dims<4>& output_dims) {
- gemmlowp::ScopedProfilingLabel label("BroadcastDiv");
+void BroadcastDiv4DSlow(const ArithmeticParams& params,
+ const RuntimeShape& unextended_input1_shape,
+ const T* input1_data,
+ const RuntimeShape& unextended_input2_shape,
+ const T* input2_data,
+ const RuntimeShape& unextended_output_shape,
+ T* output_data) {
+ T output_activation_min;
+ T output_activation_max;
+ GetActivationParams(params, &output_activation_min, &output_activation_max);
+
+ TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+ const RuntimeShape output_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_shape);
NdArrayDesc<4> desc1;
NdArrayDesc<4> desc2;
- NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
+ NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
+ unextended_input2_shape, &desc1, &desc2);
// In Tensorflow, the dimensions are canonically named (batch_number, row,
// col, channel), with extents (batches, height, width, depth), with the
@@ -1472,14 +1188,14 @@ void BroadcastDiv(const T* input1_data, const Dims<4>& input1_dims,
// We name our variables by their Tensorflow convention, but generate C code
// nesting loops such that the innermost loop has the smallest stride for
// the best cache behavior.
- for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
- for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
- for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
- for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
- output_data[Offset(output_dims, c, x, y, b)] =
+ for (int b = 0; b < output_shape.Dims(0); ++b) {
+ for (int y = 0; y < output_shape.Dims(1); ++y) {
+ for (int x = 0; x < output_shape.Dims(2); ++x) {
+ for (int c = 0; c < output_shape.Dims(3); ++c) {
+ output_data[Offset(output_shape, b, y, x, c)] =
ActivationFunctionWithMinMax(
- input1_data[SubscriptToIndex(desc1, c, x, y, b)] /
- input2_data[SubscriptToIndex(desc2, c, x, y, b)],
+ input1_data[SubscriptToIndex(desc1, b, y, x, c)] /
+ input2_data[SubscriptToIndex(desc2, b, y, x, c)],
output_activation_min, output_activation_max);
}
}
@@ -1487,11 +1203,17 @@ void BroadcastDiv(const T* input1_data, const Dims<4>& input1_dims,
}
}
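As a minimal standalone illustration of the broadcasting rule encoded by NdArrayDescsForElementwiseBroadcast above — a dimension of extent 1 reuses index 0, which the zero-stride descriptors achieve implicitly — the sketch below divides a [2 x 3] array by a [1 x 3] array. The arrays and fixed loop bounds are invented for illustration and none of the TFLite helpers are used:

#include <cstdio>

// Minimal sketch: divide a [2 x 3] input by a [1 x 3] input. The second
// input has extent 1 along the row dimension, so its row index is pinned
// to 0 -- the same effect SubscriptToIndex gets from a zero stride.
int main() {
  const float input1[2][3] = {{2.f, 4.f, 6.f}, {8.f, 10.f, 12.f}};
  const float input2[1][3] = {{2.f, 4.f, 6.f}};
  float output[2][3];
  for (int r = 0; r < 2; ++r) {
    for (int c = 0; c < 3; ++c) {
      const int r2 = 0;  // broadcast: input2's row extent is 1.
      output[r][c] = input1[r][c] / input2[r2][c];
    }
  }
  for (int r = 0; r < 2; ++r) {
    printf("%.1f %.1f %.1f\n", output[r][0], output[r][1], output[r][2]);
  }
  return 0;  // prints: 1.0 1.0 1.0 / 4.0 2.5 2.0
}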
-inline void Div(const float* input1_data, const Dims<4>& input1_dims,
- const float* input2_data, const Dims<4>& input2_dims,
- float output_activation_min, float output_activation_max,
- float* output_data, const Dims<4>& output_dims) {
- const int flat_size = MatchingFlatSize(input1_dims, input2_dims, output_dims);
+template <typename T>
+inline void Div(const ArithmeticParams& params,
+ const RuntimeShape& input1_shape, const T* input1_data,
+ const RuntimeShape& input2_shape, const T* input2_data,
+ const RuntimeShape& output_shape, T* output_data) {
+ T output_activation_min;
+ T output_activation_max;
+ GetActivationParams(params, &output_activation_min, &output_activation_max);
+
+ const int flat_size =
+ MatchingFlatSize(input1_shape, input2_shape, output_shape);
for (int i = 0; i < flat_size; ++i) {
output_data[i] = ActivationFunctionWithMinMax(
input1_data[i] / input2_data[i], output_activation_min,
@@ -1499,15 +1221,35 @@ inline void Div(const float* input1_data, const Dims<4>& input1_dims,
}
}
-inline void Sub(const float* input1_data, const Dims<4>& input1_dims,
- const float* input2_data, const Dims<4>& input2_dims,
- float output_activation_min, float output_activation_max,
- float* output_data, const Dims<4>& output_dims) {
- const int flat_size = MatchingFlatSize(input1_dims, input2_dims, output_dims);
+inline void SubNonBroadcast(const ArithmeticParams& params,
+ const RuntimeShape& input1_shape,
+ const float* input1_data,
+ const RuntimeShape& input2_shape,
+ const float* input2_data,
+ const RuntimeShape& output_shape,
+ float* output_data) {
+ const int flat_size =
+ MatchingFlatSize(input1_shape, input2_shape, output_shape);
for (int i = 0; i < flat_size; ++i) {
output_data[i] = ActivationFunctionWithMinMax(
- input1_data[i] - input2_data[i], output_activation_min,
- output_activation_max);
+ input1_data[i] - input2_data[i], params.float_activation_min,
+ params.float_activation_max);
+ }
+}
+
+inline void SubNonBroadcast(const ArithmeticParams& params,
+ const RuntimeShape& input1_shape,
+ const int32* input1_data,
+ const RuntimeShape& input2_shape,
+ const int32* input2_data,
+ const RuntimeShape& output_shape,
+ int32* output_data) {
+ const int flat_size =
+ MatchingFlatSize(input1_shape, input2_shape, output_shape);
+ for (int i = 0; i < flat_size; ++i) {
+ output_data[i] = ActivationFunctionWithMinMax(
+ input1_data[i] - input2_data[i], params.quantized_activation_min,
+ params.quantized_activation_max);
}
}
@@ -1515,16 +1257,24 @@ inline void Sub(const float* input1_data, const Dims<4>& input1_dims,
// dimensionality if the runtime code does a single loop over one dimension
// that handles broadcasting as the base case. The code generator would then
// generate max(D1, D2) nested for loops.
-template <typename T>
-void BroadcastSub(const T* input1_data, const Dims<4>& input1_dims,
- const T* input2_data, const Dims<4>& input2_dims,
- T output_activation_min, T output_activation_max,
- T* output_data, const Dims<4>& output_dims) {
- gemmlowp::ScopedProfilingLabel label("BroadcastSub");
-
+// TODO(benoitjacob): BroadcastSub is intentionally duplicated from
+// reference_ops.h. Once an optimized version is implemented and NdArrayDesc<T>
+// is no longer referenced in this file, move NdArrayDesc<T> from types.h to
+// reference_ops.h.
+inline void BroadcastSub4DSlow(const ArithmeticParams& params,
+ const RuntimeShape& input1_shape,
+ const float* input1_data,
+ const RuntimeShape& input2_shape,
+ const float* input2_data,
+ const RuntimeShape& output_shape,
+ float* output_data) {
+ gemmlowp::ScopedProfilingLabel label("BroadcastSub4DSlow/float");
NdArrayDesc<4> desc1;
NdArrayDesc<4> desc2;
- NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
+ NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
+ &desc2);
+ const RuntimeShape extended_output_shape =
+ RuntimeShape::ExtendedShape(4, output_shape);
// In Tensorflow, the dimensions are canonically named (batch_number, row,
// col, channel), with extents (batches, height, width, depth), with the
@@ -1537,36 +1287,35 @@ void BroadcastSub(const T* input1_data, const Dims<4>& input1_dims,
// We name our variables by their Tensorflow convention, but generate C code
// nesting loops such that the innermost loop has the smallest stride for the
// best cache behavior.
- for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
- for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
- for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
- for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
- output_data[Offset(output_dims, c, x, y, b)] =
+ for (int b = 0; b < extended_output_shape.Dims(0); ++b) {
+ for (int y = 0; y < extended_output_shape.Dims(1); ++y) {
+ for (int x = 0; x < extended_output_shape.Dims(2); ++x) {
+ for (int c = 0; c < extended_output_shape.Dims(3); ++c) {
+ output_data[Offset(extended_output_shape, b, y, x, c)] =
ActivationFunctionWithMinMax(
- input1_data[SubscriptToIndex(desc1, c, x, y, b)] -
- input2_data[SubscriptToIndex(desc2, c, x, y, b)],
- output_activation_min, output_activation_max);
+ input1_data[SubscriptToIndex(desc1, b, y, x, c)] -
+ input2_data[SubscriptToIndex(desc2, b, y, x, c)],
+ params.float_activation_min, params.float_activation_max);
}
}
}
}
}
-inline void BroadcastSub(int left_shift, const uint8* input1_data,
- const Dims<4>& input1_dims, int32 input1_offset,
- int32 input1_multiplier, int input1_shift,
- const uint8* input2_data, const Dims<4>& input2_dims,
- int32 input2_offset, int32 input2_multiplier,
- int input2_shift, int32 output_offset,
- int32 output_multiplier, int output_shift,
- int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims) {
- gemmlowp::ScopedProfilingLabel label("BroadcastSub/8bit");
-
+inline void BroadcastSub4DSlow(const ArithmeticParams& params,
+ const RuntimeShape& input1_shape,
+ const uint8* input1_data,
+ const RuntimeShape& input2_shape,
+ const uint8* input2_data,
+ const RuntimeShape& output_shape,
+ uint8* output_data) {
+ gemmlowp::ScopedProfilingLabel label("BroadcastSub4DSlow/uint8");
NdArrayDesc<4> desc1;
NdArrayDesc<4> desc2;
- NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
+ NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
+ &desc2);
+ const RuntimeShape extended_output_shape =
+ RuntimeShape::ExtendedShape(4, output_shape);
// In Tensorflow, the dimensions are canonically named (batch_number, row,
// col, channel), with extents (batches, height, width, depth), with the
@@ -1579,31 +1328,37 @@ inline void BroadcastSub(int left_shift, const uint8* input1_data,
// We name our variables by their Tensorflow convention, but generate C code
// nesting loops such that the innermost loop has the smallest stride for the
// best cache behavior.
- for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
- for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
- for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
- for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
+ for (int b = 0; b < extended_output_shape.Dims(0); ++b) {
+ for (int y = 0; y < extended_output_shape.Dims(1); ++y) {
+ for (int x = 0; x < extended_output_shape.Dims(2); ++x) {
+ for (int c = 0; c < extended_output_shape.Dims(3); ++c) {
const int32 input1_val =
- input1_offset + input1_data[SubscriptToIndex(desc1, c, x, y, b)];
+ params.input1_offset +
+ input1_data[SubscriptToIndex(desc1, b, y, x, c)];
const int32 input2_val =
- input2_offset + input2_data[SubscriptToIndex(desc2, c, x, y, b)];
- const int32 shifted_input1_val = input1_val * (1 << left_shift);
- const int32 shifted_input2_val = input2_val * (1 << left_shift);
+ params.input2_offset +
+ input2_data[SubscriptToIndex(desc2, b, y, x, c)];
+ const int32 shifted_input1_val =
+ input1_val * (1 << params.left_shift);
+ const int32 shifted_input2_val =
+ input2_val * (1 << params.left_shift);
const int32 scaled_input1_val =
- MultiplyByQuantizedMultiplierSmallerThanOne(
- shifted_input1_val, input1_multiplier, input1_shift);
+ MultiplyByQuantizedMultiplierSmallerThanOneExp(
+ shifted_input1_val, params.input1_multiplier,
+ params.input1_shift);
const int32 scaled_input2_val =
- MultiplyByQuantizedMultiplierSmallerThanOne(
- shifted_input2_val, input2_multiplier, input2_shift);
+ MultiplyByQuantizedMultiplierSmallerThanOneExp(
+ shifted_input2_val, params.input2_multiplier,
+ params.input2_shift);
const int32 raw_sub = scaled_input1_val - scaled_input2_val;
const int32 raw_output =
- MultiplyByQuantizedMultiplierSmallerThanOne(
- raw_sub, output_multiplier, output_shift) +
- output_offset;
+ MultiplyByQuantizedMultiplierSmallerThanOneExp(
+ raw_sub, params.output_multiplier, params.output_shift) +
+ params.output_offset;
const int32 clamped_output =
- std::min(output_activation_max,
- std::max(output_activation_min, raw_output));
- output_data[Offset(output_dims, c, x, y, b)] =
+ std::min(params.quantized_activation_max,
+ std::max(params.quantized_activation_min, raw_output));
+ output_data[Offset(extended_output_shape, b, y, x, c)] =
static_cast<uint8>(clamped_output);
}
}
@@ -1611,31 +1366,193 @@ inline void BroadcastSub(int left_shift, const uint8* input1_data,
}
}
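In real-number terms, the fixed-point pipeline above (left shift, per-input quantized multipliers, subtract, output rescale and offset, clamp) approximates: dequantize each input with its scale and zero point, subtract, then requantize with the output scale and zero point. A float reference of that end-to-end behaviour, with made-up scales and zero points rather than values from any real model, is sketched below:

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

// Float reference for quantized subtraction: dequantize, subtract,
// requantize, clamp to [0, 255]. All scales and zero points are made up.
int main() {
  const uint8_t q1 = 150, q2 = 100;
  const float scale1 = 0.02f, scale2 = 0.03f, output_scale = 0.05f;
  const int zero1 = 128, zero2 = 128, output_zero = 128;
  const float real = (q1 - zero1) * scale1 - (q2 - zero2) * scale2;  // 1.28
  const int32_t q_out =
      static_cast<int32_t>(std::round(real / output_scale)) + output_zero;
  const uint8_t clamped =
      static_cast<uint8_t>(std::min(255, std::max(0, q_out)));
  printf("%d\n", clamped);  // prints: 154
  return 0;
}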
-template <FusedActivationFunctionType Ac, typename Scalar>
-void Concatenation(int concat_dim, const Scalar* const* input_data,
- const Dims<4>* const* input_dims, int inputs_count,
- Scalar* output_data, const Dims<4>& output_dims) {
- TFLITE_DCHECK_GT(inputs_count, 1);
- int concat_size = 0;
+inline void BroadcastSub4DSlow(const ArithmeticParams& params,
+ const RuntimeShape& input1_shape,
+ const int32* input1_data,
+ const RuntimeShape& input2_shape,
+ const int32* input2_data,
+ const RuntimeShape& output_shape,
+ int32* output_data) {
+ gemmlowp::ScopedProfilingLabel label("BroadcastSub4DSlow/int32");
+ NdArrayDesc<4> desc1;
+ NdArrayDesc<4> desc2;
+ NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
+ &desc2);
+ const RuntimeShape extended_output_shape =
+ RuntimeShape::ExtendedShape(4, output_shape);
+
+ // In Tensorflow, the dimensions are canonically named (batch_number, row,
+ // col, channel), with extents (batches, height, width, depth), with the
+ // trailing dimension changing most rapidly (channels has the smallest stride,
+ // typically 1 element).
+ //
+ // In generated C code, we store arrays with the dimensions reversed. The
+ // first dimension has smallest stride.
+ //
+ // We name our variables by their Tensorflow convention, but generate C code
+ // nesting loops such that the innermost loop has the smallest stride for the
+ // best cache behavior.
+ for (int b = 0; b < extended_output_shape.Dims(0); ++b) {
+ for (int y = 0; y < extended_output_shape.Dims(1); ++y) {
+ for (int x = 0; x < extended_output_shape.Dims(2); ++x) {
+ for (int c = 0; c < extended_output_shape.Dims(3); ++c) {
+ output_data[Offset(extended_output_shape, b, y, x, c)] =
+ ActivationFunctionWithMinMax(
+ input1_data[SubscriptToIndex(desc1, b, y, x, c)] -
+ input2_data[SubscriptToIndex(desc2, b, y, x, c)],
+ params.quantized_activation_min,
+ params.quantized_activation_max);
+ }
+ }
+ }
+ }
+}
+
+template <typename T>
+void BroadcastSub4DSlow(const ArithmeticParams& params,
+ const RuntimeShape& input1_shape, const T* input1_data,
+ const RuntimeShape& input2_shape, const T* input2_data,
+ const RuntimeShape& output_shape, T* output_data) {
+ gemmlowp::ScopedProfilingLabel label("BroadcastSub4DSlow/templated");
+ NdArrayDesc<4> desc1;
+ NdArrayDesc<4> desc2;
+ NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
+ &desc2);
+ const RuntimeShape extended_output_shape =
+ RuntimeShape::ExtendedShape(4, output_shape);
+
+ // In Tensorflow, the dimensions are canonically named (batch_number, row,
+ // col, channel), with extents (batches, height, width, depth), with the
+ // trailing dimension changing most rapidly (channels has the smallest stride,
+ // typically 1 element).
+ //
+ // In generated C code, we store arrays with the dimensions reversed. The
+ // first dimension has smallest stride.
+ //
+ // We name our variables by their Tensorflow convention, but generate C code
+ // nesting loops such that the innermost loop has the smallest stride for the
+ // best cache behavior.
+ for (int b = 0; b < extended_output_shape.Dims(0); ++b) {
+ for (int y = 0; y < extended_output_shape.Dims(1); ++y) {
+ for (int x = 0; x < extended_output_shape.Dims(2); ++x) {
+ for (int c = 0; c < extended_output_shape.Dims(3); ++c) {
+ output_data[Offset(extended_output_shape, b, y, x, c)] =
+ ActivationFunctionWithMinMax(
+ input1_data[SubscriptToIndex(desc1, b, y, x, c)] -
+ input2_data[SubscriptToIndex(desc2, b, y, x, c)],
+ params.quantized_activation_min,
+ params.quantized_activation_max);
+ }
+ }
+ }
+ }
+}
+
+template <typename T>
+void Sub(const ArithmeticParams& params, const RuntimeShape& input1_shape,
+ const T* input1_data, const RuntimeShape& input2_shape,
+ const T* input2_data, const RuntimeShape& output_shape,
+ T* output_data) {
+ NdArrayDesc<4> desc1;
+ NdArrayDesc<4> desc2;
+ NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
+ &desc2);
+ const RuntimeShape extended_output_shape =
+ RuntimeShape::ExtendedShape(4, output_shape);
+
+ // In Tensorflow, the dimensions are canonically named (batch_number, row,
+ // col, channel), with extents (batches, height, width, depth), with the
+ // trailing dimension changing most rapidly (channels has the smallest stride,
+ // typically 1 element).
+ //
+ // In generated C code, we store arrays with the dimensions reversed. The
+ // first dimension has smallest stride.
+ //
+ // We name our variables by their Tensorflow convention, but generate C code
+ // nesting loops such that the innermost loop has the smallest stride for the
+ // best cache behavior.
+ for (int b = 0; b < extended_output_shape.Dims(0); ++b) {
+ for (int y = 0; y < extended_output_shape.Dims(1); ++y) {
+ for (int x = 0; x < extended_output_shape.Dims(2); ++x) {
+ for (int c = 0; c < extended_output_shape.Dims(3); ++c) {
+ output_data[Offset(extended_output_shape, b, y, x, c)] =
+ input1_data[SubscriptToIndex(desc1, b, y, x, c)] -
+ input2_data[SubscriptToIndex(desc2, b, y, x, c)];
+ }
+ }
+ }
+ }
+}
+
+inline void SubWithActivation(const ArithmeticParams& params,
+ const RuntimeShape& input1_shape,
+ const int32* input1_data,
+ const RuntimeShape& input2_shape,
+ const int32* input2_data,
+ const RuntimeShape& output_shape,
+ int32* output_data) {
+  const int flat_size =
+      MatchingFlatSize(input1_shape, input2_shape, output_shape);
+ for (int i = 0; i < flat_size; ++i) {
+ output_data[i] = ActivationFunctionWithMinMax(
+ input1_data[i] - input2_data[i], params.quantized_activation_min,
+ params.quantized_activation_max);
+ }
+}
+
+inline void SubWithActivation(const ArithmeticParams& params,
+ const RuntimeShape& input1_shape,
+ const float* input1_data,
+ const RuntimeShape& input2_shape,
+ const float* input2_data,
+ const RuntimeShape& output_shape,
+ float* output_data) {
+  const int flat_size =
+      MatchingFlatSize(input1_shape, input2_shape, output_shape);
+ for (int i = 0; i < flat_size; ++i) {
+ output_data[i] = ActivationFunctionWithMinMax(
+ input1_data[i] - input2_data[i], params.float_activation_min,
+ params.float_activation_max);
+ }
+}
+
+template <typename Scalar>
+inline void Concatenation(const ConcatenationParams& params,
+ const RuntimeShape* const* input_shapes,
+ const Scalar* const* input_data,
+ const RuntimeShape& output_shape,
+ Scalar* output_data) {
+ int axis = params.axis;
+ int inputs_count = params.inputs_count;
+ const int concat_dimensions = output_shape.DimensionsCount();
+ TFLITE_DCHECK_LT(axis, concat_dimensions);
+
+ int64_t concat_size = 0;
for (int i = 0; i < inputs_count; i++) {
- for (int j = 0; j < 4; j++) {
- if (j != concat_dim) {
- MatchingArraySize(*input_dims[i], j, output_dims, j);
+ TFLITE_DCHECK_EQ(input_shapes[i]->DimensionsCount(), concat_dimensions);
+ for (int j = 0; j < concat_dimensions; j++) {
+ if (j != axis) {
+ MatchingDim(*input_shapes[i], j, output_shape, j);
}
}
- concat_size += ArraySize(*input_dims[i], concat_dim);
+ concat_size += input_shapes[i]->Dims(axis);
}
- TFLITE_DCHECK_EQ(concat_size, ArraySize(output_dims, concat_dim));
- TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone);
- int outer_size = 1;
- for (int i = concat_dim + 1; i < 4; i++) {
- outer_size *= output_dims.sizes[i];
+ TFLITE_DCHECK_EQ(concat_size, output_shape.Dims(axis));
+ int64_t outer_size = 1;
+ for (int i = 0; i < axis; ++i) {
+ outer_size *= output_shape.Dims(i);
}
+ // For all input arrays,
+ // FlatSize() = outer_size * Dims(axis) * base_inner_size;
+ int64_t base_inner_size = 1;
+ for (int i = axis + 1; i < concat_dimensions; ++i) {
+ base_inner_size *= output_shape.Dims(i);
+ }
+
Scalar* output_ptr = output_data;
for (int k = 0; k < outer_size; k++) {
for (int i = 0; i < inputs_count; ++i) {
- const int copy_size =
- input_dims[i]->sizes[concat_dim] * input_dims[i]->strides[concat_dim];
+ const int copy_size = input_shapes[i]->Dims(axis) * base_inner_size;
memcpy(output_ptr, input_data[i] + k * copy_size,
copy_size * sizeof(Scalar));
output_ptr += copy_size;
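To make the FlatSize() = outer_size * Dims(axis) * base_inner_size decomposition concrete, the standalone sketch below walks through an arbitrarily chosen output shape [2, 3, 7, 5] with axis = 2, built from two inputs whose extents along the axis are 3 and 4; the numbers are illustrative only:

#include <cstdio>

// Worked example of the decomposition used by Concatenation above, for an
// output of shape [2, 3, 7, 5] and axis = 2 (all sizes made up).
int main() {
  const int output_dims[4] = {2, 3, 7, 5};
  const int axis = 2;
  int outer_size = 1;
  for (int i = 0; i < axis; ++i) outer_size *= output_dims[i];  // 2 * 3 = 6
  int base_inner_size = 1;
  for (int i = axis + 1; i < 4; ++i) base_inner_size *= output_dims[i];  // 5
  // Two inputs with Dims(axis) = 3 and 4 contribute copy sizes of 15 and 20
  // elements per outer step; 6 * (15 + 20) == 2 * 3 * 7 * 5 == 210.
  printf("outer=%d inner=%d total=%d\n", outer_size, base_inner_size,
         outer_size * (3 + 4) * base_inner_size);  // outer=6 inner=5 total=210
  return 0;
}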
@@ -1647,37 +1564,48 @@ void Concatenation(int concat_dim, const Scalar* const* input_data,
// TODO(prabhumk): The quantized implementation of concatenation isn't fully
// quantized as it takes scale as a floating point value. This should be fixed
// when optimizing this routine further.
-inline void Concatenation(int concat_dim, const uint8* const* input_data,
- const Dims<4>* const* input_dims,
- const int32* input_zeropoint,
- const float* input_scale, int inputs_count,
- uint8* output_data, const Dims<4>& output_dims,
- const int32 output_zeropoint,
- const float output_scale) {
- // The arguments input_zeropoint and input_scale are expected to be an array
- // that have the quantization parameters for all the inputs to the concat
- // operator.
- TFLITE_DCHECK_GT(inputs_count, 1);
+inline void ConcatenationWithScaling(const ConcatenationParams& params,
+ const RuntimeShape* const* input_shapes,
+ const uint8* const* input_data,
+ const RuntimeShape& output_shape,
+ uint8* output_data) {
+ int axis = params.axis;
+ const int32* input_zeropoint = params.input_zeropoint;
+ const float* input_scale = params.input_scale;
+ int inputs_count = params.inputs_count;
+ const int32 output_zeropoint = params.output_zeropoint;
+ const float output_scale = params.output_scale;
+
+ const int concat_dimensions = output_shape.DimensionsCount();
+ TFLITE_DCHECK_LT(axis, concat_dimensions);
+
int64_t concat_size = 0;
for (int i = 0; i < inputs_count; i++) {
- for (int j = 0; j < 4; j++) {
- if (j != concat_dim) {
- MatchingArraySize(*input_dims[i], j, output_dims, j);
+ TFLITE_DCHECK_EQ(input_shapes[i]->DimensionsCount(), concat_dimensions);
+ for (int j = 0; j < concat_dimensions; j++) {
+ if (j != axis) {
+ MatchingDim(*input_shapes[i], j, output_shape, j);
}
}
- concat_size += ArraySize(*input_dims[i], concat_dim);
+ concat_size += input_shapes[i]->Dims(axis);
}
- TFLITE_DCHECK_EQ(concat_size, ArraySize(output_dims, concat_dim));
+ TFLITE_DCHECK_EQ(concat_size, output_shape.Dims(axis));
int64_t outer_size = 1;
- for (int i = concat_dim + 1; i < 4; i++) {
- outer_size *= output_dims.sizes[i];
+ for (int i = 0; i < axis; ++i) {
+ outer_size *= output_shape.Dims(i);
}
+ // For all input arrays,
+ // FlatSize() = outer_size * Dims(axis) * base_inner_size;
+ int64_t base_inner_size = 1;
+ for (int i = axis + 1; i < concat_dimensions; ++i) {
+ base_inner_size *= output_shape.Dims(i);
+ }
+
const float inverse_output_scale = 1.f / output_scale;
uint8* output_ptr = output_data;
for (int k = 0; k < outer_size; k++) {
for (int i = 0; i < inputs_count; ++i) {
- const int copy_size =
- input_dims[i]->sizes[concat_dim] * input_dims[i]->strides[concat_dim];
+ const int copy_size = input_shapes[i]->Dims(axis) * base_inner_size;
const uint8* input_ptr = input_data[i] + k * copy_size;
if (input_zeropoint[i] == output_zeropoint &&
input_scale[i] == output_scale) {
@@ -1698,64 +1626,203 @@ inline void Concatenation(int concat_dim, const uint8* const* input_data,
}
}
-template <FusedActivationFunctionType Ac, typename Scalar>
-void DepthConcatenation(const Scalar* const* input_data,
- const Dims<4>* const* input_dims, int inputs_count,
- Scalar* output_data, const Dims<4>& output_dims) {
- Concatenation<Ac, Scalar>(0, input_data, input_dims, inputs_count,
- output_data, output_dims);
+template <typename Scalar>
+void Pack(const PackParams& params, const RuntimeShape* const* input_shapes,
+ const Scalar* const* input_data, const RuntimeShape& output_shape,
+ Scalar* output_data) {
+ const int dimensions = output_shape.DimensionsCount();
+ int axis = params.axis;
+ int inputs_count = params.inputs_count;
+
+ int outer_size = 1;
+ for (int i = 0; i < axis; i++) {
+ outer_size *= output_shape.Dims(i);
+ }
+ int copy_size = 1;
+ for (int i = params.axis + 1; i < dimensions; i++) {
+ copy_size *= output_shape.Dims(i);
+ }
+ TFLITE_DCHECK_EQ((**input_shapes).FlatSize(), copy_size * outer_size);
+
+ for (int i = 0; i < inputs_count; ++i) {
+ for (int k = 0; k < outer_size; k++) {
+ const Scalar* input_ptr = input_data[i] + copy_size * k;
+ int loc = k * inputs_count * copy_size + i * copy_size;
+ memcpy(output_data + loc, input_ptr, copy_size * sizeof(Scalar));
+ }
+ }
+}
+
+template <typename Scalar>
+void Unpack(const UnpackParams& params, const RuntimeShape& input_shape,
+ const Scalar* input_data, const RuntimeShape& output_shape,
+ Scalar* const* output_datas) {
+ const int dimensions = input_shape.DimensionsCount();
+ const int outputs_count = params.num_split;
+
+ int outer_size = 1;
+ for (int i = 0; i < params.axis; i++) {
+ outer_size *= input_shape.Dims(i);
+ }
+ int copy_size = 1;
+ for (int i = params.axis + 1; i < dimensions; i++) {
+ copy_size *= input_shape.Dims(i);
+ }
+ TFLITE_DCHECK_EQ(output_shape.FlatSize(), copy_size * outer_size);
+
+ for (int i = 0; i < outputs_count; ++i) {
+ for (int k = 0; k < outer_size; k++) {
+ Scalar* output_ptr = output_datas[i] + copy_size * k;
+ int loc = k * outputs_count * copy_size + i * copy_size;
+ memcpy(output_ptr, input_data + loc, copy_size * sizeof(Scalar));
+ }
+ }
+}
+
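+// The loc arithmetic in Pack above interleaves the inputs along the packed
+// axis; Unpack reverses the same copies. A minimal standalone sketch (kept
+// here as a comment, not compiled as part of this header) with two [2, 3]
+// inputs packed along axis 1 -- plain arrays and memcpy, shapes invented
+// purely for illustration:
+//
+//   const float in0[2][3] = {{1, 2, 3}, {4, 5, 6}};
+//   const float in1[2][3] = {{7, 8, 9}, {10, 11, 12}};
+//   const float* inputs[2] = {&in0[0][0], &in1[0][0]};
+//   const int inputs_count = 2, outer_size = 2, copy_size = 3;
+//   float packed[2 * 2 * 3];
+//   for (int i = 0; i < inputs_count; ++i) {
+//     for (int k = 0; k < outer_size; ++k) {
+//       const int loc = k * inputs_count * copy_size + i * copy_size;
+//       memcpy(packed + loc, inputs[i] + k * copy_size,
+//              copy_size * sizeof(float));
+//     }
+//   }
+//   // packed is now row-interleaved: 1 2 3 | 7 8 9 | 4 5 6 | 10 11 12,
+//   // so packed[3..5] == {7, 8, 9}.
+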
+template <typename Scalar>
+void PackWithScaling(const PackParams& params,
+ const RuntimeShape* const* input_shapes,
+ const uint8* const* input_data,
+ const RuntimeShape& output_shape, uint8* output_data) {
+ const int dimensions = output_shape.DimensionsCount();
+ int axis = params.axis;
+ const int32* input_zeropoint = params.input_zeropoint;
+ const float* input_scale = params.input_scale;
+ int inputs_count = params.inputs_count;
+ const int32 output_zeropoint = params.output_zeropoint;
+ const float output_scale = params.output_scale;
+
+ int outer_size = 1;
+ for (int i = 0; i < axis; i++) {
+ outer_size *= output_shape.Dims(i);
+ }
+ int copy_size = 1;
+ for (int i = axis + 1; i < dimensions; i++) {
+ copy_size *= output_shape.Dims(i);
+ }
+ TFLITE_DCHECK_EQ((**input_shapes).FlatSize(), copy_size * outer_size);
+
+ Scalar* output_ptr = output_data;
+ const float inverse_output_scale = 1.f / output_scale;
+ for (int k = 0; k < outer_size; k++) {
+ for (int i = 0; i < inputs_count; ++i) {
+ if (input_zeropoint[i] == output_zeropoint &&
+ input_scale[i] == output_scale) {
+ memcpy(output_ptr, input_data[i] + k * copy_size,
+ copy_size * sizeof(Scalar));
+ } else {
+ assert(false);
+ const float scale = input_scale[i] * inverse_output_scale;
+ const float bias = -input_zeropoint[i] * scale;
+ auto input_ptr = input_data[i];
+ for (int j = 0; j < copy_size; ++j) {
+ const int32_t value =
+ static_cast<int32_t>(round(input_ptr[j] * scale + bias)) +
+ output_zeropoint;
+ output_ptr[j] =
+ static_cast<uint8_t>(std::max(std::min(255, value), 0));
+ }
+ }
+ output_ptr += copy_size;
+ }
+ }
+}
+
+template <typename Scalar>
+void DepthConcatenation(const ConcatenationParams& params,
+ const RuntimeShape* const* input_shapes,
+ const Scalar* const* input_data,
+ const RuntimeShape& output_shape, Scalar* output_data) {
+ auto params_copy = params;
+ params_copy.axis = 3;
+ Concatenation(params_copy, input_shapes, input_data, output_shape,
+ output_data);
}
-inline void LstmCell(const float* input_data, const Dims<4>& input_dims,
- const float* prev_activ_data,
- const Dims<4>& prev_activ_dims, const float* weights_data,
- const Dims<4>& weights_dims, const float* bias_data,
- const Dims<4>& bias_dims, const float* prev_state_data,
- const Dims<4>& prev_state_dims, float* output_state_data,
- const Dims<4>& output_state_dims, float* output_activ_data,
- const Dims<4>& output_activ_dims, float* concat_temp_data,
- const Dims<4>& concat_temp_dims, float* activ_temp_data,
- const Dims<4>& activ_temp_dims) {
+inline void LstmCell(
+ const LstmCellParams& params, const RuntimeShape& unextended_input_shape,
+ const float* input_data, const RuntimeShape& unextended_prev_activ_shape,
+ const float* prev_activ_data, const RuntimeShape& weights_shape,
+ const float* weights_data, const RuntimeShape& unextended_bias_shape,
+ const float* bias_data, const RuntimeShape& unextended_prev_state_shape,
+ const float* prev_state_data,
+ const RuntimeShape& unextended_output_state_shape, float* output_state_data,
+ const RuntimeShape& unextended_output_activ_shape, float* output_activ_data,
+ const RuntimeShape& unextended_concat_temp_shape, float* concat_temp_data,
+ const RuntimeShape& unextended_activ_temp_shape, float* activ_temp_data) {
+ TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_prev_activ_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_bias_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_prev_state_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_state_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_activ_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_concat_temp_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_activ_temp_shape.DimensionsCount(), 4);
+ const RuntimeShape input_shape =
+ RuntimeShape::ExtendedShape(4, unextended_input_shape);
+ const RuntimeShape prev_activ_shape =
+ RuntimeShape::ExtendedShape(4, unextended_prev_activ_shape);
+ const RuntimeShape bias_shape =
+ RuntimeShape::ExtendedShape(4, unextended_bias_shape);
+ const RuntimeShape prev_state_shape =
+ RuntimeShape::ExtendedShape(4, unextended_prev_state_shape);
+ const RuntimeShape output_state_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_state_shape);
+ const RuntimeShape output_activ_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_activ_shape);
+ const RuntimeShape concat_temp_shape =
+ RuntimeShape::ExtendedShape(4, unextended_concat_temp_shape);
+ const RuntimeShape activ_temp_shape =
+ RuntimeShape::ExtendedShape(4, unextended_activ_temp_shape);
+ TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);
+
+ const int weights_dim_count = weights_shape.DimensionsCount();
const int batches =
- MatchingArraySize(input_dims, 3, prev_activ_dims, 3, prev_state_dims, 3,
- output_state_dims, 3, output_activ_dims, 3);
+ MatchingDim(input_shape, 0, prev_activ_shape, 0, prev_state_shape, 0,
+ output_state_shape, 0, output_activ_shape, 0);
const int height =
- MatchingArraySize(input_dims, 2, prev_activ_dims, 2, prev_state_dims, 2,
- output_state_dims, 2, output_activ_dims, 2);
+ MatchingDim(input_shape, 1, prev_activ_shape, 1, prev_state_shape, 1,
+ output_state_shape, 1, output_activ_shape, 1);
const int width =
- MatchingArraySize(input_dims, 1, prev_activ_dims, 1, prev_state_dims, 1,
- output_state_dims, 1, output_activ_dims, 1);
- TFLITE_CHECK_EQ(ArraySize(weights_dims, 2), 1);
- TFLITE_CHECK_EQ(ArraySize(weights_dims, 3), 1);
- const int input_depth = ArraySize(input_dims, 0);
- const int prev_activ_depth = ArraySize(prev_activ_dims, 0);
+ MatchingDim(input_shape, 2, prev_activ_shape, 2, prev_state_shape, 2,
+ output_state_shape, 2, output_activ_shape, 2);
+ const int input_depth = input_shape.Dims(3);
+ const int prev_activ_depth = prev_activ_shape.Dims(3);
const int total_input_depth = prev_activ_depth + input_depth;
- TFLITE_CHECK_EQ(ArraySize(weights_dims, 0), total_input_depth);
- TFLITE_CHECK_EQ(MatchingArraySize(bias_dims, 1, bias_dims, 2, bias_dims, 3),
- 1);
+ TFLITE_DCHECK_EQ(weights_shape.Dims(weights_dim_count - 1),
+ total_input_depth);
+ TFLITE_DCHECK_EQ(FlatSizeSkipDim(bias_shape, 3), 1);
const int intern_activ_depth =
- MatchingArraySize(weights_dims, 1, bias_dims, 0);
- TFLITE_CHECK_EQ(intern_activ_depth % 4, 0);
+ MatchingDim(weights_shape, weights_dim_count - 2, bias_shape, 3);
+ TFLITE_DCHECK_EQ(weights_shape.FlatSize(),
+ intern_activ_depth * total_input_depth);
+ TFLITE_DCHECK_EQ(intern_activ_depth % 4, 0);
const int output_depth =
- MatchingArraySize(prev_state_dims, 0, prev_activ_dims, 0,
- output_state_dims, 0, output_activ_dims, 0);
- TFLITE_CHECK_EQ(output_depth, intern_activ_depth / 4);
+ MatchingDim(prev_state_shape, 3, prev_activ_shape, 3, output_state_shape,
+ 3, output_activ_shape, 3);
+ TFLITE_DCHECK_EQ(output_depth, intern_activ_depth / 4);
// Concatenate prev_activ and input data together
std::vector<float const*> concat_input_arrays_data;
- std::vector<Dims<4> const*> concat_input_arrays_dims;
+ std::vector<RuntimeShape const*> concat_input_arrays_shapes;
concat_input_arrays_data.push_back(input_data);
concat_input_arrays_data.push_back(prev_activ_data);
- concat_input_arrays_dims.push_back(&input_dims);
- concat_input_arrays_dims.push_back(&prev_activ_dims);
- Concatenation<FusedActivationFunctionType::kNone, float>(
- 0, &(concat_input_arrays_data[0]), &(concat_input_arrays_dims[0]),
- concat_input_arrays_data.size(), concat_temp_data, concat_temp_dims);
+ concat_input_arrays_shapes.push_back(&input_shape);
+ concat_input_arrays_shapes.push_back(&prev_activ_shape);
+ tflite::ConcatenationParams concat_params;
+ concat_params.axis = 3;
+ concat_params.inputs_count = concat_input_arrays_data.size();
+ Concatenation(concat_params, &(concat_input_arrays_shapes[0]),
+ &(concat_input_arrays_data[0]), concat_temp_shape,
+ concat_temp_data);
// Fully connected
- FullyConnected<FusedActivationFunctionType::kNone>(
- concat_temp_data, concat_temp_dims, weights_data, weights_dims, bias_data,
- bias_dims, activ_temp_data, activ_temp_dims);
+ tflite::FullyConnectedParams fc_params;
+ fc_params.float_activation_min = std::numeric_limits<float>::lowest();
+ fc_params.float_activation_max = std::numeric_limits<float>::max();
+ FullyConnected(fc_params, concat_temp_shape, concat_temp_data, weights_shape,
+ weights_data, bias_shape, bias_data, activ_temp_shape,
+ activ_temp_data);
// Memory state update (the LSTM "guts")
for (int b = 0; b < batches; ++b) {
@@ -1764,24 +1831,24 @@ inline void LstmCell(const float* input_data, const Dims<4>& input_dims,
for (int c = 0; c < output_depth; ++c) {
const float input_gate =
1.f /
- (1.f + std::exp(-activ_temp_data[Offset(
- activ_temp_dims, 0 * output_depth + c, w, h, b)]));
+ (1.f + std::exp(-activ_temp_data[Offset(activ_temp_shape, b, h, w,
+ 0 * output_depth + c)]));
const float new_input = std::tanh(activ_temp_data[Offset(
- activ_temp_dims, 1 * output_depth + c, w, h, b)]);
+ activ_temp_shape, b, h, w, 1 * output_depth + c)]);
const float forget_gate =
1.f /
- (1.f + std::exp(-activ_temp_data[Offset(
- activ_temp_dims, 2 * output_depth + c, w, h, b)]));
+ (1.f + std::exp(-activ_temp_data[Offset(activ_temp_shape, b, h, w,
+ 2 * output_depth + c)]));
const float output_gate =
1.f /
- (1.f + std::exp(-activ_temp_data[Offset(
- activ_temp_dims, 3 * output_depth + c, w, h, b)]));
+ (1.f + std::exp(-activ_temp_data[Offset(activ_temp_shape, b, h, w,
+ 3 * output_depth + c)]));
const float new_state =
input_gate * new_input +
forget_gate *
- prev_state_data[Offset(prev_state_dims, c, w, h, b)];
- output_state_data[Offset(output_state_dims, c, w, h, b)] = new_state;
- output_activ_data[Offset(output_activ_dims, c, w, h, b)] =
+ prev_state_data[Offset(prev_state_shape, b, h, w, c)];
+ output_state_data[Offset(output_state_shape, b, h, w, c)] = new_state;
+ output_activ_data[Offset(output_activ_shape, b, h, w, c)] =
output_gate * std::tanh(new_state);
}
}
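Per output channel, the loop above is the standard LSTM recurrence: sigmoid input, forget, and output gates, a tanh candidate, new_state = input_gate * new_input + forget_gate * prev_state, and output = output_gate * tanh(new_state). A standalone scalar sketch with made-up gate pre-activations, not taken from any real model:

#include <cmath>
#include <cstdio>

// One LSTM cell element, mirroring the float loop above: activ_temp holds
// the four gate pre-activations [input, new_input, forget, output] for a
// single output channel (values are made up).
int main() {
  const float activ_temp[4] = {0.5f, -0.2f, 1.0f, 0.3f};
  const float prev_state = 0.1f;
  auto sigmoid = [](float x) { return 1.f / (1.f + std::exp(-x)); };
  const float input_gate = sigmoid(activ_temp[0]);
  const float new_input = std::tanh(activ_temp[1]);
  const float forget_gate = sigmoid(activ_temp[2]);
  const float output_gate = sigmoid(activ_temp[3]);
  const float new_state = input_gate * new_input + forget_gate * prev_state;
  const float output_activ = output_gate * std::tanh(new_state);
  printf("state=%f activ=%f\n", new_state, output_activ);
  return 0;
}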
@@ -1874,52 +1941,90 @@ inline void LstmCell(const float* input_data, const Dims<4>& input_dims,
// aiming for 16-bit fixed-point quantization of these internal nodes here.
//
template <int StateIntegerBits>
-void LstmCell(const uint8* input_data_uint8, const Dims<4>& input_dims,
- const uint8* prev_activ_data_uint8,
- const Dims<4>& prev_activ_dims, const uint8* weights_data_uint8,
- const Dims<4>& weights_dims, const int32* bias_data_int32,
- const Dims<4>& bias_dims, const int16* prev_state_data_int16,
- const Dims<4>& prev_state_dims, int16* output_state_data_int16,
- const Dims<4>& output_state_dims, uint8* output_activ_data_uint8,
- const Dims<4>& output_activ_dims, uint8* concat_temp_data_uint8,
- const Dims<4>& concat_temp_dims, int16* activ_temp_data_int16,
- const Dims<4>& activ_temp_dims, int32 weights_zero_point,
- int32 accum_multiplier, int accum_shift,
- gemmlowp::GemmContext* gemm_context) {
+inline void LstmCell(
+ const LstmCellParams& params, const RuntimeShape& unextended_input_shape,
+ const uint8* input_data_uint8,
+ const RuntimeShape& unextended_prev_activ_shape,
+ const uint8* prev_activ_data_uint8, const RuntimeShape& weights_shape,
+ const uint8* weights_data_uint8, const RuntimeShape& unextended_bias_shape,
+ const int32* bias_data_int32,
+ const RuntimeShape& unextended_prev_state_shape,
+ const int16* prev_state_data_int16,
+ const RuntimeShape& unextended_output_state_shape,
+ int16* output_state_data_int16,
+ const RuntimeShape& unextended_output_activ_shape,
+ uint8* output_activ_data_uint8,
+ const RuntimeShape& unextended_concat_temp_shape,
+ uint8* concat_temp_data_uint8,
+ const RuntimeShape& unextended_activ_temp_shape,
+ int16* activ_temp_data_int16, gemmlowp::GemmContext* gemm_context) {
(void)gemm_context; // only used in optimized code.
+ int32 weights_zero_point = params.weights_zero_point;
+ int32 accum_multiplier = params.accum_multiplier;
+ int accum_shift = params.accum_shift;
+ TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_prev_activ_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_bias_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_prev_state_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_state_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_activ_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_concat_temp_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_activ_temp_shape.DimensionsCount(), 4);
+ const RuntimeShape input_shape =
+ RuntimeShape::ExtendedShape(4, unextended_input_shape);
+ const RuntimeShape prev_activ_shape =
+ RuntimeShape::ExtendedShape(4, unextended_prev_activ_shape);
+ const RuntimeShape bias_shape =
+ RuntimeShape::ExtendedShape(4, unextended_bias_shape);
+ const RuntimeShape prev_state_shape =
+ RuntimeShape::ExtendedShape(4, unextended_prev_state_shape);
+ const RuntimeShape output_state_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_state_shape);
+ const RuntimeShape output_activ_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_activ_shape);
+ const RuntimeShape concat_temp_shape =
+ RuntimeShape::ExtendedShape(4, unextended_concat_temp_shape);
+ const RuntimeShape activ_temp_shape =
+ RuntimeShape::ExtendedShape(4, unextended_activ_temp_shape);
+ TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);
// Gather dimensions information, and perform consistency checks.
- const int outer_size =
- MatchingFlatSizeSkipDim(input_dims, 0, prev_activ_dims, prev_state_dims,
- output_state_dims, output_activ_dims);
- TFLITE_CHECK_EQ(ArraySize(weights_dims, 2), 1);
- TFLITE_CHECK_EQ(ArraySize(weights_dims, 3), 1);
- const int input_depth = ArraySize(input_dims, 0);
- const int prev_activ_depth = ArraySize(prev_activ_dims, 0);
+ const int weights_dim_count = weights_shape.DimensionsCount();
+ const int outer_size = MatchingFlatSizeSkipDim(
+ input_shape, 3, prev_activ_shape, prev_state_shape, output_state_shape,
+ output_activ_shape);
+ const int input_depth = input_shape.Dims(3);
+ const int prev_activ_depth = prev_activ_shape.Dims(3);
const int total_input_depth = prev_activ_depth + input_depth;
- TFLITE_CHECK_EQ(ArraySize(weights_dims, 0), total_input_depth);
- TFLITE_CHECK_EQ(MatchingArraySize(bias_dims, 1, bias_dims, 2, bias_dims, 3),
- 1);
+ TFLITE_DCHECK_EQ(weights_shape.Dims(weights_dim_count - 1),
+ total_input_depth);
const int intern_activ_depth =
- MatchingArraySize(weights_dims, 1, bias_dims, 0);
- TFLITE_CHECK_EQ(intern_activ_depth % 4, 0);
+ MatchingDim(weights_shape, weights_dim_count - 2, bias_shape, 3);
+ TFLITE_DCHECK_EQ(weights_shape.FlatSize(),
+ intern_activ_depth * total_input_depth);
+ TFLITE_DCHECK_EQ(FlatSizeSkipDim(bias_shape, 3), 1);
+ TFLITE_DCHECK_EQ(intern_activ_depth % 4, 0);
const int output_depth =
- MatchingArraySize(prev_state_dims, 0, prev_activ_dims, 0,
- output_state_dims, 0, output_activ_dims, 0);
- TFLITE_CHECK_EQ(output_depth, intern_activ_depth / 4);
- const int fc_batches = FlatSizeSkipDim(activ_temp_dims, 0);
+ MatchingDim(prev_state_shape, 3, prev_activ_shape, 3, output_state_shape,
+ 3, output_activ_shape, 3);
+ TFLITE_DCHECK_EQ(output_depth, intern_activ_depth / 4);
+ const int fc_batches = FlatSizeSkipDim(activ_temp_shape, 3);
const int fc_output_depth =
- MatchingArraySize(weights_dims, 1, activ_temp_dims, 0);
- const int fc_accum_depth = ArraySize(weights_dims, 0);
- TFLITE_CHECK_EQ(fc_output_depth, 4 * output_depth);
+ MatchingDim(weights_shape, weights_dim_count - 2, activ_temp_shape, 3);
+ const int fc_accum_depth = total_input_depth;
+ TFLITE_DCHECK_EQ(fc_output_depth, 4 * output_depth);
// Depth-concatenate prev_activ and input data together.
uint8 const* concat_input_arrays_data[2] = {input_data_uint8,
prev_activ_data_uint8};
- Dims<4> const* concat_input_arrays_dims[2] = {&input_dims, &prev_activ_dims};
- Concatenation<FusedActivationFunctionType::kNone, uint8>(
- 0, concat_input_arrays_data, concat_input_arrays_dims, 2,
- concat_temp_data_uint8, concat_temp_dims);
+ const RuntimeShape* concat_input_arrays_shapes[2] = {&input_shape,
+ &prev_activ_shape};
+ tflite::ConcatenationParams concat_params;
+ concat_params.axis = 3;
+ concat_params.inputs_count = 2;
+ Concatenation(concat_params, concat_input_arrays_shapes,
+ concat_input_arrays_data, concat_temp_shape,
+ concat_temp_data_uint8);
// Implementation of the fully connected node inside the LSTM cell.
// The operands are 8-bit integers, the accumulators are internally 32bit
@@ -2026,110 +2131,81 @@ void LstmCell(const uint8* input_data_uint8, const Dims<4>& input_dims,
}
template <typename Scalar>
-void TensorFlowSplit(const Scalar* input_data, const Dims<4>& input_dims,
- int axis, int outputs_count, Scalar* const* output_data,
- const Dims<4>* const* output_dims) {
- const int batches = ArraySize(*output_dims[0], 3);
- const int height = ArraySize(*output_dims[0], 2);
- const int width = ArraySize(*output_dims[0], 1);
- const int depth = ArraySize(*output_dims[0], 0);
-
- const int slice_size = ArraySize(*output_dims[0], axis);
+void Split(const SplitParams& params, const RuntimeShape& input_shape,
+ const Scalar* input_data, const RuntimeShape* const* output_shapes,
+ Scalar* const* output_data) {
+ const int concat_dimensions = input_shape.DimensionsCount();
+ int axis = params.axis < 0 ? params.axis + concat_dimensions : params.axis;
+ int outputs_count = params.num_split;
+ TFLITE_DCHECK_LT(axis, concat_dimensions);
- for (int i = 0; i < outputs_count; ++i) {
- int offset = i * slice_size * input_dims.strides[axis];
- for (int b = 0; b < batches; ++b) {
- for (int y = 0; y < height; ++y) {
- for (int x = 0; x < width; ++x) {
- for (int c = 0; c < depth; ++c) {
- auto out = Offset(*output_dims[i], c, x, y, b);
- auto in = Offset(input_dims, c, x, y, b);
- output_data[i][out] = input_data[offset + in];
- }
- }
+ int64_t concat_size = 0;
+ for (int i = 0; i < outputs_count; i++) {
+ TFLITE_DCHECK_EQ(output_shapes[i]->DimensionsCount(), concat_dimensions);
+ for (int j = 0; j < concat_dimensions; j++) {
+ if (j != axis) {
+ MatchingDim(*output_shapes[i], j, input_shape, j);
}
}
+ concat_size += output_shapes[i]->Dims(axis);
}
-}
-
-template <FusedActivationFunctionType Ac, typename Scalar>
-void TensorFlowSplit(const Scalar* input_data, const Dims<4>& input_dims,
- int outputs_count, Scalar* const* output_data,
- const Dims<4>* const* output_dims) {
- TFLITE_DCHECK_GE(outputs_count, 1);
- for (int i = 0; i < outputs_count; i++) {
- /* batches = */ MatchingArraySize(*output_dims[i], 3, input_dims, 3);
- /* height = */ MatchingArraySize(*output_dims[i], 2, input_dims, 2);
- /* width = */ MatchingArraySize(*output_dims[i], 1, input_dims, 1);
+ TFLITE_DCHECK_EQ(concat_size, input_shape.Dims(axis));
+ int64_t outer_size = 1;
+ for (int i = 0; i < axis; ++i) {
+ outer_size *= input_shape.Dims(i);
}
- // for now we dont have a model with a TensorFlowSplit
- // with fused activation function.
- TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone);
-
- TensorFlowSplit(input_data, input_dims, /*axis=*/0, outputs_count,
- output_data, output_dims);
-}
-
-// TODO(benoitjacob) make this a proper reference impl without Eigen!
-template <typename Scalar>
-using MatrixMap = typename std::conditional<
- std::is_const<Scalar>::value,
- Eigen::Map<const Eigen::Matrix<typename std::remove_const<Scalar>::type,
- Eigen::Dynamic, Eigen::Dynamic>>,
- Eigen::Map<Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic>>>::type;
-
-template <typename Scalar, int N>
-MatrixMap<Scalar> MapAsMatrixWithFirstDimAsRows(Scalar* data,
- const Dims<N>& dims) {
- const int rows = dims.sizes[0];
- int cols = 1;
- for (int d = 1; d < N; d++) {
- cols *= dims.sizes[d];
+ // For all output arrays,
+ // FlatSize() = outer_size * Dims(axis) * base_inner_size;
+ int64_t base_inner_size = 1;
+ for (int i = axis + 1; i < concat_dimensions; ++i) {
+ base_inner_size *= input_shape.Dims(i);
}
- return MatrixMap<Scalar>(data, rows, cols);
-}
-template <typename Scalar, int N>
-MatrixMap<Scalar> MapAsMatrixWithLastDimAsCols(Scalar* data,
- const Dims<N>& dims) {
- const int cols = dims.sizes[N - 1];
- int rows = 1;
- for (int d = 0; d < N - 1; d++) {
- rows *= dims.sizes[d];
+ const Scalar* input_ptr = input_data;
+ for (int k = 0; k < outer_size; k++) {
+ for (int i = 0; i < outputs_count; ++i) {
+ const int copy_size = output_shapes[i]->Dims(axis) * base_inner_size;
+ memcpy(output_data[i] + k * copy_size, input_ptr,
+ copy_size * sizeof(Scalar));
+ input_ptr += copy_size;
+ }
}
- return MatrixMap<Scalar>(data, rows, cols);
}
inline int NodeOffset(int b, int h, int w, int height, int width) {
return (b * height + h) * width + w;
}
-inline void AveragePool(const float* input_data, const Dims<4>& input_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, int filter_width, int filter_height,
- float output_activation_min,
- float output_activation_max, float* output_data,
- const Dims<4>& output_dims) {
- const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
- const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
- const int input_height = ArraySize(input_dims, 2);
- const int input_width = ArraySize(input_dims, 1);
- const int output_height = ArraySize(output_dims, 2);
- const int output_width = ArraySize(output_dims, 1);
+inline void AveragePool(const PoolParams& params,
+ const RuntimeShape& input_shape,
+ const float* input_data,
+ const RuntimeShape& output_shape, float* output_data) {
+ TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+ const int batches = MatchingDim(input_shape, 0, output_shape, 0);
+ const int depth = MatchingDim(input_shape, 3, output_shape, 3);
+ const int input_height = input_shape.Dims(1);
+ const int input_width = input_shape.Dims(2);
+ const int output_height = output_shape.Dims(1);
+ const int output_width = output_shape.Dims(2);
+ const int stride_height = params.stride_height;
+ const int stride_width = params.stride_width;
for (int batch = 0; batch < batches; ++batch) {
for (int out_y = 0; out_y < output_height; ++out_y) {
for (int out_x = 0; out_x < output_width; ++out_x) {
for (int channel = 0; channel < depth; ++channel) {
- const int in_x_origin = (out_x * stride_width) - pad_width;
- const int in_y_origin = (out_y * stride_height) - pad_height;
+ const int in_x_origin =
+ (out_x * stride_width) - params.padding_values.width;
+ const int in_y_origin =
+ (out_y * stride_height) - params.padding_values.height;
// Compute the boundaries of the filter region clamped so as to
// ensure that the filter window fits in the input array.
const int filter_x_start = std::max(0, -in_x_origin);
const int filter_x_end =
- std::min(filter_width, input_width - in_x_origin);
+ std::min(params.filter_width, input_width - in_x_origin);
const int filter_y_start = std::max(0, -in_y_origin);
const int filter_y_end =
- std::min(filter_height, input_height - in_y_origin);
+ std::min(params.filter_height, input_height - in_y_origin);
float total = 0.f;
float filter_count = 0;
for (int filter_y = filter_y_start; filter_y < filter_y_end;
@@ -2139,70 +2215,52 @@ inline void AveragePool(const float* input_data, const Dims<4>& input_dims,
const int in_x = in_x_origin + filter_x;
const int in_y = in_y_origin + filter_y;
total +=
- input_data[Offset(input_dims, channel, in_x, in_y, batch)];
+ input_data[Offset(input_shape, batch, in_y, in_x, channel)];
filter_count++;
}
}
const float average = total / filter_count;
- output_data[Offset(output_dims, channel, out_x, out_y, batch)] =
- ActivationFunctionWithMinMax(average, output_activation_min,
- output_activation_max);
+ output_data[Offset(output_shape, batch, out_y, out_x, channel)] =
+ ActivationFunctionWithMinMax(average, params.float_activation_min,
+ params.float_activation_max);
}
}
}
}
}
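The filter_x_start / filter_x_end computation above clamps the pooling window to the valid part of the input, and filter_count shrinks accordingly at the borders. A small standalone sketch of that edge case with an invented width-5 input, 3-wide filter, stride 1, and pad 1:

#include <algorithm>
#include <cstdio>

// Clamp a 3-wide pooling window at the left edge of a width-5 input with
// pad_width = 1 and stride 1, as AveragePool does above (sizes made up).
int main() {
  const int input_width = 5, filter_width = 3, pad_width = 1, stride_width = 1;
  const int out_x = 0;
  const int in_x_origin = out_x * stride_width - pad_width;  // -1
  const int filter_x_start = std::max(0, -in_x_origin);      // 1
  const int filter_x_end =
      std::min(filter_width, input_width - in_x_origin);     // min(3, 6) = 3
  // Only filter taps 1 and 2 land inside the input, so the average divides
  // by filter_count = 2 rather than 3 at this border position.
  printf("start=%d end=%d count=%d\n", filter_x_start, filter_x_end,
         filter_x_end - filter_x_start);  // prints: start=1 end=3 count=2
  return 0;
}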
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void AveragePool(const float* input_data, const Dims<4>& input_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, int filter_width, int filter_height,
- float* output_data, const Dims<4>& output_dims) {
- float output_activation_min, output_activation_max;
- GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
- AveragePool(input_data, input_dims, stride_width, stride_height, pad_width,
- pad_height, filter_width, filter_height, output_activation_min,
- output_activation_max, output_data, output_dims);
-}
-
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void AveragePool(const float* input_data, const Dims<4>& input_dims, int stride,
- int pad_width, int pad_height, int filter_width,
- int filter_height, float* output_data,
- const Dims<4>& output_dims) {
- AveragePool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
- filter_width, filter_height, output_data, output_dims);
-}
-
-inline void AveragePool(const uint8* input_data, const Dims<4>& input_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, int filter_width, int filter_height,
- int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims) {
- TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
- const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
- const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
- const int input_height = ArraySize(input_dims, 2);
- const int input_width = ArraySize(input_dims, 1);
- const int output_height = ArraySize(output_dims, 2);
- const int output_width = ArraySize(output_dims, 1);
+inline void AveragePool(const PoolParams& params,
+ const RuntimeShape& input_shape,
+ const uint8* input_data,
+ const RuntimeShape& output_shape, uint8* output_data) {
+ TFLITE_DCHECK_LE(params.quantized_activation_min,
+ params.quantized_activation_max);
+ TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+ const int batches = MatchingDim(input_shape, 0, output_shape, 0);
+ const int depth = MatchingDim(input_shape, 3, output_shape, 3);
+ const int input_height = input_shape.Dims(1);
+ const int input_width = input_shape.Dims(2);
+ const int output_height = output_shape.Dims(1);
+ const int output_width = output_shape.Dims(2);
+ const int stride_height = params.stride_height;
+ const int stride_width = params.stride_width;
for (int batch = 0; batch < batches; ++batch) {
for (int out_y = 0; out_y < output_height; ++out_y) {
for (int out_x = 0; out_x < output_width; ++out_x) {
for (int channel = 0; channel < depth; ++channel) {
- const int in_x_origin = (out_x * stride_width) - pad_width;
- const int in_y_origin = (out_y * stride_height) - pad_height;
+ const int in_x_origin =
+ (out_x * stride_width) - params.padding_values.width;
+ const int in_y_origin =
+ (out_y * stride_height) - params.padding_values.height;
// Compute the boundaries of the filter region clamped so as to
// ensure that the filter window fits in the input array.
const int filter_x_start = std::max(0, -in_x_origin);
const int filter_x_end =
- std::min(filter_width, input_width - in_x_origin);
+ std::min(params.filter_width, input_width - in_x_origin);
const int filter_y_start = std::max(0, -in_y_origin);
const int filter_y_end =
- std::min(filter_height, input_height - in_y_origin);
+ std::min(params.filter_height, input_height - in_y_origin);
int32 acc = 0;
int filter_count = 0;
for (int filter_y = filter_y_start; filter_y < filter_y_end;
@@ -2211,14 +2269,15 @@ inline void AveragePool(const uint8* input_data, const Dims<4>& input_dims,
++filter_x) {
const int in_x = in_x_origin + filter_x;
const int in_y = in_y_origin + filter_y;
- acc += input_data[Offset(input_dims, channel, in_x, in_y, batch)];
+ acc +=
+ input_data[Offset(input_shape, batch, in_y, in_x, channel)];
filter_count++;
}
}
acc = (acc + filter_count / 2) / filter_count;
- acc = std::max(acc, output_activation_min);
- acc = std::min(acc, output_activation_max);
- output_data[Offset(output_dims, channel, out_x, out_y, batch)] =
+ acc = std::max(acc, params.quantized_activation_min);
+ acc = std::min(acc, params.quantized_activation_max);
+ output_data[Offset(output_shape, batch, out_y, out_x, channel)] =
static_cast<uint8>(acc);
}
}
@@ -2226,64 +2285,35 @@ inline void AveragePool(const uint8* input_data, const Dims<4>& input_dims,
}
}
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void AveragePool(const uint8* input_data, const Dims<4>& input_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, int filter_width, int filter_height,
- int32 output_activation_min, int32 output_activation_max,
- uint8* output_data, const Dims<4>& output_dims) {
- static_assert(Ac == FusedActivationFunctionType::kNone ||
- Ac == FusedActivationFunctionType::kRelu ||
- Ac == FusedActivationFunctionType::kRelu6 ||
- Ac == FusedActivationFunctionType::kRelu1,
- "");
- if (Ac == FusedActivationFunctionType::kNone) {
- TFLITE_DCHECK_EQ(output_activation_min, 0);
- TFLITE_DCHECK_EQ(output_activation_max, 255);
- }
- AveragePool(input_data, input_dims, stride_width, stride_height, pad_width,
- pad_height, filter_width, filter_height, output_activation_min,
- output_activation_max, output_data, output_dims);
-}
-
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void AveragePool(const uint8* input_data, const Dims<4>& input_dims, int stride,
- int pad_width, int pad_height, int filter_width,
- int filter_height, int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims) {
- AveragePool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
- filter_width, filter_height, output_activation_min,
- output_activation_max, output_data, output_dims);
-}
-
-inline void L2Pool(const float* input_data, const Dims<4>& input_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, int filter_width, int filter_height,
- float output_activation_min, float output_activation_max,
- float* output_data, const Dims<4>& output_dims) {
- const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
- const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
- const int input_height = ArraySize(input_dims, 2);
- const int input_width = ArraySize(input_dims, 1);
- const int output_height = ArraySize(output_dims, 2);
- const int output_width = ArraySize(output_dims, 1);
+inline void L2Pool(const PoolParams& params, const RuntimeShape& input_shape,
+ const float* input_data, const RuntimeShape& output_shape,
+ float* output_data) {
+ TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+ const int batches = MatchingDim(input_shape, 0, output_shape, 0);
+ const int depth = MatchingDim(input_shape, 3, output_shape, 3);
+ const int input_height = input_shape.Dims(1);
+ const int input_width = input_shape.Dims(2);
+ const int output_height = output_shape.Dims(1);
+ const int output_width = output_shape.Dims(2);
+ const int stride_height = params.stride_height;
+ const int stride_width = params.stride_width;
for (int batch = 0; batch < batches; ++batch) {
for (int out_y = 0; out_y < output_height; ++out_y) {
for (int out_x = 0; out_x < output_width; ++out_x) {
for (int channel = 0; channel < depth; ++channel) {
- const int in_x_origin = (out_x * stride_width) - pad_width;
- const int in_y_origin = (out_y * stride_height) - pad_height;
+ const int in_x_origin =
+ (out_x * stride_width) - params.padding_values.width;
+ const int in_y_origin =
+ (out_y * stride_height) - params.padding_values.height;
// Compute the boundaries of the filter region clamped so as to
// ensure that the filter window fits in the input array.
const int filter_x_start = std::max(0, -in_x_origin);
const int filter_x_end =
- std::min(filter_width, input_width - in_x_origin);
+ std::min(params.filter_width, input_width - in_x_origin);
const int filter_y_start = std::max(0, -in_y_origin);
const int filter_y_end =
- std::min(filter_height, input_height - in_y_origin);
+ std::min(params.filter_height, input_height - in_y_origin);
float sum_squares = 0.f;
int filter_count = 0;
for (int filter_y = filter_y_start; filter_y < filter_y_end;
@@ -2293,69 +2323,51 @@ inline void L2Pool(const float* input_data, const Dims<4>& input_dims,
const int in_x = in_x_origin + filter_x;
const int in_y = in_y_origin + filter_y;
const float val =
- input_data[Offset(input_dims, channel, in_x, in_y, batch)];
+ input_data[Offset(input_shape, batch, in_y, in_x, channel)];
sum_squares += val * val;
filter_count++;
}
}
const float l2pool_result = std::sqrt(sum_squares / filter_count);
- output_data[Offset(output_dims, channel, out_x, out_y, batch)] =
- ActivationFunctionWithMinMax(l2pool_result, output_activation_min,
- output_activation_max);
+ output_data[Offset(output_shape, batch, out_y, out_x, channel)] =
+ ActivationFunctionWithMinMax(l2pool_result,
+ params.float_activation_min,
+ params.float_activation_max);
}
}
}
}
}
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void L2Pool(const float* input_data, const Dims<4>& input_dims,
- int stride_width, int stride_height, int pad_width, int pad_height,
- int filter_width, int filter_height, float* output_data,
- const Dims<4>& output_dims) {
- float output_activation_min, output_activation_max;
- GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
-
- L2Pool(input_data, input_dims, stride_width, stride_height, pad_width,
- pad_height, filter_width, filter_height, output_activation_min,
- output_activation_max, output_data, output_dims);
-}
-
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void L2Pool(const float* input_data, const Dims<4>& input_dims, int stride,
- int pad_width, int pad_height, int filter_width, int filter_height,
- float* output_data, const Dims<4>& output_dims) {
- L2Pool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
- filter_width, filter_height, output_data, output_dims);
-}
-
-inline void MaxPool(const float* input_data, const Dims<4>& input_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, int filter_width, int filter_height,
- float output_activation_min, float output_activation_max,
- float* output_data, const Dims<4>& output_dims) {
- const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
- const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
- const int input_height = ArraySize(input_dims, 2);
- const int input_width = ArraySize(input_dims, 1);
- const int output_height = ArraySize(output_dims, 2);
- const int output_width = ArraySize(output_dims, 1);
+inline void MaxPool(const PoolParams& params, const RuntimeShape& input_shape,
+ const float* input_data, const RuntimeShape& output_shape,
+ float* output_data) {
+ TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+ const int batches = MatchingDim(input_shape, 0, output_shape, 0);
+ const int depth = MatchingDim(input_shape, 3, output_shape, 3);
+ const int input_height = input_shape.Dims(1);
+ const int input_width = input_shape.Dims(2);
+ const int output_height = output_shape.Dims(1);
+ const int output_width = output_shape.Dims(2);
+ const int stride_height = params.stride_height;
+ const int stride_width = params.stride_width;
for (int batch = 0; batch < batches; ++batch) {
for (int out_y = 0; out_y < output_height; ++out_y) {
for (int out_x = 0; out_x < output_width; ++out_x) {
for (int channel = 0; channel < depth; ++channel) {
- const int in_x_origin = (out_x * stride_width) - pad_width;
- const int in_y_origin = (out_y * stride_height) - pad_height;
+ const int in_x_origin =
+ (out_x * stride_width) - params.padding_values.width;
+ const int in_y_origin =
+ (out_y * stride_height) - params.padding_values.height;
// Compute the boundaries of the filter region clamped so as to
// ensure that the filter window fits in the input array.
const int filter_x_start = std::max(0, -in_x_origin);
const int filter_x_end =
- std::min(filter_width, input_width - in_x_origin);
+ std::min(params.filter_width, input_width - in_x_origin);
const int filter_y_start = std::max(0, -in_y_origin);
const int filter_y_end =
- std::min(filter_height, input_height - in_y_origin);
+ std::min(params.filter_height, input_height - in_y_origin);
float max = std::numeric_limits<float>::lowest();
for (int filter_y = filter_y_start; filter_y < filter_y_end;
++filter_y) {
@@ -2365,68 +2377,51 @@ inline void MaxPool(const float* input_data, const Dims<4>& input_dims,
const int in_y = in_y_origin + filter_y;
max = std::max(
max,
- input_data[Offset(input_dims, channel, in_x, in_y, batch)]);
+ input_data[Offset(input_shape, batch, in_y, in_x, channel)]);
}
}
- output_data[Offset(output_dims, channel, out_x, out_y, batch)] =
- ActivationFunctionWithMinMax(max, output_activation_min,
- output_activation_max);
+ output_data[Offset(output_shape, batch, out_y, out_x, channel)] =
+ ActivationFunctionWithMinMax(max, params.float_activation_min,
+ params.float_activation_max);
}
}
}
}
}
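// Illustrative sketch (not part of this patch): how a caller wires up the new
// PoolParams-based signature, assuming the types above (PoolParams,
// RuntimeShape) and the standard headers this file already relies on
// (<algorithm>, <limits>) are in scope. The helper name is hypothetical, and
// output_shape must already reflect the chosen stride/filter/padding.
inline void RunFloatMaxPool2x2(const RuntimeShape& input_shape,
                               const float* input_data,
                               const RuntimeShape& output_shape,
                               float* output_data) {
  PoolParams params;
  params.stride_height = 2;
  params.stride_width = 2;
  params.filter_height = 2;
  params.filter_width = 2;
  params.padding_values.height = 0;  // "VALID"-style padding.
  params.padding_values.width = 0;
  // No fused activation: leave the clamp wide open.
  params.float_activation_min = std::numeric_limits<float>::lowest();
  params.float_activation_max = std::numeric_limits<float>::max();
  // Shapes are NHWC: {batches, height, width, depth}.
  MaxPool(params, input_shape, input_data, output_shape, output_data);
}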
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void MaxPool(const float* input_data, const Dims<4>& input_dims,
- int stride_width, int stride_height, int pad_width, int pad_height,
- int filter_width, int filter_height, float* output_data,
- const Dims<4>& output_dims) {
- float output_activation_min, output_activation_max;
- GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
- MaxPool(input_data, input_dims, stride_width, stride_height, pad_width,
- pad_height, filter_width, filter_height, output_activation_min,
- output_activation_max, output_data, output_dims);
-}
-
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void MaxPool(const float* input_data, const Dims<4>& input_dims, int stride,
- int pad_width, int pad_height, int filter_width, int filter_height,
- float* output_data, const Dims<4>& output_dims) {
- MaxPool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
- filter_width, filter_height, output_data, output_dims);
-}
-
-inline void MaxPool(const uint8* input_data, const Dims<4>& input_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, int filter_width, int filter_height,
- int32 output_activation_min, int32 output_activation_max,
- uint8* output_data, const Dims<4>& output_dims) {
- TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
- TFLITE_DCHECK_GE(output_activation_min, 0);
- TFLITE_DCHECK_LE(output_activation_max, 255);
- const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
- const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
- const int input_height = ArraySize(input_dims, 2);
- const int input_width = ArraySize(input_dims, 1);
- const int output_height = ArraySize(output_dims, 2);
- const int output_width = ArraySize(output_dims, 1);
+inline void MaxPool(const PoolParams& params, const RuntimeShape& input_shape,
+ const uint8* input_data, const RuntimeShape& output_shape,
+ uint8* output_data) {
+ TFLITE_DCHECK_LE(params.quantized_activation_min,
+ params.quantized_activation_max);
+ TFLITE_DCHECK_GE(params.quantized_activation_min, 0);
+ TFLITE_DCHECK_LE(params.quantized_activation_max, 255);
+ TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+ const int batches = MatchingDim(input_shape, 0, output_shape, 0);
+ const int depth = MatchingDim(input_shape, 3, output_shape, 3);
+ const int input_height = input_shape.Dims(1);
+ const int input_width = input_shape.Dims(2);
+ const int output_height = output_shape.Dims(1);
+ const int output_width = output_shape.Dims(2);
+ const int stride_height = params.stride_height;
+ const int stride_width = params.stride_width;
for (int batch = 0; batch < batches; ++batch) {
for (int out_y = 0; out_y < output_height; ++out_y) {
for (int out_x = 0; out_x < output_width; ++out_x) {
for (int channel = 0; channel < depth; ++channel) {
- const int in_x_origin = (out_x * stride_width) - pad_width;
- const int in_y_origin = (out_y * stride_height) - pad_height;
+ const int in_x_origin =
+ (out_x * stride_width) - params.padding_values.width;
+ const int in_y_origin =
+ (out_y * stride_height) - params.padding_values.height;
// Compute the boundaries of the filter region clamped so as to
// ensure that the filter window fits in the input array.
const int filter_x_start = std::max(0, -in_x_origin);
const int filter_x_end =
- std::min(filter_width, input_width - in_x_origin);
+ std::min(params.filter_width, input_width - in_x_origin);
const int filter_y_start = std::max(0, -in_y_origin);
const int filter_y_end =
- std::min(filter_height, input_height - in_y_origin);
+ std::min(params.filter_height, input_height - in_y_origin);
uint8 max = 0;
for (int filter_y = filter_y_start; filter_y < filter_y_end;
++filter_y) {
@@ -2436,12 +2431,12 @@ inline void MaxPool(const uint8* input_data, const Dims<4>& input_dims,
const int in_y = in_y_origin + filter_y;
max = std::max(
max,
- input_data[Offset(input_dims, channel, in_x, in_y, batch)]);
+ input_data[Offset(input_shape, batch, in_y, in_x, channel)]);
}
}
- max = std::max<uint8>(max, output_activation_min);
- max = std::min<uint8>(max, output_activation_max);
- output_data[Offset(output_dims, channel, out_x, out_y, batch)] =
+ max = std::max<uint8>(max, params.quantized_activation_min);
+ max = std::min<uint8>(max, params.quantized_activation_max);
+ output_data[Offset(output_shape, batch, out_y, out_x, channel)] =
static_cast<uint8>(max);
}
}
@@ -2449,71 +2444,45 @@ inline void MaxPool(const uint8* input_data, const Dims<4>& input_dims,
}
}
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void MaxPool(const uint8* input_data, const Dims<4>& input_dims,
- int stride_width, int stride_height, int pad_width, int pad_height,
- int filter_width, int filter_height, int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims) {
- static_assert(Ac == FusedActivationFunctionType::kNone ||
- Ac == FusedActivationFunctionType::kRelu ||
- Ac == FusedActivationFunctionType::kRelu6 ||
- Ac == FusedActivationFunctionType::kRelu1,
- "");
- if (Ac == FusedActivationFunctionType::kNone) {
- TFLITE_DCHECK_EQ(output_activation_min, 0);
- TFLITE_DCHECK_EQ(output_activation_max, 255);
- }
- MaxPool(input_data, input_dims, stride_width, stride_height, pad_width,
- pad_height, filter_width, filter_height, output_activation_min,
- output_activation_max, output_data, output_dims);
-}
-
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void MaxPool(const uint8* input_data, const Dims<4>& input_dims, int stride,
- int pad_width, int pad_height, int filter_width, int filter_height,
- int32 output_activation_min, int32 output_activation_max,
- uint8* output_data, const Dims<4>& output_dims) {
- MaxPool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
- filter_width, filter_height, output_activation_min,
- output_activation_max, output_data, output_dims);
-}
-
-inline void LocalResponseNormalization(const float* input_data,
- const Dims<4>& input_dims, int range,
- float bias, float alpha, float beta,
- float* output_data,
- const Dims<4>& output_dims) {
- const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims);
- const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
+inline void LocalResponseNormalization(
+ const tflite::LocalResponseNormalizationParams& op_params,
+ const RuntimeShape& input_shape, const float* input_data,
+ const RuntimeShape& output_shape, float* output_data) {
+ const int trailing_dim = input_shape.DimensionsCount() - 1;
+ const int outer_size =
+ MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
+ const int depth =
+ MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
for (int i = 0; i < outer_size; ++i) {
for (int c = 0; c < depth; ++c) {
- const int begin_input_c = std::max(0, c - range);
- const int end_input_c = std::min(depth, c + range);
+ const int begin_input_c = std::max(0, c - op_params.range);
+ const int end_input_c = std::min(depth, c + op_params.range);
float accum = 0.f;
for (int input_c = begin_input_c; input_c < end_input_c; ++input_c) {
const float input_val = input_data[i * depth + input_c];
accum += input_val * input_val;
}
- const float multiplier = std::pow(bias + alpha * accum, -beta);
+ const float multiplier =
+ std::pow(op_params.bias + op_params.alpha * accum, -op_params.beta);
output_data[i * depth + c] = input_data[i * depth + c] * multiplier;
}
}
}
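// Illustrative sketch (not part of this patch): populating the op_params read
// above. With range = 2, bias = 1, alpha = 1 and beta = 0.5, a channel whose
// windowed sum of squares is 3 gets scaled by (1 + 1 * 3)^-0.5 = 0.5. The
// helper name is hypothetical; field names match the accesses above.
inline tflite::LocalResponseNormalizationParams MakeExampleLrnParams() {
  tflite::LocalResponseNormalizationParams op_params;
  op_params.range = 2;
  op_params.bias = 1.0f;
  op_params.alpha = 1.0f;
  op_params.beta = 0.5f;
  return op_params;
}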
-inline void Softmax(const float* input_data, const Dims<4>& input_dims,
- float beta, float* output_data,
- const Dims<4>& output_dims) {
- const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims);
- const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
+inline void LogSoftmax(const SoftmaxParams& params,
+ const RuntimeShape& input_shape, const float* input_data,
+ const RuntimeShape& output_shape, float* output_data) {
+ const int trailing_dim = input_shape.DimensionsCount() - 1;
+ const int outer_size =
+ MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
+ const int depth =
+ MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
for (int i = 0; i < outer_size; ++i) {
// Find max element value which we'll use to ensure numerical stability
// taking advantage of the following equality:
- // exp(x[i])/sum(exp(x[i])) == exp(x[i]+C)/sum(exp(x[i]+C))
+ // log(exp(x[i])/sum(exp(x[i]))) == log(exp(x[i]+C)/sum(exp(x[i]+C)))
float max = std::numeric_limits<float>::lowest();
for (int c = 0; c < depth; ++c) {
max = std::max(max, input_data[i * depth + c]);
@@ -2522,133 +2491,145 @@ inline void Softmax(const float* input_data, const Dims<4>& input_dims,
// Compute sum.
float sum = 0.f;
for (int c = 0; c < depth; ++c) {
- sum += std::exp((input_data[i * depth + c] - max) * beta);
+ sum += std::exp(input_data[i * depth + c] - max);
}
// Compute result.
+ const float log_sum = std::log(sum);
for (int c = 0; c < depth; ++c) {
- output_data[i * depth + c] =
- std::exp((input_data[i * depth + c] - max) * beta) / sum;
+ output_data[i * depth + c] = input_data[i * depth + c] - max - log_sum;
}
}
}
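// Illustrative check (not part of this patch): a quick numeric instance of the
// shift-by-max identity used above, assuming only the standard headers this
// file already uses (<algorithm>, <cmath>). For the row {1, 3}: max = 3,
// log_sum = log(e^-2 + e^0) ~ 0.1269, so the outputs are ~ {-2.1269, -0.1269},
// i.e. log(softmax(x)) without overflow risk for large inputs.
inline bool LogSoftmaxRowSanityCheck() {
  const float in[2] = {1.f, 3.f};
  const float max_val = std::max(in[0], in[1]);
  const float log_sum =
      std::log(std::exp(in[0] - max_val) + std::exp(in[1] - max_val));
  const float out0 = in[0] - max_val - log_sum;  // ~ -2.1269f
  const float out1 = in[1] - max_val - log_sum;  // ~ -0.1269f
  return std::fabs(out0 + 2.1269f) < 1e-3f && std::fabs(out1 + 0.1269f) < 1e-3f;
}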
-inline void Softmax(const uint8* input_data, const Dims<4>& input_dims,
- int32 input_beta_multiplier, int32 input_beta_left_shift,
- int diff_min, uint8* output_data,
- const Dims<4>& output_dims) {
- // The representation chosen for the input to the exp() function is Q5.26.
- // We need to leave extra space since values that we skip might be as large as
- // -32 before multiplying by input_beta_multiplier, and therefore as large as
- // -16 afterwards. Note that exp(-8) is definitely not insignificant to
- // accumulation, but exp(-16) definitely is.
- static const int kScaledDiffIntegerBits = 5;
- static const int kAccumulationIntegerBits = 12;
- using FixedPointScaledDiff =
- gemmlowp::FixedPoint<int32, kScaledDiffIntegerBits>;
- using FixedPointAccum = gemmlowp::FixedPoint<int32, kAccumulationIntegerBits>;
+// Although the name of this function says that it cannot handle values less
+// than 1, in practice it can handle values as low as 1/x_max, where x_max is
+// the largest representable input. In other words, the output range is
+// symmetric.
+template <int OutputIntegerBits, int InputIntegerBits>
+inline gemmlowp::FixedPoint<int32, OutputIntegerBits>
+log_x_for_x_greater_than_or_equal_to_1_impl(
+ gemmlowp::FixedPoint<int32, InputIntegerBits> input_val) {
using FixedPoint0 = gemmlowp::FixedPoint<int32, 0>;
-
- const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims);
- const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
-
- for (int i = 0; i < outer_size; ++i) {
- uint8 max_in_row = 0;
- for (int c = 0; c < depth; ++c) {
- max_in_row = std::max(max_in_row, input_data[i * depth + c]);
- }
-
- FixedPointAccum sum_of_exps = FixedPointAccum::Zero();
- for (int c = 0; c < depth; ++c) {
- int32 input_diff =
- static_cast<int32>(input_data[i * depth + c]) - max_in_row;
- if (input_diff >= diff_min) {
- const int32 input_diff_rescaled =
- MultiplyByQuantizedMultiplierGreaterThanOne(
- input_diff, input_beta_multiplier, input_beta_left_shift);
- const FixedPointScaledDiff scaled_diff_f8 =
- FixedPointScaledDiff::FromRaw(input_diff_rescaled);
- sum_of_exps = sum_of_exps + gemmlowp::Rescale<kAccumulationIntegerBits>(
- exp_on_negative_values(scaled_diff_f8));
- }
- }
-
- int32 fixed_sum_of_exps = sum_of_exps.raw();
- int headroom_plus_one =
- CountLeadingZeros(static_cast<uint32>(fixed_sum_of_exps));
- // This is the number of bits to the left of the binary point above 1.0.
- // Consider fixed_sum_of_exps=1.25. In that case shifted_scale=0.8 and
- // no later adjustment will be needed.
- int num_bits_over_unit = kAccumulationIntegerBits - headroom_plus_one;
- int32 shifted_sum_minus_one = static_cast<int32>(
- (static_cast<uint32>(fixed_sum_of_exps) << headroom_plus_one) -
- (static_cast<uint32>(1) << 31));
-
- FixedPoint0 shifted_scale = gemmlowp::one_over_one_plus_x_for_x_in_0_1(
- FixedPoint0::FromRaw(shifted_sum_minus_one));
-
- for (int c = 0; c < depth; ++c) {
- int32 input_diff =
- static_cast<int32>(input_data[i * depth + c]) - max_in_row;
- if (input_diff >= diff_min) {
- const int32 input_diff_rescaled =
- MultiplyByQuantizedMultiplierGreaterThanOne(
- input_diff, input_beta_multiplier, input_beta_left_shift);
- const FixedPointScaledDiff scaled_diff_f8 =
- FixedPointScaledDiff::FromRaw(input_diff_rescaled);
-
- FixedPoint0 exp_in_0 = exp_on_negative_values(scaled_diff_f8);
- int32 unsat_output = gemmlowp::RoundingDivideByPOT(
- (shifted_scale * exp_in_0).raw(), num_bits_over_unit + 31 - 8);
-
- output_data[i * depth + c] = static_cast<uint8>(
- std::max(std::min(unsat_output, static_cast<int32>(255)), 0));
-
- } else {
- output_data[i * depth + c] = 0;
- }
- }
- }
+ // The reason for accumulating the result with an extra bit of headroom is
+ // that z_pow_2_adj * log_2 might be saturated, and adding num_scaled *
+ // recip_denom will otherwise introduce an error.
+ static constexpr int kAccumIntegerBits = OutputIntegerBits + 1;
+ using FixedPointAccum = gemmlowp::FixedPoint<int32, kAccumIntegerBits>;
+
+ const FixedPoint0 log_2 = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
+ FixedPoint0, 1488522236, std::log(2.0));
+ const FixedPoint0 sqrt_sqrt_half = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
+ FixedPoint0, 1805811301, std::sqrt(std::sqrt(0.5)));
+ const FixedPoint0 sqrt_half = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
+ FixedPoint0, 1518500250, std::sqrt(0.5));
+ const FixedPoint0 one_quarter =
+ GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(FixedPoint0, 536870912, 1.0 / 4.0);
+
+ const FixedPoint0 alpha_n = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
+ FixedPoint0, 117049297, 11.0 / 240.0 * std::sqrt(std::sqrt(2.0)));
+ const FixedPoint0 alpha_d = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
+ FixedPoint0, 127690142, 1.0 / 20.0 * std::sqrt(std::sqrt(2.0)));
+ const FixedPoint0 alpha_i = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
+ FixedPoint0, 1057819769,
+ 2.0 / std::sqrt(std::sqrt(2.0)) - std::sqrt(std::sqrt(2.0)));
+ const FixedPoint0 alpha_f = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
+ FixedPoint0, 638450708, 1.0 / 4.0 * std::sqrt(std::sqrt(2.0)));
+
+ const FixedPointAccum shifted_quarter =
+ gemmlowp::Rescale<kAccumIntegerBits>(one_quarter);
+
+ // Reinterpret the input value as Q0.31, because we will figure out the
+ // required shift "ourselves" instead of using, say, Rescale.
+ FixedPoint0 z_a = FixedPoint0::FromRaw(input_val.raw());
+ // z_a_pow_2 = input_integer_bits - z_a_headroom;
+ int z_a_headroom_plus_1 = CountLeadingZeros(static_cast<uint32>(z_a.raw()));
+ FixedPoint0 r_a_tmp =
+ SaturatingRoundingMultiplyByPOTParam(z_a, (z_a_headroom_plus_1 - 1));
+ const int32 r_a_raw =
+ SaturatingRoundingMultiplyByPOTParam((r_a_tmp * sqrt_half).raw(), 1);
+ // z_pow_2_adj = max(z_pow_2_a - 0.75, z_pow_2_b - 0.25);
+ // z_pow_2_adj = max(InputIntegerBits - z_a_headroom_plus_1 + 0.25,
+ // InputIntegerBits - z_b_headroom - 0.25);
+ const FixedPointAccum z_a_pow_2_adj = SaturatingAddNonGemmlowp(
+ FixedPointAccum::FromRaw(SaturatingRoundingMultiplyByPOTParam(
+ InputIntegerBits - z_a_headroom_plus_1, 31 - kAccumIntegerBits)),
+ shifted_quarter);
+
+ // z_b is treated like z_a, but premultiplying by sqrt(0.5).
+ FixedPoint0 z_b = z_a * sqrt_half;
+ int z_b_headroom = CountLeadingZeros(static_cast<uint32>(z_b.raw())) - 1;
+ const int32 r_b_raw =
+ SaturatingRoundingMultiplyByPOTParam(z_a.raw(), z_b_headroom);
+ const FixedPointAccum z_b_pow_2_adj = SaturatingSub(
+ FixedPointAccum::FromRaw(SaturatingRoundingMultiplyByPOTParam(
+ InputIntegerBits - z_b_headroom, 31 - kAccumIntegerBits)),
+ shifted_quarter);
+
+ const FixedPoint0 r = FixedPoint0::FromRaw(std::min(r_a_raw, r_b_raw));
+ const FixedPointAccum z_pow_2_adj = FixedPointAccum::FromRaw(
+ std::max(z_a_pow_2_adj.raw(), z_b_pow_2_adj.raw()));
+
+ const FixedPoint0 p = gemmlowp::RoundingHalfSum(r, sqrt_sqrt_half);
+ FixedPoint0 q = r - sqrt_sqrt_half;
+ q = q + q;
+
+ const FixedPoint0 common_sq = q * q;
+ const FixedPoint0 num = q * r + q * common_sq * alpha_n;
+ const FixedPoint0 denom_minus_one_0 =
+ p * (alpha_i + q + alpha_d * common_sq) + alpha_f * q;
+ const FixedPoint0 recip_denom =
+ one_over_one_plus_x_for_x_in_0_1(denom_minus_one_0);
+
+ const FixedPointAccum num_scaled = gemmlowp::Rescale<kAccumIntegerBits>(num);
+ return gemmlowp::Rescale<OutputIntegerBits>(z_pow_2_adj * log_2 +
+ num_scaled * recip_denom);
}
-inline void LogSoftmax(const float* input_data, const Dims<4>& input_dims,
- float* output_data, const Dims<4>& output_dims) {
- const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims);
- const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
-
- for (int i = 0; i < outer_size; ++i) {
- // Find max element value which we'll use to ensure numerical stability
- // taking advantage of the following equality:
- // log(exp(x[i])/sum(exp(x[i]))) == log(exp(x[i]+C)/sum(exp(x[i]+C)))
- float max = std::numeric_limits<float>::lowest();
- for (int c = 0; c < depth; ++c) {
- max = std::max(max, input_data[i * depth + c]);
- }
-
- // Compute sum.
- float sum = 0.f;
- for (int c = 0; c < depth; ++c) {
- sum += std::exp(input_data[i * depth + c] - max);
- }
+// Minimum output bits to accommodate log of maximum input range. It actually
+// does not matter if one considers, say, [-64,64] or [-64,64).
+//
+// For example, run this through Octave:
+// [0:127; ...
+// ceil(log(abs( log(2.^(0:127))+1 ))/log(2)); ...
+// ceil(log(abs( log(2.^(0:127))+1 ))/log(2))]
+constexpr int min_log_x_output_bits(int input_bits) {
+ return input_bits > 90
+ ? 7
+ : input_bits > 44
+ ? 6
+ : input_bits > 21
+ ? 5
+ : input_bits > 10
+ ? 4
+ : input_bits > 4 ? 3 : input_bits > 1 ? 2 : 1;
+}
- // Compute result.
- const float log_sum = std::log(sum);
- for (int c = 0; c < depth; ++c) {
- output_data[i * depth + c] = input_data[i * depth + c] - max - log_sum;
- }
- }
+template <int OutputIntegerBits, int InputIntegerBits>
+inline gemmlowp::FixedPoint<int32, OutputIntegerBits>
+log_x_for_x_greater_than_or_equal_to_1(
+ gemmlowp::FixedPoint<int32, InputIntegerBits> input_val) {
+ static_assert(
+ OutputIntegerBits >= min_log_x_output_bits(InputIntegerBits),
+ "Output integer bits must be sufficent to accommodate logs of inputs.");
+ return log_x_for_x_greater_than_or_equal_to_1_impl<OutputIntegerBits,
+ InputIntegerBits>(
+ input_val);
}
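// Illustrative check (not part of this patch): the quantized LogSoftmax below
// accumulates its sum of exponentials in Q12.19 (kAccumulationIntegerBits = 12)
// and requests the log in Q5.26 (kScaledDiffIntegerBits = 5); that combination
// passes the static_assert above because min_log_x_output_bits(12) == 4 <= 5.
static_assert(min_log_x_output_bits(12) <= 5,
              "5 output integer bits suffice for a 12-integer-bit input.");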
-inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims,
- int32 input_multiplier, int32 input_left_shift,
- int32 reverse_scaling_divisor,
- int32 reverse_scaling_right_shift, int diff_min,
- uint8* output_data, const Dims<4>& output_dims) {
+inline void LogSoftmax(const SoftmaxParams& params,
+ const RuntimeShape& input_shape, const uint8* input_data,
+ const RuntimeShape& output_shape, uint8* output_data) {
+ const int32 input_multiplier = params.input_multiplier;
+ const int32 input_left_shift = params.input_left_shift;
+ const int32 reverse_scaling_divisor = params.reverse_scaling_divisor;
+ const int32 reverse_scaling_right_shift = params.reverse_scaling_right_shift;
+ const int diff_min = params.diff_min;
// The representation chosen for the input to the exp() function is Q5.26.
- // We need to leave extra space since values that we skip might be as large as
- // -32 before multiplying by input_beta_multiplier, and therefore as large as
- // -16 afterwards. Note that exp(-8) is definitely not insignificant to
- // accumulation, but exp(-16) definitely is.
+ // We need to leave extra space since values that we skip might be as large
+ // as -32 before multiplying by input_beta_multiplier, and therefore as
+ // large as -16 afterwards. Note that exp(-8) is definitely not
+ // insignificant to accumulation, but exp(-16) definitely is.
static constexpr int kScaledDiffIntegerBits = 5;
static constexpr int kAccumulationIntegerBits = 12;
static constexpr int kOutputIntegerBits = 4;
@@ -2657,8 +2638,11 @@ inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims,
using FixedPointAccum = gemmlowp::FixedPoint<int32, kAccumulationIntegerBits>;
using FixedPoint0 = gemmlowp::FixedPoint<int32, 0>;
- const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims);
- const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
+ const int trailing_dim = input_shape.DimensionsCount() - 1;
+ const int outer_size =
+ MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
+ const int depth =
+ MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
for (int i = 0; i < outer_size; ++i) {
uint8 max_in_row = 0;
@@ -2681,13 +2665,10 @@ inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims,
}
}
- // TODO(b/77858996): Implement fixed-point log().
- // Not a fully-quantized implementation: floating-point log().
- const float float_log_sum_of_exps =
- std::log(static_cast<float>(sum_of_exps.raw()) /
- (1 << (31 - kAccumulationIntegerBits)));
- const int32 fixed_log_sum_of_exps = static_cast<int32>(TfLiteRound(
- float_log_sum_of_exps * (1 << (31 - kScaledDiffIntegerBits))));
+ const int32 fixed_log_sum_of_exps =
+ log_x_for_x_greater_than_or_equal_to_1<kScaledDiffIntegerBits>(
+ sum_of_exps)
+ .raw();
// rescaled_diff_min is smallest representable in
// Q(kScaledDiffIntegerBits).(31-kScaledDiffIntegerBits) plus the
@@ -2698,9 +2679,9 @@ inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims,
fixed_log_sum_of_exps + std::numeric_limits<int32>::lowest();
const int adjusted_diff_min =
std::max(diff_min - 1, // Note use of > below instead of >= above.
- MultiplyByQuantizedMultiplierSmallerThanOne(
+ MultiplyByQuantizedMultiplierSmallerThanOneExp(
rescaled_diff_min, reverse_scaling_divisor,
- reverse_scaling_right_shift));
+ -reverse_scaling_right_shift));
for (int c = 0; c < depth; ++c) {
int32 input_diff =
@@ -2725,9 +2706,9 @@ inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims,
}
}
-inline void Logistic(const float* input_data, const Dims<4>& input_dims,
- float* output_data, const Dims<4>& output_dims) {
- const int flat_size = MatchingFlatSize(output_dims, input_dims);
+inline void Logistic(const RuntimeShape& input_shape, const float* input_data,
+ const RuntimeShape& output_shape, float* output_data) {
+ const int flat_size = MatchingFlatSize(input_shape, output_shape);
for (int i = 0; i < flat_size; i++) {
float val = input_data[i];
@@ -2736,11 +2717,23 @@ inline void Logistic(const float* input_data, const Dims<4>& input_dims,
}
}
-inline void Logistic(const uint8* input_data, const Dims<4>& input_dims,
- int32 input_zero_point, int32 input_range_radius,
- int32 input_multiplier, int input_left_shift,
- uint8* output_data, const Dims<4>& output_dims) {
- const int flat_size = MatchingFlatSize(output_dims, input_dims);
+// Convenience version that allows, for example, generated-code calls to be
+// uniform between data types.
+inline void Logistic(const LogisticParams&, const RuntimeShape& input_shape,
+ const float* input_data, const RuntimeShape& output_shape,
+ float* output_data) {
+ // Drop params: not needed.
+ Logistic(input_shape, input_data, output_shape, output_data);
+}
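// Illustrative sketch (not part of this patch): thanks to the convenience
// overload above, a caller that already carries a LogisticParams (e.g. code
// shared with the quantized path) can keep one call shape for float; the
// params argument is simply ignored. The helper name is hypothetical.
inline void RunFloatLogistic(const LogisticParams& params,
                             const RuntimeShape& shape, const float* input_data,
                             float* output_data) {
  Logistic(params, shape, input_data, shape, output_data);
}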
+
+inline void Logistic(const LogisticParams& params,
+ const RuntimeShape& input_shape, const uint8* input_data,
+ const RuntimeShape& output_shape, uint8* output_data) {
+ const int32 input_zero_point = params.input_zero_point;
+ const int32 input_range_radius = params.input_range_radius;
+ const int32 input_multiplier = params.input_multiplier;
+ const int input_left_shift = params.input_left_shift;
+ const int flat_size = MatchingFlatSize(input_shape, output_shape);
for (int i = 0; i < flat_size; i++) {
const uint8 input_val_u8 = input_data[i];
@@ -2774,9 +2767,10 @@ inline void Logistic(const uint8* input_data, const Dims<4>& input_dims,
}
}
-inline void Logistic(const int16* input_data, const Dims<4>& input_dims,
- int16* output_data, const Dims<4>& output_dims) {
- const int flat_size = MatchingFlatSize(output_dims, input_dims);
+inline void Logistic(const LogisticParams& params,
+ const RuntimeShape& input_shape, const int16* input_data,
+ const RuntimeShape& output_shape, int16* output_data) {
+ const int flat_size = MatchingFlatSize(input_shape, output_shape);
for (int i = 0; i < flat_size; i++) {
// F0 uses 0 integer bits, range [-1, 1].
@@ -2792,9 +2786,9 @@ inline void Logistic(const int16* input_data, const Dims<4>& input_dims,
}
}
-inline void Tanh(const float* input_data, const Dims<4>& input_dims,
- float* output_data, const Dims<4>& output_dims) {
- const int flat_size = MatchingFlatSize(output_dims, input_dims);
+inline void Tanh(const RuntimeShape& input_shape, const float* input_data,
+ const RuntimeShape& output_shape, float* output_data) {
+ const int flat_size = MatchingFlatSize(input_shape, output_shape);
for (int i = 0; i < flat_size; i++) {
float val = input_data[i];
@@ -2803,12 +2797,24 @@ inline void Tanh(const float* input_data, const Dims<4>& input_dims,
}
}
-inline void Tanh(const uint8* input_data, const Dims<4>& input_dims,
- int32 input_zero_point, int32 input_range_radius,
- int32 input_multiplier, int input_left_shift,
- uint8* output_data, const Dims<4>& output_dims) {
+// Convenience version that allows, for example, generated-code calls to be
+// uniform between data types.
+inline void Tanh(const TanhParams&, const RuntimeShape& input_shape,
+ const float* input_data, const RuntimeShape& output_shape,
+ float* output_data) {
+ // Drop params: not needed.
+ Tanh(input_shape, input_data, output_shape, output_data);
+}
+
+inline void Tanh(const TanhParams& params, const RuntimeShape& input_shape,
+ const uint8* input_data, const RuntimeShape& output_shape,
+ uint8* output_data) {
+ const int32 input_zero_point = params.input_zero_point;
+ const int32 input_range_radius = params.input_range_radius;
+ const int32 input_multiplier = params.input_multiplier;
+ const int input_left_shift = params.input_left_shift;
const int32 output_zero_point = 128;
- const int flat_size = MatchingFlatSize(output_dims, input_dims);
+ const int flat_size = MatchingFlatSize(input_shape, output_shape);
for (int i = 0; i < flat_size; i++) {
const uint8 input_val_u8 = input_data[i];
@@ -2843,15 +2849,16 @@ inline void Tanh(const uint8* input_data, const Dims<4>& input_dims,
}
}
-inline void Tanh(const int16* input_data, const Dims<4>& input_dims,
- int input_left_shift, int16* output_data,
- const Dims<4>& output_dims) {
+inline void Tanh(const TanhParams& params, const RuntimeShape& input_shape,
+ const int16* input_data, const RuntimeShape& output_shape,
+ int16* output_data) {
+ const int input_left_shift = params.input_left_shift;
// Support for shifts is limited until we have a parameterized version of
// SaturatingRoundingMultiplyByPOT().
TFLITE_DCHECK_GE(input_left_shift, 0);
TFLITE_DCHECK_LE(input_left_shift, 1);
- const int flat_size = MatchingFlatSize(output_dims, input_dims);
+ const int flat_size = MatchingFlatSize(input_shape, output_shape);
// F0 uses 0 integer bits, range [-1, 1].
// This is the return type of math functions such as tanh, logistic,
@@ -2876,10 +2883,12 @@ inline void Tanh(const int16* input_data, const Dims<4>& input_dims,
}
}
-inline void Dequantize(const uint8* input_data, const Dims<4>& input_dims,
- int32 zero_point, double scale, float* output_data,
- const Dims<4>& output_dims) {
- const int flat_size = MatchingFlatSize(output_dims, input_dims);
+inline void Dequantize(const tflite::DequantizationParams& op_params,
+ const RuntimeShape& input_shape, const uint8* input_data,
+ const RuntimeShape& output_shape, float* output_data) {
+ int32 zero_point = op_params.zero_point;
+ double scale = op_params.scale;
+ const int flat_size = MatchingFlatSize(input_shape, output_shape);
for (int i = 0; i < flat_size; i++) {
int32 val = input_data[i];
@@ -2888,9 +2897,12 @@ inline void Dequantize(const uint8* input_data, const Dims<4>& input_dims,
}
}
-inline void FakeQuant(const float* input_data, const Dims<4>& input_dims,
- float rmin, float rmax, int num_bits, float* output_data,
- const Dims<4>& output_dims) {
+inline void FakeQuant(const tflite::FakeQuantParams& op_params,
+ const RuntimeShape& input_shape, const float* input_data,
+ const RuntimeShape& output_shape, float* output_data) {
+ float rmin = op_params.minmax.min;
+ float rmax = op_params.minmax.max;
+ int num_bits = op_params.num_bits;
// 0 should always be a representable value. Let's assume that the initial
// min,max range contains 0.
TFLITE_DCHECK_LE(rmin, 0.0f);
@@ -2903,24 +2915,15 @@ inline void FakeQuant(const float* input_data, const Dims<4>& input_dims,
float nudged_min, nudged_max, nudged_scale;
NudgeQuantizationRange(rmin, rmax, quant_min, quant_max, &nudged_min,
&nudged_max, &nudged_scale);
- const float inv_nudged_scale = 1.0f / nudged_scale;
-
- const int flat_size = MatchingFlatSize(output_dims, input_dims);
- for (int i = 0; i < flat_size; i++) {
- const float src_val = input_data[i];
- const float clamped = std::min(nudged_max, std::max(nudged_min, src_val));
- const float clamped_shifted = clamped - nudged_min;
- const float dst_val =
- TfLiteRound(clamped_shifted * inv_nudged_scale) * nudged_scale +
- nudged_min;
- output_data[i] = dst_val;
- }
+ const int flat_size = MatchingFlatSize(input_shape, output_shape);
+ FakeQuantizeArray(nudged_scale, nudged_min, nudged_max, input_data,
+ output_data, flat_size);
}
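// Illustrative sketch (not part of this patch): the requested range must
// bracket zero (rmin <= 0 <= rmax), per the checks above. A typical 8-bit
// setup, with a hypothetical helper name and only the fields read above:
inline tflite::FakeQuantParams MakeExampleFakeQuantParams() {
  tflite::FakeQuantParams op_params;
  op_params.minmax.min = -1.0f;
  op_params.minmax.max = 1.0f;
  op_params.num_bits = 8;
  return op_params;
}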
template <typename SrcT, typename DstT>
-inline void Cast(const SrcT* input_data, const Dims<4>& input_dims,
- DstT* output_data, const Dims<4>& output_dims) {
- const int flat_size = MatchingFlatSize(output_dims, input_dims);
+inline void Cast(const RuntimeShape& input_shape, const SrcT* input_data,
+ const RuntimeShape& output_shape, DstT* output_data) {
+ const int flat_size = MatchingFlatSize(input_shape, output_shape);
for (int i = 0; i < flat_size; i++) {
int offset = i;
@@ -2928,9 +2931,9 @@ inline void Cast(const SrcT* input_data, const Dims<4>& input_dims,
}
}
-inline void Floor(const float* input_data, const Dims<4>& input_dims,
- float* output_data, const Dims<4>& output_dims) {
- const int flat_size = MatchingFlatSize(output_dims, input_dims);
+inline void Floor(const RuntimeShape& input_shape, const float* input_data,
+ const RuntimeShape& output_shape, float* output_data) {
+ const int flat_size = MatchingFlatSize(input_shape, output_shape);
for (int i = 0; i < flat_size; i++) {
int offset = i;
@@ -2939,44 +2942,79 @@ inline void Floor(const float* input_data, const Dims<4>& input_dims,
}
template <typename T>
-inline void Gather(const T* input_data, const Dims<4>& input_dims,
- int input_rank, const int32* coords_data,
- const Dims<4>& coords_dims, T* output_data,
- const Dims<4>& output_dims) {
- TFLITE_DCHECK(coords_dims.sizes[0] == output_dims.sizes[input_rank - 1]);
- int stride = input_dims.strides[input_rank - 1];
+inline void Gather(const tflite::GatherParams& op_params,
+ const RuntimeShape& unextended_input_shape,
+ const T* input_data, const RuntimeShape& coords_shape,
+ const int32* coords_data,
+ const RuntimeShape& unextended_output_shape,
+ T* output_data) {
+ TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+ const RuntimeShape input_shape =
+ RuntimeShape::ExtendedShape(4, unextended_input_shape);
+ const RuntimeShape output_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_shape);
+
+ const int input_rank = op_params.input_rank;
+ const int gather_dimensions = output_shape.DimensionsCount();
+ TFLITE_DCHECK_GE(input_shape.DimensionsCount(), gather_dimensions);
+ const int axis = gather_dimensions - input_rank;
+ TFLITE_DCHECK_LT(axis, gather_dimensions);
+ TFLITE_DCHECK_GE(axis, 0);
+ const int coords_count = coords_shape.FlatSize();
+ TFLITE_DCHECK_EQ(coords_count, output_shape.Dims(axis));
+
+ int64_t stride = 1;
+ for (int i = axis + 1; i < gather_dimensions; ++i) {
+ stride *= input_shape.Dims(i);
+ }
T* out = output_data;
- for (int i = 0; i < coords_dims.sizes[0]; i++) {
+ for (int i = 0; i < coords_count; ++i) {
TFLITE_DCHECK_GE(coords_data[i], 0);
- TFLITE_DCHECK_LT(coords_data[i], input_dims.sizes[input_rank - 1]);
+ TFLITE_DCHECK_LT(coords_data[i], input_shape.Dims(axis));
const T* in = input_data + coords_data[i] * stride;
memcpy(out, in, sizeof(T) * stride);
out += stride;
}
}
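// Illustrative note (not part of this patch): how axis and stride fall out for
// a rank-2 gather. With the input extended to shape {1, 1, 5, 8}, a 4D output
// and op_params.input_rank = 2, axis = 4 - 2 = 2 and stride = 8, so every
// entry of coords_data copies one contiguous row of 8 elements. Only the field
// read above is populated; the helper name is hypothetical.
inline tflite::GatherParams MakeRank2GatherParams() {
  tflite::GatherParams op_params;
  op_params.input_rank = 2;
  return op_params;
}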
-inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims,
+template <typename T>
+inline void ResizeBilinear(const tflite::ResizeBilinearParams& op_params,
+ const RuntimeShape& unextended_input_shape,
+ const T* input_data,
+ const RuntimeShape& unextended_output_size_shape,
const int32* output_size_data,
- const Dims<4>& output_size_dims, float* output_data,
- const Dims<4>& output_dims, bool align_corners) {
- int32 batches = MatchingArraySize(input_dims, 3, output_dims, 3);
- int32 input_height = ArraySize(input_dims, 2);
- int32 input_width = ArraySize(input_dims, 1);
- int32 depth = MatchingArraySize(input_dims, 0, output_dims, 0);
-
- TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 3), 1);
- TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 2), 1);
- TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 1), 1);
- TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 0), 2);
- int32 output_height = output_size_data[Offset(output_size_dims, 0, 0, 0, 0)];
- int32 output_width = output_size_data[Offset(output_size_dims, 1, 0, 0, 0)];
+ const RuntimeShape& unextended_output_shape,
+ T* output_data) {
+ TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_size_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+ const RuntimeShape input_shape =
+ RuntimeShape::ExtendedShape(4, unextended_input_shape);
+ const RuntimeShape output_size_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_size_shape);
+ const RuntimeShape output_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_shape);
+
+ int32 batches = MatchingDim(input_shape, 0, output_shape, 0);
+ int32 input_height = input_shape.Dims(1);
+ int32 input_width = input_shape.Dims(2);
+ int32 depth = MatchingDim(input_shape, 3, output_shape, 3);
+
+ TFLITE_DCHECK_EQ(output_size_shape.Dims(0), 1);
+ TFLITE_DCHECK_EQ(output_size_shape.Dims(1), 1);
+ TFLITE_DCHECK_EQ(output_size_shape.Dims(2), 1);
+ TFLITE_DCHECK_EQ(output_size_shape.Dims(3), 2);
+ int32 output_height = output_size_data[Offset(output_size_shape, 0, 0, 0, 0)];
+ int32 output_width = output_size_data[Offset(output_size_shape, 0, 0, 0, 1)];
+
float height_scale = static_cast<float>(input_height) / output_height;
float width_scale = static_cast<float>(input_width) / output_width;
- if (align_corners && output_height > 1) {
+ if (op_params.align_corners && output_height > 1) {
height_scale = static_cast<float>(input_height - 1) / (output_height - 1);
}
- if (align_corners && output_width > 1) {
+ if (op_params.align_corners && output_width > 1) {
width_scale = static_cast<float>(input_width - 1) / (output_width - 1);
}
@@ -2990,70 +3028,73 @@ inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims,
int32 x0 = static_cast<int32>(std::floor(input_x));
int32 x1 = std::min(x0 + 1, input_width - 1);
for (int c = 0; c < depth; ++c) {
- float interpolation = input_data[Offset(input_dims, c, x0, y0, b)] *
- (1 - (input_y - y0)) *
- (1 - (input_x - x0)) +
- input_data[Offset(input_dims, c, x0, y1, b)] *
- (input_y - y0) * (1 - (input_x - x0)) +
- input_data[Offset(input_dims, c, x1, y0, b)] *
- (1 - (input_y - y0)) * (input_x - x0) +
- input_data[Offset(input_dims, c, x1, y1, b)] *
- (input_y - y0) * (input_x - x0);
- output_data[Offset(output_dims, c, x, y, b)] = interpolation;
+ T interpolation =
+ static_cast<T>(input_data[Offset(input_shape, b, y0, x0, c)] *
+ (1 - (input_y - y0)) * (1 - (input_x - x0)) +
+ input_data[Offset(input_shape, b, y1, x0, c)] *
+ (input_y - y0) * (1 - (input_x - x0)) +
+ input_data[Offset(input_shape, b, y0, x1, c)] *
+ (1 - (input_y - y0)) * (input_x - x0) +
+ input_data[Offset(input_shape, b, y1, x1, c)] *
+ (input_y - y0) * (input_x - x0));
+ output_data[Offset(output_shape, b, y, x, c)] = interpolation;
}
}
}
}
}
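// Illustrative note (not part of this patch): with op_params.align_corners =
// false, resizing a height of 4 up to 8 gives height_scale = 4 / 8 = 0.5, so
// output row y = 3 samples input_y = 1.5 (y0 = 1, y1 = 2) and blends the two
// source rows with weights 0.5 / 0.5; the x axis works the same way. Only the
// field read above is populated; the helper name is hypothetical.
inline tflite::ResizeBilinearParams MakeExampleResizeParams() {
  tflite::ResizeBilinearParams op_params;
  op_params.align_corners = false;
  return op_params;
}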
-// legacy, for compatibility with old checked-in code
-inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims,
- const int32* output_size_data,
- const Dims<4>& output_size_dims, float* output_data,
- const Dims<4>& output_dims) {
- ResizeBilinear(input_data, input_dims, output_size_data, output_size_dims,
- output_data, output_dims, /*align_corners=*/false);
-}
-
template <typename T>
-inline void SpaceToBatchND(const T* input_data, const Dims<4>& input_dims,
- const int32* block_shape_data,
- const Dims<4>& block_shape_dims,
- const int32* paddings_data,
- const Dims<4>& paddings_dims, T* output_data,
- const Dims<4>& output_dims) {
- const int output_batch_size = ArraySize(output_dims, 3);
- const int output_height = ArraySize(output_dims, 2);
- const int output_width = ArraySize(output_dims, 1);
- const int input_batch_size = ArraySize(input_dims, 3);
- const int input_height = ArraySize(input_dims, 2);
- const int input_width = ArraySize(input_dims, 1);
- const int depth = ArraySize(input_dims, 0);
+inline void SpaceToBatchND(
+ const SpaceToBatchParams& params,
+ const RuntimeShape& unextended_input1_shape, const T* input1_data,
+ const RuntimeShape& unextended_input2_shape, const int32* block_shape_data,
+ const RuntimeShape& unextended_input3_shape, const int32* paddings_data,
+ const RuntimeShape& unextended_output_shape, T* output_data) {
+ TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+ const RuntimeShape input1_shape =
+ RuntimeShape::ExtendedShape(4, unextended_input1_shape);
+ const RuntimeShape output_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_shape);
+
+ const int depth = input1_shape.Dims(3);
+ const int input_width = input1_shape.Dims(2);
+ const int input_height = input1_shape.Dims(1);
+ const int input_batch_size = input1_shape.Dims(0);
+
+ const int output_width = output_shape.Dims(2);
+ const int output_height = output_shape.Dims(1);
+ const int output_batch_size = output_shape.Dims(0);
+
const int block_shape_height = block_shape_data[0];
const int block_shape_width = block_shape_data[1];
const int padding_top = paddings_data[0];
const int padding_left = paddings_data[2];
+ // For uint8 quantized, the correct padding "zero value" is the output offset.
+ const int32_t pad_value = params.output_offset;
+
for (int out_b = 0; out_b < output_batch_size; ++out_b) {
int input_batch = out_b % input_batch_size;
int shift_w = (out_b / input_batch_size) % block_shape_width;
int shift_h = (out_b / input_batch_size) / block_shape_width;
for (int out_h = 0; out_h < output_height; ++out_h) {
for (int out_w = 0; out_w < output_width; ++out_w) {
- T* out = output_data + Offset(output_dims, 0, out_w, out_h, out_b);
+ T* out = output_data + Offset(output_shape, out_b, out_h, out_w, 0);
if (out_h * block_shape_height + shift_h < padding_top ||
out_h * block_shape_height + shift_h >=
padding_top + input_height ||
out_w * block_shape_width + shift_w < padding_left ||
out_w * block_shape_width + shift_w >= padding_left + input_width) {
- memset(out, 0, depth * sizeof(T));
+ // This may not execute correctly when pad_value != 0 and T != uint8.
+ memset(out, pad_value, depth * sizeof(T));
} else {
const T* in =
- input_data +
- Offset(input_dims, 0,
- (out_w * block_shape_width + shift_w) - padding_left,
+ input1_data +
+ Offset(input1_shape, input_batch,
(out_h * block_shape_height + shift_h) - padding_top,
- input_batch);
+ (out_w * block_shape_width + shift_w) - padding_left, 0);
memcpy(out, in, depth * sizeof(T));
}
}
@@ -3062,18 +3103,27 @@ inline void SpaceToBatchND(const T* input_data, const Dims<4>& input_dims,
}
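// Illustrative sketch (not part of this patch): for quantized uint8 tensors the
// padded region should hold the output zero point (the value that dequantizes
// to 0), which is why the memset above uses params.output_offset; float
// callers would typically leave the offset at 0. The helper name is
// hypothetical.
inline SpaceToBatchParams MakeQuantizedSpaceToBatchParams(
    int32 output_zero_point) {
  SpaceToBatchParams params;
  params.output_offset = output_zero_point;
  return params;
}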
template <typename T>
-inline void BatchToSpaceND(const T* input_data, const Dims<4>& input_dims,
- const int32* block_shape_data,
- const Dims<4>& block_shape_dims,
- const int32* crops_data, const Dims<4>& crops_dims,
- T* output_data, const Dims<4>& output_dims) {
- const int output_batch_size = ArraySize(output_dims, 3);
- const int output_height = ArraySize(output_dims, 2);
- const int output_width = ArraySize(output_dims, 1);
- const int input_batch_size = ArraySize(input_dims, 3);
- const int input_height = ArraySize(input_dims, 2);
- const int input_width = ArraySize(input_dims, 1);
- const int depth = ArraySize(input_dims, 0);
+inline void BatchToSpaceND(
+ const RuntimeShape& unextended_input1_shape, const T* input1_data,
+ const RuntimeShape& unextended_input2_shape, const int32* block_shape_data,
+ const RuntimeShape& unextended_input3_shape, const int32* crops_data,
+ const RuntimeShape& unextended_output_shape, T* output_data) {
+ TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+ const RuntimeShape input1_shape =
+ RuntimeShape::ExtendedShape(4, unextended_input1_shape);
+ const RuntimeShape output_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_shape);
+
+ const int output_width = output_shape.Dims(2);
+ const int output_height = output_shape.Dims(1);
+ const int output_batch_size = output_shape.Dims(0);
+
+ const int depth = input1_shape.Dims(3);
+ const int input_width = input1_shape.Dims(2);
+ const int input_height = input1_shape.Dims(1);
+ const int input_batch_size = input1_shape.Dims(0);
+
const int block_shape_width = block_shape_data[1];
const int block_shape_height = block_shape_data[0];
const int crops_top = crops_data[0];
@@ -3095,36 +3145,61 @@ inline void BatchToSpaceND(const T* input_data, const Dims<4>& input_dims,
if (out_w < 0 || out_w >= output_width) {
continue;
}
- T* out = output_data + Offset(output_dims, 0, out_w, out_h, out_batch);
- const T* in = input_data + Offset(input_dims, 0, in_w, in_h, in_batch);
+ T* out = output_data + Offset(output_shape, out_batch, out_h, out_w, 0);
+ const T* in =
+ input1_data + Offset(input1_shape, in_batch, in_h, in_w, 0);
memcpy(out, in, depth * sizeof(T));
}
}
}
}
-template <typename T>
-inline void PadV2(const T* input_data, const Dims<4>& input_dims,
- const std::vector<int>& left_paddings,
- const std::vector<int>& right_paddings, T* output_data,
- const Dims<4>& output_dims, const T pad_value) {
- TFLITE_DCHECK_EQ(left_paddings.size(), 4);
- TFLITE_DCHECK_EQ(right_paddings.size(), 4);
-
- const int output_batch = ArraySize(output_dims, 3);
- const int output_height = ArraySize(output_dims, 2);
- const int output_width = ArraySize(output_dims, 1);
- const int output_depth = ArraySize(output_dims, 0);
-
- const int left_b_padding = left_paddings[3];
- const int left_h_padding = left_paddings[2];
- const int left_w_padding = left_paddings[1];
- const int left_d_padding = left_paddings[0];
-
- const int right_b_padding = right_paddings[3];
- const int right_h_padding = right_paddings[2];
- const int right_w_padding = right_paddings[1];
- const int right_d_padding = right_paddings[0];
+// There are two versions of pad: Pad and PadV2. In PadV2 there is a second
+// scalar input that provides the padding value. Therefore pad_value_ptr can be
+// equivalent to a simple input1_data. For Pad, it should point to a zero
+// value.
+//
+// Note that two typenames are required, so that T=P=int32 is considered a
+// specialization distinct from P=int32.
+template <typename T, typename P>
+inline void PadImpl(const tflite::PadParams& op_params,
+ const RuntimeShape& input_shape, const T* input_data,
+ const P* pad_value_ptr, const RuntimeShape& output_shape,
+ T* output_data) {
+ const RuntimeShape ext_input_shape =
+ RuntimeShape::ExtendedShape(4, input_shape);
+ const RuntimeShape ext_output_shape =
+ RuntimeShape::ExtendedShape(4, output_shape);
+ TFLITE_DCHECK_LE(op_params.left_padding_count, 4);
+ TFLITE_DCHECK_LE(op_params.right_padding_count, 4);
+
+ // Runtime calls are currently fixed at 4 dimensions. Copy inputs so
+ // we can pad them to 4 dims (yes, we are "padding the padding").
+ std::vector<int> left_padding_copy(4, 0);
+ for (int i = 0; i < op_params.left_padding_count; ++i) {
+ left_padding_copy[i] = op_params.left_padding[i];
+ }
+ std::vector<int> right_padding_copy(4, 0);
+ for (int i = 0; i < op_params.right_padding_count; ++i) {
+ right_padding_copy[i] = op_params.right_padding[i];
+ }
+
+ const int output_batch = ext_output_shape.Dims(0);
+ const int output_height = ext_output_shape.Dims(1);
+ const int output_width = ext_output_shape.Dims(2);
+ const int output_depth = ext_output_shape.Dims(3);
+
+ const int left_b_padding = left_padding_copy[0];
+ const int left_h_padding = left_padding_copy[1];
+ const int left_w_padding = left_padding_copy[2];
+ const int left_d_padding = left_padding_copy[3];
+
+ const int right_b_padding = right_padding_copy[0];
+ const int right_h_padding = right_padding_copy[1];
+ const int right_w_padding = right_padding_copy[2];
+ const int right_d_padding = right_padding_copy[3];
+
+ const T pad_value = *pad_value_ptr;
const T* in_ptr = input_data;
T* out_ptr = output_data;
@@ -3150,69 +3225,83 @@ inline void PadV2(const T* input_data, const Dims<4>& input_dims,
}
}
-// Legacy Pad() method that casts an int32_t to T before padding.
-template <typename T>
-inline void Pad(const T* input_data, const Dims<4>& input_dims,
- const std::vector<int>& left_paddings,
- const std::vector<int>& right_paddings, T* output_data,
- const Dims<4>& output_dims, const int32_t pad_value) {
- const T converted_pad_value = static_cast<T>(pad_value);
- PadV2<T>(input_data, input_dims, left_paddings, right_paddings, output_data,
- output_dims, converted_pad_value);
+template <typename T, typename P>
+inline void Pad(const tflite::PadParams& op_params,
+ const RuntimeShape& input_shape, const T* input_data,
+ const P* pad_value_ptr, const RuntimeShape& output_shape,
+ T* output_data) {
+ PadImpl(op_params, input_shape, input_data, pad_value_ptr, output_shape,
+ output_data);
}
+// The second (pad-value) input can be int32 when, say, the first is uint8.
template <typename T>
-inline void Pad(const T* input_data, const Dims<4>& input_dims,
- const std::vector<int>& left_paddings,
- const std::vector<int>& right_paddings, T* output_data,
- const Dims<4>& output_dims) {
- Pad(input_data, input_dims, left_paddings, right_paddings, output_data,
- output_dims, 0);
+inline void Pad(const tflite::PadParams& op_params,
+ const RuntimeShape& input_shape, const T* input_data,
+ const int32* pad_value_ptr, const RuntimeShape& output_shape,
+ T* output_data) {
+ const T converted_pad_value = static_cast<T>(*pad_value_ptr);
+ PadImpl(op_params, input_shape, input_data, &converted_pad_value,
+ output_shape, output_data);
+}
+
+// This version avoids conflicting template matching.
+template <>
+inline void Pad(const tflite::PadParams& op_params,
+ const RuntimeShape& input_shape, const int32* input_data,
+ const int32* pad_value_ptr, const RuntimeShape& output_shape,
+ int32* output_data) {
+ PadImpl(op_params, input_shape, input_data, pad_value_ptr, output_shape,
+ output_data);
}
template <typename T>
-inline void StridedSlice(const T* input_data, const Dims<4>& input_dims,
- int begin_mask, int end_mask,
- const std::vector<int>& start_indices,
- const std::vector<int>& stop_indices,
- const std::vector<int>& strides, T* output_data,
- const Dims<4>& output_dims) {
- // Note that the axis orders are reversed for runtime ops, so the indices,
- // strides and masks must be as well too.
- TFLITE_DCHECK_EQ(start_indices.size(), 4);
- TFLITE_DCHECK_EQ(stop_indices.size(), 4);
- TFLITE_DCHECK_EQ(strides.size(), 4);
- const int start_b = strided_slice::StartForAxis(begin_mask, start_indices,
- strides, input_dims.sizes, 3);
- const int stop_b = strided_slice::StopForAxis(end_mask, stop_indices, strides,
- input_dims.sizes, 3);
- const int start_h = strided_slice::StartForAxis(begin_mask, start_indices,
- strides, input_dims.sizes, 2);
- const int stop_h = strided_slice::StopForAxis(end_mask, stop_indices, strides,
- input_dims.sizes, 2);
- const int start_w = strided_slice::StartForAxis(begin_mask, start_indices,
- strides, input_dims.sizes, 1);
- const int stop_w = strided_slice::StopForAxis(end_mask, stop_indices, strides,
- input_dims.sizes, 1);
- const int start_d = strided_slice::StartForAxis(begin_mask, start_indices,
- strides, input_dims.sizes, 0);
- const int stop_d = strided_slice::StopForAxis(end_mask, stop_indices, strides,
- input_dims.sizes, 0);
+inline void StridedSlice(const tflite::StridedSliceParams& op_params,
+ const RuntimeShape& unextended_input_shape,
+ const T* input_data,
+ const RuntimeShape& unextended_output_shape,
+ T* output_data) {
+ // Note that the output_shape is not used herein.
+ tflite::StridedSliceParams params_copy = op_params;
+
+ TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+ const RuntimeShape input_shape =
+ RuntimeShape::ExtendedShape(4, unextended_input_shape);
+ const RuntimeShape output_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_shape);
+
+ // Reverse and pad to 4 dimensions because that is what the runtime code
+  // requires (i.e. all shapes must be 4D and are given backwards).
+ strided_slice::StridedSlicePadIndices(&params_copy, 4);
+
+ const int start_b = strided_slice::StartForAxis(params_copy, input_shape, 0);
+ const int stop_b =
+ strided_slice::StopForAxis(params_copy, input_shape, 0, start_b);
+ const int start_h = strided_slice::StartForAxis(params_copy, input_shape, 1);
+ const int stop_h =
+ strided_slice::StopForAxis(params_copy, input_shape, 1, start_h);
+ const int start_w = strided_slice::StartForAxis(params_copy, input_shape, 2);
+ const int stop_w =
+ strided_slice::StopForAxis(params_copy, input_shape, 2, start_w);
+ const int start_d = strided_slice::StartForAxis(params_copy, input_shape, 3);
+ const int stop_d =
+ strided_slice::StopForAxis(params_copy, input_shape, 3, start_d);
T* out_ptr = output_data;
for (int in_b = start_b;
- !strided_slice::LoopCondition(in_b, stop_b, strides[3]);
- in_b += strides[3]) {
+ !strided_slice::LoopCondition(in_b, stop_b, params_copy.strides[0]);
+ in_b += params_copy.strides[0]) {
for (int in_h = start_h;
- !strided_slice::LoopCondition(in_h, stop_h, strides[2]);
- in_h += strides[2]) {
+ !strided_slice::LoopCondition(in_h, stop_h, params_copy.strides[1]);
+ in_h += params_copy.strides[1]) {
for (int in_w = start_w;
- !strided_slice::LoopCondition(in_w, stop_w, strides[1]);
- in_w += strides[1]) {
- for (int in_d = start_d;
- !strided_slice::LoopCondition(in_d, stop_d, strides[0]);
- in_d += strides[0]) {
- *out_ptr++ = input_data[Offset(input_dims, in_d, in_w, in_h, in_b)];
+ !strided_slice::LoopCondition(in_w, stop_w, params_copy.strides[2]);
+ in_w += params_copy.strides[2]) {
+ for (int in_d = start_d; !strided_slice::LoopCondition(
+ in_d, stop_d, params_copy.strides[3]);
+ in_d += params_copy.strides[3]) {
+ *out_ptr++ = input_data[Offset(input_shape, in_b, in_h, in_w, in_d)];
}
}
}
@@ -3220,31 +3309,39 @@ inline void StridedSlice(const T* input_data, const Dims<4>& input_dims,
}
template <typename T>
-inline void Slice(const T* input_data, const Dims<4>& input_dims,
- const std::vector<int>& begin, const std::vector<int>& size,
- T* output_data, const Dims<4>& output_dims) {
- // TODO(dkalenichenko): This op only supports 4D tensors.
- TFLITE_DCHECK_EQ(begin.size(), 4);
- TFLITE_DCHECK_EQ(size.size(), 4);
- const int start_b = begin[3];
- const int stop_b =
- size[3] == -1 ? input_dims.sizes[3] - start_b : start_b + size[3];
- const int start_h = begin[2];
- const int stop_h =
- size[2] == -1 ? input_dims.sizes[2] - start_h : start_h + size[2];
- const int start_w = begin[1];
- const int stop_w =
- size[1] == -1 ? input_dims.sizes[1] - start_w : start_w + size[1];
- const int start_d = begin[0];
- const int stop_d =
- size[0] == -1 ? input_dims.sizes[0] - start_d : start_d + size[0];
+inline void Slice(const tflite::SliceParams& op_params,
+ const RuntimeShape& input_shape, const T* input_data,
+ const RuntimeShape& output_shape, T* output_data) {
+ const RuntimeShape ext_shape = RuntimeShape::ExtendedShape(4, input_shape);
+ // TODO(dkalenichenko): This op only supports 4D tensors or smaller.
+ TFLITE_DCHECK_LE(op_params.begin_count, 4);
+ TFLITE_DCHECK_LE(op_params.size_count, 4);
+ const int begin_count = op_params.begin_count;
+ const int size_count = op_params.size_count;
+ // We front-pad the begin and size vectors.
+ const int start_b = 4 - begin_count > 0 ? 0 : op_params.begin[0];
+ const int stop_b = (4 - size_count > 0 || op_params.size[0] == -1)
+ ? ext_shape.Dims(0) - start_b
+ : start_b + op_params.size[0];
+ const int start_h = begin_count < 3 ? 0 : op_params.begin[begin_count - 3];
+ const int stop_h = (size_count < 3 || op_params.size[size_count - 3] == -1)
+ ? ext_shape.Dims(1) - start_h
+ : start_h + op_params.size[size_count - 3];
+ const int start_w = begin_count < 2 ? 0 : op_params.begin[begin_count - 2];
+ const int stop_w = (size_count < 2 || op_params.size[size_count - 2] == -1)
+ ? ext_shape.Dims(2) - start_w
+ : start_w + op_params.size[size_count - 2];
+ const int start_d = begin_count < 1 ? 0 : op_params.begin[begin_count - 1];
+ const int stop_d = (size_count < 1 || op_params.size[size_count - 1] == -1)
+ ? ext_shape.Dims(3) - start_d
+ : start_d + op_params.size[size_count - 1];
T* out_ptr = output_data;
for (int in_b = start_b; in_b < stop_b; ++in_b) {
for (int in_h = start_h; in_h < stop_h; ++in_h) {
for (int in_w = start_w; in_w < stop_w; ++in_w) {
for (int in_d = start_d; in_d < stop_d; ++in_d) {
- *out_ptr++ = input_data[Offset(input_dims, in_d, in_w, in_h, in_b)];
+ *out_ptr++ = input_data[Offset(ext_shape, in_b, in_h, in_w, in_d)];
}
}
}
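A minimal usage sketch of the front-padding behaviour, assuming the declarations above are in scope; the shapes and values are illustrative only:

// Hypothetical example: take 4 elements starting at index 2 of a 1-D tensor.
// Internally this behaves as if the input had been padded to [1, 1, 1, 8].
tflite::SliceParams params;
params.begin_count = 1;
params.begin[0] = 2;
params.size_count = 1;
params.size[0] = 4;
const float input[8] = {0, 1, 2, 3, 4, 5, 6, 7};
float output[4];
Slice(params, RuntimeShape({8}), input, RuntimeShape({4}), output);
// output now holds {2, 3, 4, 5}.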
@@ -3259,63 +3356,170 @@ inline void Exp(const T* input_data, const size_t num_elements,
}
}
+// A generic reduce method that can be used for reduce_sum, reduce_mean, etc.
+// This method iterates through input data and reduces elements along the
+// dimensions given in axis.
+template <typename In, typename Out>
+inline bool Reduce(const In* input_data, const int* input_dims,
+ const int* output_dims, const int input_num_dims,
+ const int output_num_dims, const int* axis,
+ const int num_axis, int* input_iter,
+ Out reducer(const Out current, const In in),
+ Out* output_data) {
+ // Reset input iterator.
+ for (int idx = 0; idx < input_num_dims; ++idx) {
+ input_iter[idx] = 0;
+ }
+ // Iterate through input_data.
+ do {
+ size_t input_offset =
+ ReducedOutputOffset(input_num_dims, input_dims, input_iter, 0, nullptr);
+ size_t output_offset = ReducedOutputOffset(input_num_dims, input_dims,
+ input_iter, num_axis, axis);
+ output_data[output_offset] =
+ reducer(output_data[output_offset], input_data[input_offset]);
+ } while (NextIndex(input_num_dims, input_dims, input_iter));
+ return true;
+}
+
+inline bool ResolveAxis(const int num_dims, const int* axis,
+ const int64_t num_axis, int* out_axis,
+ int* out_num_axis) {
+ *out_num_axis = 0; // Just in case.
+ // Short-circuit axis resolution for scalars; the axis will go unused.
+ if (num_dims == 0) {
+ return true;
+ }
+ // O(n^2) is fine since out_num_axis should be really small, mostly <= 4
+ for (int64_t idx = 0; idx < num_axis; ++idx) {
+ // Handle negative index.
+ int current = axis[idx] < 0 ? (axis[idx] + num_dims) : axis[idx];
+ TFLITE_DCHECK(current >= 0 && current < num_dims);
+ bool is_dup = false;
+ for (int j = 0; j < *out_num_axis; ++j) {
+ if (out_axis[j] == current) {
+ is_dup = true;
+ break;
+ }
+ }
+ if (!is_dup) {
+ out_axis[*out_num_axis] = current;
+ *out_num_axis += 1;
+ }
+ }
+ return true;
+}
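A small sketch of how negative and duplicate axes resolve, assuming the declaration above; the inputs are made up:

// Hypothetical example: with a 4-D input, {-1, 3} collapses to the single axis 3,
// because -1 wraps around to 3 and duplicates are dropped.
int axis[] = {-1, 3};
int resolved_axis[2];
int num_resolved_axis = 0;
ResolveAxis(/*num_dims=*/4, axis, /*num_axis=*/2, resolved_axis, &num_resolved_axis);
// num_resolved_axis == 1 and resolved_axis[0] == 3.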
+
+// This method expects that output_data has been initialized.
+template <typename In, typename Out>
+inline bool ReduceSumImpl(const In* input_data, const int* input_dims,
+ const int* output_dims, const int input_num_dims,
+ const int output_num_dims, const int* axis,
+ const int num_axis, int* input_iter,
+ Out* output_data) {
+ auto reducer = [](const Out current, const In in) -> Out {
+ const Out actual_in = static_cast<Out>(in);
+ return current + actual_in;
+ };
+ return Reduce<In, Out>(input_data, input_dims, output_dims, input_num_dims,
+ output_num_dims, axis, num_axis, input_iter, reducer,
+ output_data);
+}
+
+template <typename T>
+inline bool InitTensorDataForReduce(const int* dims, const int num_dims,
+ const T init_value, T* data) {
+ size_t num_elements = 1;
+ for (int idx = 0; idx < num_dims; ++idx) {
+ size_t current = static_cast<size_t>(dims[idx]);
+ // Overflow prevention.
+ if (num_elements > std::numeric_limits<size_t>::max() / current) {
+ return false;
+ }
+ num_elements *= current;
+ }
+ for (size_t idx = 0; idx < num_elements; ++idx) {
+ data[idx] = init_value;
+ }
+ return true;
+}
+
+// Computes the generic value (i.e., sum/max/min/prod) of elements across
+// dimensions given in axis. The caller must supply init_value and a reducer.
+template <typename T>
+inline bool ReduceGeneric(const T* input_data, const int* input_dims,
+ const int input_num_dims, T* output_data,
+ const int* output_dims, const int output_num_dims,
+ const int* axis, const int64_t num_axis_dimensions,
+ bool keep_dims, int* temp_index, int* resolved_axis,
+ T init_value,
+ T reducer(const T current, const T in)) {
+ // Reset output data.
+ if (!InitTensorDataForReduce(output_dims, output_num_dims, init_value,
+ output_data)) {
+ return false;
+ }
+
+ // Resolve axis.
+ int num_resolved_axis = 0;
+ if (!ResolveAxis(input_num_dims, axis, num_axis_dimensions, resolved_axis,
+ &num_resolved_axis)) {
+ return false;
+ }
+
+ return Reduce<T, T>(input_data, input_dims, output_dims, input_num_dims,
+ output_num_dims, resolved_axis, num_resolved_axis,
+ temp_index, reducer, output_data);
+}
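A minimal usage sketch of ReduceGeneric with a max reducer, assuming the declarations above; the tensor contents are made up:

// Hypothetical example: reduce-max of a 2x3 tensor along axis 1,
// i.e. output[i] = max over j of input[i][j].
const float input[] = {1, 5, 3, 9, 2, 4};
const int input_dims[] = {2, 3};
const int output_dims[] = {2};
const int axis[] = {1};
float output[2];
int temp_index[2];
int resolved_axis[1];
ReduceGeneric<float>(input, input_dims, /*input_num_dims=*/2, output,
                     output_dims, /*output_num_dims=*/1, axis,
                     /*num_axis_dimensions=*/1, /*keep_dims=*/false,
                     temp_index, resolved_axis,
                     std::numeric_limits<float>::lowest(),
                     [](const float current, const float in) -> float {
                       return in > current ? in : current;
                     });
// output is now {5, 9}.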
+
+// Computes the mean of elements across dimensions given in axis.
+// It does so in two stages: first it calculates the sum of elements along the
+// axis, then divides that sum by the number of elements in the axis.
template <typename T, typename U>
inline bool Mean(const T* input_data, const int* input_dims,
const int input_num_dims, T* output_data,
const int* output_dims, const int output_num_dims,
const int* axis, const int num_axis_dimensions, bool keep_dims,
int* temp_index, int* resolved_axis, U* temp_sum) {
- // resets output data.
+ // Reset output data.
size_t num_outputs = 1;
for (int idx = 0; idx < output_num_dims; ++idx) {
- num_outputs *= static_cast<size_t>(output_dims[idx]);
+ size_t current = static_cast<size_t>(output_dims[idx]);
+ // Overflow prevention.
+ if (num_outputs > std::numeric_limits<size_t>::max() / current) {
+ return false;
+ }
+ num_outputs *= current;
}
for (size_t idx = 0; idx < num_outputs; ++idx) {
output_data[idx] = T();
temp_sum[idx] = U();
}
- // resets temp index.
- for (int idx = 0; idx < input_num_dims; ++idx) {
- temp_index[idx] = 0;
- }
- // resolves axis.
+
+ // Resolve axis.
int num_resolved_axis = 0;
- for (int idx = 0; idx < num_axis_dimensions; ++idx) {
- int current = axis[idx];
- TFLITE_DCHECK(current < input_num_dims && current + input_num_dims >= 0);
- if (current < 0) {
- current += input_num_dims;
- }
- bool is_dup = false;
- for (int j = 0; j < num_resolved_axis; ++j) {
- if (resolved_axis[j] == current) {
- is_dup = true;
- break;
- }
- }
- if (!is_dup) {
- resolved_axis[num_resolved_axis++] = current;
- }
+ if (!ResolveAxis(input_num_dims, axis, num_axis_dimensions, resolved_axis,
+ &num_resolved_axis)) {
+ return false;
}
- // iterates through input_data.
- for (bool has_next = true; has_next;
- has_next = NextIndex(input_num_dims, input_dims, temp_index)) {
- size_t input_offset =
- ReducedOutputOffset(input_num_dims, input_dims, temp_index, 0, nullptr);
- size_t output_offset =
- ReducedOutputOffset(input_num_dims, input_dims, temp_index,
- num_resolved_axis, resolved_axis);
- temp_sum[output_offset] += static_cast<U>(input_data[input_offset]);
- }
- // takes average by num of elements added to get mean.
- size_t num_elements_in_axis = 1;
+
+ if (!ReduceSumImpl<T, U>(input_data, input_dims, output_dims, input_num_dims,
+ output_num_dims, resolved_axis, num_resolved_axis,
+ temp_index, temp_sum)) {
+ return false;
+ }
+
+ // Calculate the mean by dividing the accumulated sum by the number of aggregated elements.
+ U num_elements_in_axis = 1;
for (int idx = 0; idx < num_resolved_axis; ++idx) {
size_t current = static_cast<size_t>(input_dims[resolved_axis[idx]]);
+ // Overflow prevention.
if (current > (std::numeric_limits<U>::max() / num_elements_in_axis)) {
return false;
}
num_elements_in_axis *= current;
}
+
if (num_elements_in_axis > 0) {
for (size_t idx = 0; idx < num_outputs; ++idx) {
output_data[idx] =
@@ -3326,22 +3530,32 @@ inline bool Mean(const T* input_data, const int* input_dims,
}
template <typename T>
-inline void Mean(const T* input_data, const Dims<4>& input_dims,
- const std::vector<int>& reduction_indices, T* output_data,
- const Dims<4>& output_dims) {
- const int output_batch = ArraySize(output_dims, 3);
- const int output_height = ArraySize(output_dims, 2);
- const int output_width = ArraySize(output_dims, 1);
- const int output_depth = ArraySize(output_dims, 0);
-
- const int input_height = ArraySize(input_dims, 2);
- const int input_width = ArraySize(input_dims, 1);
+inline void Mean(const tflite::MeanParams& op_params,
+ const RuntimeShape& unextended_input_shape,
+ const T* input_data,
+ const RuntimeShape& unextended_output_shape, T* output_data) {
+ gemmlowp::ScopedProfilingLabel label("Mean");
+
+ TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+ const RuntimeShape input_shape =
+ RuntimeShape::ExtendedShape(4, unextended_input_shape);
+ const RuntimeShape output_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_shape);
+
+ const int output_batch = output_shape.Dims(0);
+ const int output_height = output_shape.Dims(1);
+ const int output_width = output_shape.Dims(2);
+ const int output_depth = output_shape.Dims(3);
+
+ const int input_height = input_shape.Dims(1);
+ const int input_width = input_shape.Dims(2);
// The current implementation only supports simultaneous reduction over
// width and height.
- TFLITE_DCHECK_EQ(reduction_indices.size(), 2);
- TFLITE_DCHECK((reduction_indices[0] == 1 && reduction_indices[1] == 2) ||
- (reduction_indices[0] == 2 && reduction_indices[1] == 1));
+ TFLITE_DCHECK_EQ(op_params.axis_count, 2);
+ TFLITE_DCHECK((op_params.axis[0] == 1 && op_params.axis[1] == 2) ||
+ (op_params.axis[0] == 2 && op_params.axis[1] == 1));
TFLITE_DCHECK_EQ(output_height, 1);
TFLITE_DCHECK_EQ(output_width, 1);
@@ -3350,52 +3564,97 @@ inline void Mean(const T* input_data, const Dims<4>& input_dims,
float value = 0;
for (int in_h = 0; in_h < input_height; ++in_h) {
for (int in_w = 0; in_w < input_width; ++in_w) {
- value += input_data[Offset(input_dims, out_d, in_w, in_h, out_b)];
+ value += input_data[Offset(input_shape, out_b, in_h, in_w, out_d)];
}
}
- output_data[Offset(output_dims, out_d, 0, 0, out_b)] =
+ output_data[Offset(output_shape, out_b, 0, 0, out_d)] =
value / (input_width * input_height);
}
}
}
-template <typename T>
-void Sub(const T* input1_data, const Dims<4>& input1_dims, const T* input2_data,
- const Dims<4>& input2_dims, T* output_data,
- const Dims<4>& output_dims) {
- NdArrayDesc<4> desc1;
- NdArrayDesc<4> desc2;
- NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
+// Computes the mean or sum of quantized elements across dimensions given in
+// axis. It does so in two stages: first it accumulates the sum of elements
+// along the axis, then requantizes that sum, dividing by the number of
+// elements when the mean is requested.
+template <typename T, typename U>
+inline bool QuantizedMeanOrSum(const T* input_data, int32 input_zero_point,
+ float input_scale, const int* input_dims,
+ const int input_num_dims, T* output_data,
+ int32 output_zero_point, float output_scale,
+ const int* output_dims,
+ const int output_num_dims, const int* axis,
+ const int num_axis_dimensions, bool keep_dims,
+ int* temp_index, int* resolved_axis, U* temp_sum,
+ bool compute_sum) {
+ // Reset output data.
+ size_t num_outputs = 1;
+ for (int idx = 0; idx < output_num_dims; ++idx) {
+ size_t current = static_cast<size_t>(output_dims[idx]);
+ // Overflow prevention.
+ if (num_outputs > std::numeric_limits<size_t>::max() / current) {
+ return false;
+ }
+ num_outputs *= current;
+ }
+ for (size_t idx = 0; idx < num_outputs; ++idx) {
+ output_data[idx] = T();
+ temp_sum[idx] = U();
+ }
- // In Tensorflow, the dimensions are canonically named (batch_number, row,
- // col, channel), with extents (batches, height, width, depth), with the
- // trailing dimension changing most rapidly (channels has the smallest stride,
- // typically 1 element).
- //
- // In generated C code, we store arrays with the dimensions reversed. The
- // first dimension has smallest stride.
- //
- // We name our variables by their Tensorflow convention, but generate C code
- // nesting loops such that the innermost loop has the smallest stride for the
- // best cache behavior.
- for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
- for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
- for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
- for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
- output_data[Offset(output_dims, c, x, y, b)] =
- input1_data[SubscriptToIndex(desc1, c, x, y, b)] -
- input2_data[SubscriptToIndex(desc2, c, x, y, b)];
- }
+ // Resolve axis.
+ int num_resolved_axis = 0;
+ if (!ResolveAxis(input_num_dims, axis, num_axis_dimensions, resolved_axis,
+ &num_resolved_axis)) {
+ return false;
+ }
+
+ if (!ReduceSumImpl<T, U>(input_data, input_dims, output_dims, input_num_dims,
+ output_num_dims, resolved_axis, num_resolved_axis,
+ temp_index, temp_sum)) {
+ return false;
+ }
+
+ // Count the aggregated elements, then requantize the sum (dividing by the count when computing the mean).
+ U num_elements_in_axis = 1;
+ for (int idx = 0; idx < num_resolved_axis; ++idx) {
+ size_t current = static_cast<size_t>(input_dims[resolved_axis[idx]]);
+ // Overflow prevention.
+ if (current > (std::numeric_limits<U>::max() / num_elements_in_axis)) {
+ return false;
+ }
+ num_elements_in_axis *= current;
+ }
+
+ if (num_elements_in_axis > 0) {
+ const float scale = input_scale / output_scale;
+ if (compute_sum) {
+ // TODO(b/116341117): Eliminate float and do this completely in 8bit.
+ const float bias = -input_zero_point * scale * num_elements_in_axis + 0.5;
+ for (size_t idx = 0; idx < num_outputs; ++idx) {
+ const U value = static_cast<U>(round(temp_sum[idx] * scale + bias)) +
+ output_zero_point;
+ output_data[idx] = static_cast<T>(value);
+ }
+ } else {
+ const float bias = -input_zero_point * scale + 0.5;
+ for (size_t idx = 0; idx < num_outputs; ++idx) {
+ float float_mean = static_cast<float>(temp_sum[idx]) /
+ static_cast<float>(num_elements_in_axis);
+
+ // Requantize the float mean to the output type.
+ output_data[idx] = static_cast<T>(round(float_mean * scale + bias)) +
+ output_zero_point;
}
}
}
+ return true;
}
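A worked example of the mean branch under made-up quantization parameters:

// Hypothetical numbers: input_scale == output_scale, both zero points 0,
// and two aggregated quantized values 10 and 25.
//   temp_sum   = 10 + 25 = 35
//   float_mean = 35 / 2  = 17.5
//   scale      = input_scale / output_scale = 1
//   bias       = -0 * 1 + 0.5 = 0.5
//   output     = round(17.5 * 1 + 0.5) + 0 = 18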
template <typename T>
-void TensorFlowMinimum(const T* input1_data, const Dims<4>& input1_dims,
- const T* input2_data, T* output_data,
- const Dims<4>& output_dims) {
- const int flat_size = MatchingFlatSize(output_dims, input1_dims);
+void Minimum(const RuntimeShape& input1_shape, const T* input1_data,
+ const T* input2_data, const RuntimeShape& output_shape,
+ T* output_data) {
+ const int flat_size = MatchingFlatSize(input1_shape, output_shape);
auto min_value = input2_data[0];
for (int i = 0; i < flat_size; i++) {
@@ -3403,11 +3662,21 @@ void TensorFlowMinimum(const T* input1_data, const Dims<4>& input1_dims,
}
}
+// Convenience version that allows, for example, generated-code calls to be
+// the same as other binary ops.
+template <typename T>
+inline void Minimum(const RuntimeShape& input1_shape, const T* input1_data,
+ const RuntimeShape&, const T* input2_data,
+ const RuntimeShape& output_shape, T* output_data) {
+ // Drop shape of second input: not needed.
+ Minimum(input1_shape, input1_data, input2_data, output_shape, output_data);
+}
+
template <typename T>
-void TensorFlowMaximum(const T* input1_data, const Dims<4>& input1_dims,
- const T* input2_data, T* output_data,
- const Dims<4>& output_dims) {
- const int flat_size = MatchingFlatSize(output_dims, input1_dims);
+void Maximum(const RuntimeShape& input1_shape, const T* input1_data,
+ const T* input2_data, const RuntimeShape& output_shape,
+ T* output_data) {
+ const int flat_size = MatchingFlatSize(input1_shape, output_shape);
auto max_value = input2_data[0];
for (int i = 0; i < flat_size; i++) {
@@ -3415,22 +3684,41 @@ void TensorFlowMaximum(const T* input1_data, const Dims<4>& input1_dims,
}
}
+// Convenience version that allows, for example, generated-code calls to be
+// the same as other binary ops.
+template <typename T>
+inline void Maximum(const RuntimeShape& input1_shape, const T* input1_data,
+ const RuntimeShape&, const T* input2_data,
+ const RuntimeShape& output_shape, T* output_data) {
+ // Drop shape of second input: not needed.
+ Maximum(input1_shape, input1_data, input2_data, output_shape, output_data);
+}
+
template <typename T, typename Op>
-void TensorFlowMaximumMinimum(const T* input1_data, const Dims<4>& input1_dims,
- const T* input2_data, const Dims<4>& input2_dims,
- T* output_data, const Dims<4>& output_dims,
- Op op) {
+void MaximumMinimumBroadcast4DSlow(const RuntimeShape& unextended_input1_shape,
+ const T* input1_data,
+ const RuntimeShape& unextended_input2_shape,
+ const T* input2_data,
+ const RuntimeShape& unextended_output_shape,
+ T* output_data, Op op) {
+ TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+ const RuntimeShape output_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_shape);
+
NdArrayDesc<4> desc1;
NdArrayDesc<4> desc2;
- NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
-
- for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
- for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
- for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
- for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
- auto out_idx = Offset(output_dims, c, x, y, b);
- auto in1_idx = SubscriptToIndex(desc1, c, x, y, b);
- auto in2_idx = SubscriptToIndex(desc2, c, x, y, b);
+ NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
+ unextended_input2_shape, &desc1, &desc2);
+
+ for (int b = 0; b < output_shape.Dims(0); ++b) {
+ for (int y = 0; y < output_shape.Dims(1); ++y) {
+ for (int x = 0; x < output_shape.Dims(2); ++x) {
+ for (int c = 0; c < output_shape.Dims(3); ++c) {
+ auto out_idx = Offset(output_shape, b, y, x, c);
+ auto in1_idx = SubscriptToIndex(desc1, b, y, x, c);
+ auto in2_idx = SubscriptToIndex(desc2, b, y, x, c);
auto in1_val = input1_data[in1_idx];
auto in2_val = input2_data[in2_idx];
output_data[out_idx] = op(in1_val, in2_val);
@@ -3440,9 +3728,10 @@ void TensorFlowMaximumMinimum(const T* input1_data, const Dims<4>& input1_dims,
}
}
-template <typename T1, typename T2, typename T3>
-void ArgMax(const T3* axis, const T1* input_data, const Dims<4>& input_dims,
- T2* output_data, const Dims<4>& output_dims) {
+template <typename T1, typename T2, typename T3, typename Cmp>
+void ArgMinMax(const RuntimeShape& input1_shape, const T1* input1_data,
+ const T3* input2_data, const RuntimeShape& output_shape,
+ T2* output_data, const Cmp& cmp) {
// The current ArgMax implementation can only determine the index of the maximum
// value in the last dimension. So the axis argument is ignored.
@@ -3450,67 +3739,121 @@ void ArgMax(const T3* axis, const T1* input_data, const Dims<4>& input_dims,
// 1). For the sake of simplicity, the output dimensions are equal to the
// input dimensions here. We enforce the constraint that the last dimension
// must always be 1.
- TFLITE_DCHECK_EQ(ArraySize(output_dims, 0), 1);
- const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims);
- const int depth = ArraySize(input_dims, 0);
+ const int trailing_dim = output_shape.DimensionsCount() - 1;
+ TFLITE_DCHECK_EQ(input1_shape.DimensionsCount(),
+ output_shape.DimensionsCount());
+ TFLITE_DCHECK_EQ(output_shape.Dims(trailing_dim), 1);
+ const int outer_size =
+ MatchingFlatSizeSkipDim(input1_shape, trailing_dim, output_shape);
+ const int depth = input1_shape.Dims(trailing_dim);
for (int i = 0; i < outer_size; ++i) {
- auto max_value = input_data[i * depth];
- int max_index = 0;
+ auto min_max_value = input1_data[i * depth];
+ int min_max_index = 0;
for (int d = 1; d < depth; ++d) {
- const auto& curr_value = input_data[i * depth + d];
- if (curr_value > max_value) {
- max_value = curr_value;
- max_index = d;
+ const auto& curr_value = input1_data[i * depth + d];
+ if (cmp(curr_value, min_max_value)) {
+ min_max_value = curr_value;
+ min_max_index = d;
}
}
- output_data[i] = max_index;
+ output_data[i] = min_max_index;
}
}
+template <typename T1, typename T2, typename T3>
+void ArgMax(const RuntimeShape& input1_shape, const T1* input1_data,
+ const T3* input2_data, const RuntimeShape& output_shape,
+ T2* output_data) {
+ ArgMinMax(input1_shape, input1_data, input2_data, output_shape, output_data,
+ std::greater<T1>());
+}
+
+// Convenience version that allows, for example, generated-code calls to be
+// the same as other binary ops.
+template <typename T1, typename T2, typename T3>
+inline void ArgMax(const RuntimeShape& input1_shape, const T1* input1_data,
+ const RuntimeShape& input2_shape, const T3* input2_data,
+ const RuntimeShape& output_shape, T2* output_data) {
+ // Drop shape of second input: not needed.
+ ArgMax(input1_shape, input1_data, input2_data, output_shape, output_data);
+}
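A minimal usage sketch, assuming the declarations above; the input values are made up:

// Hypothetical example: ArgMax over the last dimension of a [1, 1, 1, 4] tensor.
const float input[] = {0.1f, 0.9f, 0.3f, 0.5f};
const int32 axis = 3;  // The reference implementation ignores the axis tensor.
int32 output[1];
ArgMax(RuntimeShape({1, 1, 1, 4}), input, &axis, RuntimeShape({1, 1, 1, 1}),
       output);
// output[0] == 1, the index of 0.9f. Passing std::less<T1>() to ArgMinMax
// instead would return the index of the minimum.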
+
template <typename T>
-void Transpose(const T* input, const Dims<4>& input_dims, T* output,
- const Dims<4>& output_dims, int* permuted_axes) {
+void Transpose(const TransposeParams& params,
+ const RuntimeShape& unextended_input_shape, const T* input_data,
+ const RuntimeShape& unextended_output_shape, T* output_data) {
+ const int unextended_output_size = unextended_output_shape.DimensionsCount();
+ TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_size, 4);
+ TFLITE_DCHECK_EQ(unextended_output_size, params.perm_count);
+ const RuntimeShape input_shape =
+ RuntimeShape::ExtendedShape(4, unextended_input_shape);
+ const RuntimeShape output_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_shape);
+ const int input_ext_size = 4 - unextended_input_shape.DimensionsCount();
+ const int output_ext_size = 4 - unextended_output_size;
+
+ // The perm data is extended to match the output, each index incremented by
+ // the amount of front padding of the input shape.
+ int extended_perm[4];
+ for (int i = 0; i < output_ext_size; ++i) {
+ extended_perm[i] = i;
+ }
+ for (int i = 0; i < unextended_output_size; ++i) {
+ extended_perm[i + output_ext_size] = params.perm[i] + input_ext_size;
+ }
+
int out_sizes[4];
// Compute the inverse permutation array so we can do an output centered
// transpose. Also, check to make sure the output shape matches the input shape.
for (int k = 0; k < 4; k++) {
- out_sizes[k] =
- MatchingArraySize(input_dims, permuted_axes[k], output_dims, k);
+ out_sizes[k] = MatchingDim(input_shape, extended_perm[k], output_shape, k);
}
// Naive transpose loop (iterate on output index and compute input index).
int o[4]; // loop index (on output).
int i[4];
for (o[3] = 0; o[3] < out_sizes[3]; o[3]++) {
- i[permuted_axes[3]] = o[3];
+ i[extended_perm[3]] = o[3];
for (o[2] = 0; o[2] < out_sizes[2]; o[2]++) {
- i[permuted_axes[2]] = o[2];
+ i[extended_perm[2]] = o[2];
for (o[1] = 0; o[1] < out_sizes[1]; o[1]++) {
- i[permuted_axes[1]] = o[1];
+ i[extended_perm[1]] = o[1];
for (o[0] = 0; o[0] < out_sizes[0]; o[0]++) {
- i[permuted_axes[0]] = o[0];
- output[Offset(output_dims, o)] = input[Offset(input_dims, i)];
+ i[extended_perm[0]] = o[0];
+ output_data[Offset(output_shape, o)] =
+ input_data[Offset(input_shape, i)];
}
}
}
}
}
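A minimal usage sketch of a rank-2 transpose, assuming the declarations above; the matrix contents are made up:

// Hypothetical example: transpose a 2x3 matrix. The 2-D perm {1, 0} is
// extended internally to {0, 1, 3, 2} after both shapes are front-padded to 4-D.
TransposeParams params;
params.perm_count = 2;
params.perm[0] = 1;
params.perm[1] = 0;
const float input[] = {1, 2, 3, 4, 5, 6};  // [[1, 2, 3], [4, 5, 6]]
float output[6];
Transpose(params, RuntimeShape({2, 3}), input, RuntimeShape({3, 2}), output);
// output is {1, 4, 2, 5, 3, 6}, i.e. [[1, 4], [2, 5], [3, 6]].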
-inline void TransposeConv(const float* input_data, const Dims<4>& input_dims,
- const float* filter_data, const Dims<4>& filter_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, float* output_data,
- const Dims<4>& output_dims) {
- const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
- const int input_depth = MatchingArraySize(input_dims, 0, filter_dims, 3);
- const int output_depth = MatchingArraySize(filter_dims, 0, output_dims, 0);
- const int input_height = ArraySize(input_dims, 2);
- const int input_width = ArraySize(input_dims, 1);
- const int filter_height = ArraySize(filter_dims, 2);
- const int filter_width = ArraySize(filter_dims, 1);
- const int output_height = ArraySize(output_dims, 2);
- const int output_width = ArraySize(output_dims, 1);
+inline void TransposeConv(
+ const ConvParams& params, const RuntimeShape& input_shape,
+ const float* input_data, const RuntimeShape& filter_shape,
+ const float* filter_data, const RuntimeShape& output_shape,
+ float* output_data, const RuntimeShape& im2col_shape, float* im2col_data) {
+ const int stride_width = params.stride_width;
+ const int stride_height = params.stride_height;
+ const int pad_width = params.padding_values.width;
+ const int pad_height = params.padding_values.height;
+ TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+ (void)im2col_data; // only used in optimized code.
+ (void)im2col_shape; // only used in optimized code.
+
+ const int batches = MatchingDim(input_shape, 0, output_shape, 0);
+ const int input_depth = MatchingDim(input_shape, 3, filter_shape, 3);
+ const int output_depth = MatchingDim(filter_shape, 0, output_shape, 3);
+ const int input_height = input_shape.Dims(1);
+ const int input_width = input_shape.Dims(2);
+ const int filter_height = filter_shape.Dims(1);
+ const int filter_width = filter_shape.Dims(2);
+ const int output_height = output_shape.Dims(1);
+ const int output_width = output_shape.Dims(2);
// Although transpose convolution simplifies to convolution with transposed
// weights for strides of 1, non-unitary striding complicates matters. To
@@ -3519,7 +3862,8 @@ inline void TransposeConv(const float* input_data, const Dims<4>& input_dims,
// computing their influence on the output, rather than looping through the
// output elements in the typical "gather" access pattern of a conv. We
// therefore must initialize the output array to zero.
- for (int i = 0; i < FlatSize(output_dims); i++) {
+ const int num_elements = output_shape.FlatSize();
+ for (int i = 0; i < num_elements; i++) {
output_data[i] = 0.0f;
}
@@ -3541,13 +3885,14 @@ inline void TransposeConv(const float* input_data, const Dims<4>& input_dims,
// We cannot accumulate out of bounds
if ((out_x >= 0) && (out_x < output_width) && (out_y >= 0) &&
(out_y < output_height)) {
- float input_value = input_data[Offset(input_dims, in_channel,
- in_x, in_y, batch)];
+ float input_value = input_data[Offset(
+ input_shape, batch, in_y, in_x, in_channel)];
float filter_value =
- filter_data[Offset(filter_dims, out_channel, filter_x,
- filter_y, in_channel)];
- output_data[Offset(output_dims, out_channel, out_x, out_y,
- batch)] += input_value * filter_value;
+ filter_data[Offset(filter_shape, out_channel, filter_y,
+ filter_x, in_channel)];
+ output_data[Offset(output_shape, batch, out_y, out_x,
+ out_channel)] +=
+ input_value * filter_value;
}
}
}
@@ -3559,6 +3904,16 @@ inline void TransposeConv(const float* input_data, const Dims<4>& input_dims,
}
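A short note on the scatter pattern described above, with made-up stride and filter parameters:

// Hypothetical numbers: stride_width = 2, pad_width = 0, filter_width = 3.
// The input column in_x = 1 then contributes to output columns
// 1 * 2 - 0 + filter_x for filter_x in {0, 1, 2}, i.e. columns 2, 3 and 4,
// which is exactly where a stride-2 forward convolution would have gathered
// that column from.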
template <typename T>
+inline bool EqualFn(T lhs, T rhs) {
+ return lhs == rhs;
+}
+
+template <typename T>
+inline bool NotEqualFn(T lhs, T rhs) {
+ return lhs != rhs;
+}
+
+template <typename T>
inline bool GreaterFn(T lhs, T rhs) {
return lhs > rhs;
}
@@ -3579,89 +3934,144 @@ template <typename T>
using ComparisonFn = bool (*)(T, T);
template <typename T, ComparisonFn<T> F>
-inline void Comparison(const T* input1_data, const Dims<4>& input1_dims,
- const T* input2_data, const Dims<4>& input2_dims,
- bool* output_data, const Dims<4>& output_dims) {
+inline void ComparisonImpl(
+ const ComparisonParams& op_params, const RuntimeShape& input1_shape,
+ const T* input1_data, const RuntimeShape& input2_shape,
+ const T* input2_data, const RuntimeShape& output_shape, bool* output_data) {
const int64_t flatsize =
- MatchingFlatSize(input1_dims, input2_dims, output_dims);
+ MatchingFlatSize(input1_shape, input2_shape, output_shape);
for (int64_t i = 0; i < flatsize; ++i) {
output_data[i] = F(input1_data[i], input2_data[i]);
}
}
+template <ComparisonFn<float> F>
+inline void Comparison(const ComparisonParams& op_params,
+ const RuntimeShape& input1_shape,
+ const float* input1_data,
+ const RuntimeShape& input2_shape,
+ const float* input2_data,
+ const RuntimeShape& output_shape, bool* output_data) {
+ ComparisonImpl<float, F>(op_params, input1_shape, input1_data, input2_shape,
+ input2_data, output_shape, output_data);
+}
+
template <typename T, ComparisonFn<int32> F>
-inline void Comparison(int left_shift, const T* input1_data,
- const Dims<4>& input1_dims, int32 input1_offset,
- int32 input1_multiplier, int input1_shift,
- const T* input2_data, const Dims<4>& input2_dims,
- int32 input2_offset, int32 input2_multiplier,
- int input2_shift, bool* output_data,
- const Dims<4>& output_dims) {
+inline void ComparisonWithScaling(
+ const ComparisonParams& op_params, const RuntimeShape& input1_shape,
+ const T* input1_data, const RuntimeShape& input2_shape,
+ const T* input2_data, const RuntimeShape& output_shape, bool* output_data) {
+ int left_shift = op_params.left_shift;
+ int32 input1_offset = op_params.input1_offset;
+ int32 input1_multiplier = op_params.input1_multiplier;
+ int input1_shift = op_params.input1_shift;
+ int32 input2_offset = op_params.input2_offset;
+ int32 input2_multiplier = op_params.input2_multiplier;
+ int input2_shift = op_params.input2_shift;
+
const int64_t flatsize =
- MatchingFlatSize(input1_dims, input2_dims, output_dims);
+ MatchingFlatSize(input1_shape, input2_shape, output_shape);
for (int64_t i = 0; i < flatsize; ++i) {
const int32 input1_val = input1_offset + input1_data[i];
const int32 input2_val = input2_offset + input2_data[i];
const int32 shifted_input1_val = input1_val * (1 << left_shift);
const int32 shifted_input2_val = input2_val * (1 << left_shift);
- const int32 scaled_input1_val = MultiplyByQuantizedMultiplierSmallerThanOne(
- shifted_input1_val, input1_multiplier, input1_shift);
- const int32 scaled_input2_val = MultiplyByQuantizedMultiplierSmallerThanOne(
- shifted_input2_val, input2_multiplier, input2_shift);
+ const int32 scaled_input1_val =
+ MultiplyByQuantizedMultiplierSmallerThanOneExp(
+ shifted_input1_val, input1_multiplier, input1_shift);
+ const int32 scaled_input2_val =
+ MultiplyByQuantizedMultiplierSmallerThanOneExp(
+ shifted_input2_val, input2_multiplier, input2_shift);
output_data[i] = F(scaled_input1_val, scaled_input2_val);
}
}
template <typename T, ComparisonFn<T> F>
-inline void BroadcastComparison(const T* input1_data,
- const Dims<4>& input1_dims,
- const T* input2_data,
- const Dims<4>& input2_dims, bool* output_data,
- const Dims<4>& output_dims) {
+inline void BroadcastComparison4DSlowImpl(
+ const ComparisonParams& op_params,
+ const RuntimeShape& unextended_input1_shape, const T* input1_data,
+ const RuntimeShape& unextended_input2_shape, const T* input2_data,
+ const RuntimeShape& unextended_output_shape, bool* output_data) {
+ gemmlowp::ScopedProfilingLabel label("BroadcastComparison4DSlow");
+ TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+ const RuntimeShape output_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_shape);
+
NdArrayDesc<4> desc1;
NdArrayDesc<4> desc2;
- NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
- for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
- for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
- for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
- for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
- output_data[Offset(output_dims, c, x, y, b)] =
- F(input1_data[SubscriptToIndex(desc1, c, x, y, b)],
- input2_data[SubscriptToIndex(desc2, c, x, y, b)]);
+ NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
+ unextended_input2_shape, &desc1, &desc2);
+
+ for (int b = 0; b < output_shape.Dims(0); ++b) {
+ for (int y = 0; y < output_shape.Dims(1); ++y) {
+ for (int x = 0; x < output_shape.Dims(2); ++x) {
+ for (int c = 0; c < output_shape.Dims(3); ++c) {
+ output_data[Offset(output_shape, b, y, x, c)] =
+ F(input1_data[SubscriptToIndex(desc1, b, y, x, c)],
+ input2_data[SubscriptToIndex(desc2, b, y, x, c)]);
}
}
}
}
}
+template <ComparisonFn<float> F>
+inline void BroadcastComparison4DSlow(const ComparisonParams& op_params,
+ const RuntimeShape& input1_shape,
+ const float* input1_data,
+ const RuntimeShape& input2_shape,
+ const float* input2_data,
+ const RuntimeShape& output_shape,
+ bool* output_data) {
+ BroadcastComparison4DSlowImpl<float, F>(op_params, input1_shape, input1_data,
+ input2_shape, input2_data,
+ output_shape, output_data);
+}
template <typename T, ComparisonFn<int32> F>
-inline void BroadcastComparison(int left_shift, const T* input1_data,
- const Dims<4>& input1_dims, int32 input1_offset,
- int32 input1_multiplier, int input1_shift,
- const T* input2_data,
- const Dims<4>& input2_dims, int32 input2_offset,
- int32 input2_multiplier, int input2_shift,
- bool* output_data, const Dims<4>& output_dims) {
+inline void BroadcastComparison4DSlowWithScaling(
+ const ComparisonParams& op_params,
+ const RuntimeShape& unextended_input1_shape, const T* input1_data,
+ const RuntimeShape& unextended_input2_shape, const T* input2_data,
+ const RuntimeShape& unextended_output_shape, bool* output_data) {
+ gemmlowp::ScopedProfilingLabel label("BroadcastComparison4DSlowWithScaling");
+ TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+ const RuntimeShape output_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_shape);
+
NdArrayDesc<4> desc1;
NdArrayDesc<4> desc2;
- NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
- for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
- for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
- for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
- for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
+ NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
+ unextended_input2_shape, &desc1, &desc2);
+
+ int left_shift = op_params.left_shift;
+ int32 input1_offset = op_params.input1_offset;
+ int32 input1_multiplier = op_params.input1_multiplier;
+ int input1_shift = op_params.input1_shift;
+ int32 input2_offset = op_params.input2_offset;
+ int32 input2_multiplier = op_params.input2_multiplier;
+ int input2_shift = op_params.input2_shift;
+
+ for (int b = 0; b < output_shape.Dims(0); ++b) {
+ for (int y = 0; y < output_shape.Dims(1); ++y) {
+ for (int x = 0; x < output_shape.Dims(2); ++x) {
+ for (int c = 0; c < output_shape.Dims(3); ++c) {
const int32 input1_val =
- input1_offset + input1_data[SubscriptToIndex(desc1, c, x, y, b)];
+ input1_offset + input1_data[SubscriptToIndex(desc1, b, y, x, c)];
const int32 input2_val =
- input2_offset + input2_data[SubscriptToIndex(desc2, c, x, y, b)];
+ input2_offset + input2_data[SubscriptToIndex(desc2, b, y, x, c)];
const int32 shifted_input1_val = input1_val * (1 << left_shift);
const int32 shifted_input2_val = input2_val * (1 << left_shift);
const int32 scaled_input1_val =
- MultiplyByQuantizedMultiplierSmallerThanOne(
+ MultiplyByQuantizedMultiplierSmallerThanOneExp(
shifted_input1_val, input1_multiplier, input1_shift);
const int32 scaled_input2_val =
- MultiplyByQuantizedMultiplierSmallerThanOne(
+ MultiplyByQuantizedMultiplierSmallerThanOneExp(
shifted_input2_val, input2_multiplier, input2_shift);
- output_data[Offset(output_dims, c, x, y, b)] =
+ output_data[Offset(output_shape, b, y, x, c)] =
F(scaled_input1_val, scaled_input2_val);
}
}
@@ -3669,52 +4079,71 @@ inline void BroadcastComparison(int left_shift, const T* input1_data,
}
}
-#define TFLITE_COMPARISON_OP(name) \
- template <typename T> \
- inline void name(const T* input1_data, const Dims<4>& input1_dims, \
- const T* input2_data, const Dims<4>& input2_dims, \
- bool* output_data, const Dims<4>& output_dims) { \
- gemmlowp::ScopedProfilingLabel label(#name); \
- Comparison<T, name##Fn>(input1_data, input1_dims, input2_data, \
- input2_dims, output_data, output_dims); \
- } \
- template <typename T> \
- inline void name( \
- int left_shift, const T* input1_data, const Dims<4>& input1_dims, \
- int32 input1_offset, int32 input1_multiplier, int input1_shift, \
- const T* input2_data, const Dims<4>& input2_dims, int32 input2_offset, \
- int32 input2_multiplier, int input2_shift, bool* output_data, \
- const Dims<4>& output_dims) { \
- gemmlowp::ScopedProfilingLabel label(#name "/8bit"); \
- Comparison<T, name##Fn>(left_shift, input1_data, input1_dims, \
- input1_offset, input1_multiplier, input1_shift, \
- input2_data, input2_dims, input2_offset, \
- input2_multiplier, input2_shift, output_data, \
- output_dims); \
- } \
- template <typename T> \
- inline void Broadcast##name( \
- const T* input1_data, const Dims<4>& input1_dims, const T* input2_data, \
- const Dims<4>& input2_dims, bool* output_data, \
- const Dims<4>& output_dims) { \
- gemmlowp::ScopedProfilingLabel label("Broadcast" #name); \
- BroadcastComparison<T, name##Fn>(input1_data, input1_dims, input2_data, \
- input2_dims, output_data, output_dims); \
- } \
- template <typename T> \
- inline void Broadcast##name( \
- int left_shift, const T* input1_data, const Dims<4>& input1_dims, \
- int32 input1_offset, int32 input1_multiplier, int input1_shift, \
- const T* input2_data, const Dims<4>& input2_dims, int32 input2_offset, \
- int32 input2_multiplier, int input2_shift, bool* output_data, \
- const Dims<4>& output_dims) { \
- gemmlowp::ScopedProfilingLabel label("Broadcast" #name "/8bit"); \
- BroadcastComparison<T, name##Fn>(left_shift, input1_data, input1_dims, \
- input1_offset, input1_multiplier, \
- input1_shift, input2_data, input2_dims, \
- input2_offset, input2_multiplier, \
- input2_shift, output_data, output_dims); \
+#define TFLITE_COMPARISON_OP(name) \
+ inline void name(const ComparisonParams& op_params, \
+ const RuntimeShape& input1_shape, const float* input1_data, \
+ const RuntimeShape& input2_shape, const float* input2_data, \
+ const RuntimeShape& output_shape, bool* output_data) { \
+ gemmlowp::ScopedProfilingLabel label(#name); \
+ Comparison<name##Fn>(op_params, input1_shape, input1_data, input2_shape, \
+ input2_data, output_shape, output_data); \
+ } \
+ template <typename T> \
+ inline void name##NoScaling( \
+ const ComparisonParams& op_params, const RuntimeShape& input1_shape, \
+ const T* input1_data, const RuntimeShape& input2_shape, \
+ const T* input2_data, const RuntimeShape& output_shape, \
+ bool* output_data) { \
+ gemmlowp::ScopedProfilingLabel label(#name "NoScaling"); \
+ ComparisonImpl<T, name##Fn>(op_params, input1_shape, input1_data, \
+ input2_shape, input2_data, output_shape, \
+ output_data); \
+ } \
+ template <typename T> \
+ inline void name##WithScaling( \
+ const ComparisonParams& op_params, const RuntimeShape& input1_shape, \
+ const T* input1_data, const RuntimeShape& input2_shape, \
+ const T* input2_data, const RuntimeShape& output_shape, \
+ bool* output_data) { \
+ gemmlowp::ScopedProfilingLabel label(#name "WithScaling/8bit"); \
+ ComparisonWithScaling<T, name##Fn>(op_params, input1_shape, input1_data, \
+ input2_shape, input2_data, \
+ output_shape, output_data); \
+ } \
+ template <typename T> \
+ inline void Broadcast4DSlow##name##NoScaling( \
+ const ComparisonParams& op_params, const RuntimeShape& input1_shape, \
+ const T* input1_data, const RuntimeShape& input2_shape, \
+ const T* input2_data, const RuntimeShape& output_shape, \
+ bool* output_data) { \
+ gemmlowp::ScopedProfilingLabel label("Broadcast4DSlow" #name "NoScaling"); \
+ BroadcastComparison4DSlowImpl<T, name##Fn>( \
+ op_params, input1_shape, input1_data, input2_shape, input2_data, \
+ output_shape, output_data); \
+ } \
+ inline void Broadcast4DSlow##name( \
+ const ComparisonParams& op_params, const RuntimeShape& input1_shape, \
+ const float* input1_data, const RuntimeShape& input2_shape, \
+ const float* input2_data, const RuntimeShape& output_shape, \
+ bool* output_data) { \
+ gemmlowp::ScopedProfilingLabel label("Broadcast4DSlow" #name); \
+ BroadcastComparison4DSlow<name##Fn>(op_params, input1_shape, input1_data, \
+ input2_shape, input2_data, \
+ output_shape, output_data); \
+ } \
+ template <typename T> \
+ inline void Broadcast4DSlow##name##WithScaling( \
+ const ComparisonParams& op_params, const RuntimeShape& input1_shape, \
+ const T* input1_data, const RuntimeShape& input2_shape, \
+ const T* input2_data, const RuntimeShape& output_shape, \
+ bool* output_data) { \
+ gemmlowp::ScopedProfilingLabel label("Broadcast4DSlow" #name "/8bit"); \
+ BroadcastComparison4DSlowWithScaling<T, name##Fn>( \
+ op_params, input1_shape, input1_data, input2_shape, input2_data, \
+ output_shape, output_data); \
}
+TFLITE_COMPARISON_OP(Equal);
+TFLITE_COMPARISON_OP(NotEqual);
TFLITE_COMPARISON_OP(Greater);
TFLITE_COMPARISON_OP(GreaterEqual);
TFLITE_COMPARISON_OP(Less);
@@ -3722,13 +4151,13 @@ TFLITE_COMPARISON_OP(LessEqual);
#undef TFLITE_COMPARISON_OP
template <typename D, typename T>
-inline void Select(const D* input_condition_data,
- const Dims<4>& input_condition_dims, const T* input_x_data,
- const Dims<4>& input_x_dims, const T* input_y_data,
- const Dims<4>& input_y_dims, T* output_data,
- const Dims<4>& output_dims) {
- const int64_t flatsize =
- MatchingFlatSize(input_x_dims, input_y_dims, output_dims);
+void Select(const RuntimeShape& input_condition_shape,
+ const D* input_condition_data, const RuntimeShape& input_x_shape,
+ const T* input_x_data, const RuntimeShape& input_y_shape,
+ const T* input_y_data, const RuntimeShape& output_shape,
+ T* output_data) {
+ const int64_t flatsize = MatchingFlatSize(
+ input_condition_shape, input_x_shape, input_y_shape, output_shape);
for (int64_t i = 0; i < flatsize; ++i) {
output_data[i] =
input_condition_data[i] ? input_x_data[i] : input_y_data[i];
@@ -3736,24 +4165,210 @@ inline void Select(const D* input_condition_data,
}
template <typename D, typename T>
-inline void RankOneSelect(const D* input_condition_data,
- const Dims<4>& input_condition_dims,
- const T* input_x_data, const Dims<4>& input_x_dims,
- const T* input_y_data, const Dims<4>& input_y_dims,
- T* output_data, const Dims<4>& output_dims) {
- const int64_t rank = MatchingArraySize(input_condition_dims, 0, input_x_dims,
- 3, input_y_dims, 3, output_dims, 3);
+void RankOneSelect(const RuntimeShape& input_condition_shape,
+ const D* input_condition_data,
+ const RuntimeShape& input_x_shape, const T* input_x_data,
+ const RuntimeShape& input_y_shape, const T* input_y_data,
+ const RuntimeShape& output_shape, T* output_data) {
+ const int64_t outer_size = input_condition_shape.FlatSize();
+ TFLITE_DCHECK_EQ(
+ MatchingDim(input_x_shape, 0, input_y_shape, 0, output_shape, 0),
+ outer_size);
const int64_t inner_size =
- MatchingFlatSizeSkipDim(input_x_dims, 3, input_y_dims, output_dims);
+ MatchingFlatSizeSkipDim(input_x_shape, 0, input_y_shape, output_shape);
int64_t offset = 0;
- for (int64_t i = 0; i < rank; i++) {
+ for (int64_t i = 0; i < outer_size; i++) {
const T* input_data = input_condition_data[i] ? input_x_data : input_y_data;
memcpy(output_data + offset, input_data + offset, inner_size * sizeof(T));
offset += inner_size;
}
}
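A minimal usage sketch, assuming the declarations above; the rows and condition are made up:

// Hypothetical example: a rank-1 condition selects whole rows.
const bool condition[] = {true, false};
const float x[] = {1, 2, 3, 4};  // rows {1, 2} and {3, 4}
const float y[] = {5, 6, 7, 8};  // rows {5, 6} and {7, 8}
float output[4];
RankOneSelect(RuntimeShape({2}), condition, RuntimeShape({2, 2}), x,
              RuntimeShape({2, 2}), y, RuntimeShape({2, 2}), output);
// output is {1, 2, 7, 8}.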
+// For ease of implementation, the indices are always given as a vector of size-4 vectors.
+template <typename T, typename TI>
+inline void SparseToDense(const std::vector<std::vector<TI>>& indices,
+ const T* values, T default_value,
+ bool value_is_scalar,
+ const RuntimeShape& unextended_output_shape,
+ T* output_data) {
+ TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+ const RuntimeShape output_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_shape);
+ const int value_count = indices.size();
+
+ // First fill the output_data with default value.
+ const int num_elements = output_shape.FlatSize();
+ for (int i = 0; i < num_elements; ++i) {
+ output_data[i] = default_value;
+ }
+
+ // Special handling for the scalar-value case, to avoid checking the boolean
+ // condition inside the loop on every iteration.
+ if (value_is_scalar) {
+ for (int i = 0; i < value_count; ++i) {
+ const std::vector<TI>& index = indices[i];
+ TFLITE_DCHECK_EQ(index.size(), 4);
+ const T value = *values; // just use the first value.
+ output_data[Offset(output_shape, index[0], index[1], index[2],
+ index[3])] = value;
+ }
+ return;
+ }
+
+ // Go through the values and indices to fill the sparse values.
+ for (int i = 0; i < value_count; ++i) {
+ const std::vector<TI>& index = indices[i];
+ TFLITE_DCHECK_EQ(index.size(), 4);
+ const T value = values[i];
+ output_data[Offset(output_shape, index[0], index[1], index[2], index[3])] =
+ value;
+ }
+}
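A minimal usage sketch, assuming the declaration above and that <vector> is available; the indices and values are made up:

// Hypothetical example: scatter two values into a 4x4 output (padded to
// [1, 1, 4, 4]); every other entry keeps the default value.
const std::vector<std::vector<int>> indices = {{0, 0, 1, 2}, {0, 0, 3, 0}};
const float values[] = {5.f, 7.f};
float output[16];
SparseToDense(indices, values, /*default_value=*/0.f,
              /*value_is_scalar=*/false, RuntimeShape({1, 1, 4, 4}), output);
// output[1 * 4 + 2] == 5.f, output[3 * 4 + 0] == 7.f, all other entries 0.f.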
+
+template <typename T>
+inline void Pow(const RuntimeShape& input1_shape, const T* input1_data,
+ const RuntimeShape& input2_shape, const T* input2_data,
+ const RuntimeShape& output_shape, T* output_data) {
+ const int flat_size =
+ MatchingFlatSize(input1_shape, input2_shape, output_shape);
+ for (int i = 0; i < flat_size; ++i) {
+ output_data[i] = std::pow(input1_data[i], input2_data[i]);
+ }
+}
+
+template <typename T>
+inline void BroadcastPow4DSlow(const RuntimeShape& unextended_input1_shape,
+ const T* input1_data,
+ const RuntimeShape& unextended_input2_shape,
+ const T* input2_data,
+ const RuntimeShape& unextended_output_shape,
+ T* output_data) {
+ TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+ const RuntimeShape output_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_shape);
+
+ NdArrayDesc<4> desc1;
+ NdArrayDesc<4> desc2;
+ NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
+ unextended_input2_shape, &desc1, &desc2);
+
+ for (int b = 0; b < output_shape.Dims(0); ++b) {
+ for (int y = 0; y < output_shape.Dims(1); ++y) {
+ for (int x = 0; x < output_shape.Dims(2); ++x) {
+ for (int c = 0; c < output_shape.Dims(3); ++c) {
+ auto out_idx = Offset(output_shape, b, y, x, c);
+ auto in1_idx = SubscriptToIndex(desc1, b, y, x, c);
+ auto in2_idx = SubscriptToIndex(desc2, b, y, x, c);
+ auto in1_val = input1_data[in1_idx];
+ auto in2_val = input2_data[in2_idx];
+ output_data[out_idx] = std::pow(in1_val, in2_val);
+ }
+ }
+ }
+ }
+}
+
+inline void Logical(const RuntimeShape& input1_shape, const bool* input1_data,
+ const RuntimeShape& input2_shape, const bool* input2_data,
+ const RuntimeShape& output_shape, bool* output_data,
+ const std::function<bool(bool, bool)>& func) {
+ const int flat_size =
+ MatchingFlatSize(input1_shape, input2_shape, output_shape);
+ for (int i = 0; i < flat_size; ++i) {
+ output_data[i] = func(input1_data[i], input2_data[i]);
+ }
+}
+
+inline void BroadcastLogical4DSlow(
+ const RuntimeShape& unextended_input1_shape, const bool* input1_data,
+ const RuntimeShape& unextended_input2_shape, const bool* input2_data,
+ const RuntimeShape& unextended_output_shape, bool* output_data,
+ const std::function<bool(bool, bool)>& func) {
+ TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+ const RuntimeShape output_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_shape);
+
+ NdArrayDesc<4> desc1;
+ NdArrayDesc<4> desc2;
+ NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
+ unextended_input2_shape, &desc1, &desc2);
+
+ for (int b = 0; b < output_shape.Dims(0); ++b) {
+ for (int y = 0; y < output_shape.Dims(1); ++y) {
+ for (int x = 0; x < output_shape.Dims(2); ++x) {
+ for (int c = 0; c < output_shape.Dims(3); ++c) {
+ auto out_idx = Offset(output_shape, b, y, x, c);
+ auto in1_idx = SubscriptToIndex(desc1, b, y, x, c);
+ auto in2_idx = SubscriptToIndex(desc2, b, y, x, c);
+ auto in1_val = input1_data[in1_idx];
+ auto in2_val = input2_data[in2_idx];
+ output_data[out_idx] = func(in1_val, in2_val);
+ }
+ }
+ }
+ }
+}
+
+// TODO(ycling): Refactoring. Remove BroadcastLogical and use the more
+// generalized and efficient BroadcastBinaryFunction.
+//
+// Also appears to duplicate MinimumMaximum.
+//
+// R: Result type. T1: Input 1 type. T2: Input 2 type.
+template <typename R, typename T1, typename T2>
+inline void BroadcastBinaryFunction4DSlow(
+ const RuntimeShape& unextended_input1_shape, const T1* input1_data,
+ const RuntimeShape& unextended_input2_shape, const T2* input2_data,
+ const RuntimeShape& unextended_output_shape, R* output_data,
+ R (*func)(T1, T2)) {
+ TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+ const RuntimeShape output_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_shape);
+
+ NdArrayDesc<4> desc1;
+ NdArrayDesc<4> desc2;
+ NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
+ unextended_input2_shape, &desc1, &desc2);
+
+ for (int b = 0; b < output_shape.Dims(0); ++b) {
+ for (int y = 0; y < output_shape.Dims(1); ++y) {
+ for (int x = 0; x < output_shape.Dims(2); ++x) {
+ for (int c = 0; c < output_shape.Dims(3); ++c) {
+ auto out_idx = Offset(output_shape, b, y, x, c);
+ auto in1_idx = SubscriptToIndex(desc1, b, y, x, c);
+ auto in2_idx = SubscriptToIndex(desc2, b, y, x, c);
+ auto in1_val = input1_data[in1_idx];
+ auto in2_val = input2_data[in2_idx];
+ output_data[out_idx] = func(in1_val, in2_val);
+ }
+ }
+ }
+ }
+}
+
+// R: Result type. T1: Input 1 type. T2: Input 2 type.
+// TODO(renjieliu): Refactor other binary functions to use this one.
+template <typename R, typename T1, typename T2>
+inline void BinaryFunction(const RuntimeShape& input1_shape,
+ const T1* input1_data,
+ const RuntimeShape& input2_shape,
+ const T2* input2_data,
+ const RuntimeShape& output_shape, R* output_data,
+ R (*func)(T1, T2)) {
+ const int flat_size =
+ MatchingFlatSize(input1_shape, input2_shape, output_shape);
+ for (int i = 0; i < flat_size; ++i) {
+ output_data[i] = func(input1_data[i], input2_data[i]);
+ }
+}
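A minimal usage sketch of the generic helper, assuming the declaration above; the operands are made up:

// Hypothetical example: element-wise logical AND expressed through BinaryFunction.
const bool input1[] = {true, true, false};
const bool input2[] = {true, false, false};
bool output[3];
BinaryFunction<bool, bool, bool>(RuntimeShape({3}), input1, RuntimeShape({3}),
                                 input2, RuntimeShape({3}), output,
                                 [](bool x, bool y) { return x && y; });
// output is {true, false, false}.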
+
} // namespace reference_ops
} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/softmax.h b/tensorflow/contrib/lite/kernels/internal/reference/softmax.h
new file mode 100644
index 0000000000..7d44296134
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/internal/reference/softmax.h
@@ -0,0 +1,179 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_SOFTMAX_H_
+#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_SOFTMAX_H_
+
+#include "fixedpoint/fixedpoint.h"
+#include "tensorflow/contrib/lite/kernels/internal/common.h"
+#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
+#include "tensorflow/contrib/lite/kernels/internal/round.h"
+#include "tensorflow/contrib/lite/kernels/internal/types.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+
+namespace tflite {
+namespace reference_ops {
+
+inline void Softmax(const SoftmaxParams& params,
+ const RuntimeShape& input_shape, const float* input_data,
+ const RuntimeShape& output_shape, float* output_data) {
+ const int trailing_dim = input_shape.DimensionsCount() - 1;
+ const int outer_size =
+ MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
+ const int depth =
+ MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
+
+ for (int i = 0; i < outer_size; ++i) {
+ // Find max element value which we'll use to ensure numerical stability
+ // taking advantage of the following equality:
+ // exp(x[i])/sum(exp(x[i])) == exp(x[i]+C)/sum(exp(x[i]+C))
+ float max = std::numeric_limits<float>::lowest();
+ for (int c = 0; c < depth; ++c) {
+ max = std::max(max, input_data[i * depth + c]);
+ }
+
+ // Compute sum.
+ float sum = 0.f;
+ for (int c = 0; c < depth; ++c) {
+ sum += std::exp((input_data[i * depth + c] - max) * params.beta);
+ }
+
+ // Compute result.
+ for (int c = 0; c < depth; ++c) {
+ output_data[i * depth + c] =
+ std::exp((input_data[i * depth + c] - max) * params.beta) / sum;
+ }
+ }
+}
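A worked example of the max-subtraction identity with made-up inputs:

// Hypothetical numbers: row {1, 2} with beta = 1. The max is 2, so the sum
// becomes exp(-1) + exp(0) ~= 0.3679 + 1 = 1.3679 and the outputs are
// {0.2689, 0.7311} -- identical to exp(1)/(exp(1)+exp(2)) and
// exp(2)/(exp(1)+exp(2)) evaluated without the shift.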
+
+inline void Softmax(const SoftmaxParams& params,
+ const RuntimeShape& input_shape, const uint8* input_data,
+ const RuntimeShape& output_shape, uint8* output_data) {
+ const int32 input_beta_multiplier = params.input_multiplier;
+ const int32 input_beta_left_shift = params.input_left_shift;
+ const int diff_min = params.diff_min;
+ // The representation chosen for the input to the exp() function is Q5.26.
+ // We need to leave extra space since values that we skip might be as large as
+ // -32 before multiplying by input_beta_multiplier, and therefore as large as
+ // -16 afterwards. Note that exp(-8) is definitely not insignificant to
+ // accumulation, but exp(-16) definitely is.
+ static const int kScaledDiffIntegerBits = 5;
+ static const int kAccumulationIntegerBits = 12;
+ using FixedPointScaledDiff =
+ gemmlowp::FixedPoint<int32, kScaledDiffIntegerBits>;
+ using FixedPointAccum = gemmlowp::FixedPoint<int32, kAccumulationIntegerBits>;
+ using FixedPoint0 = gemmlowp::FixedPoint<int32, 0>;
+
+ const int trailing_dim = input_shape.DimensionsCount() - 1;
+ const int outer_size =
+ MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
+ const int depth =
+ MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
+
+ for (int i = 0; i < outer_size; ++i) {
+ uint8 max_in_row = 0;
+ for (int c = 0; c < depth; ++c) {
+ max_in_row = std::max(max_in_row, input_data[i * depth + c]);
+ }
+
+ FixedPointAccum sum_of_exps = FixedPointAccum::Zero();
+ for (int c = 0; c < depth; ++c) {
+ int32 input_diff =
+ static_cast<int32>(input_data[i * depth + c]) - max_in_row;
+ if (input_diff >= diff_min) {
+ const int32 input_diff_rescaled =
+ MultiplyByQuantizedMultiplierGreaterThanOne(
+ input_diff, input_beta_multiplier, input_beta_left_shift);
+ const FixedPointScaledDiff scaled_diff_f8 =
+ FixedPointScaledDiff::FromRaw(input_diff_rescaled);
+ sum_of_exps = sum_of_exps + gemmlowp::Rescale<kAccumulationIntegerBits>(
+ exp_on_negative_values(scaled_diff_f8));
+ }
+ }
+
+ int32 fixed_sum_of_exps = sum_of_exps.raw();
+ int headroom_plus_one =
+ CountLeadingZeros(static_cast<uint32>(fixed_sum_of_exps));
+ // This is the number of bits to the left of the binary point above 1.0.
+ // Consider fixed_sum_of_exps=1.25. In that case shifted_scale=0.8 and
+ // no later adjustment will be needed.
+ int num_bits_over_unit = kAccumulationIntegerBits - headroom_plus_one;
+ int32 shifted_sum_minus_one = static_cast<int32>(
+ (static_cast<uint32>(fixed_sum_of_exps) << headroom_plus_one) -
+ (static_cast<uint32>(1) << 31));
+
+ FixedPoint0 shifted_scale = gemmlowp::one_over_one_plus_x_for_x_in_0_1(
+ FixedPoint0::FromRaw(shifted_sum_minus_one));
+
+ for (int c = 0; c < depth; ++c) {
+ int32 input_diff =
+ static_cast<int32>(input_data[i * depth + c]) - max_in_row;
+ if (input_diff >= diff_min) {
+ const int32 input_diff_rescaled =
+ MultiplyByQuantizedMultiplierGreaterThanOne(
+ input_diff, input_beta_multiplier, input_beta_left_shift);
+ const FixedPointScaledDiff scaled_diff_f8 =
+ FixedPointScaledDiff::FromRaw(input_diff_rescaled);
+
+ FixedPoint0 exp_in_0 = exp_on_negative_values(scaled_diff_f8);
+ int32 unsat_output = gemmlowp::RoundingDivideByPOT(
+ (shifted_scale * exp_in_0).raw(), num_bits_over_unit + 31 - 8);
+
+ output_data[i * depth + c] = static_cast<uint8>(
+ std::max(std::min(unsat_output, static_cast<int32>(255)),
+ static_cast<int32>(0)));
+
+ } else {
+ output_data[i * depth + c] = 0;
+ }
+ }
+ }
+}
+
+// Performs softmax along the input of size (input_size * batch_size).
+inline void Softmax(const float* in, const int input_size, const int batch_size,
+ const float beta, float* out) {
+ // TF_LITE_ASSERT(input_size > 0);
+
+ // For each batch
+ for (int b = 0; b < batch_size; b++) {
+ // Find the max coeff.
+ float max_coeff = in[0];
+ for (int i = 1; i < input_size; i++) {
+ if (in[i] > max_coeff) max_coeff = in[i];
+ }
+
+ // Compute the normalized sum of exps.
+ float exp_sum = 0.0;
+ for (int i = 0; i < input_size; i++) {
+ out[i] = std::exp((in[i] - max_coeff) * beta);
+ exp_sum += out[i];
+ }
+
+ // Divide by the sum of exps.
+ float reciprocal_sum_exp = 1.f / exp_sum;
+ for (int i = 0; i < input_size; i++) {
+ out[i] *= reciprocal_sum_exp;
+ }
+
+ // Advance in and out pointers for the next batch.
+ in += input_size;
+ out += input_size;
+ }
+}
+
+} // namespace reference_ops
+} // namespace tflite
+
+#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_SOFTMAX_H_
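For orientation, here is a plain float sketch (not part of the patch) of what the fixed-point routine above computes per row: the softmax of beta * input_scale * (x - max), with probabilities written out so that 256 represents 1.0 and saturated to 255, which is the convention the quantized softmax tests later in this diff check against. The diff_min early-out that skips negligibly small terms is omitted, and the function name is purely illustrative.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

std::vector<uint8_t> SoftmaxUint8Sketch(const std::vector<uint8_t>& row,
                                        float input_scale, float beta) {
  const int max_in_row = *std::max_element(row.begin(), row.end());
  std::vector<float> exps(row.size());
  float sum_of_exps = 0.f;
  for (size_t i = 0; i < row.size(); ++i) {
    // Non-positive scaled differences; these correspond to the Q5.26 values
    // that feed exp_on_negative_values() in the fixed-point code above.
    exps[i] =
        std::exp(beta * input_scale * (static_cast<int>(row[i]) - max_in_row));
    sum_of_exps += exps[i];
  }
  std::vector<uint8_t> out(row.size());
  for (size_t i = 0; i < row.size(); ++i) {
    // 256 represents probability 1.0; saturate to the uint8 maximum of 255.
    const int q = static_cast<int>(std::round(256.f * exps[i] / sum_of_exps));
    out[i] = static_cast<uint8_t>(std::min(q, 255));
  }
  return out;
}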
diff --git a/tensorflow/contrib/lite/kernels/internal/resize_bilinear_test.cc b/tensorflow/contrib/lite/kernels/internal/resize_bilinear_test.cc
new file mode 100644
index 0000000000..15df31f75a
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/internal/resize_bilinear_test.cc
@@ -0,0 +1,138 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <algorithm>
+#include <cmath>
+#include <vector>
+
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/test_util.h"
+#include "tensorflow/contrib/lite/kernels/internal/types.h"
+
+namespace tflite {
+namespace {
+template <typename T>
+void TestOneResizeBilinear(int batch, int depth, int input_width,
+ int input_height, int output_width,
+ int output_height, float error_threshold) {
+ RuntimeShape input_dims_inference({batch, input_height, input_width, depth});
+ RuntimeShape output_dims_inference(
+ {batch, output_height, output_width, depth});
+
+ const int input_buffer_size = input_dims_inference.FlatSize();
+ const int output_buffer_size = output_dims_inference.FlatSize();
+
+ std::vector<T> input_data(input_buffer_size, 0);
+ std::vector<T> reference_output_data(output_buffer_size, 0);
+ // Initialize the output data with something other than zero, so we can catch
+ // issues with kernels failing to initialize the output.
+ std::vector<T> output_data(output_buffer_size, 3);
+
+ const T min_amplitude = static_cast<T>(0);
+ const T max_amplitude = static_cast<T>(255);
+ FillRandom(&input_data, min_amplitude, max_amplitude);
+
+ RuntimeShape output_size_dims({1, 1, 1, 2});
+ std::vector<int32> output_size_data = {output_height, output_width};
+
+ tflite::ResizeBilinearParams op_params;
+ op_params.align_corners = false;
+
+ reference_ops::ResizeBilinear(op_params, input_dims_inference,
+ input_data.data(), output_size_dims,
+ output_size_data.data(), output_dims_inference,
+ reference_output_data.data());
+ optimized_ops::ResizeBilinear(
+ op_params, input_dims_inference, input_data.data(), output_size_dims,
+ output_size_data.data(), output_dims_inference, output_data.data());
+
+ double sum_diff = 0;
+ float max_abs_val = 0;
+ for (int i = 0; i < output_buffer_size; i++) {
+ sum_diff += std::abs(static_cast<float>(output_data[i]) -
+ static_cast<float>(reference_output_data[i]));
+ max_abs_val = std::max(
+ max_abs_val, std::abs(static_cast<float>(reference_output_data[i])));
+ }
+
+ if (sum_diff != 0.f) {
+ const float mean_diff = static_cast<float>(sum_diff / output_buffer_size);
+ const float relative_error = std::abs(mean_diff) / max_abs_val;
+ ASSERT_LT(relative_error, error_threshold);
+ }
+}
+
+TEST(ResizeBilinear, TestResizeBilinear8Bit) {
+ const int kTestsToRun = 100 * 1000;
+ for (int i = 0; i < kTestsToRun; i++) {
+ const int batch = ExponentialRandomPositiveInt(0.9f, 3, 20);
+ const int depth = ExponentialRandomPositiveInt(0.9f, 6, 50);
+ const int input_width = ExponentialRandomPositiveInt(0.9f, 20, 200);
+ const int input_height = ExponentialRandomPositiveInt(0.9f, 20, 200);
+ const int output_width = ExponentialRandomPositiveInt(0.9f, 20, 200);
+ const int output_height = ExponentialRandomPositiveInt(0.9f, 20, 200);
+
+ TestOneResizeBilinear<uint8>(batch, depth, input_width, input_height,
+ output_width, output_height, 0.025);
+ }
+}
+
+TEST(ResizeBilinear2x2, TestResizeBilinear8Bit) {
+ const int kTestsToRun = 100 * 1000;
+ for (int i = 0; i < kTestsToRun; i++) {
+ const int batch = ExponentialRandomPositiveInt(0.9f, 3, 20);
+ const int depth = ExponentialRandomPositiveInt(0.9f, 6, 50);
+ const int input_width = ExponentialRandomPositiveInt(0.9f, 20, 200);
+ const int input_height = ExponentialRandomPositiveInt(0.9f, 20, 200);
+ const int output_width = input_width * 2;
+ const int output_height = input_height * 2;
+
+ TestOneResizeBilinear<uint8>(batch, depth, input_width, input_height,
+ output_width, output_height, 1e-5);
+ }
+}
+
+TEST(ResizeBilinear, TestResizeBilinear) {
+ const int kTestsToRun = 100 * 1000;
+ for (int i = 0; i < kTestsToRun; i++) {
+ const int batch = ExponentialRandomPositiveInt(0.9f, 3, 20);
+ const int depth = ExponentialRandomPositiveInt(0.9f, 6, 50);
+ const int input_width = ExponentialRandomPositiveInt(0.9f, 20, 200);
+ const int input_height = ExponentialRandomPositiveInt(0.9f, 20, 200);
+ const int output_width = ExponentialRandomPositiveInt(0.9f, 20, 200);
+ const int output_height = ExponentialRandomPositiveInt(0.9f, 20, 200);
+
+ TestOneResizeBilinear<float>(batch, depth, input_width, input_height,
+ output_width, output_height, 1e-5);
+ }
+}
+
+TEST(ResizeBilinear2x2, TestResizeBilinear) {
+ const int kTestsToRun = 100 * 1000;
+ for (int i = 0; i < kTestsToRun; i++) {
+ const int batch = ExponentialRandomPositiveInt(0.9f, 3, 20);
+ const int depth = ExponentialRandomPositiveInt(0.9f, 6, 50);
+ const int input_width = ExponentialRandomPositiveInt(0.9f, 20, 200);
+ const int input_height = ExponentialRandomPositiveInt(0.9f, 20, 200);
+ const int output_width = input_width * 2;
+ const int output_height = input_height * 2;
+
+ TestOneResizeBilinear<float>(batch, depth, input_width, input_height,
+ output_width, output_height, 1e-5);
+ }
+}
+} // namespace
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/internal/softmax_quantized_test.cc b/tensorflow/contrib/lite/kernels/internal/softmax_quantized_test.cc
new file mode 100644
index 0000000000..831fb3c243
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/internal/softmax_quantized_test.cc
@@ -0,0 +1,236 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <algorithm>
+#include <cmath>
+#include <cstdlib>
+#include <functional>
+#include <iterator>
+#include <limits>
+#include <random>
+#include <string>
+#include <vector>
+
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/test_util.h"
+#include "tensorflow/contrib/lite/string.h"
+
+namespace tflite {
+namespace {
+
+void RunSoftmaxFloatReference(const uint8* input_data,
+ const RuntimeShape& shape_common,
+ int32 input_offset, const double input_scale,
+ int stride, float beta,
+ uint8* reference_output_data) {
+ const int ref_buffer_size = shape_common.FlatSize();
+ std::vector<float> reference_dequant_data(ref_buffer_size);
+ std::vector<float> reference_output_float_data(ref_buffer_size);
+
+ // Reference data is generated by dequantizing the input to float and then
+ // applying the float Softmax.
+ DequantizationParams dq_params;
+ dq_params.zero_point = input_offset;
+ dq_params.scale = input_scale;
+ reference_ops::Dequantize(dq_params, shape_common, input_data, shape_common,
+ reference_dequant_data.data());
+ SoftmaxParams sm_params;
+ sm_params.beta = beta;
+ optimized_ops::Softmax(sm_params, shape_common, reference_dequant_data.data(),
+ shape_common, reference_output_float_data.data());
+ // Work with quantized scaling for Softmax, under which 256 represents 1, but
+ // we limit this to 255.
+ for (int i = 0; i < ref_buffer_size; i++) {
+ reference_output_data[i] = std::min(
+ 255,
+ static_cast<int>(std::round(256.0f * reference_output_float_data[i])));
+ }
+}
+
+void CheckOutputData(const uint8* test_output, const uint8* reference_output,
+ const RuntimeShape& shape_common,
+ const string& check_label, bool be_exacting) {
+ const int buffer_size = shape_common.FlatSize();
+ // While calculating some metrics in floating point, we work with quantized
+ // scaling.
+ std::vector<int> diff(buffer_size);
+ int64_t sum_diff = 0;
+ int64_t sum_abs_diff = 0;
+ for (int i = 0; i < buffer_size; i++) {
+ diff[i] = static_cast<int>(test_output[i]) - reference_output[i];
+ sum_diff += diff[i];
+ sum_abs_diff += std::abs(diff[i]);
+ }
+ // These stats help understand test failures.
+ std::sort(std::begin(diff), std::end(diff));
+ const int min_diff = diff.front();
+ const int max_diff = diff.back();
+ const int median_diff = diff[diff.size() / 2];
+ const float mean_diff = static_cast<float>(sum_diff) / buffer_size;
+ const float mean_abs_diff = static_cast<float>(sum_abs_diff) / buffer_size;
+ // We either check for bit exactness (against the reference quantized version)
+ // or for general accuracy, allowing off-by-one (against the float reference).
+ if (be_exacting) {
+ ASSERT_TRUE(std::abs(min_diff) == 0 && std::abs(max_diff) == 0);
+ } else {
+ // For small numbers of samples, the estimates of the means vary more.
+ // Rather than widen the tolerances, we skip the smaller tests.
+ ASSERT_TRUE(((std::abs(mean_diff) < 2e-2f && mean_abs_diff < 3e-2f) ||
+ buffer_size < 10000) &&
+ std::abs(median_diff) == 0 && std::abs(min_diff) <= 1 &&
+ std::abs(max_diff) <= 1);
+ }
+}
+
+// Runs the Softmax and compares against the float reference implementation and
+// the quantized reference implementation.
+void RunOneSoftmaxTest(const uint8* input_data,
+ const RuntimeShape& shape_common, int32 input_offset,
+ const double input_scale, int stride, float beta) {
+ const int buffer_size = shape_common.FlatSize();
+ std::vector<uint8> optimized_softmax_output(buffer_size);
+ std::vector<uint8> reference_float_softmax_output(buffer_size);
+ std::vector<uint8> reference_quant_softmax_output(buffer_size);
+
+ RunSoftmaxFloatReference(input_data, shape_common, input_offset, input_scale,
+ stride, beta, reference_float_softmax_output.data());
+
+ int32 input_beta_multiplier;
+ int input_beta_left_shift;
+ static const int kScaledDiffIntegerBits = 5;
+ tflite::PreprocessSoftmaxScaling(beta, input_scale, kScaledDiffIntegerBits,
+ &input_beta_multiplier,
+ &input_beta_left_shift);
+ // diff_min has a negative value, and is used to limit the maximum magnitude
+ // of the diffs, which are <= 0.
+ const int diff_min = -tflite::CalculateInputRadius(kScaledDiffIntegerBits,
+ input_beta_left_shift);
+
+ SoftmaxParams params;
+ params.input_multiplier = input_beta_multiplier;
+ params.input_left_shift = input_beta_left_shift;
+ params.diff_min = diff_min;
+ optimized_ops::Softmax(params, shape_common, input_data, shape_common,
+ optimized_softmax_output.data());
+ reference_ops::Softmax(params, shape_common, input_data, shape_common,
+ reference_quant_softmax_output.data());
+
+ CheckOutputData(optimized_softmax_output.data(),
+ reference_float_softmax_output.data(), shape_common,
+ "Optimized vs float reference", false);
+ CheckOutputData(optimized_softmax_output.data(),
+ reference_quant_softmax_output.data(), shape_common,
+ "Optimized vs quant reference", true);
+ CheckOutputData(reference_quant_softmax_output.data(),
+ reference_float_softmax_output.data(), shape_common,
+ "Quant reference vs float reference", false);
+}
+
+// This function picks some random Softmax params, which are checked for
+// desirability. If not acceptable, it returns false. If they're OK,
+// it runs the Softmax test and returns true. This allows the caller
+// to loop until a test has been run.
+//
+// Currently we do not reject for any reason.
+bool TryOneUniformSoftmax() {
+ // We pick mostly positive values, on the whole emphasizing smaller values and
+ // therefore faster tests. We test a wider range of depths. In the case of
+ // Softmax, the width and height really just create test repetitions.
+ const int batch = ExponentialRandomPositiveInt(0.9f, 3, 20);
+ const int input_depth = ExponentialRandomPositiveInt(0.75f, 175, 500);
+ const int input_width = ExponentialRandomPositiveInt(0.8f, 20, 200);
+ const int input_height = ExponentialRandomPositiveInt(0.8f, 20, 200);
+ const int stride = ExponentialRandomPositiveInt(0.9f, 3, 8);
+ const double input_scale = std::pow(10.0, UniformRandomFloat(-2.0, 1.0));
+ const int32 input_offset = UniformRandomInt(-256, 0);
+ const float beta = 1.0f + ExponentialRandomPositiveFloat(0.9f, 2, 10);
+
+ auto shape_common =
+ RuntimeShape({batch, input_height, input_width, input_depth});
+ const int buffer_size = shape_common.FlatSize();
+
+ std::vector<uint8> input_data(buffer_size);
+ FillRandom(&input_data);
+ RunOneSoftmaxTest(input_data.data(), shape_common, input_offset, input_scale,
+ stride, beta);
+ return true;
+}
+
+// See TryOneUniformSoftmax() for a general description.
+//
+// Tests with "skyscraper" input patterns are included for two reasons. (a)
+// Bimodal distributions are potentially challenging and perhaps more
+// realistic than simple uniform random inputs. (b) Some implementations of
+// Softmax may adapt as they traverse the depth, and so we test handling of
+// cases where relatively small values are encountered at the beginning and end.
+bool TryOneSkyscraperSoftmax(bool small_depth) {
+ // We pick mostly positive values, on the whole emphasizing smaller values and
+ // therefore faster tests. We test a wider range of depths. In the case of
+ // Softmax, the width and height really just create test repetitions.
+ const int batch = ExponentialRandomPositiveInt(0.9f, 3, 20);
+ const int input_depth = small_depth
+ ? ExponentialRandomPositiveInt(0.75f, 40, 500)
+ : ExponentialRandomPositiveInt(0.75f, 175, 500);
+ const int input_width = ExponentialRandomPositiveInt(0.7f, 20, 200);
+ const int input_height = ExponentialRandomPositiveInt(0.7f, 20, 200);
+ const int stride = ExponentialRandomPositiveInt(0.9f, 3, 8);
+ const double input_scale = std::pow(10.0, UniformRandomFloat(-2.0, 1.0));
+ const int32 input_offset = UniformRandomInt(-256, 0);
+ const float beta = 1.0f + ExponentialRandomPositiveFloat(0.9f, 2, 10);
+ // Extra parameters for skyscraper input patterns.
+ const double middle_proportion =
+ ExponentialRandomPositiveFloat(0.65f, 0.1, 1.0);
+ const int middle_min = UniformRandomInt(0, 255);
+ const int sides_max = UniformRandomInt(0, middle_min);
+
+ auto shape_common =
+ RuntimeShape({batch, input_height, input_width, input_depth});
+ const int buffer_size = shape_common.FlatSize();
+
+ std::vector<uint8> input_data(buffer_size);
+ FillRandomSkyscraper(&input_data, input_depth, middle_proportion, middle_min,
+ sides_max);
+ RunOneSoftmaxTest(input_data.data(), shape_common, input_offset, input_scale,
+ stride, beta);
+ return true;
+}
+
+TEST(TestQuantizedSoftmax, UniformSoftmaxTests) {
+ const int kTestsToRun = 1000;
+ for (int i = 0; i < kTestsToRun; i++) {
+ while (!TryOneUniformSoftmax()) {
+ }
+ }
+}
+
+TEST(TestQuantizedSoftmax, SkyscraperSoftmaxTests) {
+ const int kTestsToRun = 1000;
+ for (int i = 0; i < kTestsToRun; i++) {
+ while (!TryOneSkyscraperSoftmax(false)) {
+ }
+ }
+}
+
+TEST(TestQuantizedSoftmax, SmallSkyscraperSoftmaxTests) {
+ const int kTestsToRun = 1000;
+ for (int i = 0; i < kTestsToRun; i++) {
+ while (!TryOneSkyscraperSoftmax(true)) {
+ }
+ }
+}
+} // namespace
+} // namespace tflite
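FillRandomSkyscraper itself is declared in test_util.h; purely as an illustration of the parameter names used above (middle_proportion, middle_min, sides_max), and not the actual helper, one depth slice of such an input could look like this:

#include <cstdint>
#include <vector>

// Hypothetical sketch: a "skyscraper" depth slice has a central run of large
// values (at least middle_min) flanked by small values (at most sides_max).
std::vector<uint8_t> SkyscraperSliceSketch(int depth, double middle_proportion,
                                           int middle_min, int sides_max) {
  std::vector<uint8_t> slice(depth);
  const int middle = static_cast<int>(depth * middle_proportion);
  const int left = (depth - middle) / 2;
  for (int i = 0; i < depth; ++i) {
    const bool in_middle = (i >= left) && (i < left + middle);
    // The real generator draws these at random; constants keep the sketch
    // deterministic.
    slice[i] = static_cast<uint8_t>(in_middle ? middle_min : sides_max);
  }
  return slice;
}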
diff --git a/tensorflow/contrib/lite/kernels/internal/spectrogram.cc b/tensorflow/contrib/lite/kernels/internal/spectrogram.cc
index 4eddf7bf0a..20abcb7258 100644
--- a/tensorflow/contrib/lite/kernels/internal/spectrogram.cc
+++ b/tensorflow/contrib/lite/kernels/internal/spectrogram.cc
@@ -43,13 +43,13 @@ bool Spectrogram::Initialize(int window_length, int step_length) {
return Initialize(window, step_length);
}
-inline int Log2Floor(uint n) {
+inline int Log2Floor(uint32_t n) {
if (n == 0) return -1;
int log = 0;
- uint value = n;
+ uint32_t value = n;
for (int i = 4; i >= 0; --i) {
int shift = (1 << i);
- uint x = value >> shift;
+ uint32_t x = value >> shift;
if (x != 0) {
value = x;
log += shift;
@@ -58,7 +58,7 @@ inline int Log2Floor(uint n) {
return log;
}
-inline int Log2Ceiling(uint n) {
+inline int Log2Ceiling(uint32_t n) {
int floor = Log2Floor(n);
if (n == (n & ~(n - 1))) // zero or a power of two
return floor;
@@ -66,7 +66,7 @@ inline int Log2Ceiling(uint n) {
return floor + 1;
}
-inline uint NextPowerOfTwo(uint value) {
+inline uint32_t NextPowerOfTwo(uint32_t value) {
int exponent = Log2Ceiling(value);
// DCHECK_LT(exponent, std::numeric_limits<uint32>::digits);
return 1 << exponent;
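These helpers are file-local to spectrogram.cc; a quick hand check of their behavior (illustrative values only, assuming the functions are visible in the calling scope):

#include <cassert>
#include <cstdint>

void Log2HelpersSketch() {
  // Floor and ceiling agree on exact powers of two and differ otherwise.
  assert(Log2Floor(16u) == 4);
  assert(Log2Ceiling(16u) == 4);
  assert(Log2Floor(17u) == 4);
  assert(Log2Ceiling(17u) == 5);
  // NextPowerOfTwo rounds up to the next power of two (and is a no-op on one).
  assert(NextPowerOfTwo(16u) == 16u);
  assert(NextPowerOfTwo(17u) == 32u);
}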
diff --git a/tensorflow/contrib/lite/kernels/internal/strided_slice_logic.h b/tensorflow/contrib/lite/kernels/internal/strided_slice_logic.h
index ef77371bf6..af5db1064c 100644
--- a/tensorflow/contrib/lite/kernels/internal/strided_slice_logic.h
+++ b/tensorflow/contrib/lite/kernels/internal/strided_slice_logic.h
@@ -19,9 +19,9 @@ limitations under the License.
#include <limits>
#include <vector>
#include "tensorflow/contrib/lite/kernels/internal/compatibility.h"
+#include "tensorflow/contrib/lite/kernels/internal/types.h"
namespace tflite {
-
namespace strided_slice {
// Use until std::clamp() is available from C++17.
@@ -32,15 +32,51 @@ inline int Clamp(const int v, const int lo, const int hi) {
return v;
}
+inline void StridedSlicePadIndices(tflite::StridedSliceParams* p,
+ int dim_count) {
+ // Add indices and mask bits to fully include extra dimensions
+ TFLITE_CHECK_LE(dim_count, 4);
+ TFLITE_CHECK_GE(dim_count, p->start_indices_count);
+ TFLITE_CHECK_EQ(p->start_indices_count, p->stop_indices_count);
+ TFLITE_CHECK_EQ(p->stop_indices_count, p->strides_count);
+
+ const int pad_count = dim_count - p->start_indices_count;
+
+ // Pad indices at start, so move arrays by pad_count.
+ for (int i = p->start_indices_count - 1; i >= 0; --i) {
+ p->strides[i + pad_count] = p->strides[i];
+ p->start_indices[i + pad_count] = p->start_indices[i];
+ p->stop_indices[i + pad_count] = p->stop_indices[i];
+ }
+ for (int i = 0; i < pad_count; ++i) {
+ p->start_indices[i] = 0;
+ p->stop_indices[i] = 0;
+ p->strides[i] = 1;
+ }
+
+ // Pad masks with 0s or 1s as required.
+ p->shrink_axis_mask <<= pad_count;
+ p->ellipsis_mask <<= pad_count;
+ p->new_axis_mask <<= pad_count;
+ p->begin_mask <<= pad_count;
+ p->end_mask <<= pad_count;
+ p->begin_mask |= (1 << pad_count) - 1;
+ p->end_mask |= (1 << pad_count) - 1;
+
+ p->start_indices_count = dim_count;
+ p->stop_indices_count = dim_count;
+ p->strides_count = dim_count;
+}
+
// Return the index for the first element along that axis. This index will be a
// positive integer between [0, axis_size - 1] that can be used to index
// directly into the data.
-template <typename IntType>
-inline int StartForAxis(int begin_mask,
- std::vector<IntType> const& start_indices,
- std::vector<IntType> const& strides,
- int const* input_shape, int axis) {
- // Begin with the specified index
+inline int StartForAxis(const tflite::StridedSliceParams& params,
+ const RuntimeShape& input_shape, int axis) {
+ const auto begin_mask = params.begin_mask;
+ const auto* start_indices = params.start_indices;
+ const auto* strides = params.strides;
+ // Begin with the specified index.
int start = start_indices[axis];
// begin_mask override
@@ -57,7 +93,7 @@ inline int StartForAxis(int begin_mask,
}
// Handle negative indices
- int axis_size = input_shape[axis];
+ int axis_size = input_shape.Dims(axis);
if (start < 0) {
start += axis_size;
}
@@ -73,13 +109,26 @@ inline int StartForAxis(int begin_mask,
// element. ie. So if you were iterating through all elements of a 1D array of
// size 4, this function would return 4 as the stop, because it is one past the
// "real" indices of 0, 1, 2 & 3.
-template <typename IntType>
-inline int StopForAxis(int end_mask, std::vector<IntType> const& stop_indices,
- std::vector<IntType> const& strides,
- int const* input_shape, int axis) {
+inline int StopForAxis(const tflite::StridedSliceParams& params,
+ const RuntimeShape& input_shape, int axis,
+ int start_for_axis) {
+ const auto end_mask = params.end_mask;
+ const auto shrink_axis_mask = params.shrink_axis_mask;
+ const auto* stop_indices = params.stop_indices;
+ const auto* strides = params.strides;
+
// Begin with the specified index
+ const bool shrink_axis = shrink_axis_mask & (1 << axis);
int stop = stop_indices[axis];
+ // When shrinking an axis, the end position does not matter (and can be
+ // incorrect when negative indexing is used, see Issue #19260). Always use
+ // start_for_axis + 1 to generate a length 1 slice, since start_for_axis has
+ // already been adjusted for negative indices.
+ if (shrink_axis) {
+ stop = start_for_axis + 1;
+ }
+
// end_mask override
if (end_mask & (1 << axis)) {
if (strides[axis] > 0) {
@@ -93,7 +142,7 @@ inline int StopForAxis(int end_mask, std::vector<IntType> const& stop_indices,
}
// Handle negative indices
- int axis_size = input_shape[axis];
+ const int axis_size = input_shape.Dims(axis);
if (stop < 0) {
stop += axis_size;
}
@@ -117,6 +166,31 @@ inline bool LoopCondition(int index, int stop, int stride) {
return stride > 0 ? index >= stop : index <= stop;
}
+inline tflite::StridedSliceParams BuildStridedSliceParams(
+ int begin_mask, int end_mask, int shrink_axis_mask,
+ const std::vector<int>& start_indices, const std::vector<int>& stop_indices,
+ const std::vector<int>& strides) {
+ tflite::StridedSliceParams op_params;
+ const int dims_count = start_indices.size();
+
+ op_params.start_indices_count = dims_count;
+ op_params.stop_indices_count = dims_count;
+ op_params.strides_count = dims_count;
+ for (int i = 0; i < dims_count; ++i) {
+ op_params.start_indices[i] = start_indices[i];
+ op_params.stop_indices[i] = stop_indices[i];
+ op_params.strides[i] = strides[i];
+ }
+
+ op_params.begin_mask = begin_mask;
+ op_params.ellipsis_mask = 0;
+ op_params.end_mask = end_mask;
+ op_params.new_axis_mask = 0;
+ op_params.shrink_axis_mask = shrink_axis_mask;
+
+ return op_params;
+}
+
} // namespace strided_slice
} // namespace tflite
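A small usage sketch of the new params-based helpers (illustrative only, not part of the patch), slicing a 1-D tensor of size 4 as [1:3] and padding the spec out to 4-D:

#include "tensorflow/contrib/lite/kernels/internal/strided_slice_logic.h"
#include "tensorflow/contrib/lite/kernels/internal/types.h"

// Builds a 1-D slice spec [1:3] with stride 1, pads it to 4-D, and reads back
// the effective [start, stop) range on the last axis.
int StridedSliceSketch() {
  tflite::StridedSliceParams params =
      tflite::strided_slice::BuildStridedSliceParams(
          /*begin_mask=*/0, /*end_mask=*/0, /*shrink_axis_mask=*/0,
          /*start_indices=*/{1}, /*stop_indices=*/{3}, /*strides=*/{1});
  // Promote the 1-D spec to 4-D; the padded leading axes get begin/end mask
  // bits so they select their full (size 1) range.
  tflite::strided_slice::StridedSlicePadIndices(&params, 4);

  const tflite::RuntimeShape input_shape({1, 1, 1, 4});
  const int start =
      tflite::strided_slice::StartForAxis(params, input_shape, /*axis=*/3);
  const int stop = tflite::strided_slice::StopForAxis(params, input_shape,
                                                      /*axis=*/3, start);
  return stop - start;  // 2 elements: indices 1 and 2.
}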
diff --git a/tensorflow/contrib/lite/kernels/internal/tensor.h b/tensorflow/contrib/lite/kernels/internal/tensor.h
index ce887cea8b..689cea03e7 100644
--- a/tensorflow/contrib/lite/kernels/internal/tensor.h
+++ b/tensorflow/contrib/lite/kernels/internal/tensor.h
@@ -15,103 +15,30 @@ limitations under the License.
#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_H_
#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_H_
+#include <complex>
#include <vector>
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/contrib/lite/kernels/internal/types.h"
namespace tflite {
-template <typename T>
-inline T* GetTensorData(TfLiteTensor* tensor);
-
-template <>
-inline float* GetTensorData(TfLiteTensor* tensor) {
- return tensor != nullptr ? tensor->data.f : nullptr;
-}
-
-template <>
-inline uint8_t* GetTensorData(TfLiteTensor* tensor) {
- return tensor != nullptr ? tensor->data.uint8 : nullptr;
-}
-
-template <>
-inline int32_t* GetTensorData(TfLiteTensor* tensor) {
- return tensor != nullptr ? tensor->data.i32 : nullptr;
-}
-
-template <>
-inline int64_t* GetTensorData(TfLiteTensor* tensor) {
- return tensor != nullptr ? tensor->data.i64 : nullptr;
-}
-
-template <>
-inline bool* GetTensorData(TfLiteTensor* tensor) {
- return tensor != nullptr ? tensor->data.b : nullptr;
-}
-
-template <typename T>
-inline const T* GetTensorData(const TfLiteTensor* tensor);
-
template <>
-inline const float* GetTensorData(const TfLiteTensor* tensor) {
- return tensor != nullptr ? tensor->data.f : nullptr;
+inline std::complex<float>* GetTensorData(TfLiteTensor* tensor) {
+ return tensor != nullptr
+ ? reinterpret_cast<std::complex<float>*>(tensor->data.c64)
+ : nullptr;
}
template <>
-inline const uint8_t* GetTensorData(const TfLiteTensor* tensor) {
- return tensor != nullptr ? tensor->data.uint8 : nullptr;
+inline const std::complex<float>* GetTensorData(const TfLiteTensor* tensor) {
+ return tensor != nullptr
+ ? reinterpret_cast<const std::complex<float>*>(tensor->data.c64)
+ : nullptr;
}
-template <>
-inline const int32_t* GetTensorData(const TfLiteTensor* tensor) {
- return tensor != nullptr ? tensor->data.i32 : nullptr;
-}
-
-template <>
-inline const int64_t* GetTensorData(const TfLiteTensor* tensor) {
- return tensor != nullptr ? tensor->data.i64 : nullptr;
-}
-
-template <>
-inline const bool* GetTensorData(const TfLiteTensor* tensor) {
- return tensor != nullptr ? tensor->data.b : nullptr;
-}
-
-inline int RemapDim(int max_dimensions, int d) {
- return max_dimensions - d - 1;
-}
-
-// TODO(ahentz): the implementations in kernels/internal/ take a Dims<4> object
-// even if the original tensors were not 4D. We should consider rewriting them
-// to take a more generic 'shape' object.
-inline Dims<4> GetTensorDims(const int data[], const int size) {
- Dims<4> d;
- for (int i = 0; i < 4; ++i) {
- int src = size - i - 1;
- if (src >= 0) {
- d.sizes[i] = data[src];
- } else {
- d.sizes[i] = 1;
- }
- }
- d.strides[0] = 1;
- for (int i = 1; i < 4; i++) {
- d.strides[i] = d.strides[i - 1] * d.sizes[i - 1];
- }
- return d;
-}
-
-inline Dims<4> GetTensorDims(std::vector<int32_t> data) {
- return GetTensorDims(data.data(), data.size());
-}
-
-inline Dims<4> GetTensorDims(const TfLiteTensor* tensor) {
- if (tensor == nullptr) {
- return Dims<4>();
- }
-
- auto* dims = tensor->dims;
- return GetTensorDims(dims->data, dims->size);
+inline RuntimeShape GetTensorShape(std::vector<int32_t> data) {
+ return RuntimeShape(data.size(), data.data());
}
// A list of tensors in a format that can be used by kernels like split and
@@ -125,20 +52,20 @@ class VectorOfTensors {
int num_tensors = tensor_list.size;
all_data_.reserve(num_tensors);
- all_dims_.reserve(num_tensors);
- all_dims_ptr_.reserve(num_tensors);
+ all_shape_.reserve(num_tensors);
+ all_shape_ptr_.reserve(num_tensors);
for (int i = 0; i < num_tensors; ++i) {
TfLiteTensor* t = &context.tensors[tensor_list.data[i]];
all_data_.push_back(GetTensorData<T>(t));
- all_dims_.push_back(GetTensorDims(t));
+ all_shape_.push_back(GetTensorShape(t));
}
// Taking the pointer from inside a std::vector is only OK if the vector is
- // never modified, so we populate all_dims in the previous loop and then we
+ // never modified, so we populate all_shape in the previous loop and then we
// are free to grab iterators here.
for (int i = 0; i < num_tensors; ++i) {
- all_dims_ptr_.push_back(&all_dims_[i]);
+ all_shape_ptr_.push_back(&all_shape_[i]);
}
}
// Return a pointer to the data pointers of all tensors in the list. For
@@ -147,16 +74,16 @@ class VectorOfTensors {
// f[0][1] is the second element of the first tensor.
T* const* data() const { return all_data_.data(); }
- // Return a pointer the dim pointers of all tensors in the list. For
+ // Return a pointer to the shape pointers of all tensors in the list. For
// example:
- // const Dims<4>* const* d = v.dims();
+ // const RuntimeShape* const* d = v.shapes();
// dims[1] are the dimensions of the second tensor in the list.
- const Dims<4>* const* dims() const { return all_dims_ptr_.data(); }
+ const RuntimeShape* const* shapes() const { return all_shape_ptr_.data(); }
private:
std::vector<T*> all_data_;
- std::vector<Dims<4>> all_dims_;
- std::vector<Dims<4>*> all_dims_ptr_;
+ std::vector<RuntimeShape> all_shape_;
+ std::vector<RuntimeShape*> all_shape_ptr_;
};
// A list of quantized tensors in a format that can be used by kernels like
diff --git a/tensorflow/contrib/lite/kernels/internal/tensor_ctypes.h b/tensorflow/contrib/lite/kernels/internal/tensor_ctypes.h
new file mode 100644
index 0000000000..9f5b33d217
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/internal/tensor_ctypes.h
@@ -0,0 +1,102 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_CTYPES_H_
+#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_CTYPES_H_
+
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
+#include "tensorflow/contrib/lite/kernels/internal/types.h"
+
+namespace tflite {
+
+template <typename T>
+inline T* GetTensorData(TfLiteTensor* tensor);
+
+template <>
+inline float* GetTensorData(TfLiteTensor* tensor) {
+ return tensor != nullptr ? tensor->data.f : nullptr;
+}
+
+template <>
+inline uint8_t* GetTensorData(TfLiteTensor* tensor) {
+ return tensor != nullptr ? tensor->data.uint8 : nullptr;
+}
+
+template <>
+inline int16_t* GetTensorData(TfLiteTensor* tensor) {
+ return tensor != nullptr ? tensor->data.i16 : nullptr;
+}
+
+template <>
+inline int32_t* GetTensorData(TfLiteTensor* tensor) {
+ return tensor != nullptr ? tensor->data.i32 : nullptr;
+}
+
+template <>
+inline int64_t* GetTensorData(TfLiteTensor* tensor) {
+ return tensor != nullptr ? tensor->data.i64 : nullptr;
+}
+
+template <>
+inline bool* GetTensorData(TfLiteTensor* tensor) {
+ return tensor != nullptr ? tensor->data.b : nullptr;
+}
+
+template <typename T>
+inline const T* GetTensorData(const TfLiteTensor* tensor);
+
+template <>
+inline const float* GetTensorData(const TfLiteTensor* tensor) {
+ return tensor != nullptr ? tensor->data.f : nullptr;
+}
+
+template <>
+inline const uint8_t* GetTensorData(const TfLiteTensor* tensor) {
+ return tensor != nullptr ? tensor->data.uint8 : nullptr;
+}
+
+template <>
+inline const int16_t* GetTensorData(const TfLiteTensor* tensor) {
+ return tensor != nullptr ? tensor->data.i16 : nullptr;
+}
+
+template <>
+inline const int32_t* GetTensorData(const TfLiteTensor* tensor) {
+ return tensor != nullptr ? tensor->data.i32 : nullptr;
+}
+
+template <>
+inline const int64_t* GetTensorData(const TfLiteTensor* tensor) {
+ return tensor != nullptr ? tensor->data.i64 : nullptr;
+}
+
+template <>
+inline const bool* GetTensorData(const TfLiteTensor* tensor) {
+ return tensor != nullptr ? tensor->data.b : nullptr;
+}
+
+inline RuntimeShape GetTensorShape(const TfLiteTensor* tensor) {
+ if (tensor == nullptr) {
+ return RuntimeShape();
+ }
+
+ TfLiteIntArray* dims = tensor->dims;
+ const int dims_size = dims->size;
+ const int32_t* dims_data = dims->data;
+ return RuntimeShape(dims_size, dims_data);
+}
+
+} // namespace tflite
+
+#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_CTYPES_H_
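A minimal sketch of how a kernel would typically use these accessors together (illustrative fragment only; the tensor argument stands in for whatever the kernel fetched from its context):

#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor_ctypes.h"

// Both calls are null-safe; the shape is a RuntimeShape built from
// tensor->dims, and the data pointer is typed via the template specialization.
void DescribeTensorSketch(const TfLiteTensor* input) {
  const tflite::RuntimeShape shape = tflite::GetTensorShape(input);
  const float* data = tflite::GetTensorData<float>(input);
  if (data != nullptr && shape.DimensionsCount() > 0) {
    // e.g. iterate over shape.FlatSize() float elements starting at data.
  }
}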
diff --git a/tensorflow/contrib/lite/kernels/internal/tensor_test.cc b/tensorflow/contrib/lite/kernels/internal/tensor_test.cc
index bf2068d320..2ed73ba82d 100644
--- a/tensorflow/contrib/lite/kernels/internal/tensor_test.cc
+++ b/tensorflow/contrib/lite/kernels/internal/tensor_test.cc
@@ -21,28 +21,32 @@ namespace {
using ::testing::ElementsAre;
-TEST(TensorTest, GetTensorDims4D) {
- Dims<4> d = GetTensorDims({2, 3, 4, 5});
- EXPECT_THAT(d.sizes, ElementsAre(5, 4, 3, 2));
- EXPECT_THAT(d.strides, ElementsAre(1, 5, 20, 60));
+TEST(TensorTest, GetTensorShape4D) {
+ RuntimeShape d = GetTensorShape({2, 3, 4, 5});
+ EXPECT_THAT(
+ std::vector<int32>(d.DimsData(), d.DimsData() + d.DimensionsCount()),
+ ElementsAre(2, 3, 4, 5));
}
-TEST(TensorTest, GetTensorDims3D) {
- Dims<4> d = GetTensorDims({3, 4, 5});
- EXPECT_THAT(d.sizes, ElementsAre(5, 4, 3, 1));
- EXPECT_THAT(d.strides, ElementsAre(1, 5, 20, 60));
+TEST(TensorTest, GetTensorShape3D) {
+ RuntimeShape d = GetTensorShape({3, 4, 5});
+ EXPECT_THAT(
+ std::vector<int32>(d.DimsData(), d.DimsData() + d.DimensionsCount()),
+ ElementsAre(3, 4, 5));
}
-TEST(TensorTest, GetTensorDims2D) {
- Dims<4> d = GetTensorDims({4, 5});
- EXPECT_THAT(d.sizes, ElementsAre(5, 4, 1, 1));
- EXPECT_THAT(d.strides, ElementsAre(1, 5, 20, 20));
+TEST(TensorTest, GetTensorShape2D) {
+ RuntimeShape d = GetTensorShape({4, 5});
+ EXPECT_THAT(
+ std::vector<int32>(d.DimsData(), d.DimsData() + d.DimensionsCount()),
+ ElementsAre(4, 5));
}
-TEST(TensorTest, GetTensorDims1D) {
- Dims<4> d = GetTensorDims({5});
- EXPECT_THAT(d.sizes, ElementsAre(5, 1, 1, 1));
- EXPECT_THAT(d.strides, ElementsAre(1, 5, 5, 5));
+TEST(TensorTest, GetTensorShape1D) {
+ RuntimeShape d = GetTensorShape({5});
+ EXPECT_THAT(
+ std::vector<int32>(d.DimsData(), d.DimsData() + d.DimensionsCount()),
+ ElementsAre(5));
}
} // namespace
diff --git a/tensorflow/contrib/lite/kernels/internal/tensor_utils.h b/tensorflow/contrib/lite/kernels/internal/tensor_utils.h
index e1c9ccd84b..b0fe5adf65 100644
--- a/tensorflow/contrib/lite/kernels/internal/tensor_utils.h
+++ b/tensorflow/contrib/lite/kernels/internal/tensor_utils.h
@@ -15,7 +15,11 @@ limitations under the License.
#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_UTILS_H_
#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_UTILS_H_
-#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+
+#if defined(_MSC_VER)
+#define __restrict__ __restrict
+#endif
namespace tflite {
namespace tensor_utils {
@@ -23,13 +27,16 @@ namespace tensor_utils {
// Limit a float input f between +abs_limit and -abs_limit.
float Clip(float f, float abs_limit);
+// Checks if all entries of vector are zero.
+bool IsZeroVector(const float* vector, int v_size);
+
// Quantizes a buffer of floating point values using a symmetric quantization
// (i.e. linear quantization without an offset) to 8-bit signed integers.
// It also outputs the range (min, max) of the floating point buffer, and the
// scaling factor used to quantize the values.
void SymmetricQuantizeFloats(const float* values, const int size,
- int8_t* quantized_values, float* min, float* max,
- float* scaling_factor);
+ int8_t* quantized_values, float* min_value,
+ float* max_value, float* scaling_factor);
// Multiplies a matrix by a "batched" vector (i.e. a matrix with a batch
// dimension composed by input vectors independent from each other). The result
@@ -94,6 +101,11 @@ void BatchVectorBatchVectorDotProduct(const float* vector1,
int n_batch, float* result,
int result_stride);
+// Cwise product of a vector and a batch-vector.
+void VectorBatchVectorCwiseProduct(const float* vector, int v_size,
+ const float* batch_vector, int n_batch,
+ float* result);
+
// Cwise product and accumulate of a vector and a batch-vector. Since it's a MAC
// operation, the assumption here is that result array is initialized to valid
// values.
@@ -101,6 +113,10 @@ void VectorBatchVectorCwiseProductAccumulate(const float* vector, int v_size,
const float* batch_vector,
int n_batch, float* result);
+// Add another vector for each batch in the batch vector.
+void VectorBatchVectorAdd(const float* vector, int v_size, int n_batch,
+ float* batch_vector);
+
// Batch vector initialization with another vector.
void VectorBatchVectorAssign(const float* vector, int v_size, int n_batch,
float* batch_vector);
@@ -121,6 +137,10 @@ void Sub1Vector(const float* vector, int v_size, float* result);
// Fill vector with 0.f.
void ZeroVector(float* vector, int v_size);
+// Multiply all elements of vector with a scalar.
+void VectorScalarMultiply(const int8_t* vector, int v_size, float scale,
+ float* result);
+
// Clip elements of a vector using a abs_limit value.
void ClipVector(const float* vector, int v_size, float abs_limit,
float* result);
@@ -136,6 +156,12 @@ void VectorShiftLeft(float* vector, int v_size, float shift_value);
// added to get one element of output.
void ReductionSumVector(const float* input_vector, float* output_vector,
int output_size, int reduction_size);
+
+// Layer norm for each batch.
+// normalization_epsilon guards against division by zero when the variance is
+// zero.
+void MeanStddevNormalization(const float* input_vector, float* output_vector,
+ int v_size, int n_batch,
+ float normalization_epsilon);
} // namespace tensor_utils
} // namespace tflite
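As a reading of the new VectorBatchVectorAdd declaration, here is a portable-style sketch consistent with the expectation in tensor_utils_test.cc later in this diff (not the actual optimized kernel):

// batch_vector holds n_batch contiguous vectors of length v_size; each one
// gets the same vector added element-wise, in place.
void VectorBatchVectorAddSketch(const float* vector, int v_size, int n_batch,
                                float* batch_vector) {
  for (int b = 0; b < n_batch; ++b) {
    for (int i = 0; i < v_size; ++i) {
      batch_vector[b * v_size + i] += vector[i];
    }
  }
}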
diff --git a/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc b/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc
index 3d8a2eada0..6458af714b 100644
--- a/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc
+++ b/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc
@@ -14,7 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
#include <gmock/gmock.h>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
#include "tensorflow/contrib/lite/kernels/test_util.h"
namespace tflite {
@@ -32,19 +32,55 @@ TEST(uKernels, ClipTest) {
{0.0, -0.5, 1.0, -1.5, 2.0, -2.0, 2.0, -2.0, 2.0, -2.0})));
}
+TEST(uKernels, VectorScalarMultiply) {
+ constexpr int kVectorSize = 29;
+ static int8_t input[kVectorSize];
+ for (int i = 0; i < kVectorSize; ++i) {
+ input[i] = static_cast<int8_t>(i - 14);
+ }
+ const float scale = 0.1f;
+ std::vector<float> output(kVectorSize, 0.0f);
+ VectorScalarMultiply(input, kVectorSize, scale, output.data());
+ EXPECT_THAT(output,
+ ElementsAreArray(ArrayFloatNear(
+ {-1.4, -1.3, -1.2, -1.1, -1.0, -0.9, -0.8, -0.7, -0.6, -0.5,
+ -0.4, -0.3, -0.2, -0.1, 0, 0.1, 0.2, 0.3, 0.4, 0.5,
+ 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4})));
+}
+
+TEST(uKernels, IsZeroTest) {
+ constexpr int kVectorSize = 21;
+ static float zeros[kVectorSize] = {0.0};
+ EXPECT_TRUE(IsZeroVector(zeros, kVectorSize));
+
+ static float nonzeros[kVectorSize] = {
+ 1e-6, 1e-7, 1e-8, 1e-9, 1e-10, 1e-11, 1e-12,
+ 1e-13, 1e-14, 1e-15, 1e-16, 1e-17, 1e-18, 1e-19,
+ 1e-20, 1e-21, 1e-22, 1e-23, 1e-24, 1e-25, 1e-26};
+ EXPECT_FALSE(IsZeroVector(nonzeros, kVectorSize));
+}
+
+TEST(uKernels, GeneratedIsZeroTest) {
+ constexpr int kVectorSize = 39;
+ std::vector<float> input(kVectorSize);
+ ZeroVector(input.data(), kVectorSize);
+ EXPECT_TRUE(IsZeroVector(input.data(), kVectorSize));
+}
+
TEST(uKernels, SymmetricQuantizeFloatsTest) {
constexpr int kVectorSize = 9;
static float input[kVectorSize] = {-640, -635.0, -630, 10.0, 2.0,
-5.0, -10.0, 0.0, 1000.0};
- int8 output[kVectorSize];
+ int8_t output[kVectorSize];
float min, max, scaling_factor;
SymmetricQuantizeFloats(input, kVectorSize, output, &min, &max,
&scaling_factor);
EXPECT_EQ(min, -640);
EXPECT_EQ(max, 1000);
- EXPECT_NEAR(scaling_factor, 0.127, 1e-6); // EQ won't work due to fpoint.
+ // EXPECT_EQ won't work due to floating-point rounding.
+ EXPECT_NEAR(scaling_factor, 1000 / 127.0, 1e-6);
EXPECT_THAT(output,
testing::ElementsAreArray({-81, -81, -80, 1, 0, -1, -1, 0, 127}));
}
@@ -53,7 +89,7 @@ TEST(uKernels, SymmetricQuantizeFloatsAllZerosTest) {
constexpr int kVectorSize = 9;
static float input[kVectorSize] = {0, 0, 0, 0, 0, 0, 0, 0, 0};
- int8 output[kVectorSize];
+ int8_t output[kVectorSize];
float min, max, scaling_factor;
SymmetricQuantizeFloats(input, kVectorSize, output, &min, &max,
&scaling_factor);
@@ -69,14 +105,14 @@ TEST(uKernels, SymmetricQuantizeFloatsAllAlmostZeroTest) {
static float input[kVectorSize] = {-1e-5, 3e-5, -7e-6, -9e-5, 1e-6,
4e-5, 9e-6, 2e-4, 0};
- int8 output[kVectorSize];
+ int8_t output[kVectorSize];
float min, max, scaling_factor;
SymmetricQuantizeFloats(input, kVectorSize, output, &min, &max,
&scaling_factor);
EXPECT_NEAR(min, -9e-05, 1e-6);
EXPECT_NEAR(max, 0.0002, 1e-6);
- EXPECT_EQ(scaling_factor, 635000);
+ EXPECT_NEAR(scaling_factor, 1.57e-6, 1e-6);
EXPECT_THAT(output,
testing::ElementsAreArray({-6, 19, -4, -57, 1, 25, 6, 127, 0}));
}
@@ -107,6 +143,7 @@ TEST(uKernels, MatrixBatchVectorMultiplyAccumulateTest) {
-1., 3., 7., 3., 23., 3.})));
}
+#ifdef __ANDROID__
TEST(uKernels, MatrixBatchVectorMultiplyAccumulateSymmetricQuantizedTest) {
// Note we use 29 columns as this exercises all the neon kernel: the
// 16-block SIMD code, the 8-block postamble, and the leftover postamble.
@@ -130,13 +167,13 @@ TEST(uKernels, MatrixBatchVectorMultiplyAccumulateSymmetricQuantizedTest) {
-13.13, 14.14, -15.15, 16.16, -17.17, 18.18, -19.19, 20.2, -21.21, 22.22,
-23.23, 24.24, -25.25, 26.26, -27.27, 28.28, 0};
- int8* a_int8_data = reinterpret_cast<int8*>(
+ int8_t* a_int8_data = reinterpret_cast<int8_t*>(
aligned_malloc(a_rows * a_cols, kWeightsPerUint32));
float a_min, a_max;
float scaling_factor_a;
SymmetricQuantizeFloats(a_float_data, a_rows * a_cols, a_int8_data, &a_min,
&a_max, &scaling_factor_a);
- const int8 expected_a_int8_data[] = {
+ const int8_t expected_a_int8_data[] = {
/* 1st row */
5,
10,
@@ -327,7 +364,7 @@ TEST(uKernels, MatrixBatchVectorMultiplyAccumulateSymmetricQuantizedTest) {
};
// Quantized values of B:
- int8 b_int8_data[b_rows * b_cols * batches];
+ int8_t b_int8_data[b_rows * b_cols * batches];
float b_min, b_max;
float scaling_factor_b[batches];
SymmetricQuantizeFloats(b_float_data, b_rows * b_cols, b_int8_data, &b_min,
@@ -336,7 +373,7 @@ TEST(uKernels, MatrixBatchVectorMultiplyAccumulateSymmetricQuantizedTest) {
&b_int8_data[b_rows * b_cols], &b_min, &b_max,
&scaling_factor_b[1]);
- const int8 expected_b_int8_data[] = {
+ const int8_t expected_b_int8_data[] = {
/* batch 1 */
127,
-127,
@@ -429,6 +466,7 @@ TEST(uKernels, MatrixBatchVectorMultiplyAccumulateSymmetricQuantizedTest) {
aligned_free(a_int8_data);
}
+#endif // __ANDROID__
TEST(uKernels, VectorVectorCwiseProductTest) {
constexpr int kVectorSize = 10;
@@ -458,6 +496,16 @@ TEST(uKernels, VectorVectorCwiseProductAccumulateTest) {
{1.0, 1.05, 1.1, 1.15, 1.2, 1.25, 1.3, 1.35, 1.4, 1.45})));
}
+TEST(uKernels, VectorBatchVectorAddTest) {
+ constexpr int kVectorSize = 3;
+ constexpr int kBatchSize = 2;
+ static float input[kVectorSize] = {0.0, -0.5, 1.0};
+ std::vector<float> output = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0};
+ VectorBatchVectorAdd(input, kVectorSize, kBatchSize, output.data());
+ EXPECT_THAT(output,
+ testing::ElementsAreArray({1.0, 1.5, 4.0, 4.0, 4.5, 7.0}));
+}
+
TEST(uKernels, VectorBatchVectorAssignTest) {
constexpr int kVectorSize = 5;
constexpr int kBatchSize = 3;
@@ -517,6 +565,120 @@ TEST(uKernels, ZeroVectorTest) {
ElementsAreArray(ArrayFloatNear({0.0, 0.0, 0.0, 0.0, 0.0})));
}
+TEST(uKernels, VectorBatchVectorCwiseProductAccumulate) {
+ constexpr int kVectorSize = 29;
+ constexpr int kBatchSize = 4;
+ static float input[kVectorSize] = {
+ 1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9, 10.1,
+ 11.11, 12.12, 13.13, 14.14, 15.15, 16.16, 17.17, 18.18, 19.19, 20.2,
+ 21.21, 22.22, 23.23, 24.24, 25.25, 26.26, 27.27, 28.28, 0};
+ std::vector<float> output = {
+ /* batch 0 */
+ 1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9, 10.1, 11.11, 12.12, 13.13,
+ 14.14, 15.15, 16.16, 17.17, 18.18, 19.19, 20.2, 21.21, 22.22, 23.23,
+ 24.24, 25.25, 26.26, 27.27, 28.28, 0,
+ /* batch 1 */
+ -1.1, -2.2, -3.3, -4.4, -5.5, -6.6, -7.7, -8.8, -9.9, -10.1, -11.11,
+ -12.12, -13.13, -14.14, -15.15, -16.16, -17.17, -18.18, -19.19, -20.2,
+ -21.21, -22.22, -23.23, -24.24, -25.25, -26.26, -27.27, -28.28, 0,
+ /* batch 2 */
+ 1.1, -2.2, 3.3, -4.4, 5.5, -6.6, 7.7, -8.8, 9.9, -10.1, 11.11, -12.12,
+ 13.13, -14.14, 15.15, -16.16, 17.17, -18.18, 19.19, -20.2, 21.21, -22.22,
+ 23.23, -24.24, 25.25, -26.26, 27.27, -28.28, 0,
+ /* batch 3 */
+ -1.1, 2.2, -3.3, 4.4, -5.5, 6.6, -7.7, 8.8, -9.9, 10.1, -11.11, 12.12,
+ -13.13, 14.14, -15.15, 16.16, -17.17, 18.18, -19.19, 20.2, -21.21, 22.22,
+ -23.23, 24.24, -25.25, 26.26, -27.27, 28.28, 0};
+ VectorBatchVectorCwiseProductAccumulate(input, kVectorSize, output.data(),
+ kBatchSize, output.data());
+
+ // Expect output = input * output + output.
+ const std::vector<float> expected_output = {
+ /* batch 0 */
+ 2.310000, 7.040000, 14.190000, 23.760000, 35.750000, 50.159996, 66.989998,
+ 86.240005, 107.909996, 112.110008, 134.542084, 159.014389, 185.526901,
+ 214.079605, 244.672485, 277.305603, 311.978912, 348.692413, 387.446136,
+ 428.240051, 471.074066, 515.948364, 562.862854, 611.817566, 662.812500,
+ 715.847595, 770.922974, 828.038452, 0.000000,
+ /* batch 1 */
+ -2.310000, -7.040000, -14.190000, -23.760000, -35.750000, -50.159996,
+ -66.989998, -86.240005, -107.909996, -112.110008, -134.542084,
+ -159.014389, -185.526901, -214.079605, -244.672485, -277.305603,
+ -311.978912, -348.692413, -387.446136, -428.240051, -471.074066,
+ -515.948364, -562.862854, -611.817566, -662.812500, -715.847595,
+ -770.922974, -828.038452, 0.000000,
+ /* batch 2 */
+ 2.310000, -7.040000, 14.190000, -23.760000, 35.750000, -50.159996,
+ 66.989998, -86.240005, 107.909996, -112.110008, 134.542084, -159.014389,
+ 185.526901, -214.079605, 244.672485, -277.305603, 311.978912, -348.692413,
+ 387.446136, -428.240051, 471.074066, -515.948364, 562.862854, -611.817566,
+ 662.812500, -715.847595, 770.922974, -828.038452, 0.000000,
+ /* batch 3 */
+ -2.310000, 7.040000, -14.190000, 23.760000, -35.750000, 50.159996,
+ -66.989998, 86.240005, -107.909996, 112.110008, -134.542084, 159.014389,
+ -185.526901, 214.079605, -244.672485, 277.305603, -311.978912, 348.692413,
+ -387.446136, 428.240051, -471.074066, 515.948364, -562.862854, 611.817566,
+ -662.812500, 715.847595, -770.922974, 828.038452, 0.000000};
+ EXPECT_THAT(output, testing::ElementsAreArray(expected_output));
+}
+
+TEST(uKernels, VectorBatchVectorCwiseProductNoAccumulate) {
+ constexpr int kVectorSize = 29;
+ constexpr int kBatchSize = 4;
+ static float input[kVectorSize] = {
+ 1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9, 10.1,
+ 11.11, 12.12, 13.13, 14.14, 15.15, 16.16, 17.17, 18.18, 19.19, 20.2,
+ 21.21, 22.22, 23.23, 24.24, 25.25, 26.26, 27.27, 28.28, 0};
+ std::vector<float> output = {
+ /* batch 0 */
+ 1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9, 10.1, 11.11, 12.12, 13.13,
+ 14.14, 15.15, 16.16, 17.17, 18.18, 19.19, 20.2, 21.21, 22.22, 23.23,
+ 24.24, 25.25, 26.26, 27.27, 28.28, 0,
+ /* batch 1 */
+ -1.1, -2.2, -3.3, -4.4, -5.5, -6.6, -7.7, -8.8, -9.9, -10.1, -11.11,
+ -12.12, -13.13, -14.14, -15.15, -16.16, -17.17, -18.18, -19.19, -20.2,
+ -21.21, -22.22, -23.23, -24.24, -25.25, -26.26, -27.27, -28.28, 0,
+ /* batch 2 */
+ 1.1, -2.2, 3.3, -4.4, 5.5, -6.6, 7.7, -8.8, 9.9, -10.1, 11.11, -12.12,
+ 13.13, -14.14, 15.15, -16.16, 17.17, -18.18, 19.19, -20.2, 21.21, -22.22,
+ 23.23, -24.24, 25.25, -26.26, 27.27, -28.28, 0,
+ /* batch 3 */
+ -1.1, 2.2, -3.3, 4.4, -5.5, 6.6, -7.7, 8.8, -9.9, 10.1, -11.11, 12.12,
+ -13.13, 14.14, -15.15, 16.16, -17.17, 18.18, -19.19, 20.2, -21.21, 22.22,
+ -23.23, 24.24, -25.25, 26.26, -27.27, 28.28, 0};
+ VectorBatchVectorCwiseProduct(input, kVectorSize, output.data(), kBatchSize,
+ output.data());
+
+ // Expect output = input * output.
+ const std::vector<float> expected_output = {
+ /* batch 0 */
+ 1.210000, 4.840000, 10.889999, 19.360001, 30.250000, 43.559998, 59.289997,
+ 77.440002, 98.009995, 102.010010, 123.432091, 146.894394, 172.396896,
+ 199.939606, 229.522491, 261.145599, 294.808899, 330.512421, 368.256134,
+ 408.040039, 449.864075, 493.728363, 539.632874, 587.577576, 637.562500,
+ 689.587585, 743.652954, 799.758423, 0.000000,
+ /* batch 1 */
+ -1.210000, -4.840000, -10.889999, -19.360001, -30.250000, -43.559998,
+ -59.289997, -77.440002, -98.009995, -102.010010, -123.432091, -146.894394,
+ -172.396896, -199.939606, -229.522491, -261.145599, -294.808899,
+ -330.512421, -368.256134, -408.040039, -449.864075, -493.728363,
+ -539.632874, -587.577576, -637.562500, -689.587585, -743.652954,
+ -799.758423, 0.000000,
+ /* batch 2 */
+ 1.210000, -4.840000, 10.889999, -19.360001, 30.250000, -43.559998,
+ 59.289997, -77.440002, 98.009995, -102.010010, 123.432091, -146.894394,
+ 172.396896, -199.939606, 229.522491, -261.145599, 294.808899, -330.512421,
+ 368.256134, -408.040039, 449.864075, -493.728363, 539.632874, -587.577576,
+ 637.562500, -689.587585, 743.652954, -799.758423, 0.000000,
+ /* batch 3 */
+ -1.210000, 4.840000, -10.889999, 19.360001, -30.250000, 43.559998,
+ -59.289997, 77.440002, -98.009995, 102.010010, -123.432091, 146.894394,
+ -172.396896, 199.939606, -229.522491, 261.145599, -294.808899, 330.512421,
+ -368.256134, 408.040039, -449.864075, 493.728363, -539.632874, 587.577576,
+ -637.562500, 689.587585, -743.652954, 799.758423, 0.000000};
+ EXPECT_THAT(output, testing::ElementsAreArray(expected_output));
+}
+
TEST(uKernels, BatchVectorBatchVectorDotProductTest) {
constexpr int kVectorSize = 5;
constexpr int kBatch = 2;
@@ -560,5 +722,85 @@ TEST(uKernels, ReductionSumVectorTest) {
EXPECT_THAT(result2, ElementsAreArray(ArrayFloatNear({1.0, 3.5})));
}
+TEST(uKernels, MeanStddevNormalizationNoneZeroInput) {
+ constexpr int kVectorSize = 4;
+ constexpr int kBatchSize = 2;
+ constexpr float kNormalizationEpsilon = 1e-8;
+
+ // Non-zero input.
+ static float input[kVectorSize * kBatchSize] = {
+ 0.1, 0.2, 0.3, 0.4, // batch 0
+ 0.9, 1.0, 1.1, 1.2, // batch 1
+ };
+ std::vector<float> output(kVectorSize * kBatchSize);
+ MeanStddevNormalization(input, output.data(), kVectorSize, kBatchSize,
+ kNormalizationEpsilon);
+ const std::vector<float> expected_output = {
+ -1.34164071, -0.447213531, 0.44721365, 1.34164071, // batch 0
+ -1.34163153, -0.447210163, 0.447211236, 1.3416326, // batch 1
+ };
+ EXPECT_THAT(output, testing::ElementsAreArray(expected_output));
+}
+
+TEST(uKernels, MeanStddevNormalizationAllZeroInput) {
+ constexpr int kVectorSize = 4;
+ constexpr int kBatchSize = 2;
+ constexpr float kNormalizationEpsilon = 1e-8;
+
+ // Zero input.
+ static float input[kVectorSize * kBatchSize] = {
+ 0.0, 0.0, 0.0, 0.0, // batch 0
+ 0.0, 0.0, 0.0, 0.0, // batch 1
+ };
+ std::vector<float> output(kVectorSize * kBatchSize);
+ MeanStddevNormalization(input, output.data(), kVectorSize, kBatchSize,
+ kNormalizationEpsilon);
+ const std::vector<float> expected_output = {
+ 0.0, 0.0, 0.0, 0.0, // batch 0
+ 0.0, 0.0, 0.0, 0.0, // batch 1
+ };
+ EXPECT_THAT(output, testing::ElementsAreArray(expected_output));
+}
+
+TEST(uKernels, MeanStddevNormalizationMixed) {
+ constexpr int kVectorSize = 4;
+ constexpr int kBatchSize = 2;
+ constexpr float kNormalizationEpsilon = 1e-8;
+
+ // Mix of zero and non-zero input.
+ static float input[kVectorSize * kBatchSize] = {
+ 0.0, 0.0, 0.0, 0.0, // batch 0
+ 0.1, 0.2, 0.3, 0.4, // batch 1
+ };
+ std::vector<float> output(kVectorSize * kBatchSize);
+ MeanStddevNormalization(input, output.data(), kVectorSize, kBatchSize,
+ kNormalizationEpsilon);
+ const std::vector<float> expected_output = {
+ 0.0, 0.0, 0.0, 0.0, // batch 0
+ -1.34164071, -0.447213531, 0.44721365, 1.34164071, // batch 1
+ };
+ EXPECT_THAT(output, testing::ElementsAreArray(expected_output));
+}
+
+TEST(uKernels, MeanStddevNormalizationSmallValue) {
+ constexpr int kVectorSize = 4;
+ constexpr int kBatchSize = 2;
+ constexpr float kNormalizationEpsilon = 1e-8;
+
+ // Small input values.
+ static float input[kVectorSize * kBatchSize] = {
+ 3e-5, -7e-6, -9e-5, 1e-6, // batch 0
+ 4e-5, 9e-6, 2e-4, 0.0, // batch 1
+ };
+ std::vector<float> output(kVectorSize * kBatchSize);
+ MeanStddevNormalization(input, output.data(), kVectorSize, kBatchSize,
+ kNormalizationEpsilon);
+ const std::vector<float> expected_output = {
+ 1.04231524, 0.212946132, -1.64753067, 0.392269224, // batch 0
+ -0.275023013, -0.658201098, 1.70267045, -0.769446373, // batch 1
+ };
+ EXPECT_THAT(output, testing::ElementsAreArray(expected_output));
+}
+
} // namespace tensor_utils
} // namespace tflite
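As a hand check of the expected values in MeanStddevNormalizationNoneZeroInput above (arithmetic only, not taken from the patch): for batch 0, the mean of {0.1, 0.2, 0.3, 0.4} is 0.25 and the population standard deviation is sqrt(0.0125) ≈ 0.11180, so (0.1 - 0.25) / 0.11180 ≈ -1.34164 and (0.4 - 0.25) / 0.11180 ≈ 1.34164, matching the expected outputs up to float rounding. The all-zero batches normalize to zero because every deviation from the mean is zero.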
diff --git a/tensorflow/contrib/lite/kernels/internal/test_util.cc b/tensorflow/contrib/lite/kernels/internal/test_util.cc
new file mode 100644
index 0000000000..75d568ae3a
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/internal/test_util.cc
@@ -0,0 +1,107 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/kernels/internal/test_util.h"
+
+#include <cmath>
+#include <iterator>
+
+namespace tflite {
+
+// This is copied from an internal function in propagate_fixed_sizes.cc.
+bool ComputeConvSizes(const RuntimeShape& input_shape, int output_depth,
+ int filter_width, int filter_height, int stride,
+ int dilation_width_factor, int dilation_height_factor,
+ PaddingType padding_type, RuntimeShape* output_shape,
+ int* pad_width, int* pad_height) {
+ const int input_width = input_shape.Dims(2);
+ const int input_height = input_shape.Dims(1);
+ const int batch = input_shape.Dims(0);
+
+ int dilated_filter_width = dilation_width_factor * (filter_width - 1) + 1;
+ int dilated_filter_height = dilation_height_factor * (filter_height - 1) + 1;
+
+ int output_height = 0;
+ int output_width = 0;
+ if (padding_type == PaddingType::kValid) {
+ output_height = (input_height + stride - dilated_filter_height) / stride;
+ output_width = (input_width + stride - dilated_filter_width) / stride;
+ } else if (padding_type == PaddingType::kSame) {
+ output_height = (input_height + stride - 1) / stride;
+ output_width = (input_width + stride - 1) / stride;
+ } else {
+ return false;
+ }
+
+ if (output_width <= 0 || output_height <= 0) {
+ return false;
+ }
+
+ *pad_height = std::max(
+ 0, ((output_height - 1) * stride + dilated_filter_height - input_height) /
+ 2);
+ *pad_width = std::max(
+ 0,
+ ((output_width - 1) * stride + dilated_filter_width - input_width) / 2);
+
+ output_shape->BuildFrom({batch, output_height, output_width, output_depth});
+ return true;
+}
+
+std::mt19937& RandomEngine() {
+ static std::mt19937 engine;
+ return engine;
+}
+
+int UniformRandomInt(int min, int max) {
+ std::uniform_int_distribution<int> dist(min, max);
+ return dist(RandomEngine());
+}
+
+float UniformRandomFloat(float min, float max) {
+ std::uniform_real_distribution<float> dist(min, max);
+ return dist(RandomEngine());
+}
+
+int ExponentialRandomPositiveInt(float percentile, int percentile_val,
+ int max_val) {
+ const float lambda =
+ -std::log(1.f - percentile) / static_cast<float>(percentile_val);
+ std::exponential_distribution<float> dist(lambda);
+ float val;
+ do {
+ val = dist(RandomEngine());
+ } while (!val || !std::isfinite(val) || val > max_val);
+ return static_cast<int>(std::ceil(val));
+}
+
+float ExponentialRandomPositiveFloat(float percentile, float percentile_val,
+ float max_val) {
+ const float lambda =
+ -std::log(1.f - percentile) / static_cast<float>(percentile_val);
+ std::exponential_distribution<float> dist(lambda);
+ float val;
+ do {
+ val = dist(RandomEngine());
+ } while (!std::isfinite(val) || val > max_val);
+ return val;
+}
+
+void FillRandom(std::vector<float>* vec, float min, float max) {
+ std::uniform_real_distribution<float> dist(min, max);
+ auto gen = std::bind(dist, RandomEngine());
+ std::generate(std::begin(*vec), std::end(*vec), gen);
+}
+
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/internal/test_util.h b/tensorflow/contrib/lite/kernels/internal/test_util.h
new file mode 100644
index 0000000000..e4a383bedf
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/internal/test_util.h
@@ -0,0 +1,103 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TEST_UTIL_H_
+#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TEST_UTIL_H_
+
+#include <algorithm>
+#include <functional>
+#include <iterator>
+#include <limits>
+#include <random>
+#include <vector>
+
+#include "tensorflow/contrib/lite/kernels/internal/types.h"
+
+namespace tflite {
+
+// Computes output and padding dimensions.
+bool ComputeConvSizes(const RuntimeShape& input_shape, int output_depth,
+ int filter_width, int filter_height, int stride,
+ int dilation_width_factor, int dilation_height_factor,
+ PaddingType padding_type, RuntimeShape* output_shape,
+ int* pad_width, int* pad_height);
+
+// Returns a mt19937 random engine.
+std::mt19937& RandomEngine();
+
+// Returns a random integer uniformly distributed between |min| and |max|.
+int UniformRandomInt(int min, int max);
+
+// Returns a random float uniformly distributed between |min| and |max|.
+float UniformRandomFloat(float min, float max);
+
+// Returns a random element in |v|.
+template <typename T>
+const T& RandomElement(const std::vector<T>& v) {
+ return v[UniformRandomInt(0, v.size() - 1)];
+}
+
+// Returns a random exponentially distributed integer.
+int ExponentialRandomPositiveInt(float percentile, int percentile_val,
+ int max_val);
+
+// Returns a random exponentially distributed float.
+float ExponentialRandomPositiveFloat(float percentile, float percentile_val,
+ float max_val);
+
+// Fills a vector with random floats between |min| and |max|.
+void FillRandom(std::vector<float>* vec, float min, float max);
+
+// Fills a vector with random numbers between |min| and |max|.
+template <typename T>
+void FillRandom(std::vector<T>* vec, T min, T max) {
+ std::uniform_int_distribution<T> dist(min, max);
+ auto gen = std::bind(dist, RandomEngine());
+ std::generate(std::begin(*vec), std::end(*vec), gen);
+}
+
+// Fills a vector with random numbers.
+template <typename T>
+void FillRandom(std::vector<T>* vec) {
+ FillRandom(vec, std::numeric_limits<T>::min(), std::numeric_limits<T>::max());
+}
+
+template <typename T>
+void FillRandom(typename std::vector<T>::iterator begin_it,
+ typename std::vector<T>::iterator end_it, T min, T max) {
+ std::uniform_int_distribution<T> dist(min, max);
+ auto gen = std::bind(dist, RandomEngine());
+ std::generate(begin_it, end_it, gen);
+}
+
+// Fill with a "skyscraper" pattern, in which there is a central section (across
+// the depth) with higher values than the surround.
+template <typename T>
+void FillRandomSkyscraper(std::vector<T>* vec, int depth,
+ double middle_proportion, uint8 middle_min,
+ uint8 sides_max) {
+ for (auto base_it = std::begin(*vec); base_it != std::end(*vec);
+ base_it += depth) {
+ auto left_it = base_it + std::ceil(0.5 * depth * (1.0 - middle_proportion));
+ auto right_it =
+ base_it + std::ceil(0.5 * depth * (1.0 + middle_proportion));
+ FillRandom(base_it, left_it, std::numeric_limits<T>::min(), sides_max);
+ FillRandom(left_it, right_it, middle_min, std::numeric_limits<T>::max());
+ FillRandom(right_it, base_it + depth, std::numeric_limits<T>::min(),
+ sides_max);
+ }
+}
+
+} // namespace tflite
+#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TEST_UTIL_H_
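A minimal usage sketch of the declarations above (illustrative only, not part of this change; the shapes and parameter values are invented):

    #include <cstdint>
    #include <vector>

    #include "tensorflow/contrib/lite/kernels/internal/test_util.h"

    void BuildRandomConvTestCase() {
      // NHWC input with exponentially distributed spatial sizes and depth.
      const tflite::RuntimeShape input_shape(
          {1, tflite::ExponentialRandomPositiveInt(0.9f, 20, 128),
           tflite::ExponentialRandomPositiveInt(0.9f, 20, 128),
           tflite::ExponentialRandomPositiveInt(0.75f, 8, 64)});

      tflite::RuntimeShape output_shape;
      int pad_width = 0;
      int pad_height = 0;
      // Reject parameter combinations that do not produce a valid output.
      if (!tflite::ComputeConvSizes(input_shape, /*output_depth=*/16,
                                    /*filter_width=*/3, /*filter_height=*/3,
                                    /*stride=*/1, /*dilation_width_factor=*/1,
                                    /*dilation_height_factor=*/1,
                                    tflite::PaddingType::kSame, &output_shape,
                                    &pad_width, &pad_height)) {
        return;
      }

      // Quantized input data with a "skyscraper" profile across the depth.
      std::vector<std::uint8_t> input_data(input_shape.FlatSize());
      tflite::FillRandomSkyscraper(&input_data, input_shape.Dims(3),
                                   /*middle_proportion=*/0.5,
                                   /*middle_min=*/200, /*sides_max=*/30);
    }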
diff --git a/tensorflow/contrib/lite/kernels/internal/types.h b/tensorflow/contrib/lite/kernels/internal/types.h
index 43c6883278..c6bc6074d4 100644
--- a/tensorflow/contrib/lite/kernels/internal/types.h
+++ b/tensorflow/contrib/lite/kernels/internal/types.h
@@ -1,4 +1,4 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -15,11 +15,81 @@ limitations under the License.
#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TYPES_H_
#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TYPES_H_
+#include <algorithm>
+#include <cstring>
+
#include "tensorflow/contrib/lite/kernels/internal/compatibility.h"
namespace tflite {
enum class FusedActivationFunctionType : uint8 { kNone, kRelu6, kRelu1, kRelu };
+enum class PaddingType : uint8 { kNone, kSame, kValid };
+
+struct PaddingValues {
+ int16 width;
+ int16 height;
+};
+
+// This enumeration allows for non-default formats for the weights array
+// of a fully-connected operator, allowing the use of special optimized
+// runtime paths.
+enum class FullyConnectedWeightsFormat : uint8 {
+ // Default format (flat 2D layout, the inner contiguous dimension
+ // is input_depth, the outer non-contiguous dimension is output_depth)
+ kDefault,
+ // Summary: optimized layout for fast CPU runtime implementation,
+ // aimed specifically at ARM CPUs at the moment, and specialized for
+ // 8-bit quantized layers.
+ //
+ // The use case we're concerned with here is: 8-bit quantization,
+ // large weights matrix that doesn't fit in cache (e.g. 4096x2048 in
+ // a key application that drove this), very small batch size (e.g. 1 -- 4).
+ //
+ // Even with 8-bit quantization of weights, the performance of memory
+ // accesses to the weights can become the dominant issue when
+ // the batch size is small, so each weight value is used in only a few
+ // arithmetic ops, i.e. the fully-connected node has a low arithmetic
+ // intensity. The specific issues that arise are of three kinds:
+ // (1) One may, ideally, max out DRAM bandwidth, i.e. be truly memory
+ // bound. That's the "good" issue to run into.
+ // (2) One may run into sub-optimal pre-fetching: the data hasn't been
+ // prefetched into the cache by the time we need it.
+ // (3) One may run into cache aliasing: multiple values that are
+ // pre-fetched, alias each other in the L1 cache (which typically
+ // has only 4-way set associativity in ARM CPUs) and thus evict
+ // each other before we get to using them.
+ //
+ // The point of this shuffling is to avoid issues (2) and (3) so that
+ // we get as fast as possible given only the hard constraint (1).
+ // This is achieved by turning the difficulty into a solution: the
+ // difficulty, that each value loaded from memory is used only in
+ // one kernel iteration, making this operation memory-intensive, hints at
+ // the solution, of shuffling the weights so that they are stored in the
+ // exact order as the kernel needs to load them, so that the memory
+ // accesses made by the kernel are trivial. This solves (2) because the
+ // trivial memory access pattern allows the CPU's automatic prefetching
+ // to perform very well (no need even for preload instructions), and this
+ // solves (3) because the values being loaded concurrently are now
+ // contiguous in the address space, thus don't alias each other in the cache.
+ //
+ // On ARM, we typically want our kernel to process a 4x16 block of weights
+ // at a time, because:
+ // - 16 is the number of bytes in a NEON register.
+ // - 4 is how many rows we need to handle concurrently in the kernel in
+ // order to have sufficient mutual independence of instructions to
+ // maximize arithmetic throughput.
+ //
+ // Finally, the 'Int8' part in the name refers to the fact that this
+ // weights format has each weights value encoded as a signed int8 value,
+ // even if the data type of the weights buffer is uint8. This is intended
+ // to save runtime kernels the effort to have to XOR the top bit of these
+ // bytes before using them in signed arithmetic, see this file for more
+ // explanations on the 'signed int8 trick' in matrix multiplication kernels:
+ //
+ // tensorflow/contrib/lite/toco/graph_transformations/ensure_uint8_weights_safe_for_fast_int8_kernels.cc
+ //
+ kShuffled4x16Int8,
+};
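A rough sketch of the 4x16 block traversal described above (illustrative only; the authoritative shuffling is the toco graph transformation referenced in the comment, and handling of dimensions that are not multiples of 4 and 16 is omitted):

    #include <cstdint>
    #include <vector>

    // Walk a row-major [output_depth x input_depth] uint8 weights matrix in
    // blocks of 4 rows x 16 bytes, storing each block contiguously, and flip
    // the top bit of every byte so it can be read directly as a signed int8
    // (the "signed int8 trick" mentioned above).
    std::vector<std::uint8_t> ShuffleWeightsSketch(
        const std::vector<std::uint8_t>& weights, int output_depth,
        int input_depth) {
      std::vector<std::uint8_t> shuffled;
      shuffled.reserve(weights.size());
      for (int block_row = 0; block_row < output_depth; block_row += 4) {
        for (int block_col = 0; block_col < input_depth; block_col += 16) {
          for (int r = 0; r < 4; ++r) {
            for (int c = 0; c < 16; ++c) {
              const std::uint8_t w =
                  weights[(block_row + r) * input_depth + block_col + c];
              shuffled.push_back(w ^ 0x80);
            }
          }
        }
      }
      return shuffled;
    }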
// Quantization parameters, determining the mapping of quantized values
// to real values (i.e. determining how quantized values are mathematically
@@ -43,9 +113,207 @@ struct Dims {
int strides[N];
};
+class RuntimeShape {
+ public:
+ // Shapes with dimensions up to 4 are stored directly in the structure, while
+ // larger shapes are separately allocated.
+ static constexpr int kMaxSmallSize = 4;
+
+ RuntimeShape& operator=(RuntimeShape const&) = delete;
+
+ RuntimeShape() : size_(0) {}
+
+ explicit RuntimeShape(int dimensions_count) : size_(dimensions_count) {
+ if (dimensions_count > kMaxSmallSize) {
+#ifdef TF_LITE_STATIC_MEMORY
+ TFLITE_CHECK(false && "No shape resizing supported on this platform");
+#else // TF_LITE_STATIC_MEMORY
+ dims_pointer_ = new int32[dimensions_count];
+#endif // TF_LITE_STATIC_MEMORY
+ }
+ }
+
+ RuntimeShape(int shape_size, int32 value) : size_(0) {
+ Resize(shape_size);
+ for (int i = 0; i < shape_size; ++i) {
+ SetDim(i, value);
+ }
+ }
+
+ RuntimeShape(int dimensions_count, const int32* dims_data) : size_(0) {
+ ReplaceWith(dimensions_count, dims_data);
+ }
+
+ RuntimeShape(const std::initializer_list<int> init_list) : size_(0) {
+ BuildFrom(init_list);
+ }
+
+ // Avoid using this constructor. We should be able to delete it when C++17
+ // rolls out.
+ RuntimeShape(RuntimeShape const& other) : size_(other.DimensionsCount()) {
+ if (size_ > kMaxSmallSize) {
+ dims_pointer_ = new int32[size_];
+ }
+ std::memcpy(DimsData(), other.DimsData(), sizeof(int32) * size_);
+ }
+
+ bool operator==(const RuntimeShape& comp) const {
+ return this->size_ == comp.size_ &&
+ std::memcmp(DimsData(), comp.DimsData(), size_ * sizeof(int32)) == 0;
+ }
+
+ ~RuntimeShape() {
+ if (size_ > kMaxSmallSize) {
+#ifdef TF_LITE_STATIC_MEMORY
+ TFLITE_CHECK(false && "No shape resizing supported on this platform");
+#else // TF_LITE_STATIC_MEMORY
+ delete[] dims_pointer_;
+#endif // TF_LITE_STATIC_MEMORY
+ }
+ }
+
+ inline int32 DimensionsCount() const { return size_; }
+ inline int32 Dims(int i) const {
+ TFLITE_DCHECK_GE(i, 0);
+ TFLITE_DCHECK_LT(i, size_);
+ return size_ > kMaxSmallSize ? dims_pointer_[i] : dims_[i];
+ }
+ inline void SetDim(int i, int32 val) {
+ TFLITE_DCHECK_GE(i, 0);
+ TFLITE_DCHECK_LT(i, size_);
+ if (size_ > kMaxSmallSize) {
+ dims_pointer_[i] = val;
+ } else {
+ dims_[i] = val;
+ }
+ }
+
+ inline int32* DimsData() {
+ return size_ > kMaxSmallSize ? dims_pointer_ : dims_;
+ }
+ inline const int32* DimsData() const {
+ return size_ > kMaxSmallSize ? dims_pointer_ : dims_;
+ }
+ // The caller must ensure that the shape is no bigger than 4-D.
+ inline const int32* DimsDataUpTo4D() const { return dims_; }
+
+ inline void Resize(int dimensions_count) {
+ if (size_ > kMaxSmallSize) {
+#ifdef TF_LITE_STATIC_MEMORY
+ TFLITE_CHECK(false && "No shape resizing supported on this platform");
+#else // TF_LITE_STATIC_MEMORY
+ delete[] dims_pointer_;
+#endif // TF_LITE_STATIC_MEMORY
+ }
+ size_ = dimensions_count;
+ if (dimensions_count > kMaxSmallSize) {
+#ifdef TF_LITE_STATIC_MEMORY
+ TFLITE_CHECK(false && "No shape resizing supported on this platform");
+#else // TF_LITE_STATIC_MEMORY
+ dims_pointer_ = new int32[dimensions_count];
+#endif // TF_LITE_STATIC_MEMORY
+ }
+ }
+
+ inline void ReplaceWith(int dimensions_count, const int32* dims_data) {
+ Resize(dimensions_count);
+ int32* dst_dims = DimsData();
+ std::memcpy(dst_dims, dims_data, dimensions_count * sizeof(int32));
+ }
+
+ template <typename T>
+ inline void BuildFrom(const T& src_iterable) {
+ const int dimensions_count =
+ std::distance(src_iterable.begin(), src_iterable.end());
+ Resize(dimensions_count);
+ int32* data = DimsData();
+ for (auto it : src_iterable) {
+ *data = it;
+ ++data;
+ }
+ }
+
+ // This will probably be factored out. Old code made substantial use of 4-D
+ // shapes, and so this function is used to extend smaller shapes. Note that
+ // (a) as Dims<4>-dependent code is eliminated, the reliance on this should be
+  // reduced, and (b) some kernels are strictly 4-D, but then the shapes of their
+ // inputs should already be 4-D, so this function should not be needed.
+ inline static RuntimeShape ExtendedShape(int new_shape_size,
+ const RuntimeShape& shape) {
+ return RuntimeShape(new_shape_size, shape, 1);
+ }
+
+ inline void BuildFrom(const std::initializer_list<int> init_list) {
+ BuildFrom<const std::initializer_list<int>>(init_list);
+ }
+
+ // Returns the total count of elements, that is the size when flattened into a
+ // vector.
+ inline int FlatSize() const {
+ int buffer_size = 1;
+ const int* dims_data = DimsData();
+ for (int i = 0; i < size_; i++) {
+ const int dim = dims_data[i];
+ TFLITE_DCHECK_GE(dim, 1);
+ buffer_size *= dim;
+ }
+ return buffer_size;
+ }
+
+ bool operator!=(const RuntimeShape& comp) const { return !((*this) == comp); }
+
+ private:
+ // For use only by ExtendedShape(), written to guarantee (return-value) copy
+ // elision in C++17.
+ // This creates a shape padded to the desired size with the specified value.
+ RuntimeShape(int new_shape_size, const RuntimeShape& shape, int pad_value)
+ : size_(0) {
+ // If the following check fails, it is likely because a 4D-only kernel is
+ // being used with an array of larger dimension count.
+ TFLITE_CHECK_GE(new_shape_size, shape.DimensionsCount());
+ Resize(new_shape_size);
+ const int size_increase = new_shape_size - shape.DimensionsCount();
+ for (int i = 0; i < size_increase; ++i) {
+ SetDim(i, pad_value);
+ }
+ std::memcpy(DimsData() + size_increase, shape.DimsData(),
+ sizeof(int32) * shape.DimensionsCount());
+ }
+
+ int32 size_;
+ union {
+ int32 dims_[kMaxSmallSize];
+ int32* dims_pointer_;
+ };
+};
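A quick illustration of the class above (sketch only, not code in this change):

    #include "tensorflow/contrib/lite/kernels/internal/types.h"

    void RuntimeShapeSketch() {
      // Shapes of up to 4 dimensions are stored inline; this is 1x8x8x3 NHWC.
      tflite::RuntimeShape shape({1, 8, 8, 3});
      const int flat_size = shape.FlatSize();  // 1 * 8 * 8 * 3 = 192.

      // A 2-D shape left-padded with 1s so that 4-D kernels can consume it.
      tflite::RuntimeShape matrix({16, 32});
      tflite::RuntimeShape extended =
          tflite::RuntimeShape::ExtendedShape(4, matrix);  // {1, 1, 16, 32}.

      // Shapes with more than 4 dimensions transparently use a heap buffer.
      tflite::RuntimeShape five_d(5, /*value=*/1);
      five_d.SetDim(4, flat_size);
      (void)extended;
    }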
+
+// Converts inference-style shape to legacy tflite::Dims<4>.
+inline tflite::Dims<4> ToRuntimeDims(const tflite::RuntimeShape& array_shape) {
+ tflite::Dims<4> result;
+ const int dimensions_count = array_shape.DimensionsCount();
+ TFLITE_CHECK_LE(dimensions_count, 4);
+ int cum_prod = 1;
+ for (int i = 0; i < 4; i++) {
+ const int new_dim =
+ (i < dimensions_count) ? array_shape.Dims(dimensions_count - 1 - i) : 1;
+ result.sizes[i] = new_dim;
+ result.strides[i] = cum_prod;
+ cum_prod *= new_dim;
+ }
+ return result;
+}
+
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+inline RuntimeShape DimsToShape(const tflite::Dims<4>& dims) {
+ return RuntimeShape(
+ {dims.sizes[3], dims.sizes[2], dims.sizes[1], dims.sizes[0]});
+}
+
// Gets next index to iterate through a multidimensional array.
inline bool NextIndex(const int num_dims, const int* dims, int* current) {
- TFLITE_DCHECK_GT(num_dims, 0);
+ if (num_dims == 0) {
+ return false;
+ }
TFLITE_DCHECK(dims != nullptr);
TFLITE_DCHECK(current != nullptr);
int carry = 1;
@@ -72,7 +340,9 @@ inline bool NextIndex(const int num_dims, const int* dims, int* current) {
inline size_t ReducedOutputOffset(const int num_dims, const int* dims,
const int* index, const int num_axis,
const int* axis) {
- TFLITE_DCHECK_GT(num_dims, 0);
+ if (num_dims == 0) {
+ return 0;
+ }
TFLITE_DCHECK(dims != nullptr);
TFLITE_DCHECK(index != nullptr);
size_t offset = 0;
@@ -95,6 +365,16 @@ inline size_t ReducedOutputOffset(const int num_dims, const int* dims,
return offset;
}
+inline int Offset(const RuntimeShape& shape, int i0, int i1, int i2, int i3) {
+ TFLITE_DCHECK_EQ(shape.DimensionsCount(), 4);
+ const int* dims_data = shape.DimsDataUpTo4D();
+ TFLITE_DCHECK(i0 >= 0 && i0 < dims_data[0]);
+ TFLITE_DCHECK(i1 >= 0 && i1 < dims_data[1]);
+ TFLITE_DCHECK(i2 >= 0 && i2 < dims_data[2]);
+ TFLITE_DCHECK(i3 >= 0 && i3 < dims_data[3]);
+ return ((i0 * dims_data[1] + i1) * dims_data[2] + i2) * dims_data[3] + i3;
+}
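To make the offset arithmetic above concrete (an illustrative example): for an NHWC shape {2, 8, 8, 3}, Offset(shape, 1, 2, 3, 1) = ((1 * 8 + 2) * 8 + 3) * 3 + 1 = 250, i.e. the index obtained when the last dimension varies fastest.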
+
inline int Offset(const Dims<4>& dims, int i0, int i1, int i2, int i3) {
TFLITE_DCHECK(i0 >= 0 && i0 < dims.sizes[0]);
TFLITE_DCHECK(i1 >= 0 && i1 < dims.sizes[1]);
@@ -108,7 +388,14 @@ inline int Offset(const Dims<4>& dims, int* index) {
return Offset(dims, index[0], index[1], index[2], index[3]);
}
+inline int Offset(const RuntimeShape& shape, int* index) {
+ return Offset(shape, index[0], index[1], index[2], index[3]);
+}
+
// Get array size, DCHECKing that the dim index is in range.
+//
+// Note that this will be phased out with Dims<4>, since RuntimeShape::Dims()
+// already performs this check.
template <int N>
int ArraySize(const Dims<N>& array, int index) {
TFLITE_DCHECK(index >= 0 && index < N);
@@ -130,6 +417,21 @@ int MatchingArraySize(const ArrayType1& array1, int index1,
return MatchingArraySize(array1, index1, args...);
}
+// Get common shape dim, DCHECKing that they all agree.
+inline int MatchingDim(const RuntimeShape& shape1, int index1,
+ const RuntimeShape& shape2, int index2) {
+ TFLITE_DCHECK_EQ(shape1.Dims(index1), shape2.Dims(index2));
+ return shape1.Dims(index1);
+}
+
+template <typename... Args>
+int MatchingDim(const RuntimeShape& shape1, int index1,
+ const RuntimeShape& shape2, int index2, Args... args) {
+ TFLITE_DCHECK_EQ(shape1.Dims(index1), shape2.Dims(index2));
+ return MatchingDim(shape1, index1, args...);
+}
+
+// Will be phased out with Dims<4>, replaced by RuntimeShape::FlatSize().
template <int N>
inline int FlatSize(const Dims<N>& dims) {
int flat_size = 1;
@@ -139,13 +441,61 @@ inline int FlatSize(const Dims<N>& dims) {
return flat_size;
}
-// Deprecated. Prefer FlatSize.
+TFLITE_DEPRECATED("Prefer FlatSize.")
inline int RequiredBufferSizeForDims(const Dims<4>& dims) {
return FlatSize(dims);
}
// Flat size calculation, checking that dimensions match with one or more other
// arrays.
+inline int MatchingFlatSize(const RuntimeShape& shape,
+ const RuntimeShape& check_shape_0) {
+ TFLITE_DCHECK_EQ(shape.DimensionsCount(), check_shape_0.DimensionsCount());
+ const int dims_count = shape.DimensionsCount();
+ for (int i = 0; i < dims_count; ++i) {
+ TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
+ }
+ return shape.FlatSize();
+}
+
+inline int MatchingFlatSize(const RuntimeShape& shape,
+ const RuntimeShape& check_shape_0,
+ const RuntimeShape& check_shape_1) {
+ TFLITE_DCHECK_EQ(shape.DimensionsCount(), check_shape_0.DimensionsCount());
+ const int dims_count = shape.DimensionsCount();
+ for (int i = 0; i < dims_count; ++i) {
+ TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
+ }
+ return MatchingFlatSize(shape, check_shape_1);
+}
+
+inline int MatchingFlatSize(const RuntimeShape& shape,
+ const RuntimeShape& check_shape_0,
+ const RuntimeShape& check_shape_1,
+ const RuntimeShape& check_shape_2) {
+ TFLITE_DCHECK_EQ(shape.DimensionsCount(), check_shape_0.DimensionsCount());
+ const int dims_count = shape.DimensionsCount();
+ for (int i = 0; i < dims_count; ++i) {
+ TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
+ }
+ return MatchingFlatSize(shape, check_shape_1, check_shape_2);
+}
+
+inline int MatchingFlatSize(const RuntimeShape& shape,
+ const RuntimeShape& check_shape_0,
+ const RuntimeShape& check_shape_1,
+ const RuntimeShape& check_shape_2,
+ const RuntimeShape& check_shape_3) {
+ TFLITE_DCHECK_EQ(shape.DimensionsCount(), check_shape_0.DimensionsCount());
+ const int dims_count = shape.DimensionsCount();
+ for (int i = 0; i < dims_count; ++i) {
+ TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
+ }
+ return MatchingFlatSize(shape, check_shape_1, check_shape_2, check_shape_3);
+}
+
+// Flat size calculation, checking that dimensions match with one or more other
+// arrays.
template <int N>
inline int MatchingFlatSize(const Dims<N>& dims, const Dims<N>& check_dims_0) {
for (int i = 0; i < N; ++i) {
@@ -170,7 +520,7 @@ inline int MatchingFlatSize(const Dims<N>& dims, const Dims<N>& check_dims_0,
for (int i = 0; i < N; ++i) {
TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
}
- return FlatSize(dims, check_dims_1, check_dims_2);
+ return MatchingFlatSize(dims, check_dims_1, check_dims_2);
}
template <int N>
@@ -181,7 +531,7 @@ inline int MatchingFlatSize(const Dims<N>& dims, const Dims<N>& check_dims_0,
for (int i = 0; i < N; ++i) {
TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
}
- return FlatSize(dims, check_dims_1, check_dims_2, check_dims_3);
+ return MatchingFlatSize(dims, check_dims_1, check_dims_2, check_dims_3);
}
// Data is required to be contiguous, and so many operators can use either the
@@ -249,6 +599,72 @@ inline int MatchingFlatSizeSkipDim(const Dims<N>& dims, int skip_dim,
check_dims_3);
}
+// Data is required to be contiguous, and so many operators can use either the
+// full array flat size or the flat size with one dimension skipped (commonly
+// the depth).
+inline int FlatSizeSkipDim(const RuntimeShape& shape, int skip_dim) {
+ const int dims_count = shape.DimensionsCount();
+ TFLITE_DCHECK(skip_dim >= 0 && skip_dim < dims_count);
+ const auto* dims_data = shape.DimsData();
+ int flat_size = 1;
+ for (int i = 0; i < dims_count; ++i) {
+ flat_size *= (i == skip_dim) ? 1 : dims_data[i];
+ }
+ return flat_size;
+}
+
+// A combination of MatchingFlatSize() and FlatSizeSkipDim().
+inline int MatchingFlatSizeSkipDim(const RuntimeShape& shape, int skip_dim,
+ const RuntimeShape& check_shape_0) {
+ const int dims_count = shape.DimensionsCount();
+ for (int i = 0; i < dims_count; ++i) {
+ if (i != skip_dim) {
+ TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
+ }
+ }
+ return FlatSizeSkipDim(shape, skip_dim);
+}
+
+inline int MatchingFlatSizeSkipDim(const RuntimeShape& shape, int skip_dim,
+ const RuntimeShape& check_shape_0,
+ const RuntimeShape& check_shape_1) {
+ const int dims_count = shape.DimensionsCount();
+ for (int i = 0; i < dims_count; ++i) {
+ if (i != skip_dim) {
+ TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
+ }
+ }
+ return MatchingFlatSizeSkipDim(shape, skip_dim, check_shape_1);
+}
+
+inline int MatchingFlatSizeSkipDim(const RuntimeShape& shape, int skip_dim,
+ const RuntimeShape& check_shape_0,
+ const RuntimeShape& check_shape_1,
+ const RuntimeShape& check_shape_2) {
+ const int dims_count = shape.DimensionsCount();
+ for (int i = 0; i < dims_count; ++i) {
+ if (i != skip_dim) {
+ TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
+ }
+ }
+ return MatchingFlatSizeSkipDim(shape, skip_dim, check_shape_1, check_shape_2);
+}
+
+inline int MatchingFlatSizeSkipDim(const RuntimeShape& shape, int skip_dim,
+ const RuntimeShape& check_shape_0,
+ const RuntimeShape& check_shape_1,
+ const RuntimeShape& check_shape_2,
+ const RuntimeShape& check_shape_3) {
+ const int dims_count = shape.DimensionsCount();
+ for (int i = 0; i < dims_count; ++i) {
+ if (i != skip_dim) {
+ TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
+ }
+ }
+ return MatchingFlatSizeSkipDim(shape, skip_dim, check_shape_1, check_shape_2,
+ check_shape_3);
+}
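A small sketch of how these checked helpers typically combine in a kernel (illustrative; the bias-add framing and names are invented):

    #include "tensorflow/contrib/lite/kernels/internal/types.h"

    // Assumes 4-D NHWC input/output shapes and a 1-D bias whose length equals
    // the depth (the skipped, innermost dimension).
    void CheckedSizesSketch(const tflite::RuntimeShape& input_shape,
                            const tflite::RuntimeShape& bias_shape,
                            const tflite::RuntimeShape& output_shape) {
      const int flat_size = tflite::MatchingFlatSize(input_shape, output_shape);
      const int outer_size = tflite::MatchingFlatSizeSkipDim(
          input_shape, /*skip_dim=*/3, output_shape);
      const int depth = tflite::MatchingDim(input_shape, 3, bias_shape, 0);
      // With 4-D shapes, flat_size == outer_size * depth.
      (void)flat_size;
      (void)outer_size;
      (void)depth;
    }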
+
template <int N>
bool IsPackedWithoutStrides(const Dims<N>& dims) {
int expected_stride = 1;
@@ -259,6 +675,344 @@ bool IsPackedWithoutStrides(const Dims<N>& dims) {
return true;
}
+template <int N>
+void ComputeStrides(Dims<N>* dims) {
+ dims->strides[0] = 1;
+ for (int d = 1; d < N; d++) {
+ dims->strides[d] = dims->strides[d - 1] * dims->sizes[d - 1];
+ }
+}
+
+enum class BroadcastableOpCategory : uint8 {
+ kNone,
+ kNonBroadcast, // Matching input shapes.
+ kFirstInputBroadcastsFast, // Fivefold nested loops.
+ kSecondInputBroadcastsFast, // Fivefold nested loops.
+ kGenericBroadcast, // Fall-back.
+};
+
+struct MinMax {
+ float min;
+ float max;
+};
+static_assert(sizeof(MinMax) == 8, "");
+
+struct ActivationParams {
+ FusedActivationFunctionType activation_type;
+ // uint8, etc, activation params.
+ int32 quantized_activation_min;
+ int32 quantized_activation_max;
+};
+
+// For Add, Sub, Mul ops.
+struct ArithmeticParams {
+ // Shape dependent / common to data / op types.
+ BroadcastableOpCategory broadcast_category;
+ // uint8 inference params.
+ int32 input1_offset;
+ int32 input2_offset;
+ int32 output_offset;
+ int32 output_multiplier;
+ int output_shift;
+ // Add / Sub, not Mul, uint8 inference params.
+ int left_shift;
+ int32 input1_multiplier;
+ int input1_shift;
+ int32 input2_multiplier;
+ int input2_shift;
+ // uint8, etc, activation params.
+ int32 quantized_activation_min;
+ int32 quantized_activation_max;
+ // float activation params.
+ float float_activation_min;
+ float float_activation_max;
+
+ // Processed output dimensions.
+ // Let input "a" be the one that broadcasts in the faster-changing dimension.
+ // Then, after coalescing, for shapes {a0, a1, a2, a3, a4} and
+ // {b0, b1, b2, b3, b4},
+ // broadcast_shape[4] = b0 = a0.
+ // broadcast_shape[3] = b1; a1 = 1.
+ // broadcast_shape[2] = b2 = a2.
+ // broadcast_shape[1] = a3; b3 = 1.
+ // broadcast_shape[0] = b4 = a4.
+ int broadcast_shape[5];
+};
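Applying the five equalities above to a concrete coalesced pair (an illustrative example only): a = {2, 1, 3, 4, 5} broadcasting against b = {2, 7, 3, 1, 5} yields broadcast_shape = {5, 4, 3, 7, 2}, listed as {broadcast_shape[0], ..., broadcast_shape[4]}.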
+
+struct ConcatenationParams {
+ int8 axis;
+ const int32* input_zeropoint;
+ const float* input_scale;
+ uint16 inputs_count;
+ int32 output_zeropoint;
+ float output_scale;
+};
+
+struct ComparisonParams {
+ // uint8 inference params.
+ int left_shift;
+ int32 input1_offset;
+ int32 input1_multiplier;
+ int input1_shift;
+ int32 input2_offset;
+ int32 input2_multiplier;
+ int input2_shift;
+ // Shape dependent / common to inference types.
+ bool is_broadcast;
+};
+
+struct ConvParams {
+ PaddingType padding_type;
+ PaddingValues padding_values;
+ // TODO(starka): This was just "stride", so check that width+height is OK.
+ int16 stride_width;
+ int16 stride_height;
+ int16 dilation_width_factor;
+ int16 dilation_height_factor;
+ // uint8 inference params.
+ // TODO(b/65838351): Use smaller types if appropriate.
+ int32 input_offset;
+ int32 weights_offset;
+ int32 output_offset;
+ int32 output_multiplier;
+ int output_shift;
+ // uint8, etc, activation params.
+ int32 quantized_activation_min;
+ int32 quantized_activation_max;
+ // float activation params.
+ float float_activation_min;
+ float float_activation_max;
+};
+
+struct DepthToSpaceParams {
+ int32 block_size;
+};
+
+struct DepthwiseParams {
+ PaddingType padding_type;
+ PaddingValues padding_values;
+ int16 stride_width;
+ int16 stride_height;
+ int16 dilation_width_factor;
+ int16 dilation_height_factor;
+ int16 depth_multiplier;
+ // uint8 inference params.
+ // TODO(b/65838351): Use smaller types if appropriate.
+ int32 input_offset;
+ int32 weights_offset;
+ int32 output_offset;
+ int32 output_multiplier;
+ int output_shift;
+ // uint8, etc, activation params.
+ int32 quantized_activation_min;
+ int32 quantized_activation_max;
+ // float activation params.
+ float float_activation_min;
+ float float_activation_max;
+};
+
+struct DequantizationParams {
+ double scale;
+ int32 zero_point;
+};
+
+struct FakeQuantParams {
+ MinMax minmax;
+ int32 num_bits;
+};
+
+struct FullyConnectedParams {
+ // uint8 inference params.
+ // TODO(b/65838351): Use smaller types if appropriate.
+ int32 input_offset;
+ int32 weights_offset;
+ int32 output_offset;
+ int32 output_multiplier;
+ int output_shift;
+ // uint8, etc, activation params.
+ int32 quantized_activation_min;
+ int32 quantized_activation_max;
+ // float activation params.
+ float float_activation_min;
+ float float_activation_max;
+ FullyConnectedWeightsFormat weights_format;
+};
+
+struct GatherParams {
+ int16 input_rank;
+ int16 axis;
+};
+
+struct L2NormalizationParams {
+ // uint8 inference params.
+ int32 input_zero_point;
+};
+
+struct LocalResponseNormalizationParams {
+ int32 range;
+ double bias;
+ double alpha;
+ double beta;
+};
+
+struct LogisticParams {
+ // uint8 inference params.
+ int32 input_zero_point;
+ int32 input_range_radius;
+ int32 input_multiplier;
+ int input_left_shift;
+};
+
+struct LstmCellParams {
+ int32 weights_zero_point;
+ int32 accum_multiplier;
+ int accum_shift;
+ int state_integer_bits;
+};
+
+struct MeanParams {
+ int8 axis_count;
+ int16 axis[4];
+};
+
+struct PackParams {
+ int8 axis;
+ const int32* input_zeropoint;
+ const float* input_scale;
+ uint16 inputs_count;
+ int32 output_zeropoint;
+ float output_scale;
+};
+
+struct PadParams {
+ int8 left_padding_count;
+ int32 left_padding[4];
+ int8 right_padding_count;
+ int32 right_padding[4];
+};
+
+struct PoolParams {
+ FusedActivationFunctionType activation;
+ PaddingType padding_type;
+ PaddingValues padding_values;
+ int stride_height;
+ int stride_width;
+ int filter_height;
+ int filter_width;
+ // uint8, etc, activation params.
+ int32 quantized_activation_min;
+ int32 quantized_activation_max;
+ // float activation params.
+ float float_activation_min;
+ float float_activation_max;
+};
+
+struct ReshapeParams {
+ int8 shape_count;
+ int32 shape[4];
+};
+
+struct ResizeBilinearParams {
+ bool align_corners;
+};
+
+struct SliceParams {
+ int8 begin_count;
+ int32 begin[4];
+ int8 size_count;
+ int32 size[4];
+};
+
+struct SoftmaxParams {
+  // beta is not really used (not a TensorFlow parameter) and not implemented
+ // for LogSoftmax.
+ double beta;
+ // uint8 inference params. Used even when beta defaults to 1.0.
+ int32 input_multiplier;
+ int32 input_left_shift;
+ // Reverse scaling is only used by LogSoftmax.
+ int32 reverse_scaling_divisor;
+ int32 reverse_scaling_right_shift;
+ int diff_min;
+};
+
+struct SpaceToBatchParams {
+ // "Zero" padding for uint8 means padding with the output offset.
+ int32 output_offset;
+};
+
+struct SpaceToDepthParams {
+ int32 block_size;
+};
+
+struct SplitParams {
+ // Graphs that split into, say, 2000 nodes are encountered. The indices in
+ // OperatorEdges are of type uint16.
+ uint16 num_split;
+ int16 axis;
+};
+
+struct SqueezeParams {
+ int8 squeeze_dims_count;
+ int32 squeeze_dims[4];
+};
+
+struct StridedSliceParams {
+ int8 start_indices_count;
+ int16 start_indices[4];
+ int8 stop_indices_count;
+ int16 stop_indices[4];
+ int8 strides_count;
+ int16 strides[4];
+
+ int16 begin_mask;
+ int16 ellipsis_mask;
+ int16 end_mask;
+ int16 new_axis_mask;
+ int16 shrink_axis_mask;
+};
+
+struct TanhParams {
+ int32 input_zero_point;
+ int32 input_range_radius;
+ int32 input_multiplier;
+ int input_left_shift;
+};
+
+struct TransposeParams {
+ int8 perm_count;
+ int32 perm[4];
+};
+
+struct UnpackParams {
+ uint16 num_split;
+ int16 axis;
+};
+
+template <typename P>
+inline void SetActivationParams(float min, float max, P* params) {
+ params->float_activation_min = min;
+ params->float_activation_max = max;
+}
+
+template <typename P>
+inline void SetActivationParams(int32 min, int32 max, P* params) {
+ params->quantized_activation_min = min;
+ params->quantized_activation_max = max;
+}
+
+template <typename P>
+inline void GetActivationParams(const P& params, int32* min, int32* max) {
+ *min = params.quantized_activation_min;
+ *max = params.quantized_activation_max;
+}
+
+template <typename P>
+inline void GetActivationParams(const P& params, float* min, float* max) {
+ *min = params.float_activation_min;
+ *max = params.float_activation_max;
+}
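A short sketch of how the overloads above resolve by argument type (illustrative, not part of this change):

    #include <cstdint>

    #include "tensorflow/contrib/lite/kernels/internal/types.h"

    void ActivationParamsSketch() {
      tflite::ArithmeticParams params;

      // Float arguments select the float overload (float_activation_* fields).
      tflite::SetActivationParams(0.0f, 6.0f, &params);
      float fmin = 0.0f, fmax = 0.0f;
      tflite::GetActivationParams(params, &fmin, &fmax);  // fmin=0, fmax=6.

      // int32 arguments select the quantized overload instead.
      tflite::SetActivationParams(std::int32_t{0}, std::int32_t{255}, &params);
      (void)fmin;
      (void)fmax;
    }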
+
} // namespace tflite
#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TYPES_H_
diff --git a/tensorflow/contrib/lite/kernels/kernel_util.cc b/tensorflow/contrib/lite/kernels/kernel_util.cc
index 239b533a17..503ef28459 100644
--- a/tensorflow/contrib/lite/kernels/kernel_util.cc
+++ b/tensorflow/contrib/lite/kernels/kernel_util.cc
@@ -37,19 +37,17 @@ TfLiteStatus GetQuantizedConvolutionMultipler(TfLiteContext* context,
TF_LITE_ENSURE(context, std::abs(input_product_scale - bias_scale) <=
1e-6 * std::min(input_product_scale, bias_scale));
TF_LITE_ENSURE(context, input_product_scale >= 0);
- TF_LITE_ENSURE(context, input_product_scale < output_scale);
*multiplier = input_product_scale / output_scale;
return kTfLiteOk;
}
-void CalculateActivationRangeUint8(TfLiteFusedActivation activation,
- TfLiteTensor* output, int32_t* act_min,
- int32_t* act_max) {
- const int32_t qmin = std::numeric_limits<uint8_t>::min();
- const int32_t qmax = std::numeric_limits<uint8_t>::max();
-
+namespace {
+void CalculateActivationRangeQuantizedImpl(TfLiteFusedActivation activation,
+ int32_t qmin, int32_t qmax,
+ TfLiteTensor* output,
+ int32_t* act_min, int32_t* act_max) {
const auto scale = output->params.scale;
const auto zero_point = output->params.zero_point;
@@ -71,29 +69,47 @@ void CalculateActivationRangeUint8(TfLiteFusedActivation activation,
*act_max = qmax;
}
}
-
-void CalculateActivationRangeFloat(TfLiteFusedActivation activation,
- float* activation_min,
- float* activation_max) {
- if (activation == kTfLiteActRelu) {
- *activation_min = 0.f;
- *activation_max = std::numeric_limits<float>::max();
- } else if (activation == kTfLiteActRelu6) {
- *activation_min = 0.f;
- *activation_max = 6.f;
- } else if (activation == kTfLiteActRelu1) {
- *activation_min = -1.f;
- *activation_max = 1.f;
+} // namespace
+
+TfLiteStatus CalculateActivationRangeQuantized(TfLiteContext* context,
+ TfLiteFusedActivation activation,
+ TfLiteTensor* output,
+ int32_t* act_min,
+ int32_t* act_max) {
+ int32_t qmin = 0;
+ int32_t qmax = 0;
+ if (output->type == kTfLiteUInt8) {
+ qmin = std::numeric_limits<uint8_t>::min();
+ qmax = std::numeric_limits<uint8_t>::max();
+ } else if (output->type == kTfLiteInt16) {
+ qmin = std::numeric_limits<int16_t>::min();
+ qmax = std::numeric_limits<int16_t>::max();
} else {
- *activation_min = std::numeric_limits<float>::lowest();
- *activation_max = std::numeric_limits<float>::max();
+ TF_LITE_ENSURE(context, false);
}
+
+ CalculateActivationRangeQuantizedImpl(activation, qmin, qmax, output, act_min,
+ act_max);
+ return kTfLiteOk;
+}
+
+void CalculateActivationRangeUint8(TfLiteFusedActivation activation,
+ TfLiteTensor* output, int32_t* act_min,
+ int32_t* act_max) {
+ const int32_t qmin = std::numeric_limits<uint8_t>::min();
+ const int32_t qmax = std::numeric_limits<uint8_t>::max();
+
+ CalculateActivationRangeQuantizedImpl(activation, qmin, qmax, output, act_min,
+ act_max);
}
bool HaveSameShapes(const TfLiteTensor* input1, const TfLiteTensor* input2) {
return TfLiteIntArrayEqual(input1->dims, input2->dims);
}
+// TODO(petewarden): Having macros around this is ugly, look at other strategies
+// before replicating this approach elsewhere.
+#ifndef TF_LITE_STATIC_MEMORY
TfLiteStatus CalculateShapeForBroadcast(TfLiteContext* context,
const TfLiteTensor* input1,
const TfLiteTensor* input2,
@@ -112,5 +128,6 @@ TfLiteStatus CalculateShapeForBroadcast(TfLiteContext* context,
*output_shape = shape.release();
return kTfLiteOk;
}
+#endif // TF_LITE_STATIC_MEMORY
} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/kernel_util.h b/tensorflow/contrib/lite/kernels/kernel_util.h
index 82cded36f2..e9a5fd7a40 100644
--- a/tensorflow/contrib/lite/kernels/kernel_util.h
+++ b/tensorflow/contrib/lite/kernels/kernel_util.h
@@ -15,8 +15,11 @@ limitations under the License.
#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_KERNEL_UTIL_H_
#define TENSORFLOW_CONTRIB_LITE_KERNELS_KERNEL_UTIL_H_
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include <algorithm>
+#include <limits>
+
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
namespace tflite {
@@ -28,6 +31,11 @@ inline const TfLiteTensor* GetInput(TfLiteContext* context, TfLiteNode* node,
int index) {
return &context->tensors[node->inputs->data[index]];
}
+inline TfLiteTensor* GetVariableInput(TfLiteContext* context, TfLiteNode* node,
+ int index) {
+ TfLiteTensor* tensor = &context->tensors[node->inputs->data[index]];
+ return (tensor->is_variable) ? tensor : nullptr;
+}
inline TfLiteTensor* GetOutput(TfLiteContext* context, TfLiteNode* node,
int index) {
return &context->tensors[node->outputs->data[index]];
@@ -86,14 +94,35 @@ TfLiteStatus GetQuantizedConvolutionMultipler(TfLiteContext* context,
TfLiteTensor* output,
double* multiplier);
-// Calculates the useful range of an activation layer given its activation
-// tensor.
+// Calculates the useful quantized range of an activation layer given its
+// activation tensor.
+TfLiteStatus CalculateActivationRangeQuantized(TfLiteContext* context,
+ TfLiteFusedActivation activation,
+ TfLiteTensor* output,
+ int32_t* act_min,
+ int32_t* act_max);
void CalculateActivationRangeUint8(TfLiteFusedActivation activation,
TfLiteTensor* output, int32_t* act_min,
int32_t* act_max);
-void CalculateActivationRangeFloat(TfLiteFusedActivation activation,
- float* activation_min,
- float* activation_max);
+// Calculates the useful range of an activation layer given its activation
+// tensor.
+template <typename T>
+void CalculateActivationRange(TfLiteFusedActivation activation,
+ T* activation_min, T* activation_max) {
+ if (activation == kTfLiteActRelu) {
+ *activation_min = 0;
+ *activation_max = std::numeric_limits<T>::max();
+ } else if (activation == kTfLiteActRelu6) {
+ *activation_min = 0;
+ *activation_max = 6;
+ } else if (activation == kTfLiteActRelu1) {
+ *activation_min = -1;
+ *activation_max = 1;
+ } else {
+ *activation_min = std::numeric_limits<T>::lowest();
+ *activation_max = std::numeric_limits<T>::max();
+ }
+}
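The templated helper above supersedes the removed CalculateActivationRangeFloat; a float-path sketch of a caller (illustrative only, the function here is invented):

    #include "tensorflow/contrib/lite/kernels/kernel_util.h"

    // Hypothetical float path of some op's Eval.
    void FloatActivationBoundsSketch(TfLiteFusedActivation activation) {
      float output_activation_min = 0.0f;
      float output_activation_max = 0.0f;
      tflite::CalculateActivationRange(activation, &output_activation_min,
                                       &output_activation_max);
      // kTfLiteActRelu6 yields [0, 6]; kTfLiteActNone yields the full float
      // range.
    }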
// Return true if the given tensors have the same shape.
bool HaveSameShapes(const TfLiteTensor* input1, const TfLiteTensor* input2);
diff --git a/tensorflow/contrib/lite/kernels/l2norm.cc b/tensorflow/contrib/lite/kernels/l2norm.cc
index 7cea63da87..e02d7df9ef 100644
--- a/tensorflow/contrib/lite/kernels/l2norm.cc
+++ b/tensorflow/contrib/lite/kernels/l2norm.cc
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
@@ -68,10 +68,12 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
if (output->type == kTfLiteFloat32) {
-#define TF_LITE_L2NORM(type) \
- type::L2Normalization<FusedActivationFunctionType::kNone>( \
- GetTensorData<float>(input), GetTensorDims(input), \
- GetTensorData<float>(output), GetTensorDims(output))
+#define TF_LITE_L2NORM(type) \
+ tflite::L2NormalizationParams op_params; \
+ op_params.input_zero_point = 0; \
+ type::L2Normalization(op_params, GetTensorShape(input), \
+ GetTensorData<float>(input), GetTensorShape(output), \
+ GetTensorData<float>(output))
if (kernel_type == kReference) {
TF_LITE_L2NORM(reference_ops);
@@ -81,10 +83,12 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
}
#undef TF_LITE_L2NORM
} else if (output->type == kTfLiteUInt8) {
-#define TF_LITE_L2NORM(type) \
- type::L2Normalization(GetTensorData<uint8>(input), GetTensorDims(input), \
- input->params.zero_point, \
- GetTensorData<uint8>(output), GetTensorDims(output))
+#define TF_LITE_L2NORM(type) \
+ tflite::L2NormalizationParams op_params; \
+ op_params.input_zero_point = input->params.zero_point; \
+ type::L2Normalization(op_params, GetTensorShape(input), \
+ GetTensorData<uint8>(input), GetTensorShape(output), \
+ GetTensorData<uint8>(output))
if (kernel_type == kReference) {
TF_LITE_L2NORM(reference_ops);
@@ -94,7 +98,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
}
#undef TF_LITE_L2NORM
} else {
- context->ReportError(context, "Inputs and outputs not all float types.");
+ context->ReportError(context, "Output type is %d, requires float.",
+ output->type);
return kTfLiteError;
}
diff --git a/tensorflow/contrib/lite/kernels/l2norm_test.cc b/tensorflow/contrib/lite/kernels/l2norm_test.cc
index 11cc666bad..070ed60040 100644
--- a/tensorflow/contrib/lite/kernels/l2norm_test.cc
+++ b/tensorflow/contrib/lite/kernels/l2norm_test.cc
@@ -67,7 +67,7 @@ class L2NormOpModel : public SingleOpModel {
int output_;
};
-TEST(L2NormOpTest, SimpleTest) {
+TEST(L2NormOpTest, SimpleFloatTest) {
L2NormOpModel m({1, 1, 1, 6}, TensorType_FLOAT32,
ActivationFunctionType_NONE);
m.SetInput({-1.1, 0.6, 0.7, 1.2, -0.7, 0.1});
@@ -76,7 +76,7 @@ TEST(L2NormOpTest, SimpleTest) {
ElementsAreArray({-0.55, 0.3, 0.35, 0.6, -0.35, 0.05}));
}
-TEST(L2NormOpTest, MultipleBatchesTest) {
+TEST(L2NormOpTest, MultipleBatchFloatTest) {
L2NormOpModel m({3, 1, 1, 6}, TensorType_FLOAT32,
ActivationFunctionType_NONE);
m.SetInput({
@@ -105,6 +105,32 @@ TEST(L2NormOpTest, SimpleUint8Test) {
ArrayFloatNear({-0.55, 0.3, 0.35, 0.6, -0.35, 0.05}, 0.1)));
}
+TEST(L2NormOpTest, MultipleBatchUint8Test) {
+ L2NormOpModel m({3, 1, 1, 6}, TensorType_UINT8, ActivationFunctionType_NONE);
+
+ m.QuantizeAndPopulate<uint8_t>(m.input(),
+ {
+ -1.1, 0.6, 0.7, 1.2, -0.7, 0.1, // batch 1
+ -1.1, 0.6, 0.7, 1.2, -0.7, 0.1, // batch 2
+ -1.1, 0.6, 0.7, 1.2, -0.7, 0.1, // batch 3
+ });
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput<uint8_t>(),
+ ElementsAreArray({
+ 58, 166, 173, 205, 83, 134, // batch 1
+ 58, 166, 173, 205, 83, 134, // batch 2
+ 58, 166, 173, 205, 83, 134, // batch 3
+ }));
+ EXPECT_THAT(m.GetDequantizedOutput(),
+ ElementsAreArray(ArrayFloatNear(
+ {
+ -0.55, 0.3, 0.35, 0.6, -0.35, 0.05, // batch 1
+ -0.55, 0.3, 0.35, 0.6, -0.35, 0.05, // batch 2
+ -0.55, 0.3, 0.35, 0.6, -0.35, 0.05, // batch 3
+ },
+ 0.1)));
+}
+
} // namespace
} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/layer_norm_lstm.cc b/tensorflow/contrib/lite/kernels/layer_norm_lstm.cc
new file mode 100644
index 0000000000..9739fd4514
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/layer_norm_lstm.cc
@@ -0,0 +1,1316 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Layer Normalization LSTM op that applies normalization by mean and standard
+// deviation to the activation of the LSTM layers. Please see
+// https://arxiv.org/abs/1607.06450 for details.
+#include "flatbuffers/flexbuffers.h" // TF:flatbuffers
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+
+namespace tflite {
+namespace ops {
+namespace custom {
+namespace layer_norm_lstm {
+
+// Struct to hold Layer Norm LSTM option data.
+struct OpData {
+ TfLiteFusedActivation activation;
+ float cell_clip;
+ float proj_clip;
+ int scratch_tensor_index;
+};
+
+// Input Tensors of size {n_batch, n_input}
+constexpr int kInputTensor = 0;
+
+// Input weight tensors of size: {n_cell, n_input}
+constexpr int kInputToInputWeightsTensor = 1; // Optional
+constexpr int kInputToForgetWeightsTensor = 2;
+constexpr int kInputToCellWeightsTensor = 3;
+constexpr int kInputToOutputWeightsTensor = 4;
+
+// Recurrent weight tensors of size {n_cell, n_output}
+constexpr int kRecurrentToInputWeightsTensor = 5; // Optional
+constexpr int kRecurrentToForgetWeightsTensor = 6;
+constexpr int kRecurrentToCellWeightsTensor = 7;
+constexpr int kRecurrentToOutputWeightsTensor = 8;
+
+// Peephole weights tensors of size {n_cell}, representing a diagonal matrix.
+constexpr int kCellToInputWeightsTensor = 9; // Optional
+constexpr int kCellToForgetWeightsTensor = 10; // Optional
+constexpr int kCellToOutputWeightsTensor = 11; // Optional
+
+// Layer norm weights tensors of size {n_cell}, representing a diagonal matrix.
+constexpr int kInputLayerNormWeightsTensor = 12;
+constexpr int kForgetLayerNormWeightsTensor = 13;
+constexpr int kCellLayerNormWeightsTensor = 14;
+constexpr int kOutputLayerNormWeightsTensor = 15;
+
+// Gates bias tensors of size {n_cell}
+constexpr int kInputGateBiasTensor = 16; // Optional
+constexpr int kForgetGateBiasTensor = 17;
+constexpr int kCellGateBiasTensor = 18;
+constexpr int kOutputGateBiasTensor = 19;
+
+// Projection weight tensor of size {n_output, n_cell}
+constexpr int kProjectionWeightsTensor = 20; // Optional
+// Projection bias tensor of size {n_output}
+constexpr int kProjectionBiasTensor = 21; // Optional
+
+// State tensors.
+constexpr int kInputActivationStateTensor = 22;
+constexpr int kInputCellStateTensor = 23;
+
+// Output tensor.
+constexpr int kOutputTensor = 0;
+
+// Total number of scratch tensors for hybrid Op.
+constexpr int kTensorsToAdd = 7;
+
+// Small float to avoid divergence during calculation of deviation.
+const float kLayerNormEpsilon = 1e-8;
+
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+ auto* data = new OpData;
+
+ // Turn custom option data into flexbuffer map format.
+ const uint8_t* buffer_t = reinterpret_cast<const uint8_t*>(buffer);
+ const flexbuffers::Map& m = flexbuffers::GetRoot(buffer_t, length).AsMap();
+
+ // Get activation function, cell_clip and proj_clip from the flexbuffer.
+ // TODO(b/113824099): make activation more generic.
+ assert(m["fused_activation_function"].ToString() == "TANH");
+ data->activation = kTfLiteActTanh;
+ data->cell_clip = m["cell_clip"].AsFloat();
+ data->proj_clip = m["proj_clip"].AsFloat();
+
+ // Populate scratch_tensor_index.
+ context->AddTensors(context, /*tensors_to_add=*/kTensorsToAdd,
+ &data->scratch_tensor_index);
+ return data;
+}
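For orientation (not part of this change): the custom options read above form a FlexBuffers map. A compatible buffer could be assembled roughly as follows on the authoring side; the key names are the ones Init reads, everything else is hypothetical:

    #include <cstdint>
    #include <vector>

    #include "flatbuffers/flexbuffers.h"  // TF:flatbuffers

    std::vector<std::uint8_t> BuildLayerNormLstmOptionsSketch(float cell_clip,
                                                              float proj_clip) {
      flexbuffers::Builder fbb;
      fbb.Map([&]() {
        fbb.String("fused_activation_function", "TANH");
        fbb.Float("cell_clip", cell_clip);
        fbb.Float("proj_clip", proj_clip);
      });
      fbb.Finish();
      return fbb.GetBuffer();
    }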
+
+// Check that input tensor dimensions match each other.
+TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
+ TfLiteNode* node, int n_input,
+ int n_output, int n_cell) {
+ const OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
+
+ // Making sure clipping parameters have valid values.
+ // == 0 means no clipping
+ // > 0 means clipping
+ TF_LITE_ENSURE(context, op_data->cell_clip >= 0);
+ TF_LITE_ENSURE(context, op_data->proj_clip >= 0);
+
+ const TfLiteTensor* input_to_input_weights =
+ GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
+ if (input_to_input_weights != nullptr) {
+ TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->size, 2);
+ TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[0], n_cell);
+ TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[1], n_input);
+ }
+
+ const TfLiteTensor* input_to_forget_weights =
+ GetInput(context, node, kInputToForgetWeightsTensor);
+ TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->size, 2);
+ TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[0], n_cell);
+ TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[1], n_input);
+
+ const TfLiteTensor* input_to_cell_weights =
+ GetInput(context, node, kInputToCellWeightsTensor);
+ TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->size, 2);
+ TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[0], n_cell);
+ TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[1], n_input);
+
+ const TfLiteTensor* recurrent_to_input_weights =
+ GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor);
+ if (recurrent_to_input_weights != nullptr) {
+ TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->size, 2);
+ TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->data[0],
+ n_cell);
+ TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->data[1],
+ n_output);
+ }
+
+ const TfLiteTensor* recurrent_to_forget_weights =
+ GetInput(context, node, kRecurrentToForgetWeightsTensor);
+ TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->size, 2);
+ TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[0],
+ n_cell);
+ TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[1],
+ n_output);
+
+ const TfLiteTensor* recurrent_to_cell_weights =
+ GetInput(context, node, kRecurrentToCellWeightsTensor);
+ TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->size, 2);
+ TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[0], n_cell);
+ TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[1],
+ n_output);
+
+ // We make sure the input-gate's parameters are either both present (regular
+  // LSTM) or both absent (CIFG-LSTM).
+ const bool cifg_weights_all_or_none =
+ ((input_to_input_weights != nullptr) &&
+ (recurrent_to_input_weights != nullptr)) ||
+ ((input_to_input_weights == nullptr) &&
+ (recurrent_to_input_weights == nullptr));
+ TF_LITE_ENSURE(context, cifg_weights_all_or_none == true);
+
+ const TfLiteTensor* cell_to_input_weights =
+ GetOptionalInputTensor(context, node, kCellToInputWeightsTensor);
+ if (cell_to_input_weights) {
+ TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->size, 1);
+ TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->data[0], n_cell);
+ }
+
+ const TfLiteTensor* cell_to_forget_weights =
+ GetOptionalInputTensor(context, node, kCellToForgetWeightsTensor);
+ if (cell_to_forget_weights) {
+ TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->size, 1);
+ TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->data[0], n_cell);
+ }
+
+ const TfLiteTensor* cell_to_output_weights =
+ GetOptionalInputTensor(context, node, kCellToOutputWeightsTensor);
+ if (cell_to_output_weights) {
+ TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->size, 1);
+ TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->data[0], n_cell);
+ }
+
+  // Making sure the peephole weights are either all present or all absent.
+ const bool use_cifg = (input_to_input_weights == nullptr);
+ const bool peephole_weights_all_or_none =
+ ((cell_to_input_weights != nullptr || use_cifg) &&
+ (cell_to_forget_weights != nullptr) &&
+ (cell_to_output_weights != nullptr)) ||
+ ((cell_to_input_weights == nullptr) &&
+ (cell_to_forget_weights == nullptr) &&
+ (cell_to_output_weights == nullptr));
+ TF_LITE_ENSURE(context, peephole_weights_all_or_none == true);
+
+ // Making sure layer norm weights are not null and have the right dimension.
+ const TfLiteTensor* input_layer_norm_weights =
+ GetInput(context, node, kInputLayerNormWeightsTensor);
+ TF_LITE_ENSURE(context, input_layer_norm_weights != nullptr);
+ TF_LITE_ENSURE_EQ(context, input_layer_norm_weights->dims->size, 1);
+ TF_LITE_ENSURE_EQ(context, input_layer_norm_weights->dims->data[0], n_cell);
+
+ const TfLiteTensor* forget_layer_norm_weights =
+ GetInput(context, node, kForgetLayerNormWeightsTensor);
+ TF_LITE_ENSURE(context, forget_layer_norm_weights != nullptr);
+ TF_LITE_ENSURE_EQ(context, forget_layer_norm_weights->dims->size, 1);
+ TF_LITE_ENSURE_EQ(context, forget_layer_norm_weights->dims->data[0], n_cell);
+
+ const TfLiteTensor* cell_layer_norm_weights =
+ GetInput(context, node, kCellLayerNormWeightsTensor);
+ TF_LITE_ENSURE(context, cell_layer_norm_weights != nullptr);
+ TF_LITE_ENSURE_EQ(context, cell_layer_norm_weights->dims->size, 1);
+ TF_LITE_ENSURE_EQ(context, cell_layer_norm_weights->dims->data[0], n_cell);
+
+ const TfLiteTensor* output_layer_norm_weights =
+ GetInput(context, node, kOutputLayerNormWeightsTensor);
+ TF_LITE_ENSURE(context, output_layer_norm_weights != nullptr);
+ TF_LITE_ENSURE_EQ(context, output_layer_norm_weights->dims->size, 1);
+ TF_LITE_ENSURE_EQ(context, output_layer_norm_weights->dims->data[0], n_cell);
+
+ // Make sure the input gate bias is present only when not a CIFG-LSTM.
+ const TfLiteTensor* input_gate_bias =
+ GetOptionalInputTensor(context, node, kInputGateBiasTensor);
+ if (use_cifg) {
+ TF_LITE_ENSURE_EQ(context, input_gate_bias, nullptr);
+ } else {
+ TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->size, 1);
+ TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->data[0], n_cell);
+ }
+
+ const TfLiteTensor* forget_gate_bias =
+ GetInput(context, node, kForgetGateBiasTensor);
+ TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->size, 1);
+ TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->data[0], n_cell);
+
+ const TfLiteTensor* cell_bias = GetInput(context, node, kCellGateBiasTensor);
+ TF_LITE_ENSURE_EQ(context, cell_bias->dims->size, 1);
+ TF_LITE_ENSURE_EQ(context, cell_bias->dims->data[0], n_cell);
+
+ const TfLiteTensor* output_gate_bias =
+ GetInput(context, node, kOutputGateBiasTensor);
+ TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->size, 1);
+ TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->data[0], n_cell);
+
+ const TfLiteTensor* projection_weights =
+ GetOptionalInputTensor(context, node, kProjectionWeightsTensor);
+ if (projection_weights != nullptr) {
+ TF_LITE_ENSURE_EQ(context, projection_weights->dims->size, 2);
+ TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[0], n_output);
+ TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[1], n_cell);
+ }
+
+ const TfLiteTensor* projection_bias =
+ GetOptionalInputTensor(context, node, kProjectionBiasTensor);
+ if (projection_bias != nullptr) {
+ TF_LITE_ENSURE_EQ(context, projection_bias->dims->size, 1);
+ TF_LITE_ENSURE_EQ(context, projection_bias->dims->data[0], n_output);
+ }
+
+ // Making sure the projection tensors are consistent:
+ // 1) If projection weight is not present, then projection bias should not be
+ // present.
+ // 2) If projection weight is present, then projection bias is optional.
+ const bool projection_tensors_consistent =
+ ((projection_weights != nullptr) || (projection_bias == nullptr));
+ TF_LITE_ENSURE(context, projection_tensors_consistent == true);
+
+ return kTfLiteOk;
+}
+
+// Resize the output, state tensors based on the sizes of the input tensors.
+// Allocate a temporary scratch tensor. Also check that the sizes of the input
+// tensors match each other.
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
+ TF_LITE_ENSURE_EQ(context, node->inputs->size, 24);
+ TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
+
+ // Inferring batch size, number of outputs and number of cells from the
+ // input tensors.
+ const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
+ TF_LITE_ENSURE(context, input->dims->size > 1);
+ const int n_batch = input->dims->data[0];
+ const int n_input = input->dims->data[1];
+
+ const TfLiteTensor* input_to_output_weights =
+ GetInput(context, node, kInputToOutputWeightsTensor);
+ const int n_cell = input_to_output_weights->dims->data[0];
+ TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->size, 2);
+ TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->data[1], n_input);
+
+ const TfLiteTensor* recurrent_to_output_weights =
+ GetInput(context, node, kRecurrentToOutputWeightsTensor);
+ TF_LITE_ENSURE_EQ(context, recurrent_to_output_weights->dims->size, 2);
+ TF_LITE_ENSURE_EQ(context, recurrent_to_output_weights->dims->data[0],
+ n_cell);
+ const int n_output = recurrent_to_output_weights->dims->data[1];
+
+  // Check that input tensor dimensions match each other.
+ TF_LITE_ENSURE_OK(context, CheckInputTensorDimensions(context, node, n_input,
+ n_output, n_cell));
+
+ // Get the pointer to output, activation_state and cell_state tensors.
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+ const TfLiteTensor* activation_state =
+ GetInput(context, node, kInputActivationStateTensor);
+ const TfLiteTensor* cell_state =
+ GetInput(context, node, kInputCellStateTensor);
+
+ // Check the shape of input state tensors.
+  // These tensors may be 1D or 2D. It's fine as long as the total size is
+ // correct.
+ TF_LITE_ENSURE_EQ(context, NumElements(activation_state), n_batch * n_output);
+ TF_LITE_ENSURE_EQ(context, NumElements(cell_state), n_batch * n_cell);
+ // Resize the output tensors.
+ TfLiteIntArray* output_size = TfLiteIntArrayCreate(2);
+ output_size->data[0] = n_batch;
+ output_size->data[1] = n_output;
+ TF_LITE_ENSURE_OK(context,
+ context->ResizeTensor(context, output, output_size));
+
+ // The weights are of consistent type, so it suffices to check one.
+ const bool is_hybrid_op = (input_to_output_weights->type == kTfLiteUInt8 &&
+ input->type == kTfLiteFloat32);
+
+ TfLiteIntArrayFree(node->temporaries);
+ if (is_hybrid_op) {
+ node->temporaries = TfLiteIntArrayCreate(7);
+ } else {
+ node->temporaries = TfLiteIntArrayCreate(1);
+ }
+ node->temporaries->data[0] = op_data->scratch_tensor_index;
+
+ // Create a scratch buffer tensor.
+ TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0);
+ scratch_buffer->type = input->type;
+ scratch_buffer->allocation_type = kTfLiteArenaRw;
+
+ const TfLiteTensor* input_to_input_weights =
+ GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
+ const bool use_cifg = (input_to_input_weights == nullptr);
+ TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2);
+ scratch_buffer_size->data[0] = n_batch;
+ if (use_cifg) {
+ // Reserving space for Cell, Forget, Output gates
+ scratch_buffer_size->data[1] = n_cell * 3;
+ } else {
+ // Reserving space for Input, Cell, Forget, Output gates
+ scratch_buffer_size->data[1] = n_cell * 4;
+ }
+ TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_buffer,
+ scratch_buffer_size));
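+ // An illustrative sketch (an observation from Eval* below, not part of the
+ // op contract): this [n_batch, k * n_cell] buffer is later partitioned into
+ // per-gate slices, e.g. with CIFG:
+ //   cell_scratch   = scratch_buffer->data.f;
+ //   forget_scratch = scratch_buffer->data.f + 1 * n_cell * n_batch;
+ //   output_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
+ // while the non-CIFG layout prepends an input-gate slice.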
+
+ if (is_hybrid_op) {
+ // Allocate temporary tensors to store quantized values of input,
+ // activation_state and cell_state tensors.
+ node->temporaries->data[1] = op_data->scratch_tensor_index + 1;
+ TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/1);
+ input_quantized->type = kTfLiteUInt8;
+ input_quantized->allocation_type = kTfLiteArenaRw;
+ if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) {
+ TfLiteIntArray* input_quantized_size = TfLiteIntArrayCopy(input->dims);
+ TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized,
+ input_quantized_size));
+ }
+ node->temporaries->data[2] = op_data->scratch_tensor_index + 2;
+ TfLiteTensor* activation_state_quantized =
+ GetTemporary(context, node, /*index=*/2);
+ activation_state_quantized->type = kTfLiteUInt8;
+ activation_state_quantized->allocation_type = kTfLiteArenaRw;
+ if (!TfLiteIntArrayEqual(activation_state_quantized->dims,
+ activation_state->dims)) {
+ TfLiteIntArray* activation_state_quantized_size =
+ TfLiteIntArrayCopy(activation_state->dims);
+ TF_LITE_ENSURE_OK(
+ context, context->ResizeTensor(context, activation_state_quantized,
+ activation_state_quantized_size));
+ }
+ node->temporaries->data[3] = op_data->scratch_tensor_index + 3;
+ TfLiteTensor* cell_state_quantized =
+ GetTemporary(context, node, /*index=*/3);
+ cell_state_quantized->type = kTfLiteUInt8;
+ cell_state_quantized->allocation_type = kTfLiteArenaRw;
+ if (!TfLiteIntArrayEqual(cell_state_quantized->dims, cell_state->dims)) {
+ TfLiteIntArray* cell_state_quantized_size =
+ TfLiteIntArrayCopy(cell_state->dims);
+ TF_LITE_ENSURE_OK(context,
+ context->ResizeTensor(context, cell_state_quantized,
+ cell_state_quantized_size));
+ }
+
+ // Allocate temporary tensors to store scaling factors and product scaling
+ // factors. The latter is convenience storage that allows quantizing a
+ // vector once (which produces the scaling factors) and multiplying it with
+ // different matrices (which requires multiplying the scaling factors by the
+ // scaling factor of each matrix).
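+ // A sketch of the arithmetic, assuming the symmetric quantization used by
+ // the hybrid path: if x_q = quantize(x) with per-batch scale s_x and
+ // W_q = quantize(W) with scale s_W, then
+ //   W * x  ~=  (W_q * x_q) * (s_x * s_W),
+ // so the product scaling factor caches s_x[b] * s_W once per matrix.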
+ node->temporaries->data[4] = op_data->scratch_tensor_index + 4;
+ TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/4);
+ scaling_factors->type = kTfLiteFloat32;
+ scaling_factors->allocation_type = kTfLiteArenaRw;
+ TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1);
+ scaling_factors_size->data[0] = n_batch;
+ if (!TfLiteIntArrayEqual(scaling_factors->dims, scaling_factors_size)) {
+ TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors,
+ scaling_factors_size));
+ }
+ node->temporaries->data[5] = op_data->scratch_tensor_index + 5;
+ TfLiteTensor* prod_scaling_factors =
+ GetTemporary(context, node, /*index=*/5);
+ prod_scaling_factors->type = kTfLiteFloat32;
+ prod_scaling_factors->allocation_type = kTfLiteArenaRw;
+ TfLiteIntArray* prod_scaling_factors_size = TfLiteIntArrayCreate(1);
+ prod_scaling_factors_size->data[0] = n_batch;
+ if (!TfLiteIntArrayEqual(prod_scaling_factors->dims,
+ prod_scaling_factors_size)) {
+ TF_LITE_ENSURE_OK(context,
+ context->ResizeTensor(context, prod_scaling_factors,
+ prod_scaling_factors_size));
+ }
+
+ // Allocate a temporary tensor to store the recovered weights. Since this
+ // is used for the diagonal (peephole) matrices, only n_cell values need to
+ // be stored.
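+ // For example (a sketch, not a contract): a peephole term diag(w) .* c is
+ // computed by first recovering float weights, recovered_weights[i] ~=
+ // w_q[i] * w_scale, and then applying them element-wise, so only the n_cell
+ // diagonal entries are ever materialized.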
+ node->temporaries->data[6] = op_data->scratch_tensor_index + 6;
+ TfLiteTensor* recovered_weights = GetTemporary(context, node, /*index=*/6);
+ recovered_weights->type = kTfLiteFloat32;
+ recovered_weights->allocation_type = kTfLiteArenaRw;
+ TfLiteIntArray* recovered_weights_size = TfLiteIntArrayCreate(1);
+ recovered_weights_size->data[0] = n_cell;
+ if (!TfLiteIntArrayEqual(recovered_weights->dims, recovered_weights_size)) {
+ TF_LITE_ENSURE_OK(context,
+ context->ResizeTensor(context, recovered_weights,
+ recovered_weights_size));
+ }
+ }
+ return kTfLiteOk;
+}
+
+void LayerNormLstmStep(
+ const float* input_ptr_batch, const float* input_to_input_weights_ptr,
+ const float* input_to_forget_weights_ptr,
+ const float* input_to_cell_weights_ptr,
+ const float* input_to_output_weights_ptr,
+ const float* recurrent_to_input_weights_ptr,
+ const float* recurrent_to_forget_weights_ptr,
+ const float* recurrent_to_cell_weights_ptr,
+ const float* recurrent_to_output_weights_ptr,
+ const float* cell_to_input_weights_ptr,
+ const float* cell_to_forget_weights_ptr,
+ const float* cell_to_output_weights_ptr,
+ const float* input_layer_norm_weight_ptr,
+ const float* forget_layer_norm_weight_ptr,
+ const float* cell_layer_norm_weight_ptr,
+ const float* output_layer_norm_weight_ptr, const float* input_gate_bias_ptr,
+ const float* forget_gate_bias_ptr, const float* cell_bias_ptr,
+ const float* output_gate_bias_ptr, const float* projection_weights_ptr,
+ const float* projection_bias_ptr, float cell_clip, float proj_clip,
+ const TfLiteFusedActivation& activation, int n_batch, int n_cell,
+ int n_input, int n_output, float* output_state_ptr, float* cell_state_ptr,
+ float* input_gate_scratch, float* forget_gate_scratch, float* cell_scratch,
+ float* output_gate_scratch, float* output_ptr_batch) {
+ // Since we have already checked that weights are all there or none, we can
+ // check the existence of only one to get the condition.
+ const bool use_cifg = (input_to_input_weights_ptr == nullptr);
+ const bool use_peephole = (cell_to_output_weights_ptr != nullptr);
+
+ // Initialize scratch buffers with 0.
+ if (!use_cifg) {
+ tensor_utils::ZeroVector(input_gate_scratch, n_cell * n_batch);
+ }
+ tensor_utils::ZeroVector(forget_gate_scratch, n_cell * n_batch);
+ tensor_utils::ZeroVector(cell_scratch, n_cell * n_batch);
+ tensor_utils::ZeroVector(output_gate_scratch, n_cell * n_batch);
+
+ // For each batch and cell: compute input_weight * input.
+ if (!use_cifg) {
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ input_to_input_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch,
+ input_gate_scratch, /*result_stride=*/1);
+ }
+
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ input_to_forget_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch,
+ forget_gate_scratch, /*result_stride=*/1);
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ input_to_cell_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch,
+ cell_scratch, /*result_stride=*/1);
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ input_to_output_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch,
+ output_gate_scratch, /*result_stride=*/1);
+
+ // For each batch and cell: compute recurrent_weight * output_state.
+ if (!use_cifg) {
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ recurrent_to_input_weights_ptr, n_cell, n_output, output_state_ptr,
+ n_batch, input_gate_scratch, /*result_stride=*/1);
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ recurrent_to_forget_weights_ptr, n_cell, n_output, output_state_ptr,
+ n_batch, forget_gate_scratch,
+ /*result_stride=*/1);
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ recurrent_to_cell_weights_ptr, n_cell, n_output, output_state_ptr,
+ n_batch, cell_scratch, /*result_stride=*/1);
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ recurrent_to_output_weights_ptr, n_cell, n_output, output_state_ptr,
+ n_batch, output_gate_scratch,
+ /*result_stride=*/1);
+
+ // For each batch and cell: update input gate.
+ if (!use_cifg) {
+ if (use_peephole) {
+ tensor_utils::VectorBatchVectorCwiseProductAccumulate(
+ cell_to_input_weights_ptr, n_cell, cell_state_ptr, n_batch,
+ input_gate_scratch);
+ }
+ tensor_utils::MeanStddevNormalization(input_gate_scratch,
+ input_gate_scratch, n_cell, n_batch,
+ kLayerNormEpsilon);
+ tensor_utils::VectorBatchVectorCwiseProduct(input_layer_norm_weight_ptr,
+ n_cell, input_gate_scratch,
+ n_batch, input_gate_scratch);
+ tensor_utils::VectorBatchVectorAdd(input_gate_bias_ptr, n_cell, n_batch,
+ input_gate_scratch);
+ tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch,
+ input_gate_scratch);
+ }
+
+ // For each batch and cell: update forget gate.
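+ // In equation form (an illustrative sketch; this notation is not from the
+ // original comments):
+ //   f_t = sigmoid(LayerNorm(W_f x_t + R_f h_{t-1} [+ w_cf .* c_{t-1}])
+ //                     .* gamma_f + b_f)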
+ if (use_peephole) {
+ tensor_utils::VectorBatchVectorCwiseProductAccumulate(
+ cell_to_forget_weights_ptr, n_cell, cell_state_ptr, n_batch,
+ forget_gate_scratch);
+ }
+ tensor_utils::MeanStddevNormalization(forget_gate_scratch,
+ forget_gate_scratch, n_cell, n_batch,
+ kLayerNormEpsilon);
+ tensor_utils::VectorBatchVectorCwiseProduct(forget_layer_norm_weight_ptr,
+ n_cell, forget_gate_scratch,
+ n_batch, forget_gate_scratch);
+ tensor_utils::VectorBatchVectorAdd(forget_gate_bias_ptr, n_cell, n_batch,
+ forget_gate_scratch);
+ tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch,
+ forget_gate_scratch);
+
+ // For each batch and cell: update the cell.
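+ // In equation form (an illustrative sketch):
+ //   g_t = act(LayerNorm(W_c x_t + R_c h_{t-1}) .* gamma_c + b_c)
+ //   c_t = clip(f_t .* c_{t-1} + i_t .* g_t, cell_clip)
+ // where i_t is replaced by (1 - f_t) in the CIFG case.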
+ tensor_utils::MeanStddevNormalization(cell_scratch, cell_scratch, n_cell,
+ n_batch, kLayerNormEpsilon);
+ tensor_utils::VectorBatchVectorCwiseProduct(
+ cell_layer_norm_weight_ptr, n_cell, cell_scratch, n_batch, cell_scratch);
+ tensor_utils::VectorBatchVectorAdd(cell_bias_ptr, n_cell, n_batch,
+ cell_scratch);
+ tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, cell_state_ptr,
+ n_batch * n_cell, cell_state_ptr);
+ tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell,
+ activation, cell_scratch);
+ if (use_cifg) {
+ tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell,
+ forget_gate_scratch);
+ tensor_utils::VectorVectorCwiseProductAccumulate(
+ cell_scratch, forget_gate_scratch, n_batch * n_cell, cell_state_ptr);
+ } else {
+ tensor_utils::VectorVectorCwiseProductAccumulate(
+ cell_scratch, input_gate_scratch, n_batch * n_cell, cell_state_ptr);
+ }
+ if (cell_clip > 0.0) {
+ tensor_utils::ClipVector(cell_state_ptr, n_batch * n_cell, cell_clip,
+ cell_state_ptr);
+ }
+
+ // For each batch and cell: update the output gate.
+ if (use_peephole) {
+ tensor_utils::VectorBatchVectorCwiseProductAccumulate(
+ cell_to_output_weights_ptr, n_cell, cell_state_ptr, n_batch,
+ output_gate_scratch);
+ }
+ tensor_utils::MeanStddevNormalization(output_gate_scratch,
+ output_gate_scratch, n_cell, n_batch,
+ kLayerNormEpsilon);
+ tensor_utils::VectorBatchVectorCwiseProduct(output_layer_norm_weight_ptr,
+ n_cell, output_gate_scratch,
+ n_batch, output_gate_scratch);
+ tensor_utils::VectorBatchVectorAdd(output_gate_bias_ptr, n_cell, n_batch,
+ output_gate_scratch);
+ tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell,
+ output_gate_scratch);
+ tensor_utils::ApplyActivationToVector(cell_state_ptr, n_batch * n_cell,
+ activation, cell_scratch);
+ tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch,
+ n_batch * n_cell, output_gate_scratch);
+
+ // For each batch: update the projection and output_state.
+ const bool use_projection_weight = (projection_weights_ptr != nullptr);
+ const bool use_projection_bias = (projection_bias_ptr != nullptr);
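+ // In equation form (an illustrative sketch): output_gate_scratch now holds
+ // h_t = o_t .* act(c_t); with projection the emitted output is
+ //   clip(W_proj * h_t + b_proj, proj_clip),
+ // and without projection it is h_t itself.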
+ if (use_projection_weight) {
+ if (use_projection_bias) {
+ tensor_utils::VectorBatchVectorAssign(projection_bias_ptr, n_output,
+ n_batch, output_ptr_batch);
+ } else {
+ tensor_utils::ZeroVector(output_ptr_batch, n_batch * n_output);
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ projection_weights_ptr, n_output, n_cell, output_gate_scratch, n_batch,
+ output_ptr_batch, /*result_stride=*/1);
+ if (proj_clip > 0.0) {
+ tensor_utils::ClipVector(output_ptr_batch, n_batch * n_output, proj_clip,
+ output_ptr_batch);
+ }
+ } else {
+ tensor_utils::CopyVector(output_gate_scratch, n_batch * n_output,
+ output_ptr_batch);
+ }
+ tensor_utils::CopyVector(output_ptr_batch, n_batch * n_output,
+ output_state_ptr);
+}
+
+void LayerNormLstmStep(
+ const float* input_ptr_batch, const int8_t* input_to_input_weights_ptr,
+ float input_to_input_weights_scale,
+ const int8_t* input_to_forget_weights_ptr,
+ float input_to_forget_weights_scale,
+ const int8_t* input_to_cell_weights_ptr, float input_to_cell_weights_scale,
+ const int8_t* input_to_output_weights_ptr,
+ float input_to_output_weights_scale,
+ const int8_t* recurrent_to_input_weights_ptr,
+ float recurrent_to_input_weights_scale,
+ const int8_t* recurrent_to_forget_weights_ptr,
+ float recurrent_to_forget_weights_scale,
+ const int8_t* recurrent_to_cell_weights_ptr,
+ float recurrent_to_cell_weights_scale,
+ const int8_t* recurrent_to_output_weights_ptr,
+ float recurrent_to_output_weights_scale,
+ const int8_t* cell_to_input_weights_ptr, float cell_to_input_weights_scale,
+ const int8_t* cell_to_forget_weights_ptr,
+ float cell_to_forget_weights_scale,
+ const int8_t* cell_to_output_weights_ptr,
+ float cell_to_output_weights_scale,
+ const float* input_layer_norm_weight_ptr,
+ const float* forget_layer_norm_weight_ptr,
+ const float* cell_layer_norm_weight_ptr,
+ const float* output_layer_norm_weight_ptr, const float* input_gate_bias_ptr,
+ const float* forget_gate_bias_ptr, const float* cell_bias_ptr,
+ const float* output_gate_bias_ptr, const int8_t* projection_weights_ptr,
+ float projection_weights_scale, const float* projection_bias_ptr,
+ float cell_clip, float proj_clip, const TfLiteFusedActivation& activation,
+ int n_batch, int n_cell, int n_input, int n_output,
+ float* input_gate_scratch, float* forget_gate_scratch, float* cell_scratch,
+ float* output_gate_scratch, float* scaling_factors,
+ float* product_scaling_factors, float* recovered_weights,
+ int8_t* quantized_input_ptr_batch, int8_t* quantized_output_state_ptr,
+ int8_t* quantized_cell_state_ptr, float* output_state_ptr,
+ float* cell_state_ptr, float* output_ptr_batch) {
+ // Since we have already checked that weights are all there or none, we can
+ // check the existence of only one to get the condition.
+ const bool use_cifg = (input_to_input_weights_ptr == nullptr);
+ const bool use_peephole = (cell_to_output_weights_ptr != nullptr);
+
+ // Initialize scratch buffers with 0.
+ if (!use_cifg) {
+ tensor_utils::ZeroVector(input_gate_scratch, n_cell * n_batch);
+ }
+ tensor_utils::ZeroVector(forget_gate_scratch, n_cell * n_batch);
+ tensor_utils::ZeroVector(cell_scratch, n_cell * n_batch);
+ tensor_utils::ZeroVector(output_gate_scratch, n_cell * n_batch);
+
+ if (!tensor_utils::IsZeroVector(input_ptr_batch, n_batch * n_input)) {
+ // Save quantization and matmul computation for all zero input.
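+ // (A sketch of the reasoning: if the whole input batch is zero, each
+ // accumulation below would add zero, so the per-batch quantization and the
+ // matmuls can be skipped.)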
+ float unused_min, unused_max;
+ for (int b = 0; b < n_batch; ++b) {
+ const int offset = b * n_input;
+ tensor_utils::SymmetricQuantizeFloats(
+ input_ptr_batch + offset, n_input, quantized_input_ptr_batch + offset,
+ &unused_min, &unused_max, &scaling_factors[b]);
+ }
+ // For each batch and cell: compute input_weight * input.
+ if (!use_cifg) {
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * input_to_input_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ input_to_input_weights_ptr, n_cell, n_input,
+ quantized_input_ptr_batch, product_scaling_factors, n_batch,
+ input_gate_scratch, /*result_stride=*/1);
+ }
+
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * input_to_forget_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ input_to_forget_weights_ptr, n_cell, n_input, quantized_input_ptr_batch,
+ product_scaling_factors, n_batch, forget_gate_scratch,
+ /*result_stride=*/1);
+
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * input_to_cell_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ input_to_cell_weights_ptr, n_cell, n_input, quantized_input_ptr_batch,
+ product_scaling_factors, n_batch, cell_scratch, /*result_stride=*/1);
+
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * input_to_output_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ input_to_output_weights_ptr, n_cell, n_input, quantized_input_ptr_batch,
+ product_scaling_factors, n_batch, output_gate_scratch,
+ /*result_stride=*/1);
+ }
+
+ if (!tensor_utils::IsZeroVector(output_state_ptr, n_batch * n_output)) {
+ // Save quantization and matmul computation for all zero input.
+ float unused_min, unused_max;
+ for (int b = 0; b < n_batch; ++b) {
+ const int offset = b * n_output;
+ tensor_utils::SymmetricQuantizeFloats(output_state_ptr + offset, n_output,
+ quantized_output_state_ptr + offset,
+ &unused_min, &unused_max,
+ &scaling_factors[b]);
+ }
+ // For each batch and cell: compute recurrent_weight * output_state.
+ if (!use_cifg) {
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * recurrent_to_input_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ recurrent_to_input_weights_ptr, n_cell, n_output,
+ quantized_output_state_ptr, product_scaling_factors, n_batch,
+ input_gate_scratch, /*result_stride=*/1);
+ }
+
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * recurrent_to_forget_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ recurrent_to_forget_weights_ptr, n_cell, n_output,
+ quantized_output_state_ptr, product_scaling_factors, n_batch,
+ forget_gate_scratch, /*result_stride=*/1);
+
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * recurrent_to_cell_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ recurrent_to_cell_weights_ptr, n_cell, n_output,
+ quantized_output_state_ptr, product_scaling_factors, n_batch,
+ cell_scratch, /*result_stride=*/1);
+
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * recurrent_to_output_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ recurrent_to_output_weights_ptr, n_cell, n_output,
+ quantized_output_state_ptr, product_scaling_factors, n_batch,
+ output_gate_scratch, /*result_stride=*/1);
+ }
+
+ // Skip the peephole contributions below when the cell state is all zeros.
+ bool is_cell_state_all_zeros =
+ tensor_utils::IsZeroVector(cell_state_ptr, n_batch * n_cell);
+
+ // For each batch and cell: update input gate.
+ if (!use_cifg) {
+ if (use_peephole && !is_cell_state_all_zeros) {
+ tensor_utils::VectorScalarMultiply(cell_to_input_weights_ptr, n_cell,
+ cell_to_input_weights_scale,
+ recovered_weights);
+ tensor_utils::VectorBatchVectorCwiseProductAccumulate(
+ recovered_weights, n_cell, cell_state_ptr, n_batch,
+ input_gate_scratch);
+ }
+ tensor_utils::MeanStddevNormalization(input_gate_scratch,
+ input_gate_scratch, n_cell, n_batch,
+ kLayerNormEpsilon);
+ tensor_utils::VectorBatchVectorCwiseProduct(input_layer_norm_weight_ptr,
+ n_cell, input_gate_scratch,
+ n_batch, input_gate_scratch);
+ tensor_utils::VectorBatchVectorAdd(input_gate_bias_ptr, n_cell, n_batch,
+ input_gate_scratch);
+ tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch,
+ input_gate_scratch);
+ }
+
+ // For each batch and cell: update forget gate.
+ if (use_peephole && !is_cell_state_all_zeros) {
+ tensor_utils::VectorScalarMultiply(cell_to_forget_weights_ptr, n_cell,
+ cell_to_forget_weights_scale,
+ recovered_weights);
+ tensor_utils::VectorBatchVectorCwiseProductAccumulate(
+ recovered_weights, n_cell, cell_state_ptr, n_batch,
+ forget_gate_scratch);
+ }
+ tensor_utils::MeanStddevNormalization(forget_gate_scratch,
+ forget_gate_scratch, n_cell, n_batch,
+ kLayerNormEpsilon);
+ tensor_utils::VectorBatchVectorCwiseProduct(forget_layer_norm_weight_ptr,
+ n_cell, forget_gate_scratch,
+ n_batch, forget_gate_scratch);
+ tensor_utils::VectorBatchVectorAdd(forget_gate_bias_ptr, n_cell, n_batch,
+ forget_gate_scratch);
+ tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch,
+ forget_gate_scratch);
+
+ // For each batch and cell: update the cell.
+ tensor_utils::MeanStddevNormalization(cell_scratch, cell_scratch, n_cell,
+ n_batch, kLayerNormEpsilon);
+ tensor_utils::VectorBatchVectorCwiseProduct(
+ cell_layer_norm_weight_ptr, n_cell, cell_scratch, n_batch, cell_scratch);
+ tensor_utils::VectorBatchVectorAdd(cell_bias_ptr, n_cell, n_batch,
+ cell_scratch);
+ tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, cell_state_ptr,
+ n_batch * n_cell, cell_state_ptr);
+ tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell,
+ activation, cell_scratch);
+ if (use_cifg) {
+ tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell,
+ forget_gate_scratch);
+ tensor_utils::VectorVectorCwiseProductAccumulate(
+ cell_scratch, forget_gate_scratch, n_batch * n_cell, cell_state_ptr);
+ } else {
+ tensor_utils::VectorVectorCwiseProductAccumulate(
+ cell_scratch, input_gate_scratch, n_batch * n_cell, cell_state_ptr);
+ }
+ if (cell_clip > 0.0) {
+ tensor_utils::ClipVector(cell_state_ptr, n_batch * n_cell, cell_clip,
+ cell_state_ptr);
+ }
+
+ is_cell_state_all_zeros =
+ tensor_utils::IsZeroVector(cell_state_ptr, n_batch * n_cell);
+ // For each batch and cell: update the output gate.
+ if (use_peephole && !is_cell_state_all_zeros) {
+ tensor_utils::VectorScalarMultiply(cell_to_output_weights_ptr, n_cell,
+ cell_to_output_weights_scale,
+ recovered_weights);
+ tensor_utils::VectorBatchVectorCwiseProductAccumulate(
+ recovered_weights, n_cell, cell_state_ptr, n_batch,
+ output_gate_scratch);
+ }
+ tensor_utils::MeanStddevNormalization(output_gate_scratch,
+ output_gate_scratch, n_cell, n_batch,
+ kLayerNormEpsilon);
+ tensor_utils::VectorBatchVectorCwiseProduct(output_layer_norm_weight_ptr,
+ n_cell, output_gate_scratch,
+ n_batch, output_gate_scratch);
+ tensor_utils::VectorBatchVectorAdd(output_gate_bias_ptr, n_cell, n_batch,
+ output_gate_scratch);
+ tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell,
+ output_gate_scratch);
+ tensor_utils::ApplyActivationToVector(cell_state_ptr, n_batch * n_cell,
+ activation, cell_scratch);
+ tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch,
+ n_batch * n_cell, output_gate_scratch);
+
+ // For each batch: update the projection and output_state.
+ const bool use_projection_weight = (projection_weights_ptr != nullptr);
+ const bool use_projection_bias = (projection_bias_ptr != nullptr);
+ if (use_projection_weight) {
+ if (use_projection_bias) {
+ tensor_utils::VectorBatchVectorAssign(projection_bias_ptr, n_output,
+ n_batch, output_ptr_batch);
+ } else {
+ tensor_utils::ZeroVector(output_ptr_batch, n_batch * n_output);
+ }
+ if (!tensor_utils::IsZeroVector(output_gate_scratch, n_batch * n_cell)) {
+ // Save quantization and matmul computation for all zero input.
+ float unused_min, unused_max;
+ for (int b = 0; b < n_batch; ++b) {
+ const int offset = b * n_cell;
+ tensor_utils::SymmetricQuantizeFloats(
+ output_gate_scratch + offset, n_cell,
+ quantized_cell_state_ptr + offset, &unused_min, &unused_max,
+ &scaling_factors[b]);
+ }
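+ // Note (an observation, not a contract): the cell-state quantization
+ // buffer is reused here to hold the quantized pre-projection output
+ // h_t = o_t .* act(c_t).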
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * projection_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ projection_weights_ptr, n_output, n_cell, quantized_cell_state_ptr,
+ product_scaling_factors, n_batch, output_ptr_batch,
+ /*result_stride=*/1);
+ }
+ if (proj_clip > 0.0) {
+ tensor_utils::ClipVector(output_ptr_batch, n_batch * n_output, proj_clip,
+ output_ptr_batch);
+ }
+ } else {
+ tensor_utils::CopyVector(output_gate_scratch, n_batch * n_output,
+ output_ptr_batch);
+ }
+ tensor_utils::CopyVector(output_ptr_batch, n_batch * n_output,
+ output_state_ptr);
+}
+
+// The LayerNormLSTM Op engine.
+TfLiteStatus EvalFloat(
+ const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
+ const TfLiteTensor* input_to_forget_weights,
+ const TfLiteTensor* input_to_cell_weights,
+ const TfLiteTensor* input_to_output_weights,
+ const TfLiteTensor* recurrent_to_input_weights,
+ const TfLiteTensor* recurrent_to_forget_weights,
+ const TfLiteTensor* recurrent_to_cell_weights,
+ const TfLiteTensor* recurrent_to_output_weights,
+ const TfLiteTensor* cell_to_input_weights,
+ const TfLiteTensor* cell_to_forget_weights,
+ const TfLiteTensor* cell_to_output_weights,
+ const TfLiteTensor* input_layer_norm_weights,
+ const TfLiteTensor* forget_layer_norm_weights,
+ const TfLiteTensor* cell_layer_norm_weights,
+ const TfLiteTensor* output_layer_norm_weights,
+ const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
+ const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
+ const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
+ float cell_clip, float proj_clip, const TfLiteFusedActivation& activation,
+ TfLiteTensor* scratch_buffer, TfLiteTensor* activation_state,
+ TfLiteTensor* cell_state, TfLiteTensor* output) {
+ const int n_batch = input->dims->data[0];
+ const int n_input = input->dims->data[1];
+ // n_cell and n_output will be the same size when there is no projection.
+ const int n_cell = input_to_output_weights->dims->data[0];
+ const int n_output = recurrent_to_output_weights->dims->data[1];
+
+ // Since we have already checked that weights are all there or none, we can
+ // check the existence of only one to get the condition.
+ const bool use_cifg = (input_to_input_weights == nullptr);
+ const bool use_peephole = (cell_to_output_weights != nullptr);
+
+ float* input_gate_scratch = nullptr;
+ float* cell_scratch = nullptr;
+ float* forget_gate_scratch = nullptr;
+ float* output_gate_scratch = nullptr;
+ if (use_cifg) {
+ cell_scratch = scratch_buffer->data.f;
+ forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch;
+ output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
+ } else {
+ input_gate_scratch = scratch_buffer->data.f;
+ cell_scratch = scratch_buffer->data.f + n_cell * n_batch;
+ forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
+ output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch;
+ }
+
+ // Check optional tensors, the respective pointers can be null.
+ const float* input_to_input_weights_ptr =
+ (use_cifg) ? nullptr : input_to_input_weights->data.f;
+ const float* recurrent_to_input_weights_ptr =
+ (use_cifg) ? nullptr : recurrent_to_input_weights->data.f;
+ const float* input_gate_bias_ptr =
+ (use_cifg) ? nullptr : input_gate_bias->data.f;
+ const float* cell_to_input_weights_ptr =
+ (use_peephole && !use_cifg) ? cell_to_input_weights->data.f : nullptr;
+ const float* cell_to_forget_weights_ptr =
+ (use_peephole) ? cell_to_forget_weights->data.f : nullptr;
+ const float* cell_to_output_weights_ptr =
+ (use_peephole) ? cell_to_output_weights->data.f : nullptr;
+ const float* projection_weights_ptr =
+ (projection_weights == nullptr) ? nullptr : projection_weights->data.f;
+ const float* projection_bias_ptr =
+ (projection_bias == nullptr) ? nullptr : projection_bias->data.f;
+
+ // Required tensors, pointers are non-null.
+ const float* input_ptr_batch = input->data.f;
+ const float* input_to_forget_weights_ptr = input_to_forget_weights->data.f;
+ const float* input_to_cell_weights_ptr = input_to_cell_weights->data.f;
+ const float* input_to_output_weights_ptr = input_to_output_weights->data.f;
+ const float* recurrent_to_forget_weights_ptr =
+ recurrent_to_forget_weights->data.f;
+ const float* recurrent_to_cell_weights_ptr =
+ recurrent_to_cell_weights->data.f;
+ const float* recurrent_to_output_weights_ptr =
+ recurrent_to_output_weights->data.f;
+ const float* input_layer_norm_weight_ptr = input_layer_norm_weights->data.f;
+ const float* forget_layer_norm_weight_ptr = forget_layer_norm_weights->data.f;
+ const float* cell_layer_norm_weight_ptr = cell_layer_norm_weights->data.f;
+ const float* output_layer_norm_weight_ptr = output_layer_norm_weights->data.f;
+ const float* forget_gate_bias_ptr = forget_gate_bias->data.f;
+ const float* cell_bias_ptr = cell_bias->data.f;
+ const float* output_gate_bias_ptr = output_gate_bias->data.f;
+
+ float* activation_state_ptr = activation_state->data.f;
+ float* cell_state_ptr = cell_state->data.f;
+ float* output_ptr_batch = output->data.f;
+
+ LayerNormLstmStep(
+ input_ptr_batch, input_to_input_weights_ptr, input_to_forget_weights_ptr,
+ input_to_cell_weights_ptr, input_to_output_weights_ptr,
+ recurrent_to_input_weights_ptr, recurrent_to_forget_weights_ptr,
+ recurrent_to_cell_weights_ptr, recurrent_to_output_weights_ptr,
+ cell_to_input_weights_ptr, cell_to_forget_weights_ptr,
+ cell_to_output_weights_ptr, input_layer_norm_weight_ptr,
+ forget_layer_norm_weight_ptr, cell_layer_norm_weight_ptr,
+ output_layer_norm_weight_ptr, input_gate_bias_ptr, forget_gate_bias_ptr,
+ cell_bias_ptr, output_gate_bias_ptr, projection_weights_ptr,
+ projection_bias_ptr, cell_clip, proj_clip, activation, n_batch, n_cell,
+ n_input, n_output, activation_state_ptr, cell_state_ptr,
+ input_gate_scratch, forget_gate_scratch, cell_scratch,
+ output_gate_scratch, output_ptr_batch);
+
+ return kTfLiteOk;
+}
+
+TfLiteStatus EvalHybrid(
+ const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
+ const TfLiteTensor* input_to_forget_weights,
+ const TfLiteTensor* input_to_cell_weights,
+ const TfLiteTensor* input_to_output_weights,
+ const TfLiteTensor* recurrent_to_input_weights,
+ const TfLiteTensor* recurrent_to_forget_weights,
+ const TfLiteTensor* recurrent_to_cell_weights,
+ const TfLiteTensor* recurrent_to_output_weights,
+ const TfLiteTensor* cell_to_input_weights,
+ const TfLiteTensor* cell_to_forget_weights,
+ const TfLiteTensor* cell_to_output_weights,
+ const TfLiteTensor* input_layer_norm_weights,
+ const TfLiteTensor* forget_layer_norm_weights,
+ const TfLiteTensor* cell_layer_norm_weights,
+ const TfLiteTensor* output_layer_norm_weights,
+ const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
+ const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
+ const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
+ float cell_clip, float proj_clip, const TfLiteFusedActivation& activation,
+ TfLiteTensor* scratch_buffer, TfLiteTensor* scaling_factors,
+ TfLiteTensor* prod_scaling_factors, TfLiteTensor* recovered_weights,
+ TfLiteTensor* input_quantized, TfLiteTensor* activation_state_quantized,
+ TfLiteTensor* cell_state_quantized, TfLiteTensor* activation_state,
+ TfLiteTensor* cell_state, TfLiteTensor* output) {
+ const int n_batch = input->dims->data[0];
+ const int n_input = input->dims->data[1];
+ // n_cell and n_output will be the same size when there is no projection.
+ const int n_cell = input_to_output_weights->dims->data[0];
+ const int n_output = recurrent_to_output_weights->dims->data[1];
+
+ // Since we have already checked that weights are all there or none, we can
+ // check the existence of only one to get the condition.
+ const bool use_cifg = (input_to_input_weights == nullptr);
+ const bool use_peephole = (cell_to_output_weights != nullptr);
+
+ float* input_gate_scratch = nullptr;
+ float* cell_scratch = nullptr;
+ float* forget_gate_scratch = nullptr;
+ float* output_gate_scratch = nullptr;
+ if (use_cifg) {
+ cell_scratch = scratch_buffer->data.f;
+ forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch;
+ output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
+ } else {
+ input_gate_scratch = scratch_buffer->data.f;
+ cell_scratch = scratch_buffer->data.f + n_cell * n_batch;
+ forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
+ output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch;
+ }
+
+ // Check optional tensors, the respective pointers can be null.
+ int8_t* input_to_input_weights_ptr = nullptr;
+ float input_to_input_weights_scale = 1.0f;
+ int8_t* recurrent_to_input_weights_ptr = nullptr;
+ float recurrent_to_input_weights_scale = 1.0f;
+ float* input_gate_bias_ptr = nullptr;
+ if (!use_cifg) {
+ input_to_input_weights_ptr =
+ reinterpret_cast<int8_t*>(input_to_input_weights->data.uint8);
+ recurrent_to_input_weights_ptr =
+ reinterpret_cast<int8_t*>(recurrent_to_input_weights->data.uint8);
+ input_gate_bias_ptr = input_gate_bias->data.f;
+ input_to_input_weights_scale = input_to_input_weights->params.scale;
+ recurrent_to_input_weights_scale = recurrent_to_input_weights->params.scale;
+ }
+
+ int8_t* cell_to_input_weights_ptr = nullptr;
+ int8_t* cell_to_forget_weights_ptr = nullptr;
+ int8_t* cell_to_output_weights_ptr = nullptr;
+ float cell_to_input_weights_scale = 1.0f;
+ float cell_to_forget_weights_scale = 1.0f;
+ float cell_to_output_weights_scale = 1.0f;
+ if (use_peephole) {
+ if (!use_cifg) {
+ cell_to_input_weights_ptr =
+ reinterpret_cast<int8_t*>(cell_to_input_weights->data.uint8);
+ cell_to_input_weights_scale = cell_to_input_weights->params.scale;
+ }
+ cell_to_forget_weights_ptr =
+ reinterpret_cast<int8_t*>(cell_to_forget_weights->data.uint8);
+ cell_to_output_weights_ptr =
+ reinterpret_cast<int8_t*>(cell_to_output_weights->data.uint8);
+ cell_to_forget_weights_scale = cell_to_forget_weights->params.scale;
+ cell_to_output_weights_scale = cell_to_output_weights->params.scale;
+ }
+
+ const int8_t* projection_weights_ptr =
+ (projection_weights == nullptr)
+ ? nullptr
+ : reinterpret_cast<int8_t*>(projection_weights->data.uint8);
+ const float projection_weights_scale =
+ (projection_weights == nullptr) ? 1.0f : projection_weights->params.scale;
+ const float* projection_bias_ptr =
+ (projection_bias == nullptr) ? nullptr : projection_bias->data.f;
+
+ // Required tensors, pointers are non-null.
+ const float* input_ptr_batch = input->data.f;
+ const int8_t* input_to_forget_weights_ptr =
+ reinterpret_cast<int8_t*>(input_to_forget_weights->data.uint8);
+ const float input_to_forget_weights_scale =
+ input_to_forget_weights->params.scale;
+ const int8_t* input_to_cell_weights_ptr =
+ reinterpret_cast<int8_t*>(input_to_cell_weights->data.uint8);
+ const float input_to_cell_weights_scale = input_to_cell_weights->params.scale;
+ const int8_t* input_to_output_weights_ptr =
+ reinterpret_cast<int8_t*>(input_to_output_weights->data.uint8);
+ const float input_to_output_weights_scale =
+ input_to_output_weights->params.scale;
+ const int8_t* recurrent_to_forget_weights_ptr =
+ reinterpret_cast<int8_t*>(recurrent_to_forget_weights->data.uint8);
+ const float recurrent_to_forget_weights_scale =
+ recurrent_to_forget_weights->params.scale;
+ const int8_t* recurrent_to_cell_weights_ptr =
+ reinterpret_cast<int8_t*>(recurrent_to_cell_weights->data.uint8);
+ const float recurrent_to_cell_weights_scale =
+ recurrent_to_cell_weights->params.scale;
+ const int8_t* recurrent_to_output_weights_ptr =
+ reinterpret_cast<int8_t*>(recurrent_to_output_weights->data.uint8);
+ const float recurrent_to_output_weights_scale =
+ recurrent_to_output_weights->params.scale;
+ const float* input_layer_norm_weight_ptr = input_layer_norm_weights->data.f;
+ const float* forget_layer_norm_weight_ptr = forget_layer_norm_weights->data.f;
+ const float* cell_layer_norm_weight_ptr = cell_layer_norm_weights->data.f;
+ const float* output_layer_norm_weight_ptr = output_layer_norm_weights->data.f;
+ const float* forget_gate_bias_ptr = forget_gate_bias->data.f;
+ const float* cell_bias_ptr = cell_bias->data.f;
+ const float* output_gate_bias_ptr = output_gate_bias->data.f;
+
+ float* activation_state_ptr = activation_state->data.f;
+ float* cell_state_ptr = cell_state->data.f;
+ float* output_ptr_batch = output->data.f;
+
+ // Temporary storage for quantized values and scaling factors.
+ int8_t* quantized_input_ptr =
+ reinterpret_cast<int8_t*>(input_quantized->data.uint8);
+ int8_t* quantized_activation_state_ptr =
+ reinterpret_cast<int8_t*>(activation_state_quantized->data.uint8);
+ int8_t* quantized_cell_state_ptr =
+ reinterpret_cast<int8_t*>(cell_state_quantized->data.uint8);
+ float* scaling_factors_ptr = scaling_factors->data.f;
+ float* prod_scaling_factors_ptr = prod_scaling_factors->data.f;
+ float* recovered_weights_ptr = recovered_weights->data.f;
+
+ LayerNormLstmStep(
+ input_ptr_batch, input_to_input_weights_ptr, input_to_input_weights_scale,
+ input_to_forget_weights_ptr, input_to_forget_weights_scale,
+ input_to_cell_weights_ptr, input_to_cell_weights_scale,
+ input_to_output_weights_ptr, input_to_output_weights_scale,
+ recurrent_to_input_weights_ptr, recurrent_to_input_weights_scale,
+ recurrent_to_forget_weights_ptr, recurrent_to_forget_weights_scale,
+ recurrent_to_cell_weights_ptr, recurrent_to_cell_weights_scale,
+ recurrent_to_output_weights_ptr, recurrent_to_output_weights_scale,
+ cell_to_input_weights_ptr, cell_to_input_weights_scale,
+ cell_to_forget_weights_ptr, cell_to_forget_weights_scale,
+ cell_to_output_weights_ptr, cell_to_output_weights_scale,
+ input_layer_norm_weight_ptr, forget_layer_norm_weight_ptr,
+ cell_layer_norm_weight_ptr, output_layer_norm_weight_ptr,
+ input_gate_bias_ptr, forget_gate_bias_ptr, cell_bias_ptr,
+ output_gate_bias_ptr, projection_weights_ptr, projection_weights_scale,
+ projection_bias_ptr, cell_clip, proj_clip, activation, n_batch, n_cell,
+ n_input, n_output, input_gate_scratch, forget_gate_scratch, cell_scratch,
+ output_gate_scratch, scaling_factors_ptr, prod_scaling_factors_ptr,
+ recovered_weights_ptr, quantized_input_ptr,
+ quantized_activation_state_ptr, quantized_cell_state_ptr,
+ activation_state_ptr, cell_state_ptr, output_ptr_batch);
+
+ return kTfLiteOk;
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ const OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
+
+ const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+
+ const TfLiteTensor* input_to_input_weights =
+ GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
+ const TfLiteTensor* input_to_forget_weights =
+ GetInput(context, node, kInputToForgetWeightsTensor);
+ const TfLiteTensor* input_to_cell_weights =
+ GetInput(context, node, kInputToCellWeightsTensor);
+ const TfLiteTensor* input_to_output_weights =
+ GetInput(context, node, kInputToOutputWeightsTensor);
+
+ const TfLiteTensor* recurrent_to_input_weights =
+ GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor);
+ const TfLiteTensor* recurrent_to_forget_weights =
+ GetInput(context, node, kRecurrentToForgetWeightsTensor);
+ const TfLiteTensor* recurrent_to_cell_weights =
+ GetInput(context, node, kRecurrentToCellWeightsTensor);
+ const TfLiteTensor* recurrent_to_output_weights =
+ GetInput(context, node, kRecurrentToOutputWeightsTensor);
+
+ const TfLiteTensor* cell_to_input_weights =
+ GetOptionalInputTensor(context, node, kCellToInputWeightsTensor);
+ const TfLiteTensor* cell_to_forget_weights =
+ GetOptionalInputTensor(context, node, kCellToForgetWeightsTensor);
+ const TfLiteTensor* cell_to_output_weights =
+ GetOptionalInputTensor(context, node, kCellToOutputWeightsTensor);
+
+ const TfLiteTensor* input_layer_norm_weights =
+ GetInput(context, node, kInputLayerNormWeightsTensor);
+ const TfLiteTensor* forget_layer_norm_weights =
+ GetInput(context, node, kForgetLayerNormWeightsTensor);
+ const TfLiteTensor* cell_layer_norm_weights =
+ GetInput(context, node, kCellLayerNormWeightsTensor);
+ const TfLiteTensor* output_layer_norm_weights =
+ GetInput(context, node, kOutputLayerNormWeightsTensor);
+
+ const TfLiteTensor* input_gate_bias =
+ GetOptionalInputTensor(context, node, kInputGateBiasTensor);
+ const TfLiteTensor* forget_gate_bias =
+ GetInput(context, node, kForgetGateBiasTensor);
+ const TfLiteTensor* cell_bias = GetInput(context, node, kCellGateBiasTensor);
+ const TfLiteTensor* output_gate_bias =
+ GetInput(context, node, kOutputGateBiasTensor);
+
+ const TfLiteTensor* projection_weights =
+ GetOptionalInputTensor(context, node, kProjectionWeightsTensor);
+ const TfLiteTensor* projection_bias =
+ GetOptionalInputTensor(context, node, kProjectionBiasTensor);
+
+ // Index the scratch buffers pointers to the global scratch buffer.
+ TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0);
+
+ TfLiteTensor* activation_state =
+ &context->tensors[node->inputs->data[kInputActivationStateTensor]];
+ TfLiteTensor* cell_state =
+ &context->tensors[node->inputs->data[kInputCellStateTensor]];
+
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+ switch (input_to_output_weights->type) {
+ case kTfLiteFloat32: {
+ return EvalFloat(input, input_to_input_weights, input_to_forget_weights,
+ input_to_cell_weights, input_to_output_weights,
+ recurrent_to_input_weights, recurrent_to_forget_weights,
+ recurrent_to_cell_weights, recurrent_to_output_weights,
+ cell_to_input_weights, cell_to_forget_weights,
+ cell_to_output_weights, input_layer_norm_weights,
+ forget_layer_norm_weights, cell_layer_norm_weights,
+ output_layer_norm_weights, input_gate_bias,
+ forget_gate_bias, cell_bias, output_gate_bias,
+ projection_weights, projection_bias, op_data->cell_clip,
+ op_data->proj_clip, op_data->activation, scratch_buffer,
+ activation_state, cell_state, output);
+ }
+ case kTfLiteUInt8: {
+ TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/1);
+ TfLiteTensor* activation_state_quantized =
+ GetTemporary(context, node, /*index=*/2);
+ TfLiteTensor* cell_state_quantized =
+ GetTemporary(context, node, /*index=*/3);
+ TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/4);
+ TfLiteTensor* prod_scaling_factors =
+ GetTemporary(context, node, /*index=*/5);
+ TfLiteTensor* recovered_weights =
+ GetTemporary(context, node, /*index=*/6);
+ return EvalHybrid(
+ input, input_to_input_weights, input_to_forget_weights,
+ input_to_cell_weights, input_to_output_weights,
+ recurrent_to_input_weights, recurrent_to_forget_weights,
+ recurrent_to_cell_weights, recurrent_to_output_weights,
+ cell_to_input_weights, cell_to_forget_weights, cell_to_output_weights,
+ input_layer_norm_weights, forget_layer_norm_weights,
+ cell_layer_norm_weights, output_layer_norm_weights, input_gate_bias,
+ forget_gate_bias, cell_bias, output_gate_bias, projection_weights,
+ projection_bias, op_data->cell_clip, op_data->proj_clip,
+ op_data->activation, scratch_buffer, scaling_factors,
+ prod_scaling_factors, recovered_weights, input_quantized,
+ activation_state_quantized, cell_state_quantized, activation_state,
+ cell_state, output);
+ }
+ default:
+ context->ReportError(context, "Type %d is not currently supported.",
+ input_to_output_weights->type);
+ return kTfLiteError;
+ }
+ return kTfLiteOk;
+}
+
+void Free(TfLiteContext* context, void* buffer) {
+ delete reinterpret_cast<OpData*>(buffer);
+}
+
+} // namespace layer_norm_lstm
+
+TfLiteRegistration* Register_LAYER_NORM_LSTM() {
+ static TfLiteRegistration r = {layer_norm_lstm::Init, layer_norm_lstm::Free,
+ layer_norm_lstm::Prepare,
+ layer_norm_lstm::Eval};
+ return &r;
+}
+
+} // namespace custom
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/layer_norm_lstm_test.cc b/tensorflow/contrib/lite/kernels/layer_norm_lstm_test.cc
new file mode 100644
index 0000000000..479f6a7d3c
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/layer_norm_lstm_test.cc
@@ -0,0 +1,664 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// Unit test for TFLite Layer Norm LSTM op.
+
+#include <memory>
+#include <vector>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "flatbuffers/flexbuffers.h" // TF:flatbuffers
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace ops {
+namespace custom {
+
+TfLiteRegistration* Register_LAYER_NORM_LSTM();
+
+namespace {
+
+using ::testing::ElementsAreArray;
+
+class LayerNormLSTMOpModel : public SingleOpModel {
+ public:
+ LayerNormLSTMOpModel(int n_batch, int n_input, int n_cell, int n_output,
+ bool use_cifg, bool use_peephole,
+ bool use_projection_weights, bool use_projection_bias,
+ float cell_clip, float proj_clip,
+ const std::vector<std::vector<int>>& input_shapes,
+ const TensorType& weight_type = TensorType_FLOAT32)
+ : n_batch_(n_batch),
+ n_input_(n_input),
+ n_cell_(n_cell),
+ n_output_(n_output) {
+ input_ = AddInput(TensorType_FLOAT32);
+
+ if (use_cifg) {
+ input_to_input_weights_ = AddNullInput();
+ } else {
+ input_to_input_weights_ = AddInput(weight_type);
+ }
+
+ input_to_forget_weights_ = AddInput(weight_type);
+ input_to_cell_weights_ = AddInput(weight_type);
+ input_to_output_weights_ = AddInput(weight_type);
+
+ if (use_cifg) {
+ recurrent_to_input_weights_ = AddNullInput();
+ } else {
+ recurrent_to_input_weights_ = AddInput(weight_type);
+ }
+
+ recurrent_to_forget_weights_ = AddInput(weight_type);
+ recurrent_to_cell_weights_ = AddInput(weight_type);
+ recurrent_to_output_weights_ = AddInput(weight_type);
+
+ if (use_peephole) {
+ if (use_cifg) {
+ cell_to_input_weights_ = AddNullInput();
+ } else {
+ cell_to_input_weights_ = AddInput(weight_type);
+ }
+ cell_to_forget_weights_ = AddInput(weight_type);
+ cell_to_output_weights_ = AddInput(weight_type);
+ } else {
+ cell_to_input_weights_ = AddNullInput();
+ cell_to_forget_weights_ = AddNullInput();
+ cell_to_output_weights_ = AddNullInput();
+ }
+
+ input_layer_norm_weights_ = AddInput(TensorType_FLOAT32);
+ forget_layer_norm_weights_ = AddInput(TensorType_FLOAT32);
+ cell_layer_norm_weights_ = AddInput(TensorType_FLOAT32);
+ output_layer_norm_weights_ = AddInput(TensorType_FLOAT32);
+
+ if (use_cifg) {
+ input_gate_bias_ = AddNullInput();
+ } else {
+ input_gate_bias_ = AddInput(TensorType_FLOAT32);
+ }
+ forget_gate_bias_ = AddInput(TensorType_FLOAT32);
+ cell_bias_ = AddInput(TensorType_FLOAT32);
+ output_gate_bias_ = AddInput(TensorType_FLOAT32);
+
+ if (use_projection_weights) {
+ projection_weights_ = AddInput(weight_type);
+ if (use_projection_bias) {
+ projection_bias_ = AddInput(TensorType_FLOAT32);
+ } else {
+ projection_bias_ = AddNullInput();
+ }
+ } else {
+ projection_weights_ = AddNullInput();
+ projection_bias_ = AddNullInput();
+ }
+
+ // Adding the 2 state tensors.
+ output_state_ =
+ AddInput(TensorData{TensorType_FLOAT32, {n_output_ * n_batch_}}, true);
+ cell_state_ =
+ AddInput(TensorData{TensorType_FLOAT32, {n_cell_ * n_batch_}}, true);
+
+ output_ = AddOutput(TensorType_FLOAT32);
+
+ // Set up and pass in custom options using flexbuffer.
+ flexbuffers::Builder fbb;
+ fbb.Map([&]() {
+ fbb.Int("cell_clip", cell_clip);
+ fbb.Int("proj_clip", proj_clip);
+ fbb.String("fused_activation_function", "TANH");
+ });
+ fbb.Finish();
+ SetCustomOp("LAYER_NORM_LSTM", fbb.GetBuffer(), Register_LAYER_NORM_LSTM);
+ BuildInterpreter(input_shapes);
+ }
+
+ void SetInputToInputWeights(std::initializer_list<float> f) {
+ PopulateTensor(input_to_input_weights_, f);
+ }
+
+ void SetInputToForgetWeights(std::initializer_list<float> f) {
+ PopulateTensor(input_to_forget_weights_, f);
+ }
+
+ void SetInputToCellWeights(std::initializer_list<float> f) {
+ PopulateTensor(input_to_cell_weights_, f);
+ }
+
+ void SetInputToOutputWeights(std::initializer_list<float> f) {
+ PopulateTensor(input_to_output_weights_, f);
+ }
+
+ void SetRecurrentToInputWeights(std::initializer_list<float> f) {
+ PopulateTensor(recurrent_to_input_weights_, f);
+ }
+
+ void SetRecurrentToForgetWeights(std::initializer_list<float> f) {
+ PopulateTensor(recurrent_to_forget_weights_, f);
+ }
+
+ void SetRecurrentToCellWeights(std::initializer_list<float> f) {
+ PopulateTensor(recurrent_to_cell_weights_, f);
+ }
+
+ void SetRecurrentToOutputWeights(std::initializer_list<float> f) {
+ PopulateTensor(recurrent_to_output_weights_, f);
+ }
+
+ void SetCellToInputWeights(std::initializer_list<float> f) {
+ PopulateTensor(cell_to_input_weights_, f);
+ }
+
+ void SetCellToForgetWeights(std::initializer_list<float> f) {
+ PopulateTensor(cell_to_forget_weights_, f);
+ }
+
+ void SetCellToOutputWeights(std::initializer_list<float> f) {
+ PopulateTensor(cell_to_output_weights_, f);
+ }
+
+ void SetInputLayerNormWeights(std::initializer_list<float> f) {
+ PopulateTensor(input_layer_norm_weights_, f);
+ }
+
+ void SetForgetLayerNormWeights(std::initializer_list<float> f) {
+ PopulateTensor(forget_layer_norm_weights_, f);
+ }
+
+ void SetCellLayerNormWeights(std::initializer_list<float> f) {
+ PopulateTensor(cell_layer_norm_weights_, f);
+ }
+
+ void SetOutputLayerNormWeights(std::initializer_list<float> f) {
+ PopulateTensor(output_layer_norm_weights_, f);
+ }
+
+ void SetInputGateBias(std::initializer_list<float> f) {
+ PopulateTensor(input_gate_bias_, f);
+ }
+
+ void SetForgetGateBias(std::initializer_list<float> f) {
+ PopulateTensor(forget_gate_bias_, f);
+ }
+
+ void SetCellBias(std::initializer_list<float> f) {
+ PopulateTensor(cell_bias_, f);
+ }
+
+ void SetOutputGateBias(std::initializer_list<float> f) {
+ PopulateTensor(output_gate_bias_, f);
+ }
+
+ void SetProjectionWeights(std::initializer_list<float> f) {
+ PopulateTensor(projection_weights_, f);
+ }
+
+ void SetProjectionBias(std::initializer_list<float> f) {
+ PopulateTensor(projection_bias_, f);
+ }
+
+ void SetInput(int offset, const float* begin, const float* end) {
+ PopulateTensor(input_, offset, const_cast<float*>(begin),
+ const_cast<float*>(end));
+ }
+
+ std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+
+ int num_inputs() { return n_input_; }
+ int num_outputs() { return n_output_; }
+ int num_cells() { return n_cell_; }
+ int num_batches() { return n_batch_; }
+
+ protected:
+ int input_;
+ int input_to_input_weights_;
+ int input_to_forget_weights_;
+ int input_to_cell_weights_;
+ int input_to_output_weights_;
+
+ int recurrent_to_input_weights_;
+ int recurrent_to_forget_weights_;
+ int recurrent_to_cell_weights_;
+ int recurrent_to_output_weights_;
+
+ int cell_to_input_weights_;
+ int cell_to_forget_weights_;
+ int cell_to_output_weights_;
+
+ int input_layer_norm_weights_;
+ int forget_layer_norm_weights_;
+ int cell_layer_norm_weights_;
+ int output_layer_norm_weights_;
+
+ int input_gate_bias_;
+ int forget_gate_bias_;
+ int cell_bias_;
+ int output_gate_bias_;
+
+ int projection_weights_;
+ int projection_bias_;
+
+ int output_state_;
+ int cell_state_;
+
+ int output_;
+
+ int n_batch_;
+ int n_input_;
+ int n_cell_;
+ int n_output_;
+};
+
+class HybridLayerNormLSTMOpModel : public LayerNormLSTMOpModel {
+ public:
+ HybridLayerNormLSTMOpModel(int n_batch, int n_input, int n_cell, int n_output,
+ bool use_cifg, bool use_peephole,
+ bool use_projection_weights,
+ bool use_projection_bias, float cell_clip,
+ float proj_clip,
+ const std::vector<std::vector<int>>& input_shapes)
+ : LayerNormLSTMOpModel(n_batch, n_input, n_cell, n_output, use_cifg,
+ use_peephole, use_projection_weights,
+ use_projection_bias, cell_clip, proj_clip,
+ input_shapes, TensorType_UINT8) {}
+
+ void SetInputToInputWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(input_to_input_weights_, f);
+ }
+
+ void SetInputToForgetWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(input_to_forget_weights_, f);
+ }
+
+ void SetInputToCellWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(input_to_cell_weights_, f);
+ }
+
+ void SetInputToOutputWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(input_to_output_weights_, f);
+ }
+
+ void SetRecurrentToInputWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(recurrent_to_input_weights_, f);
+ }
+
+ void SetRecurrentToForgetWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(recurrent_to_forget_weights_, f);
+ }
+
+ void SetRecurrentToCellWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(recurrent_to_cell_weights_, f);
+ }
+
+ void SetRecurrentToOutputWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(recurrent_to_output_weights_, f);
+ }
+
+ void SetCellToInputWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(cell_to_input_weights_, f);
+ }
+
+ void SetCellToForgetWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(cell_to_forget_weights_, f);
+ }
+
+ void SetCellToOutputWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(cell_to_output_weights_, f);
+ }
+
+ void SetInputLayerNormWeights(std::initializer_list<float> f) {
+ PopulateTensor(input_layer_norm_weights_, f);
+ }
+
+ void SetForgetLayerNormWeights(std::initializer_list<float> f) {
+ PopulateTensor(forget_layer_norm_weights_, f);
+ }
+
+ void SetCellLayerNormWeights(std::initializer_list<float> f) {
+ PopulateTensor(cell_layer_norm_weights_, f);
+ }
+
+ void SetOutputLayerNormWeights(std::initializer_list<float> f) {
+ PopulateTensor(output_layer_norm_weights_, f);
+ }
+
+ void SetProjectionWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(projection_weights_, f);
+ }
+};
+
+class BaseLayerNormLstmTest : public ::testing::Test {
+ protected:
+ // Weights of the Layer Norm LSTM model. Some are optional.
+ std::initializer_list<float> input_to_input_weights_;
+ std::initializer_list<float> input_to_cell_weights_;
+ std::initializer_list<float> input_to_forget_weights_;
+ std::initializer_list<float> input_to_output_weights_;
+ std::initializer_list<float> input_gate_bias_;
+ std::initializer_list<float> cell_gate_bias_;
+ std::initializer_list<float> forget_gate_bias_;
+ std::initializer_list<float> output_gate_bias_;
+ std::initializer_list<float> recurrent_to_input_weights_;
+ std::initializer_list<float> recurrent_to_cell_weights_;
+ std::initializer_list<float> recurrent_to_forget_weights_;
+ std::initializer_list<float> recurrent_to_output_weights_;
+ std::initializer_list<float> cell_to_input_weights_;
+ std::initializer_list<float> cell_to_forget_weights_;
+ std::initializer_list<float> cell_to_output_weights_;
+ std::initializer_list<float> input_layer_norm_weights_;
+ std::initializer_list<float> forget_layer_norm_weights_;
+ std::initializer_list<float> cell_layer_norm_weights_;
+ std::initializer_list<float> output_layer_norm_weights_;
+ std::initializer_list<float> projection_weights_;
+
+ // Layer Norm LSTM input is stored as num_batch x num_inputs vector.
+ std::vector<std::vector<float>> layer_norm_lstm_input_;
+
+ // Compares the model output, up to the given tolerance, against the golden
+ // output of the layer_norm_lstm for the given input.
+ void VerifyGoldens(const std::vector<std::vector<float>>& input,
+ const std::vector<std::vector<float>>& output,
+ LayerNormLSTMOpModel* layer_norm_lstm,
+ float tolerance = 1e-5) {
+ const int num_batches = input.size();
+ EXPECT_GT(num_batches, 0);
+ const int num_inputs = layer_norm_lstm->num_inputs();
+ EXPECT_GT(num_inputs, 0);
+ const int input_sequence_size = input[0].size() / num_inputs;
+ EXPECT_GT(input_sequence_size, 0);
+ for (int i = 0; i < input_sequence_size; ++i) {
+ for (int b = 0; b < num_batches; ++b) {
+ const float* batch_start = input[b].data() + i * num_inputs;
+ const float* batch_end = batch_start + num_inputs;
+
+ layer_norm_lstm->SetInput(b * layer_norm_lstm->num_inputs(),
+ batch_start, batch_end);
+ }
+
+ layer_norm_lstm->Invoke();
+
+ const int num_outputs = layer_norm_lstm->num_outputs();
+ std::vector<float> expected;
+ for (int b = 0; b < num_batches; ++b) {
+ const float* golden_start_batch = output[b].data() + i * num_outputs;
+ const float* golden_end_batch = golden_start_batch + num_outputs;
+ expected.insert(expected.end(), golden_start_batch, golden_end_batch);
+ }
+ EXPECT_THAT(layer_norm_lstm->GetOutput(),
+ ElementsAreArray(ArrayFloatNear(expected, tolerance)));
+ }
+ }
+};
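
For reference, the indexing in VerifyGoldens assumes each per-batch vector stores the whole sequence flattened as input_sequence_size * num_inputs floats. A minimal standalone sketch of that layout (the names and values below are illustrative, not part of the test fixture):

#include <cstdio>
#include <vector>

int main() {
  const int num_inputs = 5;
  // One batch entry: 3 time steps, 5 features each, flattened back to back.
  std::vector<float> batch0 = {0.7f, 0.8f, 0.1f, 0.2f, 0.3f,   // step 0
                               0.8f, 0.1f, 0.2f, 0.4f, 0.5f,   // step 1
                               0.2f, 0.7f, 0.7f, 0.1f, 0.7f};  // step 2
  // Feature k of time step i lives at offset i * num_inputs + k, which is
  // exactly the pointer arithmetic batch_start = data() + i * num_inputs.
  for (int i = 0; i < 3; ++i) {
    const float* step = batch0.data() + i * num_inputs;
    std::printf("step %d starts with %.1f\n", i, step[0]);
  }
  return 0;
}
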
+
+class NoCifgPeepholeProjectionNoClippingLayerNormLstmTest
+ : public BaseLayerNormLstmTest {
+ void SetUp() override {
+ input_to_input_weights_ = {0.5, 0.6, 0.7, -0.8, -0.9, 0.1, 0.2,
+ 0.3, -0.4, 0.5, -0.8, 0.7, -0.6, 0.5,
+ -0.4, -0.5, -0.4, -0.3, -0.2, -0.1};
+
+ input_to_forget_weights_ = {-0.6, -0.1, 0.3, 0.2, 0.9, -0.5, -0.2,
+ -0.4, 0.3, -0.8, -0.4, 0.3, -0.5, -0.4,
+ -0.6, 0.3, -0.4, -0.6, -0.5, -0.5};
+
+ input_to_cell_weights_ = {-0.4, -0.3, -0.2, -0.1, -0.5, 0.5, -0.2,
+ -0.3, -0.2, -0.6, 0.6, -0.1, -0.4, -0.3,
+ -0.7, 0.7, -0.9, -0.5, 0.8, 0.6};
+
+ input_to_output_weights_ = {-0.8, -0.4, -0.2, -0.9, -0.1, -0.7, 0.3,
+ -0.3, -0.8, -0.2, 0.6, -0.2, 0.4, -0.7,
+ -0.3, -0.5, 0.1, 0.5, -0.6, -0.4};
+
+ input_gate_bias_ = {0.03, 0.15, 0.22, 0.38};
+
+ forget_gate_bias_ = {0.1, -0.3, -0.2, 0.1};
+
+ cell_gate_bias_ = {-0.05, 0.72, 0.25, 0.08};
+
+ output_gate_bias_ = {0.05, -0.01, 0.2, 0.1};
+
+ recurrent_to_input_weights_ = {-0.2, -0.3, 0.4, 0.1, -0.5, 0.9,
+ -0.2, -0.3, -0.7, 0.05, -0.2, -0.6};
+
+ recurrent_to_cell_weights_ = {-0.3, 0.2, 0.1, -0.3, 0.8, -0.08,
+ -0.2, 0.3, 0.8, -0.6, -0.1, 0.2};
+
+ recurrent_to_forget_weights_ = {-0.5, -0.3, -0.5, -0.2, 0.6, 0.4,
+ 0.9, 0.3, -0.1, 0.2, 0.5, 0.2};
+
+ recurrent_to_output_weights_ = {0.3, -0.1, 0.1, -0.2, -0.5, -0.7,
+ -0.2, -0.6, -0.1, -0.4, -0.7, -0.2};
+
+ cell_to_input_weights_ = {0.05, 0.1, 0.25, 0.15};
+
+ cell_to_forget_weights_ = {-0.02, -0.15, -0.25, -0.03};
+
+ cell_to_output_weights_ = {0.1, -0.1, -0.5, 0.05};
+
+ input_layer_norm_weights_ = {0.1, 0.2, 0.3, 0.5};
+ forget_layer_norm_weights_ = {0.2, 0.2, 0.4, 0.3};
+ cell_layer_norm_weights_ = {0.7, 0.2, 0.3, 0.8};
+ output_layer_norm_weights_ = {0.6, 0.2, 0.2, 0.5};
+
+ projection_weights_ = {-0.1, 0.2, 0.01, -0.2, 0.1, 0.5,
+ 0.3, 0.08, 0.07, 0.2, -0.4, 0.2};
+
+ layer_norm_lstm_input_ = {
+ {// Batch0: 3 (input_sequence_size) * 5 (n_input)
+ 0.7, 0.8, 0.1, 0.2, 0.3, // seq 0
+ 0.8, 0.1, 0.2, 0.4, 0.5, // seq 1
+ 0.2, 0.7, 0.7, 0.1, 0.7}, // seq 2
+
+ {// Batch1: 3 (input_sequence_size) * 5 (n_input)
+ 0.3, 0.2, 0.9, 0.8, 0.1, // seq 0
+ 0.1, 0.5, 0.2, 0.4, 0.2, // seq 1
+ 0.6, 0.9, 0.2, 0.5, 0.7}, // seq 2
+ };
+ }
+};
+
+TEST_F(NoCifgPeepholeProjectionNoClippingLayerNormLstmTest,
+ LayerNormLstmBlackBoxTest) {
+ const int n_batch = 2;
+ const int n_input = 5;
+ const int n_cell = 4;
+ const int n_output = 3;
+ const float cell_clip = 0.0;
+ const float proj_clip = 0.0;
+
+ LayerNormLSTMOpModel layer_norm_lstm(
+ n_batch, n_input, n_cell, n_output,
+ /*use_cifg=*/false, /*use_peephole=*/true,
+ /*use_projection_weights=*/true,
+ /*use_projection_bias=*/false, cell_clip, proj_clip,
+ {
+ {n_batch, n_input}, // input tensor
+
+ {n_cell, n_input}, // input_to_input_weight tensor
+ {n_cell, n_input}, // input_to_forget_weight tensor
+ {n_cell, n_input}, // input_to_cell_weight tensor
+ {n_cell, n_input}, // input_to_output_weight tensor
+
+ {n_cell, n_output}, // recurrent_to_input_weight tensor
+ {n_cell, n_output}, // recurrent_to_forget_weight tensor
+ {n_cell, n_output}, // recurrent_to_cell_weight tensor
+ {n_cell, n_output}, // recurrent_to_output_weight tensor
+
+ {n_cell}, // cell_to_input_weight tensor
+ {n_cell}, // cell_to_forget_weight tensor
+ {n_cell}, // cell_to_output_weight tensor
+
+ {n_cell}, // input_layer_norm_weight tensor
+ {n_cell}, // forget_layer_norm_weight tensor
+ {n_cell}, // cell_layer_norm_weight tensor
+ {n_cell}, // output_layer_norm_weight tensor
+
+ {n_cell}, // input_gate_bias tensor
+ {n_cell}, // forget_gate_bias tensor
+ {n_cell}, // cell_bias tensor
+ {n_cell}, // output_gate_bias tensor
+
+ {n_output, n_cell}, // projection_weight tensor
+ {0}, // projection_bias tensor
+ });
+
+ layer_norm_lstm.SetInputToInputWeights(input_to_input_weights_);
+ layer_norm_lstm.SetInputToCellWeights(input_to_cell_weights_);
+ layer_norm_lstm.SetInputToForgetWeights(input_to_forget_weights_);
+ layer_norm_lstm.SetInputToOutputWeights(input_to_output_weights_);
+
+ layer_norm_lstm.SetInputGateBias(input_gate_bias_);
+ layer_norm_lstm.SetCellBias(cell_gate_bias_);
+ layer_norm_lstm.SetForgetGateBias(forget_gate_bias_);
+ layer_norm_lstm.SetOutputGateBias(output_gate_bias_);
+
+ layer_norm_lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_);
+ layer_norm_lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
+ layer_norm_lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
+ layer_norm_lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
+
+ layer_norm_lstm.SetCellToInputWeights(cell_to_input_weights_);
+ layer_norm_lstm.SetCellToForgetWeights(cell_to_forget_weights_);
+ layer_norm_lstm.SetCellToOutputWeights(cell_to_output_weights_);
+
+ layer_norm_lstm.SetInputLayerNormWeights(input_layer_norm_weights_);
+ layer_norm_lstm.SetForgetLayerNormWeights(forget_layer_norm_weights_);
+ layer_norm_lstm.SetCellLayerNormWeights(cell_layer_norm_weights_);
+ layer_norm_lstm.SetOutputLayerNormWeights(output_layer_norm_weights_);
+
+ layer_norm_lstm.SetProjectionWeights(projection_weights_);
+
+ // Verify the final output.
+ const std::vector<std::vector<float>> layer_norm_lstm_golden_output = {
+ {
+ // Batch0: 3 (input_sequence_size) * 3 (n_output)
+ 0.0244077, 0.128027, -0.00170918, // seq 0
+ 0.0137642, 0.140751, 0.0395835, // seq 1
+ -0.00459231, 0.155278, 0.0837377, // seq 2
+ },
+ {
+ // Batch1: 3 (input_sequence_size) * 3 (n_output)
+ -0.00692428, 0.0848741, 0.063445, // seq 0
+ -0.00403912, 0.139963, 0.072681, // seq 1
+ 0.00752706, 0.161903, 0.0561371, // seq 2
+ }};
+
+ VerifyGoldens(layer_norm_lstm_input_, layer_norm_lstm_golden_output,
+ &layer_norm_lstm);
+}
+
+TEST_F(NoCifgPeepholeProjectionNoClippingLayerNormLstmTest,
+ HybridLayerNormLstmBlackBoxTest) {
+ const int n_batch = 2;
+ const int n_input = 5;
+ const int n_cell = 4;
+ const int n_output = 3;
+ const float cell_clip = 0.0;
+ const float proj_clip = 0.0;
+
+ HybridLayerNormLSTMOpModel layer_norm_lstm(
+ n_batch, n_input, n_cell, n_output,
+ /*use_cifg=*/false, /*use_peephole=*/true,
+ /*use_projection_weights=*/true,
+ /*use_projection_bias=*/false, cell_clip, proj_clip,
+ {
+ {n_batch, n_input}, // input tensor
+
+ {n_cell, n_input}, // input_to_input_weight tensor
+ {n_cell, n_input}, // input_to_forget_weight tensor
+ {n_cell, n_input}, // input_to_cell_weight tensor
+ {n_cell, n_input}, // input_to_output_weight tensor
+
+ {n_cell, n_output}, // recurrent_to_input_weight tensor
+ {n_cell, n_output}, // recurrent_to_forget_weight tensor
+ {n_cell, n_output}, // recurrent_to_cell_weight tensor
+ {n_cell, n_output}, // recurrent_to_output_weight tensor
+
+ {n_cell}, // cell_to_input_weight tensor
+ {n_cell}, // cell_to_forget_weight tensor
+ {n_cell}, // cell_to_output_weight tensor
+
+ {n_cell}, // input_layer_norm_weight tensor
+ {n_cell}, // forget_layer_norm_weight tensor
+ {n_cell}, // cell_layer_norm_weight tensor
+ {n_cell}, // output_layer_norm_weight tensor
+
+ {n_cell}, // input_gate_bias tensor
+ {n_cell}, // forget_gate_bias tensor
+ {n_cell}, // cell_bias tensor
+ {n_cell}, // output_gate_bias tensor
+
+ {n_output, n_cell}, // projection_weight tensor
+ {0}, // projection_bias tensor
+ });
+
+ layer_norm_lstm.SetInputToInputWeights(input_to_input_weights_);
+ layer_norm_lstm.SetInputToCellWeights(input_to_cell_weights_);
+ layer_norm_lstm.SetInputToForgetWeights(input_to_forget_weights_);
+ layer_norm_lstm.SetInputToOutputWeights(input_to_output_weights_);
+
+ layer_norm_lstm.SetInputGateBias(input_gate_bias_);
+ layer_norm_lstm.SetCellBias(cell_gate_bias_);
+ layer_norm_lstm.SetForgetGateBias(forget_gate_bias_);
+ layer_norm_lstm.SetOutputGateBias(output_gate_bias_);
+
+ layer_norm_lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_);
+ layer_norm_lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
+ layer_norm_lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
+ layer_norm_lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
+
+ layer_norm_lstm.SetCellToInputWeights(cell_to_input_weights_);
+ layer_norm_lstm.SetCellToForgetWeights(cell_to_forget_weights_);
+ layer_norm_lstm.SetCellToOutputWeights(cell_to_output_weights_);
+
+ layer_norm_lstm.SetInputLayerNormWeights(input_layer_norm_weights_);
+ layer_norm_lstm.SetForgetLayerNormWeights(forget_layer_norm_weights_);
+ layer_norm_lstm.SetCellLayerNormWeights(cell_layer_norm_weights_);
+ layer_norm_lstm.SetOutputLayerNormWeights(output_layer_norm_weights_);
+
+ layer_norm_lstm.SetProjectionWeights(projection_weights_);
+
+ const std::vector<std::vector<float>> layer_norm_lstm_golden_output = {
+ {
+ // Batch0: 3 (input_sequence_size) * 3 (n_output)
+ 0.0244576, 0.127847, -0.00181765, // seq 0
+ 0.0137518, 0.140892, 0.0402234, // seq 1
+ -0.0048839, 0.155096, 0.0840309, // seq 2
+ },
+ {
+ // Batch1: 3 (input_sequence_size) * 3 (n_output)
+ -0.00728636, 0.0843957, 0.0634786, // seq 0
+ -0.00448382, 0.139278, 0.0737372, // seq 1
+ 0.00734616, 0.161793, 0.0560238, // seq 2
+ }};
+
+ VerifyGoldens(layer_norm_lstm_input_, layer_norm_lstm_golden_output,
+ &layer_norm_lstm);
+}
+
+} // namespace
+} // namespace custom
+} // namespace ops
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ ::tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
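
As a rough sanity check, the hybrid golden values above differ from the float goldens by a few parts in ten thousand, which is the error introduced by symmetrically quantizing the weights to 8 bits. A small hedged sketch that measures the gap for the first few values copied from the two golden tables above:

#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
  // First three outputs of batch 0, seq 0, taken from the two golden tables.
  const std::vector<float> float_golden = {0.0244077f, 0.128027f, -0.00170918f};
  const std::vector<float> hybrid_golden = {0.0244576f, 0.127847f, -0.00181765f};
  float max_diff = 0.f;
  for (std::size_t i = 0; i < float_golden.size(); ++i) {
    max_diff = std::fmax(max_diff, std::fabs(float_golden[i] - hybrid_golden[i]));
  }
  std::printf("max |float - hybrid| = %g\n", max_diff);  // on the order of 1e-4
  return 0;
}
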
diff --git a/tensorflow/contrib/lite/kernels/local_response_norm.cc b/tensorflow/contrib/lite/kernels/local_response_norm.cc
index c15a5170b8..334d2a2788 100644
--- a/tensorflow/contrib/lite/kernels/local_response_norm.cc
+++ b/tensorflow/contrib/lite/kernels/local_response_norm.cc
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
@@ -64,11 +64,15 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
if (output->type == kTfLiteFloat32) {
-#define TF_LITE_LOCAL_RESPONSE_NORM(type) \
- type::LocalResponseNormalization( \
- GetTensorData<float>(input), GetTensorDims(input), params->radius, \
- params->bias, params->alpha, params->beta, GetTensorData<float>(output), \
- GetTensorDims(output))
+#define TF_LITE_LOCAL_RESPONSE_NORM(type) \
+ tflite::LocalResponseNormalizationParams op_params; \
+ op_params.range = params->radius; \
+ op_params.bias = params->bias; \
+ op_params.alpha = params->alpha; \
+ op_params.beta = params->beta; \
+ type::LocalResponseNormalization( \
+ op_params, GetTensorShape(input), GetTensorData<float>(input), \
+ GetTensorShape(output), GetTensorData<float>(output))
if (kernel_type == kReference) {
TF_LITE_LOCAL_RESPONSE_NORM(reference_ops);
}
@@ -77,7 +81,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
}
#undef TF_LITE_LOCAL_RESPONSE_NORM
} else {
- context->ReportError(context, "Inputs and outputs not all float types.");
+ context->ReportError(context, "Output type is %d, requires float.",
+ output->type);
return kTfLiteError;
}
diff --git a/tensorflow/contrib/lite/kernels/log_softmax_test.cc b/tensorflow/contrib/lite/kernels/log_softmax_test.cc
index 62820a2f51..1acc966cdc 100644
--- a/tensorflow/contrib/lite/kernels/log_softmax_test.cc
+++ b/tensorflow/contrib/lite/kernels/log_softmax_test.cc
@@ -90,10 +90,10 @@ TEST(LogSoftmaxOpTest, CompareWithTFmini) {
m.Invoke();
std::unique_ptr<float[]> output_buffer(new float[input_size * batch_size]);
- static tflite::Dims<4> input_dims = {{input_size, 1, 1, batch_size},
- {1, 0, 0, input_size}};
- tflite::reference_ops::LogSoftmax(input_buffer, input_dims,
- output_buffer.get(), input_dims);
+ auto input_shape = RuntimeShape({batch_size, 1, 1, input_size});
+ SoftmaxParams params;
+ tflite::reference_ops::LogSoftmax(params, input_shape, input_buffer,
+ input_shape, output_buffer.get());
std::vector<float> expected;
expected.insert(expected.end(), output_buffer.get(),
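
The reference LogSoftmax exercised here is the numerically stable log-softmax, log_softmax(x_i) = (x_i - max_j x_j) - log(sum_j exp(x_j - max_j x_j)). A minimal sketch under that assumption (LogSoftmaxRow is illustrative, not the library implementation, and assumes a non-empty row):

#include <cmath>
#include <cstddef>
#include <vector>

// Numerically stable log-softmax over a single row.
std::vector<float> LogSoftmaxRow(const std::vector<float>& x) {
  float max_val = x[0];
  for (float v : x) max_val = std::fmax(max_val, v);
  float sum_exp = 0.f;
  for (float v : x) sum_exp += std::exp(v - max_val);
  std::vector<float> out(x.size());
  for (std::size_t i = 0; i < x.size(); ++i) {
    out[i] = (x[i] - max_val) - std::log(sum_exp);
  }
  return out;
}
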
diff --git a/tensorflow/contrib/lite/kernels/logical.cc b/tensorflow/contrib/lite/kernels/logical.cc
new file mode 100644
index 0000000000..f770cb35d1
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/logical.cc
@@ -0,0 +1,134 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace logical {
+namespace {
+
+// Input/output tensor indices.
+constexpr int kInputTensor1 = 0;
+constexpr int kInputTensor2 = 1;
+constexpr int kOutputTensor = 0;
+
+// Op data for logical op.
+struct OpData {
+ bool requires_broadcast;
+};
+
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+ auto* data = new OpData;
+ data->requires_broadcast = false;
+ return data;
+}
+
+void Free(TfLiteContext* context, void* buffer) {
+ delete reinterpret_cast<OpData*>(buffer);
+}
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+
+ // Reinterpret the opaque data provided by the user.
+ OpData* data = reinterpret_cast<OpData*>(node->user_data);
+
+ const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
+ const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+ TF_LITE_ENSURE_EQ(context, input1->type, input2->type);
+
+ const TfLiteType type = input1->type;
+ if (type != kTfLiteBool) {
+ context->ReportError(context, "Logical ops only support bool type.");
+ return kTfLiteError;
+ }
+ output->type = type;
+
+ data->requires_broadcast = !HaveSameShapes(input1, input2);
+
+ TfLiteIntArray* output_size = nullptr;
+ if (data->requires_broadcast) {
+ TF_LITE_ENSURE_OK(context, CalculateShapeForBroadcast(
+ context, input1, input2, &output_size));
+ } else {
+ output_size = TfLiteIntArrayCopy(input1->dims);
+ }
+
+ return context->ResizeTensor(context, output, output_size);
+}
+
+TfLiteStatus LogicalImpl(TfLiteContext* context, TfLiteNode* node,
+ const std::function<bool(bool, bool)>& func) {
+ OpData* data = reinterpret_cast<OpData*>(node->user_data);
+
+ const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
+ const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+ if (data->requires_broadcast) {
+ reference_ops::BroadcastLogical4DSlow(
+ GetTensorShape(input1), GetTensorData<bool>(input1),
+ GetTensorShape(input2), GetTensorData<bool>(input2),
+ GetTensorShape(output), GetTensorData<bool>(output), func);
+ } else {
+ reference_ops::Logical(GetTensorShape(input1), GetTensorData<bool>(input1),
+ GetTensorShape(input2), GetTensorData<bool>(input2),
+ GetTensorShape(output), GetTensorData<bool>(output),
+ func);
+ }
+
+ return kTfLiteOk;
+}
+
+TfLiteStatus LogicalOrEval(TfLiteContext* context, TfLiteNode* node) {
+ const auto logical_or_func = std::logical_or<bool>();
+ return LogicalImpl(context, node, logical_or_func);
+}
+
+TfLiteStatus LogicalAndEval(TfLiteContext* context, TfLiteNode* node) {
+ const auto logical_and_func = std::logical_and<bool>();
+ return LogicalImpl(context, node, logical_and_func);
+}
+
+} // namespace
+} // namespace logical
+
+TfLiteRegistration* Register_LOGICAL_OR() {
+ // Init, Free, Prepare and Eval satisfy the interface required by
+ // TfLiteRegistration.
+ static TfLiteRegistration r = {logical::Init, logical::Free, logical::Prepare,
+ logical::LogicalOrEval};
+ return &r;
+}
+
+TfLiteRegistration* Register_LOGICAL_AND() {
+ // Init, Free, Prepare and Eval satisfy the interface required by
+ // TfLiteRegistration.
+ static TfLiteRegistration r = {logical::Init, logical::Free, logical::Prepare,
+ logical::LogicalAndEval};
+ return &r;
+}
+
+} // namespace builtin
+} // namespace ops
+} // namespace tflite
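
Both registrations above funnel into LogicalImpl with a different standard functor, so the non-broadcast path is simply an elementwise zip over two bool buffers. A host-side sketch of that idea, independent of the TfLite tensor types (ElementwiseLogical is illustrative):

#include <cstddef>
#include <cstdio>
#include <functional>
#include <vector>

// Apply a binary bool functor elementwise, the way LogicalImpl does for the
// non-broadcast case.
std::vector<bool> ElementwiseLogical(const std::vector<bool>& a,
                                     const std::vector<bool>& b,
                                     const std::function<bool(bool, bool)>& f) {
  std::vector<bool> out(a.size());
  for (std::size_t i = 0; i < a.size(); ++i) out[i] = f(a[i], b[i]);
  return out;
}

int main() {
  const std::vector<bool> x = {true, false, false, true};
  const std::vector<bool> y = {true, false, true, false};
  const auto or_result = ElementwiseLogical(x, y, std::logical_or<bool>());
  const auto and_result = ElementwiseLogical(x, y, std::logical_and<bool>());
  for (std::size_t i = 0; i < x.size(); ++i) {
    std::printf("%d OR %d = %d, %d AND %d = %d\n", (int)x[i], (int)y[i],
                (int)or_result[i], (int)x[i], (int)y[i], (int)and_result[i]);
  }
  return 0;
}
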
diff --git a/tensorflow/contrib/lite/kernels/logical_test.cc b/tensorflow/contrib/lite/kernels/logical_test.cc
new file mode 100644
index 0000000000..206cbde98f
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/logical_test.cc
@@ -0,0 +1,112 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAre;
+
+class LogicalOpModel : public SingleOpModel {
+ public:
+ LogicalOpModel(std::initializer_list<int> input1_shape,
+ std::initializer_list<int> input2_shape, BuiltinOperator op) {
+ input1_ = AddInput(TensorType_BOOL);
+ input2_ = AddInput(TensorType_BOOL);
+ output_ = AddOutput(TensorType_BOOL);
+ ConfigureBuiltinOp(op);
+ BuildInterpreter({input1_shape, input2_shape});
+ }
+
+ int input1() { return input1_; }
+ int input2() { return input2_; }
+
+ std::vector<bool> GetOutput() { return ExtractVector<bool>(output_); }
+ std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
+
+ private:
+ int input1_;
+ int input2_;
+ int output_;
+
+ void ConfigureBuiltinOp(BuiltinOperator op) {
+ switch (op) {
+ case BuiltinOperator_LOGICAL_OR: {
+ SetBuiltinOp(op, BuiltinOptions_LogicalOrOptions,
+ CreateLogicalOrOptions(builder_).Union());
+ break;
+ }
+ case BuiltinOperator_LOGICAL_AND: {
+ SetBuiltinOp(op, BuiltinOptions_LogicalAndOptions,
+ CreateLogicalAndOptions(builder_).Union());
+ break;
+ }
+ default: { FAIL() << "We shouldn't get here."; }
+ }
+ }
+};
+
+TEST(LogicalTest, LogicalOr) {
+ LogicalOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, BuiltinOperator_LOGICAL_OR);
+ model.PopulateTensor<bool>(model.input1(), {true, false, false, true});
+ model.PopulateTensor<bool>(model.input2(), {true, false, true, false});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(), ElementsAre(true, false, true, true));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
+}
+
+TEST(LogicalTest, BroadcastLogicalOr) {
+ LogicalOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, BuiltinOperator_LOGICAL_OR);
+ model.PopulateTensor<bool>(model.input1(), {true, false, false, true});
+ model.PopulateTensor<bool>(model.input2(), {false});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(), ElementsAre(true, false, false, true));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
+}
+
+TEST(LogicalTest, LogicalAnd) {
+ LogicalOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, BuiltinOperator_LOGICAL_AND);
+ model.PopulateTensor<bool>(model.input1(), {true, false, false, true});
+ model.PopulateTensor<bool>(model.input2(), {true, false, true, false});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(), ElementsAre(true, false, false, false));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
+}
+
+TEST(LogicalTest, BroadcastLogicalAnd) {
+ LogicalOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, BuiltinOperator_LOGICAL_AND);
+ model.PopulateTensor<bool>(model.input1(), {true, false, false, true});
+ model.PopulateTensor<bool>(model.input2(), {true});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(), ElementsAre(true, false, false, true));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
+}
+
+} // namespace
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ ::tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
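
The broadcast tests rely on the usual rule that a size-1 dimension stretches to match the other operand. A small sketch of the resulting output-shape computation, assuming standard NumPy-style broadcasting of equal-rank shapes (BroadcastShape is illustrative, not CalculateShapeForBroadcast itself):

#include <cstddef>
#include <stdexcept>
#include <vector>

// Broadcast of two equal-rank shapes: each output dimension is the larger of
// the two, and the inputs must either match or be 1 in that axis.
std::vector<int> BroadcastShape(const std::vector<int>& a,
                                const std::vector<int>& b) {
  std::vector<int> out(a.size());
  for (std::size_t i = 0; i < a.size(); ++i) {
    if (a[i] != b[i] && a[i] != 1 && b[i] != 1) {
      throw std::runtime_error("shapes are not broadcast-compatible");
    }
    out[i] = (a[i] == 1) ? b[i] : a[i];
  }
  return out;
}
// BroadcastShape({1, 1, 1, 4}, {1, 1, 1, 1}) == {1, 1, 1, 4}, matching the
// output shape the broadcast tests expect.
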
diff --git a/tensorflow/contrib/lite/kernels/lsh_projection.cc b/tensorflow/contrib/lite/kernels/lsh_projection.cc
index 25d2dc2cdd..9fa1c5f100 100644
--- a/tensorflow/contrib/lite/kernels/lsh_projection.cc
+++ b/tensorflow/contrib/lite/kernels/lsh_projection.cc
@@ -50,7 +50,6 @@ limitations under the License.
// Output.Dim == { Tensor[0].Dim[0] * Tensor[0].Dim[1] }
// A flattened tensor represents projected bit vectors.
-#include <unistd.h>
#include <cassert>
#include <cmath>
#include <cstdio>
@@ -60,8 +59,8 @@ limitations under the License.
#include <limits>
#include <memory>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"
#include <farmhash.h>
diff --git a/tensorflow/contrib/lite/kernels/lstm.cc b/tensorflow/contrib/lite/kernels/lstm.cc
index 990b3da055..16d67a1a93 100644
--- a/tensorflow/contrib/lite/kernels/lstm.cc
+++ b/tensorflow/contrib/lite/kernels/lstm.cc
@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include <unistd.h>
#include <cassert>
#include <cmath>
#include <cstdio>
@@ -21,12 +20,16 @@ limitations under the License.
#include <iostream>
#include <limits>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/activation_functor.h"
+#include "tensorflow/contrib/lite/kernels/gemm_support.h"
#include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h"
+#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/lstm_eval.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"
namespace tflite {
@@ -34,6 +37,20 @@ namespace ops {
namespace builtin {
namespace lstm {
+struct OpData {
+ // Which kernel type to use. Full kernel (20 inputs) or basic kernel
+ // (5 inputs).
+ TfLiteLSTMKernelType kernel_type;
+
+ // These fields are only used by full kernel.
+ int activation_state_tensor_index;
+ int cell_state_tensor_index;
+ int scratch_tensor_index;
+};
+
+// For the full kernel (20 inputs).
+namespace full {
+
// Input Tensors of size {n_batch, n_input}
constexpr int kInputTensor = 0;
@@ -65,26 +82,27 @@ constexpr int kProjectionWeightsTensor = 16; // Optional
// Projection bias tensor of size {n_output}
constexpr int kProjectionBiasTensor = 17; // Optional
+// These state tensors are defined as variable tensors, and will be modified by
+// this op.
+constexpr int kInputActivationStateTensor = 18;
+constexpr int kInputCellStateTensor = 19;
+
// Output tensors.
-constexpr int kOutputStateTensor = 0;
-constexpr int kCellStateTensor = 1;
-constexpr int kOutputTensor = 2;
+constexpr int kOutputTensor = 0;
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
- auto* scratch_tensor_index = new int;
- context->AddTensors(context, 1, scratch_tensor_index);
- return scratch_tensor_index;
-}
-
-void Free(TfLiteContext* context, void* buffer) {
- delete reinterpret_cast<int*>(buffer);
+ auto* op_data = new OpData();
+ op_data->kernel_type = kTfLiteLSTMFullKernel;
+ context->AddTensors(context, /*tensors_to_add=*/7,
+ &op_data->scratch_tensor_index);
+ return op_data;
}
// Check that input tensor dimensions matches with each other.
TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
TfLiteNode* node, int n_input,
int n_output, int n_cell) {
- auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);
+ const auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);
// Making sure clipping parameters have valid values.
// == 0 means no clipping
@@ -94,7 +112,7 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
const TfLiteTensor* input_to_input_weights =
GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
- if (input_to_input_weights) {
+ if (input_to_input_weights != nullptr) {
TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->size, 2);
TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[0], n_cell);
TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[1], n_input);
@@ -114,7 +132,7 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
const TfLiteTensor* recurrent_to_input_weights =
GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor);
- if (recurrent_to_input_weights) {
+ if (recurrent_to_input_weights != nullptr) {
TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->size, 2);
TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->data[0],
n_cell);
@@ -204,7 +222,7 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
const TfLiteTensor* projection_weights =
GetOptionalInputTensor(context, node, kProjectionWeightsTensor);
- if (projection_weights) {
+ if (projection_weights != nullptr) {
TF_LITE_ENSURE_EQ(context, projection_weights->dims->size, 2);
TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[0], n_output);
TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[1], n_cell);
@@ -212,7 +230,7 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
const TfLiteTensor* projection_bias =
GetOptionalInputTensor(context, node, kProjectionBiasTensor);
- if (projection_bias) {
+ if (projection_bias != nullptr) {
TF_LITE_ENSURE_EQ(context, projection_bias->dims->size, 1);
TF_LITE_ENSURE_EQ(context, projection_bias->dims->data[0], n_output);
}
@@ -233,15 +251,19 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
// Allocate a temporary scratch tensor. Also check that the sizes of the input
// tensors match each other.
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
- int* scratch_tensor_index = reinterpret_cast<int*>(node->user_data);
+ OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
+
+ TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
+ TF_LITE_ENSURE_EQ(context, node->inputs->size, 20);
- // Check we have all the inputs and outputs we need.
- TF_LITE_ENSURE_EQ(context, node->inputs->size, 18);
- TF_LITE_ENSURE_EQ(context, node->outputs->size, 3);
+ op_data->activation_state_tensor_index =
+ node->inputs->data[kInputActivationStateTensor];
+ op_data->cell_state_tensor_index = node->inputs->data[kInputCellStateTensor];
// Inferring batch size, number of outputs and number of cells from the
// input tensors.
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
TF_LITE_ENSURE(context, input->dims->size > 1);
const int n_batch = input->dims->data[0];
const int n_input = input->dims->data[1];
@@ -260,68 +282,153 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
const int n_output = recurrent_to_output_weights->dims->data[1];
// Check that input tensor dimensions matches with each other.
- CheckInputTensorDimensions(context, node, n_input, n_output, n_cell);
+ TF_LITE_ENSURE_OK(context, CheckInputTensorDimensions(context, node, n_input,
+ n_output, n_cell));
- // Get the pointer to output, output_state and cell_state tensors.
+ // Get the pointer to output, activation_state and cell_state tensors.
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
- TfLiteTensor* output_state = GetOutput(context, node, kOutputStateTensor);
- TfLiteTensor* cell_state = GetOutput(context, node, kCellStateTensor);
- // Resize the output, output_state and cell_state tensors.
+ TfLiteTensor* activation_state =
+ &context->tensors[op_data->activation_state_tensor_index];
+ TfLiteTensor* cell_state =
+ &context->tensors[op_data->cell_state_tensor_index];
+
+ // Check the shape of input state tensors.
+ // These tensors may be 1D or 2D. Either is fine as long as the total size is
+ // correct.
+ TF_LITE_ENSURE_EQ(context, NumElements(activation_state), n_batch * n_output);
+ TF_LITE_ENSURE_EQ(context, NumElements(cell_state), n_batch * n_cell);
+
+ // Resize the output tensors.
TfLiteIntArray* output_size = TfLiteIntArrayCreate(2);
output_size->data[0] = n_batch;
output_size->data[1] = n_output;
TF_LITE_ENSURE_OK(context,
context->ResizeTensor(context, output, output_size));
- TfLiteIntArray* output_state_size = TfLiteIntArrayCreate(2);
- output_state_size->data[0] = n_batch;
- output_state_size->data[1] = n_output;
- TF_LITE_ENSURE_OK(
- context, context->ResizeTensor(context, output_state, output_state_size));
+ // The weights are of consistent type, so it suffices to check one.
+ // TODO(mirkov): create a utility/macro for this check, so all Ops can use it.
+ const bool is_hybrid_op = (input_to_output_weights->type == kTfLiteUInt8 &&
+ input->type == kTfLiteFloat32);
- TfLiteIntArray* cell_size = TfLiteIntArrayCreate(2);
- cell_size->data[0] = n_batch;
- cell_size->data[1] = n_cell;
- TF_LITE_ENSURE_OK(context,
- context->ResizeTensor(context, cell_state, cell_size));
+ TfLiteIntArrayFree(node->temporaries);
+ if (is_hybrid_op) {
+ node->temporaries = TfLiteIntArrayCreate(7);
+ } else {
+ node->temporaries = TfLiteIntArrayCreate(1);
+ }
+ node->temporaries->data[0] = op_data->scratch_tensor_index;
// Create a scratch buffer tensor.
- TfLiteIntArrayFree(node->temporaries);
- node->temporaries = TfLiteIntArrayCreate(1);
- node->temporaries->data[0] = *scratch_tensor_index;
TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0);
scratch_buffer->type = input->type;
scratch_buffer->allocation_type = kTfLiteArenaRw;
- // Mark state tensors as persistent tensors.
- output_state->allocation_type = kTfLiteArenaRwPersistent;
- cell_state->allocation_type = kTfLiteArenaRwPersistent;
-
const TfLiteTensor* input_to_input_weights =
GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
const bool use_cifg = (input_to_input_weights == nullptr);
+ TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2);
+ scratch_buffer_size->data[0] = n_batch;
if (use_cifg) {
- TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2);
- scratch_buffer_size->data[0] = n_batch;
// Reserving space for Cell, Forget, Output gates
scratch_buffer_size->data[1] = n_cell * 3;
- TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_buffer,
- scratch_buffer_size));
} else {
- TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2);
- scratch_buffer_size->data[0] = n_batch;
// Reserving space for Input, Cell, Forget, Output gates
scratch_buffer_size->data[1] = n_cell * 4;
- TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_buffer,
- scratch_buffer_size));
+ }
+ TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_buffer,
+ scratch_buffer_size));
+
+ if (is_hybrid_op) {
+ // Allocate temporary tensors to store quantized values of input,
+ // activation_state and cell_state tensors.
+ node->temporaries->data[1] = op_data->scratch_tensor_index + 1;
+ TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/1);
+ input_quantized->type = kTfLiteUInt8;
+ input_quantized->allocation_type = kTfLiteArenaRw;
+ if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) {
+ TfLiteIntArray* input_quantized_size = TfLiteIntArrayCopy(input->dims);
+ TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized,
+ input_quantized_size));
+ }
+ node->temporaries->data[2] = op_data->scratch_tensor_index + 2;
+ TfLiteTensor* activation_state_quantized =
+ GetTemporary(context, node, /*index=*/2);
+ activation_state_quantized->type = kTfLiteUInt8;
+ activation_state_quantized->allocation_type = kTfLiteArenaRw;
+ if (!TfLiteIntArrayEqual(activation_state_quantized->dims,
+ activation_state->dims)) {
+ TfLiteIntArray* activation_state_quantized_size =
+ TfLiteIntArrayCopy(activation_state->dims);
+ TF_LITE_ENSURE_OK(
+ context, context->ResizeTensor(context, activation_state_quantized,
+ activation_state_quantized_size));
+ }
+ node->temporaries->data[3] = op_data->scratch_tensor_index + 3;
+ TfLiteTensor* cell_state_quantized =
+ GetTemporary(context, node, /*index=*/3);
+ cell_state_quantized->type = kTfLiteUInt8;
+ cell_state_quantized->allocation_type = kTfLiteArenaRw;
+ if (!TfLiteIntArrayEqual(cell_state_quantized->dims, cell_state->dims)) {
+ TfLiteIntArray* cell_state_quantized_size =
+ TfLiteIntArrayCopy(cell_state->dims);
+ TF_LITE_ENSURE_OK(context,
+ context->ResizeTensor(context, cell_state_quantized,
+ cell_state_quantized_size));
+ }
+
+ // Allocate temporary tensors to store scaling factors and product scaling
+ // factors. The latter is convenience storage that allows quantizing a
+ // vector once (which produces the scaling factors) and multiplying it with
+ // different matrices (which requires multiplying the scaling factors with
+ // the scaling factor of each matrix).
+ node->temporaries->data[4] = op_data->scratch_tensor_index + 4;
+ TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/4);
+ scaling_factors->type = kTfLiteFloat32;
+ scaling_factors->allocation_type = kTfLiteArenaRw;
+ TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1);
+ scaling_factors_size->data[0] = n_batch;
+ if (!TfLiteIntArrayEqual(scaling_factors->dims, scaling_factors_size)) {
+ TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors,
+ scaling_factors_size));
+ }
+ node->temporaries->data[5] = op_data->scratch_tensor_index + 5;
+ TfLiteTensor* prod_scaling_factors =
+ GetTemporary(context, node, /*index=*/5);
+ prod_scaling_factors->type = kTfLiteFloat32;
+ prod_scaling_factors->allocation_type = kTfLiteArenaRw;
+ TfLiteIntArray* prod_scaling_factors_size = TfLiteIntArrayCreate(1);
+ prod_scaling_factors_size->data[0] = n_batch;
+ if (!TfLiteIntArrayEqual(prod_scaling_factors->dims,
+ prod_scaling_factors_size)) {
+ TF_LITE_ENSURE_OK(context,
+ context->ResizeTensor(context, prod_scaling_factors,
+ prod_scaling_factors_size));
+ }
+
+ // Allocate a temporary tensor to store the recovered cell weights. Since
+ // this is used for diagonal matrices, only n_cell values need to be stored.
+ node->temporaries->data[6] = op_data->scratch_tensor_index + 6;
+ TfLiteTensor* recovered_cell_weights =
+ GetTemporary(context, node, /*index=*/6);
+ recovered_cell_weights->type = kTfLiteFloat32;
+ recovered_cell_weights->allocation_type = kTfLiteArenaRw;
+ TfLiteIntArray* recovered_cell_weights_size = TfLiteIntArrayCreate(1);
+ recovered_cell_weights_size->data[0] = n_cell;
+ if (!TfLiteIntArrayEqual(recovered_cell_weights->dims,
+ recovered_cell_weights_size)) {
+ TF_LITE_ENSURE_OK(context,
+ context->ResizeTensor(context, recovered_cell_weights,
+ recovered_cell_weights_size));
+ }
}
return kTfLiteOk;
}
-// The LSTM Op engine.
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
- auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);
+ const auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);
+ OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
+
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* input_to_input_weights =
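
The scaling-factor temporaries allocated in Prepare above exist because the hybrid path quantizes each row of the float input once, remembers the per-row scale, and then rescales every integer matrix-vector product by (input scale) x (weight scale). A toy sketch of that bookkeeping under those assumptions (SymmetricQuantize and HybridDot are illustrative helpers, not the tensor_utils routines):

#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// Symmetrically quantize a float vector to int8 and return the scale such
// that float_value ~= scale * int8_value.
float SymmetricQuantize(const std::vector<float>& v, std::vector<int8_t>* q) {
  float max_abs = 0.f;
  for (float x : v) max_abs = std::fmax(max_abs, std::fabs(x));
  q->resize(v.size());
  if (max_abs == 0.f) return 0.f;  // All-zero input quantizes to all zeros.
  const float scale = max_abs / 127.0f;
  for (std::size_t i = 0; i < v.size(); ++i) {
    (*q)[i] = static_cast<int8_t>(std::round(v[i] / scale));
  }
  return scale;
}

// A dot product in the quantized domain is rescaled by the product of the two
// scales -- the "prod_scaling_factors" idea.
float HybridDot(const std::vector<int8_t>& q_input, float input_scale,
                const std::vector<int8_t>& q_weights, float weight_scale) {
  int32_t acc = 0;
  for (std::size_t i = 0; i < q_input.size(); ++i) {
    acc += static_cast<int32_t>(q_input[i]) * static_cast<int32_t>(q_weights[i]);
  }
  return acc * (input_scale * weight_scale);
}
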
@@ -362,91 +469,303 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* projection_bias =
GetOptionalInputTensor(context, node, kProjectionBiasTensor);
- TfLiteTensor* output_state = GetOutput(context, node, kOutputStateTensor);
- TfLiteTensor* cell_state = GetOutput(context, node, kCellStateTensor);
+ // Index the scratch buffers pointers to the global scratch buffer.
+ TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0);
+
+ TfLiteTensor* activation_state =
+ &context->tensors[op_data->activation_state_tensor_index];
+ TfLiteTensor* cell_state =
+ &context->tensors[op_data->cell_state_tensor_index];
+
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
- const int n_batch = input->dims->data[0];
- const int n_input = input->dims->data[1];
- // n_cell and n_output will be the same size when there is no projection.
- const int n_cell = input_to_output_weights->dims->data[0];
- const int n_output = recurrent_to_output_weights->dims->data[1];
+ // TODO(mirkov): add a check that weights are all uint8s or all floats.
+ switch (input_to_output_weights->type) {
+ case kTfLiteFloat32: {
+ return lstm_eval::EvalFloat(
+ input, input_to_input_weights, input_to_forget_weights,
+ input_to_cell_weights, input_to_output_weights,
+ recurrent_to_input_weights, recurrent_to_forget_weights,
+ recurrent_to_cell_weights, recurrent_to_output_weights,
+ cell_to_input_weights, cell_to_forget_weights, cell_to_output_weights,
+ /*aux_input=*/nullptr,
+ /*aux_input_to_input_weights=*/nullptr,
+ /*aux_input_to_forget_weights=*/nullptr,
+ /*aux_input_to_cell_weights=*/nullptr,
+ /*aux_input_to_output_weights=*/nullptr, input_gate_bias,
+ forget_gate_bias, cell_bias, output_gate_bias, projection_weights,
+ projection_bias, params, /*forward_sequence=*/true,
+ /*output_offset=*/0, scratch_buffer, activation_state, cell_state,
+ output);
+ }
+ case kTfLiteUInt8: {
+ TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/1);
+ TfLiteTensor* activation_state_quantized =
+ GetTemporary(context, node, /*index=*/2);
+ TfLiteTensor* cell_state_quantized =
+ GetTemporary(context, node, /*index=*/3);
+ TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/4);
+ TfLiteTensor* prod_scaling_factors =
+ GetTemporary(context, node, /*index=*/5);
+ TfLiteTensor* recovered_cell_weights =
+ GetTemporary(context, node, /*index=*/6);
+ return lstm_eval::EvalHybrid(
+ input, input_to_input_weights, input_to_forget_weights,
+ input_to_cell_weights, input_to_output_weights,
+ recurrent_to_input_weights, recurrent_to_forget_weights,
+ recurrent_to_cell_weights, recurrent_to_output_weights,
+ cell_to_input_weights, cell_to_forget_weights, cell_to_output_weights,
+ /*aux_input=*/nullptr,
+ /*aux_input_to_input_weights=*/nullptr,
+ /*aux_input_to_forget_weights=*/nullptr,
+ /*aux_input_to_cell_weights=*/nullptr,
+ /*aux_input_to_output_weights=*/nullptr, input_gate_bias,
+ forget_gate_bias, cell_bias, output_gate_bias, projection_weights,
+ projection_bias, params, /*forward_sequence=*/true,
+ /*output_offset=*/0, scratch_buffer, scaling_factors,
+ prod_scaling_factors, recovered_cell_weights, input_quantized,
+ /*aux_input_quantized=*/nullptr, activation_state_quantized,
+ cell_state_quantized, activation_state, cell_state, output);
+ }
+ default:
+ context->ReportError(context, "Type %d is not currently supported.",
+ input_to_output_weights->type);
+ return kTfLiteError;
+ }
+ return kTfLiteOk;
+}
- // Since we have already checked that weights are all there or none, we can
- // check the existence of only one to get the condition.
- const bool use_cifg = (input_to_input_weights == nullptr);
- const bool use_peephole = (cell_to_output_weights != nullptr);
+} // namespace full
- // Index the scratch buffers pointers to the global scratch buffer.
- TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0);
+// For the basic kernel (5 inputs).
+namespace basic {
- float* input_gate_scratch = nullptr;
- float* cell_scratch = nullptr;
- float* forget_gate_scratch = nullptr;
- float* output_gate_scratch = nullptr;
- if (use_cifg) {
- cell_scratch = scratch_buffer->data.f;
- forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch;
- output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
+enum InputTensor {
+ kInputData = 0,
+ kInputPrevActivation = 1,
+ kInputWeights = 2,
+ kInputBiases = 3,
+ kInputPrevState = 4,
+ kInputNum = 5,
+};
+
+enum OutputTensor {
+ kOutputActivation = 0,
+ kOutputState = 1,
+ kOutputConcatTemp = 2,
+ kOutputActivationTemp = 3,
+ kOutputNum = 4,
+};
+
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+ auto* op_data = new OpData();
+ op_data->kernel_type = kTfLiteLSTMBasicKernel;
+ // `scratch_tensor_index` is unused in this kernel.
+ op_data->scratch_tensor_index = -1;
+ return op_data;
+}
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ TF_LITE_ENSURE(context, node->inputs->size == kInputNum);
+ TF_LITE_ENSURE(context, node->outputs->size == kOutputNum);
+
+ const TfLiteTensor* input = GetInput(context, node, kInputData);
+ const TfLiteTensor* prev_activation =
+ GetInput(context, node, kInputPrevActivation);
+ const TfLiteTensor* weights = GetInput(context, node, kInputWeights);
+ const TfLiteTensor* bias = GetInput(context, node, kInputBiases);
+ const TfLiteTensor* prev_state = GetInput(context, node, kInputPrevState);
+
+ TF_LITE_ENSURE_EQ(context, input->dims->size, 2);
+ const int num_batches = input->dims->data[0];
+ const int input_depth = input->dims->data[1];
+
+ TF_LITE_ENSURE_EQ(context, prev_activation->dims->size, 2);
+ TF_LITE_ENSURE_EQ(context, prev_activation->dims->data[0], num_batches);
+ const int activation_depth = prev_activation->dims->data[1];
+ const int total_depth = input_depth + activation_depth;
+
+ TF_LITE_ENSURE_EQ(context, weights->dims->size, 2);
+ TF_LITE_ENSURE_EQ(context, weights->dims->data[0], 4 * activation_depth);
+ TF_LITE_ENSURE_EQ(context, weights->dims->data[1], total_depth);
+
+ TF_LITE_ENSURE_EQ(context, bias->dims->size, 1);
+ TF_LITE_ENSURE_EQ(context, bias->dims->data[0], 4 * activation_depth);
+
+ TF_LITE_ENSURE_EQ(context, prev_state->dims->size, 2);
+ TF_LITE_ENSURE_EQ(context, prev_state->dims->data[0], num_batches);
+ TF_LITE_ENSURE_EQ(context, prev_state->dims->data[1], activation_depth);
+
+ TfLiteTensor* activation_out = GetOutput(context, node, kOutputActivation);
+ TfLiteTensor* state_out = GetOutput(context, node, kOutputState);
+ TfLiteTensor* concat_temp = GetOutput(context, node, kOutputConcatTemp);
+ TfLiteTensor* activation_temp =
+ GetOutput(context, node, kOutputActivationTemp);
+
+ TF_LITE_ENSURE_OK(context, context->ResizeTensor(
+ context, activation_out,
+ TfLiteIntArrayCopy(prev_activation->dims)));
+ TF_LITE_ENSURE_OK(
+ context, context->ResizeTensor(context, state_out,
+ TfLiteIntArrayCopy(prev_state->dims)));
+
+ TfLiteIntArray* concat_temp_size = TfLiteIntArrayCreate(2);
+ concat_temp_size->data[0] = num_batches;
+ concat_temp_size->data[1] = total_depth;
+ TF_LITE_ENSURE_OK(
+ context, context->ResizeTensor(context, concat_temp, concat_temp_size));
+ TfLiteIntArray* activation_temp_size = TfLiteIntArrayCreate(2);
+ activation_temp_size->data[0] = num_batches;
+ activation_temp_size->data[1] = 4 * activation_depth;
+ TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, activation_temp,
+ activation_temp_size));
+
+ // Set the state tensors as persistent.
+ for (auto index : {kInputPrevActivation, kInputPrevState}) {
+ TfLiteTensor* tensor = &context->tensors[node->inputs->data[index]];
+ tensor->allocation_type = kTfLiteArenaRwPersistent;
+ }
+ return kTfLiteOk;
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ const TfLiteTensor* input = GetInput(context, node, kInputData);
+ const TfLiteTensor* prev_activation =
+ GetInput(context, node, kInputPrevActivation);
+ const TfLiteTensor* weights = GetInput(context, node, kInputWeights);
+ const TfLiteTensor* bias = GetInput(context, node, kInputBiases);
+ const TfLiteTensor* prev_state = GetInput(context, node, kInputPrevState);
+
+ TfLiteTensor* activation_out = GetOutput(context, node, kOutputActivation);
+ TfLiteTensor* state_out = GetOutput(context, node, kOutputState);
+ TfLiteTensor* concat_temp = GetOutput(context, node, kOutputConcatTemp);
+ TfLiteTensor* activation_temp =
+ GetOutput(context, node, kOutputActivationTemp);
+
+ if (input->type == kTfLiteFloat32 &&
+ prev_activation->type == kTfLiteFloat32 &&
+ weights->type == kTfLiteFloat32 && bias->type == kTfLiteFloat32 &&
+ prev_state->type == kTfLiteFloat32 && state_out->type == kTfLiteFloat32 &&
+ activation_out->type == kTfLiteFloat32 &&
+ concat_temp->type == kTfLiteFloat32 &&
+ activation_temp->type == kTfLiteFloat32) {
+ tflite::LstmCellParams op_params;
+ // The float LSTM cell needs no parameters to be set; leave them untouched.
+ optimized_ops::LstmCell(
+ op_params,
+ // Inputs.
+ GetTensorShape(input), GetTensorData<float>(input),
+ GetTensorShape(prev_activation), GetTensorData<float>(prev_activation),
+ GetTensorShape(weights), GetTensorData<float>(weights),
+ GetTensorShape(bias), GetTensorData<float>(bias),
+ GetTensorShape(prev_state), GetTensorData<float>(prev_state),
+ // Outputs.
+ GetTensorShape(state_out), GetTensorData<float>(state_out),
+ GetTensorShape(activation_out), GetTensorData<float>(activation_out),
+ GetTensorShape(concat_temp), GetTensorData<float>(concat_temp),
+ GetTensorShape(activation_temp), GetTensorData<float>(activation_temp));
+ } else if (input->type == kTfLiteUInt8 &&
+ prev_activation->type == kTfLiteUInt8 &&
+ weights->type == kTfLiteUInt8 && bias->type == kTfLiteInt32 &&
+ prev_state->type == kTfLiteInt16 &&
+ state_out->type == kTfLiteInt16 &&
+ activation_out->type == kTfLiteUInt8 &&
+ concat_temp->type == kTfLiteUInt8 &&
+ activation_temp->type == kTfLiteInt16) {
+ gemmlowp::GemmContext* gemm_context = gemm_support::GetFromContext(context);
+ int state_scale_log2_rounded;
+ if (!CheckedLog2(state_out->params.scale, &state_scale_log2_rounded)) {
+ context->ReportError(
+ context,
+ "The internal state of a LSTM cell must have a power-of-two scale.");
+ return kTfLiteError;
+ }
+ const int state_integer_bits = 15 + state_scale_log2_rounded;
+ if (state_integer_bits != 4) {
+ context->ReportError(context,
+ "The only case of quantized LstmCell currently "
+ "supported is with StateIntegerBits==4");
+ return kTfLiteError;
+ }
+
+ double real_accum_multiplier = 4096 * bias->params.scale;
+ int32 accum_multiplier;
+ int accum_shift;
+ tflite::QuantizeMultiplier(real_accum_multiplier, &accum_multiplier,
+ &accum_shift);
+ tflite::LstmCellParams op_params;
+ op_params.weights_zero_point = weights->params.zero_point;
+ op_params.accum_multiplier = accum_multiplier;
+ op_params.accum_shift = accum_shift;
+ optimized_ops::LstmCell<4>(
+ op_params,
+ // Inputs.
+ GetTensorShape(input), GetTensorData<uint8_t>(input),
+ GetTensorShape(prev_activation),
+ GetTensorData<uint8_t>(prev_activation), GetTensorShape(weights),
+ GetTensorData<uint8_t>(weights), GetTensorShape(bias),
+ GetTensorData<int32_t>(bias), GetTensorShape(prev_state),
+ GetTensorData<int16_t>(prev_state),
+ // Outputs.
+ GetTensorShape(state_out), GetTensorData<int16_t>(state_out),
+ GetTensorShape(activation_out), GetTensorData<uint8_t>(activation_out),
+ GetTensorShape(concat_temp), GetTensorData<uint8_t>(concat_temp),
+ GetTensorShape(activation_temp),
+ GetTensorData<int16_t>(activation_temp), gemm_context);
} else {
- input_gate_scratch = scratch_buffer->data.f;
- cell_scratch = scratch_buffer->data.f + n_cell * n_batch;
- forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
- output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch;
+ context->ReportError(context,
+ "Unsupported combination of data types for LstmCell");
+ return kTfLiteError;
}
- // Check optional tensors, the respective pointers can be null.
- const float* input_to_input_weights_ptr =
- (use_cifg) ? nullptr : input_to_input_weights->data.f;
- const float* recurrent_to_input_weights_ptr =
- (use_cifg) ? nullptr : recurrent_to_input_weights->data.f;
- const float* input_gate_bias_ptr =
- (use_cifg) ? nullptr : input_gate_bias->data.f;
- const float* cell_to_input_weights_ptr =
- (use_peephole && !use_cifg) ? cell_to_input_weights->data.f : nullptr;
- const float* cell_to_forget_weights_ptr =
- (use_peephole) ? cell_to_forget_weights->data.f : nullptr;
- const float* cell_to_output_weights_ptr =
- (use_peephole) ? cell_to_output_weights->data.f : nullptr;
- const float* projection_weights_ptr =
- (projection_weights == nullptr) ? nullptr : projection_weights->data.f;
- const float* projection_bias_ptr =
- (projection_bias == nullptr) ? nullptr : projection_bias->data.f;
-
- // Required tensors, pointers are non-null.
- const float* input_ptr_batch = input->data.f;
- const float* input_to_forget_weights_ptr = input_to_forget_weights->data.f;
- const float* input_to_cell_weights_ptr = input_to_cell_weights->data.f;
- const float* input_to_output_weights_ptr = input_to_output_weights->data.f;
- const float* recurrent_to_forget_weights_ptr =
- recurrent_to_forget_weights->data.f;
- const float* recurrent_to_cell_weights_ptr =
- recurrent_to_cell_weights->data.f;
- const float* recurrent_to_output_weights_ptr =
- recurrent_to_output_weights->data.f;
- const float* forget_gate_bias_ptr = forget_gate_bias->data.f;
- const float* cell_bias_ptr = cell_bias->data.f;
- const float* output_gate_bias_ptr = output_gate_bias->data.f;
-
- float* output_state_ptr = output_state->data.f;
- float* cell_state_ptr = cell_state->data.f;
- float* output_ptr_batch = output->data.f;
-
- kernel_utils::LstmStep(
- input_ptr_batch, input_to_input_weights_ptr, input_to_forget_weights_ptr,
- input_to_cell_weights_ptr, input_to_output_weights_ptr,
- recurrent_to_input_weights_ptr, recurrent_to_forget_weights_ptr,
- recurrent_to_cell_weights_ptr, recurrent_to_output_weights_ptr,
- cell_to_input_weights_ptr, cell_to_forget_weights_ptr,
- cell_to_output_weights_ptr, input_gate_bias_ptr, forget_gate_bias_ptr,
- cell_bias_ptr, output_gate_bias_ptr, projection_weights_ptr,
- projection_bias_ptr, params, n_batch, n_cell, n_input, n_output,
- output_state_ptr, cell_state_ptr, input_gate_scratch, forget_gate_scratch,
- cell_scratch, output_gate_scratch, output_ptr_batch);
+ // TODO(ycling): Investigate whether this copy can be avoided with the
+ // 5-input LSTM kernel.
+ memcpy(prev_activation->data.raw, activation_out->data.raw,
+ activation_out->bytes);
+ memcpy(prev_state->data.raw, state_out->data.raw, state_out->bytes);
return kTfLiteOk;
}
+} // namespace basic
+
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+ gemm_support::IncrementUsageCounter(context);
+
+ const auto* params = reinterpret_cast<const TfLiteLSTMParams*>(buffer);
+ switch (params->kernel_type) {
+ case kTfLiteLSTMFullKernel:
+ return full::Init(context, buffer, length);
+ case kTfLiteLSTMBasicKernel:
+ return basic::Init(context, buffer, length);
+ }
+}
+void Free(TfLiteContext* context, void* buffer) {
+ gemm_support::DecrementUsageCounter(context);
+
+ delete reinterpret_cast<OpData*>(buffer);
+}
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ const auto* op_data = reinterpret_cast<const OpData*>(node->user_data);
+ switch (op_data->kernel_type) {
+ case kTfLiteLSTMFullKernel:
+ return full::Prepare(context, node);
+ case kTfLiteLSTMBasicKernel:
+ return basic::Prepare(context, node);
+ }
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ const auto* op_data = reinterpret_cast<const OpData*>(node->user_data);
+ switch (op_data->kernel_type) {
+ case kTfLiteLSTMFullKernel:
+ return full::Eval(context, node);
+ case kTfLiteLSTMBasicKernel:
+ return basic::Eval(context, node);
+ }
+}
+
} // namespace lstm
TfLiteRegistration* Register_LSTM() {
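
The quantized basic-kernel path above only accepts a cell-state scale that is an exact power of two with 15 + log2(scale) == 4, i.e. a scale of 2^-11. A hedged sketch of that arithmetic (IsPowerOfTwoScale is illustrative, not the kernel's CheckedLog2):

#include <cmath>
#include <cstdio>

// Returns true if scale is (numerically) an exact power of two and stores the
// rounded exponent, mimicking the power-of-two requirement in the kernel.
bool IsPowerOfTwoScale(double scale, int* log2_result) {
  const double log2_value = std::log2(scale);
  *log2_result = static_cast<int>(std::round(log2_value));
  return std::fabs(log2_value - *log2_result) < 1e-6;
}

int main() {
  int exponent = 0;
  const double state_scale = std::pow(2.0, -11);  // Candidate int16 state scale.
  if (IsPowerOfTwoScale(state_scale, &exponent)) {
    std::printf("state_integer_bits = %d\n", 15 + exponent);  // prints 4
  }
  return 0;
}
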
diff --git a/tensorflow/contrib/lite/kernels/lstm_eval.cc b/tensorflow/contrib/lite/kernels/lstm_eval.cc
new file mode 100644
index 0000000000..20a4e30009
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/lstm_eval.cc
@@ -0,0 +1,912 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/kernels/lstm_eval.h"
+
+#include <stdint.h>
+
+#include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace lstm_eval {
+
+namespace {
+
+// Performs an LSTM batch inference step for input specified by input_ptr_batch.
+// The LSTM cell is specified by the pointers to its weights (*_weights_ptr) and
+// biases (*_bias_ptr), and buffers (*_scratch), along with additional
+// parameters:
+// - params: various LSTM params including activation, clipping, etc.,
+// - n_batch: size of batch,
+// - n_cell: number of cells (or units),
+// - n_input: the input size,
+// - n_output: the output size.
+//
+// The pointers to the cell and output state and the output are updated.
+//
+// The pointers with the suffix "_batch" point to data aligned in batch-major
+// order, and each step processes batch_size inputs from input_ptr_batch and
+// updates batch_size cell and output states.
+inline void LstmStepWithAuxInput(
+ const float* input_ptr_batch, const float* input_to_input_weights_ptr,
+ const float* input_to_forget_weights_ptr,
+ const float* input_to_cell_weights_ptr,
+ const float* input_to_output_weights_ptr, const float* aux_input_ptr_batch,
+ const float* aux_input_to_input_weights_ptr,
+ const float* aux_input_to_forget_weights_ptr,
+ const float* aux_input_to_cell_weights_ptr,
+ const float* aux_input_to_output_weights_ptr,
+ const float* recurrent_to_input_weights_ptr,
+ const float* recurrent_to_forget_weights_ptr,
+ const float* recurrent_to_cell_weights_ptr,
+ const float* recurrent_to_output_weights_ptr,
+ const float* cell_to_input_weights_ptr,
+ const float* cell_to_forget_weights_ptr,
+ const float* cell_to_output_weights_ptr, const float* input_gate_bias_ptr,
+ const float* forget_gate_bias_ptr, const float* cell_bias_ptr,
+ const float* output_gate_bias_ptr, const float* projection_weights_ptr,
+ const float* projection_bias_ptr, const TfLiteLSTMParams* params,
+ int n_batch, int n_cell, int n_input, int n_aux_input, int n_output,
+ float* output_state_ptr, float* cell_state_ptr, float* input_gate_scratch,
+ float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch,
+ float* output_ptr_batch) {
+ // Since we have already checked that weights are all there or none, we can
+ // check the existence of only one to get the condition.
+ const bool use_cifg = (input_to_input_weights_ptr == nullptr);
+ const bool use_peephole = (cell_to_output_weights_ptr != nullptr);
+ // Initialize scratch buffers with bias.
+ if (!use_cifg) {
+ tensor_utils::VectorBatchVectorAssign(input_gate_bias_ptr, n_cell, n_batch,
+ input_gate_scratch);
+ }
+ tensor_utils::VectorBatchVectorAssign(forget_gate_bias_ptr, n_cell, n_batch,
+ forget_gate_scratch);
+ tensor_utils::VectorBatchVectorAssign(cell_bias_ptr, n_cell, n_batch,
+ cell_scratch);
+ tensor_utils::VectorBatchVectorAssign(output_gate_bias_ptr, n_cell, n_batch,
+ output_gate_scratch);
+
+ // For each batch and cell: compute input_weight * input.
+ if (!use_cifg) {
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ input_to_input_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch,
+ input_gate_scratch, /*result_stride=*/1);
+ }
+
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ input_to_forget_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch,
+ forget_gate_scratch, /*result_stride=*/1);
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ input_to_cell_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch,
+ cell_scratch, /*result_stride=*/1);
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ input_to_output_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch,
+ output_gate_scratch, /*result_stride=*/1);
+
+  // If auxiliary input is available, compute aux_input_weight * aux_input.
+ if (aux_input_ptr_batch != nullptr) {
+ if (!use_cifg) {
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ aux_input_to_input_weights_ptr, n_cell, n_aux_input,
+ aux_input_ptr_batch, n_batch, input_gate_scratch,
+ /*result_stride=*/1);
+ }
+
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ aux_input_to_forget_weights_ptr, n_cell, n_aux_input,
+ aux_input_ptr_batch, n_batch, forget_gate_scratch, /*result_stride=*/1);
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ aux_input_to_cell_weights_ptr, n_cell, n_aux_input, aux_input_ptr_batch,
+ n_batch, cell_scratch, /*result_stride=*/1);
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ aux_input_to_output_weights_ptr, n_cell, n_aux_input,
+ aux_input_ptr_batch, n_batch, output_gate_scratch, /*result_stride=*/1);
+ }
+
+ // For each batch and cell: compute recurrent_weight * output_state.
+ if (!use_cifg) {
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ recurrent_to_input_weights_ptr, n_cell, n_output, output_state_ptr,
+ n_batch, input_gate_scratch, /*result_stride=*/1);
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ recurrent_to_forget_weights_ptr, n_cell, n_output, output_state_ptr,
+ n_batch, forget_gate_scratch,
+ /*result_stride=*/1);
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ recurrent_to_cell_weights_ptr, n_cell, n_output, output_state_ptr,
+ n_batch, cell_scratch, /*result_stride=*/1);
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ recurrent_to_output_weights_ptr, n_cell, n_output, output_state_ptr,
+ n_batch, output_gate_scratch,
+ /*result_stride=*/1);
+
+ // For each batch and cell: update input gate.
+ if (!use_cifg) {
+ if (use_peephole) {
+ tensor_utils::VectorBatchVectorCwiseProductAccumulate(
+ cell_to_input_weights_ptr, n_cell, cell_state_ptr, n_batch,
+ input_gate_scratch);
+ }
+ tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch,
+ input_gate_scratch);
+ }
+
+ // For each batch and cell: update forget gate.
+ if (use_peephole) {
+ tensor_utils::VectorBatchVectorCwiseProductAccumulate(
+ cell_to_forget_weights_ptr, n_cell, cell_state_ptr, n_batch,
+ forget_gate_scratch);
+ }
+ tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch,
+ forget_gate_scratch);
+
+ // For each batch and cell: update the cell.
+ tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, cell_state_ptr,
+ n_batch * n_cell, cell_state_ptr);
+ tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell,
+ params->activation, cell_scratch);
+ if (use_cifg) {
+ tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell,
+ forget_gate_scratch);
+ tensor_utils::VectorVectorCwiseProductAccumulate(
+ cell_scratch, forget_gate_scratch, n_batch * n_cell, cell_state_ptr);
+ } else {
+ tensor_utils::VectorVectorCwiseProductAccumulate(
+ cell_scratch, input_gate_scratch, n_batch * n_cell, cell_state_ptr);
+ }
+ if (params->cell_clip > 0.0) {
+ tensor_utils::ClipVector(cell_state_ptr, n_batch * n_cell,
+ params->cell_clip, cell_state_ptr);
+ }
+
+ // For each batch and cell: update the output gate.
+ if (use_peephole) {
+ tensor_utils::VectorBatchVectorCwiseProductAccumulate(
+ cell_to_output_weights_ptr, n_cell, cell_state_ptr, n_batch,
+ output_gate_scratch);
+ }
+ tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell,
+ output_gate_scratch);
+ tensor_utils::ApplyActivationToVector(cell_state_ptr, n_batch * n_cell,
+ params->activation, cell_scratch);
+ tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch,
+ n_batch * n_cell, output_gate_scratch);
+
+ // For each batch: update the projection and output_state.
+ const bool use_projection_weight = (projection_weights_ptr != nullptr);
+ const bool use_projection_bias = (projection_bias_ptr != nullptr);
+ if (use_projection_weight) {
+ if (use_projection_bias) {
+ tensor_utils::VectorBatchVectorAssign(projection_bias_ptr, n_output,
+ n_batch, output_ptr_batch);
+ } else {
+ tensor_utils::ZeroVector(output_ptr_batch, n_batch * n_output);
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ projection_weights_ptr, n_output, n_cell, output_gate_scratch, n_batch,
+ output_ptr_batch, /*result_stride=*/1);
+ if (params->proj_clip > 0.0) {
+ tensor_utils::ClipVector(output_ptr_batch, n_batch * n_output,
+ params->proj_clip, output_ptr_batch);
+ }
+ } else {
+ tensor_utils::CopyVector(output_gate_scratch, n_batch * n_output,
+ output_ptr_batch);
+ }
+ tensor_utils::CopyVector(output_ptr_batch, n_batch * n_output,
+ output_state_ptr);
+}
+
+// Same as above but with quantized weight matrices. In detail:
+// Input of size 'n_batch * n_input':
+// input_ptr_batch
+//
+// LSTM weights:
+// Quantized input weights of size 'n_cell * n_input':
+// input_to_input_weights - optional (can be nullptr)
+// input_to_forget_weights
+// input_to_cell_weights
+// input_to_output_weights
+// Quantized recurrent weights of size 'n_cell * n_output':
+// recurrent_to_input_weights - optional
+// recurrent_to_forget_weights
+// recurrent_to_cell_weights
+// recurrent_to_output_weights
+// Quantized peephole weights of size 'n_cell', representing diagonal matrices.
+// cell_to_input_weights - optional
+// cell_to_forget_weights - optional
+// cell_to_output_weights - optional
+// Quantized projection weights of size 'n_output * n_cell'
+// projection_weights_ptr - optional
+// Weight scales (scalars) for each of the weights above.
+// input_to_input_weights_scale - optional
+// input_to_forget_weights_scale
+// input_to_cell_weights_scale
+// input_to_output_weights_scale
+// recurrent_to_input_weights_scale - optional
+// recurrent_to_forget_weights_scale
+// recurrent_to_cell_weights_scale
+// recurrent_to_output_weights_scale
+// cell_to_input_weights_scale
+// cell_to_forget_weights_scale
+// cell_to_output_weights_scale
+// projection_weights_scale - optional
+// Gate biases of size 'n_cell':
+// input_gate_bias_ptr - optional
+// forget_gate_bias_ptr
+// cell_bias_ptr
+// output_gate_bias_ptr
+//
+// Temporary pre-allocated storage for quantized values:
+// quantized_input_ptr_batch (same size as input_ptr_batch)
+// quantized_output_state_ptr (same size as output_state_ptr)
+// quantized_cell_state_ptr (same size as cell_state_ptr)
+// Temporary pre-allocated storage for recovered values:
+// recovered_cell_weights (same size as cell_to_*_weights)
+//
+// Outputs:
+// output_state_ptr - size 'n_batch * n_output'
+// cell_state_ptr - size 'n_batch * n_cell'
+// output_ptr_batch - size 'n_batch * n_output'
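+//
+// Hybrid arithmetic, as a hedged sketch (shorthand only, not code from this
+// file): each float row x_b of a batch is symmetrically quantized to int8,
+//   s_b = max(|x_b|) / 127,   q_b = round(x_b / s_b),
+// and every weight/input product is accumulated in float as
+//   scratch_b += (W_q * q_b) * (s_b * weight_scale),
+// which is what the scaling_factors and product_scaling_factors buffers hold.
+// An illustrative stand-alone version of the per-row quantization (assumed to
+// match tensor_utils::SymmetricQuantizeFloats; it skips the all-zero case,
+// which the code below never quantizes, and needs <algorithm>, <cmath> and
+// <cstdint>) might look like:
+//
+//   void QuantizeRow(const float* x, int n, int8_t* q, float* scale) {
+//     float max_abs = 0.0f;
+//     for (int i = 0; i < n; ++i) max_abs = std::max(max_abs, std::abs(x[i]));
+//     *scale = max_abs / 127.0f;
+//     for (int i = 0; i < n; ++i) {
+//       q[i] = static_cast<int8_t>(std::round(x[i] / *scale));
+//     }
+//   }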
+inline void LstmStepWithAuxInput(
+ const float* input_ptr_batch, const int8_t* input_to_input_weights_ptr,
+ float input_to_input_weights_scale,
+ const int8_t* input_to_forget_weights_ptr,
+ float input_to_forget_weights_scale,
+ const int8_t* input_to_cell_weights_ptr, float input_to_cell_weights_scale,
+ const int8_t* input_to_output_weights_ptr,
+ float input_to_output_weights_scale, const float* aux_input_ptr_batch,
+ const int8_t* aux_input_to_input_weights_ptr,
+ float aux_input_to_input_weights_scale,
+ const int8_t* aux_input_to_forget_weights_ptr,
+ float aux_input_to_forget_weights_scale,
+ const int8_t* aux_input_to_cell_weights_ptr,
+ float aux_input_to_cell_weights_scale,
+ const int8_t* aux_input_to_output_weights_ptr,
+ float aux_input_to_output_weights_scale,
+ const int8_t* recurrent_to_input_weights_ptr,
+ float recurrent_to_input_weights_scale,
+ const int8_t* recurrent_to_forget_weights_ptr,
+ float recurrent_to_forget_weights_scale,
+ const int8_t* recurrent_to_cell_weights_ptr,
+ float recurrent_to_cell_weights_scale,
+ const int8_t* recurrent_to_output_weights_ptr,
+ float recurrent_to_output_weights_scale,
+ const int8_t* cell_to_input_weights_ptr, float cell_to_input_weights_scale,
+ const int8_t* cell_to_forget_weights_ptr,
+ float cell_to_forget_weights_scale,
+ const int8_t* cell_to_output_weights_ptr,
+ float cell_to_output_weights_scale, const float* input_gate_bias_ptr,
+ const float* forget_gate_bias_ptr, const float* cell_bias_ptr,
+ const float* output_gate_bias_ptr, const int8_t* projection_weights_ptr,
+ float projection_weights_scale, const float* projection_bias_ptr,
+ const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input,
+ int n_aux_input, int n_output, float* input_gate_scratch,
+ float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch,
+ float* scaling_factors, float* product_scaling_factors,
+ float* recovered_cell_weights, int8_t* quantized_input_ptr_batch,
+ int8_t* quantized_aux_input_ptr_batch, int8_t* quantized_output_state_ptr,
+ int8_t* quantized_cell_state_ptr, float* output_state_ptr,
+ float* cell_state_ptr, float* output_ptr_batch) {
+ // Since we have already checked that weights are all there or none, we
+  // can check the existence of only one to get the condition.
+ const bool use_cifg = (input_to_input_weights_ptr == nullptr);
+ const bool use_peephole = (cell_to_output_weights_ptr != nullptr);
+ // Initialize scratch buffers with bias.
+ if (!use_cifg) {
+ tensor_utils::VectorBatchVectorAssign(input_gate_bias_ptr, n_cell, n_batch,
+ input_gate_scratch);
+ }
+ tensor_utils::VectorBatchVectorAssign(forget_gate_bias_ptr, n_cell, n_batch,
+ forget_gate_scratch);
+ tensor_utils::VectorBatchVectorAssign(cell_bias_ptr, n_cell, n_batch,
+ cell_scratch);
+ tensor_utils::VectorBatchVectorAssign(output_gate_bias_ptr, n_cell, n_batch,
+ output_gate_scratch);
+
+ if (!tensor_utils::IsZeroVector(input_ptr_batch, n_batch * n_input)) {
+    // Quantization and the matmuls below are skipped when the input is all
+    // zeros, which saves that work entirely.
+ float unused_min, unused_max;
+ for (int b = 0; b < n_batch; ++b) {
+ const int offset = b * n_input;
+ tensor_utils::SymmetricQuantizeFloats(
+ input_ptr_batch + offset, n_input, quantized_input_ptr_batch + offset,
+ &unused_min, &unused_max, &scaling_factors[b]);
+ }
+ // For each batch and cell: compute input_weight * input.
+ if (!use_cifg) {
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * input_to_input_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ input_to_input_weights_ptr, n_cell, n_input,
+ quantized_input_ptr_batch, product_scaling_factors, n_batch,
+ input_gate_scratch, /*result_stride=*/1);
+ }
+
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * input_to_forget_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ input_to_forget_weights_ptr, n_cell, n_input, quantized_input_ptr_batch,
+ product_scaling_factors, n_batch, forget_gate_scratch,
+ /*result_stride=*/1);
+
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * input_to_cell_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ input_to_cell_weights_ptr, n_cell, n_input, quantized_input_ptr_batch,
+ product_scaling_factors, n_batch, cell_scratch, /*result_stride=*/1);
+
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * input_to_output_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ input_to_output_weights_ptr, n_cell, n_input, quantized_input_ptr_batch,
+ product_scaling_factors, n_batch, output_gate_scratch,
+ /*result_stride=*/1);
+ }
+
+ if (aux_input_ptr_batch != nullptr &&
+ !tensor_utils::IsZeroVector(aux_input_ptr_batch, n_batch * n_input)) {
+    // Quantization and the matmuls below are skipped when the auxiliary input
+    // is all zeros.
+ float unused_min, unused_max;
+ for (int b = 0; b < n_batch; ++b) {
+ const int offset = b * n_input;
+ tensor_utils::SymmetricQuantizeFloats(
+ aux_input_ptr_batch + offset, n_input,
+ quantized_aux_input_ptr_batch + offset, &unused_min, &unused_max,
+ &scaling_factors[b]);
+ }
+ // For each batch and cell: compute input_weight * input.
+ if (!use_cifg) {
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * aux_input_to_input_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ aux_input_to_input_weights_ptr, n_cell, n_input,
+ quantized_aux_input_ptr_batch, product_scaling_factors, n_batch,
+ input_gate_scratch, /*result_stride=*/1);
+ }
+
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * aux_input_to_forget_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ aux_input_to_forget_weights_ptr, n_cell, n_input,
+ quantized_aux_input_ptr_batch, product_scaling_factors, n_batch,
+ forget_gate_scratch, /*result_stride=*/1);
+
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * aux_input_to_cell_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ aux_input_to_cell_weights_ptr, n_cell, n_input,
+ quantized_aux_input_ptr_batch, product_scaling_factors, n_batch,
+ cell_scratch, /*result_stride=*/1);
+
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * aux_input_to_output_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ aux_input_to_output_weights_ptr, n_cell, n_input,
+ quantized_aux_input_ptr_batch, product_scaling_factors, n_batch,
+ output_gate_scratch, /*result_stride=*/1);
+ }
+
+ if (!tensor_utils::IsZeroVector(output_state_ptr, n_batch * n_output)) {
+    // Quantization and the matmuls below are skipped when the output state is
+    // all zeros.
+ float unused_min, unused_max;
+ for (int b = 0; b < n_batch; ++b) {
+ const int offset = b * n_output;
+ tensor_utils::SymmetricQuantizeFloats(output_state_ptr + offset, n_output,
+ quantized_output_state_ptr + offset,
+ &unused_min, &unused_max,
+ &scaling_factors[b]);
+ }
+ // For each batch and cell: compute recurrent_weight * output_state.
+ if (!use_cifg) {
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * recurrent_to_input_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ recurrent_to_input_weights_ptr, n_cell, n_output,
+ quantized_output_state_ptr, product_scaling_factors, n_batch,
+ input_gate_scratch, /*result_stride=*/1);
+ }
+
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * recurrent_to_forget_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ recurrent_to_forget_weights_ptr, n_cell, n_output,
+ quantized_output_state_ptr, product_scaling_factors, n_batch,
+ forget_gate_scratch, /*result_stride=*/1);
+
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * recurrent_to_cell_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ recurrent_to_cell_weights_ptr, n_cell, n_output,
+ quantized_output_state_ptr, product_scaling_factors, n_batch,
+ cell_scratch, /*result_stride=*/1);
+
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * recurrent_to_output_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ recurrent_to_output_weights_ptr, n_cell, n_output,
+ quantized_output_state_ptr, product_scaling_factors, n_batch,
+ output_gate_scratch, /*result_stride=*/1);
+ }
+
+  // The peephole contributions below are skipped when the cell state is all
+  // zeros.
+ bool is_cell_state_all_zeros =
+ tensor_utils::IsZeroVector(cell_state_ptr, n_batch * n_cell);
+
+ // For each batch and cell: update input gate.
+ if (!use_cifg) {
+ if (use_peephole && !is_cell_state_all_zeros) {
+ tensor_utils::VectorScalarMultiply(cell_to_input_weights_ptr, n_cell,
+ cell_to_input_weights_scale,
+ recovered_cell_weights);
+ tensor_utils::VectorBatchVectorCwiseProductAccumulate(
+ recovered_cell_weights, n_cell, cell_state_ptr, n_batch,
+ input_gate_scratch);
+ }
+ tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch,
+ input_gate_scratch);
+ }
+
+ // For each batch and cell: update forget gate.
+ if (use_peephole && !is_cell_state_all_zeros) {
+ tensor_utils::VectorScalarMultiply(cell_to_forget_weights_ptr, n_cell,
+ cell_to_forget_weights_scale,
+ recovered_cell_weights);
+ tensor_utils::VectorBatchVectorCwiseProductAccumulate(
+ recovered_cell_weights, n_cell, cell_state_ptr, n_batch,
+ forget_gate_scratch);
+ }
+ tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch,
+ forget_gate_scratch);
+
+ // For each batch and cell: update the cell.
+ tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, cell_state_ptr,
+ n_batch * n_cell, cell_state_ptr);
+ tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell,
+ params->activation, cell_scratch);
+ if (use_cifg) {
+ tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell,
+ forget_gate_scratch);
+ tensor_utils::VectorVectorCwiseProductAccumulate(
+ cell_scratch, forget_gate_scratch, n_batch * n_cell, cell_state_ptr);
+ } else {
+ tensor_utils::VectorVectorCwiseProductAccumulate(
+ cell_scratch, input_gate_scratch, n_batch * n_cell, cell_state_ptr);
+ }
+ if (params->cell_clip > 0.0) {
+ tensor_utils::ClipVector(cell_state_ptr, n_batch * n_cell,
+ params->cell_clip, cell_state_ptr);
+ }
+
+ is_cell_state_all_zeros =
+ tensor_utils::IsZeroVector(cell_state_ptr, n_batch * n_cell);
+ // For each batch and cell: update the output gate.
+ if (use_peephole && !is_cell_state_all_zeros) {
+ tensor_utils::VectorScalarMultiply(cell_to_output_weights_ptr, n_cell,
+ cell_to_output_weights_scale,
+ recovered_cell_weights);
+ tensor_utils::VectorBatchVectorCwiseProductAccumulate(
+ recovered_cell_weights, n_cell, cell_state_ptr, n_batch,
+ output_gate_scratch);
+ }
+ tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell,
+ output_gate_scratch);
+ tensor_utils::ApplyActivationToVector(cell_state_ptr, n_batch * n_cell,
+ params->activation, cell_scratch);
+ tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch,
+ n_batch * n_cell, output_gate_scratch);
+
+ // For each batch: update the projection and output_state.
+ const bool use_projection_weight = (projection_weights_ptr != nullptr);
+ const bool use_projection_bias = (projection_bias_ptr != nullptr);
+ if (use_projection_weight) {
+ if (use_projection_bias) {
+ tensor_utils::VectorBatchVectorAssign(projection_bias_ptr, n_output,
+ n_batch, output_ptr_batch);
+ } else {
+ tensor_utils::ZeroVector(output_ptr_batch, n_batch * n_output);
+ }
+ if (!tensor_utils::IsZeroVector(output_gate_scratch, n_batch * n_cell)) {
+      // Quantization and the projection matmul are skipped when the gated
+      // output is all zeros.
+ float unused_min, unused_max;
+ for (int b = 0; b < n_batch; ++b) {
+ const int offset = b * n_cell;
+ tensor_utils::SymmetricQuantizeFloats(
+ output_gate_scratch + offset, n_cell,
+ quantized_cell_state_ptr + offset, &unused_min, &unused_max,
+ &scaling_factors[b]);
+ }
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * projection_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ projection_weights_ptr, n_output, n_cell, quantized_cell_state_ptr,
+ product_scaling_factors, n_batch, output_ptr_batch,
+ /*result_stride=*/1);
+ }
+ if (params->proj_clip > 0.0) {
+ tensor_utils::ClipVector(output_ptr_batch, n_batch * n_output,
+ params->proj_clip, output_ptr_batch);
+ }
+ } else {
+ tensor_utils::CopyVector(output_gate_scratch, n_batch * n_output,
+ output_ptr_batch);
+ }
+ tensor_utils::CopyVector(output_ptr_batch, n_batch * n_output,
+ output_state_ptr);
+}
+} // namespace
+
+TfLiteStatus EvalFloat(
+ const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
+ const TfLiteTensor* input_to_forget_weights,
+ const TfLiteTensor* input_to_cell_weights,
+ const TfLiteTensor* input_to_output_weights,
+ const TfLiteTensor* recurrent_to_input_weights,
+ const TfLiteTensor* recurrent_to_forget_weights,
+ const TfLiteTensor* recurrent_to_cell_weights,
+ const TfLiteTensor* recurrent_to_output_weights,
+ const TfLiteTensor* cell_to_input_weights,
+ const TfLiteTensor* cell_to_forget_weights,
+ const TfLiteTensor* cell_to_output_weights, const TfLiteTensor* aux_input,
+ const TfLiteTensor* aux_input_to_input_weights,
+ const TfLiteTensor* aux_input_to_forget_weights,
+ const TfLiteTensor* aux_input_to_cell_weights,
+ const TfLiteTensor* aux_input_to_output_weights,
+ const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
+ const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
+ const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
+ const TfLiteLSTMParams* params, bool forward_sequence, int output_offset,
+ TfLiteTensor* scratch_buffer, TfLiteTensor* activation_state,
+ TfLiteTensor* cell_state, TfLiteTensor* output) {
+ TF_LITE_ASSERT(input->dims->size >= 2 && input->dims->size <= 3);
+ const int max_time = (input->dims->size == 2) ? 1 : input->dims->data[0];
+ const int n_batch = input->dims->data[input->dims->size - 2];
+ const int n_input = input->dims->data[input->dims->size - 1];
+ const int aux_input_size =
+ (aux_input) ? aux_input->dims->data[aux_input->dims->size - 1] : 0;
+
+ // n_cell and n_output will be the same size when there is no projection.
+ const int n_cell = input_to_output_weights->dims->data[0];
+ const int n_output = recurrent_to_output_weights->dims->data[1];
+
+ // Since we have already checked that weights are all there or none, we can
+  // check the existence of only one to get the condition.
+ const bool use_cifg = (input_to_input_weights == nullptr);
+ const bool use_peephole = (cell_to_output_weights != nullptr);
+
+ // Index the scratch buffers pointers to the global scratch buffer.
+ float* input_gate_scratch = nullptr;
+ float* cell_scratch = nullptr;
+ float* forget_gate_scratch = nullptr;
+ float* output_gate_scratch = nullptr;
+ if (use_cifg) {
+ cell_scratch = scratch_buffer->data.f;
+ forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch;
+ output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
+ } else {
+ input_gate_scratch = scratch_buffer->data.f;
+ cell_scratch = scratch_buffer->data.f + n_cell * n_batch;
+ forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
+ output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch;
+ }
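+  // Scratch layout used above: with CIFG the buffer holds [cell | forget |
+  // output] blocks of n_cell * n_batch floats each; without CIFG a leading
+  // [input] block is added, for four blocks in total.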
+
+ // Check optional tensors, the respective pointers can be null.
+ const float* input_to_input_weights_ptr =
+ (use_cifg) ? nullptr : input_to_input_weights->data.f;
+ const float* recurrent_to_input_weights_ptr =
+ (use_cifg) ? nullptr : recurrent_to_input_weights->data.f;
+ const float* input_gate_bias_ptr =
+ (use_cifg) ? nullptr : input_gate_bias->data.f;
+ const float* cell_to_input_weights_ptr =
+ (use_peephole && !use_cifg) ? cell_to_input_weights->data.f : nullptr;
+ const float* cell_to_forget_weights_ptr =
+ (use_peephole) ? cell_to_forget_weights->data.f : nullptr;
+ const float* cell_to_output_weights_ptr =
+ (use_peephole) ? cell_to_output_weights->data.f : nullptr;
+ const float* projection_weights_ptr =
+ (projection_weights == nullptr) ? nullptr : projection_weights->data.f;
+ const float* projection_bias_ptr =
+ (projection_bias == nullptr) ? nullptr : projection_bias->data.f;
+
+ float* aux_input_ptr = nullptr;
+ float* aux_input_to_input_weights_ptr = nullptr;
+ float* aux_input_to_forget_weights_ptr = nullptr;
+ float* aux_input_to_cell_weights_ptr = nullptr;
+ float* aux_input_to_output_weights_ptr = nullptr;
+ if (aux_input_size > 0) {
+ aux_input_ptr = aux_input->data.f;
+ aux_input_to_input_weights_ptr = aux_input_to_input_weights->data.f;
+ aux_input_to_forget_weights_ptr = aux_input_to_forget_weights->data.f;
+ aux_input_to_cell_weights_ptr = aux_input_to_cell_weights->data.f;
+ aux_input_to_output_weights_ptr = aux_input_to_output_weights->data.f;
+ }
+
+ // Loop through the sequence.
+ const int input_step = n_batch * n_input;
+ const int output_step = n_batch * output->dims->data[output->dims->size - 1];
+ for (int t = 0; t < max_time; t++) {
+ // If this is the forward_sequence, step forward, otherwise step backwards.
+    // If forward_sequence is true, step forward in time; otherwise step
+    // backwards.
+ const float* input_ptr = input->data.f + t_rel * input_step;
+ float* output_ptr_time =
+ output->data.f + t_rel * output_step + output_offset;
+
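+    // Each step consumes the n_batch x n_input slice at time t_rel and writes
+    // the n_batch x n_output slice at the same time index; output_offset
+    // shifts the write position within a wider output row (e.g. when the
+    // caller interleaves a backward direction's results).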
+ LstmStepWithAuxInput(
+ input_ptr, input_to_input_weights_ptr, input_to_forget_weights->data.f,
+ input_to_cell_weights->data.f, input_to_output_weights->data.f,
+ aux_input_ptr, aux_input_to_input_weights_ptr,
+ aux_input_to_forget_weights_ptr, aux_input_to_cell_weights_ptr,
+ aux_input_to_output_weights_ptr, recurrent_to_input_weights_ptr,
+ recurrent_to_forget_weights->data.f, recurrent_to_cell_weights->data.f,
+ recurrent_to_output_weights->data.f, cell_to_input_weights_ptr,
+ cell_to_forget_weights_ptr, cell_to_output_weights_ptr,
+ input_gate_bias_ptr, forget_gate_bias->data.f, cell_bias->data.f,
+ output_gate_bias->data.f, projection_weights_ptr, projection_bias_ptr,
+ params, n_batch, n_cell, n_input, aux_input_size, n_output,
+ activation_state->data.f, cell_state->data.f, input_gate_scratch,
+ forget_gate_scratch, cell_scratch, output_gate_scratch,
+ output_ptr_time);
+ }
+ return kTfLiteOk;
+}
+
+TfLiteStatus EvalHybrid(
+ const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
+ const TfLiteTensor* input_to_forget_weights,
+ const TfLiteTensor* input_to_cell_weights,
+ const TfLiteTensor* input_to_output_weights,
+ const TfLiteTensor* recurrent_to_input_weights,
+ const TfLiteTensor* recurrent_to_forget_weights,
+ const TfLiteTensor* recurrent_to_cell_weights,
+ const TfLiteTensor* recurrent_to_output_weights,
+ const TfLiteTensor* cell_to_input_weights,
+ const TfLiteTensor* cell_to_forget_weights,
+ const TfLiteTensor* cell_to_output_weights, const TfLiteTensor* aux_input,
+ const TfLiteTensor* aux_input_to_input_weights,
+ const TfLiteTensor* aux_input_to_forget_weights,
+ const TfLiteTensor* aux_input_to_cell_weights,
+ const TfLiteTensor* aux_input_to_output_weights,
+ const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
+ const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
+ const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
+ const TfLiteLSTMParams* params, bool forward_sequence, int output_offset,
+ TfLiteTensor* scratch_buffer, TfLiteTensor* scaling_factors,
+ TfLiteTensor* prod_scaling_factors, TfLiteTensor* recovered_cell_weights,
+ TfLiteTensor* input_quantized, TfLiteTensor* aux_input_quantized,
+ TfLiteTensor* output_state_quantized, TfLiteTensor* cell_state_quantized,
+ TfLiteTensor* output_state, TfLiteTensor* cell_state,
+ TfLiteTensor* output) {
+ TF_LITE_ASSERT(input->dims->size >= 2 && input->dims->size <= 3);
+ const int max_time = (input->dims->size == 2) ? 1 : input->dims->data[0];
+ const int n_batch = input->dims->data[input->dims->size - 2];
+ const int n_input = input->dims->data[input->dims->size - 1];
+ const int aux_input_size =
+ (aux_input) ? aux_input->dims->data[aux_input->dims->size - 1] : 0;
+ // n_cell and n_output will be the same size when there is no projection.
+ const int n_cell = input_to_output_weights->dims->data[0];
+ const int n_output = recurrent_to_output_weights->dims->data[1];
+
+ // Since we have already checked that weights are all there or none, we can
+ // check the existence of only one to get the condition.
+ const bool use_cifg = (input_to_input_weights == nullptr);
+ const bool use_peephole = (cell_to_output_weights != nullptr);
+
+ float* input_gate_scratch = nullptr;
+ float* cell_scratch = nullptr;
+ float* forget_gate_scratch = nullptr;
+ float* output_gate_scratch = nullptr;
+ if (use_cifg) {
+ cell_scratch = scratch_buffer->data.f;
+ forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch;
+ output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
+ } else {
+ input_gate_scratch = scratch_buffer->data.f;
+ cell_scratch = scratch_buffer->data.f + n_cell * n_batch;
+ forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
+ output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch;
+ }
+
+ // Check optional tensors, the respective pointers can be null.
+ int8_t* input_to_input_weights_ptr = nullptr;
+ float input_to_input_weights_scale = 1.0f;
+ int8_t* recurrent_to_input_weights_ptr = nullptr;
+ float recurrent_to_input_weights_scale = 1.0f;
+ float* input_gate_bias_ptr = nullptr;
+ if (!use_cifg) {
+ input_to_input_weights_ptr =
+ reinterpret_cast<int8_t*>(input_to_input_weights->data.uint8);
+ recurrent_to_input_weights_ptr =
+ reinterpret_cast<int8_t*>(recurrent_to_input_weights->data.uint8);
+ input_gate_bias_ptr = input_gate_bias->data.f;
+ input_to_input_weights_scale = input_to_input_weights->params.scale;
+ recurrent_to_input_weights_scale = recurrent_to_input_weights->params.scale;
+ }
+
+ int8_t* cell_to_input_weights_ptr = nullptr;
+ int8_t* cell_to_forget_weights_ptr = nullptr;
+ int8_t* cell_to_output_weights_ptr = nullptr;
+ float cell_to_input_weights_scale = 1.0f;
+ float cell_to_forget_weights_scale = 1.0f;
+ float cell_to_output_weights_scale = 1.0f;
+ if (use_peephole) {
+ if (!use_cifg) {
+ cell_to_input_weights_ptr =
+ reinterpret_cast<int8_t*>(cell_to_input_weights->data.uint8);
+ cell_to_input_weights_scale = cell_to_input_weights->params.scale;
+ }
+ cell_to_forget_weights_ptr =
+ reinterpret_cast<int8_t*>(cell_to_forget_weights->data.uint8);
+ cell_to_output_weights_ptr =
+ reinterpret_cast<int8_t*>(cell_to_output_weights->data.uint8);
+ cell_to_forget_weights_scale = cell_to_forget_weights->params.scale;
+ cell_to_output_weights_scale = cell_to_output_weights->params.scale;
+ }
+
+ const int8_t* projection_weights_ptr =
+ (projection_weights == nullptr)
+ ? nullptr
+ : reinterpret_cast<int8_t*>(projection_weights->data.uint8);
+ const float projection_weights_scale =
+ (projection_weights == nullptr) ? 1.0f : projection_weights->params.scale;
+ const float* projection_bias_ptr =
+ (projection_bias == nullptr) ? nullptr : projection_bias->data.f;
+
+ // Required tensors, pointers are non-null.
+ const int8_t* input_to_forget_weights_ptr =
+ reinterpret_cast<int8_t*>(input_to_forget_weights->data.uint8);
+ const float input_to_forget_weights_scale =
+ input_to_forget_weights->params.scale;
+ const int8_t* input_to_cell_weights_ptr =
+ reinterpret_cast<int8_t*>(input_to_cell_weights->data.uint8);
+ const float input_to_cell_weights_scale = input_to_cell_weights->params.scale;
+ const int8_t* input_to_output_weights_ptr =
+ reinterpret_cast<int8_t*>(input_to_output_weights->data.uint8);
+ const float input_to_output_weights_scale =
+ input_to_output_weights->params.scale;
+ const int8_t* recurrent_to_forget_weights_ptr =
+ reinterpret_cast<int8_t*>(recurrent_to_forget_weights->data.uint8);
+ const float recurrent_to_forget_weights_scale =
+ recurrent_to_forget_weights->params.scale;
+ const int8_t* recurrent_to_cell_weights_ptr =
+ reinterpret_cast<int8_t*>(recurrent_to_cell_weights->data.uint8);
+ const float recurrent_to_cell_weights_scale =
+ recurrent_to_cell_weights->params.scale;
+ const int8_t* recurrent_to_output_weights_ptr =
+ reinterpret_cast<int8_t*>(recurrent_to_output_weights->data.uint8);
+ const float recurrent_to_output_weights_scale =
+ recurrent_to_output_weights->params.scale;
+ const float* forget_gate_bias_ptr = forget_gate_bias->data.f;
+ const float* cell_bias_ptr = cell_bias->data.f;
+ const float* output_gate_bias_ptr = output_gate_bias->data.f;
+
+ float* output_state_ptr = output_state->data.f;
+ float* cell_state_ptr = cell_state->data.f;
+
+ // Temporary storage for quantized values and scaling factors.
+ int8_t* quantized_input_ptr =
+ reinterpret_cast<int8_t*>(input_quantized->data.uint8);
+ int8_t* quantized_aux_input_ptr =
+ (aux_input_quantized == nullptr)
+ ? nullptr
+ : reinterpret_cast<int8_t*>(aux_input_quantized->data.uint8);
+ int8_t* quantized_output_state_ptr =
+ reinterpret_cast<int8_t*>(output_state_quantized->data.uint8);
+ int8_t* quantized_cell_state_ptr =
+ reinterpret_cast<int8_t*>(cell_state_quantized->data.uint8);
+ float* scaling_factors_ptr = scaling_factors->data.f;
+ float* prod_scaling_factors_ptr = prod_scaling_factors->data.f;
+ float* recovered_cell_weights_ptr = recovered_cell_weights->data.f;
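+  // scaling_factors and prod_scaling_factors each hold one float per batch;
+  // recovered_cell_weights holds n_cell floats of de-quantized peephole
+  // weights (see the per-step routine above).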
+
+ // Auxiliary input and weights.
+ float* aux_input_ptr = nullptr;
+ int8_t* aux_input_to_input_weights_ptr = nullptr;
+ int8_t* aux_input_to_forget_weights_ptr = nullptr;
+ int8_t* aux_input_to_cell_weights_ptr = nullptr;
+ int8_t* aux_input_to_output_weights_ptr = nullptr;
+ float aux_input_to_input_weights_scale = 0.0f;
+ float aux_input_to_forget_weights_scale = 0.0f;
+ float aux_input_to_cell_weights_scale = 0.0f;
+ float aux_input_to_output_weights_scale = 0.0f;
+ if (aux_input_size > 0) {
+ aux_input_ptr = aux_input->data.f;
+ aux_input_to_input_weights_ptr =
+ reinterpret_cast<int8_t*>(aux_input_to_input_weights->data.uint8);
+ aux_input_to_forget_weights_ptr =
+ reinterpret_cast<int8_t*>(aux_input_to_forget_weights->data.uint8);
+ aux_input_to_cell_weights_ptr =
+ reinterpret_cast<int8_t*>(aux_input_to_cell_weights->data.uint8);
+ aux_input_to_output_weights_ptr =
+ reinterpret_cast<int8_t*>(aux_input_to_output_weights->data.uint8);
+ aux_input_to_input_weights_scale = aux_input_to_input_weights->params.scale;
+ aux_input_to_forget_weights_scale =
+ aux_input_to_forget_weights->params.scale;
+ aux_input_to_cell_weights_scale = aux_input_to_cell_weights->params.scale;
+ aux_input_to_output_weights_scale =
+ aux_input_to_output_weights->params.scale;
+ }
+
+ // Feed the sequence into the LSTM step-by-step.
+ const int input_step = n_batch * n_input;
+ const int output_step = n_batch * output->dims->data[output->dims->size - 1];
+ for (int t = 0; t < max_time; t++) {
+ // If this is the forward_sequence, step forward, otherwise step backwards.
+    // If forward_sequence is true, step forward in time; otherwise step
+    // backwards.
+ const float* input_ptr = input->data.f + t_rel * input_step;
+ float* output_ptr = output->data.f + t_rel * output_step + output_offset;
+
+ LstmStepWithAuxInput(
+ input_ptr, input_to_input_weights_ptr, input_to_input_weights_scale,
+ input_to_forget_weights_ptr, input_to_forget_weights_scale,
+ input_to_cell_weights_ptr, input_to_cell_weights_scale,
+ input_to_output_weights_ptr, input_to_output_weights_scale,
+ aux_input_ptr, aux_input_to_input_weights_ptr,
+ aux_input_to_input_weights_scale, aux_input_to_forget_weights_ptr,
+ aux_input_to_forget_weights_scale, aux_input_to_cell_weights_ptr,
+ aux_input_to_cell_weights_scale, aux_input_to_output_weights_ptr,
+ aux_input_to_output_weights_scale, recurrent_to_input_weights_ptr,
+ recurrent_to_input_weights_scale, recurrent_to_forget_weights_ptr,
+ recurrent_to_forget_weights_scale, recurrent_to_cell_weights_ptr,
+ recurrent_to_cell_weights_scale, recurrent_to_output_weights_ptr,
+ recurrent_to_output_weights_scale, cell_to_input_weights_ptr,
+ cell_to_input_weights_scale, cell_to_forget_weights_ptr,
+ cell_to_forget_weights_scale, cell_to_output_weights_ptr,
+ cell_to_output_weights_scale, input_gate_bias_ptr, forget_gate_bias_ptr,
+ cell_bias_ptr, output_gate_bias_ptr, projection_weights_ptr,
+ projection_weights_scale, projection_bias_ptr, params, n_batch, n_cell,
+ n_input, aux_input_size, n_output, input_gate_scratch,
+ forget_gate_scratch, cell_scratch, output_gate_scratch,
+ scaling_factors_ptr, prod_scaling_factors_ptr,
+ recovered_cell_weights_ptr, quantized_input_ptr,
+ quantized_aux_input_ptr, quantized_output_state_ptr,
+ quantized_cell_state_ptr, output_state_ptr, cell_state_ptr, output_ptr);
+ }
+
+ return kTfLiteOk;
+}
+
+} // namespace lstm_eval
+} // namespace builtin
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/lstm_eval.h b/tensorflow/contrib/lite/kernels/lstm_eval.h
new file mode 100644
index 0000000000..adf8cf0f64
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/lstm_eval.h
@@ -0,0 +1,79 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_LSTM_EVAL_H_
+#define TENSORFLOW_CONTRIB_LITE_KERNELS_LSTM_EVAL_H_
+
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace lstm_eval {
+
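+// Float and hybrid (8-bit weights with float activations) evaluation routines
+// for the LSTM cell, intended to be shared by the LSTM kernel and its sequence
+// variants. Both walk the input over max_time steps and return kTfLiteOk on
+// success.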
+TfLiteStatus EvalFloat(
+ const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
+ const TfLiteTensor* input_to_forget_weights,
+ const TfLiteTensor* input_to_cell_weights,
+ const TfLiteTensor* input_to_output_weights,
+ const TfLiteTensor* recurrent_to_input_weights,
+ const TfLiteTensor* recurrent_to_forget_weights,
+ const TfLiteTensor* recurrent_to_cell_weights,
+ const TfLiteTensor* recurrent_to_output_weights,
+ const TfLiteTensor* cell_to_input_weights,
+ const TfLiteTensor* cell_to_forget_weights,
+ const TfLiteTensor* cell_to_output_weights, const TfLiteTensor* aux_input,
+ const TfLiteTensor* aux_input_to_input_weights,
+ const TfLiteTensor* aux_input_to_forget_weights,
+ const TfLiteTensor* aux_input_to_cell_weights,
+ const TfLiteTensor* aux_input_to_output_weights,
+ const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
+ const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
+ const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
+ const TfLiteLSTMParams* params, bool forward_sequence, int output_offset,
+ TfLiteTensor* scratch_buffer, TfLiteTensor* activation_state,
+ TfLiteTensor* cell_state, TfLiteTensor* output);
+
+TfLiteStatus EvalHybrid(
+ const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
+ const TfLiteTensor* input_to_forget_weights,
+ const TfLiteTensor* input_to_cell_weights,
+ const TfLiteTensor* input_to_output_weights,
+ const TfLiteTensor* recurrent_to_input_weights,
+ const TfLiteTensor* recurrent_to_forget_weights,
+ const TfLiteTensor* recurrent_to_cell_weights,
+ const TfLiteTensor* recurrent_to_output_weights,
+ const TfLiteTensor* cell_to_input_weights,
+ const TfLiteTensor* cell_to_forget_weights,
+ const TfLiteTensor* cell_to_output_weights, const TfLiteTensor* aux_input,
+ const TfLiteTensor* aux_input_to_input_weights,
+ const TfLiteTensor* aux_input_to_forget_weights,
+ const TfLiteTensor* aux_input_to_cell_weights,
+ const TfLiteTensor* aux_input_to_output_weights,
+ const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
+ const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
+ const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
+ const TfLiteLSTMParams* params, bool forward_sequence, int output_offset,
+ TfLiteTensor* scratch_buffer, TfLiteTensor* scaling_factors,
+ TfLiteTensor* prod_scaling_factors, TfLiteTensor* recovered_cell_weights,
+ TfLiteTensor* input_quantized, TfLiteTensor* aux_input_quantized,
+ TfLiteTensor* output_state_quantized, TfLiteTensor* cell_state_quantized,
+ TfLiteTensor* output_state, TfLiteTensor* cell_state, TfLiteTensor* output);
+
+} // namespace lstm_eval
+} // namespace builtin
+} // namespace ops
+} // namespace tflite
+#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_LSTM_EVAL_H_
diff --git a/tensorflow/contrib/lite/kernels/lstm_test.cc b/tensorflow/contrib/lite/kernels/lstm_test.cc
index d81220d8d3..e7ddfceb45 100644
--- a/tensorflow/contrib/lite/kernels/lstm_test.cc
+++ b/tensorflow/contrib/lite/kernels/lstm_test.cc
@@ -13,8 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Unit test for TFLite LSTM op.
+//
+// TODO(alanchiao): add unit test with invalid input dimensions for this and its
+// variants.
-#include <iomanip>
#include <memory>
#include <vector>
@@ -35,7 +37,8 @@ class LSTMOpModel : public SingleOpModel {
LSTMOpModel(int n_batch, int n_input, int n_cell, int n_output, bool use_cifg,
bool use_peephole, bool use_projection_weights,
bool use_projection_bias, float cell_clip, float proj_clip,
- const std::vector<std::vector<int>>& input_shapes)
+ const std::vector<std::vector<int>>& input_shapes,
+ const TensorType& weight_type = TensorType_FLOAT32)
: n_batch_(n_batch),
n_input_(n_input),
n_cell_(n_cell),
@@ -45,31 +48,31 @@ class LSTMOpModel : public SingleOpModel {
if (use_cifg) {
input_to_input_weights_ = AddNullInput();
} else {
- input_to_input_weights_ = AddInput(TensorType_FLOAT32);
+ input_to_input_weights_ = AddInput(weight_type);
}
- input_to_forget_weights_ = AddInput(TensorType_FLOAT32);
- input_to_cell_weights_ = AddInput(TensorType_FLOAT32);
- input_to_output_weights_ = AddInput(TensorType_FLOAT32);
+ input_to_forget_weights_ = AddInput(weight_type);
+ input_to_cell_weights_ = AddInput(weight_type);
+ input_to_output_weights_ = AddInput(weight_type);
if (use_cifg) {
recurrent_to_input_weights_ = AddNullInput();
} else {
- recurrent_to_input_weights_ = AddInput(TensorType_FLOAT32);
+ recurrent_to_input_weights_ = AddInput(weight_type);
}
- recurrent_to_forget_weights_ = AddInput(TensorType_FLOAT32);
- recurrent_to_cell_weights_ = AddInput(TensorType_FLOAT32);
- recurrent_to_output_weights_ = AddInput(TensorType_FLOAT32);
+ recurrent_to_forget_weights_ = AddInput(weight_type);
+ recurrent_to_cell_weights_ = AddInput(weight_type);
+ recurrent_to_output_weights_ = AddInput(weight_type);
if (use_peephole) {
if (use_cifg) {
cell_to_input_weights_ = AddNullInput();
} else {
- cell_to_input_weights_ = AddInput(TensorType_FLOAT32);
+ cell_to_input_weights_ = AddInput(weight_type);
}
- cell_to_forget_weights_ = AddInput(TensorType_FLOAT32);
- cell_to_output_weights_ = AddInput(TensorType_FLOAT32);
+ cell_to_forget_weights_ = AddInput(weight_type);
+ cell_to_output_weights_ = AddInput(weight_type);
} else {
cell_to_input_weights_ = AddNullInput();
cell_to_forget_weights_ = AddNullInput();
@@ -86,7 +89,7 @@ class LSTMOpModel : public SingleOpModel {
output_gate_bias_ = AddInput(TensorType_FLOAT32);
if (use_projection_weights) {
- projection_weights_ = AddInput(TensorType_FLOAT32);
+ projection_weights_ = AddInput(weight_type);
if (use_projection_bias) {
projection_bias_ = AddInput(TensorType_FLOAT32);
} else {
@@ -97,14 +100,19 @@ class LSTMOpModel : public SingleOpModel {
projection_bias_ = AddNullInput();
}
- output_state_ = AddOutput(TensorType_FLOAT32);
- cell_state_ = AddOutput(TensorType_FLOAT32);
+ // Adding the 2 input state tensors.
+ input_activation_state_ =
+ AddInput(TensorData{TensorType_FLOAT32, {n_output_ * n_batch_}}, true);
+ input_cell_state_ =
+ AddInput(TensorData{TensorType_FLOAT32, {n_cell_ * n_batch_}}, true);
+
output_ = AddOutput(TensorType_FLOAT32);
SetBuiltinOp(BuiltinOperator_LSTM, BuiltinOptions_LSTMOptions,
CreateLSTMOptions(builder_, ActivationFunctionType_TANH,
cell_clip, proj_clip)
.Union());
+
BuildInterpreter(input_shapes);
}
@@ -176,24 +184,9 @@ class LSTMOpModel : public SingleOpModel {
PopulateTensor(projection_bias_, f);
}
- void ResetOutputState() {
- const int zero_buffer_size = n_cell_ * n_batch_;
- std::unique_ptr<float[]> zero_buffer(new float[zero_buffer_size]);
- memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float));
- PopulateTensor(output_state_, 0, zero_buffer.get(),
- zero_buffer.get() + zero_buffer_size);
- }
-
- void ResetCellState() {
- const int zero_buffer_size = n_cell_ * n_batch_;
- std::unique_ptr<float[]> zero_buffer(new float[zero_buffer_size]);
- memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float));
- PopulateTensor(cell_state_, 0, zero_buffer.get(),
- zero_buffer.get() + zero_buffer_size);
- }
-
- void SetInput(int offset, float* begin, float* end) {
- PopulateTensor(input_, offset, begin, end);
+ void SetInput(int offset, const float* begin, const float* end) {
+ PopulateTensor(input_, offset, const_cast<float*>(begin),
+ const_cast<float*>(end));
}
std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
@@ -203,7 +196,7 @@ class LSTMOpModel : public SingleOpModel {
int num_cells() { return n_cell_; }
int num_batches() { return n_batch_; }
- private:
+ protected:
int input_;
int input_to_input_weights_;
int input_to_forget_weights_;
@@ -226,6 +219,8 @@ class LSTMOpModel : public SingleOpModel {
int projection_weights_;
int projection_bias_;
+ int input_activation_state_;
+ int input_cell_state_;
int output_;
int output_state_;
@@ -237,7 +232,174 @@ class LSTMOpModel : public SingleOpModel {
int n_output_;
};
-TEST(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping) {
+class HybridLSTMOpModel : public LSTMOpModel {
+ public:
+ HybridLSTMOpModel(int n_batch, int n_input, int n_cell, int n_output,
+ bool use_cifg, bool use_peephole,
+ bool use_projection_weights, bool use_projection_bias,
+ float cell_clip, float proj_clip,
+ const std::vector<std::vector<int>>& input_shapes)
+ : LSTMOpModel(n_batch, n_input, n_cell, n_output, use_cifg, use_peephole,
+ use_projection_weights, use_projection_bias, cell_clip,
+ proj_clip, input_shapes, TensorType_UINT8) {}
+
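+  // The weight setters below symmetrically quantize the float values into the
+  // uint8 weight tensors (together with a per-tensor scale), which is what the
+  // hybrid LSTM kernel consumes.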
+ void SetInputToInputWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(input_to_input_weights_, f);
+ }
+
+ void SetInputToForgetWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(input_to_forget_weights_, f);
+ }
+
+ void SetInputToCellWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(input_to_cell_weights_, f);
+ }
+
+ void SetInputToOutputWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(input_to_output_weights_, f);
+ }
+
+ void SetRecurrentToInputWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(recurrent_to_input_weights_, f);
+ }
+
+ void SetRecurrentToForgetWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(recurrent_to_forget_weights_, f);
+ }
+
+ void SetRecurrentToCellWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(recurrent_to_cell_weights_, f);
+ }
+
+ void SetRecurrentToOutputWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(recurrent_to_output_weights_, f);
+ }
+
+ void SetCellToInputWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(cell_to_input_weights_, f);
+ }
+
+ void SetCellToForgetWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(cell_to_forget_weights_, f);
+ }
+
+ void SetCellToOutputWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(cell_to_output_weights_, f);
+ }
+
+ void SetProjectionWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(projection_weights_, f);
+ }
+};
+
+class BaseLstmTest : public ::testing::Test {
+ protected:
+ // Weights of the LSTM model. Some are optional.
+ std::initializer_list<float> input_to_input_weights_;
+ std::initializer_list<float> input_to_cell_weights_;
+ std::initializer_list<float> input_to_forget_weights_;
+ std::initializer_list<float> input_to_output_weights_;
+ std::initializer_list<float> input_gate_bias_;
+ std::initializer_list<float> cell_gate_bias_;
+ std::initializer_list<float> forget_gate_bias_;
+ std::initializer_list<float> output_gate_bias_;
+ std::initializer_list<float> recurrent_to_input_weights_;
+ std::initializer_list<float> recurrent_to_cell_weights_;
+ std::initializer_list<float> recurrent_to_forget_weights_;
+ std::initializer_list<float> recurrent_to_output_weights_;
+ std::initializer_list<float> cell_to_input_weights_;
+ std::initializer_list<float> cell_to_forget_weights_;
+ std::initializer_list<float> cell_to_output_weights_;
+ std::initializer_list<float> projection_weights_;
+
+ // LSTM input is stored as num_batch x num_inputs vector.
+ std::vector<std::vector<float>> lstm_input_;
+ // LSTM output is stored as num_batch x num_outputs vector.
+ std::vector<std::vector<float>> lstm_golden_output_;
+
+  // Feeds the input through the LSTM and checks that its output matches the
+  // golden output to within the given tolerance.
+ void VerifyGoldens(const std::vector<std::vector<float>>& input,
+ const std::vector<std::vector<float>>& output,
+ LSTMOpModel* lstm, float tolerance = 1e-5) {
+ const int num_batches = input.size();
+ EXPECT_GT(num_batches, 0);
+ const int num_inputs = lstm->num_inputs();
+ EXPECT_GT(num_inputs, 0);
+ const int input_sequence_size = input[0].size() / num_inputs;
+ EXPECT_GT(input_sequence_size, 0);
+ for (int i = 0; i < input_sequence_size; ++i) {
+ for (int b = 0; b < num_batches; ++b) {
+ const float* batch_start = input[b].data() + i * num_inputs;
+ const float* batch_end = batch_start + num_inputs;
+
+ lstm->SetInput(b * lstm->num_inputs(), batch_start, batch_end);
+ }
+
+ lstm->Invoke();
+
+ const int num_outputs = lstm->num_outputs();
+ std::vector<float> expected;
+ for (int b = 0; b < num_batches; ++b) {
+ const float* golden_start_batch = output[b].data() + i * num_outputs;
+ const float* golden_end_batch = golden_start_batch + num_outputs;
+ expected.insert(expected.end(), golden_start_batch, golden_end_batch);
+ }
+ EXPECT_THAT(lstm->GetOutput(),
+ ElementsAreArray(ArrayFloatNear(expected, tolerance)));
+ }
+ }
+};
+
+class NoCifgNoPeepholeNoProjectionNoClippingLstmTest : public BaseLstmTest {
+ void SetUp() override {
+ input_to_input_weights_ = {-0.45018822, -0.02338299, -0.0870589,
+ -0.34550029, 0.04266912, -0.15680569,
+ -0.34856534, 0.43890524};
+ input_to_cell_weights_ = {-0.50013041, 0.1370284, 0.11810488, 0.2013163,
+ -0.20583314, 0.44344562, 0.22077113, -0.29909778};
+ input_to_forget_weights_ = {0.09701663, 0.20334584, -0.50592935,
+ -0.31343272, -0.40032279, 0.44781327,
+ 0.01387155, -0.35593212};
+ input_to_output_weights_ = {-0.25065863, -0.28290087, 0.04613829,
+ 0.40525138, 0.44272184, 0.03897077,
+ -0.1556896, 0.19487578};
+ input_gate_bias_ = {0., 0., 0., 0.};
+ cell_gate_bias_ = {0., 0., 0., 0.};
+ forget_gate_bias_ = {1., 1., 1., 1.};
+ output_gate_bias_ = {0., 0., 0., 0.};
+
+ recurrent_to_input_weights_ = {
+ -0.0063535, -0.2042388, 0.31454784, -0.35746509,
+ 0.28902304, 0.08183324, -0.16555229, 0.02286911,
+ -0.13566875, 0.03034258, 0.48091322, -0.12528998,
+ 0.24077177, -0.51332325, -0.33502164, 0.10629296};
+
+ recurrent_to_cell_weights_ = {
+ -0.3407414, 0.24443203, -0.2078532, 0.26320225,
+ 0.05695659, -0.00123841, -0.4744786, -0.35869038,
+ -0.06418842, -0.13502428, -0.501764, 0.22830659,
+ -0.46367589, 0.26016325, -0.03894562, -0.16368064};
+
+ recurrent_to_forget_weights_ = {
+ -0.48684245, -0.06655136, 0.42224967, 0.2112639,
+ 0.27654213, 0.20864892, -0.07646349, 0.45877004,
+ 0.00141793, -0.14609534, 0.36447752, 0.09196436,
+ 0.28053468, 0.01560611, -0.20127171, -0.01140004};
+
+ recurrent_to_output_weights_ = {
+ 0.43385774, -0.17194885, 0.2718237, 0.09215671,
+ 0.24107647, -0.39835793, 0.18212086, 0.01301402,
+ 0.48572797, -0.50656658, 0.20047462, -0.20607421,
+ -0.51818722, -0.15390486, 0.0468148, 0.39922136};
+
+ lstm_input_ = {{2., 3., 3., 4., 1., 1.}};
+ lstm_golden_output_ = {{-0.02973187, 0.1229473, 0.20885126, -0.15358765,
+ -0.03716109, 0.12507336, 0.41193449, -0.20860538,
+ -0.15053082, 0.09120187, 0.24278517, -0.12222792}};
+ }
+};
+
+TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) {
const int n_batch = 1;
const int n_input = 2;
// n_cell and n_output have the same size when there is no projection.
@@ -257,10 +419,10 @@ TEST(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping) {
{n_cell, n_input}, // input_to_cell_weight tensor
{n_cell, n_input}, // input_to_output_weight tensor
- {n_cell, n_output}, // recurrent_to_input_weight tensor
- {n_cell, n_output}, // recurrent_to_forget_weight tensor
- {n_cell, n_output}, // recurrent_to_cell_weight tensor
- {n_cell, n_output}, // recurrent_to_output_weight tensor
+ {n_cell, n_output}, // recurrent_to_input_weight_tensor
+ {n_cell, n_output}, // recurrent_to_forget_weight_tensor
+ {n_cell, n_output}, // recurrent_to_cell_weight_tensor
+ {n_cell, n_output}, // recurrent_to_output_weight_tensor
{0}, // cell_to_input_weight tensor
{0}, // cell_to_forget_weight tensor
@@ -275,79 +437,129 @@ TEST(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping) {
{0}, // projection_bias tensor
});
- lstm.SetInputToInputWeights({-0.45018822, -0.02338299, -0.0870589,
- -0.34550029, 0.04266912, -0.15680569,
- -0.34856534, 0.43890524});
-
- lstm.SetInputToCellWeights({-0.50013041, 0.1370284, 0.11810488, 0.2013163,
- -0.20583314, 0.44344562, 0.22077113,
- -0.29909778});
-
- lstm.SetInputToForgetWeights({0.09701663, 0.20334584, -0.50592935,
- -0.31343272, -0.40032279, 0.44781327,
- 0.01387155, -0.35593212});
-
- lstm.SetInputToOutputWeights({-0.25065863, -0.28290087, 0.04613829,
- 0.40525138, 0.44272184, 0.03897077, -0.1556896,
- 0.19487578});
-
- lstm.SetInputGateBias({0., 0., 0., 0.});
-
- lstm.SetCellBias({0., 0., 0., 0.});
-
- lstm.SetForgetGateBias({1., 1., 1., 1.});
-
- lstm.SetOutputGateBias({0., 0., 0., 0.});
-
- lstm.SetRecurrentToInputWeights(
- {-0.0063535, -0.2042388, 0.31454784, -0.35746509, 0.28902304, 0.08183324,
- -0.16555229, 0.02286911, -0.13566875, 0.03034258, 0.48091322,
- -0.12528998, 0.24077177, -0.51332325, -0.33502164, 0.10629296});
+ lstm.SetInputToInputWeights(input_to_input_weights_);
+ lstm.SetInputToCellWeights(input_to_cell_weights_);
+ lstm.SetInputToForgetWeights(input_to_forget_weights_);
+ lstm.SetInputToOutputWeights(input_to_output_weights_);
- lstm.SetRecurrentToCellWeights(
- {-0.3407414, 0.24443203, -0.2078532, 0.26320225, 0.05695659, -0.00123841,
- -0.4744786, -0.35869038, -0.06418842, -0.13502428, -0.501764, 0.22830659,
- -0.46367589, 0.26016325, -0.03894562, -0.16368064});
+ lstm.SetInputGateBias(input_gate_bias_);
+ lstm.SetCellBias(cell_gate_bias_);
+ lstm.SetForgetGateBias(forget_gate_bias_);
+ lstm.SetOutputGateBias(output_gate_bias_);
- lstm.SetRecurrentToForgetWeights(
- {-0.48684245, -0.06655136, 0.42224967, 0.2112639, 0.27654213, 0.20864892,
- -0.07646349, 0.45877004, 0.00141793, -0.14609534, 0.36447752, 0.09196436,
- 0.28053468, 0.01560611, -0.20127171, -0.01140004});
+ lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_);
+ lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
+ lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
+ lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
- lstm.SetRecurrentToOutputWeights(
- {0.43385774, -0.17194885, 0.2718237, 0.09215671, 0.24107647, -0.39835793,
- 0.18212086, 0.01301402, 0.48572797, -0.50656658, 0.20047462, -0.20607421,
- -0.51818722, -0.15390486, 0.0468148, 0.39922136});
-
- static float lstm_input[] = {2., 3., 3., 4., 1., 1.};
- static float lstm_golden_output[] = {-0.02973187, 0.1229473, 0.20885126,
- -0.15358765, -0.03716109, 0.12507336,
- 0.41193449, -0.20860538, -0.15053082,
- 0.09120187, 0.24278517, -0.12222792};
+ VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
+}
- // Resetting cell_state and output_state
- lstm.ResetCellState();
- lstm.ResetOutputState();
+TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, HybridLstmBlackBoxTest) {
+ const int n_batch = 1;
+ const int n_input = 2;
+ // n_cell and n_output have the same size when there is no projection.
+ const int n_cell = 4;
+ const int n_output = 4;
- const int input_sequence_size =
- sizeof(lstm_input) / sizeof(float) / (lstm.num_inputs());
- for (int i = 0; i < input_sequence_size; i++) {
- float* batch0_start = lstm_input + i * lstm.num_inputs();
- float* batch0_end = batch0_start + lstm.num_inputs();
+ HybridLSTMOpModel lstm(
+ n_batch, n_input, n_cell, n_output,
+ /*use_cifg=*/false, /*use_peephole=*/false,
+ /*use_projection_weights=*/false,
+ /*use_projection_bias=*/false, /*cell_clip=*/0.0, /*proj_clip=*/0.0,
+ {
+ {n_batch, n_input}, // input tensor
+
+ {n_cell, n_input}, // input_to_input_weight tensor
+ {n_cell, n_input}, // input_to_forget_weight tensor
+ {n_cell, n_input}, // input_to_cell_weight tensor
+ {n_cell, n_input}, // input_to_output_weight tensor
+
+ {n_cell, n_output}, // recurrent_to_input_weight tensor
+ {n_cell, n_output}, // recurrent_to_forget_weight tensor
+ {n_cell, n_output}, // recurrent_to_cell_weight tensor
+ {n_cell, n_output}, // recurrent_to_output_weight tensor
+
+ {0}, // cell_to_input_weight tensor
+ {0}, // cell_to_forget_weight tensor
+ {0}, // cell_to_output_weight tensor
+
+ {n_cell}, // input_gate_bias tensor
+ {n_cell}, // forget_gate_bias tensor
+ {n_cell}, // cell_bias tensor
+ {n_cell}, // output_gate_bias tensor
+
+ {0, 0}, // projection_weight tensor
+ {0}, // projection_bias tensor
+ });
+
+ lstm.SetInputToInputWeights(input_to_input_weights_);
+ lstm.SetInputToCellWeights(input_to_cell_weights_);
+ lstm.SetInputToForgetWeights(input_to_forget_weights_);
+ lstm.SetInputToOutputWeights(input_to_output_weights_);
+
+ lstm.SetInputGateBias(input_gate_bias_);
+ lstm.SetCellBias(cell_gate_bias_);
+ lstm.SetForgetGateBias(forget_gate_bias_);
+ lstm.SetOutputGateBias(output_gate_bias_);
+
+ lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_);
+ lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
+ lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
+ lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
+
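+  // The hybrid model stores the weights in 8-bit quantized form, so its
+  // output is only expected to match the float goldens within a looser
+  // tolerance than the float test above.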
+ VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm,
+ /*tolerance=*/0.0157651);
+}
- lstm.SetInput(0, batch0_start, batch0_end);
+class CifgPeepholeNoProjectionNoClippingLstmTest : public BaseLstmTest {
+ void SetUp() override {
+ input_to_cell_weights_ = {-0.49770179, -0.27711356, -0.09624726,
+ 0.05100781, 0.04717243, 0.48944736,
+ -0.38535351, -0.17212132};
- lstm.Invoke();
+ input_to_forget_weights_ = {-0.55291498, -0.42866567, 0.13056988,
+ -0.3633365, -0.22755712, 0.28253698,
+ 0.24407166, 0.33826375};
- float* golden_start = lstm_golden_output + i * lstm.num_outputs();
- float* golden_end = golden_start + lstm.num_outputs();
- std::vector<float> expected;
- expected.insert(expected.end(), golden_start, golden_end);
- EXPECT_THAT(lstm.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
+ input_to_output_weights_ = {0.10725588, -0.02335852, -0.55932593,
+ -0.09426838, -0.44257352, 0.54939759,
+ 0.01533556, 0.42751634};
+ cell_gate_bias_ = {0., 0., 0., 0.};
+ forget_gate_bias_ = {1., 1., 1., 1.};
+ output_gate_bias_ = {0., 0., 0., 0.};
+
+ recurrent_to_cell_weights_ = {
+ 0.54066205, -0.32668582, -0.43562764, -0.56094903,
+ 0.42957711, 0.01841056, -0.32764608, -0.33027974,
+ -0.10826075, 0.20675004, 0.19069612, -0.03026325,
+ -0.54532051, 0.33003211, 0.44901288, 0.21193194};
+
+ recurrent_to_forget_weights_ = {
+ -0.13832897, -0.0515101, -0.2359007, -0.16661474,
+ -0.14340827, 0.36986142, 0.23414481, 0.55899,
+ 0.10798943, -0.41174671, 0.17751795, -0.34484994,
+ -0.35874045, -0.11352962, 0.27268326, 0.54058349};
+
+ recurrent_to_output_weights_ = {
+ 0.41613156, 0.42610586, -0.16495961, -0.5663873,
+ 0.30579174, -0.05115908, -0.33941799, 0.23364776,
+ 0.11178309, 0.09481031, -0.26424935, 0.46261835,
+ 0.50248802, 0.26114327, -0.43736315, 0.33149987};
+
+ cell_to_forget_weights_ = {0.47485286, -0.51955009, -0.24458408,
+ 0.31544167};
+ cell_to_output_weights_ = {-0.17135078, 0.82760304, 0.85573703,
+ -0.77109635};
+
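+    // One batch: 3 (input_sequence_size) * 2 (n_input) input values and
+    // 3 * 4 (n_output) golden output values.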
+ lstm_input_ = {{2., 3., 3., 4., 1., 1.}};
+ lstm_golden_output_ = {{-0.36444446, -0.00352185, 0.12886585, -0.05163646,
+ -0.42312205, -0.01218222, 0.24201041, -0.08124574,
+ -0.358325, -0.04621704, 0.21641694, -0.06471302}};
}
-}
+};
-TEST(LSTMOpTest, BlackBoxTestWithCifgWithPeepholeNoProjectionNoClipping) {
+TEST_F(CifgPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) {
const int n_batch = 1;
const int n_input = 2;
// n_cell and n_output have the same size when there is no projection.
@@ -385,74 +597,681 @@ TEST(LSTMOpTest, BlackBoxTestWithCifgWithPeepholeNoProjectionNoClipping) {
{0}, // projection_bias tensor
});
- lstm.SetInputToCellWeights({-0.49770179, -0.27711356, -0.09624726, 0.05100781,
- 0.04717243, 0.48944736, -0.38535351,
- -0.17212132});
-
- lstm.SetInputToForgetWeights({-0.55291498, -0.42866567, 0.13056988,
- -0.3633365, -0.22755712, 0.28253698, 0.24407166,
- 0.33826375});
-
- lstm.SetInputToOutputWeights({0.10725588, -0.02335852, -0.55932593,
- -0.09426838, -0.44257352, 0.54939759,
- 0.01533556, 0.42751634});
-
- lstm.SetCellBias({0., 0., 0., 0.});
-
- lstm.SetForgetGateBias({1., 1., 1., 1.});
-
- lstm.SetOutputGateBias({0., 0., 0., 0.});
-
- lstm.SetRecurrentToCellWeights(
- {0.54066205, -0.32668582, -0.43562764, -0.56094903, 0.42957711,
- 0.01841056, -0.32764608, -0.33027974, -0.10826075, 0.20675004,
- 0.19069612, -0.03026325, -0.54532051, 0.33003211, 0.44901288,
- 0.21193194});
+ lstm.SetInputToCellWeights(input_to_cell_weights_);
+ lstm.SetInputToForgetWeights(input_to_forget_weights_);
+ lstm.SetInputToOutputWeights(input_to_output_weights_);
- lstm.SetRecurrentToForgetWeights(
- {-0.13832897, -0.0515101, -0.2359007, -0.16661474, -0.14340827,
- 0.36986142, 0.23414481, 0.55899, 0.10798943, -0.41174671, 0.17751795,
- -0.34484994, -0.35874045, -0.11352962, 0.27268326, 0.54058349});
+ lstm.SetCellBias(cell_gate_bias_);
+ lstm.SetForgetGateBias(forget_gate_bias_);
+ lstm.SetOutputGateBias(output_gate_bias_);
- lstm.SetRecurrentToOutputWeights(
- {0.41613156, 0.42610586, -0.16495961, -0.5663873, 0.30579174, -0.05115908,
- -0.33941799, 0.23364776, 0.11178309, 0.09481031, -0.26424935, 0.46261835,
- 0.50248802, 0.26114327, -0.43736315, 0.33149987});
+ lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
+ lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
+ lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
- lstm.SetCellToForgetWeights(
- {0.47485286, -0.51955009, -0.24458408, 0.31544167});
- lstm.SetCellToOutputWeights(
- {-0.17135078, 0.82760304, 0.85573703, -0.77109635});
+ lstm.SetCellToForgetWeights(cell_to_forget_weights_);
+ lstm.SetCellToOutputWeights(cell_to_output_weights_);
- static float lstm_input[] = {2., 3., 3., 4., 1., 1.};
- static float lstm_golden_output[] = {-0.36444446, -0.00352185, 0.12886585,
- -0.05163646, -0.42312205, -0.01218222,
- 0.24201041, -0.08124574, -0.358325,
- -0.04621704, 0.21641694, -0.06471302};
-
- // Resetting cell_state and output_state
- lstm.ResetCellState();
- lstm.ResetOutputState();
-
- const int input_sequence_size =
- sizeof(lstm_input) / sizeof(float) / (lstm.num_inputs());
- for (int i = 0; i < input_sequence_size; i++) {
- float* batch0_start = lstm_input + i * lstm.num_inputs();
- float* batch0_end = batch0_start + lstm.num_inputs();
+ VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
+}
- lstm.SetInput(0, batch0_start, batch0_end);
+TEST_F(CifgPeepholeNoProjectionNoClippingLstmTest, HybridLstmBlackBoxTest) {
+ const int n_batch = 1;
+ const int n_input = 2;
+ // n_cell and n_output have the same size when there is no projection.
+ const int n_cell = 4;
+ const int n_output = 4;
- lstm.Invoke();
+ HybridLSTMOpModel lstm(
+ n_batch, n_input, n_cell, n_output,
+ /*use_cifg=*/true, /*use_peephole=*/true,
+ /*use_projection_weights=*/false,
+ /*use_projection_bias=*/false,
+ /*cell_clip=*/0.0, /*proj_clip=*/0.0,
+ {
+ {n_batch, n_input}, // input tensor
+
+ {0, 0}, // input_to_input_weight tensor
+ {n_cell, n_input}, // input_to_forget_weight tensor
+ {n_cell, n_input}, // input_to_cell_weight tensor
+ {n_cell, n_input}, // input_to_output_weight tensor
+
+ {0, 0}, // recurrent_to_input_weight tensor
+ {n_cell, n_output}, // recurrent_to_forget_weight tensor
+ {n_cell, n_output}, // recurrent_to_cell_weight tensor
+ {n_cell, n_output}, // recurrent_to_output_weight tensor
+
+ {0}, // cell_to_input_weight tensor
+ {n_cell}, // cell_to_forget_weight tensor
+ {n_cell}, // cell_to_output_weight tensor
+
+ {0}, // input_gate_bias tensor
+ {n_cell}, // forget_gate_bias tensor
+ {n_cell}, // cell_bias tensor
+ {n_cell}, // output_gate_bias tensor
+
+ {0, 0}, // projection_weight tensor
+ {0}, // projection_bias tensor
+ });
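+  // The zero-sized shapes above mark the optional tensors this model omits:
+  // the input gate weights/bias (CIFG) and the projection weights/bias. The
+  // peephole weights are per-gate vectors of size n_cell.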
+
+ lstm.SetInputToCellWeights(input_to_cell_weights_);
+ lstm.SetInputToForgetWeights(input_to_forget_weights_);
+ lstm.SetInputToOutputWeights(input_to_output_weights_);
+
+ lstm.SetCellBias(cell_gate_bias_);
+ lstm.SetForgetGateBias(forget_gate_bias_);
+ lstm.SetOutputGateBias(output_gate_bias_);
+
+ lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
+ lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
+ lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
+
+ lstm.SetCellToForgetWeights(cell_to_forget_weights_);
+ lstm.SetCellToOutputWeights(cell_to_output_weights_);
+
+ VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.03573);
+}
- float* golden_start = lstm_golden_output + i * lstm.num_outputs();
- float* golden_end = golden_start + lstm.num_outputs();
- std::vector<float> expected;
- expected.insert(expected.end(), golden_start, golden_end);
- EXPECT_THAT(lstm.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
+class NoCifgPeepholeProjectionNoClippingLstmTest : public BaseLstmTest {
+ void SetUp() override {
+ input_to_input_weights_ = {
+ 0.021393683, 0.06124551, 0.046905167, -0.014657677, -0.03149463,
+ 0.09171803, 0.14647801, 0.10797193, -0.0057968358, 0.0019193048,
+ -0.2726754, 0.10154029, -0.018539885, 0.080349885, -0.10262385,
+ -0.022599787, -0.09121155, -0.008675967, -0.045206103, -0.0821282,
+ -0.008045952, 0.015478081, 0.055217247, 0.038719587, 0.044153627,
+ -0.06453243, 0.05031825, -0.046935108, -0.008164439, 0.014574226,
+ -0.1671009, -0.15519552, -0.16819797, -0.13971269, -0.11953059,
+ 0.25005487, -0.22790983, 0.009855087, -0.028140958, -0.11200698,
+ 0.11295408, -0.0035217577, 0.054485075, 0.05184695, 0.064711206,
+ 0.10989193, 0.11674786, 0.03490607, 0.07727357, 0.11390585,
+ -0.1863375, -0.1034451, -0.13945189, -0.049401227, -0.18767063,
+ 0.042483903, 0.14233552, 0.13832581, 0.18350165, 0.14545603,
+ -0.028545704, 0.024939531, 0.050929718, 0.0076203286, -0.0029723682,
+ -0.042484224, -0.11827596, -0.09171104, -0.10808628, -0.16327988,
+ -0.2273378, -0.0993647, -0.017155107, 0.0023917493, 0.049272764,
+ 0.0038534778, 0.054764505, 0.089753784, 0.06947234, 0.08014476,
+ -0.04544234, -0.0497073, -0.07135631, -0.048929106, -0.004042012,
+ -0.009284026, 0.018042054, 0.0036860977, -0.07427302, -0.11434604,
+ -0.018995456, 0.031487543, 0.012834908, 0.019977754, 0.044256654,
+ -0.39292613, -0.18519334, -0.11651281, -0.06809892, 0.011373677};
+
+ input_to_forget_weights_ = {
+ -0.0018401089, -0.004852237, 0.03698424, 0.014181704,
+ 0.028273236, -0.016726194, -0.05249759, -0.10204261,
+ 0.00861066, -0.040979505, -0.009899187, 0.01923892,
+ -0.028177269, -0.08535103, -0.14585495, 0.10662567,
+ -0.01909731, -0.017883534, -0.0047269356, -0.045103323,
+ 0.0030784295, 0.076784775, 0.07463696, 0.094531395,
+ 0.0814421, -0.12257899, -0.033945758, -0.031303465,
+ 0.045630626, 0.06843887, -0.13492945, -0.012480007,
+ -0.0811829, -0.07224499, -0.09628791, 0.045100946,
+ 0.0012300825, 0.013964662, 0.099372394, 0.02543059,
+ 0.06958324, 0.034257296, 0.0482646, 0.06267997,
+ 0.052625068, 0.12784666, 0.07077897, 0.025725935,
+ 0.04165009, 0.07241905, 0.018668644, -0.037377294,
+ -0.06277783, -0.08833636, -0.040120605, -0.011405586,
+ -0.007808335, -0.010301386, -0.005102167, 0.027717464,
+ 0.05483423, 0.11449111, 0.11289652, 0.10939839,
+ 0.13396506, -0.08402166, -0.01901462, -0.044678304,
+ -0.07720565, 0.014350063, -0.11757958, -0.0652038,
+ -0.08185733, -0.076754324, -0.092614375, 0.10405491,
+ 0.052960336, 0.035755895, 0.035839386, -0.012540553,
+ 0.036881298, 0.02913376, 0.03420159, 0.05448447,
+ -0.054523353, 0.02582715, 0.02327355, -0.011857179,
+ -0.0011980024, -0.034641717, -0.026125094, -0.17582615,
+ -0.15923657, -0.27486774, -0.0006143371, 0.0001771948,
+ -8.470171e-05, 0.02651807, 0.045790765, 0.06956496};
+
+ input_to_cell_weights_ = {
+ -0.04580283, -0.09549462, -0.032418985, -0.06454633,
+ -0.043528453, 0.043018587, -0.049152344, -0.12418144,
+ -0.078985475, -0.07596889, 0.019484362, -0.11434962,
+ -0.0074034138, -0.06314844, -0.092981495, 0.0062155537,
+ -0.025034338, -0.0028890965, 0.048929527, 0.06235075,
+ 0.10665918, -0.032036792, -0.08505916, -0.10843358,
+ -0.13002433, -0.036816437, -0.02130134, -0.016518239,
+ 0.0047691227, -0.0025825808, 0.066017866, 0.029991534,
+ -0.10652836, -0.1037554, -0.13056071, -0.03266643,
+ -0.033702414, -0.006473424, -0.04611692, 0.014419339,
+ -0.025174323, 0.0396852, 0.081777506, 0.06157468,
+ 0.10210095, -0.009658194, 0.046511717, 0.03603906,
+ 0.0069369148, 0.015960095, -0.06507666, 0.09551598,
+ 0.053568836, 0.06408714, 0.12835667, -0.008714329,
+ -0.20211966, -0.12093674, 0.029450472, 0.2849013,
+ -0.029227901, 0.1164364, -0.08560263, 0.09941786,
+ -0.036999565, -0.028842626, -0.0033637602, -0.017012902,
+ -0.09720865, -0.11193351, -0.029155117, -0.017936034,
+ -0.009768936, -0.04223324, -0.036159635, 0.06505112,
+ -0.021742892, -0.023377212, -0.07221364, -0.06430552,
+ 0.05453865, 0.091149814, 0.06387331, 0.007518393,
+ 0.055960953, 0.069779344, 0.046411168, 0.10509911,
+ 0.07463894, 0.0075130584, 0.012850982, 0.04555431,
+ 0.056955688, 0.06555285, 0.050801456, -0.009862683,
+ 0.00826772, -0.026555609, -0.0073611983, -0.0014897042};
+
+ input_to_output_weights_ = {
+ -0.0998932, -0.07201956, -0.052803773, -0.15629593, -0.15001918,
+ -0.07650751, 0.02359855, -0.075155355, -0.08037709, -0.15093534,
+ 0.029517552, -0.04751393, 0.010350531, -0.02664851, -0.016839722,
+ -0.023121163, 0.0077019283, 0.012851257, -0.05040649, -0.0129761,
+ -0.021737747, -0.038305793, -0.06870586, -0.01481247, -0.001285394,
+ 0.10124236, 0.083122835, 0.053313006, -0.062235646, -0.075637154,
+ -0.027833903, 0.029774971, 0.1130802, 0.09218906, 0.09506135,
+ -0.086665764, -0.037162706, -0.038880914, -0.035832845, -0.014481564,
+ -0.09825003, -0.12048569, -0.097665586, -0.05287633, -0.0964047,
+ -0.11366429, 0.035777505, 0.13568819, 0.052451383, 0.050649304,
+ 0.05798951, -0.021852335, -0.099848844, 0.014740475, -0.078897946,
+ 0.04974699, 0.014160473, 0.06973932, 0.04964942, 0.033364646,
+ 0.08190124, 0.025535367, 0.050893165, 0.048514254, 0.06945813,
+ -0.078907564, -0.06707616, -0.11844508, -0.09986688, -0.07509403,
+ 0.06263226, 0.14925587, 0.20188436, 0.12098451, 0.14639415,
+ 0.0015017595, -0.014267382, -0.03417257, 0.012711468, 0.0028300495,
+ -0.024758482, -0.05098548, -0.0821182, 0.014225672, 0.021544158,
+ 0.08949725, 0.07505268, -0.0020780868, 0.04908258, 0.06476295,
+ -0.022907063, 0.027562456, 0.040185735, 0.019567577, -0.015598739,
+ -0.049097303, -0.017121866, -0.083368234, -0.02332002, -0.0840956};
+
+ input_gate_bias_ = {0.02234832, 0.14757581, 0.18176508, 0.10380666,
+ 0.053110216, -0.06928846, -0.13942584, -0.11816189,
+ 0.19483899, 0.03652339, -0.10250295, 0.036714908,
+ -0.18426876, 0.036065217, 0.21810818, 0.02383196,
+ -0.043370757, 0.08690144, -0.04444982, 0.00030581196};
+
+ forget_gate_bias_ = {0.035185695, -0.042891346, -0.03032477, 0.23027696,
+ 0.11098921, 0.15378423, 0.09263801, 0.09790885,
+ 0.09508917, 0.061199076, 0.07665568, -0.015443159,
+ -0.03499149, 0.046190713, 0.08895977, 0.10899629,
+ 0.40694186, 0.06030037, 0.012413437, -0.06108739};
+
+ cell_gate_bias_ = {-0.024379363, 0.0055531194, 0.23377132, 0.033463873,
+ -0.1483596, -0.10639995, -0.091433935, 0.058573797,
+ -0.06809782, -0.07889636, -0.043246906, -0.09829136,
+ -0.4279842, 0.034901652, 0.18797937, 0.0075234566,
+ 0.016178843, 0.1749513, 0.13975595, 0.92058027};
+
+ output_gate_bias_ = {0.046159424, -0.0012809046, 0.03563469, 0.12648113,
+ 0.027195795, 0.35373217, -0.018957434, 0.008907322,
+ -0.0762701, 0.12018895, 0.04216877, 0.0022856654,
+ 0.040952638, 0.3147856, 0.08225149, -0.057416286,
+ -0.14995944, -0.008040261, 0.13208859, 0.029760877};
+
+ recurrent_to_input_weights_ = {
+ -0.001374326, -0.078856036, 0.10672688, 0.029162422,
+ -0.11585556, 0.02557986, -0.13446963, -0.035785314,
+ -0.01244275, 0.025961924, -0.02337298, -0.044228926,
+ -0.055839065, -0.046598054, -0.010546039, -0.06900766,
+ 0.027239809, 0.022582639, -0.013296484, -0.05459212,
+ 0.08981, -0.045407712, 0.08682226, -0.06867011,
+ -0.14390695, -0.02916037, 0.000996957, 0.091420636,
+ 0.14283475, -0.07390571, -0.06402044, 0.062524505,
+ -0.093129106, 0.04860203, -0.08364217, -0.08119002,
+ 0.009352075, 0.22920375, 0.0016303885, 0.11583097,
+ -0.13732095, 0.012405723, -0.07551853, 0.06343048,
+ 0.12162708, -0.031923793, -0.014335606, 0.01790974,
+ -0.10650317, -0.0724401, 0.08554849, -0.05727212,
+ 0.06556731, -0.042729504, -0.043227166, 0.011683251,
+ -0.013082158, -0.029302018, -0.010899579, -0.062036745,
+ -0.022509435, -0.00964907, -0.01567329, 0.04260106,
+ -0.07787477, -0.11576462, 0.017356863, 0.048673786,
+ -0.017577527, -0.05527947, -0.082487635, -0.040137455,
+ -0.10820036, -0.04666372, 0.022746278, -0.07851417,
+ 0.01068115, 0.032956902, 0.022433773, 0.0026891115,
+ 0.08944216, -0.0685835, 0.010513544, 0.07228705,
+ 0.02032331, -0.059686817, -0.0005566496, -0.086984694,
+ 0.040414046, -0.1380399, 0.094208956, -0.05722982,
+ 0.012092817, -0.04989123, -0.086576, -0.003399834,
+ -0.04696032, -0.045747425, 0.10091314, 0.048676282,
+ -0.029037097, 0.031399418, -0.0040285117, 0.047237843,
+ 0.09504992, 0.041799378, -0.049185462, -0.031518843,
+ -0.10516937, 0.026374253, 0.10058866, -0.0033195973,
+ -0.041975245, 0.0073591834, 0.0033782164, -0.004325073,
+ -0.10167381, 0.042500053, -0.01447153, 0.06464186,
+ -0.017142897, 0.03312627, 0.009205989, 0.024138335,
+ -0.011337001, 0.035530265, -0.010912711, 0.0706555,
+ -0.005894094, 0.051841937, -0.1401738, -0.02351249,
+ 0.0365468, 0.07590991, 0.08838724, 0.021681072,
+ -0.10086113, 0.019608743, -0.06195883, 0.077335775,
+ 0.023646897, -0.095322326, 0.02233014, 0.09756986,
+ -0.048691444, -0.009579111, 0.07595467, 0.11480546,
+ -0.09801813, 0.019894179, 0.08502348, 0.004032281,
+ 0.037211012, 0.068537936, -0.048005626, -0.091520436,
+ -0.028379958, -0.01556313, 0.06554592, -0.045599163,
+ -0.01672207, -0.020169014, -0.011877351, -0.20212261,
+ 0.010889619, 0.0047078193, 0.038385306, 0.08540671,
+ -0.017140968, -0.0035865551, 0.016678626, 0.005633034,
+ 0.015963363, 0.00871737, 0.060130805, 0.028611384,
+ 0.10109069, -0.015060172, -0.07894427, 0.06401885,
+ 0.011584063, -0.024466386, 0.0047652307, -0.09041358,
+ 0.030737216, -0.0046374933, 0.14215417, -0.11823516,
+ 0.019899689, 0.006106124, -0.027092824, 0.0786356,
+ 0.05052217, -0.058925, -0.011402121, -0.024987547,
+ -0.0013661642, -0.06832946, -0.015667673, -0.1083353,
+ -0.00096863037, -0.06988685, -0.053350925, -0.027275559,
+ -0.033664223, -0.07978348, -0.025200296, -0.017207067,
+ -0.058403496, -0.055697463, 0.005798788, 0.12965427,
+ -0.062582195, 0.0013350133, -0.10482091, 0.0379771,
+ 0.072521195, -0.0029455067, -0.13797039, -0.03628521,
+ 0.013806405, -0.017858358, -0.01008298, -0.07700066,
+ -0.017081132, 0.019358726, 0.0027079724, 0.004635139,
+ 0.062634714, -0.02338735, -0.039547626, -0.02050681,
+ 0.03385117, -0.083611414, 0.002862572, -0.09421313,
+ 0.058618143, -0.08598433, 0.00972939, 0.023867095,
+ -0.053934585, -0.023203006, 0.07452513, -0.048767887,
+ -0.07314807, -0.056307215, -0.10433547, -0.06440842,
+ 0.04328182, 0.04389765, -0.020006588, -0.09076438,
+ -0.11652589, -0.021705797, 0.03345259, -0.010329105,
+ -0.025767034, 0.013057034, -0.07316461, -0.10145612,
+ 0.06358255, 0.18531723, 0.07759293, 0.12006465,
+ 0.1305557, 0.058638252, -0.03393652, 0.09622831,
+ -0.16253184, -2.4580743e-06, 0.079869635, -0.070196845,
+ -0.005644518, 0.06857898, -0.12598175, -0.035084512,
+ 0.03156317, -0.12794146, -0.031963028, 0.04692781,
+ 0.030070418, 0.0071660685, -0.095516115, -0.004643372,
+ 0.040170413, -0.062104587, -0.0037324072, 0.0554317,
+ 0.08184801, -0.019164372, 0.06791302, 0.034257166,
+ -0.10307039, 0.021943003, 0.046745934, 0.0790918,
+ -0.0265588, -0.007824208, 0.042546265, -0.00977924,
+ -0.0002440307, -0.017384544, -0.017990116, 0.12252321,
+ -0.014512694, -0.08251313, 0.08861942, 0.13589665,
+ 0.026351685, 0.012641483, 0.07466548, 0.044301085,
+ -0.045414884, -0.051112458, 0.03444247, -0.08502782,
+ -0.04106223, -0.028126027, 0.028473156, 0.10467447};
+
+ recurrent_to_cell_weights_ = {
+ -0.037322544, 0.018592842, 0.0056175636, -0.06253426,
+ 0.055647098, -0.05713207, -0.05626563, 0.005559383,
+ 0.03375411, -0.025757805, -0.088049285, 0.06017052,
+ -0.06570978, 0.007384076, 0.035123326, -0.07920549,
+ 0.053676967, 0.044480428, -0.07663568, 0.0071805613,
+ 0.08089997, 0.05143358, 0.038261272, 0.03339287,
+ -0.027673481, 0.044746667, 0.028349208, 0.020090483,
+ -0.019443132, -0.030755889, -0.0040000007, 0.04465846,
+ -0.021585021, 0.0031670958, 0.0053199246, -0.056117613,
+ -0.10893326, 0.076739706, -0.08509834, -0.027997585,
+ 0.037871376, 0.01449768, -0.09002357, -0.06111149,
+ -0.046195522, 0.0422062, -0.005683705, -0.1253618,
+ -0.012925729, -0.04890792, 0.06985068, 0.037654128,
+ 0.03398274, -0.004781977, 0.007032333, -0.031787455,
+ 0.010868644, -0.031489216, 0.09525667, 0.013939797,
+ 0.0058680447, 0.0167067, 0.02668468, -0.04797466,
+ -0.048885044, -0.12722108, 0.035304096, 0.06554885,
+ 0.00972396, -0.039238118, -0.05159735, -0.11329045,
+ 0.1613692, -0.03750952, 0.06529313, -0.071974665,
+ -0.11769596, 0.015524369, -0.0013754242, -0.12446318,
+ 0.02786344, -0.014179351, 0.005264273, 0.14376344,
+ 0.015983658, 0.03406988, -0.06939408, 0.040699873,
+ 0.02111075, 0.09669095, 0.041345075, -0.08316494,
+ -0.07684199, -0.045768797, 0.032298047, -0.041805092,
+ 0.0119405, 0.0061010392, 0.12652606, 0.0064572375,
+ -0.024950314, 0.11574242, 0.04508852, -0.04335324,
+ 0.06760663, -0.027437469, 0.07216407, 0.06977076,
+ -0.05438599, 0.034033038, -0.028602652, 0.05346137,
+ 0.043184172, -0.037189785, 0.10420091, 0.00882477,
+ -0.054019816, -0.074273005, -0.030617684, -0.0028467078,
+ 0.024302477, -0.0038869337, 0.005332455, 0.0013399826,
+ 0.04361412, -0.007001822, 0.09631092, -0.06702025,
+ -0.042049985, -0.035070654, -0.04103342, -0.10273396,
+ 0.0544271, 0.037184782, -0.13150354, -0.0058036847,
+ -0.008264958, 0.042035464, 0.05891794, 0.029673764,
+ 0.0063542654, 0.044788733, 0.054816857, 0.062257513,
+ -0.00093483756, 0.048938446, -0.004952862, -0.007730018,
+ -0.04043371, -0.017094059, 0.07229206, -0.023670016,
+ -0.052195564, -0.025616996, -0.01520939, 0.045104615,
+ -0.007376126, 0.003533447, 0.006570588, 0.056037236,
+ 0.12436656, 0.051817212, 0.028532185, -0.08686856,
+ 0.11868599, 0.07663395, -0.07323171, 0.03463402,
+ -0.050708205, -0.04458982, -0.11590894, 0.021273347,
+ 0.1251325, -0.15313013, -0.12224372, 0.17228661,
+ 0.023029093, 0.086124025, 0.006445803, -0.03496501,
+ 0.028332196, 0.04449512, -0.042436164, -0.026587414,
+ -0.006041347, -0.09292539, -0.05678812, 0.03897832,
+ 0.09465633, 0.008115513, -0.02171956, 0.08304309,
+ 0.071401566, 0.019622514, 0.032163795, -0.004167056,
+ 0.02295182, 0.030739572, 0.056506045, 0.004612461,
+ 0.06524936, 0.059999723, 0.046395954, -0.0045512207,
+ -0.1335546, -0.030136576, 0.11584653, -0.014678886,
+ 0.0020118146, -0.09688814, -0.0790206, 0.039770417,
+ -0.0329582, 0.07922767, 0.029322514, 0.026405897,
+ 0.04207835, -0.07073373, 0.063781224, 0.0859677,
+ -0.10925287, -0.07011058, 0.048005477, 0.03438226,
+ -0.09606514, -0.006669445, -0.043381985, 0.04240257,
+ -0.06955775, -0.06769346, 0.043903265, -0.026784198,
+ -0.017840602, 0.024307009, -0.040079936, -0.019946516,
+ 0.045318738, -0.12233574, 0.026170589, 0.0074471775,
+ 0.15978073, 0.10185836, 0.10298046, -0.015476589,
+ -0.039390966, -0.072174534, 0.0739445, -0.1211869,
+ -0.0347889, -0.07943156, 0.014809798, -0.12412325,
+ -0.0030663363, 0.039695457, 0.0647603, -0.08291318,
+ -0.018529687, -0.004423833, 0.0037507233, 0.084633216,
+ -0.01514876, -0.056505352, -0.012800942, -0.06994386,
+ 0.012962922, -0.031234352, 0.07029052, 0.016418684,
+ 0.03618972, 0.055686004, -0.08663945, -0.017404709,
+ -0.054761406, 0.029065743, 0.052404847, 0.020238016,
+ 0.0048197987, -0.0214882, 0.07078733, 0.013016777,
+ 0.06262858, 0.009184685, 0.020785125, -0.043904778,
+ -0.0270329, -0.03299152, -0.060088247, -0.015162964,
+ -0.001828936, 0.12642565, -0.056757294, 0.013586685,
+ 0.09232601, -0.035886683, 0.06000002, 0.05229691,
+ -0.052580316, -0.082029596, -0.010794592, 0.012947712,
+ -0.036429964, -0.085508935, -0.13127148, -0.017744139,
+ 0.031502828, 0.036232427, -0.031581745, 0.023051167,
+ -0.05325106, -0.03421577, 0.028793324, -0.034633752,
+ -0.009881397, -0.043551125, -0.018609839, 0.0019097115,
+ -0.008799762, 0.056595087, 0.0022273948, 0.055752404};
+
+ recurrent_to_forget_weights_ = {
+ -0.057784554, -0.026057621, -0.068447545, -0.022581743,
+ 0.14811787, 0.10826372, 0.09471067, 0.03987225,
+ -0.0039523416, 0.00030638507, 0.053185795, 0.10572994,
+ 0.08414449, -0.022036452, -0.00066928595, -0.09203576,
+ 0.032950465, -0.10985798, -0.023809856, 0.0021431844,
+ -0.02196096, -0.00326074, 0.00058621005, -0.074678116,
+ -0.06193199, 0.055729095, 0.03736828, 0.020123724,
+ 0.061878487, -0.04729229, 0.034919553, -0.07585433,
+ -0.04421272, -0.044019096, 0.085488975, 0.04058006,
+ -0.06890133, -0.030951202, -0.024628663, -0.07672815,
+ 0.034293607, 0.08556707, -0.05293577, -0.033561368,
+ -0.04899627, 0.0241671, 0.015736353, -0.095442444,
+ -0.029564252, 0.016493602, -0.035026584, 0.022337519,
+ -0.026871363, 0.004780428, 0.0077918363, -0.03601621,
+ 0.016435321, -0.03263031, -0.09543275, -0.047392778,
+ 0.013454138, 0.028934088, 0.01685226, -0.086110644,
+ -0.046250615, -0.01847454, 0.047608484, 0.07339695,
+ 0.034546845, -0.04881143, 0.009128804, -0.08802852,
+ 0.03761666, 0.008096139, -0.014454086, 0.014361001,
+ -0.023502491, -0.0011840804, -0.07607001, 0.001856849,
+ -0.06509276, -0.006021153, -0.08570962, -0.1451793,
+ 0.060212336, 0.055259194, 0.06974018, 0.049454916,
+ -0.027794661, -0.08077226, -0.016179763, 0.1169753,
+ 0.17213494, -0.0056326236, -0.053934924, -0.0124349,
+ -0.11520337, 0.05409887, 0.088759385, 0.0019655675,
+ 0.0042065294, 0.03881498, 0.019844765, 0.041858196,
+ -0.05695512, 0.047233116, 0.038937137, -0.06542224,
+ 0.014429736, -0.09719407, 0.13908425, -0.05379757,
+ 0.012321099, 0.082840554, -0.029899208, 0.044217527,
+ 0.059855383, 0.07711018, -0.045319796, 0.0948846,
+ -0.011724666, -0.0033288454, -0.033542685, -0.04764985,
+ -0.13873616, 0.040668588, 0.034832682, -0.015319203,
+ -0.018715994, 0.046002675, 0.0599172, -0.043107376,
+ 0.0294216, -0.002314414, -0.022424703, 0.0030315618,
+ 0.0014641669, 0.0029166266, -0.11878115, 0.013738511,
+ 0.12375372, -0.0006038222, 0.029104086, 0.087442465,
+ 0.052958444, 0.07558703, 0.04817258, 0.044462286,
+ -0.015213451, -0.08783778, -0.0561384, -0.003008196,
+ 0.047060397, -0.002058388, 0.03429439, -0.018839769,
+ 0.024734668, 0.024614193, -0.042046934, 0.09597743,
+ -0.0043254104, 0.04320769, 0.0064070094, -0.0019131786,
+ -0.02558259, -0.022822596, -0.023273505, -0.02464396,
+ -0.10991725, -0.006240552, 0.0074488563, 0.024044557,
+ 0.04383914, -0.046476185, 0.028658995, 0.060410924,
+ 0.050786525, 0.009452605, -0.0073054377, -0.024810238,
+ 0.0052906186, 0.0066939713, -0.0020913032, 0.014515517,
+ 0.015898481, 0.021362653, -0.030262267, 0.016587038,
+ -0.011442813, 0.041154444, -0.007631438, -0.03423484,
+ -0.010977775, 0.036152758, 0.0066366293, 0.11915515,
+ 0.02318443, -0.041350313, 0.021485701, -0.10906167,
+ -0.028218046, -0.00954771, 0.020531068, -0.11995105,
+ -0.03672871, 0.024019798, 0.014255957, -0.05221243,
+ -0.00661567, -0.04630967, 0.033188973, 0.10107534,
+ -0.014027541, 0.030796422, -0.10270911, -0.035999842,
+ 0.15443139, 0.07684145, 0.036571592, -0.035900835,
+ -0.0034699554, 0.06209149, 0.015920248, -0.031122351,
+ -0.03858649, 0.01849943, 0.13872518, 0.01503974,
+ 0.069941424, -0.06948533, -0.0088794185, 0.061282158,
+ -0.047401894, 0.03100163, -0.041533746, -0.10430945,
+ 0.044574402, -0.01425562, -0.024290353, 0.034563623,
+ 0.05866852, 0.023947537, -0.09445152, 0.035450947,
+ 0.02247216, -0.0042998926, 0.061146557, -0.10250651,
+ 0.020881841, -0.06747029, 0.10062043, -0.0023941975,
+ 0.03532124, -0.016341697, 0.09685456, -0.016764693,
+ 0.051808182, 0.05875331, -0.04536488, 0.001626336,
+ -0.028892258, -0.01048663, -0.009793449, -0.017093895,
+ 0.010987891, 0.02357273, -0.00010856845, 0.0099760275,
+ -0.001845119, -0.03551521, 0.0018358806, 0.05763657,
+ -0.01769146, 0.040995963, 0.02235177, -0.060430344,
+ 0.11475477, -0.023854522, 0.10071741, 0.0686208,
+ -0.014250481, 0.034261297, 0.047418304, 0.08562733,
+ -0.030519066, 0.0060542435, 0.014653856, -0.038836084,
+ 0.04096551, 0.032249358, -0.08355519, -0.026823482,
+ 0.056386515, -0.010401743, -0.028396193, 0.08507674,
+ 0.014410365, 0.020995233, 0.17040324, 0.11511526,
+ 0.02459721, 0.0066619175, 0.025853224, -0.023133837,
+ -0.081302024, 0.017264642, -0.009585969, 0.09491168,
+ -0.051313367, 0.054532815, -0.014298593, 0.10657464,
+ 0.007076659, 0.10964551, 0.0409152, 0.008275321,
+ -0.07283536, 0.07937492, 0.04192024, -0.1075027};
+
+ recurrent_to_output_weights_ = {
+ 0.025825322, -0.05813119, 0.09495884, -0.045984812,
+ -0.01255415, -0.0026479573, -0.08196161, -0.054914974,
+ -0.0046604523, -0.029587349, -0.044576716, -0.07480124,
+ -0.082868785, 0.023254942, 0.027502948, -0.0039728214,
+ -0.08683098, -0.08116779, -0.014675607, -0.037924774,
+ -0.023314456, -0.007401714, -0.09255757, 0.029460307,
+ -0.08829125, -0.005139627, -0.08989442, -0.0555066,
+ 0.13596267, -0.025062224, -0.048351806, -0.03850004,
+ 0.07266485, -0.022414139, 0.05940088, 0.075114764,
+ 0.09597592, -0.010211725, -0.0049794707, -0.011523867,
+ -0.025980417, 0.072999895, 0.11091378, -0.081685916,
+ 0.014416728, 0.043229222, 0.034178585, -0.07530371,
+ 0.035837382, -0.085607, -0.007721233, -0.03287832,
+ -0.043848954, -0.06404588, -0.06632928, -0.073643476,
+ 0.008214239, -0.045984086, 0.039764922, 0.03474462,
+ 0.060612556, -0.080590084, 0.049127717, 0.04151091,
+ -0.030063879, 0.008801774, -0.023021035, -0.019558564,
+ 0.05158114, -0.010947698, -0.011825728, 0.0075720972,
+ 0.0699727, -0.0039981045, 0.069350146, 0.08799282,
+ 0.016156472, 0.035502106, 0.11695009, 0.006217345,
+ 0.13392477, -0.037875112, 0.025745004, 0.08940699,
+ -0.00924166, 0.0046702605, -0.036598757, -0.08811812,
+ 0.10522024, -0.032441203, 0.008176899, -0.04454919,
+ 0.07058152, 0.0067963637, 0.039206743, 0.03259838,
+ 0.03725492, -0.09515802, 0.013326398, -0.052055415,
+ -0.025676316, 0.03198509, -0.015951829, -0.058556724,
+ 0.036879618, 0.043357447, 0.028362012, -0.05908629,
+ 0.0059240665, -0.04995891, -0.019187413, 0.0276265,
+ -0.01628143, 0.0025863599, 0.08800015, 0.035250366,
+ -0.022165963, -0.07328642, -0.009415526, -0.07455109,
+ 0.11690406, 0.0363299, 0.07411125, 0.042103454,
+ -0.009660886, 0.019076364, 0.018299393, -0.046004917,
+ 0.08891175, 0.0431396, -0.026327137, -0.051502608,
+ 0.08979574, -0.051670972, 0.04940282, -0.07491107,
+ -0.021240504, 0.022596184, -0.034280192, 0.060163025,
+ -0.058211457, -0.051837247, -0.01349775, -0.04639988,
+ -0.035936575, -0.011681591, 0.064818054, 0.0073146066,
+ -0.021745546, -0.043124277, -0.06471268, -0.07053354,
+ -0.029321948, -0.05330136, 0.016933719, -0.053782392,
+ 0.13747959, -0.1361751, -0.11569455, 0.0033329215,
+ 0.05693899, -0.053219706, 0.063698, 0.07977434,
+ -0.07924483, 0.06936997, 0.0034815092, -0.007305279,
+ -0.037325785, -0.07251102, -0.033633437, -0.08677009,
+ 0.091591336, -0.14165086, 0.021752775, 0.019683983,
+ 0.0011612234, -0.058154266, 0.049996935, 0.0288841,
+ -0.0024567875, -0.14345716, 0.010955264, -0.10234828,
+ 0.1183656, -0.0010731248, -0.023590032, -0.072285876,
+ -0.0724771, -0.026382286, -0.0014920527, 0.042667855,
+ 0.0018776858, 0.02986552, 0.009814309, 0.0733756,
+ 0.12289186, 0.018043943, -0.0458958, 0.049412545,
+ 0.033632483, 0.05495232, 0.036686596, -0.013781798,
+ -0.010036754, 0.02576849, -0.08307328, 0.010112348,
+ 0.042521734, -0.05869831, -0.071689695, 0.03876447,
+ -0.13275425, -0.0352966, -0.023077697, 0.10285965,
+ 0.084736146, 0.15568255, -0.00040734606, 0.027835453,
+ -0.10292561, -0.032401145, 0.10053256, -0.026142767,
+ -0.08271222, -0.0030240538, -0.016368777, 0.1070414,
+ 0.042672627, 0.013456989, -0.0437609, -0.022309763,
+ 0.11576483, 0.04108048, 0.061026827, -0.0190714,
+ -0.0869359, 0.037901703, 0.0610107, 0.07202949,
+ 0.01675338, 0.086139716, -0.08795751, -0.014898893,
+ -0.023771819, -0.01965048, 0.007955471, -0.043740474,
+ 0.03346837, -0.10549954, 0.090567775, 0.042013682,
+ -0.03176985, 0.12569028, -0.02421228, -0.029526481,
+ 0.023851605, 0.031539805, 0.05292009, -0.02344001,
+ -0.07811758, -0.08834428, 0.10094801, 0.16594367,
+ -0.06861939, -0.021256343, -0.041093912, -0.06669611,
+ 0.035498552, 0.021757556, -0.09302526, -0.015403468,
+ -0.06614931, -0.051798206, -0.013874718, 0.03630673,
+ 0.010412845, -0.08077351, 0.046185967, 0.0035662893,
+ 0.03541868, -0.094149634, -0.034814864, 0.003128424,
+ -0.020674974, -0.03944324, -0.008110165, -0.11113267,
+ 0.08484226, 0.043586485, 0.040582247, 0.0968012,
+ -0.065249965, -0.028036479, 0.0050708856, 0.0017462453,
+ 0.0326779, 0.041296225, 0.09164146, -0.047743853,
+ -0.015952192, -0.034451712, 0.084197424, -0.05347844,
+ -0.11768019, 0.085926116, -0.08251791, -0.045081906,
+ 0.0948852, 0.068401024, 0.024856757, 0.06978981,
+ -0.057309967, -0.012775832, -0.0032452994, 0.01977615,
+ -0.041040014, -0.024264973, 0.063464895, 0.05431621,
+ };
+
+ cell_to_input_weights_ = {
+ 0.040369894, 0.030746894, 0.24704495, 0.018586371, -0.037586458,
+ -0.15312155, -0.11812848, -0.11465643, 0.20259799, 0.11418174,
+ -0.10116027, -0.011334949, 0.12411352, -0.076769054, -0.052169047,
+ 0.21198851, -0.38871562, -0.09061183, -0.09683246, -0.21929175};
+
+ cell_to_forget_weights_ = {
+ -0.01998659, -0.15568835, -0.24248174, -0.012770197, 0.041331276,
+ -0.072311886, -0.052123554, -0.0066330447, -0.043891653, 0.036225766,
+ -0.047248036, 0.021479502, 0.033189066, 0.11952997, -0.020432774,
+ 0.64658105, -0.06650122, -0.03467612, 0.095340036, 0.23647355};
+
+ cell_to_output_weights_ = {
+ 0.08286371, -0.08261836, -0.51210177, 0.002913762, 0.17764764,
+ -0.5495371, -0.08460716, -0.24552552, 0.030037103, 0.04123544,
+ -0.11940523, 0.007358328, 0.1890978, 0.4833202, -0.34441817,
+ 0.36312827, -0.26375428, 0.1457655, -0.19724406, 0.15548733};
+
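+    // The projection maps the n_cell (20) cell activations down to the
+    // n_output (16) output units, i.e. an n_output x n_cell weight matrix.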
+ projection_weights_ = {
+ -0.009802181, 0.09401916, 0.0717386, -0.13895074,
+ 0.09641832, 0.060420845, 0.08539281, 0.054285463,
+ 0.061395317, 0.034448683, -0.042991187, 0.019801661,
+ -0.16840284, -0.015726732, -0.23041931, -0.024478018,
+ -0.10959692, -0.013875541, 0.18600968, -0.061274476,
+ 0.0138165, -0.08160894, -0.07661644, 0.032372914,
+ 0.16169067, 0.22465782, -0.03993472, -0.004017731,
+ 0.08633481, -0.28869787, 0.08682067, 0.17240396,
+ 0.014975425, 0.056431185, 0.031037588, 0.16702051,
+ 0.0077946745, 0.15140012, 0.29405436, 0.120285,
+ -0.188994, -0.027265169, 0.043389652, -0.022061434,
+ 0.014777949, -0.20203483, 0.094781205, 0.19100232,
+ 0.13987629, -0.036132768, -0.06426278, -0.05108664,
+ 0.13221376, 0.009441198, -0.16715929, 0.15859416,
+ -0.040437475, 0.050779544, -0.022187516, 0.012166504,
+ 0.027685808, -0.07675938, -0.0055694645, -0.09444123,
+ 0.0046453946, 0.050794356, 0.10770313, -0.20790008,
+ -0.07149004, -0.11425117, 0.008225835, -0.035802525,
+ 0.14374903, 0.15262283, 0.048710253, 0.1847461,
+ -0.007487823, 0.11000021, -0.09542012, 0.22619456,
+ -0.029149994, 0.08527916, 0.009043713, 0.0042746216,
+ 0.016261552, 0.022461696, 0.12689082, -0.043589946,
+ -0.12035478, -0.08361797, -0.050666027, -0.1248618,
+ -0.1275799, -0.071875185, 0.07377272, 0.09944291,
+ -0.18897448, -0.1593054, -0.06526116, -0.040107165,
+ -0.004618631, -0.067624845, -0.007576253, 0.10727444,
+ 0.041546922, -0.20424393, 0.06907816, 0.050412357,
+ 0.00724631, 0.039827548, 0.12449835, 0.10747581,
+ 0.13708383, 0.09134148, -0.12617786, -0.06428341,
+ 0.09956831, 0.1208086, -0.14676677, -0.0727722,
+ 0.1126304, 0.010139365, 0.015571211, -0.038128063,
+ 0.022913318, -0.042050496, 0.16842307, -0.060597885,
+ 0.10531834, -0.06411776, -0.07451711, -0.03410368,
+ -0.13393489, 0.06534304, 0.003620307, 0.04490757,
+ 0.05970546, 0.05197996, 0.02839995, 0.10434969,
+ -0.013699693, -0.028353551, -0.07260381, 0.047201227,
+ -0.024575593, -0.036445823, 0.07155557, 0.009672501,
+ -0.02328883, 0.009533515, -0.03606021, -0.07421458,
+ -0.028082801, -0.2678904, -0.13221288, 0.18419984,
+ -0.13012612, -0.014588381, -0.035059117, -0.04824723,
+ 0.07830115, -0.056184657, 0.03277091, 0.025466874,
+ 0.14494097, -0.12522776, -0.098633975, -0.10766018,
+ -0.08317623, 0.08594209, 0.07749552, 0.039474737,
+ 0.1776665, -0.07409566, -0.0477268, 0.29323658,
+ 0.10801441, 0.1154011, 0.013952499, 0.10739139,
+ 0.10708251, -0.051456142, 0.0074137426, -0.10430189,
+ 0.10034707, 0.045594677, 0.0635285, -0.0715442,
+ -0.089667566, -0.10811871, 0.00026344223, 0.08298446,
+ -0.009525053, 0.006585689, -0.24567553, -0.09450807,
+ 0.09648481, 0.026996298, -0.06419476, -0.04752702,
+ -0.11063944, -0.23441927, -0.17608605, -0.052156363,
+ 0.067035615, 0.19271925, -0.0032889997, -0.043264326,
+ 0.09663576, -0.057112187, -0.10100678, 0.0628376,
+ 0.04447668, 0.017961001, -0.10094388, -0.10190601,
+ 0.18335468, 0.10494553, -0.052095775, -0.0026118709,
+ 0.10539724, -0.04383912, -0.042349473, 0.08438151,
+ -0.1947263, 0.02251204, 0.11216432, -0.10307853,
+ 0.17351969, -0.039091777, 0.08066188, -0.00561982,
+ 0.12633002, 0.11335965, -0.0088127935, -0.019777594,
+ 0.06864014, -0.059751723, 0.016233567, -0.06894641,
+ -0.28651384, -0.004228674, 0.019708522, -0.16305895,
+ -0.07468996, -0.0855457, 0.099339016, -0.07580735,
+ -0.13775392, 0.08434318, 0.08330512, -0.12131499,
+ 0.031935584, 0.09180414, -0.08876437, -0.08049874,
+ 0.008753825, 0.03498998, 0.030215185, 0.03907079,
+ 0.089751154, 0.029194152, -0.03337423, -0.019092513,
+ 0.04331237, 0.04299654, -0.036394123, -0.12915532,
+ 0.09793732, 0.07512415, -0.11319543, -0.032502122,
+ 0.15661901, 0.07671967, -0.005491124, -0.19379048,
+ -0.218606, 0.21448623, 0.017840758, 0.1416943,
+ -0.07051762, 0.19488361, 0.02664691, -0.18104725,
+ -0.09334311, 0.15026465, -0.15493552, -0.057762887,
+ -0.11604192, -0.262013, -0.01391798, 0.012185008,
+ 0.11156489, -0.07483202, 0.06693364, -0.26151478,
+ 0.046425626, 0.036540434, -0.16435726, 0.17338543,
+ -0.21401681, -0.11385144, -0.08283257, -0.069031075,
+ 0.030635102, 0.010969227, 0.11109743, 0.010919218,
+ 0.027526086, 0.13519906, 0.01891392, -0.046839405,
+ -0.040167913, 0.017953383, -0.09700955, 0.0061885654,
+ -0.07000971, 0.026893595, -0.038844477, 0.14543656};
+
+ lstm_input_ = {
+ {// Batch0: 4 (input_sequence_size) * 5 (n_input)
+ 0.787926, 0.151646, 0.071352, 0.118426, 0.458058, // step 0
+ 0.596268, 0.998386, 0.568695, 0.864524, 0.571277, // step 1
+ 0.073204, 0.296072, 0.743333, 0.069199, 0.045348, // step 2
+ 0.867394, 0.291279, 0.013714, 0.482521, 0.626339}, // step 3
+
+ {// Batch1: 4 (input_sequence_size) * 5 (n_input)
+ 0.295743, 0.544053, 0.690064, 0.858138, 0.497181, // step 0
+ 0.642421, 0.524260, 0.134799, 0.003639, 0.162482, // step 1
+ 0.640394, 0.930399, 0.050782, 0.432485, 0.988078, // step 2
+ 0.082922, 0.563329, 0.865614, 0.333232, 0.259916} // step 3
+ };
+
+ lstm_golden_output_ = {
+ {// Batch0: 4 (input_sequence_size) * 16 (n_output)
+ -0.00396806, 0.029352, -0.00279226, 0.0159977, -0.00835576,
+ -0.0211779, 0.0283512, -0.0114597, 0.00907307, -0.0244004,
+ -0.0152191, -0.0259063, 0.00914318, 0.00415118, 0.017147,
+ 0.0134203, -0.0166936, 0.0381209, 0.000889694, 0.0143363,
+ -0.0328911, -0.0234288, 0.0333051, -0.012229, 0.0110322,
+ -0.0457725, -0.000832209, -0.0202817, 0.0327257, 0.0121308,
+ 0.0155969, 0.0312091, -0.0213783, 0.0350169, 0.000324794,
+ 0.0276012, -0.0263374, -0.0371449, 0.0446149, -0.0205474,
+ 0.0103729, -0.0576349, -0.0150052, -0.0292043, 0.0376827,
+ 0.0136115, 0.0243435, 0.0354492, -0.0189322, 0.0464512,
+ -0.00251373, 0.0225745, -0.0308346, -0.0317124, 0.0460407,
+ -0.0189395, 0.0149363, -0.0530162, -0.0150767, -0.0340193,
+ 0.0286833, 0.00824207, 0.0264887, 0.0305169},
+ {// Batch1: 4 (input_sequence_size) * 16 (n_output)
+ -0.013869, 0.0287268, -0.00334693, 0.00733398, -0.0287926,
+ -0.0186926, 0.0193662, -0.0115437, 0.00422612, -0.0345232,
+ 0.00223253, -0.00957321, 0.0210624, 0.013331, 0.0150954,
+ 0.02168, -0.0141913, 0.0322082, 0.00227024, 0.0260507,
+ -0.0188721, -0.0296489, 0.0399134, -0.0160509, 0.0116039,
+ -0.0447318, -0.0150515, -0.0277406, 0.0316596, 0.0118233,
+ 0.0214762, 0.0293641, -0.0204549, 0.0450315, -0.00117378,
+ 0.0167673, -0.0375007, -0.0238314, 0.038784, -0.0174034,
+ 0.0131743, -0.0506589, -0.0048447, -0.0240239, 0.0325789,
+ 0.00790065, 0.0220157, 0.0333314, -0.0264787, 0.0387855,
+ -0.000764675, 0.0217599, -0.037537, -0.0335206, 0.0431679,
+ -0.0211424, 0.010203, -0.062785, -0.00832363, -0.025181,
+ 0.0412031, 0.0118723, 0.0239643, 0.0394009}};
}
-}
+};
-TEST(LSTMOpTest, BlackBoxTestWithPeepholeWithProjectionNoClipping) {
+TEST_F(NoCifgPeepholeProjectionNoClippingLstmTest, LstmBlackBoxTest) {
const int n_batch = 2;
const int n_input = 5;
const int n_cell = 20;
@@ -489,588 +1308,90 @@ TEST(LSTMOpTest, BlackBoxTestWithPeepholeWithProjectionNoClipping) {
{0}, // projection_bias tensor
});
- lstm.SetInputToInputWeights(
- {0.021393683, 0.06124551, 0.046905167, -0.014657677, -0.03149463,
- 0.09171803, 0.14647801, 0.10797193, -0.0057968358, 0.0019193048,
- -0.2726754, 0.10154029, -0.018539885, 0.080349885, -0.10262385,
- -0.022599787, -0.09121155, -0.008675967, -0.045206103, -0.0821282,
- -0.008045952, 0.015478081, 0.055217247, 0.038719587, 0.044153627,
- -0.06453243, 0.05031825, -0.046935108, -0.008164439, 0.014574226,
- -0.1671009, -0.15519552, -0.16819797, -0.13971269, -0.11953059,
- 0.25005487, -0.22790983, 0.009855087, -0.028140958, -0.11200698,
- 0.11295408, -0.0035217577, 0.054485075, 0.05184695, 0.064711206,
- 0.10989193, 0.11674786, 0.03490607, 0.07727357, 0.11390585,
- -0.1863375, -0.1034451, -0.13945189, -0.049401227, -0.18767063,
- 0.042483903, 0.14233552, 0.13832581, 0.18350165, 0.14545603,
- -0.028545704, 0.024939531, 0.050929718, 0.0076203286, -0.0029723682,
- -0.042484224, -0.11827596, -0.09171104, -0.10808628, -0.16327988,
- -0.2273378, -0.0993647, -0.017155107, 0.0023917493, 0.049272764,
- 0.0038534778, 0.054764505, 0.089753784, 0.06947234, 0.08014476,
- -0.04544234, -0.0497073, -0.07135631, -0.048929106, -0.004042012,
- -0.009284026, 0.018042054, 0.0036860977, -0.07427302, -0.11434604,
- -0.018995456, 0.031487543, 0.012834908, 0.019977754, 0.044256654,
- -0.39292613, -0.18519334, -0.11651281, -0.06809892, 0.011373677});
-
- lstm.SetInputToForgetWeights(
- {-0.0018401089, -0.004852237, 0.03698424, 0.014181704, 0.028273236,
- -0.016726194, -0.05249759, -0.10204261, 0.00861066, -0.040979505,
- -0.009899187, 0.01923892, -0.028177269, -0.08535103, -0.14585495,
- 0.10662567, -0.01909731, -0.017883534, -0.0047269356, -0.045103323,
- 0.0030784295, 0.076784775, 0.07463696, 0.094531395, 0.0814421,
- -0.12257899, -0.033945758, -0.031303465, 0.045630626, 0.06843887,
- -0.13492945, -0.012480007, -0.0811829, -0.07224499, -0.09628791,
- 0.045100946, 0.0012300825, 0.013964662, 0.099372394, 0.02543059,
- 0.06958324, 0.034257296, 0.0482646, 0.06267997, 0.052625068,
- 0.12784666, 0.07077897, 0.025725935, 0.04165009, 0.07241905,
- 0.018668644, -0.037377294, -0.06277783, -0.08833636, -0.040120605,
- -0.011405586, -0.007808335, -0.010301386, -0.005102167, 0.027717464,
- 0.05483423, 0.11449111, 0.11289652, 0.10939839, 0.13396506,
- -0.08402166, -0.01901462, -0.044678304, -0.07720565, 0.014350063,
- -0.11757958, -0.0652038, -0.08185733, -0.076754324, -0.092614375,
- 0.10405491, 0.052960336, 0.035755895, 0.035839386, -0.012540553,
- 0.036881298, 0.02913376, 0.03420159, 0.05448447, -0.054523353,
- 0.02582715, 0.02327355, -0.011857179, -0.0011980024, -0.034641717,
- -0.026125094, -0.17582615, -0.15923657, -0.27486774, -0.0006143371,
- 0.0001771948, -8.470171e-05, 0.02651807, 0.045790765, 0.06956496});
-
- lstm.SetInputToCellWeights(
- {-0.04580283, -0.09549462, -0.032418985, -0.06454633,
- -0.043528453, 0.043018587, -0.049152344, -0.12418144,
- -0.078985475, -0.07596889, 0.019484362, -0.11434962,
- -0.0074034138, -0.06314844, -0.092981495, 0.0062155537,
- -0.025034338, -0.0028890965, 0.048929527, 0.06235075,
- 0.10665918, -0.032036792, -0.08505916, -0.10843358,
- -0.13002433, -0.036816437, -0.02130134, -0.016518239,
- 0.0047691227, -0.0025825808, 0.066017866, 0.029991534,
- -0.10652836, -0.1037554, -0.13056071, -0.03266643,
- -0.033702414, -0.006473424, -0.04611692, 0.014419339,
- -0.025174323, 0.0396852, 0.081777506, 0.06157468,
- 0.10210095, -0.009658194, 0.046511717, 0.03603906,
- 0.0069369148, 0.015960095, -0.06507666, 0.09551598,
- 0.053568836, 0.06408714, 0.12835667, -0.008714329,
- -0.20211966, -0.12093674, 0.029450472, 0.2849013,
- -0.029227901, 0.1164364, -0.08560263, 0.09941786,
- -0.036999565, -0.028842626, -0.0033637602, -0.017012902,
- -0.09720865, -0.11193351, -0.029155117, -0.017936034,
- -0.009768936, -0.04223324, -0.036159635, 0.06505112,
- -0.021742892, -0.023377212, -0.07221364, -0.06430552,
- 0.05453865, 0.091149814, 0.06387331, 0.007518393,
- 0.055960953, 0.069779344, 0.046411168, 0.10509911,
- 0.07463894, 0.0075130584, 0.012850982, 0.04555431,
- 0.056955688, 0.06555285, 0.050801456, -0.009862683,
- 0.00826772, -0.026555609, -0.0073611983, -0.0014897042});
-
- lstm.SetInputToOutputWeights(
- {-0.0998932, -0.07201956, -0.052803773, -0.15629593, -0.15001918,
- -0.07650751, 0.02359855, -0.075155355, -0.08037709, -0.15093534,
- 0.029517552, -0.04751393, 0.010350531, -0.02664851, -0.016839722,
- -0.023121163, 0.0077019283, 0.012851257, -0.05040649, -0.0129761,
- -0.021737747, -0.038305793, -0.06870586, -0.01481247, -0.001285394,
- 0.10124236, 0.083122835, 0.053313006, -0.062235646, -0.075637154,
- -0.027833903, 0.029774971, 0.1130802, 0.09218906, 0.09506135,
- -0.086665764, -0.037162706, -0.038880914, -0.035832845, -0.014481564,
- -0.09825003, -0.12048569, -0.097665586, -0.05287633, -0.0964047,
- -0.11366429, 0.035777505, 0.13568819, 0.052451383, 0.050649304,
- 0.05798951, -0.021852335, -0.099848844, 0.014740475, -0.078897946,
- 0.04974699, 0.014160473, 0.06973932, 0.04964942, 0.033364646,
- 0.08190124, 0.025535367, 0.050893165, 0.048514254, 0.06945813,
- -0.078907564, -0.06707616, -0.11844508, -0.09986688, -0.07509403,
- 0.06263226, 0.14925587, 0.20188436, 0.12098451, 0.14639415,
- 0.0015017595, -0.014267382, -0.03417257, 0.012711468, 0.0028300495,
- -0.024758482, -0.05098548, -0.0821182, 0.014225672, 0.021544158,
- 0.08949725, 0.07505268, -0.0020780868, 0.04908258, 0.06476295,
- -0.022907063, 0.027562456, 0.040185735, 0.019567577, -0.015598739,
- -0.049097303, -0.017121866, -0.083368234, -0.02332002, -0.0840956});
-
- lstm.SetInputGateBias(
- {0.02234832, 0.14757581, 0.18176508, 0.10380666, 0.053110216,
- -0.06928846, -0.13942584, -0.11816189, 0.19483899, 0.03652339,
- -0.10250295, 0.036714908, -0.18426876, 0.036065217, 0.21810818,
- 0.02383196, -0.043370757, 0.08690144, -0.04444982, 0.00030581196});
-
- lstm.SetForgetGateBias({0.035185695, -0.042891346, -0.03032477, 0.23027696,
- 0.11098921, 0.15378423, 0.09263801, 0.09790885,
- 0.09508917, 0.061199076, 0.07665568, -0.015443159,
- -0.03499149, 0.046190713, 0.08895977, 0.10899629,
- 0.40694186, 0.06030037, 0.012413437, -0.06108739});
-
- lstm.SetCellBias({-0.024379363, 0.0055531194, 0.23377132, 0.033463873,
- -0.1483596, -0.10639995, -0.091433935, 0.058573797,
- -0.06809782, -0.07889636, -0.043246906, -0.09829136,
- -0.4279842, 0.034901652, 0.18797937, 0.0075234566,
- 0.016178843, 0.1749513, 0.13975595, 0.92058027});
-
- lstm.SetOutputGateBias(
- {0.046159424, -0.0012809046, 0.03563469, 0.12648113, 0.027195795,
- 0.35373217, -0.018957434, 0.008907322, -0.0762701, 0.12018895,
- 0.04216877, 0.0022856654, 0.040952638, 0.3147856, 0.08225149,
- -0.057416286, -0.14995944, -0.008040261, 0.13208859, 0.029760877});
-
- lstm.SetRecurrentToInputWeights(
- {-0.001374326, -0.078856036, 0.10672688, 0.029162422,
- -0.11585556, 0.02557986, -0.13446963, -0.035785314,
- -0.01244275, 0.025961924, -0.02337298, -0.044228926,
- -0.055839065, -0.046598054, -0.010546039, -0.06900766,
- 0.027239809, 0.022582639, -0.013296484, -0.05459212,
- 0.08981, -0.045407712, 0.08682226, -0.06867011,
- -0.14390695, -0.02916037, 0.000996957, 0.091420636,
- 0.14283475, -0.07390571, -0.06402044, 0.062524505,
- -0.093129106, 0.04860203, -0.08364217, -0.08119002,
- 0.009352075, 0.22920375, 0.0016303885, 0.11583097,
- -0.13732095, 0.012405723, -0.07551853, 0.06343048,
- 0.12162708, -0.031923793, -0.014335606, 0.01790974,
- -0.10650317, -0.0724401, 0.08554849, -0.05727212,
- 0.06556731, -0.042729504, -0.043227166, 0.011683251,
- -0.013082158, -0.029302018, -0.010899579, -0.062036745,
- -0.022509435, -0.00964907, -0.01567329, 0.04260106,
- -0.07787477, -0.11576462, 0.017356863, 0.048673786,
- -0.017577527, -0.05527947, -0.082487635, -0.040137455,
- -0.10820036, -0.04666372, 0.022746278, -0.07851417,
- 0.01068115, 0.032956902, 0.022433773, 0.0026891115,
- 0.08944216, -0.0685835, 0.010513544, 0.07228705,
- 0.02032331, -0.059686817, -0.0005566496, -0.086984694,
- 0.040414046, -0.1380399, 0.094208956, -0.05722982,
- 0.012092817, -0.04989123, -0.086576, -0.003399834,
- -0.04696032, -0.045747425, 0.10091314, 0.048676282,
- -0.029037097, 0.031399418, -0.0040285117, 0.047237843,
- 0.09504992, 0.041799378, -0.049185462, -0.031518843,
- -0.10516937, 0.026374253, 0.10058866, -0.0033195973,
- -0.041975245, 0.0073591834, 0.0033782164, -0.004325073,
- -0.10167381, 0.042500053, -0.01447153, 0.06464186,
- -0.017142897, 0.03312627, 0.009205989, 0.024138335,
- -0.011337001, 0.035530265, -0.010912711, 0.0706555,
- -0.005894094, 0.051841937, -0.1401738, -0.02351249,
- 0.0365468, 0.07590991, 0.08838724, 0.021681072,
- -0.10086113, 0.019608743, -0.06195883, 0.077335775,
- 0.023646897, -0.095322326, 0.02233014, 0.09756986,
- -0.048691444, -0.009579111, 0.07595467, 0.11480546,
- -0.09801813, 0.019894179, 0.08502348, 0.004032281,
- 0.037211012, 0.068537936, -0.048005626, -0.091520436,
- -0.028379958, -0.01556313, 0.06554592, -0.045599163,
- -0.01672207, -0.020169014, -0.011877351, -0.20212261,
- 0.010889619, 0.0047078193, 0.038385306, 0.08540671,
- -0.017140968, -0.0035865551, 0.016678626, 0.005633034,
- 0.015963363, 0.00871737, 0.060130805, 0.028611384,
- 0.10109069, -0.015060172, -0.07894427, 0.06401885,
- 0.011584063, -0.024466386, 0.0047652307, -0.09041358,
- 0.030737216, -0.0046374933, 0.14215417, -0.11823516,
- 0.019899689, 0.006106124, -0.027092824, 0.0786356,
- 0.05052217, -0.058925, -0.011402121, -0.024987547,
- -0.0013661642, -0.06832946, -0.015667673, -0.1083353,
- -0.00096863037, -0.06988685, -0.053350925, -0.027275559,
- -0.033664223, -0.07978348, -0.025200296, -0.017207067,
- -0.058403496, -0.055697463, 0.005798788, 0.12965427,
- -0.062582195, 0.0013350133, -0.10482091, 0.0379771,
- 0.072521195, -0.0029455067, -0.13797039, -0.03628521,
- 0.013806405, -0.017858358, -0.01008298, -0.07700066,
- -0.017081132, 0.019358726, 0.0027079724, 0.004635139,
- 0.062634714, -0.02338735, -0.039547626, -0.02050681,
- 0.03385117, -0.083611414, 0.002862572, -0.09421313,
- 0.058618143, -0.08598433, 0.00972939, 0.023867095,
- -0.053934585, -0.023203006, 0.07452513, -0.048767887,
- -0.07314807, -0.056307215, -0.10433547, -0.06440842,
- 0.04328182, 0.04389765, -0.020006588, -0.09076438,
- -0.11652589, -0.021705797, 0.03345259, -0.010329105,
- -0.025767034, 0.013057034, -0.07316461, -0.10145612,
- 0.06358255, 0.18531723, 0.07759293, 0.12006465,
- 0.1305557, 0.058638252, -0.03393652, 0.09622831,
- -0.16253184, -2.4580743e-06, 0.079869635, -0.070196845,
- -0.005644518, 0.06857898, -0.12598175, -0.035084512,
- 0.03156317, -0.12794146, -0.031963028, 0.04692781,
- 0.030070418, 0.0071660685, -0.095516115, -0.004643372,
- 0.040170413, -0.062104587, -0.0037324072, 0.0554317,
- 0.08184801, -0.019164372, 0.06791302, 0.034257166,
- -0.10307039, 0.021943003, 0.046745934, 0.0790918,
- -0.0265588, -0.007824208, 0.042546265, -0.00977924,
- -0.0002440307, -0.017384544, -0.017990116, 0.12252321,
- -0.014512694, -0.08251313, 0.08861942, 0.13589665,
- 0.026351685, 0.012641483, 0.07466548, 0.044301085,
- -0.045414884, -0.051112458, 0.03444247, -0.08502782,
- -0.04106223, -0.028126027, 0.028473156, 0.10467447});
-
- lstm.SetRecurrentToForgetWeights(
- {-0.057784554, -0.026057621, -0.068447545, -0.022581743,
- 0.14811787, 0.10826372, 0.09471067, 0.03987225,
- -0.0039523416, 0.00030638507, 0.053185795, 0.10572994,
- 0.08414449, -0.022036452, -0.00066928595, -0.09203576,
- 0.032950465, -0.10985798, -0.023809856, 0.0021431844,
- -0.02196096, -0.00326074, 0.00058621005, -0.074678116,
- -0.06193199, 0.055729095, 0.03736828, 0.020123724,
- 0.061878487, -0.04729229, 0.034919553, -0.07585433,
- -0.04421272, -0.044019096, 0.085488975, 0.04058006,
- -0.06890133, -0.030951202, -0.024628663, -0.07672815,
- 0.034293607, 0.08556707, -0.05293577, -0.033561368,
- -0.04899627, 0.0241671, 0.015736353, -0.095442444,
- -0.029564252, 0.016493602, -0.035026584, 0.022337519,
- -0.026871363, 0.004780428, 0.0077918363, -0.03601621,
- 0.016435321, -0.03263031, -0.09543275, -0.047392778,
- 0.013454138, 0.028934088, 0.01685226, -0.086110644,
- -0.046250615, -0.01847454, 0.047608484, 0.07339695,
- 0.034546845, -0.04881143, 0.009128804, -0.08802852,
- 0.03761666, 0.008096139, -0.014454086, 0.014361001,
- -0.023502491, -0.0011840804, -0.07607001, 0.001856849,
- -0.06509276, -0.006021153, -0.08570962, -0.1451793,
- 0.060212336, 0.055259194, 0.06974018, 0.049454916,
- -0.027794661, -0.08077226, -0.016179763, 0.1169753,
- 0.17213494, -0.0056326236, -0.053934924, -0.0124349,
- -0.11520337, 0.05409887, 0.088759385, 0.0019655675,
- 0.0042065294, 0.03881498, 0.019844765, 0.041858196,
- -0.05695512, 0.047233116, 0.038937137, -0.06542224,
- 0.014429736, -0.09719407, 0.13908425, -0.05379757,
- 0.012321099, 0.082840554, -0.029899208, 0.044217527,
- 0.059855383, 0.07711018, -0.045319796, 0.0948846,
- -0.011724666, -0.0033288454, -0.033542685, -0.04764985,
- -0.13873616, 0.040668588, 0.034832682, -0.015319203,
- -0.018715994, 0.046002675, 0.0599172, -0.043107376,
- 0.0294216, -0.002314414, -0.022424703, 0.0030315618,
- 0.0014641669, 0.0029166266, -0.11878115, 0.013738511,
- 0.12375372, -0.0006038222, 0.029104086, 0.087442465,
- 0.052958444, 0.07558703, 0.04817258, 0.044462286,
- -0.015213451, -0.08783778, -0.0561384, -0.003008196,
- 0.047060397, -0.002058388, 0.03429439, -0.018839769,
- 0.024734668, 0.024614193, -0.042046934, 0.09597743,
- -0.0043254104, 0.04320769, 0.0064070094, -0.0019131786,
- -0.02558259, -0.022822596, -0.023273505, -0.02464396,
- -0.10991725, -0.006240552, 0.0074488563, 0.024044557,
- 0.04383914, -0.046476185, 0.028658995, 0.060410924,
- 0.050786525, 0.009452605, -0.0073054377, -0.024810238,
- 0.0052906186, 0.0066939713, -0.0020913032, 0.014515517,
- 0.015898481, 0.021362653, -0.030262267, 0.016587038,
- -0.011442813, 0.041154444, -0.007631438, -0.03423484,
- -0.010977775, 0.036152758, 0.0066366293, 0.11915515,
- 0.02318443, -0.041350313, 0.021485701, -0.10906167,
- -0.028218046, -0.00954771, 0.020531068, -0.11995105,
- -0.03672871, 0.024019798, 0.014255957, -0.05221243,
- -0.00661567, -0.04630967, 0.033188973, 0.10107534,
- -0.014027541, 0.030796422, -0.10270911, -0.035999842,
- 0.15443139, 0.07684145, 0.036571592, -0.035900835,
- -0.0034699554, 0.06209149, 0.015920248, -0.031122351,
- -0.03858649, 0.01849943, 0.13872518, 0.01503974,
- 0.069941424, -0.06948533, -0.0088794185, 0.061282158,
- -0.047401894, 0.03100163, -0.041533746, -0.10430945,
- 0.044574402, -0.01425562, -0.024290353, 0.034563623,
- 0.05866852, 0.023947537, -0.09445152, 0.035450947,
- 0.02247216, -0.0042998926, 0.061146557, -0.10250651,
- 0.020881841, -0.06747029, 0.10062043, -0.0023941975,
- 0.03532124, -0.016341697, 0.09685456, -0.016764693,
- 0.051808182, 0.05875331, -0.04536488, 0.001626336,
- -0.028892258, -0.01048663, -0.009793449, -0.017093895,
- 0.010987891, 0.02357273, -0.00010856845, 0.0099760275,
- -0.001845119, -0.03551521, 0.0018358806, 0.05763657,
- -0.01769146, 0.040995963, 0.02235177, -0.060430344,
- 0.11475477, -0.023854522, 0.10071741, 0.0686208,
- -0.014250481, 0.034261297, 0.047418304, 0.08562733,
- -0.030519066, 0.0060542435, 0.014653856, -0.038836084,
- 0.04096551, 0.032249358, -0.08355519, -0.026823482,
- 0.056386515, -0.010401743, -0.028396193, 0.08507674,
- 0.014410365, 0.020995233, 0.17040324, 0.11511526,
- 0.02459721, 0.0066619175, 0.025853224, -0.023133837,
- -0.081302024, 0.017264642, -0.009585969, 0.09491168,
- -0.051313367, 0.054532815, -0.014298593, 0.10657464,
- 0.007076659, 0.10964551, 0.0409152, 0.008275321,
- -0.07283536, 0.07937492, 0.04192024, -0.1075027});
-
- lstm.SetRecurrentToCellWeights(
- {-0.037322544, 0.018592842, 0.0056175636, -0.06253426,
- 0.055647098, -0.05713207, -0.05626563, 0.005559383,
- 0.03375411, -0.025757805, -0.088049285, 0.06017052,
- -0.06570978, 0.007384076, 0.035123326, -0.07920549,
- 0.053676967, 0.044480428, -0.07663568, 0.0071805613,
- 0.08089997, 0.05143358, 0.038261272, 0.03339287,
- -0.027673481, 0.044746667, 0.028349208, 0.020090483,
- -0.019443132, -0.030755889, -0.0040000007, 0.04465846,
- -0.021585021, 0.0031670958, 0.0053199246, -0.056117613,
- -0.10893326, 0.076739706, -0.08509834, -0.027997585,
- 0.037871376, 0.01449768, -0.09002357, -0.06111149,
- -0.046195522, 0.0422062, -0.005683705, -0.1253618,
- -0.012925729, -0.04890792, 0.06985068, 0.037654128,
- 0.03398274, -0.004781977, 0.007032333, -0.031787455,
- 0.010868644, -0.031489216, 0.09525667, 0.013939797,
- 0.0058680447, 0.0167067, 0.02668468, -0.04797466,
- -0.048885044, -0.12722108, 0.035304096, 0.06554885,
- 0.00972396, -0.039238118, -0.05159735, -0.11329045,
- 0.1613692, -0.03750952, 0.06529313, -0.071974665,
- -0.11769596, 0.015524369, -0.0013754242, -0.12446318,
- 0.02786344, -0.014179351, 0.005264273, 0.14376344,
- 0.015983658, 0.03406988, -0.06939408, 0.040699873,
- 0.02111075, 0.09669095, 0.041345075, -0.08316494,
- -0.07684199, -0.045768797, 0.032298047, -0.041805092,
- 0.0119405, 0.0061010392, 0.12652606, 0.0064572375,
- -0.024950314, 0.11574242, 0.04508852, -0.04335324,
- 0.06760663, -0.027437469, 0.07216407, 0.06977076,
- -0.05438599, 0.034033038, -0.028602652, 0.05346137,
- 0.043184172, -0.037189785, 0.10420091, 0.00882477,
- -0.054019816, -0.074273005, -0.030617684, -0.0028467078,
- 0.024302477, -0.0038869337, 0.005332455, 0.0013399826,
- 0.04361412, -0.007001822, 0.09631092, -0.06702025,
- -0.042049985, -0.035070654, -0.04103342, -0.10273396,
- 0.0544271, 0.037184782, -0.13150354, -0.0058036847,
- -0.008264958, 0.042035464, 0.05891794, 0.029673764,
- 0.0063542654, 0.044788733, 0.054816857, 0.062257513,
- -0.00093483756, 0.048938446, -0.004952862, -0.007730018,
- -0.04043371, -0.017094059, 0.07229206, -0.023670016,
- -0.052195564, -0.025616996, -0.01520939, 0.045104615,
- -0.007376126, 0.003533447, 0.006570588, 0.056037236,
- 0.12436656, 0.051817212, 0.028532185, -0.08686856,
- 0.11868599, 0.07663395, -0.07323171, 0.03463402,
- -0.050708205, -0.04458982, -0.11590894, 0.021273347,
- 0.1251325, -0.15313013, -0.12224372, 0.17228661,
- 0.023029093, 0.086124025, 0.006445803, -0.03496501,
- 0.028332196, 0.04449512, -0.042436164, -0.026587414,
- -0.006041347, -0.09292539, -0.05678812, 0.03897832,
- 0.09465633, 0.008115513, -0.02171956, 0.08304309,
- 0.071401566, 0.019622514, 0.032163795, -0.004167056,
- 0.02295182, 0.030739572, 0.056506045, 0.004612461,
- 0.06524936, 0.059999723, 0.046395954, -0.0045512207,
- -0.1335546, -0.030136576, 0.11584653, -0.014678886,
- 0.0020118146, -0.09688814, -0.0790206, 0.039770417,
- -0.0329582, 0.07922767, 0.029322514, 0.026405897,
- 0.04207835, -0.07073373, 0.063781224, 0.0859677,
- -0.10925287, -0.07011058, 0.048005477, 0.03438226,
- -0.09606514, -0.006669445, -0.043381985, 0.04240257,
- -0.06955775, -0.06769346, 0.043903265, -0.026784198,
- -0.017840602, 0.024307009, -0.040079936, -0.019946516,
- 0.045318738, -0.12233574, 0.026170589, 0.0074471775,
- 0.15978073, 0.10185836, 0.10298046, -0.015476589,
- -0.039390966, -0.072174534, 0.0739445, -0.1211869,
- -0.0347889, -0.07943156, 0.014809798, -0.12412325,
- -0.0030663363, 0.039695457, 0.0647603, -0.08291318,
- -0.018529687, -0.004423833, 0.0037507233, 0.084633216,
- -0.01514876, -0.056505352, -0.012800942, -0.06994386,
- 0.012962922, -0.031234352, 0.07029052, 0.016418684,
- 0.03618972, 0.055686004, -0.08663945, -0.017404709,
- -0.054761406, 0.029065743, 0.052404847, 0.020238016,
- 0.0048197987, -0.0214882, 0.07078733, 0.013016777,
- 0.06262858, 0.009184685, 0.020785125, -0.043904778,
- -0.0270329, -0.03299152, -0.060088247, -0.015162964,
- -0.001828936, 0.12642565, -0.056757294, 0.013586685,
- 0.09232601, -0.035886683, 0.06000002, 0.05229691,
- -0.052580316, -0.082029596, -0.010794592, 0.012947712,
- -0.036429964, -0.085508935, -0.13127148, -0.017744139,
- 0.031502828, 0.036232427, -0.031581745, 0.023051167,
- -0.05325106, -0.03421577, 0.028793324, -0.034633752,
- -0.009881397, -0.043551125, -0.018609839, 0.0019097115,
- -0.008799762, 0.056595087, 0.0022273948, 0.055752404});
-
- lstm.SetRecurrentToOutputWeights({
- 0.025825322, -0.05813119, 0.09495884, -0.045984812, -0.01255415,
- -0.0026479573, -0.08196161, -0.054914974, -0.0046604523, -0.029587349,
- -0.044576716, -0.07480124, -0.082868785, 0.023254942, 0.027502948,
- -0.0039728214, -0.08683098, -0.08116779, -0.014675607, -0.037924774,
- -0.023314456, -0.007401714, -0.09255757, 0.029460307, -0.08829125,
- -0.005139627, -0.08989442, -0.0555066, 0.13596267, -0.025062224,
- -0.048351806, -0.03850004, 0.07266485, -0.022414139, 0.05940088,
- 0.075114764, 0.09597592, -0.010211725, -0.0049794707, -0.011523867,
- -0.025980417, 0.072999895, 0.11091378, -0.081685916, 0.014416728,
- 0.043229222, 0.034178585, -0.07530371, 0.035837382, -0.085607,
- -0.007721233, -0.03287832, -0.043848954, -0.06404588, -0.06632928,
- -0.073643476, 0.008214239, -0.045984086, 0.039764922, 0.03474462,
- 0.060612556, -0.080590084, 0.049127717, 0.04151091, -0.030063879,
- 0.008801774, -0.023021035, -0.019558564, 0.05158114, -0.010947698,
- -0.011825728, 0.0075720972, 0.0699727, -0.0039981045, 0.069350146,
- 0.08799282, 0.016156472, 0.035502106, 0.11695009, 0.006217345,
- 0.13392477, -0.037875112, 0.025745004, 0.08940699, -0.00924166,
- 0.0046702605, -0.036598757, -0.08811812, 0.10522024, -0.032441203,
- 0.008176899, -0.04454919, 0.07058152, 0.0067963637, 0.039206743,
- 0.03259838, 0.03725492, -0.09515802, 0.013326398, -0.052055415,
- -0.025676316, 0.03198509, -0.015951829, -0.058556724, 0.036879618,
- 0.043357447, 0.028362012, -0.05908629, 0.0059240665, -0.04995891,
- -0.019187413, 0.0276265, -0.01628143, 0.0025863599, 0.08800015,
- 0.035250366, -0.022165963, -0.07328642, -0.009415526, -0.07455109,
- 0.11690406, 0.0363299, 0.07411125, 0.042103454, -0.009660886,
- 0.019076364, 0.018299393, -0.046004917, 0.08891175, 0.0431396,
- -0.026327137, -0.051502608, 0.08979574, -0.051670972, 0.04940282,
- -0.07491107, -0.021240504, 0.022596184, -0.034280192, 0.060163025,
- -0.058211457, -0.051837247, -0.01349775, -0.04639988, -0.035936575,
- -0.011681591, 0.064818054, 0.0073146066, -0.021745546, -0.043124277,
- -0.06471268, -0.07053354, -0.029321948, -0.05330136, 0.016933719,
- -0.053782392, 0.13747959, -0.1361751, -0.11569455, 0.0033329215,
- 0.05693899, -0.053219706, 0.063698, 0.07977434, -0.07924483,
- 0.06936997, 0.0034815092, -0.007305279, -0.037325785, -0.07251102,
- -0.033633437, -0.08677009, 0.091591336, -0.14165086, 0.021752775,
- 0.019683983, 0.0011612234, -0.058154266, 0.049996935, 0.0288841,
- -0.0024567875, -0.14345716, 0.010955264, -0.10234828, 0.1183656,
- -0.0010731248, -0.023590032, -0.072285876, -0.0724771, -0.026382286,
- -0.0014920527, 0.042667855, 0.0018776858, 0.02986552, 0.009814309,
- 0.0733756, 0.12289186, 0.018043943, -0.0458958, 0.049412545,
- 0.033632483, 0.05495232, 0.036686596, -0.013781798, -0.010036754,
- 0.02576849, -0.08307328, 0.010112348, 0.042521734, -0.05869831,
- -0.071689695, 0.03876447, -0.13275425, -0.0352966, -0.023077697,
- 0.10285965, 0.084736146, 0.15568255, -0.00040734606, 0.027835453,
- -0.10292561, -0.032401145, 0.10053256, -0.026142767, -0.08271222,
- -0.0030240538, -0.016368777, 0.1070414, 0.042672627, 0.013456989,
- -0.0437609, -0.022309763, 0.11576483, 0.04108048, 0.061026827,
- -0.0190714, -0.0869359, 0.037901703, 0.0610107, 0.07202949,
- 0.01675338, 0.086139716, -0.08795751, -0.014898893, -0.023771819,
- -0.01965048, 0.007955471, -0.043740474, 0.03346837, -0.10549954,
- 0.090567775, 0.042013682, -0.03176985, 0.12569028, -0.02421228,
- -0.029526481, 0.023851605, 0.031539805, 0.05292009, -0.02344001,
- -0.07811758, -0.08834428, 0.10094801, 0.16594367, -0.06861939,
- -0.021256343, -0.041093912, -0.06669611, 0.035498552, 0.021757556,
- -0.09302526, -0.015403468, -0.06614931, -0.051798206, -0.013874718,
- 0.03630673, 0.010412845, -0.08077351, 0.046185967, 0.0035662893,
- 0.03541868, -0.094149634, -0.034814864, 0.003128424, -0.020674974,
- -0.03944324, -0.008110165, -0.11113267, 0.08484226, 0.043586485,
- 0.040582247, 0.0968012, -0.065249965, -0.028036479, 0.0050708856,
- 0.0017462453, 0.0326779, 0.041296225, 0.09164146, -0.047743853,
- -0.015952192, -0.034451712, 0.084197424, -0.05347844, -0.11768019,
- 0.085926116, -0.08251791, -0.045081906, 0.0948852, 0.068401024,
- 0.024856757, 0.06978981, -0.057309967, -0.012775832, -0.0032452994,
- 0.01977615, -0.041040014, -0.024264973, 0.063464895, 0.05431621,
- });
-
- lstm.SetCellToInputWeights(
- {0.040369894, 0.030746894, 0.24704495, 0.018586371, -0.037586458,
- -0.15312155, -0.11812848, -0.11465643, 0.20259799, 0.11418174,
- -0.10116027, -0.011334949, 0.12411352, -0.076769054, -0.052169047,
- 0.21198851, -0.38871562, -0.09061183, -0.09683246, -0.21929175});
-
- lstm.SetCellToForgetWeights(
- {-0.01998659, -0.15568835, -0.24248174, -0.012770197, 0.041331276,
- -0.072311886, -0.052123554, -0.0066330447, -0.043891653, 0.036225766,
- -0.047248036, 0.021479502, 0.033189066, 0.11952997, -0.020432774,
- 0.64658105, -0.06650122, -0.03467612, 0.095340036, 0.23647355});
-
- lstm.SetCellToOutputWeights(
- {0.08286371, -0.08261836, -0.51210177, 0.002913762, 0.17764764,
- -0.5495371, -0.08460716, -0.24552552, 0.030037103, 0.04123544,
- -0.11940523, 0.007358328, 0.1890978, 0.4833202, -0.34441817,
- 0.36312827, -0.26375428, 0.1457655, -0.19724406, 0.15548733});
-
- lstm.SetProjectionWeights(
- {-0.009802181, 0.09401916, 0.0717386, -0.13895074, 0.09641832,
- 0.060420845, 0.08539281, 0.054285463, 0.061395317, 0.034448683,
- -0.042991187, 0.019801661, -0.16840284, -0.015726732, -0.23041931,
- -0.024478018, -0.10959692, -0.013875541, 0.18600968, -0.061274476,
- 0.0138165, -0.08160894, -0.07661644, 0.032372914, 0.16169067,
- 0.22465782, -0.03993472, -0.004017731, 0.08633481, -0.28869787,
- 0.08682067, 0.17240396, 0.014975425, 0.056431185, 0.031037588,
- 0.16702051, 0.0077946745, 0.15140012, 0.29405436, 0.120285,
- -0.188994, -0.027265169, 0.043389652, -0.022061434, 0.014777949,
- -0.20203483, 0.094781205, 0.19100232, 0.13987629, -0.036132768,
- -0.06426278, -0.05108664, 0.13221376, 0.009441198, -0.16715929,
- 0.15859416, -0.040437475, 0.050779544, -0.022187516, 0.012166504,
- 0.027685808, -0.07675938, -0.0055694645, -0.09444123, 0.0046453946,
- 0.050794356, 0.10770313, -0.20790008, -0.07149004, -0.11425117,
- 0.008225835, -0.035802525, 0.14374903, 0.15262283, 0.048710253,
- 0.1847461, -0.007487823, 0.11000021, -0.09542012, 0.22619456,
- -0.029149994, 0.08527916, 0.009043713, 0.0042746216, 0.016261552,
- 0.022461696, 0.12689082, -0.043589946, -0.12035478, -0.08361797,
- -0.050666027, -0.1248618, -0.1275799, -0.071875185, 0.07377272,
- 0.09944291, -0.18897448, -0.1593054, -0.06526116, -0.040107165,
- -0.004618631, -0.067624845, -0.007576253, 0.10727444, 0.041546922,
- -0.20424393, 0.06907816, 0.050412357, 0.00724631, 0.039827548,
- 0.12449835, 0.10747581, 0.13708383, 0.09134148, -0.12617786,
- -0.06428341, 0.09956831, 0.1208086, -0.14676677, -0.0727722,
- 0.1126304, 0.010139365, 0.015571211, -0.038128063, 0.022913318,
- -0.042050496, 0.16842307, -0.060597885, 0.10531834, -0.06411776,
- -0.07451711, -0.03410368, -0.13393489, 0.06534304, 0.003620307,
- 0.04490757, 0.05970546, 0.05197996, 0.02839995, 0.10434969,
- -0.013699693, -0.028353551, -0.07260381, 0.047201227, -0.024575593,
- -0.036445823, 0.07155557, 0.009672501, -0.02328883, 0.009533515,
- -0.03606021, -0.07421458, -0.028082801, -0.2678904, -0.13221288,
- 0.18419984, -0.13012612, -0.014588381, -0.035059117, -0.04824723,
- 0.07830115, -0.056184657, 0.03277091, 0.025466874, 0.14494097,
- -0.12522776, -0.098633975, -0.10766018, -0.08317623, 0.08594209,
- 0.07749552, 0.039474737, 0.1776665, -0.07409566, -0.0477268,
- 0.29323658, 0.10801441, 0.1154011, 0.013952499, 0.10739139,
- 0.10708251, -0.051456142, 0.0074137426, -0.10430189, 0.10034707,
- 0.045594677, 0.0635285, -0.0715442, -0.089667566, -0.10811871,
- 0.00026344223, 0.08298446, -0.009525053, 0.006585689, -0.24567553,
- -0.09450807, 0.09648481, 0.026996298, -0.06419476, -0.04752702,
- -0.11063944, -0.23441927, -0.17608605, -0.052156363, 0.067035615,
- 0.19271925, -0.0032889997, -0.043264326, 0.09663576, -0.057112187,
- -0.10100678, 0.0628376, 0.04447668, 0.017961001, -0.10094388,
- -0.10190601, 0.18335468, 0.10494553, -0.052095775, -0.0026118709,
- 0.10539724, -0.04383912, -0.042349473, 0.08438151, -0.1947263,
- 0.02251204, 0.11216432, -0.10307853, 0.17351969, -0.039091777,
- 0.08066188, -0.00561982, 0.12633002, 0.11335965, -0.0088127935,
- -0.019777594, 0.06864014, -0.059751723, 0.016233567, -0.06894641,
- -0.28651384, -0.004228674, 0.019708522, -0.16305895, -0.07468996,
- -0.0855457, 0.099339016, -0.07580735, -0.13775392, 0.08434318,
- 0.08330512, -0.12131499, 0.031935584, 0.09180414, -0.08876437,
- -0.08049874, 0.008753825, 0.03498998, 0.030215185, 0.03907079,
- 0.089751154, 0.029194152, -0.03337423, -0.019092513, 0.04331237,
- 0.04299654, -0.036394123, -0.12915532, 0.09793732, 0.07512415,
- -0.11319543, -0.032502122, 0.15661901, 0.07671967, -0.005491124,
- -0.19379048, -0.218606, 0.21448623, 0.017840758, 0.1416943,
- -0.07051762, 0.19488361, 0.02664691, -0.18104725, -0.09334311,
- 0.15026465, -0.15493552, -0.057762887, -0.11604192, -0.262013,
- -0.01391798, 0.012185008, 0.11156489, -0.07483202, 0.06693364,
- -0.26151478, 0.046425626, 0.036540434, -0.16435726, 0.17338543,
- -0.21401681, -0.11385144, -0.08283257, -0.069031075, 0.030635102,
- 0.010969227, 0.11109743, 0.010919218, 0.027526086, 0.13519906,
- 0.01891392, -0.046839405, -0.040167913, 0.017953383, -0.09700955,
- 0.0061885654, -0.07000971, 0.026893595, -0.038844477, 0.14543656});
-
- static float lstm_input[][20] = {
- {// Batch0: 4 (input_sequence_size) * 5 (n_input)
- 0.787926, 0.151646, 0.071352, 0.118426, 0.458058, 0.596268, 0.998386,
- 0.568695, 0.864524, 0.571277, 0.073204, 0.296072, 0.743333, 0.069199,
- 0.045348, 0.867394, 0.291279, 0.013714, 0.482521, 0.626339},
-
- {// Batch1: 4 (input_sequence_size) * 5 (n_input)
- 0.295743, 0.544053, 0.690064, 0.858138, 0.497181, 0.642421, 0.524260,
- 0.134799, 0.003639, 0.162482, 0.640394, 0.930399, 0.050782, 0.432485,
- 0.988078, 0.082922, 0.563329, 0.865614, 0.333232, 0.259916}};
-
- static float lstm_golden_output[][64] = {
- {// Batch0: 4 (input_sequence_size) * 16 (n_output)
- -0.00396806, 0.029352, -0.00279226, 0.0159977, -0.00835576,
- -0.0211779, 0.0283512, -0.0114597, 0.00907307, -0.0244004,
- -0.0152191, -0.0259063, 0.00914318, 0.00415118, 0.017147,
- 0.0134203, -0.0166936, 0.0381209, 0.000889694, 0.0143363,
- -0.0328911, -0.0234288, 0.0333051, -0.012229, 0.0110322,
- -0.0457725, -0.000832209, -0.0202817, 0.0327257, 0.0121308,
- 0.0155969, 0.0312091, -0.0213783, 0.0350169, 0.000324794,
- 0.0276012, -0.0263374, -0.0371449, 0.0446149, -0.0205474,
- 0.0103729, -0.0576349, -0.0150052, -0.0292043, 0.0376827,
- 0.0136115, 0.0243435, 0.0354492, -0.0189322, 0.0464512,
- -0.00251373, 0.0225745, -0.0308346, -0.0317124, 0.0460407,
- -0.0189395, 0.0149363, -0.0530162, -0.0150767, -0.0340193,
- 0.0286833, 0.00824207, 0.0264887, 0.0305169},
- {// Batch1: 4 (input_sequence_size) * 16 (n_output)
- -0.013869, 0.0287268, -0.00334693, 0.00733398, -0.0287926,
- -0.0186926, 0.0193662, -0.0115437, 0.00422612, -0.0345232,
- 0.00223253, -0.00957321, 0.0210624, 0.013331, 0.0150954,
- 0.02168, -0.0141913, 0.0322082, 0.00227024, 0.0260507,
- -0.0188721, -0.0296489, 0.0399134, -0.0160509, 0.0116039,
- -0.0447318, -0.0150515, -0.0277406, 0.0316596, 0.0118233,
- 0.0214762, 0.0293641, -0.0204549, 0.0450315, -0.00117378,
- 0.0167673, -0.0375007, -0.0238314, 0.038784, -0.0174034,
- 0.0131743, -0.0506589, -0.0048447, -0.0240239, 0.0325789,
- 0.00790065, 0.0220157, 0.0333314, -0.0264787, 0.0387855,
- -0.000764675, 0.0217599, -0.037537, -0.0335206, 0.0431679,
- -0.0211424, 0.010203, -0.062785, -0.00832363, -0.025181,
- 0.0412031, 0.0118723, 0.0239643, 0.0394009}};
-
- // Resetting cell_state and output_state
- lstm.ResetCellState();
- lstm.ResetOutputState();
-
- const int input_sequence_size =
- sizeof(lstm_input[0]) / sizeof(float) / (lstm.num_inputs());
- for (int i = 0; i < input_sequence_size; i++) {
- float* batch0_start = lstm_input[0] + i * lstm.num_inputs();
- float* batch0_end = batch0_start + lstm.num_inputs();
-
- lstm.SetInput(0, batch0_start, batch0_end);
-
- float* batch1_start = lstm_input[1] + i * lstm.num_inputs();
- float* batch1_end = batch1_start + lstm.num_inputs();
- lstm.SetInput(lstm.num_inputs(), batch1_start, batch1_end);
-
- lstm.Invoke();
-
- float* golden_start_batch0 = lstm_golden_output[0] + i * lstm.num_outputs();
- float* golden_end_batch0 = golden_start_batch0 + lstm.num_outputs();
- float* golden_start_batch1 = lstm_golden_output[1] + i * lstm.num_outputs();
- float* golden_end_batch1 = golden_start_batch1 + lstm.num_outputs();
- std::vector<float> expected;
- expected.insert(expected.end(), golden_start_batch0, golden_end_batch0);
- expected.insert(expected.end(), golden_start_batch1, golden_end_batch1);
- EXPECT_THAT(lstm.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
- }
+ lstm.SetInputToInputWeights(input_to_input_weights_);
+ lstm.SetInputToCellWeights(input_to_cell_weights_);
+ lstm.SetInputToForgetWeights(input_to_forget_weights_);
+ lstm.SetInputToOutputWeights(input_to_output_weights_);
+
+ lstm.SetInputGateBias(input_gate_bias_);
+ lstm.SetCellBias(cell_gate_bias_);
+ lstm.SetForgetGateBias(forget_gate_bias_);
+ lstm.SetOutputGateBias(output_gate_bias_);
+
+ lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_);
+ lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
+ lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
+ lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
+
+ lstm.SetCellToInputWeights(cell_to_input_weights_);
+ lstm.SetCellToForgetWeights(cell_to_forget_weights_);
+ lstm.SetCellToOutputWeights(cell_to_output_weights_);
+
+ lstm.SetProjectionWeights(projection_weights_);
+
+ VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
+}
+
+TEST_F(NoCifgPeepholeProjectionNoClippingLstmTest, HybridLstmBlackBoxTest) {
+ const int n_batch = 2;
+ const int n_input = 5;
+ const int n_cell = 20;
+ const int n_output = 16;
+
+ HybridLSTMOpModel lstm(
+ n_batch, n_input, n_cell, n_output,
+ /*use_cifg=*/false, /*use_peephole=*/true,
+ /*use_projection_weights=*/true,
+ /*use_projection_bias=*/false,
+ /*cell_clip=*/0.0, /*proj_clip=*/0.0,
+ {
+ {n_batch, n_input}, // input tensor
+
+ {n_cell, n_input}, // input_to_input_weight tensor
+ {n_cell, n_input}, // input_to_forget_weight tensor
+ {n_cell, n_input}, // input_to_cell_weight tensor
+ {n_cell, n_input}, // input_to_output_weight tensor
+
+ {n_cell, n_output}, // recurrent_to_input_weight tensor
+ {n_cell, n_output}, // recurrent_to_forget_weight tensor
+ {n_cell, n_output}, // recurrent_to_cell_weight tensor
+ {n_cell, n_output}, // recurrent_to_output_weight tensor
+
+ {n_cell}, // cell_to_input_weight tensor
+ {n_cell}, // cell_to_forget_weight tensor
+ {n_cell}, // cell_to_output_weight tensor
+
+ {n_cell}, // input_gate_bias tensor
+ {n_cell}, // forget_gate_bias tensor
+ {n_cell}, // cell_bias tensor
+ {n_cell}, // output_gate_bias tensor
+
+ {n_output, n_cell}, // projection_weight tensor
+ {0}, // projection_bias tensor
+ });
+
+ lstm.SetInputToInputWeights(input_to_input_weights_);
+ lstm.SetInputToCellWeights(input_to_cell_weights_);
+ lstm.SetInputToForgetWeights(input_to_forget_weights_);
+ lstm.SetInputToOutputWeights(input_to_output_weights_);
+
+ lstm.SetInputGateBias(input_gate_bias_);
+ lstm.SetCellBias(cell_gate_bias_);
+ lstm.SetForgetGateBias(forget_gate_bias_);
+ lstm.SetOutputGateBias(output_gate_bias_);
+
+ lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_);
+ lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
+ lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
+ lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
+
+ lstm.SetCellToInputWeights(cell_to_input_weights_);
+ lstm.SetCellToForgetWeights(cell_to_forget_weights_);
+ lstm.SetCellToOutputWeights(cell_to_output_weights_);
+
+ lstm.SetProjectionWeights(projection_weights_);
+
+ VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.00467);
}
} // namespace
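
For reference, the golden-checking loop deleted above is what the VerifyGoldens calls now factor out for every LSTM test variant. A minimal sketch of such a helper, reconstructed from that deleted loop (the real signature and member names in lstm_test.cc may differ, and the template parameter and default tolerance are assumptions; gmock helpers and <vector> are assumed to come from the surrounding test file):

    template <typename LstmModelT>
    void VerifyGoldensSketch(float input[][20], float golden[][64],
                             LstmModelT* lstm, float tolerance = 1e-5) {
      const int input_sequence_size =
          sizeof(input[0]) / sizeof(float) / lstm->num_inputs();
      for (int i = 0; i < input_sequence_size; ++i) {
        // Feed the i-th time step of both batches.
        float* batch0_start = input[0] + i * lstm->num_inputs();
        lstm->SetInput(0, batch0_start, batch0_start + lstm->num_inputs());
        float* batch1_start = input[1] + i * lstm->num_inputs();
        lstm->SetInput(lstm->num_inputs(), batch1_start,
                       batch1_start + lstm->num_inputs());
        lstm->Invoke();
        // Concatenate the expected outputs for both batches and compare with a
        // per-element tolerance; the hybrid test above passes 0.00467 to absorb
        // weight-quantization error, while the float test uses the default.
        std::vector<float> expected;
        float* g0 = golden[0] + i * lstm->num_outputs();
        float* g1 = golden[1] + i * lstm->num_outputs();
        expected.insert(expected.end(), g0, g0 + lstm->num_outputs());
        expected.insert(expected.end(), g1, g1 + lstm->num_outputs());
        EXPECT_THAT(lstm->GetOutput(),
                    ElementsAreArray(ArrayFloatNear(expected, tolerance)));
      }
    }
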
diff --git a/tensorflow/contrib/lite/kernels/maximum_minimum.cc b/tensorflow/contrib/lite/kernels/maximum_minimum.cc
index 8d676218bd..7cb01465ee 100644
--- a/tensorflow/contrib/lite/kernels/maximum_minimum.cc
+++ b/tensorflow/contrib/lite/kernels/maximum_minimum.cc
@@ -14,8 +14,8 @@ limitations under the License.
==============================================================================*/
#include <string.h>
#include <vector>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
@@ -86,13 +86,14 @@ struct MinimumOp {
template <typename data_type, typename op_type>
void TFLiteOperation(TfLiteContext* context, TfLiteNode* node,
const OpContext& op_context) {
- reference_ops::TensorFlowMaximumMinimum<data_type>(
+ reference_ops::MaximumMinimumBroadcast4DSlow(
+ GetTensorShape(op_context.input1),
GetTensorData<data_type>(op_context.input1),
- GetTensorDims(op_context.input1),
+ GetTensorShape(op_context.input2),
GetTensorData<data_type>(op_context.input2),
- GetTensorDims(op_context.input2),
+ GetTensorShape(op_context.output),
GetTensorData<data_type>(op_context.output),
- GetTensorDims(op_context.output), op_type::template op<data_type>);
+ op_type::template op<data_type>);
}
template <KernelType kernel_type, typename OpType>
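
For context on the op_type::template op<data_type> argument in the rewritten call: the kernel dispatches through small functor structs such as MinimumOp (visible in the hunk header above). A rough sketch of that pattern, with member names assumed rather than copied from this hunk:

    struct MaximumOp {
      template <typename data_type>
      static data_type op(data_type a, data_type b) {
        return a > b ? a : b;  // element-wise maximum
      }
    };

    struct MinimumOp {
      template <typename data_type>
      static data_type op(data_type a, data_type b) {
        return a < b ? a : b;  // element-wise minimum
      }
    };

    // TFLiteOperation<float, MaximumOp> then hands MaximumOp::op<float> to
    // MaximumMinimumBroadcast4DSlow, which walks the broadcast 4-D shapes and
    // applies the scalar op per element.
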
diff --git a/tensorflow/contrib/lite/kernels/maximum_minimum_test.cc b/tensorflow/contrib/lite/kernels/maximum_minimum_test.cc
index 0752aa1804..fd4d5367c5 100644
--- a/tensorflow/contrib/lite/kernels/maximum_minimum_test.cc
+++ b/tensorflow/contrib/lite/kernels/maximum_minimum_test.cc
@@ -126,10 +126,10 @@ TEST(MaximumOpTest, FloatWithBroadcastTest) {
TEST(MaximumOpTest, Int32WithBroadcastTest) {
std::initializer_list<int32_t> data1 = {1, 0, -1, -2, 3, 11};
std::initializer_list<int32_t> data2 = {2};
- TestModel<int32>(BuiltinOperator_MAXIMUM, {TensorType_INT32, {3, 1, 2}},
+ TestModel<int32_t>(BuiltinOperator_MAXIMUM, {TensorType_INT32, {3, 1, 2}},
{TensorType_INT32, {1}}, {TensorType_INT32, {3, 1, 2}},
data1, data2, {2, 2, 2, 2, 3, 11});
- TestModel<int32>(BuiltinOperator_MINIMUM, {TensorType_INT32, {3, 1, 2}},
+ TestModel<int32_t>(BuiltinOperator_MINIMUM, {TensorType_INT32, {3, 1, 2}},
{TensorType_INT32, {1}}, {TensorType_INT32, {3, 1, 2}},
data1, data2, {1, 0, -1, -2, 2, 2});
}
diff --git a/tensorflow/contrib/lite/kernels/mean.cc b/tensorflow/contrib/lite/kernels/mean.cc
deleted file mode 100644
index 03e5db24de..0000000000
--- a/tensorflow/contrib/lite/kernels/mean.cc
+++ /dev/null
@@ -1,271 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include <string.h>
-#include <vector>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
-#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
-#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
-#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
-#include "tensorflow/contrib/lite/kernels/kernel_util.h"
-#include "tensorflow/contrib/lite/kernels/op_macros.h"
-
-namespace tflite {
-namespace ops {
-namespace builtin {
-namespace mean {
-
-// This file has reference implementation of Mean.
-enum KernelType {
- kReference,
-};
-
-struct MeanContext {
- MeanContext(TfLiteContext* context, TfLiteNode* node) {
- params = reinterpret_cast<TfLiteMeanParams*>(node->builtin_data);
- input = GetInput(context, node, 0);
- axis = GetInput(context, node, 1);
- output = GetOutput(context, node, 0);
- }
- TfLiteMeanParams* params;
- const TfLiteTensor* input;
- const TfLiteTensor* axis;
- TfLiteTensor* output;
-};
-
-void* Init(TfLiteContext* context, const char* buffer, size_t length) {
- // Creates two temp tensors to store index and axis for internal
- // implementation only.
- auto* scratch_tensor_index = new int;
- context->AddTensors(context, 3, scratch_tensor_index);
- return scratch_tensor_index;
-}
-
-void Free(TfLiteContext* context, void* buffer) {
- delete reinterpret_cast<int*>(buffer);
-}
-
-// Resizes the temp tensor that stores resolved axis.
-TfLiteStatus ResizeTempAxis(TfLiteContext* context, MeanContext* op_context,
- TfLiteTensor* resolved_axis) {
- TfLiteIntArray* axis_size = TfLiteIntArrayCreate(1);
- axis_size->data[0] = static_cast<int>(NumElements(op_context->axis));
- return context->ResizeTensor(context, resolved_axis, axis_size);
-}
-
-// Resizes the temp tensor that stores temp sum of reduced elements.
-TfLiteStatus ResizeTempSum(TfLiteContext* context, MeanContext* op_context,
- TfLiteTensor* temp_sum) {
- TfLiteIntArray* size = TfLiteIntArrayCreate(1);
- size->data[0] = static_cast<int>(NumElements(op_context->output));
- return context->ResizeTensor(context, temp_sum, size);
-}
-
-// Resizes output array based on the input size and resolved axis.
-TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
- MeanContext* op_context) {
- size_t num_axis = NumElements(op_context->axis);
- const TfLiteIntArray* input_dims = op_context->input->dims;
- int input_num_dims = NumDimensions(op_context->input);
- const int* axis = GetTensorData<int>(op_context->axis);
- if (op_context->params->keep_dims) {
- TfLiteIntArray* output_dims = TfLiteIntArrayCreate(input_num_dims);
- for (int idx = 0; idx < input_num_dims; ++idx) {
- bool is_axis = false;
- for (int axis_idx = 0; axis_idx < num_axis; ++axis_idx) {
- if (axis[axis_idx] == idx || axis[axis_idx] + input_num_dims == idx) {
- is_axis = true;
- break;
- }
- }
- if (is_axis) {
- output_dims->data[idx] = 1;
- } else {
- output_dims->data[idx] = input_dims->data[idx];
- }
- }
- return context->ResizeTensor(context, op_context->output, output_dims);
- } else {
- // Calculates size of reducing axis.
- int num_reduce_axis = num_axis;
- for (int i = 0; i < num_axis; ++i) {
- int current = axis[i];
- if (current < 0) {
- current += input_num_dims;
- }
- TF_LITE_ENSURE(context, current >= 0 && current < input_num_dims);
- for (int j = 0; j < i; ++j) {
- int previous = axis[j];
- if (previous < 0) {
- previous += input_num_dims;
- }
- if (current == previous) {
- --num_reduce_axis;
- break;
- }
- }
- }
- // Determines output dimensions.
- TfLiteIntArray* output_dims =
- TfLiteIntArrayCreate(input_num_dims - num_reduce_axis);
- int num_skip_axis = 0;
- for (int idx = 0; idx < input_num_dims; ++idx) {
- bool is_axis = false;
- for (int axis_idx = 0; axis_idx < num_axis; ++axis_idx) {
- if (axis[axis_idx] == idx || axis[axis_idx] + input_num_dims == idx) {
- ++num_skip_axis;
- is_axis = true;
- break;
- }
- }
- if (!is_axis) {
- output_dims->data[idx - num_skip_axis] = input_dims->data[idx];
- }
- }
- return context->ResizeTensor(context, op_context->output, output_dims);
- }
-}
-
-// Initializes temp tensors to store index and resolved axis.
-TfLiteStatus InitializeTemporaries(TfLiteContext* context, TfLiteNode* node,
- MeanContext* op_context) {
- // Creates a temp index to iterate through input data.
- int* scratch_tensor_index = reinterpret_cast<int*>(node->user_data);
- TfLiteIntArrayFree(node->temporaries);
- node->temporaries = TfLiteIntArrayCreate(3);
- node->temporaries->data[0] = *scratch_tensor_index;
- TfLiteTensor* scratch_tensor = GetTemporary(context, node, /*index=*/0);
- scratch_tensor->type = kTfLiteInt32;
- scratch_tensor->allocation_type = kTfLiteArenaRw;
- TfLiteIntArray* index_size = TfLiteIntArrayCreate(1);
- index_size->data[0] = NumDimensions(op_context->input);
- TF_LITE_ENSURE_OK(context,
- context->ResizeTensor(context, scratch_tensor, index_size));
-
- // Creates a temp tensor to store resolved axis given input data.
- node->temporaries->data[1] = *scratch_tensor_index + 1;
- TfLiteTensor* resolved_axis = GetTemporary(context, node, /*index=*/1);
- resolved_axis->type = kTfLiteInt32;
- // Creates a temp tensor to store temp sums when calculating mean.
- node->temporaries->data[2] = *scratch_tensor_index + 2;
- TfLiteTensor* temp_sum = GetTemporary(context, node, /*index=*/2);
- switch (op_context->input->type) {
- case kTfLiteFloat32:
- temp_sum->type = kTfLiteFloat32;
- break;
- case kTfLiteInt32:
- temp_sum->type = kTfLiteInt64;
- break;
- case kTfLiteInt64:
- temp_sum->type = kTfLiteInt64;
- break;
- case kTfLiteUInt8:
- temp_sum->type = kTfLiteInt32;
- break;
- default:
- return kTfLiteError;
- }
- return kTfLiteOk;
-}
-
-TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
- TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
- TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
-
- MeanContext op_context(context, node);
- TF_LITE_ENSURE_OK(context, InitializeTemporaries(context, node, &op_context));
-
- TfLiteTensor* resolved_axis = GetTemporary(context, node, /*index=*/1);
- TfLiteTensor* temp_sum = GetTemporary(context, node, /*index=*/2);
- // Leaves work to Eval if axis is not constant; else resizes output.
- if (!IsConstantTensor(op_context.axis)) {
- SetTensorToDynamic(op_context.output);
- SetTensorToDynamic(resolved_axis);
- SetTensorToDynamic(temp_sum);
- return kTfLiteOk;
- }
- resolved_axis->allocation_type = kTfLiteArenaRw;
- TF_LITE_ENSURE_OK(context,
- ResizeTempAxis(context, &op_context, resolved_axis));
- TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
- temp_sum->allocation_type = kTfLiteArenaRw;
- return ResizeTempSum(context, &op_context, temp_sum);
-}
-
-template <KernelType kernel_type>
-TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
- MeanContext op_context(context, node);
- int num_axis = static_cast<int>(NumElements(op_context.axis));
- TfLiteTensor* temp_index = GetTemporary(context, node, /*index=*/0);
- TfLiteTensor* resolved_axis = GetTemporary(context, node, /*index=*/1);
- TfLiteTensor* temp_sum = GetTemporary(context, node, /*index=*/2);
- // Resize the output tensor if the output tensor is dynamic.
- if (IsDynamicTensor(op_context.output)) {
- TF_LITE_ENSURE_OK(context,
- ResizeTempAxis(context, &op_context, resolved_axis));
- TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
- TF_LITE_ENSURE_OK(context, ResizeTempSum(context, &op_context, temp_sum));
- }
-
-#define TF_LITE_MEAN(kernel_type, data_type, temp_data_type) \
- kernel_type::Mean<>( \
- GetTensorData<data_type>(op_context.input), \
- op_context.input->dims->data, op_context.input->dims->size, \
- GetTensorData<data_type>(op_context.output), \
- op_context.output->dims->data, op_context.output->dims->size, \
- GetTensorData<int>(op_context.axis), num_axis, \
- op_context.params->keep_dims, GetTensorData<int>(temp_index), \
- GetTensorData<int>(resolved_axis), \
- GetTensorData<temp_data_type>(temp_sum))
-
- if (kernel_type == kReference) {
- switch (op_context.input->type) {
- case kTfLiteFloat32:
- TF_LITE_ENSURE(context, TF_LITE_MEAN(reference_ops, float, float));
- break;
- case kTfLiteInt32:
- TF_LITE_ENSURE(context, TF_LITE_MEAN(reference_ops, int, int64_t));
- break;
- case kTfLiteInt64:
- TF_LITE_ENSURE(context, TF_LITE_MEAN(reference_ops, int64_t, int64_t));
- break;
- case kTfLiteUInt8:
- TF_LITE_ENSURE_EQ(context, op_context.input->params.scale,
- op_context.output->params.scale);
- TF_LITE_ENSURE_EQ(context, op_context.input->params.zero_point,
- op_context.output->params.zero_point);
- TF_LITE_ENSURE(context, TF_LITE_MEAN(reference_ops, uint8_t, int));
- break;
- default:
- return kTfLiteError;
- }
- }
-#undef TF_LITE_MEAN
- return kTfLiteOk;
-}
-} // namespace mean
-
-TfLiteRegistration* Register_MEAN_REF() {
- static TfLiteRegistration r = {mean::Init, mean::Free, mean::Prepare,
- mean::Eval<mean::kReference>};
- return &r;
-}
-
-// TODO(kanlig): add optimized implementation of Mean.
-TfLiteRegistration* Register_MEAN() { return Register_MEAN_REF(); }
-
-} // namespace builtin
-} // namespace ops
-} // namespace tflite
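
The deleted ResizeOutputTensor above carries the axis-resolution rules for Mean. A standalone sketch of that shape logic (names here are illustrative, not TFLite API), which also makes the ConstFloatMeanOpTest case in the deleted test file below easy to check by hand:

    #include <set>
    #include <vector>

    std::vector<int> MeanOutputShape(const std::vector<int>& input_dims,
                                     const std::vector<int>& axes,
                                     bool keep_dims) {
      const int rank = static_cast<int>(input_dims.size());
      // Negative axes count from the back; duplicates collapse to one entry,
      // mirroring the num_reduce_axis bookkeeping in the deleted kernel.
      std::set<int> reduced;
      for (int a : axes) reduced.insert(a < 0 ? a + rank : a);
      std::vector<int> out;
      for (int d = 0; d < rank; ++d) {
        if (reduced.count(d)) {
          if (keep_dims) out.push_back(1);  // reduced dims kept with size 1
        } else {
          out.push_back(input_dims[d]);
        }
      }
      return out;
    }

    // Worked example from the deleted ConstFloatMeanOpTest.NotKeepDims below:
    // MeanOutputShape({4, 3, 2}, {1, 0, -3, -3}, false) == {2}, since axes
    // 1, 0 and -3 (== 0 after wrapping) all resolve to dims 0 and 1, leaving
    // only the last dimension of size 2.
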
diff --git a/tensorflow/contrib/lite/kernels/mean_test.cc b/tensorflow/contrib/lite/kernels/mean_test.cc
deleted file mode 100644
index 79c9957f76..0000000000
--- a/tensorflow/contrib/lite/kernels/mean_test.cc
+++ /dev/null
@@ -1,219 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include <gtest/gtest.h>
-#include "tensorflow/contrib/lite/interpreter.h"
-#include "tensorflow/contrib/lite/kernels/register.h"
-#include "tensorflow/contrib/lite/kernels/test_util.h"
-#include "tensorflow/contrib/lite/model.h"
-
-namespace tflite {
-namespace {
-
-using ::testing::ElementsAreArray;
-
-class BaseMeanOpModel : public SingleOpModel {
- public:
- void SetAxis(std::initializer_list<int> data) { PopulateTensor(axis_, data); }
-
- template <class T>
- void SetInput(std::initializer_list<T> data) {
- PopulateTensor(input_, data);
- }
-
- template <class T>
- std::vector<T> GetOutput() {
- return ExtractVector<T>(output_);
- }
-
- std::vector<float> GetDequantizedOutput() {
- return Dequantize<uint8_t>(ExtractVector<uint8_t>(output_),
- GetScale(output_), GetZeroPoint(output_));
- }
-
- std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
-
- int Input() { return input_; }
-
- protected:
- int input_;
- int axis_;
- int output_;
-};
-
-// Model for the tests case where axis is a const tensor.
-class MeanOpConstModel : public BaseMeanOpModel {
- public:
- MeanOpConstModel(const TensorData& input, const TensorData& output,
- std::initializer_list<int> axis_shape,
- std::initializer_list<int> axis, bool keep_dims) {
- input_ = AddInput(input);
- axis_ = AddConstInput(TensorType_INT32, axis, axis_shape);
- output_ = AddOutput(output);
- SetBuiltinOp(BuiltinOperator_MEAN, BuiltinOptions_MeanOptions,
- CreateMeanOptions(builder_, keep_dims).Union());
- BuildInterpreter({GetShape(input_)});
- }
-};
-
-// Model for the tests case where axis is a dynamic tensor.
-class MeanOpDynamicModel : public BaseMeanOpModel {
- public:
- MeanOpDynamicModel(const TensorData& input, const TensorData& output,
- const TensorData& axis, bool keep_dims) {
- input_ = AddInput(input);
- axis_ = AddInput(axis);
- output_ = AddOutput(output);
- SetBuiltinOp(BuiltinOperator_MEAN, BuiltinOptions_MeanOptions,
- CreateMeanOptions(builder_, keep_dims).Union());
- BuildInterpreter({GetShape(input_)});
- }
-};
-
-TEST(ConstFloatMeanOpTest, NotKeepDims) {
- std::initializer_list<float> data = {
- 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
- 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0};
- MeanOpConstModel m({TensorType_FLOAT32, {4, 3, 2}}, {TensorType_FLOAT32, {2}},
- {4}, {1, 0, -3, -3}, false);
- m.SetInput(data);
- m.Invoke();
- EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2}));
- EXPECT_THAT(m.GetOutput<float>(), ElementsAreArray(ArrayFloatNear({12, 13})));
-}
-
-TEST(ConstFloatMeanOpTest, KeepDims) {
- std::initializer_list<float> data = {
- 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
- 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0};
- MeanOpConstModel m({TensorType_FLOAT32, {4, 3, 2}}, {TensorType_FLOAT32, {3}},
- {2}, {0, 2}, true);
- m.SetInput(data);
- m.Invoke();
- EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3, 1}));
- EXPECT_THAT(m.GetOutput<float>(),
- ElementsAreArray(ArrayFloatNear({10.5, 12.5, 14.5})));
-}
-
-TEST(DynamicFloatMeanOpTest, NotKeepDims) {
- std::initializer_list<float> data = {
- 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
- 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0};
- MeanOpDynamicModel m({TensorType_FLOAT32, {4, 3, 2}},
- {TensorType_FLOAT32, {2}}, {TensorType_INT32, {4}},
- false);
- std::initializer_list<int> axis = {1, 0, -3, -3};
- m.SetAxis(axis);
- m.SetInput(data);
- m.Invoke();
- EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2}));
- EXPECT_THAT(m.GetOutput<float>(), ElementsAreArray(ArrayFloatNear({12, 13})));
-}
-
-TEST(DynamicFloatMeanOpTest, KeepDims) {
- std::initializer_list<float> data = {
- 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
- 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0};
- MeanOpDynamicModel m({TensorType_FLOAT32, {4, 3, 2}},
- {TensorType_FLOAT32, {3}}, {TensorType_INT32, {2}},
- true);
- std::initializer_list<int> axis = {0, 2};
- m.SetAxis(axis);
- m.SetInput(data);
- m.Invoke();
- EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3, 1}));
- EXPECT_THAT(m.GetOutput<float>(),
- ElementsAreArray(ArrayFloatNear({10.5, 12.5, 14.5})));
-}
-
-TEST(DynamicFloatMeanOpTest, Scale) {
- std::initializer_list<float> data = {9.527};
- MeanOpDynamicModel m({TensorType_FLOAT32, {1}}, {TensorType_FLOAT32, {1}},
- {TensorType_INT32, {1}}, true);
- std::initializer_list<int> axis = {0};
- m.SetAxis(axis);
- m.SetInput(data);
- m.Invoke();
- EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1}));
- EXPECT_THAT(m.GetOutput<float>(), ElementsAreArray(ArrayFloatNear({9.527})));
-}
-
-// for quantized Add, the error shouldn't exceed step
-float GetTolerance(int min, int max) { return (max - min) / 255.0; }
-
-TEST(ConstUint8MeanOpTest, NotKeepDims) {
- float kQuantizedTolerance = GetTolerance(-1.0, 1.0);
- std::initializer_list<float> data = {0.4, 0.2, 0.3, 0.4, 0.5, 0.6};
- MeanOpConstModel m({TensorType_UINT8, {1, 3, 2}, -1.0, 1.0},
- {TensorType_UINT8, {2}, -1.0, 1.0}, {1}, {1}, false);
- m.QuantizeAndPopulate<uint8_t>(m.Input(), data);
- m.Invoke();
- EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2}));
- EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear(
- {0.4, 0.4}, kQuantizedTolerance)));
-}
-
-TEST(ConstUint8MeanOpTest, KeepDims) {
- float kQuantizedTolerance = GetTolerance(-1.0, 1.0);
- std::initializer_list<float> data = {0.4, 0.2, 0.3, 0.4, 0.5, 0.6};
- MeanOpConstModel m({TensorType_UINT8, {3, 2}, -1.0, 1.0},
- {TensorType_UINT8, {3}, -1.0, 1.0}, {1}, {1}, true);
- m.QuantizeAndPopulate<uint8_t>(m.Input(), data);
- m.Invoke();
- EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 1}));
- EXPECT_THAT(
- m.GetDequantizedOutput(),
- ElementsAreArray(ArrayFloatNear({0.3, 0.35, 0.55}, kQuantizedTolerance)));
-}
-
-TEST(DynamicUint8MeanOpTest, NotKeepDims) {
- float kQuantizedTolerance = GetTolerance(-5.0, 2.0);
- std::initializer_list<float> data = {1.3, -4.8, -3.6, 0.24};
- MeanOpDynamicModel m({TensorType_UINT8, {2, 2}, -5.0, 2.0},
- {TensorType_UINT8, {2}, -5.0, 2.0},
- {TensorType_INT32, {1}}, false);
- std::initializer_list<int> axis = {1};
- m.SetAxis(axis);
- m.QuantizeAndPopulate<uint8_t>(m.Input(), data);
- m.Invoke();
- EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2}));
- EXPECT_THAT(
- m.GetDequantizedOutput(),
- ElementsAreArray(ArrayFloatNear({-1.75, -1.68}, kQuantizedTolerance)));
-}
-
-TEST(DynamicUint8MeanOpTest, KeepDims) {
- float kQuantizedTolerance = GetTolerance(-10.0, 12.0);
- std::initializer_list<float> data = {11.14, -0.14, 7.423, 0.879};
- MeanOpDynamicModel m({TensorType_UINT8, {2, 2}, -10.0, 12.0},
- {TensorType_UINT8, {2}, -10.0, 12.0},
- {TensorType_INT32, {1}}, true);
- std::initializer_list<int> axis = {0};
- m.SetAxis(axis);
- m.QuantizeAndPopulate<uint8_t>(m.Input(), data);
- m.Invoke();
- EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2}));
- EXPECT_THAT(
- m.GetDequantizedOutput(),
- ElementsAreArray(ArrayFloatNear({9.2815, 0.3695}, kQuantizedTolerance)));
-}
-
-} // namespace
-} // namespace tflite
-
-int main(int argc, char** argv) {
- ::tflite::LogToStderr();
- ::testing::InitGoogleTest(&argc, argv);
- return RUN_ALL_TESTS();
-}
diff --git a/tensorflow/contrib/lite/kernels/mfcc.cc b/tensorflow/contrib/lite/kernels/mfcc.cc
index 3f5bc4d68a..5153ce5634 100644
--- a/tensorflow/contrib/lite/kernels/mfcc.cc
+++ b/tensorflow/contrib/lite/kernels/mfcc.cc
@@ -13,9 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/lite/kernels/internal/mfcc.h"
-#include "flatbuffers/flexbuffers.h"
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "flatbuffers/flexbuffers.h" // TF:flatbuffers
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/mfcc_dct.h"
#include "tensorflow/contrib/lite/kernels/internal/mfcc_mel_filterbank.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
diff --git a/tensorflow/contrib/lite/kernels/mfcc_test.cc b/tensorflow/contrib/lite/kernels/mfcc_test.cc
index 0291ca8c1c..fe69223222 100644
--- a/tensorflow/contrib/lite/kernels/mfcc_test.cc
+++ b/tensorflow/contrib/lite/kernels/mfcc_test.cc
@@ -18,7 +18,7 @@ limitations under the License.
#include <vector>
#include <gtest/gtest.h>
-#include "flatbuffers/flexbuffers.h"
+#include "flatbuffers/flexbuffers.h" // TF:flatbuffers
#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/kernels/test_util.h"
diff --git a/tensorflow/contrib/lite/kernels/mul.cc b/tensorflow/contrib/lite/kernels/mul.cc
index 6c4c3a1edc..e0aac8a842 100644
--- a/tensorflow/contrib/lite/kernels/mul.cc
+++ b/tensorflow/contrib/lite/kernels/mul.cc
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
@@ -39,6 +39,14 @@ constexpr int kOutputTensor = 0;
struct OpData {
bool requires_broadcast;
+
+ // Parameters used in the quantized paths where the output is 8bit
+ int32 output_activation_min;
+ int32 output_activation_max;
+
+ // Parameters used in all quantized paths
+ int32_t output_multiplier;
+ int output_shift;
};
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
@@ -52,6 +60,7 @@ void Free(TfLiteContext* context, void* buffer) {
}
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ auto* params = reinterpret_cast<TfLiteMulParams*>(node->builtin_data);
OpData* data = reinterpret_cast<OpData*>(node->user_data);
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
@@ -62,7 +71,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE_EQ(context, input1->type, input2->type);
- output->type = input2->type;
data->requires_broadcast = !HaveSameShapes(input1, input2);
@@ -74,74 +82,136 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
output_size = TfLiteIntArrayCopy(input1->dims);
}
+ if (output->type == kTfLiteUInt8) {
+ CalculateActivationRangeUint8(params->activation, output,
+ &data->output_activation_min,
+ &data->output_activation_max);
+ }
+
+ if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt16) {
+ double real_multiplier =
+ input1->params.scale * input2->params.scale / output->params.scale;
+ QuantizeMultiplierSmallerThanOneExp(
+ real_multiplier, &data->output_multiplier, &data->output_shift);
+ }
+
return context->ResizeTensor(context, output, output_size);
}
template <KernelType kernel_type>
-void EvalFloat(TfLiteContext* context, TfLiteNode* node,
- TfLiteMulParams* params, const OpData* data,
- const TfLiteTensor* input1, const TfLiteTensor* input2,
- TfLiteTensor* output) {
- float output_activation_min, output_activation_max;
- CalculateActivationRangeFloat(params->activation, &output_activation_min,
- &output_activation_max);
-#define TF_LITE_MUL(type, opname) \
- type::opname(GetTensorData<float>(input1), GetTensorDims(input1), \
- GetTensorData<float>(input2), GetTensorDims(input2), \
- output_activation_min, output_activation_max, \
- GetTensorData<float>(output), GetTensorDims(output))
- if (kernel_type == kReference) {
- if (data->requires_broadcast) {
- TF_LITE_MUL(reference_ops, BroadcastMul);
+void EvalMul(TfLiteContext* context, TfLiteNode* node, TfLiteMulParams* params,
+ const OpData* data, const TfLiteTensor* input1,
+ const TfLiteTensor* input2, TfLiteTensor* output) {
+#define TF_LITE_MUL(type, opname, data_type) \
+ data_type output_activation_min, output_activation_max; \
+ CalculateActivationRange(params->activation, &output_activation_min, \
+ &output_activation_max); \
+ tflite::ArithmeticParams op_params; \
+ SetActivationParams(output_activation_min, output_activation_max, \
+ &op_params); \
+ type::opname(op_params, GetTensorShape(input1), \
+ GetTensorData<data_type>(input1), GetTensorShape(input2), \
+ GetTensorData<data_type>(input2), GetTensorShape(output), \
+ GetTensorData<data_type>(output))
+
+ if (output->type == kTfLiteInt32) {
+ if (kernel_type == kReference) {
+ if (data->requires_broadcast) {
+ TF_LITE_MUL(reference_ops, BroadcastMul4DSlow, int32_t);
+ } else {
+ TF_LITE_MUL(reference_ops, Mul, int32_t);
+ }
} else {
- TF_LITE_MUL(reference_ops, Mul);
+ if (data->requires_broadcast) {
+ TF_LITE_MUL(optimized_ops, BroadcastMul4DSlow, int32_t);
+ } else {
+ TF_LITE_MUL(optimized_ops, Mul, int32_t);
+ }
}
- } else {
- if (data->requires_broadcast) {
- TF_LITE_MUL(optimized_ops, BroadcastMul);
+ } else if (output->type == kTfLiteFloat32) {
+ if (kernel_type == kReference) {
+ if (data->requires_broadcast) {
+ TF_LITE_MUL(reference_ops, BroadcastMul4DSlow, float);
+ } else {
+ TF_LITE_MUL(reference_ops, Mul, float);
+ }
} else {
- TF_LITE_MUL(optimized_ops, Mul);
+ if (data->requires_broadcast) {
+ TF_LITE_MUL(optimized_ops, BroadcastMul4DSlow, float);
+ } else {
+ TF_LITE_MUL(optimized_ops, Mul, float);
+ }
}
}
#undef TF_LITE_MUL
}
template <KernelType kernel_type>
-void EvalQuantized(TfLiteContext* context, TfLiteNode* node,
- TfLiteMulParams* params, const OpData* data,
- const TfLiteTensor* input1, const TfLiteTensor* input2,
- TfLiteTensor* output) {
- auto input1_offset = -input1->params.zero_point;
- auto input2_offset = -input2->params.zero_point;
- auto output_offset = output->params.zero_point;
-
- int32_t output_multiplier;
- int output_shift;
-
- double real_multiplier =
- input1->params.scale * input2->params.scale / output->params.scale;
- QuantizeMultiplierSmallerThanOne(real_multiplier, &output_multiplier,
- &output_shift);
-
- int32 output_activation_min, output_activation_max;
- CalculateActivationRangeUint8(params->activation, output,
- &output_activation_min, &output_activation_max);
-
+TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
+ TfLiteMulParams* params, const OpData* data,
+ const TfLiteTensor* input1,
+ const TfLiteTensor* input2, TfLiteTensor* output) {
+ if (input1->type == kTfLiteUInt8 && input2->type == kTfLiteUInt8 &&
+ output->type == kTfLiteUInt8) {
#define TF_LITE_MUL(type, opname) \
- type::opname(GetTensorData<uint8_t>(input1), GetTensorDims(input1), \
- input1_offset, GetTensorData<uint8_t>(input2), \
- GetTensorDims(input2), input2_offset, output_offset, \
- output_multiplier, output_shift, output_activation_min, \
- output_activation_max, GetTensorData<uint8_t>(output), \
- GetTensorDims(output));
- // The quantized version of Mul doesn't support activations, so we
- // always use BroadcastMul.
- if (kernel_type == kReference) {
- TF_LITE_MUL(reference_ops, BroadcastMul);
+ tflite::ArithmeticParams op_params; \
+ SetActivationParams(data->output_activation_min, \
+ data->output_activation_max, &op_params); \
+ op_params.input1_offset = -input1->params.zero_point; \
+ op_params.input2_offset = -input2->params.zero_point; \
+ op_params.output_offset = output->params.zero_point; \
+ op_params.output_multiplier = data->output_multiplier; \
+ op_params.output_shift = data->output_shift; \
+ type::opname(op_params, GetTensorShape(input1), \
+ GetTensorData<uint8_t>(input1), GetTensorShape(input2), \
+ GetTensorData<uint8_t>(input2), GetTensorShape(output), \
+ GetTensorData<uint8_t>(output))
+
+ // The quantized version of Mul doesn't support activations, so we
+ // always use BroadcastMul.
+ if (kernel_type == kReference) {
+ TF_LITE_MUL(reference_ops, BroadcastMul4DSlow);
+ } else {
+ TF_LITE_MUL(optimized_ops, BroadcastMul4DSlow);
+ }
+#undef TF_LITE_MUL
+ } else if (input1->type == kTfLiteInt16 && input2->type == kTfLiteInt16 &&
+ output->type == kTfLiteInt16) {
+#define TF_LITE_MUL(type, opname) \
+ tflite::ArithmeticParams op_params; \
+ type::opname(op_params, GetTensorShape(input1), \
+ GetTensorData<int16_t>(input1), GetTensorShape(input2), \
+ GetTensorData<int16_t>(input2), GetTensorShape(output), \
+ GetTensorData<int16_t>(output))
+ if (kernel_type == kReference) {
+ TF_LITE_MUL(reference_ops, Mul);
+ } else {
+ TF_LITE_MUL(optimized_ops, Mul);
+ }
+#undef TF_LITE_MUL
+ } else if (input1->type == kTfLiteInt16 && input2->type == kTfLiteInt16 &&
+ output->type == kTfLiteUInt8) {
+#define TF_LITE_MUL(type, opname) \
+ tflite::ArithmeticParams op_params; \
+ SetActivationParams(data->output_activation_min, \
+ data->output_activation_max, &op_params); \
+ op_params.output_offset = output->params.zero_point; \
+ type::opname(op_params, GetTensorShape(input1), \
+ GetTensorData<int16_t>(input1), GetTensorShape(input2), \
+ GetTensorData<int16_t>(input2), GetTensorShape(output), \
+ GetTensorData<uint8_t>(output))
+ if (kernel_type == kReference) {
+ TF_LITE_MUL(reference_ops, Mul);
+ } else {
+ TF_LITE_MUL(optimized_ops, Mul);
+ }
+#undef TF_LITE_MUL
} else {
- TF_LITE_MUL(optimized_ops, BroadcastMul);
+ context->ReportError(
+ context, "Unsupported combination of input and output types in Mul.");
+ return kTfLiteError;
}
-#undef TF_LITE_MUL
+ return kTfLiteOk;
}
template <KernelType kernel_type>
@@ -153,14 +223,17 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
- if (output->type == kTfLiteFloat32) {
- EvalFloat<kernel_type>(context, node, params, data, input1, input2, output);
- } else if (output->type == kTfLiteUInt8) {
- EvalQuantized<kernel_type>(context, node, params, data, input1, input2,
- output);
+ if (output->type == kTfLiteFloat32 || output->type == kTfLiteInt32) {
+ EvalMul<kernel_type>(context, node, params, data, input1, input2, output);
+ } else if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt16) {
+ TF_LITE_ENSURE_OK(
+ context, EvalQuantized<kernel_type>(context, node, params, data, input1,
+ input2, output));
} else {
context->ReportError(context,
- "Mul only supports FLOAT32 and quantized UINT8 now.");
+ "Mul only supports FLOAT32, INT32 and quantized UINT8 "
+ "and INT16 now, got %d.",
+ output->type);
return kTfLiteError;
}
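
The quantized uint8 path above precomputes output_multiplier and output_shift in Prepare from real_multiplier = s1 * s2 / s_out, so Eval only does integer arithmetic. A float-only sketch of the per-element math that this fixed-point pipeline approximates (function and parameter names are assumptions, not the kernel's API):

    #include <algorithm>
    #include <cmath>
    #include <cstdint>

    uint8_t QuantizedMulElement(uint8_t q1, uint8_t q2,
                                int32_t input1_zero_point, float input1_scale,
                                int32_t input2_zero_point, float input2_scale,
                                int32_t output_zero_point, float output_scale,
                                int32_t activation_min, int32_t activation_max) {
      // (q - zero_point) * scale recovers the real value of each input.
      const float real1 =
          (static_cast<int32_t>(q1) - input1_zero_point) * input1_scale;
      const float real2 =
          (static_cast<int32_t>(q2) - input2_zero_point) * input2_scale;
      // Dividing by output_scale is where real_multiplier = s1 * s2 / s_out
      // comes from; the kernel folds it into a fixed-point multiplier plus a
      // shift via QuantizeMultiplierSmallerThanOneExp instead of using floats.
      const int32_t unclamped =
          output_zero_point +
          static_cast<int32_t>(std::lround((real1 * real2) / output_scale));
      // The precomputed activation range clamps the result before narrowing.
      return static_cast<uint8_t>(
          std::min(activation_max, std::max(activation_min, unclamped)));
    }
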
diff --git a/tensorflow/contrib/lite/kernels/mul_test.cc b/tensorflow/contrib/lite/kernels/mul_test.cc
index f1a30f8263..2807550a6b 100644
--- a/tensorflow/contrib/lite/kernels/mul_test.cc
+++ b/tensorflow/contrib/lite/kernels/mul_test.cc
@@ -52,12 +52,22 @@ class FloatMulOpModel : public BaseMulOpModel {
std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
};
+class IntegerMulOpModel : public BaseMulOpModel {
+ public:
+ using BaseMulOpModel::BaseMulOpModel;
+
+ std::vector<int32_t> GetOutput() { return ExtractVector<int32_t>(output_); }
+};
+
// For quantized Mul, the error shouldn't exceed (2*step + step^2).
// The param min=-1.0 & max=1.0 is used in the following tests.
// The tolerance value is ~0.0157.
const float kQuantizedStep = 2.0 / 255.0;
const float kQuantizedTolerance =
2.0 * kQuantizedStep + kQuantizedStep * kQuantizedStep;
+const float kQuantizedStepInt16 = 2.0 / 32767.0;
+const float kQuantizedToleranceInt16 =
+ 2.0 * kQuantizedStepInt16 + kQuantizedStepInt16 * kQuantizedStepInt16;
class QuantizedMulOpModel : public BaseMulOpModel {
public:
@@ -67,6 +77,11 @@ class QuantizedMulOpModel : public BaseMulOpModel {
return Dequantize<uint8_t>(ExtractVector<uint8_t>(output_),
GetScale(output_), GetZeroPoint(output_));
}
+
+ std::vector<float> GetDequantizedOutputInt16() {
+ return Dequantize<int16_t>(ExtractVector<int16_t>(output_),
+ GetScale(output_), GetZeroPoint(output_));
+ }
};
TEST(FloatMulOpTest, NoActivation) {
@@ -125,6 +140,57 @@ TEST(FloatMulOpTest, WithBroadcast) {
}
}
+TEST(IntegerMulOpTest, NoActivation) {
+ IntegerMulOpModel m({TensorType_INT32, {1, 2, 2, 1}},
+ {TensorType_INT32, {1, 2, 2, 1}}, {TensorType_INT32, {}},
+ ActivationFunctionType_NONE);
+ m.PopulateTensor<int32_t>(m.input1(), {-20, 2, 7, 8});
+ m.PopulateTensor<int32_t>(m.input2(), {1, 2, 3, 5});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({-20, 4, 21, 40}));
+}
+
+TEST(IntegerMulOpTest, ActivationRELU_N1_TO_1) {
+ IntegerMulOpModel m({TensorType_INT32, {1, 2, 2, 1}},
+ {TensorType_INT32, {1, 2, 2, 1}}, {TensorType_INT32, {}},
+ ActivationFunctionType_RELU_N1_TO_1);
+ m.PopulateTensor<int32_t>(m.input1(), {-20, 2, 7, 8});
+ m.PopulateTensor<int32_t>(m.input2(), {1, 2, 3, 5});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({-1, 1, 1, 1}));
+}
+
+TEST(IntegerMulOpTest, VariousInputShapes) {
+ std::vector<std::initializer_list<int>> test_shapes = {
+ {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}};
+ for (int i = 0; i < test_shapes.size(); ++i) {
+ IntegerMulOpModel m({TensorType_INT32, test_shapes[i]},
+ {TensorType_INT32, test_shapes[i]},
+ {TensorType_INT32, {}}, ActivationFunctionType_NONE);
+ m.PopulateTensor<int32_t>(m.input1(), {-20, 2, 7, 8, 11, 20});
+ m.PopulateTensor<int32_t>(m.input2(), {1, 2, 3, 5, 11, 1});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({-20, 4, 21, 40, 121, 20}))
+ << "With shape number " << i;
+ }
+}
+
+TEST(IntegerMulOpTest, WithBroadcast) {
+ std::vector<std::initializer_list<int>> test_shapes = {
+ {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}};
+ for (int i = 0; i < test_shapes.size(); ++i) {
+ IntegerMulOpModel m({TensorType_INT32, test_shapes[i]},
+ {TensorType_INT32, {}}, // always a scalar
+ {TensorType_INT32, {}}, ActivationFunctionType_NONE);
+ m.PopulateTensor<int32_t>(m.input1(), {-20, 2, 7, 8, 11, 20});
+ m.PopulateTensor<int32_t>(m.input2(), {1});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(),
+ ElementsAreArray(ArrayFloatNear({-20, 2, 7, 8, 11, 20})))
+ << "With shape number " << i;
+ }
+}
+
TEST(QuantizedMulOpTest, NoActivation) {
QuantizedMulOpModel m({TensorType_UINT8, {1, 2, 2, 1}, -1.0, 1.0},
{TensorType_UINT8, {1, 2, 2, 1}, -1.0, 1.0},
@@ -138,6 +204,38 @@ TEST(QuantizedMulOpTest, NoActivation) {
kQuantizedTolerance)));
}
+TEST(QuantizedMulOpTest, NoActivationInt16) {
+ const float kMin = -1.f;
+ const float kMax = 32767.f / 32768.f;
+ QuantizedMulOpModel m({TensorType_INT16, {1, 2, 2, 1}, kMin, kMax},
+ {TensorType_INT16, {1, 2, 2, 1}, kMin, kMax},
+ {TensorType_INT16, {}, kMin, kMax},
+ ActivationFunctionType_NONE);
+ m.QuantizeAndPopulate<int16_t>(m.input1(), {-0.8, 0.2, 0.9, 0.7});
+ m.QuantizeAndPopulate<int16_t>(m.input2(), {0.6, 0.4, 0.9, 0.8});
+ m.Invoke();
+ EXPECT_THAT(m.GetDequantizedOutputInt16(),
+ ElementsAreArray(ArrayFloatNear({-0.48, 0.08, 0.81, 0.56},
+ kQuantizedToleranceInt16)));
+}
+
+TEST(QuantizedMulOpTest, NoActivationInt16WithUint8Output) {
+ const float kMinInt16 = -1.f;
+ const float kMaxInt16 = 32767.f / 32768.f;
+ const float kMinUint8 = -1.f;
+ const float kMaxUint8 = 127.f / 128.f;
+ QuantizedMulOpModel m({TensorType_INT16, {1, 2, 2, 1}, kMinInt16, kMaxInt16},
+ {TensorType_INT16, {1, 2, 2, 1}, kMinInt16, kMaxInt16},
+ {TensorType_UINT8, {}, kMinUint8, kMaxUint8},
+ ActivationFunctionType_NONE);
+ m.QuantizeAndPopulate<int16_t>(m.input1(), {-0.8, 0.2, 0.9, 0.7});
+ m.QuantizeAndPopulate<int16_t>(m.input2(), {0.6, 0.4, 0.9, 0.8});
+ m.Invoke();
+ EXPECT_THAT(m.GetDequantizedOutput(),
+ ElementsAreArray(ArrayFloatNear({-0.48, 0.08, 0.81, 0.56},
+ kQuantizedTolerance)));
+}
+
// For quantized Mul, the error shouldn't exceed 2*step.
float GetTolerance(int min, int max) {
float kQuantizedStep = (max - min) / 255.0;
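The int16 tolerance defined above follows the same bound as the uint8 one: each dequantized operand can be off by up to one quantization step, so for inputs bounded by 1 the product can be off by at most 2*step + step^2. A minimal standalone sketch of that bound, using the step sizes from this test file (the program itself is illustrative and not part of the kernel tests):

#include <cstdio>

int main() {
  // Step sizes as defined in mul_test.cc: 2.0 / 255 per uint8 level and
  // 2.0 / 32767 per int16 level.
  const float step_uint8 = 2.0f / 255.0f;
  const float step_int16 = 2.0f / 32767.0f;
  // With a = a0 + e1, b = b0 + e2, |e1|, |e2| <= step and |a0|, |b0| <= 1:
  //   |a*b - a0*b0| <= |a0|*|e2| + |b0|*|e1| + |e1*e2| <= 2*step + step^2.
  const float tol_uint8 = 2.0f * step_uint8 + step_uint8 * step_uint8;
  const float tol_int16 = 2.0f * step_int16 + step_int16 * step_int16;
  std::printf("uint8 tolerance ~ %g, int16 tolerance ~ %g\n", tol_uint8,
              tol_int16);
  return 0;
}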
diff --git a/tensorflow/contrib/lite/kernels/neg.cc b/tensorflow/contrib/lite/kernels/neg.cc
index b8b53f3402..0ddd0644f5 100644
--- a/tensorflow/contrib/lite/kernels/neg.cc
+++ b/tensorflow/contrib/lite/kernels/neg.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
namespace tflite {
@@ -59,7 +59,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
break;
default:
context->ReportError(
- context, "Neg only currently supports int64, int32, and float32.",
+ context,
+ "Neg only currently supports int64, int32, and float32, got %d.",
input->type);
return kTfLiteError;
}
diff --git a/tensorflow/contrib/lite/kernels/neg_test.cc b/tensorflow/contrib/lite/kernels/neg_test.cc
index 3c95ac8cc2..3d3594c60b 100644
--- a/tensorflow/contrib/lite/kernels/neg_test.cc
+++ b/tensorflow/contrib/lite/kernels/neg_test.cc
@@ -58,9 +58,9 @@ TEST(NegOpModel, NegFloat) {
TEST(NegOpModel, NegInt32) {
NegOpModel m({TensorType_INT32, {2, 3}}, {TensorType_INT32, {2, 3}});
- m.SetInput<int32>({-2, -1, 0, 1, 2, 3});
+ m.SetInput<int32_t>({-2, -1, 0, 1, 2, 3});
m.Invoke();
- EXPECT_THAT(m.GetOutput<int32>(), ElementsAreArray({2, 1, 0, -1, -2, -3}));
+ EXPECT_THAT(m.GetOutput<int32_t>(), ElementsAreArray({2, 1, 0, -1, -2, -3}));
}
TEST(NegOpModel, NegInt64) {
diff --git a/tensorflow/contrib/lite/kernels/one_hot.cc b/tensorflow/contrib/lite/kernels/one_hot.cc
new file mode 100644
index 0000000000..910aed6f14
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/one_hot.cc
@@ -0,0 +1,199 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace one_hot {
+
+constexpr int kIndicesTensor = 0;
+constexpr int kDepthTensor = 1;
+constexpr int kOnValueTensor = 2;
+constexpr int kOffValueTensor = 3;
+constexpr int kOutputTensor = 0;
+
+// Convenience utility for destructuring a node into the appropriate tensors and
+// data for the op. Note that this destructuring is quite cheap, so we can avoid
+// allocating op-specific, persistent data on the heap.
+struct OneHotContext {
+ OneHotContext(TfLiteContext* context, TfLiteNode* node) {
+ indices = GetInput(context, node, kIndicesTensor);
+ depth = GetInput(context, node, kDepthTensor);
+ on_value = GetInput(context, node, kOnValueTensor);
+ off_value = GetInput(context, node, kOffValueTensor);
+ output = GetOutput(context, node, kOutputTensor);
+
+ const auto* params =
+ reinterpret_cast<TfLiteOneHotParams*>(node->builtin_data);
+ const int indices_dims = indices->dims->size;
+ axis = (params->axis == -1) ? indices_dims : params->axis;
+ output_dims = indices_dims + 1;
+ dtype = on_value->type;
+ }
+
+ const TfLiteTensor* indices;
+ const TfLiteTensor* depth;
+ const TfLiteTensor* on_value;
+ const TfLiteTensor* off_value;
+ TfLiteTensor* output;
+ int axis;
+ int output_dims;
+ TfLiteType dtype;
+};
+
+template <typename T, typename TI>
+void OneHotComputeImpl(const OneHotContext& op_context) {
+ // prefix_dim_size == # of elements before the axis
+ // depth == # of elements per axis
+ // suffix_dim_size == # of elements after the axis
+ int prefix_dim_size = 1;
+ for (int i = 0; i < op_context.axis; ++i) {
+ prefix_dim_size *= op_context.indices->dims->data[i];
+ }
+ const int suffix_dim_size = NumElements(op_context.indices) / prefix_dim_size;
+ const int depth = *op_context.depth->data.i32;
+
+ const T on_value = *GetTensorData<T>(op_context.on_value);
+ const T off_value = *GetTensorData<T>(op_context.off_value);
+
+ // View the indices as a matrix of size:
+ // prefix_dim_size x suffix_dim_size
+ // View the output as a matrix of size:
+ // prefix_dim_size x depth x suffix_dim_size
+ // Then the output is:
+ // output(i, j, k) == (indices(i, k) == j) ? on : off
+ T* output = GetTensorData<T>(op_context.output);
+ const TI* indices = GetTensorData<TI>(op_context.indices);
+ for (int i = 0; i < prefix_dim_size; ++i) {
+ for (int j = 0; j < depth; ++j) {
+ for (int k = 0; k < suffix_dim_size; ++k, ++output) {
+ *output = static_cast<int>(indices[i * suffix_dim_size + k]) == j
+ ? on_value
+ : off_value;
+ }
+ }
+ }
+}
+
+template <typename T>
+void OneHotCompute(const OneHotContext& op_context) {
+ if (op_context.indices->type == kTfLiteInt64) {
+ OneHotComputeImpl<T, int64_t>(op_context);
+ } else {
+ OneHotComputeImpl<T, int>(op_context);
+ }
+}
+
+TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
+ const OneHotContext& op_context) {
+ TF_LITE_ENSURE(context, *op_context.depth->data.i32 >= 0);
+ TfLiteIntArray* output_size = TfLiteIntArrayCreate(op_context.output_dims);
+ for (int i = 0; i < op_context.output_dims; ++i) {
+ if (i < op_context.axis) {
+ output_size->data[i] = op_context.indices->dims->data[i];
+ } else if (i == op_context.axis) {
+ output_size->data[i] = *op_context.depth->data.i32;
+ } else {
+ output_size->data[i] = op_context.indices->dims->data[i - 1];
+ }
+ }
+ return context->ResizeTensor(context, op_context.output, output_size);
+}
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ TF_LITE_ENSURE_EQ(context, NumInputs(node), 4);
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+
+ OneHotContext op_context{context, node};
+ switch (op_context.dtype) {
+ // TODO(b/111744875): Support uint8 and quantization.
+ case kTfLiteFloat32:
+ case kTfLiteInt16:
+ case kTfLiteInt32:
+ case kTfLiteInt64:
+ case kTfLiteBool:
+ op_context.output->type = op_context.dtype;
+ break;
+ default:
+ context->ReportError(context, "Unknown output data type: %d",
+ op_context.dtype);
+ return kTfLiteError;
+ }
+
+ TF_LITE_ENSURE(context, op_context.indices->type == kTfLiteInt32 ||
+ op_context.indices->type == kTfLiteInt64);
+ TF_LITE_ENSURE(context, op_context.axis >= 0 &&
+ op_context.axis < op_context.output_dims);
+ TF_LITE_ENSURE_EQ(context, NumElements(op_context.depth), 1);
+ TF_LITE_ENSURE_EQ(context, NumElements(op_context.on_value), 1);
+ TF_LITE_ENSURE_EQ(context, NumElements(op_context.off_value), 1);
+ TF_LITE_ENSURE_EQ(context, op_context.on_value->type, op_context.dtype);
+ TF_LITE_ENSURE_EQ(context, op_context.off_value->type, op_context.dtype);
+
+ if (!IsConstantTensor(op_context.depth)) {
+ SetTensorToDynamic(op_context.output);
+ return kTfLiteOk;
+ }
+
+ return ResizeOutputTensor(context, op_context);
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ OneHotContext op_context{context, node};
+
+ if (IsDynamicTensor(op_context.output)) {
+ ResizeOutputTensor(context, op_context);
+ }
+
+ switch (op_context.output->type) {
+ case kTfLiteFloat32:
+ OneHotCompute<float>(op_context);
+ break;
+ case kTfLiteInt32:
+ OneHotCompute<int>(op_context);
+ break;
+ case kTfLiteInt64:
+ OneHotCompute<int64_t>(op_context);
+ break;
+ case kTfLiteBool:
+ OneHotCompute<bool>(op_context);
+ break;
+ default:
+ return kTfLiteError;
+ }
+
+ return kTfLiteOk;
+}
+
+} // namespace one_hot
+
+TfLiteRegistration* Register_ONE_HOT() {
+ static TfLiteRegistration r = {
+ nullptr,
+ nullptr,
+ one_hot::Prepare,
+ one_hot::Eval,
+ };
+ return &r;
+}
+
+} // namespace builtin
+} // namespace ops
+} // namespace tflite
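As a quick illustration of the index arithmetic in OneHotComputeImpl, the snippet below runs the same prefix/depth/suffix loop on a tiny example, using plain arrays instead of TfLiteTensors; all names here are illustrative only:

#include <cstdio>
#include <vector>

int main() {
  // One-hot of indices {0, 2, 1} with depth 3 and axis = -1 (the last axis):
  // prefix_dim_size = 3, suffix_dim_size = 1, output viewed as 3 x 3 x 1.
  const std::vector<int> indices = {0, 2, 1};
  const int depth = 3;
  const int prefix_dim_size = 3;
  const int suffix_dim_size = 1;
  const float on_value = 1.f;
  const float off_value = 0.f;

  std::vector<float> output(prefix_dim_size * depth * suffix_dim_size);
  float* out = output.data();
  for (int i = 0; i < prefix_dim_size; ++i) {
    for (int j = 0; j < depth; ++j) {
      for (int k = 0; k < suffix_dim_size; ++k, ++out) {
        *out = indices[i * suffix_dim_size + k] == j ? on_value : off_value;
      }
    }
  }

  // Prints one row per index: 1 0 0 / 0 0 1 / 0 1 0.
  for (int i = 0; i < prefix_dim_size; ++i) {
    for (int j = 0; j < depth; ++j) {
      std::printf("%.0f ", output[i * depth + j]);
    }
    std::printf("\n");
  }
  return 0;
}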
diff --git a/tensorflow/contrib/lite/kernels/one_hot_test.cc b/tensorflow/contrib/lite/kernels/one_hot_test.cc
new file mode 100644
index 0000000000..6b604ec7a7
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/one_hot_test.cc
@@ -0,0 +1,182 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <initializer_list>
+
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAreArray;
+
+template <typename T>
+class OneHotOpModel : public SingleOpModel {
+ public:
+ OneHotOpModel(std::initializer_list<int> input_shape, int depth_value,
+ TensorType dtype, int axis = -1, T on_value = 1,
+ T off_value = 0, TensorType indices_type = TensorType_INT32) {
+ indices_ = AddInput(indices_type);
+ int depth = AddInput(TensorType_INT32);
+ int on = AddInput(dtype);
+ int off = AddInput(dtype);
+ output_ = AddOutput(dtype);
+ SetBuiltinOp(BuiltinOperator_ONE_HOT, BuiltinOptions_OneHotOptions,
+ CreateOneHotOptions(builder_, axis).Union());
+ BuildInterpreter({input_shape});
+
+ PopulateTensor<int>(depth, {depth_value});
+ PopulateTensor<T>(on, {on_value});
+ PopulateTensor<T>(off, {off_value});
+ }
+
+ template <typename TI>
+ void SetIndices(std::initializer_list<TI> data) {
+ PopulateTensor<TI>(indices_, data);
+ }
+
+ TfLiteStatus InvokeWithResult() { return interpreter_->Invoke(); }
+
+ int32_t GetOutputSize() { return GetTensorSize(output_); }
+ std::vector<T> GetOutput() { return ExtractVector<T>(output_); }
+ std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
+
+ private:
+ int indices_;
+ int output_;
+};
+
+TEST(OneHotOpTest, BasicFloat) {
+ const int depth = 3;
+ OneHotOpModel<float> model({3}, depth, TensorType_FLOAT32);
+ model.SetIndices({0, 1, 2});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({3, 3}));
+ EXPECT_THAT(model.GetOutput(),
+ ElementsAreArray({1.f, 0.f, 0.f, 0.f, 1.f, 0.f, 0.f, 0.f, 1.f}));
+}
+
+TEST(OneHotOpTest, BasicInt) {
+ const int depth = 3;
+ OneHotOpModel<int> model({3}, depth, TensorType_INT32);
+ model.SetIndices({0, 1, 2});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({3, 3}));
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 0, 0, 0, 1, 0, 0, 0, 1}));
+}
+
+TEST(OneHotOpTest, BasicBool) {
+ const int depth = 3;
+ OneHotOpModel<bool> model({3}, depth, TensorType_BOOL);
+ model.SetIndices({0, 1, 2});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({3, 3}));
+ EXPECT_THAT(model.GetOutput(),
+ ElementsAreArray({true, false, false, false, true, false, false,
+ false, true}));
+}
+
+TEST(OneHotOpTest, SmallDepth) {
+ const int depth = 1;
+ OneHotOpModel<int> model({3}, depth, TensorType_INT32);
+ model.SetIndices({0, 1, 2});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({3, 1}));
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 0, 0}));
+}
+
+TEST(OneHotOpTest, BigDepth) {
+ const int depth = 4;
+ OneHotOpModel<int> model({2}, depth, TensorType_INT32);
+ model.SetIndices({0, 1});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 4}));
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 0, 0, 0, 0, 1, 0, 0}));
+}
+
+TEST(OneHotOpTest, OnOffValues) {
+ const int depth = 3;
+ const int axis = -1;
+ const int on = 5;
+ const int off = 0;
+ OneHotOpModel<int> model({4}, depth, TensorType_INT32, axis, on, off);
+ model.SetIndices({0, 2, -1, 1});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({4, 3}));
+ EXPECT_THAT(model.GetOutput(),
+ ElementsAreArray({5, 0, 0, 0, 0, 5, 0, 0, 0, 0, 5, 0}));
+}
+
+TEST(OneHotOpTest, ZeroAxis) {
+ const int depth = 3;
+ const int axis = 0;
+ const int on = 5;
+ const int off = 0;
+ OneHotOpModel<int> model({4}, depth, TensorType_INT32, axis, on, off);
+ model.SetIndices({0, 2, -1, 1});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({3, 4}));
+ EXPECT_THAT(model.GetOutput(),
+ ElementsAreArray({5, 0, 0, 0, 0, 0, 0, 5, 0, 5, 0, 0}));
+}
+
+TEST(OneHotOpTest, MultiDimensionalIndices) {
+ const int depth = 3;
+ const int axis = -1;
+ const float on = 2;
+ const float off = 0;
+ OneHotOpModel<float> model({2, 2}, depth, TensorType_FLOAT32, axis, on, off);
+ model.SetIndices({0, 2, 1, -1});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 2, 3}));
+ EXPECT_THAT(model.GetOutput(),
+ ElementsAreArray({2, 0, 0, 0, 0, 2, 0, 2, 0, 0, 0, 0}));
+}
+
+TEST(OneHotOpTest, Int64Indices) {
+ const int depth = 3;
+ const int axis = -1;
+ const int on = 1;
+ const int off = 0;
+ OneHotOpModel<int> model({3}, depth, TensorType_INT32, axis, on, off,
+ TensorType_INT64);
+ std::initializer_list<int64_t> indices = {0, 1, 2};
+ model.SetIndices(indices);
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({3, 3}));
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 0, 0, 0, 1, 0, 0, 0, 1}));
+}
+
+} // namespace
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ ::tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/kernels/op_macros.h b/tensorflow/contrib/lite/kernels/op_macros.h
index 7568eaa88e..11e814daee 100644
--- a/tensorflow/contrib/lite/kernels/op_macros.h
+++ b/tensorflow/contrib/lite/kernels/op_macros.h
@@ -12,23 +12,61 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_OP_UTIL_H_
-#define TENSORFLOW_CONTRIB_LITE_KERNELS_OP_UTIL_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_OP_MACROS_H_
+#define TENSORFLOW_CONTRIB_LITE_KERNELS_OP_MACROS_H_
+// If we're on a platform without standard IO functions, fall back to a
+// non-portable function.
+#ifdef TF_LITE_MCU_DEBUG_LOG
+
+// This header is pulled in from the support library at
+// https://github.com/google/stm32_bare_lib
+#include <debug_log.h>
+
+#define DEBUG_LOG(x) \
+ do { \
+ DebugLog(x); \
+ } while (0)
+
+inline void InfiniteLoop() {
+ DEBUG_LOG("HALTED\n");
+ while (1) {
+ }
+}
+#define TFLITE_ASSERT_FALSE InfiniteLoop();
+#define TFLITE_ABORT InfiniteLoop();
+
+#else // TF_LITE_MCU_DEBUG_LOG
+
+#include <cassert>
#include <cstdio>
+#include <cstdlib>
-#define TF_LITE_FATAL(msg) \
- do { \
- fprintf(stderr, "%s\n", (msg)); \
- exit(1); \
+#define DEBUG_LOG(x) \
+ do { \
+ fprintf(stderr, "%s", (x)); \
} while (0)
+
+#define TFLITE_ASSERT_FALSE assert(false)
+#define TFLITE_ABORT abort()
+
+#endif // TF_LITE_MCU_DEBUG_LOG
+
+#define TF_LITE_FATAL(msg) \
+ do { \
+ DEBUG_LOG(msg); \
+ DEBUG_LOG("\nFATAL\n"); \
+ TFLITE_ABORT; \
+ } while (0)
+
#define TF_LITE_ASSERT(x) \
do { \
if (!(x)) TF_LITE_FATAL(#x); \
} while (0)
+
#define TF_LITE_ASSERT_EQ(x, y) \
do { \
if ((x) != (y)) TF_LITE_FATAL(#x " didn't equal " #y); \
} while (0)
-#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_OP_UTIL_H_
+#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_OP_MACROS_H_
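On a hosted build the reworked macros reduce to fprintf plus abort. The short sketch below shows what a failing TF_LITE_ASSERT_EQ ends up doing, with the macro bodies copied from the hosted branch above (TFLITE_ABORT inlined as abort()) and the checked values invented purely for illustration:

#include <cstdio>
#include <cstdlib>

#define DEBUG_LOG(x)            \
  do {                          \
    fprintf(stderr, "%s", (x)); \
  } while (0)

#define TF_LITE_FATAL(msg)  \
  do {                      \
    DEBUG_LOG(msg);         \
    DEBUG_LOG("\nFATAL\n"); \
    abort();                \
  } while (0)

#define TF_LITE_ASSERT_EQ(x, y)                            \
  do {                                                     \
    if ((x) != (y)) TF_LITE_FATAL(#x " didn't equal " #y); \
  } while (0)

int main() {
  const int expected_rank = 4;
  const int actual_rank = 4;
  TF_LITE_ASSERT_EQ(expected_rank, actual_rank);  // Passes, no output.
  TF_LITE_ASSERT_EQ(actual_rank, 3);  // Logs the message and calls abort().
  return 0;
}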
diff --git a/tensorflow/contrib/lite/kernels/optional_tensor_test.cc b/tensorflow/contrib/lite/kernels/optional_tensor_test.cc
index bcad58406a..90a915bb02 100644
--- a/tensorflow/contrib/lite/kernels/optional_tensor_test.cc
+++ b/tensorflow/contrib/lite/kernels/optional_tensor_test.cc
@@ -95,8 +95,12 @@ class LSTMOpModel : public SingleOpModel {
projection_bias_ = AddNullInput();
}
- output_state_ = AddOutput(TensorType_FLOAT32);
- cell_state_ = AddOutput(TensorType_FLOAT32);
+ // Adding the 2 input state tensors.
+ input_activation_state_ =
+ AddInput(TensorData{TensorType_FLOAT32, {n_output_ * n_batch_}}, true);
+ input_cell_state_ =
+ AddInput(TensorData{TensorType_FLOAT32, {n_cell_ * n_batch_}}, true);
+
output_ = AddOutput(TensorType_FLOAT32);
SetBuiltinOp(BuiltinOperator_LSTM, BuiltinOptions_LSTMOptions,
@@ -174,22 +178,6 @@ class LSTMOpModel : public SingleOpModel {
PopulateTensor(projection_bias_, f);
}
- void ResetOutputState() {
- const int zero_buffer_size = n_cell_ * n_batch_;
- std::unique_ptr<float[]> zero_buffer(new float[zero_buffer_size]);
- memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float));
- PopulateTensor(output_state_, 0, zero_buffer.get(),
- zero_buffer.get() + zero_buffer_size);
- }
-
- void ResetCellState() {
- const int zero_buffer_size = n_cell_ * n_batch_;
- std::unique_ptr<float[]> zero_buffer(new float[zero_buffer_size]);
- memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float));
- PopulateTensor(cell_state_, 0, zero_buffer.get(),
- zero_buffer.get() + zero_buffer_size);
- }
-
void SetInput(int offset, float* begin, float* end) {
PopulateTensor(input_, offset, begin, end);
}
@@ -228,10 +216,10 @@ class LSTMOpModel : public SingleOpModel {
int projection_weights_;
int projection_bias_;
+ int input_activation_state_;
+ int input_cell_state_;
int output_;
- int output_state_;
- int cell_state_;
int n_batch_;
int n_input_;
@@ -316,10 +304,6 @@ TEST(LSTMOpTest, BlackBoxTestWithCifgWithPeepholeNoProjectionNoClipping) {
lstm.SetCellToOutputWeights(
{-0.17135078, 0.82760304, 0.85573703, -0.77109635});
- // Resetting cell_state and output_state
- lstm.ResetCellState();
- lstm.ResetOutputState();
-
// Verify the model by unpacking it.
lstm.Verify();
}
diff --git a/tensorflow/contrib/lite/kernels/pack.cc b/tensorflow/contrib/lite/kernels/pack.cc
new file mode 100644
index 0000000000..c368582ef7
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/pack.cc
@@ -0,0 +1,135 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace pack {
+namespace {
+
+constexpr int kOutputTensor = 0;
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ const TfLitePackParams* data =
+ reinterpret_cast<TfLitePackParams*>(node->builtin_data);
+
+ TF_LITE_ENSURE_EQ(context, NumInputs(node), data->values_count);
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+
+ const TfLiteTensor* input0 = GetInput(context, node, 0);
+ TF_LITE_ENSURE(context, NumDimensions(input0) < 4);
+ TF_LITE_ENSURE(context, NumDimensions(input0) >= data->axis);
+ // TODO(renjieliu): Support negative axis.
+ TF_LITE_ENSURE(context, data->axis >= 0);
+ if (input0->type != kTfLiteInt32 && input0->type != kTfLiteFloat32 &&
+ input0->type != kTfLiteUInt8 && input0->type != kTfLiteInt16) {
+ context->ReportError(context,
+ "Currently pack only supports "
+ "float32/uint8/int16/int32.");
+ return kTfLiteError;
+ }
+ // Make sure all inputs have the same shape and type.
+ for (int i = 1; i < data->values_count; ++i) {
+ const TfLiteTensor* input = GetInput(context, node, i);
+ TF_LITE_ENSURE(context, HaveSameShapes(input0, input));
+ TF_LITE_ENSURE_EQ(context, input0->type, input->type);
+ }
+
+ // Resize output. rank R will become rank R + 1
+ const int dimension_size = NumDimensions(input0) + 1;
+ const TfLiteIntArray* input_shape = input0->dims;
+ TfLiteIntArray* output_shape = TfLiteIntArrayCreate(dimension_size);
+ int i = 0;
+ for (int index = 0; index < dimension_size; ++index) {
+ if (index == data->axis) {
+ output_shape->data[index] = data->values_count;
+ } else {
+ output_shape->data[index] = input_shape->data[i++];
+ }
+ }
+
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+ TF_LITE_ENSURE_EQ(context, output->type, input0->type);
+
+ // Guarantee input/output quantization params match as we do not support
+ // packing quantized tensors.
+ for (int i = 0; i < data->values_count; i++) {
+ const TfLiteTensor* input = GetInput(context, node, i);
+ TF_LITE_ENSURE_EQ(context, input->params.zero_point,
+ output->params.zero_point);
+ TF_LITE_ENSURE_EQ(context, input->params.scale, output->params.scale);
+ }
+
+ return context->ResizeTensor(context, output, output_shape);
+}
+
+template <typename T>
+void PackImpl(TfLiteContext* context, TfLiteNode* node, TfLiteTensor* output,
+ int values_count, int axis) {
+ VectorOfTensors<T> all_inputs(*context, *node->inputs);
+ tflite::PackParams op_params;
+ op_params.axis = axis;
+ op_params.inputs_count = values_count;
+
+ reference_ops::Pack<T>(op_params, all_inputs.shapes(), all_inputs.data(),
+ GetTensorShape(output), GetTensorData<T>(output));
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ const TfLitePackParams* data =
+ reinterpret_cast<TfLitePackParams*>(node->builtin_data);
+
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+ switch (output->type) {
+ case kTfLiteFloat32: {
+ PackImpl<float>(context, node, output, data->values_count, data->axis);
+ break;
+ }
+ case kTfLiteUInt8: {
+ PackImpl<uint8_t>(context, node, output, data->values_count, data->axis);
+ break;
+ }
+ case kTfLiteInt32: {
+ PackImpl<int32_t>(context, node, output, data->values_count, data->axis);
+ break;
+ }
+ default: {
+ context->ReportError(context,
+ "Currently pack only supports "
+ "float32/uint8/int32.");
+ return kTfLiteError;
+ }
+ }
+
+ return kTfLiteOk;
+}
+
+} // namespace
+} // namespace pack
+
+TfLiteRegistration* Register_PACK() {
+ static TfLiteRegistration r = {nullptr, nullptr, pack::Prepare, pack::Eval};
+ return &r;
+}
+
+} // namespace builtin
+} // namespace ops
+} // namespace tflite
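The output-shape logic in pack::Prepare simply inserts values_count at position axis of the input shape, turning a rank-R input into a rank R+1 output. A standalone sketch of just that step, with std::vector shapes and illustrative names:

#include <cstdio>
#include <vector>

std::vector<int> PackedShape(const std::vector<int>& input_shape,
                             int values_count, int axis) {
  std::vector<int> output_shape(input_shape.size() + 1);
  int i = 0;
  for (int index = 0; index < static_cast<int>(output_shape.size()); ++index) {
    output_shape[index] = (index == axis) ? values_count : input_shape[i++];
  }
  return output_shape;
}

int main() {
  // Packing 2 tensors of shape {2, 3} along axis 1 yields shape {2, 2, 3},
  // matching the FloatMultilDimensions test below.
  for (int d : PackedShape({2, 3}, /*values_count=*/2, /*axis=*/1)) {
    std::printf("%d ", d);
  }
  std::printf("\n");  // Prints: 2 2 3
  return 0;
}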
diff --git a/tensorflow/contrib/lite/kernels/pack_test.cc b/tensorflow/contrib/lite/kernels/pack_test.cc
new file mode 100644
index 0000000000..c70dbd2764
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/pack_test.cc
@@ -0,0 +1,154 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAre;
+using ::testing::ElementsAreArray;
+
+template <typename T>
+class PackOpModel : public SingleOpModel {
+ public:
+ PackOpModel(const TensorData& input_template, int axis, int values_count) {
+ std::vector<std::vector<int>> all_input_shapes;
+ for (int i = 0; i < values_count; ++i) {
+ all_input_shapes.push_back(input_template.shape);
+ AddInput(input_template);
+ }
+ output_ = AddOutput({input_template.type, /*shape=*/{}, input_template.min,
+ input_template.max});
+ SetBuiltinOp(BuiltinOperator_PACK, BuiltinOptions_PackOptions,
+ CreatePackOptions(builder_, values_count, axis).Union());
+ BuildInterpreter(all_input_shapes);
+ }
+
+ void SetInput(int index, std::initializer_list<T> data) {
+ PopulateTensor(index, data);
+ }
+
+ std::vector<T> GetOutput() { return ExtractVector<T>(output_); }
+ std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
+
+ private:
+ int output_;
+};
+
+// float32 tests.
+TEST(PackOpTest, FloatThreeInputs) {
+ PackOpModel<float> model({TensorType_FLOAT32, {2}}, 0, 3);
+ model.SetInput(0, {1, 4});
+ model.SetInput(1, {2, 5});
+ model.SetInput(2, {3, 6});
+ model.Invoke();
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(3, 2));
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 4, 2, 5, 3, 6}));
+}
+
+TEST(PackOpTest, FloatThreeInputsDifferentAxis) {
+ PackOpModel<float> model({TensorType_FLOAT32, {2}}, 1, 3);
+ model.SetInput(0, {1, 4});
+ model.SetInput(1, {2, 5});
+ model.SetInput(2, {3, 6});
+ model.Invoke();
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(2, 3));
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6}));
+}
+
+TEST(PackOpTest, FloatMultilDimensions) {
+ PackOpModel<float> model({TensorType_FLOAT32, {2, 3}}, 1, 2);
+ model.SetInput(0, {1, 2, 3, 4, 5, 6});
+ model.SetInput(1, {7, 8, 9, 10, 11, 12});
+ model.Invoke();
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(2, 2, 3));
+ EXPECT_THAT(model.GetOutput(),
+ ElementsAreArray({1, 2, 3, 7, 8, 9, 4, 5, 6, 10, 11, 12}));
+}
+
+// int32 tests.
+TEST(PackOpTest, Int32ThreeInputs) {
+ PackOpModel<int32_t> model({TensorType_INT32, {2}}, 0, 3);
+ model.SetInput(0, {1, 4});
+ model.SetInput(1, {2, 5});
+ model.SetInput(2, {3, 6});
+ model.Invoke();
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(3, 2));
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 4, 2, 5, 3, 6}));
+}
+
+TEST(PackOpTest, Int32ThreeInputsDifferentAxis) {
+ PackOpModel<int32_t> model({TensorType_INT32, {2}}, 1, 3);
+ model.SetInput(0, {1, 4});
+ model.SetInput(1, {2, 5});
+ model.SetInput(2, {3, 6});
+ model.Invoke();
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(2, 3));
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6}));
+}
+
+TEST(PackOpTest, Int32MultilDimensions) {
+ PackOpModel<int32_t> model({TensorType_INT32, {2, 3}}, 1, 2);
+ model.SetInput(0, {1, 2, 3, 4, 5, 6});
+ model.SetInput(1, {7, 8, 9, 10, 11, 12});
+ model.Invoke();
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(2, 2, 3));
+ EXPECT_THAT(model.GetOutput(),
+ ElementsAreArray({1, 2, 3, 7, 8, 9, 4, 5, 6, 10, 11, 12}));
+}
+
+// uint8 tests.
+TEST(PackOpTest, Uint8ThreeInputs) {
+ PackOpModel<uint8_t> model({TensorType_UINT8, {2}}, 0, 3);
+ model.SetInput(0, {1, 4});
+ model.SetInput(1, {2, 5});
+ model.SetInput(2, {3, 6});
+ model.Invoke();
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(3, 2));
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 4, 2, 5, 3, 6}));
+}
+
+TEST(PackOpTest, Uint8ThreeInputsDifferentAxis) {
+ PackOpModel<uint8_t> model({TensorType_UINT8, {2}}, 1, 3);
+ model.SetInput(0, {1, 4});
+ model.SetInput(1, {2, 5});
+ model.SetInput(2, {3, 6});
+ model.Invoke();
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(2, 3));
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6}));
+}
+
+TEST(PackOpTest, Uint8MultilDimensions) {
+ PackOpModel<uint8_t> model({TensorType_UINT8, {2, 3}}, 1, 2);
+ model.SetInput(0, {1, 2, 3, 4, 5, 6});
+ model.SetInput(1, {7, 8, 9, 10, 11, 12});
+ model.Invoke();
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(2, 2, 3));
+ EXPECT_THAT(model.GetOutput(),
+ ElementsAreArray({1, 2, 3, 7, 8, 9, 4, 5, 6, 10, 11, 12}));
+}
+
+} // namespace
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ ::tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/kernels/pad.cc b/tensorflow/contrib/lite/kernels/pad.cc
index ecac2dd5e3..0d939405f6 100644
--- a/tensorflow/contrib/lite/kernels/pad.cc
+++ b/tensorflow/contrib/lite/kernels/pad.cc
@@ -14,8 +14,8 @@ limitations under the License.
==============================================================================*/
#include <string.h>
#include <vector>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
@@ -92,8 +92,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
op_context.constant_values->type);
}
- // TODO(nupurgarg): Our current implementations rely on the inputs being 4D.
- TF_LITE_ENSURE_EQ(context, op_context.dims, 4);
+ // TODO(nupurgarg): Current implementations rely on the inputs being <= 4D.
+ TF_LITE_ENSURE(context, op_context.dims <= 4);
// Exit early if paddings is a non-const tensor. Set output tensor to
// dynamic so output size can be determined in Eval.
@@ -128,18 +128,28 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
// TODO(nupurgarg): Change kernel implementation to use padding arrays in
// forward order (depth, width, height, batch).
// Build paddings in order of int[] = {batch, height, width, depth} to match
- // kernel implementation of Pad in referenced_ops.h and optimized_ops.h.
+ // kernel implementation of Pad in reference_ops.h and optimized_ops.h.
for (int idx = op_context.dims - 1; idx >= 0; --idx) {
before_padding.push_back(paddings_data[idx * 2]);
after_padding.push_back(paddings_data[idx * 2 + 1]);
}
-#define TF_LITE_PAD(type, scalar, pad_value) \
- type::PadV2(GetTensorData<scalar>(op_context.input), \
- GetTensorDims(op_context.input), before_padding, after_padding, \
- GetTensorData<scalar>(op_context.output), \
- GetTensorDims(op_context.output), pad_value)
-
+#define TF_LITE_PAD(type, scalar, pad_value) \
+ TF_LITE_ENSURE(context, before_padding.size() <= 4); \
+ TF_LITE_ENSURE(context, after_padding.size() <= 4); \
+ tflite::PadParams op_params; \
+ op_params.left_padding_count = before_padding.size(); \
+ op_params.right_padding_count = after_padding.size(); \
+ for (int i = 0; i < op_context.dims; ++i) { \
+ op_params.left_padding[i] = before_padding[op_context.dims - 1 - i]; \
+ op_params.right_padding[i] = after_padding[op_context.dims - 1 - i]; \
+ } \
+ const scalar pad_value_copy = pad_value; \
+ \
+ type::Pad(op_params, GetTensorShape(op_context.input), \
+ GetTensorData<scalar>(op_context.input), &pad_value_copy, \
+ GetTensorShape(op_context.output), \
+ GetTensorData<scalar>(op_context.output))
switch (op_context.input->type) {
case kTfLiteFloat32: {
float pad_value = op_context.constant_values == nullptr
@@ -199,7 +209,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
}
} break;
default:
- context->ReportError(context, "Type is currently not supported by Pad.");
+ context->ReportError(context,
+ "Type %d is currently not supported by Pad.",
+ op_context.input->type);
return kTfLiteError;
}
#undef TF_LITE_PAD
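The rewritten TF_LITE_PAD macro has to undo the reversed {batch, height, width, depth} ordering built just above it, so left_padding[i] and right_padding[i] come out in forward order again. A small sketch of that reindexing with an invented 2-D paddings tensor:

#include <cstdio>
#include <vector>

int main() {
  // Paddings data laid out as {before_0, after_0, before_1, after_1}.
  const int dims = 2;
  const int paddings_data[] = {1, 2, 3, 4};

  // Reversed order, as Eval builds it for the legacy kernel layout.
  std::vector<int> before_padding, after_padding;
  for (int idx = dims - 1; idx >= 0; --idx) {
    before_padding.push_back(paddings_data[idx * 2]);
    after_padding.push_back(paddings_data[idx * 2 + 1]);
  }

  // Forward order again, as the PadParams-based call expects.
  int left_padding[2];
  int right_padding[2];
  for (int i = 0; i < dims; ++i) {
    left_padding[i] = before_padding[dims - 1 - i];
    right_padding[i] = after_padding[dims - 1 - i];
  }

  std::printf("left = {%d, %d}, right = {%d, %d}\n", left_padding[0],
              left_padding[1], right_padding[0], right_padding[1]);
  // Prints: left = {1, 3}, right = {2, 4}
  return 0;
}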
diff --git a/tensorflow/contrib/lite/kernels/pad_test.cc b/tensorflow/contrib/lite/kernels/pad_test.cc
index f8b9064fbb..f663899713 100644
--- a/tensorflow/contrib/lite/kernels/pad_test.cc
+++ b/tensorflow/contrib/lite/kernels/pad_test.cc
@@ -193,7 +193,7 @@ TEST(PadOpTest, TooManyDimensions) {
PadOpConstModel({TensorType_FLOAT32, {1, 2, 3, 4, 5, 6, 7, 8, 9}}, {9, 2},
{1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9},
{TensorType_FLOAT32}),
- "dims != 4");
+ "dims <= 4");
}
TEST(PadOpTest, UnequalDimensions) {
@@ -221,6 +221,15 @@ TEST(PadOpTest, SimpleConstTest) {
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1}));
}
+TEST(PadOpTest, SimpleConst1DTest) {
+ PadOpConstModel m({TensorType_FLOAT32, {2}}, {1, 2}, {1, 2},
+ {TensorType_FLOAT32});
+ m.SetInput({2, 3});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, 2, 3, 0, 0}));
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({5}));
+}
+
TEST(PadOpTest, SimpleDynamicTest) {
PadOpDynamicModel m({TensorType_FLOAT32, {1, 2, 2, 1}}, {4, 2},
{TensorType_FLOAT32});
@@ -334,7 +343,7 @@ TEST(PadV2OpTest, TooManyDimensions) {
{TensorType_FLOAT32, {1, 2, 3, 4, 5, 6, 7, 8, 9}}, {9, 2},
{1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9}, 0.0,
{TensorType_FLOAT32}),
- "dims != 4");
+ "dims <= 4");
}
TEST(PadV2OpTest, UnequalDimensions) {
diff --git a/tensorflow/contrib/lite/kernels/padding.h b/tensorflow/contrib/lite/kernels/padding.h
index 3cb55f19a9..42b6b45d3b 100644
--- a/tensorflow/contrib/lite/kernels/padding.h
+++ b/tensorflow/contrib/lite/kernels/padding.h
@@ -15,7 +15,7 @@ limitations under the License.
#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_PADDING_H_
#define TENSORFLOW_CONTRIB_LITE_KERNELS_PADDING_H_
-#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
namespace tflite {
diff --git a/tensorflow/contrib/lite/kernels/pooling.cc b/tensorflow/contrib/lite/kernels/pooling.cc
index 645d9f4008..6451142391 100644
--- a/tensorflow/contrib/lite/kernels/pooling.cc
+++ b/tensorflow/contrib/lite/kernels/pooling.cc
@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include <unistd.h>
#include <cassert>
#include <cmath>
#include <cstdio>
@@ -20,8 +19,8 @@ limitations under the License.
#include <iostream>
#include <limits>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
@@ -80,24 +79,24 @@ TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) {
// Matching GetWindowedOutputSize in TensorFlow.
auto padding = params->padding;
- auto computeOutSize = [padding](int imageSize, int filterSize,
- int stride) -> int {
+ auto compute_out_size = [padding](int image_size, int filter_size,
+ int stride) -> int {
return padding == kTfLitePaddingSame
- ? (imageSize + stride - 1) / stride
+ ? (image_size + stride - 1) / stride
: padding == kTfLitePaddingValid
- ? (imageSize - filterSize + stride) / stride
+ ? (image_size - filter_size + stride) / stride
: 0;
};
- int outWidth =
- computeOutSize(width, params->filter_width, params->stride_width);
- int outHeight =
- computeOutSize(height, params->filter_height, params->stride_height);
+ int out_width =
+ compute_out_size(width, params->filter_width, params->stride_width);
+ int out_height =
+ compute_out_size(height, params->filter_height, params->stride_height);
data->padding.height = ComputePadding(params->stride_height, 1, height,
- params->filter_height, outHeight);
+ params->filter_height, out_height);
data->padding.width = ComputePadding(params->stride_width, 1, width,
- params->filter_width, outWidth);
+ params->filter_width, out_width);
if (input->type == kTfLiteUInt8) {
if (pool_type == kAverage || pool_type == kMax) {
@@ -111,12 +110,12 @@ TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) {
}
}
- TfLiteIntArray* outputSize = TfLiteIntArrayCreate(4);
- outputSize->data[0] = batches;
- outputSize->data[1] = outHeight;
- outputSize->data[2] = outWidth;
- outputSize->data[3] = channels_out;
- return context->ResizeTensor(context, output, outputSize);
+ TfLiteIntArray* output_size = TfLiteIntArrayCreate(4);
+ output_size->data[0] = batches;
+ output_size->data[1] = out_height;
+ output_size->data[2] = out_width;
+ output_size->data[3] = channels_out;
+ return context->ResizeTensor(context, output, output_size);
}
template <KernelType kernel_type>
@@ -124,14 +123,21 @@ void AverageEvalFloat(TfLiteContext* context, TfLiteNode* node,
TfLitePoolParams* params, OpData* data,
const TfLiteTensor* input, TfLiteTensor* output) {
float activation_min, activation_max;
- CalculateActivationRangeFloat(params->activation, &activation_min,
- &activation_max);
-#define TF_LITE_AVERAGE_POOL(type) \
- type::AveragePool( \
- GetTensorData<float>(input), GetTensorDims(input), params->stride_width, \
- params->stride_height, data->padding.width, data->padding.height, \
- params->filter_width, params->filter_height, activation_min, \
- activation_max, GetTensorData<float>(output), GetTensorDims(output))
+ CalculateActivationRange(params->activation, &activation_min,
+ &activation_max);
+#define TF_LITE_AVERAGE_POOL(type) \
+ tflite::PoolParams op_params; \
+ op_params.stride_height = params->stride_height; \
+ op_params.stride_width = params->stride_width; \
+ op_params.filter_height = params->filter_height; \
+ op_params.filter_width = params->filter_width; \
+ op_params.padding_values.height = data->padding.height; \
+ op_params.padding_values.width = data->padding.width; \
+ op_params.float_activation_min = activation_min; \
+ op_params.float_activation_max = activation_max; \
+ type::AveragePool(op_params, GetTensorShape(input), \
+ GetTensorData<float>(input), GetTensorShape(output), \
+ GetTensorData<float>(output))
if (kernel_type == kReference) {
TF_LITE_AVERAGE_POOL(reference_ops);
} else {
@@ -148,13 +154,19 @@ void AverageEvalQuantized(TfLiteContext* context, TfLiteNode* node,
int32_t activation_max;
CalculateActivationRangeUint8(params->activation, output, &activation_min,
&activation_max);
-#define TF_LITE_AVERAGE_POOL(type) \
- type::AveragePool(GetTensorData<uint8_t>(input), GetTensorDims(input), \
- params->stride_width, params->stride_height, \
- data->padding.width, data->padding.height, \
- params->filter_width, params->filter_height, \
- activation_min, activation_max, \
- GetTensorData<uint8_t>(output), GetTensorDims(output))
+#define TF_LITE_AVERAGE_POOL(type) \
+ tflite::PoolParams op_params; \
+ op_params.stride_height = params->stride_height; \
+ op_params.stride_width = params->stride_width; \
+ op_params.filter_height = params->filter_height; \
+ op_params.filter_width = params->filter_width; \
+ op_params.padding_values.height = data->padding.height; \
+ op_params.padding_values.width = data->padding.width; \
+ op_params.quantized_activation_min = activation_min; \
+ op_params.quantized_activation_max = activation_max; \
+ type::AveragePool(op_params, GetTensorShape(input), \
+ GetTensorData<uint8_t>(input), GetTensorShape(output), \
+ GetTensorData<uint8_t>(output))
if (kernel_type == kReference) {
TF_LITE_AVERAGE_POOL(reference_ops);
} else {
@@ -168,14 +180,20 @@ void MaxEvalFloat(TfLiteContext* context, TfLiteNode* node,
TfLitePoolParams* params, OpData* data,
const TfLiteTensor* input, TfLiteTensor* output) {
float activation_min, activation_max;
- CalculateActivationRangeFloat(params->activation, &activation_min,
- &activation_max);
+ CalculateActivationRange(params->activation, &activation_min,
+ &activation_max);
#define TF_LITE_MAX_POOL(type) \
- type::MaxPool( \
- GetTensorData<float>(input), GetTensorDims(input), params->stride_width, \
- params->stride_height, data->padding.width, data->padding.height, \
- params->filter_width, params->filter_height, activation_min, \
- activation_max, GetTensorData<float>(output), GetTensorDims(output))
+ tflite::PoolParams op_params; \
+ op_params.stride_height = params->stride_height; \
+ op_params.stride_width = params->stride_width; \
+ op_params.filter_height = params->filter_height; \
+ op_params.filter_width = params->filter_width; \
+ op_params.padding_values.height = data->padding.height; \
+ op_params.padding_values.width = data->padding.width; \
+ op_params.float_activation_min = activation_min; \
+ op_params.float_activation_max = activation_max; \
+ type::MaxPool(op_params, GetTensorShape(input), GetTensorData<float>(input), \
+ GetTensorShape(output), GetTensorData<float>(output))
if (kernel_type == kReference) {
TF_LITE_MAX_POOL(reference_ops);
} else {
@@ -192,13 +210,19 @@ void MaxEvalQuantized(TfLiteContext* context, TfLiteNode* node,
int32_t activation_max;
CalculateActivationRangeUint8(params->activation, output, &activation_min,
&activation_max);
-#define TF_LITE_MAX_POOL(type) \
- type::MaxPool(GetTensorData<uint8_t>(input), GetTensorDims(input), \
- params->stride_width, params->stride_height, \
- data->padding.width, data->padding.height, \
- params->filter_width, params->filter_height, activation_min, \
- activation_max, GetTensorData<uint8_t>(output), \
- GetTensorDims(output))
+#define TF_LITE_MAX_POOL(type) \
+ tflite::PoolParams op_params; \
+ op_params.stride_height = params->stride_height; \
+ op_params.stride_width = params->stride_width; \
+ op_params.filter_height = params->filter_height; \
+ op_params.filter_width = params->filter_width; \
+ op_params.padding_values.height = data->padding.height; \
+ op_params.padding_values.width = data->padding.width; \
+ op_params.quantized_activation_min = activation_min; \
+ op_params.quantized_activation_max = activation_max; \
+ type::MaxPool(op_params, GetTensorShape(input), \
+ GetTensorData<uint8_t>(input), GetTensorShape(output), \
+ GetTensorData<uint8_t>(output))
if (kernel_type == kReference) {
TF_LITE_MAX_POOL(reference_ops);
} else {
@@ -212,14 +236,20 @@ void L2EvalFloat(TfLiteContext* context, TfLiteNode* node,
TfLitePoolParams* params, OpData* data,
const TfLiteTensor* input, TfLiteTensor* output) {
float activation_min, activation_max;
- CalculateActivationRangeFloat(params->activation, &activation_min,
- &activation_max);
-#define TF_LITE_L2_POOL(type) \
- type::L2Pool( \
- GetTensorData<float>(input), GetTensorDims(input), params->stride_width, \
- params->stride_height, data->padding.width, data->padding.height, \
- params->filter_width, params->filter_height, activation_min, \
- activation_max, GetTensorData<float>(output), GetTensorDims(output))
+ CalculateActivationRange(params->activation, &activation_min,
+ &activation_max);
+#define TF_LITE_L2_POOL(type) \
+ tflite::PoolParams op_params; \
+ op_params.stride_height = params->stride_height; \
+ op_params.stride_width = params->stride_width; \
+ op_params.filter_height = params->filter_height; \
+ op_params.filter_width = params->filter_width; \
+ op_params.padding_values.height = data->padding.height; \
+ op_params.padding_values.width = data->padding.width; \
+ op_params.float_activation_min = activation_min; \
+ op_params.float_activation_max = activation_max; \
+ type::L2Pool(op_params, GetTensorShape(input), GetTensorData<float>(input), \
+ GetTensorShape(output), GetTensorData<float>(output))
if (kernel_type == kReference) {
TF_LITE_L2_POOL(reference_ops);
} else {
@@ -246,7 +276,8 @@ TfLiteStatus AverageEval(TfLiteContext* context, TfLiteNode* node) {
output);
break;
default:
- context->ReportError(context, "Type not currently supported.");
+ context->ReportError(context, "Type %d not currently supported.",
+ input->type);
return kTfLiteError;
}
return kTfLiteOk;
@@ -267,7 +298,8 @@ TfLiteStatus MaxEval(TfLiteContext* context, TfLiteNode* node) {
MaxEvalQuantized<kernel_type>(context, node, params, data, input, output);
break;
default:
- context->ReportError(context, "Type not currently supported.");
+ context->ReportError(context, "Type %d not currently supported.",
+ input->type);
return kTfLiteError;
}
return kTfLiteOk;
@@ -288,7 +320,8 @@ TfLiteStatus L2Eval(TfLiteContext* context, TfLiteNode* node) {
// We don't have a quantized implementation, so just fall through to the
// 'default' case.
default:
- context->ReportError(context, "Type not currently supported.");
+ context->ReportError(context, "Type %d not currently supported.",
+ input->type);
return kTfLiteError;
}
return kTfLiteOk;
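The renamed compute_out_size lambda implements the usual SAME/VALID output-size rules; a standalone sketch with a small worked example (image size 10, filter 3, stride 2 gives 5 for SAME and 4 for VALID):

#include <cstdio>

enum class Padding { kSame, kValid };

// Mirrors the compute_out_size lambda in pooling.cc's GenericPrepare.
int ComputeOutSize(Padding padding, int image_size, int filter_size,
                   int stride) {
  switch (padding) {
    case Padding::kSame:
      return (image_size + stride - 1) / stride;
    case Padding::kValid:
      return (image_size - filter_size + stride) / stride;
  }
  return 0;
}

int main() {
  std::printf("SAME = %d, VALID = %d\n",
              ComputeOutSize(Padding::kSame, 10, 3, 2),
              ComputeOutSize(Padding::kValid, 10, 3, 2));  // SAME = 5, VALID = 4
  return 0;
}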
diff --git a/tensorflow/contrib/lite/kernels/pow.cc b/tensorflow/contrib/lite/kernels/pow.cc
new file mode 100644
index 0000000000..1e96cc80b1
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/pow.cc
@@ -0,0 +1,143 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace pow {
+namespace {
+
+// Input/output tensor index.
+constexpr int kInputTensor1 = 0;
+constexpr int kInputTensor2 = 1;
+constexpr int kOutputTensor = 0;
+
+// Op data for pow op.
+struct OpData {
+ bool requires_broadcast;
+};
+
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+ auto* data = new OpData;
+ data->requires_broadcast = false;
+ return data;
+}
+
+void Free(TfLiteContext* context, void* buffer) {
+ delete reinterpret_cast<OpData*>(buffer);
+}
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+
+ OpData* data = reinterpret_cast<OpData*>(node->user_data);
+
+ const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
+ const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+ TF_LITE_ENSURE_EQ(context, input1->type, input2->type);
+
+ const TfLiteType type = input1->type;
+ if (type != kTfLiteInt32 && type != kTfLiteFloat32) {
+ context->ReportError(context, "Unsupported data type %d.", type);
+ return kTfLiteError;
+ }
+ output->type = type;
+
+ data->requires_broadcast = !HaveSameShapes(input1, input2);
+
+ TfLiteIntArray* output_size = nullptr;
+ if (data->requires_broadcast) {
+ TF_LITE_ENSURE_OK(context, CalculateShapeForBroadcast(
+ context, input1, input2, &output_size));
+ } else {
+ output_size = TfLiteIntArrayCopy(input1->dims);
+ }
+
+ return context->ResizeTensor(context, output, output_size);
+}
+
+template <typename T>
+void PowImpl(const TfLiteTensor* input1, const TfLiteTensor* input2,
+ TfLiteTensor* output, bool requires_broadcast) {
+ if (requires_broadcast) {
+ reference_ops::BroadcastPow4DSlow(
+ GetTensorShape(input1), GetTensorData<T>(input1),
+ GetTensorShape(input2), GetTensorData<T>(input2),
+ GetTensorShape(output), GetTensorData<T>(output));
+ } else {
+ reference_ops::Pow(GetTensorShape(input1), GetTensorData<T>(input1),
+ GetTensorShape(input2), GetTensorData<T>(input2),
+ GetTensorShape(output), GetTensorData<T>(output));
+ }
+}
+
+TfLiteStatus CheckValue(TfLiteContext* context, const TfLiteTensor* input) {
+ const int64_t num_elements = NumElements(input);
+ const int32_t* data = GetTensorData<int32_t>(input);
+ for (int i = 0; i < num_elements; ++i) {
+ if (data[i] < 0) {
+ context->ReportError(context,
+ "POW does not support negative value for int32.");
+ return kTfLiteError;
+ }
+ }
+ return kTfLiteOk;
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ OpData* data = reinterpret_cast<OpData*>(node->user_data);
+
+ const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
+ const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+ switch (output->type) {
+ case kTfLiteInt32: {
+ // TensorFlow does not support negative for int32.
+ TF_LITE_ENSURE_OK(context, CheckValue(context, input2));
+ PowImpl<int32_t>(input1, input2, output, data->requires_broadcast);
+ break;
+ }
+ case kTfLiteFloat32: {
+ PowImpl<float>(input1, input2, output, data->requires_broadcast);
+ break;
+ }
+ default: {
+ context->ReportError(context, "Unsupported data type: %d", output->type);
+ return kTfLiteError;
+ }
+ }
+ return kTfLiteOk;
+}
+
+} // namespace
+} // namespace pow
+
+TfLiteRegistration* Register_POW() {
+ static TfLiteRegistration r = {pow::Init, pow::Free, pow::Prepare, pow::Eval};
+ return &r;
+}
+
+} // namespace builtin
+} // namespace ops
+} // namespace tflite
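The int32 path in pow::Eval rejects negative exponents up front, mirroring the TensorFlow restriction noted in the comment. Behaviourally, the non-broadcast integer case then comes down to an element-wise power; the sketch below uses plain repeated multiplication rather than calling into reference_ops, so it only illustrates the expected results, not the kernel's actual implementation:

#include <cstdint>
#include <cstdio>
#include <vector>

// Assumes exponents have already been checked to be non-negative.
int32_t IntPow(int32_t base, int32_t exponent) {
  int32_t result = 1;
  for (int32_t i = 0; i < exponent; ++i) result *= base;
  return result;
}

int main() {
  const std::vector<int32_t> base = {12, 2, 7, 8};
  const std::vector<int32_t> exponent = {1, 2, 3, 1};
  for (size_t i = 0; i < base.size(); ++i) {
    std::printf("%d ", static_cast<int>(IntPow(base[i], exponent[i])));
  }
  std::printf("\n");  // Prints: 12 4 343 8 (same values as PowOpModel.Simple)
  return 0;
}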
diff --git a/tensorflow/contrib/lite/kernels/pow_test.cc b/tensorflow/contrib/lite/kernels/pow_test.cc
new file mode 100644
index 0000000000..74b3aef5bd
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/pow_test.cc
@@ -0,0 +1,117 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAre;
+using ::testing::ElementsAreArray;
+
+template <typename T>
+class PowOpModel : public SingleOpModel {
+ public:
+ PowOpModel(const TensorData& input1, const TensorData& input2,
+ const TensorData& output) {
+ input1_ = AddInput(input1);
+ input2_ = AddInput(input2);
+ output_ = AddOutput(output);
+ SetBuiltinOp(BuiltinOperator_POW, BuiltinOptions_PowOptions,
+ CreatePowOptions(builder_).Union());
+ BuildInterpreter({GetShape(input1_), GetShape(input2_)});
+ }
+
+ int input1() { return input1_; }
+ int input2() { return input2_; }
+
+ std::vector<T> GetOutput() { return ExtractVector<T>(output_); }
+ std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
+
+ private:
+ int input1_;
+ int input2_;
+ int output_;
+};
+
+TEST(PowOpModel, Simple) {
+ PowOpModel<int32_t> model({TensorType_INT32, {1, 2, 2, 1}},
+ {TensorType_INT32, {1, 2, 2, 1}},
+ {TensorType_INT32, {}});
+ model.PopulateTensor<int32_t>(model.input1(), {12, 2, 7, 8});
+ model.PopulateTensor<int32_t>(model.input2(), {1, 2, 3, 1});
+ model.Invoke();
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 2, 2, 1));
+ EXPECT_THAT(model.GetOutput(), ElementsAre(12, 4, 343, 8));
+}
+
+TEST(PowOpModel, NegativeAndZeroValue) {
+ PowOpModel<int32_t> model({TensorType_INT32, {1, 2, 2, 1}},
+ {TensorType_INT32, {1, 2, 2, 1}},
+ {TensorType_INT32, {}});
+ model.PopulateTensor<int32_t>(model.input1(), {0, 2, -7, 8});
+ model.PopulateTensor<int32_t>(model.input2(), {1, 2, 3, 0});
+ model.Invoke();
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 2, 2, 1));
+ EXPECT_THAT(model.GetOutput(), ElementsAre(0, 4, -343, 1));
+}
+
+TEST(PowOpModel, Float) {
+ PowOpModel<float> model({TensorType_FLOAT32, {1, 2, 2, 1}},
+ {TensorType_FLOAT32, {1, 2, 2, 1}},
+ {TensorType_FLOAT32, {}});
+ model.PopulateTensor<float>(model.input1(), {0.3, 0.4, 0.7, 5.8});
+ model.PopulateTensor<float>(model.input2(), {0.5, 2.7, 3.1, 3.2});
+ model.Invoke();
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 2, 2, 1));
+ EXPECT_THAT(model.GetOutput(),
+ ElementsAreArray(ArrayFloatNear(
+ {0.5477226, 0.08424846, 0.33098164, 277.313}, 1e-3)));
+}
+
+TEST(PowOpModel, NegativeFloatTest) {
+ PowOpModel<float> model({TensorType_FLOAT32, {1, 2, 2, 1}},
+ {TensorType_FLOAT32, {1, 2, 2, 1}},
+ {TensorType_FLOAT32, {}});
+ model.PopulateTensor<float>(model.input1(), {0.3, 0.4, 0.7, 5.8});
+ model.PopulateTensor<float>(model.input2(), {0.5, -2.7, 3.1, -3.2});
+ model.Invoke();
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 2, 2, 1));
+ EXPECT_THAT(model.GetOutput(),
+ ElementsAreArray(ArrayFloatNear(
+ {0.5477226, 11.869653, 0.33098164, 0.003606}, 1e-3)));
+}
+
+TEST(PowOpModel, BroadcastTest) {
+ PowOpModel<int32_t> model({TensorType_INT32, {1, 2, 2, 1}},
+ {TensorType_INT32, {1}}, {TensorType_INT32, {}});
+ model.PopulateTensor<int32_t>(model.input1(), {12, 2, 7, 8});
+ model.PopulateTensor<int32_t>(model.input2(), {4});
+ model.Invoke();
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 2, 2, 1));
+ EXPECT_THAT(model.GetOutput(), ElementsAre(20736, 16, 2401, 4096));
+}
+
+} // namespace
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ ::tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/kernels/reduce.cc b/tensorflow/contrib/lite/kernels/reduce.cc
new file mode 100644
index 0000000000..4732a37a65
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/reduce.cc
@@ -0,0 +1,513 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <string.h>
+#include <limits>
+#include <vector>
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
+#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace reduce {
+
+// This file contains the reference implementations of the reduce_* operators.
+enum KernelType {
+ kReference,
+};
+
+struct OpContext {
+ OpContext(TfLiteContext* context, TfLiteNode* node) {
+ params = reinterpret_cast<TfLiteReducerParams*>(node->builtin_data);
+ input = GetInput(context, node, 0);
+ axis = GetInput(context, node, 1);
+ output = GetOutput(context, node, 0);
+ }
+ TfLiteReducerParams* params;
+ const TfLiteTensor* input;
+ const TfLiteTensor* axis;
+ TfLiteTensor* output;
+};
+
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+  // Creates three temp tensors to store the index, the resolved axis, and the
+  // temporary sum used internally by the reduce kernels.
+ auto* scratch_tensor_index = new int;
+ context->AddTensors(context, 3, scratch_tensor_index);
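+  // scratch_tensor_index receives the index of the first of the three
+  // consecutive scratch tensors (index, resolved axis, temp sum).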
+ return scratch_tensor_index;
+}
+
+void Free(TfLiteContext* context, void* buffer) {
+ delete reinterpret_cast<int*>(buffer);
+}
+
+// Resizes the temp tensor that stores resolved axis.
+TfLiteStatus ResizeTempAxis(TfLiteContext* context, OpContext* op_context,
+ TfLiteTensor* resolved_axis) {
+ TfLiteIntArray* axis_size = TfLiteIntArrayCreate(1);
+ axis_size->data[0] = static_cast<int>(NumElements(op_context->axis));
+ return context->ResizeTensor(context, resolved_axis, axis_size);
+}
+
+// Resizes the temp tensor that stores temp sum of reduced elements.
+TfLiteStatus ResizeTempSum(TfLiteContext* context, OpContext* op_context,
+ TfLiteTensor* temp_sum) {
+ TfLiteIntArray* size = TfLiteIntArrayCreate(1);
+ size->data[0] = static_cast<int>(NumElements(op_context->output));
+ return context->ResizeTensor(context, temp_sum, size);
+}
+
+// Resizes the output tensor based on the input shape and the reduction axes.
+TfLiteStatus ResizeOutputTensor(TfLiteContext* context, OpContext* op_context) {
+ size_t num_axis = NumElements(op_context->axis);
+ const TfLiteIntArray* input_dims = op_context->input->dims;
+ int input_num_dims = NumDimensions(op_context->input);
+ if (input_num_dims == 0) {
+ return context->ResizeTensor(context, op_context->output,
+ TfLiteIntArrayCreate(0));
+ }
+ const int* axis = GetTensorData<int>(op_context->axis);
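+  // For example, with input shape {4, 3, 2} and axis {1, 0}, keep_dims=true
+  // produces an output shape of {1, 1, 2} and keep_dims=false produces {2}.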
+ if (op_context->params->keep_dims) {
+ TfLiteIntArray* output_dims = TfLiteIntArrayCreate(input_num_dims);
+ for (int idx = 0; idx < input_num_dims; ++idx) {
+ bool is_axis = false;
+ for (int axis_idx = 0; axis_idx < num_axis; ++axis_idx) {
+ if (axis[axis_idx] == idx || axis[axis_idx] + input_num_dims == idx) {
+ is_axis = true;
+ break;
+ }
+ }
+ if (is_axis) {
+ output_dims->data[idx] = 1;
+ } else {
+ output_dims->data[idx] = input_dims->data[idx];
+ }
+ }
+ return context->ResizeTensor(context, op_context->output, output_dims);
+ } else {
+    // Counts the number of distinct reduction axes; negative values are
+    // normalized and duplicate axes are only counted once.
+ int num_reduce_axis = num_axis;
+ for (int i = 0; i < num_axis; ++i) {
+ int current = axis[i];
+ if (current < 0) {
+ current += input_num_dims;
+ }
+ TF_LITE_ENSURE(context, current >= 0 && current < input_num_dims);
+ for (int j = 0; j < i; ++j) {
+ int previous = axis[j];
+ if (previous < 0) {
+ previous += input_num_dims;
+ }
+ if (current == previous) {
+ --num_reduce_axis;
+ break;
+ }
+ }
+ }
+ // Determines output dimensions.
+ TfLiteIntArray* output_dims =
+ TfLiteIntArrayCreate(input_num_dims - num_reduce_axis);
+ int num_skip_axis = 0;
+ for (int idx = 0; idx < input_num_dims; ++idx) {
+ bool is_axis = false;
+ for (int axis_idx = 0; axis_idx < num_axis; ++axis_idx) {
+ if (axis[axis_idx] == idx || axis[axis_idx] + input_num_dims == idx) {
+ ++num_skip_axis;
+ is_axis = true;
+ break;
+ }
+ }
+ if (!is_axis) {
+ output_dims->data[idx - num_skip_axis] = input_dims->data[idx];
+ }
+ }
+ return context->ResizeTensor(context, op_context->output, output_dims);
+ }
+}
+
+// Initializes the temp tensors that store the index, the resolved axis, and
+// the temporary sum.
+TfLiteStatus InitializeTemporaries(TfLiteContext* context, TfLiteNode* node,
+ OpContext* op_context) {
+ // Creates a temp index to iterate through input data.
+ int* scratch_tensor_index = reinterpret_cast<int*>(node->user_data);
+ TfLiteIntArrayFree(node->temporaries);
+ node->temporaries = TfLiteIntArrayCreate(3);
+ node->temporaries->data[0] = *scratch_tensor_index;
+ TfLiteTensor* scratch_tensor = GetTemporary(context, node, /*index=*/0);
+ scratch_tensor->type = kTfLiteInt32;
+ scratch_tensor->allocation_type = kTfLiteArenaRw;
+ TfLiteIntArray* index_size = TfLiteIntArrayCreate(1);
+ index_size->data[0] = NumDimensions(op_context->input);
+ TF_LITE_ENSURE_OK(context,
+ context->ResizeTensor(context, scratch_tensor, index_size));
+
+ // Creates a temp tensor to store resolved axis given input data.
+ node->temporaries->data[1] = *scratch_tensor_index + 1;
+ TfLiteTensor* resolved_axis = GetTemporary(context, node, /*index=*/1);
+ resolved_axis->type = kTfLiteInt32;
+  // Creates a temp tensor to store intermediate sums when calculating mean or
+  // a rescaled sum.
+ node->temporaries->data[2] = *scratch_tensor_index + 2;
+ TfLiteTensor* temp_sum = GetTemporary(context, node, /*index=*/2);
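+  // Where possible, accumulate in a wider type than the input so that summing
+  // many elements is less likely to overflow before the mean is taken.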
+ switch (op_context->input->type) {
+ case kTfLiteFloat32:
+ temp_sum->type = kTfLiteFloat32;
+ break;
+ case kTfLiteInt32:
+ temp_sum->type = kTfLiteInt64;
+ break;
+ case kTfLiteInt64:
+ temp_sum->type = kTfLiteInt64;
+ break;
+ case kTfLiteUInt8:
+ temp_sum->type = kTfLiteInt32;
+ break;
+ case kTfLiteBool:
+ temp_sum->type = kTfLiteBool;
+ break;
+ default:
+ return kTfLiteError;
+ }
+ return kTfLiteOk;
+}
+
+TfLiteStatus PrepareSimple(TfLiteContext* context, TfLiteNode* node) {
+ TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+
+ OpContext op_context(context, node);
+ TF_LITE_ENSURE_OK(context, InitializeTemporaries(context, node, &op_context));
+
+ TfLiteTensor* resolved_axis = GetTemporary(context, node, /*index=*/1);
+ // Leaves work to Eval if axis is not constant; else resizes output.
+ if (!IsConstantTensor(op_context.axis)) {
+ SetTensorToDynamic(op_context.output);
+ SetTensorToDynamic(resolved_axis);
+ return kTfLiteOk;
+ }
+ resolved_axis->allocation_type = kTfLiteArenaRw;
+ TF_LITE_ENSURE_OK(context,
+ ResizeTempAxis(context, &op_context, resolved_axis));
+ TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
+ return kTfLiteOk;
+}
+
+TfLiteStatus PrepareAny(TfLiteContext* context, TfLiteNode* node) {
+ TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
+ const TfLiteTensor* input = GetInput(context, node, 0);
+ TF_LITE_ENSURE_EQ(context, input->type, kTfLiteBool);
+ return PrepareSimple(context, node);
+}
+
+TfLiteStatus PrepareMeanOrSum(TfLiteContext* context, TfLiteNode* node) {
+ TF_LITE_ENSURE_OK(context, PrepareSimple(context, node));
+
+  // reduce_mean and the rescaled reduce_sum require a buffer to store the
+  // intermediate sum.
+ OpContext op_context(context, node);
+ TfLiteTensor* temp_sum = GetTemporary(context, node, /*index=*/2);
+ if (!IsConstantTensor(op_context.axis)) {
+ SetTensorToDynamic(temp_sum);
+ return kTfLiteOk;
+ }
+ temp_sum->allocation_type = kTfLiteArenaRw;
+ return ResizeTempSum(context, &op_context, temp_sum);
+}
+
+template <KernelType kernel_type>
+TfLiteStatus EvalMean(TfLiteContext* context, TfLiteNode* node) {
+ OpContext op_context(context, node);
+ int num_axis = static_cast<int>(NumElements(op_context.axis));
+ TfLiteTensor* temp_index = GetTemporary(context, node, /*index=*/0);
+ TfLiteTensor* resolved_axis = GetTemporary(context, node, /*index=*/1);
+ TfLiteTensor* temp_sum = GetTemporary(context, node, /*index=*/2);
+  // Resize the output tensor if it is dynamic.
+ if (IsDynamicTensor(op_context.output)) {
+ TF_LITE_ENSURE_OK(context,
+ ResizeTempAxis(context, &op_context, resolved_axis));
+ TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
+ TF_LITE_ENSURE_OK(context, ResizeTempSum(context, &op_context, temp_sum));
+ }
+
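+// Expands to a call to the selected kernel's Mean implementation for the
+// given data and accumulator (temp) types, using the tensors resolved above.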
+#define TF_LITE_MEAN(kernel_type, data_type, temp_data_type) \
+ kernel_type::Mean<>( \
+ GetTensorData<data_type>(op_context.input), \
+ op_context.input->dims->data, op_context.input->dims->size, \
+ GetTensorData<data_type>(op_context.output), \
+ op_context.output->dims->data, op_context.output->dims->size, \
+ GetTensorData<int>(op_context.axis), num_axis, \
+ op_context.params->keep_dims, GetTensorData<int>(temp_index), \
+ GetTensorData<int>(resolved_axis), \
+ GetTensorData<temp_data_type>(temp_sum))
+
+ if (kernel_type == kReference) {
+ switch (op_context.input->type) {
+ case kTfLiteFloat32:
+ TF_LITE_ENSURE(context, TF_LITE_MEAN(reference_ops, float, float));
+ break;
+ case kTfLiteInt32:
+ TF_LITE_ENSURE(context, TF_LITE_MEAN(reference_ops, int, int64_t));
+ break;
+ case kTfLiteInt64:
+ TF_LITE_ENSURE(context, TF_LITE_MEAN(reference_ops, int64_t, int64_t));
+ break;
+ case kTfLiteUInt8:
+ if (op_context.input->params.zero_point ==
+ op_context.output->params.zero_point &&
+ op_context.input->params.scale == op_context.output->params.scale) {
+ TF_LITE_ENSURE(context, TF_LITE_MEAN(reference_ops, uint8_t, int));
+ } else {
+ TF_LITE_ENSURE(
+ context,
+ reference_ops::QuantizedMeanOrSum<>(
+ GetTensorData<uint8_t>(op_context.input),
+ op_context.input->params.zero_point,
+ op_context.input->params.scale, op_context.input->dims->data,
+ op_context.input->dims->size,
+ GetTensorData<uint8_t>(op_context.output),
+ op_context.output->params.zero_point,
+ op_context.output->params.scale,
+ op_context.output->dims->data, op_context.output->dims->size,
+ GetTensorData<int>(op_context.axis), num_axis,
+ op_context.params->keep_dims, GetTensorData<int>(temp_index),
+ GetTensorData<int>(resolved_axis),
+ GetTensorData<int>(temp_sum), /*compute_sum=*/false));
+ }
+ break;
+ default:
+ return kTfLiteError;
+ }
+ }
+#undef TF_LITE_MEAN
+ return kTfLiteOk;
+}
+
+// The underlying logic for reduce Sum/Prod/Max/Min/Any.
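+// init_value is the identity element of the reduction (e.g. 0 for sum, 1 for
+// prod) and reducer folds each input element into the running value.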
+template <typename T>
+TfLiteStatus EvalLogic(TfLiteContext* context, TfLiteNode* node,
+ OpContext* op_context, T init_value,
+ T reducer(const T current, const T in)) {
+ int64_t num_axis = NumElements(op_context->axis);
+ TfLiteTensor* temp_index = GetTemporary(context, node, /*index=*/0);
+ TfLiteTensor* resolved_axis = GetTemporary(context, node, /*index=*/1);
+  // Resize the output tensor if it is dynamic.
+ if (IsDynamicTensor(op_context->output)) {
+ TF_LITE_ENSURE_OK(context,
+ ResizeTempAxis(context, op_context, resolved_axis));
+ TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, op_context));
+ }
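+  // These reductions do not rescale, so for uint8 the input and output must
+  // share the same quantization parameters.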
+ if (op_context->input->type == kTfLiteUInt8) {
+ TF_LITE_ENSURE_EQ(context, op_context->input->params.scale,
+ op_context->output->params.scale);
+ TF_LITE_ENSURE_EQ(context, op_context->input->params.zero_point,
+ op_context->output->params.zero_point);
+ }
+ TF_LITE_ENSURE(
+ context,
+ reference_ops::ReduceGeneric<T>(
+ GetTensorData<T>(op_context->input), op_context->input->dims->data,
+ op_context->input->dims->size, GetTensorData<T>(op_context->output),
+ op_context->output->dims->data, op_context->output->dims->size,
+ GetTensorData<int>(op_context->axis), num_axis,
+ op_context->params->keep_dims, GetTensorData<int>(temp_index),
+ GetTensorData<int>(resolved_axis), init_value, reducer));
+ return kTfLiteOk;
+}
+
+enum ReduceType {
+ kSum,
+ kProd,
+ kMax,
+ kMin,
+ kAny,
+};
+
+// Eval for a given input type and reduce type.
+template <typename T>
+TfLiteStatus EvalType(TfLiteContext* context, TfLiteNode* node,
+ OpContext* op_context, ReduceType reduce_type) {
+ switch (reduce_type) {
+ case kSum:
+ return EvalLogic<T>(
+ context, node, op_context, static_cast<T>(0),
+ [](const T current, const T in) -> T { return in + current; });
+ break;
+ case kProd:
+ return EvalLogic<T>(
+ context, node, op_context, static_cast<T>(1),
+ [](const T current, const T in) -> T { return in * current; });
+ break;
+ case kMax:
+ return EvalLogic<T>(context, node, op_context,
+ std::numeric_limits<T>::lowest(),
+ [](const T current, const T in) -> T {
+ return (in > current) ? in : current;
+ });
+ break;
+ case kMin:
+ return EvalLogic<T>(context, node, op_context,
+ std::numeric_limits<T>::max(),
+ [](const T current, const T in) -> T {
+ return (in < current) ? in : current;
+ });
+ break;
+ default:
+ return kTfLiteError;
+ }
+}
+
+// Template specialization of EvalType for bool, used by reduce_any.
+template <>
+TfLiteStatus EvalType<bool>(TfLiteContext* context, TfLiteNode* node,
+ OpContext* op_context, ReduceType reduce_type) {
+ switch (reduce_type) {
+ case kAny:
+ return EvalLogic<bool>(context, node, op_context, false,
+ [](const bool current, const bool in) -> bool {
+ return in || current;
+ });
+ break;
+ default:
+ return kTfLiteError;
+ }
+}
+
+// The entry point that dispatches on the input type and then calls the
+// templated EvalType for the requested ReduceType.
+template <KernelType kernel_type, ReduceType reduce_type>
+TfLiteStatus EvalGeneric(TfLiteContext* context, TfLiteNode* node) {
+ if (kernel_type != kReference) {
+ return kTfLiteOk;
+ }
+ OpContext op_context(context, node);
+ switch (op_context.input->type) {
+ case kTfLiteFloat32:
+ return EvalType<float>(context, node, &op_context, reduce_type);
+ break;
+ case kTfLiteInt32:
+ return EvalType<int>(context, node, &op_context, reduce_type);
+ break;
+ case kTfLiteInt64:
+ return EvalType<int64_t>(context, node, &op_context, reduce_type);
+ break;
+ case kTfLiteUInt8:
+ return EvalType<uint8_t>(context, node, &op_context, reduce_type);
+ break;
+ case kTfLiteBool:
+ return EvalType<bool>(context, node, &op_context, reduce_type);
+ break;
+ default:
+ return kTfLiteError;
+ }
+}
+
+TfLiteStatus EvalSum(TfLiteContext* context, TfLiteNode* node) {
+ OpContext op_context(context, node);
+ const auto& input = op_context.input;
+ const auto& output = op_context.output;
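+  // If no requantization is needed, the generic templated sum path suffices;
+  // otherwise fall through to the rescaling uint8 implementation below.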
+ if (input->type != kTfLiteUInt8 ||
+ (input->params.scale == output->params.scale &&
+ input->params.zero_point == output->params.zero_point)) {
+ return EvalGeneric<kReference, kSum>(context, node);
+ } else {
+    // Rescaled 8-bit reduce_sum: the input and output quantization parameters
+    // differ, so requantize while summing.
+ int num_axis = static_cast<int>(NumElements(op_context.axis));
+ TfLiteTensor* temp_index = GetTemporary(context, node, /*index=*/0);
+ TfLiteTensor* resolved_axis = GetTemporary(context, node, /*index=*/1);
+ TfLiteTensor* temp_sum = GetTemporary(context, node, /*index=*/2);
+    // Resize the output tensor if it is dynamic.
+ if (IsDynamicTensor(op_context.output)) {
+ TF_LITE_ENSURE_OK(context,
+ ResizeTempAxis(context, &op_context, resolved_axis));
+ TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
+ TF_LITE_ENSURE_OK(context, ResizeTempSum(context, &op_context, temp_sum));
+ }
+
+ TF_LITE_ENSURE(
+ context,
+ reference_ops::QuantizedMeanOrSum<>(
+ GetTensorData<uint8_t>(op_context.input),
+ op_context.input->params.zero_point, op_context.input->params.scale,
+ op_context.input->dims->data, op_context.input->dims->size,
+ GetTensorData<uint8_t>(op_context.output),
+ op_context.output->params.zero_point,
+ op_context.output->params.scale, op_context.output->dims->data,
+ op_context.output->dims->size, GetTensorData<int>(op_context.axis),
+ num_axis, op_context.params->keep_dims,
+ GetTensorData<int>(temp_index), GetTensorData<int>(resolved_axis),
+ GetTensorData<int32>(temp_sum), /*compute_sum=*/true));
+ }
+
+ return kTfLiteOk;
+}
+} // namespace reduce
+
+TfLiteRegistration* Register_MEAN_REF() {
+ static TfLiteRegistration r = {reduce::Init, reduce::Free,
+ reduce::PrepareMeanOrSum,
+ reduce::EvalMean<reduce::kReference>};
+ return &r;
+}
+
+TfLiteRegistration* Register_SUM_REF() {
+ static TfLiteRegistration r = {reduce::Init, reduce::Free,
+ reduce::PrepareMeanOrSum, reduce::EvalSum};
+ return &r;
+}
+
+TfLiteRegistration* Register_REDUCE_PROD_REF() {
+ static TfLiteRegistration r = {
+ reduce::Init, reduce::Free, reduce::PrepareSimple,
+ reduce::EvalGeneric<reduce::kReference, reduce::kProd>};
+ return &r;
+}
+
+TfLiteRegistration* Register_REDUCE_MAX_REF() {
+ static TfLiteRegistration r = {
+ reduce::Init, reduce::Free, reduce::PrepareSimple,
+ reduce::EvalGeneric<reduce::kReference, reduce::kMax>};
+ return &r;
+}
+
+TfLiteRegistration* Register_REDUCE_MIN_REF() {
+ static TfLiteRegistration r = {
+ reduce::Init, reduce::Free, reduce::PrepareSimple,
+ reduce::EvalGeneric<reduce::kReference, reduce::kMin>};
+ return &r;
+}
+
+TfLiteRegistration* Register_REDUCE_ANY_REF() {
+ static TfLiteRegistration r = {
+ reduce::Init, reduce::Free, reduce::PrepareAny,
+ reduce::EvalGeneric<reduce::kReference, reduce::kAny>};
+ return &r;
+}
+
+// TODO(kanlig): add optimized implementation of Mean.
+TfLiteRegistration* Register_MEAN() { return Register_MEAN_REF(); }
+TfLiteRegistration* Register_SUM() { return Register_SUM_REF(); }
+TfLiteRegistration* Register_REDUCE_PROD() {
+ return Register_REDUCE_PROD_REF();
+}
+TfLiteRegistration* Register_REDUCE_MAX() { return Register_REDUCE_MAX_REF(); }
+TfLiteRegistration* Register_REDUCE_MIN() { return Register_REDUCE_MIN_REF(); }
+TfLiteRegistration* Register_REDUCE_ANY() { return Register_REDUCE_ANY_REF(); }
+
+} // namespace builtin
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/reduce_test.cc b/tensorflow/contrib/lite/kernels/reduce_test.cc
new file mode 100644
index 0000000000..fb2ec58ab2
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/reduce_test.cc
@@ -0,0 +1,975 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAreArray;
+using ::testing::IsEmpty;
+
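+// Common test harness: the concrete subclasses below wire up a specific
+// reduce builtin with either a constant or a dynamically populated axis
+// tensor.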
+class BaseOpModel : public SingleOpModel {
+ public:
+ void SetAxis(const std::vector<int>& data) { PopulateTensor(axis_, data); }
+
+ template <class T>
+ void SetInput(std::vector<T> data) {
+ PopulateTensor(input_, data);
+ }
+
+ template <class T>
+ std::vector<T> GetOutput() {
+ return ExtractVector<T>(output_);
+ }
+
+ std::vector<float> GetDequantizedOutput() {
+ return Dequantize<uint8_t>(ExtractVector<uint8_t>(output_),
+ GetScale(output_), GetZeroPoint(output_));
+ }
+
+ std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
+
+ int Input() { return input_; }
+
+ protected:
+ int input_;
+ int axis_;
+ int output_;
+};
+
+// Model for the test case where axis is a const tensor.
+class MeanOpConstModel : public BaseOpModel {
+ public:
+ MeanOpConstModel(const TensorData& input, const TensorData& output,
+ std::initializer_list<int> axis_shape,
+ std::initializer_list<int> axis, bool keep_dims) {
+ input_ = AddInput(input);
+ axis_ = AddConstInput(TensorType_INT32, axis, axis_shape);
+ output_ = AddOutput(output);
+ SetBuiltinOp(BuiltinOperator_MEAN, BuiltinOptions_ReducerOptions,
+ CreateReducerOptions(builder_, keep_dims).Union());
+ BuildInterpreter({GetShape(input_)});
+ }
+};
+
+// Model for the test case where axis is a dynamic tensor.
+class MeanOpDynamicModel : public BaseOpModel {
+ public:
+ MeanOpDynamicModel(const TensorData& input, const TensorData& output,
+ const TensorData& axis, bool keep_dims) {
+ input_ = AddInput(input);
+ axis_ = AddInput(axis);
+ output_ = AddOutput(output);
+ SetBuiltinOp(BuiltinOperator_MEAN, BuiltinOptions_ReducerOptions,
+ CreateReducerOptions(builder_, keep_dims).Union());
+ BuildInterpreter({GetShape(input_)});
+ }
+};
+
+// Model for the test case where axis is a const tensor.
+class SumOpConstModel : public BaseOpModel {
+ public:
+ SumOpConstModel(const TensorData& input, const TensorData& output,
+ std::initializer_list<int> axis_shape,
+ std::initializer_list<int> axis, bool keep_dims) {
+ input_ = AddInput(input);
+ axis_ = AddConstInput(TensorType_INT32, axis, axis_shape);
+ output_ = AddOutput(output);
+ SetBuiltinOp(BuiltinOperator_SUM, BuiltinOptions_ReducerOptions,
+ CreateReducerOptions(builder_, keep_dims).Union());
+ BuildInterpreter({GetShape(input_)});
+ }
+};
+
+// Model for the test case where axis is a dynamic tensor.
+class SumOpDynamicModel : public BaseOpModel {
+ public:
+ SumOpDynamicModel(const TensorData& input, const TensorData& output,
+ const TensorData& axis, bool keep_dims) {
+ input_ = AddInput(input);
+ axis_ = AddInput(axis);
+ output_ = AddOutput(output);
+ SetBuiltinOp(BuiltinOperator_SUM, BuiltinOptions_ReducerOptions,
+ CreateReducerOptions(builder_, keep_dims).Union());
+ BuildInterpreter({GetShape(input_)});
+ }
+};
+
+// Model for the test case where axis is a const tensor.
+class ProdOpConstModel : public BaseOpModel {
+ public:
+ ProdOpConstModel(const TensorData& input, const TensorData& output,
+ std::initializer_list<int> axis_shape,
+ std::initializer_list<int> axis, bool keep_dims) {
+ input_ = AddInput(input);
+ axis_ = AddConstInput(TensorType_INT32, axis, axis_shape);
+ output_ = AddOutput(output);
+ SetBuiltinOp(BuiltinOperator_REDUCE_PROD, BuiltinOptions_ReducerOptions,
+ CreateReducerOptions(builder_, keep_dims).Union());
+ BuildInterpreter({GetShape(input_)});
+ }
+};
+
+// Model for the test case where axis is a dynamic tensor.
+class ProdOpDynamicModel : public BaseOpModel {
+ public:
+ ProdOpDynamicModel(const TensorData& input, const TensorData& output,
+ const TensorData& axis, bool keep_dims) {
+ input_ = AddInput(input);
+ axis_ = AddInput(axis);
+ output_ = AddOutput(output);
+ SetBuiltinOp(BuiltinOperator_REDUCE_PROD, BuiltinOptions_ReducerOptions,
+ CreateReducerOptions(builder_, keep_dims).Union());
+ BuildInterpreter({GetShape(input_)});
+ }
+};
+
+// Model for the test case where axis is a const tensor.
+class MaxOpConstModel : public BaseOpModel {
+ public:
+ MaxOpConstModel(const TensorData& input, const TensorData& output,
+ std::initializer_list<int> axis_shape,
+ std::initializer_list<int> axis, bool keep_dims) {
+ input_ = AddInput(input);
+ axis_ = AddConstInput(TensorType_INT32, axis, axis_shape);
+ output_ = AddOutput(output);
+ SetBuiltinOp(BuiltinOperator_REDUCE_MAX, BuiltinOptions_ReducerOptions,
+ CreateReducerOptions(builder_, keep_dims).Union());
+ BuildInterpreter({GetShape(input_)});
+ }
+};
+
+// Model for the test case where axis is a dynamic tensor.
+class MaxOpDynamicModel : public BaseOpModel {
+ public:
+ MaxOpDynamicModel(const TensorData& input, const TensorData& output,
+ const TensorData& axis, bool keep_dims) {
+ input_ = AddInput(input);
+ axis_ = AddInput(axis);
+ output_ = AddOutput(output);
+ SetBuiltinOp(BuiltinOperator_REDUCE_MAX, BuiltinOptions_ReducerOptions,
+ CreateReducerOptions(builder_, keep_dims).Union());
+ BuildInterpreter({GetShape(input_)});
+ }
+};
+
+// Model for the test case where axis is a const tensor.
+class MinOpConstModel : public BaseOpModel {
+ public:
+ MinOpConstModel(const TensorData& input, const TensorData& output,
+ std::initializer_list<int> axis_shape,
+ std::initializer_list<int> axis, bool keep_dims) {
+ input_ = AddInput(input);
+ axis_ = AddConstInput(TensorType_INT32, axis, axis_shape);
+ output_ = AddOutput(output);
+ SetBuiltinOp(BuiltinOperator_REDUCE_MIN, BuiltinOptions_ReducerOptions,
+ CreateReducerOptions(builder_, keep_dims).Union());
+ BuildInterpreter({GetShape(input_)});
+ }
+};
+
+// Model for the test case where axis is a dynamic tensor.
+class MinOpDynamicModel : public BaseOpModel {
+ public:
+ MinOpDynamicModel(const TensorData& input, const TensorData& output,
+ const TensorData& axis, bool keep_dims) {
+ input_ = AddInput(input);
+ axis_ = AddInput(axis);
+ output_ = AddOutput(output);
+ SetBuiltinOp(BuiltinOperator_REDUCE_MIN, BuiltinOptions_ReducerOptions,
+ CreateReducerOptions(builder_, keep_dims).Union());
+ BuildInterpreter({GetShape(input_)});
+ }
+};
+
+// Model for the test case where axis is a const tensor.
+class AnyOpConstModel : public BaseOpModel {
+ public:
+ AnyOpConstModel(const TensorData& input, const TensorData& output,
+ std::initializer_list<int> axis_shape,
+ std::initializer_list<int> axis, bool keep_dims) {
+ input_ = AddInput(input);
+ axis_ = AddConstInput(TensorType_INT32, axis, axis_shape);
+ output_ = AddOutput(output);
+ SetBuiltinOp(BuiltinOperator_REDUCE_ANY, BuiltinOptions_ReducerOptions,
+ CreateReducerOptions(builder_, keep_dims).Union());
+ BuildInterpreter({GetShape(input_)});
+ }
+};
+
+// Model for the test case where axis is a dynamic tensor.
+class AnyOpDynamicModel : public BaseOpModel {
+ public:
+ AnyOpDynamicModel(const TensorData& input, const TensorData& output,
+ const TensorData& axis, bool keep_dims) {
+ input_ = AddInput(input);
+ axis_ = AddInput(axis);
+ output_ = AddOutput(output);
+ SetBuiltinOp(BuiltinOperator_REDUCE_ANY, BuiltinOptions_ReducerOptions,
+ CreateReducerOptions(builder_, keep_dims).Union());
+ BuildInterpreter({GetShape(input_)});
+ }
+};
+
+// For quantized ops, the error shouldn't exceed one quantization step,
+// i.e. (max - min) / 255 for uint8.
+float GetTolerance(int min, int max) { return (max - min) / 255.0; }
+
+// Tests for reduce_mean
+TEST(ConstFloatMeanOpTest, NotKeepDims) {
+ std::vector<float> data = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0,
+ 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
+ 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0};
+ MeanOpConstModel m({TensorType_FLOAT32, {4, 3, 2}}, {TensorType_FLOAT32, {2}},
+ {4}, {1, 0, -3, -3}, false);
+ m.SetInput(data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2}));
+ EXPECT_THAT(m.GetOutput<float>(), ElementsAreArray(ArrayFloatNear({12, 13})));
+}
+
+TEST(ConstFloatMeanOpTest, KeepDims) {
+ std::vector<float> data = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0,
+ 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
+ 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0};
+ MeanOpConstModel m({TensorType_FLOAT32, {4, 3, 2}}, {TensorType_FLOAT32, {3}},
+ {2}, {0, 2}, true);
+ m.SetInput(data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3, 1}));
+ EXPECT_THAT(m.GetOutput<float>(),
+ ElementsAreArray(ArrayFloatNear({10.5, 12.5, 14.5})));
+}
+
+TEST(ConstFloatMeanOpTest, Scalar) {
+ std::vector<float> data = {3.27};
+ MeanOpConstModel m({TensorType_FLOAT32, {}}, {TensorType_FLOAT32, {}}, {},
+ {0}, true);
+ m.SetInput(data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), IsEmpty());
+ EXPECT_THAT(m.GetOutput<float>(), ElementsAreArray(ArrayFloatNear({3.27})));
+}
+
+TEST(DynamicFloatMeanOpTest, NotKeepDims) {
+ std::vector<float> data = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0,
+ 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
+ 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0};
+ MeanOpDynamicModel m({TensorType_FLOAT32, {4, 3, 2}},
+ {TensorType_FLOAT32, {2}}, {TensorType_INT32, {4}},
+ false);
+ std::vector<int> axis = {1, 0, -3, -3};
+ m.SetAxis(axis);
+ m.SetInput(data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2}));
+ EXPECT_THAT(m.GetOutput<float>(), ElementsAreArray(ArrayFloatNear({12, 13})));
+}
+
+TEST(DynamicFloatMeanOpTest, KeepDims) {
+ std::vector<float> data = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0,
+ 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
+ 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0};
+ MeanOpDynamicModel m({TensorType_FLOAT32, {4, 3, 2}},
+ {TensorType_FLOAT32, {3}}, {TensorType_INT32, {2}},
+ true);
+ std::vector<int> axis = {0, 2};
+ m.SetAxis(axis);
+ m.SetInput(data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3, 1}));
+ EXPECT_THAT(m.GetOutput<float>(),
+ ElementsAreArray(ArrayFloatNear({10.5, 12.5, 14.5})));
+}
+
+TEST(DynamicFloatMeanOpTest, Scale) {
+ std::vector<float> data = {9.527};
+ MeanOpDynamicModel m({TensorType_FLOAT32, {1}}, {TensorType_FLOAT32, {1}},
+ {TensorType_INT32, {1}}, true);
+ std::vector<int> axis = {0};
+ m.SetAxis(axis);
+ m.SetInput(data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1}));
+ EXPECT_THAT(m.GetOutput<float>(), ElementsAreArray(ArrayFloatNear({9.527})));
+}
+
+TEST(ConstUint8MeanOpTest, NotKeepDims) {
+ float kQuantizedTolerance = GetTolerance(-1.0, 1.0);
+ std::vector<float> data = {0.4, 0.2, 0.3, 0.4, 0.5, 0.6};
+ MeanOpConstModel m({TensorType_UINT8, {1, 3, 2}, -1.0, 1.0},
+ {TensorType_UINT8, {2}, -1.0, 1.0}, {1}, {1}, false);
+ m.QuantizeAndPopulate<uint8_t>(m.Input(), data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2}));
+ EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear(
+ {0.4, 0.4}, kQuantizedTolerance)));
+}
+
+TEST(ConstUint8MeanOpTest, KeepDims) {
+ float kQuantizedTolerance = GetTolerance(-1.0, 1.0);
+ std::vector<float> data = {0.4, 0.2, 0.3, 0.4, 0.5, 0.6};
+ MeanOpConstModel m({TensorType_UINT8, {3, 2}, -1.0, 1.0},
+ {TensorType_UINT8, {3}, -1.0, 1.0}, {1}, {1}, true);
+ m.QuantizeAndPopulate<uint8_t>(m.Input(), data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 1}));
+ EXPECT_THAT(
+ m.GetDequantizedOutput(),
+ ElementsAreArray(ArrayFloatNear({0.3, 0.35, 0.55}, kQuantizedTolerance)));
+}
+
+TEST(DynamicUint8MeanOpTest, NotKeepDims) {
+ float kQuantizedTolerance = GetTolerance(-5.0, 2.0);
+ std::vector<float> data = {1.3, -4.8, -3.6, 0.24};
+ MeanOpDynamicModel m({TensorType_UINT8, {2, 2}, -5.0, 2.0},
+ {TensorType_UINT8, {2}, -5.0, 2.0},
+ {TensorType_INT32, {1}}, false);
+ std::vector<int> axis = {1};
+ m.SetAxis(axis);
+ m.QuantizeAndPopulate<uint8_t>(m.Input(), data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2}));
+ EXPECT_THAT(
+ m.GetDequantizedOutput(),
+ ElementsAreArray(ArrayFloatNear({-1.75, -1.68}, kQuantizedTolerance)));
+}
+
+TEST(DynamicUint8MeanOpTest, KeepDims) {
+ float kQuantizedTolerance = GetTolerance(-10.0, 12.0);
+ std::vector<float> data = {11.14, -0.14, 7.423, 0.879};
+ MeanOpDynamicModel m({TensorType_UINT8, {2, 2}, -10.0, 12.0},
+ {TensorType_UINT8, {2}, -10.0, 12.0},
+ {TensorType_INT32, {1}}, true);
+ std::vector<int> axis = {0};
+ m.SetAxis(axis);
+ m.QuantizeAndPopulate<uint8_t>(m.Input(), data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2}));
+ EXPECT_THAT(
+ m.GetDequantizedOutput(),
+ ElementsAreArray(ArrayFloatNear({9.2815, 0.3695}, kQuantizedTolerance)));
+}
+
+TEST(DynamicUint8MeanOpTest, QuantizedScalar) {
+ float kQuantizedTolerance = GetTolerance(-10.0, 12.0);
+ std::vector<float> data = {0.643};
+ MeanOpDynamicModel m({TensorType_UINT8, {}, 0.0, 1.0},
+ {TensorType_UINT8, {}, -10.0, 12.0},
+ {TensorType_INT32, {1}}, true);
+ std::vector<int> axis = {0};
+ m.QuantizeAndPopulate<uint8_t>(m.Input(), data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), IsEmpty());
+ EXPECT_THAT(m.GetDequantizedOutput(),
+ ElementsAreArray(ArrayFloatNear({0.643}, kQuantizedTolerance)));
+}
+
+TEST(ConstUint8MeanOpTest, QuantizedKeepDims) {
+ float kQuantizedTolerance = GetTolerance(-5.0, 5.0);
+ std::vector<float> data = {0.4, 0.2, 0.3, 0.4, 0.5, 0.6};
+ MeanOpConstModel m({TensorType_UINT8, {3, 2}, 0.0, 1.0},
+ {TensorType_UINT8, {3}, -5.0, 5.0}, {1}, {1}, true);
+ m.QuantizeAndPopulate<uint8_t>(m.Input(), data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 1}));
+ EXPECT_THAT(
+ m.GetDequantizedOutput(),
+ ElementsAreArray(ArrayFloatNear({0.3, 0.35, 0.55}, kQuantizedTolerance)));
+}
+
+// Tests for reduce_sum
+
+TEST(ConstFloatSumOpTest, NotKeepDims) {
+ std::vector<float> data = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0,
+ 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
+ 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0};
+ SumOpConstModel m({TensorType_FLOAT32, {4, 3, 2}}, {TensorType_FLOAT32, {2}},
+ {4}, {1, 0, -3, -3}, false);
+ m.SetInput(data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2}));
+ EXPECT_THAT(m.GetOutput<float>(),
+ ElementsAreArray(ArrayFloatNear({144, 156})));
+}
+
+TEST(ConstFloatSumOpTest, KeepDims) {
+ std::vector<float> data = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0,
+ 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
+ 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0};
+ SumOpConstModel m({TensorType_FLOAT32, {4, 3, 2}}, {TensorType_FLOAT32, {3}},
+ {2}, {0, 2}, true);
+ m.SetInput(data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3, 1}));
+ EXPECT_THAT(m.GetOutput<float>(),
+ ElementsAreArray(ArrayFloatNear({84, 100, 116})));
+}
+
+TEST(DynamicFloatSumOpTest, NotKeepDims) {
+ std::vector<float> data = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0,
+ 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
+ 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0};
+ SumOpDynamicModel m({TensorType_FLOAT32, {4, 3, 2}},
+ {TensorType_FLOAT32, {2}}, {TensorType_INT32, {4}},
+ false);
+ std::vector<int> axis = {1, 0, -3, -3};
+ m.SetAxis(axis);
+ m.SetInput(data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2}));
+ EXPECT_THAT(m.GetOutput<float>(),
+ ElementsAreArray(ArrayFloatNear({144, 156})));
+}
+
+TEST(ConstFloatSumOpTest, Scalar) {
+ std::vector<float> data = {17.};
+ SumOpConstModel m({TensorType_FLOAT32, {}}, {TensorType_FLOAT32, {}}, {}, {0},
+ false);
+ m.SetInput(data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), IsEmpty());
+ EXPECT_THAT(m.GetOutput<float>(), ElementsAreArray(ArrayFloatNear({17.})));
+}
+
+TEST(DynamicFloatSumOpTest, KeepDims) {
+ std::vector<float> data = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0,
+ 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
+ 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0};
+ SumOpDynamicModel m({TensorType_FLOAT32, {4, 3, 2}},
+ {TensorType_FLOAT32, {3}}, {TensorType_INT32, {2}}, true);
+ std::vector<int> axis = {0, 2};
+ m.SetAxis(axis);
+ m.SetInput(data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3, 1}));
+ EXPECT_THAT(m.GetOutput<float>(),
+ ElementsAreArray(ArrayFloatNear({84, 100, 116})));
+}
+
+TEST(DynamicFloatSumOpTest, Scale) {
+ std::vector<float> data = {9.527};
+ SumOpDynamicModel m({TensorType_FLOAT32, {1}}, {TensorType_FLOAT32, {1}},
+ {TensorType_INT32, {1}}, true);
+ std::vector<int> axis = {0};
+ m.SetAxis(axis);
+ m.SetInput(data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1}));
+ EXPECT_THAT(m.GetOutput<float>(), ElementsAreArray(ArrayFloatNear({9.527})));
+}
+
+TEST(ConstUint8SumOpTest, NotKeepDims) {
+ float kQuantizedTolerance = GetTolerance(-1.0, 1.0);
+ std::vector<float> data = {0.4, 0.2, 0.3, 0.4, 0.5, 0.6};
+ SumOpConstModel m({TensorType_UINT8, {1, 3, 2}, -1.0, 1.0},
+ {TensorType_UINT8, {2}, -1.0, 1.0}, {1}, {1}, false);
+ m.QuantizeAndPopulate<uint8_t>(m.Input(), data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2}));
+ EXPECT_THAT(m.GetDequantizedOutput(),
+ ElementsAreArray(
+ ArrayFloatNear({-0.823529, -0.815686}, kQuantizedTolerance)));
+}
+
+TEST(ConstUint8SumOpTest, NotKeepDimsRescaling) {
+ float kQuantizedTolerance = GetTolerance(0.0, 2.0);
+ std::vector<float> data = {0.4, 0.2, 0.3, 0.4, 0.5, 0.6};
+ SumOpConstModel m({TensorType_UINT8, {1, 3, 2}, 0.0, 1.0},
+ {TensorType_UINT8, {2}, 0.0, 2.0}, {1}, {1}, false);
+ m.QuantizeAndPopulate<uint8_t>(m.Input(), data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2}));
+ EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear(
+ {1.2, 1.2}, kQuantizedTolerance)));
+}
+
+TEST(ConstUint8SumOpTest, KeepDims) {
+ float kQuantizedTolerance = GetTolerance(-1.0, 1.0);
+ std::vector<float> data = {0.4, 0.2, 0.3, 0.4, 0.5, 0.6};
+ SumOpConstModel m({TensorType_UINT8, {3, 2}, -1.0, 1.0},
+ {TensorType_UINT8, {3}, -1.0, 1.0}, {1}, {1}, true);
+ m.QuantizeAndPopulate<uint8_t>(m.Input(), data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 1}));
+ EXPECT_THAT(m.GetDequantizedOutput(),
+ ElementsAreArray(ArrayFloatNear({-0.407843, -0.313726, 0.0941177},
+ kQuantizedTolerance)));
+}
+
+TEST(DynamicUint8SumOpTest, NotKeepDims) {
+ float kQuantizedTolerance = GetTolerance(-5.0, 2.0);
+ std::vector<float> data = {1.3, -4.8, -3.6, 0.24};
+ SumOpDynamicModel m({TensorType_UINT8, {2, 2}, -5.0, 2.0},
+ {TensorType_UINT8, {2}, -5.0, 2.0},
+ {TensorType_INT32, {1}}, false);
+ std::vector<int> axis = {1};
+ m.SetAxis(axis);
+ m.QuantizeAndPopulate<uint8_t>(m.Input(), data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2}));
+ EXPECT_THAT(m.GetDequantizedOutput(),
+ ElementsAreArray(
+ ArrayFloatNear({1.48235, 1.64706}, kQuantizedTolerance)));
+}
+
+TEST(DynamicUint8SumOpTest, KeepDims) {
+ float kQuantizedTolerance = GetTolerance(-10.0, 12.0);
+ std::vector<float> data = {11.14, -0.14, 7.423, 0.879};
+ SumOpDynamicModel m({TensorType_UINT8, {2, 2}, -10.0, 12.0},
+ {TensorType_UINT8, {2}, -10.0, 12.0},
+ {TensorType_INT32, {1}}, true);
+ std::vector<int> axis = {0};
+ m.SetAxis(axis);
+ m.QuantizeAndPopulate<uint8_t>(m.Input(), data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2}));
+ EXPECT_THAT(
+ m.GetDequantizedOutput(),
+ ElementsAreArray(ArrayFloatNear({6.47059, 10.698}, kQuantizedTolerance)));
+}
+
+// Tests for reduce_prod
+
+TEST(ConstFloatProdOpTest, NotKeepDims) {
+ std::vector<float> data = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0,
+ 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
+ 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0};
+ ProdOpConstModel m({TensorType_FLOAT32, {4, 3, 2}}, {TensorType_FLOAT32, {2}},
+ {4}, {1, 0, -3, -3}, false);
+ m.SetInput(data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2}));
+ EXPECT_THAT(
+ m.GetOutput<float>(),
+ ElementsAreArray(ArrayFloatNear({3.162341376e+11, 1.9619905536e+12})));
+}
+
+TEST(ConstFloatProdOpTest, KeepDims) {
+ std::vector<float> data = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0,
+ 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
+ 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0};
+ ProdOpConstModel m({TensorType_FLOAT32, {4, 3, 2}}, {TensorType_FLOAT32, {3}},
+ {2}, {0, 2}, true);
+ m.SetInput(data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3, 1}));
+ EXPECT_THAT(m.GetOutput<float>(),
+ ElementsAreArray(
+ ArrayFloatNear({7.74592e+06, 1.197504e+08, 6.6889152e+08})));
+}
+
+TEST(DynamicFloatProdOpTest, NotKeepDims) {
+ std::vector<float> data = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0,
+ 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
+ 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0};
+ ProdOpDynamicModel m({TensorType_FLOAT32, {4, 3, 2}},
+ {TensorType_FLOAT32, {2}}, {TensorType_INT32, {4}},
+ false);
+ std::vector<int> axis = {1, 0, -3, -3};
+ m.SetAxis(axis);
+ m.SetInput(data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2}));
+ EXPECT_THAT(
+ m.GetOutput<float>(),
+ ElementsAreArray(ArrayFloatNear({3.16234143225e+11, 1.9619905536e+12})));
+}
+
+TEST(DynamicFloatProdOpTest, KeepDims) {
+ std::vector<float> data = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0,
+ 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
+ 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0};
+ ProdOpDynamicModel m({TensorType_FLOAT32, {4, 3, 2}},
+ {TensorType_FLOAT32, {3}}, {TensorType_INT32, {2}},
+ true);
+ std::vector<int> axis = {0, 2};
+ m.SetAxis(axis);
+ m.SetInput(data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3, 1}));
+ EXPECT_THAT(m.GetOutput<float>(),
+ ElementsAreArray(
+ ArrayFloatNear({7.74592e+06, 1.197504e+08, 6.6889152e+08})));
+}
+
+TEST(DynamicFloatProdOpTest, Scale) {
+ std::vector<float> data = {9.527};
+ ProdOpDynamicModel m({TensorType_FLOAT32, {1}}, {TensorType_FLOAT32, {1}},
+ {TensorType_INT32, {1}}, true);
+ std::vector<int> axis = {0};
+ m.SetAxis(axis);
+ m.SetInput(data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1}));
+ EXPECT_THAT(m.GetOutput<float>(), ElementsAreArray(ArrayFloatNear({9.527})));
+}
+
+// Tests for reduce_max
+
+TEST(ConstFloatMaxOpTest, NotKeepDims) {
+ std::vector<float> data = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0,
+ 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
+ 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0};
+ MaxOpConstModel m({TensorType_FLOAT32, {4, 3, 2}}, {TensorType_FLOAT32, {2}},
+ {4}, {1, 0, -3, -3}, false);
+ m.SetInput(data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2}));
+ EXPECT_THAT(m.GetOutput<float>(), ElementsAreArray(ArrayFloatNear({23, 24})));
+}
+
+TEST(ConstFloatMaxOpTest, KeepDims) {
+ std::vector<float> data = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0,
+ 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
+ 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0};
+ MaxOpConstModel m({TensorType_FLOAT32, {4, 3, 2}}, {TensorType_FLOAT32, {3}},
+ {2}, {0, 2}, true);
+ m.SetInput(data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3, 1}));
+ EXPECT_THAT(m.GetOutput<float>(),
+ ElementsAreArray(ArrayFloatNear({20, 22, 24})));
+}
+
+TEST(DynamicFloatMaxOpTest, NotKeepDims) {
+ std::vector<float> data = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0,
+ 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
+ 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0};
+ MaxOpDynamicModel m({TensorType_FLOAT32, {4, 3, 2}},
+ {TensorType_FLOAT32, {2}}, {TensorType_INT32, {4}},
+ false);
+ std::vector<int> axis = {1, 0, -3, -3};
+ m.SetAxis(axis);
+ m.SetInput(data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2}));
+ EXPECT_THAT(m.GetOutput<float>(), ElementsAreArray(ArrayFloatNear({23, 24})));
+}
+
+TEST(DynamicFloatMaxOpTest, KeepDims) {
+ std::vector<float> data = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0,
+ 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
+ 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0};
+ MaxOpDynamicModel m({TensorType_FLOAT32, {4, 3, 2}},
+ {TensorType_FLOAT32, {3}}, {TensorType_INT32, {2}}, true);
+ std::vector<int> axis = {0, 2};
+ m.SetAxis(axis);
+ m.SetInput(data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3, 1}));
+ EXPECT_THAT(m.GetOutput<float>(),
+ ElementsAreArray(ArrayFloatNear({20, 22, 24})));
+}
+
+TEST(DynamicFloatMaxOpTest, Scale) {
+ std::vector<float> data = {9.527};
+ MaxOpDynamicModel m({TensorType_FLOAT32, {1}}, {TensorType_FLOAT32, {1}},
+ {TensorType_INT32, {1}}, true);
+ std::vector<int> axis = {0};
+ m.SetAxis(axis);
+ m.SetInput(data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1}));
+ EXPECT_THAT(m.GetOutput<float>(), ElementsAreArray(ArrayFloatNear({9.527})));
+}
+
+TEST(ConstUint8MaxOpTest, NotKeepDims) {
+ float kQuantizedTolerance = GetTolerance(-1.0, 1.0);
+ std::vector<float> data = {0.4, 0.2, 0.3, 0.4, 0.5, 0.6};
+ MaxOpConstModel m({TensorType_UINT8, {1, 3, 2}, -1.0, 1.0},
+ {TensorType_UINT8, {2}, -1.0, 1.0}, {1}, {1}, false);
+ m.QuantizeAndPopulate<uint8_t>(m.Input(), data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2}));
+ EXPECT_THAT(m.GetDequantizedOutput(),
+ ElementsAreArray(
+ ArrayFloatNear({0.501961, 0.603922}, kQuantizedTolerance)));
+}
+
+TEST(ConstUint8MaxOpTest, KeepDims) {
+ float kQuantizedTolerance = GetTolerance(-1.0, 1.0);
+ std::vector<float> data = {0.4, 0.2, 0.3, 0.4, 0.5, 0.6};
+ MaxOpConstModel m({TensorType_UINT8, {3, 2}, -1.0, 1.0},
+ {TensorType_UINT8, {3}, -1.0, 1.0}, {1}, {1}, true);
+ m.QuantizeAndPopulate<uint8_t>(m.Input(), data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 1}));
+ EXPECT_THAT(m.GetDequantizedOutput(),
+ ElementsAreArray(
+ ArrayFloatNear({0.4, 0.4, 0.603922}, kQuantizedTolerance)));
+}
+
+TEST(DynamicUint8MaxOpTest, NotKeepDims) {
+ float kQuantizedTolerance = GetTolerance(-5.0, 2.0);
+ std::vector<float> data = {1.3, -4.8, -3.6, 0.24};
+ MaxOpDynamicModel m({TensorType_UINT8, {2, 2}, -5.0, 2.0},
+ {TensorType_UINT8, {2}, -5.0, 2.0},
+ {TensorType_INT32, {1}}, false);
+ std::vector<int> axis = {1};
+ m.SetAxis(axis);
+ m.QuantizeAndPopulate<uint8_t>(m.Input(), data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2}));
+ EXPECT_THAT(m.GetDequantizedOutput(),
+ ElementsAreArray(
+ ArrayFloatNear({1.2902, 0.247059}, kQuantizedTolerance)));
+}
+
+TEST(DynamicUint8MaxOpTest, KeepDims) {
+ float kQuantizedTolerance = GetTolerance(-10.0, 12.0);
+ std::vector<float> data = {11.14, -0.14, 7.423, 0.879};
+ MaxOpDynamicModel m({TensorType_UINT8, {2, 2}, -10.0, 12.0},
+ {TensorType_UINT8, {2}, -10.0, 12.0},
+ {TensorType_INT32, {1}}, true);
+ std::vector<int> axis = {0};
+ m.SetAxis(axis);
+ m.QuantizeAndPopulate<uint8_t>(m.Input(), data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2}));
+ EXPECT_THAT(m.GetDequantizedOutput(),
+ ElementsAreArray(
+ ArrayFloatNear({11.1294, 0.862745}, kQuantizedTolerance)));
+}
+
+TEST(DynamicUint8MaxOpTest, Scalar) {
+ float kQuantizedTolerance = GetTolerance(-10.0, 12.0);
+ std::vector<float> data = {11.14};
+ MaxOpDynamicModel m({TensorType_UINT8, {}, -10.0, 12.0},
+ {TensorType_UINT8, {}, -10.0, 12.0},
+ {TensorType_INT32, {1}}, true);
+ std::vector<int> axis = {0};
+ m.QuantizeAndPopulate<uint8_t>(m.Input(), data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), IsEmpty());
+ EXPECT_THAT(m.GetDequantizedOutput(),
+ ElementsAreArray(ArrayFloatNear({11.1294}, kQuantizedTolerance)));
+}
+
+// Tests for reduce_min
+
+TEST(ConstFloatMinOpTest, NotKeepDims) {
+ std::vector<float> data = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0,
+ 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
+ 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0};
+ MinOpConstModel m({TensorType_FLOAT32, {4, 3, 2}}, {TensorType_FLOAT32, {2}},
+ {4}, {1, 0, -3, -3}, false);
+ m.SetInput(data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2}));
+ EXPECT_THAT(m.GetOutput<float>(), ElementsAreArray(ArrayFloatNear({1, 2})));
+}
+
+TEST(ConstFloatMinOpTest, KeepDims) {
+ std::vector<float> data = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0,
+ 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
+ 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0};
+ MinOpConstModel m({TensorType_FLOAT32, {4, 3, 2}}, {TensorType_FLOAT32, {3}},
+ {2}, {0, 2}, true);
+ m.SetInput(data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3, 1}));
+ EXPECT_THAT(m.GetOutput<float>(),
+ ElementsAreArray(ArrayFloatNear({1, 3, 5})));
+}
+
+TEST(DynamicFloatMinOpTest, NotKeepDims) {
+ std::vector<float> data = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0,
+ 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
+ 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0};
+ MinOpDynamicModel m({TensorType_FLOAT32, {4, 3, 2}},
+ {TensorType_FLOAT32, {2}}, {TensorType_INT32, {4}},
+ false);
+ std::vector<int> axis = {1, 0, -3, -3};
+ m.SetAxis(axis);
+ m.SetInput(data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2}));
+ EXPECT_THAT(m.GetOutput<float>(), ElementsAreArray(ArrayFloatNear({1, 2})));
+}
+
+TEST(DynamicFloatMinOpTest, KeepDims) {
+ std::vector<float> data = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0,
+ 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
+ 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0};
+ MinOpDynamicModel m({TensorType_FLOAT32, {4, 3, 2}},
+ {TensorType_FLOAT32, {3}}, {TensorType_INT32, {2}}, true);
+ std::vector<int> axis = {0, 2};
+ m.SetAxis(axis);
+ m.SetInput(data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3, 1}));
+ EXPECT_THAT(m.GetOutput<float>(),
+ ElementsAreArray(ArrayFloatNear({1, 3, 5})));
+}
+
+TEST(DynamicFloatMinOpTest, Scalar) {
+ std::vector<float> data = {9.527};
+ MinOpDynamicModel m({TensorType_FLOAT32, {1}}, {TensorType_FLOAT32, {1}},
+ {TensorType_INT32, {1}}, true);
+ std::vector<int> axis = {0};
+ m.SetAxis(axis);
+ m.SetInput(data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1}));
+ EXPECT_THAT(m.GetOutput<float>(), ElementsAreArray(ArrayFloatNear({9.527})));
+}
+
+TEST(ConstUint8MinOpTest, NotKeepDims) {
+ float kQuantizedTolerance = GetTolerance(-1.0, 1.0);
+ std::vector<float> data = {0.4, 0.2, 0.3, 0.4, 0.5, 0.6};
+ MinOpConstModel m({TensorType_UINT8, {1, 3, 2}, -1.0, 1.0},
+ {TensorType_UINT8, {2}, -1.0, 1.0}, {1}, {1}, false);
+ m.QuantizeAndPopulate<uint8_t>(m.Input(), data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2}));
+ EXPECT_THAT(
+ m.GetDequantizedOutput(),
+ ElementsAreArray(ArrayFloatNear({0.294117, 0.2}, kQuantizedTolerance)));
+}
+
+TEST(ConstUint8MinOpTest, KeepDims) {
+ float kQuantizedTolerance = GetTolerance(-1.0, 1.0);
+ std::vector<float> data = {0.4, 0.2, 0.3, 0.4, 0.5, 0.6};
+ MinOpConstModel m({TensorType_UINT8, {3, 2}, -1.0, 1.0},
+ {TensorType_UINT8, {3}, -1.0, 1.0}, {1}, {1}, true);
+ m.QuantizeAndPopulate<uint8_t>(m.Input(), data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 1}));
+ EXPECT_THAT(
+ m.GetDequantizedOutput(),
+ ElementsAreArray(ArrayFloatNear({0.2, 0.3, 0.5}, kQuantizedTolerance)));
+}
+
+TEST(DynamicUint8MinOpTest, NotKeepDims) {
+ float kQuantizedTolerance = GetTolerance(-5.0, 2.0);
+ std::vector<float> data = {1.3, -4.8, -3.6, 0.24};
+ MinOpDynamicModel m({TensorType_UINT8, {2, 2}, -5.0, 2.0},
+ {TensorType_UINT8, {2}, -5.0, 2.0},
+ {TensorType_INT32, {1}}, false);
+ std::vector<int> axis = {1};
+ m.SetAxis(axis);
+ m.QuantizeAndPopulate<uint8_t>(m.Input(), data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2}));
+ EXPECT_THAT(
+ m.GetDequantizedOutput(),
+ ElementsAreArray(ArrayFloatNear({-4.807843, -3.6}, kQuantizedTolerance)));
+}
+
+TEST(DynamicUint8MinOpTest, KeepDims) {
+ float kQuantizedTolerance = GetTolerance(-10.0, 12.0);
+ std::vector<float> data = {11.14, -0.14, 7.423, 0.879};
+ MinOpDynamicModel m({TensorType_UINT8, {2, 2}, -10.0, 12.0},
+ {TensorType_UINT8, {2}, -10.0, 12.0},
+ {TensorType_INT32, {1}}, true);
+ std::vector<int> axis = {0};
+ m.SetAxis(axis);
+ m.QuantizeAndPopulate<uint8_t>(m.Input(), data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2}));
+ EXPECT_THAT(m.GetDequantizedOutput(),
+ ElementsAreArray(
+ ArrayFloatNear({7.427451, -0.164706}, kQuantizedTolerance)));
+}
+
+TEST(DynamicUint8MinOpTest, Scalar) {
+ float kQuantizedTolerance = GetTolerance(-10.0, 12.0);
+ std::vector<float> data = {11.14};
+ MinOpDynamicModel m({TensorType_UINT8, {}, -10.0, 12.0},
+ {TensorType_UINT8, {}, -10.0, 12.0},
+ {TensorType_INT32, {1}}, true);
+ std::vector<int> axis = {0};
+ m.QuantizeAndPopulate<uint8_t>(m.Input(), data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), IsEmpty());
+ EXPECT_THAT(m.GetDequantizedOutput(),
+ ElementsAreArray(ArrayFloatNear({11.1294}, kQuantizedTolerance)));
+}
+
+// Tests for reduce_any
+
+TEST(ConstAnyOpTest, NotKeepDims) {
+ std::vector<bool> data = {false, false, false, false, false, false,
+ false, true, false, false, false, true};
+ AnyOpConstModel m({TensorType_BOOL, {2, 3, 2}}, {TensorType_BOOL, {2}}, {4},
+ {1, 0, -3, -3}, false);
+ m.SetInput(data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2}));
+ EXPECT_THAT(m.GetOutput<bool>(), ElementsAreArray({false, true}));
+}
+
+TEST(ConstAnyOpTest, KeepDims) {
+ std::vector<bool> data = {false, false, false, false, false, false,
+ false, true, false, false, false, true};
+ AnyOpConstModel m({TensorType_BOOL, {2, 3, 2}}, {TensorType_BOOL, {3}}, {2},
+ {0, 2}, true);
+ m.SetInput(data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3, 1}));
+ EXPECT_THAT(m.GetOutput<bool>(), ElementsAreArray({true, false, true}));
+}
+
+TEST(DynamicAnyOpTest, NotKeepDims) {
+ std::vector<bool> data = {false, false, false, false, false, false,
+ false, true, false, false, false, true};
+ AnyOpDynamicModel m({TensorType_BOOL, {2, 3, 2}}, {TensorType_BOOL, {2}},
+ {TensorType_INT32, {4}}, false);
+ std::vector<int> axis = {1, 0, -3, -3};
+ m.SetAxis(axis);
+ m.SetInput(data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2}));
+ EXPECT_THAT(m.GetOutput<bool>(), ElementsAreArray({false, true}));
+}
+
+TEST(DynamicAnyOpTest, KeepDims) {
+ std::vector<bool> data = {false, false, false, false, false, false,
+ false, true, false, false, false, true};
+ AnyOpDynamicModel m({TensorType_BOOL, {2, 3, 2}}, {TensorType_BOOL, {3}},
+ {TensorType_INT32, {2}}, true);
+ std::vector<int> axis = {0, 2};
+ m.SetAxis(axis);
+ m.SetInput(data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3, 1}));
+ EXPECT_THAT(m.GetOutput<bool>(), ElementsAreArray({true, false, true}));
+}
+
+TEST(DynamicAnyOpTest, Scalar) {
+ std::vector<bool> data = {false};
+ AnyOpDynamicModel m({TensorType_BOOL, {1}}, {TensorType_BOOL, {1}},
+ {TensorType_INT32, {1}}, true);
+ std::vector<int> axis = {0};
+ m.SetAxis(axis);
+ m.SetInput(data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1}));
+ EXPECT_THAT(m.GetOutput<bool>(), ElementsAreArray({false}));
+}
+
+} // namespace
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ ::tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/kernels/register.cc b/tensorflow/contrib/lite/kernels/register.cc
index 21cc185e9f..9402105fa7 100644
--- a/tensorflow/contrib/lite/kernels/register.cc
+++ b/tensorflow/contrib/lite/kernels/register.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/util.h"
namespace tflite {
namespace ops {
@@ -21,7 +22,10 @@ namespace ops {
namespace custom {
TfLiteRegistration* Register_AUDIO_SPECTROGRAM();
+TfLiteRegistration* Register_LAYER_NORM_LSTM();
TfLiteRegistration* Register_MFCC();
+TfLiteRegistration* Register_DETECTION_POSTPROCESS();
+TfLiteRegistration* Register_RELU_1();
} // namespace custom
@@ -73,6 +77,7 @@ TfLiteRegistration* Register_SQUEEZE();
TfLiteRegistration* Register_STRIDED_SLICE();
TfLiteRegistration* Register_EXP();
TfLiteRegistration* Register_TOPK_V2();
+TfLiteRegistration* Register_LOG();
TfLiteRegistration* Register_LOG_SOFTMAX();
TfLiteRegistration* Register_CAST();
TfLiteRegistration* Register_DEQUANTIZE();
@@ -80,16 +85,68 @@ TfLiteRegistration* Register_PRELU();
TfLiteRegistration* Register_MAXIMUM();
TfLiteRegistration* Register_MINIMUM();
TfLiteRegistration* Register_ARG_MAX();
+TfLiteRegistration* Register_ARG_MIN();
TfLiteRegistration* Register_GREATER();
TfLiteRegistration* Register_GREATER_EQUAL();
TfLiteRegistration* Register_LESS();
TfLiteRegistration* Register_LESS_EQUAL();
TfLiteRegistration* Register_FLOOR();
+TfLiteRegistration* Register_TILE();
TfLiteRegistration* Register_NEG();
+TfLiteRegistration* Register_SUM();
+TfLiteRegistration* Register_REDUCE_PROD();
+TfLiteRegistration* Register_REDUCE_MAX();
+TfLiteRegistration* Register_REDUCE_MIN();
+TfLiteRegistration* Register_REDUCE_ANY();
TfLiteRegistration* Register_SELECT();
TfLiteRegistration* Register_SLICE();
TfLiteRegistration* Register_SIN();
TfLiteRegistration* Register_TRANSPOSE_CONV();
+TfLiteRegistration* Register_EXPAND_DIMS();
+TfLiteRegistration* Register_SPARSE_TO_DENSE();
+TfLiteRegistration* Register_EQUAL();
+TfLiteRegistration* Register_NOT_EQUAL();
+TfLiteRegistration* Register_SQRT();
+TfLiteRegistration* Register_RSQRT();
+TfLiteRegistration* Register_SHAPE();
+TfLiteRegistration* Register_POW();
+TfLiteRegistration* Register_FAKE_QUANT();
+TfLiteRegistration* Register_PACK();
+TfLiteRegistration* Register_ONE_HOT();
+TfLiteRegistration* Register_LOGICAL_OR();
+TfLiteRegistration* Register_LOGICAL_AND();
+TfLiteRegistration* Register_LOGICAL_NOT();
+TfLiteRegistration* Register_UNPACK();
+TfLiteRegistration* Register_FLOOR_DIV();
+TfLiteRegistration* Register_SQUARE();
+TfLiteRegistration* Register_ZEROS_LIKE();
+
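+// Reports an error when a TensorFlow ("Flex") op is encountered without the
+// Flex delegate having claimed it.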
+TfLiteStatus UnsupportedTensorFlowOp(TfLiteContext* context, TfLiteNode* node) {
+ context->ReportError(
+ context,
+ "Regular TensorFlow ops are not supported by this interpreter. Make sure "
+ "you invoke the Flex delegate before inference.");
+ return kTfLiteError;
+}
+
+const TfLiteRegistration* BuiltinOpResolver::FindOp(tflite::BuiltinOperator op,
+ int version) const {
+ return MutableOpResolver::FindOp(op, version);
+}
+
+const TfLiteRegistration* BuiltinOpResolver::FindOp(const char* op,
+ int version) const {
+ // Return the NULL Op for all ops whose names start with "Flex", allowing
+ // the interpreter to delegate their execution.
+ if (IsFlexOp(op)) {
+ static TfLiteRegistration null_op{
+ nullptr, nullptr, &UnsupportedTensorFlowOp,
+ nullptr, nullptr, BuiltinOperator_CUSTOM,
+ "Flex", 1};
+ return &null_op;
+ }
+ return MutableOpResolver::FindOp(op, version);
+}
BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_RELU, Register_RELU());
@@ -101,7 +158,9 @@ BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_MAX_POOL_2D, Register_MAX_POOL_2D());
AddBuiltin(BuiltinOperator_L2_POOL_2D, Register_L2_POOL_2D());
AddBuiltin(BuiltinOperator_CONV_2D, Register_CONV_2D());
- AddBuiltin(BuiltinOperator_DEPTHWISE_CONV_2D, Register_DEPTHWISE_CONV_2D());
+ AddBuiltin(BuiltinOperator_DEPTHWISE_CONV_2D, Register_DEPTHWISE_CONV_2D(),
+ /* min_version */ 1,
+ /* max_version */ 2);
AddBuiltin(BuiltinOperator_SVDF, Register_SVDF());
AddBuiltin(BuiltinOperator_RNN, Register_RNN());
AddBuiltin(BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN,
@@ -111,7 +170,9 @@ BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_EMBEDDING_LOOKUP, Register_EMBEDDING_LOOKUP());
AddBuiltin(BuiltinOperator_EMBEDDING_LOOKUP_SPARSE,
Register_EMBEDDING_LOOKUP_SPARSE());
- AddBuiltin(BuiltinOperator_FULLY_CONNECTED, Register_FULLY_CONNECTED());
+ AddBuiltin(BuiltinOperator_FULLY_CONNECTED, Register_FULLY_CONNECTED(),
+ /* min_version */ 1,
+ /* max_version */ 2);
AddBuiltin(BuiltinOperator_LSH_PROJECTION, Register_LSH_PROJECTION());
AddBuiltin(BuiltinOperator_HASHTABLE_LOOKUP, Register_HASHTABLE_LOOKUP());
AddBuiltin(BuiltinOperator_SOFTMAX, Register_SOFTMAX());
@@ -123,7 +184,8 @@ BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_L2_NORMALIZATION, Register_L2_NORMALIZATION());
AddBuiltin(BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION,
Register_LOCAL_RESPONSE_NORMALIZATION());
- AddBuiltin(BuiltinOperator_LSTM, Register_LSTM());
+ AddBuiltin(BuiltinOperator_LSTM, Register_LSTM(), /* min_version */ 1,
+ /* max_version */ 2);
AddBuiltin(BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM,
Register_BIDIRECTIONAL_SEQUENCE_LSTM());
AddBuiltin(BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM,
@@ -144,6 +206,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_STRIDED_SLICE, Register_STRIDED_SLICE());
AddBuiltin(BuiltinOperator_EXP, Register_EXP());
AddBuiltin(BuiltinOperator_TOPK_V2, Register_TOPK_V2());
+ AddBuiltin(BuiltinOperator_LOG, Register_LOG());
AddBuiltin(BuiltinOperator_LOG_SOFTMAX, Register_LOG_SOFTMAX());
AddBuiltin(BuiltinOperator_CAST, Register_CAST());
AddBuiltin(BuiltinOperator_DEQUANTIZE, Register_DEQUANTIZE());
@@ -151,6 +214,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_MAXIMUM, Register_MAXIMUM());
AddBuiltin(BuiltinOperator_MINIMUM, Register_MINIMUM());
AddBuiltin(BuiltinOperator_ARG_MAX, Register_ARG_MAX());
+ AddBuiltin(BuiltinOperator_ARG_MIN, Register_ARG_MIN());
AddBuiltin(BuiltinOperator_GREATER, Register_GREATER());
AddBuiltin(BuiltinOperator_GREATER_EQUAL, Register_GREATER_EQUAL());
AddBuiltin(BuiltinOperator_LESS, Register_LESS());
@@ -161,12 +225,40 @@ BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_SLICE, Register_SLICE());
AddBuiltin(BuiltinOperator_SIN, Register_SIN());
AddBuiltin(BuiltinOperator_TRANSPOSE_CONV, Register_TRANSPOSE_CONV());
+ AddBuiltin(BuiltinOperator_TILE, Register_TILE());
+ AddBuiltin(BuiltinOperator_SUM, Register_SUM());
+ AddBuiltin(BuiltinOperator_REDUCE_PROD, Register_REDUCE_PROD());
+ AddBuiltin(BuiltinOperator_REDUCE_MAX, Register_REDUCE_MAX());
+ AddBuiltin(BuiltinOperator_REDUCE_MIN, Register_REDUCE_MIN());
+ AddBuiltin(BuiltinOperator_REDUCE_ANY, Register_REDUCE_ANY());
+ AddBuiltin(BuiltinOperator_EXPAND_DIMS, Register_EXPAND_DIMS());
+ AddBuiltin(BuiltinOperator_SPARSE_TO_DENSE, Register_SPARSE_TO_DENSE());
+ AddBuiltin(BuiltinOperator_EQUAL, Register_EQUAL());
+ AddBuiltin(BuiltinOperator_NOT_EQUAL, Register_NOT_EQUAL());
+ AddBuiltin(BuiltinOperator_SQRT, Register_SQRT());
+ AddBuiltin(BuiltinOperator_RSQRT, Register_RSQRT());
+ AddBuiltin(BuiltinOperator_SHAPE, Register_SHAPE());
+ AddBuiltin(BuiltinOperator_POW, Register_POW());
+ AddBuiltin(BuiltinOperator_FAKE_QUANT, Register_FAKE_QUANT(), 1, 2);
+ AddBuiltin(BuiltinOperator_PACK, Register_PACK());
+ AddBuiltin(BuiltinOperator_ONE_HOT, Register_ONE_HOT());
+ AddBuiltin(BuiltinOperator_LOGICAL_OR, Register_LOGICAL_OR());
+ AddBuiltin(BuiltinOperator_LOGICAL_AND, Register_LOGICAL_AND());
+ AddBuiltin(BuiltinOperator_LOGICAL_NOT, Register_LOGICAL_NOT());
+ AddBuiltin(BuiltinOperator_UNPACK, Register_UNPACK());
+ AddBuiltin(BuiltinOperator_FLOOR_DIV, Register_FLOOR_DIV());
+ AddBuiltin(BuiltinOperator_SQUARE, Register_SQUARE());
+ AddBuiltin(BuiltinOperator_ZEROS_LIKE, Register_ZEROS_LIKE());
// TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
// custom ops aren't always included by default.
AddCustom("Mfcc", tflite::ops::custom::Register_MFCC());
AddCustom("AudioSpectrogram",
tflite::ops::custom::Register_AUDIO_SPECTROGRAM());
+ AddCustom("LayerNormLstm", tflite::ops::custom::Register_LAYER_NORM_LSTM());
+ AddCustom("Relu1", tflite::ops::custom::Register_RELU_1());
+ AddCustom("TFLite_Detection_PostProcess",
+ tflite::ops::custom::Register_DETECTION_POSTPROCESS());
}
} // namespace builtin
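
Note on the Flex fallback added above: the overridden FindOp hands back a shared stub registration for any custom op whose name carries the "Flex" prefix, so converted graphs containing such ops still load, and the stub's Prepare function reports the UnsupportedTensorFlowOp error unless the Flex delegate has claimed those nodes first. A minimal sketch of how lookups behave under this scheme (the op names and harness here are illustrative assumptions, not part of the change):

    // Sketch only: resolving ops against the BuiltinOpResolver from this diff.
    #include "tensorflow/contrib/lite/kernels/register.h"

    void ResolveExamples() {
      tflite::ops::builtin::BuiltinOpResolver resolver;

      // A builtin registered with a version range (min_version 1, max_version 2).
      const TfLiteRegistration* dw_conv_v2 = resolver.FindOp(
          tflite::BuiltinOperator_DEPTHWISE_CONV_2D, /*version=*/2);

      // A "Flex"-prefixed custom op resolves to the stub registration, whose
      // Prepare step reports the UnsupportedTensorFlowOp error if no delegate
      // takes the node.
      const TfLiteRegistration* flex_stub =
          resolver.FindOp("FlexAdd", /*version=*/1);

      // Any other unknown custom op still resolves to nullptr.
      const TfLiteRegistration* unknown =
          resolver.FindOp("NotARealOp", /*version=*/1);
      (void)dw_conv_v2;
      (void)flex_stub;
      (void)unknown;
    }
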
diff --git a/tensorflow/contrib/lite/kernels/register.h b/tensorflow/contrib/lite/kernels/register.h
index b928f1b302..61856ab9de 100644
--- a/tensorflow/contrib/lite/kernels/register.h
+++ b/tensorflow/contrib/lite/kernels/register.h
@@ -16,8 +16,9 @@ limitations under the License.
#define TENSORFLOW_CONTRIB_LITE_KERNELS_REGISTER_H_
#include <unordered_map>
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/model.h"
+#include "tensorflow/contrib/lite/mutable_op_resolver.h"
namespace tflite {
namespace ops {
@@ -26,10 +27,14 @@ namespace builtin {
class BuiltinOpResolver : public MutableOpResolver {
public:
BuiltinOpResolver();
+
+ const TfLiteRegistration* FindOp(tflite::BuiltinOperator op,
+ int version) const override;
+ const TfLiteRegistration* FindOp(const char* op, int version) const override;
};
} // namespace builtin
} // namespace ops
} // namespace tflite
-#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_BUILTIN_KERNELS_H
+#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_REGISTER_H_
diff --git a/tensorflow/contrib/lite/kernels/relu1.cc b/tensorflow/contrib/lite/kernels/relu1.cc
new file mode 100644
index 0000000000..abafee2d57
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/relu1.cc
@@ -0,0 +1,61 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+
+namespace tflite {
+namespace ops {
+namespace custom {
+namespace relu1 {
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+ const TfLiteTensor* input = GetInput(context, node, 0);
+ TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
+ TfLiteTensor* output = GetOutput(context, node, 0);
+ output->type = input->type;
+ return context->ResizeTensor(context, output,
+ TfLiteIntArrayCopy(input->dims));
+}
+
+// This is derived from lite/kernels/activations.cc.
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ const TfLiteTensor* input = GetInput(context, node, 0);
+ TfLiteTensor* output = GetOutput(context, node, 0);
+ const int elements = NumElements(input);
+ const float* in = input->data.f;
+ const float* in_end = in + elements;
+ float* out = output->data.f;
+ for (; in < in_end; ++in, ++out) {
+ *out = std::min(std::max(0.f, *in), 1.f);
+ }
+ return kTfLiteOk;
+}
+
+} // namespace relu1
+
+TfLiteRegistration* Register_RELU_1() {
+ static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
+ relu1::Prepare, relu1::Eval};
+ return &r;
+}
+
+} // namespace custom
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/relu1_test.cc b/tensorflow/contrib/lite/kernels/relu1_test.cc
new file mode 100644
index 0000000000..b1d25a9f50
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/relu1_test.cc
@@ -0,0 +1,79 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <gtest/gtest.h>
+#include "flatbuffers/flexbuffers.h" // TF:flatbuffers
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+
+namespace tflite {
+namespace ops {
+namespace custom {
+
+TfLiteRegistration* Register_RELU_1();
+
+namespace {
+
+using ::testing::ElementsAreArray;
+
+class BaseActivationsOpModel : public SingleOpModel {
+ public:
+ explicit BaseActivationsOpModel(const TensorData& input) {
+ input_ = AddInput(input);
+ output_ = AddOutput({input.type, {}});
+ flexbuffers::Builder fbb;
+ fbb.Map([&]() {});
+ fbb.Finish();
+ SetCustomOp("RELU_1", fbb.GetBuffer(), Register_RELU_1);
+ BuildInterpreter({GetShape(input_)});
+ }
+
+ protected:
+ int input_;
+ int output_;
+};
+
+class FloatActivationsOpModel : public BaseActivationsOpModel {
+ public:
+ using BaseActivationsOpModel::BaseActivationsOpModel;
+
+ void SetInput(std::initializer_list<float> data) {
+ PopulateTensor(input_, data);
+ }
+ std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+};
+
+TEST(FloatActivationsOpTest, Relu1) {
+ FloatActivationsOpModel m(/*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}});
+ m.SetInput({
+ 0.0, -0.6, 0.2, -0.4, //
+ 0.3, -2.0, 1.1, -0.1, //
+ });
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({
+ 0.0, 0.0, 0.2, 0.0, //
+ 0.3, 0.0, 1.0, 0.0, //
+ }));
+}
+
+} // namespace
+} // namespace custom
+} // namespace ops
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ ::tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/kernels/reshape.cc b/tensorflow/contrib/lite/kernels/reshape.cc
index 3287040695..f41147b2d6 100644
--- a/tensorflow/contrib/lite/kernels/reshape.cc
+++ b/tensorflow/contrib/lite/kernels/reshape.cc
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <string.h>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"
@@ -25,16 +25,11 @@ namespace builtin {
namespace reshape {
constexpr int kInputTensor = 0;
+constexpr int kShapeTensor = 1;
constexpr int kOutputTensor = 0;
-TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
- auto* params = reinterpret_cast<TfLiteReshapeParams*>(node->builtin_data);
-
- // TODO(ahentz): we are often given a tensor with the shape but we only pay
- // attention to what the shape specified in 'params'.
- TF_LITE_ENSURE(context, NumInputs(node) == 1 || NumInputs(node) == 2);
- TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
-
+TfLiteStatus ResizeOutput(TfLiteContext* context, TfLiteNode* node,
+ TfLiteIntArray* output_shape) {
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
@@ -42,37 +37,84 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// special -1 value, meaning it will be calculated automatically based on the
// input. Here we calculate what that dimension should be so that the number
// of output elements is the same as the number of input elements.
- int num_input_elements = 1;
- for (int i = 0; i < NumDimensions(input); ++i) {
- num_input_elements *= SizeOfDimension(input, i);
- }
+ int num_input_elements = NumElements(input);
- TfLiteIntArray* output_size = TfLiteIntArrayCreate(params->num_dimensions);
int num_output_elements = 1;
int stretch_dim = -1;
- for (int i = 0; i < params->num_dimensions; ++i) {
- int value = params->shape[i];
+ for (int i = 0; i < output_shape->size; ++i) {
+ int value = output_shape->data[i];
if (value == -1) {
TF_LITE_ENSURE_EQ(context, stretch_dim, -1);
stretch_dim = i;
} else {
num_output_elements *= value;
- output_size->data[i] = value;
}
}
if (stretch_dim != -1) {
- output_size->data[stretch_dim] = num_input_elements / num_output_elements;
- num_output_elements *= output_size->data[stretch_dim];
+ output_shape->data[stretch_dim] = num_input_elements / num_output_elements;
+ num_output_elements *= output_shape->data[stretch_dim];
}
TF_LITE_ENSURE_EQ(context, num_input_elements, num_output_elements);
- return context->ResizeTensor(context, output, output_size);
+ return context->ResizeTensor(context, output, output_shape);
+}
+
+TfLiteStatus ResizeOutputWithShapeTensor(TfLiteContext* context,
+ TfLiteNode* node) {
+ const TfLiteTensor* shape = GetInput(context, node, kShapeTensor);
+
+ TfLiteIntArray* output_shape = TfLiteIntArrayCreate(shape->dims->data[0]);
+ for (int i = 0; i < output_shape->size; ++i) {
+ output_shape->data[i] = shape->data.i32[i];
+ }
+ return ResizeOutput(context, node, output_shape);
+}
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ auto* params = reinterpret_cast<TfLiteReshapeParams*>(node->builtin_data);
+
+ TF_LITE_ENSURE(context, NumInputs(node) == 1 || NumInputs(node) == 2);
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+
+ // Attempt to use shape tensor if it exists.
+ if (NumInputs(node) == 2) {
+ const TfLiteTensor* shape = GetInput(context, node, kShapeTensor);
+ // Check if the shape tensor is valid.
+ if (shape->dims->size == 1 && shape->type == kTfLiteInt32) {
+ // Set the output tensor as dynamic if the shape isn't constant.
+ if (!IsConstantTensor(shape)) {
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+ SetTensorToDynamic(output);
+ return kTfLiteOk;
+ }
+ // Shape is constant. Resize now.
+ return ResizeOutputWithShapeTensor(context, node);
+ }
+ }
+ // If the shape tensor was usable, the function has already returned above.
+ // Otherwise, fall back to the shape parameter in `TfLiteReshapeParams`.
+ int num_dimensions = params->num_dimensions;
+ if (num_dimensions == 1 && params->shape[0] == 0) {
+ // Legacy tflite models use a shape parameter of [0] to indicate scalars,
+ // so adjust accordingly. TODO(b/111614235): Allow zero-sized buffers during
+ // toco conversion.
+ num_dimensions = 0;
+ }
+ TfLiteIntArray* output_shape = TfLiteIntArrayCreate(num_dimensions);
+ for (int i = 0; i < num_dimensions; ++i) {
+ output_shape->data[i] = params->shape[i];
+ }
+ return ResizeOutput(context, node, output_shape);
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+ if (IsDynamicTensor(output)) {
+ TF_LITE_ENSURE_OK(context, ResizeOutputWithShapeTensor(context, node));
+ }
+
memcpy(output->data.raw, input->data.raw, input->bytes);
return kTfLiteOk;
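
The stretch-dimension handling in ResizeOutput above resolves a single -1 entry so that the output element count matches the input: reshaping 8 elements to {2, 1, -1} leaves 2 * 1 = 2 fixed elements, so the -1 becomes 8 / 2 = 4 and the output shape is {2, 1, 4}, exactly the case covered by the WithStretchDimension test further down. A standalone sketch of that calculation (plain C++ with hypothetical names, independent of the TFLite tensor types):

    #include <cassert>
    #include <vector>

    // Resolve one -1 ("stretch") entry so that the product of `shape`
    // equals num_elements, mirroring reshape::ResizeOutput.
    std::vector<int> ResolveStretchDim(std::vector<int> shape, int num_elements) {
      int known = 1;
      int stretch = -1;
      for (int i = 0; i < static_cast<int>(shape.size()); ++i) {
        if (shape[i] == -1) {
          assert(stretch == -1);  // at most one stretch dimension allowed
          stretch = i;
        } else {
          known *= shape[i];
        }
      }
      if (stretch != -1) shape[stretch] = num_elements / known;
      return shape;  // ResolveStretchDim({2, 1, -1}, 8) == {2, 1, 4}
    }
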
diff --git a/tensorflow/contrib/lite/kernels/reshape_test.cc b/tensorflow/contrib/lite/kernels/reshape_test.cc
index aecbd0399f..52d71350d3 100644
--- a/tensorflow/contrib/lite/kernels/reshape_test.cc
+++ b/tensorflow/contrib/lite/kernels/reshape_test.cc
@@ -22,18 +22,27 @@ namespace tflite {
namespace {
using ::testing::ElementsAreArray;
+using ::testing::IsEmpty;
class ReshapeOpModel : public SingleOpModel {
public:
ReshapeOpModel(std::initializer_list<int> input_shape,
- std::initializer_list<int> new_shape) {
+ std::initializer_list<int> new_shape,
+ bool use_shape_input_tensor = false) {
input_ = AddInput(TensorType_FLOAT32);
output_ = AddOutput(TensorType_FLOAT32);
+ int shape_input_tensor =
+ use_shape_input_tensor ? AddInput(TensorType_INT32) : -1;
SetBuiltinOp(
BuiltinOperator_RESHAPE, BuiltinOptions_ReshapeOptions,
CreateReshapeOptions(builder_, builder_.CreateVector<int>(new_shape))
.Union());
- BuildInterpreter({input_shape});
+ if (use_shape_input_tensor) {
+ BuildInterpreter({input_shape, GetShape(shape_input_tensor)});
+ PopulateTensor<int>(shape_input_tensor, new_shape);
+ } else {
+ BuildInterpreter({input_shape});
+ }
}
void SetInput(std::initializer_list<float> data) {
@@ -71,6 +80,14 @@ TEST(ReshapeOpTest, SimpleTest) {
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2, 2}));
}
+TEST(ReshapeOpTest, ShapeTensorInput) {
+ ReshapeOpModel m({1, 2, 4, 1}, {2, 2, 2}, /*use_shape_input_tensor=*/true);
+ m.SetInput({1, 2, 3, 4, 5, 6, 7, 8});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6, 7, 8}));
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2, 2}));
+}
+
TEST(ReshapeOpTest, WithStretchDimension) {
ReshapeOpModel m({1, 2, 4, 1}, {2, 1, -1});
m.SetInput({1, 2, 3, 4, 5, 6, 7, 8});
@@ -79,6 +96,22 @@ TEST(ReshapeOpTest, WithStretchDimension) {
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 1, 4}));
}
+TEST(ReshapeOpTest, ScalarOutput) {
+ ReshapeOpModel m({1}, {});
+ m.SetInput({3});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({3}));
+ EXPECT_THAT(m.GetOutputShape(), IsEmpty());
+}
+
+TEST(ReshapeOpTest, LegacyScalarOutput) {
+ ReshapeOpModel m({1}, {0});
+ m.SetInput({3});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({3}));
+ EXPECT_THAT(m.GetOutputShape(), IsEmpty());
+}
+
} // namespace
} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/resize_bilinear.cc b/tensorflow/contrib/lite/kernels/resize_bilinear.cc
index e4bd0f5b85..fb045d15f3 100644
--- a/tensorflow/contrib/lite/kernels/resize_bilinear.cc
+++ b/tensorflow/contrib/lite/kernels/resize_bilinear.cc
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
@@ -61,12 +61,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4);
TF_LITE_ENSURE_EQ(context, NumDimensions(size), 1);
- // TODO(ahentz): Our current implementations only support float32.
- TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
TF_LITE_ENSURE_EQ(context, size->type, kTfLiteInt32);
// ResizeBilinear produces an output with the same type as its input
// (float32 or uint8).
- output->type = kTfLiteFloat32;
+ output->type = input->type;
if (!IsConstantTensor(size)) {
SetTensorToDynamic(output);
@@ -90,21 +88,31 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
}
if (output->type == kTfLiteFloat32) {
-#define TF_LITE_RESIZE_BILINEAR(type) \
- type::ResizeBilinear(GetTensorData<float>(input), GetTensorDims(input), \
- GetTensorData<int32>(size), GetTensorDims(size), \
- GetTensorData<float>(output), GetTensorDims(output), \
- params->align_corners)
+#define TF_LITE_RESIZE_BILINEAR(type, datatype) \
+ tflite::ResizeBilinearParams op_params; \
+ op_params.align_corners = params->align_corners; \
+ type::ResizeBilinear(op_params, GetTensorShape(input), \
+ GetTensorData<datatype>(input), GetTensorShape(size), \
+ GetTensorData<int32>(size), GetTensorShape(output), \
+ GetTensorData<datatype>(output))
if (kernel_type == kReference) {
- TF_LITE_RESIZE_BILINEAR(reference_ops);
+ TF_LITE_RESIZE_BILINEAR(reference_ops, float);
}
if (kernel_type == kGenericOptimized || kernel_type == kNeonOptimized) {
- TF_LITE_RESIZE_BILINEAR(optimized_ops);
+ TF_LITE_RESIZE_BILINEAR(optimized_ops, float);
+ }
+ } else if (output->type == kTfLiteUInt8) {
+ if (kernel_type == kReference) {
+ TF_LITE_RESIZE_BILINEAR(reference_ops, uint8_t);
+ }
+ if (kernel_type == kGenericOptimized || kernel_type == kNeonOptimized) {
+ TF_LITE_RESIZE_BILINEAR(optimized_ops, uint8_t);
}
#undef TF_LITE_RESIZE_BILINEAR
} else {
- context->ReportError(context, "Inputs and outputs not all float types.");
+ context->ReportError(context, "Output type is %d, requires float.",
+ output->type);
return kTfLiteError;
}
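
The uint8 path added above reuses the same bilinear math as the float path. For the HorizontalResize cases in the tests that follow, with align_corners false, a 2-wide row {3, 6} resized to width 3 uses a scale of 2/3: output x maps to input position x * 2/3, giving 3, then 3 + 0.667 * (6 - 3) = 5, then 6 (the last position clamps to the final pixel), which is the {3, 5, 6} expectation for both the float and uint8 tests. A minimal 1-D sketch of that interpolation (an illustration, not the TFLite kernel):

    #include <algorithm>
    #include <cmath>
    #include <vector>

    // Linear resize of one row, align_corners = false.
    std::vector<float> ResizeRow(const std::vector<float>& in, int out_width) {
      const float scale = static_cast<float>(in.size()) / out_width;
      std::vector<float> out(out_width);
      for (int x = 0; x < out_width; ++x) {
        const float pos = x * scale;
        const int x0 = static_cast<int>(std::floor(pos));
        const int x1 = std::min(x0 + 1, static_cast<int>(in.size()) - 1);
        const float frac = pos - x0;
        out[x] = in[x0] + frac * (in[x1] - in[x0]);
      }
      return out;  // ResizeRow({3, 6}, 3) == {3, 5, 6}
    }
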
diff --git a/tensorflow/contrib/lite/kernels/resize_bilinear_test.cc b/tensorflow/contrib/lite/kernels/resize_bilinear_test.cc
index 4e03f3820a..f4289105f7 100644
--- a/tensorflow/contrib/lite/kernels/resize_bilinear_test.cc
+++ b/tensorflow/contrib/lite/kernels/resize_bilinear_test.cc
@@ -22,6 +22,7 @@ namespace tflite {
namespace {
using ::testing::ElementsAreArray;
+using uint8 = std::uint8_t;
class ResizeBilinearOpModel : public SingleOpModel {
public:
@@ -34,7 +35,7 @@ class ResizeBilinearOpModel : public SingleOpModel {
} else {
size_ = AddInput({TensorType_INT32, {2}});
}
- output_ = AddOutput(TensorType_FLOAT32); // Always float.
+ output_ = AddOutput(input.type);
SetBuiltinOp(BuiltinOperator_RESIZE_BILINEAR,
BuiltinOptions_ResizeBilinearOptions,
CreateResizeBilinearOptions(builder_).Union());
@@ -45,12 +46,16 @@ class ResizeBilinearOpModel : public SingleOpModel {
}
}
- void SetInput(std::initializer_list<float> data) {
+ template <typename T>
+ void SetInput(std::initializer_list<T> data) {
PopulateTensor(input_, data);
}
void SetSize(std::initializer_list<int> data) { PopulateTensor(size_, data); }
- std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+ template <typename T>
+ std::vector<T> GetOutput() {
+ return ExtractVector<T>(output_);
+ }
private:
int input_;
@@ -60,60 +65,121 @@ class ResizeBilinearOpModel : public SingleOpModel {
TEST(ResizeBilinearOpTest, HorizontalResize) {
ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 1, 2, 1}});
- m.SetInput({3, 6});
+ m.SetInput<float>({3, 6});
m.SetSize({1, 3});
m.Invoke();
- EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({3, 5, 6})));
+ EXPECT_THAT(m.GetOutput<float>(),
+ ElementsAreArray(ArrayFloatNear({3, 5, 6})));
ResizeBilinearOpModel const_m({TensorType_FLOAT32, {1, 1, 2, 1}}, {1, 3});
- const_m.SetInput({3, 6});
+ const_m.SetInput<float>({3, 6});
+ const_m.Invoke();
+ EXPECT_THAT(const_m.GetOutput<float>(),
+ ElementsAreArray(ArrayFloatNear({3, 5, 6})));
+}
+
+TEST(ResizeBilinearOpTest, HorizontalResize8Bit) {
+ ResizeBilinearOpModel m({TensorType_UINT8, {1, 1, 2, 1}});
+ m.SetInput<uint8>({3, 6});
+ m.SetSize({1, 3});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput<uint8>(),
+ ElementsAreArray(ArrayFloatNear({3, 5, 6})));
+
+ ResizeBilinearOpModel const_m({TensorType_UINT8, {1, 1, 2, 1}}, {1, 3});
+ const_m.SetInput<uint8>({3, 6});
const_m.Invoke();
- EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({3, 5, 6})));
+ EXPECT_THAT(const_m.GetOutput<uint8>(),
+ ElementsAreArray(ArrayFloatNear({3, 5, 6})));
}
TEST(ResizeBilinearOpTest, VerticalResize) {
ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 2, 1, 1}});
- m.SetInput({3, 9});
+ m.SetInput<float>({3, 9});
m.SetSize({3, 1});
m.Invoke();
- EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({3, 7, 9})));
+ EXPECT_THAT(m.GetOutput<float>(),
+ ElementsAreArray(ArrayFloatNear({3, 7, 9})));
ResizeBilinearOpModel const_m({TensorType_FLOAT32, {1, 2, 1, 1}}, {3, 1});
- const_m.SetInput({3, 9});
+ const_m.SetInput<float>({3, 9});
const_m.Invoke();
- EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({3, 7, 9})));
+ EXPECT_THAT(const_m.GetOutput<float>(),
+ ElementsAreArray(ArrayFloatNear({3, 7, 9})));
+}
+
+TEST(ResizeBilinearOpTest, VerticalResize8Bit) {
+ ResizeBilinearOpModel m({TensorType_UINT8, {1, 2, 1, 1}});
+ m.SetInput<uint8>({3, 9});
+ m.SetSize({3, 1});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput<uint8>(),
+ ElementsAreArray(ArrayFloatNear({3, 7, 9})));
+
+ ResizeBilinearOpModel const_m({TensorType_UINT8, {1, 2, 1, 1}}, {3, 1});
+ const_m.SetInput<uint8>({3, 9});
+ const_m.Invoke();
+ EXPECT_THAT(const_m.GetOutput<uint8>(),
+ ElementsAreArray(ArrayFloatNear({3, 7, 9})));
}
TEST(ResizeBilinearOpTest, TwoDimensionalResize) {
ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}});
- m.SetInput({
+ m.SetInput<float>({
3, 6, //
9, 12 //
});
m.SetSize({3, 3});
m.Invoke();
- EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({
- 3, 5, 6, //
- 7, 9, 10, //
- 9, 11, 12, //
- })));
+ EXPECT_THAT(m.GetOutput<float>(), ElementsAreArray(ArrayFloatNear({
+ 3, 5, 6, //
+ 7, 9, 10, //
+ 9, 11, 12, //
+ })));
ResizeBilinearOpModel const_m({TensorType_FLOAT32, {1, 2, 2, 1}}, {3, 3});
- const_m.SetInput({
+ const_m.SetInput<float>({
+ 3, 6, //
+ 9, 12 //
+ });
+ const_m.Invoke();
+ EXPECT_THAT(const_m.GetOutput<float>(), ElementsAreArray(ArrayFloatNear({
+ 3, 5, 6, //
+ 7, 9, 10, //
+ 9, 11, 12, //
+ })));
+}
+
+TEST(ResizeBilinearOpTest, TwoDimensionalResize8Bit) {
+ ResizeBilinearOpModel m({TensorType_UINT8, {1, 2, 2, 1}});
+ m.SetInput<uint8>({
+ 3, 6, //
+ 9, 12 //
+ });
+ m.SetSize({3, 3});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput<uint8>(), ElementsAreArray(ArrayFloatNear({
+ 3, 5, 6, //
+ 7, 9, 10, //
+ 9, 11, 12, //
+ })));
+
+ ResizeBilinearOpModel const_m({TensorType_UINT8, {1, 2, 2, 1}}, {3, 3});
+ const_m.SetInput<uint8>({
3, 6, //
9, 12 //
});
const_m.Invoke();
- EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({
- 3, 5, 6, //
- 7, 9, 10, //
- 9, 11, 12, //
- })));
+ EXPECT_THAT(const_m.GetOutput<uint8>(), ElementsAreArray(ArrayFloatNear({
+ 3, 5, 6, //
+ 7, 9, 10, //
+ 9, 11, 12, //
+ })));
}
TEST(ResizeBilinearOpTest, TwoDimensionalResizeWithTwoBatches) {
ResizeBilinearOpModel m({TensorType_FLOAT32, {2, 2, 2, 1}});
- m.SetInput({
+ m.SetInput<float>({
3, 6, //
9, 12, //
4, 10, //
@@ -121,60 +187,123 @@ TEST(ResizeBilinearOpTest, TwoDimensionalResizeWithTwoBatches) {
});
m.SetSize({3, 3});
m.Invoke();
- EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({
- 3, 5, 6, //
- 7, 9, 10, //
- 9, 11, 12, //
- 4, 8, 10, //
- 8, 12, 14, //
- 10, 14, 16, //
- })));
+ EXPECT_THAT(m.GetOutput<float>(), ElementsAreArray(ArrayFloatNear({
+ 3, 5, 6, //
+ 7, 9, 10, //
+ 9, 11, 12, //
+ 4, 8, 10, //
+ 8, 12, 14, //
+ 10, 14, 16, //
+ })));
ResizeBilinearOpModel const_m({TensorType_FLOAT32, {2, 2, 2, 1}}, {3, 3});
- const_m.SetInput({
+ const_m.SetInput<float>({
3, 6, //
9, 12, //
4, 10, //
10, 16 //
});
const_m.Invoke();
- EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({
- 3, 5, 6, //
- 7, 9, 10, //
- 9, 11, 12, //
- 4, 8, 10, //
- 8, 12, 14, //
- 10, 14, 16, //
- })));
+ EXPECT_THAT(const_m.GetOutput<float>(), ElementsAreArray(ArrayFloatNear({
+ 3, 5, 6, //
+ 7, 9, 10, //
+ 9, 11, 12, //
+ 4, 8, 10, //
+ 8, 12, 14, //
+ 10, 14, 16, //
+ })));
}
TEST(ResizeBilinearOpTest, ThreeDimensionalResize) {
ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 2, 2, 2}});
- m.SetInput({
+ m.SetInput<float>({
3, 4, 6, 10, //
9, 10, 12, 16, //
});
m.SetSize({3, 3});
m.Invoke();
- EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({
- 3, 4, 5, 8, 6, 10, //
- 7, 8, 9, 12, 10, 14, //
- 9, 10, 11, 14, 12, 16, //
- })));
+ EXPECT_THAT(m.GetOutput<float>(), ElementsAreArray(ArrayFloatNear({
+ 3, 4, 5, 8, 6, 10, //
+ 7, 8, 9, 12, 10, 14, //
+ 9, 10, 11, 14, 12, 16, //
+ })));
ResizeBilinearOpModel const_m({TensorType_FLOAT32, {1, 2, 2, 2}}, {3, 3});
- const_m.SetInput({
+ const_m.SetInput<float>({
3, 4, 6, 10, //
9, 10, 12, 16, //
});
const_m.Invoke();
- EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({
- 3, 4, 5, 8, 6, 10, //
- 7, 8, 9, 12, 10, 14, //
- 9, 10, 11, 14, 12, 16, //
- })));
+ EXPECT_THAT(const_m.GetOutput<float>(), ElementsAreArray(ArrayFloatNear({
+ 3, 4, 5, 8, 6, 10, //
+ 7, 8, 9, 12, 10, 14, //
+ 9, 10, 11, 14, 12, 16, //
+ })));
}
+TEST(ResizeBilinearOpTest, TwoDimensionalResizeWithTwoBatches8Bit) {
+ ResizeBilinearOpModel m({TensorType_UINT8, {2, 2, 2, 1}});
+ m.SetInput<uint8>({
+ 3, 6, //
+ 9, 12, //
+ 4, 10, //
+ 12, 16 //
+ });
+ m.SetSize({3, 3});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput<uint8>(), ElementsAreArray(ArrayFloatNear({
+ 3, 5, 6, //
+ 7, 9, 10, //
+ 9, 11, 12, //
+ 4, 8, 10, //
+ 9, 12, 14, //
+ 12, 14, 16, //
+ })));
+
+ ResizeBilinearOpModel const_m({TensorType_UINT8, {2, 2, 2, 1}}, {3, 3});
+ const_m.SetInput<uint8>({
+ 3, 6, //
+ 9, 12, //
+ 4, 10, //
+ 12, 16 //
+ });
+ const_m.Invoke();
+ EXPECT_THAT(const_m.GetOutput<uint8>(), ElementsAreArray(ArrayFloatNear({
+ 3, 5, 6, //
+ 7, 9, 10, //
+ 9, 11, 12, //
+ 4, 8, 10, //
+ 9, 12, 14, //
+ 12, 14, 16, //
+ })));
+}
+
+TEST(ResizeBilinearOpTest, ThreeDimensionalResize8Bit) {
+ ResizeBilinearOpModel m({TensorType_UINT8, {1, 2, 2, 2}});
+ m.SetInput<uint8>({
+ 3, 4, 6, 10, //
+ 10, 12, 14, 16, //
+ });
+ m.SetSize({3, 3});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput<uint8>(), ElementsAreArray(ArrayFloatNear({
+ 3, 4, 5, 8, 6, 10, //
+ 7, 9, 10, 12, 11, 14, //
+ 10, 12, 12, 14, 14, 16, //
+ })));
+
+ ResizeBilinearOpModel const_m({TensorType_UINT8, {1, 2, 2, 2}}, {3, 3});
+ const_m.SetInput<uint8>({
+ 3, 4, 6, 10, //
+ 10, 12, 14, 16, //
+ });
+ const_m.Invoke();
+ EXPECT_THAT(const_m.GetOutput<uint8>(), ElementsAreArray(ArrayFloatNear({
+ 3, 4, 5, 8, 6, 10, //
+ 7, 9, 10, 12, 11, 14, //
+ 10, 12, 12, 14, 14, 16, //
+ })));
+}
} // namespace
} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/select.cc b/tensorflow/contrib/lite/kernels/select.cc
index 9bc8a1a34a..4780a86ee5 100644
--- a/tensorflow/contrib/lite/kernels/select.cc
+++ b/tensorflow/contrib/lite/kernels/select.cc
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
@@ -70,12 +70,12 @@ TfLiteStatus SelectEval(TfLiteContext* context, TfLiteNode* node) {
bool is_rank_one = !HaveSameShapes(input_condition, input_x);
-#define TF_LITE_SELECT(type, op) \
- reference_ops::op(GetTensorData<bool>(input_condition), \
- GetTensorDims(input_condition), \
- GetTensorData<type>(input_x), GetTensorDims(input_x), \
- GetTensorData<type>(input_y), GetTensorDims(input_y), \
- GetTensorData<type>(output), GetTensorDims(output));
+#define TF_LITE_SELECT(type, op) \
+ reference_ops::op(GetTensorShape(input_condition), \
+ GetTensorData<bool>(input_condition), \
+ GetTensorShape(input_x), GetTensorData<type>(input_x), \
+ GetTensorShape(input_y), GetTensorData<type>(input_y), \
+ GetTensorShape(output), GetTensorData<type>(output));
#define TF_LITE_SWITCH(type, op) \
switch (type) { \
@@ -89,6 +89,9 @@ TfLiteStatus SelectEval(TfLiteContext* context, TfLiteNode* node) {
case kTfLiteUInt8: \
TF_LITE_SELECT(uint8_t, op); \
break; \
+ case kTfLiteInt16: \
+ TF_LITE_SELECT(int16_t, op); \
+ break; \
case kTfLiteInt32: \
TF_LITE_SELECT(int32_t, op); \
break; \
@@ -97,7 +100,9 @@ TfLiteStatus SelectEval(TfLiteContext* context, TfLiteNode* node) {
break; \
default: \
context->ReportError(context, \
- "Does not support type other than bool|float|int"); \
+ "Does not support type other than bool|float|int, " \
+ "got %d", \
+ type); \
return kTfLiteError; \
}
diff --git a/tensorflow/contrib/lite/kernels/select_test.cc b/tensorflow/contrib/lite/kernels/select_test.cc
index cfe24a5fc9..5b2e61cd29 100644
--- a/tensorflow/contrib/lite/kernels/select_test.cc
+++ b/tensorflow/contrib/lite/kernels/select_test.cc
@@ -88,11 +88,24 @@ TEST(SelectOpTest, SelectUInt8) {
TensorType_UINT8);
model.PopulateTensor<bool>(model.input1(), {false, true, false, false});
- model.PopulateTensor<uint8>(model.input2(), {1, 2, 3, 4});
- model.PopulateTensor<uint8>(model.input3(), {5, 6, 7, 8});
+ model.PopulateTensor<uint8_t>(model.input2(), {1, 2, 3, 4});
+ model.PopulateTensor<uint8_t>(model.input3(), {5, 6, 7, 8});
model.Invoke();
- EXPECT_THAT(model.GetOutput<uint8>(), ElementsAreArray({5, 2, 7, 8}));
+ EXPECT_THAT(model.GetOutput<uint8_t>(), ElementsAreArray({5, 2, 7, 8}));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
+}
+
+TEST(SelectOpTest, SelectInt16) {
+ SelectOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, {1, 1, 1, 4},
+ TensorType_INT16);
+
+ model.PopulateTensor<bool>(model.input1(), {false, true, false, false});
+ model.PopulateTensor<int16_t>(model.input2(), {1, 2, 3, 4});
+ model.PopulateTensor<int16_t>(model.input3(), {5, 6, 7, 8});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput<int16_t>(), ElementsAreArray({5, 2, 7, 8}));
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
}
@@ -101,11 +114,11 @@ TEST(SelectOpTest, SelectInt32) {
TensorType_INT32);
model.PopulateTensor<bool>(model.input1(), {false, true, false, false});
- model.PopulateTensor<int32>(model.input2(), {1, 2, 3, 4});
- model.PopulateTensor<int32>(model.input3(), {5, 6, 7, 8});
+ model.PopulateTensor<int32_t>(model.input2(), {1, 2, 3, 4});
+ model.PopulateTensor<int32_t>(model.input3(), {5, 6, 7, 8});
model.Invoke();
- EXPECT_THAT(model.GetOutput<int32>(), ElementsAreArray({5, 2, 7, 8}));
+ EXPECT_THAT(model.GetOutput<int32_t>(), ElementsAreArray({5, 2, 7, 8}));
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
}
@@ -113,11 +126,11 @@ TEST(SelectOpTest, RankOneSelectInt32) {
SelectOpModel model({2}, {2, 1, 2, 1}, {2, 1, 2, 1}, TensorType_INT32);
model.PopulateTensor<bool>(model.input1(), {false, true});
- model.PopulateTensor<int32>(model.input2(), {1, 2, 3, 4});
- model.PopulateTensor<int32>(model.input3(), {5, 6, 7, 8});
+ model.PopulateTensor<int32_t>(model.input2(), {1, 2, 3, 4});
+ model.PopulateTensor<int32_t>(model.input3(), {5, 6, 7, 8});
model.Invoke();
- EXPECT_THAT(model.GetOutput<int32>(), ElementsAreArray({5, 6, 3, 4}));
+ EXPECT_THAT(model.GetOutput<int32_t>(), ElementsAreArray({5, 6, 3, 4}));
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 1, 2, 1}));
}
@@ -125,11 +138,11 @@ TEST(SelectOpTest, RankZeroSelectInt32) {
SelectOpModel model({1}, {1, 2, 2, 1}, {1, 2, 2, 1}, TensorType_INT32);
model.PopulateTensor<bool>(model.input1(), {false});
- model.PopulateTensor<int32>(model.input2(), {1, 2, 3, 4});
- model.PopulateTensor<int32>(model.input3(), {5, 6, 7, 8});
+ model.PopulateTensor<int32_t>(model.input2(), {1, 2, 3, 4});
+ model.PopulateTensor<int32_t>(model.input3(), {5, 6, 7, 8});
model.Invoke();
- EXPECT_THAT(model.GetOutput<int32>(), ElementsAreArray({5, 6, 7, 8}));
+ EXPECT_THAT(model.GetOutput<int32_t>(), ElementsAreArray({5, 6, 7, 8}));
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 2, 2, 1}));
}
diff --git a/tensorflow/contrib/lite/kernels/shape.cc b/tensorflow/contrib/lite/kernels/shape.cc
new file mode 100644
index 0000000000..66d4c9e5c1
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/shape.cc
@@ -0,0 +1,93 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace shape {
+
+constexpr int kInputTensor = 0;
+constexpr int kOutputTensor = 0;
+
+template <typename OutType>
+void ExtractShape(const TfLiteTensor* input, OutType* output_data) {
+ for (int i = 0; i < NumDimensions(input); ++i) {
+ output_data[i] = SizeOfDimension(input, i);
+ }
+}
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+
+ const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+ auto* params = reinterpret_cast<TfLiteShapeParams*>(node->builtin_data);
+ switch (params->out_type) {
+ case kTfLiteInt32:
+ output->type = kTfLiteInt32;
+ break;
+ case kTfLiteInt64:
+ output->type = kTfLiteInt64;
+ break;
+ default:
+ context->ReportError(context, "Unknown shape output data type: %d",
+ params->out_type);
+ return kTfLiteError;
+ }
+
+ // Shape always produces a 1-dimensional output tensor, where each output
+ // element is the length of the corresponding input tensor's dimension.
+ TfLiteIntArray* output_size = TfLiteIntArrayCreate(1);
+ output_size->data[0] = NumDimensions(input);
+ return context->ResizeTensor(context, output, output_size);
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+ TFLITE_DCHECK_EQ(NumDimensions(output), 1);
+ TFLITE_DCHECK_EQ(SizeOfDimension(output, 0), NumDimensions(input));
+
+ switch (output->type) {
+ case kTfLiteInt32:
+ ExtractShape(input, GetTensorData<int32_t>(output));
+ break;
+ case kTfLiteInt64:
+ ExtractShape(input, GetTensorData<int64_t>(output));
+ break;
+ default:
+ return kTfLiteError;
+ }
+
+ return kTfLiteOk;
+}
+
+} // namespace shape
+
+TfLiteRegistration* Register_SHAPE() {
+ static TfLiteRegistration r = {nullptr, nullptr, shape::Prepare, shape::Eval};
+ return &r;
+}
+
+} // namespace builtin
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/shape_test.cc b/tensorflow/contrib/lite/kernels/shape_test.cc
new file mode 100644
index 0000000000..27b48f4e99
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/shape_test.cc
@@ -0,0 +1,95 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <initializer_list>
+
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAreArray;
+
+template <typename T>
+class ShapeOpModel : public SingleOpModel {
+ public:
+ ShapeOpModel(std::initializer_list<int> input_shape, TensorType input_type,
+ TensorType output_type) {
+ input_ = AddInput(input_type);
+ output_ = AddOutput(output_type);
+ SetBuiltinOp(BuiltinOperator_SHAPE, BuiltinOptions_ShapeOptions,
+ CreateShapeOptions(builder_, output_type).Union());
+ BuildInterpreter({input_shape});
+ }
+
+ TfLiteStatus InvokeWithResult() { return interpreter_->Invoke(); }
+
+ int input() { return input_; }
+
+ int32_t GetOutputSize() { return GetTensorSize(output_); }
+ std::vector<T> GetOutput() { return ExtractVector<T>(output_); }
+ std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
+
+ private:
+ int input_;
+ int output_;
+};
+
+TEST(ShapeOpTest, OutTypeInt) {
+ ShapeOpModel<int32_t> model({1, 3, 1, 3, 5}, TensorType_FLOAT32,
+ TensorType_INT32);
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 3, 1, 3, 5}));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({5}));
+}
+
+TEST(ShapeOpTest, OutTypeInt64) {
+ ShapeOpModel<int64_t> model({1, 3, 1, 3, 5}, TensorType_FLOAT32,
+ TensorType_INT64);
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 3, 1, 3, 5}));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({5}));
+}
+
+TEST(ShapeOpTest, ScalarTensor) {
+ ShapeOpModel<int32_t> model({}, TensorType_FLOAT32, TensorType_INT32);
+ model.Invoke();
+
+ EXPECT_EQ(model.GetOutputSize(), 0);
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({0}));
+}
+
+TEST(ShapeOpTest, EmptyTensor) {
+ ShapeOpModel<int32_t> model({1, 0}, TensorType_FLOAT32, TensorType_INT32);
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 0}));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2}));
+}
+
+} // namespace
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ ::tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/kernels/skip_gram.cc b/tensorflow/contrib/lite/kernels/skip_gram.cc
index c90a15b3a2..de80a4016e 100644
--- a/tensorflow/contrib/lite/kernels/skip_gram.cc
+++ b/tensorflow/contrib/lite/kernels/skip_gram.cc
@@ -33,8 +33,8 @@ limitations under the License.
#include <string>
#include <vector>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"
#include "tensorflow/contrib/lite/string_util.h"
diff --git a/tensorflow/contrib/lite/kernels/slice.cc b/tensorflow/contrib/lite/kernels/slice.cc
index b28934e2f7..ccfee41b9c 100644
--- a/tensorflow/contrib/lite/kernels/slice.cc
+++ b/tensorflow/contrib/lite/kernels/slice.cc
@@ -16,8 +16,8 @@ limitations under the License.
#include <string.h>
#include <cmath>
#include <vector>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
@@ -85,7 +85,8 @@ TfLiteStatus ResizeOutputShape(TfLiteContext* context,
TF_LITE_ENSURE_STATUS(CalculateOutputShapeVector<int64_t>(
context, input, begin, size, &output_shape_vector));
} else {
- context->ReportError(context, "Type is currently not supported by Slice.");
+ context->ReportError(
+ context, "Type %d is currently not supported by Slice.", begin->type);
return kTfLiteError;
}
@@ -148,7 +149,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
GetBeginAndSizeVectors<int64_t>(NumDimensions(input), begin, size, &begins,
&sizes);
} else {
- context->ReportError(context, "Type is currently not supported by Slice.");
+ context->ReportError(
+ context, "Type %d is currently not supported by Slice.", begin->type);
return kTfLiteError;
}
@@ -157,10 +159,28 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
sizes.push_back(1);
}
-#define TF_LITE_SLICE(data_type) \
- optimized_ops::Slice<data_type>( \
- GetTensorData<data_type>(input), GetTensorDims(input), begins, sizes, \
- GetTensorData<data_type>(output), GetTensorDims(output))
+ // The original Slice op implementation only accepted 4-D sizes. That
+ // constraint is, for the present, maintained here.
+ //
+ // The dimensions in the kernel used to be in reverse order, and TFLite
+ // arranged the begins and sizes vectors accordingly. This macro incorporates
+ // the needed reversing.
+#define TF_LITE_SLICE(data_type) \
+ { \
+ TF_LITE_ENSURE_EQ(context, begins.size(), 4); \
+ TF_LITE_ENSURE_EQ(context, sizes.size(), 4); \
+ tflite::SliceParams op_params; \
+ op_params.begin_count = 4; \
+ op_params.size_count = 4; \
+ for (int i = 0; i < 4; ++i) { \
+ op_params.begin[i] = begins[3 - i]; \
+ op_params.size[i] = sizes[3 - i]; \
+ } \
+ \
+ optimized_ops::Slice<data_type>( \
+ op_params, GetTensorShape(input), GetTensorData<data_type>(input), \
+ GetTensorShape(output), GetTensorData<data_type>(output)); \
+ }
switch (input->type) {
case kTfLiteFloat32:
@@ -179,8 +199,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_SLICE(bool);
break;
default:
- context->ReportError(context,
- "Type is currently not supported by Slice.");
+ context->ReportError(
+ context, "Type %d is currently not supported by Slice.", input->type);
return kTfLiteError;
}
#undef TF_LITE_SLICE
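
The reversal baked into TF_LITE_SLICE above is the only subtle part of this change: the begins/sizes vectors are still built innermost-dimension first (the legacy ordering), while SliceParams expects them outermost first, so element i of the params comes from element 3 - i of the vectors. A small sketch of that mapping, using a hypothetical stand-in for SliceParams:

    #include <array>

    struct SliceParamsSketch {  // illustrative stand-in for tflite::SliceParams
      int begin_count = 4;
      int size_count = 4;
      std::array<int, 4> begin{};
      std::array<int, 4> size{};
    };

    SliceParamsSketch ReverseIntoParams(const std::array<int, 4>& begins,
                                        const std::array<int, 4>& sizes) {
      SliceParamsSketch p;
      for (int i = 0; i < 4; ++i) {
        p.begin[i] = begins[3 - i];  // innermost-first -> outermost-first
        p.size[i] = sizes[3 - i];
      }
      return p;  // begins {0, 1, 0, 0} becomes p.begin {0, 0, 1, 0}
    }
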
diff --git a/tensorflow/contrib/lite/kernels/softmax_test.cc b/tensorflow/contrib/lite/kernels/softmax_test.cc
index 6c5338ff0f..bd66980226 100644
--- a/tensorflow/contrib/lite/kernels/softmax_test.cc
+++ b/tensorflow/contrib/lite/kernels/softmax_test.cc
@@ -92,10 +92,11 @@ TEST(SoftmaxOpTest, CompareWithTFminiBetaEq1) {
m.Invoke();
std::unique_ptr<float[]> output_buffer(new float[input_size * batch_size]);
- static tflite::Dims<4> input_dims = {{input_size, 1, 1, batch_size},
- {1, 0, 0, input_size}};
- tflite::reference_ops::Softmax(input_buffer, input_dims, beta,
- output_buffer.get(), input_dims);
+ auto input_shape = RuntimeShape({batch_size, 1, 1, input_size});
+ SoftmaxParams params;
+ params.beta = beta;
+ tflite::reference_ops::Softmax(params, input_shape, input_buffer, input_shape,
+ output_buffer.get());
std::vector<float> expected;
expected.insert(expected.end(), output_buffer.get(),
@@ -120,10 +121,11 @@ TEST(SoftmaxOpTest, CompareWithTFminiBetaNotEq1) {
m.Invoke();
std::unique_ptr<float[]> output_buffer(new float[input_size * batch_size]);
- static tflite::Dims<4> input_dims = {{input_size, 1, 1, batch_size},
- {1, 0, 0, input_size}};
- tflite::reference_ops::Softmax(input_buffer, input_dims, beta,
- output_buffer.get(), input_dims);
+ auto input_shape = RuntimeShape({batch_size, 1, 1, input_size});
+ SoftmaxParams params;
+ params.beta = beta;
+ tflite::reference_ops::Softmax(params, input_shape, input_buffer, input_shape,
+ output_buffer.get());
std::vector<float> expected;
expected.insert(expected.end(), output_buffer.get(),
diff --git a/tensorflow/contrib/lite/kernels/space_to_batch_nd.cc b/tensorflow/contrib/lite/kernels/space_to_batch_nd.cc
index 1e35869958..3a10d2e60c 100644
--- a/tensorflow/contrib/lite/kernels/space_to_batch_nd.cc
+++ b/tensorflow/contrib/lite/kernels/space_to_batch_nd.cc
@@ -14,8 +14,8 @@ limitations under the License.
==============================================================================*/
#include <string.h>
#include <vector>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
@@ -113,47 +113,52 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
}
-#define TF_LITE_SPACE_TO_BATCH_ND(type, scalar) \
- type::SpaceToBatchND(GetTensorData<scalar>(op_context.input), \
- GetTensorDims(op_context.input), \
+#define TF_LITE_SPACE_TO_BATCH_ND(type, scalar, pad_value) \
+ tflite::SpaceToBatchParams op_params; \
+ op_params.output_offset = pad_value; \
+ type::SpaceToBatchND(op_params, GetTensorShape(op_context.input), \
+ GetTensorData<scalar>(op_context.input), \
+ GetTensorShape(op_context.block_shape), \
GetTensorData<int32_t>(op_context.block_shape), \
- GetTensorDims(op_context.block_shape), \
+ GetTensorShape(op_context.paddings), \
GetTensorData<int32_t>(op_context.paddings), \
- GetTensorDims(op_context.paddings), \
- GetTensorData<scalar>(op_context.output), \
- GetTensorDims(op_context.output))
+ GetTensorShape(op_context.output), \
+ GetTensorData<scalar>(op_context.output))
switch (op_context.input->type) { // Already know in/out types are same.
case kTfLiteFloat32:
if (kernel_type == kReference) {
- TF_LITE_SPACE_TO_BATCH_ND(reference_ops, float);
+ TF_LITE_SPACE_TO_BATCH_ND(reference_ops, float, 0);
} else {
- TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, float);
+ TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, float, 0);
}
break;
case kTfLiteUInt8:
if (kernel_type == kReference) {
- TF_LITE_SPACE_TO_BATCH_ND(reference_ops, uint8_t);
+ TF_LITE_SPACE_TO_BATCH_ND(reference_ops, uint8_t,
+ op_context.output->params.zero_point);
} else {
- TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, uint8_t);
+ TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, uint8_t,
+ op_context.output->params.zero_point);
}
break;
case kTfLiteInt32:
if (kernel_type == kReference) {
- TF_LITE_SPACE_TO_BATCH_ND(reference_ops, int32_t);
+ TF_LITE_SPACE_TO_BATCH_ND(reference_ops, int32_t, 0);
} else {
- TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, int32_t);
+ TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, int32_t, 0);
}
break;
case kTfLiteInt64:
if (kernel_type == kReference) {
- TF_LITE_SPACE_TO_BATCH_ND(reference_ops, int64_t);
+ TF_LITE_SPACE_TO_BATCH_ND(reference_ops, int64_t, 0);
} else {
- TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, int64_t);
+ TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, int64_t, 0);
}
break;
default:
- context->ReportError(context,
- "Type is currently not supported by SpaceToBatch.");
+ context->ReportError(
+ context, "Type %d is currently not supported by SpaceToBatch.",
+ op_context.input->type);
return kTfLiteError;
}
#undef TF_LITE_SPACE_TO_BATCH_ND
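
The new pad_value argument matters only for the quantized uint8 path: padded cells must represent real 0.0, and under asymmetric quantization real 0.0 maps to the output's zero point rather than to the literal value 0, which is why op_context.output->params.zero_point is passed above. A short sketch of that relationship (the scale and zero point below are illustrative, roughly matching a [-1, 1] uint8 range):

    #include <cstdint>

    // Quantize a real value given scale and zero_point (uint8, asymmetric).
    inline uint8_t Quantize(float real, float scale, int zero_point) {
      return static_cast<uint8_t>(real / scale + zero_point);
    }

    // With scale = 2.0f / 255 and zero_point = 128, real 0.0 quantizes to 128,
    // so padded regions must be filled with 128, not 0; that is what using the
    // output zero point as pad_value accomplishes.
    inline uint8_t PadValueForOutput(float scale, int zero_point) {
      return Quantize(0.0f, scale, zero_point);
    }
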
diff --git a/tensorflow/contrib/lite/kernels/space_to_batch_nd_test.cc b/tensorflow/contrib/lite/kernels/space_to_batch_nd_test.cc
index 92a4a037d5..5756573629 100644
--- a/tensorflow/contrib/lite/kernels/space_to_batch_nd_test.cc
+++ b/tensorflow/contrib/lite/kernels/space_to_batch_nd_test.cc
@@ -23,6 +23,7 @@ namespace tflite {
namespace {
using ::testing::ElementsAreArray;
+using ::testing::Matcher;
class SpaceToBatchNDOpModel : public SingleOpModel {
public:
@@ -30,6 +31,10 @@ class SpaceToBatchNDOpModel : public SingleOpModel {
PopulateTensor<float>(input_, data);
}
+ void SetQuantizedInput(std::initializer_list<float> data) {
+ QuantizeAndPopulate<uint8_t>(input_, data);
+ }
+
void SetBlockShape(std::initializer_list<int> data) {
PopulateTensor<int>(block_shape_, data);
}
@@ -41,6 +46,11 @@ class SpaceToBatchNDOpModel : public SingleOpModel {
std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
+ std::vector<float> GetDequantizedOutput() {
+ return Dequantize<uint8_t>(ExtractVector<uint8_t>(output_),
+ GetScale(output_), GetZeroPoint(output_));
+ }
+
protected:
int input_;
int block_shape_;
@@ -56,18 +66,19 @@ class SpaceToBatchNDOpModel : public SingleOpModel {
// m.Invoke();
class SpaceToBatchNDOpConstModel : public SpaceToBatchNDOpModel {
public:
- SpaceToBatchNDOpConstModel(std::initializer_list<int> input_shape,
+ SpaceToBatchNDOpConstModel(const TensorData& input,
std::initializer_list<int> block_shape,
- std::initializer_list<int> paddings) {
- input_ = AddInput(TensorType_FLOAT32);
+ std::initializer_list<int> paddings,
+ const TensorData& output) {
+ input_ = AddInput(input);
block_shape_ = AddConstInput(TensorType_INT32, block_shape, {2});
paddings_ = AddConstInput(TensorType_INT32, paddings, {2, 2});
- output_ = AddOutput(TensorType_FLOAT32);
+ output_ = AddOutput(output);
SetBuiltinOp(BuiltinOperator_SPACE_TO_BATCH_ND,
BuiltinOptions_SpaceToBatchNDOptions,
CreateSpaceToBatchNDOptions(builder_).Union());
- BuildInterpreter({input_shape});
+ BuildInterpreter({input.shape});
}
};
@@ -81,26 +92,30 @@ class SpaceToBatchNDOpConstModel : public SpaceToBatchNDOpModel {
// m.Invoke();
class SpaceToBatchNDOpDynamicModel : public SpaceToBatchNDOpModel {
public:
- SpaceToBatchNDOpDynamicModel(std::initializer_list<int> input_shape) {
- input_ = AddInput(TensorType_FLOAT32);
+ SpaceToBatchNDOpDynamicModel(const TensorData& input,
+ const TensorData& output) {
+ input_ = AddInput(input);
block_shape_ = AddInput(TensorType_INT32);
paddings_ = AddInput(TensorType_INT32);
- output_ = AddOutput(TensorType_FLOAT32);
+ output_ = AddOutput(output);
SetBuiltinOp(BuiltinOperator_SPACE_TO_BATCH_ND,
BuiltinOptions_SpaceToBatchNDOptions,
CreateSpaceToBatchNDOptions(builder_).Union());
- BuildInterpreter({input_shape, {2}, {2, 2}});
+ BuildInterpreter({input.shape, {2}, {2, 2}});
}
};
TEST(SpaceToBatchNDOpTest, InvalidShapeTest) {
- EXPECT_DEATH(SpaceToBatchNDOpConstModel({1, 3, 3, 1}, {2, 2}, {0, 0, 0, 0}),
- "Cannot allocate tensors");
+ EXPECT_DEATH(
+ SpaceToBatchNDOpConstModel({TensorType_FLOAT32, {1, 3, 3, 1}}, {2, 2},
+ {0, 0, 0, 0}, {TensorType_FLOAT32}),
+ "Cannot allocate tensors");
}
TEST(SpaceToBatchNDOpTest, SimpleConstTest) {
- SpaceToBatchNDOpConstModel m({1, 4, 4, 1}, {2, 2}, {0, 0, 0, 0});
+ SpaceToBatchNDOpConstModel m({TensorType_FLOAT32, {1, 4, 4, 1}}, {2, 2},
+ {0, 0, 0, 0}, {TensorType_FLOAT32});
m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4, 2, 2, 1}));
@@ -109,7 +124,8 @@ TEST(SpaceToBatchNDOpTest, SimpleConstTest) {
}
TEST(SpaceToBatchNDOpTest, SimpleDynamicTest) {
- SpaceToBatchNDOpDynamicModel m({1, 4, 4, 1});
+ SpaceToBatchNDOpDynamicModel m({TensorType_FLOAT32, {1, 4, 4, 1}},
+ {TensorType_FLOAT32});
m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
m.SetBlockShape({2, 2});
m.SetPaddings({0, 0, 0, 0});
@@ -120,7 +136,8 @@ TEST(SpaceToBatchNDOpTest, SimpleDynamicTest) {
}
TEST(SpaceToBatchNDOpTest, MultipleInputBatchesConstTest) {
- SpaceToBatchNDOpConstModel m({2, 2, 4, 1}, {2, 2}, {0, 0, 0, 0});
+ SpaceToBatchNDOpConstModel m({TensorType_FLOAT32, {2, 2, 4, 1}}, {2, 2},
+ {0, 0, 0, 0}, {TensorType_FLOAT32});
m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({8, 1, 2, 1}));
@@ -129,7 +146,8 @@ TEST(SpaceToBatchNDOpTest, MultipleInputBatchesConstTest) {
}
TEST(SpaceToBatchNDOpTest, MultipleInputBatchesDynamicTest) {
- SpaceToBatchNDOpDynamicModel m({2, 2, 4, 1});
+ SpaceToBatchNDOpDynamicModel m({TensorType_FLOAT32, {2, 2, 4, 1}},
+ {TensorType_FLOAT32});
m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
m.SetBlockShape({2, 2});
m.SetPaddings({0, 0, 0, 0});
@@ -140,7 +158,8 @@ TEST(SpaceToBatchNDOpTest, MultipleInputBatchesDynamicTest) {
}
TEST(SpaceToBatchNDOpTest, SimplePaddingConstTest) {
- SpaceToBatchNDOpConstModel m({1, 5, 2, 1}, {3, 2}, {1, 0, 2, 0});
+ SpaceToBatchNDOpConstModel m({TensorType_FLOAT32, {1, 5, 2, 1}}, {3, 2},
+ {1, 0, 2, 0}, {TensorType_FLOAT32});
m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({6, 2, 2, 1}));
@@ -151,7 +170,8 @@ TEST(SpaceToBatchNDOpTest, SimplePaddingConstTest) {
}
TEST(SpaceToBatchNDOpTest, SimplePaddingDynamicTest) {
- SpaceToBatchNDOpDynamicModel m({1, 5, 2, 1});
+ SpaceToBatchNDOpDynamicModel m({TensorType_FLOAT32, {1, 5, 2, 1}},
+ {TensorType_FLOAT32});
m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10});
m.SetBlockShape({3, 2});
m.SetPaddings({1, 0, 2, 0});
@@ -164,7 +184,8 @@ TEST(SpaceToBatchNDOpTest, SimplePaddingDynamicTest) {
}
TEST(SpaceToBatchNDOpTest, ComplexPaddingConstTest) {
- SpaceToBatchNDOpConstModel m({1, 4, 2, 1}, {3, 2}, {1, 1, 2, 4});
+ SpaceToBatchNDOpConstModel m({TensorType_FLOAT32, {1, 4, 2, 1}}, {3, 2},
+ {1, 1, 2, 4}, {TensorType_FLOAT32});
m.SetInput({1, 2, 3, 4, 5, 6, 7, 8});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({6, 2, 4, 1}));
@@ -176,7 +197,8 @@ TEST(SpaceToBatchNDOpTest, ComplexPaddingConstTest) {
}
TEST(SpaceToBatchNDOpTest, ComplexPaddingDynamicTest) {
- SpaceToBatchNDOpDynamicModel m({1, 4, 2, 1});
+ SpaceToBatchNDOpDynamicModel m({TensorType_FLOAT32, {1, 4, 2, 1}},
+ {TensorType_FLOAT32});
m.SetInput({1, 2, 3, 4, 5, 6, 7, 8});
m.SetBlockShape({3, 2});
m.SetPaddings({1, 1, 2, 4});
@@ -189,6 +211,88 @@ TEST(SpaceToBatchNDOpTest, ComplexPaddingDynamicTest) {
}));
}
+class QuantizedSpaceToBatchNDOpTest : public ::testing::Test {
+ protected:
+ std::vector<Matcher<float>> DequantizedArrayNear(
+ const std::vector<float>& values, const float min, const float max) {
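+ // Tolerance of one uint8 quantization step over [min, max].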
+ const float quantization_tolerance = (max - min) / 255.0;
+ return ArrayFloatNear(values, quantization_tolerance);
+ }
+};
+
+TEST_F(QuantizedSpaceToBatchNDOpTest, ZeroNotInQuantizationRange) {
+ // The test_util and the actual quantization code currently ensure that the
+ // range includes zero, but if that ever changes, this test will catch it.
+ EXPECT_DEATH(SpaceToBatchNDOpConstModel m(
+ {TensorType_UINT8, {1, 2, 2, 1}, 1.0, 2.0}, {4, 2},
+ {0, 0, 1, 1, 1, 1, 0, 0}, {TensorType_UINT8, {}, 1.0, 2.0}),
+ ".*Check failed: f_min <= 0.*");
+}
+
+TEST_F(QuantizedSpaceToBatchNDOpTest, SimplePaddingConstTest) {
+ SpaceToBatchNDOpConstModel m({TensorType_UINT8, {1, 5, 2, 1}, -1.0, 1.0},
+ {3, 2}, {1, 0, 2, 0},
+ {TensorType_UINT8, {}, -1.0, 1.0});
+ m.SetQuantizedInput({-0.1, 0.2, -0.3, 0.4, -0.5, 0.6, -0.7, 0.8, -0.9, 0.1});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({6, 2, 2, 1}));
+ EXPECT_THAT(m.GetDequantizedOutput(),
+ ElementsAreArray(DequantizedArrayNear(
+ {0, 0, 0, -0.5, 0, 0, 0, 0.6, 0, -0.1, 0, -0.7,
+ 0, 0.2, 0, 0.8, 0, -0.3, 0, -0.9, 0, 0.4, 0, 0.1},
+ -1.0, 1.0)));
+}
+
+TEST_F(QuantizedSpaceToBatchNDOpTest, SimplePaddingDynamicTest) {
+ SpaceToBatchNDOpDynamicModel m({TensorType_UINT8, {1, 5, 2, 1}, -1.0, 1.0},
+ {TensorType_UINT8, {}, -1.0, 1.0});
+ m.SetQuantizedInput({-0.1, 0.2, -0.3, 0.4, -0.5, 0.6, -0.7, 0.8, -0.9, 0.1});
+ m.SetBlockShape({3, 2});
+ m.SetPaddings({1, 0, 2, 0});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({6, 2, 2, 1}));
+ EXPECT_THAT(m.GetDequantizedOutput(),
+ ElementsAreArray(DequantizedArrayNear(
+ {0, 0, 0, -0.5, 0, 0, 0, 0.6, 0, -0.1, 0, -0.7,
+ 0, 0.2, 0, 0.8, 0, -0.3, 0, -0.9, 0, 0.4, 0, 0.1},
+ -1.0, 1.0)));
+}
+
+TEST_F(QuantizedSpaceToBatchNDOpTest, ComplexPaddingConstTest) {
+ SpaceToBatchNDOpConstModel m({TensorType_UINT8, {1, 4, 2, 1}, -1.0, 1.0},
+ {3, 2}, {1, 1, 2, 4},
+ {TensorType_UINT8, {}, -1.0, 1.0});
+ m.SetQuantizedInput({-0.1, 0.2, -0.3, 0.4, -0.5, 0.6, -0.7, 0.8});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({6, 2, 4, 1}));
+ EXPECT_THAT(m.GetDequantizedOutput(),
+ ElementsAreArray(DequantizedArrayNear(
+ {
+ 0, 0, 0, 0, 0, -0.5, 0, 0, 0, 0, 0, 0, 0, 0.6, 0, 0,
+ 0, -0.1, 0, 0, 0, -0.7, 0, 0, 0, 0.2, 0, 0, 0, 0.8, 0, 0,
+ 0, -0.3, 0, 0, 0, 0, 0, 0, 0, 0.4, 0, 0, 0, 0, 0, 0,
+ },
+ -1.0, 1.0)));
+}
+
+TEST_F(QuantizedSpaceToBatchNDOpTest, ComplexPaddingDynamicTest) {
+ SpaceToBatchNDOpDynamicModel m({TensorType_UINT8, {1, 4, 2, 1}, -1.0, 1.0},
+ {TensorType_UINT8, {}, -1.0, 1.0});
+ m.SetQuantizedInput({-0.1, 0.2, -0.3, 0.4, -0.5, 0.6, -0.7, 0.8});
+ m.SetBlockShape({3, 2});
+ m.SetPaddings({1, 1, 2, 4});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({6, 2, 4, 1}));
+ EXPECT_THAT(m.GetDequantizedOutput(),
+ ElementsAreArray(DequantizedArrayNear(
+ {
+ 0, 0, 0, 0, 0, -0.5, 0, 0, 0, 0, 0, 0, 0, 0.6, 0, 0,
+ 0, -0.1, 0, 0, 0, -0.7, 0, 0, 0, 0.2, 0, 0, 0, 0.8, 0, 0,
+ 0, -0.3, 0, 0, 0, 0, 0, 0, 0, 0.4, 0, 0, 0, 0, 0, 0,
+ },
+ -1.0, 1.0)));
+}
+
} // namespace
} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/space_to_depth.cc b/tensorflow/contrib/lite/kernels/space_to_depth.cc
index aafce89512..64c56c017b 100644
--- a/tensorflow/contrib/lite/kernels/space_to_depth.cc
+++ b/tensorflow/contrib/lite/kernels/space_to_depth.cc
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
@@ -79,10 +79,12 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
-#define TF_LITE_SPACE_TO_DEPTH(type, scalar) \
- type::SpaceToDepth<scalar>( \
- GetTensorData<scalar>(input), GetTensorDims(input), params->block_size, \
- GetTensorData<scalar>(output), GetTensorDims(output))
+#define TF_LITE_SPACE_TO_DEPTH(type, scalar) \
+ tflite::SpaceToDepthParams op_params; \
+ op_params.block_size = params->block_size; \
+ type::SpaceToDepth(op_params, GetTensorShape(input), \
+ GetTensorData<scalar>(input), GetTensorShape(output), \
+ GetTensorData<scalar>(output))
switch (input->type) { // Already know in/out types are same.
case kTfLiteFloat32:
if (kernel_type == kReference) {
@@ -113,7 +115,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
}
break;
default:
- context->ReportError(context, "Type not currently supported.");
+ context->ReportError(context, "Type %d not currently supported.",
+ input->type);
return kTfLiteError;
}
#undef TF_LITE_SPACE_TO_DEPTH
diff --git a/tensorflow/contrib/lite/kernels/sparse_output_fully_connected.cc b/tensorflow/contrib/lite/kernels/sparse_output_fully_connected.cc
new file mode 100644
index 0000000000..843ed0768c
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/sparse_output_fully_connected.cc
@@ -0,0 +1,235 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// SparseOutputFullyConnected is a fully connected layer that computes only a
+// single output unit, selecting the corresponding row of the weights and the
+// bias entry via a lookup index.
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+
+namespace tflite {
+namespace ops {
+namespace custom {
+namespace sparse_output_fully_connected {
+
+// Input tensors of size {n_batch, n_input}
+constexpr int kInputTensor = 0;
+// Lookup index input tensor of size {1}
+constexpr int kInputLookupTensor = 1;
+
+// Weights tensor of size {n_embeddings, n_input}
+constexpr int kWeightsTensor = 2;
+// Bias tensor of size {n_embeddings}
+constexpr int kBiasTensor = 3;
+
+// Output tensor.
+constexpr int kOutputTensor = 0;
+
+// Temporary tensors.
+enum TemporaryTensor {
+ kInputQuantized = 0,
+ kScalingFactors = 1,
+ kNumTemporaryTensors = 2
+};
+
+// Struct to hold op data.
+struct OpData {
+ int scratch_tensor_index;
+};
+
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+ auto* data = new OpData;
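+ // Reserve indices for the temporary tensors used by the hybrid path.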
+ context->AddTensors(context, /*tensors_to_add=*/kNumTemporaryTensors,
+ &data->scratch_tensor_index);
+ return data;
+}
+
+void Free(TfLiteContext* context, void* buffer) {
+ delete reinterpret_cast<OpData*>(buffer);
+}
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
+
+ TF_LITE_ENSURE_EQ(context, node->inputs->size, 4);
+ TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
+
+ const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
+ TF_LITE_ENSURE_EQ(context, NumDimensions(input), 2);
+ const int n_batch = SizeOfDimension(input, 0);
+ const int n_input = SizeOfDimension(input, 1);
+
+ const TfLiteTensor* lookup = GetInput(context, node, kInputLookupTensor);
+ TF_LITE_ENSURE_EQ(context, lookup->type, kTfLiteInt32);
+ TF_LITE_ENSURE_EQ(context, NumDimensions(lookup), 1);
+ // Only support single lookup.
+ TF_LITE_ENSURE_EQ(context, SizeOfDimension(lookup, 0), 1);
+
+ const TfLiteTensor* weights = GetInput(context, node, kWeightsTensor);
+ TF_LITE_ENSURE_EQ(context, NumDimensions(weights), 2);
+ TF_LITE_ENSURE_EQ(context, SizeOfDimension(weights, 1), n_input);
+
+ const TfLiteTensor* bias = GetInput(context, node, kBiasTensor);
+ TF_LITE_ENSURE_EQ(context, NumElements(bias), SizeOfDimension(weights, 0));
+
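+ // Hybrid mode: float inputs with quantized (uint8) weights.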
+ const bool is_hybrid_op =
+ (weights->type == kTfLiteUInt8 && input->type == kTfLiteFloat32);
+
+ if (is_hybrid_op) {
+ TfLiteIntArrayFree(node->temporaries);
+ node->temporaries = TfLiteIntArrayCreate(kNumTemporaryTensors);
+
+ // Allocate temporary tensors to store quantized values of input.
+ node->temporaries->data[kInputQuantized] = op_data->scratch_tensor_index;
+ TfLiteTensor* input_quantized =
+ GetTemporary(context, node, /*index=*/kInputQuantized);
+ input_quantized->type = kTfLiteUInt8;
+ input_quantized->allocation_type = kTfLiteArenaRw;
+ if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) {
+ TfLiteIntArray* input_quantized_size = TfLiteIntArrayCopy(input->dims);
+ TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized,
+ input_quantized_size));
+ }
+
+ // Tell interpreter to allocate temporary tensors to store scaling factors.
+ node->temporaries->data[kScalingFactors] =
+ op_data->scratch_tensor_index + kScalingFactors;
+ TfLiteTensor* scaling_factors =
+ GetTemporary(context, node, /*index=*/kScalingFactors);
+ scaling_factors->type = kTfLiteFloat32;
+ scaling_factors->allocation_type = kTfLiteArenaRw;
+ TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1);
+ scaling_factors_size->data[0] = n_batch;
+ if (!TfLiteIntArrayEqual(scaling_factors->dims, scaling_factors_size)) {
+ TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors,
+ scaling_factors_size));
+ }
+ }
+ return kTfLiteOk;
+}
+
+TfLiteStatus EvalFloat(const TfLiteTensor* input, const TfLiteTensor* lookup,
+ const TfLiteTensor* weights, const TfLiteTensor* bias,
+ TfLiteTensor* output) {
+ const int n_batch = SizeOfDimension(input, 0);
+ const int n_input = SizeOfDimension(input, 1);
+
+ const float* input_ptr_batch = input->data.f;
+
+ // Initialize the weights pointer to the row selected by the lookup value.
+ int32 lookup_index = lookup->data.i32[0];
+ const float* weights_ptr = weights->data.f + lookup_index * n_input;
+
+ // Initialize output to bias.
+ if (bias) {
+ float* bias_ptr = bias->data.f + lookup_index;
+ tensor_utils::VectorBatchVectorAssign(bias_ptr, 1, n_batch, output->data.f);
+ } else {
+ tensor_utils::ZeroVector(output->data.f, n_batch * 1);
+ }
+
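+ // Multiply only the selected weights row, producing one output value per
+ // batch.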
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ weights_ptr, /*m_rows=*/1, n_input, input_ptr_batch, n_batch,
+ output->data.f, /*result_stride=*/1);
+
+ return kTfLiteOk;
+}
+
+TfLiteStatus EvalHybrid(const TfLiteTensor* input, const TfLiteTensor* lookup,
+ const TfLiteTensor* weights, const TfLiteTensor* bias,
+ TfLiteTensor* scaling_factors,
+ TfLiteTensor* input_quantized, TfLiteTensor* output) {
+ const int n_batch = SizeOfDimension(input, 0);
+ const int n_input = SizeOfDimension(input, 1);
+
+ const float* input_ptr_batch = input->data.f;
+ // Initialize the pointers to the storage for quantized input values and
+ // scaling factors.
+ int8_t* quantized_input_ptr_batch =
+ reinterpret_cast<int8_t*>(input_quantized->data.uint8);
+ float* scaling_factors_ptr = scaling_factors->data.f;
+
+ // Initialize the weights pointer to the row selected by the lookup value.
+ int32 lookup_index = lookup->data.i32[0];
+ int8_t* weights_ptr =
+ reinterpret_cast<int8_t*>(weights->data.uint8) + lookup_index * n_input;
+
+ // Initialize output to bias.
+ if (bias) {
+ float* bias_ptr = bias->data.f + lookup_index;
+ tensor_utils::VectorBatchVectorAssign(bias_ptr, 1, n_batch, output->data.f);
+ } else {
+ tensor_utils::ZeroVector(output->data.f, n_batch * 1);
+ }
+
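+ // If the input is all zeros, skip the quantized matmul; the output already
+ // holds the bias.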
+ if (!tensor_utils::IsZeroVector(input_ptr_batch, n_batch * n_input)) {
+ // Quantize input from float to int8.
+ float unused_min, unused_max;
+ for (int b = 0; b < n_batch; ++b) {
+ const int offset = b * n_input;
+ tensor_utils::SymmetricQuantizeFloats(
+ input_ptr_batch + offset, n_input, quantized_input_ptr_batch + offset,
+ &unused_min, &unused_max, &scaling_factors_ptr[b]);
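+ // Fold the weight scale into the per-batch scaling factor so the
+ // accumulation below is rescaled back to float in one step.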
+ scaling_factors_ptr[b] *= weights->params.scale;
+ }
+
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ weights_ptr, /*m_rows=*/1, n_input, quantized_input_ptr_batch,
+ scaling_factors_ptr, n_batch, output->data.f, /*result_stride=*/1);
+ }
+ return kTfLiteOk;
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ const TfLiteTensor* lookup = GetInput(context, node, kInputLookupTensor);
+ const TfLiteTensor* weights = GetInput(context, node, kWeightsTensor);
+ const TfLiteTensor* bias = GetInput(context, node, kBiasTensor);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+ switch (weights->type) {
+ case kTfLiteFloat32: {
+ return EvalFloat(input, lookup, weights, bias, output);
+ }
+ case kTfLiteUInt8: {
+ TfLiteTensor* input_quantized =
+ GetTemporary(context, node, /*index=*/kInputQuantized);
+ TfLiteTensor* scaling_factors =
+ GetTemporary(context, node, /*index=*/kScalingFactors);
+ return EvalHybrid(input, lookup, weights, bias, scaling_factors,
+ input_quantized, output);
+ }
+ default:
+ context->ReportError(context, "Type %d is not currently supported.",
+ weights->type);
+ return kTfLiteError;
+ }
+ return kTfLiteOk;
+}
+
+} // namespace sparse_output_fully_connected
+
+TfLiteRegistration* Register_SPARSE_OUTPUT_FULLY_CONNECTED() {
+ static TfLiteRegistration r = {sparse_output_fully_connected::Init,
+ sparse_output_fully_connected::Free,
+ sparse_output_fully_connected::Prepare,
+ sparse_output_fully_connected::Eval};
+ return &r;
+}
+
+} // namespace custom
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/sparse_output_fully_connected_test.cc b/tensorflow/contrib/lite/kernels/sparse_output_fully_connected_test.cc
new file mode 100644
index 0000000000..365986a5c1
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/sparse_output_fully_connected_test.cc
@@ -0,0 +1,158 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// Unit test for TFLite sparse output fully connected op.
+#include <iomanip>
+#include <random>
+#include <vector>
+
+#include <gtest/gtest.h>
+#include "flatbuffers/flexbuffers.h" // TF:flatbuffers
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+
+namespace tflite {
+
+namespace ops {
+namespace custom {
+
+TfLiteRegistration* Register_SPARSE_OUTPUT_FULLY_CONNECTED();
+
+namespace {
+
+using ::testing::ElementsAreArray;
+
+class BaseSparseOutputFullyConnectedOpModel : public SingleOpModel {
+ public:
+ BaseSparseOutputFullyConnectedOpModel(const TensorData& input,
+ const TensorData& weights,
+ const TensorData& output = {
+ TensorType_FLOAT32}) {
+ input_ = AddInput(input);
+ lookup_ = AddInput({TensorType_INT32, {1}});
+ weights_ = AddInput(weights);
+ int bias_size = GetShape(weights_)[0];
+ bias_ = AddInput({TensorType_FLOAT32, {bias_size}});
+ output_ = AddOutput(output);
+
+ // Create empty (required) options map.
+ flexbuffers::Builder fbb;
+ fbb.Map([&]() {});
+ fbb.Finish();
+
+ SetCustomOp("SPARSE_OUTPUT_FULLY_CONNECTED", fbb.GetBuffer(),
+ Register_SPARSE_OUTPUT_FULLY_CONNECTED);
+ BuildInterpreter({GetShape(input_), GetShape(lookup_), GetShape(weights_),
+ GetShape(bias_)});
+ }
+
+ void SetInput(const std::vector<float>& data) {
+ PopulateTensor(input_, data);
+ }
+
+ void SetLookup(const std::vector<int32>& f) { PopulateTensor(lookup_, f); }
+
+ void SetBias(const std::vector<float>& f) { PopulateTensor(bias_, f); }
+
+ std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+
+ protected:
+ int input_;
+ int lookup_;
+ int weights_;
+ int bias_;
+ int output_;
+};
+
+class FloatSparseOutputFullyConnectedOpModel
+ : public BaseSparseOutputFullyConnectedOpModel {
+ public:
+ using BaseSparseOutputFullyConnectedOpModel::
+ BaseSparseOutputFullyConnectedOpModel;
+
+ void SetWeights(const std::vector<float>& f) { PopulateTensor(weights_, f); }
+};
+
+class HybridSparseOutputFullyConnectedOpModel
+ : public BaseSparseOutputFullyConnectedOpModel {
+ public:
+ using BaseSparseOutputFullyConnectedOpModel::
+ BaseSparseOutputFullyConnectedOpModel;
+
+ void SetWeights(const std::vector<float>& f) {
+ SymmetricQuantizeAndPopulate(weights_, f);
+ }
+};
+
+TEST(SparseOutputFullyConnectedOpTest, SimpleTestFloat) {
+ FloatSparseOutputFullyConnectedOpModel m({TensorType_FLOAT32, {1, 5}},
+ {TensorType_FLOAT32, {3, 5}},
+ {TensorType_FLOAT32, {}});
+
+ m.SetInput({-1.0, 0.0, 1.0, 2.0, 3.0});
+
+ m.SetLookup({2});
+
+ m.SetWeights({
+ -1.0, 0.0, 1.0, 2.0, 3.0, //
+ 0.0, 1.0, 2.0, 3.0, 4.0, //
+ 1.0, 2.0, 3.0, 4.0, 5.0, //
+ });
+
+ m.SetBias({1.0, 2.0, 3.0});
+
+ m.Invoke();
+
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({28}));
+}
+
+TEST(SparseOutputFullyConnectedOpTest, SimpleTestHybrid) {
+ HybridSparseOutputFullyConnectedOpModel m({TensorType_FLOAT32, {1, 5}},
+ {TensorType_UINT8, {3, 5}},
+ {TensorType_FLOAT32, {}});
+
+ m.SetInput({-1.0, 0.0, 1.0, 2.0, 3.0});
+
+ m.SetLookup({2});
+
+ m.SetWeights({
+ -1.0, 0.0, 1.0, 2.0, 3.0, //
+ 0.0, 1.0, 2.0, 3.0, 4.0, //
+ 1.0, 2.0, 3.0, 4.0, 5.0, //
+ });
+
+ m.SetBias({1.0, 2.0, 3.0});
+
+ m.Invoke();
+
+ // We get 28.0552 instead of 28.
+ //
+ // Input -> -42, 0, 42, 85, 127 with scale factor of 127/3.
+ // Looked up weights -> 25, 51, 76, 102, 127 with scale factor of 127/5.
+ //
+ // (-42 * 25 + 0 * 51 + 42 * 76 + 85 * 102 + 127 * 127) * (3*5/127^2) + 3.0
+ // gives us the expected result.
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({28}, 0.0553)));
+}
+
+} // namespace
+} // namespace custom
+} // namespace ops
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ ::tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/kernels/sparse_to_dense.cc b/tensorflow/contrib/lite/kernels/sparse_to_dense.cc
new file mode 100644
index 0000000000..349fa0bd28
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/sparse_to_dense.cc
@@ -0,0 +1,275 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <cassert>
+#include <cmath>
+#include <cstdio>
+#include <cstdlib>
+#include <iostream>
+#include <limits>
+
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+#include "tensorflow/contrib/lite/kernels/padding.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace sparse_to_dense {
+
+constexpr int kIndicesTensor = 0;
+constexpr int kOutputShapeTensor = 1;
+constexpr int kValueInputTensor = 2;
+constexpr int kDefaultValueTensor = 3;
+constexpr int kOutputTensor = 0;
+
+constexpr int kMaxDimensions = 4;
+
+template <typename T>
+TfLiteStatus Resize(TfLiteContext* context, const TfLiteTensor* output_shape,
+ TfLiteTensor* output) {
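+ // Copy the requested dense shape out of the 1-D output_shape tensor and
+ // resize the output tensor accordingly.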
+ const int output_dimensions = NumElements(output_shape);
+ TfLiteIntArray* output_shape_array = TfLiteIntArrayCreate(output_dimensions);
+ for (int i = 0; i < output_dimensions; ++i) {
+ output_shape_array->data[i] = GetTensorData<T>(output_shape)[i];
+ }
+
+ return context->ResizeTensor(context, output, output_shape_array);
+}
+
+TfLiteStatus CheckDimensionsMatch(TfLiteContext* context,
+ const TfLiteTensor* indices,
+ const TfLiteTensor* output_shape,
+ const TfLiteTensor* values) {
+ switch (NumDimensions(indices)) {
+ case 0:
+ case 1: {
+ if (NumDimensions(values) == 0) {
+ TF_LITE_ENSURE_EQ(context, NumElements(indices), NumElements(values));
+ }
+ TF_LITE_ENSURE_EQ(context, NumElements(output_shape), 1);
+ break;
+ }
+ case 2: {
+ TF_LITE_ENSURE_EQ(context, SizeOfDimension(indices, 1),
+ NumElements(output_shape));
+ if (NumDimensions(values) == 0)
+ TF_LITE_ENSURE_EQ(context, SizeOfDimension(indices, 0),
+ NumElements(values));
+ break;
+ }
+ default:
+ context->ReportError(
+ context, "Wrong indices dimensions %d, should be less than 3.",
+ NumDimensions(indices));
+ return kTfLiteError;
+ }
+ return kTfLiteOk;
+}
+
+// Convert indices into a vector of 4-d vectors.
+// TODO(renjieliu): Revisit this to improve performance, since the repeated
+// std::vector allocations will be quite slow on phones.
+template <typename T>
+TfLiteStatus GetIndicesVector(TfLiteContext* context,
+ const TfLiteTensor* indices,
+ const int num_indices,
+ std::vector<std::vector<T>>* indices_vector) {
+ // Note: because TfLite reverses the dimensions, zeros are padded up front.
+ switch (NumDimensions(indices)) {
+ case 0:
+ case 1: {
+ const auto indices_data = GetTensorData<T>(indices);
+ for (int i = 0; i < num_indices; ++i) {
+ std::vector<T> index({0, 0, 0, indices_data[i]});
+ indices_vector->push_back(index);
+ }
+ break;
+ }
+ case 2: {
+ const int true_dimensions = SizeOfDimension(indices, 1);
+ TF_LITE_ENSURE(context, true_dimensions <= kMaxDimensions);
+ for (int i = 0; i < num_indices; ++i) {
+ std::vector<T> index;
+ index.reserve(kMaxDimensions);
+ // Pad the index with (kMaxDimensions - true_dimensions) leading zeros so
+ // that it becomes a 4-dimensional index.
+ for (int j = 0; j < kMaxDimensions - true_dimensions; ++j) {
+ index.push_back(0);
+ }
+ for (int j = 0; j < true_dimensions; ++j) {
+ index.push_back(GetTensorData<T>(indices)[i * true_dimensions + j]);
+ }
+
+ indices_vector->push_back(index);
+ }
+ break;
+ }
+ default:
+ context->ReportError(context,
+ "Indices dimensions problem, got %d dimensions",
+ NumDimensions(indices));
+ return kTfLiteError;
+ }
+ return kTfLiteOk;
+}
+
+TfLiteStatus ResizeOutputShape(TfLiteContext* context,
+ const TfLiteTensor* output_shape,
+ TfLiteTensor* output) {
+ if (output_shape->type == kTfLiteInt32) {
+ return Resize<int32_t>(context, output_shape, output);
+ } else if (output_shape->type == kTfLiteInt64) {
+ return Resize<int64_t>(context, output_shape, output);
+ } else {
+ context->ReportError(context, "Dense shape type %d not supported.",
+ output_shape->type);
+ return kTfLiteError;
+ }
+}
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ TF_LITE_ENSURE_EQ(context, NumInputs(node), 4);
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+
+ const TfLiteTensor* indices = GetInput(context, node, kIndicesTensor);
+ const TfLiteTensor* output_shape =
+ GetInput(context, node, kOutputShapeTensor);
+ const TfLiteTensor* values = GetInput(context, node, kValueInputTensor);
+ const TfLiteTensor* default_value =
+ GetInput(context, node, kDefaultValueTensor);
+
+ // TODO(renjieliu): Handle validate_indices.
+
+ // Indices can be 0-D, 1-D or 2-D.
+ TF_LITE_ASSERT(NumDimensions(indices) >= 0);
+ TF_LITE_ENSURE(context, NumDimensions(indices) < 3);
+ TF_LITE_ASSERT(NumDimensions(output_shape) >= 0);
+ TF_LITE_ENSURE_EQ(context, NumDimensions(output_shape), 1);
+ // Values can be 0-D or 1-D.
+ TF_LITE_ASSERT(NumDimensions(values) >= 0);
+ TF_LITE_ENSURE(context, NumDimensions(values) < 2);
+
+ TF_LITE_ENSURE_EQ(context, NumElements(default_value), 1);
+
+ TF_LITE_ENSURE(
+ context, indices->type == kTfLiteInt32 || indices->type == kTfLiteInt64);
+ TF_LITE_ENSURE(context, output_shape->type == kTfLiteInt32 ||
+ output_shape->type == kTfLiteInt64);
+ TF_LITE_ENSURE_EQ(context, values->type, default_value->type);
+
+ // Ensure dimensions match.
+ TF_LITE_ENSURE_OK(
+ context, CheckDimensionsMatch(context, indices, output_shape, values));
+
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+ TF_LITE_ENSURE_EQ(context, NumDimensions(output_shape), 1);
+
+ if (!IsConstantTensor(output_shape)) {
+ SetTensorToDynamic(output);
+ return kTfLiteOk;
+ }
+ return ResizeOutputShape(context, output_shape, output);
+}
+
+template <typename T, typename TI>
+TfLiteStatus SparseToDenseImpl(TfLiteContext* context, TfLiteNode* node) {
+ const TfLiteTensor* indices = GetInput(context, node, kIndicesTensor);
+ const TfLiteTensor* output_shape =
+ GetInput(context, node, kOutputShapeTensor);
+ const TfLiteTensor* values = GetInput(context, node, kValueInputTensor);
+ const TfLiteTensor* default_value =
+ GetInput(context, node, kDefaultValueTensor);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+ if (IsDynamicTensor(output)) {
+ TF_LITE_ENSURE_OK(context,
+ ResizeOutputShape(context, output_shape, output));
+ }
+
+ const int num_indices = SizeOfDimension(indices, 0);
+ const bool value_is_scalar = NumDimensions(values) == 0;
+ std::vector<std::vector<TI>> indices_vector;
+ indices_vector.reserve(num_indices);
+ TF_LITE_ENSURE_OK(context, GetIndicesVector<TI>(context, indices, num_indices,
+ &indices_vector));
+ reference_ops::SparseToDense(indices_vector, GetTensorData<T>(values),
+ *GetTensorData<T>(default_value),
+ value_is_scalar, GetTensorShape(output),
+ GetTensorData<T>(output));
+
+ return kTfLiteOk;
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ const TfLiteTensor* indices = GetInput(context, node, kIndicesTensor);
+ const TfLiteTensor* values = GetInput(context, node, kValueInputTensor);
+
+ // Currently only supports float32 and int32.
+ switch (values->type) {
+ case kTfLiteFloat32: {
+ switch (indices->type) {
+ case kTfLiteInt32: {
+ return SparseToDenseImpl<float, int32_t>(context, node);
+ }
+ case kTfLiteInt64: {
+ return SparseToDenseImpl<float, int64_t>(context, node);
+ }
+ default:
+ context->ReportError(
+ context, "Type %d is currently not supported by sparse to dense.",
+ indices->type);
+ return kTfLiteError;
+ }
+ break;
+ }
+ case kTfLiteInt32: {
+ switch (indices->type) {
+ case kTfLiteInt32: {
+ return SparseToDenseImpl<int32_t, int32_t>(context, node);
+ }
+ case kTfLiteInt64: {
+ return SparseToDenseImpl<int32_t, int64_t>(context, node);
+ }
+ default:
+ context->ReportError(
+ context, "Type %d is currently not supported by sparse to dense.",
+ indices->type);
+ return kTfLiteError;
+ }
+ break;
+ }
+ default:
+ context->ReportError(
+ context, "Type %d is currently not supported by sparse to dense.",
+ values->type);
+ return kTfLiteError;
+ }
+}
+
+} // namespace sparse_to_dense
+
+TfLiteRegistration* Register_SPARSE_TO_DENSE() {
+ static TfLiteRegistration r = {nullptr, nullptr, sparse_to_dense::Prepare,
+ sparse_to_dense::Eval};
+ return &r;
+}
+
+} // namespace builtin
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/sparse_to_dense_test.cc b/tensorflow/contrib/lite/kernels/sparse_to_dense_test.cc
new file mode 100644
index 0000000000..a51ec17afc
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/sparse_to_dense_test.cc
@@ -0,0 +1,155 @@
+
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <cstdarg>
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAreArray;
+
+template <typename T>
+class SparseToDenseOpModel : public SingleOpModel {
+ public:
+ SparseToDenseOpModel(std::initializer_list<int> indices_shape,
+ std::initializer_list<int> output_shape_shape,
+ std::initializer_list<int> values_shape, T default_value,
+ TensorType tensor_index_type,
+ TensorType tensor_input_type) {
+ indices_ = AddInput(tensor_index_type);
+ output_shape_ = AddInput(TensorType_INT32);
+ values_ = AddInput(tensor_input_type);
+ default_value_ = AddInput(tensor_input_type);
+ output_ = AddOutput(tensor_input_type);
+
+ SetBuiltinOp(BuiltinOperator_SPARSE_TO_DENSE,
+ BuiltinOptions_SparseToDenseOptions,
+ CreateSparseToDenseOptions(builder_, false).Union());
+ BuildInterpreter({indices_shape, output_shape_shape, values_shape, {1}});
+
+ PopulateTensor<T>(default_value_, {default_value});
+ }
+
+ int indices() { return indices_; }
+ int output_shape() { return output_shape_; }
+ int values() { return values_; }
+
+ std::vector<T> GetOutput() { return ExtractVector<T>(output_); }
+ std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
+
+ private:
+ int indices_;
+ int output_shape_;
+ int values_;
+ int default_value_;
+ int output_;
+};
+
+TEST(SparseToDenseOpModelTest, ZeroDimensionTest) {
+ SparseToDenseOpModel<float> m({1}, {1}, {1}, 0, TensorType_INT32,
+ TensorType_FLOAT32);
+ m.PopulateTensor<int32_t>(m.indices(), {3});
+ m.PopulateTensor<int32_t>(m.output_shape(), {5});
+ m.PopulateTensor<float>(m.values(), {7});
+ m.Invoke();
+
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, 0, 0, 7, 0}));
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({5}));
+}
+
+TEST(SparseToDenseOpModelTest, OneDimensionTest) {
+ SparseToDenseOpModel<float> m({3}, {1}, {3}, 0, TensorType_INT32,
+ TensorType_FLOAT32);
+ m.PopulateTensor<int32_t>(m.indices(), {1, 3, 5});
+ m.PopulateTensor<int32_t>(m.output_shape(), {7});
+ m.PopulateTensor<float>(m.values(), {2, 4, 6});
+ m.Invoke();
+
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, 2, 0, 4, 0, 6, 0}));
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({7}));
+}
+
+TEST(SparseToDenseOpModelTest, TwoDimensionsTest) {
+ SparseToDenseOpModel<float> m({3, 3}, {3}, {3}, 0, TensorType_INT32,
+ TensorType_FLOAT32);
+ m.PopulateTensor<int32_t>(m.indices(), {0, 0, 0, 1, 2, 1, 2, 0, 1});
+ m.PopulateTensor<int32_t>(m.output_shape(), {3, 3, 3});
+ m.PopulateTensor<float>(m.values(), {2, 4, 6});
+ m.Invoke();
+
+ EXPECT_THAT(m.GetOutput(),
+ ElementsAreArray({2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 4, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0}));
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 3, 3}));
+}
+
+TEST(SparseToDenseOpModelTest, DefaultValueTest) {
+ SparseToDenseOpModel<float> m({3, 3}, {3}, {3}, -1, TensorType_INT32,
+ TensorType_FLOAT32);
+ m.PopulateTensor<int32_t>(m.indices(), {0, 0, 0, 1, 2, 1, 2, 0, 1});
+ m.PopulateTensor<int32_t>(m.output_shape(), {3, 3, 3});
+ m.PopulateTensor<float>(m.values(), {2, 4, 6});
+ m.Invoke();
+
+ EXPECT_THAT(
+ m.GetOutput(),
+ ElementsAreArray({2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
+ -1, -1, 4, -1, -1, 6, -1, -1, -1, -1, -1, -1, -1}));
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 3, 3}));
+}
+
+TEST(SparseToDenseOpModelTest, IntegerValueTest) {
+ SparseToDenseOpModel<int32_t> m({3, 3}, {3}, {3}, -1, TensorType_INT32,
+ TensorType_INT32);
+ m.PopulateTensor<int32_t>(m.indices(), {0, 0, 0, 1, 2, 1, 2, 0, 1});
+ m.PopulateTensor<int32_t>(m.output_shape(), {3, 3, 3});
+ m.PopulateTensor<int32_t>(m.values(), {2, 4, 6});
+ m.Invoke();
+
+ EXPECT_THAT(
+ m.GetOutput(),
+ ElementsAreArray({2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
+ -1, -1, 4, -1, -1, 6, -1, -1, -1, -1, -1, -1, -1}));
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 3, 3}));
+}
+
+TEST(SparseToDenseOpModelTest, Int64IndexTest) {
+ SparseToDenseOpModel<float> m({3, 3}, {3}, {3}, -1, TensorType_INT64,
+ TensorType_FLOAT32);
+ m.PopulateTensor<int64_t>(m.indices(), {0, 0, 0, 1, 2, 1, 2, 0, 1});
+ m.PopulateTensor<int32_t>(m.output_shape(), {3, 3, 3});
+ m.PopulateTensor<float>(m.values(), {2, 4, 6});
+ m.Invoke();
+
+ EXPECT_THAT(
+ m.GetOutput(),
+ ElementsAreArray({2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
+ -1, -1, 4, -1, -1, 6, -1, -1, -1, -1, -1, -1, -1}));
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 3, 3}));
+}
+
+} // namespace
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ ::tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/kernels/split.cc b/tensorflow/contrib/lite/kernels/split.cc
index c6b94c25be..dab887bf9c 100644
--- a/tensorflow/contrib/lite/kernels/split.cc
+++ b/tensorflow/contrib/lite/kernels/split.cc
@@ -14,8 +14,8 @@ limitations under the License.
==============================================================================*/
#include <string.h>
#include <vector>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
@@ -76,8 +76,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumOutputs(node), op_context.params->num_splits);
auto input_type = op_context.input->type;
- TF_LITE_ENSURE(context,
- input_type == kTfLiteFloat32 || input_type == kTfLiteUInt8);
+ TF_LITE_ENSURE(context, input_type == kTfLiteFloat32 ||
+ input_type == kTfLiteUInt8 ||
+ input_type == kTfLiteInt16);
for (int i = 0; i < NumOutputs(node); ++i) {
GetOutput(context, node, i)->type = input_type;
}
@@ -108,25 +109,24 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
if (axis_value < 0) {
axis_value += NumDimensions(op_context.input);
}
- axis_value = RemapDim(NumDimensions(op_context.input), axis_value);
// TODO(ahentz): Our usage of VectorOfTensors could be optimized by
// calculating it in Prepare, unless we defer shape calculation.
// TODO(ahentz): We can improve the optimized_ops version to handle other
// cases too.
-#define TF_LITE_SPLIT(scalar) \
- VectorOfTensors<scalar> all_outputs(*context, *node->outputs); \
- if (axis_value == NumDimensions(op_context.input)) { \
- optimized_ops::TensorFlowSplit<FusedActivationFunctionType::kNone, \
- scalar>( \
- GetTensorData<scalar>(op_context.input), \
- GetTensorDims(op_context.input), NumOutputs(node), all_outputs.data(), \
- all_outputs.dims()); \
- } else { \
- reference_ops::TensorFlowSplit<scalar>( \
- GetTensorData<scalar>(op_context.input), \
- GetTensorDims(op_context.input), axis_value, NumOutputs(node), \
- all_outputs.data(), all_outputs.dims()); \
+#define TF_LITE_SPLIT(scalar) \
+ VectorOfTensors<scalar> all_outputs(*context, *node->outputs); \
+ tflite::SplitParams op_params; \
+ op_params.num_split = NumOutputs(node); \
+ op_params.axis = axis_value; \
+ if (axis_value == 0) { \
+ optimized_ops::Split(op_params, GetTensorShape(op_context.input), \
+ GetTensorData<scalar>(op_context.input), \
+ all_outputs.shapes(), all_outputs.data()); \
+ } else { \
+ reference_ops::Split(op_params, GetTensorShape(op_context.input), \
+ GetTensorData<scalar>(op_context.input), \
+ all_outputs.shapes(), all_outputs.data()); \
}
switch (op_context.input->type) {
case kTfLiteFloat32: {
@@ -137,9 +137,15 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_SPLIT(uint8_t);
break;
}
+ case kTfLiteInt16: {
+ TF_LITE_SPLIT(int16_t);
+ break;
+ }
default:
- context->ReportError(context,
- "Only float32 and uint8 are currently supported.");
+ context->ReportError(
+ context,
+ "Only float32, uint8 and int16 are currently supported, got %d.",
+ op_context.input->type);
return kTfLiteError;
}
#undef TF_LITE_SPLIT
diff --git a/tensorflow/contrib/lite/kernels/squeeze.cc b/tensorflow/contrib/lite/kernels/squeeze.cc
index 09a5662fd9..080c51cd18 100644
--- a/tensorflow/contrib/lite/kernels/squeeze.cc
+++ b/tensorflow/contrib/lite/kernels/squeeze.cc
@@ -14,8 +14,8 @@ limitations under the License.
==============================================================================*/
#include <string.h>
#include <vector>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"
diff --git a/tensorflow/contrib/lite/kernels/strided_slice.cc b/tensorflow/contrib/lite/kernels/strided_slice.cc
index 9417be32b3..06b36dd196 100644
--- a/tensorflow/contrib/lite/kernels/strided_slice.cc
+++ b/tensorflow/contrib/lite/kernels/strided_slice.cc
@@ -15,8 +15,8 @@ limitations under the License.
#include <string.h>
#include <cmath>
#include <vector>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
@@ -57,17 +57,6 @@ struct StridedSliceContext {
int dims;
};
-// Reverse order of bits in the mask to match the expected order in kernel
-inline int ReverseMaskBits(int mask, int num_dimensions) {
- int out = 0;
- for (int dim = 0; dim < num_dimensions; dim++) {
- out <<= 1;
- out += (mask & 1);
- mask >>= 1;
- }
- return out;
-}
-
// This Op only supports 1-4D cases and since we use the reference 4D
// implementation, the 1-3D tensors are mapped to 4D.
const int kMaxDim = 4;
@@ -121,10 +110,19 @@ TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
int32_t begin = GetBeginValueAtIndex(op_context, idx);
int32_t end = GetEndValueAtIndex(op_context, idx);
+ // When shrinking an axis, the end position does not matter (and can be
+ // incorrect when negative indexing is used, see Issue #19260). Always use
+ // begin + 1 to generate a length 1 slice, since begin has
+ // already been adjusted for negative indices by GetBeginValueAtIndex.
+ const bool shrink_axis = op_context->params->shrink_axis_mask & (1 << idx);
+ if (shrink_axis) {
+ end = begin + 1;
+ }
+
// This is valid for both positive and negative strides
int32_t dim_shape = ceil((end - begin) / static_cast<float>(stride));
dim_shape = dim_shape < 0 ? 0 : dim_shape;
- if (!(op_context->params->shrink_axis_mask & (1 << idx))) {
+ if (!shrink_axis) {
output_shape_vector.push_back(dim_shape);
}
}
@@ -189,28 +187,31 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
std::vector<int32_t> stops;
std::vector<int32_t> strides;
- for (int idx = op_context.dims - 1; idx >= 0; --idx) {
- starts.emplace_back(GetTensorData<int32_t>(op_context.begin)[idx]);
- stops.emplace_back(GetTensorData<int32_t>(op_context.end)[idx]);
- strides.emplace_back(GetTensorData<int32_t>(op_context.strides)[idx]);
- }
-
for (int i = op_context.dims; i < kMaxDim; i++) {
starts.emplace_back(0);
stops.emplace_back(1);
strides.emplace_back(1);
}
- int begin_mask =
- ReverseMaskBits(op_context.params->begin_mask, op_context.dims);
- int end_mask = ReverseMaskBits(op_context.params->end_mask, op_context.dims);
+ for (int idx = 0; idx < op_context.dims; ++idx) {
+ starts.emplace_back(GetTensorData<int32_t>(op_context.begin)[idx]);
+ stops.emplace_back(GetTensorData<int32_t>(op_context.end)[idx]);
+ strides.emplace_back(GetTensorData<int32_t>(op_context.strides)[idx]);
+ }
+
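+ // The masks are defined on the original dimensions; shift them so their
+ // bits line up with the zero-padded 4-D starts/stops/strides built above.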
+ int begin_mask = op_context.params->begin_mask << (4 - op_context.dims);
+ int end_mask = op_context.params->end_mask << (4 - op_context.dims);
+ int shrink_axis_mask = op_context.params->shrink_axis_mask
+ << (4 - op_context.dims);
+ TF_LITE_ENSURE_EQ(context, starts.size(), 4);
+ auto op_params = ::tflite::strided_slice::BuildStridedSliceParams(
+ begin_mask, end_mask, shrink_axis_mask, starts, stops, strides);
#define TF_LITE_STRIDED_SLICE(kernel_type, data_type) \
- kernel_type::StridedSlice(GetTensorData<data_type>(op_context.input), \
- GetTensorDims(op_context.input), begin_mask, \
- end_mask, starts, stops, strides, \
- GetTensorData<data_type>(op_context.output), \
- GetTensorDims(op_context.output))
+ kernel_type::StridedSlice(op_params, GetTensorShape(op_context.input), \
+ GetTensorData<data_type>(op_context.input), \
+ GetTensorShape(op_context.output), \
+ GetTensorData<data_type>(op_context.output))
switch (op_context.input->type) {
case kTfLiteFloat32:
@@ -235,8 +236,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
break;
default:
context->ReportError(context,
- "Type is currently not supported "
- "by StridedSlice.");
+ "Type %d is currently not supported "
+ "by StridedSlice.",
+ op_context.input->type);
return kTfLiteError;
}
#undef TF_LITE_STRIDED_SLICE
diff --git a/tensorflow/contrib/lite/kernels/strided_slice_test.cc b/tensorflow/contrib/lite/kernels/strided_slice_test.cc
index cc39179bc7..c5d4f9affb 100644
--- a/tensorflow/contrib/lite/kernels/strided_slice_test.cc
+++ b/tensorflow/contrib/lite/kernels/strided_slice_test.cc
@@ -21,7 +21,6 @@ limitations under the License.
namespace tflite {
namespace {
-using ::int32;
using ::testing::ElementsAreArray;
template <typename input_type = float,
@@ -50,14 +49,14 @@ class StridedSliceOpModel : public SingleOpModel {
void SetInput(std::initializer_list<input_type> data) {
PopulateTensor<input_type>(input_, data);
}
- void SetBegin(std::initializer_list<int32> data) {
- PopulateTensor<int32>(begin_, data);
+ void SetBegin(std::initializer_list<int32_t> data) {
+ PopulateTensor<int32_t>(begin_, data);
}
- void SetEnd(std::initializer_list<int32> data) {
- PopulateTensor<int32>(end_, data);
+ void SetEnd(std::initializer_list<int32_t> data) {
+ PopulateTensor<int32_t>(end_, data);
}
- void SetStrides(std::initializer_list<int32> data) {
- PopulateTensor<int32>(strides_, data);
+ void SetStrides(std::initializer_list<int32_t> data) {
+ PopulateTensor<int32_t>(strides_, data);
}
std::vector<input_type> GetOutput() {
@@ -384,6 +383,45 @@ TEST(StridedSliceOpTest, In1D_ShrinkAxisMask1) {
EXPECT_THAT(m.GetOutput(), ElementsAreArray({2}));
}
+TEST(StridedSliceOpTest, In1D_ShrinkAxisMask1_NegativeSlice) {
+ // This is equivalent to tf.range(4)[-1].
+ StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 1);
+ m.SetInput({0, 1, 2, 3});
+ m.SetBegin({-1});
+ m.SetEnd({0});
+ m.SetStrides({1});
+
+ m.Invoke();
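+ // Shrinking the only axis produces a scalar, i.e. an empty output shape.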
+ EXPECT_TRUE(m.GetOutputShape().empty());
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({3}));
+}
+
+TEST(StridedSliceOpTest, In2D_ShrinkAxis3_NegativeSlice) {
+ // This is equivalent to tf.range(4)[:, tf.newaxis][-2, -1].
+ StridedSliceOpModel<> m({4, 1}, {2}, {2}, {2}, 0, 0, 0, 0, 3);
+ m.SetInput({0, 1, 2, 3});
+ m.SetBegin({-2, -1});
+ m.SetEnd({-1, 0});
+ m.SetStrides({1, 1});
+
+ m.Invoke();
+ EXPECT_TRUE(m.GetOutputShape().empty());
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({2}));
+}
+
+TEST(StridedSliceOpTest, In2D_ShrinkAxis2_BeginEndAxis1_NegativeSlice) {
+ // This is equivalent to tf.range(4)[:, tf.newaxis][:, -1].
+ StridedSliceOpModel<> m({4, 1}, {2}, {2}, {2}, 1, 1, 0, 0, 2);
+ m.SetInput({0, 1, 2, 3});
+ m.SetBegin({0, -1});
+ m.SetEnd({0, 0});
+ m.SetStrides({1, 1});
+
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4}));
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, 1, 2, 3}));
+}
+
TEST(StridedSliceOpTest, In1D_BeginMaskShrinkAxisMask1) {
StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 1, 0, 0, 0, 1);
m.SetInput({1, 2, 3, 4});
@@ -395,17 +433,6 @@ TEST(StridedSliceOpTest, In1D_BeginMaskShrinkAxisMask1) {
EXPECT_THAT(m.GetOutput(), ElementsAreArray({1}));
}
-TEST(StridedSliceOpTest, In1D_NegativeBeginNegativeStrideShrinkAxisMask1) {
- StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 1);
- m.SetInput({1, 2, 3, 4});
- m.SetBegin({-2});
- m.SetEnd({-3});
- m.SetStrides({-1});
- m.Invoke();
- EXPECT_TRUE(m.GetOutputShape().empty());
- EXPECT_THAT(m.GetOutput(), ElementsAreArray({3}));
-}
-
TEST(StridedSliceOpTest, In2D_ShrinkAxisMask1) {
StridedSliceOpModel<> m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 1);
m.SetInput({1, 2, 3, 4, 5, 6});
@@ -538,7 +565,7 @@ TEST(StridedSliceOpTest, RunTwice) {
}
TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis1Uint8) {
- StridedSliceOpModel<uint8, TensorType_UINT8> m({2, 3, 2}, {3}, {3}, {3}, 0, 0,
+ StridedSliceOpModel<uint8_t, TensorType_UINT8> m({2, 3, 2}, {3}, {3}, {3}, 0, 0,
0, 0, 1);
m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
m.SetBegin({0, 0, 0});
diff --git a/tensorflow/contrib/lite/kernels/sub.cc b/tensorflow/contrib/lite/kernels/sub.cc
index 9531ecba98..1be0c83f17 100644
--- a/tensorflow/contrib/lite/kernels/sub.cc
+++ b/tensorflow/contrib/lite/kernels/sub.cc
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
@@ -78,29 +78,47 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
template <KernelType kernel_type>
-void EvalFloat(TfLiteContext* context, TfLiteNode* node,
- TfLiteSubParams* params, const OpData* data,
- const TfLiteTensor* input1, const TfLiteTensor* input2,
- TfLiteTensor* output) {
- float output_activation_min, output_activation_max;
- CalculateActivationRangeFloat(params->activation, &output_activation_min,
- &output_activation_max);
-#define TF_LITE_SUB(type, opname) \
- type::opname(GetTensorData<float>(input1), GetTensorDims(input1), \
- GetTensorData<float>(input2), GetTensorDims(input2), \
- output_activation_min, output_activation_max, \
- GetTensorData<float>(output), GetTensorDims(output))
- if (kernel_type == kReference) {
- if (data->requires_broadcast) {
- TF_LITE_SUB(reference_ops, BroadcastSub);
+void EvalSub(TfLiteContext* context, TfLiteNode* node, TfLiteSubParams* params,
+ const OpData* data, const TfLiteTensor* input1,
+ const TfLiteTensor* input2, TfLiteTensor* output) {
+#define TF_LITE_SUB(type, opname, data_type) \
+ data_type output_activation_min, output_activation_max; \
+ CalculateActivationRange(params->activation, &output_activation_min, \
+ &output_activation_max); \
+ tflite::ArithmeticParams op_params; \
+ SetActivationParams(output_activation_min, output_activation_max, \
+ &op_params); \
+ type::opname(op_params, GetTensorShape(input1), \
+ GetTensorData<data_type>(input1), GetTensorShape(input2), \
+ GetTensorData<data_type>(input2), GetTensorShape(output), \
+ GetTensorData<data_type>(output))
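+ // The activation range is computed inside the macro so it uses the numeric
+ // limits of the instantiated data_type.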
+ if (output->type == kTfLiteInt32) {
+ if (kernel_type == kReference) {
+ if (data->requires_broadcast) {
+ TF_LITE_SUB(reference_ops, BroadcastSub4DSlow, int32_t);
+ } else {
+ TF_LITE_SUB(reference_ops, SubWithActivation, int32_t);
+ }
} else {
- TF_LITE_SUB(reference_ops, Sub);
+ if (data->requires_broadcast) {
+ TF_LITE_SUB(optimized_ops, BroadcastSub4DSlow, int32_t);
+ } else {
+ TF_LITE_SUB(optimized_ops, SubWithActivation, int32_t);
+ }
}
- } else {
- if (data->requires_broadcast) {
- TF_LITE_SUB(optimized_ops, BroadcastSub);
+ } else if (output->type == kTfLiteFloat32) {
+ if (kernel_type == kReference) {
+ if (data->requires_broadcast) {
+ TF_LITE_SUB(reference_ops, BroadcastSub4DSlow, float);
+ } else {
+ TF_LITE_SUB(reference_ops, SubWithActivation, float);
+ }
} else {
- TF_LITE_SUB(optimized_ops, Sub);
+ if (data->requires_broadcast) {
+ TF_LITE_SUB(optimized_ops, BroadcastSub4DSlow, float);
+ } else {
+ TF_LITE_SUB(optimized_ops, SubWithActivation, float);
+ }
}
}
#undef TF_LITE_SUB
@@ -126,35 +144,45 @@ void EvalQuantized(TfLiteContext* context, TfLiteNode* node,
int32 input1_multiplier;
int input1_shift;
- QuantizeMultiplierSmallerThanOne(real_input1_multiplier, &input1_multiplier,
- &input1_shift);
+ QuantizeMultiplierSmallerThanOneExp(real_input1_multiplier,
+ &input1_multiplier, &input1_shift);
int32 input2_multiplier;
int input2_shift;
- QuantizeMultiplierSmallerThanOne(real_input2_multiplier, &input2_multiplier,
- &input2_shift);
+ QuantizeMultiplierSmallerThanOneExp(real_input2_multiplier,
+ &input2_multiplier, &input2_shift);
int32 output_multiplier;
int output_shift;
- QuantizeMultiplierSmallerThanOne(real_output_multiplier, &output_multiplier,
- &output_shift);
+ QuantizeMultiplierSmallerThanOneExp(real_output_multiplier,
+ &output_multiplier, &output_shift);
int32 output_activation_min, output_activation_max;
CalculateActivationRangeUint8(params->activation, output,
&output_activation_min, &output_activation_max);
-#define TF_LITE_SUB(type, opname) \
- type::opname(left_shift, GetTensorData<uint8_t>(input1), \
- GetTensorDims(input1), input1_offset, input1_multiplier, \
- input1_shift, GetTensorData<uint8_t>(input2), \
- GetTensorDims(input2), input2_offset, input2_multiplier, \
- input2_shift, output_offset, output_multiplier, output_shift, \
- output_activation_min, output_activation_max, \
- GetTensorData<uint8_t>(output), GetTensorDims(output));
+#define TF_LITE_SUB(type, opname) \
+ tflite::ArithmeticParams op_params; \
+ op_params.left_shift = left_shift; \
+ op_params.input1_offset = input1_offset; \
+ op_params.input1_multiplier = input1_multiplier; \
+ op_params.input1_shift = input1_shift; \
+ op_params.input2_offset = input2_offset; \
+ op_params.input2_multiplier = input2_multiplier; \
+ op_params.input2_shift = input2_shift; \
+ op_params.output_offset = output_offset; \
+ op_params.output_multiplier = output_multiplier; \
+ op_params.output_shift = output_shift; \
+ SetActivationParams(output_activation_min, output_activation_max, \
+ &op_params); \
+ type::opname(op_params, GetTensorShape(input1), \
+ GetTensorData<uint8_t>(input1), GetTensorShape(input2), \
+ GetTensorData<uint8_t>(input2), GetTensorShape(output), \
+ GetTensorData<uint8_t>(output))
// The quantized version of Sub doesn't support activations, so we
// always use BroadcastSub.
if (kernel_type == kReference) {
- TF_LITE_SUB(reference_ops, BroadcastSub);
+ TF_LITE_SUB(reference_ops, BroadcastSub4DSlow);
} else {
- TF_LITE_SUB(optimized_ops, BroadcastSub);
+ TF_LITE_SUB(optimized_ops, BroadcastSub4DSlow);
}
#undef TF_LITE_SUB
}
@@ -168,14 +196,16 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
- if (output->type == kTfLiteFloat32) {
- EvalFloat<kernel_type>(context, node, params, data, input1, input2, output);
+ if (output->type == kTfLiteFloat32 || output->type == kTfLiteInt32) {
+ EvalSub<kernel_type>(context, node, params, data, input1, input2, output);
} else if (output->type == kTfLiteUInt8) {
EvalQuantized<kernel_type>(context, node, params, data, input1, input2,
output);
} else {
- context->ReportError(context,
- "Inputs and outputs not all float|uint8 types.");
+ context->ReportError(
+ context,
+ "output type %d is not supported, requires float|uint8|int32 types.",
+ output->type);
return kTfLiteError;
}
diff --git a/tensorflow/contrib/lite/kernels/sub_test.cc b/tensorflow/contrib/lite/kernels/sub_test.cc
index ff07aeec49..5978c574d3 100644
--- a/tensorflow/contrib/lite/kernels/sub_test.cc
+++ b/tensorflow/contrib/lite/kernels/sub_test.cc
@@ -52,6 +52,13 @@ class FloatSubOpModel : public BaseSubOpModel {
std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
};
+class IntegerSubOpModel : public BaseSubOpModel {
+ public:
+ using BaseSubOpModel::BaseSubOpModel;
+
+ std::vector<int32_t> GetOutput() { return ExtractVector<int32_t>(output_); }
+};
+
class QuantizedSubOpModel : public BaseSubOpModel {
public:
using BaseSubOpModel::BaseSubOpModel;
@@ -125,6 +132,57 @@ TEST(FloatSubOpModel, WithBroadcast) {
}
}
+TEST(IntegerSubOpModel, NoActivation) {
+ IntegerSubOpModel m({TensorType_INT32, {1, 2, 2, 1}},
+ {TensorType_INT32, {1, 2, 2, 1}}, {TensorType_INT32, {}},
+ ActivationFunctionType_NONE);
+ m.PopulateTensor<int32_t>(m.input1(), {-20, 2, 7, 8});
+ m.PopulateTensor<int32_t>(m.input2(), {1, 2, 3, 5});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({-21, 0, 4, 3}));
+}
+
+TEST(IntegerSubOpModel, ActivationRELU_N1_TO_1) {
+ IntegerSubOpModel m({TensorType_INT32, {1, 2, 2, 1}},
+ {TensorType_INT32, {1, 2, 2, 1}}, {TensorType_INT32, {}},
+ ActivationFunctionType_RELU_N1_TO_1);
+ m.PopulateTensor<int32_t>(m.input1(), {-20, 2, 7, 8});
+ m.PopulateTensor<int32_t>(m.input2(), {1, 2, 3, 5});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({-1, 0, 1, 1}));
+}
+
+TEST(IntegerSubOpModel, VariousInputShapes) {
+ std::vector<std::initializer_list<int>> test_shapes = {
+ {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}};
+ for (int i = 0; i < test_shapes.size(); ++i) {
+ IntegerSubOpModel m({TensorType_INT32, test_shapes[i]},
+ {TensorType_INT32, test_shapes[i]},
+ {TensorType_INT32, {}}, ActivationFunctionType_NONE);
+ m.PopulateTensor<int32_t>(m.input1(), {-20, 2, 7, 8, 11, 20});
+ m.PopulateTensor<int32_t>(m.input2(), {1, 2, 3, 5, 11, 1});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({-21, 0, 4, 3, 0, 19}))
+ << "With shape number " << i;
+ }
+}
+
+TEST(IntegerSubOpModel, WithBroadcast) {
+ std::vector<std::initializer_list<int>> test_shapes = {
+ {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}};
+ for (int i = 0; i < test_shapes.size(); ++i) {
+ IntegerSubOpModel m({TensorType_INT32, test_shapes[i]},
+ {TensorType_INT32, {}}, // always a scalar
+ {TensorType_INT32, {}}, ActivationFunctionType_NONE);
+ m.PopulateTensor<int32_t>(m.input1(), {-20, 2, 7, 8, 11, 20});
+ m.PopulateTensor<int32_t>(m.input2(), {1});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(),
+ ElementsAreArray(ArrayFloatNear({-21, 1, 6, 7, 10, 19})))
+ << "With shape number " << i;
+ }
+}
+
TEST(QuantizedSubOpModel, QuantizedTestsNoActivation) {
float kQuantizedTolerance = GetTolerance(-1.0, 1.0);
std::vector<std::initializer_list<float>> inputs1 = {
diff --git a/tensorflow/contrib/lite/kernels/svdf.cc b/tensorflow/contrib/lite/kernels/svdf.cc
index 308860c299..9903fd5c35 100644
--- a/tensorflow/contrib/lite/kernels/svdf.cc
+++ b/tensorflow/contrib/lite/kernels/svdf.cc
@@ -12,7 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include <unistd.h>
+
+// SVDF op that compresses a fully connected op via low-rank matrix
+// factorization. See https://research.google.com/pubs/archive/43813.pdf for
+// details.
#include <cassert>
#include <cmath>
#include <cstdio>
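As a rough illustration of the low-rank factorization mentioned in the comment above, the sketch below (a hypothetical float-only, single-batch helper, not the kernel itself) shows how a dense layer over the last memory_size frames splits into a per-frame feature projection followed by per-filter temporal filters, with rank consecutive filters summed into each output unit.

// Hypothetical single-batch SVDF step. Assumed row-major layouts:
// weights_feature[num_filters][input_size], weights_time[num_filters]
// [memory_size], state[num_filters][memory_size].
#include <vector>

std::vector<float> SvdfStep(const std::vector<float>& input, int input_size,
                            const std::vector<float>& weights_feature,
                            const std::vector<float>& weights_time,
                            std::vector<float>& state, int num_filters,
                            int memory_size, int rank) {
  const int num_units = num_filters / rank;
  // Feature projection: conv1d(inputs, weights_feature) over time, realized
  // here by shifting each filter's memory left and appending the new value.
  for (int f = 0; f < num_filters; ++f) {
    float activation = 0.f;
    for (int i = 0; i < input_size; ++i) {
      activation += weights_feature[f * input_size + i] * input[i];
    }
    float* mem = &state[f * memory_size];
    for (int t = 0; t + 1 < memory_size; ++t) mem[t] = mem[t + 1];
    mem[memory_size - 1] = activation;
  }
  // Temporal filters, then reduce `rank` filters into each output unit.
  std::vector<float> output(num_units, 0.f);
  for (int f = 0; f < num_filters; ++f) {
    float dot = 0.f;
    for (int t = 0; t < memory_size; ++t) {
      dot += weights_time[f * memory_size + t] * state[f * memory_size + t];
    }
    output[f / rank] += dot;
  }
  return output;
}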
@@ -20,8 +23,8 @@ limitations under the License.
#include <iostream>
#include <limits>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/activation_functor.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
@@ -32,37 +35,113 @@ namespace ops {
namespace builtin {
namespace svdf {
+namespace {
+
+struct OpData {
+ int scratch_tensor_index;
+ bool float_weights_time_initialized;
+
+ int activation_state_tensor_index;
+};
+
+static inline void ApplyTimeWeightsBiasAndActivation(
+ int batch_size, int memory_size, int num_filters, int num_units, int rank,
+ const TfLiteTensor* weights_time, const TfLiteTensor* bias,
+ TfLiteFusedActivation activation, TfLiteTensor* activation_state,
+ TfLiteTensor* scratch, TfLiteTensor* output) {
+ // Compute matmul(state, weights_time).
+ // The right most column is used to save temporary output (with the size of
+ // num_filters). This is achieved by starting at activation_state->data.f,
+ // and having the stride equal to memory_size.
+ for (int b = 0; b < batch_size; ++b) {
+ float* state_ptr_batch =
+ activation_state->data.f + b * memory_size * num_filters;
+ float* scratch_ptr_batch = scratch->data.f + b * num_filters;
+ tensor_utils::BatchVectorBatchVectorDotProduct(
+ weights_time->data.f, state_ptr_batch, memory_size, num_filters,
+ scratch_ptr_batch, /*result_stride=*/1);
+ }
+
+ // Initialize output with bias if provided.
+ if (bias) {
+ tensor_utils::VectorBatchVectorAssign(bias->data.f, num_units, batch_size,
+ output->data.f);
+ } else {
+ tensor_utils::ZeroVector(output->data.f, batch_size * num_units);
+ }
+
+ // Reduction sum.
+ for (int b = 0; b < batch_size; ++b) {
+ float* output_ptr_batch = output->data.f + b * num_units;
+ float* scratch_ptr_batch = scratch->data.f + b * num_filters;
+ tensor_utils::ReductionSumVector(scratch_ptr_batch, output_ptr_batch,
+ num_units, rank);
+ }
+
+ // Apply activation.
+ for (int b = 0; b < batch_size; ++b) {
+ float* output_ptr_batch = output->data.f + b * num_units;
+ tensor_utils::ApplyActivationToVector(output_ptr_batch, num_units,
+ activation, output_ptr_batch);
+ }
+
+ // Left shift the activation_state to make room for next cycle's activation.
+ // TODO(alanchiao): explore collapsing this into a single loop.
+ for (int b = 0; b < batch_size; ++b) {
+ float* state_ptr_batch =
+ activation_state->data.f + b * memory_size * num_filters;
+ for (int f = 0; f < num_filters; ++f) {
+ tensor_utils::VectorShiftLeft(state_ptr_batch, memory_size,
+ /*shift_value=*/0.0f);
+ state_ptr_batch += memory_size;
+ }
+ }
+}
+
+} // namespace
+
+// Input tensors.
constexpr int kInputTensor = 0;
constexpr int kWeightsFeatureTensor = 1;
constexpr int kWeightsTimeTensor = 2;
constexpr int kBiasTensor = 3;
-constexpr int kStateTensor = 0;
-constexpr int kOutputTensor = 1;
+// This is a variable tensor, and will be modified by this op.
+constexpr int kInputActivationStateTensor = 4;
+
+// Output tensor.
+constexpr int kOutputTensor = 0;
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
- auto* scratch_tensor_index = new int;
- context->AddTensors(context, 1, scratch_tensor_index);
- return scratch_tensor_index;
+ auto* op_data = new OpData();
+ op_data->float_weights_time_initialized = false;
+ context->AddTensors(context, /*tensors_to_add=*/4,
+ &op_data->scratch_tensor_index);
+ return op_data;
}
void Free(TfLiteContext* context, void* buffer) {
- delete reinterpret_cast<int*>(buffer);
+ delete reinterpret_cast<OpData*>(buffer);
}
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
- auto* params = reinterpret_cast<TfLiteSVDFParams*>(node->builtin_data);
- int* scratch_tensor_index = reinterpret_cast<int*>(node->user_data);
+ const auto* params = reinterpret_cast<TfLiteSVDFParams*>(node->builtin_data);
+ OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
+ int scratch_tensor_index = op_data->scratch_tensor_index;
// Check we have all the inputs and outputs we need.
- TF_LITE_ENSURE_EQ(context, node->inputs->size, 4);
- TF_LITE_ENSURE_EQ(context, node->outputs->size, 2);
+ TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
+ TF_LITE_ENSURE_EQ(context, node->inputs->size, 5);
+ op_data->activation_state_tensor_index =
+ node->inputs->data[kInputActivationStateTensor];
- TfLiteTensor* input = &context->tensors[node->inputs->data[kInputTensor]];
+ const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* weights_feature =
GetInput(context, node, kWeightsFeatureTensor);
const TfLiteTensor* weights_time =
GetInput(context, node, kWeightsTimeTensor);
+ TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
+
// Check all the parameters of tensor match within themselves and match the
// input configuration.
const int rank = params->rank;
@@ -79,22 +158,15 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ASSERT_EQ(bias->dims->data[0], num_units);
}
- TfLiteTensor* state = GetOutput(context, node, kStateTensor);
+ TfLiteTensor* activation_state =
+ &context->tensors[op_data->activation_state_tensor_index];
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
- // Resize state.
- // For each batch, the state is a 2-D tensor: memory_size * num_filters
- // The left most column is used to save current cycle activation.
- // The right most column is used to save temporary output which will be
- // reduced to num_units outputs.
- TfLiteIntArray* state_size_array = TfLiteIntArrayCreate(2);
- state_size_array->data[0] = batch_size;
- state_size_array->data[1] = memory_size * num_filters;
- TF_LITE_ENSURE_OK(context,
- context->ResizeTensor(context, state, state_size_array));
-
- // Mark state as a persistent tensor.
- state->allocation_type = kTfLiteArenaRwPersistent;
+ // Check the shape of input state tensors.
+ TF_LITE_ENSURE_EQ(context, NumDimensions(activation_state), 2);
+ TF_LITE_ENSURE_EQ(context, SizeOfDimension(activation_state, 0), batch_size);
+ TF_LITE_ENSURE_EQ(context, SizeOfDimension(activation_state, 1),
+ memory_size * num_filters);
// Resize output.
TfLiteIntArray* output_size_array = TfLiteIntArrayCreate(2);
@@ -103,10 +175,18 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_OK(context,
context->ResizeTensor(context, output, output_size_array));
+ // The weights are of consistent type, so it suffices to check one.
+ const bool is_hybrid_op =
+ (input->type == kTfLiteFloat32 && weights_feature->type == kTfLiteUInt8);
+
// Resize scratch.
TfLiteIntArrayFree(node->temporaries);
- node->temporaries = TfLiteIntArrayCreate(1);
- node->temporaries->data[0] = *scratch_tensor_index;
+ if (is_hybrid_op) {
+ node->temporaries = TfLiteIntArrayCreate(4);
+ } else {
+ node->temporaries = TfLiteIntArrayCreate(1);
+ }
+ node->temporaries->data[0] = scratch_tensor_index;
TfLiteIntArray* scratch_size_array = TfLiteIntArrayCreate(2);
scratch_size_array->data[0] = batch_size;
@@ -118,24 +198,56 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_tensor,
scratch_size_array));
- return kTfLiteOk;
-}
-
-TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
- auto* params = reinterpret_cast<TfLiteSVDFParams*>(node->builtin_data);
-
- const TfLiteTensor* input = GetInput(context, node, kInputTensor);
- const TfLiteTensor* weights_feature =
- GetInput(context, node, kWeightsFeatureTensor);
- const TfLiteTensor* weights_time =
- GetInput(context, node, kWeightsTimeTensor);
+ if (is_hybrid_op) {
+ // Tell interpreter to allocate temporary tensors to store quantized values
+ // of input tensors.
+ node->temporaries->data[1] = scratch_tensor_index + 1;
+ TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/1);
+ input_quantized->type = kTfLiteUInt8;
+ input_quantized->allocation_type = kTfLiteArenaRw;
+ if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) {
+ TfLiteIntArray* input_quantized_size = TfLiteIntArrayCopy(input->dims);
+ TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized,
+ input_quantized_size));
+ }
- TfLiteTensor* state = GetOutput(context, node, kStateTensor);
- TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
- TfLiteTensor* scratch = GetTemporary(context, node, /*index=*/0);
+ // Tell interpreter to allocate temporary tensors to store scaling factors.
+ node->temporaries->data[2] = scratch_tensor_index + 2;
+ TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/2);
+ scaling_factors->type = kTfLiteFloat32;
+ scaling_factors->allocation_type = kTfLiteArenaRw;
+ TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1);
+ scaling_factors_size->data[0] = batch_size;
+ if (!TfLiteIntArrayEqual(scaling_factors->dims, scaling_factors_size)) {
+ TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors,
+ scaling_factors_size));
+ }
- const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor);
+ // Used to store dequantized weights_time matrix for hybrid computation of
+ // matmul(activation_state, weights_time), which occurs in floating point.
+ node->temporaries->data[3] = scratch_tensor_index + 3;
+ TfLiteTensor* float_weights_time = GetTemporary(context, node, /*index=*/3);
+ float_weights_time->type = kTfLiteFloat32;
+ // Persistent so that we can compute the dequantized weights only once.
+ float_weights_time->allocation_type = kTfLiteArenaRwPersistent;
+ if (!TfLiteIntArrayEqual(float_weights_time->dims, weights_time->dims)) {
+ TfLiteIntArray* float_weights_time_size =
+ TfLiteIntArrayCopy(weights_time->dims);
+ TF_LITE_ENSURE_OK(context,
+ context->ResizeTensor(context, float_weights_time,
+ float_weights_time_size));
+ }
+ }
+ return kTfLiteOk;
+}
+TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node,
+ const TfLiteTensor* input,
+ const TfLiteTensor* weights_feature,
+ const TfLiteTensor* weights_time,
+ const TfLiteTensor* bias, const TfLiteSVDFParams* params,
+ TfLiteTensor* scratch, TfLiteTensor* state,
+ TfLiteTensor* output) {
const int rank = params->rank;
const int batch_size = input->dims->data[0];
const int input_size = input->dims->data[1];
@@ -144,69 +256,156 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const int memory_size = weights_time->dims->data[1];
// Clear the activation (state left most column).
- // TODO(ghodrat): Add a test which initialize state with invalid values in
- // left most column and make sure it passes.
- for (int b = 0; b < batch_size; b++) {
+ // TODO(ghodrat): Add a test which initializes activation_state with invalid
+ // values in the left most column and make sure it passes.
+ for (int b = 0; b < batch_size; ++b) {
float* state_ptr_batch = state->data.f + b * memory_size * num_filters;
- for (int c = 0; c < num_filters; c++) {
+ for (int c = 0; c < num_filters; ++c) {
float* state_ptr = state_ptr_batch + c * memory_size;
- state_ptr[memory_size - 1] = 0.0;
+ state_ptr[memory_size - 1] = 0.0f;
}
}
// Compute conv1d(inputs, weights_feature).
- // The state left most column is used to save current cycle activation. This
+ // The state right most column is used to save current cycle activation. This
// is achieved by starting at state->data.f[memory_size - 1] and having the
// stride equal to memory_size.
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
weights_feature->data.f, num_filters, input_size, input->data.f,
batch_size, &state->data.f[memory_size - 1], memory_size);
- // Compute matmul(state, weights_time).
- // The right most column is used to save temporary output (with the size of
- // num_filters). This is achieved by starting at state->data.f and having the
- // stride equal to memory_size.
- for (int b = 0; b < batch_size; b++) {
+ ApplyTimeWeightsBiasAndActivation(batch_size, memory_size, num_filters,
+ num_units, rank, weights_time, bias,
+ params->activation, state, scratch, output);
+ return kTfLiteOk;
+}
+
+TfLiteStatus EvalHybrid(
+ TfLiteContext* context, TfLiteNode* node, const TfLiteTensor* input,
+ const TfLiteTensor* weights_feature, const TfLiteTensor* weights_time,
+ const TfLiteTensor* bias, const TfLiteSVDFParams* params,
+ TfLiteTensor* scratch, TfLiteTensor* scaling_factors,
+ TfLiteTensor* input_quantized, TfLiteTensor* state, TfLiteTensor* output) {
+ const int rank = params->rank;
+ const int batch_size = input->dims->data[0];
+ const int input_size = input->dims->data[1];
+ const int num_filters = weights_feature->dims->data[0];
+ const int num_units = num_filters / rank;
+ const int memory_size = weights_time->dims->data[1];
+
+ // Initialize the pointer to input.
+ const float* input_ptr_batch = input->data.f;
+
+ // Initialize the pointer to storage for quantized values and
+ // scaling factors.
+ int8_t* quantized_input_ptr_batch =
+ reinterpret_cast<int8_t*>(input_quantized->data.uint8);
+
+ float* scaling_factors_ptr = scaling_factors->data.f;
+
+ // Other initializations.
+ const int8_t* weights_feature_ptr =
+ reinterpret_cast<int8_t*>(weights_feature->data.uint8);
+ const float weights_feature_scale = weights_feature->params.scale;
+
+ // Clear the activation (state left most column).
+ // TODO(ghodrat): Add a test which initializes state with invalid values in
+ // the left most column and make sure it passes.
+ for (int b = 0; b < batch_size; ++b) {
float* state_ptr_batch = state->data.f + b * memory_size * num_filters;
- float* scratch_ptr_batch = scratch->data.f + b * num_filters;
- tensor_utils::BatchVectorBatchVectorDotProduct(
- weights_time->data.f, state_ptr_batch, memory_size, num_filters,
- scratch_ptr_batch, /*result_stride=*/1);
+ for (int c = 0; c < num_filters; ++c) {
+ float* state_ptr = state_ptr_batch + c * memory_size;
+ state_ptr[memory_size - 1] = 0.0;
+ }
}
- // Initialize output with bias if provided.
- if (bias) {
- tensor_utils::VectorBatchVectorAssign(bias->data.f, num_units, batch_size,
- output->data.f);
- } else {
- tensor_utils::ZeroVector(output->data.f, batch_size * num_units);
- }
+ if (!tensor_utils::IsZeroVector(input_ptr_batch, batch_size * input_size)) {
+ // Quantize input from float to int8.
+ float unused_min, unused_max;
+ for (int b = 0; b < batch_size; ++b) {
+ const int offset = b * input_size;
+ tensor_utils::SymmetricQuantizeFloats(
+ input_ptr_batch + offset, input_size,
+ quantized_input_ptr_batch + offset, &unused_min, &unused_max,
+ &scaling_factors_ptr[b]);
+ scaling_factors_ptr[b] *= weights_feature_scale;
+ }
- // Reduction sum
- for (int b = 0; b < batch_size; b++) {
- float* output_ptr_batch = output->data.f + b * num_units;
- float* scratch_ptr_batch = scratch->data.f + b * num_filters;
- tensor_utils::ReductionSumVector(scratch_ptr_batch, output_ptr_batch,
- num_units, rank);
+ // Compute conv1d(inputs, weights_feature).
+ // The rightmost column of state is used to save the current cycle
+ // activation.
+ // This is achieved by starting at state->data.f[memory_size - 1]
+ // and having the stride equal to memory_size.
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ weights_feature_ptr, num_filters, input_size, quantized_input_ptr_batch,
+ scaling_factors_ptr, batch_size, &state->data.f[memory_size - 1],
+ memory_size);
}
- // Apply activation.
- for (int b = 0; b < batch_size; b++) {
- float* output_ptr_batch = output->data.f + b * num_units;
- tensor_utils::ApplyActivationToVector(output_ptr_batch, num_units,
- params->activation, output_ptr_batch);
- }
+ // TODO(alanchiao): can optimize hybrid case ~5% by unrolling loop in applying
+ // time weights so that the inner loop multiplies eight elements at a time.
+ ApplyTimeWeightsBiasAndActivation(batch_size, memory_size, num_filters,
+ num_units, rank, weights_time, bias,
+ params->activation, state, scratch, output);
+ return kTfLiteOk;
+}
- // Right shift the state.
- for (int b = 0; b < batch_size; b++) {
- float* state_ptr_batch = state->data.f + b * memory_size * num_filters;
- for (int f = 0; f < num_filters; f++) {
- tensor_utils::VectorShiftLeft(state_ptr_batch, memory_size,
- /*shift_value=*/0.0);
- state_ptr_batch += memory_size;
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ auto* params = reinterpret_cast<TfLiteSVDFParams*>(node->builtin_data);
+ OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
+
+ const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ const TfLiteTensor* weights_feature =
+ GetInput(context, node, kWeightsFeatureTensor);
+ const TfLiteTensor* weights_time =
+ GetInput(context, node, kWeightsTimeTensor);
+ const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor);
+
+ TfLiteTensor* scratch = GetTemporary(context, node, /*index=*/0);
+
+ TfLiteTensor* activation_state =
+ &context->tensors[op_data->activation_state_tensor_index];
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+ switch (weights_feature->type) {
+ case kTfLiteFloat32: {
+ return EvalFloat(context, node, input, weights_feature, weights_time,
+ bias, params, scratch, activation_state, output);
+ break;
}
+ case kTfLiteUInt8: {
+ TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/1);
+ TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/2);
+ TfLiteTensor* float_weights_time =
+ GetTemporary(context, node, /*index=*/3);
+
+ // Dequantize weights time.
+ // TODO(alanchiao): this dequantization initialization only needs to
+ // happen once per model and should theoretically be placed in either Init
+ // or Prepare. However, TFLite doesn't allocate float_weights_time until
+ // the Eval function.
+ // TODO(alanchiao): refactor logic out into dequantize function.
+ if (!op_data->float_weights_time_initialized) {
+ const float dequantization_scale = weights_time->params.scale;
+ const int8_t* weights_time_ptr =
+ reinterpret_cast<int8_t*>(weights_time->data.uint8);
+ for (int i = 0; i < NumElements(float_weights_time); ++i) {
+ float_weights_time->data.f[i] =
+ weights_time_ptr[i] * dequantization_scale;
+ }
+ op_data->float_weights_time_initialized = true;
+ }
+ return EvalHybrid(context, node, input, weights_feature,
+ float_weights_time, bias, params, scratch,
+ scaling_factors, input_quantized, activation_state,
+ output);
+ break;
+ }
+ default:
+ context->ReportError(context, "Type %d not currently supported.",
+ weights_feature->type);
+ return kTfLiteError;
}
- return kTfLiteOk;
}
} // namespace svdf
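The hybrid path above quantizes each input row to int8 on the fly and folds the uint8 weight scale into the per-batch scaling factor, so the feature matmul runs on integers and only the accumulator is rescaled back to float. A minimal stand-alone sketch of that idea, with SymmetricQuantizeRow and HybridDot as assumed names rather than the tensor_utils API:

#include <algorithm>
#include <cmath>
#include <cstdint>

// Assumed helper: quantize one row symmetrically to int8 with its own scale.
void SymmetricQuantizeRow(const float* values, int size, int8_t* quantized,
                          float* scale) {
  float max_abs = 0.f;
  for (int i = 0; i < size; ++i) {
    max_abs = std::max(max_abs, std::fabs(values[i]));
  }
  if (max_abs == 0.f) {
    std::fill(quantized, quantized + size, 0);
    *scale = 1.f;
    return;
  }
  *scale = max_abs / 127.f;  // int8 values stay within [-127, 127].
  for (int i = 0; i < size; ++i) {
    quantized[i] = static_cast<int8_t>(std::round(values[i] / *scale));
  }
}

// Assumed helper: integer dot product rescaled by the input and weight scales,
// mirroring how scaling_factors_ptr[b] *= weights_feature_scale folds the two
// scales together before MatrixBatchVectorMultiplyAccumulate.
float HybridDot(const int8_t* quantized_input, const int8_t* quantized_weights,
                int size, float input_scale, float weight_scale) {
  int32_t acc = 0;
  for (int i = 0; i < size; ++i) {
    acc += static_cast<int32_t>(quantized_input[i]) * quantized_weights[i];
  }
  return acc * input_scale * weight_scale;
}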
diff --git a/tensorflow/contrib/lite/kernels/svdf_test.cc b/tensorflow/contrib/lite/kernels/svdf_test.cc
index 0f166dc69b..6d60dc63f4 100644
--- a/tensorflow/contrib/lite/kernels/svdf_test.cc
+++ b/tensorflow/contrib/lite/kernels/svdf_test.cc
@@ -126,28 +126,35 @@ static float svdf_golden_output_rank_2[] = {
};
// Derived class of SingleOpModel, which is used to test SVDF TFLite op.
-class SVDFOpModel : public SingleOpModel {
+class BaseSVDFOpModel : public SingleOpModel {
public:
- SVDFOpModel(int batches, int units, int input_size, int memory_size, int rank)
+ BaseSVDFOpModel(int batches, int units, int input_size, int memory_size,
+ int rank,
+ TensorType weights_feature_type = TensorType_FLOAT32,
+ TensorType weights_time_type = TensorType_FLOAT32)
: batches_(batches),
units_(units),
input_size_(input_size),
memory_size_(memory_size),
rank_(rank) {
input_ = AddInput(TensorType_FLOAT32);
- weights_feature_ = AddInput(TensorType_FLOAT32);
- weights_time_ = AddInput(TensorType_FLOAT32);
+ weights_feature_ = AddInput(weights_feature_type);
+ weights_time_ = AddInput(weights_time_type);
bias_ = AddNullInput();
- state_ = AddOutput(TensorType_FLOAT32);
+ const int num_filters = units * rank;
+ activation_state_ = AddInput(
+ TensorData{TensorType_FLOAT32, {batches, memory_size * num_filters}},
+ /*is_variable=*/true);
output_ = AddOutput(TensorType_FLOAT32);
SetBuiltinOp(
BuiltinOperator_SVDF, BuiltinOptions_SVDFOptions,
CreateSVDFOptions(builder_, rank, ActivationFunctionType_NONE).Union());
BuildInterpreter({
- {batches_, input_size_}, // Input tensor
- {units_ * rank, input_size_}, // weights_feature tensor
- {units_ * rank, memory_size_}, // weights_time tensor
- {units_} // bias tensor
+ {batches_, input_size_}, // input tensor
+ {units_ * rank, input_size_}, // weights_feature tensor
+ {units_ * rank, memory_size_}, // weights_time tensor
+ {units_}, // bias tensor
+ {batches, memory_size * num_filters} // activation_state tensor
});
}
@@ -166,15 +173,6 @@ class SVDFOpModel : public SingleOpModel {
PopulateTensor(input_, offset, begin, end);
}
- // Resets the state of SVDF op by filling it with 0's.
- void ResetState() {
- const int zero_buffer_size = rank_ * units_ * batches_ * memory_size_;
- std::unique_ptr<float[]> zero_buffer(new float[zero_buffer_size]);
- memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float));
- PopulateTensor(state_, 0, zero_buffer.get(),
- zero_buffer.get() + zero_buffer_size);
- }
-
// Extracts the output tensor from the SVDF op.
std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
@@ -182,12 +180,12 @@ class SVDFOpModel : public SingleOpModel {
int num_units() { return units_; }
int num_batches() { return batches_; }
- private:
+ protected:
int input_;
int weights_feature_;
int weights_time_;
int bias_;
- int state_;
+ int activation_state_;
int output_;
int batches_;
@@ -197,7 +195,61 @@ class SVDFOpModel : public SingleOpModel {
int rank_;
};
-TEST(SVDFOpTest, BlackBoxTestRank1) {
+class SVDFOpModel : public BaseSVDFOpModel {
+ public:
+ using BaseSVDFOpModel::BaseSVDFOpModel;
+};
+
+class HybridSVDFOpModel : public BaseSVDFOpModel {
+ public:
+ HybridSVDFOpModel(int batches, int units, int input_size, int memory_size,
+ int rank)
+ : BaseSVDFOpModel(batches, units, input_size, memory_size, rank,
+ TensorType_UINT8, TensorType_UINT8) {}
+
+ void SetWeightsFeature(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(weights_feature_, f);
+ }
+
+ void SetWeightsTime(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(weights_time_, f);
+ }
+};
+
+class SVDFOpTest : public ::testing::Test {
+ protected:
+ void VerifyGoldens(float golden_input[], float golden_output[],
+ int golden_size, BaseSVDFOpModel* svdf,
+ float tolerance = 1e-5) {
+ const int svdf_num_batches = svdf->num_batches();
+ const int svdf_input_size = svdf->input_size();
+ const int svdf_num_units = svdf->num_units();
+ const int input_sequence_size =
+ golden_size / sizeof(float) / (svdf_input_size * svdf_num_batches);
+ // Going over each input batch, setting the input tensor, invoking the SVDF
+ // op and checking the output with the expected golden values.
+ for (int i = 0; i < input_sequence_size; i++) {
+ float* batch_start =
+ golden_input + i * svdf_input_size * svdf_num_batches;
+ float* batch_end = batch_start + svdf_input_size * svdf_num_batches;
+ svdf->SetInput(0, batch_start, batch_end);
+
+ svdf->Invoke();
+
+ const float* golden_start =
+ golden_output + i * svdf_num_units * svdf_num_batches;
+ const float* golden_end =
+ golden_start + svdf_num_units * svdf_num_batches;
+ std::vector<float> expected;
+ expected.insert(expected.end(), golden_start, golden_end);
+
+ EXPECT_THAT(svdf->GetOutput(),
+ ElementsAreArray(ArrayFloatNear(expected, tolerance)));
+ }
+ }
+};
+
+TEST_F(SVDFOpTest, BlackBoxTestRank1) {
SVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3,
/*memory_size=*/10, /*rank=*/1);
svdf.SetWeightsFeature({-0.31930989, -0.36118156, 0.0079667, 0.37613347,
@@ -217,32 +269,11 @@ TEST(SVDFOpTest, BlackBoxTestRank1) {
-0.10781813, 0.27201805, 0.14324132, -0.23681851, -0.27115166,
-0.01580888, -0.14943552, 0.15465137, 0.09784451, -0.0337657});
- svdf.ResetState();
- const int svdf_num_batches = svdf.num_batches();
- const int svdf_input_size = svdf.input_size();
- const int svdf_num_units = svdf.num_units();
- const int input_sequence_size =
- sizeof(svdf_input) / sizeof(float) / (svdf_input_size * svdf_num_batches);
- // Going over each input batch, setting the input tensor, invoking the SVDF op
- // and checking the output with the expected golden values.
- for (int i = 0; i < input_sequence_size; i++) {
- float* batch_start = svdf_input + i * svdf_input_size * svdf_num_batches;
- float* batch_end = batch_start + svdf_input_size * svdf_num_batches;
- svdf.SetInput(0, batch_start, batch_end);
-
- svdf.Invoke();
-
- float* golden_start =
- svdf_golden_output_rank_1 + i * svdf_num_units * svdf_num_batches;
- float* golden_end = golden_start + svdf_num_units * svdf_num_batches;
- std::vector<float> expected;
- expected.insert(expected.end(), golden_start, golden_end);
-
- EXPECT_THAT(svdf.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
- }
+ VerifyGoldens(svdf_input, svdf_golden_output_rank_1, sizeof(svdf_input),
+ &svdf);
}
-TEST(SVDFOpTest, BlackBoxTestRank2) {
+TEST_F(SVDFOpTest, BlackBoxTestRank2) {
SVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3,
/*memory_size=*/10, /*rank=*/2);
svdf.SetWeightsFeature({-0.31930989, 0.0079667, 0.39296314, 0.37613347,
@@ -277,29 +308,73 @@ TEST(SVDFOpTest, BlackBoxTestRank2) {
0.27179423, -0.04710215, 0.31069002, 0.22672787, 0.09580326,
0.08682203, 0.1258215, 0.1851041, 0.29228821, 0.12366763});
- svdf.ResetState();
- const int svdf_num_batches = svdf.num_batches();
- const int svdf_input_size = svdf.input_size();
- const int svdf_num_units = svdf.num_units();
- const int input_sequence_size =
- sizeof(svdf_input) / sizeof(float) / (svdf_input_size * svdf_num_batches);
- // Going over each input batch, setting the input tensor, invoking the SVDF op
- // and checking the output with the expected golden values.
- for (int i = 0; i < input_sequence_size; i++) {
- float* batch_start = svdf_input + i * svdf_input_size * svdf_num_batches;
- float* batch_end = batch_start + svdf_input_size * svdf_num_batches;
- svdf.SetInput(0, batch_start, batch_end);
-
- svdf.Invoke();
-
- float* golden_start =
- svdf_golden_output_rank_2 + i * svdf_num_units * svdf_num_batches;
- float* golden_end = golden_start + svdf_num_units * svdf_num_batches;
- std::vector<float> expected;
- expected.insert(expected.end(), golden_start, golden_end);
-
- EXPECT_THAT(svdf.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
- }
+ VerifyGoldens(svdf_input, svdf_golden_output_rank_2, sizeof(svdf_input),
+ &svdf);
+}
+
+TEST_F(SVDFOpTest, BlackBoxTestHybridRank1) {
+ HybridSVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3,
+ /*memory_size=*/10, /*rank=*/1);
+ svdf.SetWeightsFeature({-0.31930989, -0.36118156, 0.0079667, 0.37613347,
+ 0.22197971, 0.12416199, 0.27901134, 0.27557442,
+ 0.3905206, -0.36137494, -0.06634006, -0.10640851});
+
+ svdf.SetWeightsTime(
+ {-0.31930989, 0.37613347, 0.27901134, -0.36137494, -0.36118156,
+ 0.22197971, 0.27557442, -0.06634006, 0.0079667, 0.12416199,
+
+ 0.3905206, -0.10640851, -0.0976817, 0.15294972, 0.39635518,
+ -0.02702999, 0.39296314, 0.15785322, 0.21931258, 0.31053296,
+
+ -0.36916667, 0.38031587, -0.21580373, 0.27072677, 0.23622236,
+ 0.34936687, 0.18174365, 0.35907319, -0.17493086, 0.324846,
+
+ -0.10781813, 0.27201805, 0.14324132, -0.23681851, -0.27115166,
+ -0.01580888, -0.14943552, 0.15465137, 0.09784451, -0.0337657});
+
+ VerifyGoldens(svdf_input, svdf_golden_output_rank_1, sizeof(svdf_input),
+ &svdf,
+ /*tolerance=*/0.002945);
+}
+
+TEST_F(SVDFOpTest, BlackBoxTestHybridRank2) {
+ HybridSVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3,
+ /*memory_size=*/10, /*rank=*/2);
+ svdf.SetWeightsFeature({-0.31930989, 0.0079667, 0.39296314, 0.37613347,
+ 0.12416199, 0.15785322, 0.27901134, 0.3905206,
+ 0.21931258, -0.36137494, -0.10640851, 0.31053296,
+ -0.36118156, -0.0976817, -0.36916667, 0.22197971,
+ 0.15294972, 0.38031587, 0.27557442, 0.39635518,
+ -0.21580373, -0.06634006, -0.02702999, 0.27072677});
+
+ svdf.SetWeightsTime(
+ {-0.31930989, 0.37613347, 0.27901134, -0.36137494, -0.36118156,
+ 0.22197971, 0.27557442, -0.06634006, 0.0079667, 0.12416199,
+
+ 0.3905206, -0.10640851, -0.0976817, 0.15294972, 0.39635518,
+ -0.02702999, 0.39296314, 0.15785322, 0.21931258, 0.31053296,
+
+ -0.36916667, 0.38031587, -0.21580373, 0.27072677, 0.23622236,
+ 0.34936687, 0.18174365, 0.35907319, -0.17493086, 0.324846,
+
+ -0.10781813, 0.27201805, 0.14324132, -0.23681851, -0.27115166,
+ -0.01580888, -0.14943552, 0.15465137, 0.09784451, -0.0337657,
+
+ -0.14884081, 0.19931212, -0.36002168, 0.34663299, -0.11405486,
+ 0.12672701, 0.39463779, -0.07886535, -0.06384811, 0.08249187,
+
+ -0.26816407, -0.19905911, 0.29211238, 0.31264046, -0.28664589,
+ 0.05698794, 0.11613581, 0.14078894, 0.02187902, -0.21781836,
+
+ -0.15567942, 0.08693647, -0.38256618, 0.36580828, -0.22922277,
+ -0.0226903, 0.12878349, -0.28122205, -0.10850525, -0.11955214,
+
+ 0.27179423, -0.04710215, 0.31069002, 0.22672787, 0.09580326,
+ 0.08682203, 0.1258215, 0.1851041, 0.29228821, 0.12366763});
+
+ VerifyGoldens(svdf_input, svdf_golden_output_rank_2, sizeof(svdf_input),
+ &svdf,
+ /*tolerance=*/0.00625109);
}
} // namespace
diff --git a/tensorflow/contrib/lite/kernels/test_util.cc b/tensorflow/contrib/lite/kernels/test_util.cc
index 1a01ee0936..05a7c23ba1 100644
--- a/tensorflow/contrib/lite/kernels/test_util.cc
+++ b/tensorflow/contrib/lite/kernels/test_util.cc
@@ -32,8 +32,8 @@ std::vector<Matcher<float>> ArrayFloatNear(const std::vector<float>& values,
return matchers;
}
-int SingleOpModel::AddInput(const TensorData& t) {
- int id = AddTensor<float>(t, {});
+int SingleOpModel::AddInput(const TensorData& t, bool is_variable) {
+ int id = AddTensor<float>(t, {}, is_variable);
inputs_.push_back(id);
return id;
}
@@ -74,8 +74,8 @@ void SingleOpModel::SetCustomOp(
CustomOptionsFormat_FLEXBUFFERS));
}
-void SingleOpModel::BuildInterpreter(
- std::vector<std::vector<int>> input_shapes) {
+void SingleOpModel::BuildInterpreter(std::vector<std::vector<int>> input_shapes,
+ bool allow_fp32_relax_to_fp16) {
auto opcodes = builder_.CreateVector(opcodes_);
auto operators = builder_.CreateVector(operators_);
auto tensors = builder_.CreateVector(tensors_);
@@ -112,8 +112,17 @@ void SingleOpModel::BuildInterpreter(
if (shape.empty()) continue;
CHECK(interpreter_->ResizeInputTensor(input_idx, shape) == kTfLiteOk);
}
+
+ interpreter_->SetAllowFp16PrecisionForFp32(allow_fp32_relax_to_fp16);
+
+ // Apply the registered delegate callback, if any.
+ if (apply_delegate_fn_) {
+ apply_delegate_fn_(interpreter_.get());
+ }
+
CHECK(interpreter_->AllocateTensors() == kTfLiteOk)
<< "Cannot allocate tensors";
+ interpreter_->ResetVariableTensors();
}
void SingleOpModel::Invoke() { CHECK(interpreter_->Invoke() == kTfLiteOk); }
diff --git a/tensorflow/contrib/lite/kernels/test_util.h b/tensorflow/contrib/lite/kernels/test_util.h
index 55edc97d19..84deb0e0e8 100644
--- a/tensorflow/contrib/lite/kernels/test_util.h
+++ b/tensorflow/contrib/lite/kernels/test_util.h
@@ -114,13 +114,22 @@ class SingleOpModel {
SingleOpModel() {}
~SingleOpModel() {}
+ // Set a function callback that is run right after the graph is prepared and
+ // that allows applying external delegates. This is useful for testing
+ // other runtimes like NN API or GPU.
+ void SetApplyDelegate(std::function<void(Interpreter*)> apply_delegate_fn) {
+ apply_delegate_fn_ = apply_delegate_fn;
+ }
+
// Copying or assignment is disallowed to simplify ownership semantics.
SingleOpModel(const SingleOpModel&) = delete;
SingleOpModel& operator=(const SingleOpModel&) = delete;
// Add a TensorType input tensor and return its index.
- int AddInput(TensorType type) { return AddInput(TensorData{type}); }
- int AddInput(const TensorData& t);
+ int AddInput(TensorType type, bool is_variable = false) {
+ return AddInput(TensorData{type}, is_variable);
+ }
+ int AddInput(const TensorData& t, bool is_variable = false);
// Templated version of AddConstInput().
template <typename T>
@@ -139,20 +148,18 @@ class SingleOpModel {
int AddOutput(const TensorData& t);
template <typename T>
- void QuantizeAndPopulate(int index, std::initializer_list<float> data) {
+ void QuantizeAndPopulate(int index, const std::vector<float>& data) {
TfLiteTensor* t = interpreter_->tensor(index);
auto q = Quantize<T>(data, t->params.scale, t->params.zero_point);
PopulateTensor(index, 0, q.data(), q.data() + q.size());
}
- void SymmetricQuantizeAndPopulate(int index,
- std::initializer_list<float> data) {
+ void SymmetricQuantizeAndPopulate(int index, const std::vector<float>& data) {
TfLiteTensor* t = interpreter_->tensor(index);
- std::vector<float> values(data);
- const int length = values.size();
+ const int length = data.size();
std::vector<int8_t> q(length);
float min, max, scaling_factor;
- tensor_utils::SymmetricQuantizeFloats(values.data(), length, q.data(), &min,
+ tensor_utils::SymmetricQuantizeFloats(data.data(), length, q.data(), &min,
&max, &scaling_factor);
// Update quantization params.
t->params.scale = scaling_factor;
@@ -175,7 +182,8 @@ class SingleOpModel {
// Build the interpreter for this model. Also, resize and allocate all
// tensors given the shapes of the inputs.
- void BuildInterpreter(std::vector<std::vector<int>> input_shapes);
+ void BuildInterpreter(std::vector<std::vector<int>> input_shapes,
+ bool allow_fp32_relax_to_fp16 = false);
void Invoke();
@@ -189,8 +197,22 @@ class SingleOpModel {
}
// Populate the tensor given its index.
+ // TODO(b/110696148) clean up and merge with vector-taking variant below.
+ template <typename T>
+ void PopulateTensor(int index, const std::initializer_list<T>& data) {
+ T* v = interpreter_->typed_tensor<T>(index);
+ CHECK(v) << "No tensor with index '" << index << "'.";
+ for (T f : data) {
+ *v = f;
+ ++v;
+ }
+ }
+
+ // Populate the tensor given its index.
+ // TODO(b/110696148) clean up and merge with initializer_list-taking variant
+ // above.
template <typename T>
- void PopulateTensor(int index, std::initializer_list<T> data) {
+ void PopulateTensor(int index, const std::vector<T>& data) {
T* v = interpreter_->typed_tensor<T>(index);
CHECK(v) << "No tensor with index '" << index << "'.";
for (T f : data) {
@@ -253,7 +275,8 @@ class SingleOpModel {
}
template <typename T>
- int AddTensor(TensorData t, std::initializer_list<T> data) {
+ int AddTensor(TensorData t, std::initializer_list<T> data,
+ bool is_variable = false) {
int id = tensors_.size();
// This is slightly different depending on whether we are adding a
@@ -270,6 +293,9 @@ class SingleOpModel {
} else if (t.type == TensorType_INT32) {
std::tie(t.scale, t.zero_point) =
QuantizationParams<int32_t>(t.min, t.max);
+ } else if (t.type == TensorType_INT16) {
+ std::tie(t.scale, t.zero_point) =
+ QuantizationParams<int16_t>(t.min, t.max);
} else {
LOG(FATAL) << "No support for the requested quantized type";
}
@@ -302,7 +328,7 @@ class SingleOpModel {
tensors_.push_back(CreateTensor(builder_,
builder_.CreateVector<int>(t.shape), t.type,
/*buffer=*/buffer_id,
- /*name=*/0, q_params));
+ /*name=*/0, q_params, is_variable));
tensor_data_[id] = t;
@@ -317,6 +343,9 @@ class SingleOpModel {
std::vector<flatbuffers::Offset<Operator>> operators_;
std::vector<flatbuffers::Offset<Buffer>> buffers_;
std::map<string, std::function<TfLiteRegistration*()>> custom_registrations_;
+ // A function pointer that gets called after the interpreter is created but
+ // before evaluation happens. This is useful for applying a delegate.
+ std::function<void(Interpreter*)> apply_delegate_fn_;
};
// Base class for single op unit tests.
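Because apply_delegate_fn_ is invoked inside BuildInterpreter, the callback has to be registered before the interpreter is built. A hypothetical test model wiring this up (the ADD op choice and the ModifyGraphWithDelegate call are illustrative assumptions, not part of this change):

// Hypothetical test model that applies a delegate right after the graph is
// prepared. Names other than the SingleOpModel members are assumptions.
class DelegatedAddOpModel : public SingleOpModel {
 public:
  explicit DelegatedAddOpModel(TfLiteDelegate* delegate) {
    input1_ = AddInput(TensorType_FLOAT32);
    input2_ = AddInput(TensorType_FLOAT32);
    output_ = AddOutput(TensorType_FLOAT32);
    SetBuiltinOp(BuiltinOperator_ADD, BuiltinOptions_AddOptions,
                 CreateAddOptions(builder_, ActivationFunctionType_NONE)
                     .Union());
    // Register before BuildInterpreter so the callback fires during setup.
    SetApplyDelegate([delegate](Interpreter* interpreter) {
      interpreter->ModifyGraphWithDelegate(delegate);
    });
    BuildInterpreter({{1, 4}, {1, 4}});
  }

 private:
  int input1_;
  int input2_;
  int output_;
};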
diff --git a/tensorflow/contrib/lite/kernels/test_util_test.cc b/tensorflow/contrib/lite/kernels/test_util_test.cc
index 1e10e89061..2365803472 100644
--- a/tensorflow/contrib/lite/kernels/test_util_test.cc
+++ b/tensorflow/contrib/lite/kernels/test_util_test.cc
@@ -22,22 +22,22 @@ using ::testing::ElementsAreArray;
TEST(TestUtilTest, QuantizeVector) {
std::vector<float> data = {-1.0, -0.5, 0.0, 0.5, 1.0, 1000.0};
- auto q_data = Quantize<uint8>(data, /*scale=*/1.0, /*zero_point=*/0);
- std::vector<uint8> expected = {0, 0, 0, 1, 1, 255};
+ auto q_data = Quantize<uint8_t>(data, /*scale=*/1.0, /*zero_point=*/0);
+ std::vector<uint8_t> expected = {0, 0, 0, 1, 1, 255};
EXPECT_THAT(q_data, ElementsAreArray(expected));
}
TEST(TestUtilTest, QuantizeVectorScalingDown) {
std::vector<float> data = {-1.0, -0.5, 0.0, 0.5, 1.0, 1000.0};
- auto q_data = Quantize<uint8>(data, /*scale=*/10.0, /*zero_point=*/0);
- std::vector<uint8> expected = {0, 0, 0, 0, 0, 100};
+ auto q_data = Quantize<uint8_t>(data, /*scale=*/10.0, /*zero_point=*/0);
+ std::vector<uint8_t> expected = {0, 0, 0, 0, 0, 100};
EXPECT_THAT(q_data, ElementsAreArray(expected));
}
TEST(TestUtilTest, QuantizeVectorScalingUp) {
std::vector<float> data = {-1.0, -0.5, 0.0, 0.5, 1.0, 1000.0};
- auto q_data = Quantize<uint8>(data, /*scale=*/0.1, /*zero_point=*/0);
- std::vector<uint8> expected = {0, 0, 0, 5, 10, 255};
+ auto q_data = Quantize<uint8_t>(data, /*scale=*/0.1, /*zero_point=*/0);
+ std::vector<uint8_t> expected = {0, 0, 0, 5, 10, 255};
EXPECT_THAT(q_data, ElementsAreArray(expected));
}
diff --git a/tensorflow/contrib/lite/kernels/tile.cc b/tensorflow/contrib/lite/kernels/tile.cc
new file mode 100644
index 0000000000..49421eb870
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/tile.cc
@@ -0,0 +1,195 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <string.h>
+#include <vector>
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace tile {
+
+constexpr int kInputTensor = 0;
+constexpr int kInputMultipliers = 1;
+constexpr int kOutputTensor = 0;
+
+namespace {
+template <typename T>
+TfLiteIntArray* MultiplyShapeDims(const TfLiteIntArray& shape,
+ const TfLiteTensor* multipliers,
+ int num_dimensions) {
+ const T* multipliers_v = GetTensorData<T>(multipliers);
+
+ TfLiteIntArray* output_shape = TfLiteIntArrayCreate(num_dimensions);
+ for (int i = 0; i < num_dimensions; ++i) {
+ output_shape->data[i] = shape.data[i] * multipliers_v[i];
+ }
+ return output_shape;
+}
+
+TfLiteStatus ResizeOutput(TfLiteContext* context, TfLiteNode* node) {
+ const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+ const TfLiteTensor* multipliers = GetInput(context, node, kInputMultipliers);
+
+ const int num_dimensions = NumDimensions(input);
+ const int num_multipliers = NumElements(multipliers);
+ TF_LITE_ENSURE_EQ(context, num_dimensions, num_multipliers);
+ switch (multipliers->type) {
+ case kTfLiteInt32:
+ return context->ResizeTensor(
+ context, output,
+ MultiplyShapeDims<int32_t>(*input->dims, multipliers,
+ num_dimensions));
+ case kTfLiteInt64:
+ return context->ResizeTensor(
+ context, output,
+ MultiplyShapeDims<int64_t>(*input->dims, multipliers,
+ num_dimensions));
+ default:
+ context->ReportError(context, "Tile not supported multiply tensor type.");
+ return kTfLiteError;
+ }
+}
+
+template <typename T>
+void CopyMultipleTimes(const T* in_data, int32_t in_size, int32_t multiplier,
+ T* out_data) {
+ for (int i = 0; i < multiplier; ++i) {
+ const T* in_end = in_data + in_size;
+ T* new_out_data = std::copy(in_data, in_end, out_data);
+ in_data = out_data;
+ out_data = new_out_data;
+ }
+}
+
+template <typename T, typename M>
+std::pair<int, int> TileOneDimension(const TfLiteIntArray& in_dimensions,
+ const T* in_data, const M* multipliers,
+ T* out_data, int dimension) {
+ const int dimension_size = in_dimensions.data[dimension];
+ if (dimension == in_dimensions.size - 1) {
+ CopyMultipleTimes(in_data, dimension_size, multipliers[dimension],
+ out_data);
+ return std::make_pair(
+ dimension_size,
+ dimension_size * static_cast<int>(multipliers[dimension]));
+ }
+ int total_stride_size = 0, total_tiled_stride_size = 0;
+ const T* copy_from_data = in_data;
+ T* copy_to_data = out_data;
+ for (int i = 0; i < dimension_size; ++i) {
+ int stride_size = 0, tiled_stride_size = 0;
+ std::tie(stride_size, tiled_stride_size) =
+ TileOneDimension(in_dimensions, copy_from_data, multipliers,
+ copy_to_data, dimension + 1);
+ copy_from_data += stride_size;
+ copy_to_data += tiled_stride_size;
+ total_stride_size += stride_size;
+ total_tiled_stride_size += tiled_stride_size;
+ }
+ CopyMultipleTimes(out_data, total_tiled_stride_size,
+ multipliers[dimension] - 1,
+ out_data + total_tiled_stride_size);
+ return std::make_pair(total_stride_size,
+ total_tiled_stride_size * multipliers[dimension]);
+}
+
+template <typename T>
+void Tile(const TfLiteIntArray& in_dimensions, const TfLiteTensor* in_data,
+ const TfLiteTensor* multipliers, TfLiteTensor* out_data) {
+ // Tile recursively, from the outermost dimension down to the innermost.
+ switch (multipliers->type) {
+ case kTfLiteInt32:
+ TileOneDimension(in_dimensions, GetTensorData<T>(in_data),
+ GetTensorData<int32_t>(multipliers),
+ GetTensorData<T>(out_data), 0);
+ break;
+ case kTfLiteInt64:
+ TileOneDimension(in_dimensions, GetTensorData<T>(in_data),
+ GetTensorData<int64_t>(multipliers),
+ GetTensorData<T>(out_data), 0);
+ break;
+ default:
+ break;
+ }
+}
+} // namespace
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+
+ const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+ TF_LITE_ENSURE_EQ(context, input->type, output->type);
+
+ const TfLiteTensor* multipliers = GetInput(context, node, kInputMultipliers);
+ // Only int32 and int64 multiplier types are supported.
+ TF_LITE_ENSURE_MSG(context,
+ (multipliers->type == kTfLiteInt32) ||
+ (multipliers->type == kTfLiteInt64),
+ "Tile only supports int32 and int64 mutlipliers.");
+
+ if (IsConstantTensor(multipliers)) {
+ TF_LITE_ENSURE_OK(context, ResizeOutput(context, node));
+ } else {
+ SetTensorToDynamic(output);
+ }
+ return kTfLiteOk;
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+ const TfLiteTensor* multipliers = GetInput(context, node, kInputMultipliers);
+
+ if (IsDynamicTensor(output)) {
+ TF_LITE_ENSURE_OK(context, ResizeOutput(context, node));
+ }
+
+ switch (output->type) {
+ case kTfLiteFloat32:
+ Tile<float>(*(input->dims), input, multipliers, output);
+ break;
+ case kTfLiteUInt8:
+ Tile<uint8_t>(*(input->dims), input, multipliers, output);
+ break;
+ case kTfLiteInt32:
+ Tile<int32_t>(*(input->dims), input, multipliers, output);
+ break;
+ case kTfLiteInt64:
+ Tile<int64_t>(*(input->dims), input, multipliers, output);
+ break;
+ default:
+ context->ReportError(context, "Type is currently not supported by Tile.");
+ return kTfLiteError;
+ }
+ return kTfLiteOk;
+}
+
+} // namespace tile
+TfLiteRegistration* Register_TILE() {
+ static TfLiteRegistration r = {nullptr, nullptr, tile::Prepare, tile::Eval};
+ return &r;
+}
+} // namespace builtin
+} // namespace ops
+} // namespace tflite
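The kernel above tiles by recursively copying ever-larger blocks; an equivalent way to read it is as pure index mapping, where every output coordinate folds back into the input range with a modulo. The sketch below is a hypothetical, int-only illustration of that equivalence, not the kernel's implementation:

// Hypothetical index-mapping Tile over a row-major int buffer.
#include <vector>

std::vector<int> TileRowMajor(const std::vector<int>& in,
                              const std::vector<int>& in_dims,
                              const std::vector<int>& multipliers) {
  const int rank = static_cast<int>(in_dims.size());
  std::vector<int> out_dims(rank);
  for (int d = 0; d < rank; ++d) out_dims[d] = in_dims[d] * multipliers[d];

  // Row-major strides for the input and output layouts.
  std::vector<int> in_strides(rank, 1);
  std::vector<int> out_strides(rank, 1);
  for (int d = rank - 2; d >= 0; --d) {
    in_strides[d] = in_strides[d + 1] * in_dims[d + 1];
    out_strides[d] = out_strides[d + 1] * out_dims[d + 1];
  }

  int out_size = 1;
  for (int d : out_dims) out_size *= d;

  std::vector<int> out(out_size);
  for (int flat = 0; flat < out_size; ++flat) {
    int remaining = flat;
    int in_index = 0;
    // Decompose the output index dimension by dimension; a modulo folds each
    // coordinate back into the input range.
    for (int d = 0; d < rank; ++d) {
      const int out_coord = remaining / out_strides[d];
      remaining %= out_strides[d];
      in_index += (out_coord % in_dims[d]) * in_strides[d];
    }
    out[flat] = in[in_index];
  }
  return out;
}

// Example: input {11, 12, 13, 21, 22, 23} with dims {2, 3} and multipliers
// {2, 1} yields dims {4, 3} with the rows repeated, matching the
// Float32Matrix and Int32Matrix tests below.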
diff --git a/tensorflow/contrib/lite/kernels/tile_test.cc b/tensorflow/contrib/lite/kernels/tile_test.cc
new file mode 100644
index 0000000000..e73ca7b750
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/tile_test.cc
@@ -0,0 +1,256 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAreArray;
+class TileOpModel : public SingleOpModel {
+ public:
+ TileOpModel(std::initializer_list<int> input_shape, TensorType input_type,
+ TensorType multiply_type) {
+ input_ = AddInput(input_type);
+ multipliers_ = AddInput(TensorType_INT32);
+ output_ = AddOutput(input_type);
+ SetBuiltinOp(BuiltinOperator_TILE, BuiltinOptions_TileOptions, 0);
+ BuildInterpreter({input_shape, {static_cast<int>(input_shape.size())}});
+ }
+
+ void SetInputFloat(std::initializer_list<float> data) {
+ PopulateTensor<float>(input_, data);
+ }
+
+ void SetInputUInt8(std::initializer_list<uint8_t> data) {
+ PopulateTensor<uint8_t>(input_, data);
+ }
+
+ void SetInputInt32(std::initializer_list<int32_t> data) {
+ PopulateTensor<int32_t>(input_, data);
+ }
+
+ void SetInputInt64(std::initializer_list<int64_t> data) {
+ PopulateTensor<int64_t>(input_, data);
+ }
+
+ void SetMultipliers(std::initializer_list<int32_t> data) {
+ PopulateTensor<int32_t>(multipliers_, data);
+ }
+
+ std::vector<float> GetOutputFloat() { return ExtractVector<float>(output_); }
+
+ std::vector<uint8_t> GetOutputUInt8() { return ExtractVector<uint8_t>(output_); }
+
+ std::vector<int32_t> GetOutputInt32() { return ExtractVector<int32_t>(output_); }
+
+ std::vector<int64_t> GetOutputInt64() {
+ return ExtractVector<int64_t>(output_);
+ }
+
+ std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
+
+ protected:
+ int input_;
+ int multipliers_;
+ int output_;
+};
+
+TEST(TileTest, Float32Vector) {
+ TileOpModel m({3}, TensorType_FLOAT32, TensorType_INT32);
+ m.SetInputFloat({1.f, 2.f, 3.f});
+ m.SetMultipliers({2});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputFloat(),
+ ElementsAreArray({1.f, 2.f, 3.f, 1.f, 2.f, 3.f}));
+}
+
+TEST(TileTest, Float32Matrix) {
+ TileOpModel m({2, 3}, TensorType_FLOAT32, TensorType_INT32);
+ m.SetInputFloat({
+ 11.f,
+ 12.f,
+ 13.f,
+ 21.f,
+ 22.f,
+ 23.f,
+ });
+ m.SetMultipliers({2, 1});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputFloat(), ElementsAreArray({
+ 11.f,
+ 12.f,
+ 13.f,
+ 21.f,
+ 22.f,
+ 23.f,
+ 11.f,
+ 12.f,
+ 13.f,
+ 21.f,
+ 22.f,
+ 23.f,
+ }));
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4, 3}));
+}
+
+TEST(TileTest, Float32HighDimension) {
+ TileOpModel m({1, 2, 3}, TensorType_FLOAT32, TensorType_INT32);
+ m.SetInputFloat({
+ 11.f,
+ 12.f,
+ 13.f,
+ 21.f,
+ 22.f,
+ 23.f,
+ });
+ m.SetMultipliers({2, 3, 1});
+ m.Invoke();
+ EXPECT_THAT(
+ m.GetOutputFloat(),
+ ElementsAreArray({11.f, 12.f, 13.f, 21.f, 22.f, 23.f, 11.f, 12.f, 13.f,
+ 21.f, 22.f, 23.f, 11.f, 12.f, 13.f, 21.f, 22.f, 23.f,
+ 11.f, 12.f, 13.f, 21.f, 22.f, 23.f, 11.f, 12.f, 13.f,
+ 21.f, 22.f, 23.f, 11.f, 12.f, 13.f, 21.f, 22.f, 23.f}));
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 6, 3}));
+}
+
+TEST(TileTest, Uint8Matrix) {
+ TileOpModel m({2, 3}, TensorType_UINT8, TensorType_INT32);
+ m.SetInputUInt8({
+ 11,
+ 12,
+ 13,
+ 21,
+ 22,
+ 23,
+ });
+ m.SetMultipliers({2, 1});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputUInt8(), ElementsAreArray({
+ 11,
+ 12,
+ 13,
+ 21,
+ 22,
+ 23,
+ 11,
+ 12,
+ 13,
+ 21,
+ 22,
+ 23,
+ }));
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4, 3}));
+}
+
+TEST(TileTest, Int32Matrix) {
+ TileOpModel m({2, 3}, TensorType_INT32, TensorType_INT32);
+ m.SetInputInt32({
+ 11,
+ 12,
+ 13,
+ 21,
+ 22,
+ 23,
+ });
+ m.SetMultipliers({2, 1});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputInt32(), ElementsAreArray({
+ 11,
+ 12,
+ 13,
+ 21,
+ 22,
+ 23,
+ 11,
+ 12,
+ 13,
+ 21,
+ 22,
+ 23,
+ }));
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4, 3}));
+}
+
+TEST(TileTest, Int64Matrix) {
+ TileOpModel m({2, 3}, TensorType_INT64, TensorType_INT32);
+ m.SetInputInt64({
+ 11,
+ 12,
+ 13,
+ 21,
+ 22,
+ 23,
+ });
+ m.SetMultipliers({2, 1});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputInt64(), ElementsAreArray({
+ 11,
+ 12,
+ 13,
+ 21,
+ 22,
+ 23,
+ 11,
+ 12,
+ 13,
+ 21,
+ 22,
+ 23,
+ }));
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4, 3}));
+}
+
+TEST(TileTest, Int64Matrix64Multipliers) {
+ TileOpModel m({2, 3}, TensorType_INT64, TensorType_INT64);
+ m.SetInputInt64({
+ 11,
+ 12,
+ 13,
+ 21,
+ 22,
+ 23,
+ });
+ m.SetMultipliers({2, 1});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputInt64(), ElementsAreArray({
+ 11,
+ 12,
+ 13,
+ 21,
+ 22,
+ 23,
+ 11,
+ 12,
+ 13,
+ 21,
+ 22,
+ 23,
+ }));
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4, 3}));
+}
+} // namespace
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ ::tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/kernels/topk_v2.cc b/tensorflow/contrib/lite/kernels/topk_v2.cc
index 0feb42b85b..6c38b6739e 100644
--- a/tensorflow/contrib/lite/kernels/topk_v2.cc
+++ b/tensorflow/contrib/lite/kernels/topk_v2.cc
@@ -14,8 +14,8 @@ limitations under the License.
==============================================================================*/
#include <algorithm>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"
@@ -56,11 +56,13 @@ TfLiteStatus ResizeOutput(TfLiteContext* context, TfLiteNode* node) {
output_values_shape->data[num_dimensions - 1] = k;
TfLiteTensor* output_indexes = GetOutput(context, node, kOutputIndexes);
TfLiteTensor* output_values = GetOutput(context, node, kOutputValues);
+ // Force output types.
+ output_indexes->type = kTfLiteInt32;
+ output_values->type = input->type;
auto resize_tensor = [context](TfLiteTensor* tensor, TfLiteIntArray* new_size,
TfLiteIntArray* delete_on_error) {
TfLiteStatus status = context->ResizeTensor(context, tensor, new_size);
if (status != kTfLiteOk) {
- TfLiteIntArrayFree(new_size);
if (delete_on_error != nullptr) {
TfLiteIntArrayFree(delete_on_error);
}
@@ -214,7 +216,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
output_values->data.i64);
break;
default:
- context->ReportError(context, "Type is currently not supported by TopK.");
+ context->ReportError(context,
+ "Type %d is currently not supported by TopK.",
+ output_values->type);
return kTfLiteError;
}
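Conceptually, the kernel picks the k largest entries of each row along the last dimension and emits both the values and their int32 indices (the output types forced above). A hypothetical per-row sketch using std::partial_sort; the real kernel's selection and tie-breaking strategy may differ:

// Hypothetical per-row top-k selection over a float row.
#include <algorithm>
#include <cstdint>
#include <numeric>
#include <vector>

void TopKRow(const float* row, int row_size, int k, int32_t* out_indexes,
             float* out_values) {
  std::vector<int32_t> idx(row_size);
  std::iota(idx.begin(), idx.end(), 0);
  // Order of equal values is an assumption; the real kernel may differ.
  std::partial_sort(idx.begin(), idx.begin() + k, idx.end(),
                    [row](int32_t a, int32_t b) { return row[a] > row[b]; });
  for (int i = 0; i < k; ++i) {
    out_indexes[i] = idx[i];
    out_values[i] = row[idx[i]];
  }
}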
diff --git a/tensorflow/contrib/lite/kernels/topk_v2_test.cc b/tensorflow/contrib/lite/kernels/topk_v2_test.cc
index 212f8acc76..16106fdafe 100644
--- a/tensorflow/contrib/lite/kernels/topk_v2_test.cc
+++ b/tensorflow/contrib/lite/kernels/topk_v2_test.cc
@@ -14,7 +14,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <gtest/gtest.h>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/kernels/test_util.h"
@@ -42,32 +42,32 @@ class TopKV2OpModel : public SingleOpModel {
PopulateTensor<float>(input_, data);
}
- void SetInputUInt8(std::initializer_list<uint8> data) {
- PopulateTensor<uint8>(input_, data);
+ void SetInputUInt8(std::initializer_list<uint8_t> data) {
+ PopulateTensor<uint8_t>(input_, data);
}
- void SetInputInt32(std::initializer_list<int32> data) {
- PopulateTensor<int32>(input_, data);
+ void SetInputInt32(std::initializer_list<int32_t> data) {
+ PopulateTensor<int32_t>(input_, data);
}
void SetInputInt64(std::initializer_list<int64_t> data) {
PopulateTensor<int64_t>(input_, data);
}
- std::vector<int32> GetIndexes() {
- return ExtractVector<int32>(output_indexes_);
+ std::vector<int32_t> GetIndexes() {
+ return ExtractVector<int32_t>(output_indexes_);
}
std::vector<float> GetValuesFloat() {
return ExtractVector<float>(output_values_);
}
- std::vector<uint8> GetValuesUInt8() {
- return ExtractVector<uint8>(output_values_);
+ std::vector<uint8_t> GetValuesUInt8() {
+ return ExtractVector<uint8_t>(output_values_);
}
- std::vector<int32> GetValuesInt32() {
- return ExtractVector<int32>(output_values_);
+ std::vector<int32_t> GetValuesInt32() {
+ return ExtractVector<int32_t>(output_values_);
}
std::vector<int64_t> GetValuesInt64() {
@@ -119,7 +119,7 @@ TEST(TopKV2OpTest, VectorFloat) {
EXPECT_THAT(m.GetValuesFloat(), ElementsAreArray(ArrayFloatNear({0.8, 0.2})));
}
-// Check that uint8 works.
+// Check that uint8_t works.
TEST(TopKV2OpTest, TypeUint8) {
TopKV2OpModel m({2, 3}, TensorType_UINT8, 2);
m.SetInputUInt8({1, 2, 3, 251, 250, 249});
@@ -128,7 +128,7 @@ TEST(TopKV2OpTest, TypeUint8) {
EXPECT_THAT(m.GetValuesUInt8(), ElementsAreArray({3, 2, 251, 250}));
}
-// Check that int32 works.
+// Check that int32_t works.
TEST(TopKV2OpTest, TypeInt32) {
TopKV2OpModel m({2, 3}, TensorType_INT32, 2);
m.SetInputInt32({1, 2, 3, 10251, 10250, 10249});
diff --git a/tensorflow/contrib/lite/kernels/transpose.cc b/tensorflow/contrib/lite/kernels/transpose.cc
index 8316a23c18..e42a30420b 100644
--- a/tensorflow/contrib/lite/kernels/transpose.cc
+++ b/tensorflow/contrib/lite/kernels/transpose.cc
@@ -14,8 +14,8 @@ limitations under the License.
==============================================================================*/
#include <string.h>
#include <vector>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
@@ -92,26 +92,19 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
}
- // Reverse the permuted axes and convert to 4D due to the way Dims are
- // constructed in GetTensorDims.
const int* perm_data = GetTensorData<int32_t>(op_context.perm);
const int size = op_context.perm->dims->data[0];
- const int kOutputDimensionNum = 4;
- int reversed_perm[kOutputDimensionNum];
-
- for (int output_k = 0, input_k = size - 1; output_k < size;
- ++output_k, --input_k) {
- reversed_perm[output_k] = size - perm_data[input_k] - 1;
- }
- for (int k = size; k < kOutputDimensionNum; ++k) {
- reversed_perm[k] = k;
+ TransposeParams params;
+ params.perm_count = size;
+ for (int i = 0; i < size; ++i) {
+ params.perm[i] = perm_data[i];
}
#define TF_LITE_TRANSPOSE(type, scalar) \
- type::Transpose(GetTensorData<scalar>(op_context.input), \
- GetTensorDims(op_context.input), \
- GetTensorData<scalar>(op_context.output), \
- GetTensorDims(op_context.output), reversed_perm)
+ type::Transpose(params, GetTensorShape(op_context.input), \
+ GetTensorData<scalar>(op_context.input), \
+ GetTensorShape(op_context.output), \
+ GetTensorData<scalar>(op_context.output))
switch (op_context.input->type) {
case kTfLiteFloat32:
@@ -136,7 +129,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
break;
default:
context->ReportError(context,
- "Type is currently not supported by Transpose.");
+ "Type %d is currently not supported by Transpose.",
+ op_context.input->type);
return kTfLiteError;
}
#undef TF_LITE_TRANSPOSE
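
For readers following the API migration in this file, here is a minimal standalone sketch of the new call convention — TransposeParams plus RuntimeShape in place of a reversed, 4-padded permutation and Dims<4> — using only the types and the Transpose signature visible in this diff (the harness itself is illustrative):

#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"

// Transpose a 2x3 row-major buffer into its 3x2 transpose.
void TransposeSketch() {
  const float input[6] = {1, 2, 3, 4, 5, 6};
  float output[6];

  tflite::RuntimeShape input_shape(2);
  input_shape.SetDim(0, 2);
  input_shape.SetDim(1, 3);

  const int perm[2] = {1, 0};  // Swap the two axes.
  tflite::RuntimeShape output_shape(2);
  tflite::TransposeParams params;
  params.perm_count = 2;
  for (int i = 0; i < 2; ++i) {
    output_shape.SetDim(i, input_shape.Dims(perm[i]));
    params.perm[i] = perm[i];
  }

  // output becomes {1, 4, 2, 5, 3, 6}, i.e. the 3x2 transpose of the input.
  tflite::reference_ops::Transpose<float>(params, input_shape, input,
                                          output_shape, output);
}
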
diff --git a/tensorflow/contrib/lite/kernels/transpose_conv.cc b/tensorflow/contrib/lite/kernels/transpose_conv.cc
index 3c99661029..1c4a5ee91d 100644
--- a/tensorflow/contrib/lite/kernels/transpose_conv.cc
+++ b/tensorflow/contrib/lite/kernels/transpose_conv.cc
@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include <unistd.h>
#include <cassert>
#include <cmath>
#include <cstdio>
@@ -20,8 +19,8 @@ limitations under the License.
#include <iostream>
#include <limits>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
@@ -70,7 +69,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4);
TF_LITE_ENSURE_EQ(context, NumDimensions(weights), 4);
- // Currenlty only supports float32.
+ // Currently only supports float32.
const TfLiteType data_type = input->type;
TF_LITE_ENSURE(context, data_type == kTfLiteFloat32);
TF_LITE_ENSURE_EQ(context, output->type, data_type);
@@ -79,7 +78,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Ensure that weights and inputs have the same channel dimension.
// Note: TOCO will reorder weights in the following format: OHWI.
TF_LITE_ENSURE_EQ(context, SizeOfDimension(input, 3),
- SizeOfDimension(weights, 0));
+ SizeOfDimension(weights, 3));
if (!IsConstantTensor(output_shape)) {
SetTensorToDynamic(output);
@@ -118,13 +117,26 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
// Currently only support float32.
switch (input->type) {
- case kTfLiteFloat32:
- optimized_ops::TransposeConv(
- GetTensorData<float>(input), GetTensorDims(input),
- GetTensorData<float>(weights), GetTensorDims(weights), stride_width,
- stride_height, padding_size.width, padding_size.height,
- GetTensorData<float>(output), GetTensorDims(output));
+ case kTfLiteFloat32: {
+ tflite::ConvParams op_params;
+ op_params.padding_type = PaddingType::kSame;
+ op_params.padding_values.width = padding_size.width;
+ op_params.padding_values.height = padding_size.height;
+ op_params.stride_width = stride_width;
+ op_params.stride_height = stride_height;
+
+ reference_ops::TransposeConv(
+ op_params, GetTensorShape(input), GetTensorData<float>(input),
+ GetTensorShape(weights), GetTensorData<float>(weights),
+ GetTensorShape(output), GetTensorData<float>(output),
+ // Last two args specify im2col which reference_ops ignores.
+ // (Note this does not lead to a performance regression, as the
+ // previous optimized version was just a copy of the reference code.)
+ // TODO(b/110208176): Allocate im2col tensors and switch to
+ // optimized_ops.
+ GetTensorShape(output), GetTensorData<float>(output));
break;
+ }
default:
context->ReportError(context, "Type %d, not currently supported.",
input->type);
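
The Prepare change above now checks the input's channel count against dimension 3 of the weights because, as the existing comment notes, TOCO emits transpose-conv filters in OHWI order. A small illustrative helper (hypothetical, not part of the kernel) spelling out the corresponding row-major layout:

// Flat offset of element (o, h, w, i) in an OHWI filter of shape
// [output_channels, height, width, input_channels]. The input-channel axis is
// innermost, which is why the channel check reads SizeOfDimension(weights, 3).
inline int OhwiFlatIndex(int o, int h, int w, int i, int height, int width,
                         int input_channels) {
  return ((o * height + h) * width + w) * input_channels + i;
}
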
diff --git a/tensorflow/contrib/lite/kernels/transpose_conv_test.cc b/tensorflow/contrib/lite/kernels/transpose_conv_test.cc
index 52be089349..55df897180 100644
--- a/tensorflow/contrib/lite/kernels/transpose_conv_test.cc
+++ b/tensorflow/contrib/lite/kernels/transpose_conv_test.cc
@@ -88,10 +88,10 @@ TEST(TransposeConvOpModelTest, SimpleTest) {
// And filter value is derived by:
// filter = tf.reshape(tf.transpose(filter, perm=[3, 0, 1, 2]), shape=[18, 1])
TEST(TransposeConvOpModelTest, TwoFiltersTest) {
- TransposeConvOpModel m({1, 4, 4, 2}, {2, 3, 3, 1}, Padding_SAME, 1, 1);
+ TransposeConvOpModel m({1, 4, 4, 2}, {1, 3, 3, 2}, Padding_SAME, 1, 1);
m.PopulateTensor<int>(m.output_shape(), {1, 4, 4, 1});
- m.PopulateTensor<float>(m.filter(), {1, 3, 5, 7, 9, 11, 13, 15, 17, 2, 4, 6,
- 8, 10, 12, 14, 16, 18});
+ m.PopulateTensor<float>(m.filter(), {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
+ 13, 14, 15, 16, 17, 18});
m.PopulateTensor<float>(
m.input(),
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
@@ -117,10 +117,10 @@ TEST(TransposeConvOpModelTest, TwoFiltersTest) {
// And filter value is derived by:
// filter = tf.reshape(tf.transpose(filter, perm=[3, 0, 1, 2]), shape=[1, 18])
TEST(TransposeConvOpModelTest, PaddingValidTest) {
- TransposeConvOpModel m({1, 4, 4, 2}, {2, 3, 3, 1}, Padding_VALID, 1, 1);
+ TransposeConvOpModel m({1, 4, 4, 2}, {1, 3, 3, 2}, Padding_VALID, 1, 1);
m.PopulateTensor<int>(m.output_shape(), {1, 6, 6, 1});
- m.PopulateTensor<float>(m.filter(), {1, 3, 5, 7, 9, 11, 13, 15, 17, 2, 4, 6,
- 8, 10, 12, 14, 16, 18});
+ m.PopulateTensor<float>(m.filter(), {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
+ 13, 14, 15, 16, 17, 18});
m.PopulateTensor<float>(
m.input(),
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
@@ -171,10 +171,10 @@ TEST(TransposeConvOpModelTest, StrideValidTest) {
// [1, 2, 2, 1 ],
// "VALID")
TEST(TransposeConvOpModelTest, MultiChannelTest) {
- TransposeConvOpModel m({1, 2, 2, 1}, {1, 3, 3, 2}, Padding_VALID, 2, 2);
+ TransposeConvOpModel m({1, 2, 2, 1}, {2, 3, 3, 1}, Padding_VALID, 2, 2);
m.PopulateTensor<int>(m.output_shape(), {1, 5, 5, 2});
- m.PopulateTensor<float>(m.filter(), {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
- 13, 14, 15, 16, 17, 18});
+ m.PopulateTensor<float>(m.filter(), {1, 3, 5, 7, 9, 11, 13, 15, 17, 2, 4, 6,
+ 8, 10, 12, 14, 16, 18});
m.PopulateTensor<float>(m.input(), {1, 2, 3, 4});
m.Invoke();
diff --git a/tensorflow/contrib/lite/kernels/transpose_test.cc b/tensorflow/contrib/lite/kernels/transpose_test.cc
index 337bc144b9..79ef0a7c56 100644
--- a/tensorflow/contrib/lite/kernels/transpose_test.cc
+++ b/tensorflow/contrib/lite/kernels/transpose_test.cc
@@ -51,21 +51,21 @@ void RunTestPermutation(const std::vector<int>& shape,
reversed_perms[k] = k;
}
- // Make input and output dims (i.e. reversed shape and dest_shape).
- Dims<4> input_dims = GetTensorDims(shape);
- Dims<4> output_dims;
- for (int i = 0; i < 4; i++) {
- output_dims.sizes[i] = input_dims.sizes[reversed_perms[i]];
+ // Make input and output shapes.
+ const RuntimeShape input_shape = GetTensorShape(shape);
+ RuntimeShape output_shape(perms.size());
+ for (int i = 0; i < perms.size(); i++) {
+ output_shape.SetDim(i, input_shape.Dims(perms[i]));
}
- output_dims.strides[0] = 1;
- for (int k = 1; k < 4; k++) {
- output_dims.strides[k] =
- output_dims.strides[k - 1] * output_dims.sizes[k - 1];
+
+ TransposeParams params;
+ params.perm_count = perms.size();
+ for (int i = 0; i < perms.size(); ++i) {
+ params.perm[i] = perms[i];
}
- reference_ops::Transpose<float>(input.data(), input_dims,
- input_transposed->data(), output_dims,
- reversed_perms);
+ reference_ops::Transpose<float>(params, input_shape, input.data(),
+ output_shape, input_transposed->data());
}
TEST(TransposeTest, TestRefOps1D) {
diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc
index 1c28123a24..89d57e4599 100644
--- a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc
+++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc
@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include <unistd.h>
#include <cassert>
#include <cmath>
#include <cstdio>
@@ -21,12 +20,13 @@ limitations under the License.
#include <iostream>
#include <limits>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/activation_functor.h"
#include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/lstm_eval.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"
namespace tflite {
@@ -65,14 +65,30 @@ constexpr int kProjectionWeightsTensor = 16; // Optional
// Projection bias tensor of size {n_output}
constexpr int kProjectionBiasTensor = 17; // Optional
+// Stateful input tensors that are variables and will be modified by the Op.
+// Activation state tensor of size {n_batch, n_output}
+constexpr int kInputActivationStateTensor = 18;
+// Cell state tensor of size {n_batch, n_cell}
+constexpr int kInputCellStateTensor = 19;
+
// Output tensors.
-constexpr int kOutputStateTensor = 0;
-constexpr int kCellStateTensor = 1;
-constexpr int kOutputTensor = 2;
+constexpr int kOutputTensor = 0;
+
+// Temporary tensors
+enum TemporaryTensor {
+ kScratchBuffer = 0,
+ kInputQuantized = 1,
+ kOutputStateQuantized = 2,
+ kCellStateQuantized = 3,
+ kScalingFactors = 4,
+ kProductScalingFactors = 5,
+ kRecoveredCellWeights = 6,
+ kNumTemporaryTensors = 7
+};
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
- auto* scratch_tensor_index = new int;
- context->AddTensors(context, 1, scratch_tensor_index);
+ auto* scratch_tensor_index = new int();
+ context->AddTensors(context, kNumTemporaryTensors, scratch_tensor_index);
return scratch_tensor_index;
}
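
Init now reserves a contiguous block of kNumTemporaryTensors indices rather than a single scratch index; the hunks below address each temporary as *scratch_tensor_index + k and register it in node->temporaries. A compressed sketch of that wiring (illustrative only; the actual Prepare creates the full block only for the hybrid path):

// Illustrative only: wire the reserved index block into node->temporaries so
// that GetTemporary(context, node, k) resolves to the k-th reserved tensor.
void WireTemporaries(TfLiteContext* context, TfLiteNode* node,
                     int scratch_tensor_index_base) {
  TfLiteIntArrayFree(node->temporaries);
  node->temporaries = TfLiteIntArrayCreate(kNumTemporaryTensors);
  for (int k = 0; k < kNumTemporaryTensors; ++k) {
    node->temporaries->data[k] = scratch_tensor_index_base + k;
  }
}
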
@@ -84,7 +100,7 @@ void Free(TfLiteContext* context, void* buffer) {
TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
TfLiteNode* node, int n_input,
int n_output, int n_cell) {
- auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);
+ const auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);
// Making sure clipping parameters have valid values.
// == 0 means no clipping
@@ -236,12 +252,13 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
int* scratch_tensor_index = reinterpret_cast<int*>(node->user_data);
// Check we have all the inputs and outputs we need.
- TF_LITE_ENSURE_EQ(context, node->inputs->size, 18);
- TF_LITE_ENSURE_EQ(context, node->outputs->size, 3);
+ TF_LITE_ENSURE_EQ(context, node->inputs->size, 20);
+ TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
// Inferring batch size, number of outputs and sequence length and
// number of cells from the input tensors.
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
TF_LITE_ENSURE(context, input->dims->size > 1);
const int max_time = input->dims->data[0];
const int n_batch = input->dims->data[1];
@@ -261,14 +278,24 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
const int n_output = recurrent_to_output_weights->dims->data[1];
// Check that input tensor dimensions matches with each other.
- CheckInputTensorDimensions(context, node, n_input, n_output, n_cell);
+ TF_LITE_ENSURE_OK(context, CheckInputTensorDimensions(context, node, n_input,
+ n_output, n_cell));
- // Get the pointer to output, output_state and cell_state buffer tensors.
+ // Get the pointer to output, activation_state and cell_state buffer tensors.
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
- TfLiteTensor* output_state = GetOutput(context, node, kOutputStateTensor);
- TfLiteTensor* cell_state = GetOutput(context, node, kCellStateTensor);
- // Resize the output, output_state and cell_state tensors.
+ TfLiteTensor* activation_state =
+ GetVariableInput(context, node, kInputActivationStateTensor);
+ TfLiteTensor* cell_state =
+ GetVariableInput(context, node, kInputCellStateTensor);
+
+ // Check the shape of input state tensors.
+ // These tensors may be 1D or 2D. It's fine as long as the total size is
+ // correct.
+ TF_LITE_ENSURE_EQ(context, NumElements(activation_state), n_batch * n_output);
+ TF_LITE_ENSURE_EQ(context, NumElements(cell_state), n_batch * n_cell);
+
+ // Resize the output tensors.
TfLiteIntArray* output_size = TfLiteIntArrayCreate(3);
output_size->data[0] = max_time;
output_size->data[1] = n_batch;
@@ -276,54 +303,137 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_OK(context,
context->ResizeTensor(context, output, output_size));
- TfLiteIntArray* output_state_size = TfLiteIntArrayCreate(2);
- output_state_size->data[0] = n_batch;
- output_state_size->data[1] = n_output;
- TF_LITE_ENSURE_OK(
- context, context->ResizeTensor(context, output_state, output_state_size));
-
- TfLiteIntArray* cell_size = TfLiteIntArrayCreate(2);
- cell_size->data[0] = n_batch;
- cell_size->data[1] = n_cell;
- TF_LITE_ENSURE_OK(context,
- context->ResizeTensor(context, cell_state, cell_size));
+ // The weights are of consistent type, so it suffices to check one.
+ // TODO(mirkov): create a utility/macro for this check, so all Ops can use it.
+ const bool is_hybrid_op = (input_to_output_weights->type == kTfLiteUInt8 &&
+ input->type == kTfLiteFloat32);
- // Create a scratch buffer tensor.
TfLiteIntArrayFree(node->temporaries);
- node->temporaries = TfLiteIntArrayCreate(1);
+ if (is_hybrid_op) {
+ node->temporaries = TfLiteIntArrayCreate(kNumTemporaryTensors);
+ } else {
+ node->temporaries = TfLiteIntArrayCreate(1);
+ }
node->temporaries->data[0] = *scratch_tensor_index;
- TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0);
+
+ // Create a scratch buffer tensor.
+ TfLiteTensor* scratch_buffer = GetTemporary(context, node, kScratchBuffer);
scratch_buffer->type = input->type;
scratch_buffer->allocation_type = kTfLiteArenaRw;
- // Mark state tensors as persistent tensors.
- output_state->allocation_type = kTfLiteArenaRwPersistent;
- cell_state->allocation_type = kTfLiteArenaRwPersistent;
-
const TfLiteTensor* input_to_input_weights =
GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
const bool use_cifg = (input_to_input_weights == nullptr);
+ TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2);
+ scratch_buffer_size->data[0] = n_batch;
if (use_cifg) {
- TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2);
- scratch_buffer_size->data[0] = n_batch;
// Reserving space for Cell, Forget, Output gates
scratch_buffer_size->data[1] = n_cell * 3;
- TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_buffer,
- scratch_buffer_size));
} else {
- TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2);
- scratch_buffer_size->data[0] = n_batch;
// Reserving space for Input, Cell, Forget, Output gates
scratch_buffer_size->data[1] = n_cell * 4;
- TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_buffer,
- scratch_buffer_size));
+ }
+ TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_buffer,
+ scratch_buffer_size));
+
+ if (is_hybrid_op) {
+ // Allocate temporary tensors to store quantized values of input,
+ // activation_state and cell_state tensors.
+ node->temporaries->data[kInputQuantized] =
+ *scratch_tensor_index + kInputQuantized;
+ TfLiteTensor* input_quantized =
+ GetTemporary(context, node, kInputQuantized);
+ input_quantized->type = kTfLiteUInt8;
+ input_quantized->allocation_type = kTfLiteArenaRw;
+ if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) {
+ TfLiteIntArray* input_quantized_size = TfLiteIntArrayCopy(input->dims);
+ TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized,
+ input_quantized_size));
+ }
+ node->temporaries->data[kOutputStateQuantized] =
+ *scratch_tensor_index + kOutputStateQuantized;
+ TfLiteTensor* activation_state_quantized =
+ GetTemporary(context, node, kOutputStateQuantized);
+ activation_state_quantized->type = kTfLiteUInt8;
+ activation_state_quantized->allocation_type = kTfLiteArenaRw;
+ if (!TfLiteIntArrayEqual(activation_state_quantized->dims,
+ activation_state->dims)) {
+ TfLiteIntArray* activation_state_quantized_size =
+ TfLiteIntArrayCopy(activation_state->dims);
+ TF_LITE_ENSURE_OK(
+ context, context->ResizeTensor(context, activation_state_quantized,
+ activation_state_quantized_size));
+ }
+ node->temporaries->data[kCellStateQuantized] =
+ *scratch_tensor_index + kCellStateQuantized;
+ TfLiteTensor* cell_state_quantized =
+ GetTemporary(context, node, kCellStateQuantized);
+ cell_state_quantized->type = kTfLiteUInt8;
+ cell_state_quantized->allocation_type = kTfLiteArenaRw;
+ if (!TfLiteIntArrayEqual(cell_state_quantized->dims, cell_state->dims)) {
+ TfLiteIntArray* cell_state_quantized_size =
+ TfLiteIntArrayCopy(cell_state->dims);
+ TF_LITE_ENSURE_OK(context,
+ context->ResizeTensor(context, cell_state_quantized,
+ cell_state_quantized_size));
+ }
+
+ // Allocate temporary tensors to store scaling factors and product scaling
+ // factors. The latter is a convenience storage that allows quantizing a
+ // vector once (which produces the scaling factors) and multiplying it with
+ // different matrices (which requires multiplying the scaling factors by
+ // the scaling factor of each matrix).
+ node->temporaries->data[kScalingFactors] =
+ *scratch_tensor_index + kScalingFactors;
+ TfLiteTensor* scaling_factors =
+ GetTemporary(context, node, kScalingFactors);
+ scaling_factors->type = kTfLiteFloat32;
+ scaling_factors->allocation_type = kTfLiteArenaRw;
+ TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1);
+ scaling_factors_size->data[0] = n_batch;
+ if (!TfLiteIntArrayEqual(scaling_factors->dims, scaling_factors_size)) {
+ TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors,
+ scaling_factors_size));
+ }
+ node->temporaries->data[kProductScalingFactors] =
+ *scratch_tensor_index + kProductScalingFactors;
+ TfLiteTensor* prod_scaling_factors =
+ GetTemporary(context, node, kProductScalingFactors);
+ prod_scaling_factors->type = kTfLiteFloat32;
+ prod_scaling_factors->allocation_type = kTfLiteArenaRw;
+ TfLiteIntArray* prod_scaling_factors_size = TfLiteIntArrayCreate(1);
+ prod_scaling_factors_size->data[0] = n_batch;
+ if (!TfLiteIntArrayEqual(prod_scaling_factors->dims,
+ prod_scaling_factors_size)) {
+ TF_LITE_ENSURE_OK(context,
+ context->ResizeTensor(context, prod_scaling_factors,
+ prod_scaling_factors_size));
+ }
+
+ // Allocate a temporary tensor to store the recovered cell weights. Since
+ // this is used for diagonal matrices, only n_cell values need to be stored.
+ node->temporaries->data[kRecoveredCellWeights] =
+ *scratch_tensor_index + kRecoveredCellWeights;
+ TfLiteTensor* recovered_cell_weights =
+ GetTemporary(context, node, kRecoveredCellWeights);
+ recovered_cell_weights->type = kTfLiteFloat32;
+ recovered_cell_weights->allocation_type = kTfLiteArenaRw;
+ TfLiteIntArray* recovered_cell_weights_size = TfLiteIntArrayCreate(1);
+ recovered_cell_weights_size->data[0] = n_cell;
+ if (!TfLiteIntArrayEqual(recovered_cell_weights->dims,
+ recovered_cell_weights_size)) {
+ TF_LITE_ENSURE_OK(context,
+ context->ResizeTensor(context, recovered_cell_weights,
+ recovered_cell_weights_size));
+ }
}
return kTfLiteOk;
}
-// The LSTM Op engine.
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
- auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);
+ const auto* params =
+ reinterpret_cast<TfLiteUnidirectionalSequenceLSTMParams*>(
+ node->builtin_data);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* input_to_input_weights =
@@ -364,94 +474,76 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* projection_bias =
GetOptionalInputTensor(context, node, kProjectionBiasTensor);
- TfLiteTensor* output_state = GetOutput(context, node, kOutputStateTensor);
- TfLiteTensor* cell_state = GetOutput(context, node, kCellStateTensor);
- TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
-
- const int max_time = input->dims->data[0];
- const int n_batch = input->dims->data[1];
- const int n_input = input->dims->data[2];
- // n_cell and n_output will be the same size when there is no projection.
- const int n_cell = input_to_output_weights->dims->data[0];
- const int n_output = recurrent_to_output_weights->dims->data[1];
-
- // Since we have already checked that weights are all there or none, we can
- // check the existence of only one to get the condition.
- const bool use_cifg = (input_to_input_weights == nullptr);
- const bool use_peephole = (cell_to_output_weights != nullptr);
-
// Index the scratch buffers pointers to the global scratch buffer.
TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0);
- float* input_gate_scratch = nullptr;
- float* cell_scratch = nullptr;
- float* forget_gate_scratch = nullptr;
- float* output_gate_scratch = nullptr;
- if (use_cifg) {
- cell_scratch = scratch_buffer->data.f;
- forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch;
- output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
- } else {
- input_gate_scratch = scratch_buffer->data.f;
- cell_scratch = scratch_buffer->data.f + n_cell * n_batch;
- forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
- output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch;
- }
- // Check optional tensors, the respective pointers can be null.
- const float* input_to_input_weights_ptr =
- (use_cifg) ? nullptr : input_to_input_weights->data.f;
- const float* recurrent_to_input_weights_ptr =
- (use_cifg) ? nullptr : recurrent_to_input_weights->data.f;
- const float* input_gate_bias_ptr =
- (use_cifg) ? nullptr : input_gate_bias->data.f;
- const float* cell_to_input_weights_ptr =
- (use_peephole && !use_cifg) ? cell_to_input_weights->data.f : nullptr;
- const float* cell_to_forget_weights_ptr =
- (use_peephole) ? cell_to_forget_weights->data.f : nullptr;
- const float* cell_to_output_weights_ptr =
- (use_peephole) ? cell_to_output_weights->data.f : nullptr;
- const float* projection_weights_ptr =
- (projection_weights == nullptr) ? nullptr : projection_weights->data.f;
- const float* projection_bias_ptr =
- (projection_bias == nullptr) ? nullptr : projection_bias->data.f;
-
- // Required tensors, pointers are non-null.
- const float* input_to_forget_weights_ptr = input_to_forget_weights->data.f;
- const float* input_to_cell_weights_ptr = input_to_cell_weights->data.f;
- const float* input_to_output_weights_ptr = input_to_output_weights->data.f;
- const float* recurrent_to_forget_weights_ptr =
- recurrent_to_forget_weights->data.f;
- const float* recurrent_to_cell_weights_ptr =
- recurrent_to_cell_weights->data.f;
- const float* recurrent_to_output_weights_ptr =
- recurrent_to_output_weights->data.f;
- const float* forget_gate_bias_ptr = forget_gate_bias->data.f;
- const float* cell_bias_ptr = cell_bias->data.f;
- const float* output_gate_bias_ptr = output_gate_bias->data.f;
-
- float* output_state_ptr = output_state->data.f;
- float* cell_state_ptr = cell_state->data.f;
-
- for (int t = 0; t < max_time; t++) {
- const float* input_ptr_batch = input->data.f + t * n_batch * n_input;
- float* output_ptr_batch = output->data.f + t * n_batch * n_output;
-
- kernel_utils::LstmStep(
- input_ptr_batch, input_to_input_weights_ptr,
- input_to_forget_weights_ptr, input_to_cell_weights_ptr,
- input_to_output_weights_ptr, recurrent_to_input_weights_ptr,
- recurrent_to_forget_weights_ptr, recurrent_to_cell_weights_ptr,
- recurrent_to_output_weights_ptr, cell_to_input_weights_ptr,
- cell_to_forget_weights_ptr, cell_to_output_weights_ptr,
- input_gate_bias_ptr, forget_gate_bias_ptr, cell_bias_ptr,
- output_gate_bias_ptr, projection_weights_ptr, projection_bias_ptr,
- params, n_batch, n_cell, n_input, n_output, output_state_ptr,
- cell_state_ptr, input_gate_scratch, forget_gate_scratch, cell_scratch,
- output_gate_scratch, output_ptr_batch);
+ TfLiteTensor* activation_state =
+ GetVariableInput(context, node, kInputActivationStateTensor);
+ TfLiteTensor* cell_state =
+ GetVariableInput(context, node, kInputCellStateTensor);
+
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+ // Copy out the LSTM-specific params so they can be passed to the eval functions.
+ TfLiteLSTMParams lstm_params;
+ lstm_params.activation = params->activation;
+ lstm_params.cell_clip = params->cell_clip;
+ lstm_params.proj_clip = params->proj_clip;
+
+ switch (input_to_output_weights->type) {
+ case kTfLiteFloat32: {
+ return lstm_eval::EvalFloat(
+ input, input_to_input_weights, input_to_forget_weights,
+ input_to_cell_weights, input_to_output_weights,
+ recurrent_to_input_weights, recurrent_to_forget_weights,
+ recurrent_to_cell_weights, recurrent_to_output_weights,
+ cell_to_input_weights, cell_to_forget_weights, cell_to_output_weights,
+ /*aux_input=*/nullptr,
+ /*aux_input_to_input_weights=*/nullptr,
+ /*aux_input_to_forget_weights=*/nullptr,
+ /*aux_input_to_cell_weights=*/nullptr,
+ /*aux_input_to_output_weights=*/nullptr, input_gate_bias,
+ forget_gate_bias, cell_bias, output_gate_bias, projection_weights,
+ projection_bias, &lstm_params, /*forward_sequence=*/true,
+ /*output_offset=*/0, scratch_buffer, activation_state, cell_state,
+ output);
+ }
+ case kTfLiteUInt8: {
+ TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/1);
+ TfLiteTensor* activation_state_quantized =
+ GetTemporary(context, node, /*index=*/2);
+ TfLiteTensor* cell_state_quantized =
+ GetTemporary(context, node, /*index=*/3);
+ TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/4);
+ TfLiteTensor* prod_scaling_factors =
+ GetTemporary(context, node, /*index=*/5);
+ TfLiteTensor* recovered_cell_weights =
+ GetTemporary(context, node, /*index=*/6);
+ return lstm_eval::EvalHybrid(
+ input, input_to_input_weights, input_to_forget_weights,
+ input_to_cell_weights, input_to_output_weights,
+ recurrent_to_input_weights, recurrent_to_forget_weights,
+ recurrent_to_cell_weights, recurrent_to_output_weights,
+ cell_to_input_weights, cell_to_forget_weights, cell_to_output_weights,
+ /*aux_input=*/nullptr,
+ /*aux_input_to_input_weights=*/nullptr,
+ /*aux_input_to_forget_weights=*/nullptr,
+ /*aux_input_to_cell_weights=*/nullptr,
+ /*aux_input_to_output_weights=*/nullptr, input_gate_bias,
+ forget_gate_bias, cell_bias, output_gate_bias, projection_weights,
+ projection_bias, &lstm_params, /*forward_sequence=*/true,
+ /*output_offset=*/0, scratch_buffer, scaling_factors,
+ prod_scaling_factors, recovered_cell_weights, input_quantized,
+ /*aux_input_quantized=*/nullptr, activation_state_quantized,
+ cell_state_quantized, activation_state, cell_state, output);
+ }
+ default:
+ context->ReportError(context, "Type %d is not currently supported.",
+ input_to_output_weights->type);
+ return kTfLiteError;
}
return kTfLiteOk;
}
-
} // namespace unidirectional_sequence_lstm
TfLiteRegistration* Register_UNIDIRECTIONAL_SEQUENCE_LSTM() {
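
The scaling_factors and prod_scaling_factors temporaries added in Prepare above support the hybrid path (float activations with uint8 weights) that Eval now dispatches to lstm_eval::EvalHybrid. A self-contained sketch of the underlying idea, with hypothetical names and assuming symmetric 8-bit quantization (this is not lstm_eval code):

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Quantize an activation vector once, producing a single scaling factor, then
// reuse it against any weight matrix by folding it into a per-matrix
// "product scaling factor".
float HybridDotSketch(const std::vector<float>& v,
                      const std::vector<int8_t>& weight_row,
                      float weight_scale) {
  float max_abs = 0.f;
  for (float x : v) max_abs = std::max(max_abs, std::fabs(x));
  if (max_abs == 0.f) return 0.f;

  // One quantization (and one scaling factor) per vector.
  const float v_scale = max_abs / 127.f;
  std::vector<int8_t> v_q(v.size());
  for (size_t i = 0; i < v.size(); ++i) {
    v_q[i] = static_cast<int8_t>(std::round(v[i] / v_scale));
  }

  // The "product scaling factor": vector scale times matrix scale.
  const float prod_scale = v_scale * weight_scale;

  int32_t acc = 0;
  for (size_t i = 0; i < v.size(); ++i) acc += weight_row[i] * v_q[i];
  return acc * prod_scale;  // Approximates the float dot product of the row with v.
}
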
diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm_test.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm_test.cc
index 5881ced7c7..c97b0fdd61 100644
--- a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm_test.cc
+++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm_test.cc
@@ -14,7 +14,6 @@ limitations under the License.
==============================================================================*/
// Unit test for TFLite Sequential LSTM op.
-#include <iomanip>
#include <memory>
#include <vector>
@@ -37,7 +36,8 @@ class UnidirectionalLSTMOpModel : public SingleOpModel {
bool use_peephole, bool use_projection_weights,
bool use_projection_bias, float cell_clip,
float proj_clip,
- const std::vector<std::vector<int>>& input_shapes)
+ const std::vector<std::vector<int>>& input_shapes,
+ const TensorType& weights_type = TensorType_FLOAT32)
: n_batch_(n_batch),
n_input_(n_input),
n_cell_(n_cell),
@@ -48,31 +48,31 @@ class UnidirectionalLSTMOpModel : public SingleOpModel {
if (use_cifg) {
input_to_input_weights_ = AddNullInput();
} else {
- input_to_input_weights_ = AddInput(TensorType_FLOAT32);
+ input_to_input_weights_ = AddInput(weights_type);
}
- input_to_forget_weights_ = AddInput(TensorType_FLOAT32);
- input_to_cell_weights_ = AddInput(TensorType_FLOAT32);
- input_to_output_weights_ = AddInput(TensorType_FLOAT32);
+ input_to_forget_weights_ = AddInput(weights_type);
+ input_to_cell_weights_ = AddInput(weights_type);
+ input_to_output_weights_ = AddInput(weights_type);
if (use_cifg) {
recurrent_to_input_weights_ = AddNullInput();
} else {
- recurrent_to_input_weights_ = AddInput(TensorType_FLOAT32);
+ recurrent_to_input_weights_ = AddInput(weights_type);
}
- recurrent_to_forget_weights_ = AddInput(TensorType_FLOAT32);
- recurrent_to_cell_weights_ = AddInput(TensorType_FLOAT32);
- recurrent_to_output_weights_ = AddInput(TensorType_FLOAT32);
+ recurrent_to_forget_weights_ = AddInput(weights_type);
+ recurrent_to_cell_weights_ = AddInput(weights_type);
+ recurrent_to_output_weights_ = AddInput(weights_type);
if (use_peephole) {
if (use_cifg) {
cell_to_input_weights_ = AddNullInput();
} else {
- cell_to_input_weights_ = AddInput(TensorType_FLOAT32);
+ cell_to_input_weights_ = AddInput(weights_type);
}
- cell_to_forget_weights_ = AddInput(TensorType_FLOAT32);
- cell_to_output_weights_ = AddInput(TensorType_FLOAT32);
+ cell_to_forget_weights_ = AddInput(weights_type);
+ cell_to_output_weights_ = AddInput(weights_type);
} else {
cell_to_input_weights_ = AddNullInput();
cell_to_forget_weights_ = AddNullInput();
@@ -89,7 +89,7 @@ class UnidirectionalLSTMOpModel : public SingleOpModel {
output_gate_bias_ = AddInput(TensorType_FLOAT32);
if (use_projection_weights) {
- projection_weights_ = AddInput(TensorType_FLOAT32);
+ projection_weights_ = AddInput(weights_type);
if (use_projection_bias) {
projection_bias_ = AddInput(TensorType_FLOAT32);
} else {
@@ -100,15 +100,22 @@ class UnidirectionalLSTMOpModel : public SingleOpModel {
projection_bias_ = AddNullInput();
}
- output_state_ = AddOutput(TensorType_FLOAT32);
- cell_state_ = AddOutput(TensorType_FLOAT32);
+ // Adding the 2 input state tensors.
+ input_activation_state_ =
+ AddInput(TensorData{TensorType_FLOAT32, {n_output_ * n_batch_}},
+ /*is_variable=*/true);
+ input_cell_state_ =
+ AddInput(TensorData{TensorType_FLOAT32, {n_cell_ * n_batch_}},
+ /*is_variable=*/true);
+
output_ = AddOutput(TensorType_FLOAT32);
- SetBuiltinOp(BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM,
- BuiltinOptions_LSTMOptions,
- CreateLSTMOptions(builder_, ActivationFunctionType_TANH,
- cell_clip, proj_clip)
- .Union());
+ SetBuiltinOp(
+ BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM,
+ BuiltinOptions_UnidirectionalSequenceLSTMOptions,
+ CreateUnidirectionalSequenceLSTMOptions(
+ builder_, ActivationFunctionType_TANH, cell_clip, proj_clip)
+ .Union());
BuildInterpreter(input_shapes);
}
@@ -180,24 +187,9 @@ class UnidirectionalLSTMOpModel : public SingleOpModel {
PopulateTensor(projection_bias_, f);
}
- void ResetOutputState() {
- const int zero_buffer_size = n_cell_ * n_batch_;
- std::unique_ptr<float[]> zero_buffer(new float[zero_buffer_size]);
- memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float));
- PopulateTensor(output_state_, 0, zero_buffer.get(),
- zero_buffer.get() + zero_buffer_size);
- }
-
- void ResetCellState() {
- const int zero_buffer_size = n_cell_ * n_batch_;
- std::unique_ptr<float[]> zero_buffer(new float[zero_buffer_size]);
- memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float));
- PopulateTensor(cell_state_, 0, zero_buffer.get(),
- zero_buffer.get() + zero_buffer_size);
- }
-
- void SetInput(int offset, float* begin, float* end) {
- PopulateTensor(input_, offset, begin, end);
+ void SetInput(int offset, const float* begin, const float* end) {
+ PopulateTensor(input_, offset, const_cast<float*>(begin),
+ const_cast<float*>(end));
}
std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
@@ -208,7 +200,7 @@ class UnidirectionalLSTMOpModel : public SingleOpModel {
int num_batches() { return n_batch_; }
int sequence_length() { return sequence_length_; }
- private:
+ protected:
int input_;
int input_to_input_weights_;
int input_to_forget_weights_;
@@ -232,9 +224,10 @@ class UnidirectionalLSTMOpModel : public SingleOpModel {
int projection_weights_;
int projection_bias_;
+ int input_activation_state_;
+ int input_cell_state_;
+
int output_;
- int output_state_;
- int cell_state_;
int n_batch_;
int n_input_;
@@ -243,7 +236,183 @@ class UnidirectionalLSTMOpModel : public SingleOpModel {
int sequence_length_;
};
-TEST(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping) {
+// The hybrid model has quantized weights.
+class HybridUnidirectionalLSTMOpModel : public UnidirectionalLSTMOpModel {
+ public:
+ HybridUnidirectionalLSTMOpModel(
+ int n_batch, int n_input, int n_cell, int n_output, int sequence_length,
+ bool use_cifg, bool use_peephole, bool use_projection_weights,
+ bool use_projection_bias, float cell_clip, float proj_clip,
+ const std::vector<std::vector<int>>& input_shapes)
+ : UnidirectionalLSTMOpModel(
+ n_batch, n_input, n_cell, n_output, sequence_length, use_cifg,
+ use_peephole, use_projection_weights, use_projection_bias,
+ cell_clip, proj_clip, input_shapes, TensorType_UINT8) {}
+
+ void SetInputToInputWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(input_to_input_weights_, f);
+ }
+
+ void SetInputToForgetWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(input_to_forget_weights_, f);
+ }
+
+ void SetInputToCellWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(input_to_cell_weights_, f);
+ }
+
+ void SetInputToOutputWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(input_to_output_weights_, f);
+ }
+
+ void SetRecurrentToInputWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(recurrent_to_input_weights_, f);
+ }
+
+ void SetRecurrentToForgetWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(recurrent_to_forget_weights_, f);
+ }
+
+ void SetRecurrentToCellWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(recurrent_to_cell_weights_, f);
+ }
+
+ void SetRecurrentToOutputWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(recurrent_to_output_weights_, f);
+ }
+
+ void SetCellToInputWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(cell_to_input_weights_, f);
+ }
+
+ void SetCellToForgetWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(cell_to_forget_weights_, f);
+ }
+
+ void SetCellToOutputWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(cell_to_output_weights_, f);
+ }
+
+ void SetProjectionWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(projection_weights_, f);
+ }
+};
+
+class BaseLstmTest : public ::testing::Test {
+ protected:
+ // Weights of the LSTM model. Some are optional.
+ std::initializer_list<float> input_to_input_weights_;
+ std::initializer_list<float> input_to_cell_weights_;
+ std::initializer_list<float> input_to_forget_weights_;
+ std::initializer_list<float> input_to_output_weights_;
+ std::initializer_list<float> input_gate_bias_;
+ std::initializer_list<float> cell_gate_bias_;
+ std::initializer_list<float> forget_gate_bias_;
+ std::initializer_list<float> output_gate_bias_;
+ std::initializer_list<float> recurrent_to_input_weights_;
+ std::initializer_list<float> recurrent_to_cell_weights_;
+ std::initializer_list<float> recurrent_to_forget_weights_;
+ std::initializer_list<float> recurrent_to_output_weights_;
+ std::initializer_list<float> cell_to_input_weights_;
+ std::initializer_list<float> cell_to_forget_weights_;
+ std::initializer_list<float> cell_to_output_weights_;
+ std::initializer_list<float> projection_weights_;
+
+ // LSTM input is stored as a num_batch x num_inputs vector.
+ std::vector<std::vector<float>> lstm_input_;
+ // LSTM output is stored as a num_batch x num_outputs vector.
+ std::vector<std::vector<float>> lstm_golden_output_;
+
+ // Feeds the input to the LSTM and compares its output against the golden
+ // output, up to the given tolerance.
+ void VerifyGoldens(const std::vector<std::vector<float>>& input,
+ const std::vector<std::vector<float>>& output,
+ UnidirectionalLSTMOpModel* lstm, float tolerance = 1e-5) {
+ const int num_batches = input.size();
+ EXPECT_GT(num_batches, 0);
+ const int num_inputs = lstm->num_inputs();
+ EXPECT_GT(num_inputs, 0);
+ const int input_sequence_size = input[0].size() / num_inputs;
+ EXPECT_GT(input_sequence_size, 0);
+ // Feed the whole sequence as input.
+ for (int i = 0; i < input_sequence_size; ++i) {
+ for (int b = 0; b < num_batches; ++b) {
+ const float* batch_start = input[b].data() + i * num_inputs;
+ const float* batch_end = batch_start + num_inputs;
+
+ lstm->SetInput(((i * num_batches) + b) * lstm->num_inputs(),
+ batch_start, batch_end);
+ }
+ }
+
+ lstm->Invoke();
+
+ const int num_outputs = lstm->num_outputs();
+ EXPECT_GT(num_outputs, 0);
+ std::vector<float> expected;
+ for (int i = 0; i < input_sequence_size; ++i) {
+ for (int b = 0; b < num_batches; ++b) {
+ const float* golden_start_batch = output[b].data() + i * num_outputs;
+ const float* golden_end_batch = golden_start_batch + num_outputs;
+
+ expected.insert(expected.end(), golden_start_batch, golden_end_batch);
+ }
+ }
+
+ EXPECT_THAT(lstm->GetOutput(),
+ ElementsAreArray(ArrayFloatNear(expected, tolerance)));
+ }
+};
+
+class NoCifgNoPeepholeNoProjectionNoClippingLstmTest : public BaseLstmTest {
+ void SetUp() override {
+ input_to_input_weights_ = {-0.45018822, -0.02338299, -0.0870589,
+ -0.34550029, 0.04266912, -0.15680569,
+ -0.34856534, 0.43890524};
+ input_to_cell_weights_ = {-0.50013041, 0.1370284, 0.11810488, 0.2013163,
+ -0.20583314, 0.44344562, 0.22077113, -0.29909778};
+ input_to_forget_weights_ = {0.09701663, 0.20334584, -0.50592935,
+ -0.31343272, -0.40032279, 0.44781327,
+ 0.01387155, -0.35593212};
+ input_to_output_weights_ = {-0.25065863, -0.28290087, 0.04613829,
+ 0.40525138, 0.44272184, 0.03897077,
+ -0.1556896, 0.19487578};
+ input_gate_bias_ = {0., 0., 0., 0.};
+ cell_gate_bias_ = {0., 0., 0., 0.};
+ forget_gate_bias_ = {1., 1., 1., 1.};
+ output_gate_bias_ = {0., 0., 0., 0.};
+
+ recurrent_to_input_weights_ = {
+ -0.0063535, -0.2042388, 0.31454784, -0.35746509,
+ 0.28902304, 0.08183324, -0.16555229, 0.02286911,
+ -0.13566875, 0.03034258, 0.48091322, -0.12528998,
+ 0.24077177, -0.51332325, -0.33502164, 0.10629296};
+
+ recurrent_to_cell_weights_ = {
+ -0.3407414, 0.24443203, -0.2078532, 0.26320225,
+ 0.05695659, -0.00123841, -0.4744786, -0.35869038,
+ -0.06418842, -0.13502428, -0.501764, 0.22830659,
+ -0.46367589, 0.26016325, -0.03894562, -0.16368064};
+
+ recurrent_to_forget_weights_ = {
+ -0.48684245, -0.06655136, 0.42224967, 0.2112639,
+ 0.27654213, 0.20864892, -0.07646349, 0.45877004,
+ 0.00141793, -0.14609534, 0.36447752, 0.09196436,
+ 0.28053468, 0.01560611, -0.20127171, -0.01140004};
+
+ recurrent_to_output_weights_ = {
+ 0.43385774, -0.17194885, 0.2718237, 0.09215671,
+ 0.24107647, -0.39835793, 0.18212086, 0.01301402,
+ 0.48572797, -0.50656658, 0.20047462, -0.20607421,
+ -0.51818722, -0.15390486, 0.0468148, 0.39922136};
+
+ lstm_input_ = {{2., 3., 3., 4., 1., 1.}};
+ lstm_golden_output_ = {{-0.02973187, 0.1229473, 0.20885126, -0.15358765,
+ -0.03716109, 0.12507336, 0.41193449, -0.20860538,
+ -0.15053082, 0.09120187, 0.24278517, -0.12222792}};
+ }
+};
+
+TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) {
const int n_batch = 1;
const int n_input = 2;
// n_cell and n_output have the same size when there is no projection.
@@ -252,9 +421,11 @@ TEST(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping) {
const int sequence_length = 3;
UnidirectionalLSTMOpModel lstm(
- n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false,
- /*use_peephole=*/false, /*use_projection_weights=*/false,
- /*use_projection_bias=*/false, /*cell_clip=*/0.0, /*proj_clip=*/0.0,
+ n_batch, n_input, n_cell, n_output, sequence_length,
+ /*use_cifg=*/false, /*use_peephole=*/false,
+ /*use_projection_weights=*/false,
+ /*use_projection_bias=*/false,
+ /*cell_clip=*/0.0, /*proj_clip=*/0.0,
{
{sequence_length, n_batch, n_input}, // input tensor
@@ -279,79 +450,138 @@ TEST(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping) {
{0, 0}, // projection_weight tensor
{0}, // projection_bias tensor
+
+ {n_batch, n_output}, // activation_state tensor
+ {n_batch, n_cell}, // cell_state tensor
});
- lstm.SetInputToInputWeights({-0.45018822, -0.02338299, -0.0870589,
- -0.34550029, 0.04266912, -0.15680569,
- -0.34856534, 0.43890524});
+ lstm.SetInputToInputWeights(input_to_input_weights_);
+ lstm.SetInputToCellWeights(input_to_cell_weights_);
+ lstm.SetInputToForgetWeights(input_to_forget_weights_);
+ lstm.SetInputToOutputWeights(input_to_output_weights_);
- lstm.SetInputToCellWeights({-0.50013041, 0.1370284, 0.11810488, 0.2013163,
- -0.20583314, 0.44344562, 0.22077113,
- -0.29909778});
+ lstm.SetInputGateBias(input_gate_bias_);
+ lstm.SetCellBias(cell_gate_bias_);
+ lstm.SetForgetGateBias(forget_gate_bias_);
+ lstm.SetOutputGateBias(output_gate_bias_);
- lstm.SetInputToForgetWeights({0.09701663, 0.20334584, -0.50592935,
- -0.31343272, -0.40032279, 0.44781327,
- 0.01387155, -0.35593212});
+ lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_);
+ lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
+ lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
+ lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
- lstm.SetInputToOutputWeights({-0.25065863, -0.28290087, 0.04613829,
- 0.40525138, 0.44272184, 0.03897077, -0.1556896,
- 0.19487578});
+ VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
+}
+
+TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, HybridLstmBlackBoxTest) {
+ const int n_batch = 1;
+ const int n_input = 2;
+ // n_cell and n_output have the same size when there is no projection.
+ const int n_cell = 4;
+ const int n_output = 4;
+ const int sequence_length = 3;
- lstm.SetInputGateBias({0., 0., 0., 0.});
+ HybridUnidirectionalLSTMOpModel lstm(
+ n_batch, n_input, n_cell, n_output, sequence_length,
+ /*use_cifg=*/false, /*use_peephole=*/false,
+ /*use_projection_weights=*/false,
+ /*use_projection_bias=*/false, /*cell_clip=*/0.0, /*proj_clip=*/0.0,
+ {
+ {sequence_length, n_batch, n_input}, // input tensor
- lstm.SetCellBias({0., 0., 0., 0.});
+ {n_cell, n_input}, // input_to_input_weight tensor
+ {n_cell, n_input}, // input_to_forget_weight tensor
+ {n_cell, n_input}, // input_to_cell_weight tensor
+ {n_cell, n_input}, // input_to_output_weight tensor
- lstm.SetForgetGateBias({1., 1., 1., 1.});
+ {n_cell, n_output}, // recurrent_to_input_weight tensor
+ {n_cell, n_output}, // recurrent_to_forget_weight tensor
+ {n_cell, n_output}, // recurrent_to_cell_weight tensor
+ {n_cell, n_output}, // recurrent_to_output_weight tensor
- lstm.SetOutputGateBias({0., 0., 0., 0.});
+ {0}, // cell_to_input_weight tensor
+ {0}, // cell_to_forget_weight tensor
+ {0}, // cell_to_output_weight tensor
- lstm.SetRecurrentToInputWeights(
- {-0.0063535, -0.2042388, 0.31454784, -0.35746509, 0.28902304, 0.08183324,
- -0.16555229, 0.02286911, -0.13566875, 0.03034258, 0.48091322,
- -0.12528998, 0.24077177, -0.51332325, -0.33502164, 0.10629296});
+ {n_cell}, // input_gate_bias tensor
+ {n_cell}, // forget_gate_bias tensor
+ {n_cell}, // cell_bias tensor
+ {n_cell}, // output_gate_bias tensor
- lstm.SetRecurrentToCellWeights(
- {-0.3407414, 0.24443203, -0.2078532, 0.26320225, 0.05695659, -0.00123841,
- -0.4744786, -0.35869038, -0.06418842, -0.13502428, -0.501764, 0.22830659,
- -0.46367589, 0.26016325, -0.03894562, -0.16368064});
+ {0, 0}, // projection_weight tensor
+ {0}, // projection_bias tensor
- lstm.SetRecurrentToForgetWeights(
- {-0.48684245, -0.06655136, 0.42224967, 0.2112639, 0.27654213, 0.20864892,
- -0.07646349, 0.45877004, 0.00141793, -0.14609534, 0.36447752, 0.09196436,
- 0.28053468, 0.01560611, -0.20127171, -0.01140004});
+ {n_batch, n_output}, // activation_state tensor
+ {n_batch, n_cell}, // cell_state tensor
+ });
- lstm.SetRecurrentToOutputWeights(
- {0.43385774, -0.17194885, 0.2718237, 0.09215671, 0.24107647, -0.39835793,
- 0.18212086, 0.01301402, 0.48572797, -0.50656658, 0.20047462, -0.20607421,
- -0.51818722, -0.15390486, 0.0468148, 0.39922136});
+ lstm.SetInputToInputWeights(input_to_input_weights_);
+ lstm.SetInputToCellWeights(input_to_cell_weights_);
+ lstm.SetInputToForgetWeights(input_to_forget_weights_);
+ lstm.SetInputToOutputWeights(input_to_output_weights_);
- // Input should have n_input * sequence_length many values.
- static float lstm_input[] = {2., 3., 3., 4., 1., 1.};
- static float lstm_golden_output[] = {-0.02973187, 0.1229473, 0.20885126,
- -0.15358765, -0.03716109, 0.12507336,
- 0.41193449, -0.20860538, -0.15053082,
- 0.09120187, 0.24278517, -0.12222792};
+ lstm.SetInputGateBias(input_gate_bias_);
+ lstm.SetCellBias(cell_gate_bias_);
+ lstm.SetForgetGateBias(forget_gate_bias_);
+ lstm.SetOutputGateBias(output_gate_bias_);
- // Resetting cell_state and output_state
- lstm.ResetCellState();
- lstm.ResetOutputState();
+ lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_);
+ lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
+ lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
+ lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
- float* batch0_start = lstm_input;
- float* batch0_end = batch0_start + lstm.num_inputs() * lstm.sequence_length();
+ VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm,
+ /*tolerance=*/0.0157651);
+}
- lstm.SetInput(0, batch0_start, batch0_end);
+class CifgPeepholeNoProjectionNoClippingLstmTest : public BaseLstmTest {
+ void SetUp() override {
+ input_to_cell_weights_ = {-0.49770179, -0.27711356, -0.09624726,
+ 0.05100781, 0.04717243, 0.48944736,
+ -0.38535351, -0.17212132};
- lstm.Invoke();
+ input_to_forget_weights_ = {-0.55291498, -0.42866567, 0.13056988,
+ -0.3633365, -0.22755712, 0.28253698,
+ 0.24407166, 0.33826375};
- float* golden_start = lstm_golden_output;
- float* golden_end =
- golden_start + lstm.num_outputs() * lstm.sequence_length();
- std::vector<float> expected;
- expected.insert(expected.end(), golden_start, golden_end);
- EXPECT_THAT(lstm.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
-}
+ input_to_output_weights_ = {0.10725588, -0.02335852, -0.55932593,
+ -0.09426838, -0.44257352, 0.54939759,
+ 0.01533556, 0.42751634};
+ cell_gate_bias_ = {0., 0., 0., 0.};
+ forget_gate_bias_ = {1., 1., 1., 1.};
+ output_gate_bias_ = {0., 0., 0., 0.};
+
+ recurrent_to_cell_weights_ = {
+ 0.54066205, -0.32668582, -0.43562764, -0.56094903,
+ 0.42957711, 0.01841056, -0.32764608, -0.33027974,
+ -0.10826075, 0.20675004, 0.19069612, -0.03026325,
+ -0.54532051, 0.33003211, 0.44901288, 0.21193194};
+
+ recurrent_to_forget_weights_ = {
+ -0.13832897, -0.0515101, -0.2359007, -0.16661474,
+ -0.14340827, 0.36986142, 0.23414481, 0.55899,
+ 0.10798943, -0.41174671, 0.17751795, -0.34484994,
+ -0.35874045, -0.11352962, 0.27268326, 0.54058349};
+
+ recurrent_to_output_weights_ = {
+ 0.41613156, 0.42610586, -0.16495961, -0.5663873,
+ 0.30579174, -0.05115908, -0.33941799, 0.23364776,
+ 0.11178309, 0.09481031, -0.26424935, 0.46261835,
+ 0.50248802, 0.26114327, -0.43736315, 0.33149987};
+
+ cell_to_forget_weights_ = {0.47485286, -0.51955009, -0.24458408,
+ 0.31544167};
+ cell_to_output_weights_ = {-0.17135078, 0.82760304, 0.85573703,
+ -0.77109635};
+
+ lstm_input_ = {{2., 3., 3., 4., 1., 1.}};
+ lstm_golden_output_ = {{-0.36444446, -0.00352185, 0.12886585, -0.05163646,
+ -0.42312205, -0.01218222, 0.24201041, -0.08124574,
+ -0.358325, -0.04621704, 0.21641694, -0.06471302}};
+ }
+};
-TEST(LSTMOpTest, BlackBoxTestWithCifgWithPeepholeNoProjectionNoClipping) {
+TEST_F(CifgPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) {
const int n_batch = 1;
const int n_input = 2;
// n_cell and n_output have the same size when there is no projection.
@@ -360,9 +590,11 @@ TEST(LSTMOpTest, BlackBoxTestWithCifgWithPeepholeNoProjectionNoClipping) {
const int sequence_length = 3;
UnidirectionalLSTMOpModel lstm(
- n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/true,
- /*use_peephole=*/true, /*use_projection_weights=*/false,
- /*use_projection_bias=*/false, /*cell_clip=*/0.0, /*proj_clip=*/0.0,
+ n_batch, n_input, n_cell, n_output, sequence_length,
+ /*use_cifg=*/true, /*use_peephole=*/true,
+ /*use_projection_weights=*/false,
+ /*use_projection_bias=*/false,
+ /*cell_clip=*/0.0, /*proj_clip=*/0.0,
{
{sequence_length, n_batch, n_input}, // input tensor
@@ -387,73 +619,690 @@ TEST(LSTMOpTest, BlackBoxTestWithCifgWithPeepholeNoProjectionNoClipping) {
{0, 0}, // projection_weight tensor
{0}, // projection_bias tensor
+
+ {n_batch, n_output}, // activation_state tensor
+ {n_batch, n_cell}, // cell_state tensor
});
- lstm.SetInputToCellWeights({-0.49770179, -0.27711356, -0.09624726, 0.05100781,
- 0.04717243, 0.48944736, -0.38535351,
- -0.17212132});
+ lstm.SetInputToCellWeights(input_to_cell_weights_);
+ lstm.SetInputToForgetWeights(input_to_forget_weights_);
+ lstm.SetInputToOutputWeights(input_to_output_weights_);
- lstm.SetInputToForgetWeights({-0.55291498, -0.42866567, 0.13056988,
- -0.3633365, -0.22755712, 0.28253698, 0.24407166,
- 0.33826375});
+ lstm.SetCellBias(cell_gate_bias_);
+ lstm.SetForgetGateBias(forget_gate_bias_);
+ lstm.SetOutputGateBias(output_gate_bias_);
- lstm.SetInputToOutputWeights({0.10725588, -0.02335852, -0.55932593,
- -0.09426838, -0.44257352, 0.54939759,
- 0.01533556, 0.42751634});
+ lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
+ lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
+ lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
- lstm.SetCellBias({0., 0., 0., 0.});
+ lstm.SetCellToForgetWeights(cell_to_forget_weights_);
+ lstm.SetCellToOutputWeights(cell_to_output_weights_);
- lstm.SetForgetGateBias({1., 1., 1., 1.});
+ VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
+}
- lstm.SetOutputGateBias({0., 0., 0., 0.});
+TEST_F(CifgPeepholeNoProjectionNoClippingLstmTest, HybridLstmBlackBoxTest) {
+ const int n_batch = 1;
+ const int n_input = 2;
+ // n_cell and n_output have the same size when there is no projection.
+ const int n_cell = 4;
+ const int n_output = 4;
+ const int sequence_length = 3;
- lstm.SetRecurrentToCellWeights(
- {0.54066205, -0.32668582, -0.43562764, -0.56094903, 0.42957711,
- 0.01841056, -0.32764608, -0.33027974, -0.10826075, 0.20675004,
- 0.19069612, -0.03026325, -0.54532051, 0.33003211, 0.44901288,
- 0.21193194});
+ HybridUnidirectionalLSTMOpModel lstm(
+ n_batch, n_input, n_cell, n_output, sequence_length,
+ /*use_cifg=*/true, /*use_peephole=*/true,
+ /*use_projection_weights=*/false,
+ /*use_projection_bias=*/false,
+ /*cell_clip=*/0.0, /*proj_clip=*/0.0,
+ {
+ {sequence_length, n_batch, n_input}, // input tensor
- lstm.SetRecurrentToForgetWeights(
- {-0.13832897, -0.0515101, -0.2359007, -0.16661474, -0.14340827,
- 0.36986142, 0.23414481, 0.55899, 0.10798943, -0.41174671, 0.17751795,
- -0.34484994, -0.35874045, -0.11352962, 0.27268326, 0.54058349});
+ {0, 0}, // input_to_input_weight tensor
+ {n_cell, n_input}, // input_to_forget_weight tensor
+ {n_cell, n_input}, // input_to_cell_weight tensor
+ {n_cell, n_input}, // input_to_output_weight tensor
- lstm.SetRecurrentToOutputWeights(
- {0.41613156, 0.42610586, -0.16495961, -0.5663873, 0.30579174, -0.05115908,
- -0.33941799, 0.23364776, 0.11178309, 0.09481031, -0.26424935, 0.46261835,
- 0.50248802, 0.26114327, -0.43736315, 0.33149987});
+ {0, 0}, // recurrent_to_input_weight tensor
+ {n_cell, n_output}, // recurrent_to_forget_weight tensor
+ {n_cell, n_output}, // recurrent_to_cell_weight tensor
+ {n_cell, n_output}, // recurrent_to_output_weight tensor
- lstm.SetCellToForgetWeights(
- {0.47485286, -0.51955009, -0.24458408, 0.31544167});
- lstm.SetCellToOutputWeights(
- {-0.17135078, 0.82760304, 0.85573703, -0.77109635});
+ {0}, // cell_to_input_weight tensor
+ {n_cell}, // cell_to_forget_weight tensor
+ {n_cell}, // cell_to_output_weight tensor
- static float lstm_input[] = {2., 3., 3., 4., 1., 1.};
- static float lstm_golden_output[] = {-0.36444446, -0.00352185, 0.12886585,
- -0.05163646, -0.42312205, -0.01218222,
- 0.24201041, -0.08124574, -0.358325,
- -0.04621704, 0.21641694, -0.06471302};
+ {0}, // input_gate_bias tensor
+ {n_cell}, // forget_gate_bias tensor
+ {n_cell}, // cell_bias tensor
+ {n_cell}, // output_gate_bias tensor
+
+ {0, 0}, // projection_weight tensor
+ {0}, // projection_bias tensor
+
+ {n_batch, n_output}, // activation_state tensor
+ {n_batch, n_cell}, // cell_state tensor
+ });
- // Resetting cell_state and output_state
- lstm.ResetCellState();
- lstm.ResetOutputState();
+ lstm.SetInputToCellWeights(input_to_cell_weights_);
+ lstm.SetInputToForgetWeights(input_to_forget_weights_);
+ lstm.SetInputToOutputWeights(input_to_output_weights_);
- float* batch0_start = lstm_input;
- float* batch0_end = batch0_start + lstm.num_inputs() * lstm.sequence_length();
+ lstm.SetCellBias(cell_gate_bias_);
+ lstm.SetForgetGateBias(forget_gate_bias_);
+ lstm.SetOutputGateBias(output_gate_bias_);
- lstm.SetInput(0, batch0_start, batch0_end);
+ lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
+ lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
+ lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
- lstm.Invoke();
+ lstm.SetCellToForgetWeights(cell_to_forget_weights_);
+ lstm.SetCellToOutputWeights(cell_to_output_weights_);
- float* golden_start = lstm_golden_output;
- float* golden_end =
- golden_start + lstm.num_outputs() * lstm.sequence_length();
- std::vector<float> expected;
- expected.insert(expected.end(), golden_start, golden_end);
- EXPECT_THAT(lstm.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
+ VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.03573);
}
-TEST(LSTMOpTest, BlackBoxTestWithPeepholeWithProjectionNoClipping) {
+class NoCifgPeepholeProjectionClippingLstmTest : public BaseLstmTest {
+ void SetUp() override {
+ input_to_input_weights_ = {
+ 0.021393683, 0.06124551, 0.046905167, -0.014657677, -0.03149463,
+ 0.09171803, 0.14647801, 0.10797193, -0.0057968358, 0.0019193048,
+ -0.2726754, 0.10154029, -0.018539885, 0.080349885, -0.10262385,
+ -0.022599787, -0.09121155, -0.008675967, -0.045206103, -0.0821282,
+ -0.008045952, 0.015478081, 0.055217247, 0.038719587, 0.044153627,
+ -0.06453243, 0.05031825, -0.046935108, -0.008164439, 0.014574226,
+ -0.1671009, -0.15519552, -0.16819797, -0.13971269, -0.11953059,
+ 0.25005487, -0.22790983, 0.009855087, -0.028140958, -0.11200698,
+ 0.11295408, -0.0035217577, 0.054485075, 0.05184695, 0.064711206,
+ 0.10989193, 0.11674786, 0.03490607, 0.07727357, 0.11390585,
+ -0.1863375, -0.1034451, -0.13945189, -0.049401227, -0.18767063,
+ 0.042483903, 0.14233552, 0.13832581, 0.18350165, 0.14545603,
+ -0.028545704, 0.024939531, 0.050929718, 0.0076203286, -0.0029723682,
+ -0.042484224, -0.11827596, -0.09171104, -0.10808628, -0.16327988,
+ -0.2273378, -0.0993647, -0.017155107, 0.0023917493, 0.049272764,
+ 0.0038534778, 0.054764505, 0.089753784, 0.06947234, 0.08014476,
+ -0.04544234, -0.0497073, -0.07135631, -0.048929106, -0.004042012,
+ -0.009284026, 0.018042054, 0.0036860977, -0.07427302, -0.11434604,
+ -0.018995456, 0.031487543, 0.012834908, 0.019977754, 0.044256654,
+ -0.39292613, -0.18519334, -0.11651281, -0.06809892, 0.011373677};
+
+ input_to_forget_weights_ = {
+ -0.0018401089, -0.004852237, 0.03698424, 0.014181704,
+ 0.028273236, -0.016726194, -0.05249759, -0.10204261,
+ 0.00861066, -0.040979505, -0.009899187, 0.01923892,
+ -0.028177269, -0.08535103, -0.14585495, 0.10662567,
+ -0.01909731, -0.017883534, -0.0047269356, -0.045103323,
+ 0.0030784295, 0.076784775, 0.07463696, 0.094531395,
+ 0.0814421, -0.12257899, -0.033945758, -0.031303465,
+ 0.045630626, 0.06843887, -0.13492945, -0.012480007,
+ -0.0811829, -0.07224499, -0.09628791, 0.045100946,
+ 0.0012300825, 0.013964662, 0.099372394, 0.02543059,
+ 0.06958324, 0.034257296, 0.0482646, 0.06267997,
+ 0.052625068, 0.12784666, 0.07077897, 0.025725935,
+ 0.04165009, 0.07241905, 0.018668644, -0.037377294,
+ -0.06277783, -0.08833636, -0.040120605, -0.011405586,
+ -0.007808335, -0.010301386, -0.005102167, 0.027717464,
+ 0.05483423, 0.11449111, 0.11289652, 0.10939839,
+ 0.13396506, -0.08402166, -0.01901462, -0.044678304,
+ -0.07720565, 0.014350063, -0.11757958, -0.0652038,
+ -0.08185733, -0.076754324, -0.092614375, 0.10405491,
+ 0.052960336, 0.035755895, 0.035839386, -0.012540553,
+ 0.036881298, 0.02913376, 0.03420159, 0.05448447,
+ -0.054523353, 0.02582715, 0.02327355, -0.011857179,
+ -0.0011980024, -0.034641717, -0.026125094, -0.17582615,
+ -0.15923657, -0.27486774, -0.0006143371, 0.0001771948,
+ -8.470171e-05, 0.02651807, 0.045790765, 0.06956496};
+
+ input_to_cell_weights_ = {
+ -0.04580283, -0.09549462, -0.032418985, -0.06454633,
+ -0.043528453, 0.043018587, -0.049152344, -0.12418144,
+ -0.078985475, -0.07596889, 0.019484362, -0.11434962,
+ -0.0074034138, -0.06314844, -0.092981495, 0.0062155537,
+ -0.025034338, -0.0028890965, 0.048929527, 0.06235075,
+ 0.10665918, -0.032036792, -0.08505916, -0.10843358,
+ -0.13002433, -0.036816437, -0.02130134, -0.016518239,
+ 0.0047691227, -0.0025825808, 0.066017866, 0.029991534,
+ -0.10652836, -0.1037554, -0.13056071, -0.03266643,
+ -0.033702414, -0.006473424, -0.04611692, 0.014419339,
+ -0.025174323, 0.0396852, 0.081777506, 0.06157468,
+ 0.10210095, -0.009658194, 0.046511717, 0.03603906,
+ 0.0069369148, 0.015960095, -0.06507666, 0.09551598,
+ 0.053568836, 0.06408714, 0.12835667, -0.008714329,
+ -0.20211966, -0.12093674, 0.029450472, 0.2849013,
+ -0.029227901, 0.1164364, -0.08560263, 0.09941786,
+ -0.036999565, -0.028842626, -0.0033637602, -0.017012902,
+ -0.09720865, -0.11193351, -0.029155117, -0.017936034,
+ -0.009768936, -0.04223324, -0.036159635, 0.06505112,
+ -0.021742892, -0.023377212, -0.07221364, -0.06430552,
+ 0.05453865, 0.091149814, 0.06387331, 0.007518393,
+ 0.055960953, 0.069779344, 0.046411168, 0.10509911,
+ 0.07463894, 0.0075130584, 0.012850982, 0.04555431,
+ 0.056955688, 0.06555285, 0.050801456, -0.009862683,
+ 0.00826772, -0.026555609, -0.0073611983, -0.0014897042};
+
+ input_to_output_weights_ = {
+ -0.0998932, -0.07201956, -0.052803773, -0.15629593, -0.15001918,
+ -0.07650751, 0.02359855, -0.075155355, -0.08037709, -0.15093534,
+ 0.029517552, -0.04751393, 0.010350531, -0.02664851, -0.016839722,
+ -0.023121163, 0.0077019283, 0.012851257, -0.05040649, -0.0129761,
+ -0.021737747, -0.038305793, -0.06870586, -0.01481247, -0.001285394,
+ 0.10124236, 0.083122835, 0.053313006, -0.062235646, -0.075637154,
+ -0.027833903, 0.029774971, 0.1130802, 0.09218906, 0.09506135,
+ -0.086665764, -0.037162706, -0.038880914, -0.035832845, -0.014481564,
+ -0.09825003, -0.12048569, -0.097665586, -0.05287633, -0.0964047,
+ -0.11366429, 0.035777505, 0.13568819, 0.052451383, 0.050649304,
+ 0.05798951, -0.021852335, -0.099848844, 0.014740475, -0.078897946,
+ 0.04974699, 0.014160473, 0.06973932, 0.04964942, 0.033364646,
+ 0.08190124, 0.025535367, 0.050893165, 0.048514254, 0.06945813,
+ -0.078907564, -0.06707616, -0.11844508, -0.09986688, -0.07509403,
+ 0.06263226, 0.14925587, 0.20188436, 0.12098451, 0.14639415,
+ 0.0015017595, -0.014267382, -0.03417257, 0.012711468, 0.0028300495,
+ -0.024758482, -0.05098548, -0.0821182, 0.014225672, 0.021544158,
+ 0.08949725, 0.07505268, -0.0020780868, 0.04908258, 0.06476295,
+ -0.022907063, 0.027562456, 0.040185735, 0.019567577, -0.015598739,
+ -0.049097303, -0.017121866, -0.083368234, -0.02332002, -0.0840956};
+
+ input_gate_bias_ = {0.02234832, 0.14757581, 0.18176508, 0.10380666,
+ 0.053110216, -0.06928846, -0.13942584, -0.11816189,
+ 0.19483899, 0.03652339, -0.10250295, 0.036714908,
+ -0.18426876, 0.036065217, 0.21810818, 0.02383196,
+ -0.043370757, 0.08690144, -0.04444982, 0.00030581196};
+
+ forget_gate_bias_ = {0.035185695, -0.042891346, -0.03032477, 0.23027696,
+ 0.11098921, 0.15378423, 0.09263801, 0.09790885,
+ 0.09508917, 0.061199076, 0.07665568, -0.015443159,
+ -0.03499149, 0.046190713, 0.08895977, 0.10899629,
+ 0.40694186, 0.06030037, 0.012413437, -0.06108739};
+
+ cell_gate_bias_ = {-0.024379363, 0.0055531194, 0.23377132, 0.033463873,
+ -0.1483596, -0.10639995, -0.091433935, 0.058573797,
+ -0.06809782, -0.07889636, -0.043246906, -0.09829136,
+ -0.4279842, 0.034901652, 0.18797937, 0.0075234566,
+ 0.016178843, 0.1749513, 0.13975595, 0.92058027};
+
+ output_gate_bias_ = {0.046159424, -0.0012809046, 0.03563469, 0.12648113,
+ 0.027195795, 0.35373217, -0.018957434, 0.008907322,
+ -0.0762701, 0.12018895, 0.04216877, 0.0022856654,
+ 0.040952638, 0.3147856, 0.08225149, -0.057416286,
+ -0.14995944, -0.008040261, 0.13208859, 0.029760877};
+
+ recurrent_to_input_weights_ = {
+ -0.001374326, -0.078856036, 0.10672688, 0.029162422,
+ -0.11585556, 0.02557986, -0.13446963, -0.035785314,
+ -0.01244275, 0.025961924, -0.02337298, -0.044228926,
+ -0.055839065, -0.046598054, -0.010546039, -0.06900766,
+ 0.027239809, 0.022582639, -0.013296484, -0.05459212,
+ 0.08981, -0.045407712, 0.08682226, -0.06867011,
+ -0.14390695, -0.02916037, 0.000996957, 0.091420636,
+ 0.14283475, -0.07390571, -0.06402044, 0.062524505,
+ -0.093129106, 0.04860203, -0.08364217, -0.08119002,
+ 0.009352075, 0.22920375, 0.0016303885, 0.11583097,
+ -0.13732095, 0.012405723, -0.07551853, 0.06343048,
+ 0.12162708, -0.031923793, -0.014335606, 0.01790974,
+ -0.10650317, -0.0724401, 0.08554849, -0.05727212,
+ 0.06556731, -0.042729504, -0.043227166, 0.011683251,
+ -0.013082158, -0.029302018, -0.010899579, -0.062036745,
+ -0.022509435, -0.00964907, -0.01567329, 0.04260106,
+ -0.07787477, -0.11576462, 0.017356863, 0.048673786,
+ -0.017577527, -0.05527947, -0.082487635, -0.040137455,
+ -0.10820036, -0.04666372, 0.022746278, -0.07851417,
+ 0.01068115, 0.032956902, 0.022433773, 0.0026891115,
+ 0.08944216, -0.0685835, 0.010513544, 0.07228705,
+ 0.02032331, -0.059686817, -0.0005566496, -0.086984694,
+ 0.040414046, -0.1380399, 0.094208956, -0.05722982,
+ 0.012092817, -0.04989123, -0.086576, -0.003399834,
+ -0.04696032, -0.045747425, 0.10091314, 0.048676282,
+ -0.029037097, 0.031399418, -0.0040285117, 0.047237843,
+ 0.09504992, 0.041799378, -0.049185462, -0.031518843,
+ -0.10516937, 0.026374253, 0.10058866, -0.0033195973,
+ -0.041975245, 0.0073591834, 0.0033782164, -0.004325073,
+ -0.10167381, 0.042500053, -0.01447153, 0.06464186,
+ -0.017142897, 0.03312627, 0.009205989, 0.024138335,
+ -0.011337001, 0.035530265, -0.010912711, 0.0706555,
+ -0.005894094, 0.051841937, -0.1401738, -0.02351249,
+ 0.0365468, 0.07590991, 0.08838724, 0.021681072,
+ -0.10086113, 0.019608743, -0.06195883, 0.077335775,
+ 0.023646897, -0.095322326, 0.02233014, 0.09756986,
+ -0.048691444, -0.009579111, 0.07595467, 0.11480546,
+ -0.09801813, 0.019894179, 0.08502348, 0.004032281,
+ 0.037211012, 0.068537936, -0.048005626, -0.091520436,
+ -0.028379958, -0.01556313, 0.06554592, -0.045599163,
+ -0.01672207, -0.020169014, -0.011877351, -0.20212261,
+ 0.010889619, 0.0047078193, 0.038385306, 0.08540671,
+ -0.017140968, -0.0035865551, 0.016678626, 0.005633034,
+ 0.015963363, 0.00871737, 0.060130805, 0.028611384,
+ 0.10109069, -0.015060172, -0.07894427, 0.06401885,
+ 0.011584063, -0.024466386, 0.0047652307, -0.09041358,
+ 0.030737216, -0.0046374933, 0.14215417, -0.11823516,
+ 0.019899689, 0.006106124, -0.027092824, 0.0786356,
+ 0.05052217, -0.058925, -0.011402121, -0.024987547,
+ -0.0013661642, -0.06832946, -0.015667673, -0.1083353,
+ -0.00096863037, -0.06988685, -0.053350925, -0.027275559,
+ -0.033664223, -0.07978348, -0.025200296, -0.017207067,
+ -0.058403496, -0.055697463, 0.005798788, 0.12965427,
+ -0.062582195, 0.0013350133, -0.10482091, 0.0379771,
+ 0.072521195, -0.0029455067, -0.13797039, -0.03628521,
+ 0.013806405, -0.017858358, -0.01008298, -0.07700066,
+ -0.017081132, 0.019358726, 0.0027079724, 0.004635139,
+ 0.062634714, -0.02338735, -0.039547626, -0.02050681,
+ 0.03385117, -0.083611414, 0.002862572, -0.09421313,
+ 0.058618143, -0.08598433, 0.00972939, 0.023867095,
+ -0.053934585, -0.023203006, 0.07452513, -0.048767887,
+ -0.07314807, -0.056307215, -0.10433547, -0.06440842,
+ 0.04328182, 0.04389765, -0.020006588, -0.09076438,
+ -0.11652589, -0.021705797, 0.03345259, -0.010329105,
+ -0.025767034, 0.013057034, -0.07316461, -0.10145612,
+ 0.06358255, 0.18531723, 0.07759293, 0.12006465,
+ 0.1305557, 0.058638252, -0.03393652, 0.09622831,
+ -0.16253184, -2.4580743e-06, 0.079869635, -0.070196845,
+ -0.005644518, 0.06857898, -0.12598175, -0.035084512,
+ 0.03156317, -0.12794146, -0.031963028, 0.04692781,
+ 0.030070418, 0.0071660685, -0.095516115, -0.004643372,
+ 0.040170413, -0.062104587, -0.0037324072, 0.0554317,
+ 0.08184801, -0.019164372, 0.06791302, 0.034257166,
+ -0.10307039, 0.021943003, 0.046745934, 0.0790918,
+ -0.0265588, -0.007824208, 0.042546265, -0.00977924,
+ -0.0002440307, -0.017384544, -0.017990116, 0.12252321,
+ -0.014512694, -0.08251313, 0.08861942, 0.13589665,
+ 0.026351685, 0.012641483, 0.07466548, 0.044301085,
+ -0.045414884, -0.051112458, 0.03444247, -0.08502782,
+ -0.04106223, -0.028126027, 0.028473156, 0.10467447};
+
+ recurrent_to_cell_weights_ = {
+ -0.037322544, 0.018592842, 0.0056175636, -0.06253426,
+ 0.055647098, -0.05713207, -0.05626563, 0.005559383,
+ 0.03375411, -0.025757805, -0.088049285, 0.06017052,
+ -0.06570978, 0.007384076, 0.035123326, -0.07920549,
+ 0.053676967, 0.044480428, -0.07663568, 0.0071805613,
+ 0.08089997, 0.05143358, 0.038261272, 0.03339287,
+ -0.027673481, 0.044746667, 0.028349208, 0.020090483,
+ -0.019443132, -0.030755889, -0.0040000007, 0.04465846,
+ -0.021585021, 0.0031670958, 0.0053199246, -0.056117613,
+ -0.10893326, 0.076739706, -0.08509834, -0.027997585,
+ 0.037871376, 0.01449768, -0.09002357, -0.06111149,
+ -0.046195522, 0.0422062, -0.005683705, -0.1253618,
+ -0.012925729, -0.04890792, 0.06985068, 0.037654128,
+ 0.03398274, -0.004781977, 0.007032333, -0.031787455,
+ 0.010868644, -0.031489216, 0.09525667, 0.013939797,
+ 0.0058680447, 0.0167067, 0.02668468, -0.04797466,
+ -0.048885044, -0.12722108, 0.035304096, 0.06554885,
+ 0.00972396, -0.039238118, -0.05159735, -0.11329045,
+ 0.1613692, -0.03750952, 0.06529313, -0.071974665,
+ -0.11769596, 0.015524369, -0.0013754242, -0.12446318,
+ 0.02786344, -0.014179351, 0.005264273, 0.14376344,
+ 0.015983658, 0.03406988, -0.06939408, 0.040699873,
+ 0.02111075, 0.09669095, 0.041345075, -0.08316494,
+ -0.07684199, -0.045768797, 0.032298047, -0.041805092,
+ 0.0119405, 0.0061010392, 0.12652606, 0.0064572375,
+ -0.024950314, 0.11574242, 0.04508852, -0.04335324,
+ 0.06760663, -0.027437469, 0.07216407, 0.06977076,
+ -0.05438599, 0.034033038, -0.028602652, 0.05346137,
+ 0.043184172, -0.037189785, 0.10420091, 0.00882477,
+ -0.054019816, -0.074273005, -0.030617684, -0.0028467078,
+ 0.024302477, -0.0038869337, 0.005332455, 0.0013399826,
+ 0.04361412, -0.007001822, 0.09631092, -0.06702025,
+ -0.042049985, -0.035070654, -0.04103342, -0.10273396,
+ 0.0544271, 0.037184782, -0.13150354, -0.0058036847,
+ -0.008264958, 0.042035464, 0.05891794, 0.029673764,
+ 0.0063542654, 0.044788733, 0.054816857, 0.062257513,
+ -0.00093483756, 0.048938446, -0.004952862, -0.007730018,
+ -0.04043371, -0.017094059, 0.07229206, -0.023670016,
+ -0.052195564, -0.025616996, -0.01520939, 0.045104615,
+ -0.007376126, 0.003533447, 0.006570588, 0.056037236,
+ 0.12436656, 0.051817212, 0.028532185, -0.08686856,
+ 0.11868599, 0.07663395, -0.07323171, 0.03463402,
+ -0.050708205, -0.04458982, -0.11590894, 0.021273347,
+ 0.1251325, -0.15313013, -0.12224372, 0.17228661,
+ 0.023029093, 0.086124025, 0.006445803, -0.03496501,
+ 0.028332196, 0.04449512, -0.042436164, -0.026587414,
+ -0.006041347, -0.09292539, -0.05678812, 0.03897832,
+ 0.09465633, 0.008115513, -0.02171956, 0.08304309,
+ 0.071401566, 0.019622514, 0.032163795, -0.004167056,
+ 0.02295182, 0.030739572, 0.056506045, 0.004612461,
+ 0.06524936, 0.059999723, 0.046395954, -0.0045512207,
+ -0.1335546, -0.030136576, 0.11584653, -0.014678886,
+ 0.0020118146, -0.09688814, -0.0790206, 0.039770417,
+ -0.0329582, 0.07922767, 0.029322514, 0.026405897,
+ 0.04207835, -0.07073373, 0.063781224, 0.0859677,
+ -0.10925287, -0.07011058, 0.048005477, 0.03438226,
+ -0.09606514, -0.006669445, -0.043381985, 0.04240257,
+ -0.06955775, -0.06769346, 0.043903265, -0.026784198,
+ -0.017840602, 0.024307009, -0.040079936, -0.019946516,
+ 0.045318738, -0.12233574, 0.026170589, 0.0074471775,
+ 0.15978073, 0.10185836, 0.10298046, -0.015476589,
+ -0.039390966, -0.072174534, 0.0739445, -0.1211869,
+ -0.0347889, -0.07943156, 0.014809798, -0.12412325,
+ -0.0030663363, 0.039695457, 0.0647603, -0.08291318,
+ -0.018529687, -0.004423833, 0.0037507233, 0.084633216,
+ -0.01514876, -0.056505352, -0.012800942, -0.06994386,
+ 0.012962922, -0.031234352, 0.07029052, 0.016418684,
+ 0.03618972, 0.055686004, -0.08663945, -0.017404709,
+ -0.054761406, 0.029065743, 0.052404847, 0.020238016,
+ 0.0048197987, -0.0214882, 0.07078733, 0.013016777,
+ 0.06262858, 0.009184685, 0.020785125, -0.043904778,
+ -0.0270329, -0.03299152, -0.060088247, -0.015162964,
+ -0.001828936, 0.12642565, -0.056757294, 0.013586685,
+ 0.09232601, -0.035886683, 0.06000002, 0.05229691,
+ -0.052580316, -0.082029596, -0.010794592, 0.012947712,
+ -0.036429964, -0.085508935, -0.13127148, -0.017744139,
+ 0.031502828, 0.036232427, -0.031581745, 0.023051167,
+ -0.05325106, -0.03421577, 0.028793324, -0.034633752,
+ -0.009881397, -0.043551125, -0.018609839, 0.0019097115,
+ -0.008799762, 0.056595087, 0.0022273948, 0.055752404};
+
+ recurrent_to_forget_weights_ = {
+ -0.057784554, -0.026057621, -0.068447545, -0.022581743,
+ 0.14811787, 0.10826372, 0.09471067, 0.03987225,
+ -0.0039523416, 0.00030638507, 0.053185795, 0.10572994,
+ 0.08414449, -0.022036452, -0.00066928595, -0.09203576,
+ 0.032950465, -0.10985798, -0.023809856, 0.0021431844,
+ -0.02196096, -0.00326074, 0.00058621005, -0.074678116,
+ -0.06193199, 0.055729095, 0.03736828, 0.020123724,
+ 0.061878487, -0.04729229, 0.034919553, -0.07585433,
+ -0.04421272, -0.044019096, 0.085488975, 0.04058006,
+ -0.06890133, -0.030951202, -0.024628663, -0.07672815,
+ 0.034293607, 0.08556707, -0.05293577, -0.033561368,
+ -0.04899627, 0.0241671, 0.015736353, -0.095442444,
+ -0.029564252, 0.016493602, -0.035026584, 0.022337519,
+ -0.026871363, 0.004780428, 0.0077918363, -0.03601621,
+ 0.016435321, -0.03263031, -0.09543275, -0.047392778,
+ 0.013454138, 0.028934088, 0.01685226, -0.086110644,
+ -0.046250615, -0.01847454, 0.047608484, 0.07339695,
+ 0.034546845, -0.04881143, 0.009128804, -0.08802852,
+ 0.03761666, 0.008096139, -0.014454086, 0.014361001,
+ -0.023502491, -0.0011840804, -0.07607001, 0.001856849,
+ -0.06509276, -0.006021153, -0.08570962, -0.1451793,
+ 0.060212336, 0.055259194, 0.06974018, 0.049454916,
+ -0.027794661, -0.08077226, -0.016179763, 0.1169753,
+ 0.17213494, -0.0056326236, -0.053934924, -0.0124349,
+ -0.11520337, 0.05409887, 0.088759385, 0.0019655675,
+ 0.0042065294, 0.03881498, 0.019844765, 0.041858196,
+ -0.05695512, 0.047233116, 0.038937137, -0.06542224,
+ 0.014429736, -0.09719407, 0.13908425, -0.05379757,
+ 0.012321099, 0.082840554, -0.029899208, 0.044217527,
+ 0.059855383, 0.07711018, -0.045319796, 0.0948846,
+ -0.011724666, -0.0033288454, -0.033542685, -0.04764985,
+ -0.13873616, 0.040668588, 0.034832682, -0.015319203,
+ -0.018715994, 0.046002675, 0.0599172, -0.043107376,
+ 0.0294216, -0.002314414, -0.022424703, 0.0030315618,
+ 0.0014641669, 0.0029166266, -0.11878115, 0.013738511,
+ 0.12375372, -0.0006038222, 0.029104086, 0.087442465,
+ 0.052958444, 0.07558703, 0.04817258, 0.044462286,
+ -0.015213451, -0.08783778, -0.0561384, -0.003008196,
+ 0.047060397, -0.002058388, 0.03429439, -0.018839769,
+ 0.024734668, 0.024614193, -0.042046934, 0.09597743,
+ -0.0043254104, 0.04320769, 0.0064070094, -0.0019131786,
+ -0.02558259, -0.022822596, -0.023273505, -0.02464396,
+ -0.10991725, -0.006240552, 0.0074488563, 0.024044557,
+ 0.04383914, -0.046476185, 0.028658995, 0.060410924,
+ 0.050786525, 0.009452605, -0.0073054377, -0.024810238,
+ 0.0052906186, 0.0066939713, -0.0020913032, 0.014515517,
+ 0.015898481, 0.021362653, -0.030262267, 0.016587038,
+ -0.011442813, 0.041154444, -0.007631438, -0.03423484,
+ -0.010977775, 0.036152758, 0.0066366293, 0.11915515,
+ 0.02318443, -0.041350313, 0.021485701, -0.10906167,
+ -0.028218046, -0.00954771, 0.020531068, -0.11995105,
+ -0.03672871, 0.024019798, 0.014255957, -0.05221243,
+ -0.00661567, -0.04630967, 0.033188973, 0.10107534,
+ -0.014027541, 0.030796422, -0.10270911, -0.035999842,
+ 0.15443139, 0.07684145, 0.036571592, -0.035900835,
+ -0.0034699554, 0.06209149, 0.015920248, -0.031122351,
+ -0.03858649, 0.01849943, 0.13872518, 0.01503974,
+ 0.069941424, -0.06948533, -0.0088794185, 0.061282158,
+ -0.047401894, 0.03100163, -0.041533746, -0.10430945,
+ 0.044574402, -0.01425562, -0.024290353, 0.034563623,
+ 0.05866852, 0.023947537, -0.09445152, 0.035450947,
+ 0.02247216, -0.0042998926, 0.061146557, -0.10250651,
+ 0.020881841, -0.06747029, 0.10062043, -0.0023941975,
+ 0.03532124, -0.016341697, 0.09685456, -0.016764693,
+ 0.051808182, 0.05875331, -0.04536488, 0.001626336,
+ -0.028892258, -0.01048663, -0.009793449, -0.017093895,
+ 0.010987891, 0.02357273, -0.00010856845, 0.0099760275,
+ -0.001845119, -0.03551521, 0.0018358806, 0.05763657,
+ -0.01769146, 0.040995963, 0.02235177, -0.060430344,
+ 0.11475477, -0.023854522, 0.10071741, 0.0686208,
+ -0.014250481, 0.034261297, 0.047418304, 0.08562733,
+ -0.030519066, 0.0060542435, 0.014653856, -0.038836084,
+ 0.04096551, 0.032249358, -0.08355519, -0.026823482,
+ 0.056386515, -0.010401743, -0.028396193, 0.08507674,
+ 0.014410365, 0.020995233, 0.17040324, 0.11511526,
+ 0.02459721, 0.0066619175, 0.025853224, -0.023133837,
+ -0.081302024, 0.017264642, -0.009585969, 0.09491168,
+ -0.051313367, 0.054532815, -0.014298593, 0.10657464,
+ 0.007076659, 0.10964551, 0.0409152, 0.008275321,
+ -0.07283536, 0.07937492, 0.04192024, -0.1075027};
+
+ recurrent_to_output_weights_ = {
+ 0.025825322, -0.05813119, 0.09495884, -0.045984812,
+ -0.01255415, -0.0026479573, -0.08196161, -0.054914974,
+ -0.0046604523, -0.029587349, -0.044576716, -0.07480124,
+ -0.082868785, 0.023254942, 0.027502948, -0.0039728214,
+ -0.08683098, -0.08116779, -0.014675607, -0.037924774,
+ -0.023314456, -0.007401714, -0.09255757, 0.029460307,
+ -0.08829125, -0.005139627, -0.08989442, -0.0555066,
+ 0.13596267, -0.025062224, -0.048351806, -0.03850004,
+ 0.07266485, -0.022414139, 0.05940088, 0.075114764,
+ 0.09597592, -0.010211725, -0.0049794707, -0.011523867,
+ -0.025980417, 0.072999895, 0.11091378, -0.081685916,
+ 0.014416728, 0.043229222, 0.034178585, -0.07530371,
+ 0.035837382, -0.085607, -0.007721233, -0.03287832,
+ -0.043848954, -0.06404588, -0.06632928, -0.073643476,
+ 0.008214239, -0.045984086, 0.039764922, 0.03474462,
+ 0.060612556, -0.080590084, 0.049127717, 0.04151091,
+ -0.030063879, 0.008801774, -0.023021035, -0.019558564,
+ 0.05158114, -0.010947698, -0.011825728, 0.0075720972,
+ 0.0699727, -0.0039981045, 0.069350146, 0.08799282,
+ 0.016156472, 0.035502106, 0.11695009, 0.006217345,
+ 0.13392477, -0.037875112, 0.025745004, 0.08940699,
+ -0.00924166, 0.0046702605, -0.036598757, -0.08811812,
+ 0.10522024, -0.032441203, 0.008176899, -0.04454919,
+ 0.07058152, 0.0067963637, 0.039206743, 0.03259838,
+ 0.03725492, -0.09515802, 0.013326398, -0.052055415,
+ -0.025676316, 0.03198509, -0.015951829, -0.058556724,
+ 0.036879618, 0.043357447, 0.028362012, -0.05908629,
+ 0.0059240665, -0.04995891, -0.019187413, 0.0276265,
+ -0.01628143, 0.0025863599, 0.08800015, 0.035250366,
+ -0.022165963, -0.07328642, -0.009415526, -0.07455109,
+ 0.11690406, 0.0363299, 0.07411125, 0.042103454,
+ -0.009660886, 0.019076364, 0.018299393, -0.046004917,
+ 0.08891175, 0.0431396, -0.026327137, -0.051502608,
+ 0.08979574, -0.051670972, 0.04940282, -0.07491107,
+ -0.021240504, 0.022596184, -0.034280192, 0.060163025,
+ -0.058211457, -0.051837247, -0.01349775, -0.04639988,
+ -0.035936575, -0.011681591, 0.064818054, 0.0073146066,
+ -0.021745546, -0.043124277, -0.06471268, -0.07053354,
+ -0.029321948, -0.05330136, 0.016933719, -0.053782392,
+ 0.13747959, -0.1361751, -0.11569455, 0.0033329215,
+ 0.05693899, -0.053219706, 0.063698, 0.07977434,
+ -0.07924483, 0.06936997, 0.0034815092, -0.007305279,
+ -0.037325785, -0.07251102, -0.033633437, -0.08677009,
+ 0.091591336, -0.14165086, 0.021752775, 0.019683983,
+ 0.0011612234, -0.058154266, 0.049996935, 0.0288841,
+ -0.0024567875, -0.14345716, 0.010955264, -0.10234828,
+ 0.1183656, -0.0010731248, -0.023590032, -0.072285876,
+ -0.0724771, -0.026382286, -0.0014920527, 0.042667855,
+ 0.0018776858, 0.02986552, 0.009814309, 0.0733756,
+ 0.12289186, 0.018043943, -0.0458958, 0.049412545,
+ 0.033632483, 0.05495232, 0.036686596, -0.013781798,
+ -0.010036754, 0.02576849, -0.08307328, 0.010112348,
+ 0.042521734, -0.05869831, -0.071689695, 0.03876447,
+ -0.13275425, -0.0352966, -0.023077697, 0.10285965,
+ 0.084736146, 0.15568255, -0.00040734606, 0.027835453,
+ -0.10292561, -0.032401145, 0.10053256, -0.026142767,
+ -0.08271222, -0.0030240538, -0.016368777, 0.1070414,
+ 0.042672627, 0.013456989, -0.0437609, -0.022309763,
+ 0.11576483, 0.04108048, 0.061026827, -0.0190714,
+ -0.0869359, 0.037901703, 0.0610107, 0.07202949,
+ 0.01675338, 0.086139716, -0.08795751, -0.014898893,
+ -0.023771819, -0.01965048, 0.007955471, -0.043740474,
+ 0.03346837, -0.10549954, 0.090567775, 0.042013682,
+ -0.03176985, 0.12569028, -0.02421228, -0.029526481,
+ 0.023851605, 0.031539805, 0.05292009, -0.02344001,
+ -0.07811758, -0.08834428, 0.10094801, 0.16594367,
+ -0.06861939, -0.021256343, -0.041093912, -0.06669611,
+ 0.035498552, 0.021757556, -0.09302526, -0.015403468,
+ -0.06614931, -0.051798206, -0.013874718, 0.03630673,
+ 0.010412845, -0.08077351, 0.046185967, 0.0035662893,
+ 0.03541868, -0.094149634, -0.034814864, 0.003128424,
+ -0.020674974, -0.03944324, -0.008110165, -0.11113267,
+ 0.08484226, 0.043586485, 0.040582247, 0.0968012,
+ -0.065249965, -0.028036479, 0.0050708856, 0.0017462453,
+ 0.0326779, 0.041296225, 0.09164146, -0.047743853,
+ -0.015952192, -0.034451712, 0.084197424, -0.05347844,
+ -0.11768019, 0.085926116, -0.08251791, -0.045081906,
+ 0.0948852, 0.068401024, 0.024856757, 0.06978981,
+ -0.057309967, -0.012775832, -0.0032452994, 0.01977615,
+ -0.041040014, -0.024264973, 0.063464895, 0.05431621,
+ };
+
+ cell_to_input_weights_ = {
+ 0.040369894, 0.030746894, 0.24704495, 0.018586371, -0.037586458,
+ -0.15312155, -0.11812848, -0.11465643, 0.20259799, 0.11418174,
+ -0.10116027, -0.011334949, 0.12411352, -0.076769054, -0.052169047,
+ 0.21198851, -0.38871562, -0.09061183, -0.09683246, -0.21929175};
+
+ cell_to_forget_weights_ = {
+ -0.01998659, -0.15568835, -0.24248174, -0.012770197, 0.041331276,
+ -0.072311886, -0.052123554, -0.0066330447, -0.043891653, 0.036225766,
+ -0.047248036, 0.021479502, 0.033189066, 0.11952997, -0.020432774,
+ 0.64658105, -0.06650122, -0.03467612, 0.095340036, 0.23647355};
+
+ cell_to_output_weights_ = {
+ 0.08286371, -0.08261836, -0.51210177, 0.002913762, 0.17764764,
+ -0.5495371, -0.08460716, -0.24552552, 0.030037103, 0.04123544,
+ -0.11940523, 0.007358328, 0.1890978, 0.4833202, -0.34441817,
+ 0.36312827, -0.26375428, 0.1457655, -0.19724406, 0.15548733};
+
+ projection_weights_ = {
+ -0.009802181, 0.09401916, 0.0717386, -0.13895074,
+ 0.09641832, 0.060420845, 0.08539281, 0.054285463,
+ 0.061395317, 0.034448683, -0.042991187, 0.019801661,
+ -0.16840284, -0.015726732, -0.23041931, -0.024478018,
+ -0.10959692, -0.013875541, 0.18600968, -0.061274476,
+ 0.0138165, -0.08160894, -0.07661644, 0.032372914,
+ 0.16169067, 0.22465782, -0.03993472, -0.004017731,
+ 0.08633481, -0.28869787, 0.08682067, 0.17240396,
+ 0.014975425, 0.056431185, 0.031037588, 0.16702051,
+ 0.0077946745, 0.15140012, 0.29405436, 0.120285,
+ -0.188994, -0.027265169, 0.043389652, -0.022061434,
+ 0.014777949, -0.20203483, 0.094781205, 0.19100232,
+ 0.13987629, -0.036132768, -0.06426278, -0.05108664,
+ 0.13221376, 0.009441198, -0.16715929, 0.15859416,
+ -0.040437475, 0.050779544, -0.022187516, 0.012166504,
+ 0.027685808, -0.07675938, -0.0055694645, -0.09444123,
+ 0.0046453946, 0.050794356, 0.10770313, -0.20790008,
+ -0.07149004, -0.11425117, 0.008225835, -0.035802525,
+ 0.14374903, 0.15262283, 0.048710253, 0.1847461,
+ -0.007487823, 0.11000021, -0.09542012, 0.22619456,
+ -0.029149994, 0.08527916, 0.009043713, 0.0042746216,
+ 0.016261552, 0.022461696, 0.12689082, -0.043589946,
+ -0.12035478, -0.08361797, -0.050666027, -0.1248618,
+ -0.1275799, -0.071875185, 0.07377272, 0.09944291,
+ -0.18897448, -0.1593054, -0.06526116, -0.040107165,
+ -0.004618631, -0.067624845, -0.007576253, 0.10727444,
+ 0.041546922, -0.20424393, 0.06907816, 0.050412357,
+ 0.00724631, 0.039827548, 0.12449835, 0.10747581,
+ 0.13708383, 0.09134148, -0.12617786, -0.06428341,
+ 0.09956831, 0.1208086, -0.14676677, -0.0727722,
+ 0.1126304, 0.010139365, 0.015571211, -0.038128063,
+ 0.022913318, -0.042050496, 0.16842307, -0.060597885,
+ 0.10531834, -0.06411776, -0.07451711, -0.03410368,
+ -0.13393489, 0.06534304, 0.003620307, 0.04490757,
+ 0.05970546, 0.05197996, 0.02839995, 0.10434969,
+ -0.013699693, -0.028353551, -0.07260381, 0.047201227,
+ -0.024575593, -0.036445823, 0.07155557, 0.009672501,
+ -0.02328883, 0.009533515, -0.03606021, -0.07421458,
+ -0.028082801, -0.2678904, -0.13221288, 0.18419984,
+ -0.13012612, -0.014588381, -0.035059117, -0.04824723,
+ 0.07830115, -0.056184657, 0.03277091, 0.025466874,
+ 0.14494097, -0.12522776, -0.098633975, -0.10766018,
+ -0.08317623, 0.08594209, 0.07749552, 0.039474737,
+ 0.1776665, -0.07409566, -0.0477268, 0.29323658,
+ 0.10801441, 0.1154011, 0.013952499, 0.10739139,
+ 0.10708251, -0.051456142, 0.0074137426, -0.10430189,
+ 0.10034707, 0.045594677, 0.0635285, -0.0715442,
+ -0.089667566, -0.10811871, 0.00026344223, 0.08298446,
+ -0.009525053, 0.006585689, -0.24567553, -0.09450807,
+ 0.09648481, 0.026996298, -0.06419476, -0.04752702,
+ -0.11063944, -0.23441927, -0.17608605, -0.052156363,
+ 0.067035615, 0.19271925, -0.0032889997, -0.043264326,
+ 0.09663576, -0.057112187, -0.10100678, 0.0628376,
+ 0.04447668, 0.017961001, -0.10094388, -0.10190601,
+ 0.18335468, 0.10494553, -0.052095775, -0.0026118709,
+ 0.10539724, -0.04383912, -0.042349473, 0.08438151,
+ -0.1947263, 0.02251204, 0.11216432, -0.10307853,
+ 0.17351969, -0.039091777, 0.08066188, -0.00561982,
+ 0.12633002, 0.11335965, -0.0088127935, -0.019777594,
+ 0.06864014, -0.059751723, 0.016233567, -0.06894641,
+ -0.28651384, -0.004228674, 0.019708522, -0.16305895,
+ -0.07468996, -0.0855457, 0.099339016, -0.07580735,
+ -0.13775392, 0.08434318, 0.08330512, -0.12131499,
+ 0.031935584, 0.09180414, -0.08876437, -0.08049874,
+ 0.008753825, 0.03498998, 0.030215185, 0.03907079,
+ 0.089751154, 0.029194152, -0.03337423, -0.019092513,
+ 0.04331237, 0.04299654, -0.036394123, -0.12915532,
+ 0.09793732, 0.07512415, -0.11319543, -0.032502122,
+ 0.15661901, 0.07671967, -0.005491124, -0.19379048,
+ -0.218606, 0.21448623, 0.017840758, 0.1416943,
+ -0.07051762, 0.19488361, 0.02664691, -0.18104725,
+ -0.09334311, 0.15026465, -0.15493552, -0.057762887,
+ -0.11604192, -0.262013, -0.01391798, 0.012185008,
+ 0.11156489, -0.07483202, 0.06693364, -0.26151478,
+ 0.046425626, 0.036540434, -0.16435726, 0.17338543,
+ -0.21401681, -0.11385144, -0.08283257, -0.069031075,
+ 0.030635102, 0.010969227, 0.11109743, 0.010919218,
+ 0.027526086, 0.13519906, 0.01891392, -0.046839405,
+ -0.040167913, 0.017953383, -0.09700955, 0.0061885654,
+ -0.07000971, 0.026893595, -0.038844477, 0.14543656};
+
+ lstm_input_ = {
+ {// Batch0: 4 (input_sequence_size) * 5 (n_input)
+ 0.787926, 0.151646, 0.071352, 0.118426, 0.458058, // step 0
+ 0.596268, 0.998386, 0.568695, 0.864524, 0.571277, // step 1
+ 0.073204, 0.296072, 0.743333, 0.069199, 0.045348, // step 2
+ 0.867394, 0.291279, 0.013714, 0.482521, 0.626339}, // step 3
+
+ {// Batch1: 4 (input_sequence_size) * 5 (n_input)
+ 0.295743, 0.544053, 0.690064, 0.858138, 0.497181, // step 0
+ 0.642421, 0.524260, 0.134799, 0.003639, 0.162482, // step 1
+ 0.640394, 0.930399, 0.050782, 0.432485, 0.988078, // step 2
+ 0.082922, 0.563329, 0.865614, 0.333232, 0.259916} // step 3
+ };
+
+ lstm_golden_output_ = {
+ {// Batch0: 4 (input_sequence_size) * 16 (n_output)
+ -0.00396806, 0.029352, -0.00279226, 0.0159977, -0.00835576,
+ -0.0211779, 0.0283512, -0.0114597, 0.00907307, -0.0244004,
+ -0.0152191, -0.0259063, 0.00914318, 0.00415118, 0.017147,
+ 0.0134203, -0.0166936, 0.0381209, 0.000889694, 0.0143363,
+ -0.0328911, -0.0234288, 0.0333051, -0.012229, 0.0110322,
+ -0.0457725, -0.000832209, -0.0202817, 0.0327257, 0.0121308,
+ 0.0155969, 0.0312091, -0.0213783, 0.0350169, 0.000324794,
+ 0.0276012, -0.0263374, -0.0371449, 0.0446149, -0.0205474,
+ 0.0103729, -0.0576349, -0.0150052, -0.0292043, 0.0376827,
+ 0.0136115, 0.0243435, 0.0354492, -0.0189322, 0.0464512,
+ -0.00251373, 0.0225745, -0.0308346, -0.0317124, 0.0460407,
+ -0.0189395, 0.0149363, -0.0530162, -0.0150767, -0.0340193,
+ 0.0286833, 0.00824207, 0.0264887, 0.0305169},
+ {// Batch1: 4 (input_sequence_size) * 16 (n_output)
+ -0.013869, 0.0287268, -0.00334693, 0.00733398, -0.0287926,
+ -0.0186926, 0.0193662, -0.0115437, 0.00422612, -0.0345232,
+ 0.00223253, -0.00957321, 0.0210624, 0.013331, 0.0150954,
+ 0.02168, -0.0141913, 0.0322082, 0.00227024, 0.0260507,
+ -0.0188721, -0.0296489, 0.0399134, -0.0160509, 0.0116039,
+ -0.0447318, -0.0150515, -0.0277406, 0.0316596, 0.0118233,
+ 0.0214762, 0.0293641, -0.0204549, 0.0450315, -0.00117378,
+ 0.0167673, -0.0375007, -0.0238314, 0.038784, -0.0174034,
+ 0.0131743, -0.0506589, -0.0048447, -0.0240239, 0.0325789,
+ 0.00790065, 0.0220157, 0.0333314, -0.0264787, 0.0387855,
+ -0.000764675, 0.0217599, -0.037537, -0.0335206, 0.0431679,
+ -0.0211424, 0.010203, -0.062785, -0.00832363, -0.025181,
+ 0.0412031, 0.0118723, 0.0239643, 0.0394009}};
+ }
+};
+
+TEST_F(NoCifgPeepholeProjectionClippingLstmTest, LstmBlackBoxTest) {
const int n_batch = 2;
const int n_input = 5;
const int n_cell = 20;
@@ -461,8 +1310,9 @@ TEST(LSTMOpTest, BlackBoxTestWithPeepholeWithProjectionNoClipping) {
const int sequence_length = 4;
UnidirectionalLSTMOpModel lstm(
- n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false,
- /*use_peephole=*/true, /*use_projection_weights=*/true,
+ n_batch, n_input, n_cell, n_output, sequence_length,
+ /*use_cifg=*/false, /*use_peephole=*/true,
+ /*use_projection_weights=*/true,
/*use_projection_bias=*/false,
/*cell_clip=*/0.0, /*proj_clip=*/0.0,
{
@@ -489,590 +1339,99 @@ TEST(LSTMOpTest, BlackBoxTestWithPeepholeWithProjectionNoClipping) {
{n_output, n_cell}, // projection_weight tensor
{0}, // projection_bias tensor
+
+ {n_batch, n_output}, // activation_state tensor
+ {n_batch, n_cell}, // cell_state tensor
});
- lstm.SetInputToInputWeights(
- {0.021393683, 0.06124551, 0.046905167, -0.014657677, -0.03149463,
- 0.09171803, 0.14647801, 0.10797193, -0.0057968358, 0.0019193048,
- -0.2726754, 0.10154029, -0.018539885, 0.080349885, -0.10262385,
- -0.022599787, -0.09121155, -0.008675967, -0.045206103, -0.0821282,
- -0.008045952, 0.015478081, 0.055217247, 0.038719587, 0.044153627,
- -0.06453243, 0.05031825, -0.046935108, -0.008164439, 0.014574226,
- -0.1671009, -0.15519552, -0.16819797, -0.13971269, -0.11953059,
- 0.25005487, -0.22790983, 0.009855087, -0.028140958, -0.11200698,
- 0.11295408, -0.0035217577, 0.054485075, 0.05184695, 0.064711206,
- 0.10989193, 0.11674786, 0.03490607, 0.07727357, 0.11390585,
- -0.1863375, -0.1034451, -0.13945189, -0.049401227, -0.18767063,
- 0.042483903, 0.14233552, 0.13832581, 0.18350165, 0.14545603,
- -0.028545704, 0.024939531, 0.050929718, 0.0076203286, -0.0029723682,
- -0.042484224, -0.11827596, -0.09171104, -0.10808628, -0.16327988,
- -0.2273378, -0.0993647, -0.017155107, 0.0023917493, 0.049272764,
- 0.0038534778, 0.054764505, 0.089753784, 0.06947234, 0.08014476,
- -0.04544234, -0.0497073, -0.07135631, -0.048929106, -0.004042012,
- -0.009284026, 0.018042054, 0.0036860977, -0.07427302, -0.11434604,
- -0.018995456, 0.031487543, 0.012834908, 0.019977754, 0.044256654,
- -0.39292613, -0.18519334, -0.11651281, -0.06809892, 0.011373677});
-
- lstm.SetInputToForgetWeights(
- {-0.0018401089, -0.004852237, 0.03698424, 0.014181704, 0.028273236,
- -0.016726194, -0.05249759, -0.10204261, 0.00861066, -0.040979505,
- -0.009899187, 0.01923892, -0.028177269, -0.08535103, -0.14585495,
- 0.10662567, -0.01909731, -0.017883534, -0.0047269356, -0.045103323,
- 0.0030784295, 0.076784775, 0.07463696, 0.094531395, 0.0814421,
- -0.12257899, -0.033945758, -0.031303465, 0.045630626, 0.06843887,
- -0.13492945, -0.012480007, -0.0811829, -0.07224499, -0.09628791,
- 0.045100946, 0.0012300825, 0.013964662, 0.099372394, 0.02543059,
- 0.06958324, 0.034257296, 0.0482646, 0.06267997, 0.052625068,
- 0.12784666, 0.07077897, 0.025725935, 0.04165009, 0.07241905,
- 0.018668644, -0.037377294, -0.06277783, -0.08833636, -0.040120605,
- -0.011405586, -0.007808335, -0.010301386, -0.005102167, 0.027717464,
- 0.05483423, 0.11449111, 0.11289652, 0.10939839, 0.13396506,
- -0.08402166, -0.01901462, -0.044678304, -0.07720565, 0.014350063,
- -0.11757958, -0.0652038, -0.08185733, -0.076754324, -0.092614375,
- 0.10405491, 0.052960336, 0.035755895, 0.035839386, -0.012540553,
- 0.036881298, 0.02913376, 0.03420159, 0.05448447, -0.054523353,
- 0.02582715, 0.02327355, -0.011857179, -0.0011980024, -0.034641717,
- -0.026125094, -0.17582615, -0.15923657, -0.27486774, -0.0006143371,
- 0.0001771948, -8.470171e-05, 0.02651807, 0.045790765, 0.06956496});
-
- lstm.SetInputToCellWeights(
- {-0.04580283, -0.09549462, -0.032418985, -0.06454633,
- -0.043528453, 0.043018587, -0.049152344, -0.12418144,
- -0.078985475, -0.07596889, 0.019484362, -0.11434962,
- -0.0074034138, -0.06314844, -0.092981495, 0.0062155537,
- -0.025034338, -0.0028890965, 0.048929527, 0.06235075,
- 0.10665918, -0.032036792, -0.08505916, -0.10843358,
- -0.13002433, -0.036816437, -0.02130134, -0.016518239,
- 0.0047691227, -0.0025825808, 0.066017866, 0.029991534,
- -0.10652836, -0.1037554, -0.13056071, -0.03266643,
- -0.033702414, -0.006473424, -0.04611692, 0.014419339,
- -0.025174323, 0.0396852, 0.081777506, 0.06157468,
- 0.10210095, -0.009658194, 0.046511717, 0.03603906,
- 0.0069369148, 0.015960095, -0.06507666, 0.09551598,
- 0.053568836, 0.06408714, 0.12835667, -0.008714329,
- -0.20211966, -0.12093674, 0.029450472, 0.2849013,
- -0.029227901, 0.1164364, -0.08560263, 0.09941786,
- -0.036999565, -0.028842626, -0.0033637602, -0.017012902,
- -0.09720865, -0.11193351, -0.029155117, -0.017936034,
- -0.009768936, -0.04223324, -0.036159635, 0.06505112,
- -0.021742892, -0.023377212, -0.07221364, -0.06430552,
- 0.05453865, 0.091149814, 0.06387331, 0.007518393,
- 0.055960953, 0.069779344, 0.046411168, 0.10509911,
- 0.07463894, 0.0075130584, 0.012850982, 0.04555431,
- 0.056955688, 0.06555285, 0.050801456, -0.009862683,
- 0.00826772, -0.026555609, -0.0073611983, -0.0014897042});
-
- lstm.SetInputToOutputWeights(
- {-0.0998932, -0.07201956, -0.052803773, -0.15629593, -0.15001918,
- -0.07650751, 0.02359855, -0.075155355, -0.08037709, -0.15093534,
- 0.029517552, -0.04751393, 0.010350531, -0.02664851, -0.016839722,
- -0.023121163, 0.0077019283, 0.012851257, -0.05040649, -0.0129761,
- -0.021737747, -0.038305793, -0.06870586, -0.01481247, -0.001285394,
- 0.10124236, 0.083122835, 0.053313006, -0.062235646, -0.075637154,
- -0.027833903, 0.029774971, 0.1130802, 0.09218906, 0.09506135,
- -0.086665764, -0.037162706, -0.038880914, -0.035832845, -0.014481564,
- -0.09825003, -0.12048569, -0.097665586, -0.05287633, -0.0964047,
- -0.11366429, 0.035777505, 0.13568819, 0.052451383, 0.050649304,
- 0.05798951, -0.021852335, -0.099848844, 0.014740475, -0.078897946,
- 0.04974699, 0.014160473, 0.06973932, 0.04964942, 0.033364646,
- 0.08190124, 0.025535367, 0.050893165, 0.048514254, 0.06945813,
- -0.078907564, -0.06707616, -0.11844508, -0.09986688, -0.07509403,
- 0.06263226, 0.14925587, 0.20188436, 0.12098451, 0.14639415,
- 0.0015017595, -0.014267382, -0.03417257, 0.012711468, 0.0028300495,
- -0.024758482, -0.05098548, -0.0821182, 0.014225672, 0.021544158,
- 0.08949725, 0.07505268, -0.0020780868, 0.04908258, 0.06476295,
- -0.022907063, 0.027562456, 0.040185735, 0.019567577, -0.015598739,
- -0.049097303, -0.017121866, -0.083368234, -0.02332002, -0.0840956});
-
- lstm.SetInputGateBias(
- {0.02234832, 0.14757581, 0.18176508, 0.10380666, 0.053110216,
- -0.06928846, -0.13942584, -0.11816189, 0.19483899, 0.03652339,
- -0.10250295, 0.036714908, -0.18426876, 0.036065217, 0.21810818,
- 0.02383196, -0.043370757, 0.08690144, -0.04444982, 0.00030581196});
-
- lstm.SetForgetGateBias({0.035185695, -0.042891346, -0.03032477, 0.23027696,
- 0.11098921, 0.15378423, 0.09263801, 0.09790885,
- 0.09508917, 0.061199076, 0.07665568, -0.015443159,
- -0.03499149, 0.046190713, 0.08895977, 0.10899629,
- 0.40694186, 0.06030037, 0.012413437, -0.06108739});
-
- lstm.SetCellBias({-0.024379363, 0.0055531194, 0.23377132, 0.033463873,
- -0.1483596, -0.10639995, -0.091433935, 0.058573797,
- -0.06809782, -0.07889636, -0.043246906, -0.09829136,
- -0.4279842, 0.034901652, 0.18797937, 0.0075234566,
- 0.016178843, 0.1749513, 0.13975595, 0.92058027});
-
- lstm.SetOutputGateBias(
- {0.046159424, -0.0012809046, 0.03563469, 0.12648113, 0.027195795,
- 0.35373217, -0.018957434, 0.008907322, -0.0762701, 0.12018895,
- 0.04216877, 0.0022856654, 0.040952638, 0.3147856, 0.08225149,
- -0.057416286, -0.14995944, -0.008040261, 0.13208859, 0.029760877});
-
- lstm.SetRecurrentToInputWeights(
- {-0.001374326, -0.078856036, 0.10672688, 0.029162422,
- -0.11585556, 0.02557986, -0.13446963, -0.035785314,
- -0.01244275, 0.025961924, -0.02337298, -0.044228926,
- -0.055839065, -0.046598054, -0.010546039, -0.06900766,
- 0.027239809, 0.022582639, -0.013296484, -0.05459212,
- 0.08981, -0.045407712, 0.08682226, -0.06867011,
- -0.14390695, -0.02916037, 0.000996957, 0.091420636,
- 0.14283475, -0.07390571, -0.06402044, 0.062524505,
- -0.093129106, 0.04860203, -0.08364217, -0.08119002,
- 0.009352075, 0.22920375, 0.0016303885, 0.11583097,
- -0.13732095, 0.012405723, -0.07551853, 0.06343048,
- 0.12162708, -0.031923793, -0.014335606, 0.01790974,
- -0.10650317, -0.0724401, 0.08554849, -0.05727212,
- 0.06556731, -0.042729504, -0.043227166, 0.011683251,
- -0.013082158, -0.029302018, -0.010899579, -0.062036745,
- -0.022509435, -0.00964907, -0.01567329, 0.04260106,
- -0.07787477, -0.11576462, 0.017356863, 0.048673786,
- -0.017577527, -0.05527947, -0.082487635, -0.040137455,
- -0.10820036, -0.04666372, 0.022746278, -0.07851417,
- 0.01068115, 0.032956902, 0.022433773, 0.0026891115,
- 0.08944216, -0.0685835, 0.010513544, 0.07228705,
- 0.02032331, -0.059686817, -0.0005566496, -0.086984694,
- 0.040414046, -0.1380399, 0.094208956, -0.05722982,
- 0.012092817, -0.04989123, -0.086576, -0.003399834,
- -0.04696032, -0.045747425, 0.10091314, 0.048676282,
- -0.029037097, 0.031399418, -0.0040285117, 0.047237843,
- 0.09504992, 0.041799378, -0.049185462, -0.031518843,
- -0.10516937, 0.026374253, 0.10058866, -0.0033195973,
- -0.041975245, 0.0073591834, 0.0033782164, -0.004325073,
- -0.10167381, 0.042500053, -0.01447153, 0.06464186,
- -0.017142897, 0.03312627, 0.009205989, 0.024138335,
- -0.011337001, 0.035530265, -0.010912711, 0.0706555,
- -0.005894094, 0.051841937, -0.1401738, -0.02351249,
- 0.0365468, 0.07590991, 0.08838724, 0.021681072,
- -0.10086113, 0.019608743, -0.06195883, 0.077335775,
- 0.023646897, -0.095322326, 0.02233014, 0.09756986,
- -0.048691444, -0.009579111, 0.07595467, 0.11480546,
- -0.09801813, 0.019894179, 0.08502348, 0.004032281,
- 0.037211012, 0.068537936, -0.048005626, -0.091520436,
- -0.028379958, -0.01556313, 0.06554592, -0.045599163,
- -0.01672207, -0.020169014, -0.011877351, -0.20212261,
- 0.010889619, 0.0047078193, 0.038385306, 0.08540671,
- -0.017140968, -0.0035865551, 0.016678626, 0.005633034,
- 0.015963363, 0.00871737, 0.060130805, 0.028611384,
- 0.10109069, -0.015060172, -0.07894427, 0.06401885,
- 0.011584063, -0.024466386, 0.0047652307, -0.09041358,
- 0.030737216, -0.0046374933, 0.14215417, -0.11823516,
- 0.019899689, 0.006106124, -0.027092824, 0.0786356,
- 0.05052217, -0.058925, -0.011402121, -0.024987547,
- -0.0013661642, -0.06832946, -0.015667673, -0.1083353,
- -0.00096863037, -0.06988685, -0.053350925, -0.027275559,
- -0.033664223, -0.07978348, -0.025200296, -0.017207067,
- -0.058403496, -0.055697463, 0.005798788, 0.12965427,
- -0.062582195, 0.0013350133, -0.10482091, 0.0379771,
- 0.072521195, -0.0029455067, -0.13797039, -0.03628521,
- 0.013806405, -0.017858358, -0.01008298, -0.07700066,
- -0.017081132, 0.019358726, 0.0027079724, 0.004635139,
- 0.062634714, -0.02338735, -0.039547626, -0.02050681,
- 0.03385117, -0.083611414, 0.002862572, -0.09421313,
- 0.058618143, -0.08598433, 0.00972939, 0.023867095,
- -0.053934585, -0.023203006, 0.07452513, -0.048767887,
- -0.07314807, -0.056307215, -0.10433547, -0.06440842,
- 0.04328182, 0.04389765, -0.020006588, -0.09076438,
- -0.11652589, -0.021705797, 0.03345259, -0.010329105,
- -0.025767034, 0.013057034, -0.07316461, -0.10145612,
- 0.06358255, 0.18531723, 0.07759293, 0.12006465,
- 0.1305557, 0.058638252, -0.03393652, 0.09622831,
- -0.16253184, -2.4580743e-06, 0.079869635, -0.070196845,
- -0.005644518, 0.06857898, -0.12598175, -0.035084512,
- 0.03156317, -0.12794146, -0.031963028, 0.04692781,
- 0.030070418, 0.0071660685, -0.095516115, -0.004643372,
- 0.040170413, -0.062104587, -0.0037324072, 0.0554317,
- 0.08184801, -0.019164372, 0.06791302, 0.034257166,
- -0.10307039, 0.021943003, 0.046745934, 0.0790918,
- -0.0265588, -0.007824208, 0.042546265, -0.00977924,
- -0.0002440307, -0.017384544, -0.017990116, 0.12252321,
- -0.014512694, -0.08251313, 0.08861942, 0.13589665,
- 0.026351685, 0.012641483, 0.07466548, 0.044301085,
- -0.045414884, -0.051112458, 0.03444247, -0.08502782,
- -0.04106223, -0.028126027, 0.028473156, 0.10467447});
-
- lstm.SetRecurrentToForgetWeights(
- {-0.057784554, -0.026057621, -0.068447545, -0.022581743,
- 0.14811787, 0.10826372, 0.09471067, 0.03987225,
- -0.0039523416, 0.00030638507, 0.053185795, 0.10572994,
- 0.08414449, -0.022036452, -0.00066928595, -0.09203576,
- 0.032950465, -0.10985798, -0.023809856, 0.0021431844,
- -0.02196096, -0.00326074, 0.00058621005, -0.074678116,
- -0.06193199, 0.055729095, 0.03736828, 0.020123724,
- 0.061878487, -0.04729229, 0.034919553, -0.07585433,
- -0.04421272, -0.044019096, 0.085488975, 0.04058006,
- -0.06890133, -0.030951202, -0.024628663, -0.07672815,
- 0.034293607, 0.08556707, -0.05293577, -0.033561368,
- -0.04899627, 0.0241671, 0.015736353, -0.095442444,
- -0.029564252, 0.016493602, -0.035026584, 0.022337519,
- -0.026871363, 0.004780428, 0.0077918363, -0.03601621,
- 0.016435321, -0.03263031, -0.09543275, -0.047392778,
- 0.013454138, 0.028934088, 0.01685226, -0.086110644,
- -0.046250615, -0.01847454, 0.047608484, 0.07339695,
- 0.034546845, -0.04881143, 0.009128804, -0.08802852,
- 0.03761666, 0.008096139, -0.014454086, 0.014361001,
- -0.023502491, -0.0011840804, -0.07607001, 0.001856849,
- -0.06509276, -0.006021153, -0.08570962, -0.1451793,
- 0.060212336, 0.055259194, 0.06974018, 0.049454916,
- -0.027794661, -0.08077226, -0.016179763, 0.1169753,
- 0.17213494, -0.0056326236, -0.053934924, -0.0124349,
- -0.11520337, 0.05409887, 0.088759385, 0.0019655675,
- 0.0042065294, 0.03881498, 0.019844765, 0.041858196,
- -0.05695512, 0.047233116, 0.038937137, -0.06542224,
- 0.014429736, -0.09719407, 0.13908425, -0.05379757,
- 0.012321099, 0.082840554, -0.029899208, 0.044217527,
- 0.059855383, 0.07711018, -0.045319796, 0.0948846,
- -0.011724666, -0.0033288454, -0.033542685, -0.04764985,
- -0.13873616, 0.040668588, 0.034832682, -0.015319203,
- -0.018715994, 0.046002675, 0.0599172, -0.043107376,
- 0.0294216, -0.002314414, -0.022424703, 0.0030315618,
- 0.0014641669, 0.0029166266, -0.11878115, 0.013738511,
- 0.12375372, -0.0006038222, 0.029104086, 0.087442465,
- 0.052958444, 0.07558703, 0.04817258, 0.044462286,
- -0.015213451, -0.08783778, -0.0561384, -0.003008196,
- 0.047060397, -0.002058388, 0.03429439, -0.018839769,
- 0.024734668, 0.024614193, -0.042046934, 0.09597743,
- -0.0043254104, 0.04320769, 0.0064070094, -0.0019131786,
- -0.02558259, -0.022822596, -0.023273505, -0.02464396,
- -0.10991725, -0.006240552, 0.0074488563, 0.024044557,
- 0.04383914, -0.046476185, 0.028658995, 0.060410924,
- 0.050786525, 0.009452605, -0.0073054377, -0.024810238,
- 0.0052906186, 0.0066939713, -0.0020913032, 0.014515517,
- 0.015898481, 0.021362653, -0.030262267, 0.016587038,
- -0.011442813, 0.041154444, -0.007631438, -0.03423484,
- -0.010977775, 0.036152758, 0.0066366293, 0.11915515,
- 0.02318443, -0.041350313, 0.021485701, -0.10906167,
- -0.028218046, -0.00954771, 0.020531068, -0.11995105,
- -0.03672871, 0.024019798, 0.014255957, -0.05221243,
- -0.00661567, -0.04630967, 0.033188973, 0.10107534,
- -0.014027541, 0.030796422, -0.10270911, -0.035999842,
- 0.15443139, 0.07684145, 0.036571592, -0.035900835,
- -0.0034699554, 0.06209149, 0.015920248, -0.031122351,
- -0.03858649, 0.01849943, 0.13872518, 0.01503974,
- 0.069941424, -0.06948533, -0.0088794185, 0.061282158,
- -0.047401894, 0.03100163, -0.041533746, -0.10430945,
- 0.044574402, -0.01425562, -0.024290353, 0.034563623,
- 0.05866852, 0.023947537, -0.09445152, 0.035450947,
- 0.02247216, -0.0042998926, 0.061146557, -0.10250651,
- 0.020881841, -0.06747029, 0.10062043, -0.0023941975,
- 0.03532124, -0.016341697, 0.09685456, -0.016764693,
- 0.051808182, 0.05875331, -0.04536488, 0.001626336,
- -0.028892258, -0.01048663, -0.009793449, -0.017093895,
- 0.010987891, 0.02357273, -0.00010856845, 0.0099760275,
- -0.001845119, -0.03551521, 0.0018358806, 0.05763657,
- -0.01769146, 0.040995963, 0.02235177, -0.060430344,
- 0.11475477, -0.023854522, 0.10071741, 0.0686208,
- -0.014250481, 0.034261297, 0.047418304, 0.08562733,
- -0.030519066, 0.0060542435, 0.014653856, -0.038836084,
- 0.04096551, 0.032249358, -0.08355519, -0.026823482,
- 0.056386515, -0.010401743, -0.028396193, 0.08507674,
- 0.014410365, 0.020995233, 0.17040324, 0.11511526,
- 0.02459721, 0.0066619175, 0.025853224, -0.023133837,
- -0.081302024, 0.017264642, -0.009585969, 0.09491168,
- -0.051313367, 0.054532815, -0.014298593, 0.10657464,
- 0.007076659, 0.10964551, 0.0409152, 0.008275321,
- -0.07283536, 0.07937492, 0.04192024, -0.1075027});
-
- lstm.SetRecurrentToCellWeights(
- {-0.037322544, 0.018592842, 0.0056175636, -0.06253426,
- 0.055647098, -0.05713207, -0.05626563, 0.005559383,
- 0.03375411, -0.025757805, -0.088049285, 0.06017052,
- -0.06570978, 0.007384076, 0.035123326, -0.07920549,
- 0.053676967, 0.044480428, -0.07663568, 0.0071805613,
- 0.08089997, 0.05143358, 0.038261272, 0.03339287,
- -0.027673481, 0.044746667, 0.028349208, 0.020090483,
- -0.019443132, -0.030755889, -0.0040000007, 0.04465846,
- -0.021585021, 0.0031670958, 0.0053199246, -0.056117613,
- -0.10893326, 0.076739706, -0.08509834, -0.027997585,
- 0.037871376, 0.01449768, -0.09002357, -0.06111149,
- -0.046195522, 0.0422062, -0.005683705, -0.1253618,
- -0.012925729, -0.04890792, 0.06985068, 0.037654128,
- 0.03398274, -0.004781977, 0.007032333, -0.031787455,
- 0.010868644, -0.031489216, 0.09525667, 0.013939797,
- 0.0058680447, 0.0167067, 0.02668468, -0.04797466,
- -0.048885044, -0.12722108, 0.035304096, 0.06554885,
- 0.00972396, -0.039238118, -0.05159735, -0.11329045,
- 0.1613692, -0.03750952, 0.06529313, -0.071974665,
- -0.11769596, 0.015524369, -0.0013754242, -0.12446318,
- 0.02786344, -0.014179351, 0.005264273, 0.14376344,
- 0.015983658, 0.03406988, -0.06939408, 0.040699873,
- 0.02111075, 0.09669095, 0.041345075, -0.08316494,
- -0.07684199, -0.045768797, 0.032298047, -0.041805092,
- 0.0119405, 0.0061010392, 0.12652606, 0.0064572375,
- -0.024950314, 0.11574242, 0.04508852, -0.04335324,
- 0.06760663, -0.027437469, 0.07216407, 0.06977076,
- -0.05438599, 0.034033038, -0.028602652, 0.05346137,
- 0.043184172, -0.037189785, 0.10420091, 0.00882477,
- -0.054019816, -0.074273005, -0.030617684, -0.0028467078,
- 0.024302477, -0.0038869337, 0.005332455, 0.0013399826,
- 0.04361412, -0.007001822, 0.09631092, -0.06702025,
- -0.042049985, -0.035070654, -0.04103342, -0.10273396,
- 0.0544271, 0.037184782, -0.13150354, -0.0058036847,
- -0.008264958, 0.042035464, 0.05891794, 0.029673764,
- 0.0063542654, 0.044788733, 0.054816857, 0.062257513,
- -0.00093483756, 0.048938446, -0.004952862, -0.007730018,
- -0.04043371, -0.017094059, 0.07229206, -0.023670016,
- -0.052195564, -0.025616996, -0.01520939, 0.045104615,
- -0.007376126, 0.003533447, 0.006570588, 0.056037236,
- 0.12436656, 0.051817212, 0.028532185, -0.08686856,
- 0.11868599, 0.07663395, -0.07323171, 0.03463402,
- -0.050708205, -0.04458982, -0.11590894, 0.021273347,
- 0.1251325, -0.15313013, -0.12224372, 0.17228661,
- 0.023029093, 0.086124025, 0.006445803, -0.03496501,
- 0.028332196, 0.04449512, -0.042436164, -0.026587414,
- -0.006041347, -0.09292539, -0.05678812, 0.03897832,
- 0.09465633, 0.008115513, -0.02171956, 0.08304309,
- 0.071401566, 0.019622514, 0.032163795, -0.004167056,
- 0.02295182, 0.030739572, 0.056506045, 0.004612461,
- 0.06524936, 0.059999723, 0.046395954, -0.0045512207,
- -0.1335546, -0.030136576, 0.11584653, -0.014678886,
- 0.0020118146, -0.09688814, -0.0790206, 0.039770417,
- -0.0329582, 0.07922767, 0.029322514, 0.026405897,
- 0.04207835, -0.07073373, 0.063781224, 0.0859677,
- -0.10925287, -0.07011058, 0.048005477, 0.03438226,
- -0.09606514, -0.006669445, -0.043381985, 0.04240257,
- -0.06955775, -0.06769346, 0.043903265, -0.026784198,
- -0.017840602, 0.024307009, -0.040079936, -0.019946516,
- 0.045318738, -0.12233574, 0.026170589, 0.0074471775,
- 0.15978073, 0.10185836, 0.10298046, -0.015476589,
- -0.039390966, -0.072174534, 0.0739445, -0.1211869,
- -0.0347889, -0.07943156, 0.014809798, -0.12412325,
- -0.0030663363, 0.039695457, 0.0647603, -0.08291318,
- -0.018529687, -0.004423833, 0.0037507233, 0.084633216,
- -0.01514876, -0.056505352, -0.012800942, -0.06994386,
- 0.012962922, -0.031234352, 0.07029052, 0.016418684,
- 0.03618972, 0.055686004, -0.08663945, -0.017404709,
- -0.054761406, 0.029065743, 0.052404847, 0.020238016,
- 0.0048197987, -0.0214882, 0.07078733, 0.013016777,
- 0.06262858, 0.009184685, 0.020785125, -0.043904778,
- -0.0270329, -0.03299152, -0.060088247, -0.015162964,
- -0.001828936, 0.12642565, -0.056757294, 0.013586685,
- 0.09232601, -0.035886683, 0.06000002, 0.05229691,
- -0.052580316, -0.082029596, -0.010794592, 0.012947712,
- -0.036429964, -0.085508935, -0.13127148, -0.017744139,
- 0.031502828, 0.036232427, -0.031581745, 0.023051167,
- -0.05325106, -0.03421577, 0.028793324, -0.034633752,
- -0.009881397, -0.043551125, -0.018609839, 0.0019097115,
- -0.008799762, 0.056595087, 0.0022273948, 0.055752404});
-
- lstm.SetRecurrentToOutputWeights({
- 0.025825322, -0.05813119, 0.09495884, -0.045984812, -0.01255415,
- -0.0026479573, -0.08196161, -0.054914974, -0.0046604523, -0.029587349,
- -0.044576716, -0.07480124, -0.082868785, 0.023254942, 0.027502948,
- -0.0039728214, -0.08683098, -0.08116779, -0.014675607, -0.037924774,
- -0.023314456, -0.007401714, -0.09255757, 0.029460307, -0.08829125,
- -0.005139627, -0.08989442, -0.0555066, 0.13596267, -0.025062224,
- -0.048351806, -0.03850004, 0.07266485, -0.022414139, 0.05940088,
- 0.075114764, 0.09597592, -0.010211725, -0.0049794707, -0.011523867,
- -0.025980417, 0.072999895, 0.11091378, -0.081685916, 0.014416728,
- 0.043229222, 0.034178585, -0.07530371, 0.035837382, -0.085607,
- -0.007721233, -0.03287832, -0.043848954, -0.06404588, -0.06632928,
- -0.073643476, 0.008214239, -0.045984086, 0.039764922, 0.03474462,
- 0.060612556, -0.080590084, 0.049127717, 0.04151091, -0.030063879,
- 0.008801774, -0.023021035, -0.019558564, 0.05158114, -0.010947698,
- -0.011825728, 0.0075720972, 0.0699727, -0.0039981045, 0.069350146,
- 0.08799282, 0.016156472, 0.035502106, 0.11695009, 0.006217345,
- 0.13392477, -0.037875112, 0.025745004, 0.08940699, -0.00924166,
- 0.0046702605, -0.036598757, -0.08811812, 0.10522024, -0.032441203,
- 0.008176899, -0.04454919, 0.07058152, 0.0067963637, 0.039206743,
- 0.03259838, 0.03725492, -0.09515802, 0.013326398, -0.052055415,
- -0.025676316, 0.03198509, -0.015951829, -0.058556724, 0.036879618,
- 0.043357447, 0.028362012, -0.05908629, 0.0059240665, -0.04995891,
- -0.019187413, 0.0276265, -0.01628143, 0.0025863599, 0.08800015,
- 0.035250366, -0.022165963, -0.07328642, -0.009415526, -0.07455109,
- 0.11690406, 0.0363299, 0.07411125, 0.042103454, -0.009660886,
- 0.019076364, 0.018299393, -0.046004917, 0.08891175, 0.0431396,
- -0.026327137, -0.051502608, 0.08979574, -0.051670972, 0.04940282,
- -0.07491107, -0.021240504, 0.022596184, -0.034280192, 0.060163025,
- -0.058211457, -0.051837247, -0.01349775, -0.04639988, -0.035936575,
- -0.011681591, 0.064818054, 0.0073146066, -0.021745546, -0.043124277,
- -0.06471268, -0.07053354, -0.029321948, -0.05330136, 0.016933719,
- -0.053782392, 0.13747959, -0.1361751, -0.11569455, 0.0033329215,
- 0.05693899, -0.053219706, 0.063698, 0.07977434, -0.07924483,
- 0.06936997, 0.0034815092, -0.007305279, -0.037325785, -0.07251102,
- -0.033633437, -0.08677009, 0.091591336, -0.14165086, 0.021752775,
- 0.019683983, 0.0011612234, -0.058154266, 0.049996935, 0.0288841,
- -0.0024567875, -0.14345716, 0.010955264, -0.10234828, 0.1183656,
- -0.0010731248, -0.023590032, -0.072285876, -0.0724771, -0.026382286,
- -0.0014920527, 0.042667855, 0.0018776858, 0.02986552, 0.009814309,
- 0.0733756, 0.12289186, 0.018043943, -0.0458958, 0.049412545,
- 0.033632483, 0.05495232, 0.036686596, -0.013781798, -0.010036754,
- 0.02576849, -0.08307328, 0.010112348, 0.042521734, -0.05869831,
- -0.071689695, 0.03876447, -0.13275425, -0.0352966, -0.023077697,
- 0.10285965, 0.084736146, 0.15568255, -0.00040734606, 0.027835453,
- -0.10292561, -0.032401145, 0.10053256, -0.026142767, -0.08271222,
- -0.0030240538, -0.016368777, 0.1070414, 0.042672627, 0.013456989,
- -0.0437609, -0.022309763, 0.11576483, 0.04108048, 0.061026827,
- -0.0190714, -0.0869359, 0.037901703, 0.0610107, 0.07202949,
- 0.01675338, 0.086139716, -0.08795751, -0.014898893, -0.023771819,
- -0.01965048, 0.007955471, -0.043740474, 0.03346837, -0.10549954,
- 0.090567775, 0.042013682, -0.03176985, 0.12569028, -0.02421228,
- -0.029526481, 0.023851605, 0.031539805, 0.05292009, -0.02344001,
- -0.07811758, -0.08834428, 0.10094801, 0.16594367, -0.06861939,
- -0.021256343, -0.041093912, -0.06669611, 0.035498552, 0.021757556,
- -0.09302526, -0.015403468, -0.06614931, -0.051798206, -0.013874718,
- 0.03630673, 0.010412845, -0.08077351, 0.046185967, 0.0035662893,
- 0.03541868, -0.094149634, -0.034814864, 0.003128424, -0.020674974,
- -0.03944324, -0.008110165, -0.11113267, 0.08484226, 0.043586485,
- 0.040582247, 0.0968012, -0.065249965, -0.028036479, 0.0050708856,
- 0.0017462453, 0.0326779, 0.041296225, 0.09164146, -0.047743853,
- -0.015952192, -0.034451712, 0.084197424, -0.05347844, -0.11768019,
- 0.085926116, -0.08251791, -0.045081906, 0.0948852, 0.068401024,
- 0.024856757, 0.06978981, -0.057309967, -0.012775832, -0.0032452994,
- 0.01977615, -0.041040014, -0.024264973, 0.063464895, 0.05431621,
- });
-
- lstm.SetCellToInputWeights(
- {0.040369894, 0.030746894, 0.24704495, 0.018586371, -0.037586458,
- -0.15312155, -0.11812848, -0.11465643, 0.20259799, 0.11418174,
- -0.10116027, -0.011334949, 0.12411352, -0.076769054, -0.052169047,
- 0.21198851, -0.38871562, -0.09061183, -0.09683246, -0.21929175});
-
- lstm.SetCellToForgetWeights(
- {-0.01998659, -0.15568835, -0.24248174, -0.012770197, 0.041331276,
- -0.072311886, -0.052123554, -0.0066330447, -0.043891653, 0.036225766,
- -0.047248036, 0.021479502, 0.033189066, 0.11952997, -0.020432774,
- 0.64658105, -0.06650122, -0.03467612, 0.095340036, 0.23647355});
-
- lstm.SetCellToOutputWeights(
- {0.08286371, -0.08261836, -0.51210177, 0.002913762, 0.17764764,
- -0.5495371, -0.08460716, -0.24552552, 0.030037103, 0.04123544,
- -0.11940523, 0.007358328, 0.1890978, 0.4833202, -0.34441817,
- 0.36312827, -0.26375428, 0.1457655, -0.19724406, 0.15548733});
-
- lstm.SetProjectionWeights(
- {-0.009802181, 0.09401916, 0.0717386, -0.13895074, 0.09641832,
- 0.060420845, 0.08539281, 0.054285463, 0.061395317, 0.034448683,
- -0.042991187, 0.019801661, -0.16840284, -0.015726732, -0.23041931,
- -0.024478018, -0.10959692, -0.013875541, 0.18600968, -0.061274476,
- 0.0138165, -0.08160894, -0.07661644, 0.032372914, 0.16169067,
- 0.22465782, -0.03993472, -0.004017731, 0.08633481, -0.28869787,
- 0.08682067, 0.17240396, 0.014975425, 0.056431185, 0.031037588,
- 0.16702051, 0.0077946745, 0.15140012, 0.29405436, 0.120285,
- -0.188994, -0.027265169, 0.043389652, -0.022061434, 0.014777949,
- -0.20203483, 0.094781205, 0.19100232, 0.13987629, -0.036132768,
- -0.06426278, -0.05108664, 0.13221376, 0.009441198, -0.16715929,
- 0.15859416, -0.040437475, 0.050779544, -0.022187516, 0.012166504,
- 0.027685808, -0.07675938, -0.0055694645, -0.09444123, 0.0046453946,
- 0.050794356, 0.10770313, -0.20790008, -0.07149004, -0.11425117,
- 0.008225835, -0.035802525, 0.14374903, 0.15262283, 0.048710253,
- 0.1847461, -0.007487823, 0.11000021, -0.09542012, 0.22619456,
- -0.029149994, 0.08527916, 0.009043713, 0.0042746216, 0.016261552,
- 0.022461696, 0.12689082, -0.043589946, -0.12035478, -0.08361797,
- -0.050666027, -0.1248618, -0.1275799, -0.071875185, 0.07377272,
- 0.09944291, -0.18897448, -0.1593054, -0.06526116, -0.040107165,
- -0.004618631, -0.067624845, -0.007576253, 0.10727444, 0.041546922,
- -0.20424393, 0.06907816, 0.050412357, 0.00724631, 0.039827548,
- 0.12449835, 0.10747581, 0.13708383, 0.09134148, -0.12617786,
- -0.06428341, 0.09956831, 0.1208086, -0.14676677, -0.0727722,
- 0.1126304, 0.010139365, 0.015571211, -0.038128063, 0.022913318,
- -0.042050496, 0.16842307, -0.060597885, 0.10531834, -0.06411776,
- -0.07451711, -0.03410368, -0.13393489, 0.06534304, 0.003620307,
- 0.04490757, 0.05970546, 0.05197996, 0.02839995, 0.10434969,
- -0.013699693, -0.028353551, -0.07260381, 0.047201227, -0.024575593,
- -0.036445823, 0.07155557, 0.009672501, -0.02328883, 0.009533515,
- -0.03606021, -0.07421458, -0.028082801, -0.2678904, -0.13221288,
- 0.18419984, -0.13012612, -0.014588381, -0.035059117, -0.04824723,
- 0.07830115, -0.056184657, 0.03277091, 0.025466874, 0.14494097,
- -0.12522776, -0.098633975, -0.10766018, -0.08317623, 0.08594209,
- 0.07749552, 0.039474737, 0.1776665, -0.07409566, -0.0477268,
- 0.29323658, 0.10801441, 0.1154011, 0.013952499, 0.10739139,
- 0.10708251, -0.051456142, 0.0074137426, -0.10430189, 0.10034707,
- 0.045594677, 0.0635285, -0.0715442, -0.089667566, -0.10811871,
- 0.00026344223, 0.08298446, -0.009525053, 0.006585689, -0.24567553,
- -0.09450807, 0.09648481, 0.026996298, -0.06419476, -0.04752702,
- -0.11063944, -0.23441927, -0.17608605, -0.052156363, 0.067035615,
- 0.19271925, -0.0032889997, -0.043264326, 0.09663576, -0.057112187,
- -0.10100678, 0.0628376, 0.04447668, 0.017961001, -0.10094388,
- -0.10190601, 0.18335468, 0.10494553, -0.052095775, -0.0026118709,
- 0.10539724, -0.04383912, -0.042349473, 0.08438151, -0.1947263,
- 0.02251204, 0.11216432, -0.10307853, 0.17351969, -0.039091777,
- 0.08066188, -0.00561982, 0.12633002, 0.11335965, -0.0088127935,
- -0.019777594, 0.06864014, -0.059751723, 0.016233567, -0.06894641,
- -0.28651384, -0.004228674, 0.019708522, -0.16305895, -0.07468996,
- -0.0855457, 0.099339016, -0.07580735, -0.13775392, 0.08434318,
- 0.08330512, -0.12131499, 0.031935584, 0.09180414, -0.08876437,
- -0.08049874, 0.008753825, 0.03498998, 0.030215185, 0.03907079,
- 0.089751154, 0.029194152, -0.03337423, -0.019092513, 0.04331237,
- 0.04299654, -0.036394123, -0.12915532, 0.09793732, 0.07512415,
- -0.11319543, -0.032502122, 0.15661901, 0.07671967, -0.005491124,
- -0.19379048, -0.218606, 0.21448623, 0.017840758, 0.1416943,
- -0.07051762, 0.19488361, 0.02664691, -0.18104725, -0.09334311,
- 0.15026465, -0.15493552, -0.057762887, -0.11604192, -0.262013,
- -0.01391798, 0.012185008, 0.11156489, -0.07483202, 0.06693364,
- -0.26151478, 0.046425626, 0.036540434, -0.16435726, 0.17338543,
- -0.21401681, -0.11385144, -0.08283257, -0.069031075, 0.030635102,
- 0.010969227, 0.11109743, 0.010919218, 0.027526086, 0.13519906,
- 0.01891392, -0.046839405, -0.040167913, 0.017953383, -0.09700955,
- 0.0061885654, -0.07000971, 0.026893595, -0.038844477, 0.14543656});
-
- static float lstm_input[][20] = {
- {// Batch0: 4 (input_sequence_size) * 5 (n_input)
- 0.787926, 0.151646, 0.071352, 0.118426, 0.458058, 0.596268, 0.998386,
- 0.568695, 0.864524, 0.571277, 0.073204, 0.296072, 0.743333, 0.069199,
- 0.045348, 0.867394, 0.291279, 0.013714, 0.482521, 0.626339},
-
- {// Batch1: 4 (input_sequence_size) * 5 (n_input)
- 0.295743, 0.544053, 0.690064, 0.858138, 0.497181, 0.642421, 0.524260,
- 0.134799, 0.003639, 0.162482, 0.640394, 0.930399, 0.050782, 0.432485,
- 0.988078, 0.082922, 0.563329, 0.865614, 0.333232, 0.259916}};
-
- static float lstm_golden_output[][64] = {
- {// Batch0: 4 (input_sequence_size) * 16 (n_output)
- -0.00396806, 0.029352, -0.00279226, 0.0159977, -0.00835576,
- -0.0211779, 0.0283512, -0.0114597, 0.00907307, -0.0244004,
- -0.0152191, -0.0259063, 0.00914318, 0.00415118, 0.017147,
- 0.0134203, -0.0166936, 0.0381209, 0.000889694, 0.0143363,
- -0.0328911, -0.0234288, 0.0333051, -0.012229, 0.0110322,
- -0.0457725, -0.000832209, -0.0202817, 0.0327257, 0.0121308,
- 0.0155969, 0.0312091, -0.0213783, 0.0350169, 0.000324794,
- 0.0276012, -0.0263374, -0.0371449, 0.0446149, -0.0205474,
- 0.0103729, -0.0576349, -0.0150052, -0.0292043, 0.0376827,
- 0.0136115, 0.0243435, 0.0354492, -0.0189322, 0.0464512,
- -0.00251373, 0.0225745, -0.0308346, -0.0317124, 0.0460407,
- -0.0189395, 0.0149363, -0.0530162, -0.0150767, -0.0340193,
- 0.0286833, 0.00824207, 0.0264887, 0.0305169},
- {// Batch1: 4 (input_sequence_size) * 16 (n_output)
- -0.013869, 0.0287268, -0.00334693, 0.00733398, -0.0287926,
- -0.0186926, 0.0193662, -0.0115437, 0.00422612, -0.0345232,
- 0.00223253, -0.00957321, 0.0210624, 0.013331, 0.0150954,
- 0.02168, -0.0141913, 0.0322082, 0.00227024, 0.0260507,
- -0.0188721, -0.0296489, 0.0399134, -0.0160509, 0.0116039,
- -0.0447318, -0.0150515, -0.0277406, 0.0316596, 0.0118233,
- 0.0214762, 0.0293641, -0.0204549, 0.0450315, -0.00117378,
- 0.0167673, -0.0375007, -0.0238314, 0.038784, -0.0174034,
- 0.0131743, -0.0506589, -0.0048447, -0.0240239, 0.0325789,
- 0.00790065, 0.0220157, 0.0333314, -0.0264787, 0.0387855,
- -0.000764675, 0.0217599, -0.037537, -0.0335206, 0.0431679,
- -0.0211424, 0.010203, -0.062785, -0.00832363, -0.025181,
- 0.0412031, 0.0118723, 0.0239643, 0.0394009}};
-
- // Resetting cell_state and output_state
- lstm.ResetCellState();
- lstm.ResetOutputState();
-
- for (int i = 0; i < lstm.sequence_length(); i++) {
- float* batch0_start = lstm_input[0] + i * lstm.num_inputs();
- float* batch0_end = batch0_start + lstm.num_inputs();
-
- lstm.SetInput(2 * i * lstm.num_inputs(), batch0_start, batch0_end);
-
- float* batch1_start = lstm_input[1] + i * lstm.num_inputs();
- float* batch1_end = batch1_start + lstm.num_inputs();
- lstm.SetInput((2 * i + 1) * lstm.num_inputs(), batch1_start, batch1_end);
- }
+ lstm.SetInputToInputWeights(input_to_input_weights_);
+ lstm.SetInputToCellWeights(input_to_cell_weights_);
+ lstm.SetInputToForgetWeights(input_to_forget_weights_);
+ lstm.SetInputToOutputWeights(input_to_output_weights_);
- lstm.Invoke();
+ lstm.SetInputGateBias(input_gate_bias_);
+ lstm.SetCellBias(cell_gate_bias_);
+ lstm.SetForgetGateBias(forget_gate_bias_);
+ lstm.SetOutputGateBias(output_gate_bias_);
- std::vector<float> expected;
- for (int i = 0; i < lstm.sequence_length(); i++) {
- float* golden_start_batch0 = lstm_golden_output[0] + i * lstm.num_outputs();
- float* golden_end_batch0 = golden_start_batch0 + lstm.num_outputs();
- float* golden_start_batch1 = lstm_golden_output[1] + i * lstm.num_outputs();
- float* golden_end_batch1 = golden_start_batch1 + lstm.num_outputs();
- expected.insert(expected.end(), golden_start_batch0, golden_end_batch0);
- expected.insert(expected.end(), golden_start_batch1, golden_end_batch1);
- }
- EXPECT_THAT(lstm.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
+ lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_);
+ lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
+ lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
+ lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
+
+ lstm.SetCellToInputWeights(cell_to_input_weights_);
+ lstm.SetCellToForgetWeights(cell_to_forget_weights_);
+ lstm.SetCellToOutputWeights(cell_to_output_weights_);
+
+ lstm.SetProjectionWeights(projection_weights_);
+
+ VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
+}
+
+TEST_F(NoCifgPeepholeProjectionClippingLstmTest, HybridLstmBlackBoxTest) {
+ const int n_batch = 2;
+ const int n_input = 5;
+ const int n_cell = 20;
+ const int n_output = 16;
+ const int sequence_length = 4;
+
+ HybridUnidirectionalLSTMOpModel lstm(
+ n_batch, n_input, n_cell, n_output, sequence_length,
+ /*use_cifg=*/false, /*use_peephole=*/true,
+ /*use_projection_weights=*/true,
+ /*use_projection_bias=*/false,
+ /*cell_clip=*/0.0, /*proj_clip=*/0.0,
+ {
+ {sequence_length, n_batch, n_input}, // input tensor
+
+ {n_cell, n_input}, // input_to_input_weight tensor
+ {n_cell, n_input}, // input_to_forget_weight tensor
+ {n_cell, n_input}, // input_to_cell_weight tensor
+ {n_cell, n_input}, // input_to_output_weight tensor
+
+ {n_cell, n_output}, // recurrent_to_input_weight tensor
+ {n_cell, n_output}, // recurrent_to_forget_weight tensor
+ {n_cell, n_output}, // recurrent_to_cell_weight tensor
+ {n_cell, n_output}, // recurrent_to_output_weight tensor
+
+ {n_cell}, // cell_to_input_weight tensor
+ {n_cell}, // cell_to_forget_weight tensor
+ {n_cell}, // cell_to_output_weight tensor
+
+ {n_cell}, // input_gate_bias tensor
+ {n_cell}, // forget_gate_bias tensor
+ {n_cell}, // cell_bias tensor
+ {n_cell}, // output_gate_bias tensor
+
+ {n_output, n_cell}, // projection_weight tensor
+ {0}, // projection_bias tensor
+
+ {n_batch, n_output}, // activation_state tensor
+ {n_batch, n_cell}, // cell_state tensor
+ });
+
+ lstm.SetInputToInputWeights(input_to_input_weights_);
+ lstm.SetInputToCellWeights(input_to_cell_weights_);
+ lstm.SetInputToForgetWeights(input_to_forget_weights_);
+ lstm.SetInputToOutputWeights(input_to_output_weights_);
+
+ lstm.SetInputGateBias(input_gate_bias_);
+ lstm.SetCellBias(cell_gate_bias_);
+ lstm.SetForgetGateBias(forget_gate_bias_);
+ lstm.SetOutputGateBias(output_gate_bias_);
+
+ lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_);
+ lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
+ lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
+ lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
+
+ lstm.SetCellToInputWeights(cell_to_input_weights_);
+ lstm.SetCellToForgetWeights(cell_to_forget_weights_);
+ lstm.SetCellToOutputWeights(cell_to_output_weights_);
+
+ lstm.SetProjectionWeights(projection_weights_);
+
+ VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.00467);
}
} // namespace
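Editor's note (not part of the diff): the refactored tests above replace the inline per-timestep loops with a shared VerifyGoldens() fixture helper, which feeds both batches step by step and compares the result against the golden output with an optional tolerance. The helper itself is outside this hunk; the sketch below only restates, as standalone C++, the batch-interleaving of expected outputs that the removed inline code performed. The name BuildInterleavedExpected and the Model template parameter are illustrative, not part of the test.

    #include <vector>

    // Interleave per-timestep golden outputs for two batches, mirroring the
    // removed inline loop: [b0 step0, b1 step0, b0 step1, b1 step1, ...].
    template <typename Model>
    std::vector<float> BuildInterleavedExpected(const float* golden_batch0,
                                                const float* golden_batch1,
                                                Model* lstm) {
      std::vector<float> expected;
      for (int i = 0; i < lstm->sequence_length(); ++i) {
        const float* b0 = golden_batch0 + i * lstm->num_outputs();
        const float* b1 = golden_batch1 + i * lstm->num_outputs();
        expected.insert(expected.end(), b0, b0 + lstm->num_outputs());
        expected.insert(expected.end(), b1, b1 + lstm->num_outputs());
      }
      return expected;
    }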
diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc
index 22c80df19c..744ee7c109 100644
--- a/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc
+++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc
@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include <unistd.h>
#include <cassert>
#include <cmath>
#include <cstdio>
@@ -20,8 +19,8 @@ limitations under the License.
#include <iostream>
#include <limits>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/activation_functor.h"
#include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
@@ -32,16 +31,19 @@ namespace ops {
namespace builtin {
namespace unidirectional_sequence_rnn {
+// Input tensors.
constexpr int kInputTensor = 0;
constexpr int kWeightsTensor = 1;
constexpr int kRecurrentWeightsTensor = 2;
constexpr int kBiasTensor = 3;
-constexpr int kHiddenStateTensor = 0;
-constexpr int kOutputTensor = 1;
+constexpr int kHiddenStateTensor = 4;
+
+// Output tensor.
+constexpr int kOutputTensor = 0;
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
auto* scratch_tensor_index = new int;
- context->AddTensors(context, /*tensors_to_add=*/2, scratch_tensor_index);
+ context->AddTensors(context, /*tensors_to_add=*/3, scratch_tensor_index);
return scratch_tensor_index;
}
@@ -51,14 +53,16 @@ void Free(TfLiteContext* context, void* buffer) {
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Check we have all the inputs and outputs we need.
- TF_LITE_ENSURE_EQ(context, node->inputs->size, 4);
- TF_LITE_ENSURE_EQ(context, node->outputs->size, 2);
+ TF_LITE_ENSURE_EQ(context, node->inputs->size, 5);
+ TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* input_weights = GetInput(context, node, kWeightsTensor);
const TfLiteTensor* recurrent_weights =
GetInput(context, node, kRecurrentWeightsTensor);
const TfLiteTensor* bias = GetInput(context, node, kBiasTensor);
+ const TfLiteTensor* hidden_state =
+ GetInput(context, node, kHiddenStateTensor);
// Check all the parameters of tensor match within themselves and match the
// input configuration.
@@ -75,20 +79,12 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ASSERT_EQ(recurrent_weights->dims->data[1], bias->dims->data[0]);
TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
TF_LITE_ENSURE_EQ(context, input_weights->type, recurrent_weights->type);
+ TF_LITE_ENSURE_EQ(context, NumDimensions(hidden_state), 2);
+ TF_LITE_ENSURE_EQ(context, hidden_state->dims->data[0], batch_size);
+ TF_LITE_ENSURE_EQ(context, hidden_state->dims->data[1], num_units);
- TfLiteTensor* hidden_state = GetOutput(context, node, kHiddenStateTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
- // Resize state.
- TfLiteIntArray* hidden_state_size_array = TfLiteIntArrayCreate(2);
- hidden_state_size_array->data[0] = batch_size;
- hidden_state_size_array->data[1] = num_units;
- TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, hidden_state,
- hidden_state_size_array));
-
- // Mark hidden state as a persistent tensor.
- hidden_state->allocation_type = kTfLiteArenaRwPersistent;
-
// Resize output.
TfLiteIntArray* output_size_array = TfLiteIntArrayCreate(3);
output_size_array->data[0] = (time_major) ? max_time : batch_size;
@@ -102,7 +98,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
if (input->type == kTfLiteFloat32 && input_weights->type == kTfLiteUInt8) {
int* scratch_tensor_index = reinterpret_cast<int*>(node->user_data);
TfLiteIntArrayFree(node->temporaries);
- node->temporaries = TfLiteIntArrayCreate(2);
+ node->temporaries = TfLiteIntArrayCreate(3);
node->temporaries->data[0] = *scratch_tensor_index;
TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/0);
input_quantized->type = kTfLiteUInt8;
@@ -125,6 +121,16 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
context->ResizeTensor(context, hidden_state_quantized,
hidden_state_quantized_size));
}
+ node->temporaries->data[2] = *scratch_tensor_index + 2;
+ TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/2);
+ scaling_factors->type = kTfLiteFloat32;
+ scaling_factors->allocation_type = kTfLiteArenaRw;
+ TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1);
+ scaling_factors_size->data[0] = batch_size;
+ if (!TfLiteIntArrayEqual(scaling_factors->dims, scaling_factors_size)) {
+ TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors,
+ scaling_factors_size));
+ }
}
return kTfLiteOk;
}
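Editor's note (not part of the diff): the new scaling_factors temporary holds one float per batch row for the hybrid path, in which float activations are symmetrically quantized to int8 before being multiplied against the uint8 weights, and the products are rescaled back to float. A minimal sketch of what such a per-row scale means is below, assuming symmetric quantization to [-127, 127]; QuantizeRowSymmetric is an illustrative name, not the utility the kernel actually calls.

    #include <algorithm>
    #include <cmath>
    #include <cstdint>

    // Quantize one batch row of floats to int8 and return the scale that maps
    // the quantized values back to floats: value_float ~= value_int8 * scale.
    float QuantizeRowSymmetric(const float* input, int size, int8_t* output) {
      float max_abs = 0.f;
      for (int i = 0; i < size; ++i) max_abs = std::max(max_abs, std::fabs(input[i]));
      if (max_abs == 0.f) {
        std::fill(output, output + size, static_cast<int8_t>(0));
        return 1.f;  // Any scale works for an all-zero row.
      }
      const float scale = max_abs / 127.f;
      for (int i = 0; i < size; ++i) {
        output[i] = static_cast<int8_t>(std::round(input[i] / scale));
      }
      return scale;  // One scale of this kind per batch row fills scaling_factors.
    }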
@@ -187,14 +193,12 @@ TfLiteStatus EvalFloat(const TfLiteTensor* input,
return kTfLiteOk;
}
-TfLiteStatus EvalQuantized(const TfLiteTensor* input,
- const TfLiteTensor* input_weights,
- const TfLiteTensor* recurrent_weights,
- const TfLiteTensor* bias,
- const TfLiteSequenceRNNParams* params,
- TfLiteTensor* input_scratch,
- TfLiteTensor* hidden_state_scratch,
- TfLiteTensor* hidden_state, TfLiteTensor* output) {
+TfLiteStatus EvalHybrid(
+ const TfLiteTensor* input, const TfLiteTensor* input_weights,
+ const TfLiteTensor* recurrent_weights, const TfLiteTensor* bias,
+ const TfLiteSequenceRNNParams* params, TfLiteTensor* input_scratch,
+ TfLiteTensor* hidden_state_scratch, TfLiteTensor* scaling_factors,
+ TfLiteTensor* hidden_state, TfLiteTensor* output) {
const bool time_major = params->time_major;
const int batch_size =
(time_major) ? input->dims->data[1] : input->dims->data[0];
@@ -218,6 +222,7 @@ TfLiteStatus EvalQuantized(const TfLiteTensor* input,
reinterpret_cast<int8_t*>(input_scratch->data.uint8);
int8_t* quantized_hidden_state_ptr =
reinterpret_cast<int8_t*>(hidden_state_scratch->data.uint8);
+ float* scaling_factors_ptr = scaling_factors->data.f;
if (time_major) {
// Initialize the pointer to hidden state.
@@ -233,7 +238,8 @@ TfLiteStatus EvalQuantized(const TfLiteTensor* input,
input_ptr_batch, input_weights_ptr, input_weights_scale,
recurrent_weights_ptr, recurrent_weights_scale, bias_ptr, input_size,
num_units, batch_size, params->activation, quantized_input_ptr,
- quantized_hidden_state_ptr, hidden_state_ptr_batch, output_ptr_batch);
+ quantized_hidden_state_ptr, scaling_factors_ptr,
+ hidden_state_ptr_batch, output_ptr_batch);
}
} else {
// For each batch
@@ -252,7 +258,7 @@ TfLiteStatus EvalQuantized(const TfLiteTensor* input,
recurrent_weights_ptr, recurrent_weights_scale, bias_ptr,
input_size, num_units, /*batch_size=*/1, params->activation,
quantized_input_ptr, quantized_hidden_state_ptr,
- hidden_state_ptr_batch, output_ptr_batch);
+ scaling_factors_ptr, hidden_state_ptr_batch, output_ptr_batch);
}
}
}
@@ -267,7 +273,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* recurrent_weights =
GetInput(context, node, kRecurrentWeightsTensor);
const TfLiteTensor* bias = GetInput(context, node, kBiasTensor);
- TfLiteTensor* hidden_state = GetOutput(context, node, kHiddenStateTensor);
+ // The hidden_state is a variable input tensor that can be modified.
+ TfLiteTensor* hidden_state =
+ const_cast<TfLiteTensor*>(GetInput(context, node, kHiddenStateTensor));
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
switch (input_weights->type) {
@@ -278,12 +286,14 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
// TODO(mirkov): implement eval with quantized inputs as well.
TfLiteTensor* input_quantized = GetTemporary(context, node, 0);
TfLiteTensor* hidden_state_quantized = GetTemporary(context, node, 1);
- return EvalQuantized(input, input_weights, recurrent_weights, bias,
- params, input_quantized, hidden_state_quantized,
- hidden_state, output);
+ TfLiteTensor* scaling_factors = GetTemporary(context, node, 2);
+ return EvalHybrid(input, input_weights, recurrent_weights, bias, params,
+ input_quantized, hidden_state_quantized,
+ scaling_factors, hidden_state, output);
}
default:
- context->ReportError(context, "Type not currently supported.");
+ context->ReportError(context, "Type %d not currently supported.",
+ input_weights->type);
return kTfLiteError;
}
return kTfLiteOk;
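Editor's note (not part of the diff): this change moves the hidden state from output index 0 to input index 4 and treats it as a stateful variable tensor, so Prepare now only validates its shape and Eval writes into it through a const_cast. If a caller needs to clear the recurrent state between independent sequences, it can zero that tensor directly; the snippet below is only an illustrative sketch, and the tensor-index plumbing is an assumption, not something this diff provides.

    #include <cstring>

    #include "tensorflow/contrib/lite/interpreter.h"

    // hidden_state_index is assumed to be the interpreter-level index of the
    // RNN's hidden-state (variable) input tensor.
    void ResetRnnHiddenState(tflite::Interpreter* interpreter,
                             int hidden_state_index) {
      TfLiteTensor* state = interpreter->tensor(hidden_state_index);
      std::memset(state->data.raw, 0, state->bytes);
    }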
diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn_test.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn_test.cc
index 0adab837b0..6b48e3fff7 100644
--- a/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn_test.cc
+++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn_test.cc
@@ -183,7 +183,7 @@ class UnidirectionalRNNOpModel : public SingleOpModel {
weights_ = AddInput(weights);
recurrent_weights_ = AddInput(recurrent_weights);
bias_ = AddInput(TensorType_FLOAT32);
- hidden_state_ = AddOutput(TensorType_FLOAT32);
+ hidden_state_ = AddInput(TensorType_FLOAT32, true);
output_ = AddOutput(TensorType_FLOAT32);
SetBuiltinOp(BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN,
BuiltinOptions_SequenceRNNOptions,
@@ -194,12 +194,14 @@ class UnidirectionalRNNOpModel : public SingleOpModel {
BuildInterpreter({{sequence_len_, batches_, input_size_},
{units_, input_size_},
{units_, units_},
- {units_}});
+ {units_},
+ {batches_, units}});
} else {
BuildInterpreter({{batches_, sequence_len_, input_size_},
{units_, input_size_},
{units_, units_},
- {units_}});
+ {units_},
+ {batches_, units_}});
}
}
@@ -221,14 +223,6 @@ class UnidirectionalRNNOpModel : public SingleOpModel {
PopulateTensor(input_, offset, begin, end);
}
- void ResetHiddenState() {
- const int zero_buffer_size = units_ * batches_;
- std::unique_ptr<float[]> zero_buffer(new float[zero_buffer_size]);
- memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float));
- PopulateTensor(hidden_state_, 0, zero_buffer.get(),
- zero_buffer.get() + zero_buffer_size);
- }
-
std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
int input_size() { return input_size_; }
@@ -273,7 +267,6 @@ TEST(UnidirectionalRNNOpTest, BlackBoxTest) {
rnn.SetWeights(rnn_weights);
rnn.SetBias(rnn_bias);
rnn.SetRecurrentWeights(rnn_recurrent_weights);
- rnn.ResetHiddenState();
const int input_sequence_size = rnn.input_size() * rnn.sequence_len();
float* batch_start = rnn_input;
@@ -299,7 +292,6 @@ TEST(HybridUnidirectionalRNNOpModelOpTest, BlackBoxTest) {
rnn.SetWeights(rnn_weights);
rnn.SetBias(rnn_bias);
rnn.SetRecurrentWeights(rnn_recurrent_weights);
- rnn.ResetHiddenState();
const int input_sequence_size = rnn.input_size() * rnn.sequence_len();
float* batch_start = rnn_input;
@@ -326,7 +318,6 @@ TEST(UnidirectionalRNNOpTest, TimeMajorBlackBoxTest) {
rnn.SetWeights(rnn_weights);
rnn.SetBias(rnn_bias);
rnn.SetRecurrentWeights(rnn_recurrent_weights);
- rnn.ResetHiddenState();
for (int i = 0; i < rnn.sequence_len(); i++) {
float* batch_start = rnn_input + i * rnn.input_size();
@@ -356,7 +347,6 @@ TEST(HybridUnidirectionalRNNOpModelOpTest, TimeMajorBlackBoxTest) {
rnn.SetWeights(rnn_weights);
rnn.SetBias(rnn_bias);
rnn.SetRecurrentWeights(rnn_recurrent_weights);
- rnn.ResetHiddenState();
for (int i = 0; i < rnn.sequence_len(); i++) {
float* batch_start = rnn_input + i * rnn.input_size();
diff --git a/tensorflow/contrib/lite/kernels/unpack.cc b/tensorflow/contrib/lite/kernels/unpack.cc
new file mode 100644
index 0000000000..a7d3a9bc76
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/unpack.cc
@@ -0,0 +1,133 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace unpack {
+namespace {
+
+constexpr int kInputTensor = 0;
+
+// Op data for unpack op.
+struct OpData {
+ int num;
+ int axis;
+};
+
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+ auto* data = new OpData;
+ data->axis = 0;
+ return data;
+}
+
+void Free(TfLiteContext* context, void* buffer) {
+ delete reinterpret_cast<OpData*>(buffer);
+}
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ const OpData* data = reinterpret_cast<OpData*>(node->builtin_data);
+
+ TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), data->num);
+
+ const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ TF_LITE_ENSURE(context, NumDimensions(input) <= 4);
+ TF_LITE_ENSURE(context, NumDimensions(input) > 1);
+ TF_LITE_ENSURE(context, NumDimensions(input) > data->axis);
+ // TODO(renjieliu): Support negative axis.
+ TF_LITE_ENSURE(context, data->axis >= 0);
+ if (input->type != kTfLiteInt32 && input->type != kTfLiteFloat32) {
+ context->ReportError(context,
+ "Currently unpack only supports int32 and float32.");
+ return kTfLiteError;
+ }
+
+ const TfLiteIntArray* input_shape = input->dims;
+ // The number of outputs must equal the input dimension along `axis`.
+ // Resize outputs; each output has rank R - 1.
+ TfLiteIntArray* output_shape = TfLiteIntArrayCreate(NumDimensions(input) - 1);
+ int o = 0;
+ for (int index = 0; index < NumDimensions(input); ++index) {
+ if (index != data->axis) {
+ output_shape->data[o++] = input_shape->data[index];
+ }
+ }
+
+ TF_LITE_ENSURE_EQ(context, data->num, input_shape->data[data->axis]);
+ for (int i = 0; i < data->num; ++i) {
+ TfLiteIntArray* copied_output_shape = TfLiteIntArrayCopy(output_shape);
+ TfLiteTensor* output = GetOutput(context, node, i);
+ TF_LITE_ENSURE_EQ(context, output->type, input->type);
+ TF_LITE_ENSURE_OK(
+ context, context->ResizeTensor(context, output, copied_output_shape));
+ }
+
+ TfLiteIntArrayFree(output_shape);
+ return kTfLiteOk;
+}
+
+template <typename T>
+void UnpackImpl(TfLiteContext* context, TfLiteNode* node,
+ const TfLiteTensor* input, int output_count, int axis) {
+ tflite::UnpackParams op_params;
+ op_params.axis = axis;
+ op_params.num_split = output_count;
+ VectorOfTensors<T> all_outputs(*context, *node->outputs);
+ reference_ops::Unpack<T>(op_params, GetTensorShape(input),
+ GetTensorData<T>(input), **all_outputs.shapes(),
+ all_outputs.data());
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ const OpData* data = reinterpret_cast<OpData*>(node->builtin_data);
+
+ const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ switch (input->type) {
+ case kTfLiteFloat32: {
+ UnpackImpl<float>(context, node, input, data->num, data->axis);
+ break;
+ }
+ case kTfLiteInt32: {
+ UnpackImpl<int32_t>(context, node, input, data->num, data->axis);
+ break;
+ }
+ default: {
+ context->ReportError(context,
+ "Currently unpack only supports int32 and float32.");
+ return kTfLiteError;
+ }
+ }
+
+ return kTfLiteOk;
+}
+} // namespace
+} // namespace unpack
+
+TfLiteRegistration* Register_UNPACK() {
+ static TfLiteRegistration r = {unpack::Init, unpack::Free, unpack::Prepare,
+ unpack::Eval};
+ return &r;
+}
+
+} // namespace builtin
+} // namespace ops
+} // namespace tflite
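Editor's note (not part of the diff): Prepare above sets the number of outputs to input_shape[axis] and gives each output the input shape with that dimension removed. The standalone sketch below (not the reference_ops::Unpack the kernel calls) spells out the same split for a row-major buffer; for a {2, 2, 2} input unpacked along axis 2 it produces two {2, 2} outputs, which is exactly what the tests in the next file check.

    #include <vector>

    // Split a row-major tensor with dimensions `dims` into dims[axis] outputs,
    // each holding the slice at a fixed index along `axis`.
    std::vector<std::vector<float>> UnpackRowMajor(const std::vector<float>& input,
                                                   const std::vector<int>& dims,
                                                   int axis) {
      int outer = 1, inner = 1;
      for (int d = 0; d < axis; ++d) outer *= dims[d];
      for (int d = axis + 1; d < static_cast<int>(dims.size()); ++d) inner *= dims[d];
      const int num_outputs = dims[axis];
      std::vector<std::vector<float>> outputs(num_outputs);
      for (int n = 0; n < num_outputs; ++n) {
        outputs[n].reserve(outer * inner);
        for (int o = 0; o < outer; ++o) {
          const int base = (o * num_outputs + n) * inner;
          for (int i = 0; i < inner; ++i) outputs[n].push_back(input[base + i]);
        }
      }
      return outputs;
    }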
diff --git a/tensorflow/contrib/lite/kernels/unpack_test.cc b/tensorflow/contrib/lite/kernels/unpack_test.cc
new file mode 100644
index 0000000000..4efc92a0fd
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/unpack_test.cc
@@ -0,0 +1,225 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <vector>
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAre;
+
+template <typename T>
+class UnpackOpModel : public SingleOpModel {
+ public:
+ UnpackOpModel(const TensorData& input, int axis) {
+ CHECK_LE(axis, input.shape.size());
+ const int num_outputs = input.shape[axis];
+ input_ = AddInput(input);
+ for (int i = 0; i < num_outputs; ++i) {
+ outputs_.push_back(AddOutput(input.type));
+ }
+ SetBuiltinOp(BuiltinOperator_UNPACK, BuiltinOptions_UnpackOptions,
+ CreateUnpackOptions(builder_, num_outputs, axis).Union());
+ BuildInterpreter({GetShape(input_)});
+ }
+
+ void SetInput(std::initializer_list<T> data) {
+ PopulateTensor<T>(input_, data);
+ }
+
+ std::vector<std::vector<T>> GetOutputDatas() {
+ std::vector<std::vector<T>> output_datas;
+ for (const int output : outputs_) {
+ output_datas.push_back(ExtractVector<T>(output));
+ }
+ return output_datas;
+ }
+
+ std::vector<std::vector<int>> GetOutputShapes() {
+ std::vector<std::vector<int>> output_shapes;
+ for (const int output : outputs_) {
+ output_shapes.push_back(GetTensorShape(output));
+ }
+ return output_shapes;
+ }
+
+ private:
+ int input_;
+ std::vector<int> outputs_;
+};
+
+// float32 tests.
+TEST(UnpackOpTest, FloatThreeOutputs) {
+ UnpackOpModel<float> model({TensorType_FLOAT32, {3, 2}}, 0);
+ model.SetInput({1, 2, 3, 4, 5, 6});
+ model.Invoke();
+
+ // Check outputs shapes.
+ const std::vector<std::vector<int>>& output_shapes = model.GetOutputShapes();
+ EXPECT_EQ(output_shapes.size(), 3);
+ EXPECT_THAT(output_shapes[0], ElementsAre(2));
+ EXPECT_THAT(output_shapes[1], ElementsAre(2));
+ EXPECT_THAT(output_shapes[2], ElementsAre(2));
+
+ // Check outputs values.
+ const std::vector<std::vector<float>>& output_datas = model.GetOutputDatas();
+ EXPECT_EQ(output_datas.size(), 3);
+ EXPECT_THAT(output_datas[0], ElementsAre(1, 2));
+ EXPECT_THAT(output_datas[1], ElementsAre(3, 4));
+ EXPECT_THAT(output_datas[2], ElementsAre(5, 6));
+}
+
+TEST(UnpackOpTest, FloatThreeOutputsAxisOne) {
+ UnpackOpModel<float> model({TensorType_FLOAT32, {3, 2}}, 1);
+ model.SetInput({1, 2, 3, 4, 5, 6});
+ model.Invoke();
+
+ // Check outputs shapes.
+ const std::vector<std::vector<int>>& output_shapes = model.GetOutputShapes();
+ EXPECT_EQ(output_shapes.size(), 2);
+ EXPECT_THAT(output_shapes[0], ElementsAre(3));
+ EXPECT_THAT(output_shapes[1], ElementsAre(3));
+
+ // Check outputs values.
+ const std::vector<std::vector<float>>& output_datas = model.GetOutputDatas();
+ EXPECT_EQ(output_datas.size(), 2);
+ EXPECT_THAT(output_datas[0], ElementsAre(1, 3, 5));
+ EXPECT_THAT(output_datas[1], ElementsAre(2, 4, 6));
+}
+
+TEST(UnpackOpTest, FloatOneOutput) {
+ UnpackOpModel<float> model({TensorType_FLOAT32, {1, 6}}, 0);
+ model.SetInput({1, 2, 3, 4, 5, 6});
+ model.Invoke();
+
+ // Check outputs shapes.
+ const std::vector<std::vector<int>>& output_shapes = model.GetOutputShapes();
+ EXPECT_EQ(output_shapes.size(), 1);
+ EXPECT_THAT(output_shapes[0], ElementsAre(6));
+
+ // Check outputs values.
+ const std::vector<std::vector<float>>& output_datas = model.GetOutputDatas();
+ EXPECT_EQ(output_datas.size(), 1);
+ EXPECT_THAT(output_datas[0], ElementsAre(1, 2, 3, 4, 5, 6));
+}
+
+TEST(UnpackOpTest, FloatThreeDimensionsOutputs) {
+ UnpackOpModel<float> model({TensorType_FLOAT32, {2, 2, 2}}, 2);
+ model.SetInput({1, 2, 3, 4, 5, 6, 7, 8});
+ model.Invoke();
+
+ // Check outputs shapes.
+ const std::vector<std::vector<int>>& output_shapes = model.GetOutputShapes();
+ EXPECT_EQ(output_shapes.size(), 2);
+ EXPECT_THAT(output_shapes[0], ElementsAre(2, 2));
+ EXPECT_THAT(output_shapes[1], ElementsAre(2, 2));
+
+ // Check outputs values.
+ const std::vector<std::vector<float>>& output_datas = model.GetOutputDatas();
+ EXPECT_EQ(output_datas.size(), 2);
+ EXPECT_THAT(output_datas[0], ElementsAre(1, 3, 5, 7));
+ EXPECT_THAT(output_datas[1], ElementsAre(2, 4, 6, 8));
+}
+
+// int32 tests.
+TEST(UnpackOpTest, IntThreeOutputs) {
+ UnpackOpModel<int32_t> model({TensorType_INT32, {3, 2}}, 0);
+ model.SetInput({1, 2, 3, 4, 5, 6});
+ model.Invoke();
+
+ // Check outputs shapes.
+ const std::vector<std::vector<int>>& output_shapes = model.GetOutputShapes();
+ EXPECT_EQ(output_shapes.size(), 3);
+ EXPECT_THAT(output_shapes[0], ElementsAre(2));
+ EXPECT_THAT(output_shapes[1], ElementsAre(2));
+ EXPECT_THAT(output_shapes[2], ElementsAre(2));
+
+ // Check outputs values.
+ const std::vector<std::vector<int32_t>>& output_datas =
+ model.GetOutputDatas();
+ EXPECT_EQ(output_datas.size(), 3);
+ EXPECT_THAT(output_datas[0], ElementsAre(1, 2));
+ EXPECT_THAT(output_datas[1], ElementsAre(3, 4));
+ EXPECT_THAT(output_datas[2], ElementsAre(5, 6));
+}
+
+TEST(UnpackOpTest, IntThreeOutputsAxisOne) {
+ UnpackOpModel<int32_t> model({TensorType_INT32, {3, 2}}, 1);
+ model.SetInput({1, 2, 3, 4, 5, 6});
+ model.Invoke();
+
+ // Check outputs shapes.
+ const std::vector<std::vector<int>>& output_shapes = model.GetOutputShapes();
+ EXPECT_EQ(output_shapes.size(), 2);
+ EXPECT_THAT(output_shapes[0], ElementsAre(3));
+ EXPECT_THAT(output_shapes[1], ElementsAre(3));
+
+ // Check outputs values.
+ const std::vector<std::vector<int32_t>>& output_datas =
+ model.GetOutputDatas();
+ EXPECT_EQ(output_datas.size(), 2);
+ EXPECT_THAT(output_datas[0], ElementsAre(1, 3, 5));
+ EXPECT_THAT(output_datas[1], ElementsAre(2, 4, 6));
+}
+
+TEST(UnpackOpTest, IntOneOutput) {
+ UnpackOpModel<int32_t> model({TensorType_INT32, {1, 6}}, 0);
+ model.SetInput({1, 2, 3, 4, 5, 6});
+ model.Invoke();
+
+ // Check outputs shapes.
+ const std::vector<std::vector<int>>& output_shapes = model.GetOutputShapes();
+ EXPECT_EQ(output_shapes.size(), 1);
+ EXPECT_THAT(output_shapes[0], ElementsAre(6));
+
+ // Check outputs values.
+ const std::vector<std::vector<int32_t>>& output_datas =
+ model.GetOutputDatas();
+ EXPECT_EQ(output_datas.size(), 1);
+ EXPECT_THAT(output_datas[0], ElementsAre(1, 2, 3, 4, 5, 6));
+}
+
+TEST(UnpackOpTest, IntThreeDimensionsOutputs) {
+ UnpackOpModel<int32_t> model({TensorType_INT32, {2, 2, 2}}, 2);
+ model.SetInput({1, 2, 3, 4, 5, 6, 7, 8});
+ model.Invoke();
+
+ // Check outputs shapes.
+ const std::vector<std::vector<int>>& output_shapes = model.GetOutputShapes();
+ EXPECT_EQ(output_shapes.size(), 2);
+ EXPECT_THAT(output_shapes[0], ElementsAre(2, 2));
+ EXPECT_THAT(output_shapes[1], ElementsAre(2, 2));
+
+ // Check outputs values.
+ const std::vector<std::vector<int32_t>>& output_datas =
+ model.GetOutputDatas();
+ EXPECT_EQ(output_datas.size(), 2);
+ EXPECT_THAT(output_datas[0], ElementsAre(1, 3, 5, 7));
+ EXPECT_THAT(output_datas[1], ElementsAre(2, 4, 6, 8));
+}
+
+} // namespace
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ ::tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/kernels/zeros_like.cc b/tensorflow/contrib/lite/kernels/zeros_like.cc
new file mode 100644
index 0000000000..cce5240a9b
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/zeros_like.cc
@@ -0,0 +1,73 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace zeros_like {
+
+constexpr int kInputTensor = 0;
+constexpr int kOutputTensor = 0;
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+ const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+ output->type = input->type;
+
+ return context->ResizeTensor(context, output,
+ TfLiteIntArrayCopy(input->dims));
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+ const int num_elements = NumElements(input);
+ switch (input->type) {
+ case kTfLiteInt64:
+ memset(GetTensorData<int64_t>(output), 0, num_elements * sizeof(int64_t));
+ break;
+ case kTfLiteInt32:
+ memset(GetTensorData<int32_t>(output), 0, num_elements * sizeof(int32_t));
+ break;
+ case kTfLiteFloat32:
+ memset(GetTensorData<float>(output), 0, num_elements * sizeof(float));
+ break;
+ default:
+ context->ReportError(context,
+ "ZerosLike currently only supports int64, int32, "
+ "and float32, got %d.",
+ input->type);
+ return kTfLiteError;
+ }
+ return kTfLiteOk;
+}
+
+} // namespace zeros_like
+
+TfLiteRegistration* Register_ZEROS_LIKE() {
+ static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
+ zeros_like::Prepare, zeros_like::Eval};
+ return &r;
+}
+
+} // namespace builtin
+} // namespace ops
+} // namespace tflite
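Editor's note (not part of the diff): the memset calls in Eval above are correct here because int32/int64 zero and IEEE-754 float 0.0 are all-zero byte patterns. A type-generic alternative that relies on value initialization instead of the byte representation would look roughly like the sketch below (illustrative only).

    #include <algorithm>

    // Write T(0) element by element instead of zeroing raw bytes.
    template <typename T>
    void FillZeros(T* data, int num_elements) {
      std::fill(data, data + num_elements, T(0));
    }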
diff --git a/tensorflow/contrib/lite/kernels/zeros_like_test.cc b/tensorflow/contrib/lite/kernels/zeros_like_test.cc
new file mode 100644
index 0000000000..d3382d1d5b
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/zeros_like_test.cc
@@ -0,0 +1,78 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAreArray;
+
+class ZerosLikeOpModel : public SingleOpModel {
+ public:
+ explicit ZerosLikeOpModel(const TensorData& input) {
+ input_ = AddInput(input);
+ output_ = AddOutput(input);
+ SetBuiltinOp(BuiltinOperator_ZEROS_LIKE, BuiltinOptions_ZerosLikeOptions,
+ CreateZerosLikeOptions(builder_).Union());
+ BuildInterpreter({GetShape(input_)});
+ }
+
+ int input() { return input_; }
+ int output() { return output_; }
+
+ protected:
+ int input_;
+ int output_;
+};
+
+TEST(ZerosLikeOpModel, ZerosLikeFloat) {
+ ZerosLikeOpModel m({TensorType_FLOAT32, {2, 3}});
+ m.PopulateTensor<float>(m.input(), {-2.0, -1.0, 0.0, 1.0, 2.0, 3.0});
+ m.Invoke();
+ EXPECT_THAT(m.ExtractVector<float>(m.output()),
+ ElementsAreArray({0.0, 0.0, 0.0, 0.0, 0.0, 0.0}));
+ EXPECT_THAT(m.GetTensorShape(m.output()), ElementsAreArray({2, 3}));
+}
+
+TEST(ZerosLikeOpModel, ZerosLikeInt32) {
+ ZerosLikeOpModel m({TensorType_INT32, {1, 2, 2, 1}});
+ m.PopulateTensor<int32_t>(m.input(), {-2, -1, 0, 3});
+ m.Invoke();
+ EXPECT_THAT(m.ExtractVector<int32_t>(m.output()),
+ ElementsAreArray({0, 0, 0, 0}));
+ EXPECT_THAT(m.GetTensorShape(m.output()), ElementsAreArray({1, 2, 2, 1}));
+}
+
+TEST(ZerosLikeOpModel, ZerosLikeInt64) {
+ ZerosLikeOpModel m({TensorType_INT64, {1, 2, 2, 1}});
+ m.PopulateTensor<int64_t>(m.input(), {-2, -1, 0, 3});
+ m.Invoke();
+ EXPECT_THAT(m.ExtractVector<int64_t>(m.output()),
+ ElementsAreArray({0, 0, 0, 0}));
+ EXPECT_THAT(m.GetTensorShape(m.output()), ElementsAreArray({1, 2, 2, 1}));
+}
+
+} // namespace
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ ::tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}