author     Andrew Selle <aselle@google.com>  2017-11-10 10:35:35 -0800
committer  Andrew Selle <aselle@andyselle.com>  2017-11-10 16:14:42 -0800
commit     0b15439f8f0f2d4755587f4096c3ea04cb199d23 (patch)
tree       9aa4fc8162bf9b4ee50112a7b85703f70ca4df08 /tensorflow/contrib/lite
parent     7ac140a5845553275427162aabd9d54987144b4a (diff)
Internal Change.
PiperOrigin-RevId: 175307445
Diffstat (limited to 'tensorflow/contrib/lite')
-rw-r--r--tensorflow/contrib/lite/BUILD280
-rw-r--r--tensorflow/contrib/lite/allocation.cc122
-rw-r--r--tensorflow/contrib/lite/allocation.h94
-rw-r--r--tensorflow/contrib/lite/build_def.bzl233
-rw-r--r--tensorflow/contrib/lite/builtin_op_data.h164
-rw-r--r--tensorflow/contrib/lite/context.c92
-rw-r--r--tensorflow/contrib/lite/context.h298
-rw-r--r--tensorflow/contrib/lite/context_test.cc74
-rw-r--r--tensorflow/contrib/lite/error_reporter.cc50
-rw-r--r--tensorflow/contrib/lite/error_reporter.h54
-rw-r--r--tensorflow/contrib/lite/interpreter.cc567
-rw-r--r--tensorflow/contrib/lite/interpreter.h376
-rw-r--r--tensorflow/contrib/lite/interpreter_test.cc526
-rw-r--r--tensorflow/contrib/lite/java/BUILD164
-rw-r--r--tensorflow/contrib/lite/java/demo/.gitignore9
-rw-r--r--tensorflow/contrib/lite/java/demo/app/build.gradle58
-rw-r--r--tensorflow/contrib/lite/java/demo/app/src/main/AndroidManifest.xml42
-rw-r--r--tensorflow/contrib/lite/java/demo/app/src/main/BUILD43
-rw-r--r--tensorflow/contrib/lite/java/demo/app/src/main/assets/BUILD26
-rw-r--r--tensorflow/contrib/lite/java/demo/app/src/main/assets/labels.txt1001
-rw-r--r--tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/AutoFitTextureView.java72
-rw-r--r--tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/Camera2BasicFragment.java708
-rw-r--r--tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/CameraActivity.java35
-rw-r--r--tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifier.java184
-rw-r--r--tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-hdpi/ic_action_info.pngbin0 -> 490 bytes
-rw-r--r--tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-hdpi/ic_launcher.pngbin0 -> 3136 bytes
-rw-r--r--tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-hdpi/tile.9.pngbin0 -> 116 bytes
-rw-r--r--tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-mdpi/ic_action_info.pngbin0 -> 320 bytes
-rw-r--r--tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-mdpi/ic_launcher.pngbin0 -> 1915 bytes
-rw-r--r--tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xhdpi/ic_action_info.pngbin0 -> 611 bytes
-rw-r--r--tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xhdpi/ic_launcher.pngbin0 -> 4294 bytes
-rw-r--r--tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xxhdpi/ic_action_info.pngbin0 -> 952 bytes
-rw-r--r--tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xxhdpi/ic_launcher.pngbin0 -> 7279 bytes
-rw-r--r--tensorflow/contrib/lite/java/demo/app/src/main/res/layout-land/fragment_camera2_basic.xml50
-rw-r--r--tensorflow/contrib/lite/java/demo/app/src/main/res/layout/activity_camera.xml22
-rw-r--r--tensorflow/contrib/lite/java/demo/app/src/main/res/layout/fragment_camera2_basic.xml45
-rw-r--r--tensorflow/contrib/lite/java/demo/app/src/main/res/values-sw600dp/template-dimens.xml24
-rw-r--r--tensorflow/contrib/lite/java/demo/app/src/main/res/values-sw600dp/template-styles.xml25
-rw-r--r--tensorflow/contrib/lite/java/demo/app/src/main/res/values-v11/template-styles.xml22
-rw-r--r--tensorflow/contrib/lite/java/demo/app/src/main/res/values-v21/base-colors.xml21
-rw-r--r--tensorflow/contrib/lite/java/demo/app/src/main/res/values-v21/base-template-styles.xml24
-rw-r--r--tensorflow/contrib/lite/java/demo/app/src/main/res/values/base-strings.xml30
-rw-r--r--tensorflow/contrib/lite/java/demo/app/src/main/res/values/colors.xml19
-rw-r--r--tensorflow/contrib/lite/java/demo/app/src/main/res/values/strings.xml24
-rw-r--r--tensorflow/contrib/lite/java/demo/app/src/main/res/values/styles.xml18
-rw-r--r--tensorflow/contrib/lite/java/demo/app/src/main/res/values/template-dimens.xml32
-rw-r--r--tensorflow/contrib/lite/java/demo/app/src/main/res/values/template-styles.xml42
-rw-r--r--tensorflow/contrib/lite/java/demo/build.gradle23
-rw-r--r--tensorflow/contrib/lite/java/demo/gradle.properties17
-rw-r--r--tensorflow/contrib/lite/java/demo/gradle/wrapper/gradle-wrapper.jarbin0 -> 53636 bytes
-rw-r--r--tensorflow/contrib/lite/java/demo/gradle/wrapper/gradle-wrapper.properties6
-rwxr-xr-xtensorflow/contrib/lite/java/demo/gradlew160
-rw-r--r--tensorflow/contrib/lite/java/demo/gradlew.bat90
-rw-r--r--tensorflow/contrib/lite/java/demo/settings.gradle1
-rw-r--r--tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/DataType.java76
-rw-r--r--tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java172
-rw-r--r--tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java276
-rw-r--r--tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Tensor.java71
-rw-r--r--tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/TensorFlowLite.java44
-rw-r--r--tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/package-info.java17
-rw-r--r--tensorflow/contrib/lite/java/src/main/native/BUILD70
-rw-r--r--tensorflow/contrib/lite/java/src/main/native/builtin_ops_jni.cc29
-rw-r--r--tensorflow/contrib/lite/java/src/main/native/exception_jni.cc66
-rw-r--r--tensorflow/contrib/lite/java/src/main/native/exception_jni.h50
-rw-r--r--tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc446
-rw-r--r--tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h151
-rw-r--r--tensorflow/contrib/lite/java/src/main/native/tensor_jni.cc242
-rw-r--r--tensorflow/contrib/lite/java/src/main/native/tensor_jni.h74
-rw-r--r--tensorflow/contrib/lite/java/src/main/native/tensorflow_lite_jni.cc26
-rw-r--r--tensorflow/contrib/lite/java/src/main/native/tensorflow_lite_jni.h36
-rw-r--r--tensorflow/contrib/lite/java/src/main/native/version_script.lds11
-rw-r--r--tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/DataTypeTest.java34
-rw-r--r--tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java221
-rw-r--r--tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java406
-rw-r--r--tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorFlowLiteTest.java32
-rw-r--r--tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java105
-rw-r--r--tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/BUILD30
-rw-r--r--tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/TestHelper.java35
-rw-r--r--tensorflow/contrib/lite/kernels/BUILD408
-rw-r--r--tensorflow/contrib/lite/kernels/activation_functor.h58
-rw-r--r--tensorflow/contrib/lite/kernels/activations.cc389
-rw-r--r--tensorflow/contrib/lite/kernels/activations_test.cc323
-rw-r--r--tensorflow/contrib/lite/kernels/add.cc184
-rw-r--r--tensorflow/contrib/lite/kernels/add_test.cc171
-rw-r--r--tensorflow/contrib/lite/kernels/basic_rnn.cc161
-rw-r--r--tensorflow/contrib/lite/kernels/basic_rnn_test.cc267
-rw-r--r--tensorflow/contrib/lite/kernels/concatenation.cc200
-rw-r--r--tensorflow/contrib/lite/kernels/concatenation_test.cc162
-rw-r--r--tensorflow/contrib/lite/kernels/conv.cc425
-rw-r--r--tensorflow/contrib/lite/kernels/conv_test.cc440
-rw-r--r--tensorflow/contrib/lite/kernels/depthwise_conv.cc289
-rw-r--r--tensorflow/contrib/lite/kernels/depthwise_conv_test.cc186
-rw-r--r--tensorflow/contrib/lite/kernels/embedding_lookup.cc104
-rw-r--r--tensorflow/contrib/lite/kernels/embedding_lookup_sparse.cc248
-rw-r--r--tensorflow/contrib/lite/kernels/embedding_lookup_sparse_test.cc166
-rw-r--r--tensorflow/contrib/lite/kernels/embedding_lookup_test.cc94
-rw-r--r--tensorflow/contrib/lite/kernels/fully_connected.cc307
-rw-r--r--tensorflow/contrib/lite/kernels/fully_connected_test.cc377
-rw-r--r--tensorflow/contrib/lite/kernels/gemm_support.cc68
-rw-r--r--tensorflow/contrib/lite/kernels/gemm_support.h54
-rw-r--r--tensorflow/contrib/lite/kernels/hashtable_lookup.cc155
-rw-r--r--tensorflow/contrib/lite/kernels/hashtable_lookup_test.cc176
-rw-r--r--tensorflow/contrib/lite/kernels/internal/BUILD359
-rw-r--r--tensorflow/contrib/lite/kernels/internal/common.h107
-rw-r--r--tensorflow/contrib/lite/kernels/internal/compatibility.h78
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/cpu_check.h65
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h987
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h1916
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/eigen_spatial_convolutions.h231
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_google.h143
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_oss.h167
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h195
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc337
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h113
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h3715
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h138
-rw-r--r--tensorflow/contrib/lite/kernels/internal/quantization_util.cc95
-rw-r--r--tensorflow/contrib/lite/kernels/internal/quantization_util.h55
-rw-r--r--tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc108
-rw-r--r--tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h115
-rw-r--r--tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h138
-rw-r--r--tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc165
-rw-r--r--tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h189
-rw-r--r--tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h2455
-rw-r--r--tensorflow/contrib/lite/kernels/internal/round.h39
-rw-r--r--tensorflow/contrib/lite/kernels/internal/tensor.h87
-rw-r--r--tensorflow/contrib/lite/kernels/internal/tensor_test.cc55
-rw-r--r--tensorflow/contrib/lite/kernels/internal/tensor_utils.cc27
-rw-r--r--tensorflow/contrib/lite/kernels/internal/tensor_utils.h116
-rw-r--r--tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc192
-rw-r--r--tensorflow/contrib/lite/kernels/internal/types.h81
-rw-r--r--tensorflow/contrib/lite/kernels/kernel_util.cc87
-rw-r--r--tensorflow/contrib/lite/kernels/kernel_util.h65
-rw-r--r--tensorflow/contrib/lite/kernels/l2norm.cc112
-rw-r--r--tensorflow/contrib/lite/kernels/l2norm_test.cc63
-rw-r--r--tensorflow/contrib/lite/kernels/local_response_norm.cc109
-rw-r--r--tensorflow/contrib/lite/kernels/local_response_norm_test.cc101
-rw-r--r--tensorflow/contrib/lite/kernels/lsh_projection.cc204
-rw-r--r--tensorflow/contrib/lite/kernels/lsh_projection_test.cc123
-rw-r--r--tensorflow/contrib/lite/kernels/lstm.cc515
-rw-r--r--tensorflow/contrib/lite/kernels/lstm_test.cc1088
-rw-r--r--tensorflow/contrib/lite/kernels/mul.cc167
-rw-r--r--tensorflow/contrib/lite/kernels/mul_test.cc127
-rw-r--r--tensorflow/contrib/lite/kernels/op_macros.h32
-rw-r--r--tensorflow/contrib/lite/kernels/optional_tensor_test.cc343
-rw-r--r--tensorflow/contrib/lite/kernels/padding.h28
-rw-r--r--tensorflow/contrib/lite/kernels/pooling.cc355
-rw-r--r--tensorflow/contrib/lite/kernels/pooling_test.cc161
-rw-r--r--tensorflow/contrib/lite/kernels/register.cc109
-rw-r--r--tensorflow/contrib/lite/kernels/register.h50
-rw-r--r--tensorflow/contrib/lite/kernels/reshape.cc91
-rw-r--r--tensorflow/contrib/lite/kernels/reshape_test.cc90
-rw-r--r--tensorflow/contrib/lite/kernels/resize_bilinear.cc129
-rw-r--r--tensorflow/contrib/lite/kernels/resize_bilinear_test.cc117
-rw-r--r--tensorflow/contrib/lite/kernels/skip_gram.cc160
-rw-r--r--tensorflow/contrib/lite/kernels/skip_gram_test.cc257
-rw-r--r--tensorflow/contrib/lite/kernels/softmax_test.cc143
-rw-r--r--tensorflow/contrib/lite/kernels/space_to_depth.cc146
-rw-r--r--tensorflow/contrib/lite/kernels/space_to_depth_test.cc102
-rw-r--r--tensorflow/contrib/lite/kernels/svdf.cc224
-rw-r--r--tensorflow/contrib/lite/kernels/svdf_test.cc312
-rw-r--r--tensorflow/contrib/lite/kernels/test_util.cc183
-rw-r--r--tensorflow/contrib/lite/kernels/test_util.h202
-rw-r--r--tensorflow/contrib/lite/model.cc673
-rw-r--r--tensorflow/contrib/lite/model.h165
-rw-r--r--tensorflow/contrib/lite/model_test.cc258
-rw-r--r--tensorflow/contrib/lite/models/smartreply/BUILD15
-rw-r--r--tensorflow/contrib/lite/models/smartreply/ops/extract_feature.cc119
-rw-r--r--tensorflow/contrib/lite/models/smartreply/ops/extract_feature_test.cc100
-rw-r--r--tensorflow/contrib/lite/models/smartreply/ops/normalize.cc105
-rw-r--r--tensorflow/contrib/lite/models/smartreply/ops/normalize_test.cc90
-rw-r--r--tensorflow/contrib/lite/models/smartreply/ops/predict.cc174
-rw-r--r--tensorflow/contrib/lite/models/smartreply/ops/predict_test.cc183
-rw-r--r--tensorflow/contrib/lite/models/smartreply/predictor.cc116
-rw-r--r--tensorflow/contrib/lite/models/smartreply/predictor.h80
-rw-r--r--tensorflow/contrib/lite/models/smartreply/predictor_test.cc150
-rw-r--r--tensorflow/contrib/lite/models/speech_hotword_model_test.cc115
-rw-r--r--tensorflow/contrib/lite/models/speech_speakerid_model_test.cc114
-rw-r--r--tensorflow/contrib/lite/models/speech_terse_am_model_test.cc127
-rw-r--r--tensorflow/contrib/lite/models/speech_tts_model_test.cc116
-rw-r--r--tensorflow/contrib/lite/models/test_utils.h84
-rw-r--r--tensorflow/contrib/lite/nnapi/BUILD25
-rw-r--r--tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h1916
-rw-r--r--tensorflow/contrib/lite/nnapi_delegate.cc386
-rw-r--r--tensorflow/contrib/lite/nnapi_delegate.h66
-rw-r--r--tensorflow/contrib/lite/optional_debug_tools.cc108
-rw-r--r--tensorflow/contrib/lite/optional_debug_tools.h32
-rw-r--r--tensorflow/contrib/lite/python/BUILD46
-rw-r--r--tensorflow/contrib/lite/python/lite.py199
-rw-r--r--tensorflow/contrib/lite/python/lite_test.py45
-rw-r--r--tensorflow/contrib/lite/schema/BUILD82
-rw-r--r--tensorflow/contrib/lite/schema/flatbuffer_compatibility_test.cc91
-rw-r--r--tensorflow/contrib/lite/schema/schema.fbs346
-rw-r--r--tensorflow/contrib/lite/schema/schema_v0.fbs247
-rw-r--r--tensorflow/contrib/lite/schema/schema_v1.fbs295
-rw-r--r--tensorflow/contrib/lite/schema/schema_v2.fbs303
-rw-r--r--tensorflow/contrib/lite/schema/schema_v3.fbs326
-rw-r--r--tensorflow/contrib/lite/schema/upgrade_schema.py341
-rw-r--r--tensorflow/contrib/lite/schema/upgrade_schema_test.py317
-rw-r--r--tensorflow/contrib/lite/simple_memory_arena.cc136
-rw-r--r--tensorflow/contrib/lite/simple_memory_arena.h84
-rw-r--r--tensorflow/contrib/lite/simple_memory_arena_test.cc91
-rw-r--r--tensorflow/contrib/lite/string.h30
-rw-r--r--tensorflow/contrib/lite/string_util.cc117
-rw-r--r--tensorflow/contrib/lite/string_util.h91
-rw-r--r--tensorflow/contrib/lite/string_util_test.cc117
-rw-r--r--tensorflow/contrib/lite/testdata/0_subgraphs.binbin0 -> 80 bytes
-rw-r--r--tensorflow/contrib/lite/testdata/2_subgraphs.binbin0 -> 172 bytes
-rw-r--r--tensorflow/contrib/lite/testdata/empty_model.binbin0 -> 132 bytes
-rw-r--r--tensorflow/contrib/lite/testdata/multi_add.binbin0 -> 652 bytes
-rw-r--r--tensorflow/contrib/lite/testdata/multi_add.json46
-rw-r--r--tensorflow/contrib/lite/testdata/no_subgraphs.binbin0 -> 80 bytes
-rw-r--r--tensorflow/contrib/lite/testdata/test_model.binbin0 -> 496 bytes
-rw-r--r--tensorflow/contrib/lite/testdata/test_model_broken.binbin0 -> 432 bytes
-rw-r--r--tensorflow/contrib/lite/testdata/test_model_broken.json62
-rw-r--r--tensorflow/contrib/lite/testdata/two_subgraphs.binbin0 -> 172 bytes
-rw-r--r--tensorflow/contrib/lite/testing/BUILD213
-rw-r--r--tensorflow/contrib/lite/testing/generate_examples.py1189
-rw-r--r--tensorflow/contrib/lite/testing/generate_examples_report.py125
-rw-r--r--tensorflow/contrib/lite/testing/generated_examples_zip_test.cc279
-rw-r--r--tensorflow/contrib/lite/testing/message.cc96
-rw-r--r--tensorflow/contrib/lite/testing/message.h82
-rw-r--r--tensorflow/contrib/lite/testing/message_test.cc121
-rw-r--r--tensorflow/contrib/lite/testing/nnapi_example.cc114
-rw-r--r--tensorflow/contrib/lite/testing/parse_testdata.cc335
-rw-r--r--tensorflow/contrib/lite/testing/parse_testdata.h74
-rw-r--r--tensorflow/contrib/lite/testing/split.cc42
-rw-r--r--tensorflow/contrib/lite/testing/split.h77
-rw-r--r--tensorflow/contrib/lite/testing/split_test.cc57
-rw-r--r--tensorflow/contrib/lite/testing/test_runner.h124
-rw-r--r--tensorflow/contrib/lite/testing/test_runner_test.cc84
-rw-r--r--tensorflow/contrib/lite/testing/tflite_driver.cc208
-rw-r--r--tensorflow/contrib/lite/testing/tflite_driver.h62
-rw-r--r--tensorflow/contrib/lite/testing/tflite_driver_test.cc61
-rw-r--r--tensorflow/contrib/lite/testing/tokenize.cc95
-rw-r--r--tensorflow/contrib/lite/testing/tokenize.h42
-rw-r--r--tensorflow/contrib/lite/testing/tokenize_test.cc105
-rw-r--r--tensorflow/contrib/lite/toco/BUILD350
-rw-r--r--tensorflow/contrib/lite/toco/allocate_transient_arrays.cc318
-rw-r--r--tensorflow/contrib/lite/toco/allocate_transient_arrays.h44
-rw-r--r--tensorflow/contrib/lite/toco/args.h225
-rw-r--r--tensorflow/contrib/lite/toco/dump_graphviz.cc293
-rw-r--r--tensorflow/contrib/lite/toco/dump_graphviz.h28
-rw-r--r--tensorflow/contrib/lite/toco/export_tensorflow.cc1570
-rw-r--r--tensorflow/contrib/lite/toco/export_tensorflow.h27
-rw-r--r--tensorflow/contrib/lite/toco/format_port.h77
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc98
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/create_im2col_arrays.cc69
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/dequantize.cc223
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/drop_fake_quant.cc56
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/drop_im2col_arrays.cc42
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/ensure_bias_vectors.cc57
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/fuse_activation_functions.cc98
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_following_affine.cc300
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc326
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.cc108
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h186
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc229
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/identify_l2_normalization.cc170
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/identify_l2_pool.cc106
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/identify_lstm.cc396
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/identify_relu1.cc103
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/make_initial_dequantize_operator.cc120
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc142
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc1129
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/quantize.cc467
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/read_fake_quant_min_max.cc105
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/remove_final_dequantize_op.cc59
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_assert.cc60
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_identity.cc38
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_binary.cc113
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_concatenation.cc40
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_concatenation_input.cc68
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.cc107
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h55
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_quantized_activation_func.cc87
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_reshape.cc92
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/remove_unused_op.cc122
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_normalization.cc135
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc247
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_concatenation.cc196
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc76
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_tensorflow_shape.cc62
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc175
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_mean_attributes.cc51
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_pad_attributes.cc55
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_reorder_axes.cc93
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_reshape_attributes.cc49
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_slice_attributes.cc52
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc62
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_concat.cc86
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc106
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_merge.cc63
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_squeeze.cc54
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_switch.cc123
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_tile.cc97
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD31
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_concatenation_test.cc221
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/unfuse_activation_functions.cc73
-rw-r--r--tensorflow/contrib/lite/toco/import_tensorflow.cc1508
-rw-r--r--tensorflow/contrib/lite/toco/import_tensorflow.h34
-rw-r--r--tensorflow/contrib/lite/toco/model.h1372
-rw-r--r--tensorflow/contrib/lite/toco/model_cmdline_flags.cc374
-rw-r--r--tensorflow/contrib/lite/toco/model_cmdline_flags.h43
-rw-r--r--tensorflow/contrib/lite/toco/model_flags.proto119
-rw-r--r--tensorflow/contrib/lite/toco/python/BUILD76
-rw-r--r--tensorflow/contrib/lite/toco/python/toco.i32
-rw-r--r--tensorflow/contrib/lite/toco/python/toco_from_protos.py63
-rw-r--r--tensorflow/contrib/lite/toco/python/toco_from_protos_test.py96
-rw-r--r--tensorflow/contrib/lite/toco/python/toco_python_api.cc85
-rw-r--r--tensorflow/contrib/lite/toco/python/toco_python_api.h33
-rw-r--r--tensorflow/contrib/lite/toco/python/toco_wrapper.py35
-rw-r--r--tensorflow/contrib/lite/toco/runtime/common.h26
-rw-r--r--tensorflow/contrib/lite/toco/runtime/types.h32
-rw-r--r--tensorflow/contrib/lite/toco/tensorflow_graph_matching/BUILD102
-rw-r--r--tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster.cc52
-rw-r--r--tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster.h101
-rw-r--r--tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster_utils.cc34
-rw-r--r--tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster_utils.h33
-rw-r--r--tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_cluster.cc151
-rw-r--r--tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_cluster.h63
-rw-r--r--tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_svdf.cc285
-rw-r--r--tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_svdf.h82
-rw-r--r--tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_svdf_test.cc212
-rw-r--r--tensorflow/contrib/lite/toco/tensorflow_util.cc197
-rw-r--r--tensorflow/contrib/lite/toco/tensorflow_util.h32
-rw-r--r--tensorflow/contrib/lite/toco/tflite/BUILD142
-rw-r--r--tensorflow/contrib/lite/toco/tflite/builtin_operator.h74
-rw-r--r--tensorflow/contrib/lite/toco/tflite/custom_operator.h74
-rw-r--r--tensorflow/contrib/lite/toco/tflite/export.cc322
-rw-r--r--tensorflow/contrib/lite/toco/tflite/export.h76
-rw-r--r--tensorflow/contrib/lite/toco/tflite/export_test.cc69
-rw-r--r--tensorflow/contrib/lite/toco/tflite/import.cc183
-rw-r--r--tensorflow/contrib/lite/toco/tflite/import.h49
-rw-r--r--tensorflow/contrib/lite/toco/tflite/import_test.cc141
-rw-r--r--tensorflow/contrib/lite/toco/tflite/operator.cc627
-rw-r--r--tensorflow/contrib/lite/toco/tflite/operator.h89
-rw-r--r--tensorflow/contrib/lite/toco/tflite/operator_test.cc370
-rw-r--r--tensorflow/contrib/lite/toco/tflite/simple_operator.h50
-rw-r--r--tensorflow/contrib/lite/toco/tflite/types.cc165
-rw-r--r--tensorflow/contrib/lite/toco/tflite/types.h58
-rw-r--r--tensorflow/contrib/lite/toco/tflite/types_test.cc191
-rw-r--r--tensorflow/contrib/lite/toco/toco.cc119
-rw-r--r--tensorflow/contrib/lite/toco/toco_cmdline_flags.cc206
-rw-r--r--tensorflow/contrib/lite/toco/toco_cmdline_flags.h35
-rw-r--r--tensorflow/contrib/lite/toco/toco_flags.proto126
-rw-r--r--tensorflow/contrib/lite/toco/toco_graphviz_dump_options.cc22
-rw-r--r--tensorflow/contrib/lite/toco/toco_graphviz_dump_options.h34
-rw-r--r--tensorflow/contrib/lite/toco/toco_port.cc227
-rw-r--r--tensorflow/contrib/lite/toco/toco_port.h80
-rw-r--r--tensorflow/contrib/lite/toco/toco_port_test.cc58
-rw-r--r--tensorflow/contrib/lite/toco/toco_tooling.cc277
-rw-r--r--tensorflow/contrib/lite/toco/toco_tooling.h50
-rw-r--r--tensorflow/contrib/lite/toco/toco_types.h45
-rw-r--r--tensorflow/contrib/lite/toco/tooling_util.cc1552
-rw-r--r--tensorflow/contrib/lite/toco/tooling_util.h292
-rw-r--r--tensorflow/contrib/lite/toco/tooling_util_test.cc96
-rw-r--r--tensorflow/contrib/lite/tools/BUILD60
-rw-r--r--tensorflow/contrib/lite/tools/gen_op_registration.cc46
-rw-r--r--tensorflow/contrib/lite/tools/gen_op_registration.h38
-rw-r--r--tensorflow/contrib/lite/tools/gen_op_registration_main.cc98
-rw-r--r--tensorflow/contrib/lite/tools/gen_op_registration_test.cc87
-rw-r--r--tensorflow/contrib/lite/tools/mutable_op_resolver.cc43
-rw-r--r--tensorflow/contrib/lite/tools/mutable_op_resolver.h45
-rw-r--r--tensorflow/contrib/lite/version.h23
365 files changed, 66897 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/BUILD b/tensorflow/contrib/lite/BUILD
new file mode 100644
index 0000000000..c58f77cb11
--- /dev/null
+++ b/tensorflow/contrib/lite/BUILD
@@ -0,0 +1,280 @@
+package(default_visibility = [
+ "//visibility:public",
+])
+
+licenses(["notice"]) # Apache 2.0
+
+load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts", "gen_selected_ops")
+
+exports_files(glob([
+ "testdata/*.bin",
+ "models/testdata/*",
+]))
+
+config_setting(
+ name = "mips",
+ values = {
+ "cpu": "mips",
+ },
+)
+
+config_setting(
+ name = "mips64",
+ values = {
+ "cpu": "mips64",
+ },
+)
+
+load(
+ "//tensorflow:tensorflow.bzl",
+ "tf_cc_test",
+)
+
+cc_library(
+ name = "schema_fbs_version",
+ hdrs = ["version.h"],
+)
+
+# Main library. No ops are included here.
+# TODO(aselle): Resolve problems preventing C99 usage.
+cc_library(
+ name = "context",
+ srcs = ["context.c"],
+ hdrs = ["context.h"],
+)
+
+cc_library(
+ name = "builtin_op_data",
+ hdrs = [
+ "builtin_op_data.h",
+ ],
+)
+
+cc_library(
+ name = "string",
+ hdrs = [
+ "string.h",
+ ],
+ deps = [
+ "//tensorflow/core:lib_platform",
+ ],
+)
+
+# TODO(ahentz): investigate dependency on gemm_support requiring usage of tf_copts.
+cc_library(
+ name = "framework",
+ srcs = [
+ "allocation.cc",
+ "error_reporter.cc",
+ "interpreter.cc",
+ "model.cc",
+ "nnapi_delegate.cc",
+ "optional_debug_tools.cc",
+ "simple_memory_arena.cc",
+ ],
+ hdrs = [
+ "allocation.h",
+ "context.h",
+ "error_reporter.h",
+ "interpreter.h",
+ "model.h",
+ "nnapi_delegate.h",
+ "optional_debug_tools.h",
+ "simple_memory_arena.h",
+ ],
+ copts = tflite_copts(),
+ deps = [
+ ":builtin_op_data",
+ ":context",
+ ":schema_fbs_version",
+ "//tensorflow/contrib/lite/kernels:gemm_support",
+ "//tensorflow/contrib/lite/nnapi:nnapi_lib",
+ "//tensorflow/contrib/lite/schema:schema_fbs",
+ "//tensorflow/core:lib_platform",
+ ],
+)
+
+cc_library(
+ name = "string_util",
+ srcs = ["string_util.cc"],
+ hdrs = ["string_util.h"],
+ deps = [
+ ":framework",
+ ":string",
+ ],
+)
+
+cc_test(
+ name = "string_util_test",
+ size = "small",
+ srcs = ["string_util_test.cc"],
+ deps = [
+ ":framework",
+ ":string_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+# Test main interpreter
+cc_test(
+ name = "interpreter_test",
+ size = "small",
+ srcs = ["interpreter_test.cc"],
+ deps = [
+ ":framework",
+ ":string_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+# Test arena allocator
+cc_test(
+ name = "simple_memory_arena_test",
+ size = "small",
+ srcs = ["simple_memory_arena_test.cc"],
+ deps = [
+ ":framework",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+# Test model framework.
+cc_test(
+ name = "model_test",
+ size = "small",
+ srcs = ["model_test.cc"],
+ data = [
+ "testdata/0_subgraphs.bin",
+ "testdata/2_subgraphs.bin",
+ "testdata/empty_model.bin",
+ "testdata/test_model.bin",
+ "testdata/test_model_broken.bin",
+ ],
+ deps = [
+ ":framework",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+# Test the C extension API code.
+cc_test(
+ name = "context_test",
+ size = "small",
+ srcs = ["context_test.cc"],
+ deps = [
+ ":framework",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+# Test the serialization of a model with optional tensors.
+
+# Model tests
+
+cc_library(
+ name = "models_test_utils",
+ testonly = 1,
+ hdrs = ["models/test_utils.h"],
+ deps = select({
+ "//tensorflow:android": [],
+ "//conditions:default": [
+ #"//file/base:path",
+ "//tensorflow/core:test",
+ ],
+ }),
+)
+
+cc_test(
+ name = "speech_hotword_model_test",
+ size = "small",
+ srcs = ["models/speech_hotword_model_test.cc"],
+ data = [
+ "models/testdata/speech_hotword_model_in.csv",
+ "models/testdata/speech_hotword_model_out_rank1.csv",
+ "models/testdata/speech_hotword_model_out_rank2.csv",
+ "models/testdata/speech_hotword_model_rank1.tflite",
+ "models/testdata/speech_hotword_model_rank2.tflite",
+ ],
+ deps = [
+ ":framework",
+ ":models_test_utils",
+ #"//file/base:path",
+ "//tensorflow/contrib/lite/kernels:builtin_ops",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+gen_selected_ops(
+ name = "speech_speakerid_ops",
+ model = "models/testdata/speech_speakerid_model.tflite",
+)
+
+cc_test(
+ name = "speech_speakerid_model_test",
+ size = "small",
+ srcs = [
+ "models/speech_speakerid_model_test.cc",
+ ":speech_speakerid_ops",
+ ],
+ data = [
+ "models/testdata/speech_speakerid_model.tflite",
+ "models/testdata/speech_speakerid_model_in.csv",
+ "models/testdata/speech_speakerid_model_out.csv",
+ ],
+ deps = [
+ ":framework",
+ ":models_test_utils",
+ #"//file/base:path",
+ "//tensorflow/contrib/lite/kernels:builtin_ops",
+ "//tensorflow/contrib/lite/tools:mutable_op_resolver",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_test(
+ name = "speech_terse_am_model_test",
+ size = "small",
+ srcs = ["models/speech_terse_am_model_test.cc"],
+ data = [
+ "models/testdata/speech_terse_am_model.tflite",
+ "models/testdata/speech_terse_am_model_in.csv",
+ "models/testdata/speech_terse_am_model_out.csv",
+ ],
+ deps = [
+ ":framework",
+ ":models_test_utils",
+ #"//file/base:path",
+ "//tensorflow/contrib/lite/kernels:builtin_ops",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_test(
+ name = "speech_tts_model_test",
+ size = "small",
+ srcs = ["models/speech_tts_model_test.cc"],
+ data = [
+ "models/testdata/speech_tts_model.tflite",
+ "models/testdata/speech_tts_model_in.csv",
+ "models/testdata/speech_tts_model_out.csv",
+ ],
+ deps = [
+ ":framework",
+ ":models_test_utils",
+ #"//file/base:path",
+ "//tensorflow/contrib/lite/kernels:builtin_ops",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/contrib/lite/allocation.cc b/tensorflow/contrib/lite/allocation.cc
new file mode 100644
index 0000000000..4b322e027d
--- /dev/null
+++ b/tensorflow/contrib/lite/allocation.cc
@@ -0,0 +1,122 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <fcntl.h>
+#include <sys/mman.h>
+#include <sys/stat.h>
+#include <sys/types.h>
+#include <unistd.h>
+#include <cassert>
+#include <cstdarg>
+#include <cstdint>
+#include <cstring>
+
+#include "tensorflow/contrib/lite/allocation.h"
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/error_reporter.h"
+#include "tensorflow/contrib/lite/nnapi_delegate.h"
+
+namespace tflite {
+
+MMAPAllocation::MMAPAllocation(const char* filename,
+ ErrorReporter* error_reporter)
+ : Allocation(error_reporter), mmapped_buffer_(MAP_FAILED) {
+ mmap_fd_ = open(filename, O_RDONLY);
+ if (mmap_fd_ == -1) {
+ error_reporter_->Report("Could not open '%s'.", filename);
+ return;
+ }
+ struct stat sb;
+ fstat(mmap_fd_, &sb);
+ buffer_size_bytes_ = sb.st_size;
+ mmapped_buffer_ =
+ mmap(nullptr, buffer_size_bytes_, PROT_READ, MAP_SHARED, mmap_fd_, 0);
+ if (mmapped_buffer_ == MAP_FAILED) {
+ error_reporter_->Report("Mmap of '%s' failed.", filename);
+ return;
+ }
+}
+
+MMAPAllocation::~MMAPAllocation() {
+ if (valid()) {
+ munmap(const_cast<void*>(mmapped_buffer_), buffer_size_bytes_);
+ }
+ if (mmap_fd_ != -1) close(mmap_fd_);
+}
+
+const void* MMAPAllocation::base() const { return mmapped_buffer_; }
+
+size_t MMAPAllocation::bytes() const { return buffer_size_bytes_; }
+
+bool MMAPAllocation::valid() const { return mmapped_buffer_ != MAP_FAILED; }
+
+FileCopyAllocation::FileCopyAllocation(const char* filename,
+ ErrorReporter* error_reporter)
+ : Allocation(error_reporter) {
+ // Obtain the file size using fstat on the opened file's descriptor.
+ std::unique_ptr<FILE, decltype(&fclose)> file(fopen(filename, "rb"), fclose);
+ if (!file) {
+ error_reporter_->Report("Could not open '%s'.", filename);
+ return;
+ }
+ // TODO(ahentz): Why did you think using fseek here was better for finding
+ // the size?
+ struct stat sb;
+ if (fstat(fileno(file.get()), &sb) != 0) {
+ error_reporter_->Report("Failed to get file size of '%s'.", filename);
+ return;
+ }
+ buffer_size_bytes_ = sb.st_size;
+ std::unique_ptr<char[]> buffer(new char[buffer_size_bytes_]);
+ if (!buffer) {
+ error_reporter_->Report("Malloc of buffer to hold copy of '%s' failed.",
+ filename);
+ return;
+ }
+ size_t bytes_read =
+ fread(buffer.get(), sizeof(char), buffer_size_bytes_, file.get());
+ if (bytes_read != buffer_size_bytes_) {
+ error_reporter_->Report("Read of '%s' failed (too few bytes read).",
+ filename);
+ return;
+ }
+ copied_buffer_ = std::move(buffer);
+}
+
+FileCopyAllocation::~FileCopyAllocation() {}
+
+const void* FileCopyAllocation::base() const { return copied_buffer_.get(); }
+
+size_t FileCopyAllocation::bytes() const { return buffer_size_bytes_; }
+
+bool FileCopyAllocation::valid() const { return copied_buffer_ != nullptr; }
+
+MemoryAllocation::MemoryAllocation(const void* ptr, size_t num_bytes,
+ ErrorReporter* error_reporter)
+ : Allocation(error_reporter) {
+ buffer_ = ptr;
+ buffer_size_bytes_ = num_bytes;
+}
+
+MemoryAllocation::~MemoryAllocation() {}
+
+const void* MemoryAllocation::base() const { return buffer_; }
+
+size_t MemoryAllocation::bytes() const { return buffer_size_bytes_; }
+
+bool MemoryAllocation::valid() const { return buffer_ != nullptr; }
+
+} // namespace tflite
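The three Allocation subclasses defined in this file share the same small interface: base(), bytes(), and valid(). Below is a minimal usage sketch, not part of the change itself: it assumes a caller-supplied tflite::ErrorReporter* (declared in error_reporter.h, added elsewhere in this commit), and the helper name LoadModel is illustrative only.

#include <memory>

#include "tensorflow/contrib/lite/allocation.h"

// Try to mmap a model file, falling back to copying it into heap memory.
// `reporter` must outlive the returned allocation.
std::unique_ptr<tflite::Allocation> LoadModel(const char* filename,
                                              tflite::ErrorReporter* reporter) {
  std::unique_ptr<tflite::Allocation> allocation(
      new tflite::MMAPAllocation(filename, reporter));
  if (!allocation->valid()) {
    allocation.reset(new tflite::FileCopyAllocation(filename, reporter));
  }
  if (!allocation->valid()) return nullptr;
  // allocation->base() and allocation->bytes() now describe the model buffer.
  return allocation;
}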
diff --git a/tensorflow/contrib/lite/allocation.h b/tensorflow/contrib/lite/allocation.h
new file mode 100644
index 0000000000..ee8a7ccd0b
--- /dev/null
+++ b/tensorflow/contrib/lite/allocation.h
@@ -0,0 +1,94 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// Memory allocation abstractions (mmap, file copy, or a caller-provided
+// buffer) used to hold TF Lite model data.
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_ALLOCATION_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_ALLOCATION_H_
+
+#include <cstdio>
+#include <cstdlib>
+#include <vector>
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/error_reporter.h"
+#include "tensorflow/contrib/lite/simple_memory_arena.h"
+
+namespace tflite {
+
+// A memory allocation handle. This could be a mmap or shared memory.
+class Allocation {
+ public:
+ Allocation(ErrorReporter* error_reporter) : error_reporter_(error_reporter) {}
+ virtual ~Allocation() {}
+
+ // Base pointer of this allocation
+ virtual const void* base() const = 0;
+ // Size in bytes of the allocation
+ virtual size_t bytes() const = 0;
+ // Whether the allocation is valid
+ virtual bool valid() const = 0;
+
+ protected:
+ ErrorReporter* error_reporter_;
+};
+
+class MMAPAllocation : public Allocation {
+ public:
+ MMAPAllocation(const char* filename, ErrorReporter* error_reporter);
+ virtual ~MMAPAllocation();
+ const void* base() const override;
+ size_t bytes() const override;
+ bool valid() const override;
+
+ protected:
+ // Data required for mmap.
+ int mmap_fd_ = -1; // mmap file descriptor
+ const void* mmapped_buffer_;
+ size_t buffer_size_bytes_ = 0;
+};
+
+class FileCopyAllocation : public Allocation {
+ public:
+ FileCopyAllocation(const char* filename, ErrorReporter* error_reporter);
+ virtual ~FileCopyAllocation();
+ const void* base() const override;
+ size_t bytes() const override;
+ bool valid() const override;
+
+ private:
+ // The copied model data.
+ std::unique_ptr<const char[]> copied_buffer_;
+ size_t buffer_size_bytes_ = 0;
+};
+
+class MemoryAllocation : public Allocation {
+ public:
+ // Wraps an existing caller-owned buffer, given its pointer and size in
+ // bytes. No copy is made; the pointer has to remain alive and unchanged
+ // until the destructor is called.
+ MemoryAllocation(const void* ptr, size_t num_bytes,
+ ErrorReporter* error_reporter);
+ virtual ~MemoryAllocation();
+ const void* base() const override;
+ size_t bytes() const override;
+ bool valid() const override;
+
+ private:
+ const void* buffer_;
+ size_t buffer_size_bytes_ = 0;
+};
+
+} // namespace tflite
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_ALLOCATION_H_
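As the comment on MemoryAllocation notes, the wrapped buffer is not copied, so it must stay alive and unchanged for the lifetime of the allocation. A short sketch of that contract follows; the caller code is hypothetical and not part of this change.

#include <vector>

#include "tensorflow/contrib/lite/allocation.h"

void UseInMemoryModel(const std::vector<char>& model_bytes,
                      tflite::ErrorReporter* reporter) {
  // `model_bytes` must outlive `allocation`; base() aliases its storage.
  tflite::MemoryAllocation allocation(model_bytes.data(), model_bytes.size(),
                                      reporter);
  if (!allocation.valid()) return;
  // Hand allocation.base() / allocation.bytes() to a model loader here.
}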
diff --git a/tensorflow/contrib/lite/build_def.bzl b/tensorflow/contrib/lite/build_def.bzl
new file mode 100644
index 0000000000..e3c9cdd99b
--- /dev/null
+++ b/tensorflow/contrib/lite/build_def.bzl
@@ -0,0 +1,233 @@
+"""Generate Flatbuffer binary from json."""
+
+def tflite_copts():
+ """Defines compile time flags."""
+ copts = [
+ "-DFARMHASH_NO_CXX_STRING",
+ ] + select({
+ "//tensorflow:android_arm64": [
+ "-std=c++11",
+ "-O3",
+ ],
+ "//tensorflow:android_arm": [
+ "-mfpu=neon",
+ "-mfloat-abi=softfp",
+ "-std=c++11",
+ "-O3",
+ ],
+ "//tensorflow:android_x86": [
+ "-DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK",
+ ],
+ "//tensorflow:ios_x86_64": [
+ "-msse4.1",
+ ],
+ "//conditions:default": [],
+ }) + select({
+ "//tensorflow:with_default_optimizations": [],
+ "//conditions:default": ["-DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK"],
+ })
+
+ return copts
+
+LINKER_SCRIPT = "//tensorflow/contrib/lite/java/src/main/native:version_script.lds"
+
+def tflite_linkopts_unstripped():
+ """Defines linker flags to reduce size of TFLite binary.
+
+ These are useful when trying to investigate the relative size of the
+ symbols in TFLite.
+
+ Returns:
+ a select object with proper linkopts
+ """
+ return select({
+ "//tensorflow:android": [
+ "-Wl,--no-export-dynamic", # Only inc syms referenced by dynamic obj.
+ "-Wl,--exclude-libs,ALL", # Exclude syms in all libs from auto export.
+ "-Wl,--gc-sections", # Eliminate unused code and data.
+ "-Wl,--as-needed", # Don't link unused libs.
+ ],
+ "//tensorflow/contrib/lite:mips": [],
+ "//tensorflow/contrib/lite:mips64": [],
+ "//conditions:default": [
+ "-Wl,--icf=all", # Identical code folding.
+ ],
+ })
+
+def tflite_jni_linkopts_unstripped():
+ """Defines linker flags to reduce size of TFLite binary with JNI.
+
+ These are useful when trying to investigate the relative size of the
+ symbols in TFLite.
+
+ Returns:
+ a select object with proper linkopts
+ """
+ return select({
+ "//tensorflow:android": [
+ "-Wl,--gc-sections", # Eliminate unused code and data.
+ "-Wl,--as-needed", # Don't link unused libs.
+ ],
+ "//tensorflow/contrib/lite:mips": [],
+ "//tensorflow/contrib/lite:mips64": [],
+ "//conditions:default": [
+ "-Wl,--icf=all", # Identical code folding.
+ ],
+ })
+
+def tflite_linkopts():
+ """Defines linker flags to reduce size of TFLite binary."""
+ return tflite_linkopts_unstripped() + select({
+ "//tensorflow:android": [
+ "-s", # Omit symbol table.
+ ],
+ "//conditions:default": [],
+ })
+
+def tflite_jni_linkopts():
+ """Defines linker flags to reduce size of TFLite binary with JNI."""
+ return tflite_jni_linkopts_unstripped() + select({
+ "//tensorflow:android": [
+ "-s", # Omit symbol table.
+ ],
+ "//conditions:default": [],
+ })
+
+
+def tflite_jni_binary(name,
+ copts=tflite_copts(),
+ linkopts=tflite_jni_linkopts(),
+ linkscript=LINKER_SCRIPT,
+ linkshared=1,
+ linkstatic=1,
+ deps=[]):
+ """Builds a jni binary for TFLite."""
+ linkopts = linkopts + [
+ "-Wl,--version-script", # Export only jni functions & classes.
+ linkscript,
+ ]
+ native.cc_binary(
+ name=name,
+ copts=copts,
+ linkshared=linkshared,
+ linkstatic=linkstatic,
+ deps= deps + [linkscript],
+ linkopts=linkopts)
+
+def tf_to_tflite(name, src, options, out):
+ """Convert a frozen tensorflow graphdef to TF Lite's flatbuffer.
+
+ Args:
+ name: Name of rule.
+ src: name of the input graphdef file.
+ options: options passed to TOCO.
+ out: name of the output flatbuffer file.
+ """
+
+ toco = "//tensorflow/contrib/lite/toco:toco"
+ native.genrule(
+ name = name,
+ srcs=[src, options],
+ outs=[out],
+ cmd = ("$(location %s) " +
+ " --input_file=$(location %s) " +
+ " --output_file=$(location %s) " +
+ " --input_format=TENSORFLOW_GRAPHDEF" +
+ " --output_format=TFLITE" +
+ " `cat $(location %s)`")
+ % (toco, src, out, options),
+ tools= [toco],
+ )
+
+def tflite_to_json(name, src, out):
+ """Convert a TF Lite flatbuffer to JSON.
+
+ Args:
+ name: Name of rule.
+ src: name of the input flatbuffer file.
+ out: name of the output JSON file.
+ """
+
+ flatc = "@flatbuffers//:flatc"
+ schema = "//tensorflow/contrib/lite/schema:schema.fbs"
+ native.genrule(
+ name = name,
+ srcs = [schema, src],
+ outs = [out],
+ cmd = ("TMP=`mktemp`; cp $(location %s) $${TMP}.bin &&" +
+ "$(location %s) --raw-binary --strict-json -t" +
+ " -o /tmp $(location %s) -- $${TMP}.bin &&" +
+ "cp $${TMP}.json $(location %s)")
+ % (src, flatc, schema, out),
+ tools = [flatc],
+ )
+
+def json_to_tflite(name, src, out):
+ """Convert a JSON file to TF Lite's flatbuffer.
+
+ Args:
+ name: Name of rule.
+ src: name of the input JSON file.
+ out: name of the output flatbuffer file.
+ """
+
+ flatc = "@flatbuffers//:flatc"
+ schema = "//tensorflow/contrib/lite/schema:schema_fbs"
+ native.genrule(
+ name = name,
+ srcs = [schema, src],
+ outs = [out],
+ cmd = ("TMP=`mktemp`; cp $(location %s) $${TMP}.json &&" +
+ "$(location %s) --raw-binary --unknown-json --allow-non-utf8 -b" +
+ " -o /tmp $(location %s) $${TMP}.json &&" +
+ "cp $${TMP}.bin $(location %s)")
+ % (src, flatc, schema, out),
+ tools = [flatc],
+ )
+
+def gen_zipped_test_files(name, files):
+ """Generate a zip file of tests by using :generate_examples.
+
+ Args:
+ name: Name of output. We will produce "`name`_files" as a target.
+ files: A list of zip file basenames.
+ """
+ toco = "//tensorflow/contrib/lite/toco:toco"
+ out_files = []
+ for f in files:
+ out_file = name + "/" + f
+ out_files.append(out_file)
+ native.genrule(
+ name = name + "_" + f + ".files",
+ cmd = ("$(locations :generate_examples) --toco $(locations %s) " % toco
+ + " --zip_to_output " + f +
+ " $(@D) zipped"),
+ outs = [out_file],
+ tools = [
+ ":generate_examples",
+ toco,
+ ],
+ )
+
+ native.filegroup(
+ name = name,
+ srcs = out_files,
+ )
+
+def gen_selected_ops(name, model):
+ """Generate the library that includes only used ops.
+
+ Args:
+ name: Name of the generated library.
+ model: TFLite model to interpret.
+ """
+ out = name + "_registration.cc"
+ tool = "//tensorflow/contrib/lite/tools:generate_op_registrations"
+ native.genrule(
+ name = name,
+ srcs = [model],
+ outs = [out],
+ cmd = ("$(location %s) --input_model=$(location %s) --output_registration=$(location %s)")
+ % (tool, model, out),
+ tools = [tool],
+ )
diff --git a/tensorflow/contrib/lite/builtin_op_data.h b/tensorflow/contrib/lite/builtin_op_data.h
new file mode 100644
index 0000000000..93072bf90b
--- /dev/null
+++ b/tensorflow/contrib/lite/builtin_op_data.h
@@ -0,0 +1,164 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_BUILTIN_OP_DATA_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_BUILTIN_OP_DATA_H_
+
+#include <stdint.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+// TODO(aselle): Consider using "if this then that" for testing.
+
+// Possible padding types (for convolutions)
+typedef enum {
+ kTfLitePaddingUnknown = 0,
+ kTfLitePaddingSame,
+ kTfLitePaddingValid,
+} TfLitePadding;
+
+typedef struct {
+ int width;
+ int height;
+} TfLitePaddingValues;
+
+// Possible fused activation functions.
+// TODO(aselle): rename to TfLiteActivation
+typedef enum {
+ kTfLiteActNone = 0,
+ kTfLiteActRelu,
+ kTfLiteActRelu1,
+ kTfLiteActRelu6,
+ kTfLiteActTanh,
+ kTfLiteActSignBit,
+ kTfLiteActSigmoid,
+} TfLiteFusedActivation;
+
+typedef struct {
+ TfLitePadding padding;
+ int stride_width;
+ int stride_height;
+ TfLiteFusedActivation activation;
+} TfLiteConvParams;
+
+typedef struct {
+ TfLitePadding padding;
+ int stride_width;
+ int stride_height;
+ int filter_width;
+ int filter_height;
+ TfLiteFusedActivation activation;
+ struct {
+ TfLitePaddingValues padding;
+ } computed;
+} TfLitePoolParams;
+
+typedef struct {
+ TfLitePadding padding;
+ int stride_width;
+ int stride_height;
+ int depth_multiplier;
+ TfLiteFusedActivation activation;
+} TfLiteDepthwiseConvParams;
+
+typedef struct {
+ int rank;
+ TfLiteFusedActivation activation;
+} TfLiteSVDFParams;
+
+typedef struct {
+ TfLiteFusedActivation activation;
+} TfLiteRNNParams;
+
+typedef struct { TfLiteFusedActivation activation; } TfLiteFullyConnectedParams;
+
+typedef enum {
+ kTfLiteLshProjectionUnknown = 0,
+ kTfLiteLshProjectionSparse = 1,
+ kTfLiteLshProjectionDense = 2,
+} TfLiteLSHProjectionType;
+
+typedef struct { TfLiteLSHProjectionType type; } TfLiteLSHProjectionParams;
+
+typedef struct { float beta; } TfLiteSoftmaxParams;
+
+typedef struct {
+ int axis;
+ TfLiteFusedActivation activation;
+} TfLiteConcatenationParams;
+
+typedef struct {
+ TfLiteFusedActivation activation;
+} TfLiteAddParams;
+
+typedef struct {
+ TfLiteFusedActivation activation;
+} TfLiteMulParams;
+
+typedef struct {
+ TfLiteFusedActivation activation;
+} TfLiteL2NormParams;
+
+typedef struct {
+ int radius;
+ float bias;
+ float alpha;
+ float beta;
+} TfLiteLocalResponseNormParams;
+
+typedef struct {
+ TfLiteFusedActivation activation;
+ float cell_clip;
+ float proj_clip;
+} TfLiteLSTMParams;
+
+typedef struct {
+ int new_height;
+ int new_width;
+} TfLiteResizeBilinearParams;
+
+typedef struct {
+ // TODO(ahentz): We can't have dynamic data in this struct, at least not yet.
+ // For now we will fix the maximum possible number of dimensions.
+ int shape[8];
+ int num_dimensions;
+} TfLiteReshapeParams;
+
+typedef struct {
+ int ngram_size;
+ int max_skip_size;
+ bool include_all_ngrams;
+} TfLiteSkipGramParams;
+
+typedef struct {
+ int block_size;
+} TfLiteSpaceToDepthParams;
+
+typedef enum {
+ kTfLiteCombinerTypeSum = 0,
+ kTfLiteCombinerTypeMean = 1,
+ kTfLiteCombinerTypeSqrtn = 2,
+} TfLiteCombinerType;
+
+typedef struct {
+ TfLiteCombinerType combiner;
+} TfLiteEmbeddingLookupSparseParams;
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_BUILTIN_OP_DATA_H_
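These plain C structs are the per-op parameter payloads that the interpreter hands to builtin kernels. In normal use they are populated by the model loader from the flatbuffer rather than by hand; the sketch below is illustrative only and is not part of the change.

#include "tensorflow/contrib/lite/builtin_op_data.h"

// Parameters for a 2x2-stride convolution with SAME padding and a fused ReLU.
TfLiteConvParams MakeExampleConvParams() {
  TfLiteConvParams params;
  params.padding = kTfLitePaddingSame;
  params.stride_width = 2;
  params.stride_height = 2;
  params.activation = kTfLiteActRelu;
  return params;
}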
diff --git a/tensorflow/contrib/lite/context.c b/tensorflow/contrib/lite/context.c
new file mode 100644
index 0000000000..c09e838c5c
--- /dev/null
+++ b/tensorflow/contrib/lite/context.c
@@ -0,0 +1,92 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/context.h"
+#include <stdio.h>
+#include <string.h>
+
+TfLiteIntArray* TfLiteIntArrayCreate(int size) {
+ TfLiteIntArray* ret =
+ (TfLiteIntArray*)malloc(sizeof(*ret) + sizeof(ret->data[0]) * size);
+ ret->size = size;
+ return ret;
+}
+
+void TfLiteIntArrayPrint(const char* s, TfLiteIntArray* a) {
+ printf("%s: length=%d [", s, a->size);
+ if (a->size) printf("%d", a->data[0]);
+ int i = 1;
+ for (; i < a->size; i++) {
+ printf(" %d", a->data[i]);
+ }
+ printf("]\n");
+}
+
+int TfLiteIntArrayEqual(TfLiteIntArray* a, TfLiteIntArray* b) {
+ if (a == b) return 1;
+ if (a == NULL || b == NULL) return 0;
+ if (a->size != b->size) return 0;
+ int i = 0;
+ for (; i < a->size; i++)
+ if (a->data[i] != b->data[i]) return 0;
+ return 1;
+}
+
+TfLiteIntArray* TfLiteIntArrayCopy(TfLiteIntArray* src) {
+ if (!src) return NULL;
+ TfLiteIntArray* ret = TfLiteIntArrayCreate(src->size);
+ if (ret) {
+ memcpy(ret->data, src->data, src->size * sizeof(int));
+ }
+ return ret;
+}
+
+void TfLiteIntArrayFree(TfLiteIntArray* a) { free(a); }
+
+void TfLiteTensorFree(TfLiteTensor* t) {
+ if (t->allocation_type == kTfLiteDynamic && t->data.raw) {
+ free(t->data.raw);
+ }
+ if (t->dims) TfLiteIntArrayFree(t->dims);
+ t->data.raw = NULL;
+ t->dims = NULL;
+}
+
+void TfLiteTensorReset(TfLiteType type, const char* name, TfLiteIntArray* dims,
+ TfLiteQuantizationParams quantization, char* buffer,
+ size_t size, TfLiteAllocationType allocation_type,
+ const void* allocation, TfLiteTensor* tensor) {
+ TfLiteTensorFree(tensor);
+ tensor->type = type;
+ tensor->name = name;
+ tensor->dims = dims;
+ tensor->params = quantization;
+ tensor->data.raw = buffer;
+ tensor->bytes = size;
+ tensor->allocation_type = allocation_type;
+ tensor->allocation = allocation;
+}
+
+void TfLiteTensorRealloc(size_t num_bytes, TfLiteTensor* tensor) {
+ if (tensor->allocation_type != kTfLiteDynamic) {
+ return;
+ }
+ if (!tensor->data.raw) {
+ tensor->data.raw = malloc(num_bytes);
+ } else if (num_bytes > tensor->bytes) {
+ tensor->data.raw = realloc(tensor->data.raw, num_bytes);
+ }
+ tensor->bytes = num_bytes;
+}
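The TfLiteIntArray helpers above follow a simple create/copy/free discipline with no other dependencies. A minimal usage sketch (the function name is hypothetical, not part of the change):

#include "tensorflow/contrib/lite/context.h"

void IntArrayExample(void) {
  // Build the shape {3, 2}.
  TfLiteIntArray* dims = TfLiteIntArrayCreate(2);
  dims->data[0] = 3;
  dims->data[1] = 2;

  // A copy compares equal to the original.
  TfLiteIntArray* copy = TfLiteIntArrayCopy(dims);
  int same = TfLiteIntArrayEqual(dims, copy);  // same == 1

  (void)same;
  TfLiteIntArrayFree(copy);
  TfLiteIntArrayFree(dims);
}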
diff --git a/tensorflow/contrib/lite/context.h b/tensorflow/contrib/lite/context.h
new file mode 100644
index 0000000000..41257a53b1
--- /dev/null
+++ b/tensorflow/contrib/lite/context.h
@@ -0,0 +1,298 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// This file defines a C API for implementing operations in tflite.
+// These operations can be defined using C++, but the interface between
+// the interpreter and the operations is C.
+//
+// Summary of abstractions
+// TF_LITE_ENSURE - Self-sufficient error checking
+// TfLiteStatus - Status reporting
+// TfLiteIntArray - stores tensor shapes (dims),
+// TfLiteContext - allows an op to access the tensors
+// TfLiteTensor - tensor (a multidimensional array)
+// TfLiteNode - a single node or operation
+// TfLiteRegistration - the implementation of a conceptual operation.
+//
+// Some abstractions in this file are created and managed by Interpreter.
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_CONTEXT_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_CONTEXT_H_
+
+#include <stdint.h>
+#include <stdlib.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+typedef enum { kTfLiteOk = 0, kTfLiteError = 1 } TfLiteStatus;
+
+#define kOptionalTensor (-1)
+
+// Fixed size list of integers. Used for dimensions and input/output tensor
+// indices.
+typedef struct {
+ int size;
+// gcc 6.1+ has a bug where flexible members aren't properly handled
+// https://github.com/google/re2/commit/b94b7cd42e9f02673cd748c1ac1d16db4052514c
+#if !defined(__clang__) && defined(__GNUC__) && __GNUC__ == 6 && \
+ __GNUC_MINOR__ >= 1
+ int data[0];
+#else
+ int data[];
+#endif
+} TfLiteIntArray;
+
+// Create an array of a given `size` (entries are uninitialized).
+// This returns a pointer that you must free using TfLiteIntArrayFree().
+TfLiteIntArray* TfLiteIntArrayCreate(int size);
+
+// Check if two arrays are equal. Returns 1 if they are equal, 0 otherwise.
+int TfLiteIntArrayEqual(TfLiteIntArray* a, TfLiteIntArray* b);
+
+// Create a copy of an array passed as `src`.
+// You are expected to free the memory with TfLiteIntArrayFree().
+TfLiteIntArray* TfLiteIntArrayCopy(TfLiteIntArray* src);
+
+// Free memory of array `v`.
+void TfLiteIntArrayFree(TfLiteIntArray* v);
+
+// Since we must not depend on any libraries, define a minimal subset of
+// error macros while avoiding names that have pre-conceived meanings like
+// assert and check.
+
+// Check whether value is true, and if not return kTfLiteError from
+// the current function (and report the error string msg).
+#define TF_LITE_ENSURE_MSG(context, value, msg) \
+ do { \
+ if (!(value)) { \
+ (context)->ReportError((context), __FILE__ " " msg); \
+ return kTfLiteError; \
+ } \
+ } while (0)
+
+// Check whether the value `a` is true, and if not return kTfLiteError from
+// the current function, while also reporting the location of the error.
+#define TF_LITE_ENSURE(context, a) \
+ do { \
+ if (!(a)) { \
+ (context)->ReportError((context), "%s:%d %s was not true.", __FILE__, \
+ __LINE__, #a); \
+ return kTfLiteError; \
+ } \
+ } while (0)
+
+#define TF_LITE_ENSURE_STATUS(a) \
+ do { \
+ if ((a) != kTfLiteOk) { \
+ return kTfLiteError; \
+ } \
+ } while (0)
+
+// Check whether the value `a == b` is true, and if not return kTfLiteError from
+// the current function, while also reporting the location of the error.
+// `a` and `b` may be evaluated more than once, so no side effects or
+// extremely expensive computations should be done.
+#define TF_LITE_ENSURE_EQ(context, a, b) \
+ do { \
+ if ((a) != (b)) { \
+ (context)->ReportError((context), "%s:%d %s != %s (%d != %d)", __FILE__, \
+ __LINE__, #a, #b, (a), (b)); \
+ return kTfLiteError; \
+ } \
+ } while (0)
+
+#define TF_LITE_ENSURE_OK(context, status) \
+ do { \
+ if ((status) != kTfLiteOk) { \
+ return status; \
+ } \
+ } while (0)
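For instance, a helper inside an op kernel might use these macros as follows. This is an illustrative sketch only (the function name and the rank-4 requirement are made up), but it shows why the macros can only appear in functions that return TfLiteStatus:

    // Hypothetical shape check called from an op's prepare function.
    TfLiteStatus CheckRankIsFour(TfLiteContext* context, const TfLiteIntArray* dims) {
      TF_LITE_ENSURE_MSG(context, dims != nullptr, "dims must not be null.");
      TF_LITE_ENSURE_EQ(context, dims->size, 4);
      return kTfLiteOk;
    }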
+
+// Types supported by tensor
+typedef enum {
+ kTfLiteNoType = 0,
+ kTfLiteFloat32 = 1,
+ kTfLiteInt32 = 2,
+ kTfLiteUInt8 = 3,
+ kTfLiteInt64 = 4,
+ kTfLiteString = 5,
+} TfLiteType;
+
+// Parameters for asymmetric quantization. Quantized values can be converted
+// back to float using:
+// real_value = scale * (quantized_value - zero_point);
+typedef struct {
+ float scale;
+ int32_t zero_point;
+} TfLiteQuantizationParams;
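A quick worked example of that formula: with scale = 0.5 and zero_point = 128, the quantized byte 130 maps back to the real value 1.0.

    TfLiteQuantizationParams params = {/*scale=*/0.5f, /*zero_point=*/128};
    uint8_t quantized_value = 130;
    float real_value = params.scale * (quantized_value - params.zero_point);  // 1.0f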
+
+// A union of pointers that points to memory for a given tensor.
+typedef union {
+ int* i32;
+ float* f;
+ char* raw;
+ const char* raw_const;
+ uint8_t* uint8;
+} TfLitePtrUnion;
+
+// Memory allocation strategies. kTfLiteMmapRo is for read-only memory-mapped
+// data (or data externally allocated). kTfLiteArenaRw is arena allocated
+// data. kTfLiteDynamic is for tensors that are allocated during evaluation.
+typedef enum {
+ kTfLiteMemNone = 0,
+ kTfLiteMmapRo,
+ kTfLiteArenaRw,
+ kTfLiteArenaRwPersistent,
+ kTfLiteDynamic,
+} TfLiteAllocationType;
+
+// A tensor in the interpreter system, which is a wrapper around a buffer of
+// data including a dimensionality (or NULL if not currently defined).
+typedef struct {
+  // The data type specification for data stored in `data`. This affects
+  // which member of the `data` union should be used.
+ TfLiteType type;
+ // A union of data pointers. The appropriate type should be used for a typed
+ // tensor based on `type`.
+ TfLitePtrUnion data;
+ // A pointer to a structure representing the dimensionality interpretation
+ // that the buffer should have. NOTE: the product of elements of `dims`
+ // and the element datatype size should be equal to `bytes` below.
+ TfLiteIntArray* dims;
+ // Quantization information.
+ TfLiteQuantizationParams params;
+ // How memory is mapped
+ // kTfLiteMmapRo: Memory mapped read only.
+ // i.e. weights
+ // kTfLiteArenaRw: Arena allocated read write memory
+ // (i.e. temporaries, outputs).
+ TfLiteAllocationType allocation_type;
+ // The number of bytes required to store the data of this Tensor. I.e.
+ // (bytes of each element) * dims[0] * ... * dims[n-1]. For example, if
+ // type is kTfLiteFloat32 and dims = {3, 2} then
+ // bytes = sizeof(float) * 3 * 2 = 4 * 3 * 2 = 24.
+ size_t bytes;
+
+ // An opaque pointer to a tflite::MMapAllocation
+ const void* allocation;
+
+ // Null-terminated name of this tensor.
+ const char* name;
+} TfLiteTensor;
+
+// Free memory of tensor `t`.
+void TfLiteTensorFree(TfLiteTensor* t);
+
+// Set all of a tensor's fields (and free any previously allocated data).
+void TfLiteTensorReset(TfLiteType type, const char* name, TfLiteIntArray* dims,
+ TfLiteQuantizationParams quantization, char* buffer,
+ size_t size, TfLiteAllocationType allocation_type,
+ const void* allocation, TfLiteTensor* tensor);
+
+// Resize the allocated data of a (dynamic) tensor.
+void TfLiteTensorRealloc(size_t num_bytes, TfLiteTensor* tensor);
+
+typedef struct TfLiteContext {
+ // Number of tensors in the context.
+ int tensors_size;
+  // An array of tensors in the interpreter context (of length `tensors_size`).
+ TfLiteTensor* tensors;
+
+ // opaque full context ptr (an opaque c++ data structure)
+ void* impl_;
+
+ // Request memory pointer be resized. Updates dimensions on the tensor.
+  // NOTE: ResizeTensor takes ownership of `new_size`.
+ TfLiteStatus (*ResizeTensor)(struct TfLiteContext*, TfLiteTensor* tensor,
+ TfLiteIntArray* new_size);
+  // Request that an error be reported with format string msg.
+ void (*ReportError)(struct TfLiteContext*, const char* msg, ...);
+
+ // Add `tensors_to_add` tensors, preserving pre-existing Tensor entries. If
+ // non-null, the value pointed to by `first_new_tensor_index` will be set to
+ // the index of the first new tensor.
+ TfLiteStatus (*AddTensors)(struct TfLiteContext*, int tensors_to_add,
+ int* first_new_tensor_index);
+
+  // TODO(ahentz): we should create a more general mechanism for this sort of
+  // library-global object.
+ void* gemm_context;
+} TfLiteContext;
+
+// A structure representing an instance of a node.
+// This structure only exhibits the inputs, outputs and user defined data, not
+// other features like the type.
+typedef struct {
+  // Inputs to this node expressed as indices into the interpreter's tensors.
+ TfLiteIntArray* inputs;
+
+  // Outputs to this node expressed as indices into the interpreter's tensors.
+ TfLiteIntArray* outputs;
+
+  // Temporary tensors used during computation. This usually contains no
+ // tensors, but ops are allowed to change that if they need scratch space of
+ // any sort.
+ TfLiteIntArray* temporaries;
+
+ // Opaque data provided by the node implementer through `Registration.init`.
+ void* user_data;
+
+ // Opaque data provided to the node if the node is a builtin.
+ void* builtin_data;
+} TfLiteNode;
+
+typedef struct {
+ // Initializes the op from serialized data.
+ // If a built-in op:
+ // `buffer` is the op's params data (TfLiteLSTMParams*).
+ // `length` is zero.
+ // If custom op:
+ // `buffer` is the op's `custom_options`.
+ // `length` is the size of the buffer.
+ //
+ // Returns a type-punned (i.e. void*) opaque data (e.g. a primitive pointer
+ // or an instance of a struct).
+ //
+ // The returned pointer will be stored with the node in the `user_data` field,
+ // accessible within prepare and invoke functions below.
+ // NOTE: if the data is already in the desired format, simply implement this
+ // function to return `nullptr` and implement the free function to be a no-op.
+ void* (*init)(TfLiteContext* context, const char* buffer, size_t length);
+
+ // The pointer `buffer` is the data previously returned by an init invocation.
+ void (*free)(TfLiteContext* context, void* buffer);
+
+ // prepare is called when the inputs this node depends on have been resized.
+ // context->ResizeTensor() can be called to request output tensors to be
+ // resized.
+ //
+ // Returns kTfLiteOk on success.
+ TfLiteStatus (*prepare)(TfLiteContext* context, TfLiteNode* node);
+
+ // Execute the node (should read node->inputs and output to node->outputs).
+ // Returns kTfLiteOk on success.
+ TfLiteStatus (*invoke)(TfLiteContext* context, TfLiteNode* node);
+
+  // Builtin code. If this kernel refers to a builtin, this is the code
+ // of the builtin. This is so we can do marshaling to other frameworks like
+ // NN API. Note, it is the responsibility of the registration binder to
+ // set this properly.
+ int32_t builtin_code;
+} TfLiteRegistration;
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_CONTEXT_H_
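Putting the pieces of this header together, a custom op is exposed to the interpreter through a TfLiteRegistration whose function pointers follow the init/free/prepare/invoke contract above. The sketch below is purely illustrative (a pass-through kernel with made-up names), not one of the kernels shipped in this change:

    #include <cstring>  // for memcpy
    #include "tensorflow/contrib/lite/context.h"

    // Hypothetical pass-through op: the single output copies the single input.
    void* PassThroughInit(TfLiteContext* context, const char* buffer, size_t length) {
      return nullptr;  // No per-node state needed.
    }
    void PassThroughFree(TfLiteContext* context, void* buffer) {}

    TfLiteStatus PassThroughPrepare(TfLiteContext* context, TfLiteNode* node) {
      TF_LITE_ENSURE_EQ(context, node->inputs->size, 1);
      TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
      TfLiteTensor* input = &context->tensors[node->inputs->data[0]];
      TfLiteTensor* output = &context->tensors[node->outputs->data[0]];
      // ResizeTensor takes ownership of the dims copy.
      return context->ResizeTensor(context, output, TfLiteIntArrayCopy(input->dims));
    }

    TfLiteStatus PassThroughInvoke(TfLiteContext* context, TfLiteNode* node) {
      TfLiteTensor* input = &context->tensors[node->inputs->data[0]];
      TfLiteTensor* output = &context->tensors[node->outputs->data[0]];
      memcpy(output->data.raw, input->data.raw, input->bytes);
      return kTfLiteOk;
    }

    // builtin_code is left 0 because this is a custom op.
    TfLiteRegistration kPassThroughRegistration = {
        PassThroughInit, PassThroughFree, PassThroughPrepare, PassThroughInvoke};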
diff --git a/tensorflow/contrib/lite/context_test.cc b/tensorflow/contrib/lite/context_test.cc
new file mode 100644
index 0000000000..d0a104f43d
--- /dev/null
+++ b/tensorflow/contrib/lite/context_test.cc
@@ -0,0 +1,74 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/context.h"
+#include <gtest/gtest.h>
+
+namespace tflite {
+
+// NOTE: this tests only the TfLiteIntArray part of context.
+// Most of context.h is provided in the context of using it with interpreter.h
+// and interpreter.cc, so interpreter_test.cc tests context structures more
+// thoroughly.
+
+TEST(IntArray, TestIntArrayCreate) {
+ TfLiteIntArray* a = TfLiteIntArrayCreate(0);
+ TfLiteIntArray* b = TfLiteIntArrayCreate(3);
+ TfLiteIntArrayFree(a);
+ TfLiteIntArrayFree(b);
+}
+
+TEST(IntArray, TestIntArrayCopy) {
+ TfLiteIntArray* a = TfLiteIntArrayCreate(2);
+ a->data[0] = 22;
+ a->data[1] = 24;
+ TfLiteIntArray* b = TfLiteIntArrayCopy(a);
+ ASSERT_NE(a, b);
+ ASSERT_EQ(a->size, b->size);
+ ASSERT_EQ(a->data[0], b->data[0]);
+ ASSERT_EQ(a->data[1], b->data[1]);
+ TfLiteIntArrayFree(a);
+ TfLiteIntArrayFree(b);
+}
+
+TEST(IntArray, TestIntArrayEqual) {
+ TfLiteIntArray* a = TfLiteIntArrayCreate(1);
+ a->data[0] = 1;
+ TfLiteIntArray* b = TfLiteIntArrayCreate(2);
+ b->data[0] = 5;
+ b->data[1] = 6;
+ TfLiteIntArray* c = TfLiteIntArrayCreate(2);
+ c->data[0] = 5;
+ c->data[1] = 6;
+ TfLiteIntArray* d = TfLiteIntArrayCreate(2);
+ d->data[0] = 6;
+ d->data[1] = 6;
+ ASSERT_FALSE(TfLiteIntArrayEqual(a, b));
+ ASSERT_TRUE(TfLiteIntArrayEqual(b, c));
+ ASSERT_TRUE(TfLiteIntArrayEqual(b, b));
+ ASSERT_FALSE(TfLiteIntArrayEqual(c, d));
+ TfLiteIntArrayFree(a);
+ TfLiteIntArrayFree(b);
+ TfLiteIntArrayFree(c);
+ TfLiteIntArrayFree(d);
+}
+
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ // On Linux, add: tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/error_reporter.cc b/tensorflow/contrib/lite/error_reporter.cc
new file mode 100644
index 0000000000..6ba5384a94
--- /dev/null
+++ b/tensorflow/contrib/lite/error_reporter.cc
@@ -0,0 +1,50 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/error_reporter.h"
+#include <cstdarg>
+#include <cstdio>
+
+namespace tflite {
+
+ErrorReporter::~ErrorReporter() {}
+
+int ErrorReporter::Report(const char* format, ...) {
+ va_list args;
+ va_start(args, format);
+ int code = Report(format, args);
+ va_end(args);
+ return code;
+}
+
+// TODO(aselle): Make the name of ReportError on context the same, so
+// we can use the ensure functions w/o a context and w/ a reporter.
+int ErrorReporter::ReportError(void*, const char* format, ...) {
+ va_list args;
+ va_start(args, format);
+ int code = Report(format, args);
+ va_end(args);
+ return code;
+}
+
+int StderrReporter::Report(const char* format, va_list args) {
+ return vfprintf(stderr, format, args);
+}
+
+ErrorReporter* DefaultErrorReporter() {
+ static StderrReporter* error_reporter = new StderrReporter;
+ return error_reporter;
+}
+
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/error_reporter.h b/tensorflow/contrib/lite/error_reporter.h
new file mode 100644
index 0000000000..637d456ce7
--- /dev/null
+++ b/tensorflow/contrib/lite/error_reporter.h
@@ -0,0 +1,54 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_ERROR_REPORTER_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_ERROR_REPORTER_H_
+
+#include <cstdarg>
+#include "tensorflow/contrib/lite/context.h"
+
+namespace tflite {
+
+// A functor that reports errors to the supporting system. Invoked similarly to
+// printf.
+//
+// Usage:
+// ErrorReporter foo;
+// foo.Report("test %d\n", 5);
+// or
+// va_list args;
+// foo.Report("test %d\n", args); // where args is va_list
+//
+// Subclass ErrorReporter to provide another reporting destination.
+// For example, if you have a GUI program, you might redirect to a buffer
+// that drives a GUI error log box.
+class ErrorReporter {
+ public:
+ virtual ~ErrorReporter();
+ virtual int Report(const char* format, va_list args) = 0;
+ int Report(const char* format, ...);
+ int ReportError(void*, const char* format, ...);
+};
+
+// An error reporter that simply writes the message to stderr.
+struct StderrReporter : public ErrorReporter {
+ int Report(const char* format, va_list args) override;
+};
+
+// Return the default error reporter (output to stderr).
+ErrorReporter* DefaultErrorReporter();
+
+} // namespace tflite
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_ERROR_REPORTER_H_
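As the comment above suggests, redirecting errors elsewhere only requires overriding the va_list overload of Report. A possible sketch (not part of this change) that buffers messages in a std::string, e.g. for tests or the GUI log box mentioned above:

    #include <cstdarg>
    #include <cstdio>
    #include <string>
    #include "tensorflow/contrib/lite/error_reporter.h"

    // Hypothetical reporter that accumulates formatted messages in memory.
    class BufferedErrorReporter : public tflite::ErrorReporter {
     public:
      int Report(const char* format, va_list args) override {
        char message[1024];
        int length = vsnprintf(message, sizeof(message), format, args);
        log_.append(message);
        return length;
      }
      const std::string& log() const { return log_; }

     private:
      std::string log_;
    };

An Interpreter constructed with a pointer to such a reporter would then collect its error strings in log() instead of writing them to stderr.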
diff --git a/tensorflow/contrib/lite/interpreter.cc b/tensorflow/contrib/lite/interpreter.cc
new file mode 100644
index 0000000000..954e236ac8
--- /dev/null
+++ b/tensorflow/contrib/lite/interpreter.cc
@@ -0,0 +1,567 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/interpreter.h"
+#include <cassert>
+#include <cstdarg>
+#include <cstdint>
+#include <cstring>
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/error_reporter.h"
+#include "tensorflow/contrib/lite/kernels/gemm_support.h"
+#include "tensorflow/contrib/lite/nnapi_delegate.h"
+
+namespace {
+
+// Memory allocation tuning
+constexpr const int kDefaultArenaAlignment = 64;
+constexpr const int kDefaultTensorAlignment = 4;
+// std::vector preallocation tuning.
+constexpr const int kSlotsToReserve = 128;
+
+} // namespace
+
+namespace tflite {
+
+Interpreter::Interpreter(ErrorReporter* error_reporter)
+ : arena_(kDefaultArenaAlignment),
+ persistent_arena_(kDefaultArenaAlignment),
+ error_reporter_(error_reporter ? error_reporter
+ : DefaultErrorReporter()) {
+ context_.impl_ = static_cast<void*>(this);
+ context_.ResizeTensor = ResizeTensor;
+ context_.ReportError = ReportError;
+ context_.AddTensors = AddTensors;
+ context_.tensors = nullptr;
+ context_.tensors_size = 0;
+ context_.gemm_context = nullptr;
+ // Reserve some space for the tensors to avoid excessive resizing.
+ tensors_.reserve(kSlotsToReserve);
+ nodes_and_registration_.reserve(kSlotsToReserve);
+ next_allocate_node_id_ = 0;
+ UseNNAPI(false);
+}
+
+Interpreter::~Interpreter() {
+ for (auto& nodeAndReg : nodes_and_registration_) {
+ TfLiteNode& node = nodeAndReg.first;
+ TfLiteIntArrayFree(node.inputs);
+ TfLiteIntArrayFree(node.outputs);
+ TfLiteIntArrayFree(node.temporaries);
+ if (node.builtin_data) free(node.builtin_data);
+ OpFree(nodeAndReg.second, node.user_data);
+ node.builtin_data = nullptr;
+ }
+
+ for (int i = 0; i < context_.tensors_size; i++) {
+ TfLiteTensorFree(&context_.tensors[i]);
+ }
+}
+
+TfLiteStatus Interpreter::SetInputs(std::vector<int> inputs) {
+ TF_LITE_ENSURE_OK(&context_,
+ CheckTensorIndices("inputs", inputs.data(), inputs.size()));
+ inputs_ = std::move(inputs);
+ return kTfLiteOk;
+}
+
+TfLiteStatus Interpreter::SetOutputs(std::vector<int> outputs) {
+ TF_LITE_ENSURE_OK(
+ &context_, CheckTensorIndices("outputs", outputs.data(), outputs.size()));
+ outputs_ = std::move(outputs);
+ return kTfLiteOk;
+}
+
+TfLiteStatus Interpreter::CheckTensorIndices(const char* label,
+ const int* indices, int length) {
+ // Making sure kOptionalTensor is not re-defined to something other than -1.
+ static_assert(kOptionalTensor == -1, "kOptionalTensor should be defined -1");
+
+ for (int i = 0; i < length; i++) {
+ int index = indices[i];
+ if (index < kOptionalTensor || index >= context_.tensors_size) {
+ ReportError(&context_, "Invalid tensor index %d in %s\n", index, label);
+ consistent_ = false;
+ return kTfLiteError;
+ }
+ }
+ return kTfLiteOk;
+}
+
+TfLiteStatus Interpreter::BytesRequired(TfLiteType type, const int* dims,
+ int dims_size, size_t* bytes) {
+ // TODO(aselle): Check for overflow here using overflow.h in TensorFlow
+ // MultiplyWithoutOverflow.
+ TF_LITE_ENSURE(&context_, bytes != nullptr);
+ size_t count = 1;
+ for (int k = 0; k < dims_size; k++) count *= dims[k];
+ switch (type) {
+ case kTfLiteFloat32:
+ *bytes = sizeof(float) * count;
+ break;
+ case kTfLiteInt32:
+ *bytes = sizeof(int32_t) * count;
+ break;
+ case kTfLiteUInt8:
+ *bytes = sizeof(uint8_t) * count;
+ break;
+ case kTfLiteInt64:
+ *bytes = sizeof(int64_t) * count;
+ break;
+ default:
+ ReportError(&context_,
+ "Only float32, int32, int64, uint8 supported currently.");
+ return kTfLiteError;
+ }
+ return kTfLiteOk;
+}
+
+TfLiteStatus Interpreter::AllocateTensorsWhoseSizesAreKnown() {
+ if (!consistent_) {
+ ReportError(&context_, "AllocateTensors() called on inconsistent model.");
+ return kTfLiteError;
+ }
+ if (next_allocate_node_id_ == nodes_and_registration_.size() && invokable_) {
+ return kTfLiteOk;
+ }
+ allocs_and_refcounts_.resize(context_.tensors_size);
+
+ int new_next_allocate_node_id = next_allocate_node_id_;
+ invokable_ = false;
+
+  // Allocate graph input tensors.
+ if (next_allocate_node_id_ == 0) {
+ for (int i = 0; i < inputs_.size(); ++i) {
+ int tensor_index = inputs_[i];
+ if (tensor_index == kOptionalTensor) {
+ continue;
+ }
+ TfLiteTensor& tensor = context_.tensors[tensor_index];
+ if (tensor.allocation_type == kTfLiteArenaRw) {
+ TF_LITE_ENSURE_OK(
+ &context_,
+ arena_.Allocate(&context_, kDefaultTensorAlignment, tensor.bytes,
+ &allocs_and_refcounts_[tensor_index].alloc));
+ }
+ }
+    // Add 1 to the reference count of output tensors, so they will not get
+    // overwritten.
+ for (int i = 0; i < outputs_.size(); ++i) {
+ allocs_and_refcounts_[outputs_[i]].count++;
+ }
+ }
+
+ // Count references to node input tensors, and resize node-referenced tensors
+ // until we encounter a node that has a dynamic output tensor.
+ for (int k = next_allocate_node_id_; k < nodes_and_registration_.size();
+ k++) {
+ new_next_allocate_node_id++;
+ TfLiteNode& node = nodes_and_registration_[k].first;
+ const TfLiteRegistration& registration = nodes_and_registration_[k].second;
+ if (OpPrepare(registration, &node) == kTfLiteError) {
+ return kTfLiteError;
+ }
+
+ TfLiteIntArray* node_inputs = node.inputs;
+ for (int i = 0; i < node_inputs->size; ++i) {
+ int tensor_index = node_inputs->data[i];
+ if (tensor_index != kOptionalTensor) {
+ allocs_and_refcounts_[node_inputs->data[i]].count++;
+ }
+ }
+
+ // Discontinue if the node has dynamic outputs.
+ bool has_unallocated_dynamic_tensor = false;
+ TfLiteIntArray* node_outputs = node.outputs;
+ for (int i = 0; i < node_outputs->size; ++i) {
+ TfLiteTensor& tensor = context_.tensors[node_outputs->data[i]];
+ if (tensor.allocation_type == kTfLiteDynamic) {
+ has_unallocated_dynamic_tensor = true;
+ break;
+ }
+ }
+ if (has_unallocated_dynamic_tensor) {
+ break;
+ }
+ }
+
+ // Allocate graph persistent outputs, e.g. RNN cell states, etc.
+ for (int k = next_allocate_node_id_; k < new_next_allocate_node_id; k++) {
+ TfLiteNode& node = nodes_and_registration_[k].first;
+
+ // Go through output tensors and allocate the persistent ones first.
+ TfLiteIntArray* node_outputs = node.outputs;
+ for (int i = 0; i < node_outputs->size; ++i) {
+ int tensor_index = node_outputs->data[i];
+ TfLiteTensor& tensor = context_.tensors[tensor_index];
+ if (tensor.allocation_type == kTfLiteArenaRwPersistent) {
+ TF_LITE_ENSURE_OK(&context_,
+ persistent_arena_.Allocate(
+ &context_, kDefaultTensorAlignment, tensor.bytes,
+ &allocs_and_refcounts_[tensor_index].alloc));
+ }
+ }
+ }
+
+ // Go through the graph in execution order.
+ for (int k = next_allocate_node_id_; k < new_next_allocate_node_id; k++) {
+ TfLiteNode& node = nodes_and_registration_[k].first;
+
+ // First allocate output tensors.
+ TfLiteIntArray* node_outputs = node.outputs;
+ for (int i = 0; i < node_outputs->size; ++i) {
+ int tensor_index = node_outputs->data[i];
+ TfLiteTensor& tensor = context_.tensors[tensor_index];
+ if (tensor.allocation_type == kTfLiteArenaRw) {
+ TF_LITE_ENSURE_OK(
+ &context_,
+ arena_.Allocate(&context_, kDefaultTensorAlignment, tensor.bytes,
+ &allocs_and_refcounts_[tensor_index].alloc));
+ }
+ }
+    // Then the temporaries, in two passes: first allocate them all, then
+ // deallocate them.
+ TfLiteIntArray* node_temporaries = node.temporaries;
+ for (int i = 0; i < node_temporaries->size; ++i) {
+ int tensor_index = node_temporaries->data[i];
+ TfLiteTensor& tensor = context_.tensors[tensor_index];
+ if (tensor.allocation_type == kTfLiteArenaRw) {
+ TF_LITE_ENSURE_OK(
+ &context_,
+ arena_.Allocate(&context_, kDefaultTensorAlignment, tensor.bytes,
+ &allocs_and_refcounts_[tensor_index].alloc));
+ }
+ }
+ for (int i = 0; i < node_temporaries->size; ++i) {
+ int tensor_index = node_temporaries->data[i];
+ TfLiteTensor& tensor = context_.tensors[tensor_index];
+ allocs_and_refcounts_[tensor_index].count--;
+ if (tensor.allocation_type == kTfLiteArenaRw &&
+ allocs_and_refcounts_[tensor_index].count == 0) {
+ TF_LITE_ENSURE_OK(
+ &context_,
+ arena_.Deallocate(&context_,
+ allocs_and_refcounts_[tensor_index].alloc));
+ }
+ }
+
+ // Then process the node's inputs.
+ TfLiteIntArray* node_inputs = node.inputs;
+ for (int i = 0; i < node_inputs->size; ++i) {
+ int tensor_index = node_inputs->data[i];
+ if (tensor_index == kOptionalTensor) {
+ continue;
+ }
+ TfLiteTensor& tensor = context_.tensors[tensor_index];
+
+ // Decrease reference count and deallocate if not needed anymore.
+ allocs_and_refcounts_[tensor_index].count--;
+ if (tensor.allocation_type == kTfLiteArenaRw &&
+ allocs_and_refcounts_[tensor_index].count == 0) {
+ TF_LITE_ENSURE_OK(
+ &context_,
+ arena_.Deallocate(&context_,
+ allocs_and_refcounts_[tensor_index].alloc));
+ }
+ }
+ }
+
+ // Resize the buffer and commit the arena.
+ TF_LITE_ENSURE_OK(&context_, arena_.Commit(&context_));
+ TF_LITE_ENSURE_OK(&context_, persistent_arena_.Commit(&context_));
+
+ // Rewire the tensors to use the underlying arena buffer.
+ for (int i = 0; i < context_.tensors_size; ++i) {
+ TfLiteTensor& tensor = context_.tensors[i];
+ if (tensor.allocation_type == kTfLiteArenaRw) {
+ TF_LITE_ENSURE_OK(
+ &context_,
+ arena_.ResolveAlloc(&context_, allocs_and_refcounts_[i].alloc,
+ &tensor.data.raw));
+ }
+ if (tensor.allocation_type == kTfLiteArenaRwPersistent) {
+ TF_LITE_ENSURE_OK(
+ &context_,
+ persistent_arena_.ResolveAlloc(
+ &context_, allocs_and_refcounts_[i].alloc, &tensor.data.raw));
+ }
+ }
+
+ invokable_ = true;
+ next_allocate_node_id_ = new_next_allocate_node_id;
+ return kTfLiteOk;
+}
+
+namespace {
+TfLiteIntArray* convertVectorToTfLiteIntArray(const std::vector<int>& x) {
+ TfLiteIntArray* lite = TfLiteIntArrayCreate(x.size());
+ for (size_t i = 0; i < x.size(); i++) lite->data[i] = x[i];
+ return lite;
+}
+} // namespace
+
+TfLiteStatus Interpreter::AllocateTensors() {
+ next_allocate_node_id_ = 0;
+ TF_LITE_ENSURE_OK(&context_, arena_.Clear());
+ TF_LITE_ENSURE_OK(&context_, persistent_arena_.Clear());
+ allocs_and_refcounts_.clear();
+ return AllocateTensorsWhoseSizesAreKnown();
+}
+
+TfLiteStatus Interpreter::AddNodeWithParameters(
+ const std::vector<int>& inputs, const std::vector<int>& outputs,
+ const char* init_data, size_t init_data_size, void* builtin_data,
+ const TfLiteRegistration* registration, int* node_index) {
+ invokable_ = false;
+
+ std::unique_ptr<void, decltype(free)*> builtin_data_deleter(builtin_data,
+ free);
+
+ TF_LITE_ENSURE_OK(&context_, CheckTensorIndices("node inputs", inputs.data(),
+ inputs.size()));
+ TF_LITE_ENSURE_OK(
+ &context_,
+ CheckTensorIndices("node outputs", outputs.data(), outputs.size()));
+
+ if (node_index) *node_index = nodes_and_registration_.size();
+ nodes_and_registration_.resize(nodes_and_registration_.size() + 1);
+ auto& node_and_reg = nodes_and_registration_.back();
+ TfLiteNode& node = node_and_reg.first;
+ if (node.inputs) TfLiteIntArrayFree(node.inputs);
+ if (node.outputs) TfLiteIntArrayFree(node.outputs);
+ if (node.temporaries) TfLiteIntArrayFree(node.temporaries);
+
+ // NOTE, here we are not using move semantics yet, since our internal
+ // representation isn't std::vector, but in the future we would like to avoid
+ // copies, so we want the interface to take r-value references now.
+ node.inputs = convertVectorToTfLiteIntArray(inputs);
+ node.outputs = convertVectorToTfLiteIntArray(outputs);
+ node.temporaries = TfLiteIntArrayCreate(0);
+ if (init_data) {
+ node.user_data = OpInit(*registration, init_data, init_data_size);
+ } else {
+ node.user_data =
+ OpInit(*registration,
+ reinterpret_cast<const char*>(builtin_data_deleter.get()), 0);
+ }
+ node.builtin_data = builtin_data_deleter.release();
+ node_and_reg.second = *registration;
+ return kTfLiteOk;
+}
+
+TfLiteStatus Interpreter::ResizeInputTensor(int tensor_index,
+ const std::vector<int>& dims) {
+ // TODO(aselle): All bounds checks can be implemented as one-sided bounds
+ // checks by casting to unsigned for efficiency. Profile before doing this.
+
+ TF_LITE_ENSURE(&context_,
+ tensor_index < context_.tensors_size && tensor_index >= 0);
+ invokable_ = false;
+ TfLiteIntArray* dims_lite = convertVectorToTfLiteIntArray(dims);
+ return ResizeTensorImpl(&context_.tensors[tensor_index], dims_lite);
+}
+
+TfLiteStatus Interpreter::Invoke() {
+ if (!consistent_) {
+ ReportError(&context_, "Invoke called on model that is not consistent.");
+ return kTfLiteError;
+ }
+ if (!invokable_) {
+ ReportError(&context_, "Invoke called on model that is not ready.");
+ return kTfLiteError;
+ }
+
+ TfLiteStatus status = kTfLiteOk;
+ if (nnapi_delegate_) {
+ if (AllocateTensorsWhoseSizesAreKnown() == kTfLiteError) {
+ return kTfLiteError;
+ }
+ if (next_allocate_node_id_ == nodes_and_registration_.size()) {
+ TF_LITE_ENSURE_OK(&context_, nnapi_delegate_->Invoke(this));
+ return kTfLiteOk;
+ } else {
+ // TODO(aselle): In the future, we would like this to be an
+ // automatic tflite CPU fallback.
+ ReportError(&context_,
+ "NNAPI was requested, but dependent sized tensors "
+ "being used.\n");
+ return kTfLiteError;
+ }
+ }
+
+ for (int i = 0; i < nodes_and_registration_.size(); i++) {
+ // Ensure we have allocated up to this node. The point of this is to
+ // allocate as much as possible before running any evaluation, but
+ // dynamic shapes can prevent this from being possible.
+ if (i >= next_allocate_node_id_) {
+ if (AllocateTensorsWhoseSizesAreKnown() == kTfLiteError) {
+ return kTfLiteError;
+ }
+ }
+ TfLiteNode& node = nodes_and_registration_[i].first;
+ const TfLiteRegistration& registration = nodes_and_registration_[i].second;
+ if (OpInvoke(registration, &node) == kTfLiteError) {
+ status = kTfLiteError;
+ }
+ }
+ return status;
+}
+
+TfLiteStatus Interpreter::ResizeTensor(TfLiteContext* context,
+ TfLiteTensor* tensor,
+ TfLiteIntArray* new_size) {
+ // Note here that context->impl_ is recovering the this pointer for an
+ // instance of Interpreter to call into the member function ResizeTensorImpl
+ // (this function is static).
+ return static_cast<Interpreter*>(context->impl_)
+ ->ResizeTensorImpl(tensor, new_size);
+}
+
+void Interpreter::ReportErrorImpl(const char* format, va_list args) {
+ error_reporter_->Report(format, args);
+}
+
+void Interpreter::ReportError(TfLiteContext* context, const char* format, ...) {
+ va_list args;
+ va_start(args, format);
+ auto* f = static_cast<Interpreter*>(context->impl_);
+ // Note here that context->impl_ is recovering the this pointer for an
+ // instance of Interpreter to call into the member function ReportErrorImpl
+ // (this function is static).
+ f->ReportErrorImpl(format, args);
+ va_end(args);
+}
+
+TfLiteStatus Interpreter::AddTensors(int tensors_to_add,
+ int* first_new_tensor_index) {
+ int base_index = tensors_.size();
+ if (first_new_tensor_index) *first_new_tensor_index = base_index;
+ tensors_.resize(tensors_.size() + tensors_to_add);
+ for (int i = base_index; i < tensors_.size(); i++) {
+ memset(&tensors_[i], 0, sizeof(tensors_[i]));
+ }
+ context_.tensors = tensors_.data();
+ context_.tensors_size = tensors_.size();
+ return kTfLiteOk;
+}
+
+TfLiteStatus Interpreter::AddTensors(TfLiteContext* context, int tensors_to_add,
+ int* first_new_tensor_index) {
+ // Note here that context->impl_ is recovering the this pointer for an
+ // instance of Interpreter to call into the member function AddTensors
+ // (this function is static).
+ return static_cast<Interpreter*>(context->impl_)
+ ->AddTensors(tensors_to_add, first_new_tensor_index);
+}
+
+TfLiteStatus Interpreter::SetTensorParametersReadOnly(
+ int tensor_index, TfLiteType type, const char* name,
+ const std::vector<int>& dims, TfLiteQuantizationParams quantization,
+ const char* buffer, size_t bytes, const Allocation* allocation) {
+ TF_LITE_ENSURE(&context_,
+ tensor_index < context_.tensors_size && tensor_index >= 0);
+ // For most tensors we know exactly how much memory is necessary so we can
+ // ensure the buffer is large enough. However, we need to skip string tensors
+ // because their sizes change with the contents of the individual strings.
+ if (type != kTfLiteString) {
+ size_t required_bytes;
+ TF_LITE_ENSURE_OK(&context_, BytesRequired(type, dims.data(), dims.size(),
+ &required_bytes));
+ TF_LITE_ENSURE_EQ(&context_, required_bytes, bytes);
+ }
+ invokable_ = false;
+ TfLiteTensorReset(type, name, convertVectorToTfLiteIntArray(dims),
+ quantization, const_cast<char*>(buffer), bytes,
+ kTfLiteMmapRo, allocation, &context_.tensors[tensor_index]);
+ return kTfLiteOk;
+}
+
+// Set description of a read/write tensor `tensor_index`. This variant does not
+// take an external buffer; the interpreter allocates storage for the tensor
+// (in its arena, or dynamically for string tensors).
+TfLiteStatus Interpreter::SetTensorParametersReadWrite(
+ int tensor_index, TfLiteType type, const char* name,
+ const std::vector<int>& dims, TfLiteQuantizationParams quantization) {
+ invokable_ = false;
+ TF_LITE_ENSURE(&context_,
+ tensor_index < context_.tensors_size && tensor_index >= 0);
+ size_t required_bytes = 0;
+ if (type != kTfLiteString) {
+ // These types will be allocated in our arena so we need to record how
+ // many bytes we will need based on the dimensions. String tensors are
+ // allocated dynamically and we can't know ahead of time how much space
+ // they will require.
+ TF_LITE_ENSURE_OK(&context_, BytesRequired(type, dims.data(), dims.size(),
+ &required_bytes));
+ }
+ TfLiteTensorReset(type, name, convertVectorToTfLiteIntArray(dims),
+ quantization,
+ /*buffer=*/nullptr, required_bytes,
+ type == kTfLiteString ? kTfLiteDynamic : kTfLiteArenaRw,
+ nullptr, &context_.tensors[tensor_index]);
+ return kTfLiteOk;
+}
+
+TfLiteStatus Interpreter::ResizeTensorImpl(TfLiteTensor* tensor,
+ TfLiteIntArray* new_size) {
+ // Note that in theory we could resize kTfLiteArenaRwPersistent tensors too.
+ if (tensor->allocation_type == kTfLiteArenaRw ||
+ tensor->allocation_type == kTfLiteDynamic) {
+ if (tensor->type != kTfLiteString) {
+ size_t bytesRequired;
+ TfLiteStatus status = BytesRequired(tensor->type, new_size->data,
+ new_size->size, &bytesRequired);
+ if (status != kTfLiteOk) {
+ TfLiteIntArrayFree(new_size);
+ return kTfLiteError;
+ }
+ tensor->bytes = bytesRequired;
+ }
+ if (tensor->dims) TfLiteIntArrayFree(tensor->dims);
+ tensor->dims = new_size;
+
+ if (tensor->allocation_type != kTfLiteDynamic) {
+ tensor->data.raw = nullptr;
+ }
+ } else {
+ // kTfLiteMmapRo tensors are stored in the flatbuffer and are therefore
+ // of fixed size.
+ TfLiteIntArrayFree(new_size);
+ ReportError(&context_, "Attempting to resize a fixed-size tensor.");
+ return kTfLiteError;
+ }
+ return kTfLiteOk;
+}
+
+void Interpreter::UseNNAPI(bool enable) {
+ // TODO(aselle): This is a workaround for finding if NNAPI exists.
+ // We also need to make sure getLibraryHandle() is renamed to be NNAPI
+ // prefixed.
+ if (!NNAPIExists()) enable = false;
+ if (!enable) {
+ nnapi_delegate_.reset();
+ } else if (!nnapi_delegate_) {
+ nnapi_delegate_.reset(new NNAPIDelegate);
+ }
+}
+
+void Interpreter::SetNumThreads(int num_threads) {
+ // TODO(ahentz): this forces us to link against gemmlowp even when the ops
+ // don't use it. We should implement some dynamic mechanism for this sort of
+ // library-specific initialization.
+ tflite::gemm_support::SetMaxNumThreads(&context_, num_threads);
+}
+
+} // namespace tflite
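For context on the dynamic-tensor path above (AllocateTensorsWhoseSizesAreKnown() stops arena planning at the first dynamic output, and Invoke() resumes allocation node by node), a kernel whose output size depends on the input values would opt in roughly as sketched below. This is a hypothetical kernel written against the shown implementations, not one shipped in this change; a float32 output is assumed.

    // Hypothetical op whose output length is only known from the input values.
    TfLiteStatus DataDependentPrepare(TfLiteContext* context, TfLiteNode* node) {
      TfLiteTensor* output = &context->tensors[node->outputs->data[0]];
      // A dynamic output stops the ahead-of-time arena planning at this node.
      output->allocation_type = kTfLiteDynamic;
      return kTfLiteOk;
    }

    TfLiteStatus DataDependentInvoke(TfLiteContext* context, TfLiteNode* node) {
      TfLiteTensor* input = &context->tensors[node->inputs->data[0]];
      TfLiteTensor* output = &context->tensors[node->outputs->data[0]];
      int out_len = input->dims->data[0];  // In practice derived from input->data.
      // Grow (or first allocate) the heap buffer of the dynamic tensor...
      TfLiteTensorRealloc(sizeof(float) * out_len, output);
      // ...then record the new shape; for kTfLiteDynamic tensors ResizeTensorImpl
      // updates dims/bytes while keeping data.raw.
      TfLiteIntArray* new_size = TfLiteIntArrayCreate(1);
      new_size->data[0] = out_len;
      TfLiteStatus status = context->ResizeTensor(context, output, new_size);
      if (status != kTfLiteOk) return status;
      // ... write out_len float values into output->data.f ...
      return kTfLiteOk;
    }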
diff --git a/tensorflow/contrib/lite/interpreter.h b/tensorflow/contrib/lite/interpreter.h
new file mode 100644
index 0000000000..8bf60e91f7
--- /dev/null
+++ b/tensorflow/contrib/lite/interpreter.h
@@ -0,0 +1,376 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// Main abstraction controlling the tflite interpreter.
+// See context.h for the API for defining operations (TfLiteRegistration).
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_INTERPRETER_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_INTERPRETER_H_
+
+#include <cstdio>
+#include <cstdlib>
+#include <vector>
+#include "tensorflow/contrib/lite/allocation.h"
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/error_reporter.h"
+#include "tensorflow/contrib/lite/simple_memory_arena.h"
+#include "tensorflow/core/platform/platform.h"
+
+namespace tflite {
+
+// Map statically from a C++ type to a TfLiteType (used below for safe casts).
+template <class T>
+constexpr TfLiteType typeToTfLiteType() {
+ return kTfLiteNoType;
+}
+template <>
+constexpr TfLiteType typeToTfLiteType<int>() {
+ return kTfLiteInt32;
+}
+template <>
+constexpr TfLiteType typeToTfLiteType<int64_t>() {
+ return kTfLiteInt64;
+}
+template <>
+constexpr TfLiteType typeToTfLiteType<float>() {
+ return kTfLiteFloat32;
+}
+template <>
+constexpr TfLiteType typeToTfLiteType<unsigned char>() {
+ return kTfLiteUInt8;
+}
+
+struct ArenaAllocRefCount {
+ ArenaAllocRefCount() : alloc(), count(0) {}
+
+ ArenaAlloc alloc;
+ int count;
+};
+
+// Forward declare since NNAPIDelegate uses Interpreter.
+class NNAPIDelegate;
+
+// An interpreter for a graph of nodes that input and output from tensors.
+// Each node of the graph processes a set of input tensors and produces a
+// set of output tensors. All input/output tensors are referenced by index.
+//
+// Usage:
+//
+// -- Create basic model
+// Interpreter foo(2, 1);
+// foo.SetTensorParametersReadWrite(0, ...);
+// foo.SetTensorParametersReadOnly(1, ...);
+// foo.SetNodeParameters(0, ...)
+//
+//  -- Resize input array to length 1.
+// foo.ResizeInputTensor(0, 1);
+// foo.AllocateTensors();
+// -- Install array data
+// foo.typed_tensor<float>(0)[0] = 3;
+// foo.Invoke();
+// foo.typed_tensor<float>(0)[0] = 4;
+// foo.Invoke();
+// -- Resize input array and set data.
+// foo.ResizeInputTensor(0, 2);
+// foo.AllocateTensors();
+// foo.typed_tensor<float>(0)[0] = 4;
+// foo.typed_tensor<float>(0)[1] = 8;
+// foo.Invoke();
+//
+
+class Interpreter {
+ public:
+ // Instantiate an interpreter. All errors associated with reading and
+ // processing this model will be forwarded to the error_reporter object.
+ //
+ // Note, if error_reporter is nullptr, then a default StderrReporter is
+ // used.
+ explicit Interpreter(ErrorReporter* error_reporter = DefaultErrorReporter());
+
+ ~Interpreter();
+
+ Interpreter(const Interpreter&) = delete;
+ Interpreter& operator=(const Interpreter&) = delete;
+
+ // Functions to build interpreter
+
+ // Provide a list of tensor indexes that are inputs to the model.
+  // Each index is bounds checked and this modifies the consistent_ flag of the
+ // interpreter.
+ TfLiteStatus SetInputs(std::vector<int> inputs);
+
+  // Provide a list of tensor indexes that are outputs of the model.
+  // Each index is bounds checked and this modifies the consistent_ flag of the
+ // interpreter.
+ TfLiteStatus SetOutputs(std::vector<int> outputs);
+
+ // Adds a node with the given parameters and returns the index of the new
+ // node in `node_index` (optionally). Interpreter will take ownership of
+  // `builtin_data` and destroy it with `free`. Ownership of 'init_data'
+ // remains with the caller.
+ TfLiteStatus AddNodeWithParameters(const std::vector<int>& inputs,
+ const std::vector<int>& outputs,
+ const char* init_data,
+ size_t init_data_size, void* builtin_data,
+ const TfLiteRegistration* registration,
+ int* node_index = nullptr);
+
+ // Adds `tensors_to_add` tensors, preserving pre-existing Tensor entries.
+ // The value pointed to by `first_new_tensor_index` will be set to the
+ // index of the first new tensor if `first_new_tensor_index` is non-null.
+ TfLiteStatus AddTensors(int tensors_to_add,
+ int* first_new_tensor_index = nullptr);
+
+  // Set description of a read-only tensor `tensor_index`. This variant assumes
+  // an external buffer of size `bytes` has been allocated. The lifetime of
+  // `buffer` must be at least as long as that of the Interpreter.
+ TfLiteStatus SetTensorParametersReadOnly(
+ int tensor_index, TfLiteType type, const char* name,
+ const std::vector<int>& dims, TfLiteQuantizationParams quantization,
+ const char* buffer, size_t bytes, const Allocation* allocation = nullptr);
+
+  // Set description of a read/write tensor `tensor_index`. This variant does
+  // not take an external buffer; the interpreter allocates storage for the
+  // tensor (in its arena, or dynamically for string tensors).
+ TfLiteStatus SetTensorParametersReadWrite(
+ int tensor_index, TfLiteType type, const char* name,
+ const std::vector<int>& dims, TfLiteQuantizationParams quantization);
+
+ // Functions to access tensor data
+
+ // Read only access to list of inputs.
+ const std::vector<int>& inputs() const { return inputs_; }
+
+ // Return the name of a given input. The given index must be between 0 and
+ // inputs().size().
+ const char* GetInputName(int index) const {
+ return context_.tensors[inputs_[index]].name;
+ }
+
+ // Read only access to list of outputs.
+ const std::vector<int>& outputs() const { return outputs_; }
+
+ // Return the name of a given output. The given index must be between 0 and
+ // outputs().size().
+ const char* GetOutputName(int index) const {
+ return context_.tensors[outputs_[index]].name;
+ }
+
+ // Return the number of tensors in the model.
+ int tensors_size() const { return context_.tensors_size; }
+
+ // Return the number of ops in the model.
+ int nodes_size() const { return nodes_and_registration_.size(); }
+
+ // Get a tensor data structure.
+ // TODO(aselle): Create a safe ArrayHandle interface to avoid exposing this
+ // read/write access to structure
+ TfLiteTensor* tensor(int tensor_index) {
+ if (tensor_index >= context_.tensors_size || tensor_index < 0)
+ return nullptr;
+ return &context_.tensors[tensor_index];
+ }
+
+ // Get a pointer to an operation and registration data structure if in bounds.
+ // TODO(aselle): Create a safe ArrayHandle interface to avoid exposing this
+ // read/write access to structure
+ const std::pair<TfLiteNode, TfLiteRegistration>* node_and_registration(
+ int node_index) {
+ if (node_index >= nodes_and_registration_.size() || node_index < 0)
+ return nullptr;
+ return &nodes_and_registration_[node_index];
+ }
+
+ // Perform a checked cast to the appropriate tensor type.
+ template <class T>
+ T* typed_tensor(int tensor_index) {
+ if (TfLiteTensor* tensor_ptr = tensor(tensor_index)) {
+ if (tensor_ptr->type == typeToTfLiteType<T>()) {
+ return reinterpret_cast<T*>(tensor_ptr->data.raw);
+ }
+ }
+ return nullptr;
+ }
+
+ // Return a pointer into the data of a given input tensor. The given index
+ // must be between 0 and inputs().size().
+ template <class T>
+ T* typed_input_tensor(int index) {
+ return typed_tensor<T>(inputs_[index]);
+ }
+
+ // Return a pointer into the data of a given output tensor. The given index
+ // must be between 0 and outputs().size().
+ template <class T>
+ T* typed_output_tensor(int index) {
+ return typed_tensor<T>(outputs_[index]);
+ }
+
+ // Change the dimensionality of a given tensor. Note, this is only acceptable
+ // for tensor indices that are inputs.
+ // Returns status of failure or success.
+ // TODO(aselle): Consider implementing ArraySlice equivalent to make this
+ // more adept at accepting data without an extra copy. Use absl::ArraySlice
+ // if our partners determine that dependency is acceptable.
+ TfLiteStatus ResizeInputTensor(int tensor_index,
+ const std::vector<int>& dims);
+
+ // Update allocations for all tensors. This will redim dependent tensors using
+ // the input tensor dimensionality as given. This is relatively expensive.
+  // If you know that your sizes are not changing, you need not call this.
+  // Returns status of success or failure.
+ // TODO(aselle): Madde
+ TfLiteStatus AllocateTensors();
+
+ // Invoke the interpreter (run the whole graph in dependency order).
+ //
+ // NOTE: It is possible that the interpreter is not in a ready state
+  // to evaluate (e.g. if a ResizeTensor() has been performed without a
+  // subsequent AllocateTensors()).
+ // Returns status of success or failure.
+ TfLiteStatus Invoke();
+
+ // Enable or disable the NN API (true to enable)
+ void UseNNAPI(bool enable);
+
+ // Set the number of threads available to the interpreter.
+ void SetNumThreads(int num_threads);
+
+ private:
+ // Give 'op_reg' a chance to initialize itself using the contents of
+ // 'buffer'.
+ void* OpInit(const TfLiteRegistration& op_reg, const char* buffer,
+ size_t length) {
+ if (op_reg.init == nullptr) return nullptr;
+ return op_reg.init(&context_, buffer, length);
+ }
+
+ // Let 'op_reg' release any memory it might have allocated via 'OpInit'.
+ void OpFree(const TfLiteRegistration& op_reg, void* buffer) {
+ if (op_reg.free == nullptr) return;
+ if (buffer) {
+ op_reg.free(&context_, buffer);
+ }
+ }
+
+ // Prepare the given 'node' for execution.
+ TfLiteStatus OpPrepare(const TfLiteRegistration& op_reg, TfLiteNode* node) {
+ if (op_reg.prepare == nullptr) return kTfLiteOk;
+ return op_reg.prepare(&context_, node);
+ }
+
+ // Invoke the operator represented by 'node'.
+ TfLiteStatus OpInvoke(const TfLiteRegistration& op_reg, TfLiteNode* node) {
+ if (op_reg.invoke == nullptr) return kTfLiteError;
+ return op_reg.invoke(&context_, node);
+ }
+
+ // Allocate tensors whose sizes are known in order of nodes. Discontinue when
+ // we encounter a node that has a dynamic output tensor.
+ TfLiteStatus AllocateTensorsWhoseSizesAreKnown();
+
+ // Tensors needed by the interpreter. Use `AddTensors` to add more blank
+ // tensor entries. Note, `tensors_.data()` needs to be synchronized to the
+ // `context_` whenever this std::vector is reallocated. Currently this
+ // only happens in `AddTensors()`.
+ std::vector<TfLiteTensor> tensors_;
+
+ // Check if an array of tensor indices are valid with respect to the Tensor
+ // array.
+ // NOTE: this changes consistent_ to be false if indices are out of bounds.
+ TfLiteStatus CheckTensorIndices(const char* label, const int* indices,
+ int length);
+
+ // Compute the number of bytes required to represent a tensor with dimensions
+ // specified by the array dims (of length dims_size). Returns the status code
+ // and bytes.
+ TfLiteStatus BytesRequired(TfLiteType type, const int* dims, int dims_size,
+ size_t* bytes);
+
+  // Implementation of the request to resize a tensor.
+ TfLiteStatus ResizeTensorImpl(TfLiteTensor* tensor, TfLiteIntArray* new_size);
+
+ // Report a detailed error string (will be printed to stderr).
+ // TODO(aselle): allow user of class to provide alternative destinations.
+ void ReportErrorImpl(const char* format, va_list args);
+
+  // Entry point for C node plugin API to request a tensor be resized.
+ static TfLiteStatus ResizeTensor(TfLiteContext* context, TfLiteTensor* tensor,
+ TfLiteIntArray* new_size);
+ // Entry point for C node plugin API to report an error.
+ static void ReportError(TfLiteContext* context, const char* format, ...);
+
+ // Entry point for C node plugin API to add new tensors.
+ static TfLiteStatus AddTensors(TfLiteContext* context, int tensors_to_add,
+ int* first_new_tensor_index);
+
+ // A pure C data structure used to communicate with the pure C plugin
+ // interface. To avoid copying tensor metadata, this is also the definitive
+ // structure to store tensors.
+ TfLiteContext context_;
+
+ // Node inputs/outputs are stored in TfLiteNode and TfLiteRegistration stores
+ // function pointers to actual implementation.
+ std::vector<std::pair<TfLiteNode, TfLiteRegistration>>
+ nodes_and_registration_;
+
+  // Raw memory buffer that is allocated for all temporaries and graph outputs
+  // that are declared kTfLiteArenaRw.
+ SimpleMemoryArena arena_;
+
+ // Raw memory buffer that is allocated for persistent tensors that are
+ // declared as kTfLiteArenaRwPersistent.
+ SimpleMemoryArena persistent_arena_;
+
+ // Stores allocation and reference counts of all tensors.
+ std::vector<ArenaAllocRefCount> allocs_and_refcounts_;
+
+ // Whether the model is consistent. That is to say if the inputs and outputs
+ // of every node and the global inputs and outputs are valid indexes into
+ // the tensor array.
+ bool consistent_ = true;
+
+ // Whether the model is safe to invoke (if any errors occurred this
+ // will be false).
+ bool invokable_ = false;
+
+ // Array of indices representing the tensors that are inputs to the
+ // interpreter.
+ std::vector<int> inputs_;
+
+ // Array of indices representing the tensors that are outputs to the
+ // interpreter.
+ std::vector<int> outputs_;
+
+  // The error reporter delegate that tflite will forward errors to.
+ ErrorReporter* error_reporter_;
+
+  // Next node for which to allocate output tensors.
+  // During Invoke(), the Interpreter allocates input tensors first, since they
+  // are known to be fixed size. It then allocates outputs from as many nodes
+  // as possible. When it reaches a node that produces a dynamically sized
+  // tensor, the Interpreter stops allocating, records the next node to
+  // allocate, and executes the node to generate the output tensor before
+  // continuing to allocate its successors. This process repeats until all
+  // nodes are executed.
+  // NOTE: this relies on nodes being stored in topological order.
+ int next_allocate_node_id_;
+
+ // Whether to delegate to NN API
+ std::unique_ptr<NNAPIDelegate> nnapi_delegate_;
+};
+
+} // namespace tflite
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_INTERPRETER_H_
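Putting the usage comment at the top of this header into compilable form, a one-node model could be assembled roughly as follows. The registration passed in is a stand-in (for example, the pass-through sketch shown earlier after context.h); everything else uses only the interface declared above.

    #include "tensorflow/contrib/lite/interpreter.h"

    TfLiteStatus BuildAndRunOneNodeModel(const TfLiteRegistration* copy_op) {
      tflite::Interpreter interpreter;
      TF_LITE_ENSURE_STATUS(interpreter.AddTensors(2));
      TF_LITE_ENSURE_STATUS(interpreter.SetInputs({0}));
      TF_LITE_ENSURE_STATUS(interpreter.SetOutputs({1}));
      TfLiteQuantizationParams quant;  // scale/zero_point unused for float32.
      interpreter.SetTensorParametersReadWrite(0, kTfLiteFloat32, "in", {3}, quant);
      interpreter.SetTensorParametersReadWrite(1, kTfLiteFloat32, "out", {3}, quant);
      TF_LITE_ENSURE_STATUS(interpreter.AddNodeWithParameters(
          {0}, {1}, nullptr, 0, nullptr, copy_op));
      TF_LITE_ENSURE_STATUS(interpreter.AllocateTensors());
      interpreter.typed_input_tensor<float>(0)[0] = 3.f;
      return interpreter.Invoke();
    }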
diff --git a/tensorflow/contrib/lite/interpreter_test.cc b/tensorflow/contrib/lite/interpreter_test.cc
new file mode 100644
index 0000000000..edff210943
--- /dev/null
+++ b/tensorflow/contrib/lite/interpreter_test.cc
@@ -0,0 +1,526 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/interpreter.h"
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/error_reporter.h"
+#include "tensorflow/contrib/lite/string_util.h"
+
+namespace tflite {
+namespace {
+
+// Make an interpreter that has no tensors and no nodes
+TEST(BasicInterpreter, ZeroInterpreter) {
+ Interpreter interpreter;
+ interpreter.SetInputs({});
+ interpreter.SetOutputs({});
+ ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
+ ASSERT_EQ(interpreter.Invoke(), kTfLiteOk);
+}
+
+// Test various error conditions.
+TEST(BasicInterpreter, InvokeInvalidModel) {
+ Interpreter interpreter;
+ ASSERT_NE(interpreter.Invoke(), kTfLiteOk);
+ ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
+ ASSERT_EQ(interpreter.Invoke(), kTfLiteOk);
+}
+
+// Test size accessor functions.
+TEST(BasicInterpreter, TestSizeFunctions) {
+ Interpreter interpreter;
+ int base_index;
+ ASSERT_EQ(interpreter.nodes_size(), 0);
+ ASSERT_EQ(interpreter.tensors_size(), 0);
+ ASSERT_EQ(interpreter.AddTensors(2, &base_index), kTfLiteOk);
+ ASSERT_EQ(interpreter.tensors_size(), 2);
+ ASSERT_EQ(base_index, 0);
+ ASSERT_EQ(interpreter.AddTensors(3, &base_index), kTfLiteOk);
+ ASSERT_EQ(interpreter.tensors_size(), 5);
+ ASSERT_EQ(interpreter.AddTensors(1), kTfLiteOk);
+ ASSERT_EQ(interpreter.tensors_size(), 6);
+ ASSERT_EQ(base_index, 2);
+}
+
+// Test if invalid indices make a model inconsistent (and conversely if
+// valid indices keep a model consistent).
+TEST(BasicInterpreter, InconsistentModel) {
+ // Invalid inputs
+ {
+ Interpreter interpreter;
+ ASSERT_NE(interpreter.SetInputs({5}), kTfLiteOk);
+ ASSERT_NE(interpreter.AllocateTensors(), kTfLiteOk);
+ ASSERT_NE(interpreter.Invoke(), kTfLiteOk);
+ ASSERT_EQ(interpreter.inputs(), std::vector<int>());
+ }
+ // Invalid outputs
+ {
+ Interpreter interpreter;
+ ASSERT_NE(interpreter.SetOutputs({5}), kTfLiteOk);
+ ASSERT_NE(interpreter.AllocateTensors(), kTfLiteOk);
+ ASSERT_NE(interpreter.Invoke(), kTfLiteOk);
+ ASSERT_EQ(interpreter.outputs(), std::vector<int>());
+ }
+ // Invalid node inputs
+ {
+ Interpreter interpreter;
+ TfLiteRegistration registration = {nullptr, nullptr, nullptr, nullptr};
+ ASSERT_NE(interpreter.AddNodeWithParameters({3}, {0}, nullptr, 0, nullptr,
+ &registration),
+ kTfLiteOk);
+ ASSERT_NE(interpreter.AllocateTensors(), kTfLiteOk);
+ ASSERT_NE(interpreter.Invoke(), kTfLiteOk);
+ }
+ // Valid inputs and outputs and a node with valid inputs and outputs
+ {
+ Interpreter interpreter;
+ ASSERT_EQ(interpreter.AddTensors(2), kTfLiteOk);
+ TfLiteRegistration registration = {nullptr, nullptr, nullptr, nullptr};
+ ASSERT_EQ(interpreter.SetInputs({0}), kTfLiteOk);
+ ASSERT_EQ(interpreter.SetOutputs({0}), kTfLiteOk);
+ ASSERT_EQ(interpreter.AddNodeWithParameters({0}, {1}, nullptr, 0, nullptr,
+ &registration),
+ kTfLiteOk);
+ }
+}
+
+// Make an interpreter that has one tensor but no ops
+TEST(BasicInterpreter, CheckAllocate) {
+ struct {
+ TfLiteType type;
+ size_t size;
+ } cases[] = {
+ {kTfLiteFloat32, sizeof(float)},
+ {kTfLiteInt32, sizeof(int32_t)},
+ {kTfLiteUInt8, sizeof(uint8_t)},
+ {kTfLiteInt64, sizeof(int64_t)},
+ };
+
+ for (auto test : cases) {
+ Interpreter interpreter;
+ ASSERT_EQ(interpreter.AddTensors(2), kTfLiteOk);
+ interpreter.SetInputs({0, 1});
+ interpreter.SetOutputs({});
+ TfLiteQuantizationParams quant;
+
+ interpreter.SetTensorParametersReadWrite(0, test.type, "", {3}, quant);
+ interpreter.SetTensorParametersReadWrite(1, test.type, "", {4}, quant);
+ ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
+ ASSERT_EQ(interpreter.tensor(0)->bytes, 3 * test.size);
+ ASSERT_NE(interpreter.tensor(0)->data.raw, nullptr);
+ ASSERT_EQ(interpreter.tensor(1)->bytes, 4 * test.size);
+ ASSERT_NE(interpreter.tensor(1)->data.raw, nullptr);
+ }
+}
+
+TEST(BasicInterpreter, CheckResize) {
+ const float floats[] = {-3., -4.};
+ const int32_t int32s[] = {-3, -4};
+ const uint8_t uint8s[] = {3, 4};
+ const int64_t int64s[] = {6, -7};
+
+ struct {
+ TfLiteType type;
+ size_t size;
+ const char* array;
+ } cases[] = {
+ {kTfLiteFloat32, sizeof(float), reinterpret_cast<const char*>(floats)},
+ {kTfLiteInt32, sizeof(int32_t), reinterpret_cast<const char*>(int32s)},
+ {kTfLiteUInt8, sizeof(uint8_t), reinterpret_cast<const char*>(uint8s)},
+ {kTfLiteInt64, sizeof(int64_t), reinterpret_cast<const char*>(int64s)},
+ };
+
+ for (auto test : cases) {
+ Interpreter interpreter;
+
+ ASSERT_EQ(interpreter.AddTensors(2), kTfLiteOk);
+ interpreter.SetInputs({0, 1});
+ interpreter.SetOutputs({});
+ TfLiteQuantizationParams quant;
+
+ ASSERT_EQ(
+ interpreter.SetTensorParametersReadWrite(0, test.type, "", {3}, quant),
+ kTfLiteOk);
+ ASSERT_EQ(interpreter.SetTensorParametersReadOnly(
+ 1, test.type, "", {2}, quant, test.array, 2 * test.size),
+ kTfLiteOk);
+ ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
+ ASSERT_EQ(interpreter.ResizeInputTensor(0, {1, 2}), kTfLiteOk);
+    // Resizing an mmapped tensor is not allowed and should produce an error.
+ ASSERT_NE(interpreter.ResizeInputTensor(1, {3}), kTfLiteOk);
+ // Set the tensor to be mmapped but with a buffer size that is insufficient
+ // to match the dimensionality.
+ ASSERT_NE(interpreter.SetTensorParametersReadOnly(
+ 1, test.type, "", {2}, quant, test.array, 1 * test.size),
+ kTfLiteOk);
+ // Allocating should work since we should have our last correct array
+ // values in place.
+ ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
+ }
+}
+
+TEST(BasicInterpreter, CheckAlignment) {
+ struct {
+ TfLiteType type;
+ } cases[] = {
+ {kTfLiteFloat32},
+ {kTfLiteInt32},
+ {kTfLiteUInt8},
+ {kTfLiteInt64},
+ };
+
+ for (auto test : cases) {
+ Interpreter interpreter;
+
+ ASSERT_EQ(interpreter.AddTensors(4), kTfLiteOk);
+
+ for (int i = 0; i < 4; i++) {
+ TfLiteQuantizationParams quant;
+ interpreter.SetTensorParametersReadWrite(i, test.type, "", {2 * i + 1},
+ quant);
+ }
+ interpreter.AllocateTensors();
+ for (int i = 0; i < 4; i++) {
+ const TfLiteTensor& tensor = *interpreter.tensor(i);
+ ASSERT_EQ(reinterpret_cast<intptr_t>(tensor.data.raw) % 4, 0);
+ }
+ }
+}
+
+TEST(BasicInterpreter, CheckArenaAllocation) {
+ Interpreter interpreter;
+ ASSERT_EQ(interpreter.AddTensors(10), kTfLiteOk);
+
+ TfLiteQuantizationParams quant;
+ TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr};
+
+ std::vector<int> sizes{2048, 4096, 1023, 2047, 1021,
+ 2047, 1023, 2046, 1021, 2048};
+ for (int i = 0; i < sizes.size(); ++i) {
+ interpreter.SetTensorParametersReadWrite(i, kTfLiteUInt8, "", {sizes[i]},
+ quant);
+ }
+ interpreter.SetInputs({0, 1});
+ interpreter.SetOutputs({9, 4});
+ interpreter.AddNodeWithParameters({0, 1}, {2, 3}, nullptr, 0, nullptr, &reg);
+ interpreter.AddNodeWithParameters({2, 1}, {4, 5}, nullptr, 0, nullptr, &reg);
+ interpreter.AddNodeWithParameters({4, 3}, {6, 7}, nullptr, 0, nullptr, &reg);
+ interpreter.AddNodeWithParameters({6, 5}, {8}, nullptr, 0, nullptr, &reg);
+ interpreter.AddNodeWithParameters({8, 7}, {9}, nullptr, 0, nullptr, &reg);
+
+ ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
+
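+  // Tensors whose lifetimes do not overlap in the graph above are expected to
+  // share arena space; the checks below pin down the reuse and relative
+  // ordering produced by the memory planner for this particular graph.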
+ ASSERT_EQ(interpreter.tensor(0)->data.raw, interpreter.tensor(4)->data.raw);
+ ASSERT_EQ(interpreter.tensor(1)->data.raw, interpreter.tensor(7)->data.raw);
+
+ ASSERT_LT(interpreter.tensor(4)->data.raw, interpreter.tensor(1)->data.raw);
+ ASSERT_LT(interpreter.tensor(6)->data.raw, interpreter.tensor(1)->data.raw);
+ ASSERT_LT(interpreter.tensor(0)->data.raw, interpreter.tensor(1)->data.raw);
+
+ ASSERT_LT(interpreter.tensor(0)->data.raw, interpreter.tensor(3)->data.raw);
+ ASSERT_LT(interpreter.tensor(1)->data.raw, interpreter.tensor(3)->data.raw);
+ ASSERT_LT(interpreter.tensor(2)->data.raw, interpreter.tensor(3)->data.raw);
+ ASSERT_LT(interpreter.tensor(4)->data.raw, interpreter.tensor(3)->data.raw);
+ ASSERT_LT(interpreter.tensor(6)->data.raw, interpreter.tensor(3)->data.raw);
+ ASSERT_LT(interpreter.tensor(7)->data.raw, interpreter.tensor(3)->data.raw);
+ ASSERT_LT(interpreter.tensor(8)->data.raw, interpreter.tensor(3)->data.raw);
+ ASSERT_LT(interpreter.tensor(9)->data.raw, interpreter.tensor(3)->data.raw);
+
+ ASSERT_LT(interpreter.tensor(0)->data.raw, interpreter.tensor(5)->data.raw);
+ ASSERT_LT(interpreter.tensor(1)->data.raw, interpreter.tensor(5)->data.raw);
+ ASSERT_LT(interpreter.tensor(2)->data.raw, interpreter.tensor(5)->data.raw);
+ ASSERT_LT(interpreter.tensor(3)->data.raw, interpreter.tensor(5)->data.raw);
+ ASSERT_LT(interpreter.tensor(4)->data.raw, interpreter.tensor(5)->data.raw);
+ ASSERT_LT(interpreter.tensor(6)->data.raw, interpreter.tensor(5)->data.raw);
+ ASSERT_LT(interpreter.tensor(7)->data.raw, interpreter.tensor(5)->data.raw);
+ ASSERT_LT(interpreter.tensor(8)->data.raw, interpreter.tensor(5)->data.raw);
+ ASSERT_LT(interpreter.tensor(9)->data.raw, interpreter.tensor(5)->data.raw);
+}
+
+TEST(BasicInterpreter, BufferAccess) {
+ Interpreter interpreter;
+ ASSERT_EQ(interpreter.AddTensors(1), kTfLiteOk);
+ ASSERT_EQ(interpreter.SetInputs({0}), kTfLiteOk);
+
+ ASSERT_EQ(interpreter.SetTensorParametersReadWrite(
+ 0, kTfLiteFloat32, "", {3}, TfLiteQuantizationParams()),
+ kTfLiteOk);
+ ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
+  // Verify that we get a valid pointer.
+ ASSERT_NE(interpreter.typed_tensor<float>(0), nullptr);
+  // Verify that a pointer of the wrong type is not returned.
+ ASSERT_EQ(interpreter.typed_tensor<int>(0), nullptr);
+  // Verify that the raw C-interface pointer matches the typed accessor.
+ ASSERT_EQ(interpreter.typed_tensor<float>(0), interpreter.tensor(0)->data.f);
+}
+
+TEST(BasicInterpreter, NoOpInterpreter) {
+ Interpreter interpreter;
+ ASSERT_EQ(interpreter.AddTensors(1), kTfLiteOk);
+ ASSERT_EQ(interpreter.SetInputs({0}), kTfLiteOk);
+ ASSERT_EQ(interpreter.SetOutputs({0}), kTfLiteOk);
+
+ ASSERT_EQ(interpreter.SetTensorParametersReadWrite(
+ 0, kTfLiteFloat32, "", {3}, TfLiteQuantizationParams()),
+ kTfLiteOk);
+
+ ASSERT_EQ(interpreter.ResizeInputTensor(interpreter.inputs()[0], {1, 2, 3}),
+ kTfLiteOk);
+ ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
+ ASSERT_EQ(interpreter.Invoke(), kTfLiteOk);
+}
+
+TEST(BasicInterpreter, OneOpInterpreter) {
+ Interpreter interpreter;
+ ASSERT_EQ(interpreter.AddTensors(2), kTfLiteOk);
+ ASSERT_EQ(interpreter.SetInputs({0}), kTfLiteOk);
+ ASSERT_EQ(interpreter.SetOutputs({1}), kTfLiteOk);
+
+ TfLiteQuantizationParams quantized;
+ ASSERT_EQ(interpreter.SetTensorParametersReadWrite(0, kTfLiteFloat32, "in1",
+ {3}, quantized),
+ kTfLiteOk);
+ ASSERT_EQ(interpreter.SetTensorParametersReadWrite(1, kTfLiteFloat32, "out0",
+ {3}, quantized),
+ kTfLiteOk);
+
+ ASSERT_EQ(interpreter.GetInputName(0), "in1");
+ ASSERT_EQ(interpreter.GetOutputName(0), "out0");
+
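+  // A test op whose init() adds two extra tensors, whose prepare() registers
+  // them as temporaries sized like the input, and whose invoke() copies the
+  // input into the output and into both temporaries.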
+ TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr};
+ reg.init = [](TfLiteContext* context, const char*, size_t) -> void* {
+ auto* first_new_tensor = new int;
+ context->AddTensors(context, 2, first_new_tensor);
+ return first_new_tensor;
+ };
+ reg.free = [](TfLiteContext* context, void* buffer) {
+ delete reinterpret_cast<int*>(buffer);
+ };
+ reg.prepare = [](TfLiteContext* context, TfLiteNode* node) {
+ auto* first_new_tensor = reinterpret_cast<int*>(node->user_data);
+
+ TfLiteTensor* tensor0 = &context->tensors[node->inputs->data[0]];
+ TfLiteTensor* tensor1 = &context->tensors[node->outputs->data[0]];
+
+ TfLiteIntArray* newSize = TfLiteIntArrayCopy(tensor0->dims);
+ TF_LITE_ENSURE_STATUS(context->ResizeTensor(context, tensor1, newSize));
+
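+    // Register the two tensors created in init() as this node's temporaries.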
+ TfLiteIntArrayFree(node->temporaries);
+ node->temporaries = TfLiteIntArrayCreate(2);
+ for (int i = 0; i < 2; ++i) {
+ node->temporaries->data[i] = *(first_new_tensor) + i;
+ }
+
+ auto setup_temporary = [&](int id) {
+ TfLiteTensor* tmp = &context->tensors[id];
+ tmp->type = kTfLiteFloat32;
+ tmp->allocation_type = kTfLiteArenaRw;
+ return context->ResizeTensor(context, tmp,
+ TfLiteIntArrayCopy(tensor0->dims));
+ };
+ TF_LITE_ENSURE_STATUS(setup_temporary(node->temporaries->data[0]));
+ TF_LITE_ENSURE_STATUS(setup_temporary(node->temporaries->data[1]));
+
+ return kTfLiteOk;
+ };
+ reg.invoke = [](TfLiteContext* context, TfLiteNode* node) {
+ TfLiteTensor* a0 = &context->tensors[node->inputs->data[0]];
+
+ auto populate = [&](int id) {
+ TfLiteTensor* t = &context->tensors[id];
+ int num = a0->dims->data[0];
+ for (int i = 0; i < num; i++) {
+ t->data.f[i] = a0->data.f[i];
+ }
+ };
+
+ populate(node->outputs->data[0]);
+ populate(node->temporaries->data[0]);
+ populate(node->temporaries->data[1]);
+ return kTfLiteOk;
+ };
+ ASSERT_EQ(
+ interpreter.AddNodeWithParameters({0}, {1}, nullptr, 0, nullptr, &reg),
+ kTfLiteOk);
+ ASSERT_EQ(interpreter.ResizeInputTensor(0, {3}), kTfLiteOk);
+ ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
+
+ ASSERT_EQ(interpreter.Invoke(), kTfLiteOk);
+}
+
+// Forces tensor allocation into three steps: one before invocation and two
+// more at invocation time. This happens because the test uses string tensors,
+// whose sizes can't be determined until invocation time.
+TEST(BasicInterpreter, ThreeStepAllocate) {
+ Interpreter interpreter;
+ ASSERT_EQ(interpreter.AddTensors(5), kTfLiteOk);
+ ASSERT_EQ(interpreter.SetInputs({0}), kTfLiteOk);
+ ASSERT_EQ(interpreter.SetOutputs({4}), kTfLiteOk);
+
+ TfLiteQuantizationParams quantized;
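+  // Serialized single-entry string tensor: count = 1, offsets {12, 15}, then
+  // the payload "ABC", for 15 bytes in total.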
+ char data[] = {1, 0, 0, 0, 12, 0, 0, 0, 15, 0, 0, 0, 'A', 'B', 'C'};
+  // Read-only string tensor.
+ ASSERT_EQ(interpreter.SetTensorParametersReadOnly(0, kTfLiteString, "", {1},
+ quantized, data, 15),
+ kTfLiteOk);
+ // Read-write string tensor.
+ ASSERT_EQ(interpreter.SetTensorParametersReadWrite(1, kTfLiteString, "", {1},
+ quantized),
+ kTfLiteOk);
+ ASSERT_EQ(interpreter.SetTensorParametersReadWrite(2, kTfLiteInt32, "", {1},
+ quantized),
+ kTfLiteOk);
+ ASSERT_EQ(interpreter.SetTensorParametersReadWrite(3, kTfLiteString, "", {1},
+ quantized),
+ kTfLiteOk);
+ ASSERT_EQ(interpreter.SetTensorParametersReadWrite(4, kTfLiteInt32, "", {1},
+ quantized),
+ kTfLiteOk);
+
+ // String-in String-out node.
+ TfLiteRegistration reg_copy = {nullptr, nullptr, nullptr, nullptr};
+ reg_copy.invoke = [](TfLiteContext* context, TfLiteNode* node) {
+ TfLiteTensor* a0 = &context->tensors[node->inputs->data[0]];
+ TfLiteTensor* a1 = &context->tensors[node->outputs->data[0]];
+ DynamicBuffer buf;
+ StringRef str_ref = GetString(a0, 0);
+ buf.AddString(str_ref);
+ buf.WriteToTensor(a1);
+ return kTfLiteOk;
+ };
+
+ // String-in Int-out node.
+ TfLiteRegistration reg_len = {nullptr, nullptr, nullptr, nullptr};
+ reg_len.prepare = [](TfLiteContext* context, TfLiteNode* node) {
+ TfLiteTensor* output = &context->tensors[node->outputs->data[0]];
+ TfLiteIntArray* outputSize = TfLiteIntArrayCreate(1);
+ outputSize->data[0] = 1;
+ return context->ResizeTensor(context, output, outputSize);
+ };
+ reg_len.invoke = [](TfLiteContext* context, TfLiteNode* node) {
+ TfLiteTensor* a0 = &context->tensors[node->inputs->data[0]];
+ TfLiteTensor* a1 = &context->tensors[node->outputs->data[0]];
+ a1->data.i32[0] = a0->bytes;
+ return kTfLiteOk;
+ };
+
+ ASSERT_EQ(interpreter.AddNodeWithParameters({0}, {1}, nullptr, 0, nullptr,
+ &reg_copy),
+ kTfLiteOk);
+ ASSERT_EQ(interpreter.AddNodeWithParameters({1}, {2}, nullptr, 0, nullptr,
+ &reg_len),
+ kTfLiteOk);
+ ASSERT_EQ(interpreter.AddNodeWithParameters({0}, {3}, nullptr, 0, nullptr,
+ &reg_copy),
+ kTfLiteOk);
+ ASSERT_EQ(interpreter.AddNodeWithParameters({3}, {4}, nullptr, 0, nullptr,
+ &reg_len),
+ kTfLiteOk);
+
+ ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
+ ASSERT_EQ(interpreter.Invoke(), kTfLiteOk);
+
+ ASSERT_EQ(interpreter.tensor(0)->bytes, 15);
+ ASSERT_NE(interpreter.tensor(0)->data.raw, nullptr);
+ ASSERT_EQ(interpreter.tensor(1)->bytes, 15);
+ ASSERT_NE(interpreter.tensor(1)->data.raw, nullptr);
+ ASSERT_EQ(interpreter.tensor(3)->bytes, 15);
+ ASSERT_NE(interpreter.tensor(4)->data.raw, nullptr);
+ ASSERT_EQ(interpreter.tensor(2)->bytes, 4);
+ ASSERT_EQ(interpreter.tensor(2)->data.i32[0], 15);
+ ASSERT_EQ(interpreter.tensor(4)->bytes, 4);
+ ASSERT_EQ(interpreter.tensor(4)->data.i32[0], 15);
+}
+
+TEST(BasicInterpreter, AllocateTwice) {
+ Interpreter interpreter;
+ ASSERT_EQ(interpreter.AddTensors(2), kTfLiteOk);
+ ASSERT_EQ(interpreter.SetInputs({0}), kTfLiteOk);
+ ASSERT_EQ(interpreter.SetOutputs({1}), kTfLiteOk);
+
+ TfLiteQuantizationParams quantized;
+ ASSERT_EQ(interpreter.SetTensorParametersReadWrite(0, kTfLiteFloat32, "", {3},
+ quantized),
+ kTfLiteOk);
+ ASSERT_EQ(interpreter.SetTensorParametersReadWrite(1, kTfLiteFloat32, "", {3},
+ quantized),
+ kTfLiteOk);
+
+ TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr};
+ reg.prepare = [](TfLiteContext* context, TfLiteNode* node) {
+ TfLiteTensor* tensor0 = &context->tensors[node->inputs->data[0]];
+ TfLiteTensor* tensor1 = &context->tensors[node->outputs->data[0]];
+ TfLiteIntArray* newSize = TfLiteIntArrayCopy(tensor0->dims);
+ return context->ResizeTensor(context, tensor1, newSize);
+ };
+ reg.invoke = [](TfLiteContext* context, TfLiteNode* node) {
+ TfLiteTensor* a0 = &context->tensors[node->inputs->data[0]];
+ TfLiteTensor* a1 = &context->tensors[node->outputs->data[0]];
+ int num = a0->dims->data[0];
+ for (int i = 0; i < num; i++) {
+ a1->data.f[i] = a0->data.f[i];
+ }
+ return kTfLiteOk;
+ };
+ ASSERT_EQ(
+ interpreter.AddNodeWithParameters({0}, {1}, nullptr, 0, nullptr, &reg),
+ kTfLiteOk);
+ ASSERT_EQ(interpreter.ResizeInputTensor(0, {3}), kTfLiteOk);
+ ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
+ ASSERT_EQ(interpreter.Invoke(), kTfLiteOk);
+ char* old_tensor0_ptr = interpreter.tensor(0)->data.raw;
+ char* old_tensor1_ptr = interpreter.tensor(1)->data.raw;
+
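+  // Allocating and invoking a second time should leave the tensor buffers at
+  // the same addresses as before.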
+ ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
+ ASSERT_EQ(interpreter.Invoke(), kTfLiteOk);
+ ASSERT_EQ(old_tensor0_ptr, interpreter.tensor(0)->data.raw);
+ ASSERT_EQ(old_tensor1_ptr, interpreter.tensor(1)->data.raw);
+}
+
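+// Captures formatted error reports so tests can assert on both the message
+// text and the number of calls.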
+struct TestErrorReporter : public ErrorReporter {
+ int Report(const char* format, va_list args) override {
+ char buffer[1024];
+ int size = vsnprintf(buffer, sizeof(buffer), format, args);
+ all_reports += buffer;
+ calls++;
+ return size;
+ }
+ int calls = 0;
+ std::string all_reports;
+};
+
+TEST(BasicInterpreter, TestNullErrorReporter) {
+ TestErrorReporter reporter;
+ Interpreter interpreter;
+}
+
+TEST(BasicInterpreter, TestCustomErrorReporter) {
+ TestErrorReporter reporter;
+ Interpreter interpreter(&reporter);
+ ASSERT_NE(interpreter.Invoke(), kTfLiteOk);
+ ASSERT_EQ(reporter.all_reports, "Invoke called on model that is not ready.");
+ ASSERT_EQ(reporter.calls, 1);
+}
+
+} // namespace
+} // namespace tflite
+
+int main(int argc, char** argv) {
+#ifdef OS_LINUX
+ FLAGS_logtostderr = true;
+#endif
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/java/BUILD b/tensorflow/contrib/lite/java/BUILD
new file mode 100644
index 0000000000..80e8bd435e
--- /dev/null
+++ b/tensorflow/contrib/lite/java/BUILD
@@ -0,0 +1,164 @@
+# Description:
+# TensorFlow Lite Java API.
+
+package(default_visibility = ["//visibility:private"])
+
+licenses(["notice"]) # Apache 2.0
+
+load("//tensorflow/java:build_defs.bzl", "JAVACOPTS")
+load("//tensorflow/contrib/lite:build_def.bzl", "tflite_jni_binary")
+
+android_library(
+ name = "tensorflowlite",
+ srcs = glob(
+ [
+ "src/main/java/org/tensorflow/lite/*.java",
+ ],
+ ),
+ visibility = ["//visibility:public"],
+ deps = [
+ ":tflite_runtime",
+ "@javax_validation",
+ ],
+)
+
+android_library(
+ name = "tensorflowlite_java",
+ srcs = glob(
+ [
+ "src/main/java/org/tensorflow/lite/*.java",
+ ],
+ ),
+ visibility = ["//visibility:public"],
+ deps = [
+ "@javax_validation",
+ ],
+)
+
+java_library(
+ name = "tensorflowlitelib",
+ srcs = glob(
+ [
+ "src/main/java/org/tensorflow/lite/*.java",
+ ],
+ ),
+ javacopts = JAVACOPTS,
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/contrib/lite/java/src/main/native",
+ "@javax_validation",
+ ],
+)
+
+java_test(
+ name = "TensorFlowLiteTest",
+ size = "small",
+ srcs = ["src/test/java/org/tensorflow/lite/TensorFlowLiteTest.java"],
+ javacopts = JAVACOPTS,
+ test_class = "org.tensorflow.lite.TensorFlowLiteTest",
+ deps = [
+ ":tensorflowlitelib",
+ "@com_google_truth",
+ "@junit",
+ ],
+)
+
+java_test(
+ name = "DataTypeTest",
+ size = "small",
+ srcs = ["src/test/java/org/tensorflow/lite/DataTypeTest.java"],
+ javacopts = JAVACOPTS,
+ test_class = "org.tensorflow.lite.DataTypeTest",
+ deps = [
+ ":tensorflowlitelib",
+ "@com_google_truth",
+ "@junit",
+ ],
+)
+
+java_test(
+ name = "NativeInterpreterWrapperTest",
+ size = "small",
+ srcs = ["src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java"],
+ data = [
+ "src/testdata/add.bin",
+ "src/testdata/int32.bin",
+ "src/testdata/int64.bin",
+ "src/testdata/invalid.model.tflite",
+ "src/testdata/uint8.bin",
+ ],
+ javacopts = JAVACOPTS,
+ test_class = "org.tensorflow.lite.NativeInterpreterWrapperTest",
+ deps = [
+ ":tensorflowlitelib",
+ "@com_google_truth",
+ "@junit",
+ ],
+)
+
+# TODO: generate large models at runtime, instead of storing them.
+java_test(
+ name = "InterpreterTest",
+ size = "small",
+ srcs = ["src/test/java/org/tensorflow/lite/InterpreterTest.java"],
+ data = [
+ "src/testdata/add.bin",
+ "src/testdata/mobilenet.tflite.bin",
+ ],
+ javacopts = JAVACOPTS,
+ test_class = "org.tensorflow.lite.InterpreterTest",
+ deps = [
+ ":tensorflowlitelib",
+ "@com_google_truth",
+ "@junit",
+ ],
+)
+
+java_test(
+ name = "TensorTest",
+ size = "small",
+ srcs = ["src/test/java/org/tensorflow/lite/TensorTest.java"],
+ data = [
+ "src/testdata/add.bin",
+ ],
+ javacopts = JAVACOPTS,
+ test_class = "org.tensorflow.lite.TensorTest",
+ deps = [
+ ":tensorflowlitelib",
+ "@com_google_truth",
+ "@junit",
+ ],
+)
+
+filegroup(
+ name = "libtensorflowlite_jni",
+ srcs = select({
+ "//conditions:default": [":libtensorflowlite_jni.so"],
+ }),
+ visibility = ["//visibility:public"],
+)
+
+cc_library(
+ name = "tflite_runtime",
+ srcs = ["libtensorflowlite_jni.so"],
+ visibility = ["//visibility:public"],
+)
+
+tflite_jni_binary(
+ name = "libtensorflowlite_jni.so",
+ deps = [
+ "//tensorflow/contrib/lite/java/src/main/native",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/contrib/lite/java/demo/.gitignore b/tensorflow/contrib/lite/java/demo/.gitignore
new file mode 100644
index 0000000000..39fb081a42
--- /dev/null
+++ b/tensorflow/contrib/lite/java/demo/.gitignore
@@ -0,0 +1,9 @@
+*.iml
+.gradle
+/local.properties
+/.idea/workspace.xml
+/.idea/libraries
+.DS_Store
+/build
+/captures
+.externalNativeBuild
diff --git a/tensorflow/contrib/lite/java/demo/app/build.gradle b/tensorflow/contrib/lite/java/demo/app/build.gradle
new file mode 100644
index 0000000000..e1470fe717
--- /dev/null
+++ b/tensorflow/contrib/lite/java/demo/app/build.gradle
@@ -0,0 +1,58 @@
+apply plugin: 'com.android.application'
+
+android {
+ compileSdkVersion 26
+ buildToolsVersion "26.0.1"
+ defaultConfig {
+ applicationId "android.example.com.tflitecamerademo"
+ minSdkVersion 15
+ targetSdkVersion 26
+ versionCode 1
+ versionName "1.0"
+ testInstrumentationRunner "android.support.test.runner.AndroidJUnitRunner"
+
+        // TODO: Remove this jackOptions block once it is no longer needed.
+ jackOptions {
+ enabled true
+ }
+ }
+ lintOptions {
+ abortOnError false
+ }
+ buildTypes {
+ release {
+ minifyEnabled false
+ proguardFiles getDefaultProguardFile('proguard-android.txt'), 'proguard-rules.pro'
+ }
+ }
+ aaptOptions {
+ noCompress "tflite"
+ }
+
+ compileOptions {
+ sourceCompatibility JavaVersion.VERSION_1_8
+ targetCompatibility JavaVersion.VERSION_1_8
+ }
+}
+
+repositories {
+ flatDir {
+ dirs 'libs'
+ }
+}
+
+dependencies {
+ compile fileTree(dir: 'libs', include: ['*.jar'])
+ androidTestCompile('com.android.support.test.espresso:espresso-core:2.2.2', {
+ exclude group: 'com.android.support', module: 'support-annotations'
+ })
+ compile 'com.android.support:appcompat-v7:25.2.0'
+ compile 'com.android.support.constraint:constraint-layout:1.0.2'
+ compile 'com.android.support:design:25.2.0'
+ compile 'com.android.support:support-annotations:25.3.1'
+ compile 'com.android.support:support-v13:25.2.0'
+
+ compile 'org.tensorflow:tensorflow-lite:+'
+
+ testCompile 'junit:junit:4.12'
+}
diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/AndroidManifest.xml b/tensorflow/contrib/lite/java/demo/app/src/main/AndroidManifest.xml
new file mode 100644
index 0000000000..ba63dce5d9
--- /dev/null
+++ b/tensorflow/contrib/lite/java/demo/app/src/main/AndroidManifest.xml
@@ -0,0 +1,42 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!--
+ Copyright 2014 The Android Open Source Project
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+-->
+
+<manifest xmlns:android="http://schemas.android.com/apk/res/android"
+ package="com.example.android.tflitecamerademo">
+
+ <uses-permission android:name="android.permission.CAMERA" />
+
+ <uses-feature android:name="android.hardware.camera" />
+ <uses-feature android:name="android.hardware.camera.autofocus" />
+
+ <uses-sdk android:minSdkVersion="21" />
+
+ <application android:allowBackup="true"
+ android:label="@string/app_name"
+ android:icon="@drawable/ic_launcher"
+ android:theme="@style/MaterialTheme">
+
+ <activity android:name="com.example.android.tflitecamerademo.CameraActivity"
+ android:label="@string/app_name">
+ <intent-filter>
+ <action android:name="android.intent.action.MAIN" />
+ <category android:name="android.intent.category.LAUNCHER" />
+ </intent-filter>
+ </activity>
+ </application>
+
+</manifest>
diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/BUILD b/tensorflow/contrib/lite/java/demo/app/src/main/BUILD
new file mode 100644
index 0000000000..512a86affe
--- /dev/null
+++ b/tensorflow/contrib/lite/java/demo/app/src/main/BUILD
@@ -0,0 +1,43 @@
+package(default_visibility = ["//visibility:private"])
+
+licenses(["notice"]) # Apache 2.0
+
+android_binary(
+ name = "TfLiteCameraDemo",
+ srcs = glob(["java/**/*.java"]),
+ assets = [
+ ":assets",
+ ],
+ assets_dir = "",
+ custom_package = "com.example.android.tflitecamerademo",
+ manifest = "AndroidManifest.xml",
+ nocompress_extensions = [
+ ".tflite",
+ ],
+ resource_files = glob(["res/**"]),
+ deps = [
+ "//tensorflow/contrib/lite/java:tensorflowlite",
+ "//tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite:testhelper",
+ "@androidsdk//com.android.support:support-v13-25.2.0",
+ "@androidsdk//com.android.support:support-v4-25.2.0",
+ ],
+)
+
+filegroup(
+ name = "assets",
+ srcs = [
+ "@tflite_mobilenet//:model_files",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/assets/BUILD b/tensorflow/contrib/lite/java/demo/app/src/main/assets/BUILD
new file mode 100644
index 0000000000..1a759f5652
--- /dev/null
+++ b/tensorflow/contrib/lite/java/demo/app/src/main/assets/BUILD
@@ -0,0 +1,26 @@
+package(default_visibility = ["//visibility:private"])
+
+licenses(["notice"]) # Apache 2.0
+
+filegroup(
+ name = "assets_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "BUILD",
+ ],
+ ),
+ visibility = ["//visibility:public"],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/assets/labels.txt b/tensorflow/contrib/lite/java/demo/app/src/main/assets/labels.txt
new file mode 100644
index 0000000000..fe811239d8
--- /dev/null
+++ b/tensorflow/contrib/lite/java/demo/app/src/main/assets/labels.txt
@@ -0,0 +1,1001 @@
+background
+tench
+goldfish
+great white shark
+tiger shark
+hammerhead
+electric ray
+stingray
+cock
+hen
+ostrich
+brambling
+goldfinch
+house finch
+junco
+indigo bunting
+robin
+bulbul
+jay
+magpie
+chickadee
+water ouzel
+kite
+bald eagle
+vulture
+great grey owl
+European fire salamander
+common newt
+eft
+spotted salamander
+axolotl
+bullfrog
+tree frog
+tailed frog
+loggerhead
+leatherback turtle
+mud turtle
+terrapin
+box turtle
+banded gecko
+common iguana
+American chameleon
+whiptail
+agama
+frilled lizard
+alligator lizard
+Gila monster
+green lizard
+African chameleon
+Komodo dragon
+African crocodile
+American alligator
+triceratops
+thunder snake
+ringneck snake
+hognose snake
+green snake
+king snake
+garter snake
+water snake
+vine snake
+night snake
+boa constrictor
+rock python
+Indian cobra
+green mamba
+sea snake
+horned viper
+diamondback
+sidewinder
+trilobite
+harvestman
+scorpion
+black and gold garden spider
+barn spider
+garden spider
+black widow
+tarantula
+wolf spider
+tick
+centipede
+black grouse
+ptarmigan
+ruffed grouse
+prairie chicken
+peacock
+quail
+partridge
+African grey
+macaw
+sulphur-crested cockatoo
+lorikeet
+coucal
+bee eater
+hornbill
+hummingbird
+jacamar
+toucan
+drake
+red-breasted merganser
+goose
+black swan
+tusker
+echidna
+platypus
+wallaby
+koala
+wombat
+jellyfish
+sea anemone
+brain coral
+flatworm
+nematode
+conch
+snail
+slug
+sea slug
+chiton
+chambered nautilus
+Dungeness crab
+rock crab
+fiddler crab
+king crab
+American lobster
+spiny lobster
+crayfish
+hermit crab
+isopod
+white stork
+black stork
+spoonbill
+flamingo
+little blue heron
+American egret
+bittern
+crane
+limpkin
+European gallinule
+American coot
+bustard
+ruddy turnstone
+red-backed sandpiper
+redshank
+dowitcher
+oystercatcher
+pelican
+king penguin
+albatross
+grey whale
+killer whale
+dugong
+sea lion
+Chihuahua
+Japanese spaniel
+Maltese dog
+Pekinese
+Shih-Tzu
+Blenheim spaniel
+papillon
+toy terrier
+Rhodesian ridgeback
+Afghan hound
+basset
+beagle
+bloodhound
+bluetick
+black-and-tan coonhound
+Walker hound
+English foxhound
+redbone
+borzoi
+Irish wolfhound
+Italian greyhound
+whippet
+Ibizan hound
+Norwegian elkhound
+otterhound
+Saluki
+Scottish deerhound
+Weimaraner
+Staffordshire bullterrier
+American Staffordshire terrier
+Bedlington terrier
+Border terrier
+Kerry blue terrier
+Irish terrier
+Norfolk terrier
+Norwich terrier
+Yorkshire terrier
+wire-haired fox terrier
+Lakeland terrier
+Sealyham terrier
+Airedale
+cairn
+Australian terrier
+Dandie Dinmont
+Boston bull
+miniature schnauzer
+giant schnauzer
+standard schnauzer
+Scotch terrier
+Tibetan terrier
+silky terrier
+soft-coated wheaten terrier
+West Highland white terrier
+Lhasa
+flat-coated retriever
+curly-coated retriever
+golden retriever
+Labrador retriever
+Chesapeake Bay retriever
+German short-haired pointer
+vizsla
+English setter
+Irish setter
+Gordon setter
+Brittany spaniel
+clumber
+English springer
+Welsh springer spaniel
+cocker spaniel
+Sussex spaniel
+Irish water spaniel
+kuvasz
+schipperke
+groenendael
+malinois
+briard
+kelpie
+komondor
+Old English sheepdog
+Shetland sheepdog
+collie
+Border collie
+Bouvier des Flandres
+Rottweiler
+German shepherd
+Doberman
+miniature pinscher
+Greater Swiss Mountain dog
+Bernese mountain dog
+Appenzeller
+EntleBucher
+boxer
+bull mastiff
+Tibetan mastiff
+French bulldog
+Great Dane
+Saint Bernard
+Eskimo dog
+malamute
+Siberian husky
+dalmatian
+affenpinscher
+basenji
+pug
+Leonberg
+Newfoundland
+Great Pyrenees
+Samoyed
+Pomeranian
+chow
+keeshond
+Brabancon griffon
+Pembroke
+Cardigan
+toy poodle
+miniature poodle
+standard poodle
+Mexican hairless
+timber wolf
+white wolf
+red wolf
+coyote
+dingo
+dhole
+African hunting dog
+hyena
+red fox
+kit fox
+Arctic fox
+grey fox
+tabby
+tiger cat
+Persian cat
+Siamese cat
+Egyptian cat
+cougar
+lynx
+leopard
+snow leopard
+jaguar
+lion
+tiger
+cheetah
+brown bear
+American black bear
+ice bear
+sloth bear
+mongoose
+meerkat
+tiger beetle
+ladybug
+ground beetle
+long-horned beetle
+leaf beetle
+dung beetle
+rhinoceros beetle
+weevil
+fly
+bee
+ant
+grasshopper
+cricket
+walking stick
+cockroach
+mantis
+cicada
+leafhopper
+lacewing
+dragonfly
+damselfly
+admiral
+ringlet
+monarch
+cabbage butterfly
+sulphur butterfly
+lycaenid
+starfish
+sea urchin
+sea cucumber
+wood rabbit
+hare
+Angora
+hamster
+porcupine
+fox squirrel
+marmot
+beaver
+guinea pig
+sorrel
+zebra
+hog
+wild boar
+warthog
+hippopotamus
+ox
+water buffalo
+bison
+ram
+bighorn
+ibex
+hartebeest
+impala
+gazelle
+Arabian camel
+llama
+weasel
+mink
+polecat
+black-footed ferret
+otter
+skunk
+badger
+armadillo
+three-toed sloth
+orangutan
+gorilla
+chimpanzee
+gibbon
+siamang
+guenon
+patas
+baboon
+macaque
+langur
+colobus
+proboscis monkey
+marmoset
+capuchin
+howler monkey
+titi
+spider monkey
+squirrel monkey
+Madagascar cat
+indri
+Indian elephant
+African elephant
+lesser panda
+giant panda
+barracouta
+eel
+coho
+rock beauty
+anemone fish
+sturgeon
+gar
+lionfish
+puffer
+abacus
+abaya
+academic gown
+accordion
+acoustic guitar
+aircraft carrier
+airliner
+airship
+altar
+ambulance
+amphibian
+analog clock
+apiary
+apron
+ashcan
+assault rifle
+backpack
+bakery
+balance beam
+balloon
+ballpoint
+Band Aid
+banjo
+bannister
+barbell
+barber chair
+barbershop
+barn
+barometer
+barrel
+barrow
+baseball
+basketball
+bassinet
+bassoon
+bathing cap
+bath towel
+bathtub
+beach wagon
+beacon
+beaker
+bearskin
+beer bottle
+beer glass
+bell cote
+bib
+bicycle-built-for-two
+bikini
+binder
+binoculars
+birdhouse
+boathouse
+bobsled
+bolo tie
+bonnet
+bookcase
+bookshop
+bottlecap
+bow
+bow tie
+brass
+brassiere
+breakwater
+breastplate
+broom
+bucket
+buckle
+bulletproof vest
+bullet train
+butcher shop
+cab
+caldron
+candle
+cannon
+canoe
+can opener
+cardigan
+car mirror
+carousel
+carpenter's kit
+carton
+car wheel
+cash machine
+cassette
+cassette player
+castle
+catamaran
+CD player
+cello
+cellular telephone
+chain
+chainlink fence
+chain mail
+chain saw
+chest
+chiffonier
+chime
+china cabinet
+Christmas stocking
+church
+cinema
+cleaver
+cliff dwelling
+cloak
+clog
+cocktail shaker
+coffee mug
+coffeepot
+coil
+combination lock
+computer keyboard
+confectionery
+container ship
+convertible
+corkscrew
+cornet
+cowboy boot
+cowboy hat
+cradle
+crane
+crash helmet
+crate
+crib
+Crock Pot
+croquet ball
+crutch
+cuirass
+dam
+desk
+desktop computer
+dial telephone
+diaper
+digital clock
+digital watch
+dining table
+dishrag
+dishwasher
+disk brake
+dock
+dogsled
+dome
+doormat
+drilling platform
+drum
+drumstick
+dumbbell
+Dutch oven
+electric fan
+electric guitar
+electric locomotive
+entertainment center
+envelope
+espresso maker
+face powder
+feather boa
+file
+fireboat
+fire engine
+fire screen
+flagpole
+flute
+folding chair
+football helmet
+forklift
+fountain
+fountain pen
+four-poster
+freight car
+French horn
+frying pan
+fur coat
+garbage truck
+gasmask
+gas pump
+goblet
+go-kart
+golf ball
+golfcart
+gondola
+gong
+gown
+grand piano
+greenhouse
+grille
+grocery store
+guillotine
+hair slide
+hair spray
+half track
+hammer
+hamper
+hand blower
+hand-held computer
+handkerchief
+hard disc
+harmonica
+harp
+harvester
+hatchet
+holster
+home theater
+honeycomb
+hook
+hoopskirt
+horizontal bar
+horse cart
+hourglass
+iPod
+iron
+jack-o'-lantern
+jean
+jeep
+jersey
+jigsaw puzzle
+jinrikisha
+joystick
+kimono
+knee pad
+knot
+lab coat
+ladle
+lampshade
+laptop
+lawn mower
+lens cap
+letter opener
+library
+lifeboat
+lighter
+limousine
+liner
+lipstick
+Loafer
+lotion
+loudspeaker
+loupe
+lumbermill
+magnetic compass
+mailbag
+mailbox
+maillot
+maillot
+manhole cover
+maraca
+marimba
+mask
+matchstick
+maypole
+maze
+measuring cup
+medicine chest
+megalith
+microphone
+microwave
+military uniform
+milk can
+minibus
+miniskirt
+minivan
+missile
+mitten
+mixing bowl
+mobile home
+Model T
+modem
+monastery
+monitor
+moped
+mortar
+mortarboard
+mosque
+mosquito net
+motor scooter
+mountain bike
+mountain tent
+mouse
+mousetrap
+moving van
+muzzle
+nail
+neck brace
+necklace
+nipple
+notebook
+obelisk
+oboe
+ocarina
+odometer
+oil filter
+organ
+oscilloscope
+overskirt
+oxcart
+oxygen mask
+packet
+paddle
+paddlewheel
+padlock
+paintbrush
+pajama
+palace
+panpipe
+paper towel
+parachute
+parallel bars
+park bench
+parking meter
+passenger car
+patio
+pay-phone
+pedestal
+pencil box
+pencil sharpener
+perfume
+Petri dish
+photocopier
+pick
+pickelhaube
+picket fence
+pickup
+pier
+piggy bank
+pill bottle
+pillow
+ping-pong ball
+pinwheel
+pirate
+pitcher
+plane
+planetarium
+plastic bag
+plate rack
+plow
+plunger
+Polaroid camera
+pole
+police van
+poncho
+pool table
+pop bottle
+pot
+potter's wheel
+power drill
+prayer rug
+printer
+prison
+projectile
+projector
+puck
+punching bag
+purse
+quill
+quilt
+racer
+racket
+radiator
+radio
+radio telescope
+rain barrel
+recreational vehicle
+reel
+reflex camera
+refrigerator
+remote control
+restaurant
+revolver
+rifle
+rocking chair
+rotisserie
+rubber eraser
+rugby ball
+rule
+running shoe
+safe
+safety pin
+saltshaker
+sandal
+sarong
+sax
+scabbard
+scale
+school bus
+schooner
+scoreboard
+screen
+screw
+screwdriver
+seat belt
+sewing machine
+shield
+shoe shop
+shoji
+shopping basket
+shopping cart
+shovel
+shower cap
+shower curtain
+ski
+ski mask
+sleeping bag
+slide rule
+sliding door
+slot
+snorkel
+snowmobile
+snowplow
+soap dispenser
+soccer ball
+sock
+solar dish
+sombrero
+soup bowl
+space bar
+space heater
+space shuttle
+spatula
+speedboat
+spider web
+spindle
+sports car
+spotlight
+stage
+steam locomotive
+steel arch bridge
+steel drum
+stethoscope
+stole
+stone wall
+stopwatch
+stove
+strainer
+streetcar
+stretcher
+studio couch
+stupa
+submarine
+suit
+sundial
+sunglass
+sunglasses
+sunscreen
+suspension bridge
+swab
+sweatshirt
+swimming trunks
+swing
+switch
+syringe
+table lamp
+tank
+tape player
+teapot
+teddy
+television
+tennis ball
+thatch
+theater curtain
+thimble
+thresher
+throne
+tile roof
+toaster
+tobacco shop
+toilet seat
+torch
+totem pole
+tow truck
+toyshop
+tractor
+trailer truck
+tray
+trench coat
+tricycle
+trimaran
+tripod
+triumphal arch
+trolleybus
+trombone
+tub
+turnstile
+typewriter keyboard
+umbrella
+unicycle
+upright
+vacuum
+vase
+vault
+velvet
+vending machine
+vestment
+viaduct
+violin
+volleyball
+waffle iron
+wall clock
+wallet
+wardrobe
+warplane
+washbasin
+washer
+water bottle
+water jug
+water tower
+whiskey jug
+whistle
+wig
+window screen
+window shade
+Windsor tie
+wine bottle
+wing
+wok
+wooden spoon
+wool
+worm fence
+wreck
+yawl
+yurt
+web site
+comic book
+crossword puzzle
+street sign
+traffic light
+book jacket
+menu
+plate
+guacamole
+consomme
+hot pot
+trifle
+ice cream
+ice lolly
+French loaf
+bagel
+pretzel
+cheeseburger
+hotdog
+mashed potato
+head cabbage
+broccoli
+cauliflower
+zucchini
+spaghetti squash
+acorn squash
+butternut squash
+cucumber
+artichoke
+bell pepper
+cardoon
+mushroom
+Granny Smith
+strawberry
+orange
+lemon
+fig
+pineapple
+banana
+jackfruit
+custard apple
+pomegranate
+hay
+carbonara
+chocolate sauce
+dough
+meat loaf
+pizza
+potpie
+burrito
+red wine
+espresso
+cup
+eggnog
+alp
+bubble
+cliff
+coral reef
+geyser
+lakeside
+promontory
+sandbar
+seashore
+valley
+volcano
+ballplayer
+groom
+scuba diver
+rapeseed
+daisy
+yellow lady's slipper
+corn
+acorn
+hip
+buckeye
+coral fungus
+agaric
+gyromitra
+stinkhorn
+earthstar
+hen-of-the-woods
+bolete
+ear
+toilet tissue
diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/AutoFitTextureView.java b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/AutoFitTextureView.java
new file mode 100644
index 0000000000..f204590659
--- /dev/null
+++ b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/AutoFitTextureView.java
@@ -0,0 +1,72 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package com.example.android.tflitecamerademo;
+
+import android.content.Context;
+import android.util.AttributeSet;
+import android.view.TextureView;
+
+/** A {@link TextureView} that can be adjusted to a specified aspect ratio. */
+public class AutoFitTextureView extends TextureView {
+
+ private int mRatioWidth = 0;
+ private int mRatioHeight = 0;
+
+ public AutoFitTextureView(Context context) {
+ this(context, null);
+ }
+
+ public AutoFitTextureView(Context context, AttributeSet attrs) {
+ this(context, attrs, 0);
+ }
+
+ public AutoFitTextureView(Context context, AttributeSet attrs, int defStyle) {
+ super(context, attrs, defStyle);
+ }
+
+ /**
+ * Sets the aspect ratio for this view. The size of the view will be measured based on the ratio
+   * calculated from the parameters. Note that the actual sizes of the parameters don't matter,
+   * that is, calling setAspectRatio(2, 3) and setAspectRatio(4, 6) produce the same result.
+ *
+ * @param width Relative horizontal size
+ * @param height Relative vertical size
+ */
+ public void setAspectRatio(int width, int height) {
+ if (width < 0 || height < 0) {
+ throw new IllegalArgumentException("Size cannot be negative.");
+ }
+ mRatioWidth = width;
+ mRatioHeight = height;
+ requestLayout();
+ }
+
+ @Override
+ protected void onMeasure(int widthMeasureSpec, int heightMeasureSpec) {
+ super.onMeasure(widthMeasureSpec, heightMeasureSpec);
+ int width = MeasureSpec.getSize(widthMeasureSpec);
+ int height = MeasureSpec.getSize(heightMeasureSpec);
+ if (0 == mRatioWidth || 0 == mRatioHeight) {
+ setMeasuredDimension(width, height);
+ } else {
+ if (width < height * mRatioWidth / mRatioHeight) {
+ setMeasuredDimension(width, width * mRatioHeight / mRatioWidth);
+ } else {
+ setMeasuredDimension(height * mRatioWidth / mRatioHeight, height);
+ }
+ }
+ }
+}
diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/Camera2BasicFragment.java b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/Camera2BasicFragment.java
new file mode 100644
index 0000000000..74737a8b88
--- /dev/null
+++ b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/Camera2BasicFragment.java
@@ -0,0 +1,708 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package com.example.android.tflitecamerademo;
+
+import android.app.Activity;
+import android.app.AlertDialog;
+import android.app.Dialog;
+import android.app.DialogFragment;
+import android.app.Fragment;
+import android.content.Context;
+import android.content.DialogInterface;
+import android.content.pm.PackageInfo;
+import android.content.pm.PackageManager;
+import android.content.res.Configuration;
+import android.graphics.Bitmap;
+import android.graphics.ImageFormat;
+import android.graphics.Matrix;
+import android.graphics.Point;
+import android.graphics.RectF;
+import android.graphics.SurfaceTexture;
+import android.hardware.camera2.CameraAccessException;
+import android.hardware.camera2.CameraCaptureSession;
+import android.hardware.camera2.CameraCharacteristics;
+import android.hardware.camera2.CameraDevice;
+import android.hardware.camera2.CameraManager;
+import android.hardware.camera2.CaptureRequest;
+import android.hardware.camera2.CaptureResult;
+import android.hardware.camera2.TotalCaptureResult;
+import android.hardware.camera2.params.StreamConfigurationMap;
+import android.media.ImageReader;
+import android.os.Bundle;
+import android.os.Handler;
+import android.os.HandlerThread;
+import android.support.annotation.NonNull;
+import android.support.v13.app.FragmentCompat;
+import android.support.v4.content.ContextCompat;
+import android.util.Log;
+import android.util.Size;
+import android.view.LayoutInflater;
+import android.view.Surface;
+import android.view.TextureView;
+import android.view.View;
+import android.view.ViewGroup;
+import android.widget.TextView;
+import android.widget.Toast;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.List;
+import java.util.concurrent.Semaphore;
+import java.util.concurrent.TimeUnit;
+
+/** Basic fragment for the camera. */
+public class Camera2BasicFragment extends Fragment
+ implements FragmentCompat.OnRequestPermissionsResultCallback {
+
+ /** Tag for the {@link Log}. */
+ private static final String TAG = "TfLiteCameraDemo";
+
+ private static final String FRAGMENT_DIALOG = "dialog";
+
+ private static final String HANDLE_THREAD_NAME = "CameraBackground";
+
+ private static final int PERMISSIONS_REQUEST_CODE = 1;
+
+ private final Object lock = new Object();
+ private boolean runClassifier = false;
+ private boolean checkedPermissions = false;
+ private TextView textView;
+ private ImageClassifier classifier;
+
+ /** Max preview width that is guaranteed by Camera2 API */
+ private static final int MAX_PREVIEW_WIDTH = 1920;
+
+ /** Max preview height that is guaranteed by Camera2 API */
+ private static final int MAX_PREVIEW_HEIGHT = 1080;
+
+ /**
+ * {@link TextureView.SurfaceTextureListener} handles several lifecycle events on a {@link
+ * TextureView}.
+ */
+ private final TextureView.SurfaceTextureListener surfaceTextureListener =
+ new TextureView.SurfaceTextureListener() {
+
+ @Override
+ public void onSurfaceTextureAvailable(SurfaceTexture texture, int width, int height) {
+ openCamera(width, height);
+ }
+
+ @Override
+ public void onSurfaceTextureSizeChanged(SurfaceTexture texture, int width, int height) {
+ configureTransform(width, height);
+ }
+
+ @Override
+ public boolean onSurfaceTextureDestroyed(SurfaceTexture texture) {
+ return true;
+ }
+
+ @Override
+ public void onSurfaceTextureUpdated(SurfaceTexture texture) {}
+ };
+
+ /** ID of the current {@link CameraDevice}. */
+ private String cameraId;
+
+ /** An {@link AutoFitTextureView} for camera preview. */
+ private AutoFitTextureView textureView;
+
+ /** A {@link CameraCaptureSession } for camera preview. */
+ private CameraCaptureSession captureSession;
+
+ /** A reference to the opened {@link CameraDevice}. */
+ private CameraDevice cameraDevice;
+
+ /** The {@link android.util.Size} of camera preview. */
+ private Size previewSize;
+
+ /** {@link CameraDevice.StateCallback} is called when {@link CameraDevice} changes its state. */
+ private final CameraDevice.StateCallback stateCallback =
+ new CameraDevice.StateCallback() {
+
+ @Override
+ public void onOpened(@NonNull CameraDevice currentCameraDevice) {
+ // This method is called when the camera is opened. We start camera preview here.
+ cameraOpenCloseLock.release();
+ cameraDevice = currentCameraDevice;
+ createCameraPreviewSession();
+ }
+
+ @Override
+ public void onDisconnected(@NonNull CameraDevice currentCameraDevice) {
+ cameraOpenCloseLock.release();
+ currentCameraDevice.close();
+ cameraDevice = null;
+ }
+
+ @Override
+ public void onError(@NonNull CameraDevice currentCameraDevice, int error) {
+ cameraOpenCloseLock.release();
+ currentCameraDevice.close();
+ cameraDevice = null;
+ Activity activity = getActivity();
+ if (null != activity) {
+ activity.finish();
+ }
+ }
+ };
+
+ /** An additional thread for running tasks that shouldn't block the UI. */
+ private HandlerThread backgroundThread;
+
+ /** A {@link Handler} for running tasks in the background. */
+ private Handler backgroundHandler;
+
+ /** An {@link ImageReader} that handles image capture. */
+ private ImageReader imageReader;
+
+ /** {@link CaptureRequest.Builder} for the camera preview */
+ private CaptureRequest.Builder previewRequestBuilder;
+
+ /** {@link CaptureRequest} generated by {@link #previewRequestBuilder} */
+ private CaptureRequest previewRequest;
+
+ /** A {@link Semaphore} to prevent the app from exiting before closing the camera. */
+ private Semaphore cameraOpenCloseLock = new Semaphore(1);
+
+ /** A {@link CameraCaptureSession.CaptureCallback} that handles events related to capture. */
+ private CameraCaptureSession.CaptureCallback captureCallback =
+ new CameraCaptureSession.CaptureCallback() {
+
+ @Override
+ public void onCaptureProgressed(
+ @NonNull CameraCaptureSession session,
+ @NonNull CaptureRequest request,
+ @NonNull CaptureResult partialResult) {}
+
+ @Override
+ public void onCaptureCompleted(
+ @NonNull CameraCaptureSession session,
+ @NonNull CaptureRequest request,
+ @NonNull TotalCaptureResult result) {}
+ };
+
+ /**
+   * Shows the classification results in the result {@link TextView} on the UI thread.
+ *
+ * @param text The message to show
+ */
+ private void showToast(final String text) {
+ final Activity activity = getActivity();
+ if (activity != null) {
+ activity.runOnUiThread(
+ new Runnable() {
+ @Override
+ public void run() {
+ textView.setText(text);
+ }
+ });
+ }
+ }
+
+ /**
+   * Chooses the optimal preview size.
+ *
+ * Attempting to use too large a preview size could exceed the camera bus' bandwidth limitation,
+ * resulting in gorgeous previews but the storage of garbage capture data.
+ *
+ * Given {@code choices} of {@code Size}s supported by a camera, choose the smallest one that is
+ * at least as large as the respective texture view size, and that is at most as large as the
+ * respective max size, and whose aspect ratio matches with the specified value. If such size
+ * doesn't exist, choose the largest one that is at most as large as the respective max size, and
+ * whose aspect ratio matches with the specified value.
+ *
+ * @param choices The list of sizes that the camera supports for the intended output class
+ * @param textureViewWidth The width of the texture view relative to sensor coordinate
+ * @param textureViewHeight The height of the texture view relative to sensor coordinate
+ * @param maxWidth The maximum width that can be chosen
+ * @param maxHeight The maximum height that can be chosen
+ * @param aspectRatio The aspect ratio
+ * @return The optimal {@code Size}, or an arbitrary one if none were big enough
+ */
+ private static Size chooseOptimalSize(
+ Size[] choices,
+ int textureViewWidth,
+ int textureViewHeight,
+ int maxWidth,
+ int maxHeight,
+ Size aspectRatio) {
+
+ // Collect the supported resolutions that are at least as big as the preview Surface
+ List<Size> bigEnough = new ArrayList<>();
+ // Collect the supported resolutions that are smaller than the preview Surface
+ List<Size> notBigEnough = new ArrayList<>();
+ int w = aspectRatio.getWidth();
+ int h = aspectRatio.getHeight();
+ for (Size option : choices) {
+ if (option.getWidth() <= maxWidth
+ && option.getHeight() <= maxHeight
+ && option.getHeight() == option.getWidth() * h / w) {
+ if (option.getWidth() >= textureViewWidth && option.getHeight() >= textureViewHeight) {
+ bigEnough.add(option);
+ } else {
+ notBigEnough.add(option);
+ }
+ }
+ }
+
+ // Pick the smallest of those big enough. If there is no one big enough, pick the
+ // largest of those not big enough.
+ if (bigEnough.size() > 0) {
+ return Collections.min(bigEnough, new CompareSizesByArea());
+ } else if (notBigEnough.size() > 0) {
+ return Collections.max(notBigEnough, new CompareSizesByArea());
+ } else {
+ Log.e(TAG, "Couldn't find any suitable preview size");
+ return choices[0];
+ }
+ }
+
+ public static Camera2BasicFragment newInstance() {
+ return new Camera2BasicFragment();
+ }
+
+ /** Layout the preview and buttons. */
+ @Override
+ public View onCreateView(
+ LayoutInflater inflater, ViewGroup container, Bundle savedInstanceState) {
+ return inflater.inflate(R.layout.fragment_camera2_basic, container, false);
+ }
+
+ /** Connect the buttons to their event handler. */
+ @Override
+ public void onViewCreated(final View view, Bundle savedInstanceState) {
+ textureView = (AutoFitTextureView) view.findViewById(R.id.texture);
+ textView = (TextView) view.findViewById(R.id.text);
+ }
+
+ /** Load the model and labels. */
+ @Override
+ public void onActivityCreated(Bundle savedInstanceState) {
+ super.onActivityCreated(savedInstanceState);
+ try {
+ classifier = new ImageClassifier(getActivity());
+ } catch (IOException e) {
+ Log.e(TAG, "Failed to initialize an image classifier.");
+ }
+ startBackgroundThread();
+ }
+
+ @Override
+ public void onResume() {
+ super.onResume();
+ startBackgroundThread();
+
+ // When the screen is turned off and turned back on, the SurfaceTexture is already
+ // available, and "onSurfaceTextureAvailable" will not be called. In that case, we can open
+ // a camera and start preview from here (otherwise, we wait until the surface is ready in
+ // the SurfaceTextureListener).
+ if (textureView.isAvailable()) {
+ openCamera(textureView.getWidth(), textureView.getHeight());
+ } else {
+ textureView.setSurfaceTextureListener(surfaceTextureListener);
+ }
+ }
+
+ @Override
+ public void onPause() {
+ closeCamera();
+ stopBackgroundThread();
+ super.onPause();
+ }
+
+ @Override
+ public void onDestroy() {
+ classifier.close();
+ super.onDestroy();
+ }
+
+ /**
+ * Sets up member variables related to camera.
+ *
+ * @param width The width of available size for camera preview
+ * @param height The height of available size for camera preview
+ */
+ private void setUpCameraOutputs(int width, int height) {
+ Activity activity = getActivity();
+ CameraManager manager = (CameraManager) activity.getSystemService(Context.CAMERA_SERVICE);
+ try {
+ for (String cameraId : manager.getCameraIdList()) {
+ CameraCharacteristics characteristics = manager.getCameraCharacteristics(cameraId);
+
+ // We don't use a front facing camera in this sample.
+ Integer facing = characteristics.get(CameraCharacteristics.LENS_FACING);
+ if (facing != null && facing == CameraCharacteristics.LENS_FACING_FRONT) {
+ continue;
+ }
+
+ StreamConfigurationMap map =
+ characteristics.get(CameraCharacteristics.SCALER_STREAM_CONFIGURATION_MAP);
+ if (map == null) {
+ continue;
+ }
+
+        // For still image captures, we use the largest available size.
+ Size largest =
+ Collections.max(
+ Arrays.asList(map.getOutputSizes(ImageFormat.JPEG)), new CompareSizesByArea());
+ imageReader =
+ ImageReader.newInstance(
+ largest.getWidth(), largest.getHeight(), ImageFormat.JPEG, /*maxImages*/ 2);
+
+ // Find out if we need to swap dimension to get the preview size relative to sensor
+ // coordinate.
+ int displayRotation = activity.getWindowManager().getDefaultDisplay().getRotation();
+ // noinspection ConstantConditions
+ /* Orientation of the camera sensor */
+ int sensorOrientation = characteristics.get(CameraCharacteristics.SENSOR_ORIENTATION);
+ boolean swappedDimensions = false;
+ switch (displayRotation) {
+ case Surface.ROTATION_0:
+ case Surface.ROTATION_180:
+ if (sensorOrientation == 90 || sensorOrientation == 270) {
+ swappedDimensions = true;
+ }
+ break;
+ case Surface.ROTATION_90:
+ case Surface.ROTATION_270:
+ if (sensorOrientation == 0 || sensorOrientation == 180) {
+ swappedDimensions = true;
+ }
+ break;
+ default:
+ Log.e(TAG, "Display rotation is invalid: " + displayRotation);
+ }
+
+ Point displaySize = new Point();
+ activity.getWindowManager().getDefaultDisplay().getSize(displaySize);
+ int rotatedPreviewWidth = width;
+ int rotatedPreviewHeight = height;
+ int maxPreviewWidth = displaySize.x;
+ int maxPreviewHeight = displaySize.y;
+
+ if (swappedDimensions) {
+ rotatedPreviewWidth = height;
+ rotatedPreviewHeight = width;
+ maxPreviewWidth = displaySize.y;
+ maxPreviewHeight = displaySize.x;
+ }
+
+ if (maxPreviewWidth > MAX_PREVIEW_WIDTH) {
+ maxPreviewWidth = MAX_PREVIEW_WIDTH;
+ }
+
+ if (maxPreviewHeight > MAX_PREVIEW_HEIGHT) {
+ maxPreviewHeight = MAX_PREVIEW_HEIGHT;
+ }
+
+ previewSize =
+ chooseOptimalSize(
+ map.getOutputSizes(SurfaceTexture.class),
+ rotatedPreviewWidth,
+ rotatedPreviewHeight,
+ maxPreviewWidth,
+ maxPreviewHeight,
+ largest);
+
+ // We fit the aspect ratio of TextureView to the size of preview we picked.
+ int orientation = getResources().getConfiguration().orientation;
+ if (orientation == Configuration.ORIENTATION_LANDSCAPE) {
+ textureView.setAspectRatio(previewSize.getWidth(), previewSize.getHeight());
+ } else {
+ textureView.setAspectRatio(previewSize.getHeight(), previewSize.getWidth());
+ }
+
+ this.cameraId = cameraId;
+ return;
+ }
+ } catch (CameraAccessException e) {
+ e.printStackTrace();
+ } catch (NullPointerException e) {
+ // Currently an NPE is thrown when the Camera2API is used but not supported on the
+ // device this code runs.
+ ErrorDialog.newInstance(getString(R.string.camera_error))
+ .show(getChildFragmentManager(), FRAGMENT_DIALOG);
+ }
+ }
+
+ private String[] getRequiredPermissions() {
+ Activity activity = getActivity();
+ try {
+ PackageInfo info =
+ activity
+ .getPackageManager()
+ .getPackageInfo(activity.getPackageName(), PackageManager.GET_PERMISSIONS);
+ String[] ps = info.requestedPermissions;
+ if (ps != null && ps.length > 0) {
+ return ps;
+ } else {
+ return new String[0];
+ }
+ } catch (Exception e) {
+ return new String[0];
+ }
+ }
+
+ /** Opens the camera specified by {@link Camera2BasicFragment#cameraId}. */
+ private void openCamera(int width, int height) {
+ if (!checkedPermissions && !allPermissionsGranted()) {
+ FragmentCompat.requestPermissions(this, getRequiredPermissions(), PERMISSIONS_REQUEST_CODE);
+ return;
+ } else {
+ checkedPermissions = true;
+ }
+ setUpCameraOutputs(width, height);
+ configureTransform(width, height);
+ Activity activity = getActivity();
+ CameraManager manager = (CameraManager) activity.getSystemService(Context.CAMERA_SERVICE);
+ try {
+ if (!cameraOpenCloseLock.tryAcquire(2500, TimeUnit.MILLISECONDS)) {
+ throw new RuntimeException("Time out waiting to lock camera opening.");
+ }
+ manager.openCamera(cameraId, stateCallback, backgroundHandler);
+ } catch (CameraAccessException e) {
+ e.printStackTrace();
+ } catch (InterruptedException e) {
+ throw new RuntimeException("Interrupted while trying to lock camera opening.", e);
+ }
+ }
+
+ private boolean allPermissionsGranted() {
+ for (String permission : getRequiredPermissions()) {
+ if (ContextCompat.checkSelfPermission(getActivity(), permission)
+ != PackageManager.PERMISSION_GRANTED) {
+ return false;
+ }
+ }
+ return true;
+ }
+
+ @Override
+ public void onRequestPermissionsResult(
+ int requestCode, @NonNull String[] permissions, @NonNull int[] grantResults) {
+ super.onRequestPermissionsResult(requestCode, permissions, grantResults);
+ }
+
+ /** Closes the current {@link CameraDevice}. */
+ private void closeCamera() {
+ try {
+ cameraOpenCloseLock.acquire();
+ if (null != captureSession) {
+ captureSession.close();
+ captureSession = null;
+ }
+ if (null != cameraDevice) {
+ cameraDevice.close();
+ cameraDevice = null;
+ }
+ if (null != imageReader) {
+ imageReader.close();
+ imageReader = null;
+ }
+ } catch (InterruptedException e) {
+ throw new RuntimeException("Interrupted while trying to lock camera closing.", e);
+ } finally {
+ cameraOpenCloseLock.release();
+ }
+ }
+
+ /** Starts a background thread and its {@link Handler}. */
+ private void startBackgroundThread() {
+ backgroundThread = new HandlerThread(HANDLE_THREAD_NAME);
+ backgroundThread.start();
+ backgroundHandler = new Handler(backgroundThread.getLooper());
+ synchronized (lock) {
+ runClassifier = true;
+ }
+ backgroundHandler.post(periodicClassify);
+ }
+
+ /** Stops the background thread and its {@link Handler}. */
+ private void stopBackgroundThread() {
+ backgroundThread.quitSafely();
+ try {
+ backgroundThread.join();
+ backgroundThread = null;
+ backgroundHandler = null;
+ synchronized (lock) {
+ runClassifier = false;
+ }
+ } catch (InterruptedException e) {
+ e.printStackTrace();
+ }
+ }
+
+  /** Takes photos and classifies them periodically. */
+ private Runnable periodicClassify =
+ new Runnable() {
+ @Override
+ public void run() {
+ synchronized (lock) {
+ if (runClassifier) {
+ classifyFrame();
+ }
+ }
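+          // Re-post this runnable so frames keep being classified for as long
+          // as the background thread is running.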
+ backgroundHandler.post(periodicClassify);
+ }
+ };
+
+ /** Creates a new {@link CameraCaptureSession} for camera preview. */
+ private void createCameraPreviewSession() {
+ try {
+ SurfaceTexture texture = textureView.getSurfaceTexture();
+ assert texture != null;
+
+ // We configure the size of default buffer to be the size of camera preview we want.
+ texture.setDefaultBufferSize(previewSize.getWidth(), previewSize.getHeight());
+
+ // This is the output Surface we need to start preview.
+ Surface surface = new Surface(texture);
+
+ // We set up a CaptureRequest.Builder with the output Surface.
+ previewRequestBuilder = cameraDevice.createCaptureRequest(CameraDevice.TEMPLATE_PREVIEW);
+ previewRequestBuilder.addTarget(surface);
+
+ // Here, we create a CameraCaptureSession for camera preview.
+ cameraDevice.createCaptureSession(
+ Arrays.asList(surface),
+ new CameraCaptureSession.StateCallback() {
+
+ @Override
+ public void onConfigured(@NonNull CameraCaptureSession cameraCaptureSession) {
+ // The camera is already closed
+ if (null == cameraDevice) {
+ return;
+ }
+
+ // When the session is ready, we start displaying the preview.
+ captureSession = cameraCaptureSession;
+ try {
+ // Auto focus should be continuous for camera preview.
+ previewRequestBuilder.set(
+ CaptureRequest.CONTROL_AF_MODE,
+ CaptureRequest.CONTROL_AF_MODE_CONTINUOUS_PICTURE);
+
+ // Finally, we start displaying the camera preview.
+ previewRequest = previewRequestBuilder.build();
+ captureSession.setRepeatingRequest(
+ previewRequest, captureCallback, backgroundHandler);
+ } catch (CameraAccessException e) {
+ e.printStackTrace();
+ }
+ }
+
+ @Override
+ public void onConfigureFailed(@NonNull CameraCaptureSession cameraCaptureSession) {
+ showToast("Failed");
+ }
+ },
+ null);
+ } catch (CameraAccessException e) {
+ e.printStackTrace();
+ }
+ }
+
+ /**
+ * Configures the necessary {@link android.graphics.Matrix} transformation to `textureView`. This
+ * method should be called after the camera preview size is determined in setUpCameraOutputs and
+ * also the size of `textureView` is fixed.
+ *
+ * @param viewWidth The width of `textureView`
+ * @param viewHeight The height of `textureView`
+ */
+ private void configureTransform(int viewWidth, int viewHeight) {
+ Activity activity = getActivity();
+ if (null == textureView || null == previewSize || null == activity) {
+ return;
+ }
+ int rotation = activity.getWindowManager().getDefaultDisplay().getRotation();
+ Matrix matrix = new Matrix();
+ RectF viewRect = new RectF(0, 0, viewWidth, viewHeight);
+ RectF bufferRect = new RectF(0, 0, previewSize.getHeight(), previewSize.getWidth());
+ float centerX = viewRect.centerX();
+ float centerY = viewRect.centerY();
+ if (Surface.ROTATION_90 == rotation || Surface.ROTATION_270 == rotation) {
+ bufferRect.offset(centerX - bufferRect.centerX(), centerY - bufferRect.centerY());
+ matrix.setRectToRect(viewRect, bufferRect, Matrix.ScaleToFit.FILL);
+ float scale =
+ Math.max(
+ (float) viewHeight / previewSize.getHeight(),
+ (float) viewWidth / previewSize.getWidth());
+ matrix.postScale(scale, scale, centerX, centerY);
+ matrix.postRotate(90 * (rotation - 2), centerX, centerY);
+ } else if (Surface.ROTATION_180 == rotation) {
+ matrix.postRotate(180, centerX, centerY);
+ }
+ textureView.setTransform(matrix);
+ }
+
+ /** Classifies a frame from the preview stream. */
+ private void classifyFrame() {
+ if (classifier == null || getActivity() == null || cameraDevice == null) {
+ showToast("Uninitialized Classifier or invalid context.");
+ return;
+ }
+ Bitmap bitmap =
+ textureView.getBitmap(ImageClassifier.DIM_IMG_SIZE_X, ImageClassifier.DIM_IMG_SIZE_Y);
+ String textToShow = classifier.classifyFrame(bitmap);
+ bitmap.recycle();
+ showToast(textToShow);
+ }
+
+ /** Compares two {@code Size}s based on their areas. */
+ private static class CompareSizesByArea implements Comparator<Size> {
+
+ @Override
+ public int compare(Size lhs, Size rhs) {
+ // We cast here to ensure the multiplications won't overflow
+ return Long.signum(
+ (long) lhs.getWidth() * lhs.getHeight() - (long) rhs.getWidth() * rhs.getHeight());
+ }
+ }
+
+ /** Shows an error message dialog. */
+ public static class ErrorDialog extends DialogFragment {
+
+ private static final String ARG_MESSAGE = "message";
+
+ public static ErrorDialog newInstance(String message) {
+ ErrorDialog dialog = new ErrorDialog();
+ Bundle args = new Bundle();
+ args.putString(ARG_MESSAGE, message);
+ dialog.setArguments(args);
+ return dialog;
+ }
+
+ @Override
+ public Dialog onCreateDialog(Bundle savedInstanceState) {
+ final Activity activity = getActivity();
+ return new AlertDialog.Builder(activity)
+ .setMessage(getArguments().getString(ARG_MESSAGE))
+ .setPositiveButton(
+ android.R.string.ok,
+ new DialogInterface.OnClickListener() {
+ @Override
+ public void onClick(DialogInterface dialogInterface, int i) {
+ activity.finish();
+ }
+ })
+ .create();
+ }
+ }
+}
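Note for readers of this hunk: the helpers above (startBackgroundThread, stopBackgroundThread, createCameraPreviewSession) are driven from the fragment lifecycle in the portion of Camera2BasicFragment that is not shown here. A minimal sketch of that wiring, assuming Camera2Basic-style onResume/onPause members of this fragment and a closeCamera() helper (inferred from the cleanup code at the top of this hunk, not confirmed by it):

  @Override
  public void onResume() {
    super.onResume();
    startBackgroundThread();   // spins up the HandlerThread and schedules periodicClassify
    // ... open the camera once the TextureView surface is available ...
  }

  @Override
  public void onPause() {
    closeCamera();             // assumed helper: releases captureSession, cameraDevice and imageReader
    stopBackgroundThread();    // quits the HandlerThread and clears runClassifier
    super.onPause();
  }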
diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/CameraActivity.java b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/CameraActivity.java
new file mode 100644
index 0000000000..e7161ddb26
--- /dev/null
+++ b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/CameraActivity.java
@@ -0,0 +1,35 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package com.example.android.tflitecamerademo;
+
+import android.app.Activity;
+import android.os.Bundle;
+
+/** Main {@code Activity} class for the Camera app. */
+public class CameraActivity extends Activity {
+
+ @Override
+ protected void onCreate(Bundle savedInstanceState) {
+ super.onCreate(savedInstanceState);
+ setContentView(R.layout.activity_camera);
+ if (null == savedInstanceState) {
+ getFragmentManager()
+ .beginTransaction()
+ .replace(R.id.container, Camera2BasicFragment.newInstance())
+ .commit();
+ }
+ }
+}
diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifier.java b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifier.java
new file mode 100644
index 0000000000..e7bad46370
--- /dev/null
+++ b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifier.java
@@ -0,0 +1,184 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package com.example.android.tflitecamerademo;
+
+import android.app.Activity;
+import android.content.res.AssetFileDescriptor;
+import android.graphics.Bitmap;
+import android.os.SystemClock;
+import android.util.Log;
+import java.io.BufferedReader;
+import java.io.FileInputStream;
+import java.io.IOException;
+import java.io.InputStreamReader;
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+import java.nio.MappedByteBuffer;
+import java.nio.channels.FileChannel;
+import java.util.AbstractMap;
+import java.util.ArrayList;
+import java.util.Comparator;
+import java.util.List;
+import java.util.Map;
+import java.util.PriorityQueue;
+import org.tensorflow.lite.Interpreter;
+
+/** Classifies images with TensorFlow Lite. */
+public class ImageClassifier {
+
+ /** Tag for the {@link Log}. */
+ private static final String TAG = "TfLiteCameraDemo";
+
+ /** Name of the model file stored in Assets. */
+ private static final String MODEL_PATH = "mobilenet_quant_v1_224.tflite";
+
+ /** Name of the label file stored in Assets. */
+ private static final String LABEL_PATH = "labels.txt";
+
+ /** Number of results to show in the UI. */
+ private static final int RESULTS_TO_SHOW = 3;
+
+ /** Dimensions of inputs. */
+ private static final int DIM_BATCH_SIZE = 1;
+
+ private static final int DIM_PIXEL_SIZE = 3;
+
+ static final int DIM_IMG_SIZE_X = 224;
+ static final int DIM_IMG_SIZE_Y = 224;
+
+ /** Preallocated buffers for storing image data. */
+ private int[] intValues = new int[DIM_IMG_SIZE_X * DIM_IMG_SIZE_Y];
+
+ /** An instance of the driver class to run model inference with TensorFlow Lite. */
+ private Interpreter tflite;
+
+ /** Labels corresponding to the output of the vision model. */
+ private List<String> labelList;
+
+ /** A ByteBuffer to hold image data, to be fed into TensorFlow Lite as inputs. */
+ private ByteBuffer imgData = null;
+
+ /** An array to hold inference results, to be fed into TensorFlow Lite as outputs. */
+ private byte[][] labelProbArray = null;
+
+ private PriorityQueue<Map.Entry<String, Float>> sortedLabels =
+ new PriorityQueue<>(
+ RESULTS_TO_SHOW,
+ new Comparator<Map.Entry<String, Float>>() {
+ @Override
+ public int compare(Map.Entry<String, Float> o1, Map.Entry<String, Float> o2) {
+ return (o1.getValue()).compareTo(o2.getValue());
+ }
+ });
+
+ /** Initializes an {@code ImageClassifier}. */
+ ImageClassifier(Activity activity) throws IOException {
+ tflite = new Interpreter(loadModelFile(activity));
+ labelList = loadLabelList(activity);
+ imgData =
+ ByteBuffer.allocateDirect(
+ DIM_BATCH_SIZE * DIM_IMG_SIZE_X * DIM_IMG_SIZE_Y * DIM_PIXEL_SIZE);
+ imgData.order(ByteOrder.nativeOrder());
+ labelProbArray = new byte[1][labelList.size()];
+ Log.d(TAG, "Created a Tensorflow Lite Image Classifier.");
+ }
+
+ /** Classifies a frame from the preview stream. */
+ String classifyFrame(Bitmap bitmap) {
+ if (tflite == null) {
+ Log.e(TAG, "Image classifier has not been initialized; Skipped.");
+ return "Uninitialized Classifier.";
+ }
+ convertBitmapToByteBuffer(bitmap);
+ // Run inference with TensorFlow Lite and time it.
+ long startTime = SystemClock.uptimeMillis();
+ tflite.run(imgData, labelProbArray);
+ long endTime = SystemClock.uptimeMillis();
+ Log.d(TAG, "Timecost to run model inference: " + Long.toString(endTime - startTime));
+ String textToShow = printTopKLabels();
+ textToShow = Long.toString(endTime - startTime) + "ms" + textToShow;
+ return textToShow;
+ }
+
+ /** Closes tflite to release resources. */
+ public void close() {
+ tflite.close();
+ tflite = null;
+ }
+
+ /** Reads label list from Assets. */
+ private List<String> loadLabelList(Activity activity) throws IOException {
+ List<String> labelList = new ArrayList<String>();
+ BufferedReader reader =
+ new BufferedReader(new InputStreamReader(activity.getAssets().open(LABEL_PATH)));
+ String line;
+ while ((line = reader.readLine()) != null) {
+ labelList.add(line);
+ }
+ reader.close();
+ return labelList;
+ }
+
+ /** Memory-maps the model file in Assets. */
+ private MappedByteBuffer loadModelFile(Activity activity) throws IOException {
+ AssetFileDescriptor fileDescriptor = activity.getAssets().openFd(MODEL_PATH);
+ FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
+ FileChannel fileChannel = inputStream.getChannel();
+ long startOffset = fileDescriptor.getStartOffset();
+ long declaredLength = fileDescriptor.getDeclaredLength();
+ return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
+ }
+
+ /** Writes image data into a {@code ByteBuffer}. */
+ private void convertBitmapToByteBuffer(Bitmap bitmap) {
+ if (imgData == null) {
+ return;
+ }
+ imgData.rewind();
+ bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());
+ // Pack the RGB bytes of each pixel into the input buffer (quantized model, one byte per channel).
+ int pixel = 0;
+ long startTime = SystemClock.uptimeMillis();
+ for (int i = 0; i < DIM_IMG_SIZE_X; ++i) {
+ for (int j = 0; j < DIM_IMG_SIZE_Y; ++j) {
+ final int val = intValues[pixel++];
+ imgData.put((byte) ((val >> 16) & 0xFF));
+ imgData.put((byte) ((val >> 8) & 0xFF));
+ imgData.put((byte) (val & 0xFF));
+ }
+ }
+ long endTime = SystemClock.uptimeMillis();
+ Log.d(TAG, "Timecost to put values into ByteBuffer: " + Long.toString(endTime - startTime));
+ }
+
+ /** Returns the top-K labels as a string, to be shown in the UI as the results. */
+ private String printTopKLabels() {
+ for (int i = 0; i < labelList.size(); ++i) {
+ sortedLabels.add(
+ new AbstractMap.SimpleEntry<>(labelList.get(i), (labelProbArray[0][i] & 0xff) / 255.0f));
+ if (sortedLabels.size() > RESULTS_TO_SHOW) {
+ sortedLabels.poll();
+ }
+ }
+ String textToShow = "";
+ final int size = sortedLabels.size();
+ for (int i = 0; i < size; ++i) {
+ Map.Entry<String, Float> label = sortedLabels.poll();
+ textToShow = "\n" + label.getKey() + ":" + Float.toString(label.getValue()) + textToShow;
+ }
+ return textToShow;
+ }
+}
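For orientation, the intended call pattern for ImageClassifier is: construct it once with an Activity (this memory-maps mobilenet_quant_v1_224.tflite and loads labels.txt from assets), feed it 224x224 bitmaps, and close it when done. A minimal sketch; the driver class below is hypothetical and must live in the same package because the constructor and classifyFrame are package-private:

  package com.example.android.tflitecamerademo;

  import android.app.Activity;
  import android.graphics.Bitmap;
  import java.io.IOException;

  class ClassifierDriver {  // hypothetical, for illustration only
    private ImageClassifier classifier;

    void start(Activity activity) throws IOException {
      // Loads the model and label list from the app's assets.
      classifier = new ImageClassifier(activity);
    }

    String classify(Bitmap frame) {
      // frame must be DIM_IMG_SIZE_X x DIM_IMG_SIZE_Y (224x224) pixels.
      return classifier.classifyFrame(frame);
    }

    void stop() {
      classifier.close();  // releases the underlying TF Lite interpreter
      classifier = null;
    }
  }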
diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-hdpi/ic_action_info.png b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-hdpi/ic_action_info.png
new file mode 100644
index 0000000000..e0a70008b1
--- /dev/null
+++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-hdpi/ic_action_info.png
Binary files differ
diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-hdpi/ic_launcher.png b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-hdpi/ic_launcher.png
new file mode 100644
index 0000000000..c22509d8df
--- /dev/null
+++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-hdpi/ic_launcher.png
Binary files differ
diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-hdpi/tile.9.png b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-hdpi/tile.9.png
new file mode 100644
index 0000000000..a84e3ef52c
--- /dev/null
+++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-hdpi/tile.9.png
Binary files differ
diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-mdpi/ic_action_info.png b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-mdpi/ic_action_info.png
new file mode 100644
index 0000000000..520c2dd100
--- /dev/null
+++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-mdpi/ic_action_info.png
Binary files differ
diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-mdpi/ic_launcher.png b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-mdpi/ic_launcher.png
new file mode 100644
index 0000000000..d68af39186
--- /dev/null
+++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-mdpi/ic_launcher.png
Binary files differ
diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xhdpi/ic_action_info.png b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xhdpi/ic_action_info.png
new file mode 100644
index 0000000000..1347b09198
--- /dev/null
+++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xhdpi/ic_action_info.png
Binary files differ
diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xhdpi/ic_launcher.png b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xhdpi/ic_launcher.png
new file mode 100644
index 0000000000..15e419b7cc
--- /dev/null
+++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xhdpi/ic_launcher.png
Binary files differ
diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xxhdpi/ic_action_info.png b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xxhdpi/ic_action_info.png
new file mode 100644
index 0000000000..fd933333b7
--- /dev/null
+++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xxhdpi/ic_action_info.png
Binary files differ
diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xxhdpi/ic_launcher.png b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xxhdpi/ic_launcher.png
new file mode 100644
index 0000000000..342ce34e16
--- /dev/null
+++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xxhdpi/ic_launcher.png
Binary files differ
diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/layout-land/fragment_camera2_basic.xml b/tensorflow/contrib/lite/java/demo/app/src/main/res/layout-land/fragment_camera2_basic.xml
new file mode 100644
index 0000000000..a84f1bbfa0
--- /dev/null
+++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/layout-land/fragment_camera2_basic.xml
@@ -0,0 +1,50 @@
+<?xml version="1.0" encoding="utf-8"?><!--
+ Copyright 2014 The Android Open Source Project
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+-->
+<RelativeLayout xmlns:android="http://schemas.android.com/apk/res/android"
+ android:layout_width="match_parent"
+ android:layout_height="match_parent">
+
+ <com.example.android.tflitecamerademo.AutoFitTextureView
+ android:id="@+id/texture"
+ android:layout_width="wrap_content"
+ android:layout_height="wrap_content"
+ android:layout_alignParentBottom="true"
+ android:layout_alignParentStart="true"
+ android:layout_alignParentTop="true" />
+
+ <FrameLayout
+ android:id="@+id/control"
+ android:layout_width="match_parent"
+ android:layout_height="wrap_content"
+ android:layout_alignParentBottom="true"
+ android:layout_alignParentEnd="true"
+ android:layout_alignParentTop="true"
+ android:layout_toRightOf="@id/texture"
+ android:background="@color/control_background"
+ android:orientation="horizontal">
+
+ <TextView android:id="@+id/text"
+ android:layout_width="wrap_content"
+ android:layout_height="wrap_content"
+ android:paddingTop="20dp"
+ android:textColor="#FFF"
+ android:textSize="20sp"
+ android:textStyle="bold" />
+
+
+ </FrameLayout>
+
+</RelativeLayout>
diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/layout/activity_camera.xml b/tensorflow/contrib/lite/java/demo/app/src/main/res/layout/activity_camera.xml
new file mode 100644
index 0000000000..286e549c65
--- /dev/null
+++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/layout/activity_camera.xml
@@ -0,0 +1,22 @@
+<?xml version="1.0" encoding="utf-8"?><!--
+ Copyright 2014 The Android Open Source Project
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+-->
+<FrameLayout xmlns:android="http://schemas.android.com/apk/res/android"
+ xmlns:tools="http://schemas.android.com/tools"
+ android:id="@+id/container"
+ android:layout_width="match_parent"
+ android:layout_height="match_parent"
+ android:background="#000"
+ tools:context="com.example.android.tflitecamerademo.CameraActivity" />
diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/layout/fragment_camera2_basic.xml b/tensorflow/contrib/lite/java/demo/app/src/main/res/layout/fragment_camera2_basic.xml
new file mode 100644
index 0000000000..15305c436e
--- /dev/null
+++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/layout/fragment_camera2_basic.xml
@@ -0,0 +1,45 @@
+<?xml version="1.0" encoding="utf-8"?><!--
+ Copyright 2014 The Android Open Source Project
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+-->
+<RelativeLayout xmlns:android="http://schemas.android.com/apk/res/android"
+ android:layout_width="match_parent"
+ android:layout_height="match_parent">
+
+ <com.example.android.tflitecamerademo.AutoFitTextureView
+ android:id="@+id/texture"
+ android:layout_width="wrap_content"
+ android:layout_height="wrap_content"
+ android:layout_alignParentStart="true"
+ android:layout_alignParentTop="true" />
+
+ <FrameLayout
+ android:id="@+id/control"
+ android:layout_width="match_parent"
+ android:layout_height="112dp"
+ android:layout_alignParentBottom="true"
+ android:layout_alignParentStart="true"
+ android:background="@color/control_background">
+
+ <TextView android:id="@+id/text"
+ android:layout_width="wrap_content"
+ android:layout_height="wrap_content"
+ android:paddingLeft="80dp"
+ android:textColor="#FFF"
+ android:textSize="20sp"
+ android:textStyle="bold" />
+
+ </FrameLayout>
+
+</RelativeLayout>
diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/values-sw600dp/template-dimens.xml b/tensorflow/contrib/lite/java/demo/app/src/main/res/values-sw600dp/template-dimens.xml
new file mode 100644
index 0000000000..22074a2bdb
--- /dev/null
+++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/values-sw600dp/template-dimens.xml
@@ -0,0 +1,24 @@
+<!--
+ Copyright 2013 The Android Open Source Project
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ -->
+
+<resources>
+
+ <!-- Semantic definitions -->
+
+ <dimen name="horizontal_page_margin">@dimen/margin_huge</dimen>
+ <dimen name="vertical_page_margin">@dimen/margin_medium</dimen>
+
+</resources>
diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/values-sw600dp/template-styles.xml b/tensorflow/contrib/lite/java/demo/app/src/main/res/values-sw600dp/template-styles.xml
new file mode 100644
index 0000000000..03d1974183
--- /dev/null
+++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/values-sw600dp/template-styles.xml
@@ -0,0 +1,25 @@
+<!--
+ Copyright 2013 The Android Open Source Project
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ -->
+
+<resources>
+
+ <style name="Widget.SampleMessage">
+ <item name="android:textAppearance">?android:textAppearanceLarge</item>
+ <item name="android:lineSpacingMultiplier">1.2</item>
+ <item name="android:shadowDy">-6.5</item>
+ </style>
+
+</resources>
diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/values-v11/template-styles.xml b/tensorflow/contrib/lite/java/demo/app/src/main/res/values-v11/template-styles.xml
new file mode 100644
index 0000000000..8c1ea66f28
--- /dev/null
+++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/values-v11/template-styles.xml
@@ -0,0 +1,22 @@
+<!--
+ Copyright 2013 The Android Open Source Project
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ -->
+
+<resources>
+
+ <!-- Activity themes -->
+ <style name="Theme.Base" parent="android:Theme.Holo.Light" />
+
+</resources>
diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/values-v21/base-colors.xml b/tensorflow/contrib/lite/java/demo/app/src/main/res/values-v21/base-colors.xml
new file mode 100644
index 0000000000..8b6ec3f85d
--- /dev/null
+++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/values-v21/base-colors.xml
@@ -0,0 +1,21 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!--
+ Copyright 2013 The Android Open Source Project
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+-->
+
+<resources>
+
+
+</resources>
diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/values-v21/base-template-styles.xml b/tensorflow/contrib/lite/java/demo/app/src/main/res/values-v21/base-template-styles.xml
new file mode 100644
index 0000000000..c778e4f98a
--- /dev/null
+++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/values-v21/base-template-styles.xml
@@ -0,0 +1,24 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!--
+ Copyright 2013 The Android Open Source Project
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+-->
+
+<resources>
+
+ <!-- Activity themes -->
+ <style name="Theme.Base" parent="android:Theme.Material.Light">
+ </style>
+
+</resources>
diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/values/base-strings.xml b/tensorflow/contrib/lite/java/demo/app/src/main/res/values/base-strings.xml
new file mode 100644
index 0000000000..ab7d3fd496
--- /dev/null
+++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/values/base-strings.xml
@@ -0,0 +1,30 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!--
+ Copyright 2013 The Android Open Source Project
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+-->
+
+<resources>
+ <string name="app_name">TfLiteCameraDemo</string>
+ <string name="intro_message">
+ <![CDATA[
+
+
+ This sample demonstrates the basic use of the TfLite API. Check the source code to see how
+ you can use TfLite for efficient, on-device inference with trained TensorFlow models.
+
+
+ ]]>
+ </string>
+</resources>
diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/values/colors.xml b/tensorflow/contrib/lite/java/demo/app/src/main/res/values/colors.xml
new file mode 100644
index 0000000000..4b75d2b2bd
--- /dev/null
+++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/values/colors.xml
@@ -0,0 +1,19 @@
+<?xml version="1.0" encoding="utf-8"?>
+<!--
+ Copyright 2015 The Android Open Source Project
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+-->
+<resources>
+ <color name="control_background">#cc4285f4</color>
+</resources>
diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/values/strings.xml b/tensorflow/contrib/lite/java/demo/app/src/main/res/values/strings.xml
new file mode 100644
index 0000000000..a08ec3eb62
--- /dev/null
+++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/values/strings.xml
@@ -0,0 +1,24 @@
+<?xml version="1.0" encoding="utf-8"?><!--
+ Copyright 2014 The Android Open Source Project
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+-->
+<resources>
+ <string name="picture">Picture</string>
+ <string name="description_info">Info</string>
+ <string name="request_permission">This sample needs camera permission.</string>
+ <string name="camera_error">This device doesn\'t support Camera2 API.</string>
+ <string name="toggle_turn_on">NN:On</string>
+ <string name="toggle_turn_off">NN:Off</string>
+ <string name="toggle">Use NNAPI</string>
+</resources>
diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/values/styles.xml b/tensorflow/contrib/lite/java/demo/app/src/main/res/values/styles.xml
new file mode 100644
index 0000000000..3f3bdfb494
--- /dev/null
+++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/values/styles.xml
@@ -0,0 +1,18 @@
+<?xml version="1.0" encoding="utf-8"?><!--
+ Copyright 2014 The Android Open Source Project
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+-->
+<resources>
+ <style name="MaterialTheme" parent="android:Theme.Material.Light.NoActionBar.Fullscreen" />
+</resources>
diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/values/template-dimens.xml b/tensorflow/contrib/lite/java/demo/app/src/main/res/values/template-dimens.xml
new file mode 100644
index 0000000000..39e710b5ca
--- /dev/null
+++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/values/template-dimens.xml
@@ -0,0 +1,32 @@
+<!--
+ Copyright 2013 The Android Open Source Project
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ -->
+
+<resources>
+
+ <!-- Define standard dimensions to comply with Holo-style grids and rhythm. -->
+
+ <dimen name="margin_tiny">4dp</dimen>
+ <dimen name="margin_small">8dp</dimen>
+ <dimen name="margin_medium">16dp</dimen>
+ <dimen name="margin_large">32dp</dimen>
+ <dimen name="margin_huge">64dp</dimen>
+
+ <!-- Semantic definitions -->
+
+ <dimen name="horizontal_page_margin">@dimen/margin_medium</dimen>
+ <dimen name="vertical_page_margin">@dimen/margin_medium</dimen>
+
+</resources>
diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/values/template-styles.xml b/tensorflow/contrib/lite/java/demo/app/src/main/res/values/template-styles.xml
new file mode 100644
index 0000000000..6e7d593dd8
--- /dev/null
+++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/values/template-styles.xml
@@ -0,0 +1,42 @@
+<!--
+ Copyright 2013 The Android Open Source Project
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ -->
+
+<resources>
+
+ <!-- Activity themes -->
+
+ <style name="Theme.Base" parent="android:Theme.Light" />
+
+ <style name="Theme.Sample" parent="Theme.Base" />
+
+ <style name="AppTheme" parent="Theme.Sample" />
+ <!-- Widget styling -->
+
+ <style name="Widget" />
+
+ <style name="Widget.SampleMessage">
+ <item name="android:textAppearance">?android:textAppearanceMedium</item>
+ <item name="android:lineSpacingMultiplier">1.1</item>
+ </style>
+
+ <style name="Widget.SampleMessageTile">
+ <item name="android:background">@drawable/tile</item>
+ <item name="android:shadowColor">#7F000000</item>
+ <item name="android:shadowDy">-3.5</item>
+ <item name="android:shadowRadius">2</item>
+ </style>
+
+</resources>
diff --git a/tensorflow/contrib/lite/java/demo/build.gradle b/tensorflow/contrib/lite/java/demo/build.gradle
new file mode 100644
index 0000000000..b78a0b86c9
--- /dev/null
+++ b/tensorflow/contrib/lite/java/demo/build.gradle
@@ -0,0 +1,23 @@
+// Top-level build file where you can add configuration options common to all sub-projects/modules.
+
+buildscript {
+ repositories {
+ jcenter()
+ }
+ dependencies {
+ classpath 'com.android.tools.build:gradle:2.3.1'
+
+ // NOTE: Do not place your application dependencies here; they belong
+ // in the individual module build.gradle files
+ }
+}
+
+allprojects {
+ repositories {
+ jcenter()
+ }
+}
+
+task clean(type: Delete) {
+ delete rootProject.buildDir
+}
diff --git a/tensorflow/contrib/lite/java/demo/gradle.properties b/tensorflow/contrib/lite/java/demo/gradle.properties
new file mode 100644
index 0000000000..aac7c9b461
--- /dev/null
+++ b/tensorflow/contrib/lite/java/demo/gradle.properties
@@ -0,0 +1,17 @@
+# Project-wide Gradle settings.
+
+# IDE (e.g. Android Studio) users:
+# Gradle settings configured through the IDE *will override*
+# any settings specified in this file.
+
+# For more details on how to configure your build environment visit
+# http://www.gradle.org/docs/current/userguide/build_environment.html
+
+# Specifies the JVM arguments used for the daemon process.
+# The setting is particularly useful for tweaking memory settings.
+org.gradle.jvmargs=-Xmx1536m
+
+# When configured, Gradle will run in incubating parallel mode.
+# This option should only be used with decoupled projects. For more details, visit
+# http://www.gradle.org/docs/current/userguide/multi_project_builds.html#sec:decoupled_projects
+# org.gradle.parallel=true
diff --git a/tensorflow/contrib/lite/java/demo/gradle/wrapper/gradle-wrapper.jar b/tensorflow/contrib/lite/java/demo/gradle/wrapper/gradle-wrapper.jar
new file mode 100644
index 0000000000..13372aef5e
--- /dev/null
+++ b/tensorflow/contrib/lite/java/demo/gradle/wrapper/gradle-wrapper.jar
Binary files differ
diff --git a/tensorflow/contrib/lite/java/demo/gradle/wrapper/gradle-wrapper.properties b/tensorflow/contrib/lite/java/demo/gradle/wrapper/gradle-wrapper.properties
new file mode 100644
index 0000000000..fa7a38a0e4
--- /dev/null
+++ b/tensorflow/contrib/lite/java/demo/gradle/wrapper/gradle-wrapper.properties
@@ -0,0 +1,6 @@
+#Thu Sep 28 09:01:41 PDT 2017
+distributionBase=GRADLE_USER_HOME
+distributionPath=wrapper/dists
+zipStoreBase=GRADLE_USER_HOME
+zipStorePath=wrapper/dists
+distributionUrl=https\://services.gradle.org/distributions/gradle-3.3-all.zip
diff --git a/tensorflow/contrib/lite/java/demo/gradlew b/tensorflow/contrib/lite/java/demo/gradlew
new file mode 100755
index 0000000000..9d82f78915
--- /dev/null
+++ b/tensorflow/contrib/lite/java/demo/gradlew
@@ -0,0 +1,160 @@
+#!/usr/bin/env bash
+
+##############################################################################
+##
+## Gradle start up script for UN*X
+##
+##############################################################################
+
+# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
+DEFAULT_JVM_OPTS=""
+
+APP_NAME="Gradle"
+APP_BASE_NAME=`basename "$0"`
+
+# Use the maximum available, or set MAX_FD != -1 to use that value.
+MAX_FD="maximum"
+
+warn ( ) {
+ echo "$*"
+}
+
+die ( ) {
+ echo
+ echo "$*"
+ echo
+ exit 1
+}
+
+# OS specific support (must be 'true' or 'false').
+cygwin=false
+msys=false
+darwin=false
+case "`uname`" in
+ CYGWIN* )
+ cygwin=true
+ ;;
+ Darwin* )
+ darwin=true
+ ;;
+ MINGW* )
+ msys=true
+ ;;
+esac
+
+# Attempt to set APP_HOME
+# Resolve links: $0 may be a link
+PRG="$0"
+# Need this for relative symlinks.
+while [ -h "$PRG" ] ; do
+ ls=`ls -ld "$PRG"`
+ link=`expr "$ls" : '.*-> \(.*\)$'`
+ if expr "$link" : '/.*' > /dev/null; then
+ PRG="$link"
+ else
+ PRG=`dirname "$PRG"`"/$link"
+ fi
+done
+SAVED="`pwd`"
+cd "`dirname \"$PRG\"`/" >/dev/null
+APP_HOME="`pwd -P`"
+cd "$SAVED" >/dev/null
+
+CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar
+
+# Determine the Java command to use to start the JVM.
+if [ -n "$JAVA_HOME" ] ; then
+ if [ -x "$JAVA_HOME/jre/sh/java" ] ; then
+ # IBM's JDK on AIX uses strange locations for the executables
+ JAVACMD="$JAVA_HOME/jre/sh/java"
+ else
+ JAVACMD="$JAVA_HOME/bin/java"
+ fi
+ if [ ! -x "$JAVACMD" ] ; then
+ die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME
+
+Please set the JAVA_HOME variable in your environment to match the
+location of your Java installation."
+ fi
+else
+ JAVACMD="java"
+ which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
+
+Please set the JAVA_HOME variable in your environment to match the
+location of your Java installation."
+fi
+
+# Increase the maximum file descriptors if we can.
+if [ "$cygwin" = "false" -a "$darwin" = "false" ] ; then
+ MAX_FD_LIMIT=`ulimit -H -n`
+ if [ $? -eq 0 ] ; then
+ if [ "$MAX_FD" = "maximum" -o "$MAX_FD" = "max" ] ; then
+ MAX_FD="$MAX_FD_LIMIT"
+ fi
+ ulimit -n $MAX_FD
+ if [ $? -ne 0 ] ; then
+ warn "Could not set maximum file descriptor limit: $MAX_FD"
+ fi
+ else
+ warn "Could not query maximum file descriptor limit: $MAX_FD_LIMIT"
+ fi
+fi
+
+# For Darwin, add options to specify how the application appears in the dock
+if $darwin; then
+ GRADLE_OPTS="$GRADLE_OPTS \"-Xdock:name=$APP_NAME\" \"-Xdock:icon=$APP_HOME/media/gradle.icns\""
+fi
+
+# For Cygwin, switch paths to Windows format before running java
+if $cygwin ; then
+ APP_HOME=`cygpath --path --mixed "$APP_HOME"`
+ CLASSPATH=`cygpath --path --mixed "$CLASSPATH"`
+ JAVACMD=`cygpath --unix "$JAVACMD"`
+
+ # We build the pattern for arguments to be converted via cygpath
+ ROOTDIRSRAW=`find -L / -maxdepth 1 -mindepth 1 -type d 2>/dev/null`
+ SEP=""
+ for dir in $ROOTDIRSRAW ; do
+ ROOTDIRS="$ROOTDIRS$SEP$dir"
+ SEP="|"
+ done
+ OURCYGPATTERN="(^($ROOTDIRS))"
+ # Add a user-defined pattern to the cygpath arguments
+ if [ "$GRADLE_CYGPATTERN" != "" ] ; then
+ OURCYGPATTERN="$OURCYGPATTERN|($GRADLE_CYGPATTERN)"
+ fi
+ # Now convert the arguments - kludge to limit ourselves to /bin/sh
+ i=0
+ for arg in "$@" ; do
+ CHECK=`echo "$arg"|egrep -c "$OURCYGPATTERN" -`
+ CHECK2=`echo "$arg"|egrep -c "^-"` ### Determine if an option
+
+ if [ $CHECK -ne 0 ] && [ $CHECK2 -eq 0 ] ; then ### Added a condition
+ eval `echo args$i`=`cygpath --path --ignore --mixed "$arg"`
+ else
+ eval `echo args$i`="\"$arg\""
+ fi
+ i=$((i+1))
+ done
+ case $i in
+ (0) set -- ;;
+ (1) set -- "$args0" ;;
+ (2) set -- "$args0" "$args1" ;;
+ (3) set -- "$args0" "$args1" "$args2" ;;
+ (4) set -- "$args0" "$args1" "$args2" "$args3" ;;
+ (5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;;
+ (6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;;
+ (7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" ;;
+ (8) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;;
+ (9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" "$args8" ;;
+ esac
+fi
+
+# Split up the JVM_OPTS And GRADLE_OPTS values into an array, following the shell quoting and substitution rules
+function splitJvmOpts() {
+ JVM_OPTS=("$@")
+}
+eval splitJvmOpts $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS
+JVM_OPTS[${#JVM_OPTS[*]}]="-Dorg.gradle.appname=$APP_BASE_NAME"
+
+exec "$JAVACMD" "${JVM_OPTS[@]}" -classpath "$CLASSPATH" org.gradle.wrapper.GradleWrapperMain "$@"
diff --git a/tensorflow/contrib/lite/java/demo/gradlew.bat b/tensorflow/contrib/lite/java/demo/gradlew.bat
new file mode 100644
index 0000000000..8a0b282aa6
--- /dev/null
+++ b/tensorflow/contrib/lite/java/demo/gradlew.bat
@@ -0,0 +1,90 @@
+@if "%DEBUG%" == "" @echo off
+@rem ##########################################################################
+@rem
+@rem Gradle startup script for Windows
+@rem
+@rem ##########################################################################
+
+@rem Set local scope for the variables with windows NT shell
+if "%OS%"=="Windows_NT" setlocal
+
+@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
+set DEFAULT_JVM_OPTS=
+
+set DIRNAME=%~dp0
+if "%DIRNAME%" == "" set DIRNAME=.
+set APP_BASE_NAME=%~n0
+set APP_HOME=%DIRNAME%
+
+@rem Find java.exe
+if defined JAVA_HOME goto findJavaFromJavaHome
+
+set JAVA_EXE=java.exe
+%JAVA_EXE% -version >NUL 2>&1
+if "%ERRORLEVEL%" == "0" goto init
+
+echo.
+echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
+echo.
+echo Please set the JAVA_HOME variable in your environment to match the
+echo location of your Java installation.
+
+goto fail
+
+:findJavaFromJavaHome
+set JAVA_HOME=%JAVA_HOME:"=%
+set JAVA_EXE=%JAVA_HOME%/bin/java.exe
+
+if exist "%JAVA_EXE%" goto init
+
+echo.
+echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME%
+echo.
+echo Please set the JAVA_HOME variable in your environment to match the
+echo location of your Java installation.
+
+goto fail
+
+:init
+@rem Get command-line arguments, handling Windows variants
+
+if not "%OS%" == "Windows_NT" goto win9xME_args
+if "%@eval[2+2]" == "4" goto 4NT_args
+
+:win9xME_args
+@rem Slurp the command line arguments.
+set CMD_LINE_ARGS=
+set _SKIP=2
+
+:win9xME_args_slurp
+if "x%~1" == "x" goto execute
+
+set CMD_LINE_ARGS=%*
+goto execute
+
+:4NT_args
+@rem Get arguments from the 4NT Shell from JP Software
+set CMD_LINE_ARGS=%$
+
+:execute
+@rem Setup the command line
+
+set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar
+
+@rem Execute Gradle
+"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %CMD_LINE_ARGS%
+
+:end
+@rem End local scope for the variables with windows NT shell
+if "%ERRORLEVEL%"=="0" goto mainEnd
+
+:fail
+rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of
+rem the _cmd.exe /c_ return code!
+if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1
+exit /b 1
+
+:mainEnd
+if "%OS%"=="Windows_NT" endlocal
+
+:omega
diff --git a/tensorflow/contrib/lite/java/demo/settings.gradle b/tensorflow/contrib/lite/java/demo/settings.gradle
new file mode 100644
index 0000000000..e7b4def49c
--- /dev/null
+++ b/tensorflow/contrib/lite/java/demo/settings.gradle
@@ -0,0 +1 @@
+include ':app'
diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/DataType.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/DataType.java
new file mode 100644
index 0000000000..d63c299589
--- /dev/null
+++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/DataType.java
@@ -0,0 +1,76 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.lite;
+
+/** Type of elements in a {@code TfLiteTensor}. */
+enum DataType {
+ /** 32-bit single precision floating point. */
+ FLOAT32(1),
+
+ /** 32-bit signed integer. */
+ INT32(2),
+
+ /** 8-bit unsigned integer. */
+ UINT8(3),
+
+ /** 64-bit signed integer. */
+ INT64(4),
+
+ /** A {@link ByteBuffer}. */
+ BYTEBUFFER(999);
+
+ private final int value;
+
+ DataType(int value) {
+ this.value = value;
+ }
+
+ /** Corresponding value of the kTfLite* enum in the TensorFlow Lite CC API. */
+ int getNumber() {
+ return value;
+ }
+
+ /** Converts an integer to the corresponding type. */
+ static DataType fromNumber(int c) {
+ for (DataType t : values) {
+ if (t.value == c) {
+ return t;
+ }
+ }
+ throw new IllegalArgumentException(
+ "DataType " + c + " is not recognized in Java (version " + TensorFlowLite.version() + ")");
+ }
+
+ /** Returns byte size of the type. */
+ int elemByteSize() {
+ switch (this) {
+ case FLOAT32:
+ return 4;
+ case INT32:
+ return 4;
+ case UINT8:
+ return 1;
+ case INT64:
+ return 8;
+ case BYTEBUFFER:
+ return 1;
+ }
+ throw new IllegalArgumentException("DataType " + this + " is not supported yet");
+ }
+
+ // Cached to avoid copying it
+ private static final DataType[] values = values();
+}
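The elemByteSize() helper above is combined with a tensor's shape when sizing input buffers (see NativeInterpreterWrapper.run below, which computes elemByteSize() * numElements(dims)). A minimal sketch of that arithmetic; byteSizeOf is a hypothetical helper in package org.tensorflow.lite (DataType is package-private) and is not part of this change:

  // Hypothetical helper, illustrating how element size and shape determine buffer size.
  static int byteSizeOf(DataType type, int[] shape) {
    int elementCount = 1;
    for (int d : shape) {
      elementCount *= d;  // same logic as NativeInterpreterWrapper.numElements
    }
    // e.g. FLOAT32 with shape {1, 224, 224, 3} -> 4 * 150528 = 602112 bytes
    return type.elemByteSize() * elementCount;
  }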
diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java
new file mode 100644
index 0000000000..dd883d69d2
--- /dev/null
+++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java
@@ -0,0 +1,172 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.lite;
+
+import java.io.File;
+import java.nio.MappedByteBuffer;
+import java.util.HashMap;
+import java.util.Map;
+import javax.validation.constraints.NotNull;
+
+/**
+ * Driver class to drive model inference with TensorFlow Lite.
+ *
+ * <p>An {@code Interpreter} encapsulates a pre-trained TensorFlow Lite model, in which operations
+ * are executed for model inference.
+ *
+ * <p>For example, if a model takes only one input and returns only one output:
+ *
+ * <pre>{@code
+ * try (Interpreter interpreter = new Interpreter(file_of_a_tensorflowlite_model)) {
+ * interpreter.run(input, output);
+ * }
+ * }</pre>
+ *
+ * <p>If a model takes multiple inputs or outputs:
+ *
+ * <pre>{@code
+ * Object[] inputs = {input0, input1, ...};
+ * Map<Integer, Object> map_of_indices_to_outputs = new HashMap<>();
+ * float[][][] ith_output = new float[3][2][4];
+ * map_of_indices_to_outputs.put(i, ith_output);
+ * try (Interpreter interpreter = new Interpreter(file_of_a_tensorflowlite_model)) {
+ * interpreter.runForMultipleInputsOutputs(inputs, map_of_indices_to_outputs);
+ * }
+ * }</pre>
+ *
+ * <p>The order of inputs and outputs is determined when converting the TensorFlow model to a
+ * TensorFlow Lite model with Toco.
+ *
+ * <p><b>WARNING:</b> Instances of an {@code Interpreter} are <b>not</b> thread-safe. An {@code
+ * Interpreter} owns resources that <b>must</b> be explicitly freed by invoking {@link #close()}.
+ */
+public final class Interpreter implements AutoCloseable {
+
+ /**
+ * Initializes an {@code Interpreter}.
+ *
+ * @param modelFile a File of a pre-trained TF Lite model.
+ */
+ public Interpreter(@NotNull File modelFile) {
+ if (modelFile == null) {
+ return;
+ }
+ wrapper = new NativeInterpreterWrapper(modelFile.getAbsolutePath());
+ }
+
+ /**
+ * Initializes an {@code Interpreter} with a {@code MappedByteBuffer} to the model file.
+ *
+ * <p>The {@code MappedByteBuffer} should remain unchanged after the construction of an
+ * {@code Interpreter}.
+ */
+ public Interpreter(@NotNull MappedByteBuffer mappedByteBuffer) {
+ wrapper = new NativeInterpreterWrapper(mappedByteBuffer);
+ }
+
+ /**
+ * Runs model inference if the model takes only one input, and provides only one output.
+ *
+ * @param input an array or multidimensional array, or a {@link ByteBuffer} of primitive types
+ * including int, float, long, and byte. {@link ByteBuffer} is the preferred way to pass large
+ * input data. When {@link ByteBuffer} is used, its content should remain unchanged until
+ * model inference is done.
+ * @param output a multidimensional array of output data.
+ */
+ public void run(@NotNull Object input, @NotNull Object output) {
+ Object[] inputs = {input};
+ Map<Integer, Object> outputs = new HashMap<>();
+ outputs.put(0, output);
+ runForMultipleInputsOutputs(inputs, outputs);
+ }
+
+ /**
+ * Runs model inference if the model takes multiple inputs, or returns multiple outputs.
+ *
+ * @param inputs an array of input data. The inputs should be in the same order as inputs of the
+ * model. Each input can be an array or multidimensional array, or a {@link ByteBuffer} of
+ * primitive types including int, float, long, and byte. {@link ByteBuffer} is the preferred
+ * way to pass large input data. When {@link ByteBuffer} is used, its content should remain
+ * unchanged until model inference is done.
+ * @param outputs a map mapping output indices to multidimensional arrays of output data. It only
+ * needs to keep entries for the outputs to be used.
+ */
+ public void runForMultipleInputsOutputs(
+ @NotNull Object[] inputs, @NotNull Map<Integer, Object> outputs) {
+ if (wrapper == null) {
+ throw new IllegalStateException("The Interpreter has already been closed.");
+ }
+ Tensor[] tensors = wrapper.run(inputs);
+ if (outputs == null || tensors == null || outputs.size() > tensors.length) {
+ throw new IllegalArgumentException("Outputs do not match with model outputs.");
+ }
+ final int size = tensors.length;
+ for (Integer idx : outputs.keySet()) {
+ if (idx == null || idx < 0 || idx >= size) {
+ throw new IllegalArgumentException(
+ String.format("Invalid index of output %d (should be in range [0, %d))", idx, size));
+ }
+ tensors[idx].copyTo(outputs.get(idx));
+ }
+ }
+
+ /**
+ * Resizes the idx-th input of the native model to the given dims.
+ *
+ * <p>IllegalArgumentException will be thrown if it fails to resize.
+ */
+ public void resizeInput(int idx, @NotNull int[] dims) {
+ if (wrapper == null) {
+ throw new IllegalStateException("The Interpreter has already been closed.");
+ }
+ wrapper.resizeInput(idx, dims);
+ }
+
+ /**
+ * Gets index of an input given the op name of the input.
+ *
+ * <p>IllegalArgumentException will be thrown if the op name does not exist in the model file used
+ * to initialize the {@link Interpreter}.
+ */
+ public int getInputIndex(String opName) {
+ if (wrapper == null) {
+ throw new IllegalStateException("The Interpreter has already been closed.");
+ }
+ return wrapper.getInputIndex(opName);
+ }
+
+ /**
+ * Gets index of an output given the op name of the output.
+ *
+ * <p>IllegalArgumentException will be thrown if the op name does not exist in the model file used
+ * to initialize the {@link Interpreter}.
+ */
+ public int getOutputIndex(String opName) {
+ if (wrapper == null) {
+ throw new IllegalStateException("The Interpreter has already been closed.");
+ }
+ return wrapper.getOutputIndex(opName);
+ }
+
+ /** Releases resources associated with the {@code Interpreter}. */
+ @Override
+ public void close() {
+ // Idempotent: wrapper is null after the first close, so repeated calls are no-ops.
+ if (wrapper != null) {
+ wrapper.close();
+ wrapper = null;
+ }
+ }
+
+ NativeInterpreterWrapper wrapper;
+}
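Putting the public API above together, a typical single-input/single-output call sequence looks like the following sketch. The model path, op name, and tensor shapes are placeholders, not part of this change:

  import java.io.File;
  import org.tensorflow.lite.Interpreter;

  class InterpreterUsageSketch {  // hypothetical, for illustration only
    static void runOnce() {
      File modelFile = new File("/path/to/model.tflite");    // placeholder path
      float[][][][] input = new float[1][224][224][3];       // placeholder input shape
      float[][] output = new float[1][1001];                 // placeholder output shape

      try (Interpreter interpreter = new Interpreter(modelFile)) {
        int inputIndex = interpreter.getInputIndex("input"); // "input" is a placeholder op name
        interpreter.resizeInput(inputIndex, new int[] {1, 224, 224, 3});
        interpreter.run(input, output);                      // single input, single output
      }
    }
  }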
diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java
new file mode 100644
index 0000000000..1939a078ad
--- /dev/null
+++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java
@@ -0,0 +1,276 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.lite;
+
+import java.lang.reflect.Array;
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+import java.nio.MappedByteBuffer;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * A wrapper around the native interpreter that controls model execution.
+ *
+ * <p><b>WARNING:</b> Resources consumed by the {@code NativeInterpreterWrapper} object must be
+ * explicitly freed by invoking the {@link #close()} method when the {@code
+ * NativeInterpreterWrapper} object is no longer needed.
+ */
+final class NativeInterpreterWrapper implements AutoCloseable {
+
+ NativeInterpreterWrapper(String modelPath) {
+ errorHandle = createErrorReporter(ERROR_BUFFER_SIZE);
+ modelHandle = createModel(modelPath, errorHandle);
+ interpreterHandle = createInterpreter(modelHandle);
+ }
+
+ /**
+ * Initializes a {@code NativeInterpreterWrapper} with a {@code MappedByteBuffer}. The
+ * MappedByteBuffer should not be modified after the construction of a {@code
+ * NativeInterpreterWrapper}.
+ */
+ NativeInterpreterWrapper(MappedByteBuffer mappedByteBuffer) {
+ modelByteBuffer = mappedByteBuffer;
+ errorHandle = createErrorReporter(ERROR_BUFFER_SIZE);
+ modelHandle = createModelWithBuffer(modelByteBuffer, errorHandle);
+ interpreterHandle = createInterpreter(modelHandle);
+ }
+
+ /** Releases resources associated with this {@code NativeInterpreterWrapper}. */
+ @Override
+ public void close() {
+ delete(errorHandle, modelHandle, interpreterHandle);
+ errorHandle = 0;
+ modelHandle = 0;
+ interpreterHandle = 0;
+ modelByteBuffer = null;
+ inputsIndexes = null;
+ outputsIndexes = null;
+ }
+
+ /** Sets inputs, runs model inference and returns outputs. */
+ Tensor[] run(Object[] inputs) {
+ if (inputs == null || inputs.length == 0) {
+ throw new IllegalArgumentException("Invalid inputs. Inputs should not be null or empty.");
+ }
+ int[] dataTypes = new int[inputs.length];
+ Object[] sizes = new Object[inputs.length];
+ int[] numsOfBytes = new int[inputs.length];
+ for (int i = 0; i < inputs.length; ++i) {
+ DataType dataType = dataTypeOf(inputs[i]);
+ dataTypes[i] = dataType.getNumber();
+ if (dataType == DataType.BYTEBUFFER) {
+ ByteBuffer buffer = (ByteBuffer) inputs[i];
+ if (buffer.order() != ByteOrder.nativeOrder()) {
+ throw new IllegalArgumentException(
+ "Invalid ByteBuffer. It shoud use ByteOrder.nativeOrder().");
+ }
+ numsOfBytes[i] = buffer.limit();
+ sizes[i] = getInputDims(interpreterHandle, i, numsOfBytes[i]);
+ } else if (isNonEmptyArray(inputs[i])) {
+ int[] dims = shapeOf(inputs[i]);
+ sizes[i] = dims;
+ numsOfBytes[i] = dataType.elemByteSize() * numElements(dims);
+ } else {
+ throw new IllegalArgumentException(
+ String.format(
+ "%d-th element of the %d inputs is not an array or a ByteBuffer.",
+ i, inputs.length));
+ }
+ }
+ long[] outputsHandles =
+ run(interpreterHandle, errorHandle, sizes, dataTypes, numsOfBytes, inputs);
+ if (outputsHandles == null || outputsHandles.length == 0) {
+ throw new IllegalStateException("Interpreter has no outputs.");
+ }
+ Tensor[] outputs = new Tensor[outputsHandles.length];
+ for (int i = 0; i < outputsHandles.length; ++i) {
+ outputs[i] = Tensor.fromHandle(outputsHandles[i]);
+ }
+ return outputs;
+ }
+
+ /** Resizes dimensions of a specific input. */
+ void resizeInput(int idx, int[] dims) {
+ resizeInput(interpreterHandle, errorHandle, idx, dims);
+ }
+
+ void setUseNNAPI(boolean useNNAPI) {
+ useNNAPI(interpreterHandle, useNNAPI);
+ }
+
+ /** Gets index of an input given its name. */
+ int getInputIndex(String name) {
+ if (inputsIndexes == null) {
+ String[] names = getInputNames(interpreterHandle);
+ inputsIndexes = new HashMap<>();
+ if (names != null) {
+ for (int i = 0; i < names.length; ++i) {
+ inputsIndexes.put(names[i], i);
+ }
+ }
+ }
+ if (inputsIndexes.containsKey(name)) {
+ return inputsIndexes.get(name);
+ } else {
+ throw new IllegalArgumentException(
+ String.format(
+ "%s is not a valid name for any input. The indexes of the inputs are %s",
+ name, inputsIndexes.toString()));
+ }
+ }
+
+ /** Gets index of an output given its name. */
+ int getOutputIndex(String name) {
+ if (outputsIndexes == null) {
+ String[] names = getOutputNames(interpreterHandle);
+ outputsIndexes = new HashMap<>();
+ if (names != null) {
+ for (int i = 0; i < names.length; ++i) {
+ outputsIndexes.put(names[i], i);
+ }
+ }
+ }
+ if (outputsIndexes.containsKey(name)) {
+ return outputsIndexes.get(name);
+ } else {
+ throw new IllegalArgumentException(
+ String.format(
+ "%s is not a valid name for any output. The indexes of the outputs are %s",
+ name, outputsIndexes.toString()));
+ }
+ }
+
+ static int numElements(int[] shape) {
+ if (shape == null) {
+ return 0;
+ }
+ int n = 1;
+ for (int i = 0; i < shape.length; i++) {
+ n *= shape[i];
+ }
+ return n;
+ }
+
+ static boolean isNonEmptyArray(Object o) {
+ return (o != null && o.getClass().isArray() && Array.getLength(o) != 0);
+ }
+
+ /** Returns the type of the data. */
+ static DataType dataTypeOf(Object o) {
+ if (o != null) {
+ Class<?> c = o.getClass();
+ while (c.isArray()) {
+ c = c.getComponentType();
+ }
+ if (float.class.equals(c)) {
+ return DataType.FLOAT32;
+ } else if (int.class.equals(c)) {
+ return DataType.INT32;
+ } else if (byte.class.equals(c)) {
+ return DataType.UINT8;
+ } else if (long.class.equals(c)) {
+ return DataType.INT64;
+ } else if (ByteBuffer.class.isInstance(o)) {
+ return DataType.BYTEBUFFER;
+ }
+ }
+    String name = (o == null) ? "null" : o.getClass().getName();
+    throw new IllegalArgumentException("cannot resolve DataType of " + name);
+ }
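+
+  // For example, dataTypeOf(new float[2][3]) resolves to DataType.FLOAT32 and a direct
+  // ByteBuffer resolves to DataType.BYTEBUFFER, while boxed arrays such as Float[][] are
+  // rejected with an IllegalArgumentException.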
+
+ /** Returns the shape of an object as an int array. */
+ static int[] shapeOf(Object o) {
+ int size = numDimensions(o);
+ int[] dimensions = new int[size];
+ fillShape(o, 0, dimensions);
+ return dimensions;
+ }
+
+ static int numDimensions(Object o) {
+ if (o == null || !o.getClass().isArray()) {
+ return 0;
+ }
+ if (Array.getLength(o) == 0) {
+ throw new IllegalArgumentException("array lengths cannot be 0.");
+ }
+ return 1 + numDimensions(Array.get(o, 0));
+ }
+
+ static void fillShape(Object o, int dim, int[] shape) {
+ if (shape == null || dim == shape.length) {
+ return;
+ }
+ final int len = Array.getLength(o);
+ if (shape[dim] == 0) {
+ shape[dim] = len;
+ } else if (shape[dim] != len) {
+ throw new IllegalArgumentException(
+ String.format("mismatched lengths (%d and %d) in dimension %d", shape[dim], len, dim));
+ }
+ for (int i = 0; i < len; ++i) {
+ fillShape(Array.get(o, i), dim + 1, shape);
+ }
+ }
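+
+  // Shape-inference sketch: for new float[2][8][8][3], numDimensions() returns 4 and
+  // shapeOf() returns {2, 8, 8, 3}; fillShape() throws if sibling sub-arrays in the same
+  // dimension have mismatched lengths.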
+
+ private static final int ERROR_BUFFER_SIZE = 512;
+
+ private long errorHandle;
+
+ private long interpreterHandle;
+
+ private long modelHandle;
+
+ private int inputSize;
+
+ private MappedByteBuffer modelByteBuffer;
+
+ private Map<String, Integer> inputsIndexes;
+
+ private Map<String, Integer> outputsIndexes;
+
+ private static native String[] getInputNames(long interpreterHandle);
+
+ private static native String[] getOutputNames(long interpreterHandle);
+
+ private static native void resizeInput(
+ long interpreterHandle, long errorHandle, int inputIdx, int[] dims);
+
+ private static native void useNNAPI(long interpreterHandle, boolean state);
+
+ private static native long createErrorReporter(int size);
+
+ private static native long createModel(String modelPathOrBuffer, long errorHandle);
+
+ private static native long createModelWithBuffer(MappedByteBuffer modelBuffer, long errorHandle);
+
+ private static native long createInterpreter(long modelHandle);
+
+ private static native long[] run(
+ long interpreterHandle,
+ long errorHandle,
+ Object[] sizes,
+ int[] dtypes,
+ int[] numsOfBytes,
+ Object[] values);
+
+ private static native void delete(long errorHandle, long modelHandle, long interpreterHandle);
+
+ private static native int[] getInputDims(long interpreterHandle, int inputIdx, int numBytes);
+
+ static {
+ TensorFlowLite.init();
+ }
+}
diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Tensor.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Tensor.java
new file mode 100644
index 0000000000..54ace6c63c
--- /dev/null
+++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Tensor.java
@@ -0,0 +1,71 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.lite;
+
+import java.util.Arrays;
+
+/**
+ * A typed multi-dimensional array used in TensorFlow Lite.
+ *
+ * <p>The native handle of a {@code Tensor} is owned by the {@code NativeInterpreterWrapper}, so it
+ * does not need to be closed here.
+ */
+final class Tensor {
+
+ static Tensor fromHandle(long nativeHandle) {
+ return new Tensor(nativeHandle);
+ }
+
+ /** Reads Tensor content into an array. */
+ <T> T copyTo(T dst) {
+ if (NativeInterpreterWrapper.dataTypeOf(dst) != dtype) {
+ throw new IllegalArgumentException(
+ String.format(
+              "Cannot convert a TensorFlowLite tensor with type %s to a Java object of "
+ + "type %s (which is compatible with the TensorFlowLite type %s)",
+ dtype, dst.getClass().getName(), NativeInterpreterWrapper.dataTypeOf(dst)));
+ }
+ int[] dstShape = NativeInterpreterWrapper.shapeOf(dst);
+ if (!Arrays.equals(dstShape, shapeCopy)) {
+ throw new IllegalArgumentException(
+ String.format(
+              "Shape of output target %s does not match the shape of the Tensor %s.",
+ Arrays.toString(dstShape), Arrays.toString(shapeCopy)));
+ }
+ readMultiDimensionalArray(nativeHandle, dst);
+ return dst;
+ }
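+
+  // Usage sketch (assuming a FLOAT32 tensor of shape {2, 8, 8, 3} obtained from
+  // NativeInterpreterWrapper.run()):
+  //
+  //   float[][][][] result = new float[2][8][8][3];
+  //   tensor.copyTo(result);  // throws IllegalArgumentException on dtype or shape mismatch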
+
+ final long nativeHandle;
+ final DataType dtype;
+ final int[] shapeCopy;
+
+ private Tensor(long nativeHandle) {
+ this.nativeHandle = nativeHandle;
+ this.dtype = DataType.fromNumber(dtype(nativeHandle));
+ this.shapeCopy = shape(nativeHandle);
+ }
+
+ private static native int dtype(long handle);
+
+ private static native int[] shape(long handle);
+
+ private static native void readMultiDimensionalArray(long handle, Object value);
+
+ static {
+ TensorFlowLite.init();
+ }
+}
diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/TensorFlowLite.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/TensorFlowLite.java
new file mode 100644
index 0000000000..711638a9f9
--- /dev/null
+++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/TensorFlowLite.java
@@ -0,0 +1,44 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.lite;
+
+/** Static utility methods for loading the TensorFlowLite runtime. */
+public final class TensorFlowLite {
+
+ private static final String LIBNAME = "tensorflowlite_jni";
+
+ private TensorFlowLite() {}
+
+ /** Returns the version of the underlying TensorFlowLite runtime. */
+ public static native String version();
+
+  /** Loads the TensorFlowLite runtime's native library. */
+ static boolean init() {
+ try {
+ System.loadLibrary(LIBNAME);
+ return true;
+ } catch (UnsatisfiedLinkError e) {
+ System.err.println("TensorFlowLite: failed to load native library: " + e.getMessage());
+ return false;
+ }
+ }
+
+ static {
+ init();
+ }
+}
diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/package-info.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/package-info.java
new file mode 100644
index 0000000000..68e6a0f578
--- /dev/null
+++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/package-info.java
@@ -0,0 +1,17 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+/** Defines classes to load and execute TensorFlowLite models. */
+package org.tensorflow.lite;
diff --git a/tensorflow/contrib/lite/java/src/main/native/BUILD b/tensorflow/contrib/lite/java/src/main/native/BUILD
new file mode 100644
index 0000000000..9c172a1f68
--- /dev/null
+++ b/tensorflow/contrib/lite/java/src/main/native/BUILD
@@ -0,0 +1,70 @@
+# Description:
+# Java Native Interface (JNI) library intended for implementing the
+# TensorFlow Lite Java API using the TensorFlow Lite CC library.
+
+package(default_visibility = ["//visibility:public"])
+
+load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts")
+
+licenses(["notice"]) # Apache 2.0
+
+cc_library(
+ name = "native_framework_only",
+ srcs = [
+ "exception_jni.cc",
+ "nativeinterpreterwrapper_jni.cc",
+ "tensor_jni.cc",
+ "tensorflow_lite_jni.cc",
+ ],
+ hdrs = [
+ "exception_jni.h",
+ "nativeinterpreterwrapper_jni.h",
+ "tensor_jni.h",
+ "tensorflow_lite_jni.h",
+ ],
+ copts = tflite_copts(),
+ linkopts = [
+ "-lm",
+ "-ldl",
+ ],
+ deps = [
+ "//tensorflow/contrib/lite:context",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite:schema_fbs_version",
+ ],
+ alwayslink = 1,
+)
+
+# This includes all ops. If you want a smaller binary, you should copy and
+# modify builtin_ops_jni.cc. You should then link your binary against both
+# ":native_framework_only" and your own version of ":native_builtin_ops".
+cc_library(
+ name = "native",
+ srcs = [
+ "builtin_ops_jni.cc",
+ ],
+ copts = tflite_copts(),
+ deps = [
+ ":native_framework_only",
+ "//tensorflow/contrib/lite/kernels:builtin_ops",
+ ],
+ alwayslink = 1,
+)
+
+exports_files(
+ [
+ "version_script.lds",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/contrib/lite/java/src/main/native/builtin_ops_jni.cc b/tensorflow/contrib/lite/java/src/main/native/builtin_ops_jni.cc
new file mode 100644
index 0000000000..cce356370f
--- /dev/null
+++ b/tensorflow/contrib/lite/java/src/main/native/builtin_ops_jni.cc
@@ -0,0 +1,29 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/kernels/register.h"
+
+namespace tflite {
+
+// The JNI code in nativeinterpreterwrapper_jni.cc expects a CreateOpResolver() function in
+// the tflite namespace. This one instantiates a BuiltinOpResolver with all the
+// builtin ops. For smaller binary sizes, users should avoid linking this in and
+// should provide a custom CreateOpResolver() instead.
+std::unique_ptr<OpResolver> CreateOpResolver() { // NOLINT
+ return std::unique_ptr<tflite::ops::builtin::BuiltinOpResolver>(
+ new tflite::ops::builtin::BuiltinOpResolver());
+}
+
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/java/src/main/native/exception_jni.cc b/tensorflow/contrib/lite/java/src/main/native/exception_jni.cc
new file mode 100644
index 0000000000..1578c9e3dd
--- /dev/null
+++ b/tensorflow/contrib/lite/java/src/main/native/exception_jni.cc
@@ -0,0 +1,66 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <stdarg.h>
+#include <stdio.h>
+#include <stdlib.h>
+
+#include "tensorflow/contrib/lite/java/src/main/native/exception_jni.h"
+
+const char kIllegalArgumentException[] = "java/lang/IllegalArgumentException";
+const char kIllegalStateException[] = "java/lang/IllegalStateException";
+const char kNullPointerException[] = "java/lang/NullPointerException";
+const char kIndexOutOfBoundsException[] = "java/lang/IndexOutOfBoundsException";
+const char kUnsupportedOperationException[] =
+ "java/lang/UnsupportedOperationException";
+
+void throwException(JNIEnv* env, const char* clazz, const char* fmt, ...) {
+ va_list args;
+ va_start(args, fmt);
+ const size_t max_msg_len = 512;
+ auto* message = static_cast<char*>(malloc(max_msg_len));
+ if (vsnprintf(message, max_msg_len, fmt, args) >= 0) {
+ env->ThrowNew(env->FindClass(clazz), message);
+ } else {
+ env->ThrowNew(env->FindClass(clazz), "");
+ }
+ free(message);
+ va_end(args);
+}
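+
+// Example: throwException(env, kIllegalArgumentException, "bad input index %d", idx) queues a
+// Java IllegalArgumentException with the formatted message; it is thrown once control returns
+// from native code to the JVM.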
+
+BufferErrorReporter::BufferErrorReporter(JNIEnv* env, int limit) {
+ buffer_ = new char[limit];
+ if (!buffer_) {
+ throwException(env, kNullPointerException,
+ "Malloc of BufferErrorReporter to hold %d char failed.",
+ limit);
+ return;
+ }
+ start_idx_ = 0;
+ end_idx_ = limit - 1;
+}
+
+BufferErrorReporter::~BufferErrorReporter() { delete[] buffer_; }
+
+int BufferErrorReporter::Report(const char* format, va_list args) {
+ int size = 0;
+ if (start_idx_ < end_idx_) {
+ size = vsnprintf(buffer_ + start_idx_, end_idx_ - start_idx_, format, args);
+ }
+ start_idx_ += size;
+ return size;
+}
+
+const char* BufferErrorReporter::CachedErrorMessage() { return buffer_; }
diff --git a/tensorflow/contrib/lite/java/src/main/native/exception_jni.h b/tensorflow/contrib/lite/java/src/main/native/exception_jni.h
new file mode 100644
index 0000000000..3ffff052df
--- /dev/null
+++ b/tensorflow/contrib/lite/java/src/main/native/exception_jni.h
@@ -0,0 +1,50 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_LITE_JAVA_EXCEPTION_JNI_H_
+#define TENSORFLOW_CONTRIB_LITE_JAVA_EXCEPTION_JNI_H_
+
+#include <jni.h>
+#include "tensorflow/contrib/lite/error_reporter.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+extern const char kIllegalArgumentException[];
+extern const char kIllegalStateException[];
+extern const char kNullPointerException[];
+extern const char kIndexOutOfBoundsException[];
+extern const char kUnsupportedOperationException[];
+
+void throwException(JNIEnv* env, const char* clazz, const char* fmt, ...);
+
+class BufferErrorReporter : public tflite::ErrorReporter {
+ public:
+ BufferErrorReporter(JNIEnv* env, int limit);
+ virtual ~BufferErrorReporter();
+ int Report(const char* format, va_list args) override;
+ const char* CachedErrorMessage();
+
+ private:
+ char* buffer_;
+ int start_idx_ = 0;
+ int end_idx_ = 0;
+};
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
+#endif // TENSORFLOW_CONTRIB_LITE_JAVA_EXCEPTION_JNI_H_
diff --git a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc
new file mode 100644
index 0000000000..bc6462eb54
--- /dev/null
+++ b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc
@@ -0,0 +1,446 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h"
+
+namespace {
+
+const int kByteBufferValue = 999;
+const int kBufferSize = 256;
+
+tflite::Interpreter* convertLongToInterpreter(JNIEnv* env, jlong handle) {
+ if (handle == 0) {
+ throwException(env, kIllegalArgumentException,
+ "Invalid handle to Interpreter.");
+ return nullptr;
+ }
+ return reinterpret_cast<tflite::Interpreter*>(handle);
+}
+
+tflite::FlatBufferModel* convertLongToModel(JNIEnv* env, jlong handle) {
+ if (handle == 0) {
+ throwException(env, kIllegalArgumentException, "Invalid handle to model.");
+ return nullptr;
+ }
+ return reinterpret_cast<tflite::FlatBufferModel*>(handle);
+}
+
+BufferErrorReporter* convertLongToErrorReporter(JNIEnv* env, jlong handle) {
+ if (handle == 0) {
+ throwException(env, kIllegalArgumentException,
+ "Invalid handle to ErrorReporter.");
+ return nullptr;
+ }
+ return reinterpret_cast<BufferErrorReporter*>(handle);
+}
+
+std::vector<int> convertJIntArrayToVector(JNIEnv* env, jintArray inputs) {
+ int size = static_cast<int>(env->GetArrayLength(inputs));
+ std::vector<int> outputs(size, 0);
+ jint* ptr = env->GetIntArrayElements(inputs, nullptr);
+ if (ptr == nullptr) {
+ throwException(env, kIllegalArgumentException,
+ "Empty dimensions of input array.");
+ return {};
+ }
+ for (int i = 0; i < size; ++i) {
+ outputs[i] = ptr[i];
+ }
+ env->ReleaseIntArrayElements(inputs, ptr, JNI_ABORT);
+ return outputs;
+}
+
+bool isByteBuffer(jint data_type) { return data_type == kByteBufferValue; }
+
+TfLiteType resolveDataType(jint data_type) {
+ switch (data_type) {
+ case 1:
+ return kTfLiteFloat32;
+ case 2:
+ return kTfLiteInt32;
+ case 3:
+ return kTfLiteUInt8;
+ case 4:
+ return kTfLiteInt64;
+ default:
+ return kTfLiteNoType;
+ }
+}
+
+void printDims(char* buffer, int max_size, int* dims, int num_dims) {
+  if (max_size <= 0) return;
+  buffer[0] = '?';
+  int size = 1;
+  // Null-terminate up front so a single-dimension result is still a valid string.
+  if (size < max_size) buffer[size] = '\0';
+ for (int i = 1; i < num_dims; ++i) {
+ if (max_size > size) {
+ int written_size =
+ snprintf(buffer + size, max_size - size, ",%d", dims[i]);
+ if (written_size < 0) return;
+ size += written_size;
+ }
+ }
+}
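+
+// For example, with dims = {2, 8, 8, 3} and num_dims = 4 this yields "?,8,8,3"; the leading
+// dimension is printed as '?' because checkInputs() below only compares dimensions from
+// index 1 onward.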
+
+TfLiteStatus checkInputs(JNIEnv* env, tflite::Interpreter* interpreter,
+ const int input_size, jintArray data_types,
+ jintArray nums_of_bytes, jobjectArray values,
+ jobjectArray sizes) {
+ if (input_size != interpreter->inputs().size()) {
+ throwException(env, kIllegalArgumentException,
+ "Expected num of inputs is %d but got %d",
+ interpreter->inputs().size(), input_size);
+ return kTfLiteError;
+ }
+ if (input_size != env->GetArrayLength(data_types) ||
+ input_size != env->GetArrayLength(nums_of_bytes) ||
+ input_size != env->GetArrayLength(values)) {
+ throwException(env, kIllegalArgumentException,
+ "Arrays in arguments should be of the same length, but got "
+ "%d sizes, %d data_types, %d nums_of_bytes, and %d values",
+ input_size, env->GetArrayLength(data_types),
+ env->GetArrayLength(nums_of_bytes),
+ env->GetArrayLength(values));
+ return kTfLiteError;
+ }
+ for (int i = 0; i < input_size; ++i) {
+ int input_idx = interpreter->inputs()[i];
+ TfLiteTensor* target = interpreter->tensor(input_idx);
+ jintArray dims =
+ static_cast<jintArray>(env->GetObjectArrayElement(sizes, i));
+ int num_dims = static_cast<int>(env->GetArrayLength(dims));
+ if (target->dims->size != num_dims) {
+ throwException(env, kIllegalArgumentException,
+ "%d-th input should have %d dimensions, but found %d "
+ "dimensions",
+ i, target->dims->size, num_dims);
+ return kTfLiteError;
+ }
+ jint* ptr = env->GetIntArrayElements(dims, nullptr);
+ for (int j = 1; j < num_dims; ++j) {
+ if (target->dims->data[j] != ptr[j]) {
+ std::unique_ptr<char[]> expected_dims(new char[kBufferSize]);
+ std::unique_ptr<char[]> obtained_dims(new char[kBufferSize]);
+ printDims(expected_dims.get(), kBufferSize, target->dims->data,
+ num_dims);
+ printDims(obtained_dims.get(), kBufferSize, ptr, num_dims);
+ throwException(env, kIllegalArgumentException,
+ "%d-th input dimension should be [%s], but found [%s]",
+ i, expected_dims.get(), obtained_dims.get());
+ env->ReleaseIntArrayElements(dims, ptr, JNI_ABORT);
+ return kTfLiteError;
+ }
+ }
+ env->ReleaseIntArrayElements(dims, ptr, JNI_ABORT);
+ env->DeleteLocalRef(dims);
+ if (env->ExceptionCheck()) return kTfLiteError;
+ }
+ return kTfLiteOk;
+}
+
+TfLiteStatus resizeInputs(JNIEnv* env, tflite::Interpreter* interpreter,
+ int input_size, jobjectArray sizes) {
+ for (int i = 0; i < input_size; ++i) {
+ int input_idx = interpreter->inputs()[i];
+ jintArray dims =
+ static_cast<jintArray>(env->GetObjectArrayElement(sizes, i));
+ TfLiteStatus status = interpreter->ResizeInputTensor(
+ input_idx, convertJIntArrayToVector(env, dims));
+ if (status != kTfLiteOk) {
+ return status;
+ }
+ env->DeleteLocalRef(dims);
+ if (env->ExceptionCheck()) return kTfLiteError;
+ }
+ return kTfLiteOk;
+}
+
+TfLiteStatus setInputs(JNIEnv* env, tflite::Interpreter* interpreter,
+ int input_size, jintArray data_types,
+ jintArray nums_of_bytes, jobjectArray values) {
+ jint* data_type = env->GetIntArrayElements(data_types, nullptr);
+ jint* num_bytes = env->GetIntArrayElements(nums_of_bytes, nullptr);
+ for (int i = 0; i < input_size; ++i) {
+ int input_idx = interpreter->inputs()[i];
+ TfLiteTensor* target = interpreter->tensor(input_idx);
+ jobject value = env->GetObjectArrayElement(values, i);
+ bool is_byte_buffer = isByteBuffer(data_type[i]);
+ if (is_byte_buffer) {
+ writeByteBuffer(env, value, &(target->data.raw),
+ static_cast<int>(num_bytes[i]));
+ } else {
+ TfLiteType type = resolveDataType(data_type[i]);
+ if (type != target->type) {
+ throwException(env, kIllegalArgumentException,
+ "DataType (%d) of input data does not match with the "
+ "DataType (%d) of model inputs.",
+ type, target->type);
+ return kTfLiteError;
+ }
+ writeMultiDimensionalArray(env, value, target->type, target->dims->size,
+ &(target->data.raw),
+ static_cast<int>(num_bytes[i]));
+ }
+ env->DeleteLocalRef(value);
+ if (env->ExceptionCheck()) return kTfLiteError;
+ }
+ env->ReleaseIntArrayElements(data_types, data_type, JNI_ABORT);
+ env->ReleaseIntArrayElements(nums_of_bytes, num_bytes, JNI_ABORT);
+ return kTfLiteOk;
+}
+
+} // namespace
+
+JNIEXPORT jobjectArray JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputNames(JNIEnv* env,
+ jclass clazz,
+ jlong handle) {
+ tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
+ if (interpreter == nullptr) return nullptr;
+ jclass string_class = env->FindClass("java/lang/String");
+ if (string_class == nullptr) {
+ throwException(env, kUnsupportedOperationException,
+ "Can not find java/lang/String class to get input names.");
+ return nullptr;
+ }
+ size_t size = interpreter->inputs().size();
+ jobjectArray names = static_cast<jobjectArray>(
+ env->NewObjectArray(size, string_class, env->NewStringUTF("")));
+ for (int i = 0; i < size; ++i) {
+ env->SetObjectArrayElement(names, i,
+ env->NewStringUTF(interpreter->GetInputName(i)));
+ }
+ return names;
+}
+
+JNIEXPORT jobjectArray JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputNames(JNIEnv* env,
+ jclass clazz,
+ jlong handle) {
+ tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
+ if (interpreter == nullptr) return nullptr;
+ jclass string_class = env->FindClass("java/lang/String");
+ if (string_class == nullptr) {
+ throwException(env, kUnsupportedOperationException,
+ "Can not find java/lang/String class to get output names.");
+ return nullptr;
+ }
+ size_t size = interpreter->outputs().size();
+ jobjectArray names = static_cast<jobjectArray>(
+ env->NewObjectArray(size, string_class, env->NewStringUTF("")));
+ for (int i = 0; i < size; ++i) {
+ env->SetObjectArrayElement(
+ names, i, env->NewStringUTF(interpreter->GetOutputName(i)));
+ }
+ return names;
+}
+
+JNIEXPORT void JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_useNNAPI(JNIEnv* env,
+ jclass clazz,
+ jlong handle,
+ jboolean state) {
+ tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
+ if (interpreter == nullptr) return;
+ interpreter->UseNNAPI(static_cast<bool>(state));
+}
+
+JNIEXPORT jlong JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_createErrorReporter(
+ JNIEnv* env, jclass clazz, jint size) {
+ BufferErrorReporter* error_reporter =
+ new BufferErrorReporter(env, static_cast<int>(size));
+ return reinterpret_cast<jlong>(error_reporter);
+}
+
+JNIEXPORT jlong JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_createModel(
+ JNIEnv* env, jclass clazz, jstring model_file, jlong error_handle) {
+ BufferErrorReporter* error_reporter =
+ convertLongToErrorReporter(env, error_handle);
+ if (error_reporter == nullptr) return 0;
+ const char* path = env->GetStringUTFChars(model_file, nullptr);
+ auto model = tflite::FlatBufferModel::BuildFromFile(path, error_reporter);
+ if (!model) {
+ throwException(env, kIllegalArgumentException,
+ "Contents of %s does not encode a valid TensorFlowLite "
+ "model: %s",
+ path, error_reporter->CachedErrorMessage());
+ env->ReleaseStringUTFChars(model_file, path);
+ return 0;
+ }
+ env->ReleaseStringUTFChars(model_file, path);
+ return reinterpret_cast<jlong>(model.release());
+}
+
+JNIEXPORT jlong JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_createModelWithBuffer(
+ JNIEnv* env, jclass /*clazz*/, jobject model_buffer, jlong error_handle) {
+ BufferErrorReporter* error_reporter =
+ convertLongToErrorReporter(env, error_handle);
+ if (error_reporter == nullptr) return 0;
+ const char* buf =
+ static_cast<char*>(env->GetDirectBufferAddress(model_buffer));
+ jlong capacity = env->GetDirectBufferCapacity(model_buffer);
+ auto model = tflite::FlatBufferModel::BuildFromBuffer(
+ buf, static_cast<size_t>(capacity), error_reporter);
+ if (!model) {
+ throwException(env, kIllegalArgumentException,
+ "MappedByteBuffer does not encode a valid TensorFlowLite "
+ "model: %s",
+ error_reporter->CachedErrorMessage());
+ return 0;
+ }
+ return reinterpret_cast<jlong>(model.release());
+}
+
+JNIEXPORT jlong JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_createInterpreter(
+ JNIEnv* env, jclass clazz, jlong model_handle) {
+ tflite::FlatBufferModel* model = convertLongToModel(env, model_handle);
+ if (model == nullptr) return 0;
+ auto resolver = ::tflite::CreateOpResolver();
+ std::unique_ptr<tflite::Interpreter> interpreter;
+ tflite::InterpreterBuilder(*model, *(resolver.get()))(&interpreter);
+ return reinterpret_cast<jlong>(interpreter.release());
+}
+
+// Sets inputs, runs inference, and returns outputs as long handles.
+JNIEXPORT jlongArray JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_run(
+ JNIEnv* env, jclass clazz, jlong interpreter_handle, jlong error_handle,
+ jobjectArray sizes, jintArray data_types, jintArray nums_of_bytes,
+ jobjectArray values) {
+ tflite::Interpreter* interpreter =
+ convertLongToInterpreter(env, interpreter_handle);
+ if (interpreter == nullptr) return nullptr;
+ BufferErrorReporter* error_reporter =
+ convertLongToErrorReporter(env, error_handle);
+ if (error_reporter == nullptr) return nullptr;
+ const int input_size = env->GetArrayLength(sizes);
+ // validates inputs
+ TfLiteStatus status = checkInputs(env, interpreter, input_size, data_types,
+ nums_of_bytes, values, sizes);
+ if (status != kTfLiteOk) return nullptr;
+ // resizes inputs
+ status = resizeInputs(env, interpreter, input_size, sizes);
+ if (status != kTfLiteOk) {
+ throwException(env, kNullPointerException, "Can not resize the input: %s",
+ error_reporter->CachedErrorMessage());
+ return nullptr;
+ }
+ // allocates memory
+ status = interpreter->AllocateTensors();
+ if (status != kTfLiteOk) {
+ throwException(env, kNullPointerException,
+ "Can not allocate memory for the given inputs: %s",
+ error_reporter->CachedErrorMessage());
+ return nullptr;
+ }
+ // sets inputs
+ status = setInputs(env, interpreter, input_size, data_types, nums_of_bytes,
+ values);
+ if (status != kTfLiteOk) return nullptr;
+ // runs inference
+ if (interpreter->Invoke() != kTfLiteOk) {
+ throwException(env, kIllegalArgumentException,
+ "Failed to run on the given Interpreter: %s",
+ error_reporter->CachedErrorMessage());
+ return nullptr;
+ }
+ // returns outputs
+ const std::vector<int>& results = interpreter->outputs();
+ if (results.empty()) {
+ throwException(env, kIllegalArgumentException,
+ "The Interpreter does not have any outputs.");
+ return nullptr;
+ }
+ jlongArray outputs = env->NewLongArray(results.size());
+ size_t size = results.size();
+ for (int i = 0; i < size; ++i) {
+ TfLiteTensor* source = interpreter->tensor(results[i]);
+ jlong output = reinterpret_cast<jlong>(source);
+ env->SetLongArrayRegion(outputs, i, 1, &output);
+ }
+ return outputs;
+}
+
+JNIEXPORT jintArray JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputDims(
+ JNIEnv* env, jclass clazz, jlong handle, jint input_idx, jint num_bytes) {
+ tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
+ if (interpreter == nullptr) return nullptr;
+ const int idx = static_cast<int>(input_idx);
+ if (input_idx >= interpreter->inputs().size()) {
+ throwException(env, kIllegalArgumentException,
+ "Out of range: Failed to get %d-th input out of %d inputs",
+ input_idx, interpreter->inputs().size());
+ return nullptr;
+ }
+ TfLiteTensor* target = interpreter->tensor(interpreter->inputs()[idx]);
+ int size = target->dims->size;
+ int expected_num_bytes = elementByteSize(target->type);
+ for (int i = 0; i < size; ++i) {
+ expected_num_bytes *= target->dims->data[i];
+ }
+ if (num_bytes != expected_num_bytes) {
+ throwException(env, kIllegalArgumentException,
+ "Failed to get input dimensions. %d-th input should have"
+ " %d bytes, but found %d bytes.",
+ idx, expected_num_bytes, num_bytes);
+ return nullptr;
+ }
+ jintArray outputs = env->NewIntArray(size);
+ env->SetIntArrayRegion(outputs, 0, size, &(target->dims->data[0]));
+ return outputs;
+}
+
+JNIEXPORT void JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_resizeInput(
+ JNIEnv* env, jclass clazz, jlong interpreter_handle, jlong error_handle,
+ jint input_idx, jintArray dims) {
+ BufferErrorReporter* error_reporter =
+ convertLongToErrorReporter(env, error_handle);
+ if (error_reporter == nullptr) return;
+ tflite::Interpreter* interpreter =
+ convertLongToInterpreter(env, interpreter_handle);
+ if (interpreter == nullptr) return;
+ const int idx = static_cast<int>(input_idx);
+  if (idx < 0 || idx >= interpreter->inputs().size()) {
+    throwException(env, kIllegalArgumentException,
+                   "Can not resize %d-th input for a model having %d inputs.",
+                   idx, interpreter->inputs().size());
+    return;
+  }
+ TfLiteStatus status = interpreter->ResizeInputTensor(
+ interpreter->inputs()[idx], convertJIntArrayToVector(env, dims));
+ if (status != kTfLiteOk) {
+ throwException(env, kIllegalArgumentException,
+ "Failed to resize %d-th input: %s", idx,
+ error_reporter->CachedErrorMessage());
+ }
+}
+
+JNIEXPORT void JNICALL Java_org_tensorflow_lite_NativeInterpreterWrapper_delete(
+ JNIEnv* env, jclass clazz, jlong error_handle, jlong model_handle,
+ jlong interpreter_handle) {
+ if (interpreter_handle != 0) {
+ delete convertLongToInterpreter(env, interpreter_handle);
+ }
+ if (model_handle != 0) {
+ delete convertLongToModel(env, model_handle);
+ }
+ if (error_handle != 0) {
+ delete convertLongToErrorReporter(env, error_handle);
+ }
+}
diff --git a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h
new file mode 100644
index 0000000000..430886b7cc
--- /dev/null
+++ b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h
@@ -0,0 +1,151 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_LITE_JAVA_NATIVEINTERPRETERWRAPPER_JNI_H_
+#define TENSORFLOW_CONTRIB_LITE_JAVA_NATIVEINTERPRETERWRAPPER_JNI_H_
+
+#include <jni.h>
+#include <stdio.h>
+#include <vector>
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/java/src/main/native/exception_jni.h"
+#include "tensorflow/contrib/lite/java/src/main/native/tensor_jni.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+// This is to be provided at link-time by a library.
+extern std::unique_ptr<OpResolver> CreateOpResolver();
+} // namespace tflite
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+/*
+ * Class: org_tensorflow_lite_NativeInterpreterWrapper
+ * Method:
+ * Signature: (J)[Ljava/lang/Object;
+ */
+JNIEXPORT jobjectArray JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputNames(JNIEnv* env,
+ jclass clazz,
+ jlong handle);
+
+/*
+ * Class: org_tensorflow_lite_NativeInterpreterWrapper
+ * Method:
+ * Signature: (J)[Ljava/lang/Object;
+ */
+JNIEXPORT jobjectArray JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputNames(JNIEnv* env,
+ jclass clazz,
+ jlong handle);
+
+/*
+ * Class: org_tensorflow_lite_NativeInterpreterWrapper
+ * Method:
+ * Signature: (JZ)
+ */
+JNIEXPORT void JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_useNNAPI(JNIEnv* env,
+ jclass clazz,
+ jlong handle,
+ jboolean state);
+
+/*
+ * Class: org_tensorflow_lite_NativeInterpreterWrapper
+ * Method:
+ * Signature: (I)J
+ */
+JNIEXPORT jlong JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_createErrorReporter(
+ JNIEnv* env, jclass clazz, jint size);
+
+/*
+ * Class: org_tensorflow_lite_NativeInterpreterWrapper
+ * Method:
+ * Signature: (Ljava/lang/String;J)J
+ */
+JNIEXPORT jlong JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_createModel(
+ JNIEnv* env, jclass clazz, jstring model_file, jlong error_handle);
+
+/*
+ * Class: org_tensorflow_lite_NativeInterpreterWrapper
+ * Method:
+ * Signature: (Ljava/lang/Object;J)J
+ */
+JNIEXPORT jlong JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_createModelWithBuffer(
+ JNIEnv* env, jclass clazz, jobject model_buffer, jlong error_handle);
+
+/*
+ * Class: org_tensorflow_lite_NativeInterpreterWrapper
+ * Method:
+ * Signature: (J)J
+ */
+JNIEXPORT jlong JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_createInterpreter(
+ JNIEnv* env, jclass clazz, jlong model_handle);
+
+/*
+ * Class: org_tensorflow_lite_NativeInterpreterWrapper
+ * Method:
+ * Signature: (JJ[Ljava/lang/Object;[I[I[Ljava/lang/Object;)[J
+ */
+JNIEXPORT jlongArray JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_run(
+ JNIEnv* env, jclass clazz, jlong interpreter_handle, jlong error_handle,
+ jobjectArray sizes, jintArray data_types, jintArray nums_of_bytes,
+ jobjectArray values);
+
+/*
+ * Class: org_tensorflow_lite_NativeInterpreterWrapper
+ * Method:
+ * Signature: (JII)[I
+ *
+ * Gets the input dimensions if num_bytes matches the number of bytes required by
+ * the input; otherwise it throws an IllegalArgumentException and returns null.
+ */
+JNIEXPORT jintArray JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputDims(
+ JNIEnv* env, jclass clazz, jlong handle, jint input_idx, jint num_bytes);
+
+/*
+ * Class: org_tensorflow_lite_NativeInterpreterWrapper
+ * Method:
+ * Signature: (JJI[I)
+ *
+ * Resizes the dimensions of an input.
+ */
+JNIEXPORT void JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_resizeInput(
+ JNIEnv* env, jclass clazz, jlong interpreter_handle, jlong error_handle,
+ jint input_idx, jintArray dims);
+
+/*
+ * Class: org_tensorflow_lite_NativeInterpreterWrapper
+ * Method:
+ * Signature: (JJJ)
+ */
+JNIEXPORT void JNICALL Java_org_tensorflow_lite_NativeInterpreterWrapper_delete(
+ JNIEnv* env, jclass clazz, jlong error_handle, jlong model_handle,
+ jlong interpreter_handle);
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
+#endif // TENSORFLOW_CONTRIB_LITE_JAVA_NATIVEINTERPRETERWRAPPER_JNI_H_
diff --git a/tensorflow/contrib/lite/java/src/main/native/tensor_jni.cc b/tensorflow/contrib/lite/java/src/main/native/tensor_jni.cc
new file mode 100644
index 0000000000..65126e78a3
--- /dev/null
+++ b/tensorflow/contrib/lite/java/src/main/native/tensor_jni.cc
@@ -0,0 +1,242 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/java/src/main/native/tensor_jni.h"
+#include <cstring>
+#include <memory>
+#include "tensorflow/contrib/lite/java/src/main/native/exception_jni.h"
+
+namespace {
+
+TfLiteTensor* convertLongToTensor(JNIEnv* env, jlong handle) {
+ if (handle == 0) {
+ throwException(env, kIllegalArgumentException,
+ "Invalid handle to TfLiteTensor.");
+ return nullptr;
+ }
+ return reinterpret_cast<TfLiteTensor*>(handle);
+}
+
+size_t writeOneDimensionalArray(JNIEnv* env, jobject object, TfLiteType type,
+ void* dst, size_t dst_size) {
+ jarray array = static_cast<jarray>(object);
+ const int num_elements = env->GetArrayLength(array);
+ size_t to_copy = num_elements * elementByteSize(type);
+ if (to_copy > dst_size) {
+ throwException(env, kIllegalStateException,
+ "cannot write Java array of %d bytes to Tensor of %d bytes",
+ to_copy, dst_size);
+ return 0;
+ }
+ switch (type) {
+ case kTfLiteFloat32: {
+ jfloatArray a = static_cast<jfloatArray>(array);
+ jfloat* values = env->GetFloatArrayElements(a, nullptr);
+ memcpy(dst, values, to_copy);
+ env->ReleaseFloatArrayElements(a, values, JNI_ABORT);
+ return to_copy;
+ }
+ case kTfLiteInt32: {
+ jintArray a = static_cast<jintArray>(array);
+ jint* values = env->GetIntArrayElements(a, nullptr);
+ memcpy(dst, values, to_copy);
+ env->ReleaseIntArrayElements(a, values, JNI_ABORT);
+ return to_copy;
+ }
+ case kTfLiteInt64: {
+ jlongArray a = static_cast<jlongArray>(array);
+ jlong* values = env->GetLongArrayElements(a, nullptr);
+ memcpy(dst, values, to_copy);
+ env->ReleaseLongArrayElements(a, values, JNI_ABORT);
+ return to_copy;
+ }
+ case kTfLiteUInt8: {
+ jbyteArray a = static_cast<jbyteArray>(array);
+ jbyte* values = env->GetByteArrayElements(a, nullptr);
+ memcpy(dst, values, to_copy);
+ env->ReleaseByteArrayElements(a, values, JNI_ABORT);
+ return to_copy;
+ }
+ default: {
+ throwException(env, kUnsupportedOperationException,
+ "TensorFlowLite currently supports float (32 bits), "
+ "int (32 bits), byte (8 bits), and long (64 bits), "
+ "support for other types (DataType %d in this case) will "
+ "be added in the future",
+                     type);
+ return 0;
+ }
+ }
+}
+
+size_t readOneDimensionalArray(JNIEnv* env, TfLiteType data_type,
+ const void* src, size_t src_size, jarray dst) {
+ const int len = env->GetArrayLength(dst);
+ const size_t size = len * elementByteSize(data_type);
+ if (size > src_size) {
+ throwException(
+ env, kIllegalStateException,
+ "cannot fill a Java array of %d bytes with a Tensor of %d bytes", size,
+ src_size);
+ return 0;
+ }
+ switch (data_type) {
+ case kTfLiteFloat32: {
+ jfloatArray float_array = static_cast<jfloatArray>(dst);
+ env->SetFloatArrayRegion(float_array, 0, len,
+ static_cast<const jfloat*>(src));
+ return size;
+ }
+ case kTfLiteInt32: {
+ jintArray int_array = static_cast<jintArray>(dst);
+ env->SetIntArrayRegion(int_array, 0, len, static_cast<const jint*>(src));
+ return size;
+ }
+ case kTfLiteInt64: {
+ jlongArray long_array = static_cast<jlongArray>(dst);
+ env->SetLongArrayRegion(long_array, 0, len,
+ static_cast<const jlong*>(src));
+ return size;
+ }
+ case kTfLiteUInt8: {
+ jbyteArray byte_array = static_cast<jbyteArray>(dst);
+ env->SetByteArrayRegion(byte_array, 0, len,
+ static_cast<const jbyte*>(src));
+ return size;
+ }
+ default: {
+ throwException(env, kIllegalStateException, "invalid DataType(%d)",
+ data_type);
+ }
+ }
+ return 0;
+}
+
+size_t readMultiDimensionalArray(JNIEnv* env, TfLiteType data_type, char* src,
+ size_t src_size, int dims_left, jarray dst) {
+ if (dims_left == 1) {
+ return readOneDimensionalArray(env, data_type, src, src_size, dst);
+ } else {
+ jobjectArray ndarray = static_cast<jobjectArray>(dst);
+ int len = env->GetArrayLength(ndarray);
+ size_t size = 0;
+ for (int i = 0; i < len; ++i) {
+ jarray row = static_cast<jarray>(env->GetObjectArrayElement(ndarray, i));
+ size += readMultiDimensionalArray(env, data_type, src + size,
+ src_size - size, dims_left - 1, row);
+ env->DeleteLocalRef(row);
+ if (env->ExceptionCheck()) return size;
+ }
+ return size;
+ }
+}
+
+} // namespace
+
+size_t elementByteSize(TfLiteType data_type) {
+  // The code in this file assumes that the TensorFlow Lite TfLiteType values and the
+  // corresponding Java primitive types have the same byte sizes. Validate that:
+ switch (data_type) {
+ case kTfLiteFloat32:
+      static_assert(sizeof(jfloat) == 4,
+                    "Java float not compatible with kTfLiteFloat32");
+      return 4;
+    case kTfLiteInt32:
+      static_assert(sizeof(jint) == 4,
+                    "Java int not compatible with kTfLiteInt32");
+ return 4;
+ case kTfLiteUInt8:
+ static_assert(sizeof(jbyte) == 1,
+ "Java byte not compatible with kTfLiteUInt8");
+ return 1;
+ case kTfLiteInt64:
+ static_assert(sizeof(jlong) == 8,
+ "Java long not compatible with kTfLiteInt64");
+ return 8;
+ default:
+ return 0;
+ }
+}
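+
+// For example, elementByteSize(kTfLiteFloat32) returns 4 and elementByteSize(kTfLiteInt64)
+// returns 8; unsupported types return 0.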
+
+size_t writeByteBuffer(JNIEnv* env, jobject object, char** dst, int dst_size) {
+ char* buf = static_cast<char*>(env->GetDirectBufferAddress(object));
+ if (!buf) {
+ throwException(env, kIllegalArgumentException,
+ "Input ByteBuffer is not a direct buffer");
+ return 0;
+ }
+ *dst = buf;
+ return dst_size;
+}
+
+size_t writeMultiDimensionalArray(JNIEnv* env, jobject src, TfLiteType type,
+ int dims_left, char** dst, int dst_size) {
+ if (dims_left <= 1) {
+ return writeOneDimensionalArray(env, src, type, *dst, dst_size);
+ } else {
+ jobjectArray ndarray = static_cast<jobjectArray>(src);
+ int len = env->GetArrayLength(ndarray);
+ size_t sz = 0;
+ for (int i = 0; i < len; ++i) {
+ jobject row = env->GetObjectArrayElement(ndarray, i);
+ char* next_dst = *dst + sz;
+ sz += writeMultiDimensionalArray(env, row, type, dims_left - 1, &next_dst,
+ dst_size - sz);
+ env->DeleteLocalRef(row);
+ if (env->ExceptionCheck()) return sz;
+ }
+ return sz;
+ }
+}
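+
+// Sketch: writing a Java float[2][3] into a kTfLiteFloat32 tensor recurses once per row,
+// copying each 12-byte row at an increasing offset from *dst and returning 24 bytes in total.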
+
+JNIEXPORT void JNICALL
+Java_org_tensorflow_lite_Tensor_readMultiDimensionalArray(JNIEnv* env,
+ jclass clazz,
+ jlong handle,
+ jobject value) {
+ TfLiteTensor* tensor = convertLongToTensor(env, handle);
+ if (tensor == nullptr) return;
+ int num_dims = tensor->dims->size;
+ if (num_dims == 0) {
+ throwException(env, kIllegalArgumentException,
+ "copyTo() is not meant for scalar Tensors.");
+ return;
+ }
+ readMultiDimensionalArray(env, tensor->type, tensor->data.raw, tensor->bytes,
+ num_dims, static_cast<jarray>(value));
+}
+
+JNIEXPORT jint JNICALL Java_org_tensorflow_lite_Tensor_dtype(JNIEnv* env,
+ jclass clazz,
+ jlong handle) {
+ TfLiteTensor* tensor = convertLongToTensor(env, handle);
+ if (tensor == nullptr) return 0;
+ return static_cast<jint>(tensor->type);
+}
+
+JNIEXPORT jintArray JNICALL
+Java_org_tensorflow_lite_Tensor_shape(JNIEnv* env, jclass clazz, jlong handle) {
+ TfLiteTensor* tensor = convertLongToTensor(env, handle);
+ if (tensor == nullptr) return nullptr;
+ int num_dims = tensor->dims->size;
+ jintArray result = env->NewIntArray(num_dims);
+ jint* dims = env->GetIntArrayElements(result, nullptr);
+ for (int i = 0; i < num_dims; ++i) {
+ dims[i] = static_cast<jint>(tensor->dims->data[i]);
+ }
+ env->ReleaseIntArrayElements(result, dims, 0);
+ return result;
+}
diff --git a/tensorflow/contrib/lite/java/src/main/native/tensor_jni.h b/tensorflow/contrib/lite/java/src/main/native/tensor_jni.h
new file mode 100644
index 0000000000..3a4910dcc3
--- /dev/null
+++ b/tensorflow/contrib/lite/java/src/main/native/tensor_jni.h
@@ -0,0 +1,74 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_LITE_JAVA_TENSOR_JNI_H_
+#define TENSORFLOW_CONTRIB_LITE_JAVA_TENSOR_JNI_H_
+
+#include <jni.h>
+#include "tensorflow/contrib/lite/context.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+/*
+ * Class:     org_tensorflow_lite_Tensor
+ * Method:
+ * Signature: (J)I
+ */
+JNIEXPORT jint JNICALL Java_org_tensorflow_lite_Tensor_dtype(JNIEnv* env,
+ jclass clazz,
+ jlong handle);
+
+/*
+ * Class:     org_tensorflow_lite_Tensor
+ * Method:
+ * Signature: (J)[I
+ */
+JNIEXPORT jintArray JNICALL Java_org_tensorflow_lite_Tensor_shape(JNIEnv* env,
+ jclass clazz,
+ jlong handle);
+
+/*
+ * Class:     org_tensorflow_lite_Tensor
+ * Method:
+ * Signature: (JLjava/lang/Object;)
+ */
+JNIEXPORT void JNICALL
+Java_org_tensorflow_lite_Tensor_readMultiDimensionalArray(JNIEnv* env,
+ jclass clazz,
+ jlong handle,
+ jobject value);
+
+/*
+ * Finds the size of each data type.
+ */
+size_t elementByteSize(TfLiteType data_type);
+
+/*
+ * Points dst at the data of a direct ByteBuffer; returns dst_size, or 0 on error.
+ */
+size_t writeByteBuffer(JNIEnv* env, jobject object, char** dst, int dst_size);
+
+/*
+ * Writes a multi-dimensional Java array into dst.
+ */
+size_t writeMultiDimensionalArray(JNIEnv* env, jobject src, TfLiteType type,
+ int dims_left, char** dst, int dst_size);
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
+#endif // TENSORFLOW_CONTRIB_LITE_JAVA_TENSOR_JNI_H_
diff --git a/tensorflow/contrib/lite/java/src/main/native/tensorflow_lite_jni.cc b/tensorflow/contrib/lite/java/src/main/native/tensorflow_lite_jni.cc
new file mode 100644
index 0000000000..2e7f2f5692
--- /dev/null
+++ b/tensorflow/contrib/lite/java/src/main/native/tensorflow_lite_jni.cc
@@ -0,0 +1,26 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <stdio.h>
+
+#include "tensorflow/contrib/lite/java/src/main/native/tensorflow_lite_jni.h"
+#include "tensorflow/contrib/lite/version.h"
+
+JNIEXPORT jstring JNICALL
+Java_org_tensorflow_lite_TensorFlowLite_version(JNIEnv* env, jclass /*clazz*/) {
+ char buf[64];
+ snprintf(buf, sizeof(buf), "%d", TFLITE_SCHEMA_VERSION);
+ return env->NewStringUTF(buf);
+}
diff --git a/tensorflow/contrib/lite/java/src/main/native/tensorflow_lite_jni.h b/tensorflow/contrib/lite/java/src/main/native/tensorflow_lite_jni.h
new file mode 100644
index 0000000000..65f8341149
--- /dev/null
+++ b/tensorflow/contrib/lite/java/src/main/native/tensorflow_lite_jni.h
@@ -0,0 +1,36 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_LITE_JAVA_TENSORFLOW_LITE_JNI_H_
+#define TENSORFLOW_CONTRIB_LITE_JAVA_TENSORFLOW_LITE_JNI_H_
+
+#include <jni.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+/*
+ * Class: org_tensorflow_lite_TensorFlowLite
+ * Method: version
+ * Signature: ()Ljava/lang/String;
+ */
+JNIEXPORT jstring JNICALL
+Java_org_tensorflow_lite_TensorFlowLite_version(JNIEnv*, jclass);
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
+#endif // TENSORFLOW_CONTRIB_LITE_JAVA_TENSORFLOW_LITE_JNI_H_
diff --git a/tensorflow/contrib/lite/java/src/main/native/version_script.lds b/tensorflow/contrib/lite/java/src/main/native/version_script.lds
new file mode 100644
index 0000000000..38c93dda73
--- /dev/null
+++ b/tensorflow/contrib/lite/java/src/main/native/version_script.lds
@@ -0,0 +1,11 @@
+VERS_1.0 {
+ # Export JNI symbols.
+ global:
+ Java_*;
+ JNI_OnLoad;
+ JNI_OnUnload;
+
+ # Hide everything else.
+ local:
+ *;
+};
diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/DataTypeTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/DataTypeTest.java
new file mode 100644
index 0000000000..cebc944200
--- /dev/null
+++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/DataTypeTest.java
@@ -0,0 +1,34 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+package org.tensorflow.lite;
+
+import static com.google.common.truth.Truth.assertThat;
+
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/** Unit tests for {@link org.tensorflow.lite.DataType}. */
+@RunWith(JUnit4.class)
+public final class DataTypeTest {
+
+ @Test
+ public void testElemByteSize() {
+ assertThat(DataType.FLOAT32.elemByteSize()).isEqualTo(4);
+ assertThat(DataType.INT32.elemByteSize()).isEqualTo(4);
+ assertThat(DataType.UINT8.elemByteSize()).isEqualTo(1);
+ assertThat(DataType.INT64.elemByteSize()).isEqualTo(8);
+ }
+}
diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java
new file mode 100644
index 0000000000..a60c63d4b8
--- /dev/null
+++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java
@@ -0,0 +1,221 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.lite;
+
+import static com.google.common.truth.Truth.assertThat;
+import static org.junit.Assert.fail;
+
+import java.io.File;
+import java.nio.MappedByteBuffer;
+import java.nio.channels.FileChannel;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.StandardOpenOption;
+import java.util.EnumSet;
+import java.util.HashMap;
+import java.util.Map;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/** Unit tests for {@link org.tensorflow.lite.Interpreter}. */
+@RunWith(JUnit4.class)
+public final class InterpreterTest {
+
+ private static final File MODEL_FILE =
+ new File("third_party/tensorflow/contrib/lite/java/src/testdata/add.bin");
+
+ private static final File MOBILENET_MODEL_FILE =
+ new File("third_party/tensorflow/contrib/lite/java/src/testdata/mobilenet.tflite.bin");
+
+ @Test
+ public void testInterpreter() throws Exception {
+ Interpreter interpreter = new Interpreter(MODEL_FILE);
+ assertThat(interpreter).isNotNull();
+ interpreter.close();
+ }
+
+ @Test
+ public void testRunWithMappedByteBufferModel() throws Exception {
+ Path path = MODEL_FILE.toPath();
+ FileChannel fileChannel =
+ (FileChannel) Files.newByteChannel(path, EnumSet.of(StandardOpenOption.READ));
+ MappedByteBuffer mappedByteBuffer =
+ fileChannel.map(FileChannel.MapMode.READ_ONLY, 0, fileChannel.size());
+ Interpreter interpreter = new Interpreter(mappedByteBuffer);
+ float[] oneD = {1.23f, 6.54f, 7.81f};
+ float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD};
+ float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
+ float[][][][] fourD = {threeD, threeD};
+ float[][][][] parsedOutputs = new float[2][8][8][3];
+ interpreter.run(fourD, parsedOutputs);
+ float[] outputOneD = parsedOutputs[0][0][0];
+ float[] expected = {3.69f, 19.62f, 23.43f};
+ assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder();
+ interpreter.close();
+ fileChannel.close();
+ }
+
+ @Test
+ public void testRun() {
+ Interpreter interpreter = new Interpreter(MODEL_FILE);
+ Float[] oneD = {1.23f, 6.54f, 7.81f};
+ Float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD};
+ Float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
+ Float[][][][] fourD = {threeD, threeD};
+ Float[][][][] parsedOutputs = new Float[2][8][8][3];
+ try {
+ interpreter.run(fourD, parsedOutputs);
+ fail();
+ } catch (IllegalArgumentException e) {
+ assertThat(e).hasMessageThat().contains("cannot resolve DataType of [[[[Ljava.lang.Float;");
+ }
+ interpreter.close();
+ }
+
+ @Test
+ public void testRunWithBoxedInputs() {
+ Interpreter interpreter = new Interpreter(MODEL_FILE);
+ float[] oneD = {1.23f, 6.54f, 7.81f};
+ float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD};
+ float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
+ float[][][][] fourD = {threeD, threeD};
+ float[][][][] parsedOutputs = new float[2][8][8][3];
+ interpreter.run(fourD, parsedOutputs);
+ float[] outputOneD = parsedOutputs[0][0][0];
+ float[] expected = {3.69f, 19.62f, 23.43f};
+ assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder();
+ interpreter.close();
+ }
+
+ @Test
+ public void testRunForMultipleInputsOutputs() {
+ Interpreter interpreter = new Interpreter(MODEL_FILE);
+ float[] oneD = {1.23f, 6.54f, 7.81f};
+ float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD};
+ float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
+ float[][][][] fourD = {threeD, threeD};
+ Object[] inputs = {fourD};
+ float[][][][] parsedOutputs = new float[2][8][8][3];
+ Map<Integer, Object> outputs = new HashMap<>();
+ outputs.put(0, parsedOutputs);
+ interpreter.runForMultipleInputsOutputs(inputs, outputs);
+ float[] outputOneD = parsedOutputs[0][0][0];
+ float[] expected = {3.69f, 19.62f, 23.43f};
+ assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder();
+ interpreter.close();
+ }
+
+ @Test
+ public void testMobilenetRun() {
+ // Create a gray image.
+ float[][][][] img = new float[1][224][224][3];
+ for (int i = 0; i < 224; ++i) {
+ for (int j = 0; j < 224; ++j) {
+ img[0][i][j][0] = 0.5f;
+ img[0][i][j][1] = 0.5f;
+ img[0][i][j][2] = 0.5f;
+ }
+ }
+
+ // Allocate memory to receive the output values.
+ float[][] labels = new float[1][1001];
+
+ Interpreter interpreter = new Interpreter(MOBILENET_MODEL_FILE);
+ interpreter.run(img, labels);
+ interpreter.close();
+
+ assertThat(labels[0])
+ .usingExactEquality()
+ .containsNoneOf(new float[] {Float.NaN, Float.NEGATIVE_INFINITY, Float.POSITIVE_INFINITY});
+ }
+
+ @Test
+ public void testRunWithWrongInputType() {
+ Interpreter interpreter = new Interpreter(MODEL_FILE);
+ int[] oneD = {4, 3, 9};
+ int[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD};
+ int[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
+ int[][][][] fourD = {threeD, threeD};
+ float[][][][] parsedOutputs = new float[2][8][8][3];
+ try {
+ interpreter.run(fourD, parsedOutputs);
+ fail();
+ } catch (IllegalArgumentException e) {
+ assertThat(e)
+ .hasMessageThat()
+ .contains(
+ "DataType (2) of input data does not match with the DataType (1) of model inputs.");
+ }
+ interpreter.close();
+ }
+
+ @Test
+ public void testRunWithWrongOutputType() {
+ Interpreter interpreter = new Interpreter(MODEL_FILE);
+ float[] oneD = {1.23f, 6.54f, 7.81f};
+ float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD};
+ float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
+ float[][][][] fourD = {threeD, threeD};
+ int[][][][] parsedOutputs = new int[2][8][8][3];
+ try {
+ interpreter.run(fourD, parsedOutputs);
+ fail();
+ } catch (IllegalArgumentException e) {
+ assertThat(e)
+ .hasMessageThat()
+ .contains(
+ "Cannot convert an TensorFlowLite tensor with type "
+ + "FLOAT32 to a Java object of type [[[[I (which is compatible with the"
+ + " TensorFlowLite type INT32)");
+ }
+ interpreter.close();
+ }
+
+ @Test
+ public void testGetInputIndex() {
+ Interpreter interpreter = new Interpreter(MOBILENET_MODEL_FILE);
+ try {
+ interpreter.getInputIndex("WrongInputName");
+ fail();
+ } catch (IllegalArgumentException e) {
+ assertThat(e)
+ .hasMessageThat()
+ .contains(
+ "WrongInputName is not a valid name for any input. The indexes of the inputs"
+ + " are {input=0}");
+ }
+ int index = interpreter.getInputIndex("input");
+ assertThat(index).isEqualTo(0);
+ }
+
+ @Test
+ public void testGetOutputIndex() {
+ Interpreter interpreter = new Interpreter(MOBILENET_MODEL_FILE);
+ try {
+ interpreter.getOutputIndex("WrongOutputName");
+ fail();
+ } catch (IllegalArgumentException e) {
+ assertThat(e)
+ .hasMessageThat()
+ .contains(
+ "WrongOutputName is not a valid name for any output. The indexes of the outputs"
+ + " are {MobilenetV1/Predictions/Softmax=0}");
+ }
+ int index = interpreter.getOutputIndex("MobilenetV1/Predictions/Softmax");
+ assertThat(index).isEqualTo(0);
+ }
+}
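Taken together, the cases above show the intended Interpreter lifecycle: construct it from a model File or a MappedByteBuffer, call run() (or runForMultipleInputsOutputs()) with arrays whose shapes and primitive types match the model, and close() it when done. A condensed sketch of that flow, reusing the same add.bin model path assumed by the tests:

import java.io.File;
import java.util.HashMap;
import java.util.Map;
import org.tensorflow.lite.Interpreter;

final class InterpreterUsageSketch {
  public static void main(String[] args) {
    // Same test model as above: a float model with input/output shape [?, 8, 8, 3].
    File model = new File("third_party/tensorflow/contrib/lite/java/src/testdata/add.bin");
    Interpreter interpreter = new Interpreter(model);

    float[][][][] input = new float[2][8][8][3];
    float[][][][] output = new float[2][8][8][3];
    interpreter.run(input, output);  // single input, single output

    // Multiple inputs/outputs: positional inputs plus an index -> buffer map.
    Object[] inputs = {input};
    Map<Integer, Object> outputs = new HashMap<>();
    outputs.put(0, output);
    interpreter.runForMultipleInputsOutputs(inputs, outputs);

    interpreter.close();  // releases the native interpreter
  }
}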
diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java
new file mode 100644
index 0000000000..9e4724b8e9
--- /dev/null
+++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java
@@ -0,0 +1,406 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.lite;
+
+import static com.google.common.truth.Truth.assertThat;
+import static org.junit.Assert.fail;
+
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/** Unit tests for {@link org.tensorflow.lite.NativeInterpreterWrapper}. */
+@RunWith(JUnit4.class)
+public final class NativeInterpreterWrapperTest {
+
+ private static final String FLOAT_MODEL_PATH =
+ "third_party/tensorflow/contrib/lite/java/src/testdata/add.bin";
+
+ private static final String INT_MODEL_PATH =
+ "third_party/tensorflow/contrib/lite/java/src/testdata/int32.bin";
+
+ private static final String LONG_MODEL_PATH =
+ "third_party/tensorflow/contrib/lite/java/src/testdata/int64.bin";
+
+ private static final String BYTE_MODEL_PATH =
+ "third_party/tensorflow/contrib/lite/java/src/testdata/uint8.bin";
+
+ private static final String INVALID_MODEL_PATH =
+ "third_party/tensorflow/contrib/lite/java/src/testdata/invalid.model.tflite";
+
+ @Test
+ public void testConstructor() {
+ NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH);
+ assertThat(wrapper).isNotNull();
+ wrapper.close();
+ }
+
+ @Test
+ public void testConstructorWithInvalidModel() {
+ try {
+ NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(INVALID_MODEL_PATH);
+ fail();
+ } catch (IllegalArgumentException e) {
+ assertThat(e)
+ .hasMessageThat()
+ .contains("Model provided has model identifier ' is ', should be 'TFL3'");
+ }
+ }
+
+ @Test
+ public void testRunWithFloat() {
+ NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH);
+ float[] oneD = {1.23f, -6.54f, 7.81f};
+ float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD};
+ float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
+ float[][][][] fourD = {threeD, threeD};
+ Object[] inputs = {fourD};
+ Tensor[] outputs = wrapper.run(inputs);
+ assertThat(outputs.length).isEqualTo(1);
+ float[][][][] parsedOutputs = new float[2][8][8][3];
+ outputs[0].copyTo(parsedOutputs);
+ float[] outputOneD = parsedOutputs[0][0][0];
+ float[] expected = {3.69f, -19.62f, 23.43f};
+ assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder();
+ wrapper.close();
+ }
+
+ @Test
+ public void testRunWithInt() {
+ NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(INT_MODEL_PATH);
+ int[] oneD = {3, 7, -4};
+ int[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD};
+ int[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
+ int[][][][] fourD = {threeD, threeD};
+ Object[] inputs = {fourD};
+ Tensor[] outputs = wrapper.run(inputs);
+ assertThat(outputs.length).isEqualTo(1);
+ int[][][][] parsedOutputs = new int[2][4][4][12];
+ outputs[0].copyTo(parsedOutputs);
+ int[] outputOneD = parsedOutputs[0][0][0];
+ int[] expected = {3, 7, -4, 3, 7, -4, 3, 7, -4, 3, 7, -4};
+ assertThat(outputOneD).isEqualTo(expected);
+ wrapper.close();
+ }
+
+ @Test
+ public void testRunWithLong() {
+ NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(LONG_MODEL_PATH);
+ long[] oneD = {-892834092L, 923423L, 2123918239018L};
+ long[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD};
+ long[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
+ long[][][][] fourD = {threeD, threeD};
+ Object[] inputs = {fourD};
+ Tensor[] outputs = wrapper.run(inputs);
+ assertThat(outputs.length).isEqualTo(1);
+ long[][][][] parsedOutputs = new long[2][4][4][12];
+ outputs[0].copyTo(parsedOutputs);
+ long[] outputOneD = parsedOutputs[0][0][0];
+ long[] expected = {-892834092L, 923423L, 2123918239018L, -892834092L, 923423L, 2123918239018L,
+ -892834092L, 923423L, 2123918239018L, -892834092L, 923423L, 2123918239018L};
+ assertThat(outputOneD).isEqualTo(expected);
+ wrapper.close();
+ }
+
+ @Test
+ public void testRunWithByte() {
+ NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(BYTE_MODEL_PATH);
+ byte[] oneD = {(byte) 0xe0, 0x4f, (byte) 0xd0};
+ byte[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD};
+ byte[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
+ byte[][][][] fourD = {threeD, threeD};
+ Object[] inputs = {fourD};
+ int[] inputDims = {2, 8, 8, 3};
+ wrapper.resizeInput(0, inputDims);
+ Tensor[] outputs = wrapper.run(inputs);
+ assertThat(outputs.length).isEqualTo(1);
+ byte[][][][] parsedOutputs = new byte[2][4][4][12];
+ outputs[0].copyTo(parsedOutputs);
+ byte[] outputOneD = parsedOutputs[0][0][0];
+ byte[] expected = {(byte) 0xe0, 0x4f, (byte) 0xd0, (byte) 0xe0, 0x4f, (byte) 0xd0,
+ (byte) 0xe0, 0x4f, (byte) 0xd0, (byte) 0xe0, 0x4f, (byte) 0xd0};
+ assertThat(outputOneD).isEqualTo(expected);
+ wrapper.close();
+ }
+
+ @Test
+ public void testRunWithByteBufferHavingBytes() {
+ NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(BYTE_MODEL_PATH);
+ ByteBuffer bbuf = ByteBuffer.allocateDirect(2 * 8 * 8 * 3);
+ bbuf.order(ByteOrder.nativeOrder());
+ bbuf.rewind();
+ for (int i = 0; i < 2; ++i) {
+ for (int j = 0; j < 8; ++j) {
+ for (int k = 0; k < 8; ++k) {
+ bbuf.put((byte) 0xe0);
+ bbuf.put((byte) 0x4f);
+ bbuf.put((byte) 0xd0);
+ }
+ }
+ }
+ Object[] inputs = {bbuf};
+ int[] inputDims = {2, 8, 8, 3};
+ wrapper.resizeInput(0, inputDims);
+ Tensor[] outputs = wrapper.run(inputs);
+ assertThat(outputs.length).isEqualTo(1);
+ byte[][][][] parsedOutputs = new byte[2][4][4][12];
+ outputs[0].copyTo(parsedOutputs);
+ byte[] outputOneD = parsedOutputs[0][0][0];
+ byte[] expected = {
+ (byte) 0xe0, 0x4f, (byte) 0xd0, (byte) 0xe0, 0x4f, (byte) 0xd0,
+ (byte) 0xe0, 0x4f, (byte) 0xd0, (byte) 0xe0, 0x4f, (byte) 0xd0
+ };
+ assertThat(outputOneD).isEqualTo(expected);
+ wrapper.close();
+ }
+
+ @Test
+ public void testRunWithByteBufferHavingFloats() {
+ NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH);
+ ByteBuffer bbuf = ByteBuffer.allocateDirect(4 * 8 * 8 * 3 * 4);
+ bbuf.order(ByteOrder.nativeOrder());
+ bbuf.rewind();
+ for (int i = 0; i < 4; ++i) {
+ for (int j = 0; j < 8; ++j) {
+ for (int k = 0; k < 8; ++k) {
+ bbuf.putFloat(1.23f);
+ bbuf.putFloat(-6.54f);
+ bbuf.putFloat(7.81f);
+ }
+ }
+ }
+ Object[] inputs = {bbuf};
+ try {
+ wrapper.run(inputs);
+ fail();
+ } catch (IllegalArgumentException e) {
+ assertThat(e)
+ .hasMessageThat()
+ .contains(
+ "Failed to get input dimensions. 0-th input should have 768 bytes, but found 3072 bytes");
+ }
+ int[] inputDims = {4, 8, 8, 3};
+ wrapper.resizeInput(0, inputDims);
+ Tensor[] outputs = wrapper.run(inputs);
+ assertThat(outputs.length).isEqualTo(1);
+ float[][][][] parsedOutputs = new float[4][8][8][3];
+ outputs[0].copyTo(parsedOutputs);
+ float[] outputOneD = parsedOutputs[0][0][0];
+ float[] expected = {3.69f, -19.62f, 23.43f};
+ assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder();
+ wrapper.close();
+ }
+
+ @Test
+ public void testRunWithByteBufferHavingWrongSize() {
+ NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(BYTE_MODEL_PATH);
+ ByteBuffer bbuf = ByteBuffer.allocateDirect(2 * 7 * 8 * 3);
+ bbuf.order(ByteOrder.nativeOrder());
+ Object[] inputs = {bbuf};
+ try {
+ wrapper.run(inputs);
+ fail();
+ } catch (IllegalArgumentException e) {
+ assertThat(e)
+ .hasMessageThat()
+ .contains(
+ "Failed to get input dimensions. 0-th input should have 192 bytes, but found 336 bytes.");
+ }
+ wrapper.close();
+ }
+
+ @Test
+ public void testRunWithWrongInputType() {
+ NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH);
+ int[] oneD = {4, 3, 9};
+ int[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD};
+ int[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
+ int[][][][] fourD = {threeD, threeD};
+ Object[] inputs = {fourD};
+ try {
+ wrapper.run(inputs);
+ fail();
+ } catch (IllegalArgumentException e) {
+ assertThat(e)
+ .hasMessageThat()
+ .contains(
+ "DataType (2) of input data does not match with the DataType (1) of model inputs.");
+ }
+ wrapper.close();
+ }
+
+ @Test
+ public void testRunAfterClose() {
+ NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH);
+ wrapper.close();
+ float[] oneD = {1.23f, 6.54f, 7.81f};
+ float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD};
+ float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
+ float[][][][] fourD = {threeD, threeD};
+ Object[] inputs = {fourD};
+ try {
+ wrapper.run(inputs);
+ fail();
+ } catch (IllegalArgumentException e) {
+ assertThat(e).hasMessageThat().contains("Invalid handle to Interpreter.");
+ }
+ }
+
+ @Test
+ public void testRunWithEmptyInputs() {
+ NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH);
+ try {
+ Object[] inputs = {};
+ wrapper.run(inputs);
+ fail();
+ } catch (IllegalArgumentException e) {
+ assertThat(e)
+ .hasMessageThat()
+ .contains("Invalid inputs. Inputs should not be null or empty.");
+ }
+ wrapper.close();
+ }
+
+ @Test
+ public void testRunWithWrongInputSize() {
+ NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH);
+ float[] oneD = {1.23f, 6.54f, 7.81f};
+ float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD};
+ float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
+ float[][][][] fourD = {threeD, threeD};
+ Object[] inputs = {fourD, fourD};
+ try {
+ wrapper.run(inputs);
+ fail();
+ } catch (IllegalArgumentException e) {
+ assertThat(e).hasMessageThat().contains("Expected num of inputs is 1 but got 2");
+ }
+ wrapper.close();
+ }
+
+ @Test
+ public void testRunWithWrongInputNumOfDims() {
+ NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH);
+ float[] oneD = {1.23f, 6.54f, 7.81f};
+ float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD};
+ float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
+ Object[] inputs = {threeD};
+ try {
+ wrapper.run(inputs);
+ fail();
+ } catch (IllegalArgumentException e) {
+ assertThat(e)
+ .hasMessageThat()
+ .contains("0-th input should have 4 dimensions, but found 3 dimensions");
+ }
+ wrapper.close();
+ }
+
+ @Test
+ public void testRunWithWrongInputDims() {
+ NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH);
+ float[] oneD = {1.23f, 6.54f, 7.81f};
+ float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD};
+ float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
+ float[][][][] fourD = {threeD, threeD};
+ Object[] inputs = {fourD};
+ try {
+ wrapper.run(inputs);
+ fail();
+ } catch (IllegalArgumentException e) {
+ assertThat(e)
+ .hasMessageThat()
+ .contains("0-th input dimension should be [?,8,8,3], but found [?,8,7,3]");
+ }
+ wrapper.close();
+ }
+
+ @Test
+ public void testNumElements() {
+ int[] shape = {2, 3, 4};
+ int num = NativeInterpreterWrapper.numElements(shape);
+ assertThat(num).isEqualTo(24);
+ shape = null;
+ num = NativeInterpreterWrapper.numElements(shape);
+ assertThat(num).isEqualTo(0);
+ }
+
+ @Test
+ public void testIsNonEmptyArray() {
+ assertThat(NativeInterpreterWrapper.isNonEmptyArray(null)).isFalse();
+ assertThat(NativeInterpreterWrapper.isNonEmptyArray(3.2)).isFalse();
+ int[] emptyArray = {};
+ assertThat(NativeInterpreterWrapper.isNonEmptyArray(emptyArray)).isFalse();
+ int[] validArray = {9, 5, 2, 1};
+ assertThat(NativeInterpreterWrapper.isNonEmptyArray(validArray)).isTrue();
+ }
+
+ @Test
+ public void testDataTypeOf() {
+ float[] testEmptyArray = {};
+ DataType dataType = NativeInterpreterWrapper.dataTypeOf(testEmptyArray);
+ assertThat(dataType).isEqualTo(DataType.FLOAT32);
+ float[] testFloatArray = {0.783f, 0.251f};
+ dataType = NativeInterpreterWrapper.dataTypeOf(testFloatArray);
+ assertThat(dataType).isEqualTo(DataType.FLOAT32);
+ float[][] testMultiDimArray = {testFloatArray, testFloatArray, testFloatArray};
+ dataType = NativeInterpreterWrapper.dataTypeOf(testMultiDimArray);
+ assertThat(dataType).isEqualTo(DataType.FLOAT32);
+ try {
+ double[] testDoubleArray = {0.783, 0.251};
+ NativeInterpreterWrapper.dataTypeOf(testDoubleArray);
+ fail();
+ } catch (IllegalArgumentException e) {
+ assertThat(e).hasMessageThat().contains("cannot resolve DataType of");
+ }
+ try {
+ Float[] testBoxedArray = {0.783f, 0.251f};
+ NativeInterpreterWrapper.dataTypeOf(testBoxedArray);
+ fail();
+ } catch (IllegalArgumentException e) {
+ assertThat(e).hasMessageThat().contains("cannot resolve DataType of [Ljava.lang.Float;");
+ }
+ }
+
+ @Test
+ public void testNumDimensions() {
+ int scalar = 1;
+ assertThat(NativeInterpreterWrapper.numDimensions(scalar)).isEqualTo(0);
+ int[][] array = {{2, 4}, {1, 9}};
+ assertThat(NativeInterpreterWrapper.numDimensions(array)).isEqualTo(2);
+ try {
+ int[] emptyArray = {};
+ NativeInterpreterWrapper.numDimensions(emptyArray);
+ fail();
+ } catch (IllegalArgumentException e) {
+ assertThat(e).hasMessageThat().contains("array lengths cannot be 0.");
+ }
+ }
+
+ @Test
+ public void testFillShape() {
+ int[][][] array = {{{23}, {14}, {87}}, {{12}, {42}, {31}}};
+ int num = NativeInterpreterWrapper.numDimensions(array);
+ int[] shape = new int[num];
+ NativeInterpreterWrapper.fillShape(array, 0, shape);
+ assertThat(num).isEqualTo(3);
+ assertThat(shape[0]).isEqualTo(2);
+ assertThat(shape[1]).isEqualTo(3);
+ assertThat(shape[2]).isEqualTo(1);
+ }
+}
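The last few tests pin down the reflection helpers that map nested Java arrays onto tensor shapes: numDimensions() counts nesting depth, fillShape() reads the length at each level, and dataTypeOf() resolves the primitive element type. The sketch below re-derives the same rules purely for illustration; it is not the wrapper's actual implementation.

import java.lang.reflect.Array;
import java.util.Arrays;

/** Illustration of the shape rules exercised by the tests above. */
final class ShapeSketch {
  // 0 for scalars; otherwise 1 + the dimensions of the first element.
  static int numDimensions(Object o) {
    if (o == null || !o.getClass().isArray()) {
      return 0;
    }
    if (Array.getLength(o) == 0) {
      throw new IllegalArgumentException("array lengths cannot be 0.");
    }
    return 1 + numDimensions(Array.get(o, 0));
  }

  // Record the array length at each nesting level, following first elements.
  static void fillShape(Object o, int dim, int[] shape) {
    if (o == null || !o.getClass().isArray()) {
      return;
    }
    shape[dim] = Array.getLength(o);
    fillShape(Array.get(o, 0), dim + 1, shape);
  }

  public static void main(String[] args) {
    int[][][] array = {{{23}, {14}, {87}}, {{12}, {42}, {31}}};
    int[] shape = new int[numDimensions(array)];
    fillShape(array, 0, shape);
    System.out.println(Arrays.toString(shape));  // [2, 3, 1], as in testFillShape
  }
}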
diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorFlowLiteTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorFlowLiteTest.java
new file mode 100644
index 0000000000..665c937cb6
--- /dev/null
+++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorFlowLiteTest.java
@@ -0,0 +1,32 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.lite;
+
+import static com.google.common.truth.Truth.assertThat;
+
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/** Unit tests for {@link org.tensorflow.lite.TensorFlowLite}. */
+@RunWith(JUnit4.class)
+public final class TensorFlowLiteTest {
+
+ @Test
+ public void testVersion() {
+ assertThat(TensorFlowLite.version()).isEqualTo("3");
+ }
+}
diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java
new file mode 100644
index 0000000000..e41e971159
--- /dev/null
+++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java
@@ -0,0 +1,105 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.lite;
+
+import static com.google.common.truth.Truth.assertThat;
+import static org.junit.Assert.fail;
+
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/** Unit tests for {@link org.tensorflow.lite.Tensor}. */
+@RunWith(JUnit4.class)
+public final class TensorTest {
+
+ private static final String MODEL_PATH =
+ "third_party/tensorflow/contrib/lite/java/src/testdata/add.bin";
+
+ private NativeInterpreterWrapper wrapper;
+ private long nativeHandle;
+
+ @Before
+ public void setUp() {
+ wrapper = new NativeInterpreterWrapper(MODEL_PATH);
+ float[] oneD = {1.23f, 6.54f, 7.81f};
+ float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD};
+ float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
+ float[][][][] fourD = {threeD, threeD};
+ Object[] inputs = {fourD};
+ Tensor[] outputs = wrapper.run(inputs);
+ nativeHandle = outputs[0].nativeHandle;
+ }
+
+ @After
+ public void tearDown() {
+ wrapper.close();
+ }
+
+ @Test
+ public void testFromHandle() throws Exception {
+ Tensor tensor = Tensor.fromHandle(nativeHandle);
+ assertThat(tensor).isNotNull();
+ int[] expectedShape = {2, 8, 8, 3};
+ assertThat(tensor.shapeCopy).isEqualTo(expectedShape);
+ assertThat(tensor.dtype).isEqualTo(DataType.FLOAT32);
+ }
+
+ @Test
+ public void testCopyTo() {
+ Tensor tensor = Tensor.fromHandle(nativeHandle);
+ float[][][][] parsedOutputs = new float[2][8][8][3];
+ tensor.copyTo(parsedOutputs);
+ float[] outputOneD = parsedOutputs[0][0][0];
+ float[] expected = {3.69f, 19.62f, 23.43f};
+ assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder();
+ }
+
+ @Test
+ public void testCopyToWrongType() {
+ Tensor tensor = Tensor.fromHandle(nativeHandle);
+ int[][][][] parsedOutputs = new int[2][8][8][3];
+ try {
+ tensor.copyTo(parsedOutputs);
+ fail();
+ } catch (IllegalArgumentException e) {
+ assertThat(e)
+ .hasMessageThat()
+ .contains(
+ "Cannot convert an TensorFlowLite tensor with type "
+ + "FLOAT32 to a Java object of type [[[[I (which is compatible with the TensorFlowLite "
+ + "type INT32)");
+ }
+ }
+
+ @Test
+ public void testCopyToWrongShape() {
+ Tensor tensor = Tensor.fromHandle(nativeHandle);
+ float[][][][] parsedOutputs = new float[1][8][8][3];
+ try {
+ tensor.copyTo(parsedOutputs);
+ fail();
+ } catch (IllegalArgumentException e) {
+ assertThat(e)
+ .hasMessageThat()
+ .contains(
+ "Shape of output target [1, 8, 8, 3] does not match "
+ + "with the shape of the Tensor [2, 8, 8, 3].");
+ }
+ }
+}
diff --git a/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/BUILD b/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/BUILD
new file mode 100644
index 0000000000..2b4f37bc6c
--- /dev/null
+++ b/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/BUILD
@@ -0,0 +1,30 @@
+# Description:
+# Internal helper library to test the TF Lite API.
+
+package(default_visibility = ["//visibility:public"])
+
+licenses(["notice"]) # Apache 2.0
+
+android_library(
+ name = "testhelper",
+ srcs = glob(
+ [
+ "*.java",
+ ],
+ ),
+ deps = [
+ "//tensorflow/contrib/lite/java:tensorflowlite_java",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/TestHelper.java b/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/TestHelper.java
new file mode 100644
index 0000000000..8660cabf70
--- /dev/null
+++ b/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/TestHelper.java
@@ -0,0 +1,35 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.lite;
+
+/** A helper class for internal tests. */
+public class TestHelper {
+
+ /**
+ * Turns on/off NNAPI of an {@code Interpreter}.
+ *
+ * @param interpreter an instance of {@code Interpreter}. If it is not initialized, an {@code
+ * IllegalArgumentException} will be thrown.
+ * @param useNNAPI a boolean value indicating whether to turn NNAPI on or off.
+ */
+ public static void setUseNNAPI(Interpreter interpreter, boolean useNNAPI) {
+ if (interpreter != null && interpreter.wrapper != null) {
+ interpreter.wrapper.setUseNNAPI(useNNAPI);
+ } else {
+ throw new IllegalArgumentException("Interpreter has not been initialized; failed to setUseNNAPI.");
+ }
+ }
+}
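TestHelper sits in the org.tensorflow.lite package so it can reach the package-private wrapper field of Interpreter. A test in the same package could toggle NNAPI around an inference like this (sketch only; the class name is made up):

package org.tensorflow.lite;

import java.io.File;

final class NnapiToggleSketch {
  static void runOnce() {
    Interpreter interpreter =
        new Interpreter(new File("third_party/tensorflow/contrib/lite/java/src/testdata/add.bin"));
    TestHelper.setUseNNAPI(interpreter, true);   // route execution through NNAPI
    float[][][][] input = new float[2][8][8][3];
    float[][][][] output = new float[2][8][8][3];
    interpreter.run(input, output);
    TestHelper.setUseNNAPI(interpreter, false);  // back to the built-in kernels
    interpreter.close();
  }
}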
diff --git a/tensorflow/contrib/lite/kernels/BUILD b/tensorflow/contrib/lite/kernels/BUILD
new file mode 100644
index 0000000000..bbbfa3e741
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/BUILD
@@ -0,0 +1,408 @@
+package(default_visibility = [
+ "//visibility:public",
+])
+
+licenses(["notice"]) # Apache 2.0
+
+load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts")
+load(
+ "//tensorflow:tensorflow.bzl",
+ "tf_cc_test",
+)
+
+tf_cc_test(
+ name = "optional_tensor_test",
+ size = "small",
+ srcs = ["optional_tensor_test.cc"],
+ deps = [
+ ":builtin_ops",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+cc_library(
+ name = "test_util",
+ testonly = 1,
+ srcs = ["test_util.cc"],
+ hdrs = ["test_util.h"],
+ deps = [
+ ":builtin_ops",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite:schema_fbs_version",
+ "//tensorflow/contrib/lite:string_util",
+ "//tensorflow/core:lib",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+cc_library(
+ name = "gemm_support",
+ srcs = [
+ "gemm_support.cc",
+ ],
+ hdrs = [
+ "gemm_support.h",
+ ],
+ copts = tflite_copts(),
+ deps = [
+ ":op_macros",
+ "//tensorflow/contrib/lite:context",
+ "@gemmlowp//:gemmlowp",
+ ],
+)
+
+cc_library(
+ name = "activation_functor",
+ hdrs = [
+ "activation_functor.h",
+ ],
+ deps = [
+ "//tensorflow/contrib/lite:builtin_op_data",
+ ],
+)
+
+cc_library(
+ name = "op_macros",
+ hdrs = [
+ "op_macros.h",
+ ],
+)
+
+cc_library(
+ name = "builtin_ops",
+ srcs = [
+ "activations.cc",
+ "add.cc",
+ "basic_rnn.cc",
+ "concatenation.cc",
+ "conv.cc",
+ "depthwise_conv.cc",
+ "embedding_lookup.cc",
+ "embedding_lookup_sparse.cc",
+ "fully_connected.cc",
+ "hashtable_lookup.cc",
+ "kernel_util.cc",
+ "l2norm.cc",
+ "local_response_norm.cc",
+ "lsh_projection.cc",
+ "lstm.cc",
+ "mul.cc",
+ "pooling.cc",
+ "register.cc",
+ "reshape.cc",
+ "resize_bilinear.cc",
+ "skip_gram.cc",
+ "space_to_depth.cc",
+ "svdf.cc",
+ ],
+ hdrs = [
+ "kernel_util.h",
+ "padding.h",
+ "register.h",
+ ],
+ # Suppress warnings that are introduced by Eigen Tensor.
+ copts = tflite_copts() + [
+ "-Wno-error=reorder",
+ ] + select({
+ "//tensorflow:ios": ["-Wno-error=invalid-partial-specialization"],
+ "//conditions:default": [
+ ],
+ }),
+ deps = [
+ ":activation_functor",
+ ":op_macros",
+ "//tensorflow/contrib/lite:builtin_op_data",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite:string_util",
+ "//tensorflow/contrib/lite/kernels:gemm_support",
+ "//tensorflow/contrib/lite/kernels/internal:optimized",
+ "//tensorflow/contrib/lite/kernels/internal:optimized_base",
+ "//tensorflow/contrib/lite/kernels/internal:quantization_util",
+ "//tensorflow/contrib/lite/kernels/internal:reference",
+ "//tensorflow/contrib/lite/kernels/internal:reference_base",
+ "//tensorflow/contrib/lite/kernels/internal:round",
+ "//tensorflow/contrib/lite/kernels/internal:tensor_utils",
+ "@farmhash_archive//:farmhash",
+ ],
+)
+
+tf_cc_test(
+ name = "activations_test",
+ size = "small",
+ srcs = ["activations_test.cc"],
+ deps = [
+ ":builtin_ops",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+tf_cc_test(
+ name = "add_test",
+ size = "small",
+ srcs = ["add_test.cc"],
+ deps = [
+ ":builtin_ops",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+tf_cc_test(
+ name = "concatenation_test",
+ size = "small",
+ srcs = ["concatenation_test.cc"],
+ deps = [
+ ":builtin_ops",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+tf_cc_test(
+ name = "conv_test",
+ size = "small",
+ srcs = ["conv_test.cc"],
+ deps = [
+ ":builtin_ops",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+tf_cc_test(
+ name = "depthwise_conv_test",
+ size = "small",
+ srcs = ["depthwise_conv_test.cc"],
+ deps = [
+ ":builtin_ops",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+tf_cc_test(
+ name = "basic_rnn_test",
+ size = "small",
+ srcs = ["basic_rnn_test.cc"],
+ deps = [
+ ":builtin_ops",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+tf_cc_test(
+ name = "l2norm_test",
+ size = "small",
+ srcs = ["l2norm_test.cc"],
+ deps = [
+ ":builtin_ops",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+tf_cc_test(
+ name = "mul_test",
+ size = "small",
+ srcs = ["mul_test.cc"],
+ deps = [
+ ":builtin_ops",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+tf_cc_test(
+ name = "reshape_test",
+ size = "small",
+ srcs = ["reshape_test.cc"],
+ deps = [
+ ":builtin_ops",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+tf_cc_test(
+ name = "resize_bilinear_test",
+ size = "small",
+ srcs = ["resize_bilinear_test.cc"],
+ deps = [
+ ":builtin_ops",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+tf_cc_test(
+ name = "svdf_test",
+ size = "small",
+ srcs = ["svdf_test.cc"],
+ deps = [
+ ":builtin_ops",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+tf_cc_test(
+ name = "embedding_lookup_test",
+ size = "small",
+ srcs = ["embedding_lookup_test.cc"],
+ deps = [
+ ":builtin_ops",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+tf_cc_test(
+ name = "embedding_lookup_sparse_test",
+ size = "small",
+ srcs = ["embedding_lookup_sparse_test.cc"],
+ deps = [
+ ":builtin_ops",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+tf_cc_test(
+ name = "fully_connected_test",
+ size = "small",
+ srcs = ["fully_connected_test.cc"],
+ deps = [
+ ":builtin_ops",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+tf_cc_test(
+ name = "local_response_norm_test",
+ size = "small",
+ srcs = ["local_response_norm_test.cc"],
+ deps = [
+ ":builtin_ops",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+tf_cc_test(
+ name = "pooling_test",
+ size = "small",
+ srcs = ["pooling_test.cc"],
+ deps = [
+ ":builtin_ops",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+tf_cc_test(
+ name = "softmax_test",
+ size = "small",
+ srcs = ["softmax_test.cc"],
+ deps = [
+ ":builtin_ops",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "//tensorflow/contrib/lite/kernels/internal:reference_base",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+tf_cc_test(
+ name = "lsh_projection_test",
+ size = "small",
+ srcs = ["lsh_projection_test.cc"],
+ deps = [
+ ":builtin_ops",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+tf_cc_test(
+ name = "hashtable_lookup_test",
+ size = "small",
+ srcs = ["hashtable_lookup_test.cc"],
+ deps = [
+ ":builtin_ops",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite:string_util",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+tf_cc_test(
+ name = "lstm_test",
+ size = "small",
+ srcs = ["lstm_test.cc"],
+ deps = [
+ ":builtin_ops",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+tf_cc_test(
+ name = "skip_gram_test",
+ size = "small",
+ srcs = ["skip_gram_test.cc"],
+ deps = [
+ ":builtin_ops",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite:string_util",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+tf_cc_test(
+ name = "space_to_depth_test",
+ size = "small",
+ srcs = ["space_to_depth_test.cc"],
+ deps = [
+ ":builtin_ops",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/contrib/lite/kernels/activation_functor.h b/tensorflow/contrib/lite/kernels/activation_functor.h
new file mode 100644
index 0000000000..cfb3369e99
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/activation_functor.h
@@ -0,0 +1,58 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_ACTIVATION_FUNCTOR_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_ACTIVATION_FUNCTOR_H_
+
+#include <algorithm>
+#include <cmath>
+#include <cstdlib>
+
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+
+namespace tflite {
+
+// Dynamic (non-fused) activation functor. Perhaps it is worth having
+// template instantiation?
+// TODO(aselle): Make this more efficient by pulling the switch to conv_eval
+// using template inlining.
+class ActivationFunctor {
+ public:
+ explicit ActivationFunctor(TfLiteFusedActivation act) : act_(act) {}
+
+ float operator()(float a) const {
+ switch (act_) {
+ case kTfLiteActNone:
+ return a;
+ case kTfLiteActRelu:
+ return a < 0.f ? 0.f : a;
+ case kTfLiteActRelu6:
+ return std::max(0.f, std::min(a, 6.f));
+ case kTfLiteActTanh:
+ return std::tanh(a);
+ case kTfLiteActSigmoid:
+ return 1.0f / (1.0f + std::exp(-a));
+ default:
+ // TODO(aselle): More informative fatal error!
+ exit(1);
+ }
+ }
+
+ private:
+ TfLiteFusedActivation act_;
+};
+
+} // namespace tflite
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_ACTIVATION_FUNCTOR_H_
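For reference, the activations the functor evaluates elementwise correspond to:

\mathrm{relu}(a) = \max(0, a), \qquad
\mathrm{relu6}(a) = \min(\max(0, a), 6), \qquad
\tanh(a), \qquad
\mathrm{sigmoid}(a) = \frac{1}{1 + e^{-a}},

with kTfLiteActNone passing a through unchanged.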
diff --git a/tensorflow/contrib/lite/kernels/activations.cc b/tensorflow/contrib/lite/kernels/activations.cc
new file mode 100644
index 0000000000..7ab60a33e5
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/activations.cc
@@ -0,0 +1,389 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <unistd.h>
+#include <cassert>
+#include <cmath>
+#include <cstdlib>
+#include <cstdio>
+#include <iostream>
+#include <limits>
+
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace activations {
+
+struct OpData {
+ int32_t input_multiplier = 0;
+ int input_left_shift = 0;
+ int32_t input_range_radius = 0;
+ int diff_min = 0;
+};
+
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+ // This is a builtin op, so we don't use the contents in 'buffer', if any.
+ // Instead, we allocate a new object to carry information from Prepare() to
+ // Eval().
+ return new OpData;
+}
+
+void Free(TfLiteContext* context, void* buffer) {
+ delete reinterpret_cast<OpData*>(buffer);
+}
+
+TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) {
+ TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+ TfLiteTensor* input = GetInput(context, node, 0);
+ TfLiteTensor* output = GetOutput(context, node, 0);
+ TF_LITE_ENSURE_EQ(context, input->type, output->type);
+
+ return context->ResizeTensor(context, output,
+ TfLiteIntArrayCopy(input->dims));
+}
+
+TfLiteStatus SigmoidPrepare(TfLiteContext* context, TfLiteNode* node) {
+ OpData* data = reinterpret_cast<OpData*>(node->user_data);
+
+ TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+ TfLiteTensor* input = GetInput(context, node, 0);
+ TfLiteTensor* output = GetOutput(context, node, 0);
+ TF_LITE_ENSURE_EQ(context, input->type, output->type);
+
+ if (input->type == kTfLiteUInt8) {
+ TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
+ TF_LITE_ENSURE(context, output->params.scale == 1. / 256);
+
+ static constexpr int kInputIntegerBits = 4;
+
+ const double input_real_multiplier =
+ input->params.scale *
+ static_cast<double>(1 << (31 - kInputIntegerBits));
+
+ QuantizeMultiplierGreaterThanOne(input_real_multiplier,
+ &data->input_multiplier,
+ &data->input_left_shift);
+ data->input_range_radius =
+ CalculateInputRadius(kInputIntegerBits, data->input_left_shift);
+ }
+
+ return context->ResizeTensor(context, output,
+ TfLiteIntArrayCopy(input->dims));
+}
+
+TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) {
+ auto* params = reinterpret_cast<TfLiteSoftmaxParams*>(node->builtin_data);
+ OpData* data = reinterpret_cast<OpData*>(node->user_data);
+
+ TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+ TfLiteTensor* input = GetInput(context, node, 0);
+ TfLiteTensor* output = GetOutput(context, node, 0);
+ TF_LITE_ENSURE_EQ(context, input->type, output->type);
+
+ TF_LITE_ENSURE(context,
+ NumDimensions(input) == 2 || NumDimensions(input) == 4);
+
+ if (input->type == kTfLiteUInt8) {
+ TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
+ TF_LITE_ENSURE(context, output->params.scale == 1. / 256);
+
+ static const int kScaledDiffIntegerBits = 5;
+
+ tflite::PreprocessSoftmaxScaling(
+ params->beta, input->params.scale, kScaledDiffIntegerBits,
+ &data->input_multiplier, &data->input_left_shift);
+ data->diff_min = -1.0 * tflite::CalculateInputRadius(
+ kScaledDiffIntegerBits, data->input_left_shift);
+ }
+
+ return context->ResizeTensor(context, output,
+ TfLiteIntArrayCopy(input->dims));
+}
+
+TfLiteStatus ReluEval(TfLiteContext* context, TfLiteNode* node) {
+ TfLiteTensor* input = GetInput(context, node, 0);
+ TfLiteTensor* output = GetOutput(context, node, 0);
+ switch (input->type) {
+ case kTfLiteFloat32: {
+ size_t elements = input->bytes / sizeof(float);
+ float* in = input->data.f;
+ float* in_end = in + elements;
+ float* out = output->data.f;
+ for (; in < in_end; in++, out++) *out = std::max(0.f, *in);
+ return kTfLiteOk;
+ }
+ break;
+ default:
+ context->ReportError(context, "Only float32 supported currently.");
+ return kTfLiteError;
+ }
+}
+
+TfLiteStatus Relu1Eval(TfLiteContext* context, TfLiteNode* node) {
+ TfLiteTensor* input = GetInput(context, node, 0);
+ TfLiteTensor* output = GetOutput(context, node, 0);
+ switch (input->type) {
+ case kTfLiteFloat32: {
+ size_t elements = input->bytes / sizeof(float);
+ float* in = input->data.f;
+ float* in_end = in + elements;
+ float* out = output->data.f;
+ for (; in < in_end; in++, out++) {
+ *out = std::min(std::max(-1.f, *in), 1.f);
+ }
+ return kTfLiteOk;
+ } break;
+ default:
+ context->ReportError(context, "Only float32 supported currently.");
+ return kTfLiteError;
+ }
+}
+
+TfLiteStatus Relu6Eval(TfLiteContext* context, TfLiteNode* node) {
+ TfLiteTensor* input = GetInput(context, node, 0);
+ TfLiteTensor* output = GetOutput(context, node, 0);
+ switch (input->type) {
+ case kTfLiteFloat32: {
+ size_t elements = input->bytes / sizeof(float);
+ float* in = input->data.f;
+ float* in_end = in + elements;
+ float* out = output->data.f;
+ for (; in < in_end; in++, out++) *out = std::min(std::max(0.f, *in), 6.f);
+ return kTfLiteOk;
+ }
+ break;
+ default:
+ context->ReportError(context, "Only float32 supported currently.");
+ return kTfLiteError;
+ }
+}
+
+TfLiteStatus TanhEval(TfLiteContext* context, TfLiteNode* node) {
+ TfLiteTensor* input = GetInput(context, node, 0);
+ TfLiteTensor* output = GetOutput(context, node, 0);
+ switch (input->type) {
+ case kTfLiteFloat32: {
+ size_t elements = input->bytes / sizeof(float);
+ float* in = input->data.f;
+ float* in_end = in + elements;
+ float* out = output->data.f;
+ for (; in < in_end; in++, out++) *out = std::tanh(*in);
+ return kTfLiteOk;
+ }
+ break;
+ default:
+ context->ReportError(context, "Only float32 supported currently.");
+ return kTfLiteError;
+ }
+}
+
+// Sigmoid is also known as "Logistic".
+TfLiteStatus SigmoidEval(TfLiteContext* context, TfLiteNode* node) {
+ OpData* data = reinterpret_cast<OpData*>(node->user_data);
+
+ TfLiteTensor* input = GetInput(context, node, 0);
+ TfLiteTensor* output = GetOutput(context, node, 0);
+ switch (input->type) {
+ case kTfLiteFloat32: {
+ size_t elements = input->bytes / sizeof(float);
+ float* in = input->data.f;
+ float* in_end = in + elements;
+ float* out = output->data.f;
+ for (; in < in_end; in++, out++) *out = 1.f / (1.f + std::exp(-*in));
+ break;
+ }
+ case kTfLiteUInt8: {
+ optimized_ops::Logistic(
+ GetTensorData<uint8_t>(input), GetTensorDims(input),
+ input->params.zero_point, data->input_range_radius,
+ data->input_multiplier, data->input_left_shift,
+ GetTensorData<uint8_t>(output), GetTensorDims(output));
+ break;
+ }
+ default:
+ context->ReportError(context, "Only float32 and uint8 supported currently.");
+ return kTfLiteError;
+ }
+ return kTfLiteOk;
+}
+
+// Takes a 2D tensor and performs softmax along the second dimension.
+void Softmax2DFloat(TfLiteTensor* input, TfLiteTensor* output,
+ TfLiteSoftmaxParams* params) {
+ const int batch_size = input->dims->data[0];
+ const int input_size = input->dims->data[1];
+ float* in = input->data.f;
+ float* out = output->data.f;
+ TF_LITE_ASSERT(input_size > 0);
+
+ // For each batch
+ for (int b = 0; b < batch_size; b++) {
+ // Find the max coeff.
+ float max_coeff = in[0];
+ for (int i = 1; i < input_size; i++) {
+ if (in[i] > max_coeff) max_coeff = in[i];
+ }
+
+ // Compute the normalized sum of exps.
+ float exp_sum = 0.0;
+ for (int i = 0; i < input_size; i++) {
+ out[i] = std::exp((in[i] - max_coeff) * params->beta);
+ exp_sum += out[i];
+ }
+
+ // Divide by the sum of exps.
+ float reciprocal_sum_exp = 1.f / exp_sum;
+ for (int i = 0; i < input_size; i++) {
+ out[i] *= reciprocal_sum_exp;
+ }
+
+ // Advance in and out pointers for the next batch.
+ in += input_size;
+ out += input_size;
+ }
+}
+
+void Softmax2DQuantized(TfLiteTensor* input, TfLiteTensor* output,
+ TfLiteSoftmaxParams* params, OpData* data) {
+ // TODO(ahentz): this is arguably a dirty trick. Since the implementation
+ // always traverses the last dimension of a 4D tensor, we will pretend our 2D
+ // tensor is 4D in a special way. We will convert a (X, Y) shape into a (X,
+ // 1, 1, Y) shape.
+ const int batch_size = input->dims->data[0];
+ const int input_size = input->dims->data[1];
+ optimized_ops::Softmax(GetTensorData<uint8_t>(input),
+ GetTensorDims({batch_size, 1, 1, input_size}),
+ data->input_multiplier, data->input_left_shift,
+ data->diff_min, GetTensorData<uint8_t>(output),
+ GetTensorDims({batch_size, 1, 1, input_size}));
+}
+
+// Takes a 4D tensor and performs softmax along the fourth dimension.
+void Softmax4DFloat(TfLiteTensor* input, TfLiteTensor* output,
+ TfLiteSoftmaxParams* params) {
+ optimized_ops::Softmax(GetTensorData<float>(input), GetTensorDims(input),
+ params->beta, GetTensorData<float>(output),
+ GetTensorDims(output));
+}
+
+void Softmax4DQuantized(TfLiteTensor* input, TfLiteTensor* output,
+ TfLiteSoftmaxParams* params, OpData* data) {
+ optimized_ops::Softmax(GetTensorData<uint8_t>(input), GetTensorDims(input),
+ data->input_multiplier, data->input_left_shift,
+ data->diff_min, GetTensorData<uint8_t>(output),
+ GetTensorDims(output));
+}
+
+TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
+ auto* params = reinterpret_cast<TfLiteSoftmaxParams*>(node->builtin_data);
+ OpData* data = reinterpret_cast<OpData*>(node->user_data);
+
+ TfLiteTensor* input = GetInput(context, node, 0);
+ TfLiteTensor* output = GetOutput(context, node, 0);
+
+ // TODO(ahentz): consider an implementation that works for many (all?)
+ // dimensions.
+ switch (input->type) {
+ case kTfLiteFloat32: {
+ if (NumDimensions(input) == 2) {
+ Softmax2DFloat(input, output, params);
+ return kTfLiteOk;
+ }
+ if (NumDimensions(input) == 4) {
+ Softmax4DFloat(input, output, params);
+ return kTfLiteOk;
+ }
+ context->ReportError(context,
+ "Only 2D and 4D tensors supported currently.");
+ return kTfLiteError;
+ }
+ case kTfLiteUInt8: {
+ if (NumDimensions(input) == 2) {
+ Softmax2DQuantized(input, output, params, data);
+ return kTfLiteOk;
+ }
+ if (NumDimensions(input) == 4) {
+ Softmax4DQuantized(input, output, params, data);
+ return kTfLiteOk;
+ }
+ context->ReportError(context,
+ "Only 2D and 4D tensors supported currently.");
+ return kTfLiteError;
+ }
+ default:
+ context->ReportError(context,
+ "Only float32 and uint8_t supported currently.");
+ return kTfLiteError;
+ }
+}
+
+} // namespace activations
+
+TfLiteRegistration* Register_RELU() {
+ static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
+ activations::GenericPrepare,
+ activations::ReluEval};
+ return &r;
+}
+
+TfLiteRegistration* Register_RELU1() {
+ static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
+ activations::GenericPrepare,
+ activations::Relu1Eval};
+ return &r;
+}
+
+TfLiteRegistration* Register_RELU6() {
+ static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
+ activations::GenericPrepare,
+ activations::Relu6Eval};
+ return &r;
+}
+
+TfLiteRegistration* Register_TANH() {
+ static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
+ activations::GenericPrepare,
+ activations::TanhEval};
+ return &r;
+}
+
+TfLiteRegistration* Register_LOGISTIC() {
+ static TfLiteRegistration r = {activations::Init, activations::Free,
+ activations::SigmoidPrepare,
+ activations::SigmoidEval};
+ return &r;
+}
+
+TfLiteRegistration* Register_SOFTMAX() {
+ static TfLiteRegistration r = {activations::Init, activations::Free,
+ activations::SoftmaxPrepare,
+ activations::SoftmaxEval};
+ return &r;
+}
+
+} // namespace builtin
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/activations_test.cc b/tensorflow/contrib/lite/kernels/activations_test.cc
new file mode 100644
index 0000000000..f10aee7017
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/activations_test.cc
@@ -0,0 +1,323 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <cstdarg>
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAreArray;
+
+class BaseActivationsOpModel : public SingleOpModel {
+ public:
+ // Most activations don't take any options, so this constructor works for
+ // them.
+ BaseActivationsOpModel(BuiltinOperator type, TensorData input) {
+ input_ = AddInput(input);
+ if (input.type == TensorType_UINT8) {
+ output_ = AddOutput({input.type, {}, 0, 0, 1. / 256});
+ } else {
+ output_ = AddOutput({input.type, {}});
+ }
+ SetBuiltinOp(type, BuiltinOptions_NONE, 0);
+ BuildInterpreter({GetShape(input_)});
+ }
+
+ // A dedicated constructor for SOFTMAX, which takes a beta option.
+ BaseActivationsOpModel(float softmax_beta, TensorData input) {
+ input_ = AddInput(input);
+ if (input.type == TensorType_UINT8) {
+ output_ = AddOutput({input.type, {}, 0, 0, 1. / 256});
+ } else {
+ output_ = AddOutput({input.type, {}});
+ }
+ SetBuiltinOp(BuiltinOperator_SOFTMAX, BuiltinOptions_SoftmaxOptions,
+ CreateSoftmaxOptions(builder_, softmax_beta).Union());
+ BuildInterpreter({GetShape(input_)});
+ }
+
+ protected:
+ int input_;
+ int output_;
+};
+
+class FloatActivationsOpModel : public BaseActivationsOpModel {
+ public:
+ using BaseActivationsOpModel::BaseActivationsOpModel;
+
+ void SetInput(std::initializer_list<float> data) {
+ PopulateTensor(input_, data);
+ }
+ std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+};
+
+// TODO(ahentz): I don't quite understand the tradeoffs in the quantized
+// implementation of sigmoid and softmax, but a tolerance of twice the output
+// scale seems reasonable. We might want to change this if we have a better
+// theoretical bound.
+const float kQuantizedTolerance = 2 * (1. / 256);
+
+class QuantizedActivationsOpModel : public BaseActivationsOpModel {
+ public:
+ using BaseActivationsOpModel::BaseActivationsOpModel;
+
+ void SetInput(std::initializer_list<float> data) {
+ QuantizeAndPopulate<uint8_t>(input_, data);
+ }
+ std::vector<uint8_t> GetOutput() { return ExtractVector<uint8_t>(output_); }
+ std::vector<float> GetDequantizedOutput() {
+ return Dequantize<uint8_t>(ExtractVector<uint8_t>(output_),
+ GetScale(output_), GetZeroPoint(output_));
+ }
+};
+
+TEST(FloatActivationsOpTest, Relu) {
+ FloatActivationsOpModel m(BuiltinOperator_RELU,
+ /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}});
+ m.SetInput({
+ 0, -6, 2, 4, //
+ 3, -2, 10, 1, //
+ });
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({
+ 0, 0, 2, 4, //
+ 3, 0, 10, 1, //
+ }));
+}
+
+TEST(FloatActivationsOpTest, Relu1) {
+ FloatActivationsOpModel m(BuiltinOperator_RELU1,
+ /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}});
+ m.SetInput({
+ 0.0, -0.6, 0.2, -0.4, //
+ 0.3, -2.0, 1.1, -0.1, //
+ });
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({
+ 0.0, -0.6, 0.2, -0.4, //
+ 0.3, -1.0, 1.0, -0.1, //
+ }));
+}
+
+TEST(FloatActivationsOpTest, Relu6) {
+ FloatActivationsOpModel m(BuiltinOperator_RELU6,
+ /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}});
+ m.SetInput({
+ 0, -6, 2, 4, //
+ 3, -2, 10, 1, //
+ });
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({
+ 0, 0, 2, 4, //
+ 3, 0, 6, 1, //
+ }));
+}
+
+TEST(FloatActivationsOpTest, Tanh) {
+ FloatActivationsOpModel m(BuiltinOperator_TANH,
+ /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}});
+ m.SetInput({
+ 0, -6, 2, 4, //
+ 3, -2, 10, 1, //
+ });
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({
+ 0, -0.9999877, 0.9640275, 0.999329, //
+ 0.99505475, -0.9640275, 1, 0.7615941, //
+ })));
+}
+
+TEST(FloatActivationsOpTest, Sigmoid) {
+ FloatActivationsOpModel m(BuiltinOperator_LOGISTIC,
+ /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}});
+ m.SetInput({
+ 0, -6, 2, 4, //
+ 3, -2, 10, 1, //
+ });
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({
+ 0.5, 0.002473, 0.880797, 0.982014, //
+ 0.952574, 0.119203, 0.999955, 0.731059, //
+ })));
+}
+
+TEST(QuantizedActivationsOpTest, Sigmoid) {
+ QuantizedActivationsOpModel m(
+ BuiltinOperator_LOGISTIC,
+ /*input=*/{TensorType_UINT8, {1, 2, 4, 1}, -10, 10});
+ m.SetInput({
+ 0, -6, 2, 4, //
+ 3, -2, 10, 1, //
+ });
+ m.Invoke();
+ EXPECT_THAT(m.GetDequantizedOutput(),
+ ElementsAreArray(ArrayFloatNear(
+ {
+ 0.5, 0.002473, 0.880797, 0.982014, //
+ 0.952574, 0.119203, 0.999955, 0.731059, //
+ },
+ kQuantizedTolerance)));
+ EXPECT_THAT(m.GetOutput(),
+ ElementsAreArray({128, 1, 227, 251, 244, 32, 255, 188}));
+}
+
+TEST(FloatActivationsOpTest, Softmax4D) {
+ FloatActivationsOpModel m(0.1,
+ /*input=*/{TensorType_FLOAT32, {1, 2, 1, 4}});
+ m.SetInput({
+ 0, -6, 2, 4, // depth = 0
+ 3, -2, 10, 1, // depth = 1
+ });
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({
+ .23463, .12877, .28658, .35003, //
+ .22528, .13664, .45365, .18443, //
+ })));
+
+ // Same input, but a different shape.
+ FloatActivationsOpModel m2(0.1,
+ /*input=*/{TensorType_FLOAT32, {4, 1, 1, 2}});
+ m2.SetInput({
+ 0, -6, //
+ 2, 4, //
+ 3, -2, //
+ 10, 1, //
+ });
+ m2.Invoke();
+ EXPECT_THAT(m2.GetOutput(), ElementsAreArray(ArrayFloatNear({
+ 0.645656, 0.354344, //
+ 0.450166, 0.549834, //
+ 0.622459, 0.377541, //
+ 0.710949, 0.28905, //
+ })));
+}
+
+TEST(QuantizedActivationsOpTest, Softmax4D) {
+ QuantizedActivationsOpModel m(
+ 0.1,
+ /*input=*/{TensorType_UINT8, {1, 2, 1, 4}, -10, 10});
+ m.SetInput({
+ 0, -6, 2, 4, // depth = 0
+ 3, -2, 10, 1, // depth = 1
+ });
+ m.Invoke();
+ EXPECT_THAT(m.GetDequantizedOutput(),
+ ElementsAreArray(ArrayFloatNear(
+ {
+ .23463, .12877, .28658, .35003, //
+ .22528, .13664, .45365, .18443, //
+ },
+ kQuantizedTolerance)));
+
+ // Same input, but a different shape.
+ QuantizedActivationsOpModel m2(
+ 0.1,
+ /*input=*/{TensorType_UINT8, {4, 1, 1, 2}, -10, 10});
+ m2.SetInput({
+ 0, -6, //
+ 2, 4, //
+ 3, -2, //
+ 10, 1, //
+ });
+ m2.Invoke();
+ EXPECT_THAT(m2.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear(
+ {
+ 0.645656, 0.354344, //
+ 0.450166, 0.549834, //
+ 0.622459, 0.377541, //
+ 0.710949, 0.28905, //
+ },
+ kQuantizedTolerance)));
+}
+
+TEST(FloatActivationsOpTest, Softmax2D) {
+ FloatActivationsOpModel m(0.1,
+ /*input=*/{TensorType_FLOAT32, {2, 4}});
+ m.SetInput({
+ 0, -6, 2, 4, //
+ 3, -2, 10, 1, //
+ });
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({
+ .23463, .12877, .28658, .35003, //
+ .22528, .13664, .45365, .18443, //
+ })));
+
+ // Same input, but a different shape.
+ FloatActivationsOpModel m2(0.1,
+ /*input=*/{TensorType_FLOAT32, {4, 2}});
+ m2.SetInput({
+ 0, -6, //
+ 2, 4, //
+ 3, -2, //
+ 10, 1, //
+ });
+ m2.Invoke();
+ EXPECT_THAT(m2.GetOutput(), ElementsAreArray(ArrayFloatNear({
+ 0.645656, 0.354344, //
+ 0.450166, 0.549834, //
+ 0.622459, 0.377541, //
+ 0.710949, 0.28905, //
+ })));
+}
+
+TEST(QuantizedActivationsOpTest, Softmax2D) {
+ QuantizedActivationsOpModel m(0.1,
+ /*input=*/{TensorType_UINT8, {2, 4}, -10, 10});
+ m.SetInput({
+ 0, -6, 2, 4, //
+ 3, -2, 10, 1, //
+ });
+ m.Invoke();
+ EXPECT_THAT(m.GetDequantizedOutput(),
+ ElementsAreArray(ArrayFloatNear(
+ {
+ .23463, .12877, .28658, .35003, //
+ .22528, .13664, .45365, .18443, //
+ },
+ kQuantizedTolerance)));
+
+ // Same input, but a different shape.
+ QuantizedActivationsOpModel m2(0.1,
+ /*input=*/{TensorType_UINT8, {4, 2}, -10, 10});
+ m2.SetInput({
+ 0, -6, //
+ 2, 4, //
+ 3, -2, //
+ 10, 1, //
+ });
+ m2.Invoke();
+ EXPECT_THAT(m2.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear(
+ {
+ 0.645656, 0.354344, //
+ 0.450166, 0.549834, //
+ 0.622459, 0.377541, //
+ 0.710949, 0.28905, //
+ },
+ kQuantizedTolerance)));
+}
+
+} // namespace
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ // On Linux, add: tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
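
The raw uint8 expectations in the quantized Sigmoid test follow directly from the output quantization parameters (scale = 1/256, zero point = 0): the ideal quantized value is round(sigmoid(x) * 256), clamped to 255, and the fixed-point kernel is only required to land within kQuantizedTolerance of that once dequantized. A minimal sketch of the arithmetic (illustrative only, not part of the patch):

#include <algorithm>
#include <cmath>
#include <cstdio>

int main() {
  const float scale = 1.0f / 256.0f;  // output scale used by the test model
  for (float x : {0.f, -6.f, 2.f, 4.f, 3.f, -2.f, 10.f, 1.f}) {
    const float s = 1.0f / (1.0f + std::exp(-x));             // float sigmoid
    const float q = std::min(255.0f, std::round(s / scale));  // ideal quantized value
    std::printf("x=%5.1f  sigmoid=%.6f  q=%3.0f  dequantized=%.6f\n", x, s, q, q * scale);
  }
  return 0;
}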
diff --git a/tensorflow/contrib/lite/kernels/add.cc b/tensorflow/contrib/lite/kernels/add.cc
new file mode 100644
index 0000000000..0e10a249ab
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/add.cc
@@ -0,0 +1,184 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace add {
+
+// This file has three implementations of Add.
+enum KernelType {
+ kReference,
+ kGenericOptimized, // Neon-free
+ kNeonOptimized,
+};
+
+constexpr int kInputTensor1 = 0;
+constexpr int kInputTensor2 = 1;
+constexpr int kOutputTensor = 0;
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+
+ TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
+ TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+ TF_LITE_ENSURE_EQ(context, NumDimensions(input1), NumDimensions(input2));
+ for (int i = 0; i < NumDimensions(input1); ++i) {
+ TF_LITE_ENSURE_EQ(context, SizeOfDimension(input1, i),
+ SizeOfDimension(input2, i));
+ }
+
+ TF_LITE_ENSURE_EQ(context, input1->type, output->type);
+ TF_LITE_ENSURE_EQ(context, input2->type, output->type);
+
+ TfLiteIntArray* output_size = TfLiteIntArrayCopy(input1->dims);
+ return context->ResizeTensor(context, output, output_size);
+}
+
+template <KernelType kernel_type>
+void EvalAddFloat(TfLiteContext* context, TfLiteNode* node,
+ TfLiteAddParams* params, TfLiteTensor* input1,
+ TfLiteTensor* input2, TfLiteTensor* output) {
+ float output_activation_min, output_activation_max;
+ CalculateActivationRangeFloat(params->activation, &output_activation_min,
+ &output_activation_max);
+#define TF_LITE_ADD(type) \
+ type::Add(GetTensorData<float>(input1), GetTensorDims(input1), \
+ GetTensorData<float>(input2), GetTensorDims(input2), \
+ output_activation_min, output_activation_max, \
+ GetTensorData<float>(output), GetTensorDims(output))
+ if (kernel_type == kReference) {
+ TF_LITE_ADD(reference_ops);
+ } else {
+ TF_LITE_ADD(optimized_ops);
+ }
+#undef TF_LITE_ADD
+}
+
+template <KernelType kernel_type>
+void EvalAddQuantized(TfLiteContext* context, TfLiteNode* node,
+ TfLiteAddParams* params, TfLiteTensor* input1,
+ TfLiteTensor* input2, TfLiteTensor* output) {
+ auto input1_offset = -input1->params.zero_point;
+ auto input2_offset = -input2->params.zero_point;
+ auto output_offset = output->params.zero_point;
+ const int left_shift = 20;
+ const double twice_max_input_scale =
+ 2 * std::max(input1->params.scale, input2->params.scale);
+ const double real_input1_multiplier =
+ input1->params.scale / twice_max_input_scale;
+ const double real_input2_multiplier =
+ input2->params.scale / twice_max_input_scale;
+ const double real_output_multiplier =
+ twice_max_input_scale / ((1 << left_shift) * output->params.scale);
+
+ int32 input1_multiplier;
+ int input1_shift;
+ QuantizeMultiplierSmallerThanOne(real_input1_multiplier, &input1_multiplier,
+ &input1_shift);
+ int32 input2_multiplier;
+ int input2_shift;
+ QuantizeMultiplierSmallerThanOne(real_input2_multiplier, &input2_multiplier,
+ &input2_shift);
+ int32 output_multiplier;
+ int output_shift;
+ QuantizeMultiplierSmallerThanOne(real_output_multiplier, &output_multiplier,
+ &output_shift);
+
+ int32 output_activation_min, output_activation_max;
+ CalculateActivationRangeUint8(params->activation, output,
+ &output_activation_min, &output_activation_max);
+
+#define TF_LITE_ADD(type) \
+ type::BroadcastAdd( \
+ left_shift, GetTensorData<uint8_t>(input1), GetTensorDims(input1), \
+ input1_offset, input1_multiplier, input1_shift, \
+ GetTensorData<uint8_t>(input2), GetTensorDims(input2), input2_offset, \
+ input2_multiplier, input2_shift, output_offset, output_multiplier, \
+ output_shift, output_activation_min, output_activation_max, \
+ GetTensorData<uint8_t>(output), GetTensorDims(output));
+
+ if (kernel_type == kReference) {
+ TF_LITE_ADD(reference_ops);
+ } else {
+ TF_LITE_ADD(optimized_ops);
+ }
+#undef TF_LITE_ADD
+}
+
+template <KernelType kernel_type>
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ auto* params = reinterpret_cast<TfLiteAddParams*>(node->builtin_data);
+
+ TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
+ TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+ if (output->type == kTfLiteFloat32) {
+ EvalAddFloat<kernel_type>(context, node, params, input1, input2, output);
+ } else if (output->type == kTfLiteUInt8) {
+ EvalAddQuantized<kernel_type>(context, node, params, input1, input2,
+ output);
+ } else {
+    context->ReportError(context,
+                         "Inputs and outputs not all float|uint8 types.");
+ return kTfLiteError;
+ }
+
+ return kTfLiteOk;
+}
+
+} // namespace add
+
+TfLiteRegistration* Register_ADD_REF() {
+ static TfLiteRegistration r = {nullptr, nullptr, add::Prepare,
+ add::Eval<add::kReference>};
+ return &r;
+}
+
+TfLiteRegistration* Register_ADD_GENERIC_OPT() {
+ static TfLiteRegistration r = {nullptr, nullptr, add::Prepare,
+ add::Eval<add::kGenericOptimized>};
+ return &r;
+}
+
+TfLiteRegistration* Register_ADD_NEON_OPT() {
+ static TfLiteRegistration r = {nullptr, nullptr, add::Prepare,
+ add::Eval<add::kNeonOptimized>};
+ return &r;
+}
+
+TfLiteRegistration* Register_ADD() {
+#ifdef USE_NEON
+ return Register_ADD_NEON_OPT();
+#else
+ return Register_ADD_GENERIC_OPT();
+#endif
+}
+
+} // namespace builtin
+} // namespace ops
+} // namespace tflite
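
The quantized path above rescales both inputs into a shared fixed-point domain (after a 20-bit left shift) and then rescales the accumulated sum into the output's scale; QuantizeMultiplierSmallerThanOne later turns each real multiplier into an int32 multiplier plus a shift. The three real multipliers compose so that the integer pipeline reproduces a float add up to rounding. A double-precision sanity check of that identity, with hypothetical scales (not part of the kernel):

#include <algorithm>
#include <cassert>
#include <cmath>

int main() {
  const double s1 = 0.05, s2 = 0.02, so = 0.1;  // hypothetical input/output scales
  const int left_shift = 20;                    // same constant as in EvalAddQuantized
  const double twice_max = 2 * std::max(s1, s2);
  const double m1 = s1 / twice_max;                        // real_input1_multiplier
  const double m2 = s2 / twice_max;                        // real_input2_multiplier
  const double mo = twice_max / ((1 << left_shift) * so);  // real_output_multiplier
  const double q1 = 37, q2 = 81;  // zero-point-corrected quantized inputs
  const double lhs = (q1 * (1 << left_shift) * m1 + q2 * (1 << left_shift) * m2) * mo;
  const double rhs = (q1 * s1 + q2 * s2) / so;  // a float add, expressed in output units
  assert(std::abs(lhs - rhs) < 1e-9);
  return 0;
}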
diff --git a/tensorflow/contrib/lite/kernels/add_test.cc b/tensorflow/contrib/lite/kernels/add_test.cc
new file mode 100644
index 0000000000..8e12a837c4
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/add_test.cc
@@ -0,0 +1,171 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAreArray;
+
+class BaseAddOpModel : public SingleOpModel {
+ public:
+ BaseAddOpModel(const TensorData& input, const TensorData& output,
+ ActivationFunctionType activation_type) {
+ input1_ = AddInput(input);
+ input2_ = AddInput(input);
+ output_ = AddOutput(output);
+ SetBuiltinOp(BuiltinOperator_ADD, BuiltinOptions_AddOptions,
+ CreateAddOptions(builder_, activation_type).Union());
+ BuildInterpreter({GetShape(input1_), GetShape(input2_)});
+ }
+
+ int input1() { return input1_; }
+ int input2() { return input2_; }
+
+ protected:
+ int input1_;
+ int input2_;
+ int output_;
+};
+
+class FloatAddOpModel : public BaseAddOpModel {
+ public:
+ using BaseAddOpModel::BaseAddOpModel;
+
+ std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+};
+
+class QuantizedAddOpModel : public BaseAddOpModel {
+ public:
+ using BaseAddOpModel::BaseAddOpModel;
+
+ std::vector<float> GetDequantizedOutput() {
+ return Dequantize<uint8_t>(ExtractVector<uint8_t>(output_),
+ GetScale(output_), GetZeroPoint(output_));
+ }
+};
+
+// For quantized Add, the error shouldn't exceed twice the quantization step.
+float GetTolerance(int min, int max) {
+ float kQuantizedStep = (max - min) / 255.0;
+ float kQuantizedTolerance = 2.0 * kQuantizedStep;
+ return kQuantizedTolerance;
+}
+
+TEST(FloatAddOpModel, NoActivation) {
+ FloatAddOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}},
+ {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE);
+ m.PopulateTensor<float>(m.input1(), {-2.0, 0.2, 0.7, 0.8});
+ m.PopulateTensor<float>(m.input2(), {0.1, 0.2, 0.3, 0.5});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({-1.9, 0.4, 1.0, 1.3}));
+}
+
+TEST(FloatAddOpModel, ActivationRELU1) {
+ FloatAddOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}},
+ {TensorType_FLOAT32, {}}, ActivationFunctionType_RELU1);
+ m.PopulateTensor<float>(m.input1(), {-2.0, 0.2, 0.7, 0.8});
+ m.PopulateTensor<float>(m.input2(), {0.1, 0.2, 0.3, 0.5});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({-1.0, 0.4, 1.0, 1.0}));
+}
+
+TEST(FloatAddOpModel, VariousInputShapes) {
+ std::vector<std::initializer_list<int>> test_shapes = {
+ {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}};
+ for (int i = 0; i < test_shapes.size(); ++i) {
+ FloatAddOpModel m({TensorType_FLOAT32, test_shapes[i]},
+ {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE);
+ m.PopulateTensor<float>(m.input1(), {-2.0, 0.2, 0.7, 0.8, 1.1, 2.0});
+ m.PopulateTensor<float>(m.input2(), {0.1, 0.2, 0.3, 0.5, 1.1, 0.1});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(),
+ ElementsAreArray({-1.9, 0.4, 1.0, 1.3, 2.2, 2.1}))
+ << "With shape number " << i;
+ }
+}
+
+TEST(QuantizedAddOpModel, QuantizedTestsNoActivation) {
+ float kQuantizedTolerance = GetTolerance(-1.0, 1.0);
+ std::vector<std::initializer_list<float>> inputs1 = {
+ {0.1, 0.2, 0.3, 0.4}, {-0.8, 0.2, 0.4, 0.7}, {-0.8, 0.2, 0.7, 0.3}};
+ std::vector<std::initializer_list<float>> inputs2 = {
+ {0.6, 0.4, 0.3, 0.1}, {0.6, 0.4, 0.5, -0.8}, {0.6, 0.4, -0.8, 0.5}};
+ std::vector<std::initializer_list<float>> results = {
+ {0.7, 0.6, 0.6, 0.5}, {-0.2, 0.6, 0.9, -0.1}, {-0.2, 0.6, -0.1, 0.8}};
+ for (int i = 0; i < inputs1.size(); ++i) {
+ QuantizedAddOpModel m({TensorType_UINT8, {1, 2, 2, 1}, -1.0, 1.0},
+ {TensorType_UINT8, {}, -1.0, 1.0},
+ ActivationFunctionType_NONE);
+ m.QuantizeAndPopulate<uint8_t>(m.input1(), inputs1[i]);
+ m.QuantizeAndPopulate<uint8_t>(m.input2(), inputs2[i]);
+ m.Invoke();
+ EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear(
+ results[i], kQuantizedTolerance)))
+ << "With test number " << i;
+ }
+}
+
+TEST(QuantizedAddOpModel, QuantizedTestsActivationRELU1) {
+ float kQuantizedTolerance = GetTolerance(-1.0, 1.0);
+ std::vector<std::initializer_list<float>> inputs1 = {{-0.8, 0.2, 0.9, 0.7},
+ {-0.8, 0.2, 0.7, 0.3}};
+ std::vector<std::initializer_list<float>> inputs2 = {{0.6, 0.4, 0.9, -0.8},
+ {0.6, 0.4, -0.8, 0.5}};
+ std::vector<std::initializer_list<float>> results = {{-0.2, 0.6, 1.0, -0.1},
+ {-0.2, 0.6, -0.1, 0.8}};
+ for (int i = 0; i < inputs1.size(); ++i) {
+ QuantizedAddOpModel m({TensorType_UINT8, {1, 2, 2, 1}, -1.0, 1.0},
+ {TensorType_UINT8, {}, -1.0, 1.0},
+ ActivationFunctionType_RELU1);
+ m.QuantizeAndPopulate<uint8_t>(m.input1(), inputs1[i]);
+ m.QuantizeAndPopulate<uint8_t>(m.input2(), inputs2[i]);
+ m.Invoke();
+ EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear(
+ results[i], kQuantizedTolerance)))
+ << "With test number " << i;
+ }
+}
+
+TEST(QuantizedAddOpModel, QuantizedVariousInputShapes) {
+ float kQuantizedTolerance = GetTolerance(-3.0, 3.0);
+ std::vector<std::initializer_list<int>> test_shapes = {
+ {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}};
+ for (int i = 0; i < test_shapes.size(); ++i) {
+ QuantizedAddOpModel m({TensorType_UINT8, test_shapes[i], -3.0, 3.0},
+ {TensorType_UINT8, {}, -3.0, 3.0},
+ ActivationFunctionType_NONE);
+ m.QuantizeAndPopulate<uint8_t>(m.input1(), {-2.0, 0.2, 0.7, 0.8, 1.1, 2.0});
+ m.QuantizeAndPopulate<uint8_t>(m.input2(), {0.1, 0.3, 0.3, 0.5, 1.1, 0.1});
+ m.Invoke();
+ EXPECT_THAT(m.GetDequantizedOutput(),
+ ElementsAreArray(ArrayFloatNear({-1.9, 0.5, 1.0, 1.3, 2.2, 2.1},
+ kQuantizedTolerance)))
+ << "With shape number " << i;
+ }
+}
+
+} // namespace
+} // namespace tflite
+int main(int argc, char** argv) {
+  // On Linux, add: tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
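
For reference, the tolerance used by the quantized Add tests is two quantization steps of the tensor range: roughly 0.0157 for [-1, 1] and 0.0471 for [-3, 3]. A quick sketch (illustrative only):

#include <cstdio>

float GetTolerance(int min, int max) {
  const float step = (max - min) / 255.0f;  // one quantization step
  return 2.0f * step;                       // allow up to two steps of error
}

int main() {
  std::printf("%f\n", GetTolerance(-1, 1));  // ~0.015686
  std::printf("%f\n", GetTolerance(-3, 3));  // ~0.047059
  return 0;
}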
diff --git a/tensorflow/contrib/lite/kernels/basic_rnn.cc b/tensorflow/contrib/lite/kernels/basic_rnn.cc
new file mode 100644
index 0000000000..3cee43c68b
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/basic_rnn.cc
@@ -0,0 +1,161 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <unistd.h>
+#include <cassert>
+#include <cmath>
+#include <cstdlib>
+#include <cstdio>
+#include <iostream>
+#include <limits>
+
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/activation_functor.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace rnn {
+
+constexpr int kInputTensor = 0;
+constexpr int kWeightsTensor = 1;
+constexpr int kRecurrentWeightsTensor = 2;
+constexpr int kBiasTensor = 3;
+constexpr int KHiddenStateTensor = 0;
+constexpr int kOutputTensor = 1;
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ // Check we have all the inputs and outputs we need.
+ TF_LITE_ENSURE_EQ(context, node->inputs->size, 4);
+ TF_LITE_ENSURE_EQ(context, node->outputs->size, 2);
+
+ TfLiteTensor* input = &context->tensors[node->inputs->data[kInputTensor]];
+ TfLiteTensor* input_weights =
+ &context->tensors[node->inputs->data[kWeightsTensor]];
+ TfLiteTensor* recurrent_weights =
+ &context->tensors[node->inputs->data[kRecurrentWeightsTensor]];
+ TfLiteTensor* bias = &context->tensors[node->inputs->data[kBiasTensor]];
+
+  // Check that the parameters of the tensors are consistent with each other
+  // and with the input configuration.
+ const int batch_size = input->dims->data[0];
+ const int num_units = input_weights->dims->data[0];
+ TF_LITE_ASSERT_EQ(input->dims->data[1], input_weights->dims->data[1]);
+ TF_LITE_ASSERT_EQ(input_weights->dims->data[0], bias->dims->data[0]);
+ TF_LITE_ASSERT_EQ(recurrent_weights->dims->data[0], bias->dims->data[0]);
+ TF_LITE_ASSERT_EQ(recurrent_weights->dims->data[1], bias->dims->data[0]);
+
+ TfLiteTensor* hidden_state =
+ &context->tensors[node->outputs->data[KHiddenStateTensor]];
+ TfLiteTensor* output = &context->tensors[node->outputs->data[kOutputTensor]];
+
+ // Resize state.
+ TfLiteIntArray* hidden_state_size_array = TfLiteIntArrayCreate(2);
+ hidden_state_size_array->data[0] = batch_size;
+ hidden_state_size_array->data[1] = num_units;
+ TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, hidden_state,
+ hidden_state_size_array));
+
+ // Mark hidden state as a persistent tensor.
+ hidden_state->allocation_type = kTfLiteArenaRwPersistent;
+
+ // Resize output.
+ TfLiteIntArray* output_size_array = TfLiteIntArrayCreate(2);
+ output_size_array->data[0] = batch_size;
+ output_size_array->data[1] = num_units;
+ TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, output,
+ output_size_array));
+
+ return kTfLiteOk;
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ auto* params = reinterpret_cast<TfLiteRNNParams*>(node->builtin_data);
+
+ TfLiteTensor* input = &context->tensors[node->inputs->data[kInputTensor]];
+ TfLiteTensor* input_weights =
+ &context->tensors[node->inputs->data[kWeightsTensor]];
+ TfLiteTensor* recurrent_weights =
+ &context->tensors[node->inputs->data[kRecurrentWeightsTensor]];
+ TfLiteTensor* bias = &context->tensors[node->inputs->data[kBiasTensor]];
+ TfLiteTensor* hidden_state =
+ &context->tensors[node->outputs->data[KHiddenStateTensor]];
+ TfLiteTensor* output = &context->tensors[node->outputs->data[kOutputTensor]];
+
+  // Initialize the bias pointer.
+ const float* bias_ptr = bias->data.f;
+
+ const int batch_size = input->dims->data[0];
+ const int num_units = input_weights->dims->data[0];
+ const int input_size = input->dims->data[1];
+ const int input_weights_stride = input_weights->dims->data[1];
+ const int recurrent_weights_stride = recurrent_weights->dims->data[1];
+
+ // For each batch
+ for (int b = 0; b < batch_size; b++) {
+    // Initialize the pointers to the input, output and hidden state for this batch.
+ const float* input_ptr_batch = input->data.f + b * input_size;
+ float* output_ptr_batch = output->data.f + b * num_units;
+ float* hidden_state_ptr_batch = hidden_state->data.f + b * num_units;
+
+ // Initialize input_weights and recurrent_weights.
+ const float* input_weights_ptr = input_weights->data.f;
+ const float* recurrent_weights_ptr = recurrent_weights->data.f;
+
+ // Output = bias
+ for (int o = 0; o < num_units; o++) {
+ output_ptr_batch[o] = bias_ptr[o];
+ }
+
+ // Output += input * input_weights
+ for (int o = 0; o < num_units; o++) {
+ for (int i = 0; i < input_size; i++) {
+ output_ptr_batch[o] += input_ptr_batch[i] * input_weights_ptr[i];
+ }
+ input_weights_ptr += input_weights_stride;
+ }
+
+ // Output += recurrent_weights * hidden_state
+ for (int o = 0; o < num_units; o++) {
+ for (int h = 0; h < num_units; h++) {
+ output_ptr_batch[o] +=
+ hidden_state_ptr_batch[h] * recurrent_weights_ptr[h];
+ }
+ recurrent_weights_ptr += recurrent_weights_stride;
+ }
+
+ // Output = activation(Output) and update hidden_state
+ for (int o = 0; o < num_units; o++) {
+ output_ptr_batch[o] =
+ (ActivationFunctor(params->activation))(output_ptr_batch[o]);
+ hidden_state_ptr_batch[o] = output_ptr_batch[o];
+ }
+ }
+
+ return kTfLiteOk;
+}
+
+} // namespace rnn
+
+TfLiteRegistration* Register_RNN() {
+ static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
+ rnn::Prepare, rnn::Eval};
+ return &r;
+}
+
+} // namespace builtin
+} // namespace ops
+} // namespace tflite
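
The per-batch loop in rnn::Eval computes output = activation(W * input + R * hidden + bias) and copies the result back into the hidden state. A self-contained sketch of one step, assuming a ReLU activation and the same row-major weight layouts the kernel indexes (W is [num_units, input_size], R is [num_units, num_units]); illustrative only:

#include <algorithm>
#include <vector>

std::vector<float> RnnStep(const std::vector<float>& x, std::vector<float>& h,
                           const std::vector<float>& W, const std::vector<float>& R,
                           const std::vector<float>& b, int num_units, int input_size) {
  std::vector<float> out(num_units);
  for (int o = 0; o < num_units; ++o) {
    float acc = b[o];  // Output = bias
    for (int i = 0; i < input_size; ++i)
      acc += x[i] * W[o * input_size + i];  // Output += input * input_weights
    for (int u = 0; u < num_units; ++u)
      acc += h[u] * R[o * num_units + u];  // Output += recurrent_weights * hidden_state
    out[o] = std::max(0.0f, acc);  // activation (ReLU here)
  }
  h = out;  // update hidden state
  return out;
}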
diff --git a/tensorflow/contrib/lite/kernels/basic_rnn_test.cc b/tensorflow/contrib/lite/kernels/basic_rnn_test.cc
new file mode 100644
index 0000000000..dfa75655bc
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/basic_rnn_test.cc
@@ -0,0 +1,267 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// Unit test for TFLite RNN op.
+
+#include <vector>
+#include <iomanip>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAreArray;
+
+static float rnn_input[] = {
+ 0.23689353, 0.285385, 0.037029743, -0.19858193, -0.27569133,
+ 0.43773448, 0.60379338, 0.35562468, -0.69424844, -0.93421471,
+ -0.87287879, 0.37144363, -0.62476718, 0.23791671, 0.40060222,
+ 0.1356622, -0.99774903, -0.98858172, -0.38952237, -0.47685933,
+ 0.31073618, 0.71511042, -0.63767755, -0.31729108, 0.33468103,
+ 0.75801885, 0.30660987, -0.37354088, 0.77002847, -0.62747043,
+ -0.68572164, 0.0069220066, 0.65791464, 0.35130811, 0.80834007,
+ -0.61777675, -0.21095741, 0.41213346, 0.73784804, 0.094794154,
+ 0.47791874, 0.86496925, -0.53376222, 0.85315156, 0.10288584,
+ 0.86684, -0.011186242, 0.10513687, 0.87825835, 0.59929144,
+ 0.62827742, 0.18899453, 0.31440187, 0.99059987, 0.87170351,
+ -0.35091716, 0.74861872, 0.17831337, 0.2755419, 0.51864719,
+ 0.55084288, 0.58982027, -0.47443086, 0.20875752, -0.058871567,
+ -0.66609079, 0.59098077, 0.73017097, 0.74604273, 0.32882881,
+ -0.17503482, 0.22396147, 0.19379807, 0.29120302, 0.077113032,
+ -0.70331609, 0.15804303, -0.93407321, 0.40182066, 0.036301374,
+ 0.66521823, 0.0300982, -0.7747041, -0.02038002, 0.020698071,
+ -0.90300065, 0.62870288, -0.23068321, 0.27531278, -0.095755219,
+ -0.712036, -0.17384434, -0.50593495, -0.18646687, -0.96508682,
+ 0.43519354, 0.14744234, 0.62589407, 0.1653645, -0.10651493,
+ -0.045277178, 0.99032974, -0.88255352, -0.85147917, 0.28153265,
+ 0.19455957, -0.55479527, -0.56042433, 0.26048636, 0.84702539,
+ 0.47587705, -0.074295521, -0.12287641, 0.70117295, 0.90532446,
+ 0.89782166, 0.79817224, 0.53402734, -0.33286154, 0.073485017,
+ -0.56172788, -0.044897556, 0.89964068, -0.067662835, 0.76863563,
+ 0.93455386, -0.6324693, -0.083922029};
+
+static float rnn_golden_output[] = {
+ 0.496726, 0, 0.965996, 0, 0.0584254, 0,
+ 0, 0.12315, 0, 0, 0.612266, 0.456601,
+ 0, 0.52286, 1.16099, 0.0291232,
+
+ 0, 0, 0.524901, 0, 0, 0,
+ 0, 1.02116, 0, 1.35762, 0, 0.356909,
+ 0.436415, 0.0355727, 0, 0,
+
+ 0, 0, 0, 0.262335, 0, 0,
+ 0, 1.33992, 0, 2.9739, 0, 0,
+ 1.31914, 2.66147, 0, 0,
+
+ 0.942568, 0, 0, 0, 0.025507, 0,
+ 0, 0, 0.321429, 0.569141, 1.25274, 1.57719,
+ 0.8158, 1.21805, 0.586239, 0.25427,
+
+ 1.04436, 0, 0.630725, 0, 0.133801, 0.210693,
+ 0.363026, 0, 0.533426, 0, 1.25926, 0.722707,
+ 0, 1.22031, 1.30117, 0.495867,
+
+ 0.222187, 0, 0.72725, 0, 0.767003, 0,
+ 0, 0.147835, 0, 0, 0, 0.608758,
+ 0.469394, 0.00720298, 0.927537, 0,
+
+ 0.856974, 0.424257, 0, 0, 0.937329, 0,
+ 0, 0, 0.476425, 0, 0.566017, 0.418462,
+ 0.141911, 0.996214, 1.13063, 0,
+
+ 0.967899, 0, 0, 0, 0.0831304, 0,
+ 0, 1.00378, 0, 0, 0, 1.44818,
+ 1.01768, 0.943891, 0.502745, 0,
+
+ 0.940135, 0, 0, 0, 0, 0,
+ 0, 2.13243, 0, 0.71208, 0.123918, 1.53907,
+ 1.30225, 1.59644, 0.70222, 0,
+
+ 0.804329, 0, 0.430576, 0, 0.505872, 0.509603,
+ 0.343448, 0, 0.107756, 0.614544, 1.44549, 1.52311,
+ 0.0454298, 0.300267, 0.562784, 0.395095,
+
+ 0.228154, 0, 0.675323, 0, 1.70536, 0.766217,
+ 0, 0, 0, 0.735363, 0.0759267, 1.91017,
+ 0.941888, 0, 0, 0,
+
+ 0, 0, 1.5909, 0, 0, 0,
+ 0, 0.5755, 0, 0.184687, 0, 1.56296,
+ 0.625285, 0, 0, 0,
+
+ 0, 0, 0.0857888, 0, 0, 0,
+ 0, 0.488383, 0.252786, 0, 0, 0,
+ 1.02817, 1.85665, 0, 0,
+
+ 0.00981836, 0, 1.06371, 0, 0, 0,
+ 0, 0, 0, 0.290445, 0.316406, 0,
+ 0.304161, 1.25079, 0.0707152, 0,
+
+ 0.986264, 0.309201, 0, 0, 0, 0,
+ 0, 1.64896, 0.346248, 0, 0.918175, 0.78884,
+ 0.524981, 1.92076, 2.07013, 0.333244,
+
+ 0.415153, 0.210318, 0, 0, 0, 0,
+ 0, 2.02616, 0, 0.728256, 0.84183, 0.0907453,
+ 0.628881, 3.58099, 1.49974, 0
+};
+
+class RNNOpModel : public SingleOpModel {
+ public:
+ RNNOpModel(int batches, int units, int size)
+ : batches_(batches), units_(units), input_size_(size) {
+ input_ = AddInput(TensorType_FLOAT32);
+ weights_ = AddInput(TensorType_FLOAT32);
+ recurrent_weights_ = AddInput(TensorType_FLOAT32);
+ bias_ = AddInput(TensorType_FLOAT32);
+ hidden_state_ = AddOutput(TensorType_FLOAT32);
+ output_ = AddOutput(TensorType_FLOAT32);
+ SetBuiltinOp(
+ BuiltinOperator_RNN, BuiltinOptions_RNNOptions,
+ CreateRNNOptions(builder_, ActivationFunctionType_RELU).Union());
+ BuildInterpreter({{batches_, input_size_},
+ {units_, input_size_},
+ {units_, units_},
+ {units_}});
+ }
+
+ void SetBias(std::initializer_list<float> f) { PopulateTensor(bias_, f); }
+
+ void SetWeights(std::initializer_list<float> f) {
+ PopulateTensor(weights_, f);
+ }
+
+ void SetRecurrentWeights(std::initializer_list<float> f) {
+ PopulateTensor(recurrent_weights_, f);
+ }
+
+ void SetInput(std::initializer_list<float> data) {
+ PopulateTensor(input_, data);
+ }
+
+ void SetInput(int offset, float* begin, float* end) {
+ PopulateTensor(input_, offset, begin, end);
+ }
+
+ void ResetHiddenState() {
+ const int zero_buffer_size = units_ * batches_;
+ std::unique_ptr<float[]> zero_buffer(new float[zero_buffer_size]);
+ memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float));
+ PopulateTensor(hidden_state_, 0, zero_buffer.get(),
+ zero_buffer.get() + zero_buffer_size);
+ }
+
+ std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+
+ int input_size() { return input_size_; }
+ int num_units() { return units_; }
+ int num_batches() { return batches_; }
+
+ private:
+ int input_;
+ int weights_;
+ int recurrent_weights_;
+ int bias_;
+ int hidden_state_;
+ int output_;
+
+ int batches_;
+ int units_;
+ int input_size_;
+};
+
+TEST(RnnOpTest, BlackBoxTest) {
+ RNNOpModel rnn(2, 16, 8);
+ rnn.SetWeights(
+ {0.461459, 0.153381, 0.529743, -0.00371218, 0.676267, -0.211346,
+ 0.317493, 0.969689, -0.343251, 0.186423, 0.398151, 0.152399,
+ 0.448504, 0.317662, 0.523556, -0.323514, 0.480877, 0.333113,
+ -0.757714, -0.674487, -0.643585, 0.217766, -0.0251462, 0.79512,
+ -0.595574, -0.422444, 0.371572, -0.452178, -0.556069, -0.482188,
+ -0.685456, -0.727851, 0.841829, 0.551535, -0.232336, 0.729158,
+ -0.00294906, -0.69754, 0.766073, -0.178424, 0.369513, -0.423241,
+ 0.548547, -0.0152023, -0.757482, -0.85491, 0.251331, -0.989183,
+ 0.306261, -0.340716, 0.886103, -0.0726757, -0.723523, -0.784303,
+ 0.0354295, 0.566564, -0.485469, -0.620498, 0.832546, 0.697884,
+ -0.279115, 0.294415, -0.584313, 0.548772, 0.0648819, 0.968726,
+ 0.723834, -0.0080452, -0.350386, -0.272803, 0.115121, -0.412644,
+ -0.824713, -0.992843, -0.592904, -0.417893, 0.863791, -0.423461,
+ -0.147601, -0.770664, -0.479006, 0.654782, 0.587314, -0.639158,
+ 0.816969, -0.337228, 0.659878, 0.73107, 0.754768, -0.337042,
+ 0.0960841, 0.368357, 0.244191, -0.817703, -0.211223, 0.442012,
+ 0.37225, -0.623598, -0.405423, 0.455101, 0.673656, -0.145345,
+ -0.511346, -0.901675, -0.81252, -0.127006, 0.809865, -0.721884,
+ 0.636255, 0.868989, -0.347973, -0.10179, -0.777449, 0.917274,
+ 0.819286, 0.206218, -0.00785118, 0.167141, 0.45872, 0.972934,
+ -0.276798, 0.837861, 0.747958, -0.0151566, -0.330057, -0.469077,
+ 0.277308, 0.415818});
+
+ rnn.SetBias({0.065691948, -0.69055247, 0.1107955, -0.97084129, -0.23957068,
+ -0.23566568, -0.389184, 0.47481549, -0.4791103, 0.29931796,
+ 0.10463274, 0.83918178, 0.37197268, 0.61957061, 0.3956964,
+ -0.37609905});
+
+ rnn.SetRecurrentWeights({0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1});
+
+ rnn.ResetHiddenState();
+ const int input_sequence_size = sizeof(rnn_input) / sizeof(float) /
+ (rnn.input_size() * rnn.num_batches());
+
+ for (int i = 0; i < input_sequence_size; i++) {
+ float* batch_start = rnn_input + i * rnn.input_size();
+ float* batch_end = batch_start + rnn.input_size();
+ rnn.SetInput(0, batch_start, batch_end);
+ rnn.SetInput(rnn.input_size(), batch_start, batch_end);
+
+ rnn.Invoke();
+
+ float* golden_start = rnn_golden_output + i * rnn.num_units();
+ float* golden_end = golden_start + rnn.num_units();
+ std::vector<float> expected;
+ expected.insert(expected.end(), golden_start, golden_end);
+ expected.insert(expected.end(), golden_start, golden_end);
+
+ EXPECT_THAT(rnn.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
+ }
+}
+
+} // namespace
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ // On Linux, add: tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
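
The BlackBoxTest above feeds the same input slice to both batches on every step, so the expected output for a step is the corresponding golden row repeated once per batch. A sketch of that expectation-building (names hypothetical, illustrative only):

#include <vector>

std::vector<float> ExpectedForStep(const float* golden, int step, int num_units,
                                   int num_batches) {
  const float* start = golden + step * num_units;
  std::vector<float> expected;
  for (int b = 0; b < num_batches; ++b)
    expected.insert(expected.end(), start, start + num_units);
  return expected;
}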
diff --git a/tensorflow/contrib/lite/kernels/concatenation.cc b/tensorflow/contrib/lite/kernels/concatenation.cc
new file mode 100644
index 0000000000..9e7a1233da
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/concatenation.cc
@@ -0,0 +1,200 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <unistd.h>
+#include <cassert>
+#include <cmath>
+#include <cstdio>
+#include <cstdlib>
+#include <iostream>
+#include <limits>
+
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace concatenation {
+
+// This file has two implementations of Concatenation.
+enum KernelType {
+ kReference,
+ kGenericOptimized,
+};
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ auto* params =
+ reinterpret_cast<TfLiteConcatenationParams*>(node->builtin_data);
+ int axis = params->axis;
+ int num_inputs = node->inputs->size;
+
+ // The number of dimensions of the input tensors must match, and all
+ // dimensions except 'axis' must be equal.
+ TfLiteTensor* t0 = &context->tensors[node->inputs->data[0]];
+ TfLiteType input_type = t0->type;
+ TF_LITE_ENSURE(context, axis >= 0);
+ TF_LITE_ENSURE(context, axis < t0->dims->size);
+
+ // TODO(ahentz): These are limitations of our implementation that could be
+ // removed with a bit of effort.
+ TF_LITE_ENSURE(context, t0->dims->size <= 4);
+ TF_LITE_ENSURE_EQ(context, params->activation, kTfLiteActNone);
+ TF_LITE_ENSURE(context,
+ input_type == kTfLiteFloat32 || input_type == kTfLiteUInt8);
+
+  // Output dimensions will match input dimensions, except 'axis', which
+  // will be the sum of the inputs' sizes along that dimension.
+ int sum_axis = t0->dims->data[axis];
+ for (int i = 1; i < num_inputs; ++i) {
+ TfLiteTensor* t = &context->tensors[node->inputs->data[i]];
+ TF_LITE_ENSURE_EQ(context, t->dims->size, t0->dims->size);
+ TF_LITE_ENSURE_EQ(context, t->type, input_type);
+ if (input_type == kTfLiteUInt8) {
+ TF_LITE_ENSURE_EQ(context, t->params.zero_point, t0->params.zero_point);
+ TF_LITE_ENSURE_EQ(context, t->params.scale, t0->params.scale);
+ }
+ for (int d = 0; d < t0->dims->size; ++d) {
+ if (d == axis) {
+ sum_axis += t->dims->data[axis];
+ } else {
+ TF_LITE_ENSURE_EQ(context, t->dims->data[d], t0->dims->data[d]);
+ }
+ }
+ }
+
+ TfLiteIntArray* output_size = TfLiteIntArrayCreate(t0->dims->size);
+ for (int d = 0; d < t0->dims->size; ++d) {
+ output_size->data[d] = (d == axis) ? sum_axis : t0->dims->data[d];
+ }
+
+ TfLiteTensor* output = &context->tensors[node->outputs->data[0]];
+ TF_LITE_ENSURE_EQ(context, output->type, input_type);
+ if (input_type == kTfLiteUInt8) {
+ TF_LITE_ENSURE_EQ(context, output->params.zero_point,
+ t0->params.zero_point);
+ TF_LITE_ENSURE_EQ(context, output->params.scale, t0->params.scale);
+ }
+
+ return context->ResizeTensor(context, output, output_size);
+}
+
+template <typename T>
+class VectorOfInputs {
+ public:
+ VectorOfInputs(const TfLiteContext& context, const TfLiteIntArray& inputs) {
+ int num_inputs = inputs.size;
+
+ all_data_.reserve(num_inputs);
+ all_dims_.reserve(num_inputs);
+ all_dims_ptr_.reserve(num_inputs);
+
+ for (int i = 0; i < num_inputs; ++i) {
+ TfLiteTensor* input = &context.tensors[inputs.data[i]];
+ all_data_.push_back(GetTensorData<T>(input));
+ all_dims_.push_back(GetTensorDims(input));
+ }
+
+    // Taking a pointer to an element inside a std::vector is only OK if the
+    // vector is never modified afterwards, so we populate all_dims_ in the
+    // previous loop and only then grab the pointers here.
+ for (int i = 0; i < num_inputs; ++i) {
+ all_dims_ptr_.push_back(&all_dims_[i]);
+ }
+ }
+ const T* const* data() const { return all_data_.data(); }
+ const Dims<4>* const* dims() const { return all_dims_ptr_.data(); }
+
+ private:
+ std::vector<T*> all_data_;
+ std::vector<Dims<4>> all_dims_;
+ std::vector<Dims<4>*> all_dims_ptr_;
+};
+
+template <KernelType kernel_type>
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ auto* params =
+ reinterpret_cast<TfLiteConcatenationParams*>(node->builtin_data);
+
+ TfLiteTensor* output = &context->tensors[node->outputs->data[0]];
+
+// TODO(ahentz): Creating 'all_inputs' below is not very efficient. We should
+// allocate and populate these during Prepare().
+// TODO(ycling): Activation function parameter is ignored. For now we don't have
+// a model with a Concatenation with fused activation function.
+#define TF_LITE_CONCATENATION(type, scalar) \
+ VectorOfInputs<scalar> all_inputs(*context, *node->inputs); \
+ type::Concatenation<FusedActivationFunctionType::kNone, scalar>( \
+ RemapDim(NumDimensions(output), params->axis), all_inputs.data(), \
+ all_inputs.dims(), node->inputs->size, GetTensorData<scalar>(output), \
+ GetTensorDims(output))
+
+  switch (output->type) {  // Already know in/out types are the same.
+ case kTfLiteFloat32:
+ if (kernel_type == kReference) {
+ TF_LITE_CONCATENATION(reference_ops, float);
+ } else {
+ TF_LITE_CONCATENATION(optimized_ops, float);
+ }
+ break;
+ case kTfLiteUInt8:
+ if (kernel_type == kReference) {
+ TF_LITE_CONCATENATION(reference_ops, uint8_t);
+ } else {
+ TF_LITE_CONCATENATION(optimized_ops, uint8_t);
+ }
+ break;
+ default:
+ context->ReportError(context,
+ "Only float32 and uint8 are currently supported.");
+ return kTfLiteError;
+ }
+
+#undef TF_LITE_CONCATENATION
+
+ return kTfLiteOk;
+}
+
+#undef TF_LITE_MACRO_DISPATCH
+
+} // namespace concatenation
+
+TfLiteRegistration* Register_CONCATENATION_REF() {
+ static TfLiteRegistration r = {
+ nullptr, nullptr, concatenation::Prepare,
+ concatenation::Eval<concatenation::kReference>};
+ return &r;
+}
+
+TfLiteRegistration* Register_CONCATENATION_GENERIC_OPT() {
+ static TfLiteRegistration r = {
+ nullptr, nullptr, concatenation::Prepare,
+ concatenation::Eval<concatenation::kGenericOptimized>};
+ return &r;
+}
+
+TfLiteRegistration* Register_CONCATENATION() {
+ // TODO(ahentz): It turns out the two versions of Concatenation are almost
+ // identical, so we should consider removing one.
+ return Register_CONCATENATION_GENERIC_OPT();
+}
+
+} // namespace builtin
+} // namespace ops
+} // namespace tflite
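
The shape rule enforced in concatenation::Prepare is: every output dimension matches input 0, except 'axis', which is the sum of the inputs' sizes along that dimension. A compact sketch (illustrative only):

#include <cassert>
#include <vector>

std::vector<int> ConcatOutputShape(const std::vector<std::vector<int>>& shapes, int axis) {
  std::vector<int> out = shapes[0];
  for (size_t i = 1; i < shapes.size(); ++i) {
    for (size_t d = 0; d < out.size(); ++d) {
      if (static_cast<int>(d) == axis) {
        out[d] += shapes[i][d];
      } else {
        assert(shapes[i][d] == out[d]);  // non-axis dimensions must match
      }
    }
  }
  return out;
}

// e.g. ConcatOutputShape({{2, 1, 2}, {2, 1, 2}}, /*axis=*/2) yields {2, 1, 4}.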
diff --git a/tensorflow/contrib/lite/kernels/concatenation_test.cc b/tensorflow/contrib/lite/kernels/concatenation_test.cc
new file mode 100644
index 0000000000..94e5b2acdc
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/concatenation_test.cc
@@ -0,0 +1,162 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <cstdarg>
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAreArray;
+
+class BaseConcatenationOpModel : public SingleOpModel {
+ public:
+ // TODO(ahentz): Also test different activation types, axis, input
+ // dimensions.
+ BaseConcatenationOpModel(const TensorData& input_template, int axis,
+ int num_inputs) {
+ std::vector<std::vector<int>> all_input_shapes;
+ for (int i = 0; i < num_inputs; ++i) {
+ all_input_shapes.push_back(input_template.shape);
+ AddInput(input_template);
+ }
+ output_ = AddOutput({input_template.type, /*shape=*/{}, input_template.min,
+ input_template.max});
+ SetBuiltinOp(
+ BuiltinOperator_CONCATENATION, BuiltinOptions_ConcatenationOptions,
+ CreateConcatenationOptions(builder_, axis, ActivationFunctionType_NONE)
+ .Union());
+ BuildInterpreter(all_input_shapes);
+ }
+
+ protected:
+ int output_;
+};
+
+class ConcatenationOpModel : public BaseConcatenationOpModel {
+ public:
+ using BaseConcatenationOpModel::BaseConcatenationOpModel;
+ void SetInput(int index, std::initializer_list<float> data) {
+ PopulateTensor(index, data);
+ }
+ std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+};
+
+class QuantizedConcatenationOpModel : public BaseConcatenationOpModel {
+ public:
+ using BaseConcatenationOpModel::BaseConcatenationOpModel;
+ void SetInput(int index, std::initializer_list<float> data) {
+ QuantizeAndPopulate<uint8_t>(index, data);
+ }
+ std::vector<uint8_t> GetOutput() { return ExtractVector<uint8_t>(output_); }
+ std::vector<float> GetDequantizedOutput() {
+ return Dequantize<uint8_t>(ExtractVector<uint8_t>(output_),
+ GetScale(output_), GetZeroPoint(output_));
+ }
+};
+
+TEST(ConcatenationOpTest, ThreeDimensionalOneInput) {
+ ConcatenationOpModel m0({TensorType_FLOAT32, {2, 1, 2}}, /*axis=*/1,
+ /*num_inputs=*/1);
+ m0.SetInput(0, {1.0f, 3.0f, 4.0f, 7.0f});
+ m0.Invoke();
+ EXPECT_THAT(m0.GetOutput(), ElementsAreArray({1, 3, 4, 7}));
+}
+
+TEST(ConcatenationOpTest, OneTrivialInput) {
+ ConcatenationOpModel m0({TensorType_FLOAT32, {1}}, /*axis=*/0,
+ /*num_inputs=*/1);
+ m0.SetInput(0, {5.0f});
+ m0.Invoke();
+ EXPECT_THAT(m0.GetOutput(), ::testing::ElementsAre(5));
+}
+
+TEST(ConcatenationOpTest, TwoDimensionalOneInput) {
+ ConcatenationOpModel m0({TensorType_FLOAT32, {2, 3}}, /*axis=*/0,
+ /*num_inputs=*/1);
+ m0.SetInput(0, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f});
+ m0.Invoke();
+ EXPECT_THAT(m0.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6}));
+}
+
+TEST(ConcatenationOpTest, TwoInputsTwoAxis) {
+ // We will concatenate two tensors along different dimensions.
+ auto tensor0 = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
+ auto tensor1 = {7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f};
+
+ ConcatenationOpModel m0({TensorType_FLOAT32, {2, 3}}, /*axis=*/0,
+ /*num_inputs=*/2);
+ m0.SetInput(0, tensor0);
+ m0.SetInput(1, tensor1);
+ m0.Invoke();
+ EXPECT_THAT(m0.GetOutput(),
+ ElementsAreArray({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}));
+
+ ConcatenationOpModel m1({TensorType_FLOAT32, {2, 3}}, /*axis=*/1,
+ /*num_inputs=*/2);
+ m1.SetInput(0, tensor0);
+ m1.SetInput(1, tensor1);
+ m1.Invoke();
+ EXPECT_THAT(m1.GetOutput(),
+ ElementsAreArray({1, 2, 3, 7, 8, 9, 4, 5, 6, 10, 11, 12}));
+}
+
+TEST(ConcatenationOpTest, FourInputs) {
+ ConcatenationOpModel m0({TensorType_FLOAT32, {2, 1, 2}}, /*axis=*/2,
+ /*num_inputs=*/4);
+ m0.SetInput(0, {1.0f, 3.0f, 4.0f, 7.0f});
+ m0.SetInput(1, {1.1f, 3.1f, 4.1f, 7.1f});
+ m0.SetInput(2, {1.2f, 3.2f, 4.2f, 7.2f});
+ m0.SetInput(3, {1.3f, 3.3f, 4.3f, 7.3f});
+ m0.Invoke();
+ EXPECT_THAT(m0.GetOutput(),
+ ElementsAreArray({
+ 1.0f, 3.0f, 1.1f, 3.1f, 1.2f, 3.2f, 1.3f, 3.3f, //
+ 4.0f, 7.0f, 4.1f, 7.1f, 4.2f, 7.2f, 4.3f, 7.3f, //
+ }));
+}
+
+TEST(ConcatenationOpTest, FourInputsQuantized) {
+ QuantizedConcatenationOpModel m0({TensorType_UINT8, {2, 1, 2}, -12.7, 12.8},
+ /*axis=*/2,
+ /*num_inputs=*/4);
+
+ m0.SetInput(0, {1.0f, 3.0f, 4.0f, 7.0f});
+ m0.SetInput(1, {1.1f, 3.1f, 4.1f, 7.1f});
+ m0.SetInput(2, {1.2f, 3.2f, 4.2f, 7.2f});
+ m0.SetInput(3, {1.3f, 3.3f, 4.3f, 7.3f});
+ m0.Invoke();
+ EXPECT_THAT(m0.GetDequantizedOutput(),
+ ElementsAreArray(ArrayFloatNear({
+ 1.0f, 3.0f, 1.1f, 3.1f, 1.2f, 3.2f, 1.3f, 3.3f, //
+ 4.0f, 7.0f, 4.1f, 7.1f, 4.2f, 7.2f, 4.3f, 7.3f, //
+ })));
+ EXPECT_THAT(m0.GetOutput(), ElementsAreArray({
+ 137, 157, 138, 158, 139, 159, 140, 160, //
+ 167, 197, 168, 198, 169, 199, 170, 200, //
+ }));
+}
+
+} // namespace
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ // On Linux, add: tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
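
The raw uint8 expectations in FourInputsQuantized follow from the [-12.7, 12.8] range: scale = 25.5 / 255 = 0.1 and zero point = 127, so q = round(x / 0.1) + 127, e.g. 1.0 -> 137 and 7.3 -> 200. A quick sketch of that mapping (illustrative only):

#include <cmath>
#include <cstdint>
#include <cstdio>

int main() {
  const float min = -12.7f, max = 12.8f;
  const float scale = (max - min) / 255.0f;                           // 0.1
  const int zero_point = static_cast<int>(std::round(-min / scale));  // 127
  for (float x : {1.0f, 3.0f, 4.0f, 7.0f, 7.3f}) {
    const uint8_t q = static_cast<uint8_t>(std::round(x / scale) + zero_point);
    std::printf("%.1f -> %d\n", x, static_cast<int>(q));
  }
  return 0;
}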
diff --git a/tensorflow/contrib/lite/kernels/conv.cc b/tensorflow/contrib/lite/kernels/conv.cc
new file mode 100644
index 0000000000..c75c04baea
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/conv.cc
@@ -0,0 +1,425 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <unistd.h>
+#include <algorithm>
+#include <cassert>
+#include <cmath>
+#include <cstdio>
+#include <cstdlib>
+#include <iostream>
+#include <limits>
+
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/gemm_support.h"
+#include "tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h"
+#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+#include "tensorflow/contrib/lite/kernels/padding.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace conv {
+
+// This file has three implementations of Conv.
+enum KernelType {
+ kReference,
+ kGenericOptimized, // Neon-free
+ kNeonOptimized,
+};
+
+struct OpData {
+ // IDs are the arbitrary identifiers used by TF Lite to identify and access
+ // memory buffers.
+ int im2col_id;
+ int hwcn_weights_id;
+
+ TfLitePaddingValues padding;
+ // The scaling factor from input to output (aka the 'real multiplier') can
+  // be represented as a fixed point multiplier plus a left shift.
+ int32_t output_multiplier;
+ int output_shift;
+ // The range of the fused activation layer. For example for kNone and
+ // uint8_t these would be 0 and 255.
+ int32_t output_activation_min;
+ int32_t output_activation_max;
+  // Indexes are offsets into the node's temporaries array, which keeps track
+  // of the allocated scratch buffers.
+ int32_t im2col_index;
+ int32_t hwcn_weights_index;
+ bool need_hwcn_weights;
+ bool have_weights_been_transposed;
+ bool need_im2col;
+};
+
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+ // This is a builtin op, so we don't use the contents in 'buffer', if any.
+ // Instead, we allocate a new object to use as scratch space for im2col, and
+ // to carry information from Prepare() to Eval().
+ auto* data = new OpData;
+ context->AddTensors(context, 1, &data->im2col_id);
+ context->AddTensors(context, 1, &data->hwcn_weights_id);
+ gemm_support::IncrementUsageCounter(context);
+ return data;
+}
+
+void Free(TfLiteContext* context, void* buffer) {
+ gemm_support::DecrementUsageCounter(context);
+ delete reinterpret_cast<OpData*>(buffer);
+}
+
+// Naive implementation of transpose for floats. Could be optimized to be more
+// cache friendly, but for now it's a one-time cost on first run, and we would
+// prefer to remove the need to do this at all eventually.
+void TransposeFloatTensor(TfLiteTensor* input, TfLiteTensor* output) {
+ const int rows = output->dims->data[1];
+ const int cols = output->dims->data[0];
+ const float* input_data = GetTensorData<float>(input);
+ float* output_data = GetTensorData<float>(output);
+ for (int i = 0; i < rows; ++i) {
+ for (int j = 0; j < cols; ++j) {
+ const float in_value = input_data[i * cols + j];
+ output_data[j * rows + i] = in_value;
+ }
+ }
+}
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ auto* params = reinterpret_cast<TfLiteConvParams*>(node->builtin_data);
+ OpData* data = reinterpret_cast<OpData*>(node->user_data);
+
+ bool hasBias = node->inputs->size == 3;
+ // Check number of inputs/outputs
+ TF_LITE_ENSURE(context, hasBias || node->inputs->size == 2);
+ TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
+ TfLiteTensor* output = &context->tensors[node->outputs->data[0]];
+ TfLiteTensor* input = &context->tensors[node->inputs->data[0]];
+ TfLiteTensor* filter = &context->tensors[node->inputs->data[1]];
+ // Check dimensionality of input, filter
+ TF_LITE_ENSURE_EQ(context, input->dims->size, 4);
+ TF_LITE_ENSURE_EQ(context, filter->dims->size, 4);
+ // Check input channels matching filter
+ TF_LITE_ENSURE_EQ(context, input->dims->data[3], filter->dims->data[3]);
+
+ // Check types. (We assume that UINT8 refers to quantized tensors)
+ TfLiteType data_type = input->type;
+ TF_LITE_ENSURE(context,
+ data_type == kTfLiteFloat32 || data_type == kTfLiteUInt8);
+ TF_LITE_ENSURE_EQ(context, output->type, data_type);
+ TF_LITE_ENSURE_EQ(context, filter->type, data_type);
+
+ TfLiteTensor* bias = nullptr;
+
+ // TODO(ahentz): At this point the optimized versions require 'bias'. We can
+ // either change that or document that convolution requires it.
+ TF_LITE_ENSURE(context, hasBias);
+
+ if (hasBias) {
+ bias = &context->tensors[node->inputs->data[2]];
+ if (data_type == kTfLiteUInt8) {
+ TF_LITE_ENSURE_EQ(context, bias->type, kTfLiteInt32);
+ TF_LITE_ENSURE_EQ(context, bias->params.zero_point, 0);
+ } else {
+ TF_LITE_ENSURE_EQ(context, bias->type, data_type);
+ }
+ TF_LITE_ENSURE_EQ(context, bias->dims->size, 1);
+ TF_LITE_ENSURE_EQ(context, bias->dims->data[0], filter->dims->data[0]);
+ }
+
+ int channels_out = filter->dims->data[0];
+ int width = input->dims->data[2];
+ int height = input->dims->data[1];
+ int filter_width = filter->dims->data[2];
+ int filter_height = filter->dims->data[1];
+ int batches = input->dims->data[0];
+
+ // Matching GetWindowedOutputSize in TensorFlow.
+ auto padding = params->padding;
+ auto computeOutSize = [padding](int imageSize, int filterSize,
+ int stride) -> int {
+ return padding == kTfLitePaddingSame
+ ? (imageSize + stride - 1) / stride
+ : padding == kTfLitePaddingValid
+ ? (imageSize - filterSize + stride) / stride
+ : 0;
+ };
+
+ int outWidth = computeOutSize(width, filter_width, params->stride_width);
+ int outHeight = computeOutSize(height, filter_height, params->stride_height);
+
+ data->padding.height =
+ ComputePadding(params->stride_height, height, filter_height, outHeight);
+ data->padding.width =
+ ComputePadding(params->stride_width, width, filter_width, outWidth);
+
+ TF_LITE_ENSURE(context, hasBias);
+
+ // Note that quantized inference requires that all tensors have their
+ // parameters set. This is usually done during quantized training.
+ if (data_type != kTfLiteFloat32) {
+ double real_multiplier = 0.0;
+ TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler(
+ context, input, filter, bias, output, &real_multiplier));
+ QuantizeMultiplierSmallerThanOne(real_multiplier, &data->output_multiplier,
+ &data->output_shift);
+ CalculateActivationRangeUint8(params->activation, output,
+ &data->output_activation_min,
+ &data->output_activation_max);
+ }
+
+ TfLiteIntArray* output_size = TfLiteIntArrayCreate(4);
+ output_size->data[0] = batches;
+ output_size->data[1] = outHeight;
+ output_size->data[2] = outWidth;
+ output_size->data[3] = channels_out;
+ auto output_status = context->ResizeTensor(context, output, output_size);
+
+ if (output_status != kTfLiteOk) return output_status;
+
+ // We don't always need to allocate im2col. It is only used in some versions
+ // of the optimized Conv. This test just mimics something that happens inside
+ // optimized_ops.h, in order to avoid a DCHECK(!im2col_data).
+ data->need_im2col =
+ (params->stride_width != 1 || params->stride_height != 1 ||
+ filter_width != 1 || filter_height != 1);
+ // If we're using the optimized multithreaded EigenTensor implementation of
+ // convolution, it expects the filter weights to be transposed compared to
+ // the normal TF Lite buffer format. Typical TF Lite weights are
+ // [filter_count, filter_height, filter_width, input_depth], but for the float
+ // implementation we need them as [filter_height, filter_width, input_depth,
+ // filter_count]. We get to that format by transposing, and create a temporary
+ // buffer to store the results.
+ // This path is only used for float processing, so only create the buffer if
+ // we're running with that data type.
+ data->need_hwcn_weights = (data_type == kTfLiteFloat32);
+
+ int temporaries_count = 0;
+ if (data->need_im2col) {
+ data->im2col_index = temporaries_count;
+ ++temporaries_count;
+ }
+ if (data->need_hwcn_weights) {
+ data->hwcn_weights_index = temporaries_count;
+ ++temporaries_count;
+ }
+
+ TfLiteIntArrayFree(node->temporaries);
+ node->temporaries = TfLiteIntArrayCreate(temporaries_count);
+
+ if (data->need_im2col) {
+ node->temporaries->data[data->im2col_index] = data->im2col_id;
+
+ TfLiteIntArray* im2col_size = TfLiteIntArrayCreate(4);
+
+ int input_depth = input->dims->data[3];
+ im2col_size->data[0] = output_size->data[0];
+ im2col_size->data[1] = output_size->data[1];
+ im2col_size->data[2] = output_size->data[2];
+ im2col_size->data[3] = input_depth * filter_height * filter_width;
+
+ TfLiteTensor* im2col =
+ &context->tensors[node->temporaries->data[data->im2col_index]];
+ im2col->type = data_type;
+ im2col->allocation_type = kTfLiteArenaRw;
+ auto im2col_status = context->ResizeTensor(context, im2col, im2col_size);
+ if (im2col_status != kTfLiteOk) return im2col_status;
+ }
+
+ if (data->need_hwcn_weights) {
+ node->temporaries->data[data->hwcn_weights_index] = data->hwcn_weights_id;
+ TfLiteIntArray* hwcn_weights_size = TfLiteIntArrayCreate(2);
+
+ // Because we're treating the filter weights as a matrix when we do the
+ // transpose, we allocate the buffer with a two-dimensional shape, where one
+ // dimension is the number of elements in each filter, and the second is the
+ // total number of filters.
+ int input_depth = input->dims->data[3];
+ hwcn_weights_size->data[0] = (filter_height * filter_width * input_depth);
+ hwcn_weights_size->data[1] = channels_out;
+
+ TfLiteTensor* hwcn_weights =
+ &context->tensors[node->temporaries->data[data->hwcn_weights_index]];
+ hwcn_weights->type = data_type;
+ hwcn_weights->allocation_type = kTfLiteDynamic;
+ // Make sure we release any previous allocations before we reallocate.
+ // TODO(petewarden): Persistent arenas would be a better fit for this, but
+ // they aren't fully implemented yet.
+ if (hwcn_weights->data.raw) {
+ free(hwcn_weights->data.raw);
+ hwcn_weights->data.raw = nullptr;
+ }
+ auto hwcn_weights_status =
+ context->ResizeTensor(context, hwcn_weights, hwcn_weights_size);
+ if (hwcn_weights_status != kTfLiteOk) return hwcn_weights_status;
+ hwcn_weights->data.raw = static_cast<char*>(malloc(hwcn_weights->bytes));
+
+ // TODO(petewarden): If Resize() is called when the size hasn't actually
+ // changed, this will do extra redundant work.
+ data->have_weights_been_transposed = false;
+ }
+
+ return kTfLiteOk;
+}
+
+template <KernelType kernel_type>
+void EvalQuantized(TfLiteContext* context, TfLiteNode* node,
+ TfLiteConvParams* params, OpData* data, TfLiteTensor* input,
+ TfLiteTensor* filter, TfLiteTensor* bias,
+ TfLiteTensor* im2col, TfLiteTensor* hwcn_weights,
+ TfLiteTensor* output) {
+ gemmlowp::GemmContext* gemm_context = gemm_support::GetFromContext(context);
+
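+ // A quantized value q represents the real value scale * (q - zero_point),
+ // so the input and filter zero points are passed negated (as additive
+ // offsets), while the output zero point is added back after requantization.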
+ auto input_offset = -input->params.zero_point;
+ auto filter_offset = -filter->params.zero_point;
+ auto output_offset = output->params.zero_point;
+
+ if (kernel_type == kReference) {
+ reference_ops::Conv(
+ GetTensorData<uint8_t>(input), GetTensorDims(input), input_offset,
+ GetTensorData<uint8_t>(filter), GetTensorDims(filter), filter_offset,
+ GetTensorData<int32_t>(bias), GetTensorDims(bias), params->stride_width,
+ params->stride_height, data->padding.width, data->padding.height,
+ output_offset, data->output_multiplier, data->output_shift,
+ data->output_activation_min, data->output_activation_max,
+ GetTensorData<uint8_t>(output), GetTensorDims(output),
+ GetTensorData<uint8_t>(im2col), GetTensorDims(im2col), gemm_context);
+ } else {
+ optimized_ops::Conv(
+ GetTensorData<uint8_t>(input), GetTensorDims(input), input_offset,
+ GetTensorData<uint8_t>(filter), GetTensorDims(filter), filter_offset,
+ GetTensorData<int32_t>(bias), GetTensorDims(bias), params->stride_width,
+ params->stride_height, data->padding.width, data->padding.height,
+ output_offset, data->output_multiplier, data->output_shift,
+ data->output_activation_min, data->output_activation_max,
+ GetTensorData<uint8_t>(output), GetTensorDims(output),
+ GetTensorData<uint8_t>(im2col), GetTensorDims(im2col), gemm_context);
+ }
+}
+
+template <KernelType kernel_type>
+void EvalFloat(TfLiteContext* context, TfLiteNode* node,
+ TfLiteConvParams* params, OpData* data, TfLiteTensor* input,
+ TfLiteTensor* filter, TfLiteTensor* bias, TfLiteTensor* im2col,
+ TfLiteTensor* hwcn_weights, TfLiteTensor* output) {
+ float output_activation_min, output_activation_max;
+ CalculateActivationRangeFloat(params->activation, &output_activation_min,
+ &output_activation_max);
+
+ const float* filter_data;
+ if (data->need_hwcn_weights) {
+ filter_data = GetTensorData<float>(hwcn_weights);
+ } else {
+ filter_data = GetTensorData<float>(filter);
+ }
+
+ if (kernel_type == kReference) {
+ reference_ops::Conv(
+ GetTensorData<float>(input), GetTensorDims(input), filter_data,
+ GetTensorDims(filter), GetTensorData<float>(bias), GetTensorDims(bias),
+ params->stride_width, params->stride_height, data->padding.width,
+ data->padding.height, output_activation_min, output_activation_max,
+ GetTensorData<float>(output), GetTensorDims(output),
+ GetTensorData<float>(im2col), GetTensorDims(im2col));
+ } else {
+ multithreaded_ops::Conv(
+ GetTensorData<float>(input), GetTensorDims(input), filter_data,
+ GetTensorDims(filter), GetTensorData<float>(bias), GetTensorDims(bias),
+ params->stride_width, params->stride_height, data->padding.width,
+ data->padding.height, params->padding, output_activation_min,
+ output_activation_max, GetTensorData<float>(output),
+ GetTensorDims(output), GetTensorData<float>(im2col),
+ GetTensorDims(im2col));
+ }
+}
+
+template <KernelType kernel_type>
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ auto* params = reinterpret_cast<TfLiteConvParams*>(node->builtin_data);
+ OpData* data = reinterpret_cast<OpData*>(node->user_data);
+
+ TfLiteTensor* output = &context->tensors[node->outputs->data[0]];
+ TfLiteTensor* input = &context->tensors[node->inputs->data[0]];
+ TfLiteTensor* filter = &context->tensors[node->inputs->data[1]];
+ bool hasBias = node->inputs->size == 3;
+ TfLiteTensor* bias =
+ hasBias ? &context->tensors[node->inputs->data[2]] : nullptr;
+ TfLiteTensor* im2col =
+ data->need_im2col
+ ? &context->tensors[node->temporaries->data[data->im2col_index]]
+ : nullptr;
+ TfLiteTensor* hwcn_weights =
+ data->need_hwcn_weights
+ ? &context->tensors[node->temporaries->data[data->hwcn_weights_index]]
+ : nullptr;
+
+ if (data->need_hwcn_weights && !data->have_weights_been_transposed) {
+ TransposeFloatTensor(filter, hwcn_weights);
+ data->have_weights_been_transposed = true;
+ }
+
+ // TODO(aselle): Consider whether float conv and quantized conv should be
+ // separate ops to avoid dispatch overhead here.
+ switch (input->type) { // Already know in/out types are the same.
+ case kTfLiteFloat32:
+ EvalFloat<kernel_type>(context, node, params, data, input, filter, bias,
+ im2col, hwcn_weights, output);
+ break;
+ case kTfLiteUInt8:
+ EvalQuantized<kernel_type>(context, node, params, data, input, filter,
+ bias, im2col, hwcn_weights, output);
+ break;
+ default:
+ context->ReportError(context, "Type not currently supported.");
+ return kTfLiteError;
+ }
+ return kTfLiteOk;
+}
+
+} // namespace conv
+
+TfLiteRegistration* Register_CONVOLUTION_REF() {
+ static TfLiteRegistration r = {conv::Init, conv::Free, conv::Prepare,
+ conv::Eval<conv::kReference>};
+ return &r;
+}
+
+TfLiteRegistration* Register_CONVOLUTION_GENERIC_OPT() {
+ static TfLiteRegistration r = {conv::Init, conv::Free, conv::Prepare,
+ conv::Eval<conv::kGenericOptimized>};
+ return &r;
+}
+
+TfLiteRegistration* Register_CONVOLUTION_NEON_OPT() {
+ static TfLiteRegistration r = {conv::Init, conv::Free, conv::Prepare,
+ conv::Eval<conv::kNeonOptimized>};
+ return &r;
+}
+
+TfLiteRegistration* Register_CONV_2D() {
+#ifdef USE_NEON
+ return Register_CONVOLUTION_NEON_OPT();
+#else
+ return Register_CONVOLUTION_GENERIC_OPT();
+#endif
+}
+
+} // namespace builtin
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/conv_test.cc b/tensorflow/contrib/lite/kernels/conv_test.cc
new file mode 100644
index 0000000000..18d7a31d59
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/conv_test.cc
@@ -0,0 +1,440 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <cstdarg>
+
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAreArray;
+
+class BaseConvolutionOpModel : public SingleOpModel {
+ public:
+ // TODO(ahentz): Also test different activation types, bias, padding types,
+ // stride values.
+ BaseConvolutionOpModel(
+ const TensorData& input, const TensorData& filter,
+ const TensorData& output, int stride_width = 2, int stride_height = 2,
+ enum Padding padding = Padding_VALID,
+ enum ActivationFunctionType activation = ActivationFunctionType_NONE) {
+ input_ = AddInput(input);
+ filter_ = AddInput(filter);
+
+ int bias_size = GetShape(filter_)[0];
+ if (input.type == TensorType_FLOAT32) {
+ bias_ = AddInput({TensorType_FLOAT32, {bias_size}});
+ } else {
+ // This is a quantized version. The scale of 'bias' depends on the scales
+ // of input and filter. Supposedly this is correctly set during quantized
+ // training.
+ auto bias_scale = GetScale(input_) * GetScale(filter_);
+ TensorData bias{TensorType_INT32, {bias_size}, 0, 0, bias_scale};
+ bias_ = AddInput(bias);
+ }
+
+ output_ = AddOutput(output);
+ if (input.type != TensorType_FLOAT32) {
+ // The following is required by quantized inference. It is the unittest's
+ // responsibility to make sure the output scale falls into the correct
+ // range.
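+ // (With those scales, the requantization multiplier
+ // input_scale * filter_scale / output_scale stays below 1, which is what
+ // the kernel's QuantizeMultiplierSmallerThanOne step expects.)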
+ CHECK_LT(GetScale(input_) * GetScale(filter_), GetScale(output_));
+ }
+
+ SetBuiltinOp(BuiltinOperator_CONV_2D, BuiltinOptions_Conv2DOptions,
+ CreateConv2DOptions(builder_, padding, stride_width,
+ stride_height, activation)
+ .Union());
+
+ BuildInterpreter({GetShape(input_), GetShape(filter_), GetShape(bias_)});
+ }
+
+ protected:
+ int input_;
+ int filter_;
+ int bias_;
+ int output_;
+};
+
+class ConvolutionOpModel : public BaseConvolutionOpModel {
+ public:
+ using BaseConvolutionOpModel::BaseConvolutionOpModel;
+
+ void SetFilter(std::initializer_list<float> f) { PopulateTensor(filter_, f); }
+
+ void SetBias(std::initializer_list<float> f) { PopulateTensor(bias_, f); }
+
+ void SetInput(std::initializer_list<float> data) {
+ PopulateTensor(input_, data);
+ }
+
+ std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+};
+
+TEST(ConvolutionOpTest, SimpleTestFloat32) {
+ ConvolutionOpModel m({TensorType_FLOAT32, {2, 2, 4, 1}},
+ {TensorType_FLOAT32, {3, 2, 2, 1}},
+ {TensorType_FLOAT32, {}});
+
+ m.SetInput({
+ // First batch
+ 1, 1, 1, 1, // row = 1
+ 2, 2, 2, 2, // row = 2
+ // Second batch
+ 1, 2, 3, 4, // row = 1
+ 1, 2, 3, 4, // row = 2
+ });
+ m.SetFilter({
+ 1, 2, 3, 4, // first 2x2 filter
+ -1, 1, -1, 1, // second 2x2 filter
+ -1, -1, 1, 1, // third 2x2 filter
+ });
+ m.SetBias({1, 2, 3});
+
+ m.Invoke();
+
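+ // Spot check: the first output value is the first 2x2 filter applied to the
+ // top-left 2x2 patch of the first batch plus its bias:
+ // 1*1 + 2*1 + 3*2 + 4*2 + 1 = 18.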
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({
+ 18, 2, 5, // first batch, left
+ 18, 2, 5, // first batch, right
+ 17, 4, 3, // second batch, left
+ 37, 4, 3, // second batch, right
+ }));
+}
+
+TEST(ConvolutionOpTest, SimpleTestFloat32WithAnisotropicStrides) {
+ ConvolutionOpModel m({TensorType_FLOAT32, {1, 3, 6, 1}},
+ {TensorType_FLOAT32, {1, 2, 2, 1}},
+ {TensorType_FLOAT32, {}},
+ /*stride_width=*/3, /*stride_height=*/1);
+ m.SetInput({
+ 3, 2, 1, -1, -2, -3, //
+ 4, 3, 2, -2, -3, -4, //
+ 5, 4, 3, -3, -4, -5, //
+ });
+ m.SetFilter({
+ 1, 2, //
+ 3, 4, //
+ });
+ m.SetBias({-1});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({
+ 30, -24, //
+ 40, -34, //
+ }));
+}
+
+TEST(ConvolutionOpTest, HandCalculatedFloat32) {
+ const int depth = 1;
+ const int image_width = 4;
+ const int image_height = 3;
+ const int image_batch_count = 1;
+ const int filter_size = 3;
+ const int filter_count = 1;
+ const int stride_width = 1;
+ const int stride_height = 1;
+ const Padding padding = Padding_SAME;
+ ConvolutionOpModel m(
+ {TensorType_FLOAT32,
+ {image_batch_count, image_height, image_width, depth}},
+ {TensorType_FLOAT32, {depth, filter_size, filter_size, filter_count}},
+ {TensorType_FLOAT32, {}}, stride_width, stride_height, padding);
+
+ // The image matrix is:
+ // | 1 | 2 | 3 | 4 |
+ // | 5 | 6 | 7 | 8 |
+ // | 9 | 10 | 11 | 12 |
+ m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+ // The filter matrix is:
+ // | 1 | 4 | 7 |
+ // | 2 | 5 | 8 |
+ // | 3 | 6 | 9 |
+ m.SetFilter({1, 4, 7, 2, 5, 8, 3, 6, 9});
+ // No bias for this test.
+ m.SetBias({0});
+
+ m.Invoke();
+ // We're sliding the 3x3 filter across the 3x4 image, with accesses outside
+ // the input set to zero because we're using the 'SAME' padding mode.
+ // The calculations behind the expected output are:
+ // (1*0)+(4*0)+(7*0)+(2*0)+(5*1)+(8*2)+(3*0)+(6*5)+(9*6)=105
+ // (1*0)+(4*0)+(7*0)+(2*1)+(5*2)+(8*3)+(3*5)+(6*6)+(9*7)=150
+ // (1*0)+(4*0)+(7*0)+(2*2)+(5*3)+(8*4)+(3*6)+(6*7)+(9*8)=183
+ // (1*0)+(4*0)+(7*0)+(2*3)+(5*4)+(8*0)+(3*7)+(6*8)+(9*0)=95
+ // (1*0)+(4*1)+(7*2)+(2*0)+(5*5)+(8*6)+(3*0)+(6*9)+(9*10)=235
+ // (1*1)+(4*2)+(7*3)+(2*5)+(5*6)+(8*7)+(3*9)+(6*10)+(9*11)=312
+ // (1*2)+(4*3)+(7*4)+(2*6)+(5*7)+(8*8)+(3*10)+(6*11)+(9*12)=357
+ // (1*3)+(4*4)+(7*0)+(2*7)+(5*8)+(8*0)+(3*11)+(6*12)+(9*0)=178
+ // (1*0)+(4*5)+(7*6)+(2*0)+(5*9)+(8*10)+(3*0)+(6*0)+(9*0)=187
+ // (1*5)+(4*6)+(7*7)+(2*9)+(5*10)+(8*11)+(3*0)+(6*0)+(9*0)=234
+ // (1*6)+(4*7)+(7*8)+(2*10)+(5*11)+(8*12)+(3*0)+(6*0)+(9*0)=261
+ // (1*7)+(4*8)+(7*0)+(2*11)+(5*12)+(8*0)+(3*0)+(6*0)+(9*0)=121
+ // This means we should end up with this matrix:
+ // | 105 | 150 | 183 | 95 |
+ // | 235 | 312 | 357 | 178 |
+ // | 187 | 234 | 261 | 121 |
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({105, 150, 183, 95, 235, 312, 357,
+ 178, 187, 234, 261, 121}));
+}
+
+TEST(ConvolutionOpTest, HandCalculatedWithBiasFloat32) {
+ const int depth = 1;
+ const int image_width = 4;
+ const int image_height = 3;
+ const int image_batch_count = 1;
+ const int filter_size = 3;
+ const int filter_count = 1;
+ const int stride_width = 1;
+ const int stride_height = 1;
+ const Padding padding = Padding_SAME;
+ ConvolutionOpModel m(
+ {TensorType_FLOAT32,
+ {image_batch_count, image_height, image_width, depth}},
+ {TensorType_FLOAT32, {depth, filter_size, filter_size, filter_count}},
+ {TensorType_FLOAT32, {}}, stride_width, stride_height, padding);
+
+ // The image matrix is:
+ // | 1 | 2 | 3 | 4 |
+ // | 5 | 6 | 7 | 8 |
+ // | 9 | 10 | 11 | 12 |
+ m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+ // The filter matrix is:
+ // | 1 | 4 | 7 |
+ // | 2 | 5 | 8 |
+ // | 3 | 6 | 9 |
+ m.SetFilter({1, 4, 7, 2, 5, 8, 3, 6, 9});
+ // Bias is | 10 |.
+ m.SetBias({10});
+
+ m.Invoke();
+ // We're sliding the 3x3 filter across the 3x4 image, with accesses outside
+ // the input set to zero because we're using the 'SAME' padding mode.
+ // The calculations behind the expected output are:
+ // (1*0)+(4*0)+(7*0)+(2*0)+(5*1)+(8*2)+(3*0)+(6*5)+(9*6)+10=115
+ // (1*0)+(4*0)+(7*0)+(2*1)+(5*2)+(8*3)+(3*5)+(6*6)+(9*7)+10=160
+ // (1*0)+(4*0)+(7*0)+(2*2)+(5*3)+(8*4)+(3*6)+(6*7)+(9*8)+10=193
+ // (1*0)+(4*0)+(7*0)+(2*3)+(5*4)+(8*0)+(3*7)+(6*8)+(9*0)+10=105
+ // (1*0)+(4*1)+(7*2)+(2*0)+(5*5)+(8*6)+(3*0)+(6*9)+(9*10)+10=245
+ // (1*1)+(4*2)+(7*3)+(2*5)+(5*6)+(8*7)+(3*9)+(6*10)+(9*11)+10=322
+ // (1*2)+(4*3)+(7*4)+(2*6)+(5*7)+(8*8)+(3*10)+(6*11)+(9*12)+10=367
+ // (1*3)+(4*4)+(7*0)+(2*7)+(5*8)+(8*0)+(3*11)+(6*12)+(9*0)+10=188
+ // (1*0)+(4*5)+(7*6)+(2*0)+(5*9)+(8*10)+(3*0)+(6*0)+(9*0)+10=197
+ // (1*5)+(4*6)+(7*7)+(2*9)+(5*10)+(8*11)+(3*0)+(6*0)+(9*0)+10=244
+ // (1*6)+(4*7)+(7*8)+(2*10)+(5*11)+(8*12)+(3*0)+(6*0)+(9*0)+10=271
+ // (1*7)+(4*8)+(7*0)+(2*11)+(5*12)+(8*0)+(3*0)+(6*0)+(9*0)+10=131
+ // This means we should end up with this matrix:
+ // | 115 | 160 | 193 | 105 |
+ // | 245 | 322 | 367 | 188 |
+ // | 197 | 244 | 271 | 131 |
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({115, 160, 193, 105, 245, 322,
+ 367, 188, 197, 244, 271, 131}));
+}
+
+TEST(ConvolutionOpTest, HandCalculatedWithReluFloat32) {
+ const int depth = 1;
+ const int image_width = 4;
+ const int image_height = 3;
+ const int image_batch_count = 1;
+ const int filter_size = 3;
+ const int filter_count = 1;
+ const int stride_width = 1;
+ const int stride_height = 1;
+ const Padding padding = Padding_SAME;
+ ConvolutionOpModel m(
+ {TensorType_FLOAT32,
+ {image_batch_count, image_height, image_width, depth}},
+ {TensorType_FLOAT32, {depth, filter_size, filter_size, filter_count}},
+ {TensorType_FLOAT32, {}}, stride_width, stride_height, padding,
+ ActivationFunctionType_RELU);
+
+ // The image matrix is:
+ // | 1 | 2 | 3 | 4 |
+ // | 5 | 6 | 7 | 8 |
+ // | 9 | 10 | 11 | 12 |
+ m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+ // The filter matrix is:
+ // | 1 | 4 | 7 |
+ // | 2 | 5 | 8 |
+ // | 3 | 6 | 9 |
+ m.SetFilter({1, 4, 7, 2, 5, 8, 3, 6, 9});
+ // Bias is | -200 |.
+ m.SetBias({-200});
+
+ m.Invoke();
+ // We're sliding the 3x3 filter across the 3x4 image, with accesses outside
+ // the input set to zero because we're using the 'SAME' padding mode.
+ // The calculations behind the expected output are:
+ // (1*0)+(4*0)+(7*0)+(2*0)+(5*1)+(8*2)+(3*0)+(6*5)+(9*6)-200=-95
+ // (1*0)+(4*0)+(7*0)+(2*1)+(5*2)+(8*3)+(3*5)+(6*6)+(9*7)-200=-50
+ // (1*0)+(4*0)+(7*0)+(2*2)+(5*3)+(8*4)+(3*6)+(6*7)+(9*8)-200=-17
+ // (1*0)+(4*0)+(7*0)+(2*3)+(5*4)+(8*0)+(3*7)+(6*8)+(9*0)-200=-105
+ // (1*0)+(4*1)+(7*2)+(2*0)+(5*5)+(8*6)+(3*0)+(6*9)+(9*10)-200=35
+ // (1*1)+(4*2)+(7*3)+(2*5)+(5*6)+(8*7)+(3*9)+(6*10)+(9*11)-200=112
+ // (1*2)+(4*3)+(7*4)+(2*6)+(5*7)+(8*8)+(3*10)+(6*11)+(9*12)-200=157
+ // (1*3)+(4*4)+(7*0)+(2*7)+(5*8)+(8*0)+(3*11)+(6*12)+(9*0)-200=-22
+ // (1*0)+(4*5)+(7*6)+(2*0)+(5*9)+(8*10)+(3*0)+(6*0)+(9*0)-200=-13
+ // (1*5)+(4*6)+(7*7)+(2*9)+(5*10)+(8*11)+(3*0)+(6*0)+(9*0)-200=34
+ // (1*6)+(4*7)+(7*8)+(2*10)+(5*11)+(8*12)+(3*0)+(6*0)+(9*0)-200=61
+ // (1*7)+(4*8)+(7*0)+(2*11)+(5*12)+(8*0)+(3*0)+(6*0)+(9*0)-200=-79
+ // All negative values are gated to zero by the Relu activation function.
+ // This means we should end up with this matrix:
+ // | 0 | 0 | 0 | 0 |
+ // | 35 | 112 | 157 | 0 |
+ // | 0 | 34 | 61 | 0 |
+ EXPECT_THAT(m.GetOutput(),
+ ElementsAreArray({0, 0, 0, 0, 35, 112, 157, 0, 0, 34, 61, 0}));
+}
+
+TEST(ConvolutionOpTest, HandCalculatedValidFloat32) {
+ const int depth = 1;
+ const int image_width = 4;
+ const int image_height = 3;
+ const int image_batch_count = 1;
+ const int filter_size = 3;
+ const int filter_count = 1;
+ const int stride_width = 1;
+ const int stride_height = 1;
+ const Padding padding = Padding_VALID;
+ ConvolutionOpModel m(
+ {TensorType_FLOAT32,
+ {image_batch_count, image_height, image_width, depth}},
+ {TensorType_FLOAT32, {depth, filter_size, filter_size, filter_count}},
+ {TensorType_FLOAT32, {}}, stride_width, stride_height, padding);
+
+ // The image matrix is:
+ // | 1 | 2 | 3 | 4 |
+ // | 5 | 6 | 7 | 8 |
+ // | 9 | 10 | 11 | 12 |
+ m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+ // The filter matrix is:
+ // | 1 | 4 | 7 |
+ // | 2 | 5 | 8 |
+ // | 3 | 6 | 9 |
+ m.SetFilter({1, 4, 7, 2, 5, 8, 3, 6, 9});
+ // No bias for this test.
+ m.SetBias({0});
+
+ m.Invoke();
+ // We're sliding the 3x3 filter across the 3x4 image, with no accesses outside
+ // the input because we're using the 'VALID' padding mode, giving a 1x2
+ // output.
+ // The calculations behind the expected output are:
+ // (1*1)+(4*2)+(7*3)+(2*5)+(5*6)+(8*7)+(3*9)+(6*10)+(9*11)=312
+ // (1*2)+(4*3)+(7*4)+(2*6)+(5*7)+(8*8)+(3*10)+(6*11)+(9*12)=357
+ // This means we should end up with this matrix:
+ // | 312 | 357 |
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({312, 357}));
+}
+
+class QuantizedConvolutionOpModel : public BaseConvolutionOpModel {
+ public:
+ using BaseConvolutionOpModel::BaseConvolutionOpModel;
+
+ void SetInput(std::initializer_list<float> data) {
+ QuantizeAndPopulate<uint8_t>(input_, data);
+ }
+
+ void SetFilter(std::initializer_list<float> data) {
+ QuantizeAndPopulate<uint8_t>(filter_, data);
+ }
+
+ void SetBias(std::initializer_list<float> data) {
+ QuantizeAndPopulate<int32_t>(bias_, data);
+ }
+
+ std::vector<uint8_t> GetOutput() { return ExtractVector<uint8_t>(output_); }
+ std::vector<float> GetDequantizedOutput() {
+ return Dequantize<uint8_t>(ExtractVector<uint8_t>(output_),
+ GetScale(output_), GetZeroPoint(output_));
+ }
+};
+
+// In these tests we set the input and output scales so that the results
+// match exactly the 'non-quantized' version.
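+// With the ranges used here, the input and filter have scale 0.5 (zero point
+// 127) and the output has scale 1.0 (zero point 127), so a real output value
+// v is stored as the uint8 value v + 127, e.g. 18 -> 145.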
+TEST(ConvolutionOpTest, SimpleTestQuantized) {
+ QuantizedConvolutionOpModel m({TensorType_UINT8, {2, 2, 4, 1}, -63.5, 64},
+ {TensorType_UINT8, {3, 2, 2, 1}, -63.5, 64},
+ {TensorType_UINT8, {}, -127, 128});
+ m.SetInput({
+ // First batch
+ 1, 1, 1, 1, // row = 1
+ 2, 2, 2, 2, // row = 2
+ // Second batch
+ 1, 2, 3, 4, // row = 1
+ 1, 2, 3, 4, // row = 2
+ });
+ m.SetFilter({
+ 1, 2, 3, 4, // first 2x2 filter
+ -1, 1, -1, 1, // second 2x2 filter
+ -1, -1, 1, 1, // third 2x2 filter
+ });
+ m.SetBias({1, 2, 3});
+
+ m.Invoke();
+
+ EXPECT_THAT(m.GetDequantizedOutput(),
+ ElementsAreArray(ArrayFloatNear(
+ {
+ 18, 2, 5, // first batch, left
+ 18, 2, 5, // first batch, right
+ 17, 4, 3, // second batch, left
+ 37, 4, 3, // second batch, right
+ },
+ 1e-5)));
+ // For good measure, let's also verify the quantized values:
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({
+ 145, 129, 132, //
+ 145, 129, 132, //
+ 144, 131, 130, //
+ 164, 131, 130, //
+ }));
+}
+
+TEST(ConvolutionOpTest, SimpleTestQuantizedWithAnisotropicStrides) {
+ QuantizedConvolutionOpModel m({TensorType_UINT8, {1, 3, 6, 1}, -63.5, 64},
+ {TensorType_UINT8, {1, 2, 2, 1}, -63.5, 64},
+ {TensorType_UINT8, {}, -127, 128},
+ /*stride_width=*/3, /*stride_height=*/1);
+ m.SetInput({
+ 3, 2, 1, -1, -2, -3, //
+ 4, 3, 2, -2, -3, -4, //
+ 5, 4, 3, -3, -4, -5, //
+ });
+ m.SetFilter({
+ 1, 2, //
+ 3, 4, //
+ });
+ m.SetBias({-1});
+ m.Invoke();
+ EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear({
+ 30, -24, //
+ 40, -34, //
+ })));
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({
+ 157, 103, //
+ 167, 93, //
+ }));
+}
+} // namespace
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ // On Linux, add: tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/kernels/depthwise_conv.cc b/tensorflow/contrib/lite/kernels/depthwise_conv.cc
new file mode 100644
index 0000000000..15dbfe08c8
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/depthwise_conv.cc
@@ -0,0 +1,289 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <unistd.h>
+#include <cassert>
+#include <cmath>
+#include <cstdio>
+#include <cstdlib>
+#include <iostream>
+#include <limits>
+
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h"
+#include "tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h"
+#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+#include "tensorflow/contrib/lite/kernels/padding.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace depthwise_conv {
+
+constexpr int kInputTensor = 0;
+constexpr int kFilterTensor = 1;
+constexpr int kBiasTensor = 2;
+constexpr int kOutputTensor = 0;
+
+// This file has three implementations of DepthwiseConv.
+enum KernelType {
+ kReference,
+ kGenericOptimized, // Neon-free
+ kNeonOptimized,
+};
+
+struct OpData {
+ TfLitePaddingValues padding;
+ // The scaling factor from input to output (aka the 'real multiplier') can
+ // be represented as a fixed point multiplier plus a left shift.
+ int32_t output_multiplier;
+ int output_shift;
+ // The range of the fused activation layer. For example for kNone and
+ // uint8_t these would be 0 and 255.
+ int32_t output_activation_min;
+ int32_t output_activation_max;
+};
+
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+ // This is a builtin op, so we don't use the contents in 'buffer', if any.
+ // Instead, we allocate a new object to carry information from Prepare() to
+ // Eval().
+ return new OpData;
+}
+
+void Free(TfLiteContext* context, void* buffer) {
+ delete reinterpret_cast<OpData*>(buffer);
+}
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ auto* params =
+ reinterpret_cast<TfLiteDepthwiseConvParams*>(node->builtin_data);
+ OpData* data = reinterpret_cast<OpData*>(node->user_data);
+
+ // TODO(ahentz): we could use GetOptionalInputTensor() here, but we need to
+ // decide whether we are OK with optional tensors being completely absent, as
+ // opposed to having -1 as their index.
+ bool hasBias = NumInputs(node) == 3;
+
+ TF_LITE_ENSURE(context, hasBias || NumInputs(node) == 2);
+ TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ TfLiteTensor* filter = GetInput(context, node, kFilterTensor);
+ TfLiteTensor* bias = nullptr;
+
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+ TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4);
+ TF_LITE_ENSURE_EQ(context, NumDimensions(filter), 4);
+
+ // The parameter 'depth_multiplier' is redundant, so we check here to make
+ // sure it is consistent with the given dimensions.
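+ // For example, an input with 2 channels and depth_multiplier == 2 requires
+ // a filter whose last dimension is 4.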
+ TF_LITE_ENSURE_EQ(context,
+ params->depth_multiplier * SizeOfDimension(input, 3),
+ SizeOfDimension(filter, 3));
+
+ const TfLiteType data_type = input->type;
+ TF_LITE_ENSURE(context,
+ data_type == kTfLiteFloat32 || data_type == kTfLiteUInt8);
+ TF_LITE_ENSURE_EQ(context, output->type, data_type);
+ TF_LITE_ENSURE_EQ(context, filter->type, data_type);
+
+ if (hasBias) {
+ bias = GetInput(context, node, kBiasTensor);
+ if (data_type == kTfLiteUInt8) {
+ TF_LITE_ENSURE_EQ(context, bias->type, kTfLiteInt32);
+ TF_LITE_ENSURE_EQ(context, bias->params.zero_point, 0);
+ } else {
+ TF_LITE_ENSURE_EQ(context, bias->type, data_type);
+ }
+ TF_LITE_ENSURE_EQ(context, NumDimensions(bias), 1);
+ TF_LITE_ENSURE_EQ(context, SizeOfDimension(filter, 3),
+ SizeOfDimension(bias, 0));
+ }
+
+ int channels_out = SizeOfDimension(filter, 3);
+ int width = SizeOfDimension(input, 2);
+ int height = SizeOfDimension(input, 1);
+ int filter_width = SizeOfDimension(filter, 2);
+ int filter_height = SizeOfDimension(filter, 1);
+ int batches = SizeOfDimension(input, 0);
+
+ // Matching GetWindowedOutputSize in TensorFlow.
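+ // For example, with image size 4, filter size 3 and stride 1, this gives an
+ // output size of 4 for kTfLitePaddingSame and 2 for kTfLitePaddingValid.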
+ auto padding = params->padding;
+ auto compute_out_size = [padding](int imageSize, int filterSize,
+ int stride) -> int {
+ return padding == kTfLitePaddingSame
+ ? (imageSize + stride - 1) / stride
+ : padding == kTfLitePaddingValid
+ ? (imageSize - filterSize + stride) / stride
+ : 0;
+ };
+
+ int out_width = compute_out_size(width, filter_width, params->stride_width);
+ int out_height =
+ compute_out_size(height, filter_height, params->stride_height);
+
+ data->padding.height =
+ ComputePadding(params->stride_height, height, filter_height, out_height);
+ data->padding.width =
+ ComputePadding(params->stride_width, width, filter_width, out_width);
+
+ // Note that quantized inference requires that all tensors have their
+ // parameters set. This is usually done during quantized training.
+ if (data_type != kTfLiteFloat32) {
+ double real_multiplier = 0.0;
+ TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler(
+ context, input, filter, bias, output, &real_multiplier));
+ QuantizeMultiplierSmallerThanOne(real_multiplier, &data->output_multiplier,
+ &data->output_shift);
+ CalculateActivationRangeUint8(params->activation, output,
+ &data->output_activation_min,
+ &data->output_activation_max);
+ }
+
+ TfLiteIntArray* outputSize = TfLiteIntArrayCreate(4);
+ outputSize->data[0] = batches;
+ outputSize->data[1] = out_height;
+ outputSize->data[2] = out_width;
+ outputSize->data[3] = channels_out;
+ return context->ResizeTensor(context, output, outputSize);
+}
+
+template <KernelType kernel_type>
+void EvalFloat(TfLiteContext* context, TfLiteNode* node,
+ TfLiteDepthwiseConvParams* params, OpData* data,
+ TfLiteTensor* input, TfLiteTensor* filter, TfLiteTensor* bias,
+ TfLiteTensor* output) {
+ float output_activation_min, output_activation_max;
+ CalculateActivationRangeFloat(params->activation, &output_activation_min,
+ &output_activation_max);
+
+ void (*depthwise_conv)(const float*, const Dims<4>&, const float*,
+ const Dims<4>&, const float*, const Dims<4>&, int, int,
+ int, int, int, float, float, float*, const Dims<4>&);
+ if (kernel_type == kReference) {
+ depthwise_conv = &reference_ops::DepthwiseConv;
+ } else {
+ depthwise_conv = &optimized_ops::DepthwiseConv;
+ }
+
+ depthwise_conv(
+ GetTensorData<float>(input), GetTensorDims(input),
+ GetTensorData<float>(filter), GetTensorDims(filter),
+ GetTensorData<float>(bias), GetTensorDims(bias), params->stride_width,
+ params->stride_height, data->padding.width, data->padding.height,
+ params->depth_multiplier, output_activation_min, output_activation_max,
+ GetTensorData<float>(output), GetTensorDims(output));
+}
+
+template <KernelType kernel_type>
+void EvalQuantized(TfLiteContext* context, TfLiteNode* node,
+ TfLiteDepthwiseConvParams* params, OpData* data,
+ TfLiteTensor* input, TfLiteTensor* filter,
+ TfLiteTensor* bias, TfLiteTensor* output) {
+ auto input_offset = -input->params.zero_point;
+ auto filter_offset = -filter->params.zero_point;
+ auto output_offset = output->params.zero_point;
+
+ void (*depthwise_conv)(const uint8*, const Dims<4>&, int32, const uint8*,
+ const Dims<4>&, int32, const int32*, const Dims<4>&,
+ int, int, int, int, int, int32, int32, int, int32,
+ int32, uint8*, const Dims<4>&);
+ if (kernel_type == kReference) {
+ depthwise_conv = &reference_ops::DepthwiseConv;
+ } else {
+ depthwise_conv = &optimized_ops::DepthwiseConv;
+ }
+
+ depthwise_conv(
+ GetTensorData<uint8_t>(input), GetTensorDims(input), input_offset,
+ GetTensorData<uint8_t>(filter), GetTensorDims(filter), filter_offset,
+ GetTensorData<int32_t>(bias), GetTensorDims(bias), params->stride_width,
+ params->stride_height, data->padding.width, data->padding.height,
+ params->depth_multiplier, output_offset, data->output_multiplier,
+ data->output_shift, data->output_activation_min,
+ data->output_activation_max, GetTensorData<uint8_t>(output),
+ GetTensorDims(output));
+}
+
+template <KernelType kernel_type>
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ auto* params =
+ reinterpret_cast<TfLiteDepthwiseConvParams*>(node->builtin_data);
+ OpData* data = reinterpret_cast<OpData*>(node->user_data);
+
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+ TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ TfLiteTensor* filter = GetInput(context, node, kFilterTensor);
+ TfLiteTensor* bias =
+ (NumInputs(node) == 3) ? GetInput(context, node, kBiasTensor) : nullptr;
+
+ // TODO(aselle): Consider whether float conv and quantized conv should be
+ // separate ops to avoid dispatch overhead here.
+ switch (input->type) { // Already know in/out types are the same.
+ case kTfLiteFloat32:
+ EvalFloat<kernel_type>(context, node, params, data, input, filter, bias,
+ output);
+ break;
+ case kTfLiteUInt8:
+ EvalQuantized<kernel_type>(context, node, params, data, input, filter,
+ bias, output);
+ break;
+ default:
+ context->ReportError(context, "Type not currently supported.");
+ return kTfLiteError;
+ }
+ return kTfLiteOk;
+}
+
+} // namespace depthwise_conv
+
+TfLiteRegistration* Register_DEPTHWISE_CONVOLUTION_REF() {
+ static TfLiteRegistration r = {
+ depthwise_conv::Init, depthwise_conv::Free, depthwise_conv::Prepare,
+ depthwise_conv::Eval<depthwise_conv::kReference>};
+ return &r;
+}
+
+TfLiteRegistration* Register_DEPTHWISE_CONVOLUTION_GENERIC_OPT() {
+ static TfLiteRegistration r = {
+ depthwise_conv::Init, depthwise_conv::Free, depthwise_conv::Prepare,
+ depthwise_conv::Eval<depthwise_conv::kGenericOptimized>};
+ return &r;
+}
+
+TfLiteRegistration* Register_DEPTHWISE_CONVOLUTION_NEON_OPT() {
+ static TfLiteRegistration r = {
+ depthwise_conv::Init, depthwise_conv::Free, depthwise_conv::Prepare,
+ depthwise_conv::Eval<depthwise_conv::kNeonOptimized>};
+ return &r;
+}
+
+TfLiteRegistration* Register_DEPTHWISE_CONV_2D() {
+#ifdef USE_NEON
+ return Register_DEPTHWISE_CONVOLUTION_NEON_OPT();
+#else
+ return Register_DEPTHWISE_CONVOLUTION_GENERIC_OPT();
+#endif
+}
+
+} // namespace builtin
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/depthwise_conv_test.cc b/tensorflow/contrib/lite/kernels/depthwise_conv_test.cc
new file mode 100644
index 0000000000..39227b2811
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/depthwise_conv_test.cc
@@ -0,0 +1,186 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <cstdarg>
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAreArray;
+
+class BaseDepthwiseConvolutionOpModel : public SingleOpModel {
+ public:
+ // TODO(ahentz): Also test different activation types, bias, padding types,
+ // stride values.
+ BaseDepthwiseConvolutionOpModel(const TensorData& input,
+ const TensorData& filter,
+ const TensorData& output) {
+ input_ = AddInput(input);
+ filter_ = AddInput(filter);
+
+ int bias_size = GetShape(filter_)[3];
+ if (input.type == TensorType_FLOAT32) {
+ bias_ = AddInput({TensorType_FLOAT32, {bias_size}});
+ } else {
+ // This is a quantized version. The scale of 'bias' depends on the scales
+ // of input and filter. Supposedly this is correctly set during quantized
+ // training.
+ auto bias_scale = GetScale(input_) * GetScale(filter_);
+ TensorData bias{TensorType_INT32, {bias_size}, 0, 0, bias_scale};
+ bias_ = AddInput(bias);
+ }
+
+ output_ = AddOutput(output);
+ if (input.type != TensorType_FLOAT32) {
+ // The following is required by quantized inference. It is the unittest's
+ // responsibility to make sure the output scale falls into the correct
+ // range.
+ CHECK_LT(GetScale(input_) * GetScale(filter_), GetScale(output_));
+ }
+
+ int input_depth = GetShape(input_)[3];
+ int output_depth = GetShape(filter_)[3];
+ int depth_mul = output_depth / input_depth;
+
+ SetBuiltinOp(
+ BuiltinOperator_DEPTHWISE_CONV_2D,
+ BuiltinOptions_DepthwiseConv2DOptions,
+ CreateDepthwiseConv2DOptions(builder_, Padding_VALID, 1, 1, depth_mul,
+ ActivationFunctionType_NONE)
+ .Union());
+
+ BuildInterpreter({GetShape(input_), GetShape(filter_), GetShape(bias_)});
+ }
+
+ protected:
+ int input_;
+ int filter_;
+ int bias_;
+ int output_;
+};
+
+class DepthwiseConvolutionOpModel : public BaseDepthwiseConvolutionOpModel {
+ public:
+ using BaseDepthwiseConvolutionOpModel::BaseDepthwiseConvolutionOpModel;
+
+ void SetFilter(std::initializer_list<float> f) { PopulateTensor(filter_, f); }
+
+ void SetBias(std::initializer_list<float> f) { PopulateTensor(bias_, f); }
+
+ void SetInput(std::initializer_list<float> data) {
+ PopulateTensor(input_, data);
+ }
+
+ std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+};
+
+TEST(DepthwiseConvolutionOpTest, SimpleTest) {
+ DepthwiseConvolutionOpModel m({TensorType_FLOAT32, {1, 3, 2, 2}},
+ {TensorType_FLOAT32, {1, 2, 2, 4}},
+ {TensorType_FLOAT32, {}});
+
+ m.SetInput({
+ 1, 2, 7, 8, // column 1
+ 3, 4, 9, 10, // column 2
+ 5, 6, 11, 12, // column 3
+ });
+ m.SetFilter({
+ 1, 2, 3, 4, //
+ -9, 10, -11, 12, //
+ 5, 6, 7, 8, //
+ 13, -14, 15, -16, //
+ });
+ m.SetBias({1, 2, 3, 4});
+
+ m.Invoke();
+
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({
+ 71, -34, 99, -20, //
+ 91, -26, 127, -4, //
+ }));
+}
+
+class QuantizedDepthwiseConvolutionOpModel
+ : public BaseDepthwiseConvolutionOpModel {
+ public:
+ using BaseDepthwiseConvolutionOpModel::BaseDepthwiseConvolutionOpModel;
+
+ void SetInput(std::initializer_list<float> data) {
+ QuantizeAndPopulate<uint8_t>(input_, data);
+ }
+
+ void SetFilter(std::initializer_list<float> data) {
+ QuantizeAndPopulate<uint8_t>(filter_, data);
+ }
+
+ void SetBias(std::initializer_list<float> data) {
+ QuantizeAndPopulate<int32_t>(bias_, data);
+ }
+
+ std::vector<uint8_t> GetOutput() { return ExtractVector<uint8_t>(output_); }
+ std::vector<float> GetDequantizedOutput() {
+ return Dequantize<uint8_t>(ExtractVector<uint8_t>(output_),
+ GetScale(output_), GetZeroPoint(output_));
+ }
+};
+
+// In this test we set the input and output scales so that the results match
+// exactly the 'non-quantized' version.
+TEST(QuantizedDepthwiseConvolutionOpTest, SimpleTestQuantized) {
+ QuantizedDepthwiseConvolutionOpModel m(
+ {TensorType_UINT8, {1, 3, 2, 2}, -63.5, 64},
+ {TensorType_UINT8, {1, 2, 2, 4}, -63.5, 64},
+ {TensorType_UINT8, {}, -127, 128});
+
+ m.SetInput({
+ 1, 2, 7, 8, // column 1
+ 3, 4, 9, 10, // column 2
+ 5, 6, 11, 12, // column 3
+ });
+ m.SetFilter({
+ 1, 2, 3, 4, //
+ -9, 10, -11, 12, //
+ 5, 6, 7, 8, //
+ 13, -14, 15, -16, //
+ });
+ m.SetBias({1, 2, 3, 4});
+
+ m.Invoke();
+
+ EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear(
+ {
+ 71, -34, 99, -20, //
+ 91, -26, 127, -4, //
+ },
+ 1e-5)));
+ // For good measure, let's also verify the quantized values:
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({
+ 198, 93, 226, 107, //
+ 218, 101, 254, 123, //
+ }));
+}
+
+} // namespace
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ // On Linux, add: tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/kernels/embedding_lookup.cc b/tensorflow/contrib/lite/kernels/embedding_lookup.cc
new file mode 100644
index 0000000000..4e8cb396d4
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/embedding_lookup.cc
@@ -0,0 +1,104 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Op that looks up items from a matrix.
+//
+// Input:
+// Tensor[0]: Row number to lookup, dim.size == 1, int32
+// Tensor[1]: 2-dimensional matrix of multi-dimensional items
+// dim.size >= 2, any data type.
+// first dimension is row, second dimension is column.
+//
+// Output:
+// Output.dim[0] == Tensor[0].dim[0], num of lookups
+// Output.dim[1] == Tensor[1].dim[1], num of items per row
+// Each item in the output is a raw byte copy of the corresponding input item.
+// When indices are out of bounds, the op will not succeed.
+//
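+// For example (hypothetical values), looking up rows [2, 0] in the 3x2 matrix
+// [[1, 2], [3, 4], [5, 6]] produces the 2x2 output [[5, 6], [1, 2]].
+//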
+
+#include <unistd.h>
+#include <cassert>
+#include <cmath>
+#include <cstdio>
+#include <cstdlib>
+#include <cstring>
+#include <iostream>
+#include <limits>
+
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace embedding_lookup {
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+
+ TfLiteTensor* lookup = GetInput(context, node, 0);
+ TF_LITE_ENSURE_EQ(context, NumDimensions(lookup), 1);
+ TF_LITE_ENSURE_EQ(context, lookup->type, kTfLiteInt32);
+
+ TfLiteTensor* value = GetInput(context, node, 1);
+ TF_LITE_ENSURE(context, NumDimensions(value) >= 2);
+
+ TfLiteTensor* output = GetOutput(context, node, 0);
+ TfLiteIntArray* outputSize = TfLiteIntArrayCreate(NumDimensions(value));
+
+ outputSize->data[0] = SizeOfDimension(lookup, 0);
+ outputSize->data[1] = SizeOfDimension(value, 1);
+ for (int i = 2; i < NumDimensions(value); i++) {
+ outputSize->data[i] = SizeOfDimension(value, i);
+ }
+ return context->ResizeTensor(context, output, outputSize);
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ TfLiteTensor* output = GetOutput(context, node, 0);
+ TfLiteTensor* lookup = GetInput(context, node, 0);
+ TfLiteTensor* value = GetInput(context, node, 1);
+
+ const int row_size = SizeOfDimension(value, 0);
+ const int row_bytes = value->bytes / row_size;
+
+ for (int i = 0; i < SizeOfDimension(lookup, 0); i++) {
+ int idx = lookup->data.i32[i];
+ if (idx >= row_size || idx < 0) {
+ context->ReportError(context, "Embedding Lookup: index out of bounds.");
+ return kTfLiteError;
+ } else {
+ memcpy(output->data.raw + i * row_bytes,
+ value->data.raw + idx * row_bytes, row_bytes);
+ }
+ }
+
+ return kTfLiteOk;
+}
+
+} // namespace embedding_lookup
+
+TfLiteRegistration* Register_EMBEDDING_LOOKUP() {
+ static TfLiteRegistration r = {nullptr, nullptr, embedding_lookup::Prepare,
+ embedding_lookup::Eval};
+ return &r;
+}
+
+} // namespace builtin
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/embedding_lookup_sparse.cc b/tensorflow/contrib/lite/kernels/embedding_lookup_sparse.cc
new file mode 100644
index 0000000000..6c770e7f71
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/embedding_lookup_sparse.cc
@@ -0,0 +1,248 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Op that looks up items from a sparse tensor in an embedding matrix.
+// The sparse lookup tensor is represented by three individual tensors: lookup,
+// indices, and dense_shape. The representation assumes that the corresponding
+// dense tensor would satisfy:
+// * dense.shape = dense_shape
+// * dense[tuple(indices[i])] = lookup[i]
+//
+// By convention, indices should be sorted.
+//
+// Options:
+// combiner: The reduction op (SUM, MEAN, SQRTN).
+// * SUM computes the weighted sum of the embedding results.
+// * MEAN is the weighted sum divided by the total weight.
+// * SQRTN is the weighted sum divided by the square root of the sum of the
+// squares of the weights.
+//
+// Input:
+// Tensor[0]: Ids to lookup, dim.size == 1, int32.
+// Tensor[1]: Indices, int32.
+// Tensor[2]: Dense shape, int32.
+// Tensor[3]: Weights to use for aggregation, float.
+// Tensor[4]: Params, a matrix of multi-dimensional items,
+// dim.size >= 2, float.
+//
+// Output:
+// A (dense) tensor representing the combined embeddings for the sparse ids.
+// For each row in the sparse tensor represented by (lookup, indices, shape)
+// the op looks up the embeddings for all ids in that row, multiplies them by
+// the corresponding weight, and combines these embeddings as specified in the
+// last dimension.
+//
+// Output.dim = [l0, ... , ln-1, e1, ..., em]
+// Where dense_shape == [l0, ..., ln] and Tensor[4].dim == [e0, e1, ..., em]
+//
+// For instance, if params is a 10x20 matrix and ids, weights are:
+//
+// [0, 0]: id 1, weight 2.0
+// [0, 1]: id 3, weight 0.5
+// [1, 0]: id 0, weight 1.0
+// [2, 3]: id 1, weight 3.0
+//
+// with combiner=MEAN, then the output will be a (3, 20) tensor where:
+//
+// output[0, :] = (params[1, :] * 2.0 + params[3, :] * 0.5) / (2.0 + 0.5)
+// output[1, :] = (params[0, :] * 1.0) / 1.0
+// output[2, :] = (params[1, :] * 3.0) / 3.0
+//
+// When indices are out of bounds, the op will not succeed.
+
+#include <algorithm>
+#include <cmath>
+
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+
+namespace {
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ TF_LITE_ENSURE_EQ(context, NumInputs(node), 5);
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+
+ TfLiteTensor* ids = GetInput(context, node, 0);
+ TF_LITE_ENSURE_EQ(context, NumDimensions(ids), 1);
+ TF_LITE_ENSURE_EQ(context, ids->type, kTfLiteInt32);
+
+ TfLiteTensor* indices = GetInput(context, node, 1);
+ TF_LITE_ENSURE_EQ(context, NumDimensions(indices), 2);
+ TF_LITE_ENSURE_EQ(context, indices->type, kTfLiteInt32);
+
+ TfLiteTensor* shape = GetInput(context, node, 2);
+ TF_LITE_ENSURE_EQ(context, NumDimensions(shape), 1);
+ TF_LITE_ENSURE_EQ(context, shape->type, kTfLiteInt32);
+
+ TfLiteTensor* weights = GetInput(context, node, 3);
+ TF_LITE_ENSURE_EQ(context, NumDimensions(weights), 1);
+ TF_LITE_ENSURE_EQ(context, weights->type, kTfLiteFloat32);
+
+ TF_LITE_ENSURE_EQ(context, SizeOfDimension(indices, 0),
+ SizeOfDimension(ids, 0));
+ TF_LITE_ENSURE_EQ(context, SizeOfDimension(indices, 0),
+ SizeOfDimension(weights, 0));
+
+ TfLiteTensor* value = GetInput(context, node, 4);
+ TF_LITE_ENSURE(context, NumDimensions(value) >= 2);
+
+ // Mark the output as a dynamic tensor.
+ TfLiteTensor* output = GetOutput(context, node, 0);
+ TF_LITE_ENSURE_EQ(context, output->type, kTfLiteFloat32);
+ output->allocation_type = kTfLiteDynamic;
+
+ return kTfLiteOk;
+}
+
+void FinalizeAggregation(TfLiteCombinerType combiner, int num_elements,
+ float current_total_weight,
+ float current_squares_weight, int embedding_size,
+ float* output) {
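+ // For example, with combiner SQRTN and accumulated weights {2.0, 4.0}, each
+ // output element is divided by sqrt(2*2 + 4*4) = sqrt(20).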
+ if (combiner != kTfLiteCombinerTypeSum && num_elements > 0) {
+ float multiplier = 1.0;
+ switch (combiner) {
+ case kTfLiteCombinerTypeMean:
+ multiplier = current_total_weight;
+ break;
+ case kTfLiteCombinerTypeSqrtn:
+ multiplier = std::sqrt(current_squares_weight);
+ break;
+ default:
+ break;
+ }
+ for (int k = 0; k < embedding_size; k++) {
+ output[k] /= multiplier;
+ }
+ }
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ auto* params =
+ reinterpret_cast<TfLiteEmbeddingLookupSparseParams*>(node->builtin_data);
+ TfLiteTensor* output = GetOutput(context, node, 0);
+ TfLiteTensor* ids = GetInput(context, node, 0);
+ TfLiteTensor* indices = GetInput(context, node, 1);
+ TfLiteTensor* dense_shape = GetInput(context, node, 2);
+ TfLiteTensor* weights = GetInput(context, node, 3);
+ TfLiteTensor* value = GetInput(context, node, 4);
+
+ const int lookup_rank = SizeOfDimension(indices, 1);
+ const int embedding_rank = NumDimensions(value);
+ const int num_lookups = SizeOfDimension(ids, 0);
+ const int num_rows = SizeOfDimension(value, 0);
+
+ // The last dimension gets replaced by the embedding.
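+ // For example, dense_shape [3, 2] (lookup_rank 2) combined with a [4, 3, 2]
+ // embedding tensor (embedding_rank 3) produces a rank 1 + 2 = 3 output of
+ // shape [3, 3, 2].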
+ const int output_rank = (lookup_rank - 1) + (embedding_rank - 1);
+
+ // Make sure that the actual dense shape of the sparse tensor represented by
+ // (lookup, indices, dense_shape) is consistent.
+ TF_LITE_ENSURE_EQ(context, SizeOfDimension(dense_shape, 0), lookup_rank);
+
+ // Resize output tensor.
+ TfLiteIntArray* output_shape = TfLiteIntArrayCreate(output_rank);
+ int k = 0;
+ int embedding_size = 1;
+ int lookup_size = 1;
+ for (int i = 0; i < lookup_rank - 1; i++, k++) {
+ const int dim = dense_shape->data.i32[i];
+ lookup_size *= dim;
+ output_shape->data[k] = dim;
+ }
+ for (int i = 1; i < embedding_rank; i++, k++) {
+ const int dim = SizeOfDimension(value, i);
+ embedding_size *= dim;
+ output_shape->data[k] = dim;
+ }
+ TF_LITE_ENSURE_STATUS(context->ResizeTensor(context, output, output_shape));
+ const int output_size = lookup_size * embedding_size;
+ TfLiteTensorRealloc(output_size * sizeof(float), output);
+
+ tensor_utils::ZeroVector(output->data.f, output_size);
+
+ // Keep track of the current bucket for aggregation/combination.
+ int current_output_offset = 0;
+ float current_total_weight = 0.0;
+ float current_squares_weight = 0.0;
+ int num_elements = 0;
+
+ for (int i = 0; i < num_lookups; i++) {
+ int idx = ids->data.i32[i];
+ if (idx >= num_rows || idx < 0) {
+ context->ReportError(context,
+ "Embedding Lookup Sparse: index out of bounds.");
+ return kTfLiteError;
+ }
+
+ // Check where we need to aggregate.
+ const int example_indices_offset = i * lookup_rank;
+ int output_bucket = 0;
+ int stride = 1;
+ for (int k = (lookup_rank - 1) - 1; k >= 0; k--) {
+ output_bucket += indices->data.i32[example_indices_offset + k] * stride;
+ stride *= dense_shape->data.i32[k];
+ }
+ const int output_offset = output_bucket * embedding_size;
+
+ // If we are in a new aggregation bucket and the combiner is not the sum,
+ // go back and finalize the result of the previous bucket.
+ if (output_offset != current_output_offset) {
+ FinalizeAggregation(params->combiner, num_elements, current_total_weight,
+ current_squares_weight, embedding_size,
+ &output->data.f[current_output_offset]);
+
+ // Track next bucket.
+ num_elements = 0;
+ current_total_weight = 0.0;
+ current_squares_weight = 0.0;
+ current_output_offset = output_offset;
+ }
+
+ // Add element to aggregation.
+ ++num_elements;
+ const int example_embedding_offset = idx * embedding_size;
+ const float w = weights->data.f[i];
+ current_squares_weight += w * w;
+ current_total_weight += w;
+ for (int k = 0; k < embedding_size; k++) {
+ output->data.f[current_output_offset + k] +=
+ (value->data.f[example_embedding_offset + k] * w);
+ }
+ }
+
+ // Finalize last bucket.
+ FinalizeAggregation(params->combiner, num_elements, current_total_weight,
+ current_squares_weight, embedding_size,
+ &output->data.f[current_output_offset]);
+
+ return kTfLiteOk;
+}
+
+} // namespace
+
+TfLiteRegistration* Register_EMBEDDING_LOOKUP_SPARSE() {
+ static TfLiteRegistration r = {nullptr, nullptr, Prepare, Eval};
+ return &r;
+}
+
+} // namespace builtin
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/embedding_lookup_sparse_test.cc b/tensorflow/contrib/lite/kernels/embedding_lookup_sparse_test.cc
new file mode 100644
index 0000000000..69d9c5cc7d
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/embedding_lookup_sparse_test.cc
@@ -0,0 +1,166 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// Unit test for TFLite sparse lookup op.
+
+#include <cmath>
+#include <vector>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAreArray;
+
+class EmbeddingLookupSparseOpModel : public SingleOpModel {
+ public:
+ EmbeddingLookupSparseOpModel(CombinerType type,
+ std::initializer_list<int> lookup_shape,
+ std::initializer_list<int> indices_shape,
+ std::initializer_list<int> dense_shape_shape,
+ std::initializer_list<int> value_shape) {
+ lookup_ = AddInput(TensorType_INT32);
+ indices_ = AddInput(TensorType_INT32);
+ dense_shape_ = AddInput(TensorType_INT32);
+ weights_ = AddInput(TensorType_FLOAT32);
+ value_ = AddInput(TensorType_FLOAT32);
+ output_ = AddOutput(TensorType_FLOAT32);
+ SetBuiltinOp(BuiltinOperator_EMBEDDING_LOOKUP_SPARSE,
+ BuiltinOptions_EmbeddingLookupSparseOptions,
+ CreateEmbeddingLookupSparseOptions(builder_, type).Union());
+ BuildInterpreter({lookup_shape, indices_shape, dense_shape_shape,
+ lookup_shape, value_shape});
+ }
+
+ void SetInput(std::initializer_list<int> lookup_data,
+ std::initializer_list<int> indices_data,
+ std::initializer_list<int> dense_shape_data,
+ std::initializer_list<float> weights_data) {
+ PopulateTensor(lookup_, lookup_data);
+ PopulateTensor(indices_, indices_data);
+ PopulateTensor(dense_shape_, dense_shape_data);
+ PopulateTensor(weights_, weights_data);
+ }
+
+ void Set3DWeightMatrix(const std::function<float(int, int, int)>& function) {
+ TfLiteTensor* tensor = interpreter_->tensor(value_);
+ int rows = tensor->dims->data[0];
+ int columns = tensor->dims->data[1];
+ int features = tensor->dims->data[2];
+ for (int i = 0; i < rows; i++) {
+ for (int j = 0; j < columns; j++) {
+ for (int k = 0; k < features; k++) {
+ tensor->data.f[(i * columns + j) * features + k] = function(i, j, k);
+ }
+ }
+ }
+ }
+
+ std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+
+ private:
+ int lookup_;
+ int weights_;
+ int indices_;
+ int dense_shape_;
+ int value_;
+ int output_;
+};
+
+TEST(EmbeddingLookupOpTest, SimpleTest) {
+ EmbeddingLookupSparseOpModel m(CombinerType_SUM, {3}, {3, 2}, {2}, {4, 3, 2});
+ m.SetInput({1, 3, 0}, {0, 0, 2, 0, 2, 1}, {3, 2}, {1.0, 2.0, 4.0});
+ m.Set3DWeightMatrix(
+ [](int i, int j, int k) { return i + j / 10.0f + k / 100.0f; });
+ m.Invoke();
+
+ EXPECT_THAT(m.GetOutput(),
+ ElementsAreArray(ArrayFloatNear({
+ 1.00, 1.01, 1.10, 1.11, 1.20, 1.21, // Row 1
+ 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, // -
+ 6.00, 6.06, 6.60, 6.66, 7.20, 7.26, // 2 * Row 3 + 4 * Row 0
+ })));
+}
+
+TEST(EmbeddingLookupOpTest, SimpleTestMean) {
+ EmbeddingLookupSparseOpModel m(CombinerType_MEAN, {3}, {3, 2}, {2},
+ {4, 3, 2});
+ m.SetInput({1, 3, 0}, {0, 0, 2, 0, 2, 1}, {3, 2}, {1.0, 2.0, 4.0});
+ m.Set3DWeightMatrix(
+ [](int i, int j, int k) { return i + j / 10.0f + k / 100.0f; });
+ m.Invoke();
+
+ EXPECT_THAT(m.GetOutput(),
+ ElementsAreArray(ArrayFloatNear({
+ 1.00, 1.01, 1.10, 1.11, 1.20, 1.21, // Row 1
+ 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, // -
+ 1.00, 1.01, 1.10, 1.11, 1.20, 1.21, // (2 * Row 3 + 4 * Row 0) / 6.0
+ })));
+}
+
+TEST(EmbeddingLookupOpTest, SimpleTestSqrtn) {
+ EmbeddingLookupSparseOpModel m(CombinerType_SQRTN, {3}, {3, 2}, {2},
+ {4, 3, 2});
+ m.SetInput({1, 3, 0}, {0, 0, 2, 0, 2, 1}, {3, 2}, {1.0, 2.0, 4.0});
+ m.Set3DWeightMatrix(
+ [](int i, int j, int k) { return i + j / 10.0f + k / 100.0f; });
+ m.Invoke();
+
+ EXPECT_THAT(
+ m.GetOutput(),
+ ElementsAreArray(ArrayFloatNear({
+ 1.00, 1.01, 1.10, 1.11, 1.20, 1.21, // Row 1
+ 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, // -
+ 6.00f / std::sqrt(20.0f), 6.06f / std::sqrt(20.0f),
+ 6.60f / std::sqrt(20.0f), 6.66f / std::sqrt(20.0f),
+ 7.20f / std::sqrt(20.0f),
+ 7.26f /
+ std::sqrt(
+ 20.0f), // 2 * Row 3 + 4 * Row 0
+ })));
+}
+
+TEST(EmbeddingLookupOpTest, Indices3DTest) {
+ EmbeddingLookupSparseOpModel m(CombinerType_SUM, {3}, {3, 3}, {3}, {4, 3, 2});
+ m.SetInput({1, 3, 0}, {0, 0, 0, 2, 0, 0, 2, 0, 1}, {3, 2, 2},
+ {1.0, 2.0, 4.0});
+ m.Set3DWeightMatrix(
+ [](int i, int j, int k) { return i + j / 10.0f + k / 100.0f; });
+ m.Invoke();
+
+ EXPECT_THAT(m.GetOutput(),
+ ElementsAreArray(ArrayFloatNear({
+ 1.00, 1.01, 1.10, 1.11, 1.20, 1.21, 0.00, 0.00, 0.00,
+ 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00,
+ 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 6.00, 6.06, 6.60,
+ 6.66, 7.20, 7.26, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00,
+ })));
+}
+
+} // namespace
+} // namespace tflite
+
+int main(int argc, char** argv) {
+#ifdef OS_LINUX
+ tflite::LogToStderr();
+#endif
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/kernels/embedding_lookup_test.cc b/tensorflow/contrib/lite/kernels/embedding_lookup_test.cc
new file mode 100644
index 0000000000..8c030b0677
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/embedding_lookup_test.cc
@@ -0,0 +1,94 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// Unit test for TFLite Lookup op.
+
+#include <iomanip>
+#include <vector>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAreArray;
+
+class EmbeddingLookupOpModel : public SingleOpModel {
+ public:
+ EmbeddingLookupOpModel(std::initializer_list<int> index_shape,
+ std::initializer_list<int> weight_shape) {
+ input_ = AddInput(TensorType_INT32);
+ weight_ = AddInput(TensorType_FLOAT32);
+ output_ = AddOutput(TensorType_FLOAT32);
+ SetBuiltinOp(BuiltinOperator_EMBEDDING_LOOKUP, BuiltinOptions_NONE, 0);
+ BuildInterpreter({index_shape, weight_shape});
+ }
+
+ void SetInput(std::initializer_list<int> data) {
+ PopulateTensor(input_, data);
+ }
+
+ void Set3DWeightMatrix(const std::function<float(int, int, int)>& function) {
+ TfLiteTensor* tensor = interpreter_->tensor(weight_);
+ int rows = tensor->dims->data[0];
+ int columns = tensor->dims->data[1];
+ int features = tensor->dims->data[2];
+ for (int i = 0; i < rows; i++) {
+ for (int j = 0; j < columns; j++) {
+ for (int k = 0; k < features; k++) {
+ tensor->data.f[(i * columns + j) * features + k] = function(i, j, k);
+ }
+ }
+ }
+ }
+
+ std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+
+ private:
+ int input_;
+ int weight_;
+ int output_;
+};
+
+// TODO(ahentz): write more tests that exercise the details of the op, such as
+// lookup errors and variable input shapes.
+TEST(EmbeddingLookupOpTest, SimpleTest) {
+ EmbeddingLookupOpModel m({3}, {3, 2, 4});
+ m.PopulateTensor<int>(0, {1, 0, 2});
+ m.Set3DWeightMatrix(
+ [](int i, int j, int k) { return i + j / 10.0f + k / 100.0f; });
+
+ m.Invoke();
+
+ EXPECT_THAT(m.GetOutput(),
+ ElementsAreArray(ArrayFloatNear({
+ 1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1
+ 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0
+ 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2
+ })));
+}
+
+} // namespace
+} // namespace tflite
+
+int main(int argc, char** argv) {
+#ifdef OS_LINUX
+ tflite::LogToStderr();
+#endif
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/kernels/fully_connected.cc b/tensorflow/contrib/lite/kernels/fully_connected.cc
new file mode 100644
index 0000000000..a77fe94e49
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/fully_connected.cc
@@ -0,0 +1,307 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <unistd.h>
+#include <cassert>
+#include <cmath>
+#include <cstdio>
+#include <cstdlib>
+#include <iostream>
+#include <limits>
+
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/activation_functor.h"
+#include "tensorflow/contrib/lite/kernels/gemm_support.h"
+#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace fully_connected {
+
+// This file has four implementations of FullyConnected
+enum KernelType {
+ kReference,
+ kGenericOptimized, // Neon-free
+ kNeonOptimized,
+ kPie, // Used by the PIE team
+};
+
+struct OpData {
+ // The scaling factor from input to output (aka the 'real multiplier') can
+ // be represented as a fixed point multiplier plus a left shift.
+ int32_t output_multiplier;
+ int output_shift;
+ // The range of the fused activation layer. For example for kNone and
+ // uint8_t these would be 0 and 255.
+ int32_t output_activation_min;
+ int32_t output_activation_max;
+};
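+
+// Illustrative example (hypothetical numbers, not part of the original patch):
+// a real multiplier of 0.40 could be stored as
+// output_multiplier = round(0.8 * 2^31) with output_shift = 1,
+// since 0.40 = 0.8 * 2^-1.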
+
+constexpr int kInputTensor = 0;
+constexpr int kWeightsTensor = 1;
+constexpr int kBiasTensor = 2;
+constexpr int kOutputTensor = 0;
+
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+ // This is a builtin op, so we don't use the contents in 'buffer', if any.
+ // Instead, we allocate a new object to carry information from Prepare() to
+ // Eval().
+ gemm_support::IncrementUsageCounter(context);
+ return new OpData;
+}
+
+void Free(TfLiteContext* context, void* buffer) {
+ gemm_support::DecrementUsageCounter(context);
+ delete reinterpret_cast<OpData*>(buffer);
+}
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ auto* params =
+ reinterpret_cast<TfLiteFullyConnectedParams*>(node->builtin_data);
+ OpData* data = reinterpret_cast<OpData*>(node->user_data);
+
+ // Check we have all the inputs and outputs we need.
+ TF_LITE_ENSURE_EQ(context, node->inputs->size, 3);
+ TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
+
+ TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ TfLiteTensor* filter = GetInput(context, node, kWeightsTensor);
+ TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+ // Check that all tensor parameters are consistent with each other and with
+ // the input configuration.
+ int input_size = 1;
+ for (int i = 0; i < input->dims->size; i++) {
+ input_size *= input->dims->data[i];
+ }
+
+ const int batch_size = input_size / filter->dims->data[1];
+ const int num_units = filter->dims->data[0];
+
+ TF_LITE_ASSERT_EQ(input_size, batch_size * filter->dims->data[1]);
+ if (bias) {
+ TF_LITE_ASSERT_EQ(bias->dims->data[0], num_units);
+ }
+
+ TF_LITE_ENSURE_EQ(context, NumDimensions(filter), 2);
+ TF_LITE_ENSURE_EQ(context, NumDimensions(bias), 1);
+
+ // Note that quantized inference requires that all tensors have their
+ // parameters set. This is usually done during quantized training.
+ TfLiteType data_type = input->type;
+ if (data_type != kTfLiteFloat32) {
+ double real_multiplier = 0.0;
+ TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler(
+ context, input, filter, bias, output, &real_multiplier));
+ QuantizeMultiplierSmallerThanOne(real_multiplier, &data->output_multiplier,
+ &data->output_shift);
+ CalculateActivationRangeUint8(params->activation, output,
+ &data->output_activation_min,
+ &data->output_activation_max);
+ }
+
+ // Resize output.
+ TfLiteIntArray* output_size_array = TfLiteIntArrayCreate(2);
+ output_size_array->data[0] = batch_size;
+ output_size_array->data[1] = num_units;
+ TF_LITE_ENSURE_OK(context,
+ context->ResizeTensor(context, output, output_size_array));
+ return kTfLiteOk;
+}
+
+TfLiteStatus EvalPie(TfLiteContext* context, TfLiteNode* node,
+ TfLiteFullyConnectedParams* params, OpData* data,
+ TfLiteTensor* input, TfLiteTensor* filter,
+ TfLiteTensor* bias, TfLiteTensor* output) {
+ int total_input_size = 1;
+ for (int i = 0; i < input->dims->size; i++) {
+ total_input_size *= input->dims->data[i];
+ }
+
+ int input_size = filter->dims->data[1];
+ const int batch_size = total_input_size / filter->dims->data[1];
+ const int num_units = filter->dims->data[0];
+
+ // Output = bias if bias tensor exists.
+ if (bias) {
+ tensor_utils::VectorBatchVectorAssign(bias->data.f, num_units, batch_size,
+ output->data.f);
+ } else {
+ tensor_utils::ZeroVector(output->data.f, batch_size * num_units);
+ }
+
+ // Compute output += weight * input
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ filter->data.f, num_units, input_size, input->data.f, batch_size,
+ output->data.f, /*result_stride=*/1);
+
+ // Apply activation function
+ tensor_utils::ApplyActivationToVector(output->data.f, batch_size * num_units,
+ params->activation, output->data.f);
+
+ return kTfLiteOk;
+}
+
+#define TF_LITE_MACRO_DISPATCH(macro_name, params, target_namespace) \
+ if (params->activation == kTfLiteActNone) { \
+ macro_name(target_namespace, kNone); \
+ } \
+ if (params->activation == kTfLiteActRelu) { \
+ macro_name(target_namespace, kRelu); \
+ } \
+ if (params->activation == kTfLiteActRelu6) { \
+ macro_name(target_namespace, kRelu6); \
+ }
+
+template <KernelType kernel_type>
+TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
+ TfLiteFullyConnectedParams* params, OpData* data,
+ TfLiteTensor* input, TfLiteTensor* filter,
+ TfLiteTensor* bias, TfLiteTensor* output) {
+ gemmlowp::GemmContext* gemm_context = gemm_support::GetFromContext(context);
+
+ int32_t input_offset = -input->params.zero_point;
+ int32_t filter_offset = -filter->params.zero_point;
+ int32_t output_offset = output->params.zero_point;
+#define TF_LITE_FULLY_CONNECTED(type) \
+ type::FullyConnected( \
+ GetTensorData<uint8_t>(input), GetTensorDims(input), input_offset, \
+ GetTensorData<uint8_t>(filter), GetTensorDims(filter), filter_offset, \
+ GetTensorData<int32_t>(bias), GetTensorDims(bias), output_offset, \
+ data->output_multiplier, data->output_shift, \
+ data->output_activation_min, data->output_activation_max, \
+ GetTensorData<uint8_t>(output), GetTensorDims(output), gemm_context)
+ if (kernel_type == kReference) {
+ TF_LITE_FULLY_CONNECTED(reference_ops);
+ } else if (kernel_type == kPie) {
+ // TODO(ahentz): we don't have a quantized version of the PIE kernels, so
+ // we just defer to the MINI ones.
+ TF_LITE_FULLY_CONNECTED(optimized_ops);
+ } else {
+ TF_LITE_FULLY_CONNECTED(optimized_ops);
+ }
+#undef TF_LITE_FULLY_CONNECTED
+
+ return kTfLiteOk;
+}
+
+template <KernelType kernel_type>
+TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node,
+ TfLiteFullyConnectedParams* params, OpData* data,
+ TfLiteTensor* input, TfLiteTensor* filter,
+ TfLiteTensor* bias, TfLiteTensor* output) {
+ float output_activation_min, output_activation_max;
+ CalculateActivationRangeFloat(params->activation, &output_activation_min,
+ &output_activation_max);
+#define TF_LITE_FULLY_CONNECTED(type) \
+ type::FullyConnected(GetTensorData<float>(input), GetTensorDims(input), \
+ GetTensorData<float>(filter), GetTensorDims(filter), \
+ GetTensorData<float>(bias), GetTensorDims(bias), \
+ output_activation_min, output_activation_max, \
+ GetTensorData<float>(output), GetTensorDims(output))
+ if (kernel_type == kReference) {
+ TF_LITE_FULLY_CONNECTED(reference_ops);
+ } else if (kernel_type == kPie) {
+ return EvalPie(context, node, params, data, input, filter, bias, output);
+ } else {
+ TF_LITE_FULLY_CONNECTED(optimized_ops);
+ }
+#undef TF_LITE_FULLY_CONNECTED
+
+ return kTfLiteOk;
+}
+
+#undef TF_LITE_MACRO_DISPATCH
+
+template <KernelType kernel_type>
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ auto* params =
+ reinterpret_cast<TfLiteFullyConnectedParams*>(node->builtin_data);
+ OpData* data = reinterpret_cast<OpData*>(node->user_data);
+
+ TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ TfLiteTensor* filter = GetInput(context, node, kWeightsTensor);
+ TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+ switch (input->type) { // Already know in/out types are same.
+ case kTfLiteFloat32:
+ return EvalFloat<kernel_type>(context, node, params, data, input, filter,
+ bias, output);
+ case kTfLiteUInt8:
+ return EvalQuantized<kernel_type>(context, node, params, data, input,
+ filter, bias, output);
+ default:
+ context->ReportError(context, "Type not currently supported.");
+ return kTfLiteError;
+ }
+ return kTfLiteOk;
+}
+
+} // namespace fully_connected
+
+TfLiteRegistration* Register_FULLY_CONNECTED_REF() {
+ static TfLiteRegistration r = {
+ fully_connected::Init, fully_connected::Free, fully_connected::Prepare,
+ fully_connected::Eval<fully_connected::kReference>};
+ return &r;
+}
+
+TfLiteRegistration* Register_FULLY_CONNECTED_NEON_OPT() {
+ static TfLiteRegistration r = {
+ fully_connected::Init, fully_connected::Free, fully_connected::Prepare,
+ fully_connected::Eval<fully_connected::kNeonOptimized>};
+ return &r;
+}
+
+TfLiteRegistration* Register_FULLY_CONNECTED_GENERIC_OPT() {
+ static TfLiteRegistration r = {
+ fully_connected::Init, fully_connected::Free, fully_connected::Prepare,
+ fully_connected::Eval<fully_connected::kGenericOptimized>};
+ return &r;
+}
+
+TfLiteRegistration* Register_FULLY_CONNECTED_PIE() {
+ static TfLiteRegistration r = {fully_connected::Init, fully_connected::Free,
+ fully_connected::Prepare,
+ fully_connected::Eval<fully_connected::kPie>};
+ return &r;
+}
+
+TfLiteRegistration* Register_FULLY_CONNECTED() {
+ // TODO(ahentz): We don't have a dedicated quantized version of the PIE
+ // kernel. For now, the quantized version just defers to the corresponding
+ // optimized MINI kernel. At some point we will allow different libraries to
+ // be built with different kernels, but for now we have to pick one here.
+ return Register_FULLY_CONNECTED_PIE();
+#ifdef USE_NEON
+ return Register_FULLY_CONNECTED_NEON_OPT();
+#else
+ return Register_FULLY_CONNECTED_GENERIC_OPT();
+#endif
+}
+
+} // namespace builtin
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/fully_connected_test.cc b/tensorflow/contrib/lite/kernels/fully_connected_test.cc
new file mode 100644
index 0000000000..112e3f1ba0
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/fully_connected_test.cc
@@ -0,0 +1,377 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// Unit test for TFLite FULLY_CONNECTED op.
+
+#include <iomanip>
+#include <vector>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAre;
+using ::testing::ElementsAreArray;
+
+static float fully_connected_input[] = {
+ 0.503691, 0.196961, 0.521017, 0.554248, 0.288678, 0.792476, 0.561653,
+ 0.462230, 0.650736, 0.163132, 0.029658, 0.411544, 0.470539, 0.572390,
+ 0.538755, 0.212030, 0.264309, 0.193908, 0.777480, 0.745661, 0.423314,
+ 0.470804, 0.175501, 0.492225, 0.192743, 0.540183, 0.372514, 0.446550,
+ 0.498173, 0.126472, 0.132706, 0.001864, 0.323433, 0.653723, 0.556112,
+ 0.612111, 0.446199, 0.117765, 0.074341, 0.096935, 0.280897, 0.103999,
+ 0.508479, 0.751437, 0.676389, 0.047234, 0.963467, 0.940698, 0.241142,
+ 0.740947, 0.686359, 0.664456, 0.211751, 0.861860, 0.156681, 0.404494,
+ 0.402043, 0.529195, 0.851044, 0.900216, 0.655667, 0.983750, 0.902081,
+ 0.979100, 0.637473, 0.458193, 0.591211, 0.083671, 0.575958, 0.665552,
+ 0.180606, 0.856856, 0.769551, 0.689086, 0.608293, 0.445940, 0.736320,
+ 0.571760, 0.386637, 0.977461, 0.312707, 0.072996, 0.641918, 0.524458,
+ 0.934856, 0.798598, 0.928951, 0.336899, 0.327793, 0.779995, 0.237115,
+ 0.983460, 0.763746, 0.139196, 0.962560, 0.401218, 0.597389, 0.553771,
+ 0.484890, 0.173347, 0.219322, 0.665496, 0.030203, 0.988873, 0.354582,
+ 0.638496, 0.434813, 0.090902, 0.210256, 0.821450, 0.068363, 0.522962,
+ 0.894446, 0.710280, 0.047420, 0.829302, 0.508879, 0.976371, 0.166202,
+ 0.836672, 0.756367, 0.403317, 0.820132, 0.520112, 0.542513, 0.782691,
+ 0.921330, 0.139902};
+
+static float fully_connected_golden_output[] = {
+ 0, 0.0732134, 0, 0, 0, 0.280859,
+ 0, 0.128927, 0, 0.0777251, 0, 0.270268,
+ 0.271435, 0.0173503, 0.335465, 0.235562,
+
+ 0, 0.0745866, 0, 0.051611, 0, 0.253876,
+ 0, 0.0814873, 0, 0.104104, 0, 0.248529,
+ 0.264194, 0, 0.302973, 0.166252,
+
+ 0, 0.0170409, 0, 0.0509851, 0, 0.212834,
+ 0, 0.0208326, 0, 0.129932, 0.203978, 0.103428,
+ 0.298051, 0, 0.332233, 0.00445903,
+
+ 0, 0.125246, 0, 0.0735336, 0, 0.0910256,
+ 0, 0, 0, 0.18933, 0.378111, 0.0712443,
+ 0.277298, 0.0123414, 0.267454, 0,
+
+ 0, 0.14687, 0, 0.155495, 0.0300215, 0.147256,
+ 0, 0, 0, 0.156412, 0.434914, 0.0461529,
+ 0.246508, 0, 0.363138, 0,
+
+ 0, 0, 0, 0.0212949, 0, 0.301708,
+ 0, 0.35497, 0, 0.406223, 0.0260211, 0.049195,
+ 0.197161, 0, 0.37316, 0,
+
+ 0, 0.221783, 0, 0, 0.0116515, 0.281945,
+ 0, 0, 0, 0, 0.285626, 0.181773,
+ 0.296401, 0.170452, 0.367135, 0.142597,
+
+ 0, 0, 0, 0, 0, 0.418886,
+ 0, 0.291063, 0, 0.227541, 0.0424759, 0.27589,
+ 0.398286, 0.177146, 0.40359, 0.121452,
+
+ 0, 0.0834884, 0, 0, 0, 0.287441,
+ 0, 0.0046838, 0, 0.0122087, 0, 0.217376,
+ 0.140183, 0.0948412, 0.436677, 0.0589876,
+
+ 0, 0.0289969, 0, 0.0921397, 0, 0.396802,
+ 0, 0.0126157, 0, 0.0968433, 0, 0.172271,
+ 0.173295, 0.0664741, 0.53645, 0.00915603,
+
+ 0, 0, 0, 0, 0, 0.147942,
+ 0, 0.263795, 0, 0.39782, 0, 0.382435,
+ 0.561072, 0.0579847, 0.145712, 0.13508,
+
+ 0, 0, 0, 0.16382, 0, 0.322294,
+ 0, 0.163798, 0, 0.405211, 0.367953, 0.076852,
+ 0.342473, 0.0834118, 0.377537, 0,
+
+ 0, 0.206, 0, 0, 0, 0.375769,
+ 0, 0, 0, 0, 0, 0.125165,
+ 0, 0.105591, 0.52055, 0.0536445,
+
+ 0, 0.259261, 0, 0, 0, 0.247707,
+ 0, 0, 0, 0, 0, 0.215862,
+ 0.149153, 0.224678, 0.359519, 0.129419,
+
+ 0, 0.17611, 0, 0.280895, 0, 0.576484,
+ 0, 0.000418848, 0, 0, 0, 0.151112,
+ 0.211902, 0, 0.566341, 0.106305,
+
+ 0, 0.0246284, 0, 0, 0, 0.196267,
+ 0, 0.0248624, 0, 0.265635, 0, 0.436199,
+ 0.408079, 0.134514, 0.328489, 0.411368};
+
+class BaseFullyConnectedOpModel : public SingleOpModel {
+ public:
+ // TODO(ahentz): test different activation types too.
+ BaseFullyConnectedOpModel(int units, int batches, const TensorData& input,
+ const TensorData& output = {TensorType_FLOAT32})
+ : batches_(batches), units_(units) {
+ int total_input_size = 1;
+ for (int i = 0; i < input.shape.size(); ++i) {
+ total_input_size *= input.shape[i];
+ }
+ input_size_ = total_input_size / batches_;
+
+ input_ = AddInput(input);
+ weights_ =
+ AddInput({input.type, {units_, input_size_}, input.min, input.max});
+
+ if (input.type == TensorType_FLOAT32) {
+ bias_ = AddInput({TensorType_FLOAT32, {units_}});
+ } else {
+ // This is a quantized version. The scale of 'bias' depends on the scales
+ // of input and filter. Supposedly this is correctly set during quantized
+ // training.
+ auto bias_scale = GetScale(input_) * GetScale(weights_);
+ TensorData bias{TensorType_INT32, {units_}, 0, 0, bias_scale};
+ bias_ = AddInput(bias);
+ }
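+    // Illustration of the quantized bias scale above (hypothetical values):
+    // an input scale of 0.5 and a weights scale of 0.5 give the int32 bias a
+    // scale of 0.25.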
+
+ output_ = AddOutput(output);
+
+ SetBuiltinOp(
+ BuiltinOperator_FULLY_CONNECTED, BuiltinOptions_FullyConnectedOptions,
+ CreateFullyConnectedOptions(builder_, ActivationFunctionType_RELU)
+ .Union());
+ BuildInterpreter({GetShape(input_), GetShape(weights_), GetShape(bias_)});
+ }
+
+ int input_size() { return input_size_; }
+ int num_units() { return units_; }
+ int num_batches() { return batches_; }
+
+ protected:
+ int input_;
+ int weights_;
+ int bias_;
+ int output_;
+
+ int batches_;
+ int units_;
+ int input_size_;
+};
+
+class FloatFullyConnectedOpModel : public BaseFullyConnectedOpModel {
+ public:
+ using BaseFullyConnectedOpModel::BaseFullyConnectedOpModel;
+
+ void SetBias(std::initializer_list<float> f) { PopulateTensor(bias_, f); }
+
+ void SetWeights(std::initializer_list<float> f) {
+ PopulateTensor(weights_, f);
+ }
+
+ void SetInput(std::initializer_list<float> data) {
+ PopulateTensor(input_, data);
+ }
+ void SetInput(int offset, float* begin, float* end) {
+ PopulateTensor(input_, offset, begin, end);
+ }
+
+ std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+};
+
+class QuantizedFullyConnectedOpModel : public BaseFullyConnectedOpModel {
+ public:
+ using BaseFullyConnectedOpModel::BaseFullyConnectedOpModel;
+
+ void SetBias(std::initializer_list<float> data) {
+ QuantizeAndPopulate<int32_t>(bias_, data);
+ }
+ void SetWeights(std::initializer_list<float> data) {
+ QuantizeAndPopulate<uint8_t>(weights_, data);
+ }
+ void SetInput(std::initializer_list<float> data) {
+ QuantizeAndPopulate<uint8_t>(input_, data);
+ }
+
+ std::vector<uint8_t> GetOutput() { return ExtractVector<uint8_t>(output_); }
+ std::vector<float> GetDequantizedOutput() {
+ return Dequantize<uint8_t>(ExtractVector<uint8_t>(output_),
+ GetScale(output_), GetZeroPoint(output_));
+ }
+};
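+
+// With the quantization used in the tests below (output range [-127, 128],
+// i.e. scale of roughly 1 and zero_point of roughly 127), a raw uint8 output
+// of 151 dequantizes to about 1 * (151 - 127) = 24.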
+
+// TODO(ahentz): add more small tests like this one, focused on making sure the
+// calculations are correct.
+TEST(FullyConnectedOpTest, SimpleTest) {
+ FloatFullyConnectedOpModel m(3, 2, {TensorType_FLOAT32, {2, 10}});
+ m.SetWeights({
+ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 0
+ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1
+ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 2
+ });
+ m.SetBias({1, 2, 3});
+
+ m.SetInput({
+ 1, 2, 3, 4, 5, 6, 7, 8, -9, -10, // b = 0
+ 1, 2, 3, 4, 5, 6, 7, -8, 9, -10, // b = 1
+ });
+
+ m.Invoke();
+
+ EXPECT_THAT(m.GetOutput(), ElementsAre(24, 25, 26, 58, 59, 60));
+}
+
+TEST(FullyConnectedOpTest, SimpleTestQuantized) {
+ QuantizedFullyConnectedOpModel m(
+ 3, 2,
+ /*input=*/{TensorType_UINT8, {2, 10}, -63.5, 64},
+ /*output=*/{TensorType_UINT8, {}, -127, 128});
+
+ // The quantization ranges are chosen so that input_scale * weights_scale is
+ // smaller than output_scale, as the quantized kernel requires.
+ m.SetWeights({
+ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 0
+ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1
+ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 2
+ });
+ m.SetBias({1, 2, 3});
+
+ m.SetInput({
+ 1, 2, 3, 4, 5, 6, 7, 8, -9, -10, // b = 0
+ 1, 2, 3, 4, 5, 6, 7, -8, 9, -10, // b = 1
+ });
+
+ m.Invoke();
+
+ EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear({
+ 24, 25, 26, //
+ 58, 59, 60, //
+ })));
+ EXPECT_THAT(m.GetOutput(), ElementsAre(151, 152, 153, 185, 186, 187));
+}
+
+TEST(FullyConnectedOpTest, SimpleTest4DInput) {
+ // Note that it is not required that the first dimension be the number of
+ // batches. All that matters is that the input can be evenly divided into
+ // batches; in this case, the total number of input elements must be a
+ // multiple of 2.
+ FloatFullyConnectedOpModel m(/*units=*/3,
+ /*batches=*/2,
+ /*input=*/{TensorType_FLOAT32, {4, 1, 5, 1}});
+ m.SetWeights({
+ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 0
+ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1
+ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 2
+ });
+ m.SetBias({1, 2, 3});
+
+ m.SetInput({
+ 1, 2, 3, 4, 5, 6, 7, 8, -9, -10, // first batch
+ 1, 2, 3, 4, 5, 6, 7, -8, 9, -10, // second batch
+ });
+
+ m.Invoke();
+
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({
+ 24, 25, 26, // first batch
+ 58, 59, 60, // second batch
+ }));
+}
+
+TEST(FullyConnectedOpTest, SimpleTest4dInputQuantized) {
+ QuantizedFullyConnectedOpModel m(
+ 3, 2,
+ /*input=*/{TensorType_UINT8, {4, 1, 5, 1}, -63.5, 64},
+ /*output=*/{TensorType_UINT8, {}, -127, 128});
+
+ // The quantization ranges are chosen so that input_scale * weights_scale is
+ // smaller than output_scale, as the quantized kernel requires.
+ m.SetWeights({
+ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 0
+ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1
+ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 2
+ });
+ m.SetBias({1, 2, 3});
+
+ m.SetInput({
+ 1, 2, 3, 4, 5, 6, 7, 8, -9, -10, // b = 0
+ 1, 2, 3, 4, 5, 6, 7, -8, 9, -10, // b = 1
+ });
+
+ m.Invoke();
+
+ EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear({
+ 24, 25, 26, //
+ 58, 59, 60, //
+ })));
+ EXPECT_THAT(m.GetOutput(), ElementsAre(151, 152, 153, 185, 186, 187));
+}
+
+// TODO(ahentz): Reconsider this test. Having arbitrary weights makes it hard
+// to debug errors and doesn't necessarily test all the important details.
+TEST(FullyConnectedOpTest, BlackBoxTest) {
+ FloatFullyConnectedOpModel m(16, 2, {TensorType_FLOAT32, {2, 8}});
+ m.SetWeights(
+ {0.091327, 0.103366, -0.316505, -0.083120, 0.149366, -0.196636,
+ -0.123672, 0.062800, 0.063031, 0.191670, -0.062001, -0.061504,
+ -0.275581, 0.059388, -0.118497, -0.079224, 0.109758, 0.008307,
+ -0.062657, -0.060962, -0.049782, -0.106719, -0.319482, -0.103650,
+ 0.266455, 0.051517, -0.123448, 0.322464, 0.043282, -0.173782,
+ -0.190381, 0.002013, 0.096086, 0.131157, 0.031164, 0.100638,
+ -0.312191, -0.080923, -0.101318, -0.116614, 0.142238, 0.086540,
+ -0.139154, 0.174268, -0.073161, 0.080072, 0.006874, 0.229382,
+ -0.104321, -0.176035, -0.208587, -0.001019, -0.162032, 0.080824,
+ -0.025021, 0.074460, -0.252595, -0.161750, -0.136403, 0.008308,
+ 0.005710, 0.096600, 0.289839, 0.218816, -0.304651, -0.070958,
+ 0.054598, 0.147113, -0.139112, -0.072798, -0.163335, -0.167863,
+ -0.128762, -0.035780, 0.117262, 0.017177, 0.263335, -0.176612,
+ 0.262961, -0.093654, -0.339283, 0.333071, 0.180827, 0.287583,
+ 0.066350, -0.197947, -0.114449, -0.236035, 0.103532, -0.034284,
+ 0.093299, -0.145361, 0.054001, 0.250570, 0.157010, -0.143480,
+ -0.139061, -0.048873, 0.067557, 0.139038, 0.324106, 0.227041,
+ 0.037793, -0.225747, -0.241619, 0.357835, 0.135762, -0.306764,
+ -0.125982, 0.091916, 0.266587, 0.030135, 0.265148, 0.141627,
+ 0.020120, 0.083815, -0.124556, -0.100124, -0.048159, 0.181172,
+ 0.302309, -0.041084, 0.146334, -0.061511, -0.232605, 0.281324,
+ 0.145408, -0.221897});
+ m.SetBias({-0.160594, 0.205770, -0.078307, -0.077984, 0.001937, 0.015860,
+ 0.036810, 0.012346, 0.001028, 0.038551, 0.075415, 0.020804,
+ 0.048478, -0.032270, 0.175688, -0.085662});
+
+ const int input_sequence_size = sizeof(fully_connected_input) /
+ sizeof(float) /
+ (m.input_size() * m.num_batches());
+ for (int i = 0; i < input_sequence_size; i++) {
+ // TODO(ahentz): This is what the original test was doing: two equal
+ // batches per invocation. We could instead use two different batches.
+ float* batch_start = fully_connected_input + i * m.input_size();
+ float* batch_end = batch_start + m.input_size();
+ m.SetInput(0, batch_start, batch_end);
+ m.SetInput(m.input_size(), batch_start, batch_end);
+
+ m.Invoke();
+
+ float* golden_start = fully_connected_golden_output + i * m.num_units();
+ float* golden_end = golden_start + m.num_units();
+ std::vector<float> expected;
+ expected.insert(expected.end(), golden_start, golden_end);
+ expected.insert(expected.end(), golden_start, golden_end);
+
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
+ }
+}
+
+} // namespace
+} // namespace tflite
+
+int main(int argc, char** argv) {
+#ifdef OS_LINUX
+ tflite::LogToStderr();
+#endif
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/kernels/gemm_support.cc b/tensorflow/contrib/lite/kernels/gemm_support.cc
new file mode 100644
index 0000000000..eb2b0aacf7
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/gemm_support.cc
@@ -0,0 +1,68 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/kernels/gemm_support.h"
+
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+
+namespace tflite {
+namespace gemm_support {
+
+struct RefCountedGemmContext {
+ gemmlowp::GemmContext* gemm_context_ = nullptr;
+ int num_references_ = 0;
+};
+
+void IncrementUsageCounter(TfLiteContext* context) {
+ auto* ptr = reinterpret_cast<RefCountedGemmContext*>(context->gemm_context);
+ if (ptr == nullptr) {
+ ptr = new RefCountedGemmContext;
+ ptr->gemm_context_ = new gemmlowp::GemmContext();
+ ptr->num_references_ = 0;
+ context->gemm_context = ptr;
+ }
+ ptr->num_references_++;
+}
+
+void DecrementUsageCounter(TfLiteContext* context) {
+ auto* ptr = reinterpret_cast<RefCountedGemmContext*>(context->gemm_context);
+ if (ptr == nullptr) {
+ TF_LITE_FATAL(
+ "Call to DecrementUsageCounter() not preceded by "
+ "IncrementUsageCounter()");
+ }
+ if (--ptr->num_references_ == 0) {
+ delete ptr->gemm_context_;
+ delete ptr;
+ context->gemm_context = nullptr;
+ }
+}
+
+gemmlowp::GemmContext* GetFromContext(TfLiteContext* context) {
+ auto* ptr = reinterpret_cast<RefCountedGemmContext*>(context->gemm_context);
+ if (ptr == nullptr) {
+ TF_LITE_FATAL(
+ "Call to GetFromContext() not preceded by IncrementUsageCounter()");
+ }
+ return ptr->gemm_context_;
+}
+
+void SetMaxNumThreads(TfLiteContext* context, int num_threads) {
+ IncrementUsageCounter(context);
+ GetFromContext(context)->set_max_num_threads(num_threads);
+ DecrementUsageCounter(context);
+}
+
+} // namespace gemm_support
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/gemm_support.h b/tensorflow/contrib/lite/kernels/gemm_support.h
new file mode 100644
index 0000000000..b531959ffb
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/gemm_support.h
@@ -0,0 +1,54 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_GEMM_SUPPORT_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_GEMM_SUPPORT_H_
+
+#include "public/gemmlowp.h"
+#include "tensorflow/contrib/lite/context.h"
+
+namespace tflite {
+namespace gemm_support {
+
+// Returns the GemmContext stored in 'context', allowing multiple ops to
+// share a single object, as long as they share a TfLiteContext. The caller
+// must ensure that this is called between IncrementUsageCounter() and
+// DecrementUsageCounter(). For example, in the implementation of an op:
+// void* Init(TfLiteContext* context, const char*, size_t) {
+// gemm_support::IncrementUsageCounter(context);
+// return nullptr;
+// }
+// void Free(TfLiteContext* context, void*) {
+// gemm_support::DecrementUsageCounter(context);
+// }
+// TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+// auto* gemm_context = gemm_support::GetFromContext(context);
+// }
+gemmlowp::GemmContext* GetFromContext(TfLiteContext* context);
+
+// Let the framework know that the GemmContext stored in 'context' will be used
+// by an op. If necessary a new GemmContext is created and placed in 'context'.
+void IncrementUsageCounter(TfLiteContext* context);
+
+// Let the framework know that the op stopped using the GemmContext stored in
+// 'context'. If there are no more usages the GemmContext will be deleted.
+void DecrementUsageCounter(TfLiteContext* context);
+
+// Set the maximum number of threads available for gemmlowp operations.
+void SetMaxNumThreads(TfLiteContext* context, int num_threads);
+
+} // namespace gemm_support
+} // namespace tflite
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_GEMM_SUPPORT_H_
diff --git a/tensorflow/contrib/lite/kernels/hashtable_lookup.cc b/tensorflow/contrib/lite/kernels/hashtable_lookup.cc
new file mode 100644
index 0000000000..3b82601d11
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/hashtable_lookup.cc
@@ -0,0 +1,155 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Op that looks up items from hashtable.
+//
+// Input:
+// Tensor[0]: Hash key to lookup, dim.size == 1, int32
+// Tensor[1]: Key of hashtable, dim.size == 1, int32
+// *MUST* be sorted in ascending order.
+// Tensor[2]: Value of hashtable, dim.size >= 1
+// Tensor[1].Dim[0] == Tensor[2].Dim[0]
+//
+// Output:
+// Output[0].dim[0] == Tensor[0].dim[0], num of lookups
+// Each item in the output is a raw byte-for-byte copy of the corresponding
+// row of the value tensor (Tensor[2]).
+// When a key is not found in the hashtable, the returned bytes are all 0s.
+//
+// Output[1].dim = { Tensor[0].dim[0] }, num of lookups
+// Each item indicates whether the corresponding lookup has a returned value.
+// 0 for missing key, 1 for found key.
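+//
+// Illustration (mirroring the unit test for this op): with sorted keys
+// {-11, 0, 1234} and a lookup of {1234, -292}, the first query copies value
+// row 2 and sets its hit to 1; the second finds no matching key, zero-fills
+// its output row (or emits an empty string for string values) and sets hit 0.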
+
+#include <unistd.h>
+#include <cassert>
+#include <cmath>
+#include <cstdio>
+#include <cstdlib>
+#include <cstring>
+#include <iostream>
+#include <limits>
+
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+#include "tensorflow/contrib/lite/string_util.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+
+namespace {
+
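+// Comparator for bsearch(): despite its name, it returns the signed difference
+// of the two int32 keys (negative, zero, or positive), which is the standard
+// ascending-order comparison bsearch expects.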
+int greater(const void* a, const void* b) {
+ return *static_cast<const int*>(a) - *static_cast<const int*>(b);
+}
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), 2);
+
+ TfLiteTensor* lookup = GetInput(context, node, 0);
+ TF_LITE_ENSURE_EQ(context, NumDimensions(lookup), 1);
+ TF_LITE_ENSURE_EQ(context, lookup->type, kTfLiteInt32);
+
+ TfLiteTensor* key = GetInput(context, node, 1);
+ TF_LITE_ENSURE_EQ(context, NumDimensions(key), 1);
+ TF_LITE_ENSURE_EQ(context, key->type, kTfLiteInt32);
+
+ TfLiteTensor* value = GetInput(context, node, 2);
+ TF_LITE_ENSURE(context, NumDimensions(value) >= 1);
+ TF_LITE_ENSURE_EQ(context, SizeOfDimension(key, 0),
+ SizeOfDimension(value, 0));
+ if (value->type == kTfLiteString) {
+ TF_LITE_ENSURE_EQ(context, NumDimensions(value), 1);
+ }
+
+ TfLiteTensor* hits = GetOutput(context, node, 1);
+ TF_LITE_ENSURE_EQ(context, hits->type, kTfLiteUInt8);
+ TfLiteIntArray* hitSize = TfLiteIntArrayCreate(1);
+ hitSize->data[0] = SizeOfDimension(lookup, 0);
+
+ TfLiteTensor* output = GetOutput(context, node, 0);
+ TF_LITE_ENSURE_EQ(context, value->type, output->type);
+
+ TfLiteStatus status = kTfLiteOk;
+ if (output->type != kTfLiteString) {
+ TfLiteIntArray* outputSize = TfLiteIntArrayCreate(NumDimensions(value));
+ outputSize->data[0] = SizeOfDimension(lookup, 0);
+ for (int i = 1; i < NumDimensions(value); i++) {
+ outputSize->data[i] = SizeOfDimension(value, i);
+ }
+ status = context->ResizeTensor(context, output, outputSize);
+ }
+ if (context->ResizeTensor(context, hits, hitSize) == kTfLiteError) {
+ status = kTfLiteError;
+ }
+ return status;
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ TfLiteTensor* output = GetOutput(context, node, 0);
+ TfLiteTensor* hits = GetOutput(context, node, 1);
+ TfLiteTensor* lookup = GetInput(context, node, 0);
+ TfLiteTensor* key = GetInput(context, node, 1);
+ TfLiteTensor* value = GetInput(context, node, 2);
+
+ const int num_rows = SizeOfDimension(value, 0);
+ const int row_bytes = value->bytes / num_rows;
+ void* pointer = nullptr;
+ DynamicBuffer buf;
+
+ for (int i = 0; i < SizeOfDimension(lookup, 0); i++) {
+ int idx = -1;
+ pointer = bsearch(&(lookup->data.i32[i]), key->data.i32, num_rows,
+ sizeof(int32_t), greater);
+ if (pointer != nullptr) {
+ idx = (reinterpret_cast<char*>(pointer) - (key->data.raw)) /
+ sizeof(int32_t);
+ }
+
+ if (idx >= num_rows || idx < 0) {
+ if (output->type == kTfLiteString) {
+ buf.AddString(nullptr, 0);
+ } else {
+ memset(output->data.raw + i * row_bytes, 0, row_bytes);
+ }
+ hits->data.uint8[i] = 0;
+ } else {
+ if (output->type == kTfLiteString) {
+ buf.AddString(GetString(value, idx));
+ } else {
+ memcpy(output->data.raw + i * row_bytes,
+ value->data.raw + idx * row_bytes, row_bytes);
+ }
+ hits->data.uint8[i] = 1;
+ }
+ }
+ if (output->type == kTfLiteString) {
+ buf.WriteToTensor(output);
+ }
+
+ return kTfLiteOk;
+}
+} // namespace
+
+TfLiteRegistration* Register_HASHTABLE_LOOKUP() {
+ static TfLiteRegistration r = {nullptr, nullptr, Prepare, Eval};
+ return &r;
+}
+
+} // namespace builtin
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/hashtable_lookup_test.cc b/tensorflow/contrib/lite/kernels/hashtable_lookup_test.cc
new file mode 100644
index 0000000000..916a23225e
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/hashtable_lookup_test.cc
@@ -0,0 +1,176 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// Unit test for TFLite Lookup op.
+
+#include <iomanip>
+#include <vector>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+#include "tensorflow/contrib/lite/string_util.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAreArray;
+
+class HashtableLookupOpModel : public SingleOpModel {
+ public:
+ HashtableLookupOpModel(std::initializer_list<int> lookup_shape,
+ std::initializer_list<int> key_shape,
+ std::initializer_list<int> value_shape,
+ TensorType type) {
+ lookup_ = AddInput(TensorType_INT32);
+ key_ = AddInput(TensorType_INT32);
+ value_ = AddInput(type);
+ output_ = AddOutput(type);
+ hit_ = AddOutput(TensorType_UINT8);
+ SetBuiltinOp(BuiltinOperator_HASHTABLE_LOOKUP, BuiltinOptions_NONE, 0);
+ BuildInterpreter({lookup_shape, key_shape, value_shape});
+ }
+
+ void SetLookup(std::initializer_list<int> data) {
+ PopulateTensor<int>(lookup_, data);
+ }
+
+ void SetHashtableKey(std::initializer_list<int> data) {
+ PopulateTensor<int>(key_, data);
+ }
+
+ void SetHashtableValue(const std::vector<string>& content) {
+ PopulateStringTensor(value_, content);
+ }
+
+ void SetHashtableValue(const std::function<float(int)>& function) {
+ TfLiteTensor* tensor = interpreter_->tensor(value_);
+ int rows = tensor->dims->data[0];
+ for (int i = 0; i < rows; i++) {
+ tensor->data.f[i] = function(i);
+ }
+ }
+
+ void SetHashtableValue(const std::function<float(int, int)>& function) {
+ TfLiteTensor* tensor = interpreter_->tensor(value_);
+ int rows = tensor->dims->data[0];
+ int features = tensor->dims->data[1];
+ for (int i = 0; i < rows; i++) {
+ for (int j = 0; j < features; j++) {
+ tensor->data.f[i * features + j] = function(i, j);
+ }
+ }
+ }
+
+ std::vector<string> GetStringOutput() {
+ TfLiteTensor* output = interpreter_->tensor(output_);
+ int num = GetStringCount(output);
+ std::vector<string> result(num);
+ for (int i = 0; i < num; i++) {
+ auto ref = GetString(output, i);
+ result[i] = string(ref.str, ref.len);
+ }
+ return result;
+ }
+
+ std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+ std::vector<uint8_t> GetHit() { return ExtractVector<uint8_t>(hit_); }
+
+ private:
+ int lookup_;
+ int key_;
+ int value_;
+ int output_;
+ int hit_;
+};
+
+// TODO(yichengfan): write more tests that exercise the details of the op,
+// such as lookup errors and variable input shapes.
+TEST(HashtableLookupOpTest, Test2DInput) {
+ HashtableLookupOpModel m({4}, {3}, {3, 2}, TensorType_FLOAT32);
+
+ m.SetLookup({1234, -292, -11, 0});
+ m.SetHashtableKey({-11, 0, 1234});
+ m.SetHashtableValue([](int i, int j) { return i + j / 10.0f; });
+
+ m.Invoke();
+
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({
+ 2.0, 2.1, // 2-nd item
+ 0, 0, // Not found
+ 0.0, 0.1, // 0-th item
+ 1.0, 1.1, // 1-st item
+ })));
+ EXPECT_THAT(m.GetHit(), ElementsAreArray({
+ 1, 0, 1, 1,
+ }));
+}
+
+TEST(HashtableLookupOpTest, Test1DInput) {
+ HashtableLookupOpModel m({4}, {3}, {3}, TensorType_FLOAT32);
+
+ m.SetLookup({1234, -292, -11, 0});
+ m.SetHashtableKey({-11, 0, 1234});
+ m.SetHashtableValue([](int i) { return i * i / 10.0f; });
+
+ m.Invoke();
+
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({
+ 0.4, // 2-nd item
+ 0, // Not found
+ 0.0, // 0-th item
+ 0.1, // 1-st item
+ })));
+ EXPECT_THAT(m.GetHit(), ElementsAreArray({
+ 1,
+ 0,
+ 1,
+ 1,
+ }));
+}
+
+TEST(HashtableLookupOpTest, TestString) {
+ HashtableLookupOpModel m({4}, {3}, {3}, TensorType_STRING);
+
+ m.SetLookup({1234, -292, -11, 0});
+ m.SetHashtableKey({-11, 0, 1234});
+ m.SetHashtableValue({"Hello", "", "Hi"});
+
+ m.Invoke();
+
+ EXPECT_THAT(m.GetStringOutput(), ElementsAreArray({
+ "Hi", // 2-nd item
+ "", // Not found
+ "Hello", // 0-th item
+ "", // 1-st item
+ }));
+ EXPECT_THAT(m.GetHit(), ElementsAreArray({
+ 1,
+ 0,
+ 1,
+ 1,
+ }));
+}
+
+} // namespace
+} // namespace tflite
+
+int main(int argc, char** argv) {
+#ifdef OS_LINUX
+ tflite::LogToStderr();
+#endif
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/kernels/internal/BUILD b/tensorflow/contrib/lite/kernels/internal/BUILD
new file mode 100644
index 0000000000..288534099b
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/internal/BUILD
@@ -0,0 +1,359 @@
+package(default_visibility = [
+ "//visibility:public",
+])
+
+licenses(["notice"]) # Apache 2.0
+
+load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts")
+
+tflite_deps_intel = [
+ "@arm_neon_2_x86_sse",
+]
+
+NEON_FLAGS_IF_APPLICABLE = select({
+ ":arm": [
+ "-O3",
+ "-mfpu=neon",
+ "-mfloat-abi=softfp",
+ ],
+ ":armeabi-v7a": [
+ "-O3",
+ "-mfpu=neon",
+ "-mfloat-abi=softfp",
+ ],
+ ":armv7a": [
+ "-O3",
+ "-mfpu=neon",
+ "-mfloat-abi=softfp",
+ ],
+ "//conditions:default": [
+ "-O3",
+ ],
+})
+
+cc_library(
+ name = "types",
+ srcs = [],
+ hdrs = [
+ "compatibility.h",
+ "types.h",
+ ],
+)
+
+config_setting(
+ name = "arm",
+ values = {
+ "cpu": "arm",
+ },
+)
+
+config_setting(
+ name = "arm64-v8a",
+ values = {
+ "cpu": "arm64-v8a",
+ },
+)
+
+config_setting(
+ name = "armv7a",
+ values = {
+ "cpu": "armv7a",
+ },
+)
+
+config_setting(
+ name = "armeabi-v7a",
+ values = {
+ "cpu": "armeabi-v7a",
+ },
+)
+
+config_setting(
+ name = "haswell",
+ values = {
+ "cpu": "haswell",
+ },
+)
+
+config_setting(
+ name = "ios_x86_64",
+ values = {
+ "cpu": "ios_x86_64",
+ },
+)
+
+config_setting(
+ name = "ios_armv7",
+ values = {
+ "cpu": "ios_armv7",
+ },
+)
+
+config_setting(
+ name = "ios_arm64",
+ values = {
+ "cpu": "ios_arm64",
+ },
+)
+
+config_setting(
+ name = "k8",
+ values = {
+ "cpu": "k8",
+ },
+)
+
+config_setting(
+ name = "x86",
+ values = {
+ "cpu": "x86",
+ },
+)
+
+config_setting(
+ name = "x86_64",
+ values = {
+ "cpu": "x86_64",
+ },
+)
+
+config_setting(
+ name = "darwin",
+ values = {
+ "cpu": "darwin",
+ },
+)
+
+cc_library(
+ name = "optimized_base",
+ srcs = [],
+ hdrs = [
+ "common.h",
+ "optimized/depthwiseconv_float.h",
+ "optimized/depthwiseconv_uint8.h",
+ "optimized/optimized_ops.h",
+ ],
+ copts = tflite_copts(),
+ deps = [
+ ":types",
+ ":round",
+ "//third_party/eigen3",
+ "@gemmlowp//:gemmlowp",
+ "//tensorflow/contrib/lite:builtin_op_data",
+ ] + select({
+ ":haswell": tflite_deps_intel,
+ ":ios_x86_64": tflite_deps_intel,
+ ":k8": tflite_deps_intel,
+ ":x86": tflite_deps_intel,
+ ":x86_64": tflite_deps_intel,
+ ":darwin": tflite_deps_intel,
+ "//conditions:default": [],
+ }),
+)
+
+cc_library(
+ name = "optimized",
+ hdrs = [
+ "optimized/eigen_spatial_convolutions.h",
+ "optimized/eigen_tensor_reduced_instantiations_oss.h",
+ "optimized/multithreaded_conv.h",
+ "tensor.h",
+ ],
+ deps = [
+ ":optimized_base",
+ ":types",
+ "//tensorflow/contrib/lite:builtin_op_data",
+ "//tensorflow/contrib/lite:context",
+ "//third_party/eigen3",
+ ],
+)
+
+cc_test(
+ name = "tensor_test",
+ srcs = ["tensor_test.cc"],
+ deps = [
+ ":reference",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+cc_library(
+ name = "round",
+ srcs = [],
+ hdrs = ["round.h"],
+)
+
+cc_library(
+ name = "quantization_util",
+ srcs = ["quantization_util.cc"],
+ hdrs = [
+ "compatibility.h",
+ "quantization_util.h",
+ ],
+ deps = [":round"],
+)
+
+cc_test(
+ name = "quantization_util_test",
+ srcs = ["quantization_util_test.cc"],
+ deps = [
+ ":quantization_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+cc_library(
+ name = "reference_base",
+ srcs = [],
+ hdrs = [
+ "common.h",
+ "reference/depthwiseconv_float.h",
+ "reference/depthwiseconv_uint8.h",
+ "reference/reference_ops.h",
+ ],
+ deps = [
+ ":round",
+ ":types",
+ "//third_party/eigen3",
+ "@gemmlowp//:gemmlowp",
+ "//tensorflow/contrib/lite:builtin_op_data",
+ ] + select({
+ ":haswell": tflite_deps_intel,
+ ":ios_x86_64": tflite_deps_intel,
+ ":k8": tflite_deps_intel,
+ ":x86": tflite_deps_intel,
+ ":x86_64": tflite_deps_intel,
+ ":darwin": tflite_deps_intel,
+ "//conditions:default": [],
+ }),
+)
+
+cc_library(
+ name = "reference",
+ hdrs = ["tensor.h"],
+ deps = [
+ ":types",
+ "//tensorflow/contrib/lite:context",
+ ],
+)
+
+cc_library(
+ name = "portable_tensor_utils",
+ srcs = [
+ "reference/portable_tensor_utils.cc",
+ ],
+ hdrs = [
+ "reference/portable_tensor_utils.h",
+ ],
+ deps = [
+ "//tensorflow/contrib/lite:builtin_op_data",
+ "//tensorflow/contrib/lite/kernels:activation_functor",
+ "//tensorflow/contrib/lite/kernels:op_macros",
+ ],
+)
+
+cc_library(
+ name = "neon_tensor_utils",
+ srcs = [
+ "optimized/neon_tensor_utils.cc",
+ ],
+ hdrs = [
+ "optimized/neon_tensor_utils.h",
+ "optimized/tensor_utils_impl.h",
+ ],
+ copts = NEON_FLAGS_IF_APPLICABLE,
+ deps = [
+ ":cpu_check",
+ ":portable_tensor_utils",
+ "//tensorflow/contrib/lite:builtin_op_data",
+ "//tensorflow/contrib/lite/kernels:activation_functor",
+ ],
+)
+
+cc_library(
+ name = "tensor_utils",
+ srcs = [
+ "tensor_utils.cc",
+ ],
+ hdrs = [
+ "optimized/tensor_utils_impl.h",
+ "reference/portable_tensor_utils.h",
+ "tensor_utils.h",
+ ],
+ copts = NEON_FLAGS_IF_APPLICABLE,
+ deps = [
+ "//tensorflow/contrib/lite/kernels:activation_functor",
+ "//tensorflow/contrib/lite:builtin_op_data",
+ ] + select({
+ ":arm": [
+ ":neon_tensor_utils",
+ ],
+ ":arm64-v8a": [
+ ":neon_tensor_utils",
+ ],
+ ":armeabi-v7a": [
+ ":neon_tensor_utils",
+ ],
+ ":armv7a": [
+ ":neon_tensor_utils",
+ ],
+ ":ios_armv7": [
+ ":neon_tensor_utils",
+ ],
+ ":ios_arm64": [
+ ":neon_tensor_utils",
+ ],
+ "//conditions:default": [
+ ":portable_tensor_utils",
+ ],
+ }),
+)
+
+cc_test(
+ name = "tensor_utils_test",
+ srcs = ["tensor_utils_test.cc"],
+ copts = NEON_FLAGS_IF_APPLICABLE,
+ linkopts = select({
+ "//tensorflow:android": [
+ "-fPIE -pie",
+ ],
+ "//conditions:default": [],
+ }),
+ linkstatic = 1,
+ deps = [
+ ":tensor_utils",
+ "//tensorflow/contrib/lite:builtin_op_data",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_library(
+ name = "cpu_check",
+ hdrs = [
+ "optimized/cpu_check.h",
+ ],
+ deps = [
+ ] + select(
+ {
+ "//tensorflow:android": [
+ "@androidndk//:cpufeatures",
+ ],
+ "//conditions:default": [],
+ },
+ ),
+)
+
+exports_files(["optimized/eigen_tensor_reduced_instantiations_oss.h"])
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/contrib/lite/kernels/internal/common.h b/tensorflow/contrib/lite/kernels/internal/common.h
new file mode 100644
index 0000000000..28f19a2506
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/internal/common.h
@@ -0,0 +1,107 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_COMMON_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_COMMON_H_
+
+#ifndef ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK
+#ifdef GEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK
+#define ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK
+#endif
+#endif
+
+#ifndef USE_NEON
+#if defined(__ARM_NEON__) || defined(__ARM_NEON)
+#define USE_NEON
+#include <arm_neon.h>
+#endif
+
+#if defined __GNUC__ && defined __SSE4_1__
+#define USE_NEON
+
+#define OPTIMIZED_OPS_H__IGNORE_DEPRECATED_DECLARATIONS
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
+#pragma GCC diagnostic ignored "-Wattributes"
+
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wnarrowing"
+#pragma GCC diagnostic ignored "-Wsequence-point"
+
+#include "NEON_2_SSE.h"
+
+#pragma GCC diagnostic pop
+#endif
+#endif
+
+#include "public/gemmlowp.h"
+#include "tensorflow/contrib/lite/kernels/internal/types.h"
+
+namespace tflite {
+
+inline void GetActivationMinMax(FusedActivationFunctionType ac,
+ float* output_activation_min,
+ float* output_activation_max) {
+ switch (ac) {
+ case FusedActivationFunctionType::kNone:
+ *output_activation_min = std::numeric_limits<float>::lowest();
+ *output_activation_max = std::numeric_limits<float>::max();
+ break;
+ case FusedActivationFunctionType::kRelu:
+ *output_activation_min = 0.f;
+ *output_activation_max = std::numeric_limits<float>::max();
+ break;
+ case FusedActivationFunctionType::kRelu1:
+ *output_activation_min = -1.f;
+ *output_activation_max = 1.f;
+ break;
+ case FusedActivationFunctionType::kRelu6:
+ *output_activation_min = 0.f;
+ *output_activation_max = 6.f;
+ break;
+ }
+}
+
+inline float ActivationFunctionWithMinMax(float x, float output_activation_min,
+ float output_activation_max) {
+ return std::min(std::max(x, output_activation_min), output_activation_max);
+}
+
+// Legacy function, left for compatibility only.
+template <FusedActivationFunctionType Ac>
+float ActivationFunction(float x) {
+ float output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+ return ActivationFunctionWithMinMax(x, output_activation_min,
+ output_activation_max);
+}
+
+inline int32 MultiplyByQuantizedMultiplierSmallerThanOne(
+ int32 x, int32 quantized_multiplier, int right_shift) {
+ using gemmlowp::RoundingDivideByPOT;
+ using gemmlowp::SaturatingRoundingDoublingHighMul;
+ return RoundingDivideByPOT(
+ SaturatingRoundingDoublingHighMul(x, quantized_multiplier), right_shift);
+}
+
+inline int32 MultiplyByQuantizedMultiplierGreaterThanOne(
+ int32 x, int32 quantized_multiplier, int left_shift) {
+ using gemmlowp::SaturatingRoundingDoublingHighMul;
+ return SaturatingRoundingDoublingHighMul(x * (1 << left_shift),
+ quantized_multiplier);
+}
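+
+// Illustrative arithmetic for the SmallerThanOne variant (example values, not
+// from the original patch): with x = 80, quantized_multiplier = 1 << 30
+// (about 0.5 in Q31 format) and right_shift = 2, the doubling high-mul yields
+// 40 and the rounding shift yields 10, i.e. roughly 80 * 0.5 / 4.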
+
+} // namespace tflite
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_COMMON_H_
diff --git a/tensorflow/contrib/lite/kernels/internal/compatibility.h b/tensorflow/contrib/lite/kernels/internal/compatibility.h
new file mode 100644
index 0000000000..796a03566a
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/internal/compatibility.h
@@ -0,0 +1,78 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_COMPATIBILITY_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_COMPATIBILITY_H_
+
+#include <cassert>
+#include <cstdint>
+#include <cstdlib>
+
+#ifndef TFLITE_DCHECK
+#define TFLITE_DCHECK(condition) (condition) ? (void)0 : assert(false)
+#endif
+
+#ifndef TFLITE_DCHECK_EQ
+#define TFLITE_DCHECK_EQ(x, y) ((x) == (y)) ? (void)0 : assert(false)
+#endif
+
+#ifndef TFLITE_DCHECK_GE
+#define TFLITE_DCHECK_GE(x, y) ((x) >= (y)) ? (void)0 : assert(false)
+#endif
+
+#ifndef TFLITE_DCHECK_GT
+#define TFLITE_DCHECK_GT(x, y) ((x) > (y)) ? (void)0 : assert(false)
+#endif
+
+#ifndef TFLITE_DCHECK_LE
+#define TFLITE_DCHECK_LE(x, y) ((x) <= (y)) ? (void)0 : assert(false)
+#endif
+
+#ifndef TFLITE_DCHECK_LT
+#define TFLITE_DCHECK_LT(x, y) ((x) < (y)) ? (void)0 : assert(false)
+#endif
+
+// TODO(ahentz): Clean up: We should stick to the DCHECK versions.
+#ifndef TFLITE_CHECK
+#define TFLITE_CHECK(condition) (condition) ? (void)0 : abort()
+#endif
+
+#ifndef TFLITE_CHECK_EQ
+#define TFLITE_CHECK_EQ(x, y) ((x) == (y)) ? (void)0 : abort()
+#endif
+
+#ifndef TFLITE_CHECK_GE
+#define TFLITE_CHECK_GE(x, y) ((x) >= (y)) ? (void)0 : abort()
+#endif
+
+#ifndef TFLITE_CHECK_GT
+#define TFLITE_CHECK_GT(x, y) ((x) > (y)) ? (void)0 : abort()
+#endif
+
+#ifndef TFLITE_CHECK_LE
+#define TFLITE_CHECK_LE(x, y) ((x) <= (y)) ? (void)0 : abort()
+#endif
+
+#ifndef TFLITE_CHECK_LT
+#define TFLITE_CHECK_LT(x, y) ((x) < (y)) ? (void)0 : abort()
+#endif
+
+// TODO(ahentz): Clean up.
+using uint8 = std::uint8_t;
+using int16 = std::int16_t;
+using uint16 = std::uint16_t;
+using int32 = std::int32_t;
+using uint32 = std::uint32_t;
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_COMPATIBILITY_H_
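A minimal usage sketch for the macros above (illustrative only, not part of this change): the TFLITE_DCHECK_* variants reduce to assert(), so they do nothing on failure in NDEBUG builds (the condition is still evaluated), while the TFLITE_CHECK_* variants call abort() unconditionally. The helper below is hypothetical:

#include <cstddef>

inline float MeanOf(const float* values, std::size_t count) {
  TFLITE_DCHECK(values != nullptr);  // Debug-only invariant.
  TFLITE_CHECK_GT(count, 0);         // Hard failure even in release builds.
  float sum = 0.f;
  for (std::size_t i = 0; i < count; ++i) sum += values[i];
  return sum / static_cast<float>(count);
}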
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/cpu_check.h b/tensorflow/contrib/lite/kernels/internal/optimized/cpu_check.h
new file mode 100644
index 0000000000..dea46cc120
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/cpu_check.h
@@ -0,0 +1,65 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_CPU_CHECK_
+#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_CPU_CHECK_
+
+namespace tflite {
+
+#ifdef __ANDROID__
+#include "ndk/sources/android/cpufeatures/cpu-features.h"
+
+// Runtime check for Neon support on Android.
+inline bool TestCPUFeatureNeon() {
+#ifdef __aarch64__
+ // ARM-64 always has NEON support.
+ return true;
+#else
+ static bool kUseAndroidNeon =
+ (android_getCpuFamily() == ANDROID_CPU_FAMILY_ARM &&
+ android_getCpuFeatures() & ANDROID_CPU_ARM_FEATURE_ARMv7 &&
+ android_getCpuFeatures() & ANDROID_CPU_ARM_FEATURE_NEON);
+ return kUseAndroidNeon;
+#endif // __aarch64__
+}
+
+#elif __ARM_NEON
+
+inline bool TestCPUFeatureNeon() {
+ return true;
+}
+
+#else
+
+inline bool TestCPUFeatureNeon() {
+ return false;
+}
+
+#endif
+
+} // namespace tflite
+
+// NEON_OR_PORTABLE(SomeFunc, args) calls NeonSomeFunc(args) if Neon is both
+// enabled at build time and detected at runtime, or PortableSomeFunc(args)
+// otherwise.
+#ifdef __ARM_ARCH_5TE__
+// Neon isn't available at all on ARMv5.
+#define NEON_OR_PORTABLE(funcname, ...) Portable##funcname(__VA_ARGS__)
+#else
+#define NEON_OR_PORTABLE(funcname, ...) \
+ TestCPUFeatureNeon() ? Neon##funcname(__VA_ARGS__) \
+ : Portable##funcname(__VA_ARGS__)
+#endif
+
+#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_CPU_CHECK_
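A usage sketch for NEON_OR_PORTABLE (the function names are hypothetical, not from this commit): a kernel that exists in both a Neon and a portable flavor is dispatched through the macro, which requires the two variants to share a signature and return type.

namespace tflite {
// Hypothetical pair of implementations plus a dispatching wrapper.
void PortableAddBias(const float* bias, float* out, int n);
void NeonAddBias(const float* bias, float* out, int n);

inline void AddBias(const float* bias, float* out, int n) {
  // Expands to a runtime Neon check, or to the portable call on ARMv5 builds.
  NEON_OR_PORTABLE(AddBias, bias, out, n);
}
}  // namespace tflite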
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h
new file mode 100644
index 0000000000..974611f52a
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h
@@ -0,0 +1,987 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_FLOAT_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_FLOAT_H_
+
+#include "public/gemmlowp.h"
+#include "tensorflow/contrib/lite/kernels/internal/common.h"
+#include "tensorflow/contrib/lite/kernels/internal/types.h"
+
+namespace tflite {
+namespace optimized_ops {
+
+// Implementation of float DepthwiseConv
+
+template <bool kAllowStrided, int kFixedInputDepth, int kFixedDepthMultiplier>
+struct FloatDepthwiseConvKernel {};
+
+#ifdef USE_NEON
+
+template <>
+struct FloatDepthwiseConvKernel<false, 8, 1> {
+ static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
+ const float* input_ptr, int input_ptr_increment,
+ const float* filter_ptr, float* acc_buffer_ptr) {
+ // Load the filters
+ float32x4_t filter[2];
+ for (int i = 0; i < 2; i++) {
+ filter[i] = vld1q_f32(filter_ptr + 4 * i);
+ }
+ int outp = 0;
+ // Handle 2 output pixels at a time.
+ for (; outp <= num_output_pixels - 2; outp += 2) {
+ // Load the inputs
+ float32x4_t input[4];
+ for (int i = 0; i < 4; i++) {
+ input[i] = vld1q_f32(input_ptr + 4 * i);
+ }
+ input_ptr += 16;
+ // Load the accumulators from acc_buffer
+ float32x4_t acc[4];
+ for (int i = 0; i < 4; i++) {
+ acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i);
+ }
+ // Multiply-accumulate
+ acc[0] = vmlaq_f32(acc[0], input[0], filter[0]);
+ acc[1] = vmlaq_f32(acc[1], input[1], filter[1]);
+ acc[2] = vmlaq_f32(acc[2], input[2], filter[0]);
+ acc[3] = vmlaq_f32(acc[3], input[3], filter[1]);
+ // Store the accumulators back to acc_buffer
+ for (int i = 0; i < 4; i++) {
+ vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]);
+ }
+ acc_buffer_ptr += 16;
+ }
+ // Handle one output pixel at a time.
+ for (; outp < num_output_pixels; outp++) {
+ // Load the inputs
+ float32x4_t input[2];
+ for (int i = 0; i < 2; i++) {
+ input[i] = vld1q_f32(input_ptr + 4 * i);
+ }
+ input_ptr += 8;
+ // Load the accumulators from acc_buffer
+ float32x4_t acc[2];
+ for (int i = 0; i < 2; i++) {
+ acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i);
+ }
+ // Multiply-accumulate
+ for (int i = 0; i < 2; i++) {
+ acc[i] = vmlaq_f32(acc[i], input[i], filter[i]);
+ }
+ // Store the accumulators back to acc_buffer
+ for (int i = 0; i < 2; i++) {
+ vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]);
+ }
+ acc_buffer_ptr += 8;
+ }
+ }
+};
+
+template <>
+struct FloatDepthwiseConvKernel<false, 2, 1> {
+ static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
+ const float* input_ptr, int input_ptr_increment,
+ const float* filter_ptr, float* acc_buffer_ptr) {
+ const float32x2_t filters = vld1_f32(filter_ptr);
+ const float32x4_t filters_dup2 = vcombine_f32(filters, filters);
+ int outp = 0;
+ // Handle 8 output pixels at a time.
+ for (; outp <= num_output_pixels - 8; outp += 8) {
+ // Load the inputs
+ float32x4_t input[4];
+ for (int i = 0; i < 4; i++) {
+ input[i] = vld1q_f32(input_ptr + 4 * i);
+ }
+ input_ptr += 16;
+ // Load the accumulators from acc_buffer
+ float32x4_t acc[4];
+ for (int i = 0; i < 4; i++) {
+ acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i);
+ }
+ // Multiply-accumulate
+ for (int i = 0; i < 4; i++) {
+ acc[i] = vmlaq_f32(acc[i], input[i], filters_dup2);
+ }
+ // Store the accumulators back to acc_buffer
+ for (int i = 0; i < 4; i++) {
+ vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]);
+ }
+ acc_buffer_ptr += 16;
+ }
+ // Handle 4 output pixels at a time.
+ for (; outp <= num_output_pixels - 4; outp += 4) {
+ // Load the inputs
+ float32x4_t input[2];
+ for (int i = 0; i < 2; i++) {
+ input[i] = vld1q_f32(input_ptr + 4 * i);
+ }
+ input_ptr += 8;
+ // Load the accumulators from acc_buffer
+ float32x4_t acc[2];
+ for (int i = 0; i < 2; i++) {
+ acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i);
+ }
+ // Multiply-accumulate
+ for (int i = 0; i < 2; i++) {
+ acc[i] = vmlaq_f32(acc[i], input[i], filters_dup2);
+ }
+ // Store the accumulators back to acc_buffer
+ for (int i = 0; i < 2; i++) {
+ vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]);
+ }
+ acc_buffer_ptr += 8;
+ }
+ // Handle 2 output pixels at a time.
+ for (; outp <= num_output_pixels - 2; outp += 2) {
+ // Load the inputs
+ const float32x4_t input = vld1q_f32(input_ptr);
+ input_ptr += 4;
+ // Load the accumulators from acc_buffer
+ float32x4_t acc = vld1q_f32(acc_buffer_ptr);
+ // Multiply-accumulate
+ acc = vmlaq_f32(acc, input, filters_dup2);
+ // Store the accumulators back to acc_buffer
+ vst1q_f32(acc_buffer_ptr, acc);
+ acc_buffer_ptr += 4;
+ }
+ // Handle 1 output pixel at a time
+ for (; outp < num_output_pixels; outp++) {
+ // Load the inputs
+ const float32x2_t input = vld1_f32(input_ptr);
+ input_ptr += 2;
+ // Load the accumulators from acc_buffer
+ float32x2_t acc = vld1_f32(acc_buffer_ptr);
+ // Multiply-accumulate
+ acc = vmla_f32(acc, input, filters);
+ // Store the accumulators back to acc_buffer
+ vst1_f32(acc_buffer_ptr, acc);
+ acc_buffer_ptr += 2;
+ }
+ }
+};
+
+template <>
+struct FloatDepthwiseConvKernel<true, 0, 1> {
+ static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
+ const float* input_ptr, int input_ptr_increment,
+ const float* filter_ptr, float* acc_buffer_ptr) {
+ // Handle one output pixel at a time.
+ for (int outp = 0; outp < num_output_pixels; outp++) {
+ const float* local_filter_ptr = filter_ptr;
+ const float* local_input_ptr = input_ptr;
+ int ic = 0;
+ // Handle 16 input channels at a time.
+ for (; ic <= input_depth - 16; ic += 16) {
+ // Load the filters
+ float32x4_t filter_0 = vld1q_f32(local_filter_ptr + 4 * 0);
+ float32x4_t filter_1 = vld1q_f32(local_filter_ptr + 4 * 1);
+ float32x4_t filter_2 = vld1q_f32(local_filter_ptr + 4 * 2);
+ float32x4_t filter_3 = vld1q_f32(local_filter_ptr + 4 * 3);
+ local_filter_ptr += 16;
+ // Load the inputs
+ float32x4_t input_0 = vld1q_f32(local_input_ptr + 4 * 0);
+ float32x4_t input_1 = vld1q_f32(local_input_ptr + 4 * 1);
+ float32x4_t input_2 = vld1q_f32(local_input_ptr + 4 * 2);
+ float32x4_t input_3 = vld1q_f32(local_input_ptr + 4 * 3);
+ local_input_ptr += 16;
+ // Load the accumulators from acc_buffer
+ float32x4_t acc_0 = vld1q_f32(acc_buffer_ptr + 4 * 0);
+ float32x4_t acc_1 = vld1q_f32(acc_buffer_ptr + 4 * 1);
+ float32x4_t acc_2 = vld1q_f32(acc_buffer_ptr + 4 * 2);
+ float32x4_t acc_3 = vld1q_f32(acc_buffer_ptr + 4 * 3);
+ // Multiply-accumulate
+ acc_0 = vmlaq_f32(acc_0, input_0, filter_0);
+ acc_1 = vmlaq_f32(acc_1, input_1, filter_1);
+ acc_2 = vmlaq_f32(acc_2, input_2, filter_2);
+ acc_3 = vmlaq_f32(acc_3, input_3, filter_3);
+ // Store the accumulators back to acc_buffer
+ vst1q_f32(acc_buffer_ptr + 4 * 0, acc_0);
+ vst1q_f32(acc_buffer_ptr + 4 * 1, acc_1);
+ vst1q_f32(acc_buffer_ptr + 4 * 2, acc_2);
+ vst1q_f32(acc_buffer_ptr + 4 * 3, acc_3);
+ acc_buffer_ptr += 16;
+ }
+ // Handle 4 input channels at a time.
+ for (; ic <= input_depth - 4; ic += 4) {
+ // Load the filters
+ float32x4_t filter;
+ filter = vld1q_f32(local_filter_ptr);
+ local_filter_ptr += 4;
+ // Load the inputs
+ float32x4_t input;
+ input = vld1q_f32(local_input_ptr);
+ local_input_ptr += 4;
+ // Load the accumulators from acc_buffer
+ float32x4_t acc;
+ acc = vld1q_f32(acc_buffer_ptr);
+ // Multiply-accumulate
+ acc = vmlaq_f32(acc, input, filter);
+ // Store the accumulators back to acc_buffer
+ vst1q_f32(acc_buffer_ptr, acc);
+ acc_buffer_ptr += 4;
+ }
+ // Handle one input channel at a time.
+ for (; ic < input_depth; ic++) {
+ const float input_val = *local_input_ptr++;
+ const float filter_val = *local_filter_ptr++;
+ *acc_buffer_ptr++ += filter_val * input_val;
+ }
+ input_ptr += input_ptr_increment;
+ }
+ }
+};
+
+template <>
+struct FloatDepthwiseConvKernel<true, 0, 8> {
+ static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
+ const float* input_ptr, int input_ptr_increment,
+ const float* filter_ptr, float* acc_buffer_ptr) {
+ // Handle one output pixel at a time.
+ for (int outp = 0; outp < num_output_pixels; outp++) {
+ const float* local_filter_ptr = filter_ptr;
+ const float* local_input_ptr = input_ptr;
+ int ic = 0;
+ // Handle 2 input channels at a time.
+ for (; ic <= input_depth - 2; ic += 2) {
+ // Load the filters
+ float32x4_t filter[4];
+ for (int i = 0; i < 4; i++) {
+ filter[i] = vld1q_f32(local_filter_ptr + 4 * i);
+ }
+ local_filter_ptr += 16;
+ // Load the inputs
+ const float32x2_t input = vld1_f32(local_input_ptr);
+ local_input_ptr += 2;
+ // Load the accumulators from acc_buffer
+ float32x4_t acc[4];
+ for (int i = 0; i < 4; i++) {
+ acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i);
+ }
+ // Multiply-accumulate
+ acc[0] = vmlaq_lane_f32(acc[0], filter[0], input, 0);
+ acc[1] = vmlaq_lane_f32(acc[1], filter[1], input, 0);
+ acc[2] = vmlaq_lane_f32(acc[2], filter[2], input, 1);
+ acc[3] = vmlaq_lane_f32(acc[3], filter[3], input, 1);
+ // Store the accumulators back to acc_buffer
+ for (int i = 0; i < 4; i++) {
+ vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]);
+ }
+ acc_buffer_ptr += 16;
+ }
+ // Handle one input channel at a time.
+ for (; ic < input_depth; ic++) {
+ // Load the filters
+ float32x4_t filter[2];
+ for (int i = 0; i < 2; i++) {
+ filter[i] = vld1q_f32(local_filter_ptr + 4 * i);
+ }
+ local_filter_ptr += 8;
+ // Load the inputs
+ const float input_val = *local_input_ptr++;
+ // Load the accumulators from acc_buffer
+ float32x4_t acc[2];
+ for (int i = 0; i < 2; i++) {
+ acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i);
+ }
+ // Multiply-accumulate
+ for (int i = 0; i < 2; i++) {
+ acc[i] = vmlaq_n_f32(acc[i], filter[i], input_val);
+ }
+ // Store the accumulators back to acc_buffer
+ for (int i = 0; i < 2; i++) {
+ vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]);
+ }
+ acc_buffer_ptr += 8;
+ }
+ input_ptr += input_ptr_increment;
+ }
+ }
+};
+
+template <>
+struct FloatDepthwiseConvKernel<true, 0, 2> {
+ static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
+ const float* input_ptr, int input_ptr_increment,
+ const float* filter_ptr, float* acc_buffer_ptr) {
+ // Handle one output pixel at a time.
+ for (int outp = 0; outp < num_output_pixels; outp++) {
+ const float* local_filter_ptr = filter_ptr;
+ const float* local_input_ptr = input_ptr;
+ int ic = 0;
+ // Handle 8 input channels at a time.
+ for (; ic <= input_depth - 8; ic += 8) {
+ // Load the filters
+ float32x4_t filter[4];
+ for (int i = 0; i < 4; i++) {
+ filter[i] = vld1q_f32(local_filter_ptr + 4 * i);
+ }
+ local_filter_ptr += 16;
+ // Load the inputs
+ float32x4x2_t input_dup2[2];
+ for (int i = 0; i < 2; i++) {
+ const float32x4_t input = vld1q_f32(local_input_ptr + 4 * i);
+ input_dup2[i] = vzipq_f32(input, input);
+ }
+ local_input_ptr += 8;
+ // Load the accumulators from acc_buffer
+ float32x4_t acc[4];
+ for (int i = 0; i < 4; i++) {
+ acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i);
+ }
+ // Multiply-accumulate
+ acc[0] = vmlaq_f32(acc[0], filter[0], input_dup2[0].val[0]);
+ acc[1] = vmlaq_f32(acc[1], filter[1], input_dup2[0].val[1]);
+ acc[2] = vmlaq_f32(acc[2], filter[2], input_dup2[1].val[0]);
+ acc[3] = vmlaq_f32(acc[3], filter[3], input_dup2[1].val[1]);
+ // Store the accumulators back to acc_buffer
+ for (int i = 0; i < 4; i++) {
+ vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]);
+ }
+ acc_buffer_ptr += 16;
+ }
+ // Handle 4 input channels at a time.
+ for (; ic <= input_depth - 4; ic += 4) {
+ // Load the filters
+ float32x2_t filter[4];
+ for (int i = 0; i < 4; i++) {
+ filter[i] = vld1_f32(local_filter_ptr + 2 * i);
+ }
+ local_filter_ptr += 8;
+ // Load the inputs
+ const float32x4_t input = vld1q_f32(local_input_ptr);
+ local_input_ptr += 4;
+ // Load the accumulators from acc_buffer
+ float32x2_t acc[4];
+ for (int i = 0; i < 4; i++) {
+ acc[i] = vld1_f32(acc_buffer_ptr + 2 * i);
+ }
+ // Multiply-accumulate
+ acc[0] = vmla_lane_f32(acc[0], filter[0], vget_low_f32(input), 0);
+ acc[1] = vmla_lane_f32(acc[1], filter[1], vget_low_f32(input), 1);
+ acc[2] = vmla_lane_f32(acc[2], filter[2], vget_high_f32(input), 0);
+ acc[3] = vmla_lane_f32(acc[3], filter[3], vget_high_f32(input), 1);
+ // Store the accumulators back to acc_buffer
+ for (int i = 0; i < 4; i++) {
+ vst1_f32(acc_buffer_ptr + 2 * i, acc[i]);
+ }
+ acc_buffer_ptr += 8;
+ }
+ // Handle 2 input channels at a time.
+ for (; ic <= input_depth - 2; ic += 2) {
+ // Load the filters
+ const float32x4_t filter = vld1q_f32(local_filter_ptr);
+ local_filter_ptr += 4;
+ // Load the inputs
+ const float32x2_t input = vld1_f32(local_input_ptr);
+ local_input_ptr += 2;
+ // Load the accumulators from acc_buffer
+ float32x2_t acc[2];
+ for (int i = 0; i < 2; i++) {
+ acc[i] = vld1_f32(acc_buffer_ptr + 2 * i);
+ }
+ // Multiply-accumulate
+ acc[0] = vmla_lane_f32(acc[0], vget_low_f32(filter), input, 0);
+ acc[1] = vmla_lane_f32(acc[1], vget_high_f32(filter), input, 1);
+ // Store the accumulators back to acc_buffer
+ for (int i = 0; i < 2; i++) {
+ vst1_f32(acc_buffer_ptr + 2 * i, acc[i]);
+ }
+ acc_buffer_ptr += 4;
+ }
+ // Handle one input channel at a time.
+ for (; ic < input_depth; ic++) {
+ // Load the inputs
+ const float input_val = *local_input_ptr++;
+ // Multiply-accumulate
+ for (int i = 0; i < 2; i++) {
+ acc_buffer_ptr[i] += local_filter_ptr[i] * input_val;
+ }
+ local_filter_ptr += 2;
+ acc_buffer_ptr += 2;
+ }
+ input_ptr += input_ptr_increment;
+ }
+ }
+};
+
+template <>
+struct FloatDepthwiseConvKernel<true, 1, 8> {
+ static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
+ const float* input_ptr, int input_ptr_increment,
+ const float* filter_ptr, float* acc_buffer_ptr) {
+ // Load the filters
+ float32x4_t filter[2];
+ for (int i = 0; i < 2; i++) {
+ filter[i] = vld1q_f32(filter_ptr + 4 * i);
+ }
+ // Handle one output pixel at a time.
+ for (int outp = 0; outp < num_output_pixels; outp++) {
+ // Load the inputs
+ const float input_val = *input_ptr;
+ input_ptr += input_ptr_increment;
+ // Load the accumulators from acc_buffer
+ float32x4_t acc[2];
+ for (int i = 0; i < 2; i++) {
+ acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i);
+ }
+ // Multiply-accumulate
+ for (int i = 0; i < 2; i++) {
+ acc[i] = vmlaq_n_f32(acc[i], filter[i], input_val);
+ }
+ // Store the accumulators back to acc_buffer
+ for (int i = 0; i < 2; i++) {
+ vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]);
+ }
+ acc_buffer_ptr += 8;
+ }
+ }
+};
+
+template <>
+struct FloatDepthwiseConvKernel<true, 1, 32> {
+ static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
+ const float* input_ptr, int input_ptr_increment,
+ const float* filter_ptr, float* acc_buffer_ptr) {
+ // Load the filters
+ float32x4_t filter_0 = vld1q_f32(filter_ptr + 4 * 0);
+ float32x4_t filter_1 = vld1q_f32(filter_ptr + 4 * 1);
+ float32x4_t filter_2 = vld1q_f32(filter_ptr + 4 * 2);
+ float32x4_t filter_3 = vld1q_f32(filter_ptr + 4 * 3);
+ float32x4_t filter_4 = vld1q_f32(filter_ptr + 4 * 4);
+ float32x4_t filter_5 = vld1q_f32(filter_ptr + 4 * 5);
+ float32x4_t filter_6 = vld1q_f32(filter_ptr + 4 * 6);
+ float32x4_t filter_7 = vld1q_f32(filter_ptr + 4 * 7);
+
+ // Handle one output pixel at a time.
+ for (int outp = 0; outp < num_output_pixels; outp++) {
+ // Load the inputs
+ const float input_val = *input_ptr;
+ input_ptr += input_ptr_increment;
+ // Load the accumulators from acc_buffer
+ float32x4_t acc_0 = vld1q_f32(acc_buffer_ptr + 4 * 0);
+ float32x4_t acc_1 = vld1q_f32(acc_buffer_ptr + 4 * 1);
+ float32x4_t acc_2 = vld1q_f32(acc_buffer_ptr + 4 * 2);
+ float32x4_t acc_3 = vld1q_f32(acc_buffer_ptr + 4 * 3);
+ float32x4_t acc_4 = vld1q_f32(acc_buffer_ptr + 4 * 4);
+ float32x4_t acc_5 = vld1q_f32(acc_buffer_ptr + 4 * 5);
+ float32x4_t acc_6 = vld1q_f32(acc_buffer_ptr + 4 * 6);
+ float32x4_t acc_7 = vld1q_f32(acc_buffer_ptr + 4 * 7);
+ // Multiply-accumulate
+ acc_0 = vmlaq_n_f32(acc_0, filter_0, input_val);
+ acc_1 = vmlaq_n_f32(acc_1, filter_1, input_val);
+ acc_2 = vmlaq_n_f32(acc_2, filter_2, input_val);
+ acc_3 = vmlaq_n_f32(acc_3, filter_3, input_val);
+ acc_4 = vmlaq_n_f32(acc_4, filter_4, input_val);
+ acc_5 = vmlaq_n_f32(acc_5, filter_5, input_val);
+ acc_6 = vmlaq_n_f32(acc_6, filter_6, input_val);
+ acc_7 = vmlaq_n_f32(acc_7, filter_7, input_val);
+ // Store the accumulators back to acc_buffer
+ vst1q_f32(acc_buffer_ptr + 4 * 0, acc_0);
+ vst1q_f32(acc_buffer_ptr + 4 * 1, acc_1);
+ vst1q_f32(acc_buffer_ptr + 4 * 2, acc_2);
+ vst1q_f32(acc_buffer_ptr + 4 * 3, acc_3);
+ vst1q_f32(acc_buffer_ptr + 4 * 4, acc_4);
+ vst1q_f32(acc_buffer_ptr + 4 * 5, acc_5);
+ vst1q_f32(acc_buffer_ptr + 4 * 6, acc_6);
+ vst1q_f32(acc_buffer_ptr + 4 * 7, acc_7);
+ acc_buffer_ptr += 32;
+ }
+ }
+};
+
+template <>
+struct FloatDepthwiseConvKernel<true, 0, 16> {
+ static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
+ const float* input_ptr, int input_ptr_increment,
+ const float* filter_ptr, float* acc_buffer_ptr) {
+ // Handle one output pixel at a time.
+ for (int outp = 0; outp < num_output_pixels; outp++) {
+ const float* local_filter_ptr = filter_ptr;
+ const float* local_input_ptr = input_ptr;
+ for (int ic = 0; ic < input_depth; ic++) {
+ // Load the filters
+ float32x4_t filter[4];
+ for (int i = 0; i < 4; i++) {
+ filter[i] = vld1q_f32(local_filter_ptr + 4 * i);
+ }
+ local_filter_ptr += 16;
+ // Load the inputs
+ const float input_val = *local_input_ptr++;
+ // Load the accumulators from acc_buffer
+ float32x4_t acc[4];
+ for (int i = 0; i < 4; i++) {
+ acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i);
+ }
+ // Multiply-accumulate
+ for (int i = 0; i < 4; i++) {
+ acc[i] = vmlaq_n_f32(acc[i], filter[i], input_val);
+ }
+ // Store the accumulators back to acc_buffer
+ for (int i = 0; i < 4; i++) {
+ vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]);
+ }
+ acc_buffer_ptr += 16;
+ }
+ input_ptr += input_ptr_increment;
+ }
+ }
+};
+
+template <>
+struct FloatDepthwiseConvKernel<true, 8, 1> {
+ static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
+ const float* input_ptr, int input_ptr_increment,
+ const float* filter_ptr, float* acc_buffer_ptr) {
+ // Load the filters
+ float32x4_t filter[2];
+ for (int i = 0; i < 2; i++) {
+ filter[i] = vld1q_f32(filter_ptr + 4 * i);
+ }
+ // Handle one output pixel at a time.
+ for (int outp = 0; outp < num_output_pixels; outp++) {
+ // Load the inputs
+ float32x4_t input[2];
+ for (int i = 0; i < 2; i++) {
+ input[i] = vld1q_f32(input_ptr + 4 * i);
+ }
+ // Load the accumulators from acc_buffer
+ float32x4_t acc[2];
+ for (int i = 0; i < 2; i++) {
+ acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i);
+ }
+ // Multiply-accumulate
+ for (int i = 0; i < 2; i++) {
+ acc[i] = vmlaq_f32(acc[i], input[i], filter[i]);
+ }
+ // Store the accumulators back to acc_buffer
+ for (int i = 0; i < 2; i++) {
+ vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]);
+ }
+ acc_buffer_ptr += 8;
+ input_ptr += input_ptr_increment;
+ }
+ }
+};
+
+template <>
+struct FloatDepthwiseConvKernel<true, 2, 1> {
+ static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
+ const float* input_ptr, int input_ptr_increment,
+ const float* filter_ptr, float* acc_buffer_ptr) {
+ float32x2_t filter = vld1_f32(filter_ptr);
+ float32x4_t filter_x4 = vcombine_f32(filter, filter);
+ int outp = 0;
+
+ // Handle two output pixels at a time.
+ for (; outp <= num_output_pixels - 2; outp += 2) {
+ // Load the inputs
+ float32x2_t input_1 = vld1_f32(input_ptr);
+ input_ptr += input_ptr_increment;
+ float32x2_t input_2 = vld1_f32(input_ptr);
+ input_ptr += input_ptr_increment;
+ float32x4_t input = vcombine_f32(input_1, input_2);
+
+ // Load the accumulators from acc_buffer
+ float32x4_t acc = vld1q_f32(acc_buffer_ptr);
+
+ // Multiply-accumulate
+ acc = vmlaq_f32(acc, input, filter_x4);
+
+ // Store the accumulators back to acc_buffer
+ vst1q_f32(acc_buffer_ptr, acc);
+ acc_buffer_ptr += 4;
+ }
+ // Handle one output pixel at a time.
+ for (; outp < num_output_pixels; outp++) {
+ // Load the inputs
+ float32x2_t input = vld1_f32(input_ptr);
+ input_ptr += input_ptr_increment;
+
+ // Load the accumulators from acc_buffer
+ float32x2_t acc = vld1_f32(acc_buffer_ptr);
+
+ // Multiply-accumulate
+ acc = vmla_f32(acc, input, filter);
+
+ // Store the accumulators back to acc_buffer
+ vst1_f32(acc_buffer_ptr, acc);
+ acc_buffer_ptr += 2;
+ }
+ }
+};
+
+template <>
+struct FloatDepthwiseConvKernel<true, 4, 1> {
+ static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
+ const float* input_ptr, int input_ptr_increment,
+ const float* filter_ptr, float* acc_buffer_ptr) {
+ float32x4_t filter = vld1q_f32(filter_ptr);
+
+ // Handle one output pixel at a time.
+ for (int outp = 0; outp < num_output_pixels; outp++) {
+ // Load the inputs
+ float32x4_t input = vld1q_f32(input_ptr);
+ // Load the accumulators from acc_buffer
+ float32x4_t acc = vld1q_f32(acc_buffer_ptr);
+ // Multiply-accumulate
+ acc = vmlaq_f32(acc, input, filter);
+ // Store the accumulators back to acc_buffer
+ vst1q_f32(acc_buffer_ptr, acc);
+ acc_buffer_ptr += 4;
+ input_ptr += input_ptr_increment;
+ }
+ }
+};
+#endif
+
+// Accumulates the effect of one row of the filter on a segment of one row
+// of the output, accessing the corresponding row of the input.
+template <bool kAllowStrided, int kFixedInputDepth, int kFixedDepthMultiplier>
+void FloatDepthwiseConvAccumRow(int stride, int input_depth, int input_width,
+ const float* input_data, int pad_width,
+ int depth_multiplier, int filter_width,
+ const float* filter_data,
+ int out_x_buffer_start, int out_x_buffer_end,
+ int output_depth, float* acc_buffer) {
+#ifdef GEMMLOWP_PROFILING
+ gemmlowp::ScopedProfilingLabel label(__PRETTY_FUNCTION__);
+#endif
+ // Sanity check parameters. This is important in particular to ensure
+ // that we keep the number of template instantiations minimal, so we don't
+ // increase binary size unnecessarily.
+ static_assert(kFixedDepthMultiplier || !kFixedInputDepth, "");
+ static_assert(kFixedInputDepth || kAllowStrided, "");
+ TFLITE_DCHECK(stride == 1 || kAllowStrided);
+ if (kFixedInputDepth) {
+ TFLITE_DCHECK_EQ(input_depth, kFixedInputDepth);
+ }
+ if (kFixedDepthMultiplier) {
+ TFLITE_DCHECK_EQ(depth_multiplier, kFixedDepthMultiplier);
+ }
+ TFLITE_DCHECK_EQ(output_depth, input_depth * depth_multiplier);
+ const int input_ptr_increment = stride * input_depth;
+ const float* filter_base_ptr = filter_data;
+ for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
+ // For the current (filter_x, filter_y) point in the filter,
+ // compute the boundaries of the corresponding output row segment.
+    int out_x_loop_start_unclamped = 0;
+    int out_x_loop_end_unclamped = 0;
+    if (kAllowStrided) {
+      if (stride == 2) {
+        out_x_loop_start_unclamped = (pad_width - filter_x + 1) / 2;
+        out_x_loop_end_unclamped =
+            (pad_width + input_width - filter_x + 1) / 2;
+      } else if (stride == 4) {
+        out_x_loop_start_unclamped = (pad_width - filter_x + 3) / 4;
+        out_x_loop_end_unclamped =
+            (pad_width + input_width - filter_x + 3) / 4;
+      } else {
+        out_x_loop_start_unclamped =
+            (pad_width - filter_x + stride - 1) / stride;
+        out_x_loop_end_unclamped =
+            (pad_width + input_width - filter_x + stride - 1) / stride;
+      }
+    } else {
+      out_x_loop_start_unclamped = pad_width - filter_x;
+      out_x_loop_end_unclamped = pad_width + input_width - filter_x;
+    }
+    // The kernel will have to iterate on the segment of the
+    // output row that starts at out_x_loop_start and ends at out_x_loop_end.
+    const int out_x_loop_start =
+        std::max(out_x_buffer_start, out_x_loop_start_unclamped);
+    const int out_x_loop_end =
+        std::min(out_x_buffer_end, out_x_loop_end_unclamped);
+
+ float* acc_buffer_ptr =
+ acc_buffer + (out_x_loop_start - out_x_buffer_start) * output_depth;
+ const int in_x_origin = (out_x_loop_start * stride) - pad_width + filter_x;
+ const float* input_ptr = input_data + in_x_origin * input_depth;
+ const int num_output_pixels = out_x_loop_end - out_x_loop_start;
+ FloatDepthwiseConvKernel<kAllowStrided, kFixedInputDepth,
+ kFixedDepthMultiplier>::Run(num_output_pixels,
+ input_depth,
+ depth_multiplier,
+ input_ptr,
+ input_ptr_increment,
+ filter_base_ptr,
+ acc_buffer_ptr);
+ filter_base_ptr += output_depth;
+ }
+}
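The loop bounds above come from keeping the sampled input column in range: for output column out_x, the input column is in_x = out_x * stride - pad_width + filter_x, and requiring 0 <= in_x < input_width yields out_x >= (pad_width - filter_x) / stride and out_x < (pad_width + input_width - filter_x) / stride, both rounded up. A small sketch of the general-stride branch (illustrative, not part of this change):

// Hypothetical helper mirroring the general-stride expressions above. The
// (a + stride - 1) / stride form is an exact ceiling for non-negative
// numerators; negative cases are covered by the later clamp to
// out_x_buffer_start.
inline void UnclampedOutXRange(int stride, int pad_width, int filter_x,
                               int input_width, int* out_x_start,
                               int* out_x_end) {
  *out_x_start = (pad_width - filter_x + stride - 1) / stride;
  *out_x_end = (pad_width + input_width - filter_x + stride - 1) / stride;
}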
+
+// Generic fallback of FloatDepthwiseConvAccumRow: portable, non-templatized.
+inline void FloatDepthwiseConvAccumRowGeneric(
+ int stride, int input_depth, int input_width, const float* input_data,
+ int pad_width, int depth_multiplier, int filter_width,
+ const float* filter_data, int out_x_buffer_start, int out_x_buffer_end,
+ int output_depth, float* acc_buffer) {
+ gemmlowp::ScopedProfilingLabel label("DepthwiseConvAccumRowGeneric (slow)");
+#ifdef TFLITE_PREVENT_SLOW_GENERIC_DEPTHWISECONV_FALLBACK
+#ifndef ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK
+ LOG(FATAL)
+ << "\n\n"
+ << "*****************************************************************\n"
+ << "* This tfmini inference code was about to use the slow generic\n"
+ << "* fallback implementation for a DepthwiseConv op, and we want you\n"
+ << "* to be aware of that so that you will know why you get terrible\n"
+ << "* performance.\n"
+ << "*\n"
+ << "* If you would like to carry on with the slow code, compile\n"
+ << "* with this preprocessor token defined:\n"
+ << "* ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK.\n"
+ << "*\n"
+ << "* The right thing to do, if you care about performance, is to add\n"
+ << "* a new DepthwiseConv kernel to tfmini to cover your case.\n"
+ << "* The relevant parameters defining your case are:\n"
+ << "* stride = " << stride << "\n"
+ << "* input_depth = " << input_depth << "\n"
+ << "* depth_multiplier = " << depth_multiplier << "\n"
+ << "*\n"
+ << "* Please do not hesitate to contact benoitjacob@ with this\n"
+ << "* information.\n"
+ << "*****************************************************************\n";
+#endif // ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK
+#endif // TFLITE_PREVENT_SLOW_GENERIC_DEPTHWISECONV_FALLBACK
+ const float* filter_base_ptr = filter_data;
+ for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
+ const int out_x_loop_start = std::max(
+ out_x_buffer_start, (pad_width - filter_x + stride - 1) / stride);
+ const int out_x_loop_end =
+ std::min(out_x_buffer_end,
+ (pad_width + input_width - filter_x + stride - 1) / stride);
+
+ float* acc_buffer_ptr =
+ acc_buffer + (out_x_loop_start - out_x_buffer_start) * output_depth;
+ const int in_x_origin = (out_x_loop_start * stride) - pad_width + filter_x;
+ const float* input_ptr = input_data + in_x_origin * input_depth;
+ const int input_ptr_increment = (stride - 1) * input_depth;
+ for (int out_x = out_x_loop_start; out_x < out_x_loop_end; out_x++) {
+ const float* filter_ptr = filter_base_ptr;
+ for (int ic = 0; ic < input_depth; ++ic) {
+ const float input_val = *input_ptr++;
+ for (int m = 0; m < depth_multiplier; m++) {
+ const float filter_val = *filter_ptr++;
+ *acc_buffer_ptr++ += filter_val * input_val;
+ }
+ }
+ input_ptr += input_ptr_increment;
+ }
+ filter_base_ptr += output_depth;
+ }
+}
+
+// Initializes the accumulator buffer with bias values.
+inline void DepthwiseConvInitAccBuffer(int num_output_pixels, int output_depth,
+ const float* bias_data,
+ float* acc_buffer) {
+ // TODO(benoitjacob): This might need optimized specializations
+ // for small output_depth values, if that ever becomes an important
+ // case (like it was for some quantized DepthwiseConv cases).
+ for (int i = 0; i < num_output_pixels; i++) {
+ memcpy(acc_buffer + i * output_depth, bias_data,
+ sizeof(acc_buffer[0]) * output_depth);
+ }
+}
+
+inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
+ const float* filter_data, const Dims<4>& filter_dims,
+ const float* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int depth_multiplier,
+ float output_activation_min,
+ float output_activation_max, float* output_data,
+ const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("DepthwiseConv");
+ const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
+ const int output_depth = MatchingArraySize(filter_dims, 0, output_dims, 0);
+ const int input_height = ArraySize(input_dims, 2);
+ const int input_width = ArraySize(input_dims, 1);
+ const int input_depth = ArraySize(input_dims, 0);
+ const int filter_height = ArraySize(filter_dims, 2);
+ const int filter_width = ArraySize(filter_dims, 1);
+ const int output_height = ArraySize(output_dims, 2);
+ const int output_width = ArraySize(output_dims, 1);
+ TFLITE_DCHECK(output_depth == input_depth * depth_multiplier);
+
+ static const int kAccBufferMaxSize = 2048;
+ float acc_buffer[kAccBufferMaxSize];
+ TFLITE_DCHECK_GE(kAccBufferMaxSize, output_depth);
+ const int kOutputPixelsInAccBuffer = kAccBufferMaxSize / output_depth;
+ const int kAccBufferActualSize = kOutputPixelsInAccBuffer * output_depth;
+ TFLITE_DCHECK_LE(kOutputPixelsInAccBuffer * output_depth,
+ kAccBufferActualSize);
+ TFLITE_DCHECK_LE(kAccBufferActualSize, kAccBufferMaxSize);
+ TFLITE_DCHECK_GE(kOutputPixelsInAccBuffer, 1);
+
+ // row_accum_func will point to the core accumulation function to be used
+ // for this DepthwiseConv op.
+ using row_accum_func_t = decltype(&FloatDepthwiseConvAccumRowGeneric);
+ row_accum_func_t row_accum_func = nullptr;
+
+#define TFMINI_USE_DEPTHWISECONV_KERNEL(ALLOW_STRIDED, FIXED_INPUT_DEPTH, \
+ FIXED_DEPTH_MULTIPLIER) \
+ if (!row_accum_func && (stride_width == 1 || ALLOW_STRIDED) && \
+ (input_depth == FIXED_INPUT_DEPTH || FIXED_INPUT_DEPTH == 0) && \
+ depth_multiplier == FIXED_DEPTH_MULTIPLIER) { \
+ row_accum_func = \
+ FloatDepthwiseConvAccumRow<ALLOW_STRIDED, FIXED_INPUT_DEPTH, \
+ FIXED_DEPTH_MULTIPLIER>; \
+ }
+
+#ifdef USE_NEON
+ // We go over our list of kernels by decreasing order of preference
+ // for the cases where multiple kernels could apply.
+
+ // Start with the fastest kernels: AllowStrided=false, fixed input depth.
+
+ TFMINI_USE_DEPTHWISECONV_KERNEL(false, 8, 1)
+ TFMINI_USE_DEPTHWISECONV_KERNEL(false, 2, 1)
+
+ // Next come the strided kernels: AllowStrided=true, fixed input depth.
+ // They are a bit less efficient, but allow stride!=1.
+
+ TFMINI_USE_DEPTHWISECONV_KERNEL(true, 8, 1)
+ TFMINI_USE_DEPTHWISECONV_KERNEL(true, 1, 8)
+ TFMINI_USE_DEPTHWISECONV_KERNEL(true, 1, 32)
+ TFMINI_USE_DEPTHWISECONV_KERNEL(true, 2, 1)
+ TFMINI_USE_DEPTHWISECONV_KERNEL(true, 4, 1)
+
+  // Finally, the kernels allowing a variable input depth;
+  // these are the least efficient but most general kernels.
+
+ TFMINI_USE_DEPTHWISECONV_KERNEL(true, 0, 1)
+ TFMINI_USE_DEPTHWISECONV_KERNEL(true, 0, 2)
+ TFMINI_USE_DEPTHWISECONV_KERNEL(true, 0, 8)
+ TFMINI_USE_DEPTHWISECONV_KERNEL(true, 0, 16)
+
+#endif // USE_NEON
+
+#undef TFMINI_USE_DEPTHWISECONV_KERNEL
+
+ // No matching fast kernel found, use slow fallback.
+ if (!row_accum_func) {
+ row_accum_func = FloatDepthwiseConvAccumRowGeneric;
+ }
+
+ // Now that we have determined row_accum_func, we can start work.
+ float* output_ptr = output_data;
+ for (int b = 0; b < batches; ++b) {
+ for (int out_y = 0; out_y < output_height; ++out_y) {
+ const int in_y_origin = (out_y * stride_height) - pad_height;
+ const int filter_y_start = std::max(0, -in_y_origin);
+ const int filter_y_end =
+ std::min(filter_height, input_height - in_y_origin);
+ for (int out_x_buffer_start = 0; out_x_buffer_start < output_width;
+ out_x_buffer_start += kOutputPixelsInAccBuffer) {
+ const int out_x_buffer_end = std::min(
+ output_width, out_x_buffer_start + kOutputPixelsInAccBuffer);
+        // We call a 'pixel' a group of activations that share all but the
+ // 'depth'/'channel' coordinate. num_output_pixels is the number of
+ // output pixels that we will accumulate in this loop iteration.
+ const int num_output_pixels = out_x_buffer_end - out_x_buffer_start;
+ // Initialize our local accumulator with the bias values, so we don't
+ // have to add them later.
+ DepthwiseConvInitAccBuffer(num_output_pixels, output_depth, bias_data,
+ acc_buffer);
+ // Accumulation loop. Most of the time should be spent in here.
+ for (int filter_y = filter_y_start; filter_y < filter_y_end;
+ ++filter_y) {
+ const int in_y = in_y_origin + filter_y;
+ row_accum_func(stride_width, input_depth, input_width,
+ input_data + in_y * input_dims.strides[2] +
+ b * input_dims.strides[3],
+ pad_width, depth_multiplier, filter_width,
+ filter_data + filter_y * filter_dims.strides[2],
+ out_x_buffer_start, out_x_buffer_end, output_depth,
+ acc_buffer);
+ }
+ // Finished accumulating. Now store to destination.
+ const int num_output_values = output_depth * num_output_pixels;
+ int i = 0;
+// TODO(benoitjacob) optimized code goes here
+#ifdef USE_NEON
+ // Handle 16 values at a time
+ for (; i <= num_output_values - 16; i += 16) {
+ float32x4_t acc[4];
+ for (int k = 0; k < 4; k++) {
+ acc[k] = vld1q_f32(acc_buffer + i + 4 * k);
+ }
+ for (int k = 0; k < 4; k++) {
+ acc[k] = vmaxq_f32(
+ vdupq_n_f32(output_activation_min),
+ vminq_f32(vdupq_n_f32(output_activation_max), acc[k]));
+ }
+ for (int k = 0; k < 4; k++) {
+ vst1q_f32(output_ptr + 4 * k, acc[k]);
+ }
+ output_ptr += 16;
+ }
+ // Handle 4 values at a time
+ for (; i <= num_output_values - 4; i += 4) {
+ float32x4_t acc = vld1q_f32(acc_buffer + i);
+
+ acc = vmaxq_f32(vdupq_n_f32(output_activation_min),
+ vminq_f32(vdupq_n_f32(output_activation_max), acc));
+
+ vst1q_f32(output_ptr, acc);
+ output_ptr += 4;
+ }
+#endif
+ // Handle leftover values, one by one. This is very slow.
+ for (; i < num_output_values; i++) {
+ float acc = acc_buffer[i];
+ acc = std::max(output_activation_min,
+ std::min(output_activation_max, acc));
+
+ *output_ptr++ = acc;
+ }
+ }
+ }
+ }
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
+ const float* filter_data, const Dims<4>& filter_dims,
+ const float* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int depth_multiplier, float* output_data,
+ const Dims<4>& output_dims) {
+ float output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+ DepthwiseConv(input_data, input_dims, filter_data, filter_dims, bias_data,
+ bias_dims, stride_width, stride_height, pad_width, pad_height,
+ depth_multiplier, output_activation_min, output_activation_max,
+ output_data, output_dims);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
+ const float* filter_data, const Dims<4>& filter_dims,
+ const float* bias_data, const Dims<4>& bias_dims, int stride,
+ int pad_width, int pad_height, int depth_multiplier,
+ float* output_data, const Dims<4>& output_dims) {
+ DepthwiseConv<Ac>(input_data, input_dims, filter_data, filter_dims, bias_data,
+ bias_dims, stride, stride, pad_width, pad_height,
+ depth_multiplier, output_data, output_dims);
+}
+
+} // namespace optimized_ops
+} // namespace tflite
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_FLOAT_H_
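For reference, the per-pixel accumulation that every specialized kernel above implements matches the scalar inner loop of FloatDepthwiseConvAccumRowGeneric: each of the input_depth channels fans out to depth_multiplier consecutive output channels. An illustrative scalar form (not part of this change):

// Hypothetical scalar reference for one output pixel.
inline void DepthwiseAccumulateOnePixel(int input_depth, int depth_multiplier,
                                        const float* input, const float* filter,
                                        float* acc) {
  for (int ic = 0; ic < input_depth; ++ic) {
    for (int m = 0; m < depth_multiplier; ++m) {
      acc[ic * depth_multiplier + m] +=
          input[ic] * filter[ic * depth_multiplier + m];
    }
  }
}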
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h
new file mode 100644
index 0000000000..051ed2a2c4
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h
@@ -0,0 +1,1916 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_UINT8_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_UINT8_H_
+
+#include "fixedpoint/fixedpoint.h"
+#include "public/gemmlowp.h"
+#include "tensorflow/contrib/lite/kernels/internal/common.h"
+#include "tensorflow/contrib/lite/kernels/internal/types.h"
+
+namespace tflite {
+namespace optimized_ops {
+
+// Implementation of quantized DepthwiseConv
+
+template <bool kAllowStrided, int kFixedInputDepth, int kFixedDepthMultiplier>
+struct QuantizedDepthwiseConvKernel {};
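Each specialization below widens the uint8 inputs and filters to int16, adds input_offset and filter_offset, and accumulates the products into int32. A scalar sketch of that per-element step (illustrative, not part of this change; treating the offsets as negated zero points follows the usual TFLite convention and is an assumption here, not something stated in this file):

// Hypothetical scalar form of the widen, offset, multiply-accumulate step.
inline void QuantizedMacReference(int n, const uint8* input, int16 input_offset,
                                  const uint8* filter, int16 filter_offset,
                                  int32* acc) {
  for (int i = 0; i < n; ++i) {
    acc[i] += (static_cast<int32>(input[i]) + input_offset) *
              (static_cast<int32>(filter[i]) + filter_offset);
  }
}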
+
+#ifdef USE_NEON
+template <>
+struct QuantizedDepthwiseConvKernel<true, 8, 2> {
+ static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
+ const uint8* input_ptr, int16 input_offset,
+ int input_ptr_increment, const uint8* filter_ptr,
+ int16 filter_offset, int32* acc_buffer_ptr) {
+ // Load the filters, add filter_offset.
+ uint8x8x2_t filter_u8;
+ filter_u8.val[0] = vld1_u8(filter_ptr);
+ filter_u8.val[1] = vld1_u8(filter_ptr + 8);
+ int16x8_t filter[2];
+ for (int i = 0; i < 2; i++) {
+ filter[i] = vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(filter_u8.val[i])),
+ vdupq_n_s16(filter_offset));
+ }
+ // Handle one output pixel at a time.
+ for (int outp = 0; outp < num_output_pixels; outp++) {
+ // Load the accumulators from acc_buffer
+ int32x4x2_t acc[2];
+ for (int i = 0; i < 2; i++) {
+ acc[i].val[0] = vld1q_s32(acc_buffer_ptr + 4 * i);
+ acc[i].val[1] = vld1q_s32(acc_buffer_ptr + 4 * i + 8);
+ }
+ // Load the inputs, add input_offset.
+ const uint8x8_t input_u8 = vld1_u8(input_ptr);
+ input_ptr += input_ptr_increment;
+ const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8));
+ const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset));
+ // Duplicate the input values, 2-fold
+ const int16x8x2_t input_dup2 = vzipq_s16(input, input);
+ // Multiply-accumulate
+ for (int i = 0; i < 2; i++) {
+ acc[0].val[i] = vmlal_s16(acc[0].val[i], vget_low_s16(filter[i]),
+ vget_low_s16(input_dup2.val[i]));
+ acc[1].val[i] = vmlal_s16(acc[1].val[i], vget_high_s16(filter[i]),
+ vget_high_s16(input_dup2.val[i]));
+ }
+ // Store the accumulators back to acc_buffer
+ for (int i = 0; i < 2; i++) {
+ vst1q_s32(acc_buffer_ptr + 4 * i, acc[i].val[0]);
+ vst1q_s32(acc_buffer_ptr + 4 * i + 8, acc[i].val[1]);
+ }
+ acc_buffer_ptr += 16;
+ }
+ }
+};
+
+template <>
+struct QuantizedDepthwiseConvKernel<false, 8, 1> {
+ static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
+ const uint8* input_ptr, int16 input_offset,
+ int input_ptr_increment, const uint8* filter_ptr,
+ int16 filter_offset, int32* acc_buffer_ptr) {
+ // Load the filters, add filter_offset.
+ const uint8x8_t filter_u8 = vld1_u8(filter_ptr);
+ const int16x8_t filter_s16 = vreinterpretq_s16_u16(vmovl_u8(filter_u8));
+ const int16x8_t filter = vaddq_s16(filter_s16, vdupq_n_s16(filter_offset));
+
+ int outp = 0;
+ // Handle 2 output pixels at a time.
+ for (; outp <= num_output_pixels - 2; outp += 2) {
+ // Load the accumulators from acc_buffer.
+ int32x4_t acc[4];
+ for (int i = 0; i < 4; i++) {
+ acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
+ }
+ // Load the inputs, add input_offset.
+ uint8x8_t input_u8[2];
+ for (int i = 0; i < 2; i++) {
+ input_u8[i] = vld1_u8(input_ptr + 8 * i);
+ }
+ input_ptr += 16;
+ int16x8_t input[2];
+ for (int i = 0; i < 2; i++) {
+ input[i] = vreinterpretq_s16_u16(vmovl_u8(input_u8[i]));
+ }
+ for (int i = 0; i < 2; i++) {
+ input[i] = vaddq_s16(input[i], vdupq_n_s16(input_offset));
+ }
+ // Multiply-accumulate.
+ acc[0] = vmlal_s16(acc[0], vget_low_s16(filter), vget_low_s16(input[0]));
+ acc[1] =
+ vmlal_s16(acc[1], vget_high_s16(filter), vget_high_s16(input[0]));
+ acc[2] = vmlal_s16(acc[2], vget_low_s16(filter), vget_low_s16(input[1]));
+ acc[3] =
+ vmlal_s16(acc[3], vget_high_s16(filter), vget_high_s16(input[1]));
+ // Store the accumulators back to acc_buffer
+ for (int i = 0; i < 4; i++) {
+ vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
+ }
+ acc_buffer_ptr += 16;
+ }
+ // Handle 1 output pixel at a time.
+ for (; outp < num_output_pixels; outp++) {
+ // Load the accumulators from acc_buffer.
+ int32x4_t acc[2];
+ acc[0] = vld1q_s32(acc_buffer_ptr);
+ acc[1] = vld1q_s32(acc_buffer_ptr + 4);
+
+ // Load the inputs, add input_offset.
+ const uint8x8_t input_u8 = vld1_u8(input_ptr);
+ input_ptr += 8;
+ const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8));
+ const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset));
+ // Multiply-accumulate.
+ acc[0] = vmlal_s16(acc[0], vget_low_s16(filter), vget_low_s16(input));
+ acc[1] = vmlal_s16(acc[1], vget_high_s16(filter), vget_high_s16(input));
+ // Store the accumulators back to acc_buffer
+ vst1q_s32(acc_buffer_ptr, acc[0]);
+ vst1q_s32(acc_buffer_ptr + 4, acc[1]);
+ acc_buffer_ptr += 8;
+ }
+ }
+};
+
+template <>
+struct QuantizedDepthwiseConvKernel<false, 4, 2> {
+ static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
+ const uint8* input_ptr, int16 input_offset,
+ int input_ptr_increment, const uint8* filter_ptr,
+ int16 filter_offset, int32* acc_buffer_ptr) {
+ // Load the filters, add filter_offset.
+ const uint8x8_t filter_u8 = vld1_u8(filter_ptr);
+ const int16x8_t filter_s16 = vreinterpretq_s16_u16(vmovl_u8(filter_u8));
+ const int16x8_t filter = vaddq_s16(filter_s16, vdupq_n_s16(filter_offset));
+
+ int outp = 0;
+ // Handle 2 output pixels at a time.
+ for (; outp <= num_output_pixels - 2; outp += 2) {
+ // Load the accumulators from acc_buffer
+ int32x4_t acc[4];
+ for (int i = 0; i < 4; i++) {
+ acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
+ }
+ // Load the inputs, add input_offset.
+ const uint8x8_t input_u8 = vld1_u8(input_ptr);
+ input_ptr += 8;
+ const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8));
+ const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset));
+ // Duplicate the input values, 2-fold
+ const int16x8x2_t input_dup2 = vzipq_s16(input, input);
+ // Multiply-accumulate
+ for (int i = 0; i < 2; i++) {
+ acc[2 * i + 0] = vmlal_s16(acc[2 * i + 0], vget_low_s16(filter),
+ vget_low_s16(input_dup2.val[i]));
+ acc[2 * i + 1] = vmlal_s16(acc[2 * i + 1], vget_high_s16(filter),
+ vget_high_s16(input_dup2.val[i]));
+ }
+ // Store the accumulators back to acc_buffer
+ for (int i = 0; i < 4; i++) {
+ vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
+ }
+ acc_buffer_ptr += 16;
+ }
+ // Handle one output pixel at a time.
+ for (; outp < num_output_pixels; outp++) {
+ // Load the accumulators from acc_buffer
+ int32x4_t acc[2];
+ for (int i = 0; i < 2; i++) {
+ acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
+ }
+ // Load the inputs, add input_offset.
+ uint8x8_t input_u8 = vdup_n_u8(0);
+ input_u8 = vset_lane_u8(input_ptr[0], input_u8, 0);
+ input_u8 = vset_lane_u8(input_ptr[1], input_u8, 1);
+ input_u8 = vset_lane_u8(input_ptr[2], input_u8, 2);
+ input_u8 = vset_lane_u8(input_ptr[3], input_u8, 3);
+ input_ptr += 4;
+ const int16x4_t input_s16 =
+ vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8)));
+ const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset));
+ // Duplicate the input values, 2-fold
+ const int16x4x2_t input_dup2 = vzip_s16(input, input);
+ // Multiply-accumulate
+ acc[0] = vmlal_s16(acc[0], vget_low_s16(filter), input_dup2.val[0]);
+ acc[1] = vmlal_s16(acc[1], vget_high_s16(filter), input_dup2.val[1]);
+ // Store the accumulators back to acc_buffer
+ for (int i = 0; i < 2; i++) {
+ vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
+ }
+ acc_buffer_ptr += 8;
+ }
+ }
+};
+
+template <>
+struct QuantizedDepthwiseConvKernel<false, 2, 8> {
+ static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
+ const uint8* input_ptr, int16 input_offset,
+ int input_ptr_increment, const uint8* filter_ptr,
+ int16 filter_offset, int32* acc_buffer_ptr) {
+ // Load the filters, add filter_offset.
+ int16x8_t filter[2];
+ for (int i = 0; i < 2; i++) {
+ const uint8x8_t filter_u8 = vld1_u8(filter_ptr + 8 * i);
+ const int16x8_t filter_s16 = vreinterpretq_s16_u16(vmovl_u8(filter_u8));
+ filter[i] = vaddq_s16(filter_s16, vdupq_n_s16(filter_offset));
+ }
+ int outp = 0;
+ // Handle two output pixels at a time.
+ for (; outp <= num_output_pixels - 2; outp += 2) {
+ // Load the accumulators from acc_buffer.
+ int32x4_t acc[8];
+ for (int i = 0; i < 8; i++) {
+ acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
+ }
+ // Load the inputs, add input_offset.
+ uint8x8_t input_u8 = vdup_n_u8(0);
+ input_u8 = vset_lane_u8(input_ptr[0], input_u8, 0);
+ input_u8 = vset_lane_u8(input_ptr[1], input_u8, 1);
+ input_u8 = vset_lane_u8(input_ptr[2], input_u8, 2);
+ input_u8 = vset_lane_u8(input_ptr[3], input_u8, 3);
+ input_ptr += 4;
+ const int16x4_t input_s16 =
+ vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8)));
+ const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset));
+ // Multiply-accumulate.
+ acc[0] = vmlal_lane_s16(acc[0], vget_low_s16(filter[0]), input, 0);
+ acc[1] = vmlal_lane_s16(acc[1], vget_high_s16(filter[0]), input, 0);
+ acc[2] = vmlal_lane_s16(acc[2], vget_low_s16(filter[1]), input, 1);
+ acc[3] = vmlal_lane_s16(acc[3], vget_high_s16(filter[1]), input, 1);
+ acc[4] = vmlal_lane_s16(acc[4], vget_low_s16(filter[0]), input, 2);
+ acc[5] = vmlal_lane_s16(acc[5], vget_high_s16(filter[0]), input, 2);
+ acc[6] = vmlal_lane_s16(acc[6], vget_low_s16(filter[1]), input, 3);
+ acc[7] = vmlal_lane_s16(acc[7], vget_high_s16(filter[1]), input, 3);
+ // Store the accumulators back to acc_buffer.
+ for (int i = 0; i < 8; i++) {
+ vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
+ }
+ acc_buffer_ptr += 32;
+ }
+ // Handle one output pixel at a time.
+ for (; outp < num_output_pixels; outp++) {
+ // Load the accumulators from acc_buffer.
+ int32x4_t acc[4];
+ for (int i = 0; i < 4; i++) {
+ acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
+ }
+ // Load the inputs, add input_offset.
+ uint8x8_t input_u8 = vdup_n_u8(0);
+ input_u8 = vset_lane_u8(input_ptr[0], input_u8, 0);
+ input_u8 = vset_lane_u8(input_ptr[1], input_u8, 1);
+ input_ptr += 2;
+ const int16x4_t input_s16 =
+ vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8)));
+ const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset));
+
+ // Multiply-accumulate.
+ acc[0] = vmlal_lane_s16(acc[0], vget_low_s16(filter[0]), input, 0);
+ acc[1] = vmlal_lane_s16(acc[1], vget_high_s16(filter[0]), input, 0);
+ acc[2] = vmlal_lane_s16(acc[2], vget_low_s16(filter[1]), input, 1);
+ acc[3] = vmlal_lane_s16(acc[3], vget_high_s16(filter[1]), input, 1);
+
+ // Store the accumulators back to acc_buffer.
+ for (int i = 0; i < 4; i++) {
+ vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
+ }
+ acc_buffer_ptr += 16;
+ }
+ }
+};
+
+template <>
+struct QuantizedDepthwiseConvKernel<false, 2, 2> {
+ static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
+ const uint8* input_ptr, int16 input_offset,
+ int input_ptr_increment, const uint8* filter_ptr,
+ int16 filter_offset, int32* acc_buffer_ptr) {
+ // Load the filters, add filter_offset.
+ uint8x8_t filter_u8 = vdup_n_u8(0);
+ filter_u8 = vset_lane_u8(filter_ptr[0], filter_u8, 0);
+ filter_u8 = vset_lane_u8(filter_ptr[1], filter_u8, 1);
+ filter_u8 = vset_lane_u8(filter_ptr[2], filter_u8, 2);
+ filter_u8 = vset_lane_u8(filter_ptr[3], filter_u8, 3);
+ const int16x4_t filter_s16 =
+ vreinterpret_s16_u16(vget_low_u16(vmovl_u8(filter_u8)));
+ const int16x4_t filter = vadd_s16(filter_s16, vdup_n_s16(filter_offset));
+
+ int outp = 0;
+ // Handle 4 output pixels at a time.
+ for (; outp <= num_output_pixels - 4; outp += 4) {
+ // Load the accumulators from acc_buffer
+ int32x4_t acc[4];
+ for (int i = 0; i < 4; i++) {
+ acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
+ }
+
+ // Load the inputs, add input_offset.
+ const uint8x8_t input_u8 = vld1_u8(input_ptr);
+ input_ptr += 8;
+ const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8));
+ const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset));
+ // Duplicate the input values, 2-fold
+ const int16x8x2_t input_dup2 = vzipq_s16(input, input);
+ // Multiply-accumulate
+ acc[0] = vmlal_s16(acc[0], filter, vget_low_s16(input_dup2.val[0]));
+ acc[1] = vmlal_s16(acc[1], filter, vget_high_s16(input_dup2.val[0]));
+ acc[2] = vmlal_s16(acc[2], filter, vget_low_s16(input_dup2.val[1]));
+ acc[3] = vmlal_s16(acc[3], filter, vget_high_s16(input_dup2.val[1]));
+ // Store the accumulators back to acc_buffer
+ for (int i = 0; i < 4; i++) {
+ vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
+ }
+ acc_buffer_ptr += 16;
+ }
+ // Handle one output pixel at a time.
+ for (; outp < num_output_pixels; outp++) {
+ // Load the accumulators from acc_buffer
+ int32x4_t acc = vld1q_s32(acc_buffer_ptr);
+
+ uint8x8_t input_u8 = vdup_n_u8(0);
+ input_u8 = vset_lane_u8(input_ptr[0], input_u8, 0);
+ input_u8 = vset_lane_u8(input_ptr[1], input_u8, 1);
+ input_ptr += 2;
+ const int16x4_t input_s16 =
+ vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8)));
+ const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset));
+ // Duplicate the input values, 2-fold
+ const int16x4_t input_dup2 = vzip_s16(input, input).val[0];
+ // Multiply-accumulate
+ acc = vmlal_s16(acc, filter, input_dup2);
+ // Store the accumulators back to acc_buffer
+ vst1q_s32(acc_buffer_ptr, acc);
+ acc_buffer_ptr += 4;
+ }
+ }
+};
+
+template <>
+struct QuantizedDepthwiseConvKernel<false, 2, 1> {
+ static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
+ const uint8* input_ptr, int16 input_offset,
+ int input_ptr_increment, const uint8* filter_ptr,
+ int16 filter_offset, int32* acc_buffer_ptr) {
+ // Load the filters, add filter_offset.
+ uint8x8_t filter_u8 = vdup_n_u8(0);
+ filter_u8 = vset_lane_u8(filter_ptr[0], filter_u8, 0);
+ filter_u8 = vset_lane_u8(filter_ptr[1], filter_u8, 1);
+ filter_u8 = vset_lane_u8(filter_ptr[0], filter_u8, 2);
+ filter_u8 = vset_lane_u8(filter_ptr[1], filter_u8, 3);
+ const int16x4_t filter_s16 =
+ vreinterpret_s16_u16(vget_low_u16(vmovl_u8(filter_u8)));
+ const int16x4_t filter = vadd_s16(filter_s16, vdup_n_s16(filter_offset));
+
+ int outp = 0;
+ // Handle 8 output pixels at a time.
+ for (; outp <= num_output_pixels - 8; outp += 8) {
+ // Load the accumulators from acc_buffer.
+ int32x4_t acc[4];
+ for (int i = 0; i < 4; i++) {
+ acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
+ }
+ // Load the inputs, add input_offset.
+ uint8x8_t input_u8[2];
+ for (int i = 0; i < 2; i++) {
+ input_u8[i] = vld1_u8(input_ptr + 8 * i);
+ }
+ input_ptr += 16;
+ int16x8_t input[2];
+ for (int i = 0; i < 2; i++) {
+ input[i] = vreinterpretq_s16_u16(vmovl_u8(input_u8[i]));
+ }
+ for (int i = 0; i < 2; i++) {
+ input[i] = vaddq_s16(input[i], vdupq_n_s16(input_offset));
+ }
+
+ // Multiply-accumulate.
+ acc[0] = vmlal_s16(acc[0], filter, vget_low_s16(input[0]));
+ acc[1] = vmlal_s16(acc[1], filter, vget_high_s16(input[0]));
+ acc[2] = vmlal_s16(acc[2], filter, vget_low_s16(input[1]));
+ acc[3] = vmlal_s16(acc[3], filter, vget_high_s16(input[1]));
+ // Store the accumulators back to acc_buffer.
+ for (int i = 0; i < 4; i++) {
+ vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
+ }
+ acc_buffer_ptr += 16;
+ }
+ // Handle 4 output pixels at a time.
+ for (; outp <= num_output_pixels - 4; outp += 4) {
+ // Load the accumulators from acc_buffer.
+ int32x4_t acc[2];
+ for (int i = 0; i < 2; i++) {
+ acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
+ }
+ // Load the inputs, add input_offset.
+ const uint8x8_t input_u8 = vld1_u8(input_ptr);
+ input_ptr += 8;
+ const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8));
+ const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset));
+
+ // Multiply-accumulate.
+ acc[0] = vmlal_s16(acc[0], filter, vget_low_s16(input));
+ acc[1] = vmlal_s16(acc[1], filter, vget_high_s16(input));
+ // Store the accumulators back to acc_buffer.
+ for (int i = 0; i < 2; i++) {
+ vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
+ }
+ acc_buffer_ptr += 8;
+ }
+ // Handle 2 output pixels at a time.
+ for (; outp <= num_output_pixels - 2; outp += 2) {
+ // Load the accumulators from acc_buffer.
+ int32x4_t acc = vld1q_s32(acc_buffer_ptr);
+ // Load the inputs, add input_offset.
+ uint8x8_t input_u8 = vdup_n_u8(0);
+ input_u8 = vset_lane_u8(input_ptr[0], input_u8, 0);
+ input_u8 = vset_lane_u8(input_ptr[1], input_u8, 1);
+ input_u8 = vset_lane_u8(input_ptr[2], input_u8, 2);
+ input_u8 = vset_lane_u8(input_ptr[3], input_u8, 3);
+ input_ptr += 4;
+ const int16x4_t input_s16 =
+ vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8)));
+ const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset));
+
+ // Multiply-accumulate.
+ acc = vmlal_s16(acc, filter, input);
+ // Store the accumulators back to acc_buffer.
+ vst1q_s32(acc_buffer_ptr, acc);
+ acc_buffer_ptr += 4;
+ }
+ // Handle 1 output pixel at a time.
+ for (; outp < num_output_pixels; outp++) {
+ // Load the accumulators from acc_buffer.
+ int32x2_t acc = vld1_s32(acc_buffer_ptr);
+ // Load the inputs, add input_offset.
+ uint8x8_t input_u8 = vdup_n_u8(0);
+ input_u8 = vset_lane_u8(input_ptr[0], input_u8, 0);
+ input_u8 = vset_lane_u8(input_ptr[1], input_u8, 1);
+ input_ptr += 2;
+ const int16x4_t input_s16 =
+ vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8)));
+ const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset));
+
+ // Multiply-accumulate.
+ acc = vget_low_s32(vmlal_s16(vcombine_s32(acc, acc), filter, input));
+ // Store the accumulators back to acc_buffer.
+ vst1_s32(acc_buffer_ptr, acc);
+ acc_buffer_ptr += 2;
+ }
+ }
+};
+
+template <>
+struct QuantizedDepthwiseConvKernel<false, 1, 2> {
+ static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
+ const uint8* input_ptr, int16 input_offset,
+ int input_ptr_increment, const uint8* filter_ptr,
+ int16 filter_offset, int32* acc_buffer_ptr) {
+ // Load the filters, add filter_offset.
+ uint8x8_t filter_u8 = vdup_n_u8(0);
+ filter_u8 = vset_lane_u8(filter_ptr[0], filter_u8, 0);
+ filter_u8 = vset_lane_u8(filter_ptr[1], filter_u8, 1);
+ filter_u8 = vset_lane_u8(filter_ptr[0], filter_u8, 2);
+ filter_u8 = vset_lane_u8(filter_ptr[1], filter_u8, 3);
+ const int16x4_t filter_s16 =
+ vreinterpret_s16_u16(vget_low_u16(vmovl_u8(filter_u8)));
+ const int16x4_t filter = vadd_s16(filter_s16, vdup_n_s16(filter_offset));
+
+ int outp = 0;
+ // Handle 8 output pixels at a time.
+ for (; outp <= num_output_pixels - 8; outp += 8) {
+ // Load the accumulators from acc_buffer
+ int32x4_t acc[4];
+ for (int i = 0; i < 4; i++) {
+ acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
+ }
+
+ // Load the inputs, add input_offset.
+ const uint8x8_t input_u8 = vld1_u8(input_ptr);
+ input_ptr += 8;
+ const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8));
+ const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset));
+ // Duplicate the input values, 2-fold
+ const int16x8x2_t input_dup2 = vzipq_s16(input, input);
+ // Multiply-accumulate
+ acc[0] = vmlal_s16(acc[0], filter, vget_low_s16(input_dup2.val[0]));
+ acc[1] = vmlal_s16(acc[1], filter, vget_high_s16(input_dup2.val[0]));
+ acc[2] = vmlal_s16(acc[2], filter, vget_low_s16(input_dup2.val[1]));
+ acc[3] = vmlal_s16(acc[3], filter, vget_high_s16(input_dup2.val[1]));
+ // Store the accumulators back to acc_buffer
+ for (int i = 0; i < 4; i++) {
+ vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
+ }
+ acc_buffer_ptr += 16;
+ }
+ // Handle one output pixel at a time.
+ for (; outp < num_output_pixels; outp++) {
+ // Load the accumulators from acc_buffer
+ int32x2_t acc = vld1_s32(acc_buffer_ptr);
+
+ // Load the inputs, add input_offset.
+ const uint32 input = *input_ptr++ + input_offset;
+
+ // Multiply-accumulate
+ acc = vget_low_s32(vmlal_n_s16(vcombine_s32(acc, acc), filter, input));
+ // Store the accumulators back to acc_buffer
+ vst1_s32(acc_buffer_ptr, acc);
+ acc_buffer_ptr += 2;
+ }
+ }
+};
+
+template <>
+struct QuantizedDepthwiseConvKernel<false, 1, 4> {
+ static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
+ const uint8* input_ptr, int16 input_offset,
+ int input_ptr_increment, const uint8* filter_ptr,
+ int16 filter_offset, int32* acc_buffer_ptr) {
+ // Load the filters, add filter_offset.
+ uint8x8_t filter_u8 = vdup_n_u8(0);
+ filter_u8 = vset_lane_u8(filter_ptr[0], filter_u8, 0);
+ filter_u8 = vset_lane_u8(filter_ptr[1], filter_u8, 1);
+ filter_u8 = vset_lane_u8(filter_ptr[2], filter_u8, 2);
+ filter_u8 = vset_lane_u8(filter_ptr[3], filter_u8, 3);
+ const int16x4_t filter_s16 =
+ vreinterpret_s16_u16(vget_low_u16(vmovl_u8(filter_u8)));
+ const int16x4_t filter = vadd_s16(filter_s16, vdup_n_s16(filter_offset));
+
+ int outp = 0;
+ // Handle 8 output pixels at a time.
+ for (; outp <= num_output_pixels - 8; outp += 8) {
+ // Load the accumulators from acc_buffer
+ int32x4_t acc[8];
+ for (int i = 0; i < 8; i++) {
+ acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
+ }
+
+ // Load the inputs, add input_offset.
+ uint8x8_t input_u8 = vld1_u8(input_ptr);
+ input_ptr += 8;
+ const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8));
+ const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset));
+
+ // Multiply-accumulate
+ acc[0] = vmlal_lane_s16(acc[0], filter, vget_low_s16(input), 0);
+ acc[1] = vmlal_lane_s16(acc[1], filter, vget_low_s16(input), 1);
+ acc[2] = vmlal_lane_s16(acc[2], filter, vget_low_s16(input), 2);
+ acc[3] = vmlal_lane_s16(acc[3], filter, vget_low_s16(input), 3);
+ acc[4] = vmlal_lane_s16(acc[4], filter, vget_high_s16(input), 0);
+ acc[5] = vmlal_lane_s16(acc[5], filter, vget_high_s16(input), 1);
+ acc[6] = vmlal_lane_s16(acc[6], filter, vget_high_s16(input), 2);
+ acc[7] = vmlal_lane_s16(acc[7], filter, vget_high_s16(input), 3);
+
+ // Store the accumulators back to acc_buffer
+ for (int i = 0; i < 8; i++) {
+ vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
+ }
+ acc_buffer_ptr += 32;
+ }
+ // Handle 4 output pixels at a time.
+ for (; outp <= num_output_pixels - 4; outp += 4) {
+ // Load the accumulators from acc_buffer
+ int32x4_t acc[4];
+ for (int i = 0; i < 4; i++) {
+ acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
+ }
+
+ // Load the inputs, add input_offset.
+ uint8x8_t input_u8 = vdup_n_u8(0);
+ input_u8 = vset_lane_u8(input_ptr[0], input_u8, 0);
+ input_u8 = vset_lane_u8(input_ptr[1], input_u8, 1);
+ input_u8 = vset_lane_u8(input_ptr[2], input_u8, 2);
+ input_u8 = vset_lane_u8(input_ptr[3], input_u8, 3);
+ input_ptr += 4;
+ const int16x4_t input_s16 =
+ vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8)));
+ const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset));
+
+ // Multiply-accumulate
+ acc[0] = vmlal_lane_s16(acc[0], filter, input, 0);
+ acc[1] = vmlal_lane_s16(acc[1], filter, input, 1);
+ acc[2] = vmlal_lane_s16(acc[2], filter, input, 2);
+ acc[3] = vmlal_lane_s16(acc[3], filter, input, 3);
+
+ // Store the accumulators back to acc_buffer
+ for (int i = 0; i < 4; i++) {
+ vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
+ }
+ acc_buffer_ptr += 16;
+ }
+ // Handle one output pixel at a time.
+ for (; outp < num_output_pixels; outp++) {
+ // Load the accumulators from acc_buffer
+ int32x4_t acc = vld1q_s32(acc_buffer_ptr);
+
+ // Load the inputs, add input_offset.
+ const uint32 input = *input_ptr++ + input_offset;
+
+ // Multiply-accumulate
+ acc = vmlal_n_s16(acc, filter, input);
+ // Store the accumulators back to acc_buffer
+ vst1q_s32(acc_buffer_ptr, acc);
+ acc_buffer_ptr += 4;
+ }
+ }
+};
+
+template <>
+struct QuantizedDepthwiseConvKernel<false, 4, 1> {
+ static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
+ const uint8* input_ptr, int16 input_offset,
+ int input_ptr_increment, const uint8* filter_ptr,
+ int16 filter_offset, int32* acc_buffer_ptr) {
+ // Load the filters, add filter_offset.
+ uint8x8_t filter_u8 = vdup_n_u8(0);
+ filter_u8 = vset_lane_u8(filter_ptr[0], filter_u8, 0);
+ filter_u8 = vset_lane_u8(filter_ptr[1], filter_u8, 1);
+ filter_u8 = vset_lane_u8(filter_ptr[2], filter_u8, 2);
+ filter_u8 = vset_lane_u8(filter_ptr[3], filter_u8, 3);
+ const int16x4_t filter_s16 =
+ vreinterpret_s16_u16(vget_low_u16(vmovl_u8(filter_u8)));
+ const int16x4_t filter = vadd_s16(filter_s16, vdup_n_s16(filter_offset));
+
+ int outp = 0;
+ // Handle 4 output pixels at a time.
+ for (; outp <= num_output_pixels - 4; outp += 4) {
+ // Load the accumulators from acc_buffer
+ int32x4_t acc[4];
+ for (int i = 0; i < 4; i++) {
+ acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
+ }
+ // Load the inputs, add input_offset.
+ int16x8_t input[2];
+ for (int i = 0; i < 2; i++) {
+ const uint8x8_t input_u8 = vld1_u8(input_ptr + 8 * i);
+ const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8));
+ input[i] = vaddq_s16(input_s16, vdupq_n_s16(input_offset));
+ }
+ input_ptr += 16;
+ // Multiply-accumulate
+ for (int i = 0; i < 2; i++) {
+ acc[2 * i + 0] =
+ vmlal_s16(acc[2 * i + 0], filter, vget_low_s16(input[i]));
+ acc[2 * i + 1] =
+ vmlal_s16(acc[2 * i + 1], filter, vget_high_s16(input[i]));
+ }
+ // Store the accumulators back to acc_buffer
+ for (int i = 0; i < 4; i++) {
+ vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
+ }
+ acc_buffer_ptr += 16;
+ }
+ // Handle one output pixel at a time.
+ for (; outp < num_output_pixels; outp++) {
+ // Load the accumulators from acc_buffer
+ int32x4_t acc;
+ acc = vld1q_s32(acc_buffer_ptr);
+
+ // Load the inputs, add input_offset.
+ uint8x8_t input_u8 = vdup_n_u8(0);
+ input_u8 = vset_lane_u8(input_ptr[0], input_u8, 0);
+ input_u8 = vset_lane_u8(input_ptr[1], input_u8, 1);
+ input_u8 = vset_lane_u8(input_ptr[2], input_u8, 2);
+ input_u8 = vset_lane_u8(input_ptr[3], input_u8, 3);
+ input_ptr += 4;
+ const int16x4_t input_s16 =
+ vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8)));
+ const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset));
+ // Multiply-accumulate
+ acc = vmlal_s16(acc, filter, input);
+ // Store the accumulators back to acc_buffer
+ vst1q_s32(acc_buffer_ptr, acc);
+ acc_buffer_ptr += 4;
+ }
+ }
+};
+
+template <>
+struct QuantizedDepthwiseConvKernel<false, 4, 4> {
+ static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
+ const uint8* input_ptr, int16 input_offset,
+ int input_ptr_increment, const uint8* filter_ptr,
+ int16 filter_offset, int32* acc_buffer_ptr) {
+ // Load the filters, add filter_offset.
+ int16x8_t filter[2];
+ for (int i = 0; i < 2; i++) {
+ const uint8x8_t filter_u8 = vld1_u8(filter_ptr + 8 * i);
+ const int16x8_t filter_s16 = vreinterpretq_s16_u16(vmovl_u8(filter_u8));
+ filter[i] = vaddq_s16(filter_s16, vdupq_n_s16(filter_offset));
+ }
+
+ int outp = 0;
+ // Handle 2 output pixels at a time.
+ for (; outp <= num_output_pixels - 2; outp += 2) {
+ // Load the accumulators from acc_buffer
+ int32x4_t acc[8];
+ for (int i = 0; i < 8; i++) {
+ acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
+ }
+
+ // Load the inputs, add input_offset.
+ uint8x8_t input_u8 = vld1_u8(input_ptr);
+ input_ptr += 8;
+ const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8));
+ const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset));
+
+ // Multiply-accumulate
+ acc[0] = vmlal_lane_s16(acc[0], vget_low_s16(filter[0]),
+ vget_low_s16(input), 0);
+ acc[1] = vmlal_lane_s16(acc[1], vget_high_s16(filter[0]),
+ vget_low_s16(input), 1);
+ acc[2] = vmlal_lane_s16(acc[2], vget_low_s16(filter[1]),
+ vget_low_s16(input), 2);
+ acc[3] = vmlal_lane_s16(acc[3], vget_high_s16(filter[1]),
+ vget_low_s16(input), 3);
+ acc[4] = vmlal_lane_s16(acc[4], vget_low_s16(filter[0]),
+ vget_high_s16(input), 0);
+ acc[5] = vmlal_lane_s16(acc[5], vget_high_s16(filter[0]),
+ vget_high_s16(input), 1);
+ acc[6] = vmlal_lane_s16(acc[6], vget_low_s16(filter[1]),
+ vget_high_s16(input), 2);
+ acc[7] = vmlal_lane_s16(acc[7], vget_high_s16(filter[1]),
+ vget_high_s16(input), 3);
+ // Store the accumulators back to acc_buffer
+ for (int i = 0; i < 8; i++) {
+ vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
+ }
+ acc_buffer_ptr += 32;
+ }
+ // Handle one output pixel at a time.
+ for (; outp < num_output_pixels; outp++) {
+ // Load the accumulators from acc_buffer
+ int32x4_t acc[4];
+ for (int i = 0; i < 4; i++) {
+ acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
+ }
+
+ // Load the inputs, add input_offset.
+ uint8x8_t input_u8 = vdup_n_u8(0);
+ input_u8 = vset_lane_u8(input_ptr[0], input_u8, 0);
+ input_u8 = vset_lane_u8(input_ptr[1], input_u8, 1);
+ input_u8 = vset_lane_u8(input_ptr[2], input_u8, 2);
+ input_u8 = vset_lane_u8(input_ptr[3], input_u8, 3);
+ input_ptr += 4;
+ const int16x4_t input_s16 =
+ vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8)));
+ const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset));
+
+ // Multiply-accumulate
+ acc[0] = vmlal_lane_s16(acc[0], vget_low_s16(filter[0]), input, 0);
+ acc[1] = vmlal_lane_s16(acc[1], vget_high_s16(filter[0]), input, 1);
+ acc[2] = vmlal_lane_s16(acc[2], vget_low_s16(filter[1]), input, 2);
+ acc[3] = vmlal_lane_s16(acc[3], vget_high_s16(filter[1]), input, 3);
+ // Store the accumulators back to acc_buffer
+ for (int i = 0; i < 4; i++) {
+ vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
+ }
+ acc_buffer_ptr += 16;
+ }
+ }
+};
+
+template <>
+struct QuantizedDepthwiseConvKernel<true, 0, 3> {
+ static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
+ const uint8* input_ptr, int16 input_offset,
+ int input_ptr_increment, const uint8* filter_ptr,
+ int16 filter_offset, int32* acc_buffer_ptr) {
+ // We will have to duplicate bytes in a NEON register, 3-fold.
+ // We will do that by register-level table lookup using VTBL instructions.
+ // Here we prepare the registers containing the table-lookup indices.
+ static const uint8 dup3_indices_array[3][8] = {{0, 0, 0, 1, 1, 1, 2, 2},
+ {2, 3, 3, 3, 4, 4, 4, 5},
+ {5, 5, 6, 6, 6, 7, 7, 7}};
+ uint8x8_t dup3_indices[3];
+ for (int i = 0; i < 3; i++) {
+ dup3_indices[i] = vld1_u8(dup3_indices_array[i]);
+ }
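+ // For example, the three VTBL lookups below turn an 8-byte input
+ // {a, b, c, d, e, f, g, h} into the 3-fold-duplicated 24-byte sequence
+ // {a, a, a, b, b, b, c, c}, {c, d, d, d, e, e, e, f}, {f, f, g, g, g, h, h, h},
+ // matching the depth_multiplier == 3 layout of the accumulator buffer.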
+
+ // Handle one output pixel at a time.
+ for (int outp = 0; outp < num_output_pixels; outp++) {
+ const uint8* local_filter_ptr = filter_ptr;
+ const uint8* local_input_ptr = input_ptr;
+ int ic = 0;
+ // Handle 8 input channels at a time.
+ for (; ic <= input_depth - 8; ic += 8) {
+ // Load the filters, add filter_offset.
+ int16x8_t filter[3];
+ uint8x8x3_t filter_u8;
+ filter_u8.val[0] = vld1_u8(local_filter_ptr);
+ filter_u8.val[1] = vld1_u8(local_filter_ptr + 8);
+ filter_u8.val[2] = vld1_u8(local_filter_ptr + 16);
+ local_filter_ptr += 24;
+ for (int i = 0; i < 3; i++) {
+ const int16x8_t filter_s16 =
+ vreinterpretq_s16_u16(vmovl_u8(filter_u8.val[i]));
+ filter[i] = vaddq_s16(filter_s16, vdupq_n_s16(filter_offset));
+ }
+ // Load the inputs, duplicate 3-fold, add input_offset.
+ const uint8x8_t input_u8 = vld1_u8(local_input_ptr);
+ local_input_ptr += 8;
+
+ uint8x8_t input_u8_dup3[3];
+ for (int i = 0; i < 3; i++) {
+ input_u8_dup3[i] = vtbl1_u8(input_u8, dup3_indices[i]);
+ }
+ int16x8_t input_dup3[3];
+ for (int i = 0; i < 3; i++) {
+ const int16x8_t input_s16_dup3 =
+ vreinterpretq_s16_u16(vmovl_u8(input_u8_dup3[i]));
+ input_dup3[i] = vaddq_s16(input_s16_dup3, vdupq_n_s16(input_offset));
+ }
+ // Load the accumulators from acc_buffer
+ int32x4x3_t acc[2];
+ for (int i = 0; i < 2; i++) {
+ acc[i].val[0] = vld1q_s32(acc_buffer_ptr + 4 * i);
+ acc[i].val[1] = vld1q_s32(acc_buffer_ptr + 4 * i + 8);
+ acc[i].val[2] = vld1q_s32(acc_buffer_ptr + 4 * i + 16);
+ }
+ // Multiply-accumulate
+ for (int j = 0; j < 3; j++) {
+ acc[0].val[j] = vmlal_s16(acc[0].val[j], vget_low_s16(input_dup3[j]),
+ vget_low_s16(filter[j]));
+ acc[1].val[j] = vmlal_s16(acc[1].val[j], vget_high_s16(input_dup3[j]),
+ vget_high_s16(filter[j]));
+ }
+ // Store the accumulators back to acc_buffer
+ for (int i = 0; i < 2; i++) {
+ vst1q_s32(acc_buffer_ptr + 4 * i, acc[i].val[0]);
+ vst1q_s32(acc_buffer_ptr + 4 * i + 8, acc[i].val[1]);
+ vst1q_s32(acc_buffer_ptr + 4 * i + 16, acc[i].val[2]);
+ }
+ acc_buffer_ptr += 24;
+ }
+ // Handle one input channel at a time.
+ for (; ic < input_depth; ic++) {
+ const int16 input_val = *local_input_ptr++ + input_offset;
+ for (int i = 0; i < 3; i++) {
+ const int16 filter_val = local_filter_ptr[i] + filter_offset;
+ *acc_buffer_ptr++ += static_cast<int32>(filter_val) * input_val;
+ }
+ local_filter_ptr += 3;
+ }
+ input_ptr += input_ptr_increment;
+ }
+ }
+};
+
+template <>
+struct QuantizedDepthwiseConvKernel<true, 0, 2> {
+ static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
+ const uint8* input_ptr, int16 input_offset,
+ int input_ptr_increment, const uint8* filter_ptr,
+ int16 filter_offset, int32* acc_buffer_ptr) {
+ // Handle one output pixel at a time.
+ for (int outp = 0; outp < num_output_pixels; outp++) {
+ const uint8* local_filter_ptr = filter_ptr;
+ const uint8* local_input_ptr = input_ptr;
+ int ic = 0;
+ // Handle 8 input channels at a time.
+ for (; ic <= input_depth - 8; ic += 8) {
+ // Load the filters, add filter_offset.
+ int16x8_t filter[2];
+ uint8x8x2_t filter_u8;
+ filter_u8.val[0] = vld1_u8(local_filter_ptr);
+ filter_u8.val[1] = vld1_u8(local_filter_ptr + 8);
+ local_filter_ptr += 16;
+ for (int i = 0; i < 2; i++) {
+ const int16x8_t filter_s16 =
+ vreinterpretq_s16_u16(vmovl_u8(filter_u8.val[i]));
+ filter[i] = vaddq_s16(filter_s16, vdupq_n_s16(filter_offset));
+ }
+ // Load the inputs, add input_offset, duplicate 2-fold.
+ const uint8x8_t input_u8 = vld1_u8(local_input_ptr);
+ local_input_ptr += 8;
+ const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8));
+ const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset));
+ const int16x8x2_t input_dup2 = vzipq_s16(input, input);
+ // Load the accumulators from acc_buffer.
+ int32x4x2_t acc[2];
+ for (int i = 0; i < 2; i++) {
+ acc[i].val[0] = vld1q_s32(acc_buffer_ptr + 4 * i);
+ acc[i].val[1] = vld1q_s32(acc_buffer_ptr + 4 * i + 8);
+ }
+ // Multiply-accumulate.
+ for (int j = 0; j < 2; j++) {
+ acc[0].val[j] = vmlal_s16(acc[0].val[j], vget_low_s16(filter[j]),
+ vget_low_s16(input_dup2.val[j]));
+ acc[1].val[j] = vmlal_s16(acc[1].val[j], vget_high_s16(filter[j]),
+ vget_high_s16(input_dup2.val[j]));
+ }
+ // Store the accumulators back to acc_buffer.
+ for (int i = 0; i < 2; i++) {
+ vst1q_s32(acc_buffer_ptr + 4 * i, acc[i].val[0]);
+ vst1q_s32(acc_buffer_ptr + 4 * i + 8, acc[i].val[1]);
+ }
+ acc_buffer_ptr += 16;
+ }
+ // Handle one input channel at a time.
+ for (; ic < input_depth; ic++) {
+ // Load the inputs.
+ const int16 input_val = *local_input_ptr++ + input_offset;
+ for (int i = 0; i < 2; i++) {
+ const int16 filter_val = local_filter_ptr[i] + filter_offset;
+ *acc_buffer_ptr++ += static_cast<int32>(filter_val) * input_val;
+ }
+ local_filter_ptr += 2;
+ }
+ input_ptr += input_ptr_increment;
+ }
+ }
+};
+
+template <>
+struct QuantizedDepthwiseConvKernel<true, 0, 1> {
+ static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
+ const uint8* input_ptr, int16 input_offset,
+ int input_ptr_increment, const uint8* filter_ptr,
+ int16 filter_offset, int32* acc_buffer_ptr) {
+ // Handle one output pixel at a time.
+ for (int outp = 0; outp < num_output_pixels; outp++) {
+ const uint8* local_filter_ptr = filter_ptr;
+ const uint8* local_input_ptr = input_ptr;
+ int ic = 0;
+ // Handle 16 input channels at a time.
+ for (; ic <= input_depth - 16; ic += 16) {
+ // Load the filters, add filter_offset.
+ uint8x8_t filter_u8_0 = vld1_u8(local_filter_ptr + 8 * 0);
+ uint8x8_t filter_u8_1 = vld1_u8(local_filter_ptr + 8 * 1);
+ local_filter_ptr += 16;
+ int16x8_t filter_0 = vreinterpretq_s16_u16(vmovl_u8(filter_u8_0));
+ int16x8_t filter_1 = vreinterpretq_s16_u16(vmovl_u8(filter_u8_1));
+ filter_0 = vaddq_s16(filter_0, vdupq_n_s16(filter_offset));
+ filter_1 = vaddq_s16(filter_1, vdupq_n_s16(filter_offset));
+ // Load the inputs, add input_offset.
+ uint8x8_t input_u8_0 = vld1_u8(local_input_ptr + 8 * 0);
+ uint8x8_t input_u8_1 = vld1_u8(local_input_ptr + 8 * 1);
+ local_input_ptr += 16;
+ int16x8_t input_0 = vreinterpretq_s16_u16(vmovl_u8(input_u8_0));
+ int16x8_t input_1 = vreinterpretq_s16_u16(vmovl_u8(input_u8_1));
+ input_0 = vaddq_s16(input_0, vdupq_n_s16(input_offset));
+ input_1 = vaddq_s16(input_1, vdupq_n_s16(input_offset));
+ // Load the accumulators from acc_buffer
+ int32x4_t acc_0 = vld1q_s32(acc_buffer_ptr + 4 * 0);
+ int32x4_t acc_1 = vld1q_s32(acc_buffer_ptr + 4 * 1);
+ int32x4_t acc_2 = vld1q_s32(acc_buffer_ptr + 4 * 2);
+ int32x4_t acc_3 = vld1q_s32(acc_buffer_ptr + 4 * 3);
+ acc_0 = vmlal_s16(acc_0, vget_low_s16(input_0), vget_low_s16(filter_0));
+ acc_1 =
+ vmlal_s16(acc_1, vget_high_s16(input_0), vget_high_s16(filter_0));
+ acc_2 = vmlal_s16(acc_2, vget_low_s16(input_1), vget_low_s16(filter_1));
+ acc_3 =
+ vmlal_s16(acc_3, vget_high_s16(input_1), vget_high_s16(filter_1));
+ // Store the accumulators back to acc_buffer
+ vst1q_s32(acc_buffer_ptr + 4 * 0, acc_0);
+ vst1q_s32(acc_buffer_ptr + 4 * 1, acc_1);
+ vst1q_s32(acc_buffer_ptr + 4 * 2, acc_2);
+ vst1q_s32(acc_buffer_ptr + 4 * 3, acc_3);
+ acc_buffer_ptr += 16;
+ }
+ // Handle 8 input channels at a time.
+ for (; ic <= input_depth - 8; ic += 8) {
+ // Load the filters, add filter_offset.
+ const uint8x8_t filter_u8 = vld1_u8(local_filter_ptr);
+ local_filter_ptr += 8;
+ const int16x8_t filter_s16 = vreinterpretq_s16_u16(vmovl_u8(filter_u8));
+ const int16x8_t filter =
+ vaddq_s16(filter_s16, vdupq_n_s16(filter_offset));
+ // Load the inputs, add input_offset.
+ const uint8x8_t input_u8 = vld1_u8(local_input_ptr);
+ local_input_ptr += 8;
+ const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8));
+ const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset));
+ // Load the accumulators from acc_buffer
+ int32x4_t acc[2];
+ for (int i = 0; i < 2; i++) {
+ acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
+ }
+ // Multiply-accumulate
+ acc[0] = vmlal_s16(acc[0], vget_low_s16(input), vget_low_s16(filter));
+ acc[1] = vmlal_s16(acc[1], vget_high_s16(input), vget_high_s16(filter));
+ // Store the accumulators back to acc_buffer
+ for (int i = 0; i < 2; i++) {
+ vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
+ }
+ acc_buffer_ptr += 8;
+ }
+ // Handle one input channel at a time.
+ for (; ic < input_depth; ic++) {
+ const int16 input_val = *local_input_ptr++ + input_offset;
+ const int16 filter_val = *local_filter_ptr++ + filter_offset;
+ *acc_buffer_ptr++ += static_cast<int32>(filter_val) * input_val;
+ }
+ input_ptr += input_ptr_increment;
+ }
+ }
+};
+
+template <>
+struct QuantizedDepthwiseConvKernel<true, 16, 1> {
+ static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
+ const uint8* input_ptr, int16 input_offset,
+ int input_ptr_increment, const uint8* filter_ptr,
+ int16 filter_offset, int32* acc_buffer_ptr) {
+ // Load the filters, add filter_offset.
+ uint8x8_t filter_u8[2];
+ for (int i = 0; i < 2; i++) {
+ filter_u8[i] = vld1_u8(filter_ptr + 8 * i);
+ }
+ int16x8_t filter[2];
+ for (int i = 0; i < 2; i++) {
+ filter[i] = vreinterpretq_s16_u16(vmovl_u8(filter_u8[i]));
+ }
+ for (int i = 0; i < 2; i++) {
+ filter[i] = vaddq_s16(filter[i], vdupq_n_s16(filter_offset));
+ }
+ // Handle one output pixel at a time.
+ for (int outp = 0; outp < num_output_pixels; outp++) {
+ // Load the inputs, add input_offset.
+ uint8x8_t input_u8[2];
+ for (int i = 0; i < 2; i++) {
+ input_u8[i] = vld1_u8(input_ptr + 8 * i);
+ }
+ input_ptr += input_ptr_increment;
+ int16x8_t input[2];
+ for (int i = 0; i < 2; i++) {
+ input[i] = vreinterpretq_s16_u16(vmovl_u8(input_u8[i]));
+ }
+ for (int i = 0; i < 2; i++) {
+ input[i] = vaddq_s16(input[i], vdupq_n_s16(input_offset));
+ }
+ // Load the accumulators from acc_buffer
+ int32x4_t acc[4];
+ for (int i = 0; i < 4; i++) {
+ acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
+ }
+ // Multiply-accumulate
+ for (int i = 0; i < 2; i++) {
+ acc[2 * i + 0] = vmlal_s16(acc[2 * i + 0], vget_low_s16(input[i]),
+ vget_low_s16(filter[i]));
+ acc[2 * i + 1] = vmlal_s16(acc[2 * i + 1], vget_high_s16(input[i]),
+ vget_high_s16(filter[i]));
+ }
+ // Store the accumulators back to acc_buffer
+ for (int i = 0; i < 4; i++) {
+ vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
+ }
+ acc_buffer_ptr += 16;
+ }
+ }
+};
+
+template <>
+struct QuantizedDepthwiseConvKernel<true, 8, 1> {
+ static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
+ const uint8* input_ptr, int16 input_offset,
+ int input_ptr_increment, const uint8* filter_ptr,
+ int16 filter_offset, int32* acc_buffer_ptr) {
+ // Load the filters, add filter_offset.
+ const uint8x8_t filter_u8 = vld1_u8(filter_ptr);
+ const int16x8_t filter_s16 = vreinterpretq_s16_u16(vmovl_u8(filter_u8));
+ const int16x8_t filter = vaddq_s16(filter_s16, vdupq_n_s16(filter_offset));
+ // Handle one output pixel at a time.
+ for (int outp = 0; outp < num_output_pixels; outp++) {
+ // Load the inputs, add input_offset.
+ const uint8x8_t input_u8 = vld1_u8(input_ptr);
+ const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8));
+ const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset));
+ // Load the accumulators from acc_buffer
+ int32x4_t acc[2];
+ for (int i = 0; i < 2; i++) {
+ acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
+ }
+ // Multiply-accumulate
+ acc[0] = vmlal_s16(acc[0], vget_low_s16(input), vget_low_s16(filter));
+ acc[1] = vmlal_s16(acc[1], vget_high_s16(input), vget_high_s16(filter));
+ // Store the accumulators back to acc_buffer
+ for (int i = 0; i < 2; i++) {
+ vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
+ }
+ acc_buffer_ptr += 8;
+ input_ptr += input_ptr_increment;
+ }
+ }
+};
+
+template <>
+struct QuantizedDepthwiseConvKernel<true, 1, 16> {
+ static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
+ const uint8* input_ptr, int16 input_offset,
+ int input_ptr_increment, const uint8* filter_ptr,
+ int16 filter_offset, int32* acc_buffer_ptr) {
+ // Load the filters, add filter_offset.
+ uint8x8_t filter_u8[2];
+ for (int i = 0; i < 2; i++) {
+ filter_u8[i] = vld1_u8(filter_ptr + 8 * i);
+ }
+ int16x8_t filter[2];
+ for (int i = 0; i < 2; i++) {
+ filter[i] = vreinterpretq_s16_u16(vmovl_u8(filter_u8[i]));
+ }
+ for (int i = 0; i < 2; i++) {
+ filter[i] = vaddq_s16(filter[i], vdupq_n_s16(filter_offset));
+ }
+ // Handle one output pixel at a time.
+ for (int outp = 0; outp < num_output_pixels; outp++) {
+ uint8 input_u8 = *input_ptr;
+ input_ptr += input_ptr_increment;
+ int16 input = static_cast<int16>(input_u8 + input_offset);
+ // Load the accumulators from acc_buffer
+ int32x4_t acc[4];
+ for (int i = 0; i < 4; i++) {
+ acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
+ }
+ // Multiply-accumulate
+ for (int i = 0; i < 2; i++) {
+ acc[2 * i + 0] =
+ vmlal_n_s16(acc[2 * i + 0], vget_low_s16(filter[i]), input);
+ acc[2 * i + 1] =
+ vmlal_n_s16(acc[2 * i + 1], vget_high_s16(filter[i]), input);
+ }
+ // Store the accumulators back to acc_buffer
+ for (int i = 0; i < 4; i++) {
+ vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
+ }
+ acc_buffer_ptr += 16;
+ }
+ }
+};
+
+template <>
+struct QuantizedDepthwiseConvKernel<true, 1, 32> {
+ static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
+ const uint8* input_ptr, int16 input_offset,
+ int input_ptr_increment, const uint8* filter_ptr,
+ int16 filter_offset, int32* acc_buffer_ptr) {
+ // Load the filters, add filter_offset.
+ uint8x8_t filter_u8_0 = vld1_u8(filter_ptr + 8 * 0);
+ uint8x8_t filter_u8_1 = vld1_u8(filter_ptr + 8 * 1);
+ uint8x8_t filter_u8_2 = vld1_u8(filter_ptr + 8 * 2);
+ uint8x8_t filter_u8_3 = vld1_u8(filter_ptr + 8 * 3);
+ int16x8_t filter_0 = vreinterpretq_s16_u16(vmovl_u8(filter_u8_0));
+ int16x8_t filter_1 = vreinterpretq_s16_u16(vmovl_u8(filter_u8_1));
+ int16x8_t filter_2 = vreinterpretq_s16_u16(vmovl_u8(filter_u8_2));
+ int16x8_t filter_3 = vreinterpretq_s16_u16(vmovl_u8(filter_u8_3));
+ filter_0 = vaddq_s16(filter_0, vdupq_n_s16(filter_offset));
+ filter_1 = vaddq_s16(filter_1, vdupq_n_s16(filter_offset));
+ filter_2 = vaddq_s16(filter_2, vdupq_n_s16(filter_offset));
+ filter_3 = vaddq_s16(filter_3, vdupq_n_s16(filter_offset));
+ // Handle one output pixel at a time.
+ for (int outp = 0; outp < num_output_pixels; outp++) {
+ uint8 input_u8 = *input_ptr;
+ input_ptr += input_ptr_increment;
+ int16 input = static_cast<int16>(input_u8 + input_offset);
+ // Load the accumulators from acc_buffer
+ int32x4_t acc_0 = vld1q_s32(acc_buffer_ptr + 4 * 0);
+ int32x4_t acc_1 = vld1q_s32(acc_buffer_ptr + 4 * 1);
+ int32x4_t acc_2 = vld1q_s32(acc_buffer_ptr + 4 * 2);
+ int32x4_t acc_3 = vld1q_s32(acc_buffer_ptr + 4 * 3);
+ int32x4_t acc_4 = vld1q_s32(acc_buffer_ptr + 4 * 4);
+ int32x4_t acc_5 = vld1q_s32(acc_buffer_ptr + 4 * 5);
+ int32x4_t acc_6 = vld1q_s32(acc_buffer_ptr + 4 * 6);
+ int32x4_t acc_7 = vld1q_s32(acc_buffer_ptr + 4 * 7);
+ // Multiply-accumulate
+ acc_0 = vmlal_n_s16(acc_0, vget_low_s16(filter_0), input);
+ acc_1 = vmlal_n_s16(acc_1, vget_high_s16(filter_0), input);
+ acc_2 = vmlal_n_s16(acc_2, vget_low_s16(filter_1), input);
+ acc_3 = vmlal_n_s16(acc_3, vget_high_s16(filter_1), input);
+ acc_4 = vmlal_n_s16(acc_4, vget_low_s16(filter_2), input);
+ acc_5 = vmlal_n_s16(acc_5, vget_high_s16(filter_2), input);
+ acc_6 = vmlal_n_s16(acc_6, vget_low_s16(filter_3), input);
+ acc_7 = vmlal_n_s16(acc_7, vget_high_s16(filter_3), input);
+ // Store the accumulators back to acc_buffer
+ vst1q_s32(acc_buffer_ptr + 4 * 0, acc_0);
+ vst1q_s32(acc_buffer_ptr + 4 * 1, acc_1);
+ vst1q_s32(acc_buffer_ptr + 4 * 2, acc_2);
+ vst1q_s32(acc_buffer_ptr + 4 * 3, acc_3);
+ vst1q_s32(acc_buffer_ptr + 4 * 4, acc_4);
+ vst1q_s32(acc_buffer_ptr + 4 * 5, acc_5);
+ vst1q_s32(acc_buffer_ptr + 4 * 6, acc_6);
+ vst1q_s32(acc_buffer_ptr + 4 * 7, acc_7);
+ acc_buffer_ptr += 32;
+ }
+ }
+};
+
+template <>
+struct QuantizedDepthwiseConvKernel<true, 1, 8> {
+ static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
+ const uint8* input_ptr, int16 input_offset,
+ int input_ptr_increment, const uint8* filter_ptr,
+ int16 filter_offset, int32* acc_buffer_ptr) {
+ // Load the filters, add filter_offset.
+ const uint8x8_t filter_u8 = vld1_u8(filter_ptr);
+ const int16x8_t filter = vaddq_s16(
+ vreinterpretq_s16_u16(vmovl_u8(filter_u8)), vdupq_n_s16(filter_offset));
+ // Handle one output pixel at a time.
+ for (int outp = 0; outp < num_output_pixels; outp++) {
+ uint8 input_u8 = *input_ptr;
+ input_ptr += input_ptr_increment;
+ int16 input = static_cast<int16>(input_u8 + input_offset);
+ // Load the accumulators from acc_buffer
+ int32x4_t acc[2];
+ for (int i = 0; i < 2; i++) {
+ acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
+ }
+ // Multiply-accumulate
+ acc[0] = vmlal_n_s16(acc[0], vget_low_s16(filter), input);
+ acc[1] = vmlal_n_s16(acc[1], vget_high_s16(filter), input);
+ // Store the accumulators back to acc_buffer
+ for (int i = 0; i < 2; i++) {
+ vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
+ }
+ acc_buffer_ptr += 8;
+ }
+ }
+};
+
+template <>
+struct QuantizedDepthwiseConvKernel<true, 2, 1> {
+ static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
+ const uint8* input_ptr, int16 input_offset,
+ int input_ptr_increment, const uint8* filter_ptr,
+ int16 filter_offset, int32* acc_buffer_ptr) {
+ // Load the filters, add filter_offset.
+ uint8x8_t filter_u8 = vdup_n_u8(0);
+ filter_u8 = vset_lane_u8(filter_ptr[0], filter_u8, 0);
+ filter_u8 = vset_lane_u8(filter_ptr[1], filter_u8, 1);
+ filter_u8 = vset_lane_u8(filter_ptr[0], filter_u8, 2);
+ filter_u8 = vset_lane_u8(filter_ptr[1], filter_u8, 3);
+ const int16x4_t filter_s16 =
+ vreinterpret_s16_u16(vget_low_u16(vmovl_u8(filter_u8)));
+ const int16x4_t filter = vadd_s16(filter_s16, vdup_n_s16(filter_offset));
+
+ int outp = 0;
+
+ // Handle 2 output pixels at a time.
+ for (; outp <= num_output_pixels - 2; outp += 2) {
+ // Load the accumulators from acc_buffer.
+ int32x4_t acc = vld1q_s32(acc_buffer_ptr);
+ // Load the inputs, add input_offset.
+ uint16x4_t input_u16 = vdup_n_u16(0);
+ input_u16 = vset_lane_u16((reinterpret_cast<const uint16*>(input_ptr))[0],
+ input_u16, 0);
+ input_ptr += input_ptr_increment;
+ input_u16 = vset_lane_u16((reinterpret_cast<const uint16*>(input_ptr))[0],
+ input_u16, 1);
+ input_ptr += input_ptr_increment;
+ const int16x4_t input_s16 = vreinterpret_s16_u16(
+ vget_low_u16(vmovl_u8(vreinterpret_u8_u16(input_u16))));
+ const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset));
+
+ // Multiply-accumulate.
+ acc = vmlal_s16(acc, filter, input);
+ // Store the accumulators back to acc_buffer.
+ vst1q_s32(acc_buffer_ptr, acc);
+ acc_buffer_ptr += 4;
+ }
+
+ // Handle 1 output pixel at a time.
+ for (; outp < num_output_pixels; outp++) {
+ // Load the accumulators from acc_buffer.
+ int32x2_t acc = vld1_s32(acc_buffer_ptr);
+ // Load the inputs, add input_offset.
+ uint8x8_t input_u8 = vdup_n_u8(0);
+ input_u8 = vset_lane_u8(input_ptr[0], input_u8, 0);
+ input_u8 = vset_lane_u8(input_ptr[1], input_u8, 1);
+ input_ptr += input_ptr_increment;
+ const int16x4_t input_s16 =
+ vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8)));
+ const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset));
+
+ // Multiply-accumulate.
+ acc = vget_low_s32(vmlal_s16(vcombine_s32(acc, acc), filter, input));
+ // Store the accumulators back to acc_buffer.
+ vst1_s32(acc_buffer_ptr, acc);
+ acc_buffer_ptr += 2;
+ }
+ }
+};
+
+template <>
+struct QuantizedDepthwiseConvKernel<true, 4, 1> {
+ static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
+ const uint8* input_ptr, int16 input_offset,
+ int input_ptr_increment, const uint8* filter_ptr,
+ int16 filter_offset, int32* acc_buffer_ptr) {
+ if (num_output_pixels <= 0) {
+ return;
+ }
+
+ // Load the filters, add filter_offset.
+ uint8x8_t filter_u8 = vdup_n_u8(0);
+ filter_u8 = vset_lane_u8(filter_ptr[0], filter_u8, 0);
+ filter_u8 = vset_lane_u8(filter_ptr[1], filter_u8, 1);
+ filter_u8 = vset_lane_u8(filter_ptr[2], filter_u8, 2);
+ filter_u8 = vset_lane_u8(filter_ptr[3], filter_u8, 3);
+ const int16x4_t filter_s16 =
+ vreinterpret_s16_u16(vget_low_u16(vmovl_u8(filter_u8)));
+ const int16x4_t filter = vadd_s16(filter_s16, vdup_n_s16(filter_offset));
+
+ int outp = 0;
+
+ // Handle one output pixel at a time up to the second-to-last pixel: the
+ // full 8-byte vector load below reads eight input values while only four
+ // are consumed, so the last pixel is handled separately to avoid reading
+ // past the end of the input row.
+ for (; outp < num_output_pixels - 1; outp++) {
+ // Load the accumulators from acc_buffer
+ int32x4_t acc;
+ acc = vld1q_s32(acc_buffer_ptr);
+
+ // Load the inputs, add input_offset.
+ uint8x8_t input_u8 = vld1_u8(input_ptr);
+ input_ptr += input_ptr_increment;
+ const int16x4_t input_s16 =
+ vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8)));
+ const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset));
+ // Multiply-accumulate
+ acc = vmlal_s16(acc, filter, input);
+ // Store the accumulators back to acc_buffer
+ vst1q_s32(acc_buffer_ptr, acc);
+ acc_buffer_ptr += 4;
+ }
+
+ // Handle the last output pixel.
+ // Load the accumulators from acc_buffer
+ int32x4_t acc;
+ acc = vld1q_s32(acc_buffer_ptr);
+
+ // Load the inputs, add input_offset.
+ uint8x8_t input_u8 = vdup_n_u8(0);
+ input_u8 = vset_lane_u8(input_ptr[0], input_u8, 0);
+ input_u8 = vset_lane_u8(input_ptr[1], input_u8, 1);
+ input_u8 = vset_lane_u8(input_ptr[2], input_u8, 2);
+ input_u8 = vset_lane_u8(input_ptr[3], input_u8, 3);
+ const int16x4_t input_s16 =
+ vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8)));
+ const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset));
+ // Multiply-accumulate
+ acc = vmlal_s16(acc, filter, input);
+ // Store the accumulators back to acc_buffer
+ vst1q_s32(acc_buffer_ptr, acc);
+ }
+};
+
+template <>
+struct QuantizedDepthwiseConvKernel<false, 12, 1> {
+ static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
+ const uint8* input_ptr, int16 input_offset,
+ int input_ptr_increment, const uint8* filter_ptr,
+ int16 filter_offset, int32* acc_buffer_ptr) {
+ // Load the filters, add filter_offset.
+ uint8x8_t filter_u8_0 = vld1_u8(filter_ptr);
+ uint8x8_t filter_u8_1 = vld1_u8(filter_ptr + 4);
+ int16x8_t filter_s16_0 = vreinterpretq_s16_u16(vmovl_u8(filter_u8_0));
+ int16x8_t filter_s16_1 = vreinterpretq_s16_u16(vmovl_u8(filter_u8_1));
+ filter_s16_0 = vaddq_s16(filter_s16_0, vdupq_n_s16(filter_offset));
+ filter_s16_1 = vaddq_s16(filter_s16_1, vdupq_n_s16(filter_offset));
+ int16x4_t filter_0 = vget_low_s16(filter_s16_0);
+ int16x4_t filter_1 = vget_high_s16(filter_s16_0);
+ int16x4_t filter_2 = vget_high_s16(filter_s16_1);
+
+ // Handle one output pixel at a time.
+ for (int outp = 0; outp < num_output_pixels; outp++) {
+ // Load the inputs, add input_offset.
+ uint8x8_t input_u8_0 = vld1_u8(input_ptr);
+ uint8x8_t input_u8_1 = vld1_u8(input_ptr + 4);
+ input_ptr += input_ptr_increment;
+ int16x8_t input_0 = vreinterpretq_s16_u16(vmovl_u8(input_u8_0));
+ int16x8_t input_1 = vreinterpretq_s16_u16(vmovl_u8(input_u8_1));
+ input_0 = vaddq_s16(input_0, vdupq_n_s16(input_offset));
+ input_1 = vaddq_s16(input_1, vdupq_n_s16(input_offset));
+
+ // Load the accumulators from acc_buffer
+ int32x4_t acc_0 = vld1q_s32(acc_buffer_ptr + 4 * 0);
+ int32x4_t acc_1 = vld1q_s32(acc_buffer_ptr + 4 * 1);
+ int32x4_t acc_2 = vld1q_s32(acc_buffer_ptr + 4 * 2);
+
+ // Multiply-accumulate
+ acc_0 = vmlal_s16(acc_0, vget_low_s16(input_0), filter_0);
+ acc_1 = vmlal_s16(acc_1, vget_high_s16(input_0), filter_1);
+ acc_2 = vmlal_s16(acc_2, vget_high_s16(input_1), filter_2);
+
+ // Store the accumulators back to acc_buffer
+ vst1q_s32(acc_buffer_ptr + 4 * 0, acc_0);
+ vst1q_s32(acc_buffer_ptr + 4 * 1, acc_1);
+ vst1q_s32(acc_buffer_ptr + 4 * 2, acc_2);
+
+ acc_buffer_ptr += 12;
+ }
+ }
+};
+#endif
+
+// Accumulates the effect of one row of the filter on a segment of one row
+// of the output, reading the corresponding row of the input.
+template <bool kAllowStrided, int kFixedInputDepth, int kFixedDepthMultiplier>
+void QuantizedDepthwiseConvAccumRow(
+ int stride, int input_depth, int input_width, const uint8* input_data,
+ int16 input_offset, int pad_width, int depth_multiplier, int filter_width,
+ const uint8* filter_data, int16 filter_offset, int out_x_buffer_start,
+ int out_x_buffer_end, int output_depth, int32* acc_buffer) {
+#ifdef GEMMLOWP_PROFILING
+ gemmlowp::ScopedProfilingLabel label(__PRETTY_FUNCTION__);
+#endif
+ // Sanity check parameters. This is important in particular to ensure
+ // that we keep the number of template instantiations minimal, so we don't
+ // increase binary size unnecessarily.
+ static_assert(kFixedDepthMultiplier || !kFixedInputDepth, "");
+ static_assert(kFixedInputDepth || kAllowStrided, "");
+ TFLITE_DCHECK(stride == 1 || kAllowStrided);
+ if (kFixedInputDepth) {
+ TFLITE_DCHECK_EQ(input_depth, kFixedInputDepth);
+ }
+ if (kFixedDepthMultiplier) {
+ TFLITE_DCHECK_EQ(depth_multiplier, kFixedDepthMultiplier);
+ }
+ TFLITE_DCHECK_EQ(output_depth, input_depth * depth_multiplier);
+ const int input_ptr_increment = stride * input_depth;
+ const uint8* filter_base_ptr = filter_data;
+ for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
+ // For the current (filter_x, filter_y) point in the filter,
+ // compute the boundaries of the corresponding output row segment.
+ int out_x_loop_start_unclamped = 0;
+ int out_x_loop_end_unclamped = 0;
+ if (kAllowStrided) {
+ if (stride == 2) {
+ out_x_loop_start_unclamped = (pad_width - filter_x + 1) / 2;
+ out_x_loop_end_unclamped =
+ (pad_width + input_width - filter_x + 1) / 2;
+ } else if (stride == 4) {
+ out_x_loop_start_unclamped = (pad_width - filter_x + 3) / 4;
+ out_x_loop_end_unclamped =
+ (pad_width + input_width - filter_x + 3) / 4;
+ } else {
+ out_x_loop_start_unclamped =
+ (pad_width - filter_x + stride - 1) / stride;
+ out_x_loop_end_unclamped =
+ (pad_width + input_width - filter_x + stride - 1) / stride;
+ }
+ } else {
+ out_x_loop_start_unclamped = pad_width - filter_x;
+ out_x_loop_end_unclamped = pad_width + input_width - filter_x;
+ }
+ // The kernel will have to iterate on the segment of the
+ // output row that starts at out_x_loop_start and ends at out_x_loop_end.
+ const int out_x_loop_start =
+ std::max(out_x_buffer_start, out_x_loop_start_unclamped);
+ const int out_x_loop_end =
+ std::min(out_x_buffer_end, out_x_loop_end_unclamped);
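+ // For example, with stride == 2, pad_width == 1, filter_x == 0 and
+ // input_width == 5, the unclamped range is [(1 + 1) / 2, (1 + 5 + 1) / 2)
+ // == [1, 3): out_x == 1 and out_x == 2 read in_x == 1 and in_x == 3, both
+ // inside the input row, while out_x == 0 and out_x == 3 would fall outside.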
+
+ int32* acc_buffer_ptr =
+ acc_buffer + (out_x_loop_start - out_x_buffer_start) * output_depth;
+ const int in_x_origin = (out_x_loop_start * stride) - pad_width + filter_x;
+ const uint8* input_ptr = input_data + in_x_origin * input_depth;
+ const int num_output_pixels = out_x_loop_end - out_x_loop_start;
+ QuantizedDepthwiseConvKernel<
+ kAllowStrided, kFixedInputDepth,
+ kFixedDepthMultiplier>::Run(num_output_pixels, input_depth,
+ depth_multiplier, input_ptr, input_offset,
+ input_ptr_increment, filter_base_ptr,
+ filter_offset, acc_buffer_ptr);
+ filter_base_ptr += output_depth;
+ }
+}
+
+// Generic fallback of QuantizedDepthwiseConvAccumRow: portable, non-templatized.
+inline void QuantizedDepthwiseConvAccumRowGeneric(
+ int stride, int input_depth, int input_width, const uint8* input_data,
+ int16 input_offset, int pad_width, int depth_multiplier, int filter_width,
+ const uint8* filter_data, int16 filter_offset, int out_x_buffer_start,
+ int out_x_buffer_end, int output_depth, int32* acc_buffer) {
+ gemmlowp::ScopedProfilingLabel label("DepthwiseConvAccumRowGeneric (slow)");
+#ifdef TFLITE_PREVENT_SLOW_GENERIC_DEPTHWISECONV_FALLBACK
+#ifndef ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK
+ LOG(FATAL)
+ << "\n\n"
+ << "*****************************************************************\n"
+ << "* This tfmini inference code was about to use the slow generic\n"
+ << "* fallback implementation for a DepthwiseConv op, and we want you\n"
+ << "* to be aware of that so that you will know why you get terrible\n"
+ << "* performance.\n"
+ << "*\n"
+ << "* If you would like to carry on with the slow code, compile\n"
+ << "* with this preprocessor token defined:\n"
+ << "* TFLITE_ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK.\n"
+ << "*\n"
+ << "* The right thing to do, if you care about performance, is to add\n"
+ << "* a new DepthwiseConv kernel to tfmini to cover your case.\n"
+ << "* The relevant parameters defining your case are:\n"
+ << "* stride = " << stride << "\n"
+ << "* input_depth = " << input_depth << "\n"
+ << "* depth_multiplier = " << depth_multiplier << "\n"
+ << "*\n"
+ << "* Please do not hesitate to contact benoitjacob@ with this\n"
+ << "* information.\n"
+ << "*****************************************************************\n";
+#endif // ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK
+#endif // TFLITE_PREVENT_SLOW_GENERIC_DEPTHWISECONV_FALLBACK
+ const uint8* filter_base_ptr = filter_data;
+ for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
+ const int out_x_loop_start = std::max(
+ out_x_buffer_start, (pad_width - filter_x + stride - 1) / stride);
+ const int out_x_loop_end =
+ std::min(out_x_buffer_end,
+ (pad_width + input_width - filter_x + stride - 1) / stride);
+
+ int32* acc_buffer_ptr =
+ acc_buffer + (out_x_loop_start - out_x_buffer_start) * output_depth;
+ const int in_x_origin = (out_x_loop_start * stride) - pad_width + filter_x;
+ const uint8* input_ptr = input_data + in_x_origin * input_depth;
+ const int input_ptr_increment = (stride - 1) * input_depth;
+ for (int out_x = out_x_loop_start; out_x < out_x_loop_end; out_x++) {
+ const uint8* filter_ptr = filter_base_ptr;
+ for (int ic = 0; ic < input_depth; ++ic) {
+ const int16 input_val = *input_ptr++ + input_offset;
+ for (int m = 0; m < depth_multiplier; m++) {
+ const int16 filter_val = *filter_ptr++ + filter_offset;
+ *acc_buffer_ptr++ += static_cast<int32>(filter_val) * input_val;
+ }
+ }
+ input_ptr += input_ptr_increment;
+ }
+ filter_base_ptr += output_depth;
+ }
+}
+
+// Initializes the accumulator buffer with bias values.
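+// The buffer is laid out as num_output_pixels consecutive groups of
+// output_depth int32 values, so e.g. with output_depth == 2 and bias
+// {b0, b1} it becomes {b0, b1, b0, b1, ...}, which is why the
+// output_depth == 2 path below places bias_data[1] in lanes 1 and 3.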
+inline void DepthwiseConvInitAccBuffer(int num_output_pixels, int output_depth,
+ const int32* bias_data,
+ int32* acc_buffer) {
+ int i = 0;
+#ifdef USE_NEON
+ if (output_depth == 1) {
+ const int32x4_t b = vdupq_n_s32(bias_data[0]);
+ for (; i <= num_output_pixels - 16; i += 16) {
+ vst1q_s32(acc_buffer + i + 0, b);
+ vst1q_s32(acc_buffer + i + 4, b);
+ vst1q_s32(acc_buffer + i + 8, b);
+ vst1q_s32(acc_buffer + i + 12, b);
+ }
+ for (; i <= num_output_pixels - 4; i += 4) {
+ vst1q_s32(acc_buffer + i, b);
+ }
+ } else if (output_depth == 2) {
+ int32x4_t b = vdupq_n_s32(bias_data[0]);
+ b = vsetq_lane_s32(bias_data[1], b, 1);
+ b = vsetq_lane_s32(bias_data[1], b, 3);
+ for (; i <= num_output_pixels - 8; i += 8) {
+ vst1q_s32(acc_buffer + 2 * i + 0, b);
+ vst1q_s32(acc_buffer + 2 * i + 4, b);
+ vst1q_s32(acc_buffer + 2 * i + 8, b);
+ vst1q_s32(acc_buffer + 2 * i + 12, b);
+ }
+ for (; i <= num_output_pixels - 2; i += 2) {
+ vst1q_s32(acc_buffer + 2 * i, b);
+ }
+ } else if (output_depth == 4) {
+ const int32x4_t b = vld1q_s32(bias_data);
+ for (; i <= num_output_pixels - 4; i += 4) {
+ vst1q_s32(acc_buffer + 4 * i + 0, b);
+ vst1q_s32(acc_buffer + 4 * i + 4, b);
+ vst1q_s32(acc_buffer + 4 * i + 8, b);
+ vst1q_s32(acc_buffer + 4 * i + 12, b);
+ }
+ for (; i < num_output_pixels; i++) {
+ vst1q_s32(acc_buffer + 4 * i, b);
+ }
+ } else if (output_depth == 8) {
+ const int32x4_t b0 = vld1q_s32(bias_data);
+ const int32x4_t b1 = vld1q_s32(bias_data + 4);
+ for (; i <= num_output_pixels - 2; i += 2) {
+ vst1q_s32(acc_buffer + 8 * i + 0, b0);
+ vst1q_s32(acc_buffer + 8 * i + 4, b1);
+ vst1q_s32(acc_buffer + 8 * i + 8, b0);
+ vst1q_s32(acc_buffer + 8 * i + 12, b1);
+ }
+ for (; i < num_output_pixels; i++) {
+ vst1q_s32(acc_buffer + 8 * i + 0, b0);
+ vst1q_s32(acc_buffer + 8 * i + 4, b1);
+ }
+ } else if (output_depth == 16) {
+ const int32x4_t b0 = vld1q_s32(bias_data);
+ const int32x4_t b1 = vld1q_s32(bias_data + 4);
+ const int32x4_t b2 = vld1q_s32(bias_data + 8);
+ const int32x4_t b3 = vld1q_s32(bias_data + 12);
+ for (; i < num_output_pixels; i++) {
+ vst1q_s32(acc_buffer + 16 * i + 0, b0);
+ vst1q_s32(acc_buffer + 16 * i + 4, b1);
+ vst1q_s32(acc_buffer + 16 * i + 8, b2);
+ vst1q_s32(acc_buffer + 16 * i + 12, b3);
+ }
+ }
+#endif
+ for (; i < num_output_pixels; i++) {
+ memcpy(acc_buffer + i * output_depth, bias_data,
+ sizeof(acc_buffer[0]) * output_depth);
+ }
+}
+
+inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int depth_multiplier,
+ int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("DepthwiseConv/8bit");
+ TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
+
+ const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
+ const int output_depth = MatchingArraySize(filter_dims, 0, output_dims, 0);
+ const int input_height = ArraySize(input_dims, 2);
+ const int input_width = ArraySize(input_dims, 1);
+ const int input_depth = ArraySize(input_dims, 0);
+ const int filter_height = ArraySize(filter_dims, 2);
+ const int filter_width = ArraySize(filter_dims, 1);
+ const int output_height = ArraySize(output_dims, 2);
+ const int output_width = ArraySize(output_dims, 1);
+ TFLITE_DCHECK(output_depth == input_depth * depth_multiplier);
+
+ static const int kAccBufferMaxSize = 2048;
+ int32 acc_buffer[kAccBufferMaxSize];
+ TFLITE_DCHECK_GE(kAccBufferMaxSize, output_depth);
+ const int kOutputPixelsInAccBuffer = kAccBufferMaxSize / output_depth;
+ const int kAccBufferActualSize = kOutputPixelsInAccBuffer * output_depth;
+ TFLITE_DCHECK_LE(kOutputPixelsInAccBuffer * output_depth,
+ kAccBufferActualSize);
+ TFLITE_DCHECK_LE(kAccBufferActualSize, kAccBufferMaxSize);
+ TFLITE_DCHECK_GE(kOutputPixelsInAccBuffer, 1);
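+ // For example, with output_depth == 96 this gives
+ // kOutputPixelsInAccBuffer == 21 and kAccBufferActualSize == 2016, i.e.
+ // each accumulation pass below covers at most 21 output pixels of the row.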
+
+ // row_accum_func will point to the core accumulation function to be used
+ // for this DepthwiseConv op.
+ using row_accum_func_t = decltype(&QuantizedDepthwiseConvAccumRowGeneric);
+ row_accum_func_t row_accum_func = nullptr;
+
+#define TFMINI_USE_DEPTHWISECONV_KERNEL(ALLOW_STRIDED, FIXED_INPUT_DEPTH, \
+ FIXED_DEPTH_MULTIPLIER) \
+ if (!row_accum_func && (stride_width == 1 || ALLOW_STRIDED) && \
+ (input_depth == FIXED_INPUT_DEPTH || FIXED_INPUT_DEPTH == 0) && \
+ depth_multiplier == FIXED_DEPTH_MULTIPLIER) { \
+ row_accum_func = \
+ QuantizedDepthwiseConvAccumRow<ALLOW_STRIDED, FIXED_INPUT_DEPTH, \
+ FIXED_DEPTH_MULTIPLIER>; \
+ }
+
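+ // Each TFMINI_USE_DEPTHWISECONV_KERNEL invocation below tries one
+ // specialization; the first whose compile-time parameters match the runtime
+ // stride_width, input_depth and depth_multiplier (FIXED_INPUT_DEPTH == 0
+ // meaning "any input depth") sets row_accum_func. For example, on NEON
+ // builds, stride_width == 1, input_depth == 8, depth_multiplier == 1
+ // selects QuantizedDepthwiseConvKernel<false, 8, 1>.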
+#ifdef USE_NEON
+ // We go over our list of kernels in decreasing order of preference
+ // for the cases where multiple kernels could apply.
+
+ // Start with the fastest kernels: AllowStrided=false, fixed input depth.
+
+ TFMINI_USE_DEPTHWISECONV_KERNEL(false, 1, 2)
+ TFMINI_USE_DEPTHWISECONV_KERNEL(false, 2, 2)
+ TFMINI_USE_DEPTHWISECONV_KERNEL(false, 4, 2)
+ TFMINI_USE_DEPTHWISECONV_KERNEL(false, 1, 4)
+ TFMINI_USE_DEPTHWISECONV_KERNEL(false, 4, 1)
+ TFMINI_USE_DEPTHWISECONV_KERNEL(false, 4, 4)
+ TFMINI_USE_DEPTHWISECONV_KERNEL(false, 8, 1)
+ TFMINI_USE_DEPTHWISECONV_KERNEL(false, 2, 8)
+ TFMINI_USE_DEPTHWISECONV_KERNEL(false, 2, 1)
+ TFMINI_USE_DEPTHWISECONV_KERNEL(false, 12, 1)
+
+ // Next come the strided kernels: AllowStrided=true, fixed input depth.
+ // They are a bit less efficient, but allow stride!=1.
+
+ TFMINI_USE_DEPTHWISECONV_KERNEL(true, 8, 2)
+ TFMINI_USE_DEPTHWISECONV_KERNEL(true, 16, 1)
+ TFMINI_USE_DEPTHWISECONV_KERNEL(true, 1, 16)
+ TFMINI_USE_DEPTHWISECONV_KERNEL(true, 1, 32)
+ TFMINI_USE_DEPTHWISECONV_KERNEL(true, 1, 8)
+ TFMINI_USE_DEPTHWISECONV_KERNEL(true, 8, 1)
+ TFMINI_USE_DEPTHWISECONV_KERNEL(true, 2, 1)
+ TFMINI_USE_DEPTHWISECONV_KERNEL(true, 4, 1)
+
+ // Finally, the kernels allowing a variable input depth;
+ // these are the least efficient but most general kernels.
+
+ TFMINI_USE_DEPTHWISECONV_KERNEL(true, 0, 1)
+ TFMINI_USE_DEPTHWISECONV_KERNEL(true, 0, 2)
+ TFMINI_USE_DEPTHWISECONV_KERNEL(true, 0, 3)
+#endif // USE_NEON
+
+ // No matching fast kernel found, use slow fallback.
+ if (!row_accum_func) {
+ row_accum_func = QuantizedDepthwiseConvAccumRowGeneric;
+ }
+
+#undef TFMINI_USE_DEPTHWISECONV_KERNEL
+
+ // Now that we have determined row_accum_func, we can start work.
+ uint8* output_ptr = output_data;
+ for (int b = 0; b < batches; ++b) {
+ for (int out_y = 0; out_y < output_height; ++out_y) {
+ const int in_y_origin = (out_y * stride_height) - pad_height;
+ const int filter_y_start = std::max(0, -in_y_origin);
+ const int filter_y_end =
+ std::min(filter_height, input_height - in_y_origin);
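+ // For example, with pad_height == 1 and out_y == 0, in_y_origin == -1,
+ // so filter_y_start == 1: the first filter row would fall above the image
+ // and is skipped.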
+ for (int out_x_buffer_start = 0; out_x_buffer_start < output_width;
+ out_x_buffer_start += kOutputPixelsInAccBuffer) {
+ const int out_x_buffer_end = std::min(
+ output_width, out_x_buffer_start + kOutputPixelsInAccBuffer);
+ // We call a 'pixel' a group of activations that share all but the
+ // 'depth'/'channel' coordinate. num_output_pixels is the number of
+ // output pixels that we will accumulate in this loop iteration.
+ const int num_output_pixels = out_x_buffer_end - out_x_buffer_start;
+ // Initialize our local accumulator with the bias values, so we don't
+ // have to add them later.
+ DepthwiseConvInitAccBuffer(num_output_pixels, output_depth, bias_data,
+ acc_buffer);
+ // Accumulation loop. Most of the time should be spent in here.
+ for (int filter_y = filter_y_start; filter_y < filter_y_end;
+ ++filter_y) {
+ const int in_y = in_y_origin + filter_y;
+ row_accum_func(
+ stride_width, input_depth, input_width,
+ input_data + in_y * input_dims.strides[2] +
+ b * input_dims.strides[3],
+ input_offset, pad_width, depth_multiplier, filter_width,
+ filter_data + filter_y * filter_dims.strides[2], filter_offset,
+ out_x_buffer_start, out_x_buffer_end, output_depth, acc_buffer);
+ }
+ // Finished accumulating int32 values. Now need to convert them to
+ // the final 8bit form and store them.
+ gemmlowp::ScopedProfilingLabel label("downquantize+store");
+ const int num_output_values = output_depth * num_output_pixels;
+ int i = 0;
+#ifdef USE_NEON
+ using gemmlowp::RoundingDivideByPOT;
+ const int32x4_t output_offset_vec = vdupq_n_s32(output_offset);
+ const int32x4_t output_activation_min_vec =
+ vdupq_n_s32(output_activation_min);
+ const int32x4_t output_activation_max_vec =
+ vdupq_n_s32(output_activation_max);
+ // Handle 16 values at once.
+ // This allows us to issue 4 mutually independent int32
+ // multiplications (vqrdmulh), which should alleviate most of their
+ // high latency.
+ for (; i <= num_output_values - 16; i += 16) {
+ int32x4_t acc[4];
+ for (int j = 0; j < 4; j++) {
+ acc[j] = vld1q_s32(acc_buffer + i + 4 * j);
+ }
+
+ // Fixed-point multiplication.
+ for (int j = 0; j < 4; j++) {
+ acc[j] = vqrdmulhq_n_s32(acc[j], output_multiplier);
+ }
+ for (int j = 0; j < 4; j++) {
+ acc[j] = RoundingDivideByPOT(acc[j], output_shift);
+ }
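+ // Together, vqrdmulhq_n_s32 and RoundingDivideByPOT effectively compute
+ // round(acc * output_multiplier / 2^(31 + output_shift)) with saturation,
+ // i.e. a fixed-point multiplication by the real output scale encoded as
+ // (output_multiplier, output_shift); this is the vector counterpart of
+ // MultiplyByQuantizedMultiplierSmallerThanOne used in the scalar tail below.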
+ // Add the output offset.
+ for (int j = 0; j < 4; j++) {
+ acc[j] = vaddq_s32(acc[j], output_offset_vec);
+ }
+ // Apply the activation function.
+ for (int j = 0; j < 4; j++) {
+ acc[j] = vmaxq_s32(acc[j], output_activation_min_vec);
+ }
+ for (int j = 0; j < 4; j++) {
+ acc[j] = vminq_s32(acc[j], output_activation_max_vec);
+ }
+ // Saturating cast to uint8 and store to destination.
+ int16x4_t acc_s16[4];
+ for (int j = 0; j < 4; j++) {
+ acc_s16[j] = vqmovn_s32(acc[j]);
+ }
+ const int16x8_t res_s16_0 = vcombine_s16(acc_s16[0], acc_s16[1]);
+ const int16x8_t res_s16_1 = vcombine_s16(acc_s16[2], acc_s16[3]);
+ const uint8x8_t res_u8_0 = vqmovun_s16(res_s16_0);
+ const uint8x8_t res_u8_1 = vqmovun_s16(res_s16_1);
+ vst1q_u8(output_ptr, vcombine_u8(res_u8_0, res_u8_1));
+ output_ptr += 16;
+ }
+ // Handle 8 values at once.
+ // Not as good as 16 (now we're only issuing 2 mutually independent
+ // vqrdmulh instructions, so we're probably paying for their high
+ // latency).
+ for (; i <= num_output_values - 8; i += 8) {
+ int32x4_t acc0 = vld1q_s32(acc_buffer + i);
+ int32x4_t acc1 = vld1q_s32(acc_buffer + i + 4);
+ // Fixed-point multiplication.
+ acc0 = vqrdmulhq_n_s32(acc0, output_multiplier);
+ acc1 = vqrdmulhq_n_s32(acc1, output_multiplier);
+ // Rounding right shift.
+ acc0 = RoundingDivideByPOT(acc0, output_shift);
+ acc1 = RoundingDivideByPOT(acc1, output_shift);
+ // Add the output offset.
+ acc0 = vaddq_s32(acc0, output_offset_vec);
+ acc1 = vaddq_s32(acc1, output_offset_vec);
+ // Apply the activation function.
+ acc0 = vmaxq_s32(acc0, output_activation_min_vec);
+ acc1 = vmaxq_s32(acc1, output_activation_min_vec);
+ acc0 = vminq_s32(acc0, output_activation_max_vec);
+ acc1 = vminq_s32(acc1, output_activation_max_vec);
+ // Saturating cast to uint8 and store to destination.
+ const int16x4_t acc0_s16 = vqmovn_s32(acc0);
+ const int16x4_t acc1_s16 = vqmovn_s32(acc1);
+ const int16x8_t res_s16 = vcombine_s16(acc0_s16, acc1_s16);
+ const uint8x8_t res_u8 = vqmovun_s16(res_s16);
+ vst1_u8(output_ptr, res_u8);
+ output_ptr += 8;
+ }
+ // Handle 4 values at once. Now we're paying the full price of the
+ // high latency of vqrdmulh. Also, storing only 4 bytes at the end
+ // (without any alignment) can only be done 1 byte at a time.
+ // Yet, that is still worth doing to minimize the amount of leftover
+ // that will have to go through the very slow scalar code.
+ for (; i <= num_output_values - 4; i += 4) {
+ int32x4_t acc = vld1q_s32(acc_buffer + i);
+ // Fixed-point multiplication.
+ acc = vqrdmulhq_n_s32(acc, output_multiplier);
+ // Rounding right shift.
+ acc = RoundingDivideByPOT(acc, output_shift);
+ // Add the output offset.
+ acc = vaddq_s32(acc, output_offset_vec);
+ // Apply the activation function.
+ acc = vmaxq_s32(acc, output_activation_min_vec);
+ acc = vminq_s32(acc, output_activation_max_vec);
+ // Saturating cast to uint8 and store to destination.
+ const int16x4_t acc_s16 = vqmovn_s32(acc);
+ const int16x8_t res_s16 = vcombine_s16(acc_s16, acc_s16);
+ const uint8x8_t res_u8 = vqmovun_s16(res_s16);
+ vst1_lane_u8(output_ptr + 0, res_u8, 0);
+ vst1_lane_u8(output_ptr + 1, res_u8, 1);
+ vst1_lane_u8(output_ptr + 2, res_u8, 2);
+ vst1_lane_u8(output_ptr + 3, res_u8, 3);
+ output_ptr += 4;
+ }
+#endif // USE_NEON
+
+ // Handle leftover values, one by one. This is very slow.
+ for (; i < num_output_values; i++) {
+ int32 acc = acc_buffer[i];
+ acc = MultiplyByQuantizedMultiplierSmallerThanOne(
+ acc, output_multiplier, output_shift);
+ acc += output_offset;
+ acc = std::max(acc, output_activation_min);
+ acc = std::min(acc, output_activation_max);
+ *output_ptr++ = static_cast<uint8>(acc);
+ }
+ }
+ }
+ }
+}
+
+// Legacy, for compatibility with old checked-in code.
+template <FusedActivationFunctionType Ac>
+void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int depth_multiplier, int32 output_offset,
+ int32 output_multiplier, int output_shift,
+ int32 output_activation_min, int32 output_activation_max,
+ uint8* output_data, const Dims<4>& output_dims) {
+ if (Ac == FusedActivationFunctionType::kNone) {
+ TFLITE_DCHECK_EQ(output_activation_min, 0);
+ TFLITE_DCHECK_EQ(output_activation_max, 255);
+ }
+ DepthwiseConv(input_data, input_dims, input_offset, filter_data, filter_dims,
+ filter_offset, bias_data, bias_dims, stride_width,
+ stride_height, pad_width, pad_height, depth_multiplier,
+ output_offset, output_multiplier, output_shift,
+ output_activation_min, output_activation_max, output_data,
+ output_dims);
+}
+
+// Legacy, for compatibility with old checked-in code.
+template <FusedActivationFunctionType Ac>
+void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims, int stride,
+ int pad_width, int pad_height, int depth_multiplier,
+ int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ DepthwiseConv<Ac>(input_data, input_dims, input_offset, filter_data,
+ filter_dims, filter_offset, bias_data, bias_dims, stride,
+ stride, pad_width, pad_height, depth_multiplier,
+ output_offset, output_multiplier, output_shift,
+ output_activation_min, output_activation_max, output_data,
+ output_dims);
+}
+
+} // namespace optimized_ops
+} // namespace tflite
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_UINT8_H_
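
As a side note for readers of the output stage above: the scalar leftover loop applies the same requantization steps as the NEON path (fixed-point multiply by output_multiplier, rounding right shift by output_shift, add output_offset, clamp to the activation range). Below is a minimal standalone sketch of those steps, assuming gemmlowp-style rounding semantics for the two helpers; it is an illustration only, not part of the diff.

#include <algorithm>
#include <cstdint>
#include <limits>

// Rounding-doubling high multiply: returns the high 32 bits of 2*a*b,
// rounded to nearest, saturating the single overflow case.
int32_t SaturatingRoundingDoublingHighMul(int32_t a, int32_t b) {
  if (a == std::numeric_limits<int32_t>::min() &&
      b == std::numeric_limits<int32_t>::min()) {
    return std::numeric_limits<int32_t>::max();
  }
  const int64_t ab = static_cast<int64_t>(a) * b;
  const int32_t nudge = ab >= 0 ? (1 << 30) : (1 - (1 << 30));
  return static_cast<int32_t>((ab + nudge) / (1ll << 31));
}

// Rounding arithmetic right shift by 'exponent'.
int32_t RoundingDivideByPOT(int32_t x, int exponent) {
  const int32_t mask = static_cast<int32_t>((1ll << exponent) - 1);
  const int32_t remainder = x & mask;
  const int32_t threshold = (mask >> 1) + (x < 0 ? 1 : 0);
  return (x >> exponent) + (remainder > threshold ? 1 : 0);
}

// One accumulator value through the output pipeline of the leftover loop.
uint8_t RequantizeOneValue(int32_t acc, int32_t output_multiplier,
                           int output_shift, int32_t output_offset,
                           int32_t output_activation_min,
                           int32_t output_activation_max) {
  acc = SaturatingRoundingDoublingHighMul(acc, output_multiplier);
  acc = RoundingDivideByPOT(acc, output_shift);
  acc += output_offset;
  acc = std::max(acc, output_activation_min);
  acc = std::min(acc, output_activation_max);
  return static_cast<uint8_t>(acc);
}
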
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/eigen_spatial_convolutions.h b/tensorflow/contrib/lite/kernels/internal/optimized/eigen_spatial_convolutions.h
new file mode 100644
index 0000000000..8004c24a99
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/eigen_spatial_convolutions.h
@@ -0,0 +1,231 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Copied from tensorflow/core/kernels/eigen_spatial_convolutions.h.
+// TODO(petewarden) - move this to a common location in Eigen itself.
+
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_SPATIAL_CONVOLUTIONS_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_SPATIAL_CONVOLUTIONS_H_
+
+#define EIGEN_USE_CUSTOM_THREAD_POOL
+#define EIGEN_USE_THREADS
+
+// NOTE: Eigen is slightly different internally and externally. We need to
+// hack the unsupported/Eigen/CXX11/Tensor header instantiation macros at
+// specific places, so we need two copies of the hacked file, one for
+// internal and one for external.
+// If you have trouble, simply undef out the reducer macro, e.g.
+// TFLITE_REDUCE_INSTANTIATIONS_GOOGLE, but be aware that this will make
+// the binary much bigger!
+#define TFLITE_REDUCE_INSTANTIATIONS_OPEN_SOURCE
+#define Eigen EigenForTFLite
+#if defined(TFLITE_REDUCE_INSTANTIATIONS_GOOGLE)
+#include "tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_google.h"
+#elif defined(TFLITE_REDUCE_INSTANTIATIONS_OPEN_SOURCE)
+#include "tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_oss.h"
+#else
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#endif
+
+
+namespace Eigen {
+
+/** SpatialConvolution
+ * \ingroup CXX11_NeuralNetworks_Module
+ *
+ * \brief Applies a 2D convolution over a multichannel input image.
+ *
+ * The input parameter is expected to be a tensor with a rank of 3 or more
+ * (channels, height, width, and optionally others)
+ * The kernel parameter is expected to be a 4D tensor (filters, channels,
+ * kernel_height, kernel_width)
+ * The input and the kernel must both be in col-major layout. The result will
+ * also be in col-major layout.
+ *
+ * If col_in_stride, row_in_stride > 1, then applies convolution with holes
+ * (aka atrous convolution), sampling every col_in_stride, row_in_stride input
+ * pixels.
+ *
+ * The result can be assigned to a tensor of rank equal to the rank of the
+ * input. The dimensions of the result will be filters, height, width (and
+ * others if applicable).
+ *
+ * It is possible to swap the order of the width and height dimensions provided
+ * that the same order is used in the input, the kernel, and the output.
+ *
+ */
+template <typename Input, typename Kernel>
+EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE static const typename internal::conditional<
+ internal::traits<Input>::Layout == ColMajor,
+ TensorReshapingOp<
+ const DSizes<typename internal::traits<Input>::Index,
+ internal::traits<Input>::NumDimensions>,
+ const TensorContractionOp<
+ const array<IndexPair<typename internal::traits<Input>::Index>,
+ 1>,
+ const TensorReshapingOp<
+ const DSizes<typename internal::traits<Input>::Index, 2>,
+ const Kernel>,
+ const TensorReshapingOp<
+ const DSizes<typename internal::traits<Input>::Index, 2>,
+ const TensorImagePatchOp<Dynamic, Dynamic,
+ const Input> > > >,
+ TensorReshapingOp<
+ const DSizes<typename internal::traits<Input>::Index,
+ internal::traits<Input>::NumDimensions>,
+ const TensorContractionOp<
+ const array<IndexPair<typename internal::traits<Input>::Index>,
+ 1>,
+ const TensorReshapingOp<
+ const DSizes<typename internal::traits<Input>::Index, 2>,
+ const TensorImagePatchOp<Dynamic, Dynamic, const Input> >,
+ const TensorReshapingOp<
+ const DSizes<typename internal::traits<Input>::Index, 2>,
+ const Kernel> > > >::type
+ SpatialConvolution(const Input& input, const Kernel& kernel,
+ const DenseIndex row_stride = 1,
+ const DenseIndex col_stride = 1,
+ const PaddingType padding_type = PADDING_SAME,
+ const DenseIndex row_in_stride = 1,
+ const DenseIndex col_in_stride = 1) {
+ typedef typename internal::traits<Input>::Index TensorIndex;
+ TensorRef<Tensor<typename internal::traits<Input>::Scalar,
+ internal::traits<Input>::NumDimensions,
+ internal::traits<Input>::Layout, TensorIndex> >
+ in(input);
+ TensorRef<Tensor<typename internal::traits<Kernel>::Scalar,
+ internal::traits<Kernel>::NumDimensions,
+ internal::traits<Kernel>::Layout, TensorIndex> >
+ kern(kernel);
+
+ EIGEN_STATIC_ASSERT(
+ internal::traits<Input>::Layout == internal::traits<Kernel>::Layout,
+ YOU_MADE_A_PROGRAMMING_MISTAKE);
+ const bool isColMajor = (internal::traits<Input>::Layout == ColMajor);
+
+ const int NumDims = internal::traits<Input>::NumDimensions;
+
+ // Number of filters to apply. This is the same as the output depth of the
+ // result
+ const TensorIndex kernelFilters =
+ isColMajor ? kern.dimensions()[0] : kern.dimensions()[3];
+ // Number of channels. This is the same as the input depth.
+ const TensorIndex kernelChannels =
+ isColMajor ? kern.dimensions()[1] : kern.dimensions()[2];
+ const TensorIndex kernelRows =
+ isColMajor ? kern.dimensions()[2] : kern.dimensions()[1];
+ const TensorIndex kernelCols =
+ isColMajor ? kern.dimensions()[3] : kern.dimensions()[0];
+
+ const DenseIndex kernelRowsEff =
+ kernelRows + (kernelRows - 1) * (row_in_stride - 1);
+ const DenseIndex kernelColsEff =
+ kernelCols + (kernelCols - 1) * (col_in_stride - 1);
+
+ array<IndexPair<TensorIndex>, 1> contract_dims;
+ contract_dims[0] = IndexPair<TensorIndex>(1, 0);
+
+ const TensorIndex InputRows =
+ isColMajor ? in.dimension(1) : in.dimension(NumDims - 2);
+ const TensorIndex InputCols =
+ isColMajor ? in.dimension(2) : in.dimension(NumDims - 3);
+
+ TensorIndex out_height;
+ TensorIndex out_width;
+ switch (padding_type) {
+ case PADDING_VALID:
+ out_height = numext::ceil((InputRows - kernelRowsEff + 1.f) /
+ static_cast<float>(row_stride));
+ out_width = numext::ceil((InputCols - kernelColsEff + 1.f) /
+ static_cast<float>(col_stride));
+ break;
+ case PADDING_SAME:
+ out_height = numext::ceil(InputRows / static_cast<float>(row_stride));
+ out_width = numext::ceil(InputCols / static_cast<float>(col_stride));
+ break;
+ default:
+ // Initialize unused variables to avoid a compiler warning
+ out_height = 0;
+ out_width = 0;
+ eigen_assert(false && "unexpected padding");
+ }
+
+ // Molds the output of the patch extraction code into a 2d tensor:
+ // - the first dimension (dims[0]): the patch values to be multiplied with the
+ // kernels
+ // - the second dimension (dims[1]): everything else
+ DSizes<TensorIndex, 2> pre_contract_dims;
+ if (isColMajor) {
+ pre_contract_dims[0] = kernelChannels * kernelRows * kernelCols;
+ pre_contract_dims[1] = out_height * out_width;
+ for (int i = 3; i < NumDims; ++i) {
+ pre_contract_dims[1] *= in.dimension(i);
+ }
+ } else {
+ pre_contract_dims[1] = kernelChannels * kernelRows * kernelCols;
+ pre_contract_dims[0] = out_height * out_width;
+ for (int i = 0; i < NumDims - 3; ++i) {
+ pre_contract_dims[0] *= in.dimension(i);
+ }
+ }
+
+  // Molds the output of the contraction into the shape expected by the user
+ // (assuming this is ColMajor):
+ // - 1st dim: kernel filters
+ // - 2nd dim: output height
+ // - 3rd dim: output width
+ // - 4th dim and beyond: everything else including batch size
+ DSizes<TensorIndex, NumDims> post_contract_dims;
+ if (isColMajor) {
+ post_contract_dims[0] = kernelFilters;
+ post_contract_dims[1] = out_height;
+ post_contract_dims[2] = out_width;
+ for (int i = 3; i < NumDims; ++i) {
+ post_contract_dims[i] = in.dimension(i);
+ }
+ } else {
+ post_contract_dims[NumDims - 1] = kernelFilters;
+ post_contract_dims[NumDims - 2] = out_height;
+ post_contract_dims[NumDims - 3] = out_width;
+ for (int i = 0; i < NumDims - 3; ++i) {
+ post_contract_dims[i] = in.dimension(i);
+ }
+ }
+
+ DSizes<TensorIndex, 2> kernel_dims;
+ if (isColMajor) {
+ kernel_dims[0] = kernelFilters;
+ kernel_dims[1] = kernelChannels * kernelRows * kernelCols;
+ } else {
+ kernel_dims[0] = kernelChannels * kernelRows * kernelCols;
+ kernel_dims[1] = kernelFilters;
+ }
+ // TODO(yangke): choose() is defined in TensorContraction.h -- consider
+ // moving it to somewhere more "common".
+ return
+ input
+ .extract_image_patches(kernelRows, kernelCols, row_stride, col_stride,
+ row_in_stride, col_in_stride, padding_type)
+ .reshape(pre_contract_dims)
+ .contract(kernel.reshape(kernel_dims), contract_dims)
+ .reshape(post_contract_dims);
+}
+
+} // end namespace Eigen
+
+// clang-format on
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_SPATIAL_CONVOLUTIONS_H_
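
A minimal usage sketch of the SpatialConvolution expression defined above, with col-major tensors laid out as the doc comment describes. The include path and shapes are illustrative assumptions (and note that this header aliases the Eigen namespace to EigenForTFLite); it is not part of the diff.

#include "tensorflow/contrib/lite/kernels/internal/optimized/eigen_spatial_convolutions.h"

int main() {
  // Col-major input: (channels, rows, cols, batch); kernel: (filters,
  // channels, kernel_rows, kernel_cols).
  Eigen::Tensor<float, 4> input(3, 32, 32, 1);
  Eigen::Tensor<float, 4> kernel(8, 3, 3, 3);
  input.setRandom();
  kernel.setRandom();
  // SAME padding with unit strides preserves the spatial size, so the
  // result has shape (filters, rows, cols, batch) = (8, 32, 32, 1).
  Eigen::Tensor<float, 4> output = Eigen::SpatialConvolution(
      input, kernel, /*row_stride=*/1, /*col_stride=*/1, Eigen::PADDING_SAME);
  return output.dimension(0) == 8 ? 0 : 1;
}
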
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_google.h b/tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_google.h
new file mode 100644
index 0000000000..7f78f69360
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_google.h
@@ -0,0 +1,143 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_TENSOR_REDUCED_INSTANTIATIONS_GOOGLE_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_TENSOR_REDUCED_INSTANTIATIONS_GOOGLE_H_
+
+#define EIGEN_USE_CUSTOM_THREAD_POOL
+#define EIGEN_USE_THREADS
+
+// clang-format off
+
+#include <stdint.h>
+
+#include <cstddef>
+#include <cstring>
+#include <cmath>
+#include <random>
+#include <atomic>
+#include <condition_variable> // NOLINT(build/c++11)
+#include <mutex> // NOLINT(build/c++11)
+#include <thread> // NOLINT(build/c++11)
+#include <functional>
+
+#ifdef _WIN32
+#include <winbase.h>
+#elif defined(__APPLE__)
+#include <mach/mach_time.h>
+#else
+#include <time.h>
+#endif
+
+
+// Because some programs may link Eigen in through other frameworks with
+// different flags, we can run into multiple definition issues if we don't have
+// a private namespace for our versions. This is a nasty hack, but a similar
+// approach is used elsewhere to handle the problem, so it should be stable.
+#define Eigen EigenForTFLite
+
+#include "Eigen/src/Core/util/StaticAssert.h"
+#include "unsupported/Eigen/CXX11/Core"
+#include "unsupported/Eigen/SpecialFunctions"
+
+#include "Eigen/src/Core/util/DisableStupidWarnings.h"
+
+#include "Eigen/Core"
+
+// Beware: the order of the include matters to some compilers. For example
+// TensorIndexList.h should be included before TensorDimensions.h in order to
+// use index lists to encode tensor dimensions when compiling with llvm.
+// We're defining this ourselves rather than using the Eigen Tensor header file
+// so that we can alter the macro definition of TENSOR_CONTRACTION_DISPATCH to
+// reduce binary size.
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorMeta.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorCostModel.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/ThreadPoolInterface.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorDeviceType.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorNonBlockingThreadPool.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorIndexList.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorDimensionList.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorDimensions.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorInitializer.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorTraits.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorRandom.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorUInt128.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorIntDiv.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorBlock.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorGlobalFunctions.h"
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorStats.h"
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h"
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorReduction.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorArgMax.h"
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorConcatenation.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorContractionMappers.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorContractionBlocking.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h"
+#undef TENSOR_CONTRACTION_DISPATCH
+#define TENSOR_CONTRACTION_DISPATCH(METHOD, ALIGNMENT, ARGS) \
+ if (this->m_lhs_inner_dim_contiguous && \
+ this->m_rhs_inner_dim_contiguous && \
+ !this->m_rhs_inner_dim_reordered) { \
+ METHOD<true, true, false, ALIGNMENT> ARGS; \
+ } else { \
+ eigen_assert(false && "Unsupported contraction formats"); \
+ }
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorContractionCuda.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorConversion.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorConvolution.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorFFT.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorPatch.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorImagePatch.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorVolumePatch.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorChipping.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorInflation.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorLayoutSwap.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorMorphing.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorPadding.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorReverse.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorShuffling.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorStriding.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorCustomOp.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorEvalTo.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorForcedEval.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorGenerator.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorAssign.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorScan.h"
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorDevice.h"
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorStorage.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/Tensor.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorFixedSize.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorMap.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorRef.h"
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorReductionCuda.h"
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorIO.h"
+
+#include "Eigen/src/Core/util/ReenableStupidWarnings.h"
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_TENSOR_REDUCED_INSTANTIATIONS_GOOGLE_H_
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_oss.h b/tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_oss.h
new file mode 100644
index 0000000000..1d5c316194
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_oss.h
@@ -0,0 +1,167 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// This is essentially unsupported/Eigen/CXX11/Tensor.
+// TODO(petewarden) - move this to a common location in Eigen itself.
+
+// clang-format off
+
+
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_TENSOR_REDUCED_INSTANTIATIONS_OSS_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_TENSOR_REDUCED_INSTANTIATIONS_OSS_H_
+
+
+#include "Eigen/Core"
+
+#if defined(EIGEN_USE_SYCL)
+#undef min
+#undef max
+#undef isnan
+#undef isinf
+#undef isfinite
+#include <CL/sycl.hpp>
+#include <iostream>
+#include <map>
+#include <memory>
+#include <utility>
+#endif
+#include <cmath>
+#include <cstddef>
+#include <cstring>
+
+
+
+
+
+#ifdef _WIN32
+typedef __int16 int16_t;
+typedef unsigned __int16 uint16_t;
+typedef __int32 int32_t;
+typedef unsigned __int32 uint32_t;
+typedef __int64 int64_t;
+typedef unsigned __int64 uint64_t;
+#include <windows.h>
+#else
+#include <stdint.h>
+#include <unistd.h>
+#endif
+
+#if __cplusplus > 199711 || EIGEN_COMP_MSVC >= 1900
+#include <random>
+#endif
+
+#ifdef _WIN32
+#include <windows.h>
+#elif defined(__APPLE__)
+#include <mach/mach_time.h>
+#else
+#include <time.h>
+#endif
+
+// #if defined(EIGEN_USE_LIBXSMM)
+// #include "libxsmm.h"
+// #endif
+
+#ifdef EIGEN_USE_THREADS
+#include "unsupported/Eigen/CXX11/ThreadPool"
+#endif
+
+
+#include "Eigen/src/Core/util/DisableStupidWarnings.h"
+
+#include "unsupported/Eigen/SpecialFunctions"
+#include "unsupported/Eigen/CXX11/src/util/CXX11Meta.h"
+#include "unsupported/Eigen/CXX11/src/util/MaxSizeVector.h"
+
+
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorMacros.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorMeta.h"
+
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorCostModel.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorDeviceDefault.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorDeviceThreadPool.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorDeviceCuda.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorDeviceSycl.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorIndexList.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorDimensionList.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorDimensions.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorInitializer.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorTraits.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorRandom.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorUInt128.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorIntDiv.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorGlobalFunctions.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorBase.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorReduction.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorReductionCuda.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorArgMax.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorConcatenation.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorContractionMapper.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorContractionBlocking.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h"
+
+#undef TENSOR_CONTRACTION_DISPATCH
+#define TENSOR_CONTRACTION_DISPATCH(METHOD, ALIGNMENT, ARGS) \
+ if (this->m_lhs_inner_dim_contiguous && \
+ this->m_rhs_inner_dim_contiguous && \
+ !this->m_rhs_inner_dim_reordered) { \
+ METHOD<true, true, false, ALIGNMENT> ARGS; \
+ } else { \
+ eigen_assert(false && "Unsupported contraction formats"); \
+ }
+
+
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorContractionCuda.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorConversion.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorConvolution.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorFFT.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorPatch.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorImagePatch.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorVolumePatch.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorChipping.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorInflation.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorLayoutSwap.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorMorphing.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorPadding.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorReverse.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorShuffling.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorStriding.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorCustomOp.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorEvalTo.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorForcedEval.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorGenerator.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorAssign.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorScan.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorTrace.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorSycl.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorDevice.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorStorage.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/Tensor.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorFixedSize.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorMap.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorRef.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorIO.h"
+
+#include "Eigen/src/Core/util/ReenableStupidWarnings.h"
+
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_TENSOR_REDUCED_INSTANTIATIONS_OSS_H_
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h b/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h
new file mode 100644
index 0000000000..b3615f4658
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h
@@ -0,0 +1,195 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_MULTITHREAD_CONV
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_MULTITHREAD_CONV
+
+#include <assert.h>
+#include <stdint.h>
+#include <sys/types.h>
+#include <algorithm>
+#include <cmath>
+#include <limits>
+#include <memory>
+#include <tuple>
+#include <type_traits>
+
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/kernels/internal/common.h"
+#include "tensorflow/contrib/lite/kernels/internal/optimized/eigen_spatial_convolutions.h"
+#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/types.h"
+
+namespace tflite {
+namespace multithreaded_ops {
+
+class EigenThreadPoolWrapper : public Eigen::ThreadPoolInterface {
+ public:
+ explicit EigenThreadPoolWrapper(Eigen::ThreadPool* pool) : pool_(pool) {}
+ ~EigenThreadPoolWrapper() override {}
+
+ void Schedule(std::function<void()> fn) override {
+ pool_->Schedule(std::move(fn));
+ }
+ int NumThreads() const override { return pool_->NumThreads(); }
+ int CurrentThreadId() const override { return pool_->CurrentThreadId(); }
+
+ private:
+ Eigen::ThreadPool* pool_ = nullptr;
+};
+
+// We have a single global threadpool for all convolution operations. This
+// means that inferences started from different threads may block each other,
+// but since the operations would be competing for the same underlying CPU
+// cores anyway, it shouldn't affect overall performance.
+const Eigen::ThreadPoolDevice& GetThreadPoolDevice() {
+ const int thread_count = 4;
+ static Eigen::ThreadPool* tp = new Eigen::ThreadPool(thread_count);
+ static EigenThreadPoolWrapper* thread_pool_wrapper =
+ new EigenThreadPoolWrapper(tp);
+ static Eigen::ThreadPoolDevice* device =
+ new Eigen::ThreadPoolDevice(thread_pool_wrapper, thread_count);
+ return *device;
+}
+
+// Shorthands for the types we need when interfacing with the EigenTensor
+// library.
+typedef Eigen::TensorMap<
+ Eigen::Tensor<float, 2, Eigen::RowMajor, Eigen::DenseIndex>, Eigen::Aligned>
+ EigenMatrix;
+typedef Eigen::TensorMap<
+ Eigen::Tensor<const float, 2, Eigen::RowMajor, Eigen::DenseIndex>,
+ Eigen::Aligned>
+ ConstEigenMatrix;
+
+typedef Eigen::TensorMap<
+ Eigen::Tensor<float, 4, Eigen::RowMajor, Eigen::DenseIndex>, Eigen::Aligned>
+ EigenTensor;
+typedef Eigen::TensorMap<
+ Eigen::Tensor<const float, 4, Eigen::RowMajor, Eigen::DenseIndex>,
+ Eigen::Aligned>
+ ConstEigenTensor;
+
+// Utility functions we need for the EigenTensor API.
+template <typename Device, typename T>
+struct MatMulConvFunctor {
+ // Computes on device "d": out = in0 * in1, where * is matrix
+ // multiplication.
+ void operator()(
+ const Device& d, EigenMatrix out, ConstEigenMatrix in0,
+ ConstEigenMatrix in1,
+ const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair) {
+ out.device(d) = in0.contract(in1, dim_pair);
+ }
+};
+
+template <class T>
+class EigenTensorConvFunctor {
+ private:
+ Eigen::PaddingType TfLitePadding2EigenPadding(TfLitePadding padding) {
+ switch (padding) {
+ case kTfLitePaddingValid:
+ return Eigen::PADDING_VALID;
+ case kTfLitePaddingSame:
+ return Eigen::PADDING_SAME;
+ case kTfLitePaddingUnknown:
+ assert(false); // should never get here.
+ return Eigen::PADDING_VALID;
+ }
+ return Eigen::PADDING_SAME; // Prevent compiler warning about missing
+ // return
+ }
+
+ public:
+ void operator()(const T* input_data, T* im2col_buffer, int input_batches,
+ int input_height, int input_width, int input_depth,
+ const T* filter_data, int filter_height, int filter_width,
+ int filter_count, int stride_rows, int stride_cols,
+ int pad_width, int pad_height, TfLitePadding padding,
+ T* output_data, int output_height, int output_width) {
+ const Eigen::ThreadPoolDevice& device = GetThreadPoolDevice();
+
+ const bool is_1x1_kernel = (filter_height == 1 && filter_width == 1 &&
+ stride_rows == 1 && stride_cols == 1);
+ if (is_1x1_kernel) {
+ // For 1x1 kernel, the 2D convolution is reduced to matrix
+ // multiplication.
+ const int conv_width = output_height * output_width;
+ Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair;
+ dim_pair[0] = Eigen::IndexPair<Eigen::DenseIndex>(1, 0);
+ EigenMatrix output(output_data, conv_width, filter_count);
+ ConstEigenMatrix input(input_data, conv_width, input_depth);
+ ConstEigenMatrix filter(filter_data, input_depth, filter_count);
+ MatMulConvFunctor<Eigen::ThreadPoolDevice, T>()(device, output, input,
+ filter, dim_pair);
+ } else if (filter_height == input_height && filter_width == input_width &&
+ pad_width == 0 && pad_height == 0) {
+ // If the input data and filter have the same height/width,
+ // the 2D convolution is reduced to matrix multiplication.
+ const int k = // Length of reduction dimension.
+ filter_width * filter_height * input_depth;
+ Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair;
+ dim_pair[0] = Eigen::IndexPair<Eigen::DenseIndex>(1, 0);
+ EigenMatrix output(output_data, 1, filter_count);
+ ConstEigenMatrix input(input_data, 1, k);
+ ConstEigenMatrix filter(filter_data, k, filter_count);
+ MatMulConvFunctor<Eigen::ThreadPoolDevice, T>()(device, output, input,
+ filter, dim_pair);
+ } else {
+ EigenTensor output(output_data, input_batches, output_height,
+ output_width, filter_count);
+ ConstEigenTensor input(input_data, input_batches, input_height,
+ input_width, input_depth);
+ ConstEigenTensor filter(filter_data, filter_height, filter_width,
+ input_depth, filter_count);
+ output.device(device) =
+ Eigen::SpatialConvolution(input, filter, stride_cols, stride_rows,
+ TfLitePadding2EigenPadding(padding));
+ }
+ }
+};
+
+inline void Conv(const float* input_data, const Dims<4>& input_dims,
+ const float* filter_data, const Dims<4>& filter_dims,
+ const float* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, TfLitePadding padding,
+ float output_activation_min, float output_activation_max,
+ float* output_data, const Dims<4>& output_dims,
+ float* im2col_data, const Dims<4>& im2col_dims) {
+ const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
+ const int input_depth = MatchingArraySize(input_dims, 0, filter_dims, 0);
+ const int output_depth = MatchingArraySize(filter_dims, 3, output_dims, 0);
+ const int input_height = ArraySize(input_dims, 2);
+ const int input_width = ArraySize(input_dims, 1);
+ const int filter_height = ArraySize(filter_dims, 2);
+ const int filter_width = ArraySize(filter_dims, 1);
+ const int output_height = ArraySize(output_dims, 2);
+ const int output_width = ArraySize(output_dims, 1);
+ EigenTensorConvFunctor<float> conv_functor;
+ conv_functor(input_data, im2col_data, batches, input_height, input_width,
+ input_depth, filter_data, filter_height, filter_width,
+ output_depth, stride_height, stride_width, pad_height, pad_width,
+ padding, output_data, output_height, output_width);
+
+ optimized_ops::AddBiasAndEvalActivationFunction(
+ bias_data, bias_dims, output_data, output_dims, output_activation_min,
+ output_activation_max);
+}
+
+} // namespace multithreaded_ops
+} // namespace tflite
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_MULTITHREAD_CONV
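
The 1x1-kernel fast path in EigenTensorConvFunctor above relies on the fact that a 1x1 convolution with unit strides is exactly a matrix multiplication over the channel dimension: output[y, x, f] = sum_c input[y, x, c] * filter[c, f]. Below is a small self-contained check of that equivalence written with plain loops; the shapes are arbitrary illustrative choices and the code is not part of the diff.

#include <cassert>
#include <cstddef>
#include <vector>

int main() {
  const int height = 4, width = 5, input_depth = 3, filter_count = 2;
  std::vector<float> input(height * width * input_depth);
  std::vector<float> filter(input_depth * filter_count);
  for (std::size_t i = 0; i < input.size(); ++i) input[i] = 0.1f * i;
  for (std::size_t i = 0; i < filter.size(); ++i) filter[i] = 0.2f * i;

  // Direct 1x1 convolution: every output pixel mixes channels independently.
  std::vector<float> conv_out(height * width * filter_count, 0.f);
  for (int y = 0; y < height; ++y)
    for (int x = 0; x < width; ++x)
      for (int f = 0; f < filter_count; ++f)
        for (int c = 0; c < input_depth; ++c)
          conv_out[(y * width + x) * filter_count + f] +=
              input[(y * width + x) * input_depth + c] *
              filter[c * filter_count + f];

  // Equivalent matrix product: (height*width, input_depth) times
  // (input_depth, filter_count), as in the fast path above.
  const int conv_width = height * width;
  std::vector<float> matmul_out(conv_width * filter_count, 0.f);
  for (int row = 0; row < conv_width; ++row)
    for (int f = 0; f < filter_count; ++f)
      for (int c = 0; c < input_depth; ++c)
        matmul_out[row * filter_count + f] +=
            input[row * input_depth + c] * filter[c * filter_count + f];

  for (std::size_t i = 0; i < conv_out.size(); ++i)
    assert(conv_out[i] == matmul_out[i]);
  return 0;
}
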
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc
new file mode 100644
index 0000000000..bf0bdfb1fb
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc
@@ -0,0 +1,337 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <string.h>
+
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/kernels/activation_functor.h"
+#include "tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h"
+
+#ifdef USE_NEON
+
+#include <arm_neon.h>
+#define kFloatWeightsPerNeonLane 4
+
+namespace tflite {
+namespace tensor_utils {
+
+void NeonMatrixBatchVectorMultiplyAccumulate(const float* matrix, int m_rows,
+ int m_cols, const float* vector,
+ int n_batch, float* result,
+ int result_stride) {
+  // If m_cols is not divisible by kFloatWeightsPerNeonLane, we cannot use the
+  // main vectorized loop, and we need to process sequentially. postamble_start
+  // shows the start index where this should happen.
+ const int postamble_start =
+ m_cols - (m_cols & (kFloatWeightsPerNeonLane - 1));
+
+ // The arrays used to cache the vector.
+ float32x4_t* vector_cache_float32x4 =
+ new float32x4_t[(m_cols / kFloatWeightsPerNeonLane) *
+ sizeof(float32x4_t)];
+ const int kUnrollSize = 2;
+ for (int b = 0; b < n_batch; b++) {
+ float* result_in_batch = result + b * m_rows * result_stride;
+ const float* vector_in_batch = vector + b * m_cols;
+
+ const float* matrix_ptr0 = matrix;
+ // If there is only 1 row, we don't want to assign an illegal pointer.
+ const float* matrix_ptr1 = nullptr;
+ if (m_rows > 1) {
+ matrix_ptr1 = matrix + m_cols;
+ }
+
+    // Cache the vector.
+ for (int c = 0; c < postamble_start; c += kFloatWeightsPerNeonLane) {
+ vector_cache_float32x4[c >> 2] = vld1q_f32(vector_in_batch + c);
+ }
+
+ // Main matrix by vector multiplication loop, which handles two rows of
+ // matrix by vector multiplication.
+ for (int r = 0; r < (m_rows & ~(kUnrollSize - 1)); r += kUnrollSize) {
+ float32x4_t acc0_32x4 = vmovq_n_f32(0.0);
+ float32x4_t acc1_32x4 = vmovq_n_f32(0.0);
+ for (int c = 0; c < postamble_start; c += kFloatWeightsPerNeonLane) {
+ float32x4_t temp = vector_cache_float32x4[c >> 2];
+ // Load 4 float values from vector1 and vector2 and accumulator.
+ float32x4_t v0_f32x4 = vld1q_f32(matrix_ptr0 + c);
+ float32x4_t v1_f32x4 = vld1q_f32(matrix_ptr1 + c);
+ // Vector multiply-accumulate 4 float
+ acc0_32x4 = vmlaq_f32(acc0_32x4, v0_f32x4, temp);
+ acc1_32x4 = vmlaq_f32(acc1_32x4, v1_f32x4, temp);
+ }
+ // Add the 4 intermediate sum values to get the final dot-prod value for
+ // this column.
+ *result_in_batch +=
+ (vgetq_lane_f32(acc0_32x4, 0) + vgetq_lane_f32(acc0_32x4, 1) +
+ vgetq_lane_f32(acc0_32x4, 2) + vgetq_lane_f32(acc0_32x4, 3));
+ *(result_in_batch + result_stride) +=
+ (vgetq_lane_f32(acc1_32x4, 0) + vgetq_lane_f32(acc1_32x4, 1) +
+ vgetq_lane_f32(acc1_32x4, 2) + vgetq_lane_f32(acc1_32x4, 3));
+ for (int c = postamble_start; c < m_cols; c++) {
+ *result_in_batch += matrix_ptr0[c] * vector_in_batch[c];
+ *(result_in_batch + result_stride) +=
+ matrix_ptr1[c] * vector_in_batch[c];
+ }
+ matrix_ptr0 += kUnrollSize * m_cols;
+ matrix_ptr1 += kUnrollSize * m_cols;
+ result_in_batch += kUnrollSize * result_stride;
+ }
+ for (int r = (m_rows & ~(kUnrollSize - 1)); r < m_rows; r++) {
+ float32x4_t acc0_32x4 = vmovq_n_f32(0.0);
+ for (int c = 0; c < postamble_start; c += kFloatWeightsPerNeonLane) {
+ float32x4_t temp = vector_cache_float32x4[c >> 2];
+ // Load 4 float values from vector1 and vector2 and accumulator.
+ float32x4_t v0_f32x4 = vld1q_f32(matrix_ptr0 + c);
+ // Vector multiply-accumulate 4 float
+ acc0_32x4 = vmlaq_f32(acc0_32x4, v0_f32x4, temp);
+ }
+ // Add the 4 intermediate sum values to get the final dot-prod value for
+ // this column.
+ *result_in_batch +=
+ (vgetq_lane_f32(acc0_32x4, 0) + vgetq_lane_f32(acc0_32x4, 1) +
+ vgetq_lane_f32(acc0_32x4, 2) + vgetq_lane_f32(acc0_32x4, 3));
+ for (int c = postamble_start; c < m_cols; c++) {
+ *result_in_batch += matrix_ptr0[c] * vector_in_batch[c];
+ }
+ matrix_ptr0 += m_cols;
+ result_in_batch += result_stride;
+ }
+ }
+ delete[] vector_cache_float32x4;
+}
+
+void NeonVectorVectorCwiseProduct(const float* vector1, const float* vector2,
+ int v_size, float* result) {
+ // If v_size is not divisible by kWeightsPerNeonLane, we cannot use the main
+ // vectorized loop, and we need to process sequentially. postamble_start shows
+ // the start index where this should happen.
+ const int postamble_start =
+ v_size - (v_size & (kFloatWeightsPerNeonLane - 1));
+ for (int v = 0; v < postamble_start; v += kFloatWeightsPerNeonLane) {
+ // Load 4 float values from vector1 and vector2.
+ float32x4_t v1_f32x4 = vld1q_f32(vector1 + v);
+ float32x4_t v2_f32x4 = vld1q_f32(vector2 + v);
+ // Vector multiply 4 float
+ float32x4_t mul_32x4 = vmulq_f32(v1_f32x4, v2_f32x4);
+ // Save to result array.
+ vst1q_f32(&result[v], mul_32x4);
+ }
+ for (int v = postamble_start; v < v_size; v++) {
+ result[v] = vector1[v] * vector2[v];
+ }
+}
+
+void NeonVectorVectorCwiseProductAccumulate(const float* vector1,
+ const float* vector2, int v_size,
+ float* result) {
+ // If v_size is not divisible by kWeightsPerNeonLane, we cannot use the main
+ // vectorized loop, and we need to process sequentially. postamble_start shows
+ // the start index where this should happen.
+ const int postamble_start =
+ v_size - (v_size & (kFloatWeightsPerNeonLane - 1));
+ for (int v = 0; v < postamble_start; v += kFloatWeightsPerNeonLane) {
+ // Load 4 float values from vector1 and vector2 and accumulator.
+ float32x4_t v1_f32x4 = vld1q_f32(vector1 + v);
+ float32x4_t v2_f32x4 = vld1q_f32(vector2 + v);
+ float32x4_t acc_32x4 = vld1q_f32(result + v);
+ // Vector multiply-accumulate 4 float
+ acc_32x4 = vmlaq_f32(acc_32x4, v1_f32x4, v2_f32x4);
+ // Save to result array.
+ vst1q_f32(&result[v], acc_32x4);
+ }
+ for (int v = postamble_start; v < v_size; v++) {
+ result[v] += vector1[v] * vector2[v];
+ }
+}
+
+void NeonVectorBatchVectorCwiseProductAccumulate(const float* vector,
+ int v_size,
+ const float* batch_vector,
+ int n_batch, float* result) {
+ // If v_size is not divisible by kWeightsPerNeonLane, we cannot use the main
+ // vectorized loop, and we need to process sequentially. postamble_start shows
+ // the start index where this should happen.
+ const int postamble_start =
+ v_size - (v_size & (kFloatWeightsPerNeonLane - 1));
+
+ // The arrays used to cache the vector.
+ float32x4_t* vector_cache_float32x4 =
+ new float32x4_t[(v_size / kFloatWeightsPerNeonLane) *
+ sizeof(float32x4_t)];
+ for (int v = 0; v < postamble_start; v += kFloatWeightsPerNeonLane) {
+ vector_cache_float32x4[v >> 2] = vld1q_f32(vector + v);
+ }
+
+ float* result_ptr = result;
+ const float* batch_vector_ptr = batch_vector;
+ for (int b = 0; b < n_batch; b++) {
+ for (int v = 0; v < postamble_start; v += kFloatWeightsPerNeonLane) {
+ // Load from memory to vectors.
+ float32x4_t result_f32x4 = vld1q_f32(result_ptr + v);
+ float32x4_t batch_vector_f32x4 = vld1q_f32(batch_vector_ptr + v);
+ // Multiply-accumulate.
+ result_f32x4 = vmlaq_f32(result_f32x4, batch_vector_f32x4,
+ vector_cache_float32x4[v >> 2]);
+ // Store.
+ vst1q_f32(result_ptr + v, result_f32x4);
+ }
+ // Postamble loop
+ for (int v = postamble_start; v < v_size; v++) {
+ result_ptr[v] += vector[v] * batch_vector_ptr[v];
+ }
+ // Update the pointers.
+ result_ptr += v_size;
+ batch_vector_ptr += v_size;
+ }
+ delete[] vector_cache_float32x4;
+}
+
+void NeonSub1Vector(const float* vector, int v_size, float* result) {
+ // If v_size is not divisible by kWeightsPerNeonLane, we cannot use the main
+ // vectorized loop, and we need to process sequentially. postamble_start shows
+ // the start index where this should happen.
+ const int postamble_start =
+ v_size - (v_size & (kFloatWeightsPerNeonLane - 1));
+
+ float32x4_t one_f32x4 = vmovq_n_f32(1.0);
+ for (int v = 0; v < postamble_start; v += kFloatWeightsPerNeonLane) {
+ // Load 4 float values from the current pointers of the input column and
+ // subtract from 1.
+ float32x4_t v_f32x4 = vld1q_f32(vector + v);
+ float32x4_t result_f32x4 = vsubq_f32(one_f32x4, v_f32x4);
+ // Save to output.
+ vst1q_f32(result + v, result_f32x4);
+ }
+ for (int v = postamble_start; v < v_size; v++) {
+ result[v] = 1.0f - vector[v];
+ }
+}
+
+void NeonClipVector(const float* vector, int v_size, float abs_limit,
+ float* result) {
+ // If v_size is not divisible by kWeightsPerNeonLane, we cannot use the main
+ // vectorized loop, and we need to process sequentially. postamble_start shows
+ // the start index where this should happen.
+ const int postamble_start =
+ v_size - (v_size & (kFloatWeightsPerNeonLane - 1));
+
+ // Replicate abs_limit and -abs_limit in two vectors.
+ const float32x4_t abs_limit_f32x4 = vmovq_n_f32(abs_limit);
+ const float32x4_t neg_abs_limit_f32x4 = vmovq_n_f32(-abs_limit);
+
+ for (int v = 0; v < postamble_start; v += kFloatWeightsPerNeonLane) {
+ // Load from memory to vector.
+ float32x4_t v_f32x4 = vld1q_f32(vector + v);
+ // Clip between abs_limit and -abs_limit.
+ float32x4_t result_f32x4 = vminq_f32(abs_limit_f32x4, v_f32x4);
+ result_f32x4 = vmaxq_f32(neg_abs_limit_f32x4, result_f32x4);
+ // Save to output.
+ vst1q_f32(result + v, result_f32x4);
+ }
+ // Postamble loop.
+ for (int v = postamble_start; v < v_size; v++) {
+ result[v] = (abs_limit < vector[v]) ? abs_limit : vector[v];
+ result[v] = (-abs_limit > result[v]) ? -abs_limit : result[v];
+ }
+}
+
+float NeonVectorVectorDotProduct(const float* vector1, const float* vector2,
+ int v_size) {
+ // If v_size is not divisible by kWeightsPerNeonLane, we cannot use the main
+ // vectorized loop, and we need to process sequentially. postamble_start shows
+ // the start index where this should happen.
+ const int postamble_start =
+ v_size - (v_size & (kFloatWeightsPerNeonLane - 1));
+ float32x4_t acc_32x4 = vmovq_n_f32(0.0);
+ for (int v = 0; v < postamble_start; v += kFloatWeightsPerNeonLane) {
+ // Load 4 float values from vector1 and vector2 and accumulator.
+ float32x4_t v1_f32x4 = vld1q_f32(vector1 + v);
+ float32x4_t v2_f32x4 = vld1q_f32(vector2 + v);
+ // Vector multiply-accumulate 4 float
+ acc_32x4 = vmlaq_f32(acc_32x4, v1_f32x4, v2_f32x4);
+ }
+
+ float result = (vgetq_lane_f32(acc_32x4, 0) + vgetq_lane_f32(acc_32x4, 1) +
+ vgetq_lane_f32(acc_32x4, 2) + vgetq_lane_f32(acc_32x4, 3));
+ // Postamble loop.
+ for (int v = postamble_start; v < v_size; v++) {
+ result += vector1[v] * vector2[v];
+ }
+ return result;
+}
+
+void NeonBatchVectorBatchVectorDotProduct(const float* vector1,
+ const float* vector2, int v_size,
+ int n_batch, float* result,
+ int result_stride) {
+ float* result_ptr = result;
+ const float* vector1_ptr = vector1;
+ const float* vector2_ptr = vector2;
+ for (int b = 0; b < n_batch; b++) {
+ *result_ptr = NeonVectorVectorDotProduct(vector1_ptr, vector2_ptr, v_size);
+ vector1_ptr += v_size;
+ vector2_ptr += v_size;
+ result_ptr += result_stride;
+ }
+}
+
+void NeonReductionSumVector(const float* input_vector, float* output_vector,
+ int output_size, int reduction_size) {
+ const float* input_vector_ptr = input_vector;
+ for (int o = 0; o < output_size; o++) {
+ // If reduction_size is not divisible by kWeightsPerNeonLane, we cannot use
+ // the main vectorized loop, and we need to process sequentially.
+ // postamble_start shows the start index where this should happen.
+ const int postamble_start =
+ reduction_size - (reduction_size & (kFloatWeightsPerNeonLane - 1));
+ float32x4_t sum_f32x4 = vmovq_n_f32(0.0);
+ for (int r = 0; r < postamble_start; r += kFloatWeightsPerNeonLane) {
+ float32x4_t v1_f32x4 = vld1q_f32(input_vector_ptr + r);
+ sum_f32x4 = vaddq_f32(sum_f32x4, v1_f32x4);
+ }
+ output_vector[o] +=
+ (vgetq_lane_f32(sum_f32x4, 0) + vgetq_lane_f32(sum_f32x4, 1) +
+ vgetq_lane_f32(sum_f32x4, 2) + vgetq_lane_f32(sum_f32x4, 3));
+ input_vector_ptr += postamble_start;
+
+ // Postamble loop.
+ for (int r = postamble_start; r < reduction_size; r++) {
+ output_vector[o] += *input_vector_ptr++;
+ }
+ }
+}
+
+void NeonVectorShiftLeft(float* vector, int v_size, float shift_value) {
+  // This variable keeps track of the last index read by the upcoming vector
+  // load, to make sure we do not read past the end of the vector.
+ int last_index_copy = kFloatWeightsPerNeonLane;
+ int current_index_copy = 0;
+ while (last_index_copy < v_size) {
+ float32x4_t v_f32x4 = vld1q_f32(vector + current_index_copy + 1);
+ vst1q_f32(vector + current_index_copy, v_f32x4);
+ current_index_copy += kFloatWeightsPerNeonLane;
+ last_index_copy += kFloatWeightsPerNeonLane;
+ }
+ // Postamble loop.
+ for (int i = current_index_copy; i < v_size - 1; i++) {
+ vector[i] = vector[i + 1];
+ }
+ vector[v_size - 1] = shift_value;
+}
+
+} // namespace tensor_utils
+} // namespace tflite
+
+#endif // USE_NEON
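
All of the kernels in this file share the main-loop/postamble structure described in their comments: round the length down to a multiple of the NEON lane width (4 floats), handle full lanes in the vectorized loop, then finish the remainder element by element. A scalar sketch of that pattern, for illustration only and not part of the diff:

constexpr int kLane = 4;  // Mirrors kFloatWeightsPerNeonLane.

float DotProductWithPostamble(const float* a, const float* b, int size) {
  // Round 'size' down to a multiple of kLane; since kLane is a power of two,
  // this is what size - (size & (kLane - 1)) computes in the kernels above.
  const int postamble_start = size - (size & (kLane - 1));
  float acc = 0.f;
  // Stand-in for the vectorized main loop: kLane values per iteration.
  for (int i = 0; i < postamble_start; i += kLane) {
    for (int lane = 0; lane < kLane; ++lane) {
      acc += a[i + lane] * b[i + lane];
    }
  }
  // Postamble: the remaining 0..kLane-1 values, processed sequentially.
  for (int i = postamble_start; i < size; ++i) {
    acc += a[i] * b[i];
  }
  return acc;
}
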
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h
new file mode 100644
index 0000000000..3a4af87304
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h
@@ -0,0 +1,113 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_NEON_TENSOR_UTILS_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_NEON_TENSOR_UTILS_H_
+
+// TODO(ghodrat): Remove this header file and the dependency to internal data
+// structure.
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/kernels/internal/optimized/cpu_check.h"
+#include "tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h"
+
+namespace tflite {
+namespace tensor_utils {
+
+void MatrixBatchVectorMultiplyAccumulate(const float* matrix, int m_rows,
+ int m_cols, const float* vector,
+ int n_batch, float* result,
+ int result_stride) {
+ NEON_OR_PORTABLE(MatrixBatchVectorMultiplyAccumulate, matrix, m_rows, m_cols,
+ vector, n_batch, result, result_stride);
+}
+
+void VectorVectorCwiseProduct(const float* vector1, const float* vector2,
+ int v_size, float* result) {
+ NEON_OR_PORTABLE(VectorVectorCwiseProduct, vector1, vector2, v_size, result);
+}
+
+void VectorVectorCwiseProductAccumulate(const float* vector1,
+ const float* vector2, int v_size,
+ float* result) {
+ NEON_OR_PORTABLE(VectorVectorCwiseProductAccumulate, vector1, vector2, v_size,
+ result);
+}
+
+void VectorBatchVectorCwiseProductAccumulate(const float* vector, int v_size,
+ const float* batch_vector,
+ int n_batch, float* result) {
+ NEON_OR_PORTABLE(VectorBatchVectorCwiseProductAccumulate, vector, v_size,
+ batch_vector, n_batch, result);
+}
+
+float VectorVectorDotProduct(const float* vector1, const float* vector2,
+ int v_size) {
+ return NEON_OR_PORTABLE(VectorVectorDotProduct, vector1, vector2, v_size);
+}
+
+void BatchVectorBatchVectorDotProduct(const float* vector1,
+ const float* vector2, int v_size,
+ int n_batch, float* result,
+ int result_stride) {
+ NEON_OR_PORTABLE(BatchVectorBatchVectorDotProduct, vector1, vector2, v_size,
+ n_batch, result, result_stride);
+}
+
+void VectorBatchVectorAssign(const float* vector, int v_size, int n_batch,
+ float* batch_vector) {
+ PortableVectorBatchVectorAssign(vector, v_size, n_batch, batch_vector);
+}
+
+void ApplySigmoidToVector(const float* vector, int v_size, float* result) {
+ PortableApplySigmoidToVector(vector, v_size, result);
+}
+
+void ApplyActivationToVector(const float* vector, int v_size,
+ TfLiteFusedActivation activation, float* result) {
+ PortableApplyActivationToVector(vector, v_size, activation, result);
+}
+
+void CopyVector(const float* vector, int v_size, float* result) {
+ PortableCopyVector(vector, v_size, result);
+}
+
+void Sub1Vector(const float* vector, int v_size, float* result) {
+ NEON_OR_PORTABLE(Sub1Vector, vector, v_size, result);
+}
+
+void ZeroVector(float* vector, int v_size) {
+ PortableZeroVector(vector, v_size);
+}
+
+float Clip(float f, float abs_limit) { return PortableClip(f, abs_limit); }
+
+void ClipVector(const float* vector, int v_size, float abs_limit,
+ float* result) {
+ NEON_OR_PORTABLE(ClipVector, vector, v_size, abs_limit, result);
+}
+
+void VectorShiftLeft(float* vector, int v_size, float shift_value) {
+ NEON_OR_PORTABLE(VectorShiftLeft, vector, v_size, shift_value);
+}
+
+void ReductionSumVector(const float* input_vector, float* output_vector,
+ int output_size, int reduction_size) {
+ NEON_OR_PORTABLE(ReductionSumVector, input_vector, output_vector, output_size,
+ reduction_size);
+}
+
+} // namespace tensor_utils
+} // namespace tflite
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_NEON_TENSOR_UTILS_H_
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
new file mode 100644
index 0000000000..cd565c16a1
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
@@ -0,0 +1,3715 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_OPS_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_OPS_H_
+
+#include <assert.h>
+#include <stdint.h>
+#include <sys/types.h>
+#include <algorithm>
+#include <cmath>
+#include <limits>
+#include <memory>
+#include <tuple>
+#include <type_traits>
+
+#include "third_party/eigen3/Eigen/Core"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "fixedpoint/fixedpoint.h"
+#include "public/gemmlowp.h"
+#include "tensorflow/contrib/lite/kernels/internal/common.h"
+#include "tensorflow/contrib/lite/kernels/internal/round.h"
+#include "tensorflow/contrib/lite/kernels/internal/types.h"
+
+namespace tflite {
+namespace optimized_ops {
+
+// Make a local VectorMap typedef that allows mapping a float array
+// as an Eigen vector expression. The std::conditional here is to
+// construct the suitable Eigen type for the constness of the
+// data. Indeed, for const data, we need to produce
+// Eigen::Map<const Eigen::Matrix<float, ...>>
+// and not the more straightforward
+// Eigen::Map<Eigen::Matrix<const float, ...>>
+template <typename Scalar>
+using VectorMap = typename std::conditional<
+ std::is_const<Scalar>::value,
+ Eigen::Map<const Eigen::Matrix<typename std::remove_const<Scalar>::type,
+ Eigen::Dynamic, 1>>,
+ Eigen::Map<Eigen::Matrix<Scalar, Eigen::Dynamic, 1>>>::type;
+
+template <typename Scalar, int N>
+VectorMap<Scalar> MapAsVector(Scalar* data, const Dims<N>& dims) {
+ const int size = RequiredBufferSizeForDims(dims);
+ return VectorMap<Scalar>(data, size, 1);
+}
+
+// Make a local MatrixMap typedef that allows mapping a float array
+// as an Eigen matrix expression. The same explanation as for VectorMap
+// above also applies here.
+template <typename Scalar>
+using MatrixMap = typename std::conditional<
+ std::is_const<Scalar>::value,
+ Eigen::Map<const Eigen::Matrix<typename std::remove_const<Scalar>::type,
+ Eigen::Dynamic, Eigen::Dynamic>>,
+ Eigen::Map<Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic>>>::type;
+
+template <typename Scalar, int N>
+MatrixMap<Scalar> MapAsMatrixWithFirstDimAsRows(Scalar* data,
+ const Dims<N>& dims) {
+ const int rows = dims.sizes[0];
+ int cols = 1;
+ for (int d = 1; d < N; d++) {
+ cols *= dims.sizes[d];
+ }
+ return MatrixMap<Scalar>(data, rows, cols);
+}
+
+template <typename Scalar, int N>
+MatrixMap<Scalar> MapAsMatrixWithLastDimAsCols(Scalar* data,
+ const Dims<N>& dims) {
+ const int cols = dims.sizes[N - 1];
+ int rows = 1;
+ for (int d = 0; d < N - 1; d++) {
+ rows *= dims.sizes[d];
+ }
+ return MatrixMap<Scalar>(data, rows, cols);
+}
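
As an illustration of what the Map helpers above produce, here is a hypothetical standalone sketch (not part of the diff): the maps are zero-copy Eigen views over an existing buffer, so writes through them mutate the original data. The shape {2, 3, 4, 5} is an arbitrary example.

#include <vector>
#include "third_party/eigen3/Eigen/Core"

void MapExample() {
  std::vector<float> buffer(2 * 3 * 4 * 5, 1.f);
  // Equivalent of MapAsVector: all 120 elements viewed as a column vector.
  Eigen::Map<Eigen::Matrix<float, Eigen::Dynamic, 1>> as_vector(
      buffer.data(), static_cast<Eigen::Index>(buffer.size()));
  // Equivalent of MapAsMatrixWithFirstDimAsRows for sizes {2, 3, 4, 5}:
  // 2 rows (the innermost dimension) by 3 * 4 * 5 = 60 columns.
  Eigen::Map<Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic>> as_matrix(
      buffer.data(), 2, 60);
  as_vector.setZero();              // Writes through to 'buffer'.
  as_matrix.col(0).setConstant(1.f);
}
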
+
+template <typename Scalar>
+using ArrayMap = typename std::conditional<
+ std::is_const<Scalar>::value,
+ Eigen::Map<const Eigen::Array<typename std::remove_const<Scalar>::type,
+ Eigen::Dynamic, Eigen::Dynamic>>,
+ Eigen::Map<Eigen::Array<Scalar, Eigen::Dynamic, Eigen::Dynamic>>>::type;
+
+template <typename Scalar, int N>
+ArrayMap<Scalar> MapAsArrayWithFirstDimAsRows(Scalar* data,
+ const Dims<N>& dims) {
+ const int rows = dims.sizes[0];
+ int cols = 1;
+ for (int d = 1; d < N; d++) {
+ cols *= dims.sizes[d];
+ }
+ return ArrayMap<Scalar>(data, rows, cols);
+}
+
+// TODO(b/62193649): this function is only needed as long
+// as we have the --variable_batch hack.
+template <typename Scalar, int N>
+MatrixMap<Scalar> MapAsMatrixWithGivenNumberOfRows(Scalar* data,
+ const Dims<N>& dims,
+ int rows) {
+ int cols = 1;
+ bool matched_rows = false;
+ for (int d = 0; d < N; d++) {
+ cols *= dims.sizes[d];
+ if (cols == rows) {
+ matched_rows = true;
+ cols = 1;
+ }
+ }
+ TFLITE_DCHECK(matched_rows);
+ return MatrixMap<Scalar>(data, rows, cols);
+}
+
+// DO NOT USE THIS STRUCT FOR NEW FUNCTIONALITY BEYOND IMPLEMENTING ELEMENT-WISE
+// BROADCASTING.
+//
+// NdArrayDesc<N> describes the shape and memory layout of an N-dimensional
+// rectangular array of numbers.
+//
+// NdArrayDesc<N> is basically identical to Dims<N> defined in types.h.
+// However, as Dims<N> is to be deprecated, this class exists as an adaptor
+// to enable simple unoptimized implementations of element-wise broadcasting
+// operations.
+template <int N>
+struct NdArrayDesc {
+ // The "extent" of each dimension. Indices along dimension d must be in the
+ // half-open interval [0, extents[d]).
+ int extents[N];
+
+ // The number of *elements* (not bytes) between consecutive indices of each
+ // dimension.
+ int strides[N];
+};
+
+// DO NOT USE THIS FUNCTION FOR NEW FUNCTIONALITY BEYOND IMPLEMENTING
+// ELEMENT-WISE BROADCASTING.
+//
+// Same as Offset(), except it takes an NdArrayDesc<N> instead of a Dims<N>.
+inline int SubscriptToIndex(const NdArrayDesc<4>& desc, int i0, int i1, int i2,
+ int i3) {
+ TFLITE_DCHECK(i0 >= 0 && i0 < desc.extents[0]);
+ TFLITE_DCHECK(i1 >= 0 && i1 < desc.extents[1]);
+ TFLITE_DCHECK(i2 >= 0 && i2 < desc.extents[2]);
+ TFLITE_DCHECK(i3 >= 0 && i3 < desc.extents[3]);
+ return i0 * desc.strides[0] + i1 * desc.strides[1] + i2 * desc.strides[2] +
+ i3 * desc.strides[3];
+}
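+
+// Worked example (values are illustrative only): for a packed 4-D array with
+// extents {D0, D1, D2, D3}, the strides are {1, D0, D0*D1, D0*D1*D2}, so
+//   SubscriptToIndex(desc, i0, i1, i2, i3)
+//       == i0 + D0 * (i1 + D1 * (i2 + D2 * i3)),
+// which matches Offset() on the equivalent Dims<4>.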
+
+// Given the dimensions of the operands for an element-wise binary broadcast,
+// adjusts them so that they can be directly iterated over with simple loops.
+// Returns the adjusted dims as instances of NdArrayDesc in 'desc0_out' and
+// 'desc1_out'. 'desc0_out' and 'desc1_out' cannot be nullptr.
+//
+// This function assumes that the two input shapes are compatible up to
+// broadcasting and the shorter one has already been prepended with 1s to be the
+// same length. E.g., if shape0 is (1, 16, 16, 64) and shape1 is (1, 64),
+// shape1 must already have been prepended to be (1, 1, 1, 64). Recall that
+// Dims<N> refers to shapes in reverse order. In this case, input0_dims will be
+// (64, 16, 16, 1) and input1_dims will be (64, 1, 1, 1).
+//
+// When two shapes are compatible up to broadcasting, for each dimension d,
+// the input extents are either equal, or one of them is 1.
+//
+// This function performs the following for each dimension d:
+// - If the extents are equal, then do nothing since the loop that walks over
+// both of the input arrays is correct.
+// - Otherwise, one (and only one) of the extents must be 1. Say extent0 is 1
+// and extent1 is e1. Then set extent0 to e1 and stride0 *to 0*. This allows
+// array0 to be referenced *at any index* in dimension d and still access the
+// same slice.
+template <int N>
+inline void NdArrayDescsForElementwiseBroadcast(const Dims<N>& input0_dims,
+ const Dims<N>& input1_dims,
+ NdArrayDesc<N>* desc0_out,
+ NdArrayDesc<N>* desc1_out) {
+ TFLITE_DCHECK(desc0_out != nullptr);
+ TFLITE_DCHECK(desc1_out != nullptr);
+
+ // Copy dims to desc.
+ for (int i = 0; i < N; ++i) {
+ desc0_out->extents[i] = input0_dims.sizes[i];
+ desc0_out->strides[i] = input0_dims.strides[i];
+ desc1_out->extents[i] = input1_dims.sizes[i];
+ desc1_out->strides[i] = input1_dims.strides[i];
+ }
+
+ // Walk over each dimension. If the extents are equal do nothing.
+ // Otherwise, set the desc with extent 1 to have extent equal to the other and
+ // stride 0.
+ for (int i = 0; i < N; ++i) {
+ const int extent0 = ArraySize(input0_dims, i);
+ const int extent1 = ArraySize(input1_dims, i);
+ if (extent0 != extent1) {
+ if (extent0 == 1) {
+ desc0_out->strides[i] = 0;
+ desc0_out->extents[i] = extent1;
+ } else {
+ TFLITE_DCHECK_EQ(extent1, 1);
+ desc1_out->strides[i] = 0;
+ desc1_out->extents[i] = extent0;
+ }
+ }
+ }
+}
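+
+// Worked example (reusing the shapes from the comment above): with
+// input0_dims extents (64, 16, 16, 1) and input1_dims extents (64, 1, 1, 1),
+// desc0_out is left unchanged, while desc1_out gets extents (64, 16, 16, 1)
+// and strides[1] = strides[2] = 0, so indexing input1 at any (x, y) position
+// re-reads the same 64-element slice, which is exactly the broadcast
+// semantics.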
+
+inline bool AreSameDims(const Dims<4>& dims1, const Dims<4>& dims2) {
+ for (int i = 0; i < 4; i++) {
+ if (dims1.sizes[i] != dims2.sizes[i]) {
+ return false;
+ }
+ }
+ return true;
+}
+
+inline void AddBiasAndEvalActivationFunction(const float* bias_data,
+ const Dims<4>& bias_dims,
+ float* array_data,
+ const Dims<4>& array_dims,
+ float output_activation_min,
+ float output_activation_max) {
+#ifdef USE_NEON
+ gemmlowp::ScopedProfilingLabel label("AddBiasAndEvalActivationFunction");
+ const int bias_size = bias_dims.sizes[3] * bias_dims.strides[3];
+ const int array_size = array_dims.sizes[3] * array_dims.strides[3];
+ TFLITE_DCHECK_EQ((array_size % bias_size), 0);
+ float* array_ptr = array_data;
+ float* array_end_ptr = array_ptr + array_size;
+ const auto activation_min = vdupq_n_f32(output_activation_min);
+ const auto activation_max = vdupq_n_f32(output_activation_max);
+ for (; array_ptr != array_end_ptr; array_ptr += bias_size) {
+ int i = 0;
+ for (; i <= bias_size - 16; i += 16) {
+ auto b0 = vld1q_f32(bias_data + i);
+ auto b1 = vld1q_f32(bias_data + i + 4);
+ auto b2 = vld1q_f32(bias_data + i + 8);
+ auto b3 = vld1q_f32(bias_data + i + 12);
+ auto a0 = vld1q_f32(array_ptr + i);
+ auto a1 = vld1q_f32(array_ptr + i + 4);
+ auto a2 = vld1q_f32(array_ptr + i + 8);
+ auto a3 = vld1q_f32(array_ptr + i + 12);
+ auto x0 = vaddq_f32(a0, b0);
+ auto x1 = vaddq_f32(a1, b1);
+ auto x2 = vaddq_f32(a2, b2);
+ auto x3 = vaddq_f32(a3, b3);
+ x0 = vmaxq_f32(activation_min, x0);
+ x1 = vmaxq_f32(activation_min, x1);
+ x2 = vmaxq_f32(activation_min, x2);
+ x3 = vmaxq_f32(activation_min, x3);
+ x0 = vminq_f32(activation_max, x0);
+ x1 = vminq_f32(activation_max, x1);
+ x2 = vminq_f32(activation_max, x2);
+ x3 = vminq_f32(activation_max, x3);
+ vst1q_f32(array_ptr + i, x0);
+ vst1q_f32(array_ptr + i + 4, x1);
+ vst1q_f32(array_ptr + i + 8, x2);
+ vst1q_f32(array_ptr + i + 12, x3);
+ }
+ for (; i <= bias_size - 4; i += 4) {
+ auto b = vld1q_f32(bias_data + i);
+ auto a = vld1q_f32(array_ptr + i);
+ auto x = vaddq_f32(a, b);
+ x = vmaxq_f32(activation_min, x);
+ x = vminq_f32(activation_max, x);
+ vst1q_f32(array_ptr + i, x);
+ }
+ for (; i < bias_size; i++) {
+ array_ptr[i] = ActivationFunctionWithMinMax(array_ptr[i] + bias_data[i],
+ output_activation_min,
+ output_activation_max);
+ }
+ }
+#else // not NEON
+ gemmlowp::ScopedProfilingLabel label("AddBiasAndEvalActivationFunction");
+ const int bias_size = bias_dims.sizes[3] * bias_dims.strides[3];
+ const int array_size = array_dims.sizes[3] * array_dims.strides[3];
+ TFLITE_DCHECK_EQ((array_size % bias_size), 0);
+ for (int array_offset = 0; array_offset < array_size;
+ array_offset += bias_size) {
+ for (int i = 0; i < bias_size; i++) {
+ array_data[array_offset + i] = ActivationFunctionWithMinMax(
+ array_data[array_offset + i] + bias_data[i], output_activation_min,
+ output_activation_max);
+ }
+ }
+#endif
+}
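+
+// Shape note (a sketch of the intended use, not an additional contract): the
+// bias is a tile of 'bias_size' elements added cyclically over the array.
+// E.g. with bias_size = 4 and array_size = 8,
+//   array[0..3] += bias[0..3];  array[4..7] += bias[0..3];
+// and every element is then clamped to
+//   [output_activation_min, output_activation_max].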
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void AddBiasAndEvalActivationFunction(const float* bias_data,
+ const Dims<4>& bias_dims,
+ float* array_data,
+ const Dims<4>& array_dims) {
+ float output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+ AddBiasAndEvalActivationFunction(bias_data, bias_dims, array_data, array_dims,
+ output_activation_min,
+ output_activation_max);
+}
+
+template <typename Lhs, typename Rhs, typename Result>
+void Gemm(const Eigen::MatrixBase<Lhs>& lhs, const Eigen::MatrixBase<Rhs>& rhs,
+ Eigen::MatrixBase<Result>* result) {
+ if (rhs.cols() == 1) {
+ gemmlowp::ScopedProfilingLabel label("GEMV");
+ result->col(0).noalias() = lhs * rhs.col(0);
+ } else {
+ gemmlowp::ScopedProfilingLabel label("GEMM");
+ result->noalias() = lhs * rhs;
+ }
+}
+
+inline void FullyConnected(const float* input_data, const Dims<4>& input_dims,
+ const float* weights_data,
+ const Dims<4>& weights_dims, const float* bias_data,
+ const Dims<4>& bias_dims,
+ float output_activation_min,
+ float output_activation_max, float* output_data,
+ const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("FullyConnected");
+ // TODO(b/62193649): this convoluted shape computation (determining
+ // input_rows from the weights_dims, then MapAsMatrixWithGivenNumberOfRows)
+  // is because the current --variable_batch hack consists of overwriting the
+  // 3rd dimension with the runtime batch size, as we don't keep track, for
+  // each array, of which dimension is its batch dimension.
+ // When that is fixed, this should become:
+ // const auto input_matrix_map =
+ // MapAsMatrixWithFirstDimAsRows(input_data, input_dims);
+ const int input_rows = ArraySize(weights_dims, 0);
+ const auto input_matrix_map =
+ MapAsMatrixWithGivenNumberOfRows(input_data, input_dims, input_rows);
+ const auto filter_matrix_map =
+ MapAsMatrixWithFirstDimAsRows(weights_data, weights_dims);
+ auto output_matrix_map =
+ MapAsMatrixWithFirstDimAsRows(output_data, output_dims);
+
+ Gemm(filter_matrix_map.transpose(), input_matrix_map, &output_matrix_map);
+ AddBiasAndEvalActivationFunction(bias_data, bias_dims, output_data,
+ output_dims, output_activation_min,
+ output_activation_max);
+}
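+
+// In matrix terms, the float FullyConnected above computes (informal sketch)
+//   output = Clamp(weights^T * input + bias,
+//                  output_activation_min, output_activation_max),
+// where 'input' is mapped as an (input_rows x batches) matrix and 'weights'
+// as an (input_rows x output_rows) matrix, so each output column is one
+// batch element.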
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void FullyConnected(const float* input_data, const Dims<4>& input_dims,
+ const float* weights_data, const Dims<4>& weights_dims,
+ const float* bias_data, const Dims<4>& bias_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ float output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+ FullyConnected(input_data, input_dims, weights_data, weights_dims, bias_data,
+ bias_dims, output_activation_min, output_activation_max,
+ output_data, output_dims);
+}
+
+inline void preload_l1_stream(const uint8* ptr) {
+#ifdef GEMMLOWP_ARM_64
+ asm volatile("prfm pldl1strm, [%[ptr]]\n" ::[ptr] "r"(ptr) :);
+#else
+ gemmlowp::Prefetch(ptr);
+#endif
+}
+
+#ifdef USE_NEON
+inline void FullyConnectedAsGEMV(
+ const uint8* input_data, const Dims<4>& input_dims, int32 input_offset,
+ const uint8* filter_data, const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims, int32 output_offset,
+ int32 output_multiplier, int output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("FullyConnectedAsGEMV/8bit");
+ TFLITE_DCHECK(IsPackedWithoutStrides(input_dims));
+ TFLITE_DCHECK(IsPackedWithoutStrides(filter_dims));
+ TFLITE_DCHECK(IsPackedWithoutStrides(bias_dims));
+ TFLITE_DCHECK(IsPackedWithoutStrides(output_dims));
+ TFLITE_DCHECK_EQ(ArraySize(output_dims, 1) * ArraySize(output_dims, 2) *
+ ArraySize(output_dims, 3),
+ 1);
+ const int input_size = input_dims.strides[3];
+ const int output_size = MatchingArraySize(filter_dims, 1, output_dims, 0);
+ static constexpr int kPeel = 4;
+ for (int k = 0; k < input_size; k += 64) {
+ preload_l1_stream(input_data + k);
+ }
+ for (int k = 0; k < kPeel * input_size; k += 64) {
+ preload_l1_stream(filter_data + k);
+ }
+ TFLITE_DCHECK(!(output_size % kPeel));
+ const int32* bias_ptr = bias_data;
+ uint8* output_ptr = output_data;
+ for (int out = 0; out < output_size; out += kPeel) {
+ int32x4_t acc[kPeel];
+ for (int k = 0; k < kPeel; k++) {
+ acc[k] = vdupq_n_s32(0);
+ }
+ const int16x8_t input_offset_vec = vdupq_n_s16(input_offset);
+ const int16x8_t filter_offset_vec = vdupq_n_s16(filter_offset);
+ int in = 0;
+ for (; in <= input_size - 16; in += 16) {
+ const uint8x16_t input_val_u8 = vld1q_u8(input_data + in);
+ uint8x16_t filter_val_u8[kPeel];
+ for (int k = 0; k < kPeel; k++) {
+ const uint8* filter_ptr = filter_data + in + (out + k) * input_size;
+ filter_val_u8[k] = vld1q_u8(filter_ptr);
+ preload_l1_stream(filter_ptr + 64);
+ }
+ int16x8_t input_val[2];
+ const uint8x8_t low = vget_low_u8(input_val_u8);
+ const uint8x8_t high = vget_high_u8(input_val_u8);
+ input_val[0] = vreinterpretq_s16_u16(vmovl_u8(low));
+ input_val[1] = vreinterpretq_s16_u16(vmovl_u8(high));
+ input_val[0] = vaddq_s16(input_val[0], input_offset_vec);
+ input_val[1] = vaddq_s16(input_val[1], input_offset_vec);
+ int16x8_t filter_val[kPeel][2];
+ for (int k = 0; k < kPeel; k++) {
+ const uint8x8_t low = vget_low_u8(filter_val_u8[k]);
+ const uint8x8_t high = vget_high_u8(filter_val_u8[k]);
+ filter_val[k][0] = vreinterpretq_s16_u16(vmovl_u8(low));
+ filter_val[k][1] = vreinterpretq_s16_u16(vmovl_u8(high));
+ filter_val[k][0] = vaddq_s16(filter_val[k][0], filter_offset_vec);
+ filter_val[k][1] = vaddq_s16(filter_val[k][1], filter_offset_vec);
+ }
+ for (int p = 0; p < 2; p++) {
+ for (int k = 0; k < kPeel; k++) {
+ acc[k] = vmlal_s16(acc[k], vget_low_s16(filter_val[k][p]),
+ vget_low_s16(input_val[p]));
+ }
+ for (int k = 0; k < kPeel; k++) {
+ acc[k] = vmlal_s16(acc[k], vget_high_s16(filter_val[k][p]),
+ vget_high_s16(input_val[p]));
+ }
+ }
+ }
+ for (; in <= input_size - 8; in += 8) {
+ const uint8x8_t input_val_u8 = vld1_u8(input_data + in);
+ uint8x8_t filter_val_u8[kPeel];
+ for (int k = 0; k < kPeel; k++) {
+ const uint8* filter_ptr = filter_data + in + (out + k) * input_size;
+ filter_val_u8[k] = vld1_u8(filter_ptr);
+ }
+ int16x8_t input_val;
+ input_val = vreinterpretq_s16_u16(vmovl_u8(input_val_u8));
+ input_val = vaddq_s16(input_val, input_offset_vec);
+ int16x8_t filter_val[kPeel];
+ for (int k = 0; k < kPeel; k++) {
+ filter_val[k] = vreinterpretq_s16_u16(vmovl_u8(filter_val_u8[k]));
+ filter_val[k] = vaddq_s16(filter_val[k], filter_offset_vec);
+ }
+ for (int k = 0; k < kPeel; k++) {
+ acc[k] = vmlal_s16(acc[k], vget_low_s16(filter_val[k]),
+ vget_low_s16(input_val));
+ }
+ for (int k = 0; k < kPeel; k++) {
+ acc[k] = vmlal_s16(acc[k], vget_high_s16(filter_val[k]),
+ vget_high_s16(input_val));
+ }
+ }
+ if (in < input_size) {
+ int32 buf[4 * kPeel];
+ for (int k = 0; k < 4; k++) {
+ vst1q_s32(buf + 4 * k, acc[k]);
+ }
+ for (; in < input_size; in++) {
+ int lane = (in + 8 - input_size) % 4;
+ const int32 input_val = input_data[in] + input_offset;
+ for (int k = 0; k < kPeel; k++) {
+ int32 filter_val =
+ filter_data[in + (out + k) * input_size] + filter_offset;
+ buf[lane + 4 * k] += filter_val * input_val;
+ }
+ }
+ for (int k = 0; k < 4; k++) {
+ acc[k] = vld1q_s32(buf + 4 * k);
+ }
+ }
+
+ // Horizontally reduce accumulators
+ int32x2_t pairwise_reduced_acc[kPeel];
+ for (int k = 0; k < kPeel; k++) {
+ pairwise_reduced_acc[k] =
+ vpadd_s32(vget_low_s32(acc[k]), vget_high_s32(acc[k]));
+ }
+ static_assert(kPeel == 4, "the code below currently assumes kPeel = 4");
+ const int32x2_t reduced_lo =
+ vpadd_s32(pairwise_reduced_acc[0], pairwise_reduced_acc[1]);
+ const int32x2_t reduced_hi =
+ vpadd_s32(pairwise_reduced_acc[2], pairwise_reduced_acc[3]);
+ int32x4_t reduced = vcombine_s32(reduced_lo, reduced_hi);
+ // Add bias values.
+ int32x4_t bias_vec = vld1q_s32(bias_ptr);
+ bias_ptr += 4;
+ reduced = vaddq_s32(reduced, bias_vec);
+ // Multiply by the fixed-point multiplier.
+ reduced = vqrdmulhq_n_s32(reduced, output_multiplier);
+ // Rounding-shift-right.
+ using gemmlowp::RoundingDivideByPOT;
+ reduced = RoundingDivideByPOT(reduced, output_shift);
+ // Add the output offset.
+ const int32x4_t output_offset_vec = vdupq_n_s32(output_offset);
+ reduced = vaddq_s32(reduced, output_offset_vec);
+ // Narrow values down to 16 bit signed.
+ const int16x4_t res16 = vqmovn_s32(reduced);
+ // Narrow values down to 8 bit unsigned, saturating.
+ uint8x8_t res8 = vqmovun_s16(vcombine_s16(res16, res16));
+ // Apply the clamping from the activation function
+ res8 = vmax_u8(res8, vdup_n_u8(output_activation_min));
+ res8 = vmin_u8(res8, vdup_n_u8(output_activation_max));
+ // Store results to destination. Assumes 32bit alignment.
+ vst1_lane_u32(reinterpret_cast<uint32*>(output_ptr),
+ vreinterpret_u32_u8(res8), 0);
+ output_ptr += kPeel;
+ }
+}
+#endif // USE_NEON
+
+struct GemmlowpOutputPipeline {
+ typedef gemmlowp::VectorMap<const int32, gemmlowp::VectorShape::Col>
+ ColVectorMap;
+ typedef std::tuple<
+ gemmlowp::OutputStageBiasAddition<ColVectorMap>,
+ gemmlowp::OutputStageQuantizeDownInt32ToUint8ScaleByFixedPoint,
+ gemmlowp::OutputStageClamp, gemmlowp::OutputStageSaturatingCastToUint8>
+ Pipeline;
+ static Pipeline Make(const int32* bias_data, int output_rows,
+ int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max) {
+ ColVectorMap bias_vector(bias_data, output_rows);
+ gemmlowp::OutputStageBiasAddition<ColVectorMap> bias_addition_stage;
+ bias_addition_stage.bias_vector = bias_vector;
+ gemmlowp::OutputStageQuantizeDownInt32ToUint8ScaleByFixedPoint
+ quantize_down_stage;
+ quantize_down_stage.result_offset_after_shift = output_offset;
+ quantize_down_stage.result_fixedpoint_multiplier = output_multiplier;
+ quantize_down_stage.result_shift = output_shift;
+ gemmlowp::OutputStageClamp clamp_stage;
+ clamp_stage.min = output_activation_min;
+ clamp_stage.max = output_activation_max;
+ gemmlowp::OutputStageSaturatingCastToUint8 saturating_cast_stage;
+ return std::make_tuple(bias_addition_stage, quantize_down_stage,
+ clamp_stage, saturating_cast_stage);
+ }
+};
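+
+// The pipeline above corresponds to the following per-accumulator math
+// (informal sketch of the gemmlowp output stages, using the parameter names
+// passed to Make()):
+//   acc    += bias[row];
+//   scaled  = RoundingDivideByPOT(
+//                 SaturatingRoundingDoublingHighMul(acc, output_multiplier),
+//                 output_shift) + output_offset;
+//   out     = uint8 saturating cast of Clamp(scaled, output_activation_min,
+//                                            output_activation_max);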
+
+inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims,
+ int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims,
+ gemmlowp::GemmContext* gemm_context) {
+ gemmlowp::ScopedProfilingLabel label("FullyConnected/8bit");
+ // TODO(benoitjacob): This really should be:
+ // const int batches = ArraySize(output_dims, 1);
+  // but the current --variable_batch hack consists of overwriting the 3rd
+  // dimension with the runtime batch size, as we don't keep track, for each
+  // array, of which dimension is its batch dimension.
+ const int batches = ArraySize(output_dims, 1) * ArraySize(output_dims, 2) *
+ ArraySize(output_dims, 3);
+#ifdef USE_NEON
+ const int output_size = MatchingArraySize(filter_dims, 1, output_dims, 0);
+ if (batches == 1 && !(output_size % 4)) {
+ return FullyConnectedAsGEMV(
+ input_data, input_dims, input_offset, filter_data, filter_dims,
+ filter_offset, bias_data, bias_dims, output_offset, output_multiplier,
+ output_shift, output_activation_min, output_activation_max, output_data,
+ output_dims);
+ }
+#endif // USE_NEON
+ const int filter_rows = filter_dims.sizes[1];
+ const int filter_cols = filter_dims.sizes[0];
+ TFLITE_DCHECK_EQ(filter_dims.sizes[2], 1);
+ TFLITE_DCHECK_EQ(filter_dims.sizes[3], 1);
+ const int output_rows = output_dims.sizes[0];
+ TFLITE_DCHECK_EQ(output_rows, filter_rows);
+ TFLITE_DCHECK_EQ(bias_dims.sizes[0], output_rows);
+ TFLITE_DCHECK_EQ(bias_dims.sizes[1], 1);
+ TFLITE_DCHECK_EQ(bias_dims.sizes[2], 1);
+ TFLITE_DCHECK_EQ(bias_dims.sizes[3], 1);
+
+ gemmlowp::MatrixMap<const uint8, gemmlowp::MapOrder::RowMajor> filter_matrix(
+ filter_data, output_rows, filter_cols, filter_cols);
+ gemmlowp::MatrixMap<const uint8, gemmlowp::MapOrder::ColMajor> input_matrix(
+ input_data, filter_cols, batches, filter_cols);
+ gemmlowp::MatrixMap<uint8, gemmlowp::MapOrder::ColMajor> output_matrix(
+ output_data, output_rows, batches, output_rows);
+ const auto& output_pipeline = GemmlowpOutputPipeline::Make(
+ bias_data, output_rows, output_offset, output_multiplier, output_shift,
+ output_activation_min, output_activation_max);
+ gemmlowp::GemmWithOutputPipeline<uint8, uint8,
+ gemmlowp::L8R8WithLhsNonzeroBitDepthParams>(
+ gemm_context, filter_matrix, input_matrix, &output_matrix, filter_offset,
+ input_offset, output_pipeline);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims,
+ int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims,
+ gemmlowp::GemmContext* gemm_context) {
+ static_assert(Ac == FusedActivationFunctionType::kNone ||
+ Ac == FusedActivationFunctionType::kRelu ||
+ Ac == FusedActivationFunctionType::kRelu6 ||
+ Ac == FusedActivationFunctionType::kRelu1,
+ "");
+ FullyConnected(input_data, input_dims, input_offset, filter_data, filter_dims,
+ filter_offset, bias_data, bias_dims, output_offset,
+ output_multiplier, output_shift, output_activation_min,
+ output_activation_max, output_data, output_dims, gemm_context);
+}
+
+template <typename T>
+inline void ExtractPatchIntoBufferColumn(
+ const Dims<4>& input_dims, int w, int h, int b, int kheight, int kwidth,
+ int stride_width, int stride_height, int pad_width, int pad_height,
+ int in_width, int in_height, int in_depth, int single_buffer_length,
+ int buffer_id, const T* in_data, T* conv_buffer_data, uint8 byte_zero) {
+ gemmlowp::ScopedProfilingLabel label("ExtractPatchIntoBufferColumn");
+ // This chunk of code reshapes all the inputs corresponding to
+ // output (b, h, w) to a column vector in conv_buffer(:, buffer_id).
+ const int kwidth_times_indepth = kwidth * in_depth;
+ const int inwidth_times_indepth = in_width * in_depth;
+ const int ih_ungated_start = h * stride_height - pad_height;
+ const int ih_ungated_end = (ih_ungated_start + kheight);
+ const int ih_end = std::min(ih_ungated_end, in_height);
+ const int iw_ungated_start = w * stride_width - pad_width;
+ const int iw_ungated_end = (iw_ungated_start + kwidth);
+ const int iw_end = std::min(iw_ungated_end, in_width);
+ // If the patch is off the edge of the input image, skip writing those rows
+ // and columns from the patch into the output array.
+ const int h_offset = std::max(0, -ih_ungated_start);
+ const int w_offset = std::max(0, -iw_ungated_start);
+ const int ih_start = std::max(0, ih_ungated_start);
+ const int iw_start = std::max(0, iw_ungated_start);
+ const int single_row_num =
+ std::min(kwidth - w_offset, in_width - iw_start) * in_depth;
+ const int output_row_offset = (buffer_id * single_buffer_length);
+ int out_offset =
+ output_row_offset + (h_offset * kwidth + w_offset) * in_depth;
+ int in_offset = Offset(input_dims, 0, iw_start, ih_start, b);
+
+ // Express all of the calculations as padding around the input patch.
+ const int top_padding = h_offset;
+ const int bottom_padding = (ih_ungated_end - ih_end);
+ const int left_padding = w_offset;
+ const int right_padding = (iw_ungated_end - iw_end);
+ assert(single_row_num ==
+ ((kwidth - (left_padding + right_padding)) * in_depth));
+
+ // Write out zeroes to the elements representing the top rows of the input
+ // patch that are off the edge of the input image.
+ if (top_padding > 0) {
+ const int top_row_elements = (top_padding * kwidth * in_depth);
+ memset(conv_buffer_data + output_row_offset, byte_zero,
+ (top_row_elements * sizeof(T)));
+ }
+
+  // If the patch is on the interior of the input image horizontally, just copy
+  // over the rows sequentially; otherwise add zero padding at the start or end.
+ if ((left_padding == 0) && (right_padding == 0)) {
+ for (int ih = ih_start; ih < ih_end; ++ih) {
+ memcpy(conv_buffer_data + out_offset, in_data + in_offset,
+ single_row_num * sizeof(T));
+ out_offset += kwidth_times_indepth;
+ in_offset += inwidth_times_indepth;
+ }
+ } else {
+ for (int ih = ih_start; ih < ih_end; ++ih) {
+ if (left_padding > 0) {
+ const int left_start = (out_offset - (left_padding * in_depth));
+ memset(conv_buffer_data + left_start, byte_zero,
+ (left_padding * in_depth * sizeof(T)));
+ }
+ memcpy(conv_buffer_data + out_offset, in_data + in_offset,
+ single_row_num * sizeof(T));
+ if (right_padding > 0) {
+ const int right_start = (out_offset + single_row_num);
+ memset(conv_buffer_data + right_start, byte_zero,
+ (right_padding * in_depth * sizeof(T)));
+ }
+ out_offset += kwidth_times_indepth;
+ in_offset += inwidth_times_indepth;
+ }
+ }
+
+ // If the bottom of the patch falls off the input image, pad the values
+ // representing those input rows with zeroes.
+ if (bottom_padding > 0) {
+ const int bottom_row_elements = (bottom_padding * kwidth * in_depth);
+ const int bottom_start =
+ output_row_offset +
+ ((top_padding + (ih_end - ih_start)) * kwidth * in_depth);
+ memset(conv_buffer_data + bottom_start, byte_zero,
+ (bottom_row_elements * sizeof(T)));
+ }
+}
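+
+// Padding example (illustrative values only): for a 3x3 kernel (kheight =
+// kwidth = 3) with stride 1 and pad_width = pad_height = 1, the patch for
+// output position (h = 0, w = 0) hangs one row above and one column to the
+// left of the image, so top_padding = left_padding = 1; those patch entries
+// are memset to 'byte_zero' while the in-image rows are memcpy'd.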
+
+template <typename T>
+void Im2col(const T* input_data, const Dims<4>& input_dims, int stride_width,
+ int stride_height, int pad_width, int pad_height, int kheight,
+ int kwidth, uint8 byte_zero, T* output_data,
+ const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("Im2col");
+ TFLITE_DCHECK(IsPackedWithoutStrides(input_dims));
+ TFLITE_DCHECK(IsPackedWithoutStrides(output_dims));
+ const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
+ const int input_depth = ArraySize(input_dims, 0);
+ const int input_width = ArraySize(input_dims, 1);
+ const int input_height = ArraySize(input_dims, 2);
+ const int output_depth = ArraySize(output_dims, 0);
+ const int output_width = ArraySize(output_dims, 1);
+ const int output_height = ArraySize(output_dims, 2);
+
+ int buffer_id = 0;
+ // Loop over the output nodes.
+ for (int b = 0; b < batches; ++b) {
+ for (int h = 0; h < output_height; ++h) {
+ for (int w = 0; w < output_width; ++w) {
+ ExtractPatchIntoBufferColumn(
+ input_dims, w, h, b, kheight, kwidth, stride_width, stride_height,
+ pad_width, pad_height, input_width, input_height, input_depth,
+ output_depth, buffer_id, input_data, output_data, byte_zero);
+ ++buffer_id;
+ }
+ }
+ }
+}
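+
+// Layout note (a sketch, assuming the usual im2col convention used here):
+// each buffer column buffer_id corresponds to one output position (b, h, w)
+// and holds kheight * kwidth * input_depth values, so output_depth of the
+// im2col buffer equals kheight * kwidth * input_depth and the convolution
+// reduces to a single GEMM against the reshaped filter.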
+
+// legacy, for compatibility with old checked-in code
+template <typename T>
+void Im2col(const T* input_data, const Dims<4>& input_dims, int stride,
+ int pad_width, int pad_height, int kheight, int kwidth,
+ uint8 byte_zero, T* output_data, const Dims<4>& output_dims) {
+ Im2col(input_data, input_dims, stride, stride, pad_width, pad_height, kheight,
+ kwidth, byte_zero, output_data, output_dims);
+}
+
+inline void Conv(const float* input_data, const Dims<4>& input_dims,
+ const float* filter_data, const Dims<4>& filter_dims,
+ const float* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, float output_activation_min,
+ float output_activation_max, float* output_data,
+ const Dims<4>& output_dims, float* im2col_data,
+ const Dims<4>& im2col_dims) {
+ (void)im2col_data;
+ (void)im2col_dims;
+ gemmlowp::ScopedProfilingLabel label("Conv");
+
+ const float* gemm_input_data = nullptr;
+ const Dims<4>* gemm_input_dims = nullptr;
+ const int filter_width = ArraySize(filter_dims, 1);
+ const int filter_height = ArraySize(filter_dims, 2);
+ const bool need_im2col = stride_width != 1 || stride_height != 1 ||
+ filter_width != 1 || filter_height != 1;
+ if (need_im2col) {
+ TFLITE_DCHECK(im2col_data);
+ Im2col(input_data, input_dims, stride_width, stride_height, pad_width,
+ pad_height, filter_height, filter_width, 0, im2col_data,
+ im2col_dims);
+ gemm_input_data = im2col_data;
+ gemm_input_dims = &im2col_dims;
+ } else {
+ // TODO(aselle): We need to make sure to not send im2col if it is not
+ // needed.
+ TFLITE_DCHECK(!im2col_data);
+ gemm_input_data = input_data;
+ gemm_input_dims = &input_dims;
+ }
+
+ const auto im2col_matrix_map =
+ MapAsMatrixWithFirstDimAsRows(gemm_input_data, *gemm_input_dims);
+ const auto filter_matrix_map =
+ MapAsMatrixWithLastDimAsCols(filter_data, filter_dims);
+ auto output_matrix_map =
+ MapAsMatrixWithFirstDimAsRows(output_data, output_dims);
+
+ Gemm(filter_matrix_map.transpose(), im2col_matrix_map, &output_matrix_map);
+
+ AddBiasAndEvalActivationFunction(bias_data, bias_dims, output_data,
+ output_dims, output_activation_min,
+ output_activation_max);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void Conv(const float* input_data, const Dims<4>& input_dims,
+ const float* filter_data, const Dims<4>& filter_dims,
+ const float* bias_data, const Dims<4>& bias_dims, int stride_width,
+ int stride_height, int pad_width, int pad_height, float* output_data,
+ const Dims<4>& output_dims, float* im2col_data,
+ const Dims<4>& im2col_dims) {
+ float output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+ Conv(input_data, input_dims, filter_data, filter_dims, bias_data, bias_dims,
+ stride_width, stride_height, pad_width, pad_height,
+ output_activation_min, output_activation_max, output_data, output_dims,
+ im2col_data, im2col_dims);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void Conv(const float* input_data, const Dims<4>& input_dims,
+ const float* filter_data, const Dims<4>& filter_dims,
+ const float* bias_data, const Dims<4>& bias_dims, int stride,
+ int pad_width, int pad_height, float* output_data,
+ const Dims<4>& output_dims, float* im2col_data,
+ const Dims<4>& im2col_dims) {
+ Conv<Ac>(input_data, input_dims, filter_data, filter_dims, bias_data,
+ bias_dims, stride, stride, pad_width, pad_height, output_data,
+ output_dims, im2col_data, im2col_dims);
+}
+
+inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims, uint8* im2col_data,
+ const Dims<4>& im2col_dims,
+ gemmlowp::GemmContext* gemm_context) {
+ gemmlowp::ScopedProfilingLabel label("Conv/8bit");
+
+ TFLITE_DCHECK(IsPackedWithoutStrides(input_dims));
+ TFLITE_DCHECK(IsPackedWithoutStrides(filter_dims));
+ TFLITE_DCHECK(IsPackedWithoutStrides(output_dims));
+
+ const uint8* gemm_input_data = nullptr;
+ const Dims<4>* gemm_input_dims = nullptr;
+ const int filter_width = ArraySize(filter_dims, 1);
+ const int filter_height = ArraySize(filter_dims, 2);
+ const bool need_im2col = stride_width != 1 || stride_height != 1 ||
+ filter_width != 1 || filter_height != 1;
+ if (need_im2col) {
+ TFLITE_DCHECK(im2col_data);
+ const int input_zero_point = -input_offset;
+ TFLITE_DCHECK_GE(input_zero_point, 0);
+ TFLITE_DCHECK_LE(input_zero_point, 255);
+ Im2col(input_data, input_dims, stride_width, stride_height, pad_width,
+ pad_height, filter_height, filter_width, input_zero_point,
+ im2col_data, im2col_dims);
+ gemm_input_data = im2col_data;
+ gemm_input_dims = &im2col_dims;
+ } else {
+ TFLITE_DCHECK(!im2col_data);
+ gemm_input_data = input_data;
+ gemm_input_dims = &input_dims;
+ }
+
+ const int gemm_input_rows = gemm_input_dims->sizes[0];
+ const int gemm_input_cols = gemm_input_dims->sizes[1] *
+ gemm_input_dims->sizes[2] *
+ gemm_input_dims->sizes[3];
+ const int filter_rows = filter_dims.sizes[3];
+ const int filter_cols =
+ filter_dims.sizes[0] * filter_dims.sizes[1] * filter_dims.sizes[2];
+ const int output_rows = output_dims.sizes[0];
+ const int output_cols =
+ output_dims.sizes[1] * output_dims.sizes[2] * output_dims.sizes[3];
+ TFLITE_DCHECK_EQ(output_rows, filter_rows);
+ TFLITE_DCHECK_EQ(output_cols, gemm_input_cols);
+ TFLITE_DCHECK_EQ(filter_cols, gemm_input_rows);
+ TFLITE_DCHECK_EQ(bias_dims.sizes[0], output_rows);
+ TFLITE_DCHECK_EQ(bias_dims.sizes[1], 1);
+ TFLITE_DCHECK_EQ(bias_dims.sizes[2], 1);
+ TFLITE_DCHECK_EQ(bias_dims.sizes[3], 1);
+ gemmlowp::MatrixMap<const uint8, gemmlowp::MapOrder::RowMajor> filter_matrix(
+ filter_data, filter_rows, filter_cols);
+ gemmlowp::MatrixMap<const uint8, gemmlowp::MapOrder::ColMajor> input_matrix(
+ gemm_input_data, gemm_input_rows, gemm_input_cols);
+ gemmlowp::MatrixMap<uint8, gemmlowp::MapOrder::ColMajor> output_matrix(
+ output_data, output_rows, output_cols);
+ const auto& output_pipeline = GemmlowpOutputPipeline::Make(
+ bias_data, output_rows, output_offset, output_multiplier, output_shift,
+ output_activation_min, output_activation_max);
+ gemmlowp::GemmWithOutputPipeline<uint8, uint8,
+ gemmlowp::L8R8WithLhsNonzeroBitDepthParams>(
+ gemm_context, filter_matrix, input_matrix, &output_matrix, filter_offset,
+ input_offset, output_pipeline);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims, uint8* im2col_data,
+ const Dims<4>& im2col_dims,
+ gemmlowp::GemmContext* gemm_context) {
+ static_assert(Ac == FusedActivationFunctionType::kNone ||
+ Ac == FusedActivationFunctionType::kRelu ||
+ Ac == FusedActivationFunctionType::kRelu6 ||
+ Ac == FusedActivationFunctionType::kRelu1,
+ "");
+ if (Ac == FusedActivationFunctionType::kNone) {
+ TFLITE_DCHECK_EQ(output_activation_min, 0);
+ TFLITE_DCHECK_EQ(output_activation_max, 255);
+ }
+ Conv(input_data, input_dims, input_offset, filter_data, filter_dims,
+ filter_offset, bias_data, bias_dims, stride_width, stride_height,
+ pad_width, pad_height, output_offset, output_multiplier, output_shift,
+ output_activation_min, output_activation_max, output_data, output_dims,
+ im2col_data, im2col_dims, gemm_context);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void Conv(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims, int stride,
+ int pad_width, int pad_height, int32 output_offset,
+ int32 output_multiplier, int output_shift,
+ int32 output_activation_min, int32 output_activation_max,
+ uint8* output_data, const Dims<4>& output_dims, uint8* im2col_data,
+ const Dims<4>& im2col_dims, gemmlowp::GemmContext* gemm_context) {
+ static_assert(Ac == FusedActivationFunctionType::kNone ||
+ Ac == FusedActivationFunctionType::kRelu ||
+ Ac == FusedActivationFunctionType::kRelu6 ||
+ Ac == FusedActivationFunctionType::kRelu1,
+ "");
+ Conv(input_data, input_dims, input_offset, filter_data, filter_dims,
+ filter_offset, bias_data, bias_dims, stride, stride, pad_width,
+ pad_height, output_offset, output_multiplier, output_shift,
+ output_activation_min, output_activation_max, output_data, output_dims,
+ im2col_data, im2col_dims, gemm_context);
+}
+
+template <typename T>
+inline void DepthToSpace(const T* input_data, const Dims<4>& input_dims,
+ int block_size, T* output_data,
+ const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("DepthToSpace");
+
+ const int input_depth = ArraySize(input_dims, 0);
+ const int input_width = ArraySize(input_dims, 1);
+ const int input_height = ArraySize(input_dims, 2);
+
+ const int output_depth = ArraySize(output_dims, 0);
+ const int batch_size = ArraySize(output_dims, 3);
+
+  // Number of contiguous values that we can copy in one iteration.
+ const int stride = block_size * output_depth;
+
+ for (int batch = 0; batch < batch_size; ++batch) {
+ for (int in_h = 0; in_h < input_height; ++in_h) {
+ const T* input_ptr = input_data + Offset(input_dims, 0, 0, in_h, batch);
+ for (int offset_h = 0; offset_h < block_size; ++offset_h) {
+ const T* src = input_ptr;
+ for (int in_w = 0; in_w < input_width; ++in_w) {
+ memcpy(output_data, src, stride * sizeof(T));
+ output_data += stride;
+ src += input_depth;
+ }
+ input_ptr += stride;
+ }
+ }
+ }
+}
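+
+// Shape example (illustrative only): with block_size = 2, an input of shape
+// (batch, H, W, 4) in TensorFlow order becomes an output of shape
+// (batch, 2H, 2W, 1); each memcpy above then moves
+// stride = block_size * output_depth = 2 contiguous values per input pixel.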
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac, typename T>
+void Im2col(const T* input_data, const Dims<4>& input_dims, int stride,
+ int pad_width, int pad_height, int kheight, int kwidth,
+ uint8 byte_zero, T* output_data, const Dims<4>& output_dims) {
+ Im2col(input_data, input_dims, stride, stride, pad_width, pad_height, kheight,
+ kwidth, byte_zero, output_data, output_dims);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void ConvAsGemm(const float* input_data, const Dims<4>& input_dims,
+ const float* filter_data, const Dims<4>& filter_dims,
+ const float* bias_data, const Dims<4>& bias_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("ConvAsGemm");
+
+ const auto input_matrix_map =
+ MapAsMatrixWithFirstDimAsRows(input_data, input_dims);
+ const auto filter_matrix_map =
+ MapAsMatrixWithLastDimAsCols(filter_data, filter_dims);
+ auto output_matrix_map =
+ MapAsMatrixWithFirstDimAsRows(output_data, output_dims);
+
+ Gemm(filter_matrix_map.transpose(), input_matrix_map, &output_matrix_map);
+
+ AddBiasAndEvalActivationFunction<Ac>(bias_data, bias_dims, output_data,
+ output_dims);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void ConvAsGemm(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims,
+ int32 output_offset, int32 output_multiplier, int output_shift,
+ int32 output_activation_min, int32 output_activation_max,
+ uint8* output_data, const Dims<4>& output_dims,
+ gemmlowp::GemmContext* gemm_context) {
+ gemmlowp::ScopedProfilingLabel label("ConvAsGemm/8bit");
+ static_assert(Ac == FusedActivationFunctionType::kNone ||
+ Ac == FusedActivationFunctionType::kRelu ||
+ Ac == FusedActivationFunctionType::kRelu6 ||
+ Ac == FusedActivationFunctionType::kRelu1,
+ "");
+ const int input_rows = input_dims.sizes[0];
+ const int input_cols =
+ input_dims.sizes[1] * input_dims.sizes[2] * input_dims.sizes[3];
+ const int filter_rows = filter_dims.sizes[3];
+ const int filter_cols =
+ filter_dims.sizes[0] * filter_dims.sizes[1] * filter_dims.sizes[2];
+ const int output_rows = output_dims.sizes[0];
+ const int output_cols =
+ output_dims.sizes[1] * output_dims.sizes[2] * output_dims.sizes[3];
+ TFLITE_DCHECK_EQ(output_rows, filter_rows);
+ TFLITE_DCHECK_EQ(output_cols, input_cols);
+ TFLITE_DCHECK_EQ(filter_cols, input_rows);
+ TFLITE_DCHECK_EQ(bias_dims.sizes[0], output_rows);
+ TFLITE_DCHECK_EQ(bias_dims.sizes[1], 1);
+ TFLITE_DCHECK_EQ(bias_dims.sizes[2], 1);
+ TFLITE_DCHECK_EQ(bias_dims.sizes[3], 1);
+ gemmlowp::MatrixMap<const uint8, gemmlowp::MapOrder::RowMajor> filter_matrix(
+ filter_data, output_rows, filter_cols, filter_cols);
+ gemmlowp::MatrixMap<const uint8, gemmlowp::MapOrder::ColMajor> input_matrix(
+ input_data, filter_cols, output_cols, filter_cols);
+ gemmlowp::MatrixMap<uint8, gemmlowp::MapOrder::ColMajor> output_matrix(
+ output_data, output_rows, output_cols, output_rows);
+ const auto& output_pipeline = GemmlowpOutputPipeline::Make(
+ bias_data, output_rows, output_offset, output_multiplier, output_shift,
+ output_activation_min, output_activation_max);
+ gemmlowp::GemmWithOutputPipeline<uint8, uint8,
+ gemmlowp::L8R8WithLhsNonzeroBitDepthParams>(
+ gemm_context, filter_matrix, input_matrix, &output_matrix, filter_offset,
+ input_offset, output_pipeline);
+}
+
+template <typename T>
+inline void SpaceToDepth(const T* input_data, const Dims<4>& input_dims,
+ int block_size, T* output_data,
+ const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("SpaceToDepth");
+
+ const int output_depth = ArraySize(output_dims, 0);
+ const int output_width = ArraySize(output_dims, 1);
+ const int output_height = ArraySize(output_dims, 2);
+
+ const int input_depth = ArraySize(input_dims, 0);
+ const int batch_size = ArraySize(input_dims, 3);
+
+  // Number of contiguous values that we can copy in one iteration.
+ const int stride = block_size * input_depth;
+
+ for (int batch = 0; batch < batch_size; ++batch) {
+ for (int out_h = 0; out_h < output_height; ++out_h) {
+ T* output_ptr = output_data + Offset(output_dims, 0, 0, out_h, batch);
+ for (int offset_h = 0; offset_h < block_size; ++offset_h) {
+ T* dst = output_ptr;
+ for (int out_w = 0; out_w < output_width; ++out_w) {
+ memcpy(dst, input_data, stride * sizeof(T));
+ input_data += stride;
+ dst += output_depth;
+ }
+ output_ptr += stride;
+ }
+ }
+ }
+}
+
+template <FusedActivationFunctionType Ac>
+void NonGlobalBatchNormalization(
+ const float* input_data, const Dims<4>& input_dims, const float* mean_data,
+ const Dims<4>& mean_dims, const float* multiplier_data,
+ const Dims<4>& multiplier_dims, const float* offset_data,
+ const Dims<4>& offset_dims, float* output_data,
+ const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("NonGlobalBatchNormalization");
+ const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
+ const int height =
+ MatchingArraySize(input_dims, 2, mean_dims, 2, multiplier_dims, 2,
+ offset_dims, 2, output_dims, 2);
+ const int width =
+ MatchingArraySize(input_dims, 1, mean_dims, 1, multiplier_dims, 1,
+ offset_dims, 1, output_dims, 1);
+ const int depth =
+ MatchingArraySize(input_dims, 0, mean_dims, 0, multiplier_dims, 0,
+ offset_dims, 0, output_dims, 0);
+
+ for (int b = 0; b < batches; ++b) {
+ for (int y = 0; y < height; ++y) {
+ for (int x = 0; x < width; ++x) {
+ for (int c = 0; c < depth; ++c) {
+ output_data[Offset(output_dims, c, x, y, b)] = ActivationFunction<Ac>(
+ (input_data[Offset(input_dims, c, x, y, b)] -
+ mean_data[Offset(mean_dims, c, x, y, 0)]) *
+ multiplier_data[Offset(multiplier_dims, c, x, y, 0)] +
+ offset_data[Offset(offset_dims, c, x, y, 0)]);
+ }
+ }
+ }
+ }
+}
+
+template <FusedActivationFunctionType Ac>
+void GlobalBatchNormalization(const float* input_data,
+ const Dims<4>& input_dims, const float* mean_data,
+ const Dims<4>& mean_dims,
+ const float* multiplier_data,
+ const Dims<4>& multiplier_dims,
+ const float* offset_data,
+ const Dims<4>& offset_dims, float* output_data,
+ const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("GlobalBatchNormalization");
+ const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
+ const int height = MatchingArraySize(input_dims, 2, output_dims, 2);
+ const int width = MatchingArraySize(input_dims, 1, output_dims, 1);
+ const int depth =
+ MatchingArraySize(input_dims, 0, mean_dims, 0, multiplier_dims, 0,
+ offset_dims, 0, output_dims, 0);
+
+ for (int b = 0; b < batches; ++b) {
+ for (int y = 0; y < height; ++y) {
+ for (int x = 0; x < width; ++x) {
+ for (int c = 0; c < depth; ++c) {
+ output_data[Offset(output_dims, c, x, y, b)] = ActivationFunction<Ac>(
+ (input_data[Offset(input_dims, c, x, y, b)] -
+ mean_data[Offset(mean_dims, c, 0, 0, 0)]) *
+ multiplier_data[Offset(multiplier_dims, c, 0, 0, 0)] +
+ offset_data[Offset(offset_dims, c, 0, 0, 0)]);
+ }
+ }
+ }
+ }
+}
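+
+// Math note (informal): both batch-normalization variants above compute
+//   output = Ac((input - mean) * multiplier + offset)
+// per element; 'multiplier' is assumed to already fold in the usual
+// gamma / sqrt(variance + epsilon) factor, so no division is needed at
+// inference time.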
+
+inline void Relu(const float* input_data, const Dims<4>& input_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("Relu (not fused)");
+
+ const auto input = MapAsVector(input_data, input_dims);
+ auto output = MapAsVector(output_data, output_dims);
+ output = input.cwiseMax(0.0f);
+}
+
+inline void Relu1(const float* input_data, const Dims<4>& input_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("Relu1 (not fused)");
+ const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
+ const int height = MatchingArraySize(input_dims, 2, output_dims, 2);
+ const int width = MatchingArraySize(input_dims, 1, output_dims, 1);
+ const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
+ for (int b = 0; b < batches; ++b) {
+ for (int y = 0; y < height; ++y) {
+ for (int x = 0; x < width; ++x) {
+ for (int c = 0; c < depth; ++c) {
+ float val = input_data[Offset(input_dims, c, x, y, b)];
+ const float upper = 1;
+ const float lower = -1;
+ float clamped = val > upper ? upper : val < lower ? lower : val;
+ output_data[Offset(output_dims, c, x, y, b)] = clamped;
+ }
+ }
+ }
+ }
+}
+
+inline void Relu6(const float* input_data, const Dims<4>& input_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("Relu6 (not fused)");
+ const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
+ const int height = MatchingArraySize(input_dims, 2, output_dims, 2);
+ const int width = MatchingArraySize(input_dims, 1, output_dims, 1);
+ const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
+ for (int b = 0; b < batches; ++b) {
+ for (int y = 0; y < height; ++y) {
+ for (int x = 0; x < width; ++x) {
+ for (int c = 0; c < depth; ++c) {
+ float val = input_data[Offset(input_dims, c, x, y, b)];
+ const float upper = 6;
+ const float lower = 0;
+ float clamped = val > upper ? upper : val < lower ? lower : val;
+ output_data[Offset(output_dims, c, x, y, b)] = clamped;
+ }
+ }
+ }
+ }
+}
+
+template <FusedActivationFunctionType Ac>
+void L2Normalization(const float* input_data, const Dims<4>& input_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("L2Normalization");
+ static_assert(Ac == FusedActivationFunctionType::kNone, "");
+ const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
+ const int height = MatchingArraySize(input_dims, 2, output_dims, 2);
+ const int width = MatchingArraySize(input_dims, 1, output_dims, 1);
+ const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
+ for (int b = 0; b < batches; ++b) {
+ for (int y = 0; y < height; ++y) {
+ for (int x = 0; x < width; ++x) {
+ float squared_l2_norm = 0;
+ for (int c = 0; c < depth; ++c) {
+ float val = input_data[Offset(input_dims, c, x, y, b)];
+ squared_l2_norm += val * val;
+ }
+ float inverse_l2_norm = 1.0f / std::sqrt(squared_l2_norm);
+ for (int c = 0; c < depth; ++c) {
+ output_data[Offset(output_dims, c, x, y, b)] =
+ input_data[Offset(input_dims, c, x, y, b)] * inverse_l2_norm;
+ }
+ }
+ }
+ }
+}
+
+inline void GetInvSqrtQuantizedMultiplier(int32 input, int32* output_inv_sqrt,
+ int* output_shift) {
+ *output_shift = 11;
+ while (input >= (1 << 29)) {
+ input /= 4;
+ ++*output_shift;
+ }
+ TFLITE_DCHECK_GT(input, 0);
+ const unsigned max_left_shift_bits = __builtin_clz(input) - 1;
+ const unsigned max_left_shift_bit_pairs = max_left_shift_bits / 2;
+ const unsigned left_shift_bit_pairs = max_left_shift_bit_pairs - 1;
+ *output_shift -= left_shift_bit_pairs;
+ input <<= 2 * left_shift_bit_pairs;
+ TFLITE_DCHECK_GE(input, (1 << 27));
+ TFLITE_DCHECK_LT(input, (1 << 29));
+ using gemmlowp::FixedPoint;
+ using gemmlowp::Rescale;
+ using gemmlowp::SaturatingRoundingMultiplyByPOT;
+ // Using 3 integer bits gives us enough room for the internal arithmetic in
+ // this Newton-Raphson iteration.
+ using F3 = FixedPoint<int32, 3>;
+ using F0 = FixedPoint<int32, 0>;
+ const F3 fixedpoint_input = F3::FromRaw(input >> 1);
+ const F3 fixedpoint_half_input =
+ SaturatingRoundingMultiplyByPOT<-1>(fixedpoint_input);
+ const F3 fixedpoint_half_three =
+ GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F3, (1 << 28) + (1 << 27), 1.5);
+ // Newton-Raphson iteration
+ // Naive unoptimized starting guess: x = 1
+ F3 x = F3::One();
+ // Naive unoptimized number of iterations: 5
+ for (int i = 0; i < 5; i++) {
+ const F3 x3 = Rescale<3>(x * x * x);
+ x = Rescale<3>(fixedpoint_half_three * x - fixedpoint_half_input * x3);
+ }
+ const F0 fixedpoint_half_sqrt_2 =
+ GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F0, 1518500250, std::sqrt(2.) / 2.);
+ x = x * fixedpoint_half_sqrt_2;
+ *output_inv_sqrt = x.raw();
+ if (*output_shift < 0) {
+ *output_inv_sqrt <<= -*output_shift;
+ *output_shift = 0;
+ }
+}
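+
+// Math note (informal): the loop above is the standard Newton-Raphson
+// iteration for 1 / sqrt(a),
+//   x_{n+1} = x_n * (3 - a * x_n^2) / 2
+//           = 1.5 * x_n - 0.5 * a * x_n^3,
+// evaluated in gemmlowp 32-bit fixed point with 3 integer bits (the F3 type
+// above); fixedpoint_half_three holds 1.5 and fixedpoint_half_input holds
+// a / 2, matching the second form.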
+
+inline void L2Normalization(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_zero_point, uint8* output_data,
+ const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("L2Normalization/8bit");
+ const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
+ const int height = MatchingArraySize(input_dims, 2, output_dims, 2);
+ const int width = MatchingArraySize(input_dims, 1, output_dims, 1);
+ const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
+ TFLITE_DCHECK(IsPackedWithoutStrides(input_dims));
+ TFLITE_DCHECK(IsPackedWithoutStrides(output_dims));
+ TFLITE_DCHECK_EQ(batches, 1);
+ TFLITE_DCHECK_EQ(height, 1);
+ TFLITE_DCHECK_EQ(width, 1);
+ int32 square_l2_norm = 0;
+ for (int i = 0; i < depth; i++) {
+ int32 diff = input_data[i] - input_zero_point;
+ square_l2_norm += diff * diff;
+ }
+ int32 inv_l2norm_multiplier;
+ int inv_l2norm_shift;
+ GetInvSqrtQuantizedMultiplier(square_l2_norm, &inv_l2norm_multiplier,
+ &inv_l2norm_shift);
+
+ for (int i = 0; i < depth; i++) {
+ int32 diff = input_data[i] - input_zero_point;
+ int32 rescaled_diff = MultiplyByQuantizedMultiplierSmallerThanOne(
+ 128 * diff, inv_l2norm_multiplier, inv_l2norm_shift);
+ int32 unclamped_output_val = 128 + rescaled_diff;
+ int32 output_val = std::min(255, std::max(0, unclamped_output_val));
+ output_data[i] = static_cast<uint8>(output_val);
+ }
+}
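+
+// Quantization note (informal sketch of the math above): with
+// diff[i] = input[i] - input_zero_point, the per-element output is
+//   output[i] = Clamp(128 + 128 * diff[i] / sqrt(sum_j diff[j]^2), 0, 255),
+// i.e. the result is encoded with zero point 128 and scale 1/128 of the
+// unit-normalized value.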
+
+inline void Add(const float* input1_data, const Dims<4>& input1_dims,
+ const float* input2_data, const Dims<4>& input2_dims,
+ float output_activation_min, float output_activation_max,
+ float* output_data, const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("Add");
+ /* const int batches = */ MatchingArraySize(input1_dims, 3, input2_dims, 3,
+ output_dims, 3);
+ /* const int height = */ MatchingArraySize(input1_dims, 2, input2_dims, 2,
+ output_dims, 2);
+ /* const int width = */ MatchingArraySize(input1_dims, 1, input2_dims, 1,
+ output_dims, 1);
+ /* const int depth = */ MatchingArraySize(input1_dims, 0, input2_dims, 0,
+ output_dims, 0);
+ TFLITE_DCHECK(IsPackedWithoutStrides(input1_dims));
+ TFLITE_DCHECK(IsPackedWithoutStrides(input2_dims));
+ TFLITE_DCHECK(IsPackedWithoutStrides(output_dims));
+
+ int i = 0;
+ const int size = input1_dims.sizes[3] * input1_dims.strides[3];
+#ifdef USE_NEON
+ const auto activation_min = vdupq_n_f32(output_activation_min);
+ const auto activation_max = vdupq_n_f32(output_activation_max);
+ for (; i <= size - 16; i += 16) {
+ auto a10 = vld1q_f32(input1_data + i);
+ auto a11 = vld1q_f32(input1_data + i + 4);
+ auto a12 = vld1q_f32(input1_data + i + 8);
+ auto a13 = vld1q_f32(input1_data + i + 12);
+ auto a20 = vld1q_f32(input2_data + i);
+ auto a21 = vld1q_f32(input2_data + i + 4);
+ auto a22 = vld1q_f32(input2_data + i + 8);
+ auto a23 = vld1q_f32(input2_data + i + 12);
+ auto x0 = vaddq_f32(a10, a20);
+ auto x1 = vaddq_f32(a11, a21);
+ auto x2 = vaddq_f32(a12, a22);
+ auto x3 = vaddq_f32(a13, a23);
+ x0 = vmaxq_f32(activation_min, x0);
+ x1 = vmaxq_f32(activation_min, x1);
+ x2 = vmaxq_f32(activation_min, x2);
+ x3 = vmaxq_f32(activation_min, x3);
+ x0 = vminq_f32(activation_max, x0);
+ x1 = vminq_f32(activation_max, x1);
+ x2 = vminq_f32(activation_max, x2);
+ x3 = vminq_f32(activation_max, x3);
+ vst1q_f32(output_data + i, x0);
+ vst1q_f32(output_data + i + 4, x1);
+ vst1q_f32(output_data + i + 8, x2);
+ vst1q_f32(output_data + i + 12, x3);
+ }
+ for (; i <= size - 4; i += 4) {
+ auto a1 = vld1q_f32(input1_data + i);
+ auto a2 = vld1q_f32(input2_data + i);
+ auto x = vaddq_f32(a1, a2);
+ x = vmaxq_f32(activation_min, x);
+ x = vminq_f32(activation_max, x);
+ vst1q_f32(output_data + i, x);
+ }
+#endif // NEON
+
+ for (; i < size; i++) {
+ auto x = input1_data[i] + input2_data[i];
+ output_data[i] = ActivationFunctionWithMinMax(x, output_activation_min,
+ output_activation_max);
+ }
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void Add(const float* input1_data, const Dims<4>& input1_dims,
+ const float* input2_data, const Dims<4>& input2_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ float output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+
+ Add(input1_data, input1_dims, input2_data, input2_dims, output_activation_min,
+ output_activation_max, output_data, output_dims);
+}
+
+template <FusedActivationFunctionType Ac>
+inline void Add(int left_shift, const uint8* input1_data,
+ const Dims<4>& input1_dims, int32 input1_offset,
+ int32 input1_multiplier, int input1_shift,
+ const uint8* input2_data, const Dims<4>& input2_dims,
+ int32 input2_offset, int32 input2_multiplier, int input2_shift,
+ int32 output_offset, int32 output_multiplier, int output_shift,
+ int32 output_activation_min, int32 output_activation_max,
+ uint8* output_data, const Dims<4>& output_dims) {
+ static_assert(Ac == FusedActivationFunctionType::kNone ||
+ Ac == FusedActivationFunctionType::kRelu ||
+ Ac == FusedActivationFunctionType::kRelu6 ||
+ Ac == FusedActivationFunctionType::kRelu1,
+ "");
+ TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
+ if (Ac == FusedActivationFunctionType::kNone) {
+ TFLITE_DCHECK_EQ(output_activation_min, 0);
+ TFLITE_DCHECK_EQ(output_activation_max, 255);
+ }
+ gemmlowp::ScopedProfilingLabel label("Add/8bit");
+ /* const int batches = */ MatchingArraySize(input1_dims, 3, input2_dims, 3,
+ output_dims, 3);
+ /* const int height = */ MatchingArraySize(input1_dims, 2, input2_dims, 2,
+ output_dims, 2);
+ /* const int width = */ MatchingArraySize(input1_dims, 1, input2_dims, 1,
+ output_dims, 1);
+ /* const int depth = */ MatchingArraySize(input1_dims, 0, input2_dims, 0,
+ output_dims, 0);
+ TFLITE_DCHECK(IsPackedWithoutStrides(input1_dims));
+ TFLITE_DCHECK(IsPackedWithoutStrides(input2_dims));
+ TFLITE_DCHECK(IsPackedWithoutStrides(output_dims));
+
+ int i = 0;
+ const int size = input1_dims.sizes[3] * input1_dims.strides[3];
+ TFLITE_DCHECK_GT(input1_offset, -256);
+ TFLITE_DCHECK_GT(input2_offset, -256);
+ TFLITE_DCHECK_LT(input1_offset, 256);
+ TFLITE_DCHECK_LT(input2_offset, 256);
+#ifdef USE_NEON
+ for (; i <= size - 8; i += 8) {
+ const auto input1_val_original = vld1_u8(input1_data + i);
+ const auto input2_val_original = vld1_u8(input2_data + i);
+ const auto input1_val_s16 =
+ vreinterpretq_s16_u16(vmovl_u8(input1_val_original));
+ const auto input2_val_s16 =
+ vreinterpretq_s16_u16(vmovl_u8(input2_val_original));
+ const auto input1_val =
+ vaddq_s16(input1_val_s16, vdupq_n_s16(input1_offset));
+ const auto input2_val =
+ vaddq_s16(input2_val_s16, vdupq_n_s16(input2_offset));
+ const auto input1_val_high = vget_high_s16(input1_val);
+ const auto input1_val_low = vget_low_s16(input1_val);
+ const auto input2_val_high = vget_high_s16(input2_val);
+ const auto input2_val_low = vget_low_s16(input2_val);
+ auto x11 = vmovl_s16(input1_val_low);
+ auto x12 = vmovl_s16(input1_val_high);
+ auto x21 = vmovl_s16(input2_val_low);
+ auto x22 = vmovl_s16(input2_val_high);
+ const auto left_shift_dup = vdupq_n_s32(left_shift);
+ x11 = vshlq_s32(x11, left_shift_dup);
+ x12 = vshlq_s32(x12, left_shift_dup);
+ x21 = vshlq_s32(x21, left_shift_dup);
+ x22 = vshlq_s32(x22, left_shift_dup);
+ x11 = vqrdmulhq_n_s32(x11, input1_multiplier);
+ x12 = vqrdmulhq_n_s32(x12, input1_multiplier);
+ x21 = vqrdmulhq_n_s32(x21, input2_multiplier);
+ x22 = vqrdmulhq_n_s32(x22, input2_multiplier);
+ const auto input1_shift_dup = vdupq_n_s32(-input1_shift);
+ const auto input2_shift_dup = vdupq_n_s32(-input2_shift);
+ x11 = vshlq_s32(x11, input1_shift_dup);
+ x12 = vshlq_s32(x12, input1_shift_dup);
+ x21 = vshlq_s32(x21, input2_shift_dup);
+ x22 = vshlq_s32(x22, input2_shift_dup);
+ auto s1 = vaddq_s32(x11, x21);
+ auto s2 = vaddq_s32(x12, x22);
+ s1 = vqrdmulhq_n_s32(s1, output_multiplier);
+ s2 = vqrdmulhq_n_s32(s2, output_multiplier);
+ using gemmlowp::RoundingDivideByPOT;
+ s1 = RoundingDivideByPOT(s1, output_shift);
+ s2 = RoundingDivideByPOT(s2, output_shift);
+ const auto s1_narrowed = vmovn_s32(s1);
+ const auto s2_narrowed = vmovn_s32(s2);
+ const auto s = vaddq_s16(vcombine_s16(s1_narrowed, s2_narrowed),
+ vdupq_n_s16(output_offset));
+ vst1_u8(output_data + i, vqmovun_s16(s));
+ }
+#endif // NEON
+
+ for (; i < size; i++) {
+ const int32 input1_val = input1_offset + input1_data[i];
+ const int32 input2_val = input2_offset + input2_data[i];
+ const int32 shifted_input1_val = input1_val * (1 << left_shift);
+ const int32 shifted_input2_val = input2_val * (1 << left_shift);
+ const int32 scaled_input1_val = MultiplyByQuantizedMultiplierSmallerThanOne(
+ shifted_input1_val, input1_multiplier, input1_shift);
+ const int32 scaled_input2_val = MultiplyByQuantizedMultiplierSmallerThanOne(
+ shifted_input2_val, input2_multiplier, input2_shift);
+ const int32 raw_sum = scaled_input1_val + scaled_input2_val;
+ const int32 raw_output = MultiplyByQuantizedMultiplierSmallerThanOne(
+ raw_sum, output_multiplier, output_shift) +
+ output_offset;
+ const int32 clamped_output = std::min(
+ output_activation_max, std::max(output_activation_min, raw_output));
+ output_data[i] = static_cast<uint8>(clamped_output);
+ }
+}
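+// In equation form, the scalar tail loop above computes, per element i:
+//   a = Q1((input1_data[i] + input1_offset) << left_shift)
+//   b = Q2((input2_data[i] + input2_offset) << left_shift)
+//   output_data[i] = clamp(Qout(a + b) + output_offset)
+// where Q1/Q2/Qout denote MultiplyByQuantizedMultiplierSmallerThanOne with the
+// respective (multiplier, shift) pairs and clamp() applies
+// [output_activation_min, output_activation_max]. The NEON path computes the
+// same thing eight elements at a time.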
+
+template <FusedActivationFunctionType Ac>
+void Add(const int32* input1_data, const Dims<4>& input1_dims,
+ const int32* input2_data, const Dims<4>& input2_dims,
+ int32* output_data, const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("Add/int32");
+ TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone);
+
+ auto input1_map = MapAsVector(input1_data, input1_dims);
+ auto input2_map = MapAsVector(input2_data, input2_dims);
+ auto output_map = MapAsVector(output_data, output_dims);
+ if (AreSameDims(input1_dims, input2_dims)) {
+ output_map.array() = input1_map.array() + input2_map.array();
+ } else if (RequiredBufferSizeForDims(input2_dims) == 1) {
+ auto scalar = input2_data[0];
+ output_map.array() = input1_map.array() + scalar;
+ } else if (RequiredBufferSizeForDims(input1_dims) == 1) {
+ auto scalar = input1_data[0];
+ output_map.array() = scalar + input2_map.array();
+ } else {
+    // Should not get here.
+ TFLITE_DCHECK(false);
+ }
+}
+
+// TODO(jiawen): We can implement BroadcastAdd on buffers of arbitrary
+// dimensionality if the runtime code does a single loop over one dimension
+// that handles broadcasting as the base case. The code generator would then
+// generate max(D1, D2) nested for loops.
+// TODO(benoitjacob): BroadcastAdd is intentionally duplicated from
+// reference_ops.h. Once an optimized version is implemented and NdArrayDesc<T>
+// is no longer referenced in this file, move NdArrayDesc<T> from types.h to
+// reference_ops.h.
+template <FusedActivationFunctionType Ac, typename T>
+void BroadcastAdd(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, const Dims<4>& input2_dims,
+ T* output_data, const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("BroadcastAdd");
+
+ NdArrayDesc<4> desc1;
+ NdArrayDesc<4> desc2;
+ NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
+
+  // In TensorFlow, the dimensions are canonically named (batch_number, row,
+  // col, channel), with extents (batches, height, width, depth), and with the
+  // trailing dimension changing most rapidly (channels has the smallest
+  // stride, typically 1 element).
+  //
+  // In generated C code, we store arrays with the dimensions reversed: the
+  // first dimension has the smallest stride.
+  //
+  // We name our variables by the TensorFlow convention, but generate C code
+  // that nests the loops such that the innermost loop has the smallest stride,
+  // for the best cache behavior.
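+  //
+  // For example (illustrative, assuming a fully packed layout where
+  // strides[0] == 1 and strides[i] == strides[i - 1] * sizes[i - 1]): with
+  // sizes {depth = 2, width = 3, height = 4, batches = 1} the strides are
+  // {1, 2, 6, 24}, so the element at (c = 1, x = 2, y = 3, b = 0) lives at
+  // flat offset 1*1 + 2*2 + 3*6 + 0*24 = 23, i.e. channels vary fastest,
+  // exactly as the loop nesting below assumes.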
+ for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
+ for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
+ for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
+ for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
+ output_data[Offset(output_dims, c, x, y, b)] = ActivationFunction<Ac>(
+ input1_data[SubscriptToIndex(desc1, c, x, y, b)] +
+ input2_data[SubscriptToIndex(desc2, c, x, y, b)]);
+ }
+ }
+ }
+ }
+}
+
+inline void BroadcastAdd(int left_shift, const uint8* input1_data,
+ const Dims<4>& input1_dims, int32 input1_offset,
+ int32 input1_multiplier, int input1_shift,
+ const uint8* input2_data, const Dims<4>& input2_dims,
+ int32 input2_offset, int32 input2_multiplier,
+ int input2_shift, int32 output_offset,
+ int32 output_multiplier, int output_shift,
+ int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("BroadcastAdd/8bit");
+
+ NdArrayDesc<4> desc1;
+ NdArrayDesc<4> desc2;
+ NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
+
+  // In TensorFlow, the dimensions are canonically named (batch_number, row,
+  // col, channel), with extents (batches, height, width, depth), and with the
+  // trailing dimension changing most rapidly (channels has the smallest
+  // stride, typically 1 element).
+  //
+  // In generated C code, we store arrays with the dimensions reversed: the
+  // first dimension has the smallest stride.
+  //
+  // We name our variables by the TensorFlow convention, but generate C code
+  // that nests the loops such that the innermost loop has the smallest stride,
+  // for the best cache behavior.
+ for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
+ for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
+ for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
+ for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
+ const int32 input1_val =
+ input1_offset + input1_data[SubscriptToIndex(desc1, c, x, y, b)];
+ const int32 input2_val =
+ input2_offset + input2_data[SubscriptToIndex(desc2, c, x, y, b)];
+ const int32 shifted_input1_val = input1_val * (1 << left_shift);
+ const int32 shifted_input2_val = input2_val * (1 << left_shift);
+ const int32 scaled_input1_val =
+ MultiplyByQuantizedMultiplierSmallerThanOne(
+ shifted_input1_val, input1_multiplier, input1_shift);
+ const int32 scaled_input2_val =
+ MultiplyByQuantizedMultiplierSmallerThanOne(
+ shifted_input2_val, input2_multiplier, input2_shift);
+ const int32 raw_sum = scaled_input1_val + scaled_input2_val;
+ const int32 raw_output =
+ MultiplyByQuantizedMultiplierSmallerThanOne(
+ raw_sum, output_multiplier, output_shift) +
+ output_offset;
+ const int32 clamped_output =
+ std::min(output_activation_max,
+ std::max(output_activation_min, raw_output));
+ output_data[Offset(output_dims, c, x, y, b)] =
+ static_cast<uint8>(clamped_output);
+ }
+ }
+ }
+ }
+}
+
+template <FusedActivationFunctionType Ac>
+inline void BroadcastAdd(int left_shift, const uint8* input1_data,
+ const Dims<4>& input1_dims, int32 input1_offset,
+ int32 input1_multiplier, int input1_shift,
+ const uint8* input2_data, const Dims<4>& input2_dims,
+ int32 input2_offset, int32 input2_multiplier,
+ int input2_shift, int32 output_offset,
+ int32 output_multiplier, int output_shift,
+ int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ static_assert(Ac == FusedActivationFunctionType::kNone ||
+ Ac == FusedActivationFunctionType::kRelu ||
+ Ac == FusedActivationFunctionType::kRelu6 ||
+ Ac == FusedActivationFunctionType::kRelu1,
+ "");
+ TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
+ if (Ac == FusedActivationFunctionType::kNone) {
+ TFLITE_DCHECK_EQ(output_activation_min, 0);
+ TFLITE_DCHECK_EQ(output_activation_max, 255);
+ }
+ BroadcastAdd(left_shift, input1_data, input1_dims, input1_offset,
+ input1_multiplier, input1_shift, input2_data, input2_dims,
+ input2_offset, input2_multiplier, input2_shift, output_offset,
+ output_multiplier, output_shift, output_activation_min,
+ output_activation_max, output_data, output_dims);
+}
+
+inline void Mul(const float* input1_data, const Dims<4>& input1_dims,
+ const float* input2_data, const Dims<4>& input2_dims,
+ float output_activation_min, float output_activation_max,
+ float* output_data, const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("Mul");
+ /* const int batches = */ MatchingArraySize(input1_dims, 3, input2_dims, 3,
+ output_dims, 3);
+ /* const int height = */ MatchingArraySize(input1_dims, 2, input2_dims, 2,
+ output_dims, 2);
+ /* const int width = */ MatchingArraySize(input1_dims, 1, input2_dims, 1,
+ output_dims, 1);
+ /* const int depth = */ MatchingArraySize(input1_dims, 0, input2_dims, 0,
+ output_dims, 0);
+ TFLITE_DCHECK(IsPackedWithoutStrides(input1_dims));
+ TFLITE_DCHECK(IsPackedWithoutStrides(input2_dims));
+ TFLITE_DCHECK(IsPackedWithoutStrides(output_dims));
+
+ int i = 0;
+ const int size = input1_dims.sizes[3] * input1_dims.strides[3];
+#ifdef USE_NEON
+ const auto activation_min = vdupq_n_f32(output_activation_min);
+ const auto activation_max = vdupq_n_f32(output_activation_max);
+ for (; i <= size - 16; i += 16) {
+ auto a10 = vld1q_f32(input1_data + i);
+ auto a11 = vld1q_f32(input1_data + i + 4);
+ auto a12 = vld1q_f32(input1_data + i + 8);
+ auto a13 = vld1q_f32(input1_data + i + 12);
+ auto a20 = vld1q_f32(input2_data + i);
+ auto a21 = vld1q_f32(input2_data + i + 4);
+ auto a22 = vld1q_f32(input2_data + i + 8);
+ auto a23 = vld1q_f32(input2_data + i + 12);
+ auto x0 = vmulq_f32(a10, a20);
+ auto x1 = vmulq_f32(a11, a21);
+ auto x2 = vmulq_f32(a12, a22);
+ auto x3 = vmulq_f32(a13, a23);
+
+ x0 = vmaxq_f32(activation_min, x0);
+ x1 = vmaxq_f32(activation_min, x1);
+ x2 = vmaxq_f32(activation_min, x2);
+ x3 = vmaxq_f32(activation_min, x3);
+ x0 = vminq_f32(activation_max, x0);
+ x1 = vminq_f32(activation_max, x1);
+ x2 = vminq_f32(activation_max, x2);
+ x3 = vminq_f32(activation_max, x3);
+
+ vst1q_f32(output_data + i, x0);
+ vst1q_f32(output_data + i + 4, x1);
+ vst1q_f32(output_data + i + 8, x2);
+ vst1q_f32(output_data + i + 12, x3);
+ }
+ for (; i <= size - 4; i += 4) {
+ auto a1 = vld1q_f32(input1_data + i);
+ auto a2 = vld1q_f32(input2_data + i);
+ auto x = vmulq_f32(a1, a2);
+
+ x = vmaxq_f32(activation_min, x);
+ x = vminq_f32(activation_max, x);
+
+ vst1q_f32(output_data + i, x);
+ }
+#endif  // USE_NEON
+
+ for (; i < size; i++) {
+ auto x = input1_data[i] * input2_data[i];
+ output_data[i] = ActivationFunctionWithMinMax(x, output_activation_min,
+ output_activation_max);
+ }
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void Mul(const float* input1_data, const Dims<4>& input1_dims,
+ const float* input2_data, const Dims<4>& input2_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ float output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+
+ Mul(input1_data, input1_dims, input2_data, input2_dims, output_activation_min,
+ output_activation_max, output_data, output_dims);
+}
+
+template <FusedActivationFunctionType Ac>
+void Mul(const int32* input1_data, const Dims<4>& input1_dims,
+ const int32* input2_data, const Dims<4>& input2_dims,
+ int32* output_data, const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("Mul/int32");
+ TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone);
+
+ auto input1_map = MapAsVector(input1_data, input1_dims);
+ auto input2_map = MapAsVector(input2_data, input2_dims);
+ auto output_map = MapAsVector(output_data, output_dims);
+ if (AreSameDims(input1_dims, input2_dims)) {
+ output_map.array() = input1_map.array() * input2_map.array();
+ } else if (RequiredBufferSizeForDims(input2_dims) == 1) {
+ auto scalar = input2_data[0];
+ output_map.array() = input1_map.array() * scalar;
+ } else if (RequiredBufferSizeForDims(input1_dims) == 1) {
+ auto scalar = input1_data[0];
+ output_map.array() = scalar * input2_map.array();
+ } else {
+    // Should not get here.
+ TFLITE_DCHECK(false);
+ }
+}
+
+// TODO(jiawen): We can implement BroadcastMul on buffers of arbitrary
+// dimensionality if the runtime code does a single loop over one dimension
+// that handles broadcasting as the base case. The code generator would then
+// generate max(D1, D2) nested for loops.
+// TODO(benoitjacob): BroadcastMul is intentionally duplicated from
+// reference_ops.h. Once an optimized version is implemented and NdArrayDesc<T>
+// is no longer referenced in this file, move NdArrayDesc<T> from types.h to
+// reference_ops.h.
+template <FusedActivationFunctionType Ac, typename T>
+void BroadcastMul(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, const Dims<4>& input2_dims,
+ T* output_data, const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("BroadcastMul");
+
+ NdArrayDesc<4> desc1;
+ NdArrayDesc<4> desc2;
+ NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
+
+  // In TensorFlow, the dimensions are canonically named (batch_number, row,
+  // col, channel), with extents (batches, height, width, depth), and with the
+  // trailing dimension changing most rapidly (channels has the smallest
+  // stride, typically 1 element).
+  //
+  // In generated C code, we store arrays with the dimensions reversed: the
+  // first dimension has the smallest stride.
+  //
+  // We name our variables by the TensorFlow convention, but generate C code
+  // that nests the loops such that the innermost loop has the smallest stride,
+  // for the best cache behavior.
+ for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
+ for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
+ for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
+ for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
+ output_data[Offset(output_dims, c, x, y, b)] = ActivationFunction<Ac>(
+ input1_data[SubscriptToIndex(desc1, c, x, y, b)] *
+ input2_data[SubscriptToIndex(desc2, c, x, y, b)]);
+ }
+ }
+ }
+ }
+}
+
+inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims,
+ int32 input1_offset, const uint8* input2_data,
+ const Dims<4>& input2_dims, int32 input2_offset,
+ int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("BroadcastMul/8bit");
+
+ NdArrayDesc<4> desc1;
+ NdArrayDesc<4> desc2;
+ NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
+
+  // In TensorFlow, the dimensions are canonically named (batch_number, row,
+  // col, channel), with extents (batches, height, width, depth), and with the
+  // trailing dimension changing most rapidly (channels has the smallest
+  // stride, typically 1 element).
+  //
+  // In generated C code, we store arrays with the dimensions reversed: the
+  // first dimension has the smallest stride.
+  //
+  // We name our variables by the TensorFlow convention, but generate C code
+  // that nests the loops such that the innermost loop has the smallest stride,
+  // for the best cache behavior.
+ for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
+ for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
+ for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
+ for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
+ const int32 input1_val =
+ input1_offset + input1_data[SubscriptToIndex(desc1, c, x, y, b)];
+ const int32 input2_val =
+ input2_offset + input2_data[SubscriptToIndex(desc2, c, x, y, b)];
+ const int32 unclamped_result =
+ output_offset +
+ MultiplyByQuantizedMultiplierSmallerThanOne(
+ input1_val * input2_val, output_multiplier, output_shift);
+ const int32 clamped_output =
+ std::min(output_activation_max,
+ std::max(output_activation_min, unclamped_result));
+ output_data[Offset(output_dims, c, x, y, b)] =
+ static_cast<uint8>(clamped_output);
+ }
+ }
+ }
+ }
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims,
+ int32 input1_offset, const uint8* input2_data,
+ const Dims<4>& input2_dims, int32 input2_offset,
+ int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ BroadcastMul(input1_data, input1_dims, input1_offset, input2_data,
+ input2_dims, input2_offset, output_offset, output_multiplier,
+ output_shift, output_activation_min, output_activation_max,
+ output_data, output_dims);
+}
+
+template <FusedActivationFunctionType Ac, typename Scalar>
+void Concatenation(int concat_dim, const Scalar* const* input_data,
+ const Dims<4>* const* input_dims, int inputs_count,
+ Scalar* output_data, const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("Concatenation");
+ int concat_size = 0;
+ for (int i = 0; i < inputs_count; i++) {
+ for (int j = 0; j < 4; j++) {
+ if (j != concat_dim) {
+ MatchingArraySize(*input_dims[i], j, output_dims, j);
+ }
+ }
+ concat_size += ArraySize(*input_dims[i], concat_dim);
+ }
+ TFLITE_DCHECK_EQ(concat_size, ArraySize(output_dims, concat_dim));
+ TFLITE_DCHECK(IsPackedWithoutStrides(output_dims));
+  // For now we don't have a model with a Concatenation
+  // with a fused activation function.
+ TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone);
+ int outer_size = 1;
+ for (int i = concat_dim + 1; i < 4; i++) {
+ outer_size *= output_dims.sizes[i];
+ }
+ Scalar* output_ptr = output_data;
+ for (int k = 0; k < outer_size; k++) {
+ for (int i = 0; i < inputs_count; ++i) {
+ const int copy_size =
+ input_dims[i]->sizes[concat_dim] * input_dims[i]->strides[concat_dim];
+ memcpy(output_ptr, input_data[i] + k * copy_size,
+ copy_size * sizeof(Scalar));
+ output_ptr += copy_size;
+ }
+ }
+}
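+// For example (illustrative): a depth concatenation (concat_dim = 0) of two
+// packed inputs with depths 3 and 5 and identical width/height/batches
+// produces, for every (x, y, b) position, 8 output channels assembled as the
+// 3 channels of input 0 followed by the 5 channels of input 1, which is
+// exactly the memcpy pattern in the inner loop above.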
+
+template <FusedActivationFunctionType Ac, typename Scalar>
+void DepthConcatenation(const Scalar* const* input_data,
+ const Dims<4>* const* input_dims, int inputs_count,
+ Scalar* output_data, const Dims<4>& output_dims) {
+ Concatenation<Ac, Scalar>(0, input_data, input_dims, inputs_count,
+ output_data, output_dims);
+}
+
+inline void LstmCell(const float* input_data, const Dims<4>& input_dims,
+ const float* prev_activ_data,
+ const Dims<4>& prev_activ_dims, const float* weights_data,
+ const Dims<4>& weights_dims, const float* bias_data,
+ const Dims<4>& bias_dims, const float* prev_state_data,
+ const Dims<4>& prev_state_dims, float* output_state_data,
+ const Dims<4>& output_state_dims, float* output_activ_data,
+ const Dims<4>& output_activ_dims, float* concat_temp_data,
+ const Dims<4>& concat_temp_dims, float* activ_temp_data,
+ const Dims<4>& activ_temp_dims) {
+ gemmlowp::ScopedProfilingLabel label("LstmCell");
+ MatchingArraySize( // batches
+ input_dims, 3, prev_activ_dims, 3, prev_state_dims, 3, output_state_dims,
+ 3, output_activ_dims, 3);
+ MatchingArraySize( // height
+ input_dims, 2, prev_activ_dims, 2, prev_state_dims, 2, output_state_dims,
+ 2, output_activ_dims, 2);
+ MatchingArraySize( // width
+ input_dims, 1, prev_activ_dims, 1, prev_state_dims, 1, output_state_dims,
+ 1, output_activ_dims, 1);
+ TFLITE_CHECK_EQ(ArraySize(weights_dims, 2), 1);
+ TFLITE_CHECK_EQ(ArraySize(weights_dims, 3), 1);
+ const int input_depth = ArraySize(input_dims, 0);
+ const int prev_activ_depth = ArraySize(prev_activ_dims, 0);
+ const int total_input_depth = prev_activ_depth + input_depth;
+ TFLITE_CHECK_EQ(ArraySize(weights_dims, 0), total_input_depth);
+ TFLITE_CHECK_EQ(MatchingArraySize(bias_dims, 1, bias_dims, 2, bias_dims, 3),
+ 1);
+ const int intern_activ_depth =
+ MatchingArraySize(weights_dims, 1, bias_dims, 0);
+ TFLITE_CHECK_EQ(intern_activ_depth % 4, 0);
+ const int output_depth =
+ MatchingArraySize(prev_state_dims, 0, prev_activ_dims, 0,
+ output_state_dims, 0, output_activ_dims, 0);
+ TFLITE_CHECK_EQ(output_depth, intern_activ_depth / 4);
+
+ // Concatenate prev_activ and input data together
+ std::vector<float const*> concat_input_arrays_data;
+ std::vector<Dims<4> const*> concat_input_arrays_dims;
+ concat_input_arrays_data.push_back(input_data);
+ concat_input_arrays_data.push_back(prev_activ_data);
+ concat_input_arrays_dims.push_back(&input_dims);
+ concat_input_arrays_dims.push_back(&prev_activ_dims);
+ Concatenation<FusedActivationFunctionType::kNone, float>(
+ 0, &(concat_input_arrays_data[0]), &(concat_input_arrays_dims[0]),
+ concat_input_arrays_data.size(), concat_temp_data, concat_temp_dims);
+
+ // Fully connected
+ FullyConnected<FusedActivationFunctionType::kNone>(
+ concat_temp_data, concat_temp_dims, weights_data, weights_dims, bias_data,
+ bias_dims, activ_temp_data, activ_temp_dims);
+
+ // Map raw arrays to Eigen arrays so we can use Eigen's optimized array
+ // operations.
+ ArrayMap<float> activ_temp_map =
+ MapAsArrayWithFirstDimAsRows(activ_temp_data, activ_temp_dims);
+ auto input_gate_sm = activ_temp_map.block(0 * output_depth, 0, output_depth,
+ activ_temp_map.cols());
+ auto new_input_sm = activ_temp_map.block(1 * output_depth, 0, output_depth,
+ activ_temp_map.cols());
+ auto forget_gate_sm = activ_temp_map.block(2 * output_depth, 0, output_depth,
+ activ_temp_map.cols());
+ auto output_gate_sm = activ_temp_map.block(3 * output_depth, 0, output_depth,
+ activ_temp_map.cols());
+ ArrayMap<const float> prev_state_map =
+ MapAsArrayWithFirstDimAsRows(prev_state_data, prev_state_dims);
+ ArrayMap<float> output_state_map =
+ MapAsArrayWithFirstDimAsRows(output_state_data, output_state_dims);
+ ArrayMap<float> output_activ_map =
+ MapAsArrayWithFirstDimAsRows(output_activ_data, output_activ_dims);
+
+ // Combined memory state and final output calculation
+ gemmlowp::ScopedProfilingLabel label2("MemoryStateAndFinalOutput");
+ output_state_map =
+ input_gate_sm.unaryExpr(Eigen::internal::scalar_sigmoid_op<float>()) *
+ new_input_sm.tanh() +
+ forget_gate_sm.unaryExpr(Eigen::internal::scalar_sigmoid_op<float>()) *
+ prev_state_map;
+ output_activ_map =
+ output_gate_sm.unaryExpr(Eigen::internal::scalar_sigmoid_op<float>()) *
+ output_state_map.tanh();
+}
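+// In equation form, the float LstmCell above computes, per batch element:
+//   [i; n; f; o] = W * concat(input, prev_activ) + bias
+//                  (four chunks of output_depth rows each)
+//   output_state = sigmoid(i) * tanh(n) + sigmoid(f) * prev_state
+//   output_activ = sigmoid(o) * tanh(output_state)
+// which is what the Eigen block/unaryExpr expressions at the end of the
+// function evaluate.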
+
+template <FusedActivationFunctionType Ac, typename Scalar>
+void TensorFlowSplit(const Scalar* input_data, const Dims<4>& input_dims,
+ int outputs_count, Scalar* const* output_data,
+ const Dims<4>* const* output_dims) {
+ gemmlowp::ScopedProfilingLabel label("TensorFlowSplit");
+ TFLITE_DCHECK_GE(outputs_count, 1);
+ for (int i = 0; i < outputs_count; i++) {
+ /* batches = */ MatchingArraySize(*output_dims[i], 3, input_dims, 3);
+ /* height = */ MatchingArraySize(*output_dims[i], 2, input_dims, 2);
+ /* width = */ MatchingArraySize(*output_dims[i], 1, input_dims, 1);
+ }
+ const int batches = MatchingArraySize(*output_dims[0], 3, input_dims, 3);
+ const int height = MatchingArraySize(*output_dims[0], 2, input_dims, 2);
+ const int width = MatchingArraySize(*output_dims[0], 1, input_dims, 1);
+ TFLITE_DCHECK(IsPackedWithoutStrides(input_dims));
+  // For now we don't have a model with a TensorFlowSplit
+  // with a fused activation function.
+ TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone);
+ const int whb = width * height * batches;
+ const Scalar* input_ptr = input_data;
+ for (int k = 0; k < whb; k++) {
+ for (int i = 0; i < outputs_count; ++i) {
+ memcpy(output_data[i] + k * output_dims[i]->sizes[0], input_ptr,
+ output_dims[i]->sizes[0] * sizeof(Scalar));
+ input_ptr += output_dims[i]->sizes[0];
+ }
+ }
+}
+
+inline int NodeOffset(int b, int h, int w, int height, int width) {
+ return (b * height + h) * width + w;
+}
+
+inline void AveragePool(const float* input_data, const Dims<4>& input_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int kwidth, int kheight,
+ float output_activation_min,
+ float output_activation_max, float* output_data,
+ const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("AveragePool");
+ const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
+ const int input_height = ArraySize(input_dims, 2);
+ const int input_width = ArraySize(input_dims, 1);
+ const int output_height = ArraySize(output_dims, 2);
+ const int output_width = ArraySize(output_dims, 1);
+ const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
+
+ // TODO(benoitjacob) make this a proper reference impl without Eigen!
+ const auto in_mat = MapAsMatrixWithFirstDimAsRows(input_data, input_dims);
+ auto out_mat = MapAsMatrixWithFirstDimAsRows(output_data, output_dims);
+ // TODO(benoitjacob) get rid of the dynamic memory allocation here!
+ Eigen::VectorXf out_count(out_mat.cols());
+ out_count.setZero();
+ // Prefill the output to 0.
+ out_mat.setZero();
+ for (int b = 0; b < batches; ++b) {
+ for (int h = 0; h < input_height; ++h) {
+ for (int w = 0; w < input_width; ++w) {
+ // (h_start, h_end) * (w_start, w_end) is the range that the input
+ // vector projects to.
+ int hpad = h + pad_height;
+ int wpad = w + pad_width;
+ int h_start =
+ (hpad < kheight) ? 0 : (hpad - kheight) / stride_height + 1;
+ int h_end = std::min(hpad / stride_height + 1, output_height);
+ int w_start = (wpad < kwidth) ? 0 : (wpad - kwidth) / stride_width + 1;
+ int w_end = std::min(wpad / stride_width + 1, output_width);
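+        // For example (illustrative numbers): with kheight = 3,
+        // stride_height = 2 and pad_height = 1, input row h = 4 has hpad = 5,
+        // so h_start = (5 - 3) / 2 + 1 = 2 and
+        // h_end = min(5 / 2 + 1, output_height) = min(3, output_height);
+        // assuming output_height > 2, this row therefore contributes only to
+        // output row 2, matching the inverse mapping where output row 2
+        // covers input rows 3..5.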
+ // compute elementwise sum
+ for (int ph = h_start; ph < h_end; ++ph) {
+ for (int pw = w_start; pw < w_end; ++pw) {
+ int out_offset = NodeOffset(b, ph, pw, output_height, output_width);
+ out_mat.col(out_offset) +=
+ in_mat.col(NodeOffset(b, h, w, input_height, input_width));
+ out_count(out_offset)++;
+ }
+ }
+ }
+ }
+ }
+ // Divide the output by the actual number of elements being averaged over
+ TFLITE_DCHECK_GT(out_count.minCoeff(), 0);
+ out_mat.array().rowwise() /= out_count.transpose().array();
+
+ for (int b = 0; b < batches; ++b) {
+ for (int y = 0; y < output_height; ++y) {
+ for (int x = 0; x < output_width; ++x) {
+ for (int c = 0; c < depth; ++c) {
+ output_data[Offset(output_dims, c, x, y, b)] =
+ ActivationFunctionWithMinMax(
+ output_data[Offset(output_dims, c, x, y, b)],
+ output_activation_min, output_activation_max);
+ }
+ }
+ }
+ }
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void AveragePool(const float* input_data, const Dims<4>& input_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int kwidth, int kheight, float* output_data,
+ const Dims<4>& output_dims) {
+ float output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+
+ AveragePool(input_data, input_dims, stride_width, stride_height, pad_width,
+ pad_height, kwidth, kheight, output_activation_min,
+ output_activation_max, output_data, output_dims);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void AveragePool(const float* input_data, const Dims<4>& input_dims, int stride,
+ int pad_width, int pad_height, int filter_width,
+ int filter_height, float* output_data,
+ const Dims<4>& output_dims) {
+ AveragePool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
+ filter_width, filter_height, output_data, output_dims);
+}
+
+inline void AveragePool(const uint8* input_data, const Dims<4>& input_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int filter_width, int filter_height,
+ int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("AveragePool/8bit");
+ TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
+ const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
+ const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
+ const int input_height = ArraySize(input_dims, 2);
+ const int input_width = ArraySize(input_dims, 1);
+ const int output_height = ArraySize(output_dims, 2);
+ const int output_width = ArraySize(output_dims, 1);
+ for (int batch = 0; batch < batches; ++batch) {
+ for (int out_y = 0; out_y < output_height; ++out_y) {
+ for (int out_x = 0; out_x < output_width; ++out_x) {
+ const int in_x_origin = (out_x * stride_width) - pad_width;
+ const int in_y_origin = (out_y * stride_height) - pad_height;
+ const int filter_x_start = std::max(0, -in_x_origin);
+ const int filter_x_end =
+ std::min(filter_width, input_width - in_x_origin);
+ const int filter_y_start = std::max(0, -in_y_origin);
+ const int filter_y_end =
+ std::min(filter_height, input_height - in_y_origin);
+ const int filter_count =
+ (filter_x_end - filter_x_start) * (filter_y_end - filter_y_start);
+ // 1280 required by Inception v3
+ static constexpr int kAccBufferMaxSize = 2048;
+ TFLITE_DCHECK_LE(depth, kAccBufferMaxSize);
+ uint16 acc[kAccBufferMaxSize];
+ memset(acc, 0, depth * sizeof(acc[0]));
+ const uint8* input_ptr =
+ input_data + input_dims.strides[1] * in_x_origin +
+ input_dims.strides[2] * in_y_origin + input_dims.strides[3] * batch;
+ for (int fy = filter_y_start; fy < filter_y_end; fy++) {
+ const uint8* input_row_ptr = input_ptr + fy * input_dims.strides[2] +
+ filter_x_start * input_dims.strides[1];
+ for (int fx = filter_x_start; fx < filter_x_end; fx++) {
+ int channel = 0;
+#ifdef USE_NEON
+ for (; channel <= depth - 16; channel += 16) {
+ uint16x8_t acc_reg[2];
+ for (int i = 0; i < 2; i++) {
+ acc_reg[i] = vld1q_u16(acc + channel + 8 * i);
+ }
+ uint8x16_t input_reg = vld1q_u8(input_row_ptr);
+ input_row_ptr += 16;
+ acc_reg[0] = vaddw_u8(acc_reg[0], vget_low_u8(input_reg));
+ acc_reg[1] = vaddw_u8(acc_reg[1], vget_high_u8(input_reg));
+ for (int i = 0; i < 2; i++) {
+ vst1q_u16(acc + channel + 8 * i, acc_reg[i]);
+ }
+ }
+ for (; channel <= depth - 8; channel += 8) {
+ uint16x8_t acc_reg = vld1q_u16(acc + channel);
+ uint8x8_t input_reg = vld1_u8(input_row_ptr);
+ input_row_ptr += 8;
+ acc_reg = vaddw_u8(acc_reg, input_reg);
+ vst1q_u16(acc + channel, acc_reg);
+ }
+#endif
+ for (; channel < depth; ++channel) {
+ acc[channel] += *input_row_ptr++;
+ }
+ }
+ }
+ uint8* output_ptr =
+ output_data + Offset(output_dims, 0, out_x, out_y, batch);
+ int channel = 0;
+#ifdef USE_NEON
+#define AVGPOOL_DIVIDING_BY(FILTER_COUNT) \
+ if (filter_count == FILTER_COUNT) { \
+ for (; channel <= depth - 8; channel += 8) { \
+ uint16 buf[8]; \
+ for (int i = 0; i < 8; i++) { \
+ buf[i] = (acc[channel + i] + FILTER_COUNT / 2) / FILTER_COUNT; \
+ } \
+ uint8x8_t buf8 = vqmovn_u16(vld1q_u16(buf)); \
+ buf8 = vmin_u8(buf8, vdup_n_u8(output_activation_max)); \
+ buf8 = vmax_u8(buf8, vdup_n_u8(output_activation_min)); \
+ vst1_u8(output_ptr + channel, buf8); \
+ } \
+ }
+ AVGPOOL_DIVIDING_BY(9)
+ AVGPOOL_DIVIDING_BY(15)
+#undef AVGPOOL_DIVIDING_BY
+ for (; channel <= depth - 8; channel += 8) {
+ uint16 buf[8];
+ for (int i = 0; i < 8; i++) {
+ buf[i] = (acc[channel + i] + filter_count / 2) / filter_count;
+ }
+ uint8x8_t buf8 = vqmovn_u16(vld1q_u16(buf));
+ buf8 = vmin_u8(buf8, vdup_n_u8(output_activation_max));
+ buf8 = vmax_u8(buf8, vdup_n_u8(output_activation_min));
+ vst1_u8(output_ptr + channel, buf8);
+ }
+#endif
+ for (; channel < depth; ++channel) {
+ uint16 a = (acc[channel] + filter_count / 2) / filter_count;
+ a = std::max<uint16>(a, output_activation_min);
+ a = std::min<uint16>(a, output_activation_max);
+ output_ptr[channel] = static_cast<uint8>(a);
+ }
+ }
+ }
+ }
+}
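+// Note on the integer average in the scalar tail above:
+// (acc[channel] + filter_count / 2) / filter_count is a round-to-nearest
+// division; e.g. with acc = 10 and filter_count = 3 it yields
+// (10 + 1) / 3 = 3, matching round(10 / 3.0).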
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void AveragePool(const uint8* input_data, const Dims<4>& input_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int filter_width, int filter_height,
+ int32 output_activation_min, int32 output_activation_max,
+ uint8* output_data, const Dims<4>& output_dims) {
+ static_assert(Ac == FusedActivationFunctionType::kNone ||
+ Ac == FusedActivationFunctionType::kRelu ||
+ Ac == FusedActivationFunctionType::kRelu6 ||
+ Ac == FusedActivationFunctionType::kRelu1,
+ "");
+ if (Ac == FusedActivationFunctionType::kNone) {
+ TFLITE_DCHECK_EQ(output_activation_min, 0);
+ TFLITE_DCHECK_EQ(output_activation_max, 255);
+ }
+ AveragePool(input_data, input_dims, stride_width, stride_height, pad_width,
+ pad_height, filter_width, filter_height, output_activation_min,
+ output_activation_max, output_data, output_dims);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void AveragePool(const uint8* input_data, const Dims<4>& input_dims, int stride,
+ int pad_width, int pad_height, int filter_width,
+ int filter_height, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ AveragePool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
+ filter_width, filter_height, output_activation_min,
+ output_activation_max, output_data, output_dims);
+}
+
+inline void MaxPool(const float* input_data, const Dims<4>& input_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int kwidth, int kheight,
+ float output_activation_min, float output_activation_max,
+ float* output_data, const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("MaxPool");
+ const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
+ const int input_height = ArraySize(input_dims, 2);
+ const int input_width = ArraySize(input_dims, 1);
+ const int output_height = ArraySize(output_dims, 2);
+ const int output_width = ArraySize(output_dims, 1);
+ const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
+
+ const auto in_mat = MapAsMatrixWithFirstDimAsRows(input_data, input_dims);
+ auto out_mat = MapAsMatrixWithFirstDimAsRows(output_data, output_dims);
+  // Prefill the output to the minimum representable float value.
+ out_mat.setConstant(std::numeric_limits<float>::lowest());
+ for (int b = 0; b < batches; ++b) {
+ for (int h = 0; h < input_height; ++h) {
+ for (int w = 0; w < input_width; ++w) {
+ // (h_start, h_end) * (w_start, w_end) is the range that the input
+ // vector projects to.
+ int hpad = h + pad_height;
+ int wpad = w + pad_width;
+ int h_start =
+ (hpad < kheight) ? 0 : (hpad - kheight) / stride_height + 1;
+ int h_end = std::min(hpad / stride_height + 1, output_height);
+ int w_start = (wpad < kwidth) ? 0 : (wpad - kwidth) / stride_width + 1;
+ int w_end = std::min(wpad / stride_width + 1, output_width);
+        // compute elementwise max
+ for (int ph = h_start; ph < h_end; ++ph) {
+ for (int pw = w_start; pw < w_end; ++pw) {
+ int out_offset = NodeOffset(b, ph, pw, output_height, output_width);
+ out_mat.col(out_offset) =
+ out_mat.col(out_offset)
+ .cwiseMax(in_mat.col(
+ NodeOffset(b, h, w, input_height, input_width)));
+ }
+ }
+ }
+ }
+ }
+
+ for (int b = 0; b < batches; ++b) {
+ for (int y = 0; y < output_height; ++y) {
+ for (int x = 0; x < output_width; ++x) {
+ for (int c = 0; c < depth; ++c) {
+ output_data[Offset(output_dims, c, x, y, b)] =
+ ActivationFunctionWithMinMax(
+ output_data[Offset(output_dims, c, x, y, b)],
+ output_activation_min, output_activation_max);
+ }
+ }
+ }
+ }
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void MaxPool(const float* input_data, const Dims<4>& input_dims,
+ int stride_width, int stride_height, int pad_width, int pad_height,
+ int kwidth, int kheight, float* output_data,
+ const Dims<4>& output_dims) {
+ float output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+ MaxPool(input_data, input_dims, stride_width, stride_height, pad_width,
+ pad_height, kwidth, kheight, output_activation_min,
+ output_activation_max, output_data, output_dims);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void MaxPool(const float* input_data, const Dims<4>& input_dims, int stride,
+ int pad_width, int pad_height, int filter_width, int filter_height,
+ float* output_data, const Dims<4>& output_dims) {
+ MaxPool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
+ filter_width, filter_height, output_data, output_dims);
+}
+
+inline void MaxPool(const uint8* input_data, const Dims<4>& input_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int filter_width, int filter_height,
+ int32 output_activation_min, int32 output_activation_max,
+ uint8* output_data, const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("MaxPool/8bit");
+ TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
+ const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
+ const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
+ const int input_height = ArraySize(input_dims, 2);
+ const int input_width = ArraySize(input_dims, 1);
+ const int output_height = ArraySize(output_dims, 2);
+ const int output_width = ArraySize(output_dims, 1);
+ for (int batch = 0; batch < batches; ++batch) {
+ for (int out_y = 0; out_y < output_height; ++out_y) {
+ for (int out_x = 0; out_x < output_width; ++out_x) {
+ const int in_x_origin = (out_x * stride_width) - pad_width;
+ const int in_y_origin = (out_y * stride_height) - pad_height;
+ const int filter_x_start = std::max(0, -in_x_origin);
+ const int filter_x_end =
+ std::min(filter_width, input_width - in_x_origin);
+ const int filter_y_start = std::max(0, -in_y_origin);
+ const int filter_y_end =
+ std::min(filter_height, input_height - in_y_origin);
+ // 2048 required by Inception v3
+ static constexpr int kAccBufferMaxSize = 2048;
+ TFLITE_DCHECK_LE(depth, kAccBufferMaxSize);
+ uint8 acc[kAccBufferMaxSize];
+ memset(acc, 0, depth * sizeof(acc[0]));
+ const uint8* input_ptr =
+ input_data + input_dims.strides[1] * in_x_origin +
+ input_dims.strides[2] * in_y_origin + input_dims.strides[3] * batch;
+ for (int fy = filter_y_start; fy < filter_y_end; fy++) {
+ const uint8* input_row_ptr = input_ptr + fy * input_dims.strides[2] +
+ filter_x_start * input_dims.strides[1];
+ for (int fx = filter_x_start; fx < filter_x_end; fx++) {
+ int channel = 0;
+#ifdef USE_NEON
+ for (; channel <= depth - 16; channel += 16) {
+ uint8x16_t acc_reg = vld1q_u8(acc + channel);
+ uint8x16_t input_reg = vld1q_u8(input_row_ptr);
+ input_row_ptr += 16;
+ acc_reg = vmaxq_u8(acc_reg, input_reg);
+ vst1q_u8(acc + channel, acc_reg);
+ }
+
+ for (; channel <= depth - 8; channel += 8) {
+ uint8x8_t acc_reg = vld1_u8(acc + channel);
+ uint8x8_t input_reg = vld1_u8(input_row_ptr);
+ input_row_ptr += 8;
+ acc_reg = vmax_u8(acc_reg, input_reg);
+ vst1_u8(acc + channel, acc_reg);
+ }
+#endif
+ for (; channel < depth; ++channel) {
+ acc[channel] = std::max(acc[channel], *input_row_ptr++);
+ }
+ }
+ }
+ uint8* output_ptr =
+ output_data + Offset(output_dims, 0, out_x, out_y, batch);
+ int channel = 0;
+#ifdef USE_NEON
+ for (; channel <= depth - 16; channel += 16) {
+ uint8x16_t a = vld1q_u8(acc + channel);
+ a = vminq_u8(a, vdupq_n_u8(output_activation_max));
+ a = vmaxq_u8(a, vdupq_n_u8(output_activation_min));
+ vst1q_u8(output_ptr + channel, a);
+ }
+ for (; channel <= depth - 8; channel += 8) {
+ uint8x8_t a = vld1_u8(acc + channel);
+ a = vmin_u8(a, vdup_n_u8(output_activation_max));
+ a = vmax_u8(a, vdup_n_u8(output_activation_min));
+ vst1_u8(output_ptr + channel, a);
+ }
+#endif
+ for (; channel < depth; ++channel) {
+ uint8 a = acc[channel];
+ a = std::max<uint8>(a, output_activation_min);
+ a = std::min<uint8>(a, output_activation_max);
+ output_ptr[channel] = static_cast<uint8>(a);
+ }
+ }
+ }
+ }
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void MaxPool(const uint8* input_data, const Dims<4>& input_dims,
+ int stride_width, int stride_height, int pad_width, int pad_height,
+ int filter_width, int filter_height, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ static_assert(Ac == FusedActivationFunctionType::kNone ||
+ Ac == FusedActivationFunctionType::kRelu ||
+ Ac == FusedActivationFunctionType::kRelu6 ||
+ Ac == FusedActivationFunctionType::kRelu1,
+ "");
+ if (Ac == FusedActivationFunctionType::kNone) {
+ TFLITE_DCHECK_EQ(output_activation_min, 0);
+ TFLITE_DCHECK_EQ(output_activation_max, 255);
+ }
+ MaxPool(input_data, input_dims, stride_width, stride_height, pad_width,
+ pad_height, filter_width, filter_height, output_activation_min,
+ output_activation_max, output_data, output_dims);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void MaxPool(const uint8* input_data, const Dims<4>& input_dims, int stride,
+ int pad_width, int pad_height, int filter_width, int filter_height,
+ int32 output_activation_min, int32 output_activation_max,
+ uint8* output_data, const Dims<4>& output_dims) {
+ MaxPool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
+ filter_width, filter_height, output_activation_min,
+ output_activation_max, output_data, output_dims);
+}
+
+inline void L2Pool(const float* input_data, const Dims<4>& input_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int filter_width, int filter_height,
+ float output_activation_min, float output_activation_max,
+ float* output_data, const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("L2Pool");
+ const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
+ const int input_height = ArraySize(input_dims, 2);
+ const int input_width = ArraySize(input_dims, 1);
+ const int output_height = ArraySize(output_dims, 2);
+ const int output_width = ArraySize(output_dims, 1);
+  // Actually carry out L2 Pool. The code is written in forward mode: we go
+  // through the input values once, and write to all the pooled regions that
+  // each value maps to.
+ const auto in_mat = MapAsMatrixWithFirstDimAsRows(input_data, input_dims);
+ auto out_mat = MapAsMatrixWithFirstDimAsRows(output_data, output_dims);
+ Eigen::VectorXf in_square(in_mat.rows());
+ Eigen::VectorXf out_count(out_mat.cols());
+ out_count.setZero();
+ // Prefill the output to 0.
+ out_mat.setZero();
+ for (int b = 0; b < batches; ++b) {
+ for (int h = 0; h < input_height; ++h) {
+ for (int w = 0; w < input_width; ++w) {
+ // (h_start, h_end) * (w_start, w_end) is the range that the input
+ // vector projects to.
+ const int hpad = h + pad_height;
+ const int wpad = w + pad_width;
+ const int h_start = (hpad < filter_height)
+ ? 0
+ : (hpad - filter_height) / stride_height + 1;
+ const int h_end = std::min(hpad / stride_height + 1, output_height);
+ const int w_start = (wpad < filter_width)
+ ? 0
+ : (wpad - filter_width) / stride_width + 1;
+ const int w_end = std::min(wpad / stride_width + 1, output_width);
+ // pre-compute square
+ const int in_offset = w + input_width * (h + input_height * b);
+ in_square =
+ in_mat.col(in_offset).array() * in_mat.col(in_offset).array();
+ // compute elementwise sum of squares
+ for (int ph = h_start; ph < h_end; ++ph) {
+ for (int pw = w_start; pw < w_end; ++pw) {
+ const int out_offset = pw + output_width * (ph + output_height * b);
+ out_mat.col(out_offset) += in_square;
+ out_count(out_offset)++;
+ }
+ }
+ }
+ }
+ }
+
+ out_count = out_count.array().inverse();
+ out_mat =
+ (out_mat.array().rowwise() * out_count.transpose().array()).cwiseSqrt();
+}
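+// In other words, each output cell above ends up holding
+//   sqrt((sum of squared inputs in its pooling window) / (number of inputs
+//   in the window)),
+// with window membership computed in the same forward mode as in AveragePool
+// above.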
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void L2Pool(const float* input_data, const Dims<4>& input_dims,
+ int stride_width, int stride_height, int pad_width, int pad_height,
+ int filter_width, int filter_height, float* output_data,
+ const Dims<4>& output_dims) {
+ float output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+ L2Pool(input_data, input_dims, stride_width, stride_height, pad_width,
+ pad_height, filter_width, filter_height, output_activation_min,
+ output_activation_max, output_data, output_dims);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void L2Pool(const float* input_data, const Dims<4>& input_dims, int stride,
+ int pad_width, int pad_height, int filter_width, int filter_height,
+ float* output_data, const Dims<4>& output_dims) {
+ L2Pool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
+ filter_width, filter_height, output_data, output_dims);
+}
+
+inline void LocalResponseNormalization(const float* input_data,
+ const Dims<4>& input_dims, int range,
+ float bias, float alpha, float beta,
+ float* output_data,
+ const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("LocalResponseNormalization");
+ /* const int batches = */ MatchingArraySize(input_dims, 3, output_dims, 3);
+ /* const int height = */ MatchingArraySize(input_dims, 2, output_dims, 2);
+ /* const int width = */ MatchingArraySize(input_dims, 1, output_dims, 1);
+ /* const int depth = */ MatchingArraySize(input_dims, 0, output_dims, 0);
+
+ const auto data_in = MapAsMatrixWithFirstDimAsRows(input_data, input_dims);
+ auto data_out = MapAsMatrixWithFirstDimAsRows(output_data, output_dims);
+
+ // Carry out local response normalization, vector by vector.
+  // Since the data are stored column-major, a row-wise operation would
+  // probably not be memory-efficient anyway, so we do an explicit for loop
+  // over the columns.
+ const int double_range = range * 2;
+ Eigen::VectorXf padded_square(data_in.rows() + double_range);
+ padded_square.setZero();
+ for (int r = 0; r < data_in.cols(); ++r) {
+ // Do local response normalization for data_in(:, r)
+    // First, compute the squares and store them in a buffer for repeated use.
+ padded_square.block(range, 0, data_in.rows(), 1) =
+ data_in.col(r).cwiseProduct(data_in.col(r)) * alpha;
+    // Then, compute the scales and write them to data_out.
+ float accumulated_scale = 0;
+ for (int i = 0; i < double_range; ++i) {
+ accumulated_scale += padded_square(i);
+ }
+ for (int i = 0; i < data_in.rows(); ++i) {
+ accumulated_scale += padded_square(i + double_range);
+ data_out(i, r) = bias + accumulated_scale;
+ accumulated_scale -= padded_square(i);
+ }
+ }
+
+ // In a few cases, the pow computation could benefit from speedups.
+ if (beta == 1) {
+ data_out.array() = data_in.array() * data_out.array().inverse();
+ } else if (beta == 0.5) {
+ data_out.array() = data_in.array() * data_out.array().sqrt().inverse();
+ } else {
+ data_out.array() = data_in.array() * data_out.array().pow(-beta);
+ }
+}
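+// In equation form, the function above computes, per column and channel i:
+//   out_i = in_i * (bias + alpha * sum_{j = i - range .. i + range} in_j^2)^(-beta)
+// where out-of-range j contribute zero; the sliding-window sum is maintained
+// incrementally via accumulated_scale, and the final pow is specialized for
+// beta == 1 and beta == 0.5.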
+
+inline void Softmax(const float* input_data, const Dims<4>& input_dims,
+ float beta, float* output_data,
+ const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("Softmax");
+ /* const int batches = */ MatchingArraySize(input_dims, 3, output_dims, 3);
+ /* const int height = */ MatchingArraySize(input_dims, 2, output_dims, 2);
+ /* const int width = */ MatchingArraySize(input_dims, 1, output_dims, 1);
+ /* const int depth = */ MatchingArraySize(input_dims, 0, output_dims, 0);
+
+ const auto in_mat = MapAsMatrixWithFirstDimAsRows(input_data, input_dims);
+ auto out_mat = MapAsMatrixWithFirstDimAsRows(output_data, output_dims);
+ // Compute the exponential first, removing the max coefficient for numerical
+ // stability.
+ out_mat = (in_mat.rowwise() - in_mat.colwise().maxCoeff()).array() * beta;
+ // We are separating out the exp function so that exp can be vectorized.
+ out_mat = out_mat.array().exp();
+ // Normalize to get the activations.
+ Eigen::Array<float, 1, Eigen::Dynamic> scale =
+ out_mat.array().colwise().sum().inverse();
+ out_mat.array().rowwise() *= scale;
+}
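+// In equation form, per column (i.e. per (x, y, b) position):
+//   out_i = exp(beta * (in_i - max_j in_j)) /
+//           sum_k exp(beta * (in_k - max_j in_j))
+// which is what the three Eigen statements above evaluate.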
+
+inline void Softmax(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_beta_multiplier, int32 input_beta_left_shift,
+ int diff_min, uint8* output_data,
+ const Dims<4>& output_dims) {
+ // The representation chosen for the input to the exp() function is Q5.26.
+ // We need to leave extra space since values that we skip might be as large as
+ // -32 before multiplying by input_beta_multiplier, and therefore as large as
+ // -16 afterwards. Note that exp(-8) is definitely not insignificant to
+ // accumulation, but exp(-16) definitely is.
+ static const int kScaledDiffIntegerBits = 5;
+ static const int kAccumulationIntegerBits = 12;
+ using FixedPointScaledDiff =
+ gemmlowp::FixedPoint<int32, kScaledDiffIntegerBits>;
+ using FixedPointAccum = gemmlowp::FixedPoint<int32, kAccumulationIntegerBits>;
+ using FixedPoint0 = gemmlowp::FixedPoint<int32, 0>;
+
+ gemmlowp::ScopedProfilingLabel label("Softmax");
+ const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
+ const int height = MatchingArraySize(input_dims, 2, output_dims, 2);
+ const int width = MatchingArraySize(input_dims, 1, output_dims, 1);
+ const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
+
+ for (int b = 0; b < batches; ++b) {
+ for (int x = 0; x < width; ++x) {
+ for (int y = 0; y < height; ++y) {
+ uint8 max_in_row = 0;
+ for (int c = 0; c < depth; ++c) {
+ max_in_row =
+ std::max(max_in_row, input_data[Offset(input_dims, c, x, y, b)]);
+ }
+
+ FixedPointAccum sum_of_exps = FixedPointAccum::Zero();
+ for (int c = 0; c < depth; ++c) {
+ int32 input_diff =
+ static_cast<int32>(input_data[Offset(input_dims, c, x, y, b)]) -
+ max_in_row;
+ if (input_diff >= diff_min) {
+ const int32 input_diff_rescaled =
+ MultiplyByQuantizedMultiplierGreaterThanOne(
+ input_diff, input_beta_multiplier, input_beta_left_shift);
+ const FixedPointScaledDiff scaled_diff_f8 =
+ FixedPointScaledDiff::FromRaw(input_diff_rescaled);
+ sum_of_exps =
+ sum_of_exps + gemmlowp::Rescale<kAccumulationIntegerBits>(
+ exp_on_negative_values(scaled_diff_f8));
+ }
+ }
+
+ int32 fixed_sum_of_exps = sum_of_exps.raw();
+ // TODO(starka): Use a NEON intrinsic like vclzq_u32 instead.
+ int headroom_plus_one =
+ __builtin_clz(static_cast<uint32>(fixed_sum_of_exps));
+ // This is the number of bits to the left of the binary point above 1.0.
+ // Consider fixed_sum_of_exps=1.25. In that case shifted_scale=0.8 and
+ // no later adjustment will be needed.
+ int num_bits_over_unit = kAccumulationIntegerBits - headroom_plus_one;
+ int32 shifted_sum_minus_one = static_cast<int32>(
+ (static_cast<uint32>(fixed_sum_of_exps) << headroom_plus_one) -
+ (static_cast<uint32>(1) << 31));
+
+ FixedPoint0 shifted_scale = gemmlowp::one_over_one_plus_x_for_x_in_0_1(
+ FixedPoint0::FromRaw(shifted_sum_minus_one));
+
+ for (int c = 0; c < depth; ++c) {
+ int32 input_diff =
+ static_cast<int32>(input_data[Offset(input_dims, c, x, y, b)]) -
+ max_in_row;
+ if (input_diff >= diff_min) {
+ const int32 input_diff_rescaled =
+ MultiplyByQuantizedMultiplierGreaterThanOne(
+ input_diff, input_beta_multiplier, input_beta_left_shift);
+ const FixedPointScaledDiff scaled_diff_f8 =
+ FixedPointScaledDiff::FromRaw(input_diff_rescaled);
+
+ FixedPoint0 exp_in_0 = exp_on_negative_values(scaled_diff_f8);
+ int32 unsat_output = gemmlowp::RoundingDivideByPOT(
+ (shifted_scale * exp_in_0).raw(), num_bits_over_unit + 31 - 8);
+
+ output_data[Offset(output_dims, c, x, y, b)] =
+ std::max(std::min(unsat_output, 255), 0);
+
+ } else {
+ output_data[Offset(output_dims, c, x, y, b)] = 0;
+ }
+ }
+ }
+ }
+ }
+}
+
+inline void Logistic(const float* input_data, const Dims<4>& input_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("Logistic");
+ auto input_map = MapAsVector(input_data, input_dims);
+ auto output_map = MapAsVector(output_data, output_dims);
+ output_map.array() =
+ input_map.array().unaryExpr(Eigen::internal::scalar_sigmoid_op<float>());
+}
+
+inline void Logistic(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_zero_point, int32 input_range_radius,
+ int32 input_multiplier, int input_left_shift,
+ uint8* output_data, const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("Logistic");
+ /* batches */ MatchingArraySize(input_dims, 3, output_dims, 3);
+ /* height */ MatchingArraySize(input_dims, 2, output_dims, 2);
+ /* width */ MatchingArraySize(input_dims, 1, output_dims, 1);
+ /* depth */ MatchingArraySize(input_dims, 0, output_dims, 0);
+ const int size = RequiredBufferSizeForDims(input_dims);
+
+ int c = 0;
+#ifdef USE_NEON
+ // Handle 16 values at a time
+ for (; c <= size - 16; c += 16) {
+ // Read input uint8 values, cast to int16 and subtract input_zero_point
+ uint8x16_t input_val_u8 = vld1q_u8(input_data + c);
+ int16x8_t input_val_centered_0 =
+ vsubq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(input_val_u8))),
+ vdupq_n_s16(input_zero_point));
+ int16x8_t input_val_centered_1 =
+ vsubq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(input_val_u8))),
+ vdupq_n_s16(input_zero_point));
+
+ // Prepare the bit masks that we will use at the end to implement the logic
+ // that was expressed in the scalar code with branching:
+ // if (input_val_centered < -input_range_radius) {
+ // output_val = 0;
+ // } else if (input_val_centered > input_range_radius) {
+ // output_val = 255;
+ // } else {
+ // ...
+ uint16x8_t mask_rightclamp_0 =
+ vcgtq_s16(input_val_centered_0, vdupq_n_s16(input_range_radius));
+ uint16x8_t mask_rightclamp_1 =
+ vcgtq_s16(input_val_centered_1, vdupq_n_s16(input_range_radius));
+ uint16x8_t mask_leftclamp_0 =
+ vcgeq_s16(input_val_centered_0, vdupq_n_s16(-input_range_radius));
+ uint16x8_t mask_leftclamp_1 =
+ vcgeq_s16(input_val_centered_1, vdupq_n_s16(-input_range_radius));
+ uint8x16_t mask_rightclamp = vcombine_u8(vshrn_n_u16(mask_rightclamp_0, 8),
+ vshrn_n_u16(mask_rightclamp_1, 8));
+ uint8x16_t mask_leftclamp = vcombine_u8(vshrn_n_u16(mask_leftclamp_0, 8),
+ vshrn_n_u16(mask_leftclamp_1, 8));
+
+ // This performs what is expressed in the scalar code as
+ // const int32 input_val_rescaled =
+ // MultiplyByQuantizedMultiplierGreaterThanOne(
+ // input_val_centered, input_multiplier, input_left_shift);
+ int32x4_t input_val_rescaled_0 =
+ vshlq_s32(vmovl_s16(vget_low_s16(input_val_centered_0)),
+ vdupq_n_s32(input_left_shift));
+ int32x4_t input_val_rescaled_1 =
+ vshlq_s32(vmovl_s16(vget_high_s16(input_val_centered_0)),
+ vdupq_n_s32(input_left_shift));
+ int32x4_t input_val_rescaled_2 =
+ vshlq_s32(vmovl_s16(vget_low_s16(input_val_centered_1)),
+ vdupq_n_s32(input_left_shift));
+ int32x4_t input_val_rescaled_3 =
+ vshlq_s32(vmovl_s16(vget_high_s16(input_val_centered_1)),
+ vdupq_n_s32(input_left_shift));
+ input_val_rescaled_0 =
+ vqrdmulhq_n_s32(input_val_rescaled_0, input_multiplier);
+ input_val_rescaled_1 =
+ vqrdmulhq_n_s32(input_val_rescaled_1, input_multiplier);
+ input_val_rescaled_2 =
+ vqrdmulhq_n_s32(input_val_rescaled_2, input_multiplier);
+ input_val_rescaled_3 =
+ vqrdmulhq_n_s32(input_val_rescaled_3, input_multiplier);
+
+ // Invoke gemmlowp::logistic on FixedPoint wrapping int32x4_t
+ using FixedPoint4 = gemmlowp::FixedPoint<int32x4_t, 4>;
+ using FixedPoint0 = gemmlowp::FixedPoint<int32x4_t, 0>;
+ const FixedPoint4 input_val_f4_0 =
+ FixedPoint4::FromRaw(input_val_rescaled_0);
+ const FixedPoint4 input_val_f4_1 =
+ FixedPoint4::FromRaw(input_val_rescaled_1);
+ const FixedPoint4 input_val_f4_2 =
+ FixedPoint4::FromRaw(input_val_rescaled_2);
+ const FixedPoint4 input_val_f4_3 =
+ FixedPoint4::FromRaw(input_val_rescaled_3);
+ const FixedPoint0 output_val_f0_0 = gemmlowp::logistic(input_val_f4_0);
+ const FixedPoint0 output_val_f0_1 = gemmlowp::logistic(input_val_f4_1);
+ const FixedPoint0 output_val_f0_2 = gemmlowp::logistic(input_val_f4_2);
+ const FixedPoint0 output_val_f0_3 = gemmlowp::logistic(input_val_f4_3);
+
+ // Divide by 2^23 as in the scalar code
+ using gemmlowp::RoundingDivideByPOT;
+ int32x4_t output_val_s32_0 = RoundingDivideByPOT(output_val_f0_0.raw(), 23);
+ int32x4_t output_val_s32_1 = RoundingDivideByPOT(output_val_f0_1.raw(), 23);
+ int32x4_t output_val_s32_2 = RoundingDivideByPOT(output_val_f0_2.raw(), 23);
+ int32x4_t output_val_s32_3 = RoundingDivideByPOT(output_val_f0_3.raw(), 23);
+
+ // Cast output values to uint8, saturating
+ int16x8_t output_val_s16_0 = vcombine_s16(vqmovn_s32(output_val_s32_0),
+ vqmovn_s32(output_val_s32_1));
+ int16x8_t output_val_s16_1 = vcombine_s16(vqmovn_s32(output_val_s32_2),
+ vqmovn_s32(output_val_s32_3));
+ uint8x16_t output_val_u8 = vcombine_u8(vqmovun_s16(output_val_s16_0),
+ vqmovun_s16(output_val_s16_1));
+
+ // Perform the bit-masking with the bit masks computed at the beginning,
+ // see the comment there.
+ output_val_u8 = vorrq_u8(output_val_u8, mask_rightclamp);
+ output_val_u8 = vandq_u8(output_val_u8, mask_leftclamp);
+
+ // Store back to memory
+ vst1q_u8(output_data + c, output_val_u8);
+ }
+#endif
+ // Leftover loop: handle one value at a time with scalar code.
+ for (; c < size; ++c) {
+ const uint8 input_val_u8 = input_data[c];
+ const int32 input_val_centered =
+ static_cast<int32>(input_val_u8) - input_zero_point;
+ uint8 output_val;
+ if (input_val_centered < -input_range_radius) {
+ output_val = 0;
+ } else if (input_val_centered > input_range_radius) {
+ output_val = 255;
+ } else {
+ const int32 input_val_rescaled =
+ MultiplyByQuantizedMultiplierGreaterThanOne(
+ input_val_centered, input_multiplier, input_left_shift);
+ using FixedPoint4 = gemmlowp::FixedPoint<int32, 4>;
+ using FixedPoint0 = gemmlowp::FixedPoint<int32, 0>;
+ const FixedPoint4 input_val_f4 = FixedPoint4::FromRaw(input_val_rescaled);
+ const FixedPoint0 output_val_f0 = gemmlowp::logistic(input_val_f4);
+ using gemmlowp::RoundingDivideByPOT;
+ int32 output_val_s32 = RoundingDivideByPOT(output_val_f0.raw(), 23);
+ if (output_val_s32 == 256) {
+ output_val_s32 = 255;
+ }
+ TFLITE_DCHECK_GE(output_val_s32, 0);
+ TFLITE_DCHECK_LE(output_val_s32, 255);
+ output_val = static_cast<uint8>(output_val_s32);
+ }
+ output_data[c] = output_val;
+ }
+}
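+// In the scalar tail above, the output is effectively round(256 * logistic(x))
+// saturated to [0, 255], where x is the centered input rescaled into fixed
+// point (4 integer bits) via input_multiplier / input_left_shift:
+// gemmlowp::logistic returns a Q0.31 value, and RoundingDivideByPOT(..., 23)
+// rescales it to the 8-bit output range.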
+
+inline void Tanh(const float* input_data, const Dims<4>& input_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("Tanh");
+ auto input_map = MapAsVector(input_data, input_dims);
+ auto output_map = MapAsVector(output_data, output_dims);
+ output_map.array() = input_map.array().tanh();
+}
+
+inline void Dequantize(const uint8* input_data, const Dims<4>& input_dims,
+ int32 zero_point, double scale, float* output_data,
+ const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("Dequantize");
+ const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
+ const int height = MatchingArraySize(input_dims, 2, output_dims, 2);
+ const int width = MatchingArraySize(input_dims, 1, output_dims, 1);
+ const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
+ for (int b = 0; b < batches; ++b) {
+ for (int y = 0; y < height; ++y) {
+ for (int x = 0; x < width; ++x) {
+ for (int c = 0; c < depth; ++c) {
+ int32 val = input_data[Offset(input_dims, c, x, y, b)];
+ float result = static_cast<float>(scale * (val - zero_point));
+ output_data[Offset(output_dims, c, x, y, b)] = result;
+ }
+ }
+ }
+ }
+}
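For readers tracing the Dims<4> indexing used throughout this header, a minimal sketch of how a caller might describe a packed buffer before invoking Dequantize (illustration only, not part of this file; the FillPackedDims helper is hypothetical):

// Hypothetical helper: fill a packed Dims<4> from per-dimension sizes, with
// dimension 0 (depth) having the smallest stride, as the kernels here assume.
inline void FillPackedDims(const int sizes[4], tflite::Dims<4>* dims) {
  int stride = 1;
  for (int i = 0; i < 4; ++i) {
    dims->sizes[i] = sizes[i];
    dims->strides[i] = stride;
    stride *= sizes[i];
  }
}
// Usage sketch: dequantize a 1x2x2x3 (batch x height x width x depth) uint8
// buffer with zero_point 128 and scale 0.5 into a float buffer of equal shape.
//   const int sizes[4] = {3, 2, 2, 1};  // {depth, width, height, batch}
//   tflite::Dims<4> dims;
//   FillPackedDims(sizes, &dims);
//   Dequantize(quantized_data, dims, /*zero_point=*/128, /*scale=*/0.5,
//              float_data, dims);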
+
+inline void FakeQuant(const float* input_data, const Dims<4>& input_dims,
+ float rmin, float rmax, float* output_data,
+ const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("FakeQuant");
+
+ // 0 should always be a representable value. Let's assume that the initial
+ // min,max range contains 0.
+ TFLITE_DCHECK_LE(rmin, 0.);
+ TFLITE_DCHECK_GE(rmax, 0.);
+
+ // Determine quantization parameters: zero_point, scale.
+ using Integer = uint8;
+ const Integer qmin = std::numeric_limits<Integer>::min();
+ const Integer qmax = std::numeric_limits<Integer>::max();
+ const float qmin_float = qmin;
+ const float qmax_float = qmax;
+ int32 zero_point = 0;
+ float scale = 0.f;
+ // If rmin==rmax, both must be zero per the above assertion,
+ // so we are done.
+ if (rmin != rmax) {
+ // First determine the scale.
+ scale = (rmax - rmin) / (qmax_float - qmin_float);
+
+ // Zero-point computation.
+ // First the initial floating-point computation. The zero-point can be
+ // determined from solving an affine equation for any known pair
+ // (real value, corresponding quantized value).
+ // We know two such pairs: (rmin, qmin) and (rmax, qmax).
+ // The arithmetic error on the zero point computed from either pair
+ // will be roughly machine_epsilon * (sum of absolute values of terms)
+ // so we want to use the variant that adds the smaller terms.
+ const float zero_point_from_min = qmin_float - rmin / scale;
+ const float zero_point_from_max = qmax_float - rmax / scale;
+ const float zero_point_from_min_error =
+ std::abs(qmin_float) + std::abs(rmin / scale);
+ const float zero_point_from_max_error =
+ std::abs(qmax_float) + std::abs(rmax / scale);
+
+ const float zero_point_float =
+ zero_point_from_min_error < zero_point_from_max_error
+ ? zero_point_from_min
+ : zero_point_from_max;
+
+ // Now we need to nudge the zero point to be an integer
+ // (our zero points are integer, and this is motivated by the requirement
+ // to be able to represent the real value "0" exactly as a quantized value,
+ // which is required in multiple places, for example in Im2col with SAME
+ // padding).
+ if (zero_point_float < qmin_float) {
+ zero_point = qmin;
+ } else if (zero_point_float > qmax_float) {
+ zero_point = qmax;
+ } else {
+ zero_point = static_cast<int32>(TfLiteRound(zero_point_float));
+ }
+ // The zero point should always be in the range of quantized value,
+ // [qmin, qmax].
+ TFLITE_DCHECK_GE(zero_point, qmin);
+ TFLITE_DCHECK_LE(zero_point, qmax);
+ }
+
+ const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
+ const int height = MatchingArraySize(input_dims, 2, output_dims, 2);
+ const int width = MatchingArraySize(input_dims, 1, output_dims, 1);
+ const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
+ for (int b = 0; b < batches; ++b) {
+ for (int y = 0; y < height; ++y) {
+ for (int x = 0; x < width; ++x) {
+ for (int c = 0; c < depth; ++c) {
+ const float src_val = input_data[Offset(input_dims, c, x, y, b)];
+ const float unclamped_quantized_val =
+ TfLiteRound(zero_point + src_val / scale);
+ const float quantized_val = std::min(
+ qmax_float, std::max(qmin_float, unclamped_quantized_val));
+ const float dst_val = scale * (quantized_val - zero_point);
+ output_data[Offset(output_dims, c, x, y, b)] = dst_val;
+ }
+ }
+ }
+ }
+}
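A concrete illustration of the nudging logic above (numbers chosen for the example, not taken from the code): with rmin = -1.f and rmax = 2.f,

    scale               = (2 - (-1)) / 255  ~= 0.011765
    zero_point_from_min = 0 - (-1) / scale   = 85.0   (error ~ 85 * epsilon)
    zero_point_from_max = 255 - 2 / scale    = 85.0   (error ~ 425 * epsilon)

so the min-derived variant is preferred and the zero point lands exactly on the integer 85. A real value of 0.5f then fake-quantizes to round(85 + 0.5 / scale) = round(127.5) = 128, which maps back to 0.011765 * (128 - 85) ~= 0.506f.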
+
+template <typename SrcT, typename DstT>
+inline void Cast(const SrcT* input_data, const Dims<4>& input_dims,
+ DstT* output_data, const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("Cast");
+ auto input_map = MapAsVector(input_data, input_dims);
+ auto output_map = MapAsVector(output_data, output_dims);
+ output_map.array() = input_map.array().template cast<DstT>();
+}
+
+inline void Floor(const float* input_data, const Dims<4>& input_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("Floor");
+ auto input_map = MapAsVector(input_data, input_dims);
+ auto output_map = MapAsVector(output_data, output_dims);
+ output_map.array() = Eigen::floor(input_map.array());
+}
+
+template <typename T>
+inline void Gather(const T* input_data, const Dims<4>& input_dims,
+ int input_rank, const int32* coords_data,
+ const Dims<4>& coords_dims, T* output_data,
+ const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("Gather");
+
+ TFLITE_DCHECK(coords_dims.sizes[0] == output_dims.sizes[input_rank - 1]);
+ int stride = input_dims.strides[input_rank - 1];
+ T* out = output_data;
+
+ for (int i = 0; i < coords_dims.sizes[0]; i++) {
+ TFLITE_DCHECK_GE(coords_data[i], 0);
+ TFLITE_DCHECK_LT(coords_data[i], input_dims.sizes[input_rank - 1]);
+ const T* in = input_data + coords_data[i] * stride;
+ memcpy(out, in, sizeof(T) * stride);
+ out += stride;
+ }
+}
+
+#ifdef USE_NEON
+inline void ResizeBilinearKernel(const float* input_ptr, int32 depth,
+ float scale, float* output_ptr) {
+ int ic = 0;
+ // Handle 32 input channels at a time.
+ for (; ic <= depth - 32; ic += 32) {
+ float32x4x2_t input[4];
+ for (int i = 0; i < 4; i++) {
+ input[i].val[0] = vld1q_f32(input_ptr + 8 * i);
+ input[i].val[1] = vld1q_f32(input_ptr + 8 * i + 4);
+ }
+ float32x4x2_t acc[4];
+ for (int i = 0; i < 4; i++) {
+ acc[i].val[0] = vld1q_f32(output_ptr + 8 * i);
+ acc[i].val[1] = vld1q_f32(output_ptr + 8 * i + 4);
+ }
+ for (int i = 0; i < 4; i++) {
+ acc[i].val[0] = vmlaq_n_f32(acc[i].val[0], input[i].val[0], scale);
+ acc[i].val[1] = vmlaq_n_f32(acc[i].val[1], input[i].val[1], scale);
+ }
+ for (int i = 0; i < 4; i++) {
+ vst1q_f32(output_ptr, acc[i].val[0]);
+ vst1q_f32(output_ptr + 4, acc[i].val[1]);
+ output_ptr += 8;
+ }
+ input_ptr += 32;
+ }
+ // Handle 16 input channels at a time.
+ for (; ic <= depth - 16; ic += 16) {
+ float32x4x2_t input[2];
+ for (int i = 0; i < 2; i++) {
+ input[i].val[0] = vld1q_f32(input_ptr + 8 * i);
+ input[i].val[1] = vld1q_f32(input_ptr + 8 * i + 4);
+ }
+ float32x4x2_t acc[2];
+ for (int i = 0; i < 2; i++) {
+ acc[i].val[0] = vld1q_f32(output_ptr + 8 * i);
+ acc[i].val[1] = vld1q_f32(output_ptr + 8 * i + 4);
+ }
+ for (int i = 0; i < 2; i++) {
+ acc[i].val[0] = vmlaq_n_f32(acc[i].val[0], input[i].val[0], scale);
+ acc[i].val[1] = vmlaq_n_f32(acc[i].val[1], input[i].val[1], scale);
+ }
+ for (int i = 0; i < 2; i++) {
+ vst1q_f32(output_ptr, acc[i].val[0]);
+ vst1q_f32(output_ptr + 4, acc[i].val[1]);
+ output_ptr += 8;
+ }
+ input_ptr += 16;
+ }
+ // Handle 8 input channels at a time.
+ for (; ic <= depth - 8; ic += 8) {
+ float32x4x2_t input;
+ input.val[0] = vld1q_f32(input_ptr);
+ input.val[1] = vld1q_f32(input_ptr + 4);
+
+ float32x4x2_t acc;
+ acc.val[0] = vld1q_f32(output_ptr);
+ acc.val[1] = vld1q_f32(output_ptr + 4);
+ acc.val[0] = vmlaq_n_f32(acc.val[0], input.val[0], scale);
+ acc.val[1] = vmlaq_n_f32(acc.val[1], input.val[1], scale);
+
+ vst1q_f32(output_ptr, acc.val[0]);
+ vst1q_f32(output_ptr + 4, acc.val[1]);
+
+ input_ptr += 8;
+ output_ptr += 8;
+ }
+ // Handle 4 input channels at a time.
+ for (; ic <= depth - 4; ic += 4) {
+ float32x4_t input = vld1q_f32(input_ptr);
+ float32x4_t acc = vld1q_f32(output_ptr);
+
+ acc = vmlaq_n_f32(acc, input, scale);
+ vst1q_f32(output_ptr, acc);
+
+ input_ptr += 4;
+ output_ptr += 4;
+ }
+ // Handle 1 input channel at a time.
+ for (; ic < depth; ic++) {
+ *output_ptr += *input_ptr * scale;
+ output_ptr++;
+ input_ptr++;
+ }
+}
+#else
+inline void ResizeBilinearKernel(const float* input_ptr, int32 depth,
+ float scale, float* output_ptr) {
+ for (int32 i = 0; i < depth; i++) {
+ *output_ptr += *input_ptr * scale;
+ output_ptr++;
+ input_ptr++;
+ }
+}
+#endif
+
+inline void ResizeBilinearKernel2x2(int32 x0, int32 x1, int32 y0, int32 y1,
+ int32 x, int32 y, int32 depth, int32 batch,
+ const float* input_data,
+ const Dims<4>& input_dims,
+ float* output_data,
+ const Dims<4>& output_dims) {
+ const int32 input_width = ArraySize(input_dims, 1);
+ const int32 output_width = ArraySize(output_dims, 1);
+
+ const int32 input_x_offset = (x1 - x0) * depth;
+ const int32 input_y_offset = (y1 - y0) * depth * input_width;
+ const int32 output_x_offset = depth;
+ const int32 output_y_offset = depth * output_width;
+
+#ifdef USE_NEON
+ TFLITE_DCHECK(IsPackedWithoutStrides(input_dims));
+ TFLITE_DCHECK(x1 >= x0);
+ TFLITE_DCHECK(y1 >= y0);
+
+ int ic = 0;
+ // Handle 8 input channels at a time.
+ for (; ic <= depth - 8; ic += 8) {
+ const float* input_ptr = nullptr;
+
+ float32x4x2_t x0y0;
+ input_ptr = &input_data[Offset(input_dims, ic, x0, y0, batch)];
+ x0y0.val[0] = vld1q_f32(input_ptr);
+ x0y0.val[1] = vld1q_f32(input_ptr + 4);
+
+ float32x4x2_t x1y0;
+ input_ptr += input_x_offset;
+ x1y0.val[0] = vld1q_f32(input_ptr);
+ x1y0.val[1] = vld1q_f32(input_ptr + 4);
+
+ float32x4x2_t x0y1;
+ input_ptr += -input_x_offset + input_y_offset;
+ x0y1.val[0] = vld1q_f32(input_ptr);
+ x0y1.val[1] = vld1q_f32(input_ptr + 4);
+
+ float32x4x2_t x1y1;
+ input_ptr += input_x_offset;
+ x1y1.val[0] = vld1q_f32(input_ptr);
+ x1y1.val[1] = vld1q_f32(input_ptr + 4);
+
+ // Top left corner.
+ float* output_ptr = &output_data[Offset(output_dims, ic, x, y, batch)];
+ vst1q_f32(output_ptr, x0y0.val[0]);
+ vst1q_f32(output_ptr + 4, x0y0.val[1]);
+
+ // Top right corner.
+ output_ptr += output_x_offset;
+ float32x4x2_t tr;
+ tr.val[0] = vaddq_f32(x0y0.val[0], x1y0.val[0]);
+ tr.val[1] = vaddq_f32(x0y0.val[1], x1y0.val[1]);
+ tr.val[0] = vmulq_n_f32(tr.val[0], 0.5f);
+ tr.val[1] = vmulq_n_f32(tr.val[1], 0.5f);
+
+ vst1q_f32(output_ptr, tr.val[0]);
+ vst1q_f32(output_ptr + 4, tr.val[1]);
+
+ // Bottom left corner.
+ output_ptr += -output_x_offset + output_y_offset;
+ float32x4x2_t bl;
+ bl.val[0] = vaddq_f32(x0y0.val[0], x0y1.val[0]);
+ bl.val[1] = vaddq_f32(x0y0.val[1], x0y1.val[1]);
+ bl.val[0] = vmulq_n_f32(bl.val[0], 0.5f);
+ bl.val[1] = vmulq_n_f32(bl.val[1], 0.5f);
+ vst1q_f32(output_ptr, bl.val[0]);
+ vst1q_f32(output_ptr + 4, bl.val[1]);
+
+ // Bottom right corner.
+ output_ptr += output_x_offset;
+ float32x4x2_t br;
+ br.val[0] = vaddq_f32(x1y0.val[0], x1y1.val[0]);
+ br.val[1] = vaddq_f32(x1y0.val[1], x1y1.val[1]);
+ br.val[0] = vmlaq_n_f32(bl.val[0], br.val[0], 0.5f);
+ br.val[1] = vmlaq_n_f32(bl.val[1], br.val[1], 0.5f);
+ br.val[0] = vmulq_n_f32(br.val[0], 0.5f);
+ br.val[1] = vmulq_n_f32(br.val[1], 0.5f);
+ vst1q_f32(output_ptr, br.val[0]);
+ vst1q_f32(output_ptr + 4, br.val[1]);
+ }
+ // Handle 4 input channels at a time.
+ for (; ic <= depth - 4; ic += 4) {
+ const float* input_ptr = &input_data[Offset(input_dims, ic, x0, y0, batch)];
+ float32x4_t x0y0 = vld1q_f32(input_ptr);
+ float32x4_t x1y0 = vld1q_f32(input_ptr + input_x_offset);
+ float32x4_t x0y1 = vld1q_f32(input_ptr + input_y_offset);
+ float32x4_t x1y1 = vld1q_f32(input_ptr + input_x_offset + input_y_offset);
+
+ // Top left corner.
+ float* output_ptr = &output_data[Offset(output_dims, ic, x, y, batch)];
+ vst1q_f32(output_ptr, x0y0);
+
+ // Top right corner.
+ output_ptr += output_x_offset;
+ float32x4_t tr = vaddq_f32(x0y0, x1y0);
+ tr = vmulq_n_f32(tr, 0.5f);
+ vst1q_f32(output_ptr, tr);
+
+ // Bottom left corner.
+ output_ptr += -output_x_offset + output_y_offset;
+ float32x4_t bl = vaddq_f32(x0y0, x0y1);
+ bl = vmulq_n_f32(bl, 0.5f);
+ vst1q_f32(output_ptr, bl);
+
+ // Bottom right corner.
+ output_ptr += output_x_offset;
+ float32x4_t br = vaddq_f32(x1y0, x1y1);
+ br = vmlaq_n_f32(bl, br, 0.5f);
+ br = vmulq_n_f32(br, 0.5f);
+ vst1q_f32(output_ptr, br);
+ }
+ // Handle one input channel at a time.
+ for (; ic < depth; ic++) {
+ const int32 input_offset = Offset(input_dims, ic, x0, y0, batch);
+
+ float x0y0 = input_data[input_offset];
+ float x1y0 = input_data[input_offset + input_x_offset];
+ float x0y1 = input_data[input_offset + input_y_offset];
+ float x1y1 = input_data[input_offset + input_x_offset + input_y_offset];
+
+ // Top left corner.
+ const int32 output_offset = Offset(output_dims, ic, x, y, batch);
+ output_data[output_offset] = x0y0;
+
+ // Top right corner.
+ output_data[output_offset + output_x_offset] = (x0y0 + x1y0) / 2;
+
+ // Bottom left corner.
+ float output = (x0y0 + x0y1) / 2;
+ output_data[output_offset + output_y_offset] = output;
+
+ // Bottom right corner.
+ output_data[output_offset + output_x_offset + output_y_offset] =
+ (output + ((x1y0 + x1y1) / 2)) / 2;
+ }
+#else
+ for (int ch = 0; ch < depth; ch++) {
+ const int32 input_offset = Offset(input_dims, ch, x0, y0, batch);
+
+ float x0y0 = input_data[input_offset];
+ float x1y0 = input_data[input_offset + input_x_offset];
+ float x0y1 = input_data[input_offset + input_y_offset];
+ float x1y1 = input_data[input_offset + input_x_offset + input_y_offset];
+
+ // Top left corner.
+ const int32 output_offset = Offset(output_dims, ch, x, y, batch);
+ output_data[output_offset] = x0y0;
+
+ // Top right corner.
+ output_data[output_offset + output_x_offset] = (x0y0 + x1y0) / 2;
+
+ // Bottom left corner.
+ float output = (x0y0 + x0y1) / 2;
+ output_data[output_offset + output_y_offset] = output;
+
+ // Bottom right corner.
+ output_data[output_offset + output_x_offset + output_y_offset] =
+ (output + ((x1y0 + x1y1) / 2)) / 2;
+ }
+#endif
+}
+
+inline void ResizeBilinear2x2(const float* input_data,
+ const Dims<4>& input_dims, float* output_data,
+ const Dims<4>& output_dims, int32 batches,
+ int32 input_height, int32 input_width,
+ int32 depth, int32 output_height,
+ int32 output_width) {
+ for (int b = 0; b < batches; b++) {
+ for (int y0 = 0, y = 0; y <= output_height - 2; y += 2, y0++) {
+ for (int x0 = 0, x = 0; x <= output_width - 2; x += 2, x0++) {
+ int32 x1 = std::min(x0 + 1, input_width - 1);
+ int32 y1 = std::min(y0 + 1, input_height - 1);
+ ResizeBilinearKernel2x2(x0, x1, y0, y1, x, y, depth, b, input_data,
+ input_dims, output_data, output_dims);
+ }
+ }
+ }
+}
+
+inline void ResizeBilinearGeneric(const float* input_data,
+ const Dims<4>& input_dims, float* output_data,
+ const Dims<4>& output_dims, int32 batches,
+ int32 input_height, int32 input_width,
+ int32 depth, int32 output_height,
+ int32 output_width, float height_scale,
+ float width_scale) {
+ memset(output_data, 0,
+ batches * output_height * output_width * depth * sizeof(float));
+
+ int32 output_offset = 0;
+ for (int b = 0; b < batches; ++b) {
+ for (int y = 0; y < output_height; ++y) {
+ float input_y = y * height_scale;
+ int32 y0 = static_cast<int32>(std::floor(input_y));
+ int32 y1 = std::min(y0 + 1, input_height - 1);
+ for (int x = 0; x < output_width; ++x) {
+ float input_x = x * width_scale;
+ int32 x0 = static_cast<int32>(input_x);
+ int32 x1 = std::min(x0 + 1, input_width - 1);
+ float* output_ptr = &output_data[output_offset];
+
+ // Run kernel on the 4 corners of the bilinear resize algorithm.
+ int32 input_offset = Offset(input_dims, 0, x0, y0, b);
+ float scale = (1 - (input_y - y0)) * (1 - (input_x - x0));
+ const float* input_ptr = &input_data[input_offset];
+ ResizeBilinearKernel(input_ptr, depth, scale, output_ptr);
+
+ input_offset = Offset(input_dims, 0, x1, y0, b);
+ scale = (1 - (input_y - y0)) * (input_x - x0);
+ input_ptr = &input_data[input_offset];
+ ResizeBilinearKernel(input_ptr, depth, scale, output_ptr);
+
+ input_offset = Offset(input_dims, 0, x0, y1, b);
+ scale = (input_y - y0) * (1 - (input_x - x0));
+ input_ptr = &input_data[input_offset];
+ ResizeBilinearKernel(input_ptr, depth, scale, output_ptr);
+
+ input_offset = Offset(input_dims, 0, x1, y1, b);
+ scale = (input_y - y0) * (input_x - x0);
+ input_ptr = &input_data[input_offset];
+ ResizeBilinearKernel(input_ptr, depth, scale, output_ptr);
+
+ output_offset += depth;
+ }
+ }
+ }
+}
+
+inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims,
+ const int32* output_size_data,
+ const Dims<4>& output_size_dims, float* output_data,
+ const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("ResizeBilinear");
+ int32 batches = MatchingArraySize(input_dims, 3, output_dims, 3);
+ int32 input_height = ArraySize(input_dims, 2);
+ int32 input_width = ArraySize(input_dims, 1);
+ int32 depth = MatchingArraySize(input_dims, 0, output_dims, 0);
+
+ TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 3), 1);
+ TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 2), 1);
+ TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 1), 1);
+ TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 0), 2);
+ int32 output_height = output_size_data[Offset(output_size_dims, 0, 0, 0, 0)];
+ int32 output_width = output_size_data[Offset(output_size_dims, 1, 0, 0, 0)];
+
+ // Specialize for 2x2 upsample.
+ if (output_height == 2 * input_height && output_width == 2 * input_width) {
+ ResizeBilinear2x2(input_data, input_dims, output_data, output_dims, batches,
+ input_height, input_width, depth, output_height,
+ output_width);
+ } else {
+ float height_scale = static_cast<float>(input_height) / output_height;
+ float width_scale = static_cast<float>(input_width) / output_width;
+
+ ResizeBilinearGeneric(input_data, input_dims, output_data, output_dims,
+ batches, input_height, input_width, depth,
+ output_height, output_width, height_scale,
+ width_scale);
+ }
+}
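A small numeric example of the generic path (illustration only): resizing a 3x3 input to 5x5 does not hit the 2x2 specialization, so height_scale = width_scale = 3 / 5 = 0.6. For output pixel (y = 3, x = 1), input_y = 1.8 and input_x = 0.6, giving y0 = 1, y1 = 2, x0 = 0, x1 = 1, and the four kernel invocations accumulate the corners with weights

    (1 - 0.8) * (1 - 0.6) = 0.08   // (x0, y0)
    (1 - 0.8) * 0.6       = 0.12   // (x1, y0)
    0.8 * (1 - 0.6)       = 0.32   // (x0, y1)
    0.8 * 0.6             = 0.48   // (x1, y1)

which sum to 1, as expected for bilinear interpolation.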
+
+template <typename T>
+inline void SpaceToBatchND(const T* input_data, const Dims<4>& input_dims,
+ const int32* block_shape_data,
+ const Dims<4>& block_shape_dims,
+ const int32* paddings_data,
+ const Dims<4>& paddings_dims, T* output_data,
+ const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("SpaceToBatchND");
+
+ const int output_batch_size = ArraySize(output_dims, 3);
+ const int output_height = ArraySize(output_dims, 2);
+ const int output_width = ArraySize(output_dims, 1);
+ const int input_batch_size = ArraySize(input_dims, 3);
+ const int input_height = ArraySize(input_dims, 2);
+ const int input_width = ArraySize(input_dims, 1);
+ const int depth = ArraySize(input_dims, 0);
+ const int block_shape_height = block_shape_data[0];
+ const int block_shape_width = block_shape_data[1];
+ const int padding_top = paddings_data[0];
+ const int padding_left = paddings_data[2];
+
+ for (int out_b = 0; out_b < output_batch_size; ++out_b) {
+ int input_batch = out_b % input_batch_size;
+ int shift_w = (out_b / input_batch_size) % block_shape_width;
+ int shift_h = (out_b / input_batch_size) / block_shape_width;
+ for (int out_h = 0; out_h < output_height; ++out_h) {
+ for (int out_w = 0; out_w < output_width; ++out_w) {
+ T* out = output_data + Offset(output_dims, 0, out_w, out_h, out_b);
+ if (out_h * block_shape_height < padding_top ||
+ out_h * block_shape_height >= padding_top + input_height ||
+ out_w * block_shape_width < padding_left ||
+ out_w * block_shape_width >= padding_left + input_width) {
+ memset(out, 0, depth * sizeof(T));
+ } else {
+ const T* in =
+ input_data +
+ Offset(input_dims, 0,
+ (out_w * block_shape_width + shift_w) - padding_left,
+ (out_h * block_shape_height + shift_h) - padding_top,
+ input_batch);
+ memcpy(out, in, depth * sizeof(T));
+ }
+ }
+ }
+ }
+}
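To make the index bookkeeping above concrete (shapes assumed for illustration): with block_shape = {2, 2}, zero padding, and a 1x4x4xD input, output_batch_size is 4 and each output batch is 2x2xD. For output batch k, shift_w = k % 2 and shift_h = k / 2, so output element (out_h, out_w) copies input element (2 * out_h + shift_h, 2 * out_w + shift_w): batch 0 gathers the even rows and even columns, batch 1 the even rows and odd columns, batch 2 the odd rows and even columns, and batch 3 the odd rows and odd columns.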
+
+template <typename T>
+inline void BatchToSpaceND(const T* input_data, const Dims<4>& input_dims,
+ const int32* block_shape_data,
+ const Dims<4>& block_shape_dims, T* output_data,
+ const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("BatchToSpaceND");
+
+ const int output_batch_size = ArraySize(output_dims, 3);
+ const int input_batch_size = ArraySize(input_dims, 3);
+ const int input_height = ArraySize(input_dims, 2);
+ const int input_width = ArraySize(input_dims, 1);
+ const int depth = ArraySize(input_dims, 0);
+ const int block_shape_width = block_shape_data[1];
+ const int block_shape_height = block_shape_data[0];
+
+ for (int in_batch = 0; in_batch < input_batch_size; ++in_batch) {
+ for (int in_h = 0; in_h < input_height; ++in_h) {
+ for (int in_w = 0; in_w < input_width; ++in_w) {
+ int out_batch = in_batch % output_batch_size;
+ int out_w = in_w * block_shape_width +
+ (in_batch / output_batch_size) % block_shape_width;
+ int out_h = in_h * block_shape_height +
+ (in_batch / output_batch_size) / block_shape_width;
+ T* out = output_data + Offset(output_dims, 0, out_w, out_h, out_batch);
+ const T* in = input_data + Offset(input_dims, 0, in_w, in_h, in_batch);
+ memcpy(out, in, depth * sizeof(T));
+ }
+ }
+ }
+}
+
+template <typename T>
+inline void Pad(const T* input_data, const Dims<4>& input_dims,
+ const std::vector<int>& left_paddings,
+ const std::vector<int>& right_paddings, T* output_data,
+ const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("Pad");
+ const int output_batch = ArraySize(output_dims, 3);
+ const int output_height = ArraySize(output_dims, 2);
+ const int output_width = ArraySize(output_dims, 1);
+ const int output_depth = ArraySize(output_dims, 0);
+
+ const int left_b_padding = left_paddings[3];
+ const int left_h_padding = left_paddings[2];
+ const int left_w_padding = left_paddings[1];
+ const int left_d_padding = left_paddings[0];
+
+ const int right_b_padding = right_paddings[3];
+ const int right_h_padding = right_paddings[2];
+ const int right_w_padding = right_paddings[1];
+ const int right_d_padding = right_paddings[0];
+
+ const int input_depth = ArraySize(input_dims, 0);
+
+ if (left_b_padding != 0) {
+ memset(output_data, 0,
+ left_b_padding * output_height * output_width * output_depth *
+ sizeof(T));
+ }
+ for (int out_b = left_b_padding; out_b < output_batch - right_b_padding;
+ ++out_b) {
+ if (left_h_padding != 0) {
+ memset(output_data + Offset(output_dims, 0, 0, 0, out_b), 0,
+ left_h_padding * output_width * output_depth * sizeof(T));
+ }
+ for (int out_h = left_h_padding; out_h < output_height - right_h_padding;
+ ++out_h) {
+ if (left_w_padding != 0) {
+ memset(output_data + Offset(output_dims, 0, 0, out_h, out_b), 0,
+ left_w_padding * output_depth * sizeof(T));
+ }
+ for (int out_w = left_w_padding; out_w < output_width - right_w_padding;
+ ++out_w) {
+ if (left_d_padding != 0) {
+ memset(output_data + Offset(output_dims, 0, out_w, out_h, out_b), 0,
+ left_d_padding * sizeof(T));
+ }
+
+ T* out = output_data +
+ Offset(output_dims, left_d_padding, out_w, out_h, out_b);
+ const T* in =
+ input_data + Offset(input_dims, 0, out_w - left_w_padding,
+ out_h - left_h_padding, out_b - left_b_padding);
+ memcpy(out, in, input_depth * sizeof(T));
+
+ if (right_d_padding != 0) {
+ memset(
+ output_data + Offset(output_dims, output_depth - right_d_padding,
+ out_w, out_h, out_b),
+ 0, right_d_padding * sizeof(T));
+ }
+ }
+ if (right_w_padding != 0) {
+ memset(
+ output_data + Offset(output_dims, 0, output_width - right_w_padding,
+ out_h, out_b),
+ 0, right_w_padding * output_depth * sizeof(T));
+ }
+ }
+ if (right_h_padding != 0) {
+ memset(output_data + Offset(output_dims, 0, 0,
+ output_height - right_h_padding, out_b),
+ 0, right_h_padding * output_width * output_depth * sizeof(T));
+ }
+ }
+ if (right_b_padding != 0) {
+ memset(output_data +
+ Offset(output_dims, 0, 0, 0, output_batch - right_b_padding),
+ 0,
+ right_b_padding * output_height * output_width * output_depth *
+ sizeof(T));
+ }
+}
+
+template <typename T>
+inline void StridedSlice(const T* input_data, const Dims<4>& input_dims,
+ int begin_mask, int end_mask,
+ const std::vector<int>& starts,
+ const std::vector<int>& stops,
+ const std::vector<int>& strides, T* output_data,
+ const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("StridedSlice");
+ const int start_b = (begin_mask & 8) ? 0 : starts[3];
+ const int stop_b = (end_mask & 8) ? input_dims.sizes[3] : stops[3];
+ const int start_h = (begin_mask & 4) ? 0 : starts[2];
+ const int stop_h = (end_mask & 4) ? input_dims.sizes[2] : stops[2];
+ const int start_w = (begin_mask & 2) ? 0 : starts[1];
+ const int stop_w = (end_mask & 2) ? input_dims.sizes[1] : stops[1];
+ const int start_d = (begin_mask & 1) ? 0 : starts[0];
+ const int stop_d = (end_mask & 1) ? input_dims.sizes[0] : stops[0];
+
+ T* out_ptr = output_data;
+ if (strides[0] == 0) {
+ for (int in_b = start_b; in_b < stop_b; in_b += strides[3]) {
+ for (int in_h = start_h; in_h < stop_h; in_h += strides[2]) {
+ for (int in_w = start_w; in_w < stop_w; in_w += strides[1]) {
+ const int len = stop_d - start_d;
+ memcpy(out_ptr,
+ input_data + Offset(input_dims, start_d, in_w, in_h, in_b),
+ len * sizeof(T));
+ out_ptr += len;
+ }
+ }
+ }
+ } else {
+ for (int in_b = start_b; in_b < stop_b; in_b += strides[3]) {
+ for (int in_h = start_h; in_h < stop_h; in_h += strides[2]) {
+ for (int in_w = start_w; in_w < stop_w; in_w += strides[1]) {
+ for (int in_d = start_d; in_d < stop_d; in_d += strides[0]) {
+ *out_ptr++ = input_data[Offset(input_dims, in_d, in_w, in_h, in_b)];
+ }
+ }
+ }
+ }
+ }
+}
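Note that the mask bits follow the reversed Dims ordering: bit 0 (value 1) covers depth, bit 1 (2) width, bit 2 (4) height, and bit 3 (8) batch, and a depth stride of 0 selects the dense memcpy path above. A usage sketch (values assumed for illustration, not taken from any caller in this change):

// Sketch: keep every other row of an input tensor, all columns and channels.
//   std::vector<int> starts  = {0, 0, 0, 0};  // ignored: begin_mask = 0xF
//   std::vector<int> stops   = {0, 0, 0, 0};  // ignored: end_mask = 0xF
//   std::vector<int> strides = {0, 1, 2, 1};  // {depth, width, height, batch}
//   StridedSlice(input_data, input_dims, /*begin_mask=*/0xF, /*end_mask=*/0xF,
//                starts, stops, strides, output_data, output_dims);
// output_dims is assumed to describe the halved number of rows.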
+
+template <typename T>
+inline void Slice(const T* input_data, const Dims<4>& input_dims,
+ const std::vector<int>& begin, const std::vector<int>& size,
+ T* output_data, const Dims<4>& output_dims) {
+ // TODO(dkalenichenko): This op only supports 4D tensors.
+ TFLITE_DCHECK_EQ(begin.size(), 4);
+ TFLITE_DCHECK_EQ(size.size(), 4);
+ const int start_b = begin[3];
+ const int stop_b =
+ size[3] == -1 ? input_dims.sizes[3] - start_b : start_b + size[3];
+ const int start_h = begin[2];
+  const int stop_h =
+      size[2] == -1 ? input_dims.sizes[2] - start_h : start_h + size[2];
+ const int start_w = begin[1];
+  const int stop_w =
+      size[1] == -1 ? input_dims.sizes[1] - start_w : start_w + size[1];
+ const int start_d = begin[0];
+ const int stop_d =
+ size[0] == -1 ? input_dims.sizes[0] - start_d : start_d + size[0];
+
+ T* out_ptr = output_data;
+ for (int in_b = start_b; in_b < stop_b; ++in_b) {
+ for (int in_h = start_h; in_h < stop_h; ++in_h) {
+ for (int in_w = start_w; in_w < stop_w; ++in_w) {
+ const int len = stop_d - start_d;
+ memcpy(out_ptr,
+ input_data + Offset(input_dims, start_d, in_w, in_h, in_b),
+ len * sizeof(T));
+ out_ptr += len;
+ }
+ }
+ }
+}
+
+template <typename T>
+inline void Mean(const T* input_data, const Dims<4>& input_dims,
+ const std::vector<int>& reduction_indices, T* output_data,
+ const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("Mean");
+ const int output_batch = ArraySize(output_dims, 3);
+ const int output_height = ArraySize(output_dims, 2);
+ const int output_width = ArraySize(output_dims, 1);
+ const int output_depth = ArraySize(output_dims, 0);
+
+ const int input_height = ArraySize(input_dims, 2);
+ const int input_width = ArraySize(input_dims, 1);
+
+ // The current implementation only supports simultaneous reduction over
+ // width and height.
+ TFLITE_DCHECK_EQ(reduction_indices.size(), 2);
+ TFLITE_DCHECK((reduction_indices[0] == 1 && reduction_indices[1] == 2) ||
+ (reduction_indices[0] == 2 && reduction_indices[1] == 1));
+ TFLITE_DCHECK_EQ(output_height, 1);
+ TFLITE_DCHECK_EQ(output_width, 1);
+
+ for (int out_b = 0; out_b < output_batch; ++out_b) {
+ for (int out_d = 0; out_d < output_depth; ++out_d) {
+ float value = 0;
+ for (int in_h = 0; in_h < input_height; ++in_h) {
+ for (int in_w = 0; in_w < input_width; ++in_w) {
+ value += input_data[Offset(input_dims, out_d, in_w, in_h, out_b)];
+ }
+ }
+ output_data[Offset(output_dims, out_d, 0, 0, out_b)] =
+ value / (input_width * input_height);
+ }
+ }
+}
+
+template <typename T>
+void GenericBroadcastSub(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, const Dims<4>& input2_dims,
+ T* output_data, const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("GenericBroadcastSub");
+
+ NdArrayDesc<4> desc1;
+ NdArrayDesc<4> desc2;
+ NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
+
+  // In TensorFlow, the dimensions are canonically named (batch_number, row,
+ // col, channel), with extents (batches, height, width, depth), with the
+ // trailing dimension changing most rapidly (channels has the smallest stride,
+ // typically 1 element).
+ //
+ // In generated C code, we store arrays with the dimensions reversed. The
+ // first dimension has smallest stride.
+ //
+  // We name our variables by their TensorFlow convention, but generate C code
+ // nesting loops such that the innermost loop has the smallest stride for the
+ // best cache behavior.
+ for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
+ for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
+ for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
+ for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
+ output_data[Offset(output_dims, c, x, y, b)] =
+ input1_data[SubscriptToIndex(desc1, c, x, y, b)] -
+ input2_data[SubscriptToIndex(desc2, c, x, y, b)];
+ }
+ }
+ }
+ }
+}
+
+template <typename T>
+void Sub(const T* input1_data, const Dims<4>& input1_dims, const T* input2_data,
+ const Dims<4>& input2_dims, T* output_data,
+ const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("Sub");
+
+ auto input1_map = MapAsVector(input1_data, input1_dims);
+ auto input2_map = MapAsVector(input2_data, input2_dims);
+ auto output_map = MapAsVector(output_data, output_dims);
+ if (AreSameDims(input1_dims, input2_dims)) {
+ output_map.array() = input1_map.array() - input2_map.array();
+ } else if (RequiredBufferSizeForDims(input1_dims) == 1) {
+ auto scalar = input1_data[0];
+ output_map.array() = scalar - input2_map.array();
+ } else if (RequiredBufferSizeForDims(input2_dims) == 1) {
+ auto scalar = input2_data[0];
+ output_map.array() = input1_map.array() - scalar;
+ } else {
+ GenericBroadcastSub(input1_data, input1_dims, input2_data, input2_dims,
+ output_data, output_dims);
+ }
+}
+
+template <typename T>
+void TensorFlowMinimum(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, T* output_data,
+ const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("TensorFlowMinimum");
+ auto input1_map = MapAsVector(input1_data, input1_dims);
+ auto output_map = MapAsVector(output_data, output_dims);
+ auto min_value = input2_data[0];
+ output_map.array() = input1_map.array().min(min_value);
+}
+
+template <typename T>
+void TensorFlowMaximum(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, T* output_data,
+ const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("TensorFlowMaximum");
+ auto input1_map = MapAsVector(input1_data, input1_dims);
+ auto output_map = MapAsVector(output_data, output_dims);
+ auto max_value = input2_data[0];
+ output_map.array() = input1_map.array().max(max_value);
+}
+} // namespace optimized_ops
+} // namespace tflite
+
+#if defined OPTIMIZED_OPS_H__IGNORE_DEPRECATED_DECLARATIONS
+#undef OPTIMIZED_OPS_H__IGNORE_DEPRECATED_DECLARATIONS
+#pragma GCC diagnostic pop
+#endif
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_OPS_H_
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h b/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h
new file mode 100644
index 0000000000..f8be99e82f
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h
@@ -0,0 +1,138 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TF_LITE_KERNELS_INTERNAL_OPTIMIZED_TENSOR_UTILS_IMPL_H_
+#define TF_LITE_KERNELS_INTERNAL_OPTIMIZED_TENSOR_UTILS_IMPL_H_
+
+// TODO(ghodrat): Remove this header file and the dependency on internal data
+// structures.
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+
+#ifndef USE_NEON
+#if defined(__ARM_NEON__) || defined(__ARM_NEON)
+#define USE_NEON
+#endif // defined(__ARM_NEON__) || defined(__ARM_NEON)
+#endif // USE_NEON
+
+namespace tflite {
+namespace tensor_utils {
+
+// Multiply a matrix by a batch vector, and store results in a batch-size
+// vector.
+void PortableMatrixBatchVectorMultiplyAccumulate(const float* matrix,
+ int m_rows, int m_cols,
+ const float* vector,
+ int n_batch, float* result,
+ int result_stride);
+void NeonMatrixBatchVectorMultiplyAccumulate(const float* matrix, int m_rows,
+ int m_cols, const float* vector,
+ int n_batch, float* result,
+ int result_stride);
+
+// Cwise product of two vectors.
+void PortableVectorVectorCwiseProduct(const float* vector1,
+ const float* vector2, int v_size,
+ float* result);
+void NeonVectorVectorCwiseProduct(const float* vector1, const float* vector2,
+ int v_size, float* result);
+
+// Cwise product and accumulate of two vectors. Since it's a MAC operation,
+// the assumption here is that the result array is initialized to valid values.
+void PortableVectorVectorCwiseProductAccumulate(const float* vector1,
+ const float* vector2,
+ int v_size, float* result);
+void NeonVectorVectorCwiseProductAccumulate(const float* vector1,
+ const float* vector2, int v_size,
+ float* result);
+
+// Dot product of two vectors.
+float PortableVectorVectorDotProduct(const float* vector1, const float* vector2,
+ int v_size);
+float NeonVectorVectorDotProduct(const float* vector1, const float* vector2,
+ int v_size);
+
+// Dot product of two batch vectors.
+void PortableBatchVectorBatchVectorDotProduct(const float* vector1,
+ const float* vector2, int v_size,
+ int n_batch, float* result,
+ int result_stride);
+void NeonBatchVectorBatchVectorDotProduct(const float* vector1,
+ const float* vector2, int v_size,
+ int n_batch, float* result,
+ int result_stride);
+
+// Cwise product and accumulate of a vector and a batch-vector. Since it's a
+// MAC operation, the assumption here is that the result array is initialized
+// to valid values.
+void PortableVectorBatchVectorCwiseProductAccumulate(const float* vector,
+ int v_size,
+ const float* batch_vector,
+ int n_batch,
+ float* result);
+void NeonVectorBatchVectorCwiseProductAccumulate(const float* vector,
+ int v_size,
+ const float* batch_vector,
+ int n_batch, float* result);
+
+// Compute "1.0f - elements of vector" (used in CIFG).
+void PortableSub1Vector(const float* vector, int v_size, float* result);
+void NeonSub1Vector(const float* vector, int v_size, float* result);
+
+// Clip elements of a vector using an abs_limit value.
+void PortableClipVector(const float* vector, int v_size, float abs_limit,
+ float* result);
+void NeonClipVector(const float* vector, int v_size, float abs_limit,
+ float* result);
+
+// Batch vector initialization with another vector.
+void PortableVectorBatchVectorAssign(const float* vector, int v_size,
+ int n_batch, float* batch_vector);
+
+// Apply sigmoid to elements of a vector.
+void PortableApplySigmoidToVector(const float* vector, int v_size,
+ float* result);
+
+// Apply activation function to elements of a vector.
+void PortableApplyActivationToVector(const float* vector, int v_size,
+ TfLiteFusedActivation activation,
+ float* result);
+
+// Copy vector to another vector.
+void PortableCopyVector(const float* vector, int v_size, float* result);
+
+// Fill vector with 0.f.
+void PortableZeroVector(float* vector, int v_size);
+
+// Limit a float input f to the range [-abs_limit, +abs_limit].
+float PortableClip(float f, float abs_limit);
+
+// Shift a vector of size v_size left in place.
+void PortableVectorShiftLeft(float* vector, int v_size, float shift_value);
+void NeonVectorShiftLeft(float* vector, int v_size, float shift_value);
+
+// Reduce-sum on a float input vector:
+// input_vector: float pointer to input vector.
+//  output_vector: float pointer to output vector.
+// output_size: output vector size.
+// reduction_size: number of consecutive elements from input vector which are
+// added to get one element of output.
+void PortableReductionSumVector(const float* input_vector, float* output_vector,
+ int output_size, int reduction_size);
+void NeonReductionSumVector(const float* input_vector, float* output_vector,
+ int output_size, int reduction_size);
+
+} // namespace tensor_utils
+} // namespace tflite
+
+#endif // TF_LITE_KERNELS_INTERNAL_OPTIMIZED_TENSOR_UTILS_IMPL_H_
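A minimal usage sketch for the matrix-batch-vector primitive declared above (illustration only; the row-major matrix layout and the accumulate-into-result behavior are assumptions based on the declaration, and linking requires the corresponding implementation file):

#include "tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h"

// Sketch: accumulate matrix * vector for a batch of two vectors into a
// zero-initialized result buffer of n_batch * m_rows floats.
void ExampleMatrixBatchVectorMultiply() {
  const float matrix[2 * 3] = {1, 0, 0, 0, 1, 0};   // m_rows = 2, m_cols = 3
  const float vectors[2 * 3] = {1, 2, 3, 4, 5, 6};  // n_batch = 2
  float result[2 * 2] = {0.f, 0.f, 0.f, 0.f};
  tflite::tensor_utils::PortableMatrixBatchVectorMultiplyAccumulate(
      matrix, /*m_rows=*/2, /*m_cols=*/3, vectors, /*n_batch=*/2, result,
      /*result_stride=*/1);
}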
diff --git a/tensorflow/contrib/lite/kernels/internal/quantization_util.cc b/tensorflow/contrib/lite/kernels/internal/quantization_util.cc
new file mode 100644
index 0000000000..98f2e365c5
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/internal/quantization_util.cc
@@ -0,0 +1,95 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <algorithm>
+#include <cmath>
+#include <limits>
+
+#include "tensorflow/contrib/lite/kernels/internal/compatibility.h"
+#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
+#include "tensorflow/contrib/lite/kernels/internal/round.h"
+
+namespace tflite {
+
+void QuantizeMultiplierSmallerThanOne(double double_multiplier,
+ int32_t* quantized_multiplier,
+ int* right_shift) {
+ TFLITE_CHECK(double_multiplier >= 0.);
+ TFLITE_CHECK(double_multiplier < 1.);
+ if (double_multiplier == 0.) {
+ *quantized_multiplier = 0;
+ *right_shift = 0;
+ return;
+ }
+ TFLITE_CHECK(double_multiplier > 0.);
+ const double q = std::frexp(double_multiplier, right_shift);
+ *right_shift *= -1;
+
+ auto q_fixed = static_cast<int64_t>(TfLiteRound(q * (1ll << 31)));
+ TFLITE_CHECK(q_fixed <= (1ll << 31));
+ if (q_fixed == (1ll << 31)) {
+ q_fixed /= 2;
+ --*right_shift;
+ }
+ TFLITE_CHECK_GE(*right_shift, 0);
+ TFLITE_CHECK_LE(q_fixed, std::numeric_limits<int32_t>::max());
+ *quantized_multiplier = static_cast<int32_t>(q_fixed);
+}
+
+void QuantizeMultiplierGreaterThanOne(double double_multiplier,
+ int32_t* quantized_multiplier,
+ int* left_shift) {
+ TFLITE_CHECK(double_multiplier > 1.);
+ const double q = std::frexp(double_multiplier, left_shift);
+ auto q_fixed = static_cast<int64_t>(TfLiteRound(q * (1ll << 31)));
+ TFLITE_CHECK(q_fixed <= (1ll << 31));
+ if (q_fixed == (1ll << 31)) {
+ q_fixed /= 2;
+ ++*left_shift;
+ }
+ TFLITE_CHECK_GE(*left_shift, 0);
+ TFLITE_CHECK_LE(q_fixed, std::numeric_limits<int32_t>::max());
+ *quantized_multiplier = static_cast<int32_t>(q_fixed);
+}
+
+void PreprocessSoftmaxScaling(double beta, double input_scale,
+ int input_integer_bits,
+ int32_t* quantized_multiplier, int* left_shift) {
+ // If the overall multiplier (input and beta) is large, then exp() of an
+ // input difference of 1 scaled by this will be large. In other words, we
+ // can cap the multiplier and know that, when it is used, the output will be
+  // (rounded to) zero wherever the input is not at the maximum value.
+
+ // If the overall scale is less than one, and input_integer_bits=0, then the
+  // result is the double equivalent of Q0.31 (actually with more precision). Thus
+ // this generates a Q(input_integer_bits).(31-input_integer_bits)
+ // representation.
+ const double input_beta_real_multiplier = std::min(
+ beta * input_scale * (1 << (31 - input_integer_bits)), (1ll << 31) - 1.0);
+
+ QuantizeMultiplierGreaterThanOne(input_beta_real_multiplier,
+ quantized_multiplier, left_shift);
+}
+
+int CalculateInputRadius(int input_integer_bits, int input_left_shift) {
+ const double max_input_rescaled = 1.0 * ((1 << input_integer_bits) - 1) *
+ (1ll << (31 - input_integer_bits)) /
+ (1ll << input_left_shift);
+  // Tighten the bound using floor. If we could use the exact value, then after
+  // scaling the difference the result would land exactly at the maximum, so we
+  // must ensure that our value has strictly lower magnitude.
+ return static_cast<int>(std::floor(max_input_rescaled));
+}
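A quick numeric check of the formula (it matches the unit test added later in this change): with input_integer_bits = 4 and input_left_shift = 27,

    max_input_rescaled = 15 * 2^27 / 2^27 = 15,

so the radius is 15; kernels such as the quantized Logistic above use this radius to take the saturating fast paths for centered inputs whose magnitude exceeds it.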
+
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/internal/quantization_util.h b/tensorflow/contrib/lite/kernels/internal/quantization_util.h
new file mode 100644
index 0000000000..efb7191c8d
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/internal/quantization_util.h
@@ -0,0 +1,55 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef PHOTOS_VISION_LEARNING_TENSORFLOW_MINI_QUANTIZATION_UTIL_H_
+#define PHOTOS_VISION_LEARNING_TENSORFLOW_MINI_QUANTIZATION_UTIL_H_
+
+#include <cstdint>
+
+namespace tflite {
+
+// Decompose a double multiplier into a Q0.31 int32 representation of its
+// significand, and shift representation of its exponent.
+//
+// Restricted to the case where the multiplier < 1 (and non-negative).
+void QuantizeMultiplierSmallerThanOne(double double_multiplier,
+ int32_t* quantized_multiplier,
+ int* right_shift);
+
+// Decompose a double multiplier into a Q0.31 int32 representation of its
+// significand, and shift representation of its exponent.
+//
+// Restricted to the case where the multiplier > 1.
+void QuantizeMultiplierGreaterThanOne(double double_multiplier,
+ int32_t* quantized_multiplier,
+ int* left_shift);
+
+// This first creates a multiplier in a double equivalent of
+// Q(input_integer_bits).(31-input_integer_bits) representation, with extra
+// precision in the double's fractional bits. It then splits the result into
+// significand and exponent.
+void PreprocessSoftmaxScaling(double beta, double input_scale,
+ int input_integer_bits,
+ int32_t* quantized_multiplier, int* left_shift);
+
+// Calculate the largest input that will result in a within-bounds intermediate
+// result within MultiplyByQuantizedMultiplierGreaterThanOne. In other words,
+// it must not overflow before we reduce the value by multiplication by the
+// input multiplier. The negative radius is used as the minimum difference
+// in Softmax.
+int CalculateInputRadius(int input_integer_bits, int input_left_shift);
+
+} // namespace tflite
+
+#endif // PHOTOS_VISION_LEARNING_TENSORFLOW_MINI_QUANTIZATION_UTIL_H_
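A usage sketch for the first helper above (illustration only; the scale values are placeholders and the surrounding kernel wiring is assumed, not defined in this header):

#include <cstdint>
#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"

// Sketch: derive the fixed-point multiplier/shift pair that a quantized kernel
// would pass to MultiplyByQuantizedMultiplierSmallerThanOne().
void ExampleDeriveOutputMultiplier() {
  const double input_scale = 0.5, filter_scale = 0.25, output_scale = 1.0;
  const double real_multiplier = input_scale * filter_scale / output_scale;
  int32_t output_multiplier;
  int output_right_shift;
  tflite::QuantizeMultiplierSmallerThanOne(real_multiplier, &output_multiplier,
                                           &output_right_shift);
  // real_multiplier is 0.125, so output_multiplier holds 0.5 in Q0.31
  // (1 << 30) and output_right_shift is 2, since 0.5 * 2^-2 = 0.125.
}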
diff --git a/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc b/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc
new file mode 100644
index 0000000000..d6f306e2cb
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc
@@ -0,0 +1,108 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+namespace tflite {
+namespace {
+
+using ::testing::Pair;
+
+TEST(QuantizationUtilTest, QuantizeMultiplierSmallerThanOne) {
+ auto quantize = [](double d) {
+ int32_t q;
+ int s;
+ QuantizeMultiplierSmallerThanOne(d, &q, &s);
+ return std::pair<int32_t, int>{q, s};
+ };
+
+ EXPECT_DEATH(quantize(-0.1), "");
+ EXPECT_THAT(quantize(0.0), Pair(0, 0));
+ EXPECT_THAT(quantize(0.25), Pair(1073741824, 1));
+
+ // Around 0.5 we can see the change in exponent and how we try hard to
+  // avoid hitting max int32.
+ EXPECT_THAT(quantize(0.50 - 5e-9), Pair(2147483627, 1));
+ EXPECT_THAT(quantize(0.50 - 1e-10), Pair(1073741824, 0));
+ EXPECT_THAT(quantize(0.50), Pair(1073741824, 0));
+
+ EXPECT_THAT(quantize(0.75), Pair(1610612736, 0));
+ EXPECT_THAT(quantize(1 - 1e-9), Pair(2147483646, 0));
+
+ // If we get close enough to 1.0 it crashes and dies in one of two ways:
+ // Either the shift becomes negative or we trigger the 'less-than-one' CHECK.
+ EXPECT_DEATH(quantize(1 - 1e-15), "");
+ EXPECT_DEATH(quantize(1 - 1e-17), "");
+ EXPECT_DEATH(quantize(1.0), "");
+}
+
+TEST(QuantizationUtilTest, QuantizeMultiplierGreaterThanOne) {
+ auto quantize = [](double d) {
+ int32_t q;
+ int s;
+ QuantizeMultiplierGreaterThanOne(d, &q, &s);
+ return std::pair<int32_t, int>{q, s};
+ };
+
+ // If we are close enough to 1.0 it crashes.
+ EXPECT_DEATH(quantize(1 + 1e-16), "");
+
+ EXPECT_THAT(quantize(1 + 1e-11), Pair(1073741824, 1));
+ EXPECT_THAT(quantize(1.25), Pair(1342177280, 1));
+ EXPECT_THAT(quantize(1.50), Pair(1610612736, 1));
+ EXPECT_THAT(quantize(1.75), Pair(1879048192, 1));
+
+ // Around the powers of two we see the change in exponent. Also,
+ // we try hard to avoid hitting max int32.
+ EXPECT_THAT(quantize(2 - 1e-9), Pair(2147483647, 1));
+ EXPECT_THAT(quantize(2 - 1e-11), Pair(1073741824, 2));
+ EXPECT_THAT(quantize(2), Pair(1073741824, 2));
+}
+
+TEST(QuantizationUtilTest, PreprocessSoftmaxScaling) {
+ auto quantize = [](double beta, double scale, int integer_bits) {
+ int32_t q;
+ int s;
+ PreprocessSoftmaxScaling(beta, scale, integer_bits, &q, &s);
+ return std::pair<int32_t, int>{q, s};
+ };
+
+  // If beta * scale is greater than what fits in the number of integer bits,
+  // the result is moved near the maximum. Otherwise it quantizes as expected.
+ // With 4 integer bits we can represent up to 16.0.
+ EXPECT_THAT(quantize(1.0, 16.0, 4), Pair(2147483647, 31));
+ EXPECT_THAT(quantize(1.0, 8.0, 4), Pair(1073741824, 31));
+ // But with 5 bits we can go further.
+ EXPECT_THAT(quantize(2.0, 16.0, 5), Pair(2147483647, 31));
+ EXPECT_THAT(quantize(2.0, 8.0, 5), Pair(1073741824, 31));
+}
+
+TEST(QuantizationUtilTest, CalculateInputRadius) {
+ EXPECT_EQ(CalculateInputRadius(4, 27), 15);
+ EXPECT_EQ(CalculateInputRadius(3, 27), 14);
+ EXPECT_EQ(CalculateInputRadius(3, 28), 7);
+ EXPECT_EQ(CalculateInputRadius(4, 2), 503316480);
+}
+
+} // namespace
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ // On Linux, add: tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h b/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h
new file mode 100644
index 0000000000..8e0f234545
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h
@@ -0,0 +1,115 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_DEPTHWISECONV_FLOAT_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_DEPTHWISECONV_FLOAT_H_
+
+#include "tensorflow/contrib/lite/kernels/internal/common.h"
+#include "tensorflow/contrib/lite/kernels/internal/compatibility.h"
+#include "tensorflow/contrib/lite/kernels/internal/types.h"
+
+namespace tflite {
+namespace reference_ops {
+
+inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
+ const float* filter_data, const Dims<4>& filter_dims,
+ const float* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int depth_multiplier,
+ float output_activation_min,
+ float output_activation_max, float* output_data,
+ const Dims<4>& output_dims) {
+ const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
+ const int output_depth = MatchingArraySize(filter_dims, 0, output_dims, 0);
+ const int input_height = ArraySize(input_dims, 2);
+ const int input_width = ArraySize(input_dims, 1);
+ const int input_depth = ArraySize(input_dims, 0);
+ const int filter_height = ArraySize(filter_dims, 2);
+ const int filter_width = ArraySize(filter_dims, 1);
+ const int output_height = ArraySize(output_dims, 2);
+ const int output_width = ArraySize(output_dims, 1);
+ TFLITE_DCHECK(output_depth == input_depth * depth_multiplier);
+
+ for (int b = 0; b < batches; ++b) {
+ for (int out_y = 0; out_y < output_height; ++out_y) {
+ for (int out_x = 0; out_x < output_width; ++out_x) {
+ for (int ic = 0; ic < input_depth; ++ic) {
+ for (int m = 0; m < depth_multiplier; m++) {
+ const int oc = m + ic * depth_multiplier;
+ const int in_x_origin = (out_x * stride_width) - pad_width;
+ const int in_y_origin = (out_y * stride_height) - pad_height;
+ float total = 0.f;
+ for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
+ for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
+ const int in_x = in_x_origin + filter_x;
+ const int in_y = in_y_origin + filter_y;
+ // If the location is outside the bounds of the input image,
+ // use zero as a default value.
+ if ((in_x >= 0) && (in_x < input_width) && (in_y >= 0) &&
+ (in_y < input_height)) {
+ float input_value =
+ input_data[Offset(input_dims, ic, in_x, in_y, b)];
+ float filter_value = filter_data[Offset(
+ filter_dims, oc, filter_x, filter_y, 0)];
+ total += (input_value * filter_value);
+ }
+ }
+ }
+ float bias_value = 0.0f;
+ if (bias_data) {
+ bias_value = bias_data[Offset(bias_dims, oc, 0, 0, 0)];
+ }
+ output_data[Offset(output_dims, oc, out_x, out_y, b)] =
+ ActivationFunctionWithMinMax(total + bias_value,
+ output_activation_min,
+ output_activation_max);
+ }
+ }
+ }
+ }
+ }
+}
+
+// Legacy, for compatibility with old checked-in code.
+template <FusedActivationFunctionType Ac>
+void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
+ const float* filter_data, const Dims<4>& filter_dims,
+ const float* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int depth_multiplier, float* output_data,
+ const Dims<4>& output_dims) {
+ float output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+ DepthwiseConv(input_data, input_dims, filter_data, filter_dims, bias_data,
+ bias_dims, stride_width, stride_height, pad_width, pad_height,
+ depth_multiplier, output_activation_min, output_activation_max,
+ output_data, output_dims);
+}
+
+// Legacy, for compatibility with old checked-in code.
+template <FusedActivationFunctionType Ac>
+void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
+ const float* filter_data, const Dims<4>& filter_dims,
+ const float* bias_data, const Dims<4>& bias_dims, int stride,
+ int pad_width, int pad_height, int depth_multiplier,
+ float* output_data, const Dims<4>& output_dims) {
+ DepthwiseConv<Ac>(input_data, input_dims, filter_data, filter_dims, bias_data,
+ bias_dims, stride, stride, pad_width, pad_height,
+ depth_multiplier, output_data, output_dims);
+}
+
+} // end namespace reference_ops
+} // end namespace tflite
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_DEPTHWISECONV_FLOAT_H_
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h b/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h
new file mode 100644
index 0000000000..8a80558b32
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h
@@ -0,0 +1,138 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_DEPTHWISECONV_UINT8_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_DEPTHWISECONV_UINT8_H_
+
+#include <algorithm>
+
+#include "fixedpoint/fixedpoint.h"
+#include "public/gemmlowp.h"
+#include "tensorflow/contrib/lite/kernels/internal/common.h"
+#include "tensorflow/contrib/lite/kernels/internal/compatibility.h"
+#include "tensorflow/contrib/lite/kernels/internal/types.h"
+
+namespace tflite {
+namespace reference_ops {
+
+inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int depth_multiplier,
+ int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
+ const int output_depth = MatchingArraySize(filter_dims, 0, output_dims, 0);
+ const int input_height = ArraySize(input_dims, 2);
+ const int input_width = ArraySize(input_dims, 1);
+ const int input_depth = ArraySize(input_dims, 0);
+ const int filter_height = ArraySize(filter_dims, 2);
+ const int filter_width = ArraySize(filter_dims, 1);
+ const int output_height = ArraySize(output_dims, 2);
+ const int output_width = ArraySize(output_dims, 1);
+ TFLITE_DCHECK(output_depth == input_depth * depth_multiplier);
+
+ for (int b = 0; b < batches; ++b) {
+ for (int out_y = 0; out_y < output_height; ++out_y) {
+ for (int out_x = 0; out_x < output_width; ++out_x) {
+ for (int ic = 0; ic < input_depth; ++ic) {
+ for (int m = 0; m < depth_multiplier; m++) {
+ const int oc = m + ic * depth_multiplier;
+ const int in_x_origin = (out_x * stride_width) - pad_width;
+ const int in_y_origin = (out_y * stride_height) - pad_height;
+ int32 acc = 0;
+ for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
+ for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
+ const int in_x = in_x_origin + filter_x;
+ const int in_y = in_y_origin + filter_y;
+ // If the location is outside the bounds of the input image,
+ // use zero as a default value.
+ if ((in_x >= 0) && (in_x < input_width) && (in_y >= 0) &&
+ (in_y < input_height)) {
+ int32 input_val =
+ input_data[Offset(input_dims, ic, in_x, in_y, b)];
+ int32 filter_val = filter_data[Offset(filter_dims, oc,
+ filter_x, filter_y, 0)];
+ acc +=
+ (filter_val + filter_offset) * (input_val + input_offset);
+ }
+ }
+ }
+ if (bias_data) {
+ acc += bias_data[Offset(bias_dims, oc, 0, 0, 0)];
+ }
+ acc = MultiplyByQuantizedMultiplierSmallerThanOne(
+ acc, output_multiplier, output_shift);
+ acc += output_offset;
+ acc = std::max(acc, output_activation_min);
+ acc = std::min(acc, output_activation_max);
+ output_data[Offset(output_dims, oc, out_x, out_y, b)] =
+ static_cast<uint8>(acc);
+ }
+ }
+ }
+ }
+ }
+}
+
+// Legacy, for compatibility with old checked-in code.
+template <FusedActivationFunctionType Ac>
+void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int depth_multiplier, int32 output_offset,
+ int32 output_multiplier, int output_shift,
+ int32 output_activation_min, int32 output_activation_max,
+ uint8* output_data, const Dims<4>& output_dims) {
+ if (Ac == FusedActivationFunctionType::kNone) {
+ TFLITE_DCHECK_EQ(output_activation_min, 0);
+ TFLITE_DCHECK_EQ(output_activation_max, 255);
+ }
+ DepthwiseConv(input_data, input_dims, input_offset, filter_data, filter_dims,
+ filter_offset, bias_data, bias_dims, stride_width,
+ stride_height, pad_width, pad_height, depth_multiplier,
+ output_offset, output_multiplier, output_shift,
+ output_activation_min, output_activation_max, output_data,
+ output_dims);
+}
+
+// Legacy, for compatibility with old checked-in code.
+template <FusedActivationFunctionType Ac>
+void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims, int stride,
+ int pad_width, int pad_height, int depth_multiplier,
+ int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ DepthwiseConv<Ac>(input_data, input_dims, input_offset, filter_data,
+ filter_dims, filter_offset, bias_data, bias_dims, stride,
+ stride, pad_width, pad_height, depth_multiplier,
+ output_offset, output_multiplier, output_shift,
+ output_activation_min, output_activation_max, output_data,
+ output_dims);
+}
+
+} // end namespace reference_ops
+} // end namespace tflite
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_DEPTHWISECONV_UINT8_H_
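
Editor's note (not part of the patch): the uint8 kernel above accumulates in int32 and then scales each accumulator back into the uint8 output range via MultiplyByQuantizedMultiplierSmallerThanOne, the output offset, and the fused activation clamp. A minimal sketch of that requantization step, written with a plain real-valued multiplier instead of the gemmlowp fixed-point helpers; the function name and the floating-point rounding are illustrative assumptions only.

    #include <algorithm>
    #include <cmath>
    #include <cstdint>

    // Hypothetical stand-in for the requantization applied to each accumulator.
    inline std::uint8_t RequantizeSketch(std::int32_t acc, double real_multiplier,
                                         std::int32_t output_offset,
                                         std::int32_t activation_min,
                                         std::int32_t activation_max) {
      // Scale the int32 accumulator down to the uint8 output scale.
      std::int32_t scaled =
          static_cast<std::int32_t>(std::lround(acc * real_multiplier));
      scaled += output_offset;                    // re-center on the output zero point
      scaled = std::max(scaled, activation_min);  // fused activation clamp
      scaled = std::min(scaled, activation_max);
      return static_cast<std::uint8_t>(scaled);
    }
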
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc
new file mode 100644
index 0000000000..c5b0bccc9d
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc
@@ -0,0 +1,165 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <string.h>
+
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/kernels/activation_functor.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+
+namespace tflite {
+namespace tensor_utils {
+
+float PortableClip(float f, float abs_limit) {
+ float result = (abs_limit < f) ? abs_limit : f;
+ result = (-abs_limit > result) ? -abs_limit : result;
+ return result;
+}
+
+void PortableMatrixBatchVectorMultiplyAccumulate(const float* matrix,
+ int m_rows, int m_cols,
+ const float* vector,
+ int n_batch, float* result,
+ int result_stride) {
+ float* result_in_batch = result;
+ for (int b = 0; b < n_batch; b++) {
+ const float* matrix_ptr = matrix;
+ for (int r = 0; r < m_rows; r++) {
+ const float* vector_in_batch = vector + b * m_cols;
+ for (int c = 0; c < m_cols; c++) {
+ *result_in_batch += *matrix_ptr++ * *vector_in_batch++;
+ }
+ result_in_batch += result_stride;
+ }
+ }
+}
+
+void PortableVectorVectorCwiseProduct(const float* vector1,
+ const float* vector2, int v_size,
+ float* result) {
+ for (int v = 0; v < v_size; v++) {
+ *result++ = *vector1++ * *vector2++;
+ }
+}
+
+float PortableVectorVectorDotProduct(const float* vector1, const float* vector2,
+ int v_size) {
+ float result = 0.0;
+ for (int v = 0; v < v_size; v++) {
+ result += *vector1++ * *vector2++;
+ }
+ return result;
+}
+
+void PortableBatchVectorBatchVectorDotProduct(const float* vector1,
+ const float* vector2, int v_size,
+ int n_batch, float* result,
+ int result_stride) {
+ float* result_ptr = result;
+ const float* vector1_ptr = vector1;
+ const float* vector2_ptr = vector2;
+ for (int b = 0; b < n_batch; b++) {
+ *result_ptr =
+ PortableVectorVectorDotProduct(vector1_ptr, vector2_ptr, v_size);
+ vector1_ptr += v_size;
+ vector2_ptr += v_size;
+ result_ptr += result_stride;
+ }
+}
+
+void PortableVectorVectorCwiseProductAccumulate(const float* vector1,
+ const float* vector2,
+ int v_size, float* result) {
+ for (int v = 0; v < v_size; v++) {
+ *result++ += *vector1++ * *vector2++;
+ }
+}
+
+void PortableVectorBatchVectorCwiseProductAccumulate(const float* vector,
+ int v_size,
+ const float* batch_vector,
+ int n_batch,
+ float* result) {
+ for (int b = 0; b < n_batch; b++) {
+ for (int v = 0; v < v_size; v++) {
+ *result++ += vector[v] * *batch_vector++;
+ }
+ }
+}
+
+void PortableVectorBatchVectorAssign(const float* vector, int v_size,
+ int n_batch, float* batch_vector) {
+ for (int b = 0; b < n_batch; b++) {
+ memcpy(batch_vector + b * v_size, vector, v_size * sizeof(float));
+ }
+}
+
+void PortableApplySigmoidToVector(const float* vector, int v_size,
+ float* result) {
+ auto sigmoid_func = ActivationFunctor(kTfLiteActSigmoid);
+ for (int v = 0; v < v_size; v++) {
+ *result++ = (sigmoid_func)(*vector++);
+ }
+}
+
+void PortableApplyActivationToVector(const float* vector, int v_size,
+ TfLiteFusedActivation activation,
+ float* result) {
+ auto activation_func = ActivationFunctor(activation);
+ for (int v = 0; v < v_size; v++) {
+ *result++ = (activation_func)(*vector++);
+ }
+}
+
+void PortableCopyVector(const float* vector, int v_size, float* result) {
+ memcpy(result, vector, v_size * sizeof(float));
+}
+
+void PortableSub1Vector(const float* vector, int v_size, float* result) {
+ for (int v = 0; v < v_size; v++) {
+ *result++ = 1.0f - *vector++;
+ }
+}
+
+void PortableZeroVector(float* vector, int v_size) {
+ memset(vector, 0, v_size * sizeof(float));
+}
+
+void PortableClipVector(const float* vector, int v_size, float abs_limit,
+ float* result) {
+ for (int v = 0; v < v_size; v++) {
+ *result++ = PortableClip(*vector++, abs_limit);
+ }
+}
+
+void PortableVectorShiftLeft(float* vector, int v_size, float shift_value) {
+ TF_LITE_ASSERT(v_size > 0);
+ for (int i = 0; i < v_size - 1; i++) {
+ vector[i] = vector[i + 1];
+ }
+ vector[v_size - 1] = shift_value;
+}
+
+void PortableReductionSumVector(const float* input_vector, float* output_vector,
+ int output_size, int reduction_size) {
+ const float* input_vector_ptr = input_vector;
+ for (int o = 0; o < output_size; o++) {
+ for (int r = 0; r < reduction_size; r++) {
+ output_vector[o] += *input_vector_ptr++;
+ }
+ }
+}
+
+} // namespace tensor_utils
+} // namespace tflite
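
Editor's note (not part of the patch): a small usage sketch for the matrix/batch-vector routine above, assuming the corresponding header is included; the matrix, vector, and expected values are made up for illustration. The routine accumulates into result, so the destination must be initialized first.

    // 2x3 row-major matrix times one batch vector of length 3.
    float matrix[6] = {1.f, 2.f, 3.f,
                       4.f, 5.f, 6.f};
    float vector[3] = {1.f, 0.f, 2.f};
    float result[2] = {0.f, 0.f};  // must be zeroed: the call accumulates
    tflite::tensor_utils::PortableMatrixBatchVectorMultiplyAccumulate(
        matrix, /*m_rows=*/2, /*m_cols=*/3, vector, /*n_batch=*/1, result,
        /*result_stride=*/1);
    // result is now {1*1 + 2*0 + 3*2, 4*1 + 5*0 + 6*2} = {7, 16}.
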
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h
new file mode 100644
index 0000000000..c2ab78000b
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h
@@ -0,0 +1,189 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_PORTABLE_TENSOR_UTILS_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_PORTABLE_TENSOR_UTILS_H_
+
+// TODO(ghodrat): Remove this header file and the dependency on the internal
+// data structure.
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+
+namespace tflite {
+namespace tensor_utils {
+
+// Limit a float input f between -abs_limit and +abs_limit.
+float PortableClip(float f, float abs_limit);
+
+// Multiply a matrix by a batch vector, and store results in a batch-size
+// vector.
+void PortableMatrixBatchVectorMultiplyAccumulate(const float* matrix,
+ int m_rows, int m_cols,
+ const float* vector,
+ int n_batch, float* result,
+ int result_stride);
+
+// Cwise product of two vectors.
+void PortableVectorVectorCwiseProduct(const float* vector1,
+ const float* vector2, int v_size,
+ float* result);
+
+// Cwise product and accumulate of two vectors. Since it's a MAC operation,
+// the assumption here is that the result array is initialized to valid
+// values.
+void PortableVectorVectorCwiseProductAccumulate(const float* vector1,
+ const float* vector2,
+ int v_size, float* result);
+
+// Dot product of two vectors.
+float PortableVectorVectorDotProduct(const float* vector1, const float* vector2,
+ int v_size);
+
+// Dot product of two batch vectors.
+void PortableBatchVectorBatchVectorDotProduct(const float* vector1,
+ const float* vector2, int v_size,
+ int n_batch, float* result,
+ int result_stride);
+
+// Cwise product and accumulate of a vector and a batch-vector. Since it's a
+// MAC operation, the assumption here is that the result array is initialized
+// to valid values.
+void PortableVectorBatchVectorCwiseProductAccumulate(const float* vector,
+ int v_size,
+ const float* batch_vector,
+ int n_batch,
+ float* result);
+
+// Batch vector initialization with another vector.
+void PortableVectorBatchVectorAssign(const float* vector, int v_size,
+ int n_batch, float* batch_vector);
+
+// Apply sigmoid to elements of a vector.
+void PortableApplySigmoidToVector(const float* vector, int v_size,
+ float* result);
+
+// Apply activation function to elements of a vector.
+void PortableApplyActivationToVector(const float* vector, int v_size,
+ TfLiteFusedActivation activation,
+ float* result);
+
+// Copy vector to another vector.
+void PortableCopyVector(const float* vector, int v_size, float* result);
+
+// Compute "1.0f - elements of vector" (used in CIFG).
+void PortableSub1Vector(const float* vector, int v_size, float* result);
+
+// Fill vector with 0.f.
+void PortableZeroVector(float* vector, int v_size);
+
+// Clip elements of a vector using an abs_limit value.
+void PortableClipVector(const float* vector, int v_size, float abs_limit,
+ float* result);
+
+// Shift a vector of size v_size left by one position (in place), storing
+// shift_value into the last element.
+void PortableVectorShiftLeft(float* vector, int v_size, float shift_value);
+
+// Reduce-sum on a float input vector:
+// input_vector: float pointer to input vector.
+//   output_vector: float pointer to output vector.
+// output_size: output vector size.
+// reduction_size: number of consecutive elements from input vector which are
+// added to get one element of output.
+void PortableReductionSumVector(const float* input_vector, float* output_vector,
+ int output_size, int reduction_size);
+
+float Clip(float f, float abs_limit) { return PortableClip(f, abs_limit); }
+
+void MatrixBatchVectorMultiplyAccumulate(const float* matrix, int m_rows,
+ int m_cols, const float* vector,
+ int n_batch, float* result,
+ int result_stride) {
+ PortableMatrixBatchVectorMultiplyAccumulate(matrix, m_rows, m_cols, vector,
+ n_batch, result, result_stride);
+}
+
+void VectorVectorCwiseProduct(const float* vector1, const float* vector2,
+ int v_size, float* result) {
+ PortableVectorVectorCwiseProduct(vector1, vector2, v_size, result);
+}
+
+void VectorVectorCwiseProductAccumulate(const float* vector1,
+ const float* vector2, int v_size,
+ float* result) {
+ PortableVectorVectorCwiseProductAccumulate(vector1, vector2, v_size, result);
+}
+
+void VectorBatchVectorCwiseProductAccumulate(const float* vector, int v_size,
+ const float* batch_vector,
+ int n_batch, float* result) {
+ PortableVectorBatchVectorCwiseProductAccumulate(vector, v_size, batch_vector,
+ n_batch, result);
+}
+
+float VectorVectorDotProduct(const float* vector1, const float* vector2,
+ int v_size) {
+ return PortableVectorVectorDotProduct(vector1, vector2, v_size);
+}
+
+void BatchVectorBatchVectorDotProduct(const float* vector1,
+ const float* vector2, int v_size,
+ int n_batch, float* result,
+ int result_stride) {
+ PortableBatchVectorBatchVectorDotProduct(vector1, vector2, v_size, n_batch,
+ result, result_stride);
+}
+
+void VectorBatchVectorAssign(const float* vector, int v_size, int n_batch,
+ float* batch_vector) {
+ PortableVectorBatchVectorAssign(vector, v_size, n_batch, batch_vector);
+}
+
+void ApplySigmoidToVector(const float* vector, int v_size, float* result) {
+ PortableApplySigmoidToVector(vector, v_size, result);
+}
+
+void ApplyActivationToVector(const float* vector, int v_size,
+ TfLiteFusedActivation activation, float* result) {
+ PortableApplyActivationToVector(vector, v_size, activation, result);
+}
+
+void CopyVector(const float* vector, int v_size, float* result) {
+ PortableCopyVector(vector, v_size, result);
+}
+
+void Sub1Vector(const float* vector, int v_size, float* result) {
+ PortableSub1Vector(vector, v_size, result);
+}
+
+void ZeroVector(float* vector, int v_size) {
+ PortableZeroVector(vector, v_size);
+}
+
+void ClipVector(const float* vector, int v_size, float abs_limit,
+ float* result) {
+ PortableClipVector(vector, v_size, abs_limit, result);
+}
+
+void VectorShiftLeft(float* vector, int v_size, float shift_value) {
+ PortableVectorShiftLeft(vector, v_size, shift_value);
+}
+
+void ReductionSumVector(const float* input_vector, float* output_vector,
+ int output_size, int reduction_size) {
+ PortableReductionSumVector(input_vector, output_vector, output_size,
+ reduction_size);
+}
+
+} // namespace tensor_utils
+} // namespace tflite
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_PORTABLE_TENSOR_UTILS_H_
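
Editor's note (not part of the patch): the unprefixed wrappers above let kernel code call tensor_utils routines without naming a particular backend implementation. A tiny usage sketch; the input values are made up.

    float input[4] = {-2.f, -0.5f, 0.5f, 2.f};
    float output[4];
    tflite::tensor_utils::ClipVector(input, /*v_size=*/4, /*abs_limit=*/1.f, output);
    // output is now {-1.f, -0.5f, 0.5f, 1.f}.
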
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
new file mode 100644
index 0000000000..b9ca3d5c62
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
@@ -0,0 +1,2455 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_REFERENCE_OPS_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_REFERENCE_OPS_H_
+
+#include <stdint.h>
+#include <sys/types.h>
+#include <algorithm>
+#include <cmath>
+#include <limits>
+#include <memory>
+#include <type_traits>
+
+#include "third_party/eigen3/Eigen/Core"
+#include "fixedpoint/fixedpoint.h"
+#include "public/gemmlowp.h"
+#include "tensorflow/contrib/lite/kernels/internal/common.h"
+#include "tensorflow/contrib/lite/kernels/internal/round.h"
+#include "tensorflow/contrib/lite/kernels/internal/types.h"
+
+namespace tflite {
+namespace reference_ops {
+
+inline int32 MultiplyByQuantizedMultiplierSmallerThanOne(
+ int32 x, int32 quantized_multiplier, int right_shift) {
+ using gemmlowp::RoundingDivideByPOT;
+ using gemmlowp::SaturatingRoundingDoublingHighMul;
+ return RoundingDivideByPOT(
+ SaturatingRoundingDoublingHighMul(x, quantized_multiplier), right_shift);
+}
+
+inline int32 MultiplyByQuantizedMultiplierGreaterThanOne(
+ int32 x, int32 quantized_multiplier, int left_shift) {
+ using gemmlowp::SaturatingRoundingDoublingHighMul;
+ return SaturatingRoundingDoublingHighMul(x * (1 << left_shift),
+ quantized_multiplier);
+}
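
Editor's note (not part of the patch): MultiplyByQuantizedMultiplierSmallerThanOne computes approximately x * quantized_multiplier / 2^31 / 2^right_shift, i.e. a multiplication by a real scale in (0, 1) expressed as a Q31 fixed-point value plus a power-of-two shift. A sketch of how such a pair is typically derived from a real multiplier; the helper name is hypothetical and the converter-side code is not part of this file.

    #include <cmath>
    #include <cstdint>

    // Hypothetical helper: split real_multiplier in (0, 1) into a Q31 integer in
    // [2^30, 2^31) and a right shift, so that
    //   real_multiplier ~= quantized_multiplier / 2^31 / 2^right_shift.
    inline void QuantizeMultiplierSketch(double real_multiplier,
                                         std::int32_t* quantized_multiplier,
                                         int* right_shift) {
      int shift = 0;
      double m = real_multiplier;
      while (m < 0.5) {  // normalize into [0.5, 1)
        m *= 2.0;
        ++shift;
      }
      *quantized_multiplier =
          static_cast<std::int32_t>(std::lround(m * (1ll << 31)));
      *right_shift = shift;
      // e.g. 0.3 -> (1288490189, 1), so
      // MultiplyByQuantizedMultiplierSmallerThanOne(1000, 1288490189, 1) ~= 300.
    }
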
+
+template <typename T>
+int CountLeadingZeros(T integer_input) {
+ static_assert(std::is_unsigned<T>::value,
+ "Only unsigned integer types handled.");
+ const T one_in_leading_positive = static_cast<T>(1)
+ << (std::numeric_limits<T>::digits - 1);
+ int leading_zeros = 0;
+ while (integer_input < one_in_leading_positive) {
+ integer_input <<= 1;
+ ++leading_zeros;
+ }
+ return leading_zeros;
+}
+
+// DO NOT USE THIS STRUCT FOR NEW FUNCTIONALITY BEYOND IMPLEMENTING ELEMENT-WISE
+// BROADCASTING.
+//
+// NdArrayDesc<N> describes the shape and memory layout of an N-dimensional
+// rectangular array of numbers.
+//
+// NdArrayDesc<N> is basically identical to Dims<N> defined in types.h.
+// However, as Dims<N> is to be deprecated, this class exists as an adaptor
+// to enable simple unoptimized implementations of element-wise broadcasting
+// operations.
+template <int N>
+struct NdArrayDesc {
+ // The "extent" of each dimension. Indices along dimension d must be in the
+ // half-open interval [0, extents[d]).
+ int extents[N];
+
+ // The number of *elements* (not bytes) between consecutive indices of each
+ // dimension.
+ int strides[N];
+};
+
+// DO NOT USE THIS FUNCTION FOR NEW FUNCTIONALITY BEYOND IMPLEMENTING
+// ELEMENT-WISE BROADCASTING.
+//
+// Same as Offset(), except it takes an NdArrayDesc<N> instead of a Dims<N>.
+inline int SubscriptToIndex(const NdArrayDesc<4>& desc, int i0, int i1, int i2,
+ int i3) {
+ TFLITE_DCHECK(i0 >= 0 && i0 < desc.extents[0]);
+ TFLITE_DCHECK(i1 >= 0 && i1 < desc.extents[1]);
+ TFLITE_DCHECK(i2 >= 0 && i2 < desc.extents[2]);
+ TFLITE_DCHECK(i3 >= 0 && i3 < desc.extents[3]);
+ return i0 * desc.strides[0] + i1 * desc.strides[1] + i2 * desc.strides[2] +
+ i3 * desc.strides[3];
+}
+
+// Given the dimensions of the operands for an element-wise binary broadcast,
+// adjusts them so that they can be directly iterated over with simple loops.
+// Returns the adjusted dims as instances of NdArrayDesc in 'desc0_out' and
+// 'desc1_out'. 'desc0_out' and 'desc1_out' cannot be nullptr.
+//
+// This function assumes that the two input shapes are compatible up to
+// broadcasting and the shorter one has already been prepended with 1s to be the
+// same length. E.g., if shape0 is (1, 16, 16, 64) and shape1 is (1, 64),
+// shape1 must already have been prepended to be (1, 1, 1, 64). Recall that
+// Dims<N> refer to shapes in reverse order. In this case, input0_dims will be
+// (64, 16, 16, 1) and input1_dims will be (64, 1, 1, 1).
+//
+// When two shapes are compatible up to broadcasting, for each dimension d,
+// the input extents are either equal, or one of them is 1.
+//
+// This function performs the following for each dimension d:
+// - If the extents are equal, then do nothing since the loop that walks over
+// both of the input arrays is correct.
+// - Otherwise, one (and only one) of the extents must be 1. Say extent0 is 1
+// and extent1 is e1. Then set extent0 to e1 and stride0 *to 0*. This allows
+// array0 to be referenced *at any index* in dimension d and still access the
+// same slice.
+template <int N>
+inline void NdArrayDescsForElementwiseBroadcast(const Dims<N>& input0_dims,
+ const Dims<N>& input1_dims,
+ NdArrayDesc<N>* desc0_out,
+ NdArrayDesc<N>* desc1_out) {
+ TFLITE_DCHECK(desc0_out != nullptr);
+ TFLITE_DCHECK(desc1_out != nullptr);
+
+ // Copy dims to desc.
+ for (int i = 0; i < N; ++i) {
+ desc0_out->extents[i] = input0_dims.sizes[i];
+ desc0_out->strides[i] = input0_dims.strides[i];
+ desc1_out->extents[i] = input1_dims.sizes[i];
+ desc1_out->strides[i] = input1_dims.strides[i];
+ }
+
+ // Walk over each dimension. If the extents are equal do nothing.
+ // Otherwise, set the desc with extent 1 to have extent equal to the other and
+ // stride 0.
+ for (int i = 0; i < N; ++i) {
+ const int extent0 = ArraySize(input0_dims, i);
+ const int extent1 = ArraySize(input1_dims, i);
+ if (extent0 != extent1) {
+ if (extent0 == 1) {
+ desc0_out->strides[i] = 0;
+ desc0_out->extents[i] = extent1;
+ } else {
+ TFLITE_DCHECK_EQ(extent1, 1);
+ desc1_out->strides[i] = 0;
+ desc1_out->extents[i] = extent0;
+ }
+ }
+ }
+}
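
Editor's note (not part of the patch): a concrete sketch of the broadcast setup above, broadcasting a length-64 vector (e.g. a per-channel term) against a batch-1, 16x16, depth-64 array. The Dims construction below is illustrative.

    tflite::Dims<4> act_dims, vec_dims;
    const int act_sizes[4] = {64, 16, 16, 1};  // depth, width, height, batch (reversed)
    const int vec_sizes[4] = {64, 1, 1, 1};
    int act_stride = 1, vec_stride = 1;
    for (int i = 0; i < 4; ++i) {
      act_dims.sizes[i] = act_sizes[i];
      act_dims.strides[i] = act_stride;
      act_stride *= act_sizes[i];
      vec_dims.sizes[i] = vec_sizes[i];
      vec_dims.strides[i] = vec_stride;
      vec_stride *= vec_sizes[i];
    }
    tflite::reference_ops::NdArrayDesc<4> act_desc, vec_desc;
    tflite::reference_ops::NdArrayDescsForElementwiseBroadcast(
        act_dims, vec_dims, &act_desc, &vec_desc);
    // vec_desc now has extent 16 but stride 0 along the width and height
    // dimensions, so SubscriptToIndex(vec_desc, c, x, y, 0) == c for every
    // (x, y): all spatial positions read the same per-channel element.
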
+
+inline void Conv(const float* input_data, const Dims<4>& input_dims,
+ const float* filter_data, const Dims<4>& filter_dims,
+ const float* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, float output_activation_min,
+ float output_activation_max, float* output_data,
+ const Dims<4>& output_dims, float* im2col_data,
+ const Dims<4>& im2col_dims) {
+ (void)im2col_data; // only used in optimized code.
+ (void)im2col_dims; // only used in optimized code.
+ const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
+ const int input_depth = MatchingArraySize(input_dims, 0, filter_dims, 0);
+ const int output_depth = MatchingArraySize(filter_dims, 3, output_dims, 0);
+ if (bias_data) {
+ TFLITE_DCHECK_EQ(ArraySize(filter_dims, 3), ArraySize(bias_dims, 0));
+ }
+ const int input_height = ArraySize(input_dims, 2);
+ const int input_width = ArraySize(input_dims, 1);
+ const int filter_height = ArraySize(filter_dims, 2);
+ const int filter_width = ArraySize(filter_dims, 1);
+ const int output_height = ArraySize(output_dims, 2);
+ const int output_width = ArraySize(output_dims, 1);
+ for (int batch = 0; batch < batches; ++batch) {
+ for (int out_y = 0; out_y < output_height; ++out_y) {
+ for (int out_x = 0; out_x < output_width; ++out_x) {
+ for (int out_channel = 0; out_channel < output_depth; ++out_channel) {
+ const int in_x_origin = (out_x * stride_width) - pad_width;
+ const int in_y_origin = (out_y * stride_height) - pad_height;
+ float total = 0.f;
+ for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
+ for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
+ for (int in_channel = 0; in_channel < input_depth; ++in_channel) {
+ const int in_x = in_x_origin + filter_x;
+ const int in_y = in_y_origin + filter_y;
+ // If the location is outside the bounds of the input image,
+ // use zero as a default value.
+ if ((in_x >= 0) && (in_x < input_width) && (in_y >= 0) &&
+ (in_y < input_height)) {
+ float input_value = input_data[Offset(input_dims, in_channel,
+ in_x, in_y, batch)];
+ float filter_value =
+ filter_data[Offset(filter_dims, in_channel, filter_x,
+ filter_y, out_channel)];
+ total += (input_value * filter_value);
+ }
+ }
+ }
+ }
+ float bias_value = 0.0f;
+ if (bias_data) {
+ bias_value = bias_data[Offset(bias_dims, out_channel, 0, 0, 0)];
+ }
+ output_data[Offset(output_dims, out_channel, out_x, out_y, batch)] =
+ ActivationFunctionWithMinMax(total + bias_value,
+ output_activation_min,
+ output_activation_max);
+ }
+ }
+ }
+ }
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void Conv(const float* input_data, const Dims<4>& input_dims,
+ const float* filter_data, const Dims<4>& filter_dims,
+ const float* bias_data, const Dims<4>& bias_dims, int stride_width,
+ int stride_height, int pad_width, int pad_height, float* output_data,
+ const Dims<4>& output_dims, float* im2col_data,
+ const Dims<4>& im2col_dims) {
+ float output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+ Conv(input_data, input_dims, filter_data, filter_dims, bias_data, bias_dims,
+ stride_width, stride_height, pad_width, pad_height,
+ output_activation_min, output_activation_max, output_data, output_dims,
+ im2col_data, im2col_dims);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void Conv(const float* input_data, const Dims<4>& input_dims,
+ const float* filter_data, const Dims<4>& filter_dims,
+ const float* bias_data, const Dims<4>& bias_dims, int stride,
+ int pad_width, int pad_height, float* output_data,
+ const Dims<4>& output_dims, float* im2col_data,
+ const Dims<4>& im2col_dims) {
+ Conv<Ac>(input_data, input_dims, filter_data, filter_dims, bias_data,
+ bias_dims, stride, stride, pad_width, pad_height, output_data,
+ output_dims, im2col_data, im2col_dims);
+}
+
+inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims, uint8* im2col_data,
+ const Dims<4>& im2col_dims,
+ gemmlowp::GemmContext* gemm_context) {
+ (void)im2col_data; // only used in optimized code.
+ (void)im2col_dims; // only used in optimized code.
+ (void)gemm_context; // only used in optimized code.
+ TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
+ const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
+ const int input_depth = MatchingArraySize(input_dims, 0, filter_dims, 0);
+ const int output_depth =
+ MatchingArraySize(filter_dims, 3, bias_dims, 0, output_dims, 0);
+ const int input_height = ArraySize(input_dims, 2);
+ const int input_width = ArraySize(input_dims, 1);
+ const int filter_height = ArraySize(filter_dims, 2);
+ const int filter_width = ArraySize(filter_dims, 1);
+ const int output_height = ArraySize(output_dims, 2);
+ const int output_width = ArraySize(output_dims, 1);
+ for (int batch = 0; batch < batches; ++batch) {
+ for (int out_y = 0; out_y < output_height; ++out_y) {
+ for (int out_x = 0; out_x < output_width; ++out_x) {
+ for (int out_channel = 0; out_channel < output_depth; ++out_channel) {
+ const int in_x_origin = (out_x * stride_width) - pad_width;
+ const int in_y_origin = (out_y * stride_height) - pad_height;
+ int32 acc = 0;
+ for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
+ for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
+ for (int in_channel = 0; in_channel < input_depth; ++in_channel) {
+ const int in_x = in_x_origin + filter_x;
+ const int in_y = in_y_origin + filter_y;
+ // If the location is outside the bounds of the input image,
+ // use zero as a default value.
+ if ((in_x >= 0) && (in_x < input_width) && (in_y >= 0) &&
+ (in_y < input_height)) {
+ int32 input_val = input_data[Offset(input_dims, in_channel,
+ in_x, in_y, batch)];
+ int32 filter_val =
+ filter_data[Offset(filter_dims, in_channel, filter_x,
+ filter_y, out_channel)];
+ acc +=
+ (filter_val + filter_offset) * (input_val + input_offset);
+ }
+ }
+ }
+ }
+ if (bias_data) {
+ acc += bias_data[Offset(bias_dims, out_channel, 0, 0, 0)];
+ }
+ acc = MultiplyByQuantizedMultiplierSmallerThanOne(
+ acc, output_multiplier, output_shift);
+ acc += output_offset;
+ acc = std::max(acc, output_activation_min);
+ acc = std::min(acc, output_activation_max);
+ output_data[Offset(output_dims, out_channel, out_x, out_y, batch)] =
+ static_cast<uint8>(acc);
+ }
+ }
+ }
+ }
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims, uint8* im2col_data,
+ const Dims<4>& im2col_dims,
+ gemmlowp::GemmContext* gemm_context) {
+ static_assert(Ac == FusedActivationFunctionType::kNone ||
+ Ac == FusedActivationFunctionType::kRelu ||
+ Ac == FusedActivationFunctionType::kRelu6 ||
+ Ac == FusedActivationFunctionType::kRelu1,
+ "");
+ if (Ac == FusedActivationFunctionType::kNone) {
+ TFLITE_DCHECK_EQ(output_activation_min, 0);
+ TFLITE_DCHECK_EQ(output_activation_max, 255);
+ }
+ Conv(input_data, input_dims, input_offset, filter_data, filter_dims,
+ filter_offset, bias_data, bias_dims, stride_width, stride_height,
+ pad_width, pad_height, output_offset, output_multiplier, output_shift,
+ output_activation_min, output_activation_max, output_data, output_dims,
+ im2col_data, im2col_dims, gemm_context);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void Conv(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims, int stride,
+ int pad_width, int pad_height, int32 output_offset,
+ int32 output_multiplier, int output_shift,
+ int32 output_activation_min, int32 output_activation_max,
+ uint8* output_data, const Dims<4>& output_dims, uint8* im2col_data,
+ const Dims<4>& im2col_dims, gemmlowp::GemmContext* gemm_context) {
+ Conv<Ac>(input_data, input_dims, input_offset, filter_data, filter_dims,
+ filter_offset, bias_data, bias_dims, stride, stride, pad_width,
+ pad_height, output_offset, output_multiplier, output_shift,
+ output_activation_min, output_activation_max, output_data,
+ output_dims, im2col_data, im2col_dims, gemm_context);
+}
+
+template <typename T>
+inline void DepthToSpace(const T* input_data, const Dims<4>& input_dims,
+ int block_size, T* output_data,
+ const Dims<4>& output_dims) {
+ const int input_depth = ArraySize(input_dims, 0);
+ const int input_width = ArraySize(input_dims, 1);
+ const int input_height = ArraySize(input_dims, 2);
+ const int input_batch = ArraySize(input_dims, 3);
+
+ const int output_depth = ArraySize(output_dims, 0);
+ const int output_width = ArraySize(output_dims, 1);
+ const int output_height = ArraySize(output_dims, 2);
+ const int output_batch = ArraySize(output_dims, 3);
+
+ TFLITE_DCHECK_EQ(input_width * block_size, output_width);
+ TFLITE_DCHECK_EQ(input_height * block_size, output_height);
+ TFLITE_DCHECK_EQ(input_depth, output_depth * block_size * block_size);
+ TFLITE_DCHECK_EQ(input_batch, output_batch);
+
+ for (int out_b = 0; out_b < output_batch; ++out_b) {
+ for (int out_h = 0; out_h < output_height; ++out_h) {
+ for (int out_w = 0; out_w < output_width; ++out_w) {
+ for (int out_d = 0; out_d < output_depth; ++out_d) {
+ const int in_d =
+ out_d + ((out_h % block_size) * block_size + out_w % block_size) *
+ output_depth;
+ const int in_w = out_w / block_size;
+ const int in_h = out_h / block_size;
+ const int in_b = out_b;
+
+ const int output_index =
+ Offset(output_dims, out_d, out_w, out_h, out_b);
+ const int input_index = Offset(input_dims, in_d, in_w, in_h, in_b);
+
+ output_data[output_index] = input_data[input_index];
+ }
+ }
+ }
+ }
+}
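
Editor's note (not part of the patch): a worked example of the index mapping above. DepthToSpace with block_size = 2 on a single 1x1 pixel with 4 channels unfolds channel c into spatial position (h = c / 2, w = c % 2). The Dims construction and values below are illustrative.

    tflite::Dims<4> in_dims, out_dims;
    const int in_sizes[4] = {4, 1, 1, 1};   // depth, width, height, batch (reversed)
    const int out_sizes[4] = {1, 2, 2, 1};
    int in_stride = 1, out_stride = 1;
    for (int i = 0; i < 4; ++i) {
      in_dims.sizes[i] = in_sizes[i];
      in_dims.strides[i] = in_stride;
      in_stride *= in_sizes[i];
      out_dims.sizes[i] = out_sizes[i];
      out_dims.strides[i] = out_stride;
      out_stride *= out_sizes[i];
    }
    const float input[4] = {10.f, 11.f, 12.f, 13.f};
    float output[4];
    tflite::reference_ops::DepthToSpace(input, in_dims, /*block_size=*/2,
                                        output, out_dims);
    // output is {10, 11, 12, 13}, i.e. the 2x2 spatial block
    //   [10 11]
    //   [12 13]
    // with a single remaining channel.
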
+
+template <typename T>
+inline void SpaceToDepth(const T* input_data, const Dims<4>& input_dims,
+ int block_size, T* output_data,
+ const Dims<4>& output_dims) {
+ const int input_depth = ArraySize(input_dims, 0);
+ const int input_width = ArraySize(input_dims, 1);
+ const int input_height = ArraySize(input_dims, 2);
+ const int input_batch = ArraySize(input_dims, 3);
+
+ const int output_depth = ArraySize(output_dims, 0);
+ const int output_width = ArraySize(output_dims, 1);
+ const int output_height = ArraySize(output_dims, 2);
+ const int output_batch = ArraySize(output_dims, 3);
+
+ TFLITE_DCHECK_EQ(input_width, output_width * block_size);
+ TFLITE_DCHECK_EQ(input_height, output_height * block_size);
+ TFLITE_DCHECK_EQ(input_depth * block_size * block_size, output_depth);
+ TFLITE_DCHECK_EQ(input_batch, output_batch);
+
+ for (int in_b = 0; in_b < input_batch; ++in_b) {
+ for (int in_h = 0; in_h < input_height; ++in_h) {
+ for (int in_w = 0; in_w < input_width; ++in_w) {
+ for (int in_d = 0; in_d < input_depth; ++in_d) {
+ const int out_d =
+ in_d + ((in_h % block_size) * block_size + in_w % block_size) *
+ input_depth;
+ const int out_w = in_w / block_size;
+ const int out_h = in_h / block_size;
+ const int out_b = in_b;
+
+ const int output_index =
+ Offset(output_dims, out_d, out_w, out_h, out_b);
+ const int input_index = Offset(input_dims, in_d, in_w, in_h, in_b);
+
+ output_data[output_index] = input_data[input_index];
+ }
+ }
+ }
+ }
+}
+
+inline void FullyConnected(const float* input_data, const Dims<4>& input_dims,
+ const float* weights_data,
+ const Dims<4>& weights_dims, const float* bias_data,
+ const Dims<4>& bias_dims,
+ float output_activation_min,
+ float output_activation_max, float* output_data,
+ const Dims<4>& output_dims) {
+ // TODO(benoitjacob): This really should be:
+ // const int batches = ArraySize(output_dims, 1);
+ // but the current --variable_batch hack consists in overwriting the 3rd
+ // dimension with the runtime batch size, as we don't keep track for each
+ // array of which dimension is the batch dimension in it.
+ const int batches = ArraySize(output_dims, 1) * ArraySize(output_dims, 2) *
+ ArraySize(output_dims, 3);
+ const int output_depth = MatchingArraySize(weights_dims, 1, output_dims, 0);
+ const int accum_depth = ArraySize(weights_dims, 0);
+ TFLITE_DCHECK(IsPackedWithoutStrides(input_dims));
+ TFLITE_DCHECK(IsPackedWithoutStrides(weights_dims));
+ for (int b = 0; b < batches; ++b) {
+ for (int out_c = 0; out_c < output_depth; ++out_c) {
+ float total = 0.f;
+ for (int d = 0; d < accum_depth; ++d) {
+ total += input_data[b * accum_depth + d] *
+ weights_data[out_c * accum_depth + d];
+ }
+ float bias_value = 0.0f;
+ if (bias_data) {
+ bias_value = bias_data[Offset(bias_dims, out_c, 0, 0, 0)];
+ }
+ output_data[out_c + output_depth * b] = ActivationFunctionWithMinMax(
+ total + bias_value, output_activation_min, output_activation_max);
+ }
+ }
+}
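
Editor's note (not part of the patch): a small usage sketch of the float FullyConnected above: one batch, two inputs, three outputs, with the fused activation clamping results to [0, 100]. All buffers and dims below are illustrative.

    // Fill packed (stride-contiguous) Dims<4> for input, weights, bias, output.
    tflite::Dims<4> in_dims, w_dims, b_dims, out_dims;
    const int in_sizes[4] = {2, 1, 1, 1};   // accum_depth, then batch dims
    const int w_sizes[4] = {2, 3, 1, 1};    // accum_depth, output_depth
    const int b_sizes[4] = {3, 1, 1, 1};
    const int out_sizes[4] = {3, 1, 1, 1};  // output_depth, then batch dims
    tflite::Dims<4>* all_dims[4] = {&in_dims, &w_dims, &b_dims, &out_dims};
    const int* all_sizes[4] = {in_sizes, w_sizes, b_sizes, out_sizes};
    for (int d = 0; d < 4; ++d) {
      int stride = 1;
      for (int i = 0; i < 4; ++i) {
        all_dims[d]->sizes[i] = all_sizes[d][i];
        all_dims[d]->strides[i] = stride;
        stride *= all_sizes[d][i];
      }
    }
    const float input[2] = {1.f, 10.f};
    const float weights[6] = {1.f, 2.f,   // row for output 0
                              3.f, 4.f,   // row for output 1
                              5.f, 6.f};  // row for output 2
    const float bias[3] = {0.5f, 0.f, -100.f};
    float output[3];
    tflite::reference_ops::FullyConnected(
        input, in_dims, weights, w_dims, bias, b_dims,
        /*output_activation_min=*/0.f, /*output_activation_max=*/100.f,
        output, out_dims);
    // output is {21.5, 43, 0}: the third accumulator (5 + 60 - 100 = -35) is
    // clamped to the activation minimum.
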
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void FullyConnected(const float* input_data, const Dims<4>& input_dims,
+ const float* weights_data, const Dims<4>& weights_dims,
+ const float* bias_data, const Dims<4>& bias_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ float output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+ FullyConnected(input_data, input_dims, weights_data, weights_dims, bias_data,
+ bias_dims, output_activation_min, output_activation_max,
+ output_data, output_dims);
+}
+
+inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims,
+ int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims,
+ gemmlowp::GemmContext* gemm_context) {
+ (void)gemm_context; // only used in optimized code.
+ TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
+ // TODO(benoitjacob): This really should be:
+ // const int batches = ArraySize(output_dims, 1);
+ // but the current --variable_batch hack consists in overwriting the 3rd
+ // dimension with the runtime batch size, as we don't keep track for each
+ // array of which dimension is the batch dimension in it.
+ const int batches = ArraySize(output_dims, 1) * ArraySize(output_dims, 2) *
+ ArraySize(output_dims, 3);
+ const int output_depth = MatchingArraySize(filter_dims, 1, output_dims, 0);
+ const int accum_depth = ArraySize(filter_dims, 0);
+ TFLITE_DCHECK(IsPackedWithoutStrides(input_dims));
+ TFLITE_DCHECK(IsPackedWithoutStrides(filter_dims));
+ for (int b = 0; b < batches; ++b) {
+ for (int out_c = 0; out_c < output_depth; ++out_c) {
+ int32 acc = 0;
+ for (int d = 0; d < accum_depth; ++d) {
+ int32 input_val = input_data[b * accum_depth + d];
+ int32 filter_val = filter_data[out_c * accum_depth + d];
+ acc += (filter_val + filter_offset) * (input_val + input_offset);
+ }
+ if (bias_data) {
+ acc += bias_data[Offset(bias_dims, out_c, 0, 0, 0)];
+ }
+ acc = MultiplyByQuantizedMultiplierSmallerThanOne(acc, output_multiplier,
+ output_shift);
+ acc += output_offset;
+ acc = std::max(acc, output_activation_min);
+ acc = std::min(acc, output_activation_max);
+ output_data[out_c + output_depth * b] = static_cast<uint8>(acc);
+ }
+ }
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims,
+ int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims,
+ gemmlowp::GemmContext* gemm_context) {
+ static_assert(Ac == FusedActivationFunctionType::kNone ||
+ Ac == FusedActivationFunctionType::kRelu ||
+ Ac == FusedActivationFunctionType::kRelu6 ||
+ Ac == FusedActivationFunctionType::kRelu1,
+ "");
+ if (Ac == FusedActivationFunctionType::kNone) {
+ TFLITE_DCHECK_EQ(output_activation_min, 0);
+ TFLITE_DCHECK_EQ(output_activation_max, 255);
+ }
+ FullyConnected(input_data, input_dims, input_offset, filter_data, filter_dims,
+ filter_offset, bias_data, bias_dims, output_offset,
+ output_multiplier, output_shift, output_activation_min,
+ output_activation_max, output_data, output_dims, gemm_context);
+}
+
+template <FusedActivationFunctionType Ac>
+void NonGlobalBatchNormalization(
+ const float* input_data, const Dims<4>& input_dims, const float* mean_data,
+ const Dims<4>& mean_dims, const float* multiplier_data,
+ const Dims<4>& multiplier_dims, const float* offset_data,
+ const Dims<4>& offset_dims, float* output_data,
+ const Dims<4>& output_dims) {
+ const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
+ const int height =
+ MatchingArraySize(input_dims, 2, mean_dims, 2, multiplier_dims, 2,
+ offset_dims, 2, output_dims, 2);
+ const int width =
+ MatchingArraySize(input_dims, 1, mean_dims, 1, multiplier_dims, 1,
+ offset_dims, 1, output_dims, 1);
+ const int depth =
+ MatchingArraySize(input_dims, 0, mean_dims, 0, multiplier_dims, 0,
+ offset_dims, 0, output_dims, 0);
+
+ for (int b = 0; b < batches; ++b) {
+ for (int y = 0; y < height; ++y) {
+ for (int x = 0; x < width; ++x) {
+ for (int c = 0; c < depth; ++c) {
+ output_data[Offset(output_dims, c, x, y, b)] = ActivationFunction<Ac>(
+ (input_data[Offset(input_dims, c, x, y, b)] -
+ mean_data[Offset(mean_dims, c, x, y, 0)]) *
+ multiplier_data[Offset(multiplier_dims, c, x, y, 0)] +
+ offset_data[Offset(offset_dims, c, x, y, 0)]);
+ }
+ }
+ }
+ }
+}
+
+template <FusedActivationFunctionType Ac>
+void GlobalBatchNormalization(const float* input_data,
+ const Dims<4>& input_dims, const float* mean_data,
+ const Dims<4>& mean_dims,
+ const float* multiplier_data,
+ const Dims<4>& multiplier_dims,
+ const float* offset_data,
+ const Dims<4>& offset_dims, float* output_data,
+ const Dims<4>& output_dims) {
+ const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
+ const int height = MatchingArraySize(input_dims, 2, output_dims, 2);
+ const int width = MatchingArraySize(input_dims, 1, output_dims, 1);
+ const int depth =
+ MatchingArraySize(input_dims, 0, mean_dims, 0, multiplier_dims, 0,
+ offset_dims, 0, output_dims, 0);
+
+ for (int b = 0; b < batches; ++b) {
+ for (int y = 0; y < height; ++y) {
+ for (int x = 0; x < width; ++x) {
+ for (int c = 0; c < depth; ++c) {
+ output_data[Offset(output_dims, c, x, y, b)] = ActivationFunction<Ac>(
+ (input_data[Offset(input_dims, c, x, y, b)] -
+ mean_data[Offset(mean_dims, c, 0, 0, 0)]) *
+ multiplier_data[Offset(multiplier_dims, c, 0, 0, 0)] +
+ offset_data[Offset(offset_dims, c, 0, 0, 0)]);
+ }
+ }
+ }
+ }
+}
+
+inline void Relu(const float* input_data, const Dims<4>& input_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
+ const int height = MatchingArraySize(input_dims, 2, output_dims, 2);
+ const int width = MatchingArraySize(input_dims, 1, output_dims, 1);
+ const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
+ for (int b = 0; b < batches; ++b) {
+ for (int y = 0; y < height; ++y) {
+ for (int x = 0; x < width; ++x) {
+ for (int c = 0; c < depth; ++c) {
+ float val = input_data[Offset(input_dims, c, x, y, b)];
+ const float lower = 0;
+ float clamped = val < lower ? lower : val;
+ output_data[Offset(output_dims, c, x, y, b)] = clamped;
+ }
+ }
+ }
+ }
+}
+
+inline void Relu1(const float* input_data, const Dims<4>& input_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
+ const int height = MatchingArraySize(input_dims, 2, output_dims, 2);
+ const int width = MatchingArraySize(input_dims, 1, output_dims, 1);
+ const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
+ for (int b = 0; b < batches; ++b) {
+ for (int y = 0; y < height; ++y) {
+ for (int x = 0; x < width; ++x) {
+ for (int c = 0; c < depth; ++c) {
+ float val = input_data[Offset(input_dims, c, x, y, b)];
+ const float upper = 1;
+ const float lower = -1;
+ float clamped = val > upper ? upper : val < lower ? lower : val;
+ output_data[Offset(output_dims, c, x, y, b)] = clamped;
+ }
+ }
+ }
+ }
+}
+
+inline void Relu6(const float* input_data, const Dims<4>& input_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
+ const int height = MatchingArraySize(input_dims, 2, output_dims, 2);
+ const int width = MatchingArraySize(input_dims, 1, output_dims, 1);
+ const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
+ for (int b = 0; b < batches; ++b) {
+ for (int y = 0; y < height; ++y) {
+ for (int x = 0; x < width; ++x) {
+ for (int c = 0; c < depth; ++c) {
+ float val = input_data[Offset(input_dims, c, x, y, b)];
+ const float upper = 6;
+ const float lower = 0;
+ float clamped = val > upper ? upper : val < lower ? lower : val;
+ output_data[Offset(output_dims, c, x, y, b)] = clamped;
+ }
+ }
+ }
+ }
+}
+
+template <FusedActivationFunctionType Ac>
+void L2Normalization(const float* input_data, const Dims<4>& input_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ static_assert(Ac == FusedActivationFunctionType::kNone, "");
+ const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
+ const int height = MatchingArraySize(input_dims, 2, output_dims, 2);
+ const int width = MatchingArraySize(input_dims, 1, output_dims, 1);
+ const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
+ for (int b = 0; b < batches; ++b) {
+ for (int y = 0; y < height; ++y) {
+ for (int x = 0; x < width; ++x) {
+ float squared_l2_norm = 0;
+ for (int c = 0; c < depth; ++c) {
+ float val = input_data[Offset(input_dims, c, x, y, b)];
+ squared_l2_norm += val * val;
+ }
+ float l2_norm = std::sqrt(squared_l2_norm);
+ for (int c = 0; c < depth; ++c) {
+ output_data[Offset(output_dims, c, x, y, b)] =
+ input_data[Offset(input_dims, c, x, y, b)] / l2_norm;
+ }
+ }
+ }
+ }
+}
+
+inline void GetInvSqrtQuantizedMultiplier(int32 input, int32* output_inv_sqrt,
+ int* output_shift) {
+ *output_shift = 11;
+ while (input >= (1 << 29)) {
+ input /= 4;
+ ++*output_shift;
+ }
+ TFLITE_DCHECK_GT(input, 0);
+ const unsigned max_left_shift_bits = __builtin_clz(input) - 1;
+ const unsigned max_left_shift_bit_pairs = max_left_shift_bits / 2;
+ const unsigned left_shift_bit_pairs = max_left_shift_bit_pairs - 1;
+ *output_shift -= left_shift_bit_pairs;
+ input <<= 2 * left_shift_bit_pairs;
+ TFLITE_DCHECK_GE(input, (1 << 27));
+ TFLITE_DCHECK_LT(input, (1 << 29));
+ using gemmlowp::FixedPoint;
+ using gemmlowp::Rescale;
+ using gemmlowp::SaturatingRoundingMultiplyByPOT;
+ // Using 3 integer bits gives us enough room for the internal arithmetic in
+ // this Newton-Raphson iteration.
+ using F3 = FixedPoint<int32, 3>;
+ using F0 = FixedPoint<int32, 0>;
+ const F3 fixedpoint_input = F3::FromRaw(input >> 1);
+ const F3 fixedpoint_half_input =
+ SaturatingRoundingMultiplyByPOT<-1>(fixedpoint_input);
+ const F3 fixedpoint_half_three =
+ GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F3, (1 << 28) + (1 << 27), 1.5);
+ // Newton-Raphson iteration
+ // Naive unoptimized starting guess: x = 1
+ F3 x = F3::One();
+ // Naive unoptimized number of iterations: 5
+ for (int i = 0; i < 5; i++) {
+ const F3 x3 = Rescale<3>(x * x * x);
+ x = Rescale<3>(fixedpoint_half_three * x - fixedpoint_half_input * x3);
+ }
+ const F0 fixedpoint_half_sqrt_2 =
+ GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F0, 1518500250, std::sqrt(2.) / 2.);
+ x = x * fixedpoint_half_sqrt_2;
+ *output_inv_sqrt = x.raw();
+ if (*output_shift < 0) {
+ *output_inv_sqrt <<= -*output_shift;
+ *output_shift = 0;
+ }
+}
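
Editor's note (not part of the patch): the fixed-point loop above is the standard Newton-Raphson iteration for the inverse square root, x_{n+1} = x_n * (3 - a * x_n^2) / 2, applied after the raw-value shifting has normalized a into roughly [1/4, 1). The same iteration written in plain floats, as a hypothetical helper for intuition only.

    // Converges to 1/sqrt(a) for a normalized into roughly [0.25, 1),
    // mirroring the normalization performed above.
    inline float InvSqrtNewtonSketch(float a) {
      float x = 1.0f;  // same naive starting guess as the fixed-point code
      for (int i = 0; i < 5; ++i) {  // same naive iteration count
        x = 1.5f * x - 0.5f * a * x * x * x;
      }
      return x;
    }
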
+
+inline void L2Normalization(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_zero_point, uint8* output_data,
+ const Dims<4>& output_dims) {
+ const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
+ const int height = MatchingArraySize(input_dims, 2, output_dims, 2);
+ const int width = MatchingArraySize(input_dims, 1, output_dims, 1);
+ const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
+ TFLITE_DCHECK_EQ(batches, 1);
+ TFLITE_DCHECK_EQ(height, 1);
+ TFLITE_DCHECK_EQ(width, 1);
+ int32 square_l2_norm = 0;
+ for (int i = 0; i < depth; i++) {
+ int32 diff = input_data[Offset(input_dims, i, 0, 0, 0)] - input_zero_point;
+ square_l2_norm += diff * diff;
+ }
+ int32 inv_l2norm_multiplier;
+ int inv_l2norm_shift;
+ GetInvSqrtQuantizedMultiplier(square_l2_norm, &inv_l2norm_multiplier,
+ &inv_l2norm_shift);
+
+ for (int i = 0; i < depth; i++) {
+ int32 diff = input_data[Offset(input_dims, i, 0, 0, 0)] - input_zero_point;
+ int32 rescaled_diff = MultiplyByQuantizedMultiplierSmallerThanOne(
+ 128 * diff, inv_l2norm_multiplier, inv_l2norm_shift);
+ int32 unclamped_output_val = 128 + rescaled_diff;
+ int32 output_val = std::min(255, std::max(0, unclamped_output_val));
+ output_data[Offset(output_dims, i, 0, 0, 0)] =
+ static_cast<uint8>(output_val);
+ }
+}
+
+inline void Add(const float* input1_data, const Dims<4>& input1_dims,
+ const float* input2_data, const Dims<4>& input2_dims,
+ float output_activation_min, float output_activation_max,
+ float* output_data, const Dims<4>& output_dims) {
+ const int batches =
+ MatchingArraySize(input1_dims, 3, input2_dims, 3, output_dims, 3);
+ const int height =
+ MatchingArraySize(input1_dims, 2, input2_dims, 2, output_dims, 2);
+ const int width =
+ MatchingArraySize(input1_dims, 1, input2_dims, 1, output_dims, 1);
+ const int depth =
+ MatchingArraySize(input1_dims, 0, input2_dims, 0, output_dims, 0);
+ for (int b = 0; b < batches; ++b) {
+ for (int y = 0; y < height; ++y) {
+ for (int x = 0; x < width; ++x) {
+ for (int c = 0; c < depth; ++c) {
+ output_data[Offset(output_dims, c, x, y, b)] =
+ ActivationFunctionWithMinMax(
+ input1_data[Offset(input1_dims, c, x, y, b)] +
+ input2_data[Offset(input2_dims, c, x, y, b)],
+ output_activation_min, output_activation_max);
+ }
+ }
+ }
+ }
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void Add(const float* input1_data, const Dims<4>& input1_dims,
+ const float* input2_data, const Dims<4>& input2_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ float output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+
+ Add(input1_data, input1_dims, input2_data, input2_dims, output_activation_min,
+ output_activation_max, output_data, output_dims);
+}
+
+template <FusedActivationFunctionType Ac>
+inline void Add(int left_shift, const uint8* input1_data,
+ const Dims<4>& input1_dims, int32 input1_offset,
+ int32 input1_multiplier, int input1_shift,
+ const uint8* input2_data, const Dims<4>& input2_dims,
+ int32 input2_offset, int32 input2_multiplier, int input2_shift,
+ int32 output_offset, int32 output_multiplier, int output_shift,
+ int32 output_activation_min, int32 output_activation_max,
+ uint8* output_data, const Dims<4>& output_dims) {
+ static_assert(Ac == FusedActivationFunctionType::kNone ||
+ Ac == FusedActivationFunctionType::kRelu ||
+ Ac == FusedActivationFunctionType::kRelu6 ||
+ Ac == FusedActivationFunctionType::kRelu1,
+ "");
+ TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
+ if (Ac == FusedActivationFunctionType::kNone) {
+ TFLITE_DCHECK_EQ(output_activation_min, 0);
+ TFLITE_DCHECK_EQ(output_activation_max, 255);
+ }
+ const int batches =
+ MatchingArraySize(input1_dims, 3, input2_dims, 3, output_dims, 3);
+ const int height =
+ MatchingArraySize(input1_dims, 2, input2_dims, 2, output_dims, 2);
+ const int width =
+ MatchingArraySize(input1_dims, 1, input2_dims, 1, output_dims, 1);
+ const int depth =
+ MatchingArraySize(input1_dims, 0, input2_dims, 0, output_dims, 0);
+ for (int b = 0; b < batches; ++b) {
+ for (int y = 0; y < height; ++y) {
+ for (int x = 0; x < width; ++x) {
+ for (int c = 0; c < depth; ++c) {
+ const int32 input1_val =
+ input1_offset + input1_data[Offset(input1_dims, c, x, y, b)];
+ const int32 input2_val =
+ input2_offset + input2_data[Offset(input2_dims, c, x, y, b)];
+ const int32 shifted_input1_val = input1_val * (1 << left_shift);
+ const int32 shifted_input2_val = input2_val * (1 << left_shift);
+ const int32 scaled_input1_val =
+ MultiplyByQuantizedMultiplierSmallerThanOne(
+ shifted_input1_val, input1_multiplier, input1_shift);
+ const int32 scaled_input2_val =
+ MultiplyByQuantizedMultiplierSmallerThanOne(
+ shifted_input2_val, input2_multiplier, input2_shift);
+ const int32 raw_sum = scaled_input1_val + scaled_input2_val;
+ const int32 raw_output =
+ MultiplyByQuantizedMultiplierSmallerThanOne(
+ raw_sum, output_multiplier, output_shift) +
+ output_offset;
+ const int32 clamped_output =
+ std::min(output_activation_max,
+ std::max(output_activation_min, raw_output));
+ output_data[Offset(output_dims, c, x, y, b)] =
+ static_cast<uint8>(clamped_output);
+ }
+ }
+ }
+ }
+}
+
+// TODO(jiawen): We can implement BroadcastAdd on buffers of arbitrary
+// dimensionality if the runtime code does a single loop over one dimension
+// that handles broadcasting as the base case. The code generator would then
+// generate max(D1, D2) nested for loops.
+template <FusedActivationFunctionType Ac>
+void BroadcastAdd(const float* input1_data, const Dims<4>& input1_dims,
+ const float* input2_data, const Dims<4>& input2_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("BroadcastAdd");
+
+ NdArrayDesc<4> desc1;
+ NdArrayDesc<4> desc2;
+ NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
+
+ // In Tensorflow, the dimensions are canonically named (batch_number, row,
+ // col, channel), with extents (batches, height, width, depth), with the
+ // trailing dimension changing most rapidly (channels has the smallest stride,
+ // typically 1 element).
+ //
+ // In generated C code, we store arrays with the dimensions reversed. The
+ // first dimension has smallest stride.
+ //
+ // We name our variables by their Tensorflow convention, but generate C code
+ // nesting loops such that the innermost loop has the smallest stride for the
+ // best cache behavior.
+ for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
+ for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
+ for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
+ for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
+ output_data[Offset(output_dims, c, x, y, b)] = ActivationFunction<Ac>(
+ input1_data[SubscriptToIndex(desc1, c, x, y, b)] +
+ input2_data[SubscriptToIndex(desc2, c, x, y, b)]);
+ }
+ }
+ }
+ }
+}
+
+inline void BroadcastAdd(int left_shift, const uint8* input1_data,
+ const Dims<4>& input1_dims, int32 input1_offset,
+ int32 input1_multiplier, int input1_shift,
+ const uint8* input2_data, const Dims<4>& input2_dims,
+ int32 input2_offset, int32 input2_multiplier,
+ int input2_shift, int32 output_offset,
+ int32 output_multiplier, int output_shift,
+ int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("BroadcastAdd/8bit");
+
+ NdArrayDesc<4> desc1;
+ NdArrayDesc<4> desc2;
+ NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
+
+ // In Tensorflow, the dimensions are canonically named (batch_number, row,
+ // col, channel), with extents (batches, height, width, depth), with the
+ // trailing dimension changing most rapidly (channels has the smallest stride,
+ // typically 1 element).
+ //
+ // In generated C code, we store arrays with the dimensions reversed. The
+ // first dimension has smallest stride.
+ //
+ // We name our variables by their Tensorflow convention, but generate C code
+ // nesting loops such that the innermost loop has the smallest stride for the
+ // best cache behavior.
+ for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
+ for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
+ for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
+ for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
+ const int32 input1_val =
+ input1_offset + input1_data[SubscriptToIndex(desc1, c, x, y, b)];
+ const int32 input2_val =
+ input2_offset + input2_data[SubscriptToIndex(desc2, c, x, y, b)];
+ const int32 shifted_input1_val = input1_val * (1 << left_shift);
+ const int32 shifted_input2_val = input2_val * (1 << left_shift);
+ const int32 scaled_input1_val =
+ MultiplyByQuantizedMultiplierSmallerThanOne(
+ shifted_input1_val, input1_multiplier, input1_shift);
+ const int32 scaled_input2_val =
+ MultiplyByQuantizedMultiplierSmallerThanOne(
+ shifted_input2_val, input2_multiplier, input2_shift);
+ const int32 raw_sum = scaled_input1_val + scaled_input2_val;
+ const int32 raw_output =
+ MultiplyByQuantizedMultiplierSmallerThanOne(
+ raw_sum, output_multiplier, output_shift) +
+ output_offset;
+ const int32 clamped_output =
+ std::min(output_activation_max,
+ std::max(output_activation_min, raw_output));
+ output_data[Offset(output_dims, c, x, y, b)] =
+ static_cast<uint8>(clamped_output);
+ }
+ }
+ }
+ }
+}
+
+template <FusedActivationFunctionType Ac>
+inline void BroadcastAdd(int left_shift, const uint8* input1_data,
+ const Dims<4>& input1_dims, int32 input1_offset,
+ int32 input1_multiplier, int input1_shift,
+ const uint8* input2_data, const Dims<4>& input2_dims,
+ int32 input2_offset, int32 input2_multiplier,
+ int input2_shift, int32 output_offset,
+ int32 output_multiplier, int output_shift,
+ int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ static_assert(Ac == FusedActivationFunctionType::kNone ||
+ Ac == FusedActivationFunctionType::kRelu ||
+ Ac == FusedActivationFunctionType::kRelu6 ||
+ Ac == FusedActivationFunctionType::kRelu1,
+ "");
+ TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
+ if (Ac == FusedActivationFunctionType::kNone) {
+ TFLITE_DCHECK_EQ(output_activation_min, 0);
+ TFLITE_DCHECK_EQ(output_activation_max, 255);
+ }
+ BroadcastAdd(left_shift, input1_data, input1_dims, input1_offset,
+ input1_multiplier, input1_shift, input2_data, input2_dims,
+ input2_offset, input2_multiplier, input2_shift, output_offset,
+ output_multiplier, output_shift, output_activation_min,
+ output_activation_max, output_data, output_dims);
+}
+
+inline void Mul(const float* input1_data, const Dims<4>& input1_dims,
+ const float* input2_data, const Dims<4>& input2_dims,
+ float output_activation_min, float output_activation_max,
+ float* output_data, const Dims<4>& output_dims) {
+ const int batches =
+ MatchingArraySize(input1_dims, 3, input2_dims, 3, output_dims, 3);
+ const int height =
+ MatchingArraySize(input1_dims, 2, input2_dims, 2, output_dims, 2);
+ const int width =
+ MatchingArraySize(input1_dims, 1, input2_dims, 1, output_dims, 1);
+ const int depth =
+ MatchingArraySize(input1_dims, 0, input2_dims, 0, output_dims, 0);
+ for (int b = 0; b < batches; ++b) {
+ for (int y = 0; y < height; ++y) {
+ for (int x = 0; x < width; ++x) {
+ for (int c = 0; c < depth; ++c) {
+ output_data[Offset(output_dims, c, x, y, b)] =
+ ActivationFunctionWithMinMax(
+ input1_data[Offset(input1_dims, c, x, y, b)] *
+ input2_data[Offset(input2_dims, c, x, y, b)],
+ output_activation_min, output_activation_max);
+ }
+ }
+ }
+ }
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void Mul(const float* input1_data, const Dims<4>& input1_dims,
+ const float* input2_data, const Dims<4>& input2_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ float output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+
+ Mul(input1_data, input1_dims, input2_data, input2_dims, output_activation_min,
+ output_activation_max, output_data, output_dims);
+}
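+
+// Minimal usage sketch for the float Mul above (illustrative only; it assumes
+// the packed Dims<4> layout used throughout this header, with dimension 0
+// innermost):
+//   Dims<4> dims;
+//   dims.sizes[0] = 4;
+//   dims.strides[0] = 1;
+//   for (int i = 1; i < 4; ++i) {
+//     dims.sizes[i] = 1;
+//     dims.strides[i] = 4;
+//   }
+//   float in1[4] = {1, 2, 3, 4};
+//   float in2[4] = {2, 2, 2, 2};
+//   float out[4];
+//   // Clamp to [0, 6], i.e. a Relu6-style fused activation.
+//   Mul(in1, dims, in2, dims, 0.f, 6.f, out, dims);  // out = {2, 4, 6, 6}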
+
+// TODO(jiawen): We can implement BroadcastMul on buffers of arbitrary
+// dimensionality if the runtime code does a single loop over one dimension
+// that handles broadcasting as the base case. The code generator would then
+// generate max(D1, D2) nested for loops.
+template <FusedActivationFunctionType Ac>
+void BroadcastMul(const float* input1_data, const Dims<4>& input1_dims,
+ const float* input2_data, const Dims<4>& input2_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("BroadcastMul");
+
+ NdArrayDesc<4> desc1;
+ NdArrayDesc<4> desc2;
+ NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
+
+ // In Tensorflow, the dimensions are canonically named (batch_number, row,
+ // col, channel), with extents (batches, height, width, depth), with the
+ // trailing dimension changing most rapidly (channels has the smallest
+ // stride, typically 1 element).
+ //
+ // In generated C code, we store arrays with the dimensions reversed. The
+ // first dimension has smallest stride.
+ //
+ // We name our variables by their Tensorflow convention, but generate C code
+ // nesting loops such that the innermost loop has the smallest stride for
+ // the best cache behavior.
+ for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
+ for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
+ for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
+ for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
+ output_data[Offset(output_dims, c, x, y, b)] = ActivationFunction<Ac>(
+ input1_data[SubscriptToIndex(desc1, c, x, y, b)] *
+ input2_data[SubscriptToIndex(desc2, c, x, y, b)]);
+ }
+ }
+ }
+ }
+}
+
+inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims,
+ int32 input1_offset, const uint8* input2_data,
+ const Dims<4>& input2_dims, int32 input2_offset,
+ int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("BroadcastMul/8bit");
+
+ NdArrayDesc<4> desc1;
+ NdArrayDesc<4> desc2;
+ NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
+
+ // In Tensorflow, the dimensions are canonically named (batch_number, row,
+ // col, channel), with extents (batches, height, width, depth), with the
+ // trailing dimension changing most rapidly (channels has the smallest
+ // stride, typically 1 element).
+ //
+ // In generated C code, we store arrays with the dimensions reversed. The
+ // first dimension has smallest stride.
+ //
+ // We name our variables by their Tensorflow convention, but generate C code
+ // nesting loops such that the innermost loop has the smallest stride for
+ // the best cache behavior.
+ for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
+ for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
+ for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
+ for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
+ const int32 input1_val =
+ input1_offset + input1_data[SubscriptToIndex(desc1, c, x, y, b)];
+ const int32 input2_val =
+ input2_offset + input2_data[SubscriptToIndex(desc2, c, x, y, b)];
+ const int32 unclamped_result =
+ output_offset +
+ MultiplyByQuantizedMultiplierSmallerThanOne(
+ input1_val * input2_val, output_multiplier, output_shift);
+ const int32 clamped_output =
+ std::min(output_activation_max,
+ std::max(output_activation_min, unclamped_result));
+ output_data[Offset(output_dims, c, x, y, b)] =
+ static_cast<uint8>(clamped_output);
+ }
+ }
+ }
+ }
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims,
+ int32 input1_offset, const uint8* input2_data,
+ const Dims<4>& input2_dims, int32 input2_offset,
+ int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ BroadcastMul(input1_data, input1_dims, input1_offset, input2_data,
+ input2_dims, input2_offset, output_offset, output_multiplier,
+ output_shift, output_activation_min, output_activation_max,
+ output_data, output_dims);
+}
+
+template <FusedActivationFunctionType Ac, typename Scalar>
+void Concatenation(int concat_dim, const Scalar* const* input_data,
+ const Dims<4>* const* input_dims, int inputs_count,
+ Scalar* output_data, const Dims<4>& output_dims) {
+ TFLITE_DCHECK_GT(inputs_count, 1);
+ int concat_size = 0;
+ for (int i = 0; i < inputs_count; i++) {
+ for (int j = 0; j < 4; j++) {
+ if (j != concat_dim) {
+ MatchingArraySize(*input_dims[i], j, output_dims, j);
+ }
+ }
+ concat_size += ArraySize(*input_dims[i], concat_dim);
+ }
+ TFLITE_DCHECK_EQ(concat_size, ArraySize(output_dims, concat_dim));
+ TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone);
+ int outer_size = 1;
+ for (int i = concat_dim + 1; i < 4; i++) {
+ outer_size *= output_dims.sizes[i];
+ }
+ Scalar* output_ptr = output_data;
+ for (int k = 0; k < outer_size; k++) {
+ for (int i = 0; i < inputs_count; ++i) {
+ const int copy_size =
+ input_dims[i]->sizes[concat_dim] * input_dims[i]->strides[concat_dim];
+ memcpy(output_ptr, input_data[i] + k * copy_size,
+ copy_size * sizeof(Scalar));
+ output_ptr += copy_size;
+ }
+ }
+}
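+
+// Note on usage: concat_dim indexes the storage-order dimensions used
+// throughout this header (0 = depth/channel, 3 = batch). For example,
+// concatenating a depth-3 and a depth-5 tensor of otherwise equal shape with
+// concat_dim = 0 yields a depth-8 output, as DepthConcatenation below does.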
+
+template <FusedActivationFunctionType Ac, typename Scalar>
+void DepthConcatenation(const Scalar* const* input_data,
+ const Dims<4>* const* input_dims, int inputs_count,
+ Scalar* output_data, const Dims<4>& output_dims) {
+ Concatenation<Ac, Scalar>(0, input_data, input_dims, inputs_count,
+ output_data, output_dims);
+}
+
+inline void LstmCell(const float* input_data, const Dims<4>& input_dims,
+ const float* prev_activ_data,
+ const Dims<4>& prev_activ_dims, const float* weights_data,
+ const Dims<4>& weights_dims, const float* bias_data,
+ const Dims<4>& bias_dims, const float* prev_state_data,
+ const Dims<4>& prev_state_dims, float* output_state_data,
+ const Dims<4>& output_state_dims, float* output_activ_data,
+ const Dims<4>& output_activ_dims, float* concat_temp_data,
+ const Dims<4>& concat_temp_dims, float* activ_temp_data,
+ const Dims<4>& activ_temp_dims) {
+ const int batches =
+ MatchingArraySize(input_dims, 3, prev_activ_dims, 3, prev_state_dims, 3,
+ output_state_dims, 3, output_activ_dims, 3);
+ const int height =
+ MatchingArraySize(input_dims, 2, prev_activ_dims, 2, prev_state_dims, 2,
+ output_state_dims, 2, output_activ_dims, 2);
+ const int width =
+ MatchingArraySize(input_dims, 1, prev_activ_dims, 1, prev_state_dims, 1,
+ output_state_dims, 1, output_activ_dims, 1);
+ TFLITE_CHECK_EQ(ArraySize(weights_dims, 2), 1);
+ TFLITE_CHECK_EQ(ArraySize(weights_dims, 3), 1);
+ const int input_depth = ArraySize(input_dims, 0);
+ const int prev_activ_depth = ArraySize(prev_activ_dims, 0);
+ const int total_input_depth = prev_activ_depth + input_depth;
+ TFLITE_CHECK_EQ(ArraySize(weights_dims, 0), total_input_depth);
+ TFLITE_CHECK_EQ(MatchingArraySize(bias_dims, 1, bias_dims, 2, bias_dims, 3),
+ 1);
+ const int intern_activ_depth =
+ MatchingArraySize(weights_dims, 1, bias_dims, 0);
+ TFLITE_CHECK_EQ(intern_activ_depth % 4, 0);
+ const int output_depth =
+ MatchingArraySize(prev_state_dims, 0, prev_activ_dims, 0,
+ output_state_dims, 0, output_activ_dims, 0);
+ TFLITE_CHECK_EQ(output_depth, intern_activ_depth / 4);
+
+ // Concatenate prev_activ and input data together
+ std::vector<float const*> concat_input_arrays_data;
+ std::vector<Dims<4> const*> concat_input_arrays_dims;
+ concat_input_arrays_data.push_back(input_data);
+ concat_input_arrays_data.push_back(prev_activ_data);
+ concat_input_arrays_dims.push_back(&input_dims);
+ concat_input_arrays_dims.push_back(&prev_activ_dims);
+ Concatenation<FusedActivationFunctionType::kNone, float>(
+ 0, &(concat_input_arrays_data[0]), &(concat_input_arrays_dims[0]),
+ concat_input_arrays_data.size(), concat_temp_data, concat_temp_dims);
+
+ // Fully connected
+ FullyConnected<FusedActivationFunctionType::kNone>(
+ concat_temp_data, concat_temp_dims, weights_data, weights_dims, bias_data,
+ bias_dims, activ_temp_data, activ_temp_dims);
+
+ // Memory state update (the LSTM "guts")
+ for (int b = 0; b < batches; ++b) {
+ for (int w = 0; w < width; ++w) {
+ for (int h = 0; h < height; ++h) {
+ for (int c = 0; c < output_depth; ++c) {
+ const float input_gate =
+ 1.f /
+ (1.f + std::exp(-activ_temp_data[Offset(
+ activ_temp_dims, 0 * output_depth + c, w, h, b)]));
+ const float new_input = std::tanh(activ_temp_data[Offset(
+ activ_temp_dims, 1 * output_depth + c, w, h, b)]);
+ const float forget_gate =
+ 1.f /
+ (1.f + std::exp(-activ_temp_data[Offset(
+ activ_temp_dims, 2 * output_depth + c, w, h, b)]));
+ const float output_gate =
+ 1.f /
+ (1.f + std::exp(-activ_temp_data[Offset(
+ activ_temp_dims, 3 * output_depth + c, w, h, b)]));
+ const float new_state =
+ input_gate * new_input +
+ forget_gate *
+ prev_state_data[Offset(prev_state_dims, c, w, h, b)];
+ output_state_data[Offset(output_state_dims, c, w, h, b)] = new_state;
+ output_activ_data[Offset(output_activ_dims, c, w, h, b)] =
+ output_gate * std::tanh(new_state);
+ }
+ }
+ }
+ }
+}
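+
+// In equation form, the "guts" loop above computes, per output channel c:
+//   i = sigmoid(a[0 * output_depth + c])   (input gate)
+//   g = tanh   (a[1 * output_depth + c])   (candidate / new input)
+//   f = sigmoid(a[2 * output_depth + c])   (forget gate)
+//   o = sigmoid(a[3 * output_depth + c])   (output gate)
+//   state  = i * g + f * prev_state[c]
+//   output = o * tanh(state)
+// where a is the fully-connected activation buffer (activ_temp_data).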
+
+template <FusedActivationFunctionType Ac, typename Scalar>
+void TensorFlowSplit(const Scalar* input_data, const Dims<4>& input_dims,
+ int outputs_count, Scalar* const* output_data,
+ const Dims<4>* const* output_dims) {
+ TFLITE_DCHECK_GE(outputs_count, 1);
+ for (int i = 0; i < outputs_count; i++) {
+ /* batches = */ MatchingArraySize(*output_dims[i], 3, input_dims, 3);
+ /* height = */ MatchingArraySize(*output_dims[i], 2, input_dims, 2);
+ /* width = */ MatchingArraySize(*output_dims[i], 1, input_dims, 1);
+ }
+ const int batches = MatchingArraySize(*output_dims[0], 3, input_dims, 3);
+ const int height = MatchingArraySize(*output_dims[0], 2, input_dims, 2);
+ const int width = MatchingArraySize(*output_dims[0], 1, input_dims, 1);
+  // For now we don't have a model with a TensorFlowSplit
+  // with a fused activation function.
+ TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone);
+ for (int b = 0; b < batches; ++b) {
+ for (int y = 0; y < height; ++y) {
+ for (int x = 0; x < width; ++x) {
+ int in_c = 0;
+ for (int i = 0; i < outputs_count; ++i) {
+ const int depth = ArraySize(*output_dims[i], 0);
+ for (int c = 0; c < depth; ++c) {
+ output_data[i][Offset(*output_dims[i], c, x, y, b)] =
+ input_data[Offset(input_dims, in_c, x, y, b)];
+ in_c++;
+ }
+ }
+ TFLITE_DCHECK(in_c == ArraySize(input_dims, 0));
+ }
+ }
+ }
+}
+
+// TODO(benoitjacob) make this a proper reference impl without Eigen!
+template <typename Scalar>
+using MatrixMap = typename std::conditional<
+ std::is_const<Scalar>::value,
+ Eigen::Map<const Eigen::Matrix<typename std::remove_const<Scalar>::type,
+ Eigen::Dynamic, Eigen::Dynamic>>,
+ Eigen::Map<Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic>>>::type;
+
+template <typename Scalar, int N>
+MatrixMap<Scalar> MapAsMatrixWithFirstDimAsRows(Scalar* data,
+ const Dims<N>& dims) {
+ const int rows = dims.sizes[0];
+ int cols = 1;
+ for (int d = 1; d < N; d++) {
+ cols *= dims.sizes[d];
+ }
+ return MatrixMap<Scalar>(data, rows, cols);
+}
+
+template <typename Scalar, int N>
+MatrixMap<Scalar> MapAsMatrixWithLastDimAsCols(Scalar* data,
+ const Dims<N>& dims) {
+ const int cols = dims.sizes[N - 1];
+ int rows = 1;
+ for (int d = 0; d < N - 1; d++) {
+ rows *= dims.sizes[d];
+ }
+ return MatrixMap<Scalar>(data, rows, cols);
+}
+
+inline int NodeOffset(int b, int h, int w, int height, int width) {
+ return (b * height + h) * width + w;
+}
+
+inline void AveragePool(const float* input_data, const Dims<4>& input_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int filter_width, int filter_height,
+ float output_activation_min,
+ float output_activation_max, float* output_data,
+ const Dims<4>& output_dims) {
+ const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
+ const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
+ const int input_height = ArraySize(input_dims, 2);
+ const int input_width = ArraySize(input_dims, 1);
+ const int output_height = ArraySize(output_dims, 2);
+ const int output_width = ArraySize(output_dims, 1);
+ for (int batch = 0; batch < batches; ++batch) {
+ for (int out_y = 0; out_y < output_height; ++out_y) {
+ for (int out_x = 0; out_x < output_width; ++out_x) {
+ for (int channel = 0; channel < depth; ++channel) {
+ const int in_x_origin = (out_x * stride_width) - pad_width;
+ const int in_y_origin = (out_y * stride_height) - pad_height;
+ // Compute the boundaries of the filter region clamped so as to
+ // ensure that the filter window fits in the input array.
+ const int filter_x_start = std::max(0, -in_x_origin);
+ const int filter_x_end =
+ std::min(filter_width, input_width - in_x_origin);
+ const int filter_y_start = std::max(0, -in_y_origin);
+ const int filter_y_end =
+ std::min(filter_height, input_height - in_y_origin);
+ float total = 0.f;
+ float filter_count = 0;
+ for (int filter_y = filter_y_start; filter_y < filter_y_end;
+ ++filter_y) {
+ for (int filter_x = filter_x_start; filter_x < filter_x_end;
+ ++filter_x) {
+ const int in_x = in_x_origin + filter_x;
+ const int in_y = in_y_origin + filter_y;
+ total +=
+ input_data[Offset(input_dims, channel, in_x, in_y, batch)];
+ filter_count++;
+ }
+ }
+ const float average = total / filter_count;
+ output_data[Offset(output_dims, channel, out_x, out_y, batch)] =
+ ActivationFunctionWithMinMax(average, output_activation_min,
+ output_activation_max);
+ }
+ }
+ }
+ }
+}
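+
+// Note that near the borders the filter window is clamped to the valid input
+// region and the sum is divided by filter_count, the number of taps actually
+// used, rather than by filter_width * filter_height. For example, with a 3x3
+// filter and one pixel of padding, a corner output averages only the 4
+// in-bounds values.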
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void AveragePool(const float* input_data, const Dims<4>& input_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int filter_width, int filter_height,
+ float* output_data, const Dims<4>& output_dims) {
+ float output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+ AveragePool(input_data, input_dims, stride_width, stride_height, pad_width,
+ pad_height, filter_width, filter_height, output_activation_min,
+ output_activation_max, output_data, output_dims);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void AveragePool(const float* input_data, const Dims<4>& input_dims, int stride,
+ int pad_width, int pad_height, int filter_width,
+ int filter_height, float* output_data,
+ const Dims<4>& output_dims) {
+ AveragePool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
+ filter_width, filter_height, output_data, output_dims);
+}
+
+inline void AveragePool(const uint8* input_data, const Dims<4>& input_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int filter_width, int filter_height,
+ int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
+ const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
+ const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
+ const int input_height = ArraySize(input_dims, 2);
+ const int input_width = ArraySize(input_dims, 1);
+ const int output_height = ArraySize(output_dims, 2);
+ const int output_width = ArraySize(output_dims, 1);
+ for (int batch = 0; batch < batches; ++batch) {
+ for (int out_y = 0; out_y < output_height; ++out_y) {
+ for (int out_x = 0; out_x < output_width; ++out_x) {
+ for (int channel = 0; channel < depth; ++channel) {
+ const int in_x_origin = (out_x * stride_width) - pad_width;
+ const int in_y_origin = (out_y * stride_height) - pad_height;
+ // Compute the boundaries of the filter region clamped so as to
+ // ensure that the filter window fits in the input array.
+ const int filter_x_start = std::max(0, -in_x_origin);
+ const int filter_x_end =
+ std::min(filter_width, input_width - in_x_origin);
+ const int filter_y_start = std::max(0, -in_y_origin);
+ const int filter_y_end =
+ std::min(filter_height, input_height - in_y_origin);
+ int32 acc = 0;
+ int filter_count = 0;
+ for (int filter_y = filter_y_start; filter_y < filter_y_end;
+ ++filter_y) {
+ for (int filter_x = filter_x_start; filter_x < filter_x_end;
+ ++filter_x) {
+ const int in_x = in_x_origin + filter_x;
+ const int in_y = in_y_origin + filter_y;
+ acc += input_data[Offset(input_dims, channel, in_x, in_y, batch)];
+ filter_count++;
+ }
+ }
+ acc = (acc + filter_count / 2) / filter_count;
+ acc = std::max(acc, output_activation_min);
+ acc = std::min(acc, output_activation_max);
+ output_data[Offset(output_dims, channel, out_x, out_y, batch)] =
+ static_cast<uint8>(acc);
+ }
+ }
+ }
+ }
+}
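+
+// Note: (acc + filter_count / 2) / filter_count above is integer division
+// rounded to nearest (acc is a sum of uint8 values and hence non-negative),
+// mirroring the float average computed in the reference AveragePool above.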
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void AveragePool(const uint8* input_data, const Dims<4>& input_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int filter_width, int filter_height,
+ int32 output_activation_min, int32 output_activation_max,
+ uint8* output_data, const Dims<4>& output_dims) {
+ static_assert(Ac == FusedActivationFunctionType::kNone ||
+ Ac == FusedActivationFunctionType::kRelu ||
+ Ac == FusedActivationFunctionType::kRelu6 ||
+ Ac == FusedActivationFunctionType::kRelu1,
+ "");
+ if (Ac == FusedActivationFunctionType::kNone) {
+ TFLITE_DCHECK_EQ(output_activation_min, 0);
+ TFLITE_DCHECK_EQ(output_activation_max, 255);
+ }
+ AveragePool(input_data, input_dims, stride_width, stride_height, pad_width,
+ pad_height, filter_width, filter_height, output_activation_min,
+ output_activation_max, output_data, output_dims);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void AveragePool(const uint8* input_data, const Dims<4>& input_dims, int stride,
+ int pad_width, int pad_height, int filter_width,
+ int filter_height, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ AveragePool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
+ filter_width, filter_height, output_activation_min,
+ output_activation_max, output_data, output_dims);
+}
+
+inline void L2Pool(const float* input_data, const Dims<4>& input_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int filter_width, int filter_height,
+ float output_activation_min, float output_activation_max,
+ float* output_data, const Dims<4>& output_dims) {
+ const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
+ const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
+ const int input_height = ArraySize(input_dims, 2);
+ const int input_width = ArraySize(input_dims, 1);
+ const int output_height = ArraySize(output_dims, 2);
+ const int output_width = ArraySize(output_dims, 1);
+ for (int batch = 0; batch < batches; ++batch) {
+ for (int out_y = 0; out_y < output_height; ++out_y) {
+ for (int out_x = 0; out_x < output_width; ++out_x) {
+ for (int channel = 0; channel < depth; ++channel) {
+ const int in_x_origin = (out_x * stride_width) - pad_width;
+ const int in_y_origin = (out_y * stride_height) - pad_height;
+ // Compute the boundaries of the filter region clamped so as to
+ // ensure that the filter window fits in the input array.
+ const int filter_x_start = std::max(0, -in_x_origin);
+ const int filter_x_end =
+ std::min(filter_width, input_width - in_x_origin);
+ const int filter_y_start = std::max(0, -in_y_origin);
+ const int filter_y_end =
+ std::min(filter_height, input_height - in_y_origin);
+ float sum_squares = 0.f;
+ int filter_count = 0;
+ for (int filter_y = filter_y_start; filter_y < filter_y_end;
+ ++filter_y) {
+ for (int filter_x = filter_x_start; filter_x < filter_x_end;
+ ++filter_x) {
+ const int in_x = in_x_origin + filter_x;
+ const int in_y = in_y_origin + filter_y;
+ const float val =
+ input_data[Offset(input_dims, channel, in_x, in_y, batch)];
+ sum_squares += val * val;
+ filter_count++;
+ }
+ }
+ const float l2pool_result = std::sqrt(sum_squares / filter_count);
+ output_data[Offset(output_dims, channel, out_x, out_y, batch)] =
+ ActivationFunctionWithMinMax(l2pool_result, output_activation_min,
+ output_activation_max);
+ }
+ }
+ }
+ }
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void L2Pool(const float* input_data, const Dims<4>& input_dims,
+ int stride_width, int stride_height, int pad_width, int pad_height,
+ int filter_width, int filter_height, float* output_data,
+ const Dims<4>& output_dims) {
+ float output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+
+ L2Pool(input_data, input_dims, stride_width, stride_height, pad_width,
+ pad_height, filter_width, filter_height, output_activation_min,
+ output_activation_max, output_data, output_dims);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void L2Pool(const float* input_data, const Dims<4>& input_dims, int stride,
+ int pad_width, int pad_height, int filter_width, int filter_height,
+ float* output_data, const Dims<4>& output_dims) {
+ L2Pool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
+ filter_width, filter_height, output_data, output_dims);
+}
+
+inline void MaxPool(const float* input_data, const Dims<4>& input_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int filter_width, int filter_height,
+ float output_activation_min, float output_activation_max,
+ float* output_data, const Dims<4>& output_dims) {
+ const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
+ const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
+ const int input_height = ArraySize(input_dims, 2);
+ const int input_width = ArraySize(input_dims, 1);
+ const int output_height = ArraySize(output_dims, 2);
+ const int output_width = ArraySize(output_dims, 1);
+ for (int batch = 0; batch < batches; ++batch) {
+ for (int out_y = 0; out_y < output_height; ++out_y) {
+ for (int out_x = 0; out_x < output_width; ++out_x) {
+ for (int channel = 0; channel < depth; ++channel) {
+ const int in_x_origin = (out_x * stride_width) - pad_width;
+ const int in_y_origin = (out_y * stride_height) - pad_height;
+ // Compute the boundaries of the filter region clamped so as to
+ // ensure that the filter window fits in the input array.
+ const int filter_x_start = std::max(0, -in_x_origin);
+ const int filter_x_end =
+ std::min(filter_width, input_width - in_x_origin);
+ const int filter_y_start = std::max(0, -in_y_origin);
+ const int filter_y_end =
+ std::min(filter_height, input_height - in_y_origin);
+ float max = std::numeric_limits<float>::lowest();
+ for (int filter_y = filter_y_start; filter_y < filter_y_end;
+ ++filter_y) {
+ for (int filter_x = filter_x_start; filter_x < filter_x_end;
+ ++filter_x) {
+ const int in_x = in_x_origin + filter_x;
+ const int in_y = in_y_origin + filter_y;
+ max = std::max(
+ max,
+ input_data[Offset(input_dims, channel, in_x, in_y, batch)]);
+ }
+ }
+ output_data[Offset(output_dims, channel, out_x, out_y, batch)] =
+ ActivationFunctionWithMinMax(max, output_activation_min,
+ output_activation_max);
+ }
+ }
+ }
+ }
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void MaxPool(const float* input_data, const Dims<4>& input_dims,
+ int stride_width, int stride_height, int pad_width, int pad_height,
+ int filter_width, int filter_height, float* output_data,
+ const Dims<4>& output_dims) {
+ float output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+ MaxPool(input_data, input_dims, stride_width, stride_height, pad_width,
+ pad_height, filter_width, filter_height, output_activation_min,
+ output_activation_max, output_data, output_dims);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void MaxPool(const float* input_data, const Dims<4>& input_dims, int stride,
+ int pad_width, int pad_height, int filter_width, int filter_height,
+ float* output_data, const Dims<4>& output_dims) {
+ MaxPool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
+ filter_width, filter_height, output_data, output_dims);
+}
+
+inline void MaxPool(const uint8* input_data, const Dims<4>& input_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int filter_width, int filter_height,
+ int32 output_activation_min, int32 output_activation_max,
+ uint8* output_data, const Dims<4>& output_dims) {
+ TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
+ TFLITE_DCHECK_GE(output_activation_min, 0);
+ TFLITE_DCHECK_LE(output_activation_max, 255);
+ const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
+ const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
+ const int input_height = ArraySize(input_dims, 2);
+ const int input_width = ArraySize(input_dims, 1);
+ const int output_height = ArraySize(output_dims, 2);
+ const int output_width = ArraySize(output_dims, 1);
+ for (int batch = 0; batch < batches; ++batch) {
+ for (int out_y = 0; out_y < output_height; ++out_y) {
+ for (int out_x = 0; out_x < output_width; ++out_x) {
+ for (int channel = 0; channel < depth; ++channel) {
+ const int in_x_origin = (out_x * stride_width) - pad_width;
+ const int in_y_origin = (out_y * stride_height) - pad_height;
+ // Compute the boundaries of the filter region clamped so as to
+ // ensure that the filter window fits in the input array.
+ const int filter_x_start = std::max(0, -in_x_origin);
+ const int filter_x_end =
+ std::min(filter_width, input_width - in_x_origin);
+ const int filter_y_start = std::max(0, -in_y_origin);
+ const int filter_y_end =
+ std::min(filter_height, input_height - in_y_origin);
+ uint8 max = 0;
+ for (int filter_y = filter_y_start; filter_y < filter_y_end;
+ ++filter_y) {
+ for (int filter_x = filter_x_start; filter_x < filter_x_end;
+ ++filter_x) {
+ const int in_x = in_x_origin + filter_x;
+ const int in_y = in_y_origin + filter_y;
+ max = std::max(
+ max,
+ input_data[Offset(input_dims, channel, in_x, in_y, batch)]);
+ }
+ }
+ max = std::max<uint8>(max, output_activation_min);
+ max = std::min<uint8>(max, output_activation_max);
+ output_data[Offset(output_dims, channel, out_x, out_y, batch)] =
+ static_cast<uint8>(max);
+ }
+ }
+ }
+ }
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void MaxPool(const uint8* input_data, const Dims<4>& input_dims,
+ int stride_width, int stride_height, int pad_width, int pad_height,
+ int filter_width, int filter_height, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ static_assert(Ac == FusedActivationFunctionType::kNone ||
+ Ac == FusedActivationFunctionType::kRelu ||
+ Ac == FusedActivationFunctionType::kRelu6 ||
+ Ac == FusedActivationFunctionType::kRelu1,
+ "");
+ if (Ac == FusedActivationFunctionType::kNone) {
+ TFLITE_DCHECK_EQ(output_activation_min, 0);
+ TFLITE_DCHECK_EQ(output_activation_max, 255);
+ }
+ MaxPool(input_data, input_dims, stride_width, stride_height, pad_width,
+ pad_height, filter_width, filter_height, output_activation_min,
+ output_activation_max, output_data, output_dims);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void MaxPool(const uint8* input_data, const Dims<4>& input_dims, int stride,
+ int pad_width, int pad_height, int filter_width, int filter_height,
+ int32 output_activation_min, int32 output_activation_max,
+ uint8* output_data, const Dims<4>& output_dims) {
+ MaxPool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
+ filter_width, filter_height, output_activation_min,
+ output_activation_max, output_data, output_dims);
+}
+
+inline void LocalResponseNormalization(const float* input_data,
+ const Dims<4>& input_dims, int range,
+ float bias, float alpha, float beta,
+ float* output_data,
+ const Dims<4>& output_dims) {
+ const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
+ const int height = MatchingArraySize(input_dims, 2, output_dims, 2);
+ const int width = MatchingArraySize(input_dims, 1, output_dims, 1);
+ const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
+
+ for (int b = 0; b < batches; ++b) {
+ for (int y = 0; y < height; ++y) {
+ for (int x = 0; x < width; ++x) {
+ for (int c = 0; c < depth; ++c) {
+ const int begin_input_c = std::max(0, c - range);
+ const int end_input_c = std::min(depth, c + range);
+ float accum = 0.f;
+ for (int input_c = begin_input_c; input_c < end_input_c; ++input_c) {
+ const float input_val =
+ input_data[Offset(input_dims, input_c, x, y, b)];
+ accum += input_val * input_val;
+ }
+ const float multiplier = std::pow(bias + alpha * accum, -beta);
+ output_data[Offset(output_dims, c, x, y, b)] =
+ input_data[Offset(input_dims, c, x, y, b)] * multiplier;
+ }
+ }
+ }
+ }
+}
+
+inline void Softmax(const float* input_data, const Dims<4>& input_dims,
+ float beta, float* output_data,
+ const Dims<4>& output_dims) {
+ const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
+ const int height = MatchingArraySize(input_dims, 2, output_dims, 2);
+ const int width = MatchingArraySize(input_dims, 1, output_dims, 1);
+ const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
+
+ for (int b = 0; b < batches; ++b) {
+ for (int y = 0; y < height; ++y) {
+ for (int x = 0; x < width; ++x) {
+ // Find max element value which we'll use to ensure numerical stability
+ // taking advantage of the following equality:
+ // exp(x[i])/sum(exp(x[i])) == exp(x[i]+C)/sum(exp(x[i]+C))
+ float max = std::numeric_limits<float>::lowest();
+ for (int c = 0; c < depth; ++c) {
+ max = std::max(max, input_data[Offset(input_dims, c, x, y, b)]);
+ }
+
+ // Compute sum.
+ float sum = 0.f;
+ for (int c = 0; c < depth; ++c) {
+ sum += std::exp((input_data[Offset(input_dims, c, x, y, b)] - max) *
+ beta);
+ }
+
+ // Compute result.
+ for (int c = 0; c < depth; ++c) {
+ output_data[Offset(output_dims, c, x, y, b)] =
+ std::exp((input_data[Offset(input_dims, c, x, y, b)] - max) *
+ beta) /
+ sum;
+ }
+ }
+ }
+ }
+}
+
+inline void Softmax(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_beta_multiplier, int32 input_beta_left_shift,
+ int diff_min, uint8* output_data,
+ const Dims<4>& output_dims) {
+ // The representation chosen for the input to the exp() function is Q5.26.
+ // We need to leave extra space since values that we skip might be as large as
+ // -32 before multiplying by input_beta_multiplier, and therefore as large as
+ // -16 afterwards. Note that exp(-8) is definitely not insignificant to
+ // accumulation, but exp(-16) definitely is.
+ static const int kScaledDiffIntegerBits = 5;
+ static const int kAccumulationIntegerBits = 12;
+ using FixedPointScaledDiff =
+ gemmlowp::FixedPoint<int32, kScaledDiffIntegerBits>;
+ using FixedPointAccum = gemmlowp::FixedPoint<int32, kAccumulationIntegerBits>;
+ using FixedPoint0 = gemmlowp::FixedPoint<int32, 0>;
+
+ const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
+ const int height = MatchingArraySize(input_dims, 2, output_dims, 2);
+ const int width = MatchingArraySize(input_dims, 1, output_dims, 1);
+ const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
+
+ for (int b = 0; b < batches; ++b) {
+ for (int x = 0; x < width; ++x) {
+ for (int y = 0; y < height; ++y) {
+ uint8 max_in_row = 0;
+ for (int c = 0; c < depth; ++c) {
+ max_in_row =
+ std::max(max_in_row, input_data[Offset(input_dims, c, x, y, b)]);
+ }
+
+ FixedPointAccum sum_of_exps = FixedPointAccum::Zero();
+ for (int c = 0; c < depth; ++c) {
+ int32 input_diff =
+ static_cast<int32>(input_data[Offset(input_dims, c, x, y, b)]) -
+ max_in_row;
+ if (input_diff >= diff_min) {
+ const int32 input_diff_rescaled =
+ MultiplyByQuantizedMultiplierGreaterThanOne(
+ input_diff, input_beta_multiplier, input_beta_left_shift);
+ const FixedPointScaledDiff scaled_diff_f8 =
+ FixedPointScaledDiff::FromRaw(input_diff_rescaled);
+ sum_of_exps =
+ sum_of_exps + gemmlowp::Rescale<kAccumulationIntegerBits>(
+ exp_on_negative_values(scaled_diff_f8));
+ }
+ }
+
+ int32 fixed_sum_of_exps = sum_of_exps.raw();
+ int headroom_plus_one =
+ CountLeadingZeros(static_cast<uint32>(fixed_sum_of_exps));
+ // This is the number of bits to the left of the binary point above 1.0.
+ // Consider fixed_sum_of_exps=1.25. In that case shifted_scale=0.8 and
+ // no later adjustment will be needed.
+ int num_bits_over_unit = kAccumulationIntegerBits - headroom_plus_one;
+ int32 shifted_sum_minus_one = static_cast<int32>(
+ (static_cast<uint32>(fixed_sum_of_exps) << headroom_plus_one) -
+ (static_cast<uint32>(1) << 31));
+
+ FixedPoint0 shifted_scale = gemmlowp::one_over_one_plus_x_for_x_in_0_1(
+ FixedPoint0::FromRaw(shifted_sum_minus_one));
+
+ for (int c = 0; c < depth; ++c) {
+ int32 input_diff =
+ static_cast<int32>(input_data[Offset(input_dims, c, x, y, b)]) -
+ max_in_row;
+ if (input_diff >= diff_min) {
+ const int32 input_diff_rescaled =
+ MultiplyByQuantizedMultiplierGreaterThanOne(
+ input_diff, input_beta_multiplier, input_beta_left_shift);
+ const FixedPointScaledDiff scaled_diff_f8 =
+ FixedPointScaledDiff::FromRaw(input_diff_rescaled);
+
+ FixedPoint0 exp_in_0 = exp_on_negative_values(scaled_diff_f8);
+ int32 unsat_output = gemmlowp::RoundingDivideByPOT(
+ (shifted_scale * exp_in_0).raw(), num_bits_over_unit + 31 - 8);
+
+ output_data[Offset(output_dims, c, x, y, b)] = static_cast<uint8>(
+ std::max(std::min(unsat_output, static_cast<int32>(255)), 0));
+
+ } else {
+ output_data[Offset(output_dims, c, x, y, b)] = 0;
+ }
+ }
+ }
+ }
+ }
+}
+
+inline void Logistic(const float* input_data, const Dims<4>& input_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
+ const int height = MatchingArraySize(input_dims, 2, output_dims, 2);
+ const int width = MatchingArraySize(input_dims, 1, output_dims, 1);
+ const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
+ for (int b = 0; b < batches; ++b) {
+ for (int y = 0; y < height; ++y) {
+ for (int x = 0; x < width; ++x) {
+ for (int c = 0; c < depth; ++c) {
+ float val = input_data[Offset(input_dims, c, x, y, b)];
+ float result = 1.f / (1.f + std::exp(-val));
+ output_data[Offset(output_dims, c, x, y, b)] = result;
+ }
+ }
+ }
+ }
+}
+
+inline void Logistic(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_zero_point, int32 input_range_radius,
+ int32 input_multiplier, int input_left_shift,
+ uint8* output_data, const Dims<4>& output_dims) {
+ const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
+ const int height = MatchingArraySize(input_dims, 2, output_dims, 2);
+ const int width = MatchingArraySize(input_dims, 1, output_dims, 1);
+ const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
+ for (int b = 0; b < batches; ++b) {
+ for (int y = 0; y < height; ++y) {
+ for (int x = 0; x < width; ++x) {
+ for (int c = 0; c < depth; ++c) {
+ const uint8 input_val_u8 = input_data[Offset(input_dims, c, x, y, b)];
+ const int32 input_val_centered =
+ static_cast<int32>(input_val_u8) - input_zero_point;
+ uint8 output_val;
+ if (input_val_centered <= -input_range_radius) {
+ output_val = 0;
+ } else if (input_val_centered >= input_range_radius) {
+ output_val = 255;
+ } else {
+ const int32 input_val_rescaled =
+ MultiplyByQuantizedMultiplierGreaterThanOne(
+ input_val_centered, input_multiplier, input_left_shift);
+ using FixedPoint4 = gemmlowp::FixedPoint<int32, 4>;
+ using FixedPoint0 = gemmlowp::FixedPoint<int32, 0>;
+ const FixedPoint4 input_val_f4 =
+ FixedPoint4::FromRaw(input_val_rescaled);
+ const FixedPoint0 output_val_f0 = gemmlowp::logistic(input_val_f4);
+ using gemmlowp::RoundingDivideByPOT;
+ int32 output_val_s32 = RoundingDivideByPOT(output_val_f0.raw(), 23);
+ if (output_val_s32 == 256) {
+ output_val_s32 = 255;
+ }
+ TFLITE_DCHECK_GE(output_val_s32, 0);
+ TFLITE_DCHECK_LE(output_val_s32, 255);
+ output_val = static_cast<uint8>(output_val_s32);
+ }
+ output_data[Offset(output_dims, c, x, y, b)] = output_val;
+ }
+ }
+ }
+ }
+}
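+
+// In the quantized Logistic above, gemmlowp::logistic returns a Q0.31 value
+// in [0, 1); RoundingDivideByPOT(..., 23) therefore scales it to [0, 256],
+// which is why an exact result of 256 is saturated back to 255 before the
+// final cast to uint8.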
+
+inline void Tanh(const float* input_data, const Dims<4>& input_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
+ const int height = MatchingArraySize(input_dims, 2, output_dims, 2);
+ const int width = MatchingArraySize(input_dims, 1, output_dims, 1);
+ const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
+ for (int b = 0; b < batches; ++b) {
+ for (int y = 0; y < height; ++y) {
+ for (int x = 0; x < width; ++x) {
+ for (int c = 0; c < depth; ++c) {
+ float val = input_data[Offset(input_dims, c, x, y, b)];
+ float result = std::tanh(val);
+ output_data[Offset(output_dims, c, x, y, b)] = result;
+ }
+ }
+ }
+ }
+}
+
+inline void Dequantize(const uint8* input_data, const Dims<4>& input_dims,
+ int32 zero_point, double scale, float* output_data,
+ const Dims<4>& output_dims) {
+ const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
+ const int height = MatchingArraySize(input_dims, 2, output_dims, 2);
+ const int width = MatchingArraySize(input_dims, 1, output_dims, 1);
+ const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
+ for (int b = 0; b < batches; ++b) {
+ for (int y = 0; y < height; ++y) {
+ for (int x = 0; x < width; ++x) {
+ for (int c = 0; c < depth; ++c) {
+ int32 val = input_data[Offset(input_dims, c, x, y, b)];
+ float result = static_cast<float>(scale * (val - zero_point));
+ output_data[Offset(output_dims, c, x, y, b)] = result;
+ }
+ }
+ }
+ }
+}
+
+inline void FakeQuant(const float* input_data, const Dims<4>& input_dims,
+ float rmin, float rmax, float* output_data,
+ const Dims<4>& output_dims) {
+ // 0 should always be a representable value. Let's assume that the initial
+ // min,max range contains 0.
+ TFLITE_DCHECK_LE(rmin, 0.);
+ TFLITE_DCHECK_GE(rmax, 0.);
+
+ // Determine quantization parameters: zero_point, scale.
+ using Integer = uint8;
+ const Integer qmin = std::numeric_limits<Integer>::min();
+ const Integer qmax = std::numeric_limits<Integer>::max();
+ const float qmin_float = qmin;
+ const float qmax_float = qmax;
+ int32 zero_point = 0;
+ float scale = 0.f;
+ // If rmin==rmax, both must be zero per the above assertion,
+ // so we are done.
+ if (rmin != rmax) {
+ // First determine the scale.
+ scale = (rmax - rmin) / (qmax_float - qmin_float);
+
+ // Zero-point computation.
+ // First the initial floating-point computation. The zero-point can be
+ // determined from solving an affine equation for any known pair
+ // (real value, corresponding quantized value).
+ // We know two such pairs: (rmin, qmin) and (rmax, qmax).
+ // The arithmetic error on the zero point computed from either pair
+ // will be roughly machine_epsilon * (sum of absolute values of terms)
+ // so we want to use the variant that adds the smaller terms.
+ const float zero_point_from_min = qmin_float - rmin / scale;
+ const float zero_point_from_max = qmax_float - rmax / scale;
+ const float zero_point_from_min_error =
+ std::abs(qmin_float) + std::abs(rmin / scale);
+ const float zero_point_from_max_error =
+ std::abs(qmax_float) + std::abs(rmax / scale);
+
+ const float zero_point_float =
+ zero_point_from_min_error < zero_point_from_max_error
+ ? zero_point_from_min
+ : zero_point_from_max;
+
+ // Now we need to nudge the zero point to be an integer
+ // (our zero points are integer, and this is motivated by the requirement
+ // to be able to represent the real value "0" exactly as a quantized value,
+ // which is required in multiple places, for example in Im2col with SAME
+ // padding).
+ if (zero_point_float < qmin_float) {
+ zero_point = qmin;
+ } else if (zero_point_float > qmax_float) {
+ zero_point = qmax;
+ } else {
+ zero_point = static_cast<int32>(TfLiteRound(zero_point_float));
+ }
+ // The zero point should always be in the range of quantized value,
+ // [qmin, qmax].
+ TFLITE_DCHECK_GE(zero_point, qmin);
+ TFLITE_DCHECK_LE(zero_point, qmax);
+ }
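+  // Worked example: with rmin = -1.f and rmax = 2.f, scale = 3.f / 255 and
+  // zero_point_from_min = 0 - (-1.f) / (3.f / 255) = 85 exactly, so the
+  // nudged zero_point is 85 and the real value 0 is representable with no
+  // rounding error, as required above.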
+
+ const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
+ const int height = MatchingArraySize(input_dims, 2, output_dims, 2);
+ const int width = MatchingArraySize(input_dims, 1, output_dims, 1);
+ const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
+ for (int b = 0; b < batches; ++b) {
+ for (int y = 0; y < height; ++y) {
+ for (int x = 0; x < width; ++x) {
+ for (int c = 0; c < depth; ++c) {
+ const float src_val = input_data[Offset(input_dims, c, x, y, b)];
+ const float unclamped_quantized_val =
+ TfLiteRound(zero_point + src_val / scale);
+ const float quantized_val = std::min(
+ qmax_float, std::max(qmin_float, unclamped_quantized_val));
+ const float dst_val = scale * (quantized_val - zero_point);
+ output_data[Offset(output_dims, c, x, y, b)] = dst_val;
+ }
+ }
+ }
+ }
+}
+
+template <typename SrcT, typename DstT>
+inline void Cast(const SrcT* input_data, const Dims<4>& input_dims,
+ DstT* output_data, const Dims<4>& output_dims) {
+ const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
+ const int height = MatchingArraySize(input_dims, 2, output_dims, 2);
+ const int width = MatchingArraySize(input_dims, 1, output_dims, 1);
+ const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
+ for (int b = 0; b < batches; ++b) {
+ for (int y = 0; y < height; ++y) {
+ for (int x = 0; x < width; ++x) {
+ for (int c = 0; c < depth; ++c) {
+ int offset = Offset(input_dims, c, x, y, b);
+ output_data[offset] = static_cast<DstT>(input_data[offset]);
+ }
+ }
+ }
+ }
+}
+
+inline void Floor(const float* input_data, const Dims<4>& input_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
+ const int height = MatchingArraySize(input_dims, 2, output_dims, 2);
+ const int width = MatchingArraySize(input_dims, 1, output_dims, 1);
+ const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
+ for (int b = 0; b < batches; ++b) {
+ for (int y = 0; y < height; ++y) {
+ for (int x = 0; x < width; ++x) {
+ for (int c = 0; c < depth; ++c) {
+ int offset = Offset(input_dims, c, x, y, b);
+ output_data[offset] = std::floor(input_data[offset]);
+ }
+ }
+ }
+ }
+}
+
+template <typename T>
+inline void Gather(const T* input_data, const Dims<4>& input_dims,
+ int input_rank, const int32* coords_data,
+ const Dims<4>& coords_dims, T* output_data,
+ const Dims<4>& output_dims) {
+ TFLITE_DCHECK(coords_dims.sizes[0] == output_dims.sizes[input_rank - 1]);
+ int stride = input_dims.strides[input_rank - 1];
+ T* out = output_data;
+
+ for (int i = 0; i < coords_dims.sizes[0]; i++) {
+ TFLITE_DCHECK_GE(coords_data[i], 0);
+ TFLITE_DCHECK_LT(coords_data[i], input_dims.sizes[input_rank - 1]);
+ const T* in = input_data + coords_data[i] * stride;
+ memcpy(out, in, sizeof(T) * stride);
+ out += stride;
+ }
+}
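+
+// Example: with input_rank = 2, the input is viewed as input_dims.sizes[1]
+// slices of stride elements each; coords_data = {3, 1} copies slice 3
+// followed by slice 1 into the output.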
+
+inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims,
+ const int32* output_size_data,
+ const Dims<4>& output_size_dims, float* output_data,
+ const Dims<4>& output_dims) {
+ int32 batches = MatchingArraySize(input_dims, 3, output_dims, 3);
+ int32 input_height = ArraySize(input_dims, 2);
+ int32 input_width = ArraySize(input_dims, 1);
+ int32 depth = MatchingArraySize(input_dims, 0, output_dims, 0);
+
+ TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 3), 1);
+ TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 2), 1);
+ TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 1), 1);
+ TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 0), 2);
+ int32 output_height = output_size_data[Offset(output_size_dims, 0, 0, 0, 0)];
+ int32 output_width = output_size_data[Offset(output_size_dims, 1, 0, 0, 0)];
+ float height_scale = static_cast<float>(input_height) / output_height;
+ float width_scale = static_cast<float>(input_width) / output_width;
+
+ for (int b = 0; b < batches; ++b) {
+ for (int y = 0; y < output_height; ++y) {
+ float input_y = y * height_scale;
+ int32 y0 = static_cast<int32>(std::floor(input_y));
+ int32 y1 = std::min(y0 + 1, input_height - 1);
+ for (int x = 0; x < output_width; ++x) {
+ float input_x = x * width_scale;
+ int32 x0 = static_cast<int32>(std::floor(input_x));
+ int32 x1 = std::min(x0 + 1, input_width - 1);
+ for (int c = 0; c < depth; ++c) {
+ float interpolation = input_data[Offset(input_dims, c, x0, y0, b)] *
+ (1 - (input_y - y0)) *
+ (1 - (input_x - x0)) +
+ input_data[Offset(input_dims, c, x0, y1, b)] *
+ (input_y - y0) * (1 - (input_x - x0)) +
+ input_data[Offset(input_dims, c, x1, y0, b)] *
+ (1 - (input_y - y0)) * (input_x - x0) +
+ input_data[Offset(input_dims, c, x1, y1, b)] *
+ (input_y - y0) * (input_x - x0);
+ output_data[Offset(output_dims, c, x, y, b)] = interpolation;
+ }
+ }
+ }
+ }
+}
+
+template <typename T>
+inline void SpaceToBatchND(const T* input_data, const Dims<4>& input_dims,
+ const int32* block_shape_data,
+ const Dims<4>& block_shape_dims,
+ const int32* paddings_data,
+ const Dims<4>& paddings_dims, T* output_data,
+ const Dims<4>& output_dims) {
+ const int output_batch_size = ArraySize(output_dims, 3);
+ const int output_height = ArraySize(output_dims, 2);
+ const int output_width = ArraySize(output_dims, 1);
+ const int input_batch_size = ArraySize(input_dims, 3);
+ const int input_height = ArraySize(input_dims, 2);
+ const int input_width = ArraySize(input_dims, 1);
+ const int depth = ArraySize(input_dims, 0);
+ const int block_shape_height = block_shape_data[0];
+ const int block_shape_width = block_shape_data[1];
+ const int padding_top = paddings_data[0];
+ const int padding_left = paddings_data[2];
+
+ for (int out_b = 0; out_b < output_batch_size; ++out_b) {
+ int input_batch = out_b % input_batch_size;
+ int shift_w = (out_b / input_batch_size) % block_shape_width;
+ int shift_h = (out_b / input_batch_size) / block_shape_width;
+ for (int out_h = 0; out_h < output_height; ++out_h) {
+ for (int out_w = 0; out_w < output_width; ++out_w) {
+ T* out = output_data + Offset(output_dims, 0, out_w, out_h, out_b);
+ if (out_h * block_shape_height < padding_top ||
+ out_h * block_shape_height >= padding_top + input_height ||
+ out_w * block_shape_width < padding_left ||
+ out_w * block_shape_width >= padding_left + input_width) {
+ memset(out, 0, depth * sizeof(T));
+ } else {
+ const T* in =
+ input_data +
+ Offset(input_dims, 0,
+ (out_w * block_shape_width + shift_w) - padding_left,
+ (out_h * block_shape_height + shift_h) - padding_top,
+ input_batch);
+ memcpy(out, in, depth * sizeof(T));
+ }
+ }
+ }
+ }
+}
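+
+// Example: with block_shape = {2, 2} and zero padding, a 1x4x4xD input is
+// rearranged into a 4x2x2xD output, each output batch holding one phase of
+// the 2x2 spatial subsampling grid; padded positions are written as zeros.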
+
+template <typename T>
+inline void BatchToSpaceND(const T* input_data, const Dims<4>& input_dims,
+ const int32* block_shape_data,
+ const Dims<4>& block_shape_dims, T* output_data,
+ const Dims<4>& output_dims) {
+ const int output_batch_size = ArraySize(output_dims, 3);
+ const int input_batch_size = ArraySize(input_dims, 3);
+ const int input_height = ArraySize(input_dims, 2);
+ const int input_width = ArraySize(input_dims, 1);
+ const int depth = ArraySize(input_dims, 0);
+ const int block_shape_width = block_shape_data[1];
+ const int block_shape_height = block_shape_data[0];
+
+ for (int in_batch = 0; in_batch < input_batch_size; ++in_batch) {
+ for (int in_h = 0; in_h < input_height; ++in_h) {
+ for (int in_w = 0; in_w < input_width; ++in_w) {
+ int out_batch = in_batch % output_batch_size;
+ int out_w = in_w * block_shape_width +
+ (in_batch / output_batch_size) % block_shape_width;
+ int out_h = in_h * block_shape_height +
+ (in_batch / output_batch_size) / block_shape_width;
+ T* out = output_data + Offset(output_dims, 0, out_w, out_h, out_batch);
+ const T* in = input_data + Offset(input_dims, 0, in_w, in_h, in_batch);
+ memcpy(out, in, depth * sizeof(T));
+ }
+ }
+ }
+}
+
+template <typename T>
+inline void Pad(const T* input_data, const Dims<4>& input_dims,
+ const std::vector<int>& left_paddings,
+ const std::vector<int>& right_paddings, T* output_data,
+ const Dims<4>& output_dims) {
+ const int output_batch = ArraySize(output_dims, 3);
+ const int output_height = ArraySize(output_dims, 2);
+ const int output_width = ArraySize(output_dims, 1);
+ const int output_depth = ArraySize(output_dims, 0);
+
+ const int left_b_padding = left_paddings[3];
+ const int left_h_padding = left_paddings[2];
+ const int left_w_padding = left_paddings[1];
+ const int left_d_padding = left_paddings[0];
+
+ const int right_b_padding = right_paddings[3];
+ const int right_h_padding = right_paddings[2];
+ const int right_w_padding = right_paddings[1];
+ const int right_d_padding = right_paddings[0];
+
+ const T* in_ptr = input_data;
+ T* out_ptr = output_data;
+ for (int out_b = 0; out_b < output_batch; ++out_b) {
+ for (int out_h = 0; out_h < output_height; ++out_h) {
+ for (int out_w = 0; out_w < output_width; ++out_w) {
+ for (int out_d = 0; out_d < output_depth; ++out_d) {
+ if (out_b < left_b_padding ||
+ out_b >= output_batch - right_b_padding ||
+ out_h < left_h_padding ||
+ out_h >= output_height - right_h_padding ||
+ out_w < left_w_padding ||
+ out_w >= output_width - right_w_padding ||
+ out_d < left_d_padding ||
+ out_d >= output_depth - right_d_padding) {
+ *out_ptr++ = 0;
+ } else {
+ *out_ptr++ = *in_ptr++;
+ }
+ }
+ }
+ }
+ }
+}
+
+template <typename T>
+inline void StridedSlice(const T* input_data, const Dims<4>& input_dims,
+ int begin_mask, int end_mask,
+ const std::vector<int>& starts,
+ const std::vector<int>& stops,
+ const std::vector<int>& strides, T* output_data,
+ const Dims<4>& output_dims) {
+ const int start_b = (begin_mask & 8) ? 0 : starts[3];
+ const int stop_b = (end_mask & 8) ? input_dims.sizes[3] : stops[3];
+ const int start_h = (begin_mask & 4) ? 0 : starts[2];
+ const int stop_h = (end_mask & 4) ? input_dims.sizes[2] : stops[2];
+ const int start_w = (begin_mask & 2) ? 0 : starts[1];
+ const int stop_w = (end_mask & 2) ? input_dims.sizes[1] : stops[1];
+ const int start_d = (begin_mask & 1) ? 0 : starts[0];
+ const int stop_d = (end_mask & 1) ? input_dims.sizes[0] : stops[0];
+
+ T* out_ptr = output_data;
+ for (int in_b = start_b; in_b < stop_b; in_b += strides[3]) {
+ for (int in_h = start_h; in_h < stop_h; in_h += strides[2]) {
+ for (int in_w = start_w; in_w < stop_w; in_w += strides[1]) {
+ for (int in_d = start_d; in_d < stop_d; in_d += strides[0]) {
+ *out_ptr++ = input_data[Offset(input_dims, in_d, in_w, in_h, in_b)];
+ }
+ }
+ }
+ }
+}
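+
+// Here the begin_mask / end_mask bits follow the storage-order dimensions:
+// bit 0 (value 1) masks the depth start/stop, bit 1 the width, bit 2 the
+// height and bit 3 (value 8) the batch. A set bit replaces the corresponding
+// start with 0 and the corresponding stop with the full dimension size.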
+
+template <typename T>
+inline void Slice(const T* input_data, const Dims<4>& input_dims,
+ const std::vector<int>& begin, const std::vector<int>& size,
+ T* output_data, const Dims<4>& output_dims) {
+ // TODO(dkalenichenko): This op only supports 4D tensors.
+ TFLITE_DCHECK_EQ(begin.size(), 4);
+ TFLITE_DCHECK_EQ(size.size(), 4);
+ const int start_b = begin[3];
+ const int stop_b =
+ size[3] == -1 ? input_dims.sizes[3] - start_b : start_b + size[3];
+ const int start_h = begin[2];
+  const int stop_h =
+      size[2] == -1 ? input_dims.sizes[2] - start_h : start_h + size[2];
+  const int start_w = begin[1];
+  const int stop_w =
+      size[1] == -1 ? input_dims.sizes[1] - start_w : start_w + size[1];
+ const int start_d = begin[0];
+ const int stop_d =
+ size[0] == -1 ? input_dims.sizes[0] - start_d : start_d + size[0];
+
+ T* out_ptr = output_data;
+ for (int in_b = start_b; in_b < stop_b; ++in_b) {
+ for (int in_h = start_h; in_h < stop_h; ++in_h) {
+ for (int in_w = start_w; in_w < stop_w; ++in_w) {
+ for (int in_d = start_d; in_d < stop_d; ++in_d) {
+ *out_ptr++ = input_data[Offset(input_dims, in_d, in_w, in_h, in_b)];
+ }
+ }
+ }
+ }
+}
+
+template <typename T>
+inline void Mean(const T* input_data, const Dims<4>& input_dims,
+ const std::vector<int>& reduction_indices, T* output_data,
+ const Dims<4>& output_dims) {
+ const int output_batch = ArraySize(output_dims, 3);
+ const int output_height = ArraySize(output_dims, 2);
+ const int output_width = ArraySize(output_dims, 1);
+ const int output_depth = ArraySize(output_dims, 0);
+
+ const int input_height = ArraySize(input_dims, 2);
+ const int input_width = ArraySize(input_dims, 1);
+
+ // The current implementation only supports simultaneous reduction over
+ // width and height.
+ TFLITE_DCHECK_EQ(reduction_indices.size(), 2);
+ TFLITE_DCHECK((reduction_indices[0] == 1 && reduction_indices[1] == 2) ||
+ (reduction_indices[0] == 2 && reduction_indices[1] == 1));
+ TFLITE_DCHECK_EQ(output_height, 1);
+ TFLITE_DCHECK_EQ(output_width, 1);
+
+ for (int out_b = 0; out_b < output_batch; ++out_b) {
+ for (int out_d = 0; out_d < output_depth; ++out_d) {
+ float value = 0;
+ for (int in_h = 0; in_h < input_height; ++in_h) {
+ for (int in_w = 0; in_w < input_width; ++in_w) {
+ value += input_data[Offset(input_dims, out_d, in_w, in_h, out_b)];
+ }
+ }
+ output_data[Offset(output_dims, out_d, 0, 0, out_b)] =
+ value / (input_width * input_height);
+ }
+ }
+}
+
+template <typename T>
+void Sub(const T* input1_data, const Dims<4>& input1_dims, const T* input2_data,
+ const Dims<4>& input2_dims, T* output_data,
+ const Dims<4>& output_dims) {
+ NdArrayDesc<4> desc1;
+ NdArrayDesc<4> desc2;
+ NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
+
+ // In Tensorflow, the dimensions are canonically named (batch_number, row,
+ // col, channel), with extents (batches, height, width, depth), with the
+ // trailing dimension changing most rapidly (channels has the smallest stride,
+ // typically 1 element).
+ //
+ // In generated C code, we store arrays with the dimensions reversed. The
+ // first dimension has smallest stride.
+ //
+ // We name our variables by their Tensorflow convention, but generate C code
+ // nesting loops such that the innermost loop has the smallest stride for the
+ // best cache behavior.
+ for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
+ for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
+ for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
+ for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
+ output_data[Offset(output_dims, c, x, y, b)] =
+ input1_data[SubscriptToIndex(desc1, c, x, y, b)] -
+ input2_data[SubscriptToIndex(desc2, c, x, y, b)];
+ }
+ }
+ }
+ }
+}
+
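+// Note: TensorFlowMinimum and TensorFlowMaximum below treat input2 as a
+// scalar; only input2_data[0] is read, and input1 is clamped against it
+// elementwise.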
+template <typename T>
+void TensorFlowMinimum(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, T* output_data,
+ const Dims<4>& output_dims) {
+ int batches = MatchingArraySize(input1_dims, 3, output_dims, 3);
+ int input_height = MatchingArraySize(input1_dims, 2, output_dims, 2);
+ int input_width = MatchingArraySize(input1_dims, 1, output_dims, 1);
+ int depth = MatchingArraySize(input1_dims, 0, output_dims, 0);
+
+ auto min_value = input2_data[0];
+
+ for (int b = 0; b < batches; b++) {
+ for (int y = 0; y < input_height; y++) {
+ for (int x = 0; x < input_width; x++) {
+ for (int c = 0; c < depth; c++) {
+ int offset = Offset(input1_dims, c, x, y, b);
+ output_data[offset] =
+ input1_data[offset] > min_value ? min_value : input1_data[offset];
+ }
+ }
+ }
+ }
+}
+
+template <typename T>
+void TensorFlowMaximum(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, T* output_data,
+ const Dims<4>& output_dims) {
+ int batches = MatchingArraySize(input1_dims, 3, output_dims, 3);
+ int input_height = MatchingArraySize(input1_dims, 2, output_dims, 2);
+ int input_width = MatchingArraySize(input1_dims, 1, output_dims, 1);
+ int depth = MatchingArraySize(input1_dims, 0, output_dims, 0);
+
+ auto max_value = input2_data[0];
+
+ for (int b = 0; b < batches; b++) {
+ for (int y = 0; y < input_height; y++) {
+ for (int x = 0; x < input_width; x++) {
+ for (int c = 0; c < depth; c++) {
+ int offset = Offset(input1_dims, c, x, y, b);
+ output_data[offset] =
+ input1_data[offset] < max_value ? max_value : input1_data[offset];
+ }
+ }
+ }
+ }
+}
+
+} // namespace reference_ops
+} // namespace tflite
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_REFERENCE_OPS_H_
diff --git a/tensorflow/contrib/lite/kernels/internal/round.h b/tensorflow/contrib/lite/kernels/internal/round.h
new file mode 100644
index 0000000000..38525b0e20
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/internal/round.h
@@ -0,0 +1,39 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_ROUND_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_ROUND_H_
+
+#include <cmath>
+
+namespace tflite {
+
+// TODO(aselle): See if we can do this only on jdk. Also mikecase, check
+// if you need this for java host build.
+#if defined(__ANDROID__) && !defined(__NDK_MAJOR__)
+template <class T>
+inline float TfLiteRound(const float x) {
+ return ::round(x);
+}
+inline double TfLiteRound(const double x) { return ::round(x); }
+#else
+template <class T>
+inline T TfLiteRound(const T x) {
+ return std::round(x);
+}
+#endif
+
+} // namespace tflite
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_ROUND_H_
diff --git a/tensorflow/contrib/lite/kernels/internal/tensor.h b/tensorflow/contrib/lite/kernels/internal/tensor.h
new file mode 100644
index 0000000000..ee4111e041
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/internal/tensor.h
@@ -0,0 +1,87 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_H_
+
+#include <vector>
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/internal/types.h"
+
+namespace tflite {
+
+template <typename T>
+inline T* GetTensorData(TfLiteTensor* tensor);
+
+template <>
+inline float* GetTensorData(TfLiteTensor* tensor) {
+ return tensor != nullptr ? tensor->data.f : nullptr;
+}
+
+template <>
+inline uint8_t* GetTensorData(TfLiteTensor* tensor) {
+ return tensor != nullptr ? tensor->data.uint8 : nullptr;
+}
+
+template <>
+inline int32_t* GetTensorData(TfLiteTensor* tensor) {
+ return tensor != nullptr ? tensor->data.i32 : nullptr;
+}
+
+template <>
+inline int64_t* GetTensorData(TfLiteTensor* tensor) {
+ return tensor != nullptr ? reinterpret_cast<int64_t*>(tensor->data.raw)
+ : nullptr;
+}
+
+inline int RemapDim(int max_dimensions, int d) {
+ return max_dimensions - d - 1;
+}
+
+// TODO(ahentz): the implementations in kernels/internal/ take a Dims<4> object
+// even if the original tensors were not 4D. We should consider rewriting them
+// to take a more generic 'shape' object.
+inline Dims<4> GetTensorDims(const int data[], const int size) {
+ Dims<4> d;
+ for (int i = 0; i < 4; ++i) {
+ int src = size - i - 1;
+ if (src >= 0) {
+ d.sizes[i] = data[src];
+ } else {
+ d.sizes[i] = 1;
+ }
+ }
+ d.strides[0] = 1;
+ for (int i = 1; i < 4; i++) {
+ d.strides[i] = d.strides[i - 1] * d.sizes[i - 1];
+ }
+ return d;
+}
+
+inline Dims<4> GetTensorDims(std::vector<int32_t> data) {
+ return GetTensorDims(data.data(), data.size());
+}
+
+inline Dims<4> GetTensorDims(const TfLiteTensor* tensor) {
+ if (tensor == nullptr) {
+ return Dims<4>();
+ }
+
+ auto* dims = tensor->dims;
+ return GetTensorDims(dims->data, dims->size);
+}
+
+} // namespace tflite
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_H_
diff --git a/tensorflow/contrib/lite/kernels/internal/tensor_test.cc b/tensorflow/contrib/lite/kernels/internal/tensor_test.cc
new file mode 100644
index 0000000000..bf2068d320
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/internal/tensor_test.cc
@@ -0,0 +1,55 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAre;
+
+TEST(TensorTest, GetTensorDims4D) {
+ Dims<4> d = GetTensorDims({2, 3, 4, 5});
+ EXPECT_THAT(d.sizes, ElementsAre(5, 4, 3, 2));
+ EXPECT_THAT(d.strides, ElementsAre(1, 5, 20, 60));
+}
+
+TEST(TensorTest, GetTensorDims3D) {
+ Dims<4> d = GetTensorDims({3, 4, 5});
+ EXPECT_THAT(d.sizes, ElementsAre(5, 4, 3, 1));
+ EXPECT_THAT(d.strides, ElementsAre(1, 5, 20, 60));
+}
+
+TEST(TensorTest, GetTensorDims2D) {
+ Dims<4> d = GetTensorDims({4, 5});
+ EXPECT_THAT(d.sizes, ElementsAre(5, 4, 1, 1));
+ EXPECT_THAT(d.strides, ElementsAre(1, 5, 20, 20));
+}
+
+TEST(TensorTest, GetTensorDims1D) {
+ Dims<4> d = GetTensorDims({5});
+ EXPECT_THAT(d.sizes, ElementsAre(5, 1, 1, 1));
+ EXPECT_THAT(d.strides, ElementsAre(1, 5, 5, 5));
+}
+
+} // namespace
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ // On Linux, add: tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/kernels/internal/tensor_utils.cc b/tensorflow/contrib/lite/kernels/internal/tensor_utils.cc
new file mode 100644
index 0000000000..904a97803a
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/internal/tensor_utils.cc
@@ -0,0 +1,27 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
+
+#ifndef USE_NEON
+#if defined(__ARM_NEON__) || defined(__ARM_NEON)
+#define USE_NEON
+#endif // defined(__ARM_NEON__) || defined(__ARM_NEON)
+#endif // USE_NEON
+
+#ifdef USE_NEON
+#include "tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h"
+#else
+#include "tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h"
+#endif // USE_NEON
diff --git a/tensorflow/contrib/lite/kernels/internal/tensor_utils.h b/tensorflow/contrib/lite/kernels/internal/tensor_utils.h
new file mode 100644
index 0000000000..0e69ef5982
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/internal/tensor_utils.h
@@ -0,0 +1,116 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_UTILS_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_UTILS_H_
+
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+
+namespace tflite {
+namespace tensor_utils {
+
+// Limit a float input f to the range [-abs_limit, +abs_limit].
+float Clip(float f, float abs_limit);
+
+// Multiply a matrix by a batch vector, and store results in a batch-size
+// vector using a stride value provided in result_stride. 'result_stride' is
+// the number of elements between consecutive result values. For example,
+// result_stride = 1 will make the output look like this:
+// [O_1, O_2, ... O_rows] in memory, while result_stride = 3 will arrange it
+// like this in memory: [O_1, x, x, O_2, x, x, ..., O_rows]
+void MatrixBatchVectorMultiplyAccumulate(const float* matrix, int m_rows,
+ int m_cols, const float* vector,
+ int n_batch, float* result,
+ int result_stride);
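
The stride semantics above can be pinned down with a small portable sketch (an illustration only, consistent with the expectations in tensor_utils_test.cc further below; the shipped implementations are the NEON and portable backends selected in tensor_utils.cc):

void MatrixBatchVectorMultiplyAccumulateSketch(
    const float* matrix, int m_rows, int m_cols, const float* vector,
    int n_batch, float* result, int result_stride) {
  for (int b = 0; b < n_batch; ++b) {
    const float* v = vector + b * m_cols;
    for (int r = 0; r < m_rows; ++r) {
      float dot = 0.f;
      for (int c = 0; c < m_cols; ++c) {
        dot += matrix[r * m_cols + c] * v[c];
      }
      // Accumulate into the existing value; consecutive results are
      // result_stride elements apart.
      result[(b * m_rows + r) * result_stride] += dot;
    }
  }
}
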
+
+// Cwise product of two vectors.
+void VectorVectorCwiseProduct(const float* vector1, const float* vector2,
+ int v_size, float* result);
+
+// Cwise product and accumulate of two vectors. Since it's a MAC operation,
+// the result array is assumed to be initialized to valid values.
+void VectorVectorCwiseProductAccumulate(const float* vector1,
+ const float* vector2, int v_size,
+ float* result);
+
+// Dot product of two vectors.
+float VectorVectorDotProduct(const float* vector1, const float* vector2,
+ int v_size);
+
+// Dot product of two batch vectors of size n_batch * v_size:
+// vector1 = [x_1_1, x_1_2, ..., x_1_vsize,
+// x_2_1, x_2_2, ..., x_2_vsize,
+// ...
+// x_nbatch_1,..., x_nbatch_vsize]
+// vector2 = [y_1_1, y_1_2, ..., y_1_vsize,
+// y_2_1, y_2_2, ..., y_2_vsize,
+// ...
+// y_nbatch_1,..., y_nbatch_vsize]
+// Then result will be a vector of n_batch size which will be saved with a
+// stride of result_stride in memory starting from 'result':
+// [x_1_1 * y_1_1 + x_1_2 * y_1_2 + ... + x_1_vsize * y_1_vsize,
+// x_2_1 * y_2_1 + x_2_2 * y_2_2 + ... + x_2_vsize * y_2_vsize,
+// ...
+// x_nbatch_1 * y_nbatch_1 + ... + x_nbatch_vsize * y_nbatch_vsize]
+void BatchVectorBatchVectorDotProduct(const float* vector1,
+ const float* vector2, int v_size,
+ int n_batch, float* result,
+ int result_stride);
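
A similar sketch for the batched dot product just described (illustration only; the plain store into 'result' follows the comment above, which describes the results being saved rather than accumulated):

void BatchVectorBatchVectorDotProductSketch(
    const float* vector1, const float* vector2, int v_size, int n_batch,
    float* result, int result_stride) {
  for (int b = 0; b < n_batch; ++b) {
    float dot = 0.f;
    for (int i = 0; i < v_size; ++i) {
      dot += vector1[b * v_size + i] * vector2[b * v_size + i];
    }
    result[b * result_stride] = dot;
  }
}
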
+
+// Cwise product and accumulate of a vector and a batch-vector. Since it's a
+// MAC operation, the result array is assumed to be initialized to valid
+// values.
+void VectorBatchVectorCwiseProductAccumulate(const float* vector, int v_size,
+ const float* batch_vector,
+ int n_batch, float* result);
+
+// Batch vector initialization with another vector.
+void VectorBatchVectorAssign(const float* vector, int v_size, int n_batch,
+ float* batch_vector);
+
+// Apply sigmoid to elements of a vector.
+void ApplySigmoidToVector(const float* vector, int v_size, float* result);
+
+// Apply activation function to elements of a vector.
+void ApplyActivationToVector(const float* vector, int v_size,
+ TfLiteFusedActivation activation, float* result);
+
+// Copy vector to another vector.
+void CopyVector(const float* vector, int v_size, float* result);
+
+// Compute "1.0f - elements of vector" (used in CIFG).
+void Sub1Vector(const float* vector, int v_size, float* result);
+
+// Fill vector with 0.f.
+void ZeroVector(float* vector, int v_size);
+
+// Clip elements of a vector using an abs_limit value.
+void ClipVector(const float* vector, int v_size, float abs_limit,
+ float* result);
+
+// Shift a vector of size v_size left by one element in place; shift_value
+// becomes the new last element.
+void VectorShiftLeft(float* vector, int v_size, float shift_value);
+
+// Reduce-sum on a float input vector:
+// input_vector: float pointer to input vector.
+// output_vector: float pointer to output vector.
+// output_size: output vector size.
+// reduction_size: number of consecutive elements from input vector which are
+// added to get one element of output.
+void ReductionSumVector(const float* input_vector, float* output_vector,
+ int output_size, int reduction_size);
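
And a sketch of the reduce-sum contract (illustration only): output element o sums the reduction_size consecutive inputs starting at o * reduction_size.

void ReductionSumVectorSketch(const float* input_vector, float* output_vector,
                              int output_size, int reduction_size) {
  for (int o = 0; o < output_size; ++o) {
    float sum = 0.f;
    for (int r = 0; r < reduction_size; ++r) {
      sum += input_vector[o * reduction_size + r];
    }
    output_vector[o] = sum;
  }
}
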
+} // namespace tensor_utils
+} // namespace tflite
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_UTILS_H_
diff --git a/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc b/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc
new file mode 100644
index 0000000000..588f1a428b
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc
@@ -0,0 +1,192 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
+#include <gmock/gmock.h>
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+
+namespace tflite {
+namespace tensor_utils {
+
+TEST(uKernels, ClipTest) {
+ constexpr int kVectorSize = 10;
+ constexpr float kAbsLimit = 2.0;
+ static float input[kVectorSize] = {0.0, -0.5, 1.0, -1.5, 2.0,
+ -2.5, 3.0, -3.5, 4.0, -4.5};
+ std::vector<float> output(kVectorSize);
+ ClipVector(input, kVectorSize, kAbsLimit, output.data());
+ EXPECT_THAT(output,
+ ElementsAreArray(ArrayFloatNear(
+ {0.0, -0.5, 1.0, -1.5, 2.0, -2.0, 2.0, -2.0, 2.0, -2.0})));
+}
+
+TEST(uKernels, MatrixBatchVectorMultiplyAccumulateTest) {
+ constexpr int kRow = 3;
+ constexpr int kCol = 4;
+ constexpr int kBatch = 2;
+ static float matrix[kRow * kCol] = {1.0, 2.0, 3.0, 4.0, //
+ -1.0, -2.0, -3.0, -4.0, //
+ 1.0, -2.0, 3.0, -4.0};
+ static float vector[kCol * kBatch] = {1.0, -1.0, 1.0, -1.0, //
+ 2.0, -2.0, 2.0, -2.0};
+ std::vector<float> output(kRow * kBatch);
+ std::fill(output.begin(), output.end(), 3.0);
+ MatrixBatchVectorMultiplyAccumulate(matrix, kRow, kCol, vector, kBatch,
+ output.data(), /*result_stride=*/1);
+ EXPECT_THAT(output, ElementsAreArray(ArrayFloatNear({1., 5., 13., //
+ -1., 7., 23.})));
+
+ std::vector<float> output_with_stride2(kRow * kBatch * 2);
+ std::fill(output_with_stride2.begin(), output_with_stride2.end(), 3.0);
+ MatrixBatchVectorMultiplyAccumulate(matrix, kRow, kCol, vector, kBatch,
+ output_with_stride2.data(),
+ /*result_stride=*/2);
+ EXPECT_THAT(output_with_stride2,
+ ElementsAreArray(ArrayFloatNear({1., 3., 5., 3., 13., 3., //
+ -1., 3., 7., 3., 23., 3.})));
+}
+
+TEST(uKernels, VectorVectorCwiseProductTest) {
+ constexpr int kVectorSize = 10;
+ static float input1[kVectorSize] = {0.0, -0.5, 1.0, -1.5, 2.0,
+ -2.5, 3.0, -3.5, 4.0, -4.5};
+ static float input2[kVectorSize] = {0.1, -0.1, 0.1, -0.1, 0.1,
+ -0.1, 0.1, -0.1, 0.1, -0.1};
+ std::vector<float> output(kVectorSize);
+ VectorVectorCwiseProduct(input1, input2, kVectorSize, output.data());
+ EXPECT_THAT(output,
+ ElementsAreArray(ArrayFloatNear(
+ {0.0, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45})));
+}
+
+TEST(uKernels, VectorVectorCwiseProductAccumulateTest) {
+ constexpr int kVectorSize = 10;
+ static float input1[kVectorSize] = {0.0, -0.5, 1.0, -1.5, 2.0,
+ -2.5, 3.0, -3.5, 4.0, -4.5};
+ static float input2[kVectorSize] = {0.1, -0.1, 0.1, -0.1, 0.1,
+ -0.1, 0.1, -0.1, 0.1, -0.1};
+ std::vector<float> output(kVectorSize);
+ std::fill(output.begin(), output.end(), 1.0);
+ VectorVectorCwiseProductAccumulate(input1, input2, kVectorSize,
+ output.data());
+ EXPECT_THAT(output,
+ ElementsAreArray(ArrayFloatNear(
+ {1.0, 1.05, 1.1, 1.15, 1.2, 1.25, 1.3, 1.35, 1.4, 1.45})));
+}
+
+TEST(uKernels, VectorBatchVectorAssignTest) {
+ constexpr int kVectorSize = 5;
+ constexpr int kBatchSize = 3;
+ static float input[kVectorSize] = {0.0, -0.5, 1.0, -1.5, 2.0};
+ std::vector<float> output(kVectorSize * kBatchSize);
+ VectorBatchVectorAssign(input, kVectorSize, kBatchSize, output.data());
+ EXPECT_THAT(output, ElementsAreArray(ArrayFloatNear(
+ {0.0, -0.5, 1.0, -1.5, 2.0, 0.0, -0.5, 1.0, -1.5, 2.0,
+ 0.0, -0.5, 1.0, -1.5, 2.0})));
+}
+
+TEST(uKernels, ApplySigmoidToVectorTest) {
+ constexpr int kVectorSize = 5;
+ static float input[kVectorSize] = {0.0, -0.5, 1.0, -1.5, 2.0};
+ std::vector<float> output(kVectorSize);
+ ApplySigmoidToVector(input, kVectorSize, output.data());
+ EXPECT_THAT(output, ElementsAreArray(ArrayFloatNear(
+ {0.5, 0.377541, 0.731059, 0.182426, 0.880797})));
+}
+
+TEST(uKernels, ApplyActivationToVectorTest) {
+ constexpr int kVectorSize = 5;
+ static float input[kVectorSize] = {0.0, -0.5, 1.0, -1.5, 2.0};
+ std::vector<float> output(kVectorSize);
+ ApplyActivationToVector(input, kVectorSize, kTfLiteActRelu, output.data());
+ EXPECT_THAT(output,
+ ElementsAreArray(ArrayFloatNear({0.0, 0.0, 1.0, 0.0, 2.0})));
+
+ ApplyActivationToVector(input, kVectorSize, kTfLiteActTanh, output.data());
+ EXPECT_THAT(output, ElementsAreArray(ArrayFloatNear(
+ {0.0, -0.462117, 0.761594, -0.905148, 0.964028})));
+}
+
+TEST(uKernels, CopyVectorTest) {
+ constexpr int kVectorSize = 5;
+ static float input[kVectorSize] = {0.0, -0.5, 1.0, -1.5, 2.0};
+ std::vector<float> output(kVectorSize);
+ CopyVector(input, kVectorSize, output.data());
+ EXPECT_THAT(output,
+ ElementsAreArray(ArrayFloatNear({0.0, -0.5, 1.0, -1.5, 2.0})));
+}
+
+TEST(uKernels, Sub1VectorTest) {
+ constexpr int kVectorSize = 5;
+ static float input[kVectorSize] = {0.0, -0.5, 1.0, -1.5, 2.0};
+ std::vector<float> output(kVectorSize);
+ Sub1Vector(input, kVectorSize, output.data());
+ EXPECT_THAT(output,
+ ElementsAreArray(ArrayFloatNear({1.0, 1.5, 0.0, 2.5, -1.0})));
+}
+
+TEST(uKernels, ZeroVectorTest) {
+ constexpr int kVectorSize = 5;
+ std::vector<float> output(kVectorSize);
+ ZeroVector(output.data(), kVectorSize);
+ EXPECT_THAT(output,
+ ElementsAreArray(ArrayFloatNear({0.0, 0.0, 0.0, 0.0, 0.0})));
+}
+
+TEST(uKernels, BatchVectorBatchVectorDotProductTest) {
+ constexpr int kVectorSize = 5;
+ constexpr int kBatch = 2;
+ static float input1[kVectorSize * kBatch] = {0.0, -0.5, 1.0, -1.5, 2.0,
+ -2.5, 3.0, -3.5, 4.0, -4.5};
+ static float input2[kVectorSize * kBatch] = {0.1, -0.1, 0.1, -0.1, 0.1,
+ -0.1, 0.1, -0.1, 0.1, -0.1};
+ std::vector<float> output(kBatch);
+ BatchVectorBatchVectorDotProduct(input1, input2, kVectorSize, kBatch,
+ output.data(), /*result_stride=*/1);
+ EXPECT_THAT(output, ElementsAreArray(ArrayFloatNear({0.5, 1.75})));
+}
+
+TEST(uKernels, VectorShiftLeftTest) {
+ constexpr int kVectorSize = 5;
+ static float input[kVectorSize] = {0.0, -0.5, 1.0, -1.5, 2.0};
+ std::vector<float> result(kVectorSize);
+ VectorShiftLeft(input, kVectorSize, 3.0);
+ result.assign(input, input + kVectorSize);
+ EXPECT_THAT(result,
+ ElementsAreArray(ArrayFloatNear({-0.5, 1.0, -1.5, 2.0, 3.0})));
+}
+
+TEST(uKernels, ReductionSumVectorTest) {
+ constexpr int kInputVectorSize = 10;
+ constexpr int kOutputVectorSize1 = 5;
+ constexpr int kReductionSize1 = 2;
+ static float input[kInputVectorSize] = {0.0, -0.5, 1.0, -1.5, 2.0,
+ 0.0, -0.5, 1.0, 1.0, 2.0};
+ std::vector<float> result1(kOutputVectorSize1);
+ ReductionSumVector(input, result1.data(), kOutputVectorSize1,
+ kReductionSize1);
+ EXPECT_THAT(result1,
+ ElementsAreArray(ArrayFloatNear({-0.5, -0.5, 2.0, 0.5, 3.0})));
+
+ constexpr int kOutputVectorSize2 = 2;
+ constexpr int kReductionSize2 = 5;
+ std::vector<float> result2(kOutputVectorSize2);
+ ReductionSumVector(input, result2.data(), kOutputVectorSize2,
+ kReductionSize2);
+ EXPECT_THAT(result2, ElementsAreArray(ArrayFloatNear({1.0, 3.5})));
+}
+
+} // namespace tensor_utils
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/internal/types.h b/tensorflow/contrib/lite/kernels/internal/types.h
new file mode 100644
index 0000000000..07f1cb4004
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/internal/types.h
@@ -0,0 +1,81 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TYPES_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TYPES_H_
+
+#include "tensorflow/contrib/lite/kernels/internal/compatibility.h"
+
+namespace tflite {
+
+enum class FusedActivationFunctionType : uint8 { kNone, kRelu6, kRelu1, kRelu };
+
+template <int N>
+struct Dims {
+ int sizes[N];
+ int strides[N];
+};
+
+inline int Offset(const Dims<4>& dims, int i0, int i1, int i2, int i3) {
+ TFLITE_DCHECK(i0 >= 0 && i0 < dims.sizes[0]);
+ TFLITE_DCHECK(i1 >= 0 && i1 < dims.sizes[1]);
+ TFLITE_DCHECK(i2 >= 0 && i2 < dims.sizes[2]);
+ TFLITE_DCHECK(i3 >= 0 && i3 < dims.sizes[3]);
+ return i0 * dims.strides[0] + i1 * dims.strides[1] + i2 * dims.strides[2] +
+ i3 * dims.strides[3];
+}
+
+// Get array size, DCHECKing that the dim index is in range.
+template <int N>
+int ArraySize(const Dims<N>& array, int index) {
+ TFLITE_DCHECK(index >= 0 && index < N);
+ return array.sizes[index];
+}
+
+// Get common array size, DCHECKing that they all agree.
+template <typename ArrayType1, typename ArrayType2>
+int MatchingArraySize(const ArrayType1& array1, int index1,
+ const ArrayType2& array2, int index2) {
+ TFLITE_DCHECK_EQ(ArraySize(array1, index1), ArraySize(array2, index2));
+ return ArraySize(array1, index1);
+}
+
+template <typename ArrayType1, typename ArrayType2, typename... Args>
+int MatchingArraySize(const ArrayType1& array1, int index1,
+ const ArrayType2& array2, int index2, Args... args) {
+ TFLITE_DCHECK_EQ(ArraySize(array1, index1), ArraySize(array2, index2));
+ return MatchingArraySize(array1, index1, args...);
+}
+
+inline int RequiredBufferSizeForDims(const Dims<4>& dims) {
+ int max_offset = 0;
+ for (int i = 0; i < 4; i++) {
+ max_offset += (dims.sizes[i] - 1) * dims.strides[i];
+ }
+ return max_offset + 1;
+}
+
+template <int N>
+bool IsPackedWithoutStrides(const Dims<N>& dims) {
+ int expected_stride = 1;
+ for (int d = 0; d < N; d++) {
+ if (dims.strides[d] != expected_stride) return false;
+ expected_stride *= dims.sizes[d];
+ }
+ return true;
+}
+
+} // namespace tflite
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TYPES_H_
diff --git a/tensorflow/contrib/lite/kernels/kernel_util.cc b/tensorflow/contrib/lite/kernels/kernel_util.cc
new file mode 100644
index 0000000000..b0546c00cf
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/kernel_util.cc
@@ -0,0 +1,87 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include <algorithm>
+#include <cmath>
+#include "tensorflow/contrib/lite/kernels/internal/round.h"
+
+namespace tflite {
+
+TfLiteStatus GetQuantizedConvolutionMultipler(
+ TfLiteContext* context, TfLiteTensor* input, TfLiteTensor* filter,
+ TfLiteTensor* bias, TfLiteTensor* output, double* multiplier) {
+ const double input_product_scale = input->params.scale * filter->params.scale;
+ const double bias_scale = bias->params.scale;
+ const double output_scale = output->params.scale;
+
+ // TODO(ahentz): The following conditions must be guaranteed by the training
+ // pipeline.
+ TF_LITE_ENSURE(context, std::abs(input_product_scale - bias_scale) <=
+ 1e-6 * std::min(input_product_scale, bias_scale));
+ TF_LITE_ENSURE(context, input_product_scale >= 0);
+ TF_LITE_ENSURE(context, input_product_scale < output_scale);
+
+ *multiplier = input_product_scale / output_scale;
+
+ return kTfLiteOk;
+}
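
A worked example with hypothetical scales: an input scale of 0.5 and a filter scale of 0.25 give input_product_scale = 0.125, which must closely match the bias scale and stay below the output scale; with an output scale of 0.25 the resulting real multiplier is 0.125 / 0.25 = 0.5.
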
+
+void CalculateActivationRangeUint8(TfLiteFusedActivation activation,
+ TfLiteTensor* output, int32_t* act_min,
+ int32_t* act_max) {
+ const int32_t qmin = std::numeric_limits<uint8_t>::min();
+ const int32_t qmax = std::numeric_limits<uint8_t>::max();
+
+ const auto scale = output->params.scale;
+ const auto zero_point = output->params.zero_point;
+
+ auto quantize = [scale, zero_point](float f) {
+ return zero_point + static_cast<int32_t>(TfLiteRound(f / scale));
+ };
+
+ if (activation == kTfLiteActRelu) {
+ *act_min = std::max(qmin, quantize(0.0));
+ *act_max = qmax;
+ } else if (activation == kTfLiteActRelu6) {
+ *act_min = std::max(qmin, quantize(0.0));
+ *act_max = std::min(qmax, quantize(6.0));
+ } else if (activation == kTfLiteActRelu1) {
+ *act_min = std::max(qmin, quantize(-1.0));
+ *act_max = std::min(qmax, quantize(1.0));
+ } else {
+ *act_min = qmin;
+ *act_max = qmax;
+ }
+}
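
For example, with a hypothetical output quantization of scale = 0.05 and zero_point = 10, a fused Relu6 gives act_min = max(0, 10 + round(0 / 0.05)) = 10 and act_max = min(255, 10 + round(6 / 0.05)) = 130, so the kernel clamps its quantized results to [10, 130].
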
+
+void CalculateActivationRangeFloat(TfLiteFusedActivation activation,
+ float* activation_min,
+ float* activation_max) {
+ if (activation == kTfLiteActRelu) {
+ *activation_min = 0.f;
+ *activation_max = std::numeric_limits<float>::max();
+ } else if (activation == kTfLiteActRelu6) {
+ *activation_min = 0.f;
+ *activation_max = 6.f;
+ } else if (activation == kTfLiteActRelu1) {
+ *activation_min = -1.f;
+ *activation_max = 1.f;
+ } else {
+ *activation_min = std::numeric_limits<float>::lowest();
+ *activation_max = std::numeric_limits<float>::max();
+ }
+}
+
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/kernel_util.h b/tensorflow/contrib/lite/kernels/kernel_util.h
new file mode 100644
index 0000000000..25556ae456
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/kernel_util.h
@@ -0,0 +1,65 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_KERNEL_UTIL_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_KERNEL_UTIL_H_
+
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/context.h"
+
+namespace tflite {
+
+inline int NumDimensions(const TfLiteTensor* t) { return t->dims->size; }
+inline int SizeOfDimension(const TfLiteTensor* t, int dim) {
+ return t->dims->data[dim];
+}
+inline TfLiteTensor* GetInput(TfLiteContext* context, TfLiteNode* node,
+ int index) {
+ return &context->tensors[node->inputs->data[index]];
+}
+inline TfLiteTensor* GetOutput(TfLiteContext* context, TfLiteNode* node,
+ int index) {
+ return &context->tensors[node->outputs->data[index]];
+}
+inline int NumInputs(const TfLiteNode* node) { return node->inputs->size; }
+inline int NumOutputs(const TfLiteNode* node) { return node->outputs->size; }
+
+inline TfLiteTensor* GetOptionalInputTensor(TfLiteContext* context,
+ const TfLiteNode* node, int index) {
+ const bool use_tensor = node->inputs->data[index] != kOptionalTensor;
+ if (use_tensor) {
+ return &context->tensors[node->inputs->data[index]];
+ }
+ return nullptr;
+}
+
+// Calculates the multiplication factor for a quantized convolution (or
+// quantized depthwise convolution) involving the given tensors. Returns an
+// error if the scales of the tensors are not compatible.
+TfLiteStatus GetQuantizedConvolutionMultipler(
+ TfLiteContext* context, TfLiteTensor* input, TfLiteTensor* filter,
+ TfLiteTensor* bias, TfLiteTensor* output, double* multiplier);
+
+// Calculates the useful range of an activation layer given its fused
+// activation function (and, for the quantized version, the output tensor's
+// quantization parameters).
+void CalculateActivationRangeUint8(TfLiteFusedActivation activation,
+ TfLiteTensor* output, int32_t* act_min,
+ int32_t* act_max);
+void CalculateActivationRangeFloat(TfLiteFusedActivation activation,
+ float* activation_min,
+ float* activation_max);
+
+} // namespace tflite
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_KERNEL_UTIL_H_
diff --git a/tensorflow/contrib/lite/kernels/l2norm.cc b/tensorflow/contrib/lite/kernels/l2norm.cc
new file mode 100644
index 0000000000..f43aa372b6
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/l2norm.cc
@@ -0,0 +1,112 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace l2norm {
+
+// This file has two implementations of L2Norm.
+enum KernelType {
+ kReference,
+ kGenericOptimized,
+};
+
+constexpr int kInputTensor = 0;
+constexpr int kOutputTensor = 0;
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ auto* params = reinterpret_cast<TfLiteL2NormParams*>(node->builtin_data);
+
+ TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+
+ TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+ // TODO(ahentz): Our current implementations rely on the inputs being 4D.
+ TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4);
+
+ // TODO(ahentz): Our current implementations only support float32.
+ TF_LITE_ENSURE_EQ(context, output->type, kTfLiteFloat32);
+ TF_LITE_ENSURE_EQ(context, input->type, output->type);
+
+ // TODO(ahentz): For some reason our implementations don't support
+ // activations.
+ TF_LITE_ENSURE_EQ(context, params->activation, kTfLiteActNone);
+
+ TfLiteIntArray* output_size = TfLiteIntArrayCreate(4);
+ output_size->data[0] = input->dims->data[0];
+ output_size->data[1] = input->dims->data[1];
+ output_size->data[2] = input->dims->data[2];
+ output_size->data[3] = input->dims->data[3];
+
+ return context->ResizeTensor(context, output, output_size);
+}
+
+template <KernelType kernel_type>
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+ if (output->type == kTfLiteFloat32) {
+#define TF_LITE_L2NORM(type) \
+ type::L2Normalization<FusedActivationFunctionType::kNone>( \
+ GetTensorData<float>(input), GetTensorDims(input), \
+ GetTensorData<float>(output), GetTensorDims(output))
+
+ if (kernel_type == kReference) {
+ TF_LITE_L2NORM(reference_ops);
+ }
+ if (kernel_type == kGenericOptimized) {
+ TF_LITE_L2NORM(optimized_ops);
+ }
+#undef TF_LITE_L2NORM
+ } else {
+ context->ReportError(context, "Inputs and outputs not all float types.");
+ return kTfLiteError;
+ }
+
+ return kTfLiteOk;
+}
+
+} // namespace l2norm
+
+TfLiteRegistration* Register_L2NORM_REF() {
+ static TfLiteRegistration r = {nullptr, nullptr, l2norm::Prepare,
+ l2norm::Eval<l2norm::kReference>};
+ return &r;
+}
+
+TfLiteRegistration* Register_L2NORM_GENERIC_OPT() {
+ static TfLiteRegistration r = {nullptr, nullptr, l2norm::Prepare,
+ l2norm::Eval<l2norm::kGenericOptimized>};
+ return &r;
+}
+
+TfLiteRegistration* Register_L2_NORMALIZATION() {
+ return Register_L2NORM_GENERIC_OPT();
+}
+
+} // namespace builtin
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/l2norm_test.cc b/tensorflow/contrib/lite/kernels/l2norm_test.cc
new file mode 100644
index 0000000000..b1db89b8bd
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/l2norm_test.cc
@@ -0,0 +1,63 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAreArray;
+
+class L2NormOpModel : public SingleOpModel {
+ public:
+ L2NormOpModel(std::initializer_list<int> input_shape,
+ ActivationFunctionType activation_type) {
+ input_ = AddInput(TensorType_FLOAT32);
+ output_ = AddOutput(TensorType_FLOAT32);
+ SetBuiltinOp(BuiltinOperator_L2_NORMALIZATION, BuiltinOptions_L2NormOptions,
+ CreateL2NormOptions(builder_, activation_type).Union());
+ BuildInterpreter({input_shape});
+ }
+
+ void SetInput(std::initializer_list<float> data) {
+ PopulateTensor(input_, data);
+ }
+
+ std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+
+ private:
+ int input_;
+ int output_;
+};
+
+TEST(L2NormOpTest, SimpleTest) {
+ L2NormOpModel m({1, 1, 1, 6}, ActivationFunctionType_NONE);
+ m.SetInput({-1.1, 0.6, 0.7, 1.2, -0.7, 0.1});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(),
+ ElementsAreArray({-0.55, 0.3, 0.35, 0.6, -0.35, 0.05}));
+}
+
+} // namespace
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ // On Linux, add: tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/kernels/local_response_norm.cc b/tensorflow/contrib/lite/kernels/local_response_norm.cc
new file mode 100644
index 0000000000..c1c70d0dfa
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/local_response_norm.cc
@@ -0,0 +1,109 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace local_response_norm {
+
+// This file has two implementations of LocalResponseNorm.
+enum KernelType {
+ kReference,
+ kGenericOptimized,
+};
+
+constexpr int kInputTensor = 0;
+constexpr int kOutputTensor = 0;
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+
+ TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+ TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4);
+
+ TF_LITE_ENSURE_EQ(context, output->type, kTfLiteFloat32);
+ TF_LITE_ENSURE_EQ(context, input->type, output->type);
+
+ TfLiteIntArray* output_size = TfLiteIntArrayCreate(4);
+ output_size->data[0] = input->dims->data[0];
+ output_size->data[1] = input->dims->data[1];
+ output_size->data[2] = input->dims->data[2];
+ output_size->data[3] = input->dims->data[3];
+
+ return context->ResizeTensor(context, output, output_size);
+}
+
+template <KernelType kernel_type>
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ auto* params =
+ reinterpret_cast<TfLiteLocalResponseNormParams*>(node->builtin_data);
+
+ TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+ if (output->type == kTfLiteFloat32) {
+#define TF_LITE_LOCAL_RESPONSE_NORM(type) \
+ type::LocalResponseNormalization( \
+ GetTensorData<float>(input), GetTensorDims(input), params->radius, \
+ params->bias, params->alpha, params->beta, GetTensorData<float>(output), \
+ GetTensorDims(output))
+ if (kernel_type == kReference) {
+ TF_LITE_LOCAL_RESPONSE_NORM(reference_ops);
+ }
+ if (kernel_type == kGenericOptimized) {
+ TF_LITE_LOCAL_RESPONSE_NORM(optimized_ops);
+ }
+#undef TF_LITE_LOCAL_RESPONSE_NORM
+ } else {
+ context->ReportError(context, "Inputs and outputs not all float types.");
+ return kTfLiteError;
+ }
+
+ return kTfLiteOk;
+}
+
+} // namespace local_response_norm
+
+TfLiteRegistration* Register_LOCAL_RESPONSE_NORM_REF() {
+ static TfLiteRegistration r = {
+ nullptr, nullptr, local_response_norm::Prepare,
+ local_response_norm::Eval<local_response_norm::kReference>};
+ return &r;
+}
+
+TfLiteRegistration* Register_LOCAL_RESPONSE_NORM_GENERIC_OPT() {
+ static TfLiteRegistration r = {
+ nullptr, nullptr, local_response_norm::Prepare,
+ local_response_norm::Eval<local_response_norm::kGenericOptimized>};
+ return &r;
+}
+
+TfLiteRegistration* Register_LOCAL_RESPONSE_NORMALIZATION() {
+ return Register_LOCAL_RESPONSE_NORM_GENERIC_OPT();
+}
+
+} // namespace builtin
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/local_response_norm_test.cc b/tensorflow/contrib/lite/kernels/local_response_norm_test.cc
new file mode 100644
index 0000000000..63a8b0a3d0
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/local_response_norm_test.cc
@@ -0,0 +1,101 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAreArray;
+
+class LocalResponseNormOpModel : public SingleOpModel {
+ public:
+ LocalResponseNormOpModel(std::initializer_list<int> input_shape, int radius,
+ float bias, float alpha, float beta) {
+ input_ = AddInput(TensorType_FLOAT32);
+ output_ = AddOutput(TensorType_FLOAT32);
+ SetBuiltinOp(BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION,
+ BuiltinOptions_LocalResponseNormalizationOptions,
+ CreateLocalResponseNormalizationOptions(builder_, radius, bias,
+ alpha, beta)
+ .Union());
+ BuildInterpreter({input_shape});
+ }
+
+ void SetInput(std::initializer_list<float> data) {
+ PopulateTensor(input_, data);
+ }
+
+ std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+
+ private:
+ int input_;
+ int output_;
+};
+
+TEST(LocalResponseNormOpTest, SameAsL2Norm) {
+ LocalResponseNormOpModel m({1, 1, 1, 6}, /*radius=*/20, /*bias=*/0.0,
+ /*alpha=*/1.0, /*beta=*/0.5);
+ m.SetInput({-1.1, 0.6, 0.7, 1.2, -0.7, 0.1});
+ m.Invoke();
+ // The result is every input divided by 2.
+ EXPECT_THAT(
+ m.GetOutput(),
+ ElementsAreArray(ArrayFloatNear({-0.55, 0.3, 0.35, 0.6, -0.35, 0.05})));
+}
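
Worked check of the expected values above and below: the sum of squares of this input is 1.21 + 0.36 + 0.49 + 1.44 + 0.49 + 0.01 = 4.0, and with radius 20 every channel sees the whole window, so the divisor is (bias + alpha * 4.0)^beta: 2 here, 4 in WithAlpha, and 5 in WithBias.
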
+
+TEST(LocalResponseNormOpTest, WithAlpha) {
+ LocalResponseNormOpModel m({1, 1, 1, 6}, /*radius=*/20, /*bias=*/0.0,
+ /*alpha=*/4.0, /*beta=*/0.5);
+ m.SetInput({-1.1, 0.6, 0.7, 1.2, -0.7, 0.1});
+ m.Invoke();
+  // The result is every input divided by 4.
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear(
+ {-0.275, 0.15, 0.175, 0.3, -0.175, 0.025})));
+}
+
+TEST(LocalResponseNormOpTest, WithBias) {
+ LocalResponseNormOpModel m({1, 1, 1, 6}, /*radius=*/20, /*bias=*/9.0,
+ /*alpha=*/4.0, /*beta=*/0.5);
+ m.SetInput({-1.1, 0.6, 0.7, 1.2, -0.7, 0.1});
+ m.Invoke();
+ // The result is every input divided by 5.
+ EXPECT_THAT(
+ m.GetOutput(),
+ ElementsAreArray(ArrayFloatNear({-0.22, 0.12, 0.14, 0.24, -0.14, 0.02})));
+}
+
+TEST(LocalResponseNormOpTest, SmallRadius) {
+ LocalResponseNormOpModel m({1, 1, 1, 6}, /*radius=*/2, /*bias=*/9.0,
+ /*alpha=*/4.0, /*beta=*/0.5);
+ m.SetInput({-1.1, 0.6, 0.7, 1.2, -0.7, 0.1});
+ m.Invoke();
+ EXPECT_THAT(
+ m.GetOutput(),
+ ElementsAreArray(ArrayFloatNear(
+ {-0.264926, 0.125109, 0.140112, 0.267261, -0.161788, 0.0244266})));
+}
+
+} // namespace
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ // On Linux, add: tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/kernels/lsh_projection.cc b/tensorflow/contrib/lite/kernels/lsh_projection.cc
new file mode 100644
index 0000000000..5f73b56ed9
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/lsh_projection.cc
@@ -0,0 +1,204 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// LSH Projection projects an input to a bit vector via locality sensitive
+// hashing.
+//
+// Options:
+// Sparse:
+// Computed bit vector is considered to be sparse.
+// Each output element is an int32 made up by multiple bits computed from
+// hash functions.
+//
+// Dense:
+// Computed bit vector is considered to be dense. Each output element is
+// either 0 or 1 that represents a bit.
+//
+// Input:
+// Tensor[0]: Hash functions. Dim.size == 2, DataType: Float.
+// Tensor[0].Dim[0]: Num of hash functions.
+// Tensor[0].Dim[1]: Num of projected output bits generated by
+// each hash function.
+// In the sparse case, Tensor[0].Dim[1] + ceil(log2(Tensor[0].Dim[0])) <= 32.
+//
+// Tensor[1]: Input. Dim.size >= 1, No restriction on DataType.
+// Tensor[2]: Optional, Weight. Dim.size == 1, DataType: Float.
+// If not set, each element of input is considered to have the same
+// weight of 1.0. Requires Tensor[1].Dim[0] == Tensor[2].Dim[0].
+//
+// Output:
+// Sparse:
+// Output.Dim == { Tensor[0].Dim[0] }
+// A tensor of int32 that represents hash signatures.
+//
+// NOTE: To avoid collisions across hash functions, an offset value of
+// k * (1 << Tensor[0].Dim[1]) will be added to each signature, where
+// k is the index of the hash function.
+// Dense:
+// Output.Dim == { Tensor[0].Dim[0] * Tensor[0].Dim[1] }
+// A flattened tensor represents projected bit vectors.
+
+#include <unistd.h>
+#include <cassert>
+#include <cmath>
+#include <cstdio>
+#include <cstdlib>
+#include <cstring>
+#include <iostream>
+#include <limits>
+#include <memory>
+
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+#include <farmhash.h>
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace lsh_projection {
+
+TfLiteStatus Resize(TfLiteContext* context, TfLiteNode* node) {
+ auto* params =
+ reinterpret_cast<TfLiteLSHProjectionParams*>(node->builtin_data);
+ TF_LITE_ENSURE(context, NumInputs(node) == 2 || NumInputs(node) == 3);
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+
+ TfLiteTensor* hash = GetInput(context, node, 0);
+ TF_LITE_ENSURE_EQ(context, NumDimensions(hash), 2);
+ // Support up to 32 bits.
+ TF_LITE_ENSURE(context, SizeOfDimension(hash, 1) <= 32);
+
+ TfLiteTensor* input = GetInput(context, node, 1);
+ TF_LITE_ENSURE(context, NumDimensions(input) >= 1);
+
+ if (NumInputs(node) == 3) {
+ TfLiteTensor* weight = GetInput(context, node, 2);
+ TF_LITE_ENSURE_EQ(context, NumDimensions(weight), 1);
+ TF_LITE_ENSURE_EQ(context, SizeOfDimension(weight, 0),
+ SizeOfDimension(input, 0));
+ }
+
+ TfLiteTensor* output = GetOutput(context, node, 0);
+ TfLiteIntArray* outputSize = TfLiteIntArrayCreate(1);
+ switch (params->type) {
+ case kTfLiteLshProjectionSparse:
+ outputSize->data[0] = SizeOfDimension(hash, 0);
+ break;
+ case kTfLiteLshProjectionDense:
+ outputSize->data[0] = SizeOfDimension(hash, 0) * SizeOfDimension(hash, 1);
+ break;
+ default:
+ return kTfLiteError;
+ }
+ return context->ResizeTensor(context, output, outputSize);
+}
+
+// Compute the sign bit of the dot product of hash(seed, input) and weight.
+// NOTE: a float seed is used and converted to double as a temporary solution
+// to match the trained model. This is going to change once the new
+// model is trained with an optimized method.
+//
+int RunningSignBit(const TfLiteTensor* input, const TfLiteTensor* weight,
+ float seed) {
+ double score = 0.0;
+ int input_item_bytes = input->bytes / SizeOfDimension(input, 0);
+ char* input_ptr = input->data.raw;
+
+ const size_t seed_size = sizeof(float);
+ const size_t key_bytes = sizeof(float) + input_item_bytes;
+ std::unique_ptr<char[]> key(new char[key_bytes]);
+
+ for (int i = 0; i < SizeOfDimension(input, 0); ++i) {
+ // Create running hash id and value for current dimension.
+ memcpy(key.get(), &seed, seed_size);
+ memcpy(key.get() + seed_size, input_ptr, input_item_bytes);
+
+ int64_t hash_signature = ::util::Fingerprint64(key.get(), key_bytes);
+ double running_value = static_cast<double>(hash_signature);
+ input_ptr += input_item_bytes;
+ if (weight == nullptr) {
+ score += running_value;
+ } else {
+ score += weight->data.f[i] * running_value;
+ }
+ }
+
+ return (score > 0) ? 1 : 0;
+}
+
+void SparseLshProjection(const TfLiteTensor* hash, const TfLiteTensor* input,
+ const TfLiteTensor* weight, int32_t* out_buf) {
+ int num_hash = SizeOfDimension(hash, 0);
+ int num_bits = SizeOfDimension(hash, 1);
+ for (int i = 0; i < num_hash; i++) {
+ int32_t hash_signature = 0;
+ for (int j = 0; j < num_bits; j++) {
+ float seed = hash->data.f[i * num_bits + j];
+ int bit = RunningSignBit(input, weight, seed);
+ hash_signature = (hash_signature << 1) | bit;
+ }
+ *out_buf++ = hash_signature + i * (1 << num_bits);
+ }
+}
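
A worked example of the offset term in SparseLshProjection: with num_bits = 2 bits per hash function, hash function i = 1 whose bits evaluate to 0b01 = 1 emits 1 * (1 << 2) + 1 = 5, the same '4 + 1' arithmetic spelled out in the Sparse1DInputs test further below.
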
+
+void DenseLshProjection(const TfLiteTensor* hash, const TfLiteTensor* input,
+ const TfLiteTensor* weight, int32_t* out_buf) {
+ int num_hash = SizeOfDimension(hash, 0);
+ int num_bits = SizeOfDimension(hash, 1);
+ for (int i = 0; i < num_hash; i++) {
+ for (int j = 0; j < num_bits; j++) {
+ float seed = hash->data.f[i * num_bits + j];
+ int bit = RunningSignBit(input, weight, seed);
+ *out_buf++ = bit;
+ }
+ }
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ auto* params =
+ reinterpret_cast<TfLiteLSHProjectionParams*>(node->builtin_data);
+
+ int32_t* out_buf = GetOutput(context, node, 0)->data.i32;
+ TfLiteTensor* hash = GetInput(context, node, 0);
+ TfLiteTensor* input = GetInput(context, node, 1);
+ TfLiteTensor* weight =
+ NumInputs(node) == 2 ? nullptr : GetInput(context, node, 2);
+
+ switch (params->type) {
+ case kTfLiteLshProjectionDense:
+ DenseLshProjection(hash, input, weight, out_buf);
+ break;
+ case kTfLiteLshProjectionSparse:
+ SparseLshProjection(hash, input, weight, out_buf);
+ break;
+ default:
+ return kTfLiteError;
+ }
+
+ return kTfLiteOk;
+}
+} // namespace lsh_projection
+
+TfLiteRegistration* Register_LSH_PROJECTION() {
+ static TfLiteRegistration r = {nullptr, nullptr, lsh_projection::Resize,
+ lsh_projection::Eval};
+ return &r;
+}
+
+} // namespace builtin
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/lsh_projection_test.cc b/tensorflow/contrib/lite/kernels/lsh_projection_test.cc
new file mode 100644
index 0000000000..1011927848
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/lsh_projection_test.cc
@@ -0,0 +1,123 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <vector>
+
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAre;
+
+class LSHProjectionOpModel : public SingleOpModel {
+ public:
+ LSHProjectionOpModel(LSHProjectionType type,
+ std::initializer_list<int> hash_shape,
+ std::initializer_list<int> input_shape,
+ std::initializer_list<int> weight_shape) {
+ hash_ = AddInput(TensorType_FLOAT32);
+ input_ = AddInput(TensorType_INT32);
+ if (weight_shape.size() > 0) {
+ weight_ = AddInput(TensorType_FLOAT32);
+ }
+ output_ = AddOutput(TensorType_INT32);
+
+ SetBuiltinOp(BuiltinOperator_LSH_PROJECTION,
+ BuiltinOptions_LSHProjectionOptions,
+ CreateLSHProjectionOptions(builder_, type).Union());
+ if (weight_shape.size() > 0) {
+ BuildInterpreter({hash_shape, input_shape, weight_shape});
+ } else {
+ BuildInterpreter({hash_shape, input_shape});
+ }
+
+ output_size_ = 1;
+ for (int i : hash_shape) {
+ output_size_ *= i;
+ if (type == LSHProjectionType_SPARSE) {
+ break;
+ }
+ }
+ }
+ void SetInput(std::initializer_list<int> data) {
+ PopulateTensor(input_, data);
+ }
+
+ void SetHash(std::initializer_list<float> data) {
+ PopulateTensor(hash_, data);
+ }
+
+ void SetWeight(std::initializer_list<float> f) { PopulateTensor(weight_, f); }
+
+ std::vector<int> GetOutput() { return ExtractVector<int>(output_); }
+
+ private:
+ int input_;
+ int hash_;
+ int weight_;
+ int output_;
+
+ int output_size_;
+};
+
+TEST(LSHProjectionOpTest2, Dense1DInputs) {
+ LSHProjectionOpModel m(LSHProjectionType_DENSE, {3, 2}, {5}, {5});
+
+ m.SetInput({12345, 54321, 67890, 9876, -12345678});
+ m.SetHash({0.123, 0.456, -0.321, 1.234, 5.678, -4.321});
+ m.SetWeight({1.0, 1.0, 1.0, 1.0, 1.0});
+
+ m.Invoke();
+
+ EXPECT_THAT(m.GetOutput(), ElementsAre(0, 0, 0, 1, 0, 0));
+}
+
+TEST(LSHProjectionOpTest2, Sparse1DInputs) {
+ LSHProjectionOpModel m(LSHProjectionType_SPARSE, {3, 2}, {5}, {});
+
+ m.SetInput({12345, 54321, 67890, 9876, -12345678});
+ m.SetHash({0.123, 0.456, -0.321, 1.234, 5.678, -4.321});
+
+ m.Invoke();
+
+ EXPECT_THAT(m.GetOutput(), ElementsAre(0 + 0, 4 + 1, 8 + 0));
+}
+
+TEST(LSHProjectionOpTest2, Sparse3DInputs) {
+ LSHProjectionOpModel m(LSHProjectionType_SPARSE, {3, 2}, {5, 2, 2}, {5});
+
+ m.SetInput({1234, 2345, 3456, 1234, 4567, 5678, 6789, 4567, 7891, 8912,
+ 9123, 7890, -987, -876, -765, -987, -543, -432, -321, -543});
+ m.SetHash({0.123, 0.456, -0.321, 1.234, 5.678, -4.321});
+ m.SetWeight({0.12, 0.34, 0.56, 0.67, 0.78});
+
+ m.Invoke();
+
+ EXPECT_THAT(m.GetOutput(), ElementsAre(0 + 2, 4 + 1, 8 + 1));
+}
+
+} // namespace
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ // On Linux, add: tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/kernels/lstm.cc b/tensorflow/contrib/lite/kernels/lstm.cc
new file mode 100644
index 0000000000..6c06264d84
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/lstm.cc
@@ -0,0 +1,515 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <unistd.h>
+#include <cassert>
+#include <cmath>
+#include <cstdio>
+#include <cstdlib>
+#include <iostream>
+#include <limits>
+
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/activation_functor.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace lstm {
+
+// Input Tensors of size {n_batch, n_input}
+constexpr int kInputTensor = 0;
+
+// Input weight tensors of size: {n_cell, n_input}
+constexpr int kInputToInputWeightsTensor = 1; // Optional
+constexpr int kInputToForgetWeightsTensor = 2;
+constexpr int kInputToCellWeightsTensor = 3;
+constexpr int kInputToOutputWeightsTensor = 4;
+
+// Recurrent weight tensors of size {n_cell, n_output}
+constexpr int kRecurrentToInputWeightsTensor = 5; // Optional
+constexpr int kRecurrentToForgetWeightsTensor = 6;
+constexpr int kRecurrentToCellWeightsTensor = 7;
+constexpr int kRecurrentToOutputWeightsTensor = 8;
+
+// Peephole weights tensors of size {n_cell}, representing a diagonal matrix.
+constexpr int kCellToInputWeightsTensor = 9; // Optional
+constexpr int kCellToForgetWeightsTensor = 10; // Optional
+constexpr int kCellToOutputWeightsTensor = 11; // Optional
+
+// Gates bias tensors of size {n_cell}
+constexpr int kInputGateBiasTensor = 12; // Optional
+constexpr int kForgetGateBiasTensor = 13;
+constexpr int kCellGateBiasTensor = 14;
+constexpr int kOutputGateBiasTensor = 15;
+
+// Projection weight tensor of size {n_output, n_cell}
+constexpr int kProjectionWeightsTensor = 16; // Optional
+// Projection bias tensor of size {n_output}
+constexpr int kProjectionBiasTensor = 17; // Optional
+
+// Output tensors.
+constexpr int kScratchBufferTensor = 0;
+constexpr int kOutputStateTensor = 1;
+constexpr int kCellStateTensor = 2;
+constexpr int kOutputTensor = 3;
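+
+// Summary of the computation performed by Eval() below: one step of a
+// standard LSTM cell, optionally CIFG-coupled, with peephole connections and
+// an output projection. Informally, per batch:
+//   i = sigmoid(W_xi x + W_hi h_prev + w_ci .* c_prev + b_i)  (absent with CIFG)
+//   f = sigmoid(W_xf x + W_hf h_prev + w_cf .* c_prev + b_f)
+//   g = act(W_xc x + W_hc h_prev + b_c)
+//   c = f .* c_prev + i .* g   (CIFG uses i = 1 - f; clipped if cell_clip > 0)
+//   o = sigmoid(W_xo x + W_ho h_prev + w_co .* c + b_o)
+//   h = o .* act(c); with projection: h = clip(W_proj h + b_proj, proj_clip)
+// where act() is the configured activation and .* is the elementwise product.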
+
+// Check that the input tensor dimensions are consistent with each other.
+TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
+ TfLiteNode* node, int n_input,
+ int n_output, int n_cell) {
+ auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);
+
+ // Making sure clipping parameters have valid values.
+ // == 0 means no clipping
+ // > 0 means clipping
+ TF_LITE_ENSURE(context, params->cell_clip >= 0);
+ TF_LITE_ENSURE(context, params->proj_clip >= 0);
+
+ TfLiteTensor* input_to_input_weights =
+ GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
+ if (input_to_input_weights) {
+ TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->size, 2);
+ TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[0], n_cell);
+ TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[1], n_input);
+ }
+
+ TfLiteTensor* input_to_forget_weights =
+ GetInput(context, node, kInputToForgetWeightsTensor);
+ TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->size, 2);
+ TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[0], n_cell);
+ TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[1], n_input);
+
+ TfLiteTensor* input_to_cell_weights =
+ GetInput(context, node, kInputToCellWeightsTensor);
+ TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->size, 2);
+ TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[0], n_cell);
+ TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[1], n_input);
+
+ TfLiteTensor* recurrent_to_input_weights =
+ GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor);
+ if (recurrent_to_input_weights) {
+ TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->size, 2);
+ TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->data[0],
+ n_cell);
+ TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->data[1],
+ n_output);
+ }
+
+ TfLiteTensor* recurrent_to_forget_weights =
+ GetInput(context, node, kRecurrentToForgetWeightsTensor);
+ TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->size, 2);
+ TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[0],
+ n_cell);
+ TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[1],
+ n_output);
+
+ TfLiteTensor* recurrent_to_cell_weights =
+ GetInput(context, node, kRecurrentToCellWeightsTensor);
+ TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->size, 2);
+ TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[0], n_cell);
+ TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[1],
+ n_output);
+
+  // Make sure the input-gate parameters are either both present (regular
+  // LSTM) or both absent (CIFG-LSTM).
+ const bool cifg_weights_all_or_none =
+ ((input_to_input_weights != nullptr) &&
+ (recurrent_to_input_weights != nullptr)) ||
+ ((input_to_input_weights == nullptr) &&
+ (recurrent_to_input_weights == nullptr));
+ TF_LITE_ENSURE(context, cifg_weights_all_or_none == true);
+
+ TfLiteTensor* cell_to_input_weights =
+ GetOptionalInputTensor(context, node, kCellToInputWeightsTensor);
+ if (cell_to_input_weights) {
+ TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->size, 1);
+ TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->data[0], n_cell);
+ }
+
+ TfLiteTensor* cell_to_forget_weights =
+ GetOptionalInputTensor(context, node, kCellToForgetWeightsTensor);
+ if (cell_to_forget_weights) {
+ TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->size, 1);
+ TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->data[0], n_cell);
+ }
+
+ TfLiteTensor* cell_to_output_weights =
+ GetOptionalInputTensor(context, node, kCellToOutputWeightsTensor);
+ if (cell_to_output_weights) {
+ TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->size, 1);
+ TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->data[0], n_cell);
+ }
+
+  // Make sure the peephole weights are either all present or all absent.
+ const bool use_cifg = (input_to_input_weights == nullptr);
+ const bool peephole_weights_all_or_none =
+ ((cell_to_input_weights != nullptr || use_cifg) &&
+ (cell_to_forget_weights != nullptr) &&
+ (cell_to_output_weights != nullptr)) ||
+ ((cell_to_input_weights == nullptr) &&
+ (cell_to_forget_weights == nullptr) &&
+ (cell_to_output_weights == nullptr));
+ TF_LITE_ENSURE(context, peephole_weights_all_or_none == true);
+
+ // Make sure the input gate bias is present only when not a CIFG-LSTM.
+ TfLiteTensor* input_gate_bias =
+ GetOptionalInputTensor(context, node, kInputGateBiasTensor);
+ if (use_cifg) {
+ TF_LITE_ENSURE_EQ(context, input_gate_bias, nullptr);
+ } else {
+ TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->size, 1);
+ TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->data[0], n_cell);
+ }
+
+ TfLiteTensor* forget_gate_bias =
+ GetInput(context, node, kForgetGateBiasTensor);
+ TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->size, 1);
+ TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->data[0], n_cell);
+
+ TfLiteTensor* cell_bias = GetInput(context, node, kCellGateBiasTensor);
+ TF_LITE_ENSURE_EQ(context, cell_bias->dims->size, 1);
+ TF_LITE_ENSURE_EQ(context, cell_bias->dims->data[0], n_cell);
+
+ TfLiteTensor* output_gate_bias =
+ GetInput(context, node, kOutputGateBiasTensor);
+ TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->size, 1);
+ TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->data[0], n_cell);
+
+ TfLiteTensor* projection_weights =
+ GetOptionalInputTensor(context, node, kProjectionWeightsTensor);
+ if (projection_weights) {
+ TF_LITE_ENSURE_EQ(context, projection_weights->dims->size, 2);
+ TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[0], n_output);
+ TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[1], n_cell);
+ }
+
+ TfLiteTensor* projection_bias =
+ GetOptionalInputTensor(context, node, kProjectionBiasTensor);
+ if (projection_bias) {
+ TF_LITE_ENSURE_EQ(context, projection_bias->dims->size, 1);
+ TF_LITE_ENSURE_EQ(context, projection_bias->dims->data[0], n_output);
+ }
+
+ // Making sure the projection tensors are consistent:
+ // 1) If projection weight is not present, then projection bias should not be
+ // present.
+ // 2) If projection weight is present, then projection bias is optional.
+ // TODO(ghodrat): make sure this is correct.
+  const bool projection_tensors_consistent =
+      ((projection_weights != nullptr) || (projection_bias == nullptr));
+  TF_LITE_ENSURE(context, projection_tensors_consistent == true);
+
+ return kTfLiteOk;
+}
+
+// Resize the output, state and scratch tensors based on the sizes of the input
+// tensors. Also check that the sizes of the input tensors are consistent with
+// each other.
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ // Check we have all the inputs and outputs we need.
+ TF_LITE_ENSURE_EQ(context, node->inputs->size, 18);
+ TF_LITE_ENSURE_EQ(context, node->outputs->size, 4);
+
+ // Inferring batch size, number of outputs and number of cells from the
+ // input tensors.
+ TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ TF_LITE_ENSURE(context, input->dims->size > 1);
+ const int n_batch = input->dims->data[0];
+ const int n_input = input->dims->data[1];
+
+ TfLiteTensor* input_to_output_weights =
+ GetInput(context, node, kInputToOutputWeightsTensor);
+ const int n_cell = input_to_output_weights->dims->data[0];
+ TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->size, 2);
+ TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->data[1], n_input);
+
+ TfLiteTensor* recurrent_to_output_weights =
+ GetInput(context, node, kRecurrentToOutputWeightsTensor);
+ TF_LITE_ENSURE_EQ(context, recurrent_to_output_weights->dims->size, 2);
+ TF_LITE_ENSURE_EQ(context, recurrent_to_output_weights->dims->data[0],
+ n_cell);
+ const int n_output = recurrent_to_output_weights->dims->data[1];
+
+  // Check that the input tensor dimensions are consistent with each other.
+  TF_LITE_ENSURE_OK(
+      context, CheckInputTensorDimensions(context, node, n_input, n_output,
+                                          n_cell));
+
+ // Get the pointer to output, state and scratch buffer tensors.
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+ TfLiteTensor* output_state = GetOutput(context, node, kOutputStateTensor);
+ TfLiteTensor* cell_state = GetOutput(context, node, kCellStateTensor);
+ // TODO(ghodrat): Modify this as soon as we have a finalized method for
+ // scratch buffers.
+ TfLiteTensor* scratch_buffer = GetOutput(context, node, kScratchBufferTensor);
+
+ // Resize the output and output_state tensors.
+ TfLiteIntArray* output_size = TfLiteIntArrayCreate(2);
+ output_size->data[0] = n_batch;
+ output_size->data[1] = n_output;
+ TF_LITE_ENSURE_OK(context,
+ context->ResizeTensor(context, output, output_size));
+
+ TfLiteIntArray* output_state_size = TfLiteIntArrayCreate(2);
+ output_state_size->data[0] = n_batch;
+ output_state_size->data[1] = n_output;
+ TF_LITE_ENSURE_OK(
+ context, context->ResizeTensor(context, output_state, output_state_size));
+
+  // Resize the cell state tensor.
+ TfLiteIntArray* cell_size = TfLiteIntArrayCreate(2);
+ cell_size->data[0] = n_batch;
+ cell_size->data[1] = n_cell;
+ TF_LITE_ENSURE_OK(context,
+ context->ResizeTensor(context, cell_state, cell_size));
+
+ // Mark state tensors as persistent tensors.
+ output_state->allocation_type = kTfLiteArenaRwPersistent;
+ cell_state->allocation_type = kTfLiteArenaRwPersistent;
+
+ TfLiteTensor* input_to_input_weights =
+ GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
+ const bool use_cifg = (input_to_input_weights == nullptr);
+ if (use_cifg) {
+ TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2);
+ scratch_buffer_size->data[0] = n_batch;
+ // Reserving space for Cell, Forget, Output gates
+ scratch_buffer_size->data[1] = n_cell * 3;
+ TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_buffer,
+ scratch_buffer_size));
+ } else {
+ TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2);
+ scratch_buffer_size->data[0] = n_batch;
+ // Reserving space for Input, Cell, Forget, Output gates
+ scratch_buffer_size->data[1] = n_cell * 4;
+ TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_buffer,
+ scratch_buffer_size));
+ }
+ return kTfLiteOk;
+}
+
+// The LSTM Op engine.
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);
+ TfLiteTensor* input = GetInput(context, node, kInputTensor);
+
+ TfLiteTensor* input_to_input_weights =
+ GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
+ TfLiteTensor* input_to_forget_weights =
+ GetInput(context, node, kInputToForgetWeightsTensor);
+ TfLiteTensor* input_to_cell_weights =
+ GetInput(context, node, kInputToCellWeightsTensor);
+ TfLiteTensor* input_to_output_weights =
+ GetInput(context, node, kInputToOutputWeightsTensor);
+
+ TfLiteTensor* recurrent_to_input_weights =
+ GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor);
+ TfLiteTensor* recurrent_to_forget_weights =
+ GetInput(context, node, kRecurrentToForgetWeightsTensor);
+ TfLiteTensor* recurrent_to_cell_weights =
+ GetInput(context, node, kRecurrentToCellWeightsTensor);
+ TfLiteTensor* recurrent_to_output_weights =
+ GetInput(context, node, kRecurrentToOutputWeightsTensor);
+
+ TfLiteTensor* cell_to_input_weights =
+ GetOptionalInputTensor(context, node, kCellToInputWeightsTensor);
+ TfLiteTensor* cell_to_forget_weights =
+ GetOptionalInputTensor(context, node, kCellToForgetWeightsTensor);
+ TfLiteTensor* cell_to_output_weights =
+ GetOptionalInputTensor(context, node, kCellToOutputWeightsTensor);
+
+ TfLiteTensor* input_gate_bias =
+ GetOptionalInputTensor(context, node, kInputGateBiasTensor);
+ TfLiteTensor* forget_gate_bias =
+ GetInput(context, node, kForgetGateBiasTensor);
+ TfLiteTensor* cell_bias = GetInput(context, node, kCellGateBiasTensor);
+ TfLiteTensor* output_gate_bias =
+ GetInput(context, node, kOutputGateBiasTensor);
+
+ TfLiteTensor* projection_weights =
+ GetOptionalInputTensor(context, node, kProjectionWeightsTensor);
+ TfLiteTensor* projection_bias =
+ GetOptionalInputTensor(context, node, kProjectionBiasTensor);
+
+ TfLiteTensor* output_state = GetOutput(context, node, kOutputStateTensor);
+ TfLiteTensor* cell_state = GetOutput(context, node, kCellStateTensor);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+ const int n_batch = input->dims->data[0];
+ const int n_input = input->dims->data[1];
+ // n_cell and n_output will be the same size when there is no projection.
+ const int n_cell = input_to_output_weights->dims->data[0];
+ const int n_output = recurrent_to_output_weights->dims->data[1];
+
+  // Since we have already checked that the weights are either all present or
+  // all absent, checking the existence of one of them is enough to determine
+  // the mode.
+ const bool use_cifg = (input_to_input_weights == nullptr);
+ const bool use_peephole = (cell_to_output_weights != nullptr);
+
+  // Point the per-gate scratch buffers into the shared scratch tensor. The
+  // layout is [input | cell | forget | output] gate buffers, each of size
+  // n_cell * n_batch; the input-gate buffer is omitted when using CIFG.
+ TfLiteTensor* scratch_buffer = GetOutput(context, node, kScratchBufferTensor);
+ float* input_gate_scratch = nullptr;
+ float* cell_scratch = nullptr;
+ float* forget_gate_scratch = nullptr;
+ float* output_gate_scratch = nullptr;
+ if (use_cifg) {
+ cell_scratch = scratch_buffer->data.f;
+ forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch;
+ output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
+ } else {
+ input_gate_scratch = scratch_buffer->data.f;
+ cell_scratch = scratch_buffer->data.f + n_cell * n_batch;
+ forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
+ output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch;
+ }
+
+ // Initialize scratch buffers with bias.
+ if (!use_cifg) {
+ tensor_utils::VectorBatchVectorAssign(input_gate_bias->data.f, n_cell,
+ n_batch, input_gate_scratch);
+ }
+ tensor_utils::VectorBatchVectorAssign(forget_gate_bias->data.f, n_cell,
+ n_batch, forget_gate_scratch);
+ tensor_utils::VectorBatchVectorAssign(cell_bias->data.f, n_cell, n_batch,
+ cell_scratch);
+ tensor_utils::VectorBatchVectorAssign(output_gate_bias->data.f, n_cell,
+ n_batch, output_gate_scratch);
+
+ // For each batch and cell: compute input_weight * input.
+ if (!use_cifg) {
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ input_to_input_weights->data.f, n_cell, n_input, input->data.f, n_batch,
+ input_gate_scratch, /*result_stride=*/1);
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ input_to_forget_weights->data.f, n_cell, n_input, input->data.f, n_batch,
+ forget_gate_scratch, /*result_stride=*/1);
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ input_to_cell_weights->data.f, n_cell, n_input, input->data.f, n_batch,
+ cell_scratch, /*result_stride=*/1);
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ input_to_output_weights->data.f, n_cell, n_input, input->data.f, n_batch,
+ output_gate_scratch, /*result_stride=*/1);
+
+ // For each batch and cell: compute recurrent_weight * output_state.
+ if (!use_cifg) {
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ recurrent_to_input_weights->data.f, n_cell, n_output,
+ output_state->data.f, n_batch, input_gate_scratch, /*result_stride=*/1);
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ recurrent_to_forget_weights->data.f, n_cell, n_output,
+ output_state->data.f, n_batch, forget_gate_scratch, /*result_stride=*/1);
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ recurrent_to_cell_weights->data.f, n_cell, n_output, output_state->data.f,
+ n_batch, cell_scratch, /*result_stride=*/1);
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ recurrent_to_output_weights->data.f, n_cell, n_output,
+ output_state->data.f, n_batch, output_gate_scratch, /*result_stride=*/1);
+
+ // For each batch and cell: update input gate.
+ if (!use_cifg) {
+ if (use_peephole) {
+ tensor_utils::VectorBatchVectorCwiseProductAccumulate(
+ cell_to_input_weights->data.f, n_cell, cell_state->data.f, n_batch,
+ input_gate_scratch);
+ }
+ tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch,
+ input_gate_scratch);
+ }
+
+ // For each batch and cell: update forget gate.
+ if (use_peephole) {
+ tensor_utils::VectorBatchVectorCwiseProductAccumulate(
+ cell_to_forget_weights->data.f, n_cell, cell_state->data.f, n_batch,
+ forget_gate_scratch);
+ }
+ tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch,
+ forget_gate_scratch);
+
+ // For each batch and cell: update the cell.
+ tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch,
+ cell_state->data.f, n_batch * n_cell,
+ cell_state->data.f);
+ tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell,
+ params->activation, cell_scratch);
+ if (use_cifg) {
+ tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell,
+ forget_gate_scratch);
+ tensor_utils::VectorVectorCwiseProductAccumulate(
+ cell_scratch, forget_gate_scratch, n_batch * n_cell,
+ cell_state->data.f);
+ } else {
+ tensor_utils::VectorVectorCwiseProductAccumulate(
+ cell_scratch, input_gate_scratch, n_batch * n_cell, cell_state->data.f);
+ }
+ if (params->cell_clip > 0.0) {
+ tensor_utils::ClipVector(cell_state->data.f, n_batch * n_cell,
+ params->cell_clip, cell_state->data.f);
+ }
+
+ // For each batch and cell: update the output gate.
+ if (use_peephole) {
+ tensor_utils::VectorBatchVectorCwiseProductAccumulate(
+ cell_to_output_weights->data.f, n_cell, cell_state->data.f, n_batch,
+ output_gate_scratch);
+ }
+ tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell,
+ output_gate_scratch);
+ tensor_utils::ApplyActivationToVector(cell_state->data.f, n_batch * n_cell,
+ params->activation, cell_scratch);
+ tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch,
+ n_batch * n_cell, output_gate_scratch);
+
+ // For each batch: update the projection and output_state.
+ const bool use_projection_weight = (projection_weights != nullptr);
+ const bool use_projection_bias = (projection_bias != nullptr);
+ if (use_projection_weight) {
+ if (use_projection_bias) {
+ tensor_utils::VectorBatchVectorAssign(projection_bias->data.f, n_output,
+ n_batch, output->data.f);
+ } else {
+ tensor_utils::ZeroVector(output->data.f, n_batch * n_output);
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ projection_weights->data.f, n_output, n_cell, output_gate_scratch,
+ n_batch, output->data.f, /*result_stride=*/1);
+ if (params->proj_clip > 0.0) {
+ tensor_utils::ClipVector(output->data.f, n_batch * n_output,
+ params->proj_clip, output->data.f);
+ }
+ } else {
+ tensor_utils::CopyVector(output_gate_scratch, n_batch * n_output,
+ output->data.f);
+ }
+ tensor_utils::CopyVector(output->data.f, n_batch * n_output,
+ output_state->data.f);
+
+ return kTfLiteOk;
+}
+
+} // namespace lstm
+
+TfLiteRegistration* Register_LSTM() {
+ static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
+ lstm::Prepare, lstm::Eval};
+ return &r;
+}
+
+} // namespace builtin
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/lstm_test.cc b/tensorflow/contrib/lite/kernels/lstm_test.cc
new file mode 100644
index 0000000000..be4c7ddbf8
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/lstm_test.cc
@@ -0,0 +1,1088 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// Unit test for TFLite LSTM op.
+
+#include <cstring>
+#include <iomanip>
+#include <memory>
+#include <vector>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAreArray;
+
+class LSTMOpModel : public SingleOpModel {
+ public:
+ LSTMOpModel(int n_batch, int n_input, int n_cell, int n_output, bool use_cifg,
+ bool use_peephole, bool use_projection_weights,
+ bool use_projection_bias, float cell_clip, float proj_clip,
+ const std::vector<std::vector<int>>& input_shapes)
+ : n_batch_(n_batch),
+ n_input_(n_input),
+ n_cell_(n_cell),
+ n_output_(n_output) {
+ input_ = AddInput(TensorType_FLOAT32);
+
+ if (use_cifg) {
+ input_to_input_weights_ = AddNullInput();
+ } else {
+ input_to_input_weights_ = AddInput(TensorType_FLOAT32);
+ }
+
+ input_to_forget_weights_ = AddInput(TensorType_FLOAT32);
+ input_to_cell_weights_ = AddInput(TensorType_FLOAT32);
+ input_to_output_weights_ = AddInput(TensorType_FLOAT32);
+
+ if (use_cifg) {
+ recurrent_to_input_weights_ = AddNullInput();
+ } else {
+ recurrent_to_input_weights_ = AddInput(TensorType_FLOAT32);
+ }
+
+ recurrent_to_forget_weights_ = AddInput(TensorType_FLOAT32);
+ recurrent_to_cell_weights_ = AddInput(TensorType_FLOAT32);
+ recurrent_to_output_weights_ = AddInput(TensorType_FLOAT32);
+
+ if (use_peephole) {
+ if (use_cifg) {
+ cell_to_input_weights_ = AddNullInput();
+ } else {
+ cell_to_input_weights_ = AddInput(TensorType_FLOAT32);
+ }
+ cell_to_forget_weights_ = AddInput(TensorType_FLOAT32);
+ cell_to_output_weights_ = AddInput(TensorType_FLOAT32);
+ } else {
+ cell_to_input_weights_ = AddNullInput();
+ cell_to_forget_weights_ = AddNullInput();
+ cell_to_output_weights_ = AddNullInput();
+ }
+
+ if (use_cifg) {
+ input_gate_bias_ = AddNullInput();
+ } else {
+ input_gate_bias_ = AddInput(TensorType_FLOAT32);
+ }
+ forget_gate_bias_ = AddInput(TensorType_FLOAT32);
+ cell_bias_ = AddInput(TensorType_FLOAT32);
+ output_gate_bias_ = AddInput(TensorType_FLOAT32);
+
+ if (use_projection_weights) {
+ projection_weights_ = AddInput(TensorType_FLOAT32);
+ if (use_projection_bias) {
+ projection_bias_ = AddInput(TensorType_FLOAT32);
+ } else {
+ projection_bias_ = AddNullInput();
+ }
+ } else {
+ projection_weights_ = AddNullInput();
+ projection_bias_ = AddNullInput();
+ }
+
+ scratch_buffer_ = AddOutput(TensorType_FLOAT32);
+ // TODO(ghodrat): Modify these states when we have a permanent solution for
+ // persistent buffer.
+ output_state_ = AddOutput(TensorType_FLOAT32);
+ cell_state_ = AddOutput(TensorType_FLOAT32);
+ output_ = AddOutput(TensorType_FLOAT32);
+
+ SetBuiltinOp(BuiltinOperator_LSTM, BuiltinOptions_LSTMOptions,
+ CreateLSTMOptions(builder_, ActivationFunctionType_TANH,
+ cell_clip, proj_clip)
+ .Union());
+ BuildInterpreter(input_shapes);
+ }
+
+ void SetInputToInputWeights(std::initializer_list<float> f) {
+ PopulateTensor(input_to_input_weights_, f);
+ }
+
+ void SetInputToForgetWeights(std::initializer_list<float> f) {
+ PopulateTensor(input_to_forget_weights_, f);
+ }
+
+ void SetInputToCellWeights(std::initializer_list<float> f) {
+ PopulateTensor(input_to_cell_weights_, f);
+ }
+
+ void SetInputToOutputWeights(std::initializer_list<float> f) {
+ PopulateTensor(input_to_output_weights_, f);
+ }
+
+ void SetRecurrentToInputWeights(std::initializer_list<float> f) {
+ PopulateTensor(recurrent_to_input_weights_, f);
+ }
+
+ void SetRecurrentToForgetWeights(std::initializer_list<float> f) {
+ PopulateTensor(recurrent_to_forget_weights_, f);
+ }
+
+ void SetRecurrentToCellWeights(std::initializer_list<float> f) {
+ PopulateTensor(recurrent_to_cell_weights_, f);
+ }
+
+ void SetRecurrentToOutputWeights(std::initializer_list<float> f) {
+ PopulateTensor(recurrent_to_output_weights_, f);
+ }
+
+ void SetCellToInputWeights(std::initializer_list<float> f) {
+ PopulateTensor(cell_to_input_weights_, f);
+ }
+
+ void SetCellToForgetWeights(std::initializer_list<float> f) {
+ PopulateTensor(cell_to_forget_weights_, f);
+ }
+
+ void SetCellToOutputWeights(std::initializer_list<float> f) {
+ PopulateTensor(cell_to_output_weights_, f);
+ }
+
+ void SetInputGateBias(std::initializer_list<float> f) {
+ PopulateTensor(input_gate_bias_, f);
+ }
+
+ void SetForgetGateBias(std::initializer_list<float> f) {
+ PopulateTensor(forget_gate_bias_, f);
+ }
+
+ void SetCellBias(std::initializer_list<float> f) {
+ PopulateTensor(cell_bias_, f);
+ }
+
+ void SetOutputGateBias(std::initializer_list<float> f) {
+ PopulateTensor(output_gate_bias_, f);
+ }
+
+ void SetProjectionWeights(std::initializer_list<float> f) {
+ PopulateTensor(projection_weights_, f);
+ }
+
+ void SetProjectionBias(std::initializer_list<float> f) {
+ PopulateTensor(projection_bias_, f);
+ }
+
+ void ResetOutputState() {
+ const int zero_buffer_size = n_cell_ * n_batch_;
+ std::unique_ptr<float[]> zero_buffer(new float[zero_buffer_size]);
+ memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float));
+ PopulateTensor(output_state_, 0, zero_buffer.get(),
+ zero_buffer.get() + zero_buffer_size);
+ }
+
+ void ResetCellState() {
+ const int zero_buffer_size = n_cell_ * n_batch_;
+ std::unique_ptr<float[]> zero_buffer(new float[zero_buffer_size]);
+ memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float));
+ PopulateTensor(cell_state_, 0, zero_buffer.get(),
+ zero_buffer.get() + zero_buffer_size);
+ }
+
+ void SetInput(int offset, float* begin, float* end) {
+ PopulateTensor(input_, offset, begin, end);
+ }
+
+ std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+
+ int num_inputs() { return n_input_; }
+ int num_outputs() { return n_output_; }
+ int num_cells() { return n_cell_; }
+ int num_batches() { return n_batch_; }
+
+ private:
+ int input_;
+ int input_to_input_weights_;
+ int input_to_forget_weights_;
+ int input_to_cell_weights_;
+ int input_to_output_weights_;
+
+ int recurrent_to_input_weights_;
+ int recurrent_to_forget_weights_;
+ int recurrent_to_cell_weights_;
+ int recurrent_to_output_weights_;
+
+ int cell_to_input_weights_;
+ int cell_to_forget_weights_;
+ int cell_to_output_weights_;
+
+ int input_gate_bias_;
+ int forget_gate_bias_;
+ int cell_bias_;
+ int output_gate_bias_;
+
+ int projection_weights_;
+ int projection_bias_;
+
+ int output_;
+ int output_state_;
+ int cell_state_;
+ int scratch_buffer_;
+
+ int n_batch_;
+ int n_input_;
+ int n_cell_;
+ int n_output_;
+};
+
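+// The black-box tests below drive the op one time step per Invoke(): the
+// states are zeroed once via ResetCellState()/ResetOutputState(), then each
+// step feeds one input vector per batch and compares the resulting output
+// against the corresponding slice of the golden output. output_state and
+// cell_state are carried across invocations through the persistent state
+// tensors.
+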
+TEST(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping) {
+ const int n_batch = 1;
+ const int n_input = 2;
+ // n_cell and n_output have the same size when there is no projection.
+ const int n_cell = 4;
+ const int n_output = 4;
+
+ LSTMOpModel lstm(n_batch, n_input, n_cell, n_output,
+ /*use_cifg=*/false, /*use_peephole=*/false,
+ /*use_projection_weights=*/false,
+ /*use_projection_bias=*/false,
+ /*cell_clip=*/0.0, /*proj_clip=*/0.0,
+ {
+ {n_batch, n_input}, // input tensor
+
+ {n_cell, n_input}, // input_to_input_weight tensor
+ {n_cell, n_input}, // input_to_forget_weight tensor
+ {n_cell, n_input}, // input_to_cell_weight tensor
+ {n_cell, n_input}, // input_to_output_weight tensor
+
+ {n_cell, n_output}, // recurrent_to_input_weight tensor
+ {n_cell, n_output}, // recurrent_to_forget_weight tensor
+ {n_cell, n_output}, // recurrent_to_cell_weight tensor
+ {n_cell, n_output}, // recurrent_to_output_weight tensor
+
+ {0}, // cell_to_input_weight tensor
+ {0}, // cell_to_forget_weight tensor
+ {0}, // cell_to_output_weight tensor
+
+ {n_cell}, // input_gate_bias tensor
+ {n_cell}, // forget_gate_bias tensor
+ {n_cell}, // cell_bias tensor
+ {n_cell}, // output_gate_bias tensor
+
+ {0, 0}, // projection_weight tensor
+ {0}, // projection_bias tensor
+ });
+
+ lstm.SetInputToInputWeights({-0.45018822, -0.02338299, -0.0870589,
+ -0.34550029, 0.04266912, -0.15680569,
+ -0.34856534, 0.43890524});
+
+ lstm.SetInputToCellWeights({-0.50013041, 0.1370284, 0.11810488, 0.2013163,
+ -0.20583314, 0.44344562, 0.22077113,
+ -0.29909778});
+
+ lstm.SetInputToForgetWeights({0.09701663, 0.20334584, -0.50592935,
+ -0.31343272, -0.40032279, 0.44781327,
+ 0.01387155, -0.35593212});
+
+ lstm.SetInputToOutputWeights({-0.25065863, -0.28290087, 0.04613829,
+ 0.40525138, 0.44272184, 0.03897077, -0.1556896,
+ 0.19487578});
+
+ lstm.SetInputGateBias({0., 0., 0., 0.});
+
+ lstm.SetCellBias({0., 0., 0., 0.});
+
+ lstm.SetForgetGateBias({1., 1., 1., 1.});
+
+ lstm.SetOutputGateBias({0., 0., 0., 0.});
+
+ lstm.SetRecurrentToInputWeights(
+ {-0.0063535, -0.2042388, 0.31454784, -0.35746509, 0.28902304, 0.08183324,
+ -0.16555229, 0.02286911, -0.13566875, 0.03034258, 0.48091322,
+ -0.12528998, 0.24077177, -0.51332325, -0.33502164, 0.10629296});
+
+ lstm.SetRecurrentToCellWeights(
+ {-0.3407414, 0.24443203, -0.2078532, 0.26320225, 0.05695659, -0.00123841,
+ -0.4744786, -0.35869038, -0.06418842, -0.13502428, -0.501764, 0.22830659,
+ -0.46367589, 0.26016325, -0.03894562, -0.16368064});
+
+ lstm.SetRecurrentToForgetWeights(
+ {-0.48684245, -0.06655136, 0.42224967, 0.2112639, 0.27654213, 0.20864892,
+ -0.07646349, 0.45877004, 0.00141793, -0.14609534, 0.36447752, 0.09196436,
+ 0.28053468, 0.01560611, -0.20127171, -0.01140004});
+
+ lstm.SetRecurrentToOutputWeights(
+ {0.43385774, -0.17194885, 0.2718237, 0.09215671, 0.24107647, -0.39835793,
+ 0.18212086, 0.01301402, 0.48572797, -0.50656658, 0.20047462, -0.20607421,
+ -0.51818722, -0.15390486, 0.0468148, 0.39922136});
+
+ static float lstm_input[] = {2., 3., 3., 4., 1., 1.};
+ static float lstm_golden_output[] = {-0.02973187, 0.1229473, 0.20885126,
+ -0.15358765, -0.03716109, 0.12507336,
+ 0.41193449, -0.20860538, -0.15053082,
+ 0.09120187, 0.24278517, -0.12222792};
+
+ // Resetting cell_state and output_state
+ lstm.ResetCellState();
+ lstm.ResetOutputState();
+
+ const int input_sequence_size =
+ sizeof(lstm_input) / sizeof(float) / (lstm.num_inputs());
+ for (int i = 0; i < input_sequence_size; i++) {
+ float* batch0_start = lstm_input + i * lstm.num_inputs();
+ float* batch0_end = batch0_start + lstm.num_inputs();
+
+ lstm.SetInput(0, batch0_start, batch0_end);
+
+ lstm.Invoke();
+
+ float* golden_start = lstm_golden_output + i * lstm.num_outputs();
+ float* golden_end = golden_start + lstm.num_outputs();
+ std::vector<float> expected;
+ expected.insert(expected.end(), golden_start, golden_end);
+ EXPECT_THAT(lstm.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
+ }
+}
+
+TEST(LSTMOpTest, BlackBoxTestWithCifgWithPeepholeNoProjectionNoClipping) {
+ const int n_batch = 1;
+ const int n_input = 2;
+ // n_cell and n_output have the same size when there is no projection.
+ const int n_cell = 4;
+ const int n_output = 4;
+
+ LSTMOpModel lstm(n_batch, n_input, n_cell, n_output,
+ /*use_cifg=*/true, /*use_peephole=*/true,
+ /*use_projection_weights=*/false,
+ /*use_projection_bias=*/false,
+ /*cell_clip=*/0.0, /*proj_clip=*/0.0,
+ {
+ {n_batch, n_input}, // input tensor
+
+ {0, 0}, // input_to_input_weight tensor
+ {n_cell, n_input}, // input_to_forget_weight tensor
+ {n_cell, n_input}, // input_to_cell_weight tensor
+ {n_cell, n_input}, // input_to_output_weight tensor
+
+ {0, 0}, // recurrent_to_input_weight tensor
+ {n_cell, n_output}, // recurrent_to_forget_weight tensor
+ {n_cell, n_output}, // recurrent_to_cell_weight tensor
+ {n_cell, n_output}, // recurrent_to_output_weight tensor
+
+ {0}, // cell_to_input_weight tensor
+ {n_cell}, // cell_to_forget_weight tensor
+ {n_cell}, // cell_to_output_weight tensor
+
+ {0}, // input_gate_bias tensor
+ {n_cell}, // forget_gate_bias tensor
+ {n_cell}, // cell_bias tensor
+ {n_cell}, // output_gate_bias tensor
+
+ {0, 0}, // projection_weight tensor
+ {0}, // projection_bias tensor
+ });
+
+ lstm.SetInputToCellWeights({-0.49770179, -0.27711356, -0.09624726, 0.05100781,
+ 0.04717243, 0.48944736, -0.38535351,
+ -0.17212132});
+
+ lstm.SetInputToForgetWeights({-0.55291498, -0.42866567, 0.13056988,
+ -0.3633365, -0.22755712, 0.28253698, 0.24407166,
+ 0.33826375});
+
+ lstm.SetInputToOutputWeights({0.10725588, -0.02335852, -0.55932593,
+ -0.09426838, -0.44257352, 0.54939759,
+ 0.01533556, 0.42751634});
+
+ lstm.SetCellBias({0., 0., 0., 0.});
+
+ lstm.SetForgetGateBias({1., 1., 1., 1.});
+
+ lstm.SetOutputGateBias({0., 0., 0., 0.});
+
+ lstm.SetRecurrentToCellWeights(
+ {0.54066205, -0.32668582, -0.43562764, -0.56094903, 0.42957711,
+ 0.01841056, -0.32764608, -0.33027974, -0.10826075, 0.20675004,
+ 0.19069612, -0.03026325, -0.54532051, 0.33003211, 0.44901288,
+ 0.21193194});
+
+ lstm.SetRecurrentToForgetWeights(
+ {-0.13832897, -0.0515101, -0.2359007, -0.16661474, -0.14340827,
+ 0.36986142, 0.23414481, 0.55899, 0.10798943, -0.41174671, 0.17751795,
+ -0.34484994, -0.35874045, -0.11352962, 0.27268326, 0.54058349});
+
+ lstm.SetRecurrentToOutputWeights(
+ {0.41613156, 0.42610586, -0.16495961, -0.5663873, 0.30579174, -0.05115908,
+ -0.33941799, 0.23364776, 0.11178309, 0.09481031, -0.26424935, 0.46261835,
+ 0.50248802, 0.26114327, -0.43736315, 0.33149987});
+
+ lstm.SetCellToForgetWeights(
+ {0.47485286, -0.51955009, -0.24458408, 0.31544167});
+ lstm.SetCellToOutputWeights(
+ {-0.17135078, 0.82760304, 0.85573703, -0.77109635});
+
+ static float lstm_input[] = {2., 3., 3., 4., 1., 1.};
+ static float lstm_golden_output[] = {-0.36444446, -0.00352185, 0.12886585,
+ -0.05163646, -0.42312205, -0.01218222,
+ 0.24201041, -0.08124574, -0.358325,
+ -0.04621704, 0.21641694, -0.06471302};
+
+ // Resetting cell_state and output_state
+ lstm.ResetCellState();
+ lstm.ResetOutputState();
+
+ const int input_sequence_size =
+ sizeof(lstm_input) / sizeof(float) / (lstm.num_inputs());
+ for (int i = 0; i < input_sequence_size; i++) {
+ float* batch0_start = lstm_input + i * lstm.num_inputs();
+ float* batch0_end = batch0_start + lstm.num_inputs();
+
+ lstm.SetInput(0, batch0_start, batch0_end);
+
+ lstm.Invoke();
+
+ float* golden_start = lstm_golden_output + i * lstm.num_outputs();
+ float* golden_end = golden_start + lstm.num_outputs();
+ std::vector<float> expected;
+ expected.insert(expected.end(), golden_start, golden_end);
+ EXPECT_THAT(lstm.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
+ }
+}
+
+TEST(LSTMOpTest, BlackBoxTestWithPeepholeWithProjectionNoClipping) {
+ const int n_batch = 2;
+ const int n_input = 5;
+ const int n_cell = 20;
+ const int n_output = 16;
+
+ LSTMOpModel lstm(n_batch, n_input, n_cell, n_output,
+ /*use_cifg=*/false, /*use_peephole=*/true,
+ /*use_projection_weights=*/true,
+ /*use_projection_bias=*/false,
+ /*cell_clip=*/0.0, /*proj_clip=*/0.0,
+ {
+ {n_batch, n_input}, // input tensor
+
+ {n_cell, n_input}, // input_to_input_weight tensor
+ {n_cell, n_input}, // input_to_forget_weight tensor
+ {n_cell, n_input}, // input_to_cell_weight tensor
+ {n_cell, n_input}, // input_to_output_weight tensor
+
+ {n_cell, n_output}, // recurrent_to_input_weight tensor
+ {n_cell, n_output}, // recurrent_to_forget_weight tensor
+ {n_cell, n_output}, // recurrent_to_cell_weight tensor
+ {n_cell, n_output}, // recurrent_to_output_weight tensor
+
+ {n_cell}, // cell_to_input_weight tensor
+ {n_cell}, // cell_to_forget_weight tensor
+ {n_cell}, // cell_to_output_weight tensor
+
+ {n_cell}, // input_gate_bias tensor
+ {n_cell}, // forget_gate_bias tensor
+ {n_cell}, // cell_bias tensor
+ {n_cell}, // output_gate_bias tensor
+
+ {n_output, n_cell}, // projection_weight tensor
+ {0}, // projection_bias tensor
+ });
+
+ lstm.SetInputToInputWeights(
+ {0.021393683, 0.06124551, 0.046905167, -0.014657677, -0.03149463,
+ 0.09171803, 0.14647801, 0.10797193, -0.0057968358, 0.0019193048,
+ -0.2726754, 0.10154029, -0.018539885, 0.080349885, -0.10262385,
+ -0.022599787, -0.09121155, -0.008675967, -0.045206103, -0.0821282,
+ -0.008045952, 0.015478081, 0.055217247, 0.038719587, 0.044153627,
+ -0.06453243, 0.05031825, -0.046935108, -0.008164439, 0.014574226,
+ -0.1671009, -0.15519552, -0.16819797, -0.13971269, -0.11953059,
+ 0.25005487, -0.22790983, 0.009855087, -0.028140958, -0.11200698,
+ 0.11295408, -0.0035217577, 0.054485075, 0.05184695, 0.064711206,
+ 0.10989193, 0.11674786, 0.03490607, 0.07727357, 0.11390585,
+ -0.1863375, -0.1034451, -0.13945189, -0.049401227, -0.18767063,
+ 0.042483903, 0.14233552, 0.13832581, 0.18350165, 0.14545603,
+ -0.028545704, 0.024939531, 0.050929718, 0.0076203286, -0.0029723682,
+ -0.042484224, -0.11827596, -0.09171104, -0.10808628, -0.16327988,
+ -0.2273378, -0.0993647, -0.017155107, 0.0023917493, 0.049272764,
+ 0.0038534778, 0.054764505, 0.089753784, 0.06947234, 0.08014476,
+ -0.04544234, -0.0497073, -0.07135631, -0.048929106, -0.004042012,
+ -0.009284026, 0.018042054, 0.0036860977, -0.07427302, -0.11434604,
+ -0.018995456, 0.031487543, 0.012834908, 0.019977754, 0.044256654,
+ -0.39292613, -0.18519334, -0.11651281, -0.06809892, 0.011373677});
+
+ lstm.SetInputToForgetWeights(
+ {-0.0018401089, -0.004852237, 0.03698424, 0.014181704, 0.028273236,
+ -0.016726194, -0.05249759, -0.10204261, 0.00861066, -0.040979505,
+ -0.009899187, 0.01923892, -0.028177269, -0.08535103, -0.14585495,
+ 0.10662567, -0.01909731, -0.017883534, -0.0047269356, -0.045103323,
+ 0.0030784295, 0.076784775, 0.07463696, 0.094531395, 0.0814421,
+ -0.12257899, -0.033945758, -0.031303465, 0.045630626, 0.06843887,
+ -0.13492945, -0.012480007, -0.0811829, -0.07224499, -0.09628791,
+ 0.045100946, 0.0012300825, 0.013964662, 0.099372394, 0.02543059,
+ 0.06958324, 0.034257296, 0.0482646, 0.06267997, 0.052625068,
+ 0.12784666, 0.07077897, 0.025725935, 0.04165009, 0.07241905,
+ 0.018668644, -0.037377294, -0.06277783, -0.08833636, -0.040120605,
+ -0.011405586, -0.007808335, -0.010301386, -0.005102167, 0.027717464,
+ 0.05483423, 0.11449111, 0.11289652, 0.10939839, 0.13396506,
+ -0.08402166, -0.01901462, -0.044678304, -0.07720565, 0.014350063,
+ -0.11757958, -0.0652038, -0.08185733, -0.076754324, -0.092614375,
+ 0.10405491, 0.052960336, 0.035755895, 0.035839386, -0.012540553,
+ 0.036881298, 0.02913376, 0.03420159, 0.05448447, -0.054523353,
+ 0.02582715, 0.02327355, -0.011857179, -0.0011980024, -0.034641717,
+ -0.026125094, -0.17582615, -0.15923657, -0.27486774, -0.0006143371,
+ 0.0001771948, -8.470171e-05, 0.02651807, 0.045790765, 0.06956496});
+
+ lstm.SetInputToCellWeights(
+ {-0.04580283, -0.09549462, -0.032418985, -0.06454633,
+ -0.043528453, 0.043018587, -0.049152344, -0.12418144,
+ -0.078985475, -0.07596889, 0.019484362, -0.11434962,
+ -0.0074034138, -0.06314844, -0.092981495, 0.0062155537,
+ -0.025034338, -0.0028890965, 0.048929527, 0.06235075,
+ 0.10665918, -0.032036792, -0.08505916, -0.10843358,
+ -0.13002433, -0.036816437, -0.02130134, -0.016518239,
+ 0.0047691227, -0.0025825808, 0.066017866, 0.029991534,
+ -0.10652836, -0.1037554, -0.13056071, -0.03266643,
+ -0.033702414, -0.006473424, -0.04611692, 0.014419339,
+ -0.025174323, 0.0396852, 0.081777506, 0.06157468,
+ 0.10210095, -0.009658194, 0.046511717, 0.03603906,
+ 0.0069369148, 0.015960095, -0.06507666, 0.09551598,
+ 0.053568836, 0.06408714, 0.12835667, -0.008714329,
+ -0.20211966, -0.12093674, 0.029450472, 0.2849013,
+ -0.029227901, 0.1164364, -0.08560263, 0.09941786,
+ -0.036999565, -0.028842626, -0.0033637602, -0.017012902,
+ -0.09720865, -0.11193351, -0.029155117, -0.017936034,
+ -0.009768936, -0.04223324, -0.036159635, 0.06505112,
+ -0.021742892, -0.023377212, -0.07221364, -0.06430552,
+ 0.05453865, 0.091149814, 0.06387331, 0.007518393,
+ 0.055960953, 0.069779344, 0.046411168, 0.10509911,
+ 0.07463894, 0.0075130584, 0.012850982, 0.04555431,
+ 0.056955688, 0.06555285, 0.050801456, -0.009862683,
+ 0.00826772, -0.026555609, -0.0073611983, -0.0014897042});
+
+ lstm.SetInputToOutputWeights(
+ {-0.0998932, -0.07201956, -0.052803773, -0.15629593, -0.15001918,
+ -0.07650751, 0.02359855, -0.075155355, -0.08037709, -0.15093534,
+ 0.029517552, -0.04751393, 0.010350531, -0.02664851, -0.016839722,
+ -0.023121163, 0.0077019283, 0.012851257, -0.05040649, -0.0129761,
+ -0.021737747, -0.038305793, -0.06870586, -0.01481247, -0.001285394,
+ 0.10124236, 0.083122835, 0.053313006, -0.062235646, -0.075637154,
+ -0.027833903, 0.029774971, 0.1130802, 0.09218906, 0.09506135,
+ -0.086665764, -0.037162706, -0.038880914, -0.035832845, -0.014481564,
+ -0.09825003, -0.12048569, -0.097665586, -0.05287633, -0.0964047,
+ -0.11366429, 0.035777505, 0.13568819, 0.052451383, 0.050649304,
+ 0.05798951, -0.021852335, -0.099848844, 0.014740475, -0.078897946,
+ 0.04974699, 0.014160473, 0.06973932, 0.04964942, 0.033364646,
+ 0.08190124, 0.025535367, 0.050893165, 0.048514254, 0.06945813,
+ -0.078907564, -0.06707616, -0.11844508, -0.09986688, -0.07509403,
+ 0.06263226, 0.14925587, 0.20188436, 0.12098451, 0.14639415,
+ 0.0015017595, -0.014267382, -0.03417257, 0.012711468, 0.0028300495,
+ -0.024758482, -0.05098548, -0.0821182, 0.014225672, 0.021544158,
+ 0.08949725, 0.07505268, -0.0020780868, 0.04908258, 0.06476295,
+ -0.022907063, 0.027562456, 0.040185735, 0.019567577, -0.015598739,
+ -0.049097303, -0.017121866, -0.083368234, -0.02332002, -0.0840956});
+
+ lstm.SetInputGateBias(
+ {0.02234832, 0.14757581, 0.18176508, 0.10380666, 0.053110216,
+ -0.06928846, -0.13942584, -0.11816189, 0.19483899, 0.03652339,
+ -0.10250295, 0.036714908, -0.18426876, 0.036065217, 0.21810818,
+ 0.02383196, -0.043370757, 0.08690144, -0.04444982, 0.00030581196});
+
+ lstm.SetForgetGateBias({0.035185695, -0.042891346, -0.03032477, 0.23027696,
+ 0.11098921, 0.15378423, 0.09263801, 0.09790885,
+ 0.09508917, 0.061199076, 0.07665568, -0.015443159,
+ -0.03499149, 0.046190713, 0.08895977, 0.10899629,
+ 0.40694186, 0.06030037, 0.012413437, -0.06108739});
+
+ lstm.SetCellBias({-0.024379363, 0.0055531194, 0.23377132, 0.033463873,
+ -0.1483596, -0.10639995, -0.091433935, 0.058573797,
+ -0.06809782, -0.07889636, -0.043246906, -0.09829136,
+ -0.4279842, 0.034901652, 0.18797937, 0.0075234566,
+ 0.016178843, 0.1749513, 0.13975595, 0.92058027});
+
+ lstm.SetOutputGateBias(
+ {0.046159424, -0.0012809046, 0.03563469, 0.12648113, 0.027195795,
+ 0.35373217, -0.018957434, 0.008907322, -0.0762701, 0.12018895,
+ 0.04216877, 0.0022856654, 0.040952638, 0.3147856, 0.08225149,
+ -0.057416286, -0.14995944, -0.008040261, 0.13208859, 0.029760877});
+
+ lstm.SetRecurrentToInputWeights(
+ {-0.001374326, -0.078856036, 0.10672688, 0.029162422,
+ -0.11585556, 0.02557986, -0.13446963, -0.035785314,
+ -0.01244275, 0.025961924, -0.02337298, -0.044228926,
+ -0.055839065, -0.046598054, -0.010546039, -0.06900766,
+ 0.027239809, 0.022582639, -0.013296484, -0.05459212,
+ 0.08981, -0.045407712, 0.08682226, -0.06867011,
+ -0.14390695, -0.02916037, 0.000996957, 0.091420636,
+ 0.14283475, -0.07390571, -0.06402044, 0.062524505,
+ -0.093129106, 0.04860203, -0.08364217, -0.08119002,
+ 0.009352075, 0.22920375, 0.0016303885, 0.11583097,
+ -0.13732095, 0.012405723, -0.07551853, 0.06343048,
+ 0.12162708, -0.031923793, -0.014335606, 0.01790974,
+ -0.10650317, -0.0724401, 0.08554849, -0.05727212,
+ 0.06556731, -0.042729504, -0.043227166, 0.011683251,
+ -0.013082158, -0.029302018, -0.010899579, -0.062036745,
+ -0.022509435, -0.00964907, -0.01567329, 0.04260106,
+ -0.07787477, -0.11576462, 0.017356863, 0.048673786,
+ -0.017577527, -0.05527947, -0.082487635, -0.040137455,
+ -0.10820036, -0.04666372, 0.022746278, -0.07851417,
+ 0.01068115, 0.032956902, 0.022433773, 0.0026891115,
+ 0.08944216, -0.0685835, 0.010513544, 0.07228705,
+ 0.02032331, -0.059686817, -0.0005566496, -0.086984694,
+ 0.040414046, -0.1380399, 0.094208956, -0.05722982,
+ 0.012092817, -0.04989123, -0.086576, -0.003399834,
+ -0.04696032, -0.045747425, 0.10091314, 0.048676282,
+ -0.029037097, 0.031399418, -0.0040285117, 0.047237843,
+ 0.09504992, 0.041799378, -0.049185462, -0.031518843,
+ -0.10516937, 0.026374253, 0.10058866, -0.0033195973,
+ -0.041975245, 0.0073591834, 0.0033782164, -0.004325073,
+ -0.10167381, 0.042500053, -0.01447153, 0.06464186,
+ -0.017142897, 0.03312627, 0.009205989, 0.024138335,
+ -0.011337001, 0.035530265, -0.010912711, 0.0706555,
+ -0.005894094, 0.051841937, -0.1401738, -0.02351249,
+ 0.0365468, 0.07590991, 0.08838724, 0.021681072,
+ -0.10086113, 0.019608743, -0.06195883, 0.077335775,
+ 0.023646897, -0.095322326, 0.02233014, 0.09756986,
+ -0.048691444, -0.009579111, 0.07595467, 0.11480546,
+ -0.09801813, 0.019894179, 0.08502348, 0.004032281,
+ 0.037211012, 0.068537936, -0.048005626, -0.091520436,
+ -0.028379958, -0.01556313, 0.06554592, -0.045599163,
+ -0.01672207, -0.020169014, -0.011877351, -0.20212261,
+ 0.010889619, 0.0047078193, 0.038385306, 0.08540671,
+ -0.017140968, -0.0035865551, 0.016678626, 0.005633034,
+ 0.015963363, 0.00871737, 0.060130805, 0.028611384,
+ 0.10109069, -0.015060172, -0.07894427, 0.06401885,
+ 0.011584063, -0.024466386, 0.0047652307, -0.09041358,
+ 0.030737216, -0.0046374933, 0.14215417, -0.11823516,
+ 0.019899689, 0.006106124, -0.027092824, 0.0786356,
+ 0.05052217, -0.058925, -0.011402121, -0.024987547,
+ -0.0013661642, -0.06832946, -0.015667673, -0.1083353,
+ -0.00096863037, -0.06988685, -0.053350925, -0.027275559,
+ -0.033664223, -0.07978348, -0.025200296, -0.017207067,
+ -0.058403496, -0.055697463, 0.005798788, 0.12965427,
+ -0.062582195, 0.0013350133, -0.10482091, 0.0379771,
+ 0.072521195, -0.0029455067, -0.13797039, -0.03628521,
+ 0.013806405, -0.017858358, -0.01008298, -0.07700066,
+ -0.017081132, 0.019358726, 0.0027079724, 0.004635139,
+ 0.062634714, -0.02338735, -0.039547626, -0.02050681,
+ 0.03385117, -0.083611414, 0.002862572, -0.09421313,
+ 0.058618143, -0.08598433, 0.00972939, 0.023867095,
+ -0.053934585, -0.023203006, 0.07452513, -0.048767887,
+ -0.07314807, -0.056307215, -0.10433547, -0.06440842,
+ 0.04328182, 0.04389765, -0.020006588, -0.09076438,
+ -0.11652589, -0.021705797, 0.03345259, -0.010329105,
+ -0.025767034, 0.013057034, -0.07316461, -0.10145612,
+ 0.06358255, 0.18531723, 0.07759293, 0.12006465,
+ 0.1305557, 0.058638252, -0.03393652, 0.09622831,
+ -0.16253184, -2.4580743e-06, 0.079869635, -0.070196845,
+ -0.005644518, 0.06857898, -0.12598175, -0.035084512,
+ 0.03156317, -0.12794146, -0.031963028, 0.04692781,
+ 0.030070418, 0.0071660685, -0.095516115, -0.004643372,
+ 0.040170413, -0.062104587, -0.0037324072, 0.0554317,
+ 0.08184801, -0.019164372, 0.06791302, 0.034257166,
+ -0.10307039, 0.021943003, 0.046745934, 0.0790918,
+ -0.0265588, -0.007824208, 0.042546265, -0.00977924,
+ -0.0002440307, -0.017384544, -0.017990116, 0.12252321,
+ -0.014512694, -0.08251313, 0.08861942, 0.13589665,
+ 0.026351685, 0.012641483, 0.07466548, 0.044301085,
+ -0.045414884, -0.051112458, 0.03444247, -0.08502782,
+ -0.04106223, -0.028126027, 0.028473156, 0.10467447});
+
+ lstm.SetRecurrentToForgetWeights(
+ {-0.057784554, -0.026057621, -0.068447545, -0.022581743,
+ 0.14811787, 0.10826372, 0.09471067, 0.03987225,
+ -0.0039523416, 0.00030638507, 0.053185795, 0.10572994,
+ 0.08414449, -0.022036452, -0.00066928595, -0.09203576,
+ 0.032950465, -0.10985798, -0.023809856, 0.0021431844,
+ -0.02196096, -0.00326074, 0.00058621005, -0.074678116,
+ -0.06193199, 0.055729095, 0.03736828, 0.020123724,
+ 0.061878487, -0.04729229, 0.034919553, -0.07585433,
+ -0.04421272, -0.044019096, 0.085488975, 0.04058006,
+ -0.06890133, -0.030951202, -0.024628663, -0.07672815,
+ 0.034293607, 0.08556707, -0.05293577, -0.033561368,
+ -0.04899627, 0.0241671, 0.015736353, -0.095442444,
+ -0.029564252, 0.016493602, -0.035026584, 0.022337519,
+ -0.026871363, 0.004780428, 0.0077918363, -0.03601621,
+ 0.016435321, -0.03263031, -0.09543275, -0.047392778,
+ 0.013454138, 0.028934088, 0.01685226, -0.086110644,
+ -0.046250615, -0.01847454, 0.047608484, 0.07339695,
+ 0.034546845, -0.04881143, 0.009128804, -0.08802852,
+ 0.03761666, 0.008096139, -0.014454086, 0.014361001,
+ -0.023502491, -0.0011840804, -0.07607001, 0.001856849,
+ -0.06509276, -0.006021153, -0.08570962, -0.1451793,
+ 0.060212336, 0.055259194, 0.06974018, 0.049454916,
+ -0.027794661, -0.08077226, -0.016179763, 0.1169753,
+ 0.17213494, -0.0056326236, -0.053934924, -0.0124349,
+ -0.11520337, 0.05409887, 0.088759385, 0.0019655675,
+ 0.0042065294, 0.03881498, 0.019844765, 0.041858196,
+ -0.05695512, 0.047233116, 0.038937137, -0.06542224,
+ 0.014429736, -0.09719407, 0.13908425, -0.05379757,
+ 0.012321099, 0.082840554, -0.029899208, 0.044217527,
+ 0.059855383, 0.07711018, -0.045319796, 0.0948846,
+ -0.011724666, -0.0033288454, -0.033542685, -0.04764985,
+ -0.13873616, 0.040668588, 0.034832682, -0.015319203,
+ -0.018715994, 0.046002675, 0.0599172, -0.043107376,
+ 0.0294216, -0.002314414, -0.022424703, 0.0030315618,
+ 0.0014641669, 0.0029166266, -0.11878115, 0.013738511,
+ 0.12375372, -0.0006038222, 0.029104086, 0.087442465,
+ 0.052958444, 0.07558703, 0.04817258, 0.044462286,
+ -0.015213451, -0.08783778, -0.0561384, -0.003008196,
+ 0.047060397, -0.002058388, 0.03429439, -0.018839769,
+ 0.024734668, 0.024614193, -0.042046934, 0.09597743,
+ -0.0043254104, 0.04320769, 0.0064070094, -0.0019131786,
+ -0.02558259, -0.022822596, -0.023273505, -0.02464396,
+ -0.10991725, -0.006240552, 0.0074488563, 0.024044557,
+ 0.04383914, -0.046476185, 0.028658995, 0.060410924,
+ 0.050786525, 0.009452605, -0.0073054377, -0.024810238,
+ 0.0052906186, 0.0066939713, -0.0020913032, 0.014515517,
+ 0.015898481, 0.021362653, -0.030262267, 0.016587038,
+ -0.011442813, 0.041154444, -0.007631438, -0.03423484,
+ -0.010977775, 0.036152758, 0.0066366293, 0.11915515,
+ 0.02318443, -0.041350313, 0.021485701, -0.10906167,
+ -0.028218046, -0.00954771, 0.020531068, -0.11995105,
+ -0.03672871, 0.024019798, 0.014255957, -0.05221243,
+ -0.00661567, -0.04630967, 0.033188973, 0.10107534,
+ -0.014027541, 0.030796422, -0.10270911, -0.035999842,
+ 0.15443139, 0.07684145, 0.036571592, -0.035900835,
+ -0.0034699554, 0.06209149, 0.015920248, -0.031122351,
+ -0.03858649, 0.01849943, 0.13872518, 0.01503974,
+ 0.069941424, -0.06948533, -0.0088794185, 0.061282158,
+ -0.047401894, 0.03100163, -0.041533746, -0.10430945,
+ 0.044574402, -0.01425562, -0.024290353, 0.034563623,
+ 0.05866852, 0.023947537, -0.09445152, 0.035450947,
+ 0.02247216, -0.0042998926, 0.061146557, -0.10250651,
+ 0.020881841, -0.06747029, 0.10062043, -0.0023941975,
+ 0.03532124, -0.016341697, 0.09685456, -0.016764693,
+ 0.051808182, 0.05875331, -0.04536488, 0.001626336,
+ -0.028892258, -0.01048663, -0.009793449, -0.017093895,
+ 0.010987891, 0.02357273, -0.00010856845, 0.0099760275,
+ -0.001845119, -0.03551521, 0.0018358806, 0.05763657,
+ -0.01769146, 0.040995963, 0.02235177, -0.060430344,
+ 0.11475477, -0.023854522, 0.10071741, 0.0686208,
+ -0.014250481, 0.034261297, 0.047418304, 0.08562733,
+ -0.030519066, 0.0060542435, 0.014653856, -0.038836084,
+ 0.04096551, 0.032249358, -0.08355519, -0.026823482,
+ 0.056386515, -0.010401743, -0.028396193, 0.08507674,
+ 0.014410365, 0.020995233, 0.17040324, 0.11511526,
+ 0.02459721, 0.0066619175, 0.025853224, -0.023133837,
+ -0.081302024, 0.017264642, -0.009585969, 0.09491168,
+ -0.051313367, 0.054532815, -0.014298593, 0.10657464,
+ 0.007076659, 0.10964551, 0.0409152, 0.008275321,
+ -0.07283536, 0.07937492, 0.04192024, -0.1075027});
+
+ lstm.SetRecurrentToCellWeights(
+ {-0.037322544, 0.018592842, 0.0056175636, -0.06253426,
+ 0.055647098, -0.05713207, -0.05626563, 0.005559383,
+ 0.03375411, -0.025757805, -0.088049285, 0.06017052,
+ -0.06570978, 0.007384076, 0.035123326, -0.07920549,
+ 0.053676967, 0.044480428, -0.07663568, 0.0071805613,
+ 0.08089997, 0.05143358, 0.038261272, 0.03339287,
+ -0.027673481, 0.044746667, 0.028349208, 0.020090483,
+ -0.019443132, -0.030755889, -0.0040000007, 0.04465846,
+ -0.021585021, 0.0031670958, 0.0053199246, -0.056117613,
+ -0.10893326, 0.076739706, -0.08509834, -0.027997585,
+ 0.037871376, 0.01449768, -0.09002357, -0.06111149,
+ -0.046195522, 0.0422062, -0.005683705, -0.1253618,
+ -0.012925729, -0.04890792, 0.06985068, 0.037654128,
+ 0.03398274, -0.004781977, 0.007032333, -0.031787455,
+ 0.010868644, -0.031489216, 0.09525667, 0.013939797,
+ 0.0058680447, 0.0167067, 0.02668468, -0.04797466,
+ -0.048885044, -0.12722108, 0.035304096, 0.06554885,
+ 0.00972396, -0.039238118, -0.05159735, -0.11329045,
+ 0.1613692, -0.03750952, 0.06529313, -0.071974665,
+ -0.11769596, 0.015524369, -0.0013754242, -0.12446318,
+ 0.02786344, -0.014179351, 0.005264273, 0.14376344,
+ 0.015983658, 0.03406988, -0.06939408, 0.040699873,
+ 0.02111075, 0.09669095, 0.041345075, -0.08316494,
+ -0.07684199, -0.045768797, 0.032298047, -0.041805092,
+ 0.0119405, 0.0061010392, 0.12652606, 0.0064572375,
+ -0.024950314, 0.11574242, 0.04508852, -0.04335324,
+ 0.06760663, -0.027437469, 0.07216407, 0.06977076,
+ -0.05438599, 0.034033038, -0.028602652, 0.05346137,
+ 0.043184172, -0.037189785, 0.10420091, 0.00882477,
+ -0.054019816, -0.074273005, -0.030617684, -0.0028467078,
+ 0.024302477, -0.0038869337, 0.005332455, 0.0013399826,
+ 0.04361412, -0.007001822, 0.09631092, -0.06702025,
+ -0.042049985, -0.035070654, -0.04103342, -0.10273396,
+ 0.0544271, 0.037184782, -0.13150354, -0.0058036847,
+ -0.008264958, 0.042035464, 0.05891794, 0.029673764,
+ 0.0063542654, 0.044788733, 0.054816857, 0.062257513,
+ -0.00093483756, 0.048938446, -0.004952862, -0.007730018,
+ -0.04043371, -0.017094059, 0.07229206, -0.023670016,
+ -0.052195564, -0.025616996, -0.01520939, 0.045104615,
+ -0.007376126, 0.003533447, 0.006570588, 0.056037236,
+ 0.12436656, 0.051817212, 0.028532185, -0.08686856,
+ 0.11868599, 0.07663395, -0.07323171, 0.03463402,
+ -0.050708205, -0.04458982, -0.11590894, 0.021273347,
+ 0.1251325, -0.15313013, -0.12224372, 0.17228661,
+ 0.023029093, 0.086124025, 0.006445803, -0.03496501,
+ 0.028332196, 0.04449512, -0.042436164, -0.026587414,
+ -0.006041347, -0.09292539, -0.05678812, 0.03897832,
+ 0.09465633, 0.008115513, -0.02171956, 0.08304309,
+ 0.071401566, 0.019622514, 0.032163795, -0.004167056,
+ 0.02295182, 0.030739572, 0.056506045, 0.004612461,
+ 0.06524936, 0.059999723, 0.046395954, -0.0045512207,
+ -0.1335546, -0.030136576, 0.11584653, -0.014678886,
+ 0.0020118146, -0.09688814, -0.0790206, 0.039770417,
+ -0.0329582, 0.07922767, 0.029322514, 0.026405897,
+ 0.04207835, -0.07073373, 0.063781224, 0.0859677,
+ -0.10925287, -0.07011058, 0.048005477, 0.03438226,
+ -0.09606514, -0.006669445, -0.043381985, 0.04240257,
+ -0.06955775, -0.06769346, 0.043903265, -0.026784198,
+ -0.017840602, 0.024307009, -0.040079936, -0.019946516,
+ 0.045318738, -0.12233574, 0.026170589, 0.0074471775,
+ 0.15978073, 0.10185836, 0.10298046, -0.015476589,
+ -0.039390966, -0.072174534, 0.0739445, -0.1211869,
+ -0.0347889, -0.07943156, 0.014809798, -0.12412325,
+ -0.0030663363, 0.039695457, 0.0647603, -0.08291318,
+ -0.018529687, -0.004423833, 0.0037507233, 0.084633216,
+ -0.01514876, -0.056505352, -0.012800942, -0.06994386,
+ 0.012962922, -0.031234352, 0.07029052, 0.016418684,
+ 0.03618972, 0.055686004, -0.08663945, -0.017404709,
+ -0.054761406, 0.029065743, 0.052404847, 0.020238016,
+ 0.0048197987, -0.0214882, 0.07078733, 0.013016777,
+ 0.06262858, 0.009184685, 0.020785125, -0.043904778,
+ -0.0270329, -0.03299152, -0.060088247, -0.015162964,
+ -0.001828936, 0.12642565, -0.056757294, 0.013586685,
+ 0.09232601, -0.035886683, 0.06000002, 0.05229691,
+ -0.052580316, -0.082029596, -0.010794592, 0.012947712,
+ -0.036429964, -0.085508935, -0.13127148, -0.017744139,
+ 0.031502828, 0.036232427, -0.031581745, 0.023051167,
+ -0.05325106, -0.03421577, 0.028793324, -0.034633752,
+ -0.009881397, -0.043551125, -0.018609839, 0.0019097115,
+ -0.008799762, 0.056595087, 0.0022273948, 0.055752404});
+
+ lstm.SetRecurrentToOutputWeights({
+ 0.025825322, -0.05813119, 0.09495884, -0.045984812, -0.01255415,
+ -0.0026479573, -0.08196161, -0.054914974, -0.0046604523, -0.029587349,
+ -0.044576716, -0.07480124, -0.082868785, 0.023254942, 0.027502948,
+ -0.0039728214, -0.08683098, -0.08116779, -0.014675607, -0.037924774,
+ -0.023314456, -0.007401714, -0.09255757, 0.029460307, -0.08829125,
+ -0.005139627, -0.08989442, -0.0555066, 0.13596267, -0.025062224,
+ -0.048351806, -0.03850004, 0.07266485, -0.022414139, 0.05940088,
+ 0.075114764, 0.09597592, -0.010211725, -0.0049794707, -0.011523867,
+ -0.025980417, 0.072999895, 0.11091378, -0.081685916, 0.014416728,
+ 0.043229222, 0.034178585, -0.07530371, 0.035837382, -0.085607,
+ -0.007721233, -0.03287832, -0.043848954, -0.06404588, -0.06632928,
+ -0.073643476, 0.008214239, -0.045984086, 0.039764922, 0.03474462,
+ 0.060612556, -0.080590084, 0.049127717, 0.04151091, -0.030063879,
+ 0.008801774, -0.023021035, -0.019558564, 0.05158114, -0.010947698,
+ -0.011825728, 0.0075720972, 0.0699727, -0.0039981045, 0.069350146,
+ 0.08799282, 0.016156472, 0.035502106, 0.11695009, 0.006217345,
+ 0.13392477, -0.037875112, 0.025745004, 0.08940699, -0.00924166,
+ 0.0046702605, -0.036598757, -0.08811812, 0.10522024, -0.032441203,
+ 0.008176899, -0.04454919, 0.07058152, 0.0067963637, 0.039206743,
+ 0.03259838, 0.03725492, -0.09515802, 0.013326398, -0.052055415,
+ -0.025676316, 0.03198509, -0.015951829, -0.058556724, 0.036879618,
+ 0.043357447, 0.028362012, -0.05908629, 0.0059240665, -0.04995891,
+ -0.019187413, 0.0276265, -0.01628143, 0.0025863599, 0.08800015,
+ 0.035250366, -0.022165963, -0.07328642, -0.009415526, -0.07455109,
+ 0.11690406, 0.0363299, 0.07411125, 0.042103454, -0.009660886,
+ 0.019076364, 0.018299393, -0.046004917, 0.08891175, 0.0431396,
+ -0.026327137, -0.051502608, 0.08979574, -0.051670972, 0.04940282,
+ -0.07491107, -0.021240504, 0.022596184, -0.034280192, 0.060163025,
+ -0.058211457, -0.051837247, -0.01349775, -0.04639988, -0.035936575,
+ -0.011681591, 0.064818054, 0.0073146066, -0.021745546, -0.043124277,
+ -0.06471268, -0.07053354, -0.029321948, -0.05330136, 0.016933719,
+ -0.053782392, 0.13747959, -0.1361751, -0.11569455, 0.0033329215,
+ 0.05693899, -0.053219706, 0.063698, 0.07977434, -0.07924483,
+ 0.06936997, 0.0034815092, -0.007305279, -0.037325785, -0.07251102,
+ -0.033633437, -0.08677009, 0.091591336, -0.14165086, 0.021752775,
+ 0.019683983, 0.0011612234, -0.058154266, 0.049996935, 0.0288841,
+ -0.0024567875, -0.14345716, 0.010955264, -0.10234828, 0.1183656,
+ -0.0010731248, -0.023590032, -0.072285876, -0.0724771, -0.026382286,
+ -0.0014920527, 0.042667855, 0.0018776858, 0.02986552, 0.009814309,
+ 0.0733756, 0.12289186, 0.018043943, -0.0458958, 0.049412545,
+ 0.033632483, 0.05495232, 0.036686596, -0.013781798, -0.010036754,
+ 0.02576849, -0.08307328, 0.010112348, 0.042521734, -0.05869831,
+ -0.071689695, 0.03876447, -0.13275425, -0.0352966, -0.023077697,
+ 0.10285965, 0.084736146, 0.15568255, -0.00040734606, 0.027835453,
+ -0.10292561, -0.032401145, 0.10053256, -0.026142767, -0.08271222,
+ -0.0030240538, -0.016368777, 0.1070414, 0.042672627, 0.013456989,
+ -0.0437609, -0.022309763, 0.11576483, 0.04108048, 0.061026827,
+ -0.0190714, -0.0869359, 0.037901703, 0.0610107, 0.07202949,
+ 0.01675338, 0.086139716, -0.08795751, -0.014898893, -0.023771819,
+ -0.01965048, 0.007955471, -0.043740474, 0.03346837, -0.10549954,
+ 0.090567775, 0.042013682, -0.03176985, 0.12569028, -0.02421228,
+ -0.029526481, 0.023851605, 0.031539805, 0.05292009, -0.02344001,
+ -0.07811758, -0.08834428, 0.10094801, 0.16594367, -0.06861939,
+ -0.021256343, -0.041093912, -0.06669611, 0.035498552, 0.021757556,
+ -0.09302526, -0.015403468, -0.06614931, -0.051798206, -0.013874718,
+ 0.03630673, 0.010412845, -0.08077351, 0.046185967, 0.0035662893,
+ 0.03541868, -0.094149634, -0.034814864, 0.003128424, -0.020674974,
+ -0.03944324, -0.008110165, -0.11113267, 0.08484226, 0.043586485,
+ 0.040582247, 0.0968012, -0.065249965, -0.028036479, 0.0050708856,
+ 0.0017462453, 0.0326779, 0.041296225, 0.09164146, -0.047743853,
+ -0.015952192, -0.034451712, 0.084197424, -0.05347844, -0.11768019,
+ 0.085926116, -0.08251791, -0.045081906, 0.0948852, 0.068401024,
+ 0.024856757, 0.06978981, -0.057309967, -0.012775832, -0.0032452994,
+ 0.01977615, -0.041040014, -0.024264973, 0.063464895, 0.05431621,
+ });
+
+ lstm.SetCellToInputWeights(
+ {0.040369894, 0.030746894, 0.24704495, 0.018586371, -0.037586458,
+ -0.15312155, -0.11812848, -0.11465643, 0.20259799, 0.11418174,
+ -0.10116027, -0.011334949, 0.12411352, -0.076769054, -0.052169047,
+ 0.21198851, -0.38871562, -0.09061183, -0.09683246, -0.21929175});
+
+ lstm.SetCellToForgetWeights(
+ {-0.01998659, -0.15568835, -0.24248174, -0.012770197, 0.041331276,
+ -0.072311886, -0.052123554, -0.0066330447, -0.043891653, 0.036225766,
+ -0.047248036, 0.021479502, 0.033189066, 0.11952997, -0.020432774,
+ 0.64658105, -0.06650122, -0.03467612, 0.095340036, 0.23647355});
+
+ lstm.SetCellToOutputWeights(
+ {0.08286371, -0.08261836, -0.51210177, 0.002913762, 0.17764764,
+ -0.5495371, -0.08460716, -0.24552552, 0.030037103, 0.04123544,
+ -0.11940523, 0.007358328, 0.1890978, 0.4833202, -0.34441817,
+ 0.36312827, -0.26375428, 0.1457655, -0.19724406, 0.15548733});
+
+ lstm.SetProjectionWeights(
+ {-0.009802181, 0.09401916, 0.0717386, -0.13895074, 0.09641832,
+ 0.060420845, 0.08539281, 0.054285463, 0.061395317, 0.034448683,
+ -0.042991187, 0.019801661, -0.16840284, -0.015726732, -0.23041931,
+ -0.024478018, -0.10959692, -0.013875541, 0.18600968, -0.061274476,
+ 0.0138165, -0.08160894, -0.07661644, 0.032372914, 0.16169067,
+ 0.22465782, -0.03993472, -0.004017731, 0.08633481, -0.28869787,
+ 0.08682067, 0.17240396, 0.014975425, 0.056431185, 0.031037588,
+ 0.16702051, 0.0077946745, 0.15140012, 0.29405436, 0.120285,
+ -0.188994, -0.027265169, 0.043389652, -0.022061434, 0.014777949,
+ -0.20203483, 0.094781205, 0.19100232, 0.13987629, -0.036132768,
+ -0.06426278, -0.05108664, 0.13221376, 0.009441198, -0.16715929,
+ 0.15859416, -0.040437475, 0.050779544, -0.022187516, 0.012166504,
+ 0.027685808, -0.07675938, -0.0055694645, -0.09444123, 0.0046453946,
+ 0.050794356, 0.10770313, -0.20790008, -0.07149004, -0.11425117,
+ 0.008225835, -0.035802525, 0.14374903, 0.15262283, 0.048710253,
+ 0.1847461, -0.007487823, 0.11000021, -0.09542012, 0.22619456,
+ -0.029149994, 0.08527916, 0.009043713, 0.0042746216, 0.016261552,
+ 0.022461696, 0.12689082, -0.043589946, -0.12035478, -0.08361797,
+ -0.050666027, -0.1248618, -0.1275799, -0.071875185, 0.07377272,
+ 0.09944291, -0.18897448, -0.1593054, -0.06526116, -0.040107165,
+ -0.004618631, -0.067624845, -0.007576253, 0.10727444, 0.041546922,
+ -0.20424393, 0.06907816, 0.050412357, 0.00724631, 0.039827548,
+ 0.12449835, 0.10747581, 0.13708383, 0.09134148, -0.12617786,
+ -0.06428341, 0.09956831, 0.1208086, -0.14676677, -0.0727722,
+ 0.1126304, 0.010139365, 0.015571211, -0.038128063, 0.022913318,
+ -0.042050496, 0.16842307, -0.060597885, 0.10531834, -0.06411776,
+ -0.07451711, -0.03410368, -0.13393489, 0.06534304, 0.003620307,
+ 0.04490757, 0.05970546, 0.05197996, 0.02839995, 0.10434969,
+ -0.013699693, -0.028353551, -0.07260381, 0.047201227, -0.024575593,
+ -0.036445823, 0.07155557, 0.009672501, -0.02328883, 0.009533515,
+ -0.03606021, -0.07421458, -0.028082801, -0.2678904, -0.13221288,
+ 0.18419984, -0.13012612, -0.014588381, -0.035059117, -0.04824723,
+ 0.07830115, -0.056184657, 0.03277091, 0.025466874, 0.14494097,
+ -0.12522776, -0.098633975, -0.10766018, -0.08317623, 0.08594209,
+ 0.07749552, 0.039474737, 0.1776665, -0.07409566, -0.0477268,
+ 0.29323658, 0.10801441, 0.1154011, 0.013952499, 0.10739139,
+ 0.10708251, -0.051456142, 0.0074137426, -0.10430189, 0.10034707,
+ 0.045594677, 0.0635285, -0.0715442, -0.089667566, -0.10811871,
+ 0.00026344223, 0.08298446, -0.009525053, 0.006585689, -0.24567553,
+ -0.09450807, 0.09648481, 0.026996298, -0.06419476, -0.04752702,
+ -0.11063944, -0.23441927, -0.17608605, -0.052156363, 0.067035615,
+ 0.19271925, -0.0032889997, -0.043264326, 0.09663576, -0.057112187,
+ -0.10100678, 0.0628376, 0.04447668, 0.017961001, -0.10094388,
+ -0.10190601, 0.18335468, 0.10494553, -0.052095775, -0.0026118709,
+ 0.10539724, -0.04383912, -0.042349473, 0.08438151, -0.1947263,
+ 0.02251204, 0.11216432, -0.10307853, 0.17351969, -0.039091777,
+ 0.08066188, -0.00561982, 0.12633002, 0.11335965, -0.0088127935,
+ -0.019777594, 0.06864014, -0.059751723, 0.016233567, -0.06894641,
+ -0.28651384, -0.004228674, 0.019708522, -0.16305895, -0.07468996,
+ -0.0855457, 0.099339016, -0.07580735, -0.13775392, 0.08434318,
+ 0.08330512, -0.12131499, 0.031935584, 0.09180414, -0.08876437,
+ -0.08049874, 0.008753825, 0.03498998, 0.030215185, 0.03907079,
+ 0.089751154, 0.029194152, -0.03337423, -0.019092513, 0.04331237,
+ 0.04299654, -0.036394123, -0.12915532, 0.09793732, 0.07512415,
+ -0.11319543, -0.032502122, 0.15661901, 0.07671967, -0.005491124,
+ -0.19379048, -0.218606, 0.21448623, 0.017840758, 0.1416943,
+ -0.07051762, 0.19488361, 0.02664691, -0.18104725, -0.09334311,
+ 0.15026465, -0.15493552, -0.057762887, -0.11604192, -0.262013,
+ -0.01391798, 0.012185008, 0.11156489, -0.07483202, 0.06693364,
+ -0.26151478, 0.046425626, 0.036540434, -0.16435726, 0.17338543,
+ -0.21401681, -0.11385144, -0.08283257, -0.069031075, 0.030635102,
+ 0.010969227, 0.11109743, 0.010919218, 0.027526086, 0.13519906,
+ 0.01891392, -0.046839405, -0.040167913, 0.017953383, -0.09700955,
+ 0.0061885654, -0.07000971, 0.026893595, -0.038844477, 0.14543656});
+
+ static float lstm_input[][20] = {
+ {// Batch0: 4 (input_sequence_size) * 5 (n_input)
+ 0.787926, 0.151646, 0.071352, 0.118426, 0.458058, 0.596268, 0.998386,
+ 0.568695, 0.864524, 0.571277, 0.073204, 0.296072, 0.743333, 0.069199,
+ 0.045348, 0.867394, 0.291279, 0.013714, 0.482521, 0.626339},
+
+ {// Batch1: 4 (input_sequence_size) * 5 (n_input)
+ 0.295743, 0.544053, 0.690064, 0.858138, 0.497181, 0.642421, 0.524260,
+ 0.134799, 0.003639, 0.162482, 0.640394, 0.930399, 0.050782, 0.432485,
+ 0.988078, 0.082922, 0.563329, 0.865614, 0.333232, 0.259916}};
+
+ static float lstm_golden_output[][64] = {
+ {// Batch0: 4 (input_sequence_size) * 16 (n_output)
+ -0.00396806, 0.029352, -0.00279226, 0.0159977, -0.00835576,
+ -0.0211779, 0.0283512, -0.0114597, 0.00907307, -0.0244004,
+ -0.0152191, -0.0259063, 0.00914318, 0.00415118, 0.017147,
+ 0.0134203, -0.0166936, 0.0381209, 0.000889694, 0.0143363,
+ -0.0328911, -0.0234288, 0.0333051, -0.012229, 0.0110322,
+ -0.0457725, -0.000832209, -0.0202817, 0.0327257, 0.0121308,
+ 0.0155969, 0.0312091, -0.0213783, 0.0350169, 0.000324794,
+ 0.0276012, -0.0263374, -0.0371449, 0.0446149, -0.0205474,
+ 0.0103729, -0.0576349, -0.0150052, -0.0292043, 0.0376827,
+ 0.0136115, 0.0243435, 0.0354492, -0.0189322, 0.0464512,
+ -0.00251373, 0.0225745, -0.0308346, -0.0317124, 0.0460407,
+ -0.0189395, 0.0149363, -0.0530162, -0.0150767, -0.0340193,
+ 0.0286833, 0.00824207, 0.0264887, 0.0305169},
+ {// Batch1: 4 (input_sequence_size) * 16 (n_output)
+ -0.013869, 0.0287268, -0.00334693, 0.00733398, -0.0287926,
+ -0.0186926, 0.0193662, -0.0115437, 0.00422612, -0.0345232,
+ 0.00223253, -0.00957321, 0.0210624, 0.013331, 0.0150954,
+ 0.02168, -0.0141913, 0.0322082, 0.00227024, 0.0260507,
+ -0.0188721, -0.0296489, 0.0399134, -0.0160509, 0.0116039,
+ -0.0447318, -0.0150515, -0.0277406, 0.0316596, 0.0118233,
+ 0.0214762, 0.0293641, -0.0204549, 0.0450315, -0.00117378,
+ 0.0167673, -0.0375007, -0.0238314, 0.038784, -0.0174034,
+ 0.0131743, -0.0506589, -0.0048447, -0.0240239, 0.0325789,
+ 0.00790065, 0.0220157, 0.0333314, -0.0264787, 0.0387855,
+ -0.000764675, 0.0217599, -0.037537, -0.0335206, 0.0431679,
+ -0.0211424, 0.010203, -0.062785, -0.00832363, -0.025181,
+ 0.0412031, 0.0118723, 0.0239643, 0.0394009}};
+
+ // Resetting cell_state and output_state
+ lstm.ResetCellState();
+ lstm.ResetOutputState();
+
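+  // Feed the two batches one timestep at a time; after each Invoke() the
+  // output is checked against the corresponding slice of the golden output.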
+ const int input_sequence_size =
+ sizeof(lstm_input[0]) / sizeof(float) / (lstm.num_inputs());
+ for (int i = 0; i < input_sequence_size; i++) {
+ float* batch0_start = lstm_input[0] + i * lstm.num_inputs();
+ float* batch0_end = batch0_start + lstm.num_inputs();
+
+ lstm.SetInput(0, batch0_start, batch0_end);
+
+ float* batch1_start = lstm_input[1] + i * lstm.num_inputs();
+ float* batch1_end = batch1_start + lstm.num_inputs();
+ lstm.SetInput(lstm.num_inputs(), batch1_start, batch1_end);
+
+ lstm.Invoke();
+
+ float* golden_start_batch0 = lstm_golden_output[0] + i * lstm.num_outputs();
+ float* golden_end_batch0 = golden_start_batch0 + lstm.num_outputs();
+ float* golden_start_batch1 = lstm_golden_output[1] + i * lstm.num_outputs();
+ float* golden_end_batch1 = golden_start_batch1 + lstm.num_outputs();
+ std::vector<float> expected;
+ expected.insert(expected.end(), golden_start_batch0, golden_end_batch0);
+ expected.insert(expected.end(), golden_start_batch1, golden_end_batch1);
+ EXPECT_THAT(lstm.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
+ }
+}
+
+} // namespace
+} // namespace tflite
+
+int main(int argc, char** argv) {
+  // On Linux, add: tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/kernels/mul.cc b/tensorflow/contrib/lite/kernels/mul.cc
new file mode 100644
index 0000000000..81c73f2523
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/mul.cc
@@ -0,0 +1,167 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace mul {
+
+// This file has three implementations of Mul.
+enum KernelType {
+ kReference,
+ kGenericOptimized, // Neon-free
+ kNeonOptimized,
+};
+
+constexpr int kInputTensor1 = 0;
+constexpr int kInputTensor2 = 1;
+constexpr int kOutputTensor = 0;
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+
+ TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
+ TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+ TF_LITE_ENSURE_EQ(context, NumDimensions(input1), NumDimensions(input2));
+ for (int i = 0; i < NumDimensions(input1); ++i) {
+ TF_LITE_ENSURE_EQ(context, SizeOfDimension(input1, i),
+ SizeOfDimension(input2, i));
+ }
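+  // Note that both inputs must have identical shapes; this kernel does not
+  // broadcast.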
+
+ TF_LITE_ENSURE_EQ(context, input1->type, output->type);
+ TF_LITE_ENSURE_EQ(context, input2->type, output->type);
+
+ TfLiteIntArray* output_size = TfLiteIntArrayCopy(input1->dims);
+ return context->ResizeTensor(context, output, output_size);
+}
+
+template <KernelType kernel_type>
+void EvalFloat(TfLiteContext* context, TfLiteNode* node,
+ TfLiteMulParams* params, TfLiteTensor* input1,
+ TfLiteTensor* input2, TfLiteTensor* output) {
+ float output_activation_min, output_activation_max;
+ CalculateActivationRangeFloat(params->activation, &output_activation_min,
+ &output_activation_max);
+#define TF_LITE_MUL(type) \
+ type::Mul(GetTensorData<float>(input1), GetTensorDims(input1), \
+ GetTensorData<float>(input2), GetTensorDims(input2), \
+ output_activation_min, output_activation_max, \
+ GetTensorData<float>(output), GetTensorDims(output))
+ if (kernel_type == kReference) {
+ TF_LITE_MUL(reference_ops);
+ } else {
+ TF_LITE_MUL(optimized_ops);
+ }
+#undef TF_LITE_MUL
+}
+
+template <KernelType kernel_type>
+void EvalQuantized(TfLiteContext* context, TfLiteNode* node,
+ TfLiteMulParams* params, TfLiteTensor* input1,
+ TfLiteTensor* input2, TfLiteTensor* output) {
+ auto input1_offset = -input1->params.zero_point;
+ auto input2_offset = -input2->params.zero_point;
+ auto output_offset = output->params.zero_point;
+
+ int32_t output_multiplier;
+ int output_shift;
+
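+  // With r = scale * (q - zero_point), elementwise multiplication gives
+  //   q_out = Z_out + (S1 * S2 / S_out) * (q1 - Z1) * (q2 - Z2),
+  // clamped to the activation range. The real multiplier S1 * S2 / S_out is
+  // expected to be < 1 and is encoded as a fixed-point multiplier plus a
+  // right shift for the integer kernel.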
+ double real_multiplier =
+ input1->params.scale * input2->params.scale / output->params.scale;
+ QuantizeMultiplierSmallerThanOne(real_multiplier, &output_multiplier,
+ &output_shift);
+
+ int32 output_activation_min, output_activation_max;
+ CalculateActivationRangeUint8(params->activation, output,
+ &output_activation_min, &output_activation_max);
+
+#define TF_LITE_MUL(type) \
+ type::BroadcastMul(GetTensorData<uint8_t>(input1), GetTensorDims(input1), \
+ input1_offset, GetTensorData<uint8_t>(input2), \
+ GetTensorDims(input2), input2_offset, output_offset, \
+ output_multiplier, output_shift, output_activation_min, \
+ output_activation_max, GetTensorData<uint8_t>(output), \
+ GetTensorDims(output));
+ if (kernel_type == kReference) {
+ TF_LITE_MUL(reference_ops);
+ } else {
+ TF_LITE_MUL(optimized_ops);
+ }
+#undef TF_LITE_MUL
+}
+
+template <KernelType kernel_type>
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ auto* params = reinterpret_cast<TfLiteMulParams*>(node->builtin_data);
+
+ TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
+ TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+ if (output->type == kTfLiteFloat32) {
+ EvalFloat<kernel_type>(context, node, params, input1, input2, output);
+ } else if (output->type == kTfLiteUInt8) {
+ EvalQuantized<kernel_type>(context, node, params, input1, input2, output);
+ } else {
+ context->ReportError(context,
+ "Mul only supports FLOAT32 and quantized UINT8 now.");
+ return kTfLiteError;
+ }
+
+ return kTfLiteOk;
+}
+
+} // namespace mul
+
+TfLiteRegistration* Register_MUL_REF() {
+ static TfLiteRegistration r = {nullptr, nullptr, mul::Prepare,
+ mul::Eval<mul::kReference>};
+ return &r;
+}
+
+TfLiteRegistration* Register_MUL_GENERIC_OPT() {
+ static TfLiteRegistration r = {nullptr, nullptr, mul::Prepare,
+ mul::Eval<mul::kGenericOptimized>};
+ return &r;
+}
+
+TfLiteRegistration* Register_MUL_NEON_OPT() {
+ static TfLiteRegistration r = {nullptr, nullptr, mul::Prepare,
+ mul::Eval<mul::kNeonOptimized>};
+ return &r;
+}
+
+TfLiteRegistration* Register_MUL() {
+#ifdef USE_NEON
+ return Register_MUL_NEON_OPT();
+#else
+ return Register_MUL_GENERIC_OPT();
+#endif
+}
+
+} // namespace builtin
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/mul_test.cc b/tensorflow/contrib/lite/kernels/mul_test.cc
new file mode 100644
index 0000000000..4b858e1f39
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/mul_test.cc
@@ -0,0 +1,127 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAreArray;
+
+class BaseMulOpModel : public SingleOpModel {
+ public:
+ BaseMulOpModel(TensorData input, TensorData output,
+ ActivationFunctionType activation_type) {
+ input1_ = AddInput(input);
+ input2_ = AddInput(input);
+ output_ = AddOutput(output);
+ SetBuiltinOp(BuiltinOperator_MUL, BuiltinOptions_MulOptions,
+ CreateMulOptions(builder_, activation_type).Union());
+ BuildInterpreter({GetShape(input1_), GetShape(input2_)});
+ }
+
+ int input1() { return input1_; }
+ int input2() { return input2_; }
+
+ protected:
+ int input1_;
+ int input2_;
+ int output_;
+};
+
+class FloatMulOpModel : public BaseMulOpModel {
+ public:
+ using BaseMulOpModel::BaseMulOpModel;
+
+ std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+};
+
+// For quantized Mul, the error shouldn't exceed (2*step + step^2).
+// The parameters min=-1.0 and max=1.0 are used in the following tests.
+// The tolerance value is ~0.0157.
+const float kQuantizedStep = 2.0 / 255.0;
+const float kQuantizedTolerance =
+ 2.0 * kQuantizedStep + kQuantizedStep * kQuantizedStep;
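+// (kQuantizedStep = 2/255 ~= 0.00784, so kQuantizedTolerance ~= 2 * 0.00784 +
+// 0.00784^2 ~= 0.0157.)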
+
+class QuantizedMulOpModel : public BaseMulOpModel {
+ public:
+ using BaseMulOpModel::BaseMulOpModel;
+
+ std::vector<float> GetDequantizedOutput() {
+ return Dequantize<uint8_t>(ExtractVector<uint8_t>(output_),
+ GetScale(output_), GetZeroPoint(output_));
+ }
+};
+
+TEST(FloatMulOpTest, NoActivation) {
+ FloatMulOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}},
+ {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE);
+ m.PopulateTensor<float>(m.input1(), {-2.0, 0.2, 0.7, 0.8});
+ m.PopulateTensor<float>(m.input2(), {0.1, 0.2, 0.3, 0.5});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(),
+ ElementsAreArray(ArrayFloatNear({-0.2, 0.04, 0.21, 0.4})));
+}
+
+TEST(FloatMulOpTest, ActivationRELU1) {
+ FloatMulOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}},
+ {TensorType_FLOAT32, {}}, ActivationFunctionType_RELU1);
+ m.PopulateTensor<float>(m.input1(), {-2.0, 0.2, 0.7, 0.8});
+ m.PopulateTensor<float>(m.input2(), {0.1, 0.2, 0.3, 5});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(),
+ ElementsAreArray(ArrayFloatNear({-0.2, 0.04, 0.21, 1.0})));
+}
+
+TEST(FloatMulOpTest, VariousInputShapes) {
+ std::vector<std::initializer_list<int>> test_shapes = {
+ {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}};
+ for (int i = 0; i < test_shapes.size(); ++i) {
+ FloatMulOpModel m({TensorType_FLOAT32, test_shapes[i]},
+ {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE);
+ m.PopulateTensor<float>(m.input1(), {-2.0, 0.2, 0.7, 0.8, 1.1, 2.0});
+ m.PopulateTensor<float>(m.input2(), {0.1, 0.2, 0.3, 0.5, 1.1, 0.1});
+ m.Invoke();
+ EXPECT_THAT(
+ m.GetOutput(),
+ ElementsAreArray(ArrayFloatNear({-0.2, 0.04, 0.21, 0.4, 1.21, 0.2})))
+ << "With shape number " << i;
+ }
+}
+
+TEST(QuantizedMulOpTest, NoActivation) {
+ QuantizedMulOpModel m({TensorType_UINT8, {1, 2, 2, 1}, -1.0, 1.0},
+ {TensorType_UINT8, {}, -1.0, 1.0},
+ ActivationFunctionType_NONE);
+ m.QuantizeAndPopulate<uint8_t>(m.input1(), {-0.8, 0.2, 0.9, 0.7});
+ m.QuantizeAndPopulate<uint8_t>(m.input2(), {0.6, 0.4, 0.9, 0.8});
+ m.Invoke();
+ EXPECT_THAT(m.GetDequantizedOutput(),
+ ElementsAreArray(ArrayFloatNear({-0.48, 0.08, 0.81, 0.56},
+ kQuantizedTolerance)));
+}
+
+} // namespace
+} // namespace tflite
+
+int main(int argc, char** argv) {
+  // On Linux, add: tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/kernels/op_macros.h b/tensorflow/contrib/lite/kernels/op_macros.h
new file mode 100644
index 0000000000..7535afaf8e
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/op_macros.h
@@ -0,0 +1,32 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_OP_UTIL_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_OP_UTIL_H_
+
+#define TF_LITE_FATAL(msg) \
+ do { \
+ fprintf(stderr, "%s\n", (msg)); \
+ exit(1); \
+ } while (0)
+#define TF_LITE_ASSERT(x) \
+ do { \
+ if (!(x)) TF_LITE_FATAL(#x); \
+ } while (0)
+#define TF_LITE_ASSERT_EQ(x, y) \
+ do { \
+ if ((x) != (y)) TF_LITE_FATAL(#x " didn't equal " #y); \
+ } while (0)
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_OP_UTIL_H_
diff --git a/tensorflow/contrib/lite/kernels/optional_tensor_test.cc b/tensorflow/contrib/lite/kernels/optional_tensor_test.cc
new file mode 100644
index 0000000000..8977d27f73
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/optional_tensor_test.cc
@@ -0,0 +1,343 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// Unit test for the TFLite LSTM op with optional (null) input tensors.
+
+#include <iomanip>
+#include <memory>
+#include <vector>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAreArray;
+
+class LSTMOpModel : public SingleOpModel {
+ public:
+ LSTMOpModel(int n_batch, int n_input, int n_cell, int n_output, bool use_cifg,
+ bool use_peephole, bool use_projection_weights,
+ bool use_projection_bias, float cell_clip, float proj_clip,
+ const std::vector<std::vector<int>>& input_shapes)
+ : n_batch_(n_batch),
+ n_input_(n_input),
+ n_cell_(n_cell),
+ n_output_(n_output) {
+ input_ = AddInput(TensorType_FLOAT32);
+
+ if (use_cifg) {
+ input_to_input_weights_ = AddNullInput();
+ } else {
+ input_to_input_weights_ = AddInput(TensorType_FLOAT32);
+ }
+
+ input_to_forget_weights_ = AddInput(TensorType_FLOAT32);
+ input_to_cell_weights_ = AddInput(TensorType_FLOAT32);
+ input_to_output_weights_ = AddInput(TensorType_FLOAT32);
+
+ if (use_cifg) {
+ recurrent_to_input_weights_ = AddNullInput();
+ } else {
+ recurrent_to_input_weights_ = AddInput(TensorType_FLOAT32);
+ }
+
+ recurrent_to_forget_weights_ = AddInput(TensorType_FLOAT32);
+ recurrent_to_cell_weights_ = AddInput(TensorType_FLOAT32);
+ recurrent_to_output_weights_ = AddInput(TensorType_FLOAT32);
+
+ if (use_peephole) {
+ if (use_cifg) {
+ cell_to_input_weights_ = AddNullInput();
+ } else {
+ cell_to_input_weights_ = AddInput(TensorType_FLOAT32);
+ }
+ cell_to_forget_weights_ = AddInput(TensorType_FLOAT32);
+ cell_to_output_weights_ = AddInput(TensorType_FLOAT32);
+ } else {
+ cell_to_input_weights_ = AddNullInput();
+ cell_to_forget_weights_ = AddNullInput();
+ cell_to_output_weights_ = AddNullInput();
+ }
+
+ if (use_cifg) {
+ input_gate_bias_ = AddNullInput();
+ } else {
+ input_gate_bias_ = AddInput(TensorType_FLOAT32);
+ }
+ forget_gate_bias_ = AddInput(TensorType_FLOAT32);
+ cell_bias_ = AddInput(TensorType_FLOAT32);
+ output_gate_bias_ = AddInput(TensorType_FLOAT32);
+
+ if (use_projection_weights) {
+ projection_weights_ = AddInput(TensorType_FLOAT32);
+ if (use_projection_bias) {
+ projection_bias_ = AddInput(TensorType_FLOAT32);
+ } else {
+ projection_bias_ = AddNullInput();
+ }
+ } else {
+ projection_weights_ = AddNullInput();
+ projection_bias_ = AddNullInput();
+ }
+
+ scratch_buffer_ = AddOutput(TensorType_FLOAT32);
+ // TODO(ghodrat): Modify these states when we have a permanent solution for
+ // persistent buffer.
+ output_state_ = AddOutput(TensorType_FLOAT32);
+ cell_state_ = AddOutput(TensorType_FLOAT32);
+ output_ = AddOutput(TensorType_FLOAT32);
+
+ SetBuiltinOp(BuiltinOperator_LSTM, BuiltinOptions_LSTMOptions,
+ CreateLSTMOptions(builder_, ActivationFunctionType_TANH,
+ cell_clip, proj_clip)
+ .Union());
+ BuildInterpreter(input_shapes);
+ }
+
+ void SetInputToInputWeights(std::initializer_list<float> f) {
+ PopulateTensor(input_to_input_weights_, f);
+ }
+
+ void SetInputToForgetWeights(std::initializer_list<float> f) {
+ PopulateTensor(input_to_forget_weights_, f);
+ }
+
+ void SetInputToCellWeights(std::initializer_list<float> f) {
+ PopulateTensor(input_to_cell_weights_, f);
+ }
+
+ void SetInputToOutputWeights(std::initializer_list<float> f) {
+ PopulateTensor(input_to_output_weights_, f);
+ }
+
+ void SetRecurrentToInputWeights(std::initializer_list<float> f) {
+ PopulateTensor(recurrent_to_input_weights_, f);
+ }
+
+ void SetRecurrentToForgetWeights(std::initializer_list<float> f) {
+ PopulateTensor(recurrent_to_forget_weights_, f);
+ }
+
+ void SetRecurrentToCellWeights(std::initializer_list<float> f) {
+ PopulateTensor(recurrent_to_cell_weights_, f);
+ }
+
+ void SetRecurrentToOutputWeights(std::initializer_list<float> f) {
+ PopulateTensor(recurrent_to_output_weights_, f);
+ }
+
+ void SetCellToInputWeights(std::initializer_list<float> f) {
+ PopulateTensor(cell_to_input_weights_, f);
+ }
+
+ void SetCellToForgetWeights(std::initializer_list<float> f) {
+ PopulateTensor(cell_to_forget_weights_, f);
+ }
+
+ void SetCellToOutputWeights(std::initializer_list<float> f) {
+ PopulateTensor(cell_to_output_weights_, f);
+ }
+
+ void SetInputGateBias(std::initializer_list<float> f) {
+ PopulateTensor(input_gate_bias_, f);
+ }
+
+ void SetForgetGateBias(std::initializer_list<float> f) {
+ PopulateTensor(forget_gate_bias_, f);
+ }
+
+ void SetCellBias(std::initializer_list<float> f) {
+ PopulateTensor(cell_bias_, f);
+ }
+
+ void SetOutputGateBias(std::initializer_list<float> f) {
+ PopulateTensor(output_gate_bias_, f);
+ }
+
+ void SetProjectionWeights(std::initializer_list<float> f) {
+ PopulateTensor(projection_weights_, f);
+ }
+
+ void SetProjectionBias(std::initializer_list<float> f) {
+ PopulateTensor(projection_bias_, f);
+ }
+
+ void ResetOutputState() {
+ const int zero_buffer_size = n_cell_ * n_batch_;
+ std::unique_ptr<float[]> zero_buffer(new float[zero_buffer_size]);
+ memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float));
+ PopulateTensor(output_state_, 0, zero_buffer.get(),
+ zero_buffer.get() + zero_buffer_size);
+ }
+
+ void ResetCellState() {
+ const int zero_buffer_size = n_cell_ * n_batch_;
+ std::unique_ptr<float[]> zero_buffer(new float[zero_buffer_size]);
+ memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float));
+ PopulateTensor(cell_state_, 0, zero_buffer.get(),
+ zero_buffer.get() + zero_buffer_size);
+ }
+
+ void SetInput(int offset, float* begin, float* end) {
+ PopulateTensor(input_, offset, begin, end);
+ }
+
+ std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+ void Verify() {
+ auto model = tflite::UnPackModel(builder_.GetBufferPointer());
+ EXPECT_NE(model, nullptr);
+ }
+
+ int num_inputs() { return n_input_; }
+ int num_outputs() { return n_output_; }
+ int num_cells() { return n_cell_; }
+ int num_batches() { return n_batch_; }
+
+ private:
+ int input_;
+ int input_to_input_weights_;
+ int input_to_forget_weights_;
+ int input_to_cell_weights_;
+ int input_to_output_weights_;
+
+ int recurrent_to_input_weights_;
+ int recurrent_to_forget_weights_;
+ int recurrent_to_cell_weights_;
+ int recurrent_to_output_weights_;
+
+ int cell_to_input_weights_;
+ int cell_to_forget_weights_;
+ int cell_to_output_weights_;
+
+ int input_gate_bias_;
+ int forget_gate_bias_;
+ int cell_bias_;
+ int output_gate_bias_;
+
+ int projection_weights_;
+ int projection_bias_;
+
+ int output_;
+ int output_state_;
+ int cell_state_;
+ int scratch_buffer_;
+
+ int n_batch_;
+ int n_input_;
+ int n_cell_;
+ int n_output_;
+};
+
+
+TEST(LSTMOpTest, BlackBoxTestWithCifgWithPeepholeNoProjectionNoClipping) {
+ const int n_batch = 1;
+ const int n_input = 2;
+ // n_cell and n_output have the same size when there is no projection.
+ const int n_cell = 4;
+ const int n_output = 4;
+
+ LSTMOpModel lstm(n_batch, n_input, n_cell, n_output,
+ /*use_cifg=*/true, /*use_peephole=*/true,
+ /*use_projection_weights=*/false,
+ /*use_projection_bias=*/false,
+ /*cell_clip=*/0.0, /*proj_clip=*/0.0,
+ {
+ {n_batch, n_input}, // input tensor
+
+ {0, 0}, // input_to_input_weight tensor
+ {n_cell, n_input}, // input_to_forget_weight tensor
+ {n_cell, n_input}, // input_to_cell_weight tensor
+ {n_cell, n_input}, // input_to_output_weight tensor
+
+ {0, 0}, // recurrent_to_input_weight tensor
+ {n_cell, n_output}, // recurrent_to_forget_weight tensor
+ {n_cell, n_output}, // recurrent_to_cell_weight tensor
+ {n_cell, n_output}, // recurrent_to_output_weight tensor
+
+ {0}, // cell_to_input_weight tensor
+ {n_cell}, // cell_to_forget_weight tensor
+ {n_cell}, // cell_to_output_weight tensor
+
+ {0}, // input_gate_bias tensor
+ {n_cell}, // forget_gate_bias tensor
+ {n_cell}, // cell_bias tensor
+ {n_cell}, // output_gate_bias tensor
+
+ {0, 0}, // projection_weight tensor
+ {0}, // projection_bias tensor
+ });
+
+
+ lstm.SetInputToCellWeights({-0.49770179, -0.27711356, -0.09624726, 0.05100781,
+ 0.04717243, 0.48944736, -0.38535351,
+ -0.17212132});
+
+ lstm.SetInputToForgetWeights({-0.55291498, -0.42866567, 0.13056988,
+ -0.3633365, -0.22755712, 0.28253698, 0.24407166,
+ 0.33826375});
+
+ lstm.SetInputToOutputWeights({0.10725588, -0.02335852, -0.55932593,
+ -0.09426838, -0.44257352, 0.54939759,
+ 0.01533556, 0.42751634});
+
+ lstm.SetCellBias({0., 0., 0., 0.});
+
+ lstm.SetForgetGateBias({1., 1., 1., 1.});
+
+ lstm.SetOutputGateBias({0., 0., 0., 0.});
+
+ lstm.SetRecurrentToCellWeights(
+ {0.54066205, -0.32668582, -0.43562764, -0.56094903, 0.42957711,
+ 0.01841056, -0.32764608, -0.33027974, -0.10826075, 0.20675004,
+ 0.19069612, -0.03026325, -0.54532051, 0.33003211, 0.44901288,
+ 0.21193194});
+
+ lstm.SetRecurrentToForgetWeights(
+ {-0.13832897, -0.0515101, -0.2359007, -0.16661474, -0.14340827,
+ 0.36986142, 0.23414481, 0.55899, 0.10798943, -0.41174671, 0.17751795,
+ -0.34484994, -0.35874045, -0.11352962, 0.27268326, 0.54058349});
+
+ lstm.SetRecurrentToOutputWeights(
+ {0.41613156, 0.42610586, -0.16495961, -0.5663873, 0.30579174, -0.05115908,
+ -0.33941799, 0.23364776, 0.11178309, 0.09481031, -0.26424935, 0.46261835,
+ 0.50248802, 0.26114327, -0.43736315, 0.33149987});
+
+ lstm.SetCellToForgetWeights(
+ {0.47485286, -0.51955009, -0.24458408, 0.31544167});
+ lstm.SetCellToOutputWeights(
+ {-0.17135078, 0.82760304, 0.85573703, -0.77109635});
+
+ // Resetting cell_state and output_state
+ lstm.ResetCellState();
+ lstm.ResetOutputState();
+
+ // Verify the model by unpacking it.
+ lstm.Verify();
+}
+
+} // namespace
+} // namespace tflite
+
+int main(int argc, char** argv) {
+  // On Linux, add: tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/kernels/padding.h b/tensorflow/contrib/lite/kernels/padding.h
new file mode 100644
index 0000000000..3a60274524
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/padding.h
@@ -0,0 +1,28 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_PADDING_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_PADDING_H_
+
+namespace tflite {
+
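+// Computes half of the total padding required to produce `out_size` outputs
+// from `in_size` inputs with the given stride and filter size, clamped at
+// zero. E.g. in_size=5, filter_size=3, stride=1, out_size=5 gives
+// ((5 - 1) * 1 + 3 - 5) / 2 = 1.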
+inline int ComputePadding(int stride, int in_size, int filter_size,
+ int out_size) {
+ int padding = ((out_size - 1) * stride + filter_size - in_size) / 2;
+ return padding > 0 ? padding : 0;
+}
+
+} // namespace tflite
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_PADDING_H_
diff --git a/tensorflow/contrib/lite/kernels/pooling.cc b/tensorflow/contrib/lite/kernels/pooling.cc
new file mode 100644
index 0000000000..b798801108
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/pooling.cc
@@ -0,0 +1,355 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <unistd.h>
+#include <cassert>
+#include <cmath>
+#include <cstdio>
+#include <cstdlib>
+#include <iostream>
+#include <limits>
+
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+#include "tensorflow/contrib/lite/kernels/padding.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace pooling {
+
+// This file has two implementations of each pooling op.
+enum KernelType {
+ kReference,
+ kGenericOptimized,
+};
+
+enum PoolType {
+ kAverage,
+ kMax,
+ kL2,
+};
+
+struct OpData {
+ TfLitePaddingValues padding;
+};
+
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+ // This is a builtin op, so we don't use the contents in 'buffer', if any.
+ // Instead, we allocate a new object to carry information from Prepare() to
+ // Eval().
+ return new OpData;
+}
+
+void Free(TfLiteContext* context, void* buffer) {
+ delete reinterpret_cast<OpData*>(buffer);
+}
+
+template <PoolType pool_type>
+TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) {
+ auto* params = reinterpret_cast<TfLitePoolParams*>(node->builtin_data);
+ OpData* data = reinterpret_cast<OpData*>(node->user_data);
+
+ TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+ TfLiteTensor* output = GetOutput(context, node, 0);
+ TfLiteTensor* input = GetInput(context, node, 0);
+ TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4);
+ TF_LITE_ENSURE_EQ(context, input->type, output->type);
+
+ int batches = input->dims->data[0];
+ int height = input->dims->data[1];
+ int width = input->dims->data[2];
+ int channels_out = input->dims->data[3];
+
+ // Matching GetWindowedOutputSize in TensorFlow.
+ auto padding = params->padding;
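+  // For SAME padding the output size is ceil(imageSize / stride); for VALID
+  // padding it is floor((imageSize - filterSize) / stride) + 1.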
+ auto computeOutSize = [padding](int imageSize, int filterSize,
+ int stride) -> int {
+ return padding == kTfLitePaddingSame
+ ? (imageSize + stride - 1) / stride
+ : padding == kTfLitePaddingValid
+ ? (imageSize - filterSize + stride) / stride
+ : 0;
+ };
+
+ int outWidth =
+ computeOutSize(width, params->filter_width, params->stride_width);
+ int outHeight =
+ computeOutSize(height, params->filter_height, params->stride_height);
+
+ data->padding.height = ComputePadding(params->stride_height, height,
+ params->filter_height, outHeight);
+ data->padding.width = ComputePadding(params->stride_width, width,
+ params->filter_width, outWidth);
+
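+  // Average and max pooling operate directly on the quantized values, so for
+  // these pool types the output must share the input's scale and zero point
+  // (no requantization is done).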
+ if (input->type == kTfLiteUInt8) {
+ if (pool_type == kAverage || pool_type == kMax) {
+ TF_LITE_ENSURE_EQ(context, input->params.scale, output->params.scale);
+ TF_LITE_ENSURE_EQ(context, input->params.zero_point,
+ output->params.zero_point);
+ }
+ if (pool_type == kL2) {
+ // We currently don't have a quantized implementation of L2Pool
+ TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
+ }
+ }
+
+ TfLiteIntArray* outputSize = TfLiteIntArrayCreate(4);
+ outputSize->data[0] = batches;
+ outputSize->data[1] = outHeight;
+ outputSize->data[2] = outWidth;
+ outputSize->data[3] = channels_out;
+ return context->ResizeTensor(context, output, outputSize);
+}
+
+template <KernelType kernel_type>
+void AverageEvalFloat(TfLiteContext* context, TfLiteNode* node,
+ TfLitePoolParams* params, OpData* data,
+ TfLiteTensor* input, TfLiteTensor* output) {
+ float activation_min, activation_max;
+ CalculateActivationRangeFloat(params->activation, &activation_min,
+ &activation_max);
+#define TF_LITE_AVERAGE_POOL(type) \
+ type::AveragePool( \
+ GetTensorData<float>(input), GetTensorDims(input), params->stride_width, \
+ params->stride_height, data->padding.width, data->padding.height, \
+ params->filter_width, params->filter_height, activation_min, \
+ activation_max, GetTensorData<float>(output), GetTensorDims(output))
+ if (kernel_type == kReference) {
+ TF_LITE_AVERAGE_POOL(reference_ops);
+ } else {
+ TF_LITE_AVERAGE_POOL(optimized_ops);
+ }
+#undef TF_LITE_AVERAGE_POOL
+}
+
+template <KernelType kernel_type>
+void AverageEvalQuantized(TfLiteContext* context, TfLiteNode* node,
+ TfLitePoolParams* params, OpData* data,
+ TfLiteTensor* input, TfLiteTensor* output) {
+ int32_t activation_min;
+ int32_t activation_max;
+ CalculateActivationRangeUint8(params->activation, output, &activation_min,
+ &activation_max);
+#define TF_LITE_AVERAGE_POOL(type) \
+ type::AveragePool(GetTensorData<uint8_t>(input), GetTensorDims(input), \
+ params->stride_width, params->stride_height, \
+ data->padding.width, data->padding.height, \
+ params->filter_width, params->filter_height, \
+ activation_min, activation_max, \
+ GetTensorData<uint8_t>(output), GetTensorDims(output))
+ if (kernel_type == kReference) {
+ TF_LITE_AVERAGE_POOL(reference_ops);
+ } else {
+ TF_LITE_AVERAGE_POOL(optimized_ops);
+ }
+#undef TF_LITE_AVERAGE_POOL
+}
+
+template <KernelType kernel_type>
+void MaxEvalFloat(TfLiteContext* context, TfLiteNode* node,
+ TfLitePoolParams* params, OpData* data, TfLiteTensor* input,
+ TfLiteTensor* output) {
+ float activation_min, activation_max;
+ CalculateActivationRangeFloat(params->activation, &activation_min,
+ &activation_max);
+#define TF_LITE_MAX_POOL(type) \
+ type::MaxPool( \
+ GetTensorData<float>(input), GetTensorDims(input), params->stride_width, \
+ params->stride_height, data->padding.width, data->padding.height, \
+ params->filter_width, params->filter_height, activation_min, \
+ activation_max, GetTensorData<float>(output), GetTensorDims(output))
+ if (kernel_type == kReference) {
+ TF_LITE_MAX_POOL(reference_ops);
+ } else {
+ TF_LITE_MAX_POOL(optimized_ops);
+ }
+#undef TF_LITE_MAX_POOL
+}
+
+template <KernelType kernel_type>
+void MaxEvalQuantized(TfLiteContext* context, TfLiteNode* node,
+ TfLitePoolParams* params, OpData* data,
+ TfLiteTensor* input, TfLiteTensor* output) {
+ int32_t activation_min;
+ int32_t activation_max;
+ CalculateActivationRangeUint8(params->activation, output, &activation_min,
+ &activation_max);
+#define TF_LITE_MAX_POOL(type) \
+ type::MaxPool(GetTensorData<uint8_t>(input), GetTensorDims(input), \
+ params->stride_width, params->stride_height, \
+ data->padding.width, data->padding.height, \
+ params->filter_width, params->filter_height, activation_min, \
+ activation_max, GetTensorData<uint8_t>(output), \
+ GetTensorDims(output))
+ if (kernel_type == kReference) {
+ TF_LITE_MAX_POOL(reference_ops);
+ } else {
+ TF_LITE_MAX_POOL(optimized_ops);
+ }
+#undef TF_LITE_MAX_POOL
+}
+
+template <KernelType kernel_type>
+void L2EvalFloat(TfLiteContext* context, TfLiteNode* node,
+ TfLitePoolParams* params, OpData* data, TfLiteTensor* input,
+ TfLiteTensor* output) {
+ float activation_min, activation_max;
+ CalculateActivationRangeFloat(params->activation, &activation_min,
+ &activation_max);
+#define TF_LITE_L2_POOL(type) \
+ type::L2Pool( \
+ GetTensorData<float>(input), GetTensorDims(input), params->stride_width, \
+ params->stride_height, data->padding.width, data->padding.height, \
+ params->filter_width, params->filter_height, activation_min, \
+ activation_max, GetTensorData<float>(output), GetTensorDims(output))
+ if (kernel_type == kReference) {
+ TF_LITE_L2_POOL(reference_ops);
+ } else {
+ TF_LITE_L2_POOL(optimized_ops);
+ }
+#undef TF_LITE_L2_POOL
+}
+
+#undef TF_LITE_KERNEL_TYPE_DISPATCH
+
+template <KernelType kernel_type>
+TfLiteStatus AverageEval(TfLiteContext* context, TfLiteNode* node) {
+ auto* params = reinterpret_cast<TfLitePoolParams*>(node->builtin_data);
+ OpData* data = reinterpret_cast<OpData*>(node->user_data);
+
+ TfLiteTensor* output = GetOutput(context, node, 0);
+ TfLiteTensor* input = GetInput(context, node, 0);
+ switch (input->type) { // Already know in/out types are same.
+ case kTfLiteFloat32:
+ AverageEvalFloat<kernel_type>(context, node, params, data, input, output);
+ break;
+ case kTfLiteUInt8:
+ AverageEvalQuantized<kernel_type>(context, node, params, data, input,
+ output);
+ break;
+ default:
+ context->ReportError(context, "Type not currently supported.");
+ return kTfLiteError;
+ }
+ return kTfLiteOk;
+}
+
+template <KernelType kernel_type>
+TfLiteStatus MaxEval(TfLiteContext* context, TfLiteNode* node) {
+ auto* params = reinterpret_cast<TfLitePoolParams*>(node->builtin_data);
+ OpData* data = reinterpret_cast<OpData*>(node->user_data);
+
+ TfLiteTensor* output = GetOutput(context, node, 0);
+ TfLiteTensor* input = GetInput(context, node, 0);
+ switch (input->type) { // Already know in/out types are same.
+ case kTfLiteFloat32:
+ MaxEvalFloat<kernel_type>(context, node, params, data, input, output);
+ break;
+ case kTfLiteUInt8:
+ MaxEvalQuantized<kernel_type>(context, node, params, data, input, output);
+ break;
+ default:
+ context->ReportError(context, "Type not currently supported.");
+ return kTfLiteError;
+ }
+ return kTfLiteOk;
+}
+
+template <KernelType kernel_type>
+TfLiteStatus L2Eval(TfLiteContext* context, TfLiteNode* node) {
+ auto* params = reinterpret_cast<TfLitePoolParams*>(node->builtin_data);
+ OpData* data = reinterpret_cast<OpData*>(node->user_data);
+
+ TfLiteTensor* output = GetOutput(context, node, 0);
+ TfLiteTensor* input = GetInput(context, node, 0);
+ switch (input->type) { // Already know in/out types are same.
+ case kTfLiteFloat32:
+ L2EvalFloat<kernel_type>(context, node, params, data, input, output);
+ break;
+ case kTfLiteUInt8:
+ // We don't have a quantized implementation, so just fall through to the
+ // 'default' case.
+ default:
+ context->ReportError(context, "Type not currently supported.");
+ return kTfLiteError;
+ }
+ return kTfLiteOk;
+}
+
+} // namespace pooling
+
+TfLiteRegistration* Register_AVERAGE_POOL_REF() {
+ static TfLiteRegistration r = {pooling::Init, pooling::Free,
+ pooling::GenericPrepare<pooling::kAverage>,
+ pooling::AverageEval<pooling::kReference>};
+ return &r;
+}
+
+TfLiteRegistration* Register_MAX_POOL_REF() {
+ static TfLiteRegistration r = {pooling::Init, pooling::Free,
+ pooling::GenericPrepare<pooling::kMax>,
+ pooling::MaxEval<pooling::kReference>};
+ return &r;
+}
+
+TfLiteRegistration* Register_L2_POOL_REF() {
+ static TfLiteRegistration r = {pooling::Init, pooling::Free,
+ pooling::GenericPrepare<pooling::kL2>,
+ pooling::L2Eval<pooling::kReference>};
+ return &r;
+}
+
+TfLiteRegistration* Register_AVERAGE_POOL_GENERIC_OPT() {
+ static TfLiteRegistration r = {
+ pooling::Init, pooling::Free, pooling::GenericPrepare<pooling::kAverage>,
+ pooling::AverageEval<pooling::kGenericOptimized>};
+ return &r;
+}
+
+TfLiteRegistration* Register_MAX_POOL_GENERIC_OPT() {
+ static TfLiteRegistration r = {pooling::Init, pooling::Free,
+ pooling::GenericPrepare<pooling::kMax>,
+ pooling::MaxEval<pooling::kGenericOptimized>};
+ return &r;
+}
+
+TfLiteRegistration* Register_L2_POOL_GENERIC_OPT() {
+ static TfLiteRegistration r = {pooling::Init, pooling::Free,
+ pooling::GenericPrepare<pooling::kL2>,
+ pooling::L2Eval<pooling::kGenericOptimized>};
+ return &r;
+}
+
+TfLiteRegistration* Register_AVERAGE_POOL_2D() {
+ return Register_AVERAGE_POOL_GENERIC_OPT();
+}
+
+TfLiteRegistration* Register_MAX_POOL_2D() {
+ return Register_MAX_POOL_GENERIC_OPT();
+}
+
+TfLiteRegistration* Register_L2_POOL_2D() {
+ return Register_L2_POOL_GENERIC_OPT();
+}
+
+} // namespace builtin
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/pooling_test.cc b/tensorflow/contrib/lite/kernels/pooling_test.cc
new file mode 100644
index 0000000000..e1b51ec7d5
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/pooling_test.cc
@@ -0,0 +1,161 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <cstdarg>
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAreArray;
+
+class BasePoolingOpModel : public SingleOpModel {
+ public:
+ // TODO(ahentz): Also test different activation types, bias, padding types,
+ // stride values.
+ BasePoolingOpModel(BuiltinOperator type, const TensorData& input,
+ int filter_width, int filter_height,
+ const TensorData& output) {
+ input_ = AddInput(input);
+ output_ = AddOutput(output);
+
+ SetBuiltinOp(
+ type, BuiltinOptions_Pool2DOptions,
+ CreatePool2DOptions(builder_, Padding_VALID, 2, 2, filter_width,
+ filter_height, ActivationFunctionType_NONE)
+ .Union());
+
+ BuildInterpreter({GetShape(input_)});
+ }
+
+ protected:
+ int input_;
+ int output_;
+};
+
+class FloatPoolingOpModel : public BasePoolingOpModel {
+ public:
+ using BasePoolingOpModel::BasePoolingOpModel;
+
+ void SetInput(std::initializer_list<float> data) {
+ PopulateTensor(input_, data);
+ }
+
+ std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+};
+
+class QuantizedPoolingOpModel : public BasePoolingOpModel {
+ public:
+ using BasePoolingOpModel::BasePoolingOpModel;
+
+ void SetInput(std::initializer_list<float> data) {
+ QuantizeAndPopulate<uint8_t>(input_, data);
+ }
+
+ std::vector<uint8_t> GetOutput() { return ExtractVector<uint8_t>(output_); }
+ std::vector<float> GetDequantizedOutput() {
+ return Dequantize<uint8_t>(ExtractVector<uint8_t>(output_),
+ GetScale(output_), GetZeroPoint(output_));
+ }
+};
+
+TEST(FloatPoolingOpTest, AveragePool) {
+ FloatPoolingOpModel m(BuiltinOperator_AVERAGE_POOL_2D,
+ /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}},
+ /*filter_width=*/2, /*filter_height=*/2,
+ /*output=*/{TensorType_FLOAT32, {}});
+ m.SetInput({
+ 0, 6, 2, 4, //
+ 3, 2, 10, 7, //
+ });
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({2.75, 5.75}));
+}
+
+TEST(QuantizedPoolingOpTest, AveragePool) {
+ // Choose the input ranges carefully so that the dequantized output matches
+ // the results of the float model above.
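+  // With the range [0, 15.9375] the scale is 15.9375 / 255 = 0.0625 and the
+  // zero point is 0, so the inputs and the expected averages (2.75 and 5.75)
+  // are exactly representable; they quantize to 44 and 92, matching the raw
+  // output check below.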
+ QuantizedPoolingOpModel m(
+ BuiltinOperator_AVERAGE_POOL_2D,
+ /*input=*/{TensorType_UINT8, {1, 2, 4, 1}, 0, 15.9375},
+ /*filter_width=*/2, /*filter_height=*/2,
+ /*output=*/{TensorType_UINT8, {}, 0, 15.9375});
+ m.SetInput({
+ 0, 6, 2, 4, //
+ 3, 2, 10, 7, //
+ });
+ m.Invoke();
+
+ EXPECT_THAT(m.GetDequantizedOutput(),
+ ElementsAreArray(ArrayFloatNear({2.75, 5.75})));
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({44, 92}));
+}
+
+TEST(FloatPoolingOpTest, MaxPool) {
+ FloatPoolingOpModel m(BuiltinOperator_MAX_POOL_2D,
+ /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}},
+ /*filter_width=*/2, /*filter_height=*/2,
+ /*output=*/{TensorType_FLOAT32, {}});
+ m.SetInput({
+ 0, 6, 2, 4, //
+ 3, 2, 10, 7, //
+ });
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({6, 10}));
+}
+
+TEST(QuantizedPoolingOpTest, MaxPool) {
+ // Choose the input ranges carefully so that the dequantized output matches
+ // the results of the float model above.
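+  // Same quantization as above (scale 0.0625, zero point 0): the expected
+  // maxima 6 and 10 quantize to 96 and 160.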
+ QuantizedPoolingOpModel m(
+ BuiltinOperator_MAX_POOL_2D,
+ /*input=*/{TensorType_UINT8, {1, 2, 4, 1}, 0, 15.9375},
+ /*filter_width=*/2, /*filter_height=*/2,
+ /*output=*/{TensorType_UINT8, {}, 0, 15.9375});
+ m.SetInput({
+ 0, 6, 2, 4, //
+ 3, 2, 10, 7, //
+ });
+ m.Invoke();
+
+ EXPECT_THAT(m.GetDequantizedOutput(),
+ ElementsAreArray(ArrayFloatNear({6, 10})));
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({96, 160}));
+}
+
+TEST(FloatPoolingOpTest, L2Pool) {
+ FloatPoolingOpModel m(BuiltinOperator_L2_POOL_2D,
+ /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}},
+ /*filter_width=*/2, /*filter_height=*/2,
+ /*output=*/{TensorType_FLOAT32, {}});
+ m.SetInput({
+ 0, 6, 2, 4, //
+ 3, 2, 10, 7, //
+ });
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({3.5, 6.5}));
+}
+
+} // namespace
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ // On Linux, add: tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/kernels/register.cc b/tensorflow/contrib/lite/kernels/register.cc
new file mode 100644
index 0000000000..ca7a0dd194
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/register.cc
@@ -0,0 +1,109 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/kernels/register.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+
+TfLiteRegistration* Register_RELU();
+TfLiteRegistration* Register_RELU1();
+TfLiteRegistration* Register_RELU6();
+TfLiteRegistration* Register_TANH();
+TfLiteRegistration* Register_LOGISTIC();
+TfLiteRegistration* Register_AVERAGE_POOL_2D();
+TfLiteRegistration* Register_MAX_POOL_2D();
+TfLiteRegistration* Register_L2_POOL_2D();
+TfLiteRegistration* Register_CONV_2D();
+TfLiteRegistration* Register_DEPTHWISE_CONV_2D();
+TfLiteRegistration* Register_SVDF();
+TfLiteRegistration* Register_RNN();
+TfLiteRegistration* Register_EMBEDDING_LOOKUP();
+TfLiteRegistration* Register_EMBEDDING_LOOKUP_SPARSE();
+TfLiteRegistration* Register_FULLY_CONNECTED();
+TfLiteRegistration* Register_LSH_PROJECTION();
+TfLiteRegistration* Register_HASHTABLE_LOOKUP();
+TfLiteRegistration* Register_SOFTMAX();
+TfLiteRegistration* Register_CONCATENATION();
+TfLiteRegistration* Register_ADD();
+TfLiteRegistration* Register_MUL();
+TfLiteRegistration* Register_L2_NORMALIZATION();
+TfLiteRegistration* Register_LOCAL_RESPONSE_NORMALIZATION();
+TfLiteRegistration* Register_LSTM();
+TfLiteRegistration* Register_RESHAPE();
+TfLiteRegistration* Register_RESIZE_BILINEAR();
+TfLiteRegistration* Register_SKIP_GRAM();
+TfLiteRegistration* Register_SPACE_TO_DEPTH();
+
+BuiltinOpResolver::BuiltinOpResolver() {
+ AddBuiltin(BuiltinOperator_RELU, Register_RELU());
+ AddBuiltin(BuiltinOperator_RELU1, Register_RELU1());
+ AddBuiltin(BuiltinOperator_RELU6, Register_RELU6());
+ AddBuiltin(BuiltinOperator_TANH, Register_TANH());
+ AddBuiltin(BuiltinOperator_LOGISTIC, Register_LOGISTIC());
+ AddBuiltin(BuiltinOperator_AVERAGE_POOL_2D, Register_AVERAGE_POOL_2D());
+ AddBuiltin(BuiltinOperator_MAX_POOL_2D, Register_MAX_POOL_2D());
+ AddBuiltin(BuiltinOperator_L2_POOL_2D, Register_L2_POOL_2D());
+ AddBuiltin(BuiltinOperator_CONV_2D, Register_CONV_2D());
+ AddBuiltin(BuiltinOperator_DEPTHWISE_CONV_2D, Register_DEPTHWISE_CONV_2D());
+ AddBuiltin(BuiltinOperator_SVDF, Register_SVDF());
+ AddBuiltin(BuiltinOperator_RNN, Register_RNN());
+ AddBuiltin(BuiltinOperator_EMBEDDING_LOOKUP, Register_EMBEDDING_LOOKUP());
+ AddBuiltin(BuiltinOperator_EMBEDDING_LOOKUP_SPARSE,
+ Register_EMBEDDING_LOOKUP_SPARSE());
+ AddBuiltin(BuiltinOperator_FULLY_CONNECTED, Register_FULLY_CONNECTED());
+ AddBuiltin(BuiltinOperator_LSH_PROJECTION, Register_LSH_PROJECTION());
+ AddBuiltin(BuiltinOperator_HASHTABLE_LOOKUP, Register_HASHTABLE_LOOKUP());
+ AddBuiltin(BuiltinOperator_SOFTMAX, Register_SOFTMAX());
+ AddBuiltin(BuiltinOperator_CONCATENATION, Register_CONCATENATION());
+ AddBuiltin(BuiltinOperator_ADD, Register_ADD());
+ AddBuiltin(BuiltinOperator_MUL, Register_MUL());
+ AddBuiltin(BuiltinOperator_L2_NORMALIZATION, Register_L2_NORMALIZATION());
+ AddBuiltin(BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION,
+ Register_LOCAL_RESPONSE_NORMALIZATION());
+ AddBuiltin(BuiltinOperator_LSTM, Register_LSTM());
+ AddBuiltin(BuiltinOperator_RESHAPE, Register_RESHAPE());
+ AddBuiltin(BuiltinOperator_RESIZE_BILINEAR, Register_RESIZE_BILINEAR());
+ AddBuiltin(BuiltinOperator_SKIP_GRAM, Register_SKIP_GRAM());
+ AddBuiltin(BuiltinOperator_SPACE_TO_DEPTH, Register_SPACE_TO_DEPTH());
+}
+
+TfLiteRegistration* BuiltinOpResolver::FindOp(
+ tflite::BuiltinOperator op) const {
+ auto it = builtins_.find(op);
+ return it != builtins_.end() ? it->second : nullptr;
+}
+
+TfLiteRegistration* BuiltinOpResolver::FindOp(const char* op) const {
+ auto it = custom_ops_.find(op);
+ return it != custom_ops_.end() ? it->second : nullptr;
+}
+
+void BuiltinOpResolver::AddBuiltin(tflite::BuiltinOperator op,
+ TfLiteRegistration* registration) {
+ registration->builtin_code = op;
+ builtins_.insert(std::make_pair(op, registration));
+}
+
+void BuiltinOpResolver::AddCustom(const char* name,
+ TfLiteRegistration* registration) {
+ registration->builtin_code = BuiltinOperator_CUSTOM;
+ custom_ops_.insert(std::make_pair(std::string(name), registration));
+}
+
+} // namespace builtin
+} // namespace ops
+} // namespace tflite
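
A minimal usage sketch of the resolver defined above (not part of the patch; Register_MY_CUSTOM_OP() is a hypothetical factory for a custom kernel, and model loading is omitted):

  tflite::ops::builtin::BuiltinOpResolver resolver;
  // Builtins are pre-registered by the constructor; customs are added by name.
  resolver.AddCustom("MyCustomOp", Register_MY_CUSTOM_OP());
  // Builtins are looked up by enum value, custom ops by name.
  TfLiteRegistration* avg_pool =
      resolver.FindOp(tflite::BuiltinOperator_AVERAGE_POOL_2D);
  TfLiteRegistration* custom = resolver.FindOp("MyCustomOp");  // nullptr if absent.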
diff --git a/tensorflow/contrib/lite/kernels/register.h b/tensorflow/contrib/lite/kernels/register.h
new file mode 100644
index 0000000000..28f5e0fcc8
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/register.h
@@ -0,0 +1,50 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_REGISTER_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_REGISTER_H_
+
+#include <unordered_map>
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+
+class BuiltinOpResolver : public OpResolver {
+ public:
+ BuiltinOpResolver();
+ TfLiteRegistration* FindOp(tflite::BuiltinOperator op) const override;
+ TfLiteRegistration* FindOp(const char* op) const override;
+ void AddBuiltin(tflite::BuiltinOperator op, TfLiteRegistration* registration);
+ void AddCustom(const char* name, TfLiteRegistration* registration);
+
+ private:
+ struct BuiltinOperatorHasher {
+ size_t operator()(const tflite::BuiltinOperator& x) const {
+ return std::hash<size_t>()(static_cast<size_t>(x));
+ }
+ };
+ std::unordered_map<tflite::BuiltinOperator, TfLiteRegistration*,
+ BuiltinOperatorHasher>
+ builtins_;
+ std::unordered_map<std::string, TfLiteRegistration*> custom_ops_;
+};
+
+} // namespace builtin
+} // namespace ops
+} // namespace tflite
+
+#endif  // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_REGISTER_H_
diff --git a/tensorflow/contrib/lite/kernels/reshape.cc b/tensorflow/contrib/lite/kernels/reshape.cc
new file mode 100644
index 0000000000..f3e6ddc9f4
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/reshape.cc
@@ -0,0 +1,91 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <string.h>
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace reshape {
+
+constexpr int kInputTensor = 0;
+constexpr int kOutputTensor = 0;
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ auto* params = reinterpret_cast<TfLiteReshapeParams*>(node->builtin_data);
+
+  // TODO(ahentz): we are often given a tensor with the shape, but we only pay
+  // attention to the shape specified in 'params'.
+ TF_LITE_ENSURE(context, NumInputs(node) == 1 || NumInputs(node) == 2);
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+
+ TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+  // TensorFlow's Reshape allows one of the shape components to have the
+  // special -1 value, meaning it will be calculated automatically based on the
+  // input. Here we calculate what that dimension should be so that the number
+  // of output elements is the same as the number of input elements.
+ int num_input_elements = 1;
+ for (int i = 0; i < NumDimensions(input); ++i) {
+ num_input_elements *= SizeOfDimension(input, i);
+ }
+
+ TfLiteIntArray* output_size = TfLiteIntArrayCreate(params->num_dimensions);
+ int num_output_elements = 1;
+  int stretch_dim = -1;
+  for (int i = 0; i < params->num_dimensions; ++i) {
+    int value = params->shape[i];
+    if (value == -1) {
+      TF_LITE_ENSURE_EQ(context, stretch_dim, -1);
+      stretch_dim = i;
+    } else {
+      num_output_elements *= value;
+      output_size->data[i] = value;
+    }
+  }
+  if (stretch_dim != -1) {
+    output_size->data[stretch_dim] = num_input_elements / num_output_elements;
+    num_output_elements *= output_size->data[stretch_dim];
+ }
+
+ TF_LITE_ENSURE_EQ(context, num_input_elements, num_output_elements);
+ return context->ResizeTensor(context, output, output_size);
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+ memcpy(output->data.raw, input->data.raw, input->bytes);
+
+ return kTfLiteOk;
+}
+
+} // namespace reshape
+
+TfLiteRegistration* Register_RESHAPE() {
+ static TfLiteRegistration r = {nullptr, nullptr, reshape::Prepare,
+ reshape::Eval};
+ return &r;
+}
+
+} // namespace builtin
+} // namespace ops
+} // namespace tflite
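
A standalone sketch of the stretch-dimension resolution performed in Prepare() above (mirroring the logic rather than calling it): with 8 input elements and a requested shape of {2, 1, -1}, the -1 entry resolves to 8 / (2 * 1) = 4, which is what the WithStretchDimension test in the next file expects.

  #include <vector>

  std::vector<int> ResolveShape(int num_input_elements, std::vector<int> shape) {
    int stretch_dim = -1;
    int num_output_elements = 1;
    for (int i = 0; i < static_cast<int>(shape.size()); ++i) {
      if (shape[i] == -1) {
        stretch_dim = i;  // At most one -1 is allowed.
      } else {
        num_output_elements *= shape[i];
      }
    }
    if (stretch_dim != -1) {
      shape[stretch_dim] = num_input_elements / num_output_elements;
    }
    return shape;  // ResolveShape(8, {2, 1, -1}) -> {2, 1, 4}.
  }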
diff --git a/tensorflow/contrib/lite/kernels/reshape_test.cc b/tensorflow/contrib/lite/kernels/reshape_test.cc
new file mode 100644
index 0000000000..59ce7d5648
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/reshape_test.cc
@@ -0,0 +1,90 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAreArray;
+
+class ReshapeOpModel : public SingleOpModel {
+ public:
+ ReshapeOpModel(std::initializer_list<int> input_shape,
+ std::initializer_list<int> new_shape) {
+ input_ = AddInput(TensorType_FLOAT32);
+ output_ = AddOutput(TensorType_FLOAT32);
+ SetBuiltinOp(
+ BuiltinOperator_RESHAPE, BuiltinOptions_ReshapeOptions,
+ CreateReshapeOptions(builder_, builder_.CreateVector<int>(new_shape))
+ .Union());
+ BuildInterpreter({input_shape});
+ }
+
+ void SetInput(std::initializer_list<float> data) {
+ PopulateTensor<float>(input_, data);
+ }
+ std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+ std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
+
+ private:
+ int input_;
+ int output_;
+};
+
+TEST(ReshapeOpTest, MismatchedDimensions) {
+ EXPECT_DEATH(ReshapeOpModel({1, 2, 4, 1}, {2, 1}),
+ "num_input_elements != num_output_elements");
+}
+
+TEST(ReshapeOpTest, TooManyDimensions) {
+ EXPECT_DEATH(
+ ReshapeOpModel({1, 2, 3, 4, 5, 6, 7, 8, 9}, {1, 2, 3, 4, 5, 6, 7, 8, 9}),
+ "Found too many dimensions");
+}
+
+TEST(ReshapeOpTest, TooManySpecialDimensions) {
+  EXPECT_DEATH(ReshapeOpModel({1, 2, 4, 1}, {-1, -1, 2, 4}),
+               "stretch_dim != -1");
+}
+
+TEST(ReshapeOpTest, SimpleTest) {
+ ReshapeOpModel m({1, 2, 4, 1}, {2, 2, 2});
+ m.SetInput({1, 2, 3, 4, 5, 6, 7, 8});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6, 7, 8}));
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2, 2}));
+}
+
+TEST(ReshapeOpTest, WithStretchDimension) {
+ ReshapeOpModel m({1, 2, 4, 1}, {2, 1, -1});
+ m.SetInput({1, 2, 3, 4, 5, 6, 7, 8});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6, 7, 8}));
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 1, 4}));
+}
+
+} // namespace
+} // namespace tflite
+
+int main(int argc, char** argv) {
+  tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/kernels/resize_bilinear.cc b/tensorflow/contrib/lite/kernels/resize_bilinear.cc
new file mode 100644
index 0000000000..1613c9a89f
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/resize_bilinear.cc
@@ -0,0 +1,129 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace resize_bilinear {
+
+// This file has three implementations of RESIZE_BILINEAR.
+enum KernelType {
+ kReference,
+ kGenericOptimized, // Neon-free
+ kNeonOptimized,
+};
+
+constexpr int kInputTensor = 0;
+constexpr int kOutputTensor = 0;
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ auto* params =
+ reinterpret_cast<TfLiteResizeBilinearParams*>(node->builtin_data);
+
+ TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+
+ TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+ // TODO(ahentz): Our current implementations rely on the inputs being 4D.
+ TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4);
+
+ // TODO(ahentz): Our current implementations only support float32.
+ TF_LITE_ENSURE_EQ(context, output->type, kTfLiteFloat32);
+ TF_LITE_ENSURE_EQ(context, input->type, output->type);
+
+ TfLiteIntArray* output_size = TfLiteIntArrayCreate(4);
+ output_size->data[0] = input->dims->data[0];
+ output_size->data[1] = params->new_height;
+ output_size->data[2] = params->new_width;
+ output_size->data[3] = input->dims->data[3];
+
+ return context->ResizeTensor(context, output, output_size);
+}
+
+template <KernelType kernel_type>
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ auto* params =
+ reinterpret_cast<TfLiteResizeBilinearParams*>(node->builtin_data);
+
+ TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+ // We have to fake a tensor here, to satisfy ResizeBilinear().
+ int32 output_size_data[2] = {params->new_height, params->new_width};
+
+ if (output->type == kTfLiteFloat32) {
+#define TF_LITE_RESIZE_BILINEAR(type) \
+ type::ResizeBilinear(GetTensorData<float>(input), GetTensorDims(input), \
+ output_size_data, GetTensorDims({1, 1, 1, 2}), \
+ GetTensorData<float>(output), GetTensorDims(output))
+
+ if (kernel_type == kReference) {
+ TF_LITE_RESIZE_BILINEAR(reference_ops);
+ }
+ if (kernel_type == kGenericOptimized || kernel_type == kNeonOptimized) {
+ TF_LITE_RESIZE_BILINEAR(optimized_ops);
+ }
+#undef TF_LITE_RESIZE_BILINEAR
+ } else {
+ context->ReportError(context, "Inputs and outputs not all float types.");
+ return kTfLiteError;
+ }
+
+ return kTfLiteOk;
+}
+
+} // namespace resize_bilinear
+
+TfLiteRegistration* Register_RESIZE_BILINEAR_REF() {
+ static TfLiteRegistration r = {
+ nullptr, nullptr, resize_bilinear::Prepare,
+ resize_bilinear::Eval<resize_bilinear::kReference>};
+ return &r;
+}
+
+TfLiteRegistration* Register_RESIZE_BILINEAR_GENERIC_OPT() {
+ static TfLiteRegistration r = {
+ nullptr, nullptr, resize_bilinear::Prepare,
+ resize_bilinear::Eval<resize_bilinear::kGenericOptimized>};
+ return &r;
+}
+
+TfLiteRegistration* Register_RESIZE_BILINEAR_NEON_OPT() {
+ static TfLiteRegistration r = {
+ nullptr, nullptr, resize_bilinear::Prepare,
+ resize_bilinear::Eval<resize_bilinear::kNeonOptimized>};
+ return &r;
+}
+
+TfLiteRegistration* Register_RESIZE_BILINEAR() {
+#ifdef USE_NEON
+ return Register_RESIZE_BILINEAR_NEON_OPT();
+#else
+ return Register_RESIZE_BILINEAR_GENERIC_OPT();
+#endif
+}
+
+} // namespace builtin
+} // namespace ops
+} // namespace tflite
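
For intuition, a one-dimensional sketch of the interpolation scheme the kernel is assumed to use (align_corners behavior and edge cases are not modeled): each output position maps back to x * (in_width / out_width) and blends the two nearest input columns. It reproduces the HorizontalResize expectation in the test file that follows.

  #include <algorithm>
  #include <cmath>
  #include <vector>

  std::vector<float> ResizeRow(const std::vector<float>& in, int out_width) {
    const int in_width = static_cast<int>(in.size());
    const float scale = static_cast<float>(in_width) / out_width;
    std::vector<float> out(out_width);
    for (int x = 0; x < out_width; ++x) {
      const float in_x = x * scale;
      const int x0 = static_cast<int>(std::floor(in_x));
      const int x1 = std::min(x0 + 1, in_width - 1);
      const float frac = in_x - x0;
      out[x] = in[x0] * (1 - frac) + in[x1] * frac;
    }
    return out;  // ResizeRow({3, 6}, 3) ~= {3, 5, 6}, as in HorizontalResize.
  }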
diff --git a/tensorflow/contrib/lite/kernels/resize_bilinear_test.cc b/tensorflow/contrib/lite/kernels/resize_bilinear_test.cc
new file mode 100644
index 0000000000..0257c0b557
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/resize_bilinear_test.cc
@@ -0,0 +1,117 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAreArray;
+
+class ResizeBilinearOpModel : public SingleOpModel {
+ public:
+ ResizeBilinearOpModel(std::initializer_list<int> input_shape, int new_height,
+ int new_width) {
+ input_ = AddInput(TensorType_FLOAT32);
+ output_ = AddOutput(TensorType_FLOAT32);
+ SetBuiltinOp(
+ BuiltinOperator_RESIZE_BILINEAR, BuiltinOptions_ResizeBilinearOptions,
+ CreateResizeBilinearOptions(builder_, new_height, new_width).Union());
+ BuildInterpreter({input_shape});
+ }
+
+ void SetInput(std::initializer_list<float> data) {
+ PopulateTensor(input_, data);
+ }
+
+ std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+
+ private:
+ int input_;
+ int output_;
+};
+
+TEST(ResizeBilinearOpTest, HorizontalResize) {
+ ResizeBilinearOpModel m({1, 1, 2, 1}, 1, 3);
+ m.SetInput({3, 6});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({3, 5, 6})));
+}
+
+TEST(ResizeBilinearOpTest, VerticalResize) {
+ ResizeBilinearOpModel m({1, 2, 1, 1}, 3, 1);
+ m.SetInput({3, 9});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({3, 7, 9})));
+}
+
+TEST(ResizeBilinearOpTest, TwoDimensionalResize) {
+ ResizeBilinearOpModel m({1, 2, 2, 1}, 3, 3);
+ m.SetInput({
+ 3, 6, //
+ 9, 12 //
+ });
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({
+ 3, 5, 6, //
+ 7, 9, 10, //
+ 9, 11, 12, //
+ })));
+}
+
+TEST(ResizeBilinearOpTest, TwoDimensionalResizeWithTwoBatches) {
+ ResizeBilinearOpModel m({2, 2, 2, 1}, 3, 3);
+ m.SetInput({
+ 3, 6, //
+ 9, 12, //
+ 4, 10, //
+ 10, 16 //
+ });
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({
+ 3, 5, 6, //
+ 7, 9, 10, //
+ 9, 11, 12, //
+ 4, 8, 10, //
+ 8, 12, 14, //
+ 10, 14, 16, //
+ })));
+}
+
+TEST(ResizeBilinearOpTest, ThreeDimensionalResize) {
+ ResizeBilinearOpModel m({1, 2, 2, 2}, 3, 3);
+ m.SetInput({
+ 3, 4, 6, 10, //
+ 9, 10, 12, 16, //
+ });
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({
+ 3, 4, 5, 8, 6, 10, //
+ 7, 8, 9, 12, 10, 14, //
+ 9, 10, 11, 14, 12, 16, //
+ })));
+}
+
+} // namespace
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ // On Linux, add: tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/kernels/skip_gram.cc b/tensorflow/contrib/lite/kernels/skip_gram.cc
new file mode 100644
index 0000000000..c90a15b3a2
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/skip_gram.cc
@@ -0,0 +1,160 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Generate a list of skip grams from an input.
+//
+// Options:
+// ngram_size: num of words for each output item.
+// max_skip_size: max num of words to skip.
+//     When it is 0, the op generates regular n-grams without skipping.
+// include_all_ngrams: include all ngrams with size up to ngram_size.
+//
+// Input:
+// A string tensor to generate n-grams.
+// Dim = {1}
+//
+// Output:
+// A list of strings, each of which contains ngram_size words.
+// Dim = {num_ngram}
+
+#include <ctype.h>
+#include <string>
+#include <vector>
+
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+#include "tensorflow/contrib/lite/string_util.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+
+namespace {
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+
+ TF_LITE_ENSURE_EQ(context, GetInput(context, node, 0)->type, kTfLiteString);
+ TF_LITE_ENSURE_EQ(context, GetOutput(context, node, 0)->type, kTfLiteString);
+ return kTfLiteOk;
+}
+
+bool ShouldIncludeCurrentNgram(const TfLiteSkipGramParams* params, int size) {
+ if (size <= 0) {
+ return false;
+ }
+ if (params->include_all_ngrams) {
+ return size <= params->ngram_size;
+ } else {
+ return size == params->ngram_size;
+ }
+}
+
+bool ShouldStepInRecursion(const TfLiteSkipGramParams* params,
+ const std::vector<int>& stack, int stack_idx,
+ int num_words) {
+  // If the current stack size and the next word index are within valid range.
+ if (stack_idx < params->ngram_size && stack[stack_idx] + 1 < num_words) {
+ // If this stack is empty, step in for first word enumeration.
+ if (stack_idx == 0) {
+ return true;
+ }
+    // If the next word index is within the range allowed by max_skip_size.
+ // NOTE: equivalent to
+ // next_word_idx = stack[stack_idx] + 1
+ // next_word_idx - stack[stack_idx-1] <= max_skip_size + 1
+ if (stack[stack_idx] - stack[stack_idx - 1] <= params->max_skip_size) {
+ return true;
+ }
+ }
+ return false;
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ auto* params = reinterpret_cast<TfLiteSkipGramParams*>(node->builtin_data);
+
+ // Split sentence to words.
+ std::vector<StringRef> words;
+ tflite::StringRef strref = tflite::GetString(GetInput(context, node, 0), 0);
+ int prev_idx = 0;
+ for (int i = 1; i < strref.len; i++) {
+ if (isspace(*(strref.str + i))) {
+ if (i > prev_idx && !isspace(*(strref.str + prev_idx))) {
+ words.push_back({strref.str + prev_idx, i - prev_idx});
+ }
+ prev_idx = i + 1;
+ }
+ }
+ if (strref.len > prev_idx) {
+ words.push_back({strref.str + prev_idx, strref.len - prev_idx});
+ }
+
+  // Generate n-grams iteratively, using an explicit stack to emulate recursion.
+ tflite::DynamicBuffer buf;
+ if (words.size() < params->ngram_size) {
+ buf.WriteToTensor(GetOutput(context, node, 0));
+ return kTfLiteOk;
+ }
+
+ // Stack stores the index of word used to generate ngram.
+ // The size of stack is the size of ngram.
+ std::vector<int> stack(params->ngram_size, 0);
+ // Stack index that indicates which depth the recursion is operating at.
+ int stack_idx = 1;
+ int num_words = words.size();
+
+ while (stack_idx >= 0) {
+ if (ShouldStepInRecursion(params, stack, stack_idx, num_words)) {
+      // When the current depth can be filled with a new word and that word is
+      // within the maximum skip range, push it onto the stack and step into
+      // the next depth.
+ stack[stack_idx]++;
+ stack_idx++;
+ if (stack_idx < params->ngram_size) {
+ stack[stack_idx] = stack[stack_idx - 1];
+ }
+ } else {
+ if (ShouldIncludeCurrentNgram(params, stack_idx)) {
+ // Add n-gram to tensor buffer when the stack has filled with enough
+ // words to generate the ngram.
+ std::vector<StringRef> gram(stack_idx);
+ for (int i = 0; i < stack_idx; i++) {
+ gram[i] = words[stack[i]];
+ }
+ buf.AddJoinedString(gram, ' ');
+ }
+      // When the current depth cannot be filled with a valid new word, step
+      // back to the previous depth and move on to the next candidate word
+      // there.
+ stack_idx--;
+ }
+ }
+
+ buf.WriteToTensor(GetOutput(context, node, 0));
+ return kTfLiteOk;
+}
+} // namespace
+
+TfLiteRegistration* Register_SKIP_GRAM() {
+ static TfLiteRegistration r = {nullptr, nullptr, Prepare, Eval};
+ return &r;
+}
+
+} // namespace builtin
+} // namespace ops
+} // namespace tflite
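
For reference, a compact recursive sketch of the enumeration the iterative stack-based loop above performs (the early exit for sentences shorter than ngram_size is omitted): consecutive picked words may be at most max_skip_size positions apart, and include_all_ngrams also emits every n-gram of smaller sizes.

  #include <algorithm>
  #include <string>
  #include <vector>

  void Generate(const std::vector<std::string>& words, int ngram_size,
                int max_skip_size, bool include_all_ngrams,
                std::vector<int>* picked, std::vector<std::string>* out) {
    const int size = static_cast<int>(picked->size());
    const int num_words = static_cast<int>(words.size());
    if (size > 0 &&
        (include_all_ngrams ? size <= ngram_size : size == ngram_size)) {
      std::string gram = words[(*picked)[0]];
      for (int k = 1; k < size; ++k) gram += " " + words[(*picked)[k]];
      out->push_back(gram);
    }
    if (size == ngram_size) return;
    // The next word may skip at most max_skip_size words after the last pick.
    const int start = picked->empty() ? 0 : picked->back() + 1;
    const int limit =
        picked->empty()
            ? num_words
            : std::min(num_words, picked->back() + 2 + max_skip_size);
    for (int i = start; i < limit; ++i) {
      picked->push_back(i);
      Generate(words, ngram_size, max_skip_size, include_all_ngrams, picked, out);
      picked->pop_back();
    }
  }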
diff --git a/tensorflow/contrib/lite/kernels/skip_gram_test.cc b/tensorflow/contrib/lite/kernels/skip_gram_test.cc
new file mode 100644
index 0000000000..e7f6bc904b
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/skip_gram_test.cc
@@ -0,0 +1,257 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <vector>
+
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+#include "tensorflow/contrib/lite/string_util.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAre;
+
+static char kSentence[] = "The quick\t brown fox\n jumps over\n the lazy dog!";
+
+class SkipGramOp : public SingleOpModel {
+ public:
+ SkipGramOp(int ngram_size, int max_skip_size, bool include_all_ngrams) {
+ input_ = AddInput(TensorType_STRING);
+ output_ = AddOutput(TensorType_STRING);
+
+ SetBuiltinOp(BuiltinOperator_SKIP_GRAM, BuiltinOptions_SkipGramOptions,
+ CreateSkipGramOptions(builder_, ngram_size, max_skip_size,
+ include_all_ngrams)
+ .Union());
+ BuildInterpreter({{1}});
+ }
+ void SetInput(const string& content) {
+ PopulateStringTensor(input_, {content});
+ }
+
+ std::vector<string> GetOutput() {
+ std::vector<string> ans;
+ TfLiteTensor* tensor = interpreter_->tensor(output_);
+
+ int num = GetStringCount(tensor);
+ for (int i = 0; i < num; i++) {
+ StringRef strref = GetString(tensor, i);
+ ans.push_back(string(strref.str, strref.len));
+ }
+ return ans;
+ }
+
+ private:
+ int input_;
+ int output_;
+};
+
+TEST(SkipGramTest, TestUnigram) {
+ SkipGramOp m(1, 0, false);
+
+ m.SetInput(kSentence);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), testing::UnorderedElementsAreArray(
+ {"The", "quick", "brown", "fox", "jumps",
+ "over", "the", "lazy", "dog!"}));
+}
+
+TEST(SkipGramTest, TestBigram) {
+ SkipGramOp m(2, 0, false);
+ m.SetInput(kSentence);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(),
+ testing::UnorderedElementsAreArray(
+ {"The quick", "quick brown", "brown fox", "fox jumps",
+ "jumps over", "over the", "the lazy", "lazy dog!"}));
+}
+
+TEST(SkipGramTest, TestAllBigram) {
+ SkipGramOp m(2, 0, true);
+ m.SetInput(kSentence);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(),
+ testing::UnorderedElementsAreArray(
+ {// Unigram
+ "The", "quick", "brown", "fox", "jumps", "over", "the",
+ "lazy", "dog!",
+ // Bigram
+ "The quick", "quick brown", "brown fox", "fox jumps",
+ "jumps over", "over the", "the lazy", "lazy dog!"}));
+}
+
+TEST(SkipGramTest, TestAllTrigram) {
+ SkipGramOp m(3, 0, true);
+ m.SetInput(kSentence);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(),
+ testing::UnorderedElementsAreArray(
+ {// Unigram
+ "The", "quick", "brown", "fox", "jumps", "over", "the",
+ "lazy", "dog!",
+ // Bigram
+ "The quick", "quick brown", "brown fox", "fox jumps",
+ "jumps over", "over the", "the lazy", "lazy dog!",
+ // Trigram
+ "The quick brown", "quick brown fox", "brown fox jumps",
+ "fox jumps over", "jumps over the", "over the lazy",
+ "the lazy dog!"}));
+}
+
+TEST(SkipGramTest, TestSkip1Bigram) {
+ SkipGramOp m(2, 1, false);
+ m.SetInput(kSentence);
+ m.Invoke();
+ EXPECT_THAT(
+ m.GetOutput(),
+ testing::UnorderedElementsAreArray(
+ {"The quick", "The brown", "quick brown", "quick fox", "brown fox",
+ "brown jumps", "fox jumps", "fox over", "jumps over", "jumps the",
+ "over the", "over lazy", "the lazy", "the dog!", "lazy dog!"}));
+}
+
+TEST(SkipGramTest, TestSkip2Bigram) {
+ SkipGramOp m(2, 2, false);
+ m.SetInput(kSentence);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(),
+ testing::UnorderedElementsAreArray(
+ {"The quick", "The brown", "The fox", "quick brown",
+ "quick fox", "quick jumps", "brown fox", "brown jumps",
+ "brown over", "fox jumps", "fox over", "fox the",
+ "jumps over", "jumps the", "jumps lazy", "over the",
+ "over lazy", "over dog!", "the lazy", "the dog!",
+ "lazy dog!"}));
+}
+
+TEST(SkipGramTest, TestSkip1Trigram) {
+ SkipGramOp m(3, 1, false);
+ m.SetInput(kSentence);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(),
+ testing::UnorderedElementsAreArray(
+ {"The quick brown", "The quick fox", "The brown fox",
+ "The brown jumps", "quick brown fox", "quick brown jumps",
+ "quick fox jumps", "quick fox over", "brown fox jumps",
+ "brown fox over", "brown jumps over", "brown jumps the",
+ "fox jumps over", "fox jumps the", "fox over the",
+ "fox over lazy", "jumps over the", "jumps over lazy",
+ "jumps the lazy", "jumps the dog!", "over the lazy",
+ "over the dog!", "over lazy dog!", "the lazy dog!"}));
+}
+
+TEST(SkipGramTest, TestSkip2Trigram) {
+ SkipGramOp m(3, 2, false);
+ m.SetInput(kSentence);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(),
+ testing::UnorderedElementsAreArray(
+ {"The quick brown", "The quick fox", "The quick jumps",
+ "The brown fox", "The brown jumps", "The brown over",
+ "The fox jumps", "The fox over", "The fox the",
+ "quick brown fox", "quick brown jumps", "quick brown over",
+ "quick fox jumps", "quick fox over", "quick fox the",
+ "quick jumps over", "quick jumps the", "quick jumps lazy",
+ "brown fox jumps", "brown fox over", "brown fox the",
+ "brown jumps over", "brown jumps the", "brown jumps lazy",
+ "brown over the", "brown over lazy", "brown over dog!",
+ "fox jumps over", "fox jumps the", "fox jumps lazy",
+ "fox over the", "fox over lazy", "fox over dog!",
+ "fox the lazy", "fox the dog!", "jumps over the",
+ "jumps over lazy", "jumps over dog!", "jumps the lazy",
+ "jumps the dog!", "jumps lazy dog!", "over the lazy",
+ "over the dog!", "over lazy dog!", "the lazy dog!"}));
+}
+
+TEST(SkipGramTest, TestAllSkip2Trigram) {
+ SkipGramOp m(3, 2, true);
+ m.SetInput(kSentence);
+ m.Invoke();
+ EXPECT_THAT(
+ m.GetOutput(),
+ testing::UnorderedElementsAreArray(
+ {// Unigram
+ "The", "quick", "brown", "fox", "jumps", "over", "the", "lazy",
+ "dog!",
+ // Bigram
+ "The quick", "The brown", "The fox", "quick brown", "quick fox",
+ "quick jumps", "brown fox", "brown jumps", "brown over", "fox jumps",
+ "fox over", "fox the", "jumps over", "jumps the", "jumps lazy",
+ "over the", "over lazy", "over dog!", "the lazy", "the dog!",
+ "lazy dog!",
+ // Trigram
+ "The quick brown", "The quick fox", "The quick jumps",
+ "The brown fox", "The brown jumps", "The brown over",
+ "The fox jumps", "The fox over", "The fox the", "quick brown fox",
+ "quick brown jumps", "quick brown over", "quick fox jumps",
+ "quick fox over", "quick fox the", "quick jumps over",
+ "quick jumps the", "quick jumps lazy", "brown fox jumps",
+ "brown fox over", "brown fox the", "brown jumps over",
+ "brown jumps the", "brown jumps lazy", "brown over the",
+ "brown over lazy", "brown over dog!", "fox jumps over",
+ "fox jumps the", "fox jumps lazy", "fox over the", "fox over lazy",
+ "fox over dog!", "fox the lazy", "fox the dog!", "jumps over the",
+ "jumps over lazy", "jumps over dog!", "jumps the lazy",
+ "jumps the dog!", "jumps lazy dog!", "over the lazy",
+ "over the dog!", "over lazy dog!", "the lazy dog!"}));
+}
+
+TEST(SkipGramTest, TestSingleWord) {
+ SkipGramOp m(1, 1, false);
+ m.SetInput("Hi");
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAre("Hi"));
+}
+
+TEST(SkipGramTest, TestWordsLessThanGram) {
+ SkipGramOp m(3, 1, false);
+ m.SetInput("Hi hi");
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), std::vector<string>());
+}
+
+TEST(SkipGramTest, TestEmptyInput) {
+ SkipGramOp m(1, 1, false);
+ m.SetInput("");
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAre());
+}
+
+TEST(SkipGramTest, TestWhitespaceInput) {
+ SkipGramOp m(1, 1, false);
+ m.SetInput(" ");
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAre());
+}
+
+TEST(SkipGramTest, TestInputWithExtraSpace) {
+ SkipGramOp m(1, 1, false);
+ m.SetInput(" Hello world ! ");
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAre("Hello", "world", "!"));
+}
+
+} // namespace
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ // On Linux, add: tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/kernels/softmax_test.cc b/tensorflow/contrib/lite/kernels/softmax_test.cc
new file mode 100644
index 0000000000..ec8ec03b0d
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/softmax_test.cc
@@ -0,0 +1,143 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// Unit test for TFLite SOFTMAX op.
+
+#include <iomanip>
+#include <memory>
+#include <vector>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace {
+
+class SoftmaxOpModel : public SingleOpModel {
+ public:
+ SoftmaxOpModel(int batches, int size, float beta)
+ : batches_(batches), input_size_(size), beta_(beta) {
+ input_ = AddInput(TensorType_FLOAT32);
+ output_ = AddOutput(TensorType_FLOAT32);
+ SetBuiltinOp(BuiltinOperator_SOFTMAX, BuiltinOptions_SoftmaxOptions,
+ CreateSoftmaxOptions(builder_, beta_).Union());
+ BuildInterpreter({{batches_, input_size_}});
+ }
+
+ void SetInput(std::initializer_list<float> data) {
+ PopulateTensor(input_, data);
+ }
+
+ void SetInput(int offset, float* begin, float* end) {
+ PopulateTensor(input_, offset, begin, end);
+ }
+
+ std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+
+ private:
+ int input_;
+ int output_;
+
+ int batches_;
+ int input_size_;
+ float beta_;
+};
+
+TEST(SoftmaxOpTest, SimpleTest) {
+ SoftmaxOpModel m(/*batches=*/2, /*size=*/5, /*beta=*/1.0);
+ m.SetInput({
+ 1.0, 2.0, 3.0, 4.0, 5.0, // b = 0
+ -1.0, -2.0, -3.0, -4.0, -5.0, // b = 0
+ });
+
+ m.Invoke();
+
+ EXPECT_THAT(
+ m.GetOutput(),
+ ElementsAreArray(ArrayFloatNear(
+ {0.011656231, 0.031684921, 0.086128544, 0.234121657, 0.636408647,
+ 0.636408647, 0.234121657, 0.086128544, 0.031684921, 0.011656231},
+ 1e-6)));
+}
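
The expected values above follow the scaled softmax out[i] = exp(beta * in[i]) / sum_j exp(beta * in[j]). A minimal sketch of that formula (omitting the usual max-subtraction used for numerical stability):

  #include <cmath>
  #include <vector>

  std::vector<float> ScaledSoftmax(const std::vector<float>& in, float beta) {
    std::vector<float> out(in.size());
    float sum = 0.0f;
    for (float v : in) sum += std::exp(beta * v);
    for (size_t i = 0; i < in.size(); ++i) out[i] = std::exp(beta * in[i]) / sum;
    return out;  // ScaledSoftmax({1, 2, 3, 4, 5}, 1.0).back() ~= 0.6364, as above.
  }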
+
+TEST(SoftmaxOpTest, CompareWithTFminiBetaEq1) {
+ const int batch_size = 2;
+ const int input_size = 5;
+ const float beta = 1.0;
+ static float input_buffer[] = {
+ 1.0, 2.0, 3.0, 4.0, 5.0, // b = 0
+ -1.0, -2.0, -3.0, -4.0, -5.0, // b = 1
+ };
+
+ SoftmaxOpModel m(batch_size, input_size, beta);
+
+ m.SetInput(0, input_buffer, input_buffer + input_size * batch_size);
+
+ m.Invoke();
+
+ std::unique_ptr<float[]> output_buffer(new float[input_size * batch_size]);
+ static tflite::Dims<4> input_dims = {{input_size, 1, 1, batch_size},
+ {1, 0, 0, input_size}};
+ tflite::reference_ops::Softmax(input_buffer, input_dims, beta,
+ output_buffer.get(), input_dims);
+
+ std::vector<float> expected;
+ expected.insert(expected.end(), output_buffer.get(),
+ output_buffer.get() + input_size * batch_size);
+
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear(expected, 1e-6)));
+}
+
+TEST(SoftmaxOpTest, CompareWithTFminiBetaNotEq1) {
+ const int batch_size = 2;
+ const int input_size = 5;
+ const float beta = 0.5;
+ static float input_buffer[] = {
+ 1.0, 2.0, 3.0, 4.0, 5.0, // b = 0
+ -1.0, -2.0, -3.0, -4.0, -5.0, // b = 1
+ };
+
+ SoftmaxOpModel m(batch_size, input_size, beta);
+
+ m.SetInput(0, input_buffer, input_buffer + input_size * batch_size);
+
+ m.Invoke();
+
+ std::unique_ptr<float[]> output_buffer(new float[input_size * batch_size]);
+ static tflite::Dims<4> input_dims = {{input_size, 1, 1, batch_size},
+ {1, 0, 0, input_size}};
+ tflite::reference_ops::Softmax(input_buffer, input_dims, beta,
+ output_buffer.get(), input_dims);
+
+ std::vector<float> expected;
+ expected.insert(expected.end(), output_buffer.get(),
+ output_buffer.get() + input_size * batch_size);
+
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear(expected, 1e-6)));
+}
+
+} // namespace
+} // namespace tflite
+
+int main(int argc, char** argv) {
+  tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/kernels/space_to_depth.cc b/tensorflow/contrib/lite/kernels/space_to_depth.cc
new file mode 100644
index 0000000000..cb2e509c98
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/space_to_depth.cc
@@ -0,0 +1,146 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace space_to_depth {
+
+// This file has two implementations of SpaceToDepth. Note that SpaceToDepth
+// only works on 4D tensors.
+enum KernelType {
+ kReference,
+ kGenericOptimized,
+};
+
+constexpr int kInputTensor = 0;
+constexpr int kOutputTensor = 0;
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ auto* params =
+ reinterpret_cast<TfLiteSpaceToDepthParams*>(node->builtin_data);
+
+ TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+
+ TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+ TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4);
+
+ auto data_type = output->type;
+ TF_LITE_ENSURE(context,
+ data_type == kTfLiteFloat32 || data_type == kTfLiteUInt8 ||
+ data_type == kTfLiteInt32 || data_type == kTfLiteInt64);
+ TF_LITE_ENSURE_EQ(context, input->type, output->type);
+
+ const int block_size = params->block_size;
+ const int input_height = input->dims->data[1];
+ const int input_width = input->dims->data[2];
+ int output_height = input_height / block_size;
+ int output_width = input_width / block_size;
+
+ TF_LITE_ENSURE_EQ(context, input_height, output_height * block_size);
+ TF_LITE_ENSURE_EQ(context, input_width, output_width * block_size);
+
+ TfLiteIntArray* output_size = TfLiteIntArrayCreate(4);
+ output_size->data[0] = input->dims->data[0];
+ output_size->data[1] = output_height;
+ output_size->data[2] = output_width;
+ output_size->data[3] = input->dims->data[3] * block_size * block_size;
+
+ return context->ResizeTensor(context, output, output_size);
+}
+
+template <KernelType kernel_type>
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ auto* params =
+ reinterpret_cast<TfLiteSpaceToDepthParams*>(node->builtin_data);
+
+ TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+#define TF_LITE_SPACE_TO_DEPTH(type, scalar) \
+ type::SpaceToDepth<scalar>( \
+ GetTensorData<scalar>(input), GetTensorDims(input), params->block_size, \
+ GetTensorData<scalar>(output), GetTensorDims(output))
+ switch (input->type) { // Already know in/out types are same.
+ case kTfLiteFloat32:
+ if (kernel_type == kReference) {
+ TF_LITE_SPACE_TO_DEPTH(reference_ops, float);
+ } else {
+ TF_LITE_SPACE_TO_DEPTH(optimized_ops, float);
+ }
+ break;
+ case kTfLiteUInt8:
+ if (kernel_type == kReference) {
+ TF_LITE_SPACE_TO_DEPTH(reference_ops, uint8_t);
+ } else {
+ TF_LITE_SPACE_TO_DEPTH(optimized_ops, uint8_t);
+ }
+ break;
+ case kTfLiteInt32:
+ if (kernel_type == kReference) {
+ TF_LITE_SPACE_TO_DEPTH(reference_ops, int32_t);
+ } else {
+ TF_LITE_SPACE_TO_DEPTH(optimized_ops, int32_t);
+ }
+ break;
+ case kTfLiteInt64:
+ if (kernel_type == kReference) {
+ TF_LITE_SPACE_TO_DEPTH(reference_ops, int64_t);
+ } else {
+ TF_LITE_SPACE_TO_DEPTH(optimized_ops, int64_t);
+ }
+ break;
+ default:
+ context->ReportError(context, "Type not currently supported.");
+ return kTfLiteError;
+ }
+#undef TF_LITE_SPACE_TO_DEPTH
+
+ return kTfLiteOk;
+}
+
+} // namespace space_to_depth
+
+TfLiteRegistration* Register_SPACE_TO_DEPTH_REF() {
+ static TfLiteRegistration r = {
+ nullptr, nullptr, space_to_depth::Prepare,
+ space_to_depth::Eval<space_to_depth::kReference>};
+ return &r;
+}
+
+TfLiteRegistration* Register_SPACE_TO_DEPTH_GENERIC_OPT() {
+ static TfLiteRegistration r = {
+ nullptr, nullptr, space_to_depth::Prepare,
+ space_to_depth::Eval<space_to_depth::kGenericOptimized>};
+ return &r;
+}
+
+TfLiteRegistration* Register_SPACE_TO_DEPTH() {
+ return Register_SPACE_TO_DEPTH_GENERIC_OPT();
+}
+
+} // namespace builtin
+} // namespace ops
+} // namespace tflite
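
For reference, a standalone sketch of the NHWC index mapping the kernel is assumed to implement (single batch for brevity): each block_size x block_size spatial block is folded into the channel dimension, row-major within the block, with the original channels fastest. It reproduces the expectations in the test file that follows, e.g. the 4x4x1 Int64 input with block size 2 becomes a 2x2x4 output that reads {1, 2, ..., 16} in order.

  #include <vector>

  template <typename T>
  std::vector<T> SpaceToDepth(const std::vector<T>& in, int height, int width,
                              int depth, int block) {
    std::vector<T> out(in.size());
    const int out_width = width / block;
    const int out_depth = depth * block * block;
    for (int y = 0; y < height; ++y) {
      for (int x = 0; x < width; ++x) {
        for (int c = 0; c < depth; ++c) {
          const int oy = y / block, ox = x / block;
          const int oc = ((y % block) * block + (x % block)) * depth + c;
          out[(oy * out_width + ox) * out_depth + oc] =
              in[(y * width + x) * depth + c];
        }
      }
    }
    return out;
  }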
diff --git a/tensorflow/contrib/lite/kernels/space_to_depth_test.cc b/tensorflow/contrib/lite/kernels/space_to_depth_test.cc
new file mode 100644
index 0000000000..911f08a92c
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/space_to_depth_test.cc
@@ -0,0 +1,102 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAre;
+using ::testing::ElementsAreArray;
+
+class SpaceToDepthOpModel : public SingleOpModel {
+ public:
+ SpaceToDepthOpModel(const TensorData& tensor_data, int block_size) {
+ input_ = AddInput(tensor_data);
+ output_ = AddOutput(tensor_data);
+ SetBuiltinOp(BuiltinOperator_SPACE_TO_DEPTH,
+ BuiltinOptions_SpaceToDepthOptions,
+ CreateSpaceToDepthOptions(builder_, block_size).Union());
+ BuildInterpreter({GetShape(input_)});
+ }
+
+ template <typename T>
+ void SetInput(std::initializer_list<T> data) {
+ PopulateTensor<T>(input_, data);
+ }
+ template <typename T>
+ std::vector<T> GetOutput() {
+ return ExtractVector<T>(output_);
+ }
+ std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
+
+ private:
+ int input_;
+ int output_;
+};
+
+TEST(SpaceToDepthOpModel, BadBlockSize) {
+ EXPECT_DEATH(SpaceToDepthOpModel({TensorType_FLOAT32, {1, 2, 2, 1}}, 3),
+ "Cannot allocate tensors");
+}
+
+TEST(SpaceToDepthOpModel, Float32) {
+ SpaceToDepthOpModel m({TensorType_FLOAT32, {1, 2, 2, 2}}, 2);
+ m.SetInput<float>({1.4, 2.3, 3.2, 4.1, 5.4, 6.3, 7.2, 8.1});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput<float>(),
+ ElementsAreArray({1.4, 2.3, 3.2, 4.1, 5.4, 6.3, 7.2, 8.1}));
+ EXPECT_THAT(m.GetOutputShape(), ElementsAre(1, 1, 1, 8));
+}
+
+TEST(SpaceToDepthOpModel, Uint8) {
+ SpaceToDepthOpModel m({TensorType_UINT8, {1, 2, 2, 1}}, 2);
+ m.SetInput<uint8_t>({1, 2, 3, 4});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput<uint8_t>(), ElementsAreArray({1, 2, 3, 4}));
+ EXPECT_THAT(m.GetOutputShape(), ElementsAre(1, 1, 1, 4));
+}
+
+TEST(SpaceToDepthOpModel, Int32) {
+ SpaceToDepthOpModel m({TensorType_INT32, {1, 2, 2, 3}}, 2);
+ m.SetInput<int32_t>({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput<int32_t>(),
+ ElementsAreArray({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}));
+ EXPECT_THAT(m.GetOutputShape(), ElementsAre(1, 1, 1, 12));
+}
+
+TEST(SpaceToDepthOpModel, Int64) {
+ SpaceToDepthOpModel m({TensorType_INT64, {1, 4, 4, 1}}, 2);
+ m.SetInput<int64_t>({1, 2, 5, 6, 3, 4, 7, 8, 9, 10, 13, 14, 11, 12, 15, 16});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput<int64_t>(),
+ ElementsAreArray(
+ {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}));
+ EXPECT_THAT(m.GetOutputShape(), ElementsAre(1, 2, 2, 4));
+}
+
+} // namespace
+} // namespace tflite
+
+int main(int argc, char** argv) {
+  tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/kernels/svdf.cc b/tensorflow/contrib/lite/kernels/svdf.cc
new file mode 100644
index 0000000000..dd414d53bd
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/svdf.cc
@@ -0,0 +1,224 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <unistd.h>
+#include <cassert>
+#include <cmath>
+#include <cstdlib>
+#include <cstdio>
+#include <iostream>
+#include <limits>
+
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/activation_functor.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace svdf {
+
+constexpr int kInputTensor = 0;
+constexpr int kWeightsFeatureTensor = 1;
+constexpr int kWeightsTimeTensor = 2;
+constexpr int kBiasTensor = 3;
+constexpr int kStateTensor = 0;
+constexpr int kOutputTensor = 1;
+
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+ auto* scratch_tensor_index = new int;
+ context->AddTensors(context, 1, scratch_tensor_index);
+ return scratch_tensor_index;
+}
+
+void Free(TfLiteContext* context, void* buffer) {
+ delete reinterpret_cast<int*>(buffer);
+}
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ auto* params = reinterpret_cast<TfLiteSVDFParams*>(node->builtin_data);
+ int* scratch_tensor_index = reinterpret_cast<int*>(node->user_data);
+
+ // Check we have all the inputs and outputs we need.
+ TF_LITE_ENSURE_EQ(context, node->inputs->size, 4);
+ TF_LITE_ENSURE_EQ(context, node->outputs->size, 2);
+
+ TfLiteTensor* input = &context->tensors[node->inputs->data[kInputTensor]];
+ TfLiteTensor* weights_feature =
+ &context->tensors[node->inputs->data[kWeightsFeatureTensor]];
+ TfLiteTensor* weights_time =
+ &context->tensors[node->inputs->data[kWeightsTimeTensor]];
+
+ // Check all the parameters of tensor match within themselves and match the
+ // input configuration.
+ const int rank = params->rank;
+ const int batch_size = input->dims->data[0];
+ const int num_filters = weights_feature->dims->data[0];
+ TF_LITE_ASSERT_EQ(num_filters % rank, 0);
+ const int num_units = num_filters / rank;
+ const int memory_size = weights_time->dims->data[1];
+ TF_LITE_ASSERT_EQ(input->dims->data[1], weights_feature->dims->data[1]);
+ TF_LITE_ASSERT_EQ(weights_time->dims->data[0], num_filters);
+
+ TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor);
+ if (bias) {
+ TF_LITE_ASSERT_EQ(bias->dims->data[0], num_units);
+ }
+
+ TfLiteTensor* state = &context->tensors[node->outputs->data[kStateTensor]];
+  TfLiteTensor* output = &context->tensors[node->outputs->data[kOutputTensor]];
+
+ // Resize state.
+ // For each batch, the state is a 2-D tensor: memory_size * num_filters
+ // The left most column is used to save current cycle activation.
+ // The right most column is used to save temporary output which will be
+ // reduced to num_units outputs.
+ TfLiteIntArray* state_size_array = TfLiteIntArrayCreate(2);
+ state_size_array->data[0] = batch_size;
+ state_size_array->data[1] = memory_size * num_filters;
+ TF_LITE_ENSURE_OK(context,
+ context->ResizeTensor(context, state, state_size_array));
+
+ // Mark state as a persistent tensor.
+ state->allocation_type = kTfLiteArenaRwPersistent;
+
+ // Resize output.
+ TfLiteIntArray* output_size_array = TfLiteIntArrayCreate(2);
+ output_size_array->data[0] = batch_size;
+ output_size_array->data[1] = num_units;
+ TF_LITE_ENSURE_OK(context,
+ context->ResizeTensor(context, output, output_size_array));
+
+ // Resize scratch.
+ TfLiteIntArrayFree(node->temporaries);
+ node->temporaries = TfLiteIntArrayCreate(1);
+ node->temporaries->data[0] = *scratch_tensor_index;
+
+ TfLiteIntArray* scratch_size_array = TfLiteIntArrayCreate(2);
+ scratch_size_array->data[0] = batch_size;
+ scratch_size_array->data[1] = num_filters;
+
+ TfLiteTensor* scratch_tensor = &context->tensors[node->temporaries->data[0]];
+ scratch_tensor->type = input->type;
+ scratch_tensor->allocation_type = kTfLiteArenaRw;
+ TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_tensor,
+ scratch_size_array));
+
+ return kTfLiteOk;
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ auto* params = reinterpret_cast<TfLiteSVDFParams*>(node->builtin_data);
+
+ TfLiteTensor* input = &context->tensors[node->inputs->data[kInputTensor]];
+ TfLiteTensor* weights_feature =
+ &context->tensors[node->inputs->data[kWeightsFeatureTensor]];
+ TfLiteTensor* weights_time =
+ &context->tensors[node->inputs->data[kWeightsTimeTensor]];
+
+ TfLiteTensor* state = &context->tensors[node->outputs->data[kStateTensor]];
+  TfLiteTensor* output = &context->tensors[node->outputs->data[kOutputTensor]];
+ TfLiteTensor* scratch = &context->tensors[node->temporaries->data[0]];
+
+ TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor);
+
+ const int rank = params->rank;
+ const int batch_size = input->dims->data[0];
+ const int input_size = input->dims->data[1];
+ const int num_filters = weights_feature->dims->data[0];
+ const int num_units = num_filters / rank;
+ const int memory_size = weights_time->dims->data[1];
+
+ // Clear the activation (state left most column).
+  // TODO(ghodrat): Add a test which initializes the state with invalid values
+  // in the left-most column and makes sure it passes.
+ for (int b = 0; b < batch_size; b++) {
+ float* state_ptr_batch = state->data.f + b * memory_size * num_filters;
+ for (int c = 0; c < num_filters; c++) {
+ float* state_ptr = state_ptr_batch + c * memory_size;
+ state_ptr[memory_size - 1] = 0.0;
+ }
+ }
+
+ // Compute conv1d(inputs, weights_feature).
+ // The state left most column is used to save current cycle activation. This
+ // is achieved by starting at state->data.f[memory_size - 1] and having the
+ // stride equal to memory_size.
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ weights_feature->data.f, num_filters, input_size, input->data.f,
+ batch_size, &state->data.f[memory_size - 1], memory_size);
+
+ // Compute matmul(state, weights_time).
+ // The right most column is used to save temporary output (with the size of
+ // num_filters). This is achieved by starting at state->data.f and having the
+ // stride equal to memory_size.
+ for (int b = 0; b < batch_size; b++) {
+ float* state_ptr_batch = state->data.f + b * memory_size * num_filters;
+ float* scratch_ptr_batch = scratch->data.f + b * num_filters;
+ tensor_utils::BatchVectorBatchVectorDotProduct(
+ weights_time->data.f, state_ptr_batch, memory_size, num_filters,
+ scratch_ptr_batch, /*result_stride=*/1);
+ }
+
+ // Initialize output with bias if provided.
+ if (bias) {
+ tensor_utils::VectorBatchVectorAssign(bias->data.f, num_units, batch_size,
+ output->data.f);
+ } else {
+ tensor_utils::ZeroVector(output->data.f, batch_size * num_units);
+ }
+
+ // Reduction sum
+ // TODO(ghodrat): Consider not reusing state for the temporary output, this
+ // way ReductionSum operates on row-vector instead of column vector.
+ for (int b = 0; b < batch_size; b++) {
+ float* output_ptr_batch = output->data.f + b * num_units;
+ float* scratch_ptr_batch = scratch->data.f + b * num_filters;
+ tensor_utils::ReductionSumVector(scratch_ptr_batch, output_ptr_batch,
+ num_units, rank);
+ }
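+  // Roughly, for each batch b and unit u this accumulates the `rank` scratch
+  // entries belonging to that unit:
+  //   output[b][u] += scratch[b][u * rank] + ... + scratch[b][u * rank + rank - 1]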
+
+ // Apply activation.
+ for (int b = 0; b < batch_size; b++) {
+ float* output_ptr_batch = output->data.f + b * num_units;
+ tensor_utils::ApplyActivationToVector(output_ptr_batch, num_units,
+ params->activation, output_ptr_batch);
+ }
+
+ // Right shift the state.
+ for (int b = 0; b < batch_size; b++) {
+ float* state_ptr_batch = state->data.f + b * memory_size * num_filters;
+ for (int f = 0; f < num_filters; f++) {
+ tensor_utils::VectorShiftLeft(state_ptr_batch, memory_size,
+ /*shift_value=*/0.0);
+ state_ptr_batch += memory_size;
+ }
+ }
+ return kTfLiteOk;
+}
+
+} // namespace svdf
+
+TfLiteRegistration* Register_SVDF() {
+ static TfLiteRegistration r = {svdf::Init, svdf::Free, svdf::Prepare,
+ svdf::Eval};
+ return &r;
+}
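+
+// A minimal sketch of how this registration is typically consumed; the builtin
+// op resolver wires it up roughly as follows:
+//   resolver.AddBuiltin(BuiltinOperator_SVDF, Register_SVDF());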
+
+} // namespace builtin
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/svdf_test.cc b/tensorflow/contrib/lite/kernels/svdf_test.cc
new file mode 100644
index 0000000000..d956025e9d
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/svdf_test.cc
@@ -0,0 +1,312 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// Unit test for TFLite SVDF op.
+
+#include <vector>
+#include <iomanip>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAreArray;
+
+static float svdf_input[] = {
+ 0.12609188, -0.46347019, -0.89598465,
+ 0.35867718, 0.36897406, 0.73463392,
+
+ 0.14278367, -1.64410412, -0.75222826,
+ -0.57290924, 0.12729003, 0.7567004,
+
+ 0.49837467, 0.19278903, 0.26584083,
+ 0.17660543, 0.52949083, -0.77931279,
+
+ -0.11186574, 0.13164264, -0.05349274,
+ -0.72674477, -0.5683046, 0.55900657,
+
+ -0.68892461, 0.37783599, 0.18263303,
+ -0.63690937, 0.44483393, -0.71817774,
+
+ -0.81299269, -0.86831826, 1.43940818,
+ -0.95760226, 1.82078898, 0.71135032,
+
+ -1.45006323, -0.82251364, -1.69082689,
+ -1.65087092, -1.89238167, 1.54172635,
+
+ 0.03966608, -0.24936394, -0.77526885,
+ 2.06740379, -1.51439476, 1.43768692,
+
+ 0.11771342, -0.23761693, -0.65898693,
+ 0.31088525, -1.55601168, -0.87661445,
+
+ -0.89477462, 1.67204106, -0.53235275,
+ -0.6230064, 0.29819036, 1.06939757,
+};
+
+static float svdf_golden_output_rank_1[] = {
+ 0.014899, -0.0517661, -0.143725, -0.00271883,
+ -0.03004015, 0.09565311, 0.1587342, 0.00784263,
+
+ 0.068281, -0.162217, -0.152268, 0.00323521,
+ 0.01582633, 0.03858774, -0.03001583, -0.02671271,
+
+ -0.0317821, -0.0333089, 0.0609602, 0.0333759,
+ -0.01432795, 0.05524484, 0.1101355, -0.02382665,
+
+ -0.00623099, -0.077701, -0.391193, -0.0136691,
+ -0.02333033, 0.02293761, 0.12338032, 0.04326871,
+
+ 0.201551, -0.164607, -0.179462, -0.0592739,
+ 0.01064911, -0.17503069, 0.07821996, -0.00224009,
+
+ 0.0886511, -0.0875401, -0.269283, 0.0281379,
+ -0.02282338, 0.09741908, 0.32973239, 0.12281385,
+
+ -0.201174, -0.586145, -0.628624, -0.0330412,
+ 0.24780814, -0.39304617, -0.22473189, 0.02589256,
+
+ -0.0839096, -0.299329, 0.108746, 0.109808,
+ 0.10084175, -0.06416984, 0.28936723, 0.0026358,
+
+ 0.419114, -0.237824, -0.422627, 0.175115,
+ -0.2314795, -0.18584411, -0.4228974, -0.12928449,
+
+ 0.36726, -0.522303, -0.456502, -0.175475,
+ 0.17012937, -0.34447709, 0.38505614, -0.28158101,
+};
+
+static float svdf_golden_output_rank_2[] = {
+ -0.09623547, -0.10193135, 0.11083051, -0.0347917,
+ 0.1141196, 0.12965347, -0.12652366, 0.01007236,
+
+ -0.16396809, -0.21247184, 0.11259045, -0.04156673,
+ 0.10132131, -0.06143532, -0.00924693, 0.10084561,
+
+ 0.01257364, 0.0506071, -0.19287863, -0.07162561,
+ -0.02033747, 0.22673416, 0.15487903, 0.02525555,
+
+ -0.1411963, -0.37054959, 0.01774767, 0.05867489,
+ 0.09607603, -0.0141301, -0.08995658, 0.12867066,
+
+ -0.27142537, -0.16955489, 0.18521598, -0.12528358,
+ 0.00331409, 0.11167502, 0.02218599, -0.07309391,
+
+ 0.09593632, -0.28361851, -0.0773851, 0.17199151,
+ -0.00075242, 0.33691186, -0.1536046, 0.16572715,
+
+ -0.27916506, -0.27626723, 0.42615682, 0.3225764,
+ -0.37472126, -0.55655634, -0.05013514, 0.289112,
+
+ -0.24418658, 0.07540751, -0.1940318, -0.08911639,
+ 0.00732617, 0.46737891, 0.26449674, 0.24888524,
+
+ -0.17225097, -0.54660404, -0.38795233, 0.08389944,
+ 0.07736043, -0.28260678, 0.15666828, 1.14949894,
+
+ -0.57454878, -0.64704704, 0.73235172, -0.34616736,
+ 0.21120001, -0.22927976, 0.02455296, -0.35906726,
+};
+
+// Derived class of SingleOpModel, which is used to test the SVDF TFLite op.
+class SVDFOpModel : public SingleOpModel {
+ public:
+ SVDFOpModel(int batches, int units, int input_size, int memory_size, int rank)
+ : batches_(batches),
+ units_(units),
+ input_size_(input_size),
+ memory_size_(memory_size),
+ rank_(rank) {
+ input_ = AddInput(TensorType_FLOAT32);
+ weights_feature_ = AddInput(TensorType_FLOAT32);
+ weights_time_ = AddInput(TensorType_FLOAT32);
+ bias_ = AddNullInput();
+ state_ = AddOutput(TensorType_FLOAT32);
+ output_ = AddOutput(TensorType_FLOAT32);
+ SetBuiltinOp(
+ BuiltinOperator_SVDF, BuiltinOptions_SVDFOptions,
+ CreateSVDFOptions(builder_, rank, ActivationFunctionType_NONE).Union());
+ BuildInterpreter({
+ {batches_, input_size_}, // Input tensor
+ {units_ * rank, input_size_}, // weights_feature tensor
+ {units_ * rank, memory_size_}, // weights_time tensor
+ {units_} // bias tensor
+ });
+ }
+
+ // Populates the weights_feature tensor.
+ void SetWeightsFeature(std::initializer_list<float> f) {
+ PopulateTensor(weights_feature_, f);
+ }
+
+ // Populates the weights_time tensor.
+ void SetWeightsTime(std::initializer_list<float> f) {
+ PopulateTensor(weights_time_, f);
+ }
+
+ // Populates the input tensor.
+ void SetInput(int offset, float* begin, float* end) {
+ PopulateTensor(input_, offset, begin, end);
+ }
+
+ // Resets the state of SVDF op by filling it with 0's.
+ void ResetState() {
+ const int zero_buffer_size = rank_ * units_ * batches_ * memory_size_;
+ std::unique_ptr<float[]> zero_buffer(new float[zero_buffer_size]);
+ memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float));
+ PopulateTensor(state_, 0, zero_buffer.get(),
+ zero_buffer.get() + zero_buffer_size);
+ }
+
+ // Extracts the output tensor from the SVDF op.
+ std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+
+ int input_size() { return input_size_; }
+ int num_units() { return units_; }
+ int num_batches() { return batches_; }
+
+ private:
+ int input_;
+ int weights_feature_;
+ int weights_time_;
+ int bias_;
+ int state_;
+ int output_;
+
+ int batches_;
+ int units_;
+ int input_size_;
+ int memory_size_;
+ int rank_;
+};
+
+TEST(SVDFOpTest, BlackBoxTestRank1) {
+ SVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3,
+ /*memory_size=*/10, /*rank=*/1);
+ svdf.SetWeightsFeature({-0.31930989, -0.36118156, 0.0079667, 0.37613347,
+ 0.22197971, 0.12416199, 0.27901134, 0.27557442,
+ 0.3905206, -0.36137494, -0.06634006, -0.10640851});
+
+ svdf.SetWeightsTime(
+ {-0.31930989, 0.37613347, 0.27901134, -0.36137494, -0.36118156,
+ 0.22197971, 0.27557442, -0.06634006, 0.0079667, 0.12416199,
+
+ 0.3905206, -0.10640851, -0.0976817, 0.15294972, 0.39635518,
+ -0.02702999, 0.39296314, 0.15785322, 0.21931258, 0.31053296,
+
+ -0.36916667, 0.38031587, -0.21580373, 0.27072677, 0.23622236,
+ 0.34936687, 0.18174365, 0.35907319, -0.17493086, 0.324846,
+
+ -0.10781813, 0.27201805, 0.14324132, -0.23681851, -0.27115166,
+ -0.01580888, -0.14943552, 0.15465137, 0.09784451, -0.0337657});
+
+ svdf.ResetState();
+ const int svdf_num_batches = svdf.num_batches();
+ const int svdf_input_size = svdf.input_size();
+ const int svdf_num_units = svdf.num_units();
+ const int input_sequence_size =
+ sizeof(svdf_input) / sizeof(float) / (svdf_input_size * svdf_num_batches);
+  // Go over each input batch, set the input tensor, invoke the SVDF op, and
+  // check the output against the expected golden values.
+ for (int i = 0; i < input_sequence_size; i++) {
+ float* batch_start = svdf_input + i * svdf_input_size * svdf_num_batches;
+ float* batch_end = batch_start + svdf_input_size * svdf_num_batches;
+ svdf.SetInput(0, batch_start, batch_end);
+
+ svdf.Invoke();
+
+ float* golden_start =
+ svdf_golden_output_rank_1 + i * svdf_num_units * svdf_num_batches;
+ float* golden_end = golden_start + svdf_num_units * svdf_num_batches;
+ std::vector<float> expected;
+ expected.insert(expected.end(), golden_start, golden_end);
+
+ EXPECT_THAT(svdf.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
+ }
+}
+
+TEST(SVDFOpTest, BlackBoxTestRank2) {
+ SVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3,
+ /*memory_size=*/10, /*rank=*/2);
+ svdf.SetWeightsFeature({-0.31930989, 0.0079667, 0.39296314, 0.37613347,
+ 0.12416199, 0.15785322, 0.27901134, 0.3905206,
+ 0.21931258, -0.36137494, -0.10640851, 0.31053296,
+ -0.36118156, -0.0976817, -0.36916667, 0.22197971,
+ 0.15294972, 0.38031587, 0.27557442, 0.39635518,
+ -0.21580373, -0.06634006, -0.02702999, 0.27072677});
+
+ svdf.SetWeightsTime(
+ {-0.31930989, 0.37613347, 0.27901134, -0.36137494, -0.36118156,
+ 0.22197971, 0.27557442, -0.06634006, 0.0079667, 0.12416199,
+
+ 0.3905206, -0.10640851, -0.0976817, 0.15294972, 0.39635518,
+ -0.02702999, 0.39296314, 0.15785322, 0.21931258, 0.31053296,
+
+ -0.36916667, 0.38031587, -0.21580373, 0.27072677, 0.23622236,
+ 0.34936687, 0.18174365, 0.35907319, -0.17493086, 0.324846,
+
+ -0.10781813, 0.27201805, 0.14324132, -0.23681851, -0.27115166,
+ -0.01580888, -0.14943552, 0.15465137, 0.09784451, -0.0337657,
+
+ -0.14884081, 0.19931212, -0.36002168, 0.34663299, -0.11405486,
+ 0.12672701, 0.39463779, -0.07886535, -0.06384811, 0.08249187,
+
+ -0.26816407, -0.19905911, 0.29211238, 0.31264046, -0.28664589,
+ 0.05698794, 0.11613581, 0.14078894, 0.02187902, -0.21781836,
+
+ -0.15567942, 0.08693647, -0.38256618, 0.36580828, -0.22922277,
+ -0.0226903, 0.12878349, -0.28122205, -0.10850525, -0.11955214,
+
+ 0.27179423, -0.04710215, 0.31069002, 0.22672787, 0.09580326,
+ 0.08682203, 0.1258215, 0.1851041, 0.29228821, 0.12366763});
+
+ svdf.ResetState();
+ const int svdf_num_batches = svdf.num_batches();
+ const int svdf_input_size = svdf.input_size();
+ const int svdf_num_units = svdf.num_units();
+ const int input_sequence_size =
+ sizeof(svdf_input) / sizeof(float) / (svdf_input_size * svdf_num_batches);
+  // Go over each input batch, set the input tensor, invoke the SVDF op, and
+  // check the output against the expected golden values.
+ for (int i = 0; i < input_sequence_size; i++) {
+ float* batch_start = svdf_input + i * svdf_input_size * svdf_num_batches;
+ float* batch_end = batch_start + svdf_input_size * svdf_num_batches;
+ svdf.SetInput(0, batch_start, batch_end);
+
+ svdf.Invoke();
+
+ float* golden_start =
+ svdf_golden_output_rank_2 + i * svdf_num_units * svdf_num_batches;
+ float* golden_end = golden_start + svdf_num_units * svdf_num_batches;
+ std::vector<float> expected;
+ expected.insert(expected.end(), golden_start, golden_end);
+
+ EXPECT_THAT(svdf.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
+ }
+}
+
+} // namespace
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ // On Linux, add: tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/kernels/test_util.cc b/tensorflow/contrib/lite/kernels/test_util.cc
new file mode 100644
index 0000000000..f716ba8741
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/test_util.cc
@@ -0,0 +1,183 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+
+#include "tensorflow/contrib/lite/version.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace tflite {
+
+using ::testing::FloatNear;
+using ::testing::Matcher;
+
+namespace {
+template <typename T>
+std::pair<float, int32_t> QuantizationParams(float f_min, float f_max) {
+ // These are required by many quantized operations.
+ CHECK_LE(f_min, 0);
+ CHECK_GE(f_max, 0);
+ T q_min = std::numeric_limits<T>::min();
+ T q_max = std::numeric_limits<T>::max();
+ float range = q_max - q_min;
+ float scale = (f_max - f_min) / range;
+ int32_t zero_point = std::min(
+ q_max,
+ std::max(q_min, static_cast<T>(std::round(q_min - f_min / scale))));
+ return {scale, zero_point};
+}
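+// For example, with T = uint8_t, f_min = -1.0f and f_max = 1.0f this yields
+// roughly scale = 2.0f / 255 and zero_point = 128, so that 0.0f maps to the
+// quantized value 128.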
+} // namespace
+
+std::vector<Matcher<float>> ArrayFloatNear(const std::vector<float>& values,
+ float max_abs_error) {
+ std::vector<Matcher<float>> matchers;
+ matchers.reserve(values.size());
+ for (const float& v : values) {
+ matchers.emplace_back(FloatNear(v, max_abs_error));
+ }
+ return matchers;
+}
+
+int SingleOpModel::AddTensor(TensorData t) {
+ int id = tensors_.size();
+
+ // This is slightly different depending on whether we are adding a
+ // quantized or a regular tensor.
+ bool is_quantized = (t.min != 0 || t.max != 0 || t.scale != 0);
+
+ flatbuffers::Offset<QuantizationParameters> q_params = 0;
+
+ if (is_quantized) {
+ if (t.min != 0 || t.max != 0) {
+ if (t.type == TensorType_UINT8) {
+ std::tie(t.scale, t.zero_point) =
+ QuantizationParams<uint8_t>(t.min, t.max);
+ } else if (t.type == TensorType_INT32) {
+ std::tie(t.scale, t.zero_point) =
+ QuantizationParams<int32_t>(t.min, t.max);
+ } else {
+ LOG(FATAL) << "No support for the requested quantized type";
+ }
+ t.min = 0;
+ t.max = 0;
+ }
+
+ q_params = CreateQuantizationParameters(
+ builder_, /*min=*/0, /*max=*/0, builder_.CreateVector<float>({t.scale}),
+ builder_.CreateVector<int64_t>({t.zero_point}));
+ }
+
+ tensors_.push_back(CreateTensor(builder_, builder_.CreateVector<int>({}),
+ t.type, /*buffer=*/0,
+ /*name=*/0, q_params));
+
+ tensor_data_[id] = t;
+
+ return id;
+}
+
+int SingleOpModel::AddInput(const TensorData& t) {
+ int id = AddTensor(t);
+ inputs_.push_back(id);
+ return id;
+}
+
+int SingleOpModel::AddNullInput() {
+ int id = kOptionalTensor;
+ inputs_.push_back(id);
+ return id;
+}
+
+int SingleOpModel::AddOutput(const TensorData& t) {
+ int id = AddTensor(t);
+ outputs_.push_back(id);
+ return id;
+}
+
+void SingleOpModel::SetBuiltinOp(BuiltinOperator type,
+ BuiltinOptions builtin_options_type,
+ flatbuffers::Offset<void> builtin_options) {
+ opcodes_.push_back(CreateOperatorCode(builder_, type, 0));
+ operators_.push_back(CreateOperator(
+ builder_, /*opcode_index=*/0, builder_.CreateVector<int32_t>(inputs_),
+ builder_.CreateVector<int32_t>(outputs_), builtin_options_type,
+ builtin_options,
+ /*custom_options=*/0, CustomOptionsFormat_FLEXBUFFERS));
+}
+
+void SingleOpModel::SetCustomOp(
+ const string& name, const std::vector<uint8_t>& custom_option,
+    const std::function<TfLiteRegistration*()>& registration) {
+  custom_registrations_[name] = registration;
+ opcodes_.push_back(
+ CreateOperatorCodeDirect(builder_, BuiltinOperator_CUSTOM, name.data()));
+ operators_.push_back(CreateOperator(
+ builder_, /*opcode_index=*/0, builder_.CreateVector<int32_t>(inputs_),
+ builder_.CreateVector<int32_t>(outputs_), BuiltinOptions_NONE, 0,
+ builder_.CreateVector<uint8_t>(custom_option),
+ CustomOptionsFormat_FLEXBUFFERS));
+}
+
+void SingleOpModel::BuildInterpreter(
+ std::vector<std::vector<int>> input_shapes) {
+ auto opcodes = builder_.CreateVector(opcodes_);
+ auto operators = builder_.CreateVector(operators_);
+ auto tensors = builder_.CreateVector(tensors_);
+ auto inputs = builder_.CreateVector<int32_t>(inputs_);
+ auto outputs = builder_.CreateVector<int32_t>(outputs_);
+ // Create a single subgraph
+ std::vector<flatbuffers::Offset<SubGraph>> subgraphs;
+ auto subgraph = CreateSubGraph(builder_, tensors, inputs, outputs, operators);
+ subgraphs.push_back(subgraph);
+ auto subgraphs_flatbuffer = builder_.CreateVector(subgraphs);
+
+ std::vector<flatbuffers::Offset<Buffer>> buffers_vec;
+ auto buffers = builder_.CreateVector(buffers_vec);
+ auto description = builder_.CreateString("programmatic model");
+ builder_.Finish(CreateModel(builder_, TFLITE_SCHEMA_VERSION, opcodes,
+ subgraphs_flatbuffer, description, buffers));
+
+ auto* model = GetModel(builder_.GetBufferPointer());
+
+ ops::builtin::BuiltinOpResolver builtins;
+ for (const auto& reg : custom_registrations_) {
+ builtins.AddCustom(reg.first.data(), reg.second());
+ }
+ InterpreterBuilder(model, builtins)(&interpreter_);
+
+ CHECK(interpreter_ != nullptr);
+
+ int i = 0;
+ for (const auto& shape : input_shapes) {
+ int input_idx = interpreter_->inputs()[i++];
+ if (input_idx == kOptionalTensor) continue;
+ CHECK(interpreter_->ResizeInputTensor(input_idx, shape) == kTfLiteOk);
+ }
+ CHECK(interpreter_->AllocateTensors() == kTfLiteOk)
+ << "Cannot allocate tensors";
+}
+
+void SingleOpModel::Invoke() { CHECK(interpreter_->Invoke() == kTfLiteOk); }
+
+int32_t SingleOpModel::GetTensorSize(int index) const {
+ TfLiteTensor* t = interpreter_->tensor(index);
+ CHECK(t);
+ int total_size = 1;
+ for (int i = 0; i < t->dims->size; ++i) {
+ total_size *= t->dims->data[i];
+ }
+ return total_size;
+}
+
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/test_util.h b/tensorflow/contrib/lite/kernels/test_util.h
new file mode 100644
index 0000000000..e68e494661
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/test_util.h
@@ -0,0 +1,202 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_TEST_UTIL_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_TEST_UTIL_H_
+
+#include <vector>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/model.h"
+#include "tensorflow/contrib/lite/string_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace tflite {
+
+inline void LogToStderr() {
+#ifdef PLATFORM_GOOGLE
+ FLAGS_logtostderr = true;
+#endif
+}
+
+// A gmock matcher that checks that the elements of a float vector match to
+// within a given tolerance.
+std::vector<::testing::Matcher<float>> ArrayFloatNear(
+ const std::vector<float>& values, float max_abs_error = 1e-5);
+
+template <typename T>
+inline std::vector<T> Quantize(const std::vector<float>& data, float scale,
+ int32_t zero_point) {
+ std::vector<T> q;
+ for (float f : data) {
+ q.push_back(std::max(
+ std::numeric_limits<T>::min(),
+ std::min(std::numeric_limits<T>::max(),
+ static_cast<T>(std::round(zero_point + (f / scale))))));
+ }
+ return q;
+}
+
+template <typename T>
+inline std::vector<float> Dequantize(const std::vector<T>& data, float scale,
+ int32_t zero_point) {
+ std::vector<float> f;
+ for (T q : data) {
+ f.push_back(scale * (q - zero_point));
+ }
+ return f;
+}
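+
+// For illustration, a quick round trip through the helpers above (the values
+// are just an example):
+//   auto q = Quantize<uint8_t>({-1.0f, 0.0f, 1.0f}, /*scale=*/0.5f,
+//                              /*zero_point=*/128);  // {126, 128, 130}
+//   auto f = Dequantize<uint8_t>(q, 0.5f, 128);      // {-1.0f, 0.0f, 1.0f}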
+
+// A test model that contains a single operator. All operator inputs and
+// outputs are external to the model, so the tests can directly access them.
+// Typical usage:
+// SingleOpModel m;
+// int a = m.AddInput({TensorType_FLOAT32, a_shape});
+// int b = m.AddInput({TensorType_FLOAT32, b_shape});
+// int c = m.AddOutput({TensorType_FLOAT32, {}});
+// m.SetBuiltinOp(...);
+// m.BuildInterpreter({GetShape(a), GetShape(b)});
+// m.PopulateTensor(a, {...});
+// m.PopulateTensor(b, {...});
+// m.Invoke();
+// EXPECT_THAT(m.ExtractVector<float>(c), ArrayFloatNear({...}));
+//
+
+// A helper struct to construct test tensors. This is particularly useful for
+// quantized tensors, which must have their scale and zero_point defined before
+// the actual data is known. This mimics what happens in practice: quantization
+// parameters are calculated during training.
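+//
+// For example, {TensorType_FLOAT32, {2, 3}} describes a plain float tensor,
+// while {TensorType_UINT8, {2, 3}, -1.0f, 1.0f} describes a quantized tensor
+// whose scale and zero_point are derived from the given min/max (see
+// SingleOpModel::AddTensor).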
+struct TensorData {
+ TensorType type;
+ std::vector<int> shape;
+ float min;
+ float max;
+ float scale;
+ int32_t zero_point;
+};
+
+class SingleOpModel {
+ public:
+ SingleOpModel() {}
+ ~SingleOpModel() {}
+
+ // Copying or assignment is disallowed to simplify ownership semantics.
+ SingleOpModel(const SingleOpModel&) = delete;
+ SingleOpModel& operator=(const SingleOpModel&) = delete;
+
+ // Add a TensorType input tensor and return its index.
+ int AddInput(TensorType type) { return AddInput(TensorData{type}); }
+ int AddInput(const TensorData& t);
+
+ // Add a null input tensor (optional input) and return kOptionalTensor.
+ int AddNullInput();
+
+ // Add a TensorType output tensor and return its index.
+ int AddOutput(TensorType type) { return AddOutput(TensorData{type}); }
+ int AddOutput(const TensorData& t);
+
+ template <typename T>
+ void QuantizeAndPopulate(int index, std::initializer_list<float> data) {
+ TfLiteTensor* t = interpreter_->tensor(index);
+ auto q = Quantize<T>(data, t->params.scale, t->params.zero_point);
+ PopulateTensor(index, 0, q.data(), q.data() + q.size());
+ }
+
+ const std::vector<int>& GetShape(int id) { return tensor_data_.at(id).shape; }
+
+ float GetScale(int id) { return tensor_data_.at(id).scale; }
+ int32_t GetZeroPoint(int id) { return tensor_data_.at(id).zero_point; }
+
+ // Define the operator in this model.
+ void SetBuiltinOp(BuiltinOperator type, BuiltinOptions builtin_options_type,
+ flatbuffers::Offset<void> builtin_options);
+ void SetCustomOp(const string& name,
+ const std::vector<uint8_t>& custom_option,
+                   const std::function<TfLiteRegistration*()>& registration);
+
+ // Build the interpreter for this model. Also, resize and allocate all
+ // tensors given the shapes of the inputs.
+ void BuildInterpreter(std::vector<std::vector<int>> input_shapes);
+
+ void Invoke();
+
+ void PopulateStringTensor(int index, const std::vector<string>& content) {
+ auto tensor = interpreter_->tensor(index);
+ DynamicBuffer buf;
+ for (const string& s : content) {
+ buf.AddString(s.data(), s.length());
+ }
+ buf.WriteToTensor(tensor);
+ }
+
+ // Populate the tensor given its index.
+ template <typename T>
+ void PopulateTensor(int index, std::initializer_list<T> data) {
+ T* v = interpreter_->typed_tensor<T>(index);
+ CHECK(v) << "No tensor with index '" << index << "'.";
+ for (T f : data) {
+ *v = f;
+ ++v;
+ }
+ }
+
+ // Partially populate the tensor, starting at the given offset.
+ template <typename T>
+ void PopulateTensor(int index, int offset, T* begin, T* end) {
+ T* v = interpreter_->typed_tensor<T>(index);
+ memcpy(v + offset, begin, (end - begin) * sizeof(T));
+ }
+
+ // Return a vector with the flattened contents of a tensor.
+ template <typename T>
+ std::vector<T> ExtractVector(int index) {
+ T* v = interpreter_->typed_tensor<T>(index);
+ CHECK(v);
+ return std::vector<T>(v, v + GetTensorSize(index));
+ }
+
+ std::vector<int> GetTensorShape(int index) {
+ std::vector<int> result;
+ TfLiteTensor* t = interpreter_->tensor(index);
+ for (int i = 0; i < t->dims->size; ++i) {
+ result.push_back(t->dims->data[i]);
+ }
+ return result;
+ }
+
+ protected:
+ int32_t GetTensorSize(int index) const;
+
+ flatbuffers::FlatBufferBuilder builder_;
+ std::unique_ptr<tflite::Interpreter> interpreter_;
+
+ private:
+ int AddTensor(TensorData t);
+
+ std::map<int, TensorData> tensor_data_;
+ std::vector<int32_t> inputs_;
+ std::vector<int32_t> outputs_;
+ std::vector<flatbuffers::Offset<Tensor>> tensors_;
+ std::vector<flatbuffers::Offset<OperatorCode>> opcodes_;
+ std::vector<flatbuffers::Offset<Operator>> operators_;
+ std::map<string, std::function<TfLiteRegistration*()>> custom_registrations_;
+};
+
+} // namespace tflite
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_TEST_UTIL_H_
diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc
new file mode 100644
index 0000000000..f8208f6f98
--- /dev/null
+++ b/tensorflow/contrib/lite/model.cc
@@ -0,0 +1,673 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <fcntl.h>
+#include <stdint.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <sys/mman.h>
+#include <sys/stat.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include "tensorflow/contrib/lite/allocation.h"
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/error_reporter.h"
+#include "tensorflow/contrib/lite/model.h"
+#include "tensorflow/contrib/lite/nnapi_delegate.h"
+#include "tensorflow/contrib/lite/version.h"
+
+namespace tflite {
+
+const char* kEmptyTensorName = "";
+
+std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromFile(
+ const char* filename, ErrorReporter* error_reporter) {
+ std::unique_ptr<FlatBufferModel> model;
+ model.reset(new FlatBufferModel(filename, /*mmap_file=*/true, error_reporter,
+ /*use_nnapi=*/true));
+ if (!model->initialized()) model.reset();
+ return model;
+}
+
+std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromBuffer(
+ const char* buffer, size_t buffer_size, ErrorReporter* error_reporter) {
+ std::unique_ptr<FlatBufferModel> model;
+ model.reset(new FlatBufferModel(buffer, buffer_size, error_reporter));
+ if (!model->initialized()) model.reset();
+ return model;
+}
+
+FlatBufferModel::FlatBufferModel(const char* filename, bool mmap_file,
+ ErrorReporter* error_reporter, bool use_nnapi)
+ : error_reporter_(error_reporter ? error_reporter
+ : DefaultErrorReporter()) {
+ if (mmap_file) {
+ if (use_nnapi && NNAPIExists())
+ allocation_ = new NNAPIAllocation(filename, error_reporter);
+ else
+ allocation_ = new MMAPAllocation(filename, error_reporter);
+ } else {
+ allocation_ = new FileCopyAllocation(filename, error_reporter);
+ }
+ if (!allocation_->valid()) return;
+ if (!CheckModelIdentifier()) return;
+
+ model_ = ::tflite::GetModel(allocation_->base());
+}
+
+bool FlatBufferModel::CheckModelIdentifier() const {
+ if (!tflite::ModelBufferHasIdentifier(allocation_->base())) {
+ const char* ident = flatbuffers::GetBufferIdentifier(allocation_->base());
+ error_reporter_->Report(
+ "Model provided has model identifier '%c%c%c%c', should be '%s'\n",
+ ident[0], ident[1], ident[2], ident[3], tflite::ModelIdentifier());
+ return false;
+ }
+ return true;
+}
+
+FlatBufferModel::FlatBufferModel(const char* ptr, size_t num_bytes,
+ ErrorReporter* error_reporter)
+ : error_reporter_(error_reporter ? error_reporter
+ : DefaultErrorReporter()) {
+ allocation_ = new MemoryAllocation(ptr, num_bytes, error_reporter);
+ if (!allocation_->valid()) return;
+ model_ = ::tflite::GetModel(allocation_->base());
+}
+
+FlatBufferModel::~FlatBufferModel() { delete allocation_; }
+
+InterpreterBuilder::InterpreterBuilder(const FlatBufferModel& model,
+ const OpResolver& op_resolver)
+ : model_(model.GetModel()),
+ op_resolver_(op_resolver),
+ error_reporter_(model.error_reporter()),
+ allocation_(model.allocation()) {}
+
+InterpreterBuilder::InterpreterBuilder(const ::tflite::Model* model,
+ const OpResolver& op_resolver,
+ ErrorReporter* error_reporter)
+ : model_(model),
+ op_resolver_(op_resolver),
+ error_reporter_(error_reporter ? error_reporter
+ : DefaultErrorReporter()) {}
+
+TfLiteStatus InterpreterBuilder::BuildLocalIndexToRegistrationMapping() {
+ TfLiteStatus status = kTfLiteOk;
+ auto opcodes = model_->operator_codes();
+ for (const OperatorCode* opcode : *opcodes) {
+ TfLiteRegistration* registration = nullptr;
+
+ if (opcode->builtin_code() != BuiltinOperator_CUSTOM) {
+ auto x = opcode->builtin_code();
+ flatbuffer_op_index_to_registration_types_.push_back(x);
+ registration = op_resolver_.FindOp(x);
+ if (registration == nullptr) {
+ error_reporter_->Report("Didn't find op for builtin opcode '%s'\n",
+ EnumNameBuiltinOperator(x));
+ status = kTfLiteError;
+ }
+ } else if (!opcode->custom_code()) {
+ error_reporter_->Report(
+ "Operator with builtin_code==0 has no custom_code.\n");
+ status = kTfLiteError;
+ } else {
+ const char* name = opcode->custom_code()->c_str();
+ registration = op_resolver_.FindOp(name);
+ flatbuffer_op_index_to_registration_types_.push_back(
+ BuiltinOperator_CUSTOM);
+ if (registration == nullptr) {
+ error_reporter_->Report("Didn't find custom op for name '%s'\n", name);
+ status = kTfLiteError;
+ }
+ }
+ flatbuffer_op_index_to_registration_.push_back(registration);
+ }
+ return status;
+}
+
+namespace {
+template <class T>
+std::vector<int> FlatBufferIntArrayToVector(T* flat_array) {
+ std::vector<int> ret(flat_array->Length());
+ for (int i = 0; i < flat_array->Length(); i++) {
+ ret[i] = flat_array->Get(i);
+ }
+ return ret;
+}
+
+// Allocate a structure using C malloc, but make sure the structure is a
+// POD structure that doesn't require constructors to run. The reason we do
+// this is that the Interpreter's C extension part will take ownership and wants
+// to use malloc() and free().
+template <class T>
+T* MallocPOD() {
+ static_assert(std::is_pod<T>::value, "Builtin data structure must be POD.");
+ return static_cast<T*>(malloc(sizeof(T)));
+}
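+// For example, ParseOpData below creates builtin params as
+//   TfLiteConvParams* params = MallocPOD<TfLiteConvParams>();
+// and the resulting pointer is later released with plain free().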
+
+// Parse the appropriate data out of the op.
+//
+// This handles builtin data explicitly, as there are flatbuffer schemas for it.
+//
+// Returns memory that must be freed.
+void* ParseOpData(const Operator* op, BuiltinOperator op_type,
+ ErrorReporter* error_reporter) {
+ auto parse_padding = [](Padding padding) {
+ switch (padding) {
+ case Padding_SAME:
+ return kTfLitePaddingSame;
+ case Padding_VALID:
+ return kTfLitePaddingValid;
+ }
+ return kTfLitePaddingUnknown;
+ };
+ auto parse_activation = [](ActivationFunctionType activation) {
+ switch (activation) {
+ case ActivationFunctionType_NONE:
+ return kTfLiteActNone;
+ case ActivationFunctionType_RELU:
+ return kTfLiteActRelu;
+ case ActivationFunctionType_RELU1:
+ return kTfLiteActRelu1;
+ case ActivationFunctionType_RELU6:
+ return kTfLiteActRelu6;
+ case ActivationFunctionType_TANH:
+ return kTfLiteActTanh;
+ case ActivationFunctionType_SIGN_BIT:
+ return kTfLiteActSignBit;
+ }
+ return kTfLiteActNone;
+ };
+ auto parseLSHProjectionType = [](LSHProjectionType type) {
+ switch (type) {
+ case LSHProjectionType_SPARSE:
+ return kTfLiteLshProjectionSparse;
+ case LSHProjectionType_DENSE:
+ return kTfLiteLshProjectionDense;
+ default:
+ return kTfLiteLshProjectionUnknown;
+ }
+ };
+ auto parseCombinerType = [](CombinerType type) {
+ switch (type) {
+ case CombinerType_MEAN:
+ return kTfLiteCombinerTypeMean;
+ case CombinerType_SQRTN:
+ return kTfLiteCombinerTypeSqrtn;
+ case CombinerType_SUM:
+ default:
+ return kTfLiteCombinerTypeSum;
+ }
+ };
+
+ void* builtin_data = nullptr;
+ switch (op_type) {
+ case BuiltinOperator_CALL:
+ // TODO(aselle): Implement call in BuiltinOptions, but nullptrs are
+ // ok for now, since there is no call implementation either.
+ break;
+ case BuiltinOperator_CUSTOM:
+ break;
+ case BuiltinOperator_CONV_2D: {
+ TfLiteConvParams* params = MallocPOD<TfLiteConvParams>();
+ if (auto* conv_params = op->builtin_options_as_Conv2DOptions()) {
+ params->padding = parse_padding(conv_params->padding());
+ params->stride_width = conv_params->stride_w();
+ params->stride_height = conv_params->stride_h();
+ params->activation =
+ parse_activation(conv_params->fused_activation_function());
+ }
+ builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_TANH:
+ case BuiltinOperator_LOGISTIC:
+ case BuiltinOperator_RELU:
+ case BuiltinOperator_RELU1:
+ case BuiltinOperator_RELU6:
+ case BuiltinOperator_CONCAT_EMBEDDINGS:
+ break;
+ case BuiltinOperator_LSH_PROJECTION: {
+ TfLiteLSHProjectionParams* params =
+ MallocPOD<TfLiteLSHProjectionParams>();
+ if (auto* lshParams = op->builtin_options_as_LSHProjectionOptions()) {
+ params->type = parseLSHProjectionType(lshParams->type());
+ }
+ builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_AVERAGE_POOL_2D:
+ case BuiltinOperator_MAX_POOL_2D:
+ case BuiltinOperator_L2_POOL_2D: {
+ TfLitePoolParams* params = MallocPOD<TfLitePoolParams>();
+ if (auto* pool_params = op->builtin_options_as_Pool2DOptions()) {
+ params->padding = parse_padding(pool_params->padding());
+ params->stride_width = pool_params->stride_w();
+ params->stride_height = pool_params->stride_h();
+ params->filter_width = pool_params->filter_width();
+ params->filter_height = pool_params->filter_height();
+ params->activation =
+ parse_activation(pool_params->fused_activation_function());
+ }
+ builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_DEPTHWISE_CONV_2D: {
+ TfLiteDepthwiseConvParams* params =
+ MallocPOD<TfLiteDepthwiseConvParams>();
+ if (auto* conv_params = op->builtin_options_as_DepthwiseConv2DOptions()) {
+ params->padding = parse_padding(conv_params->padding());
+ params->stride_width = conv_params->stride_w();
+ params->stride_height = conv_params->stride_h();
+ params->depth_multiplier = conv_params->depth_multiplier();
+ params->activation =
+ parse_activation(conv_params->fused_activation_function());
+ }
+ builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_SVDF: {
+ TfLiteSVDFParams* params = MallocPOD<TfLiteSVDFParams>();
+ if (auto* svdf_params = op->builtin_options_as_SVDFOptions()) {
+ params->rank = svdf_params->rank();
+ params->activation =
+ parse_activation(svdf_params->fused_activation_function());
+ }
+ builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_RNN: {
+ TfLiteRNNParams* params = MallocPOD<TfLiteRNNParams>();
+ if (auto* rnn_params = op->builtin_options_as_RNNOptions()) {
+ params->activation =
+ parse_activation(rnn_params->fused_activation_function());
+ }
+ builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_EMBEDDING_LOOKUP:
+ // no-op.
+ break;
+ case BuiltinOperator_EMBEDDING_LOOKUP_SPARSE: {
+ TfLiteEmbeddingLookupSparseParams* params =
+ MallocPOD<TfLiteEmbeddingLookupSparseParams>();
+ if (auto* embedding_params =
+ op->builtin_options_as_EmbeddingLookupSparseOptions()) {
+ params->combiner = parseCombinerType(embedding_params->combiner());
+ }
+ builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_FULLY_CONNECTED: {
+ TfLiteFullyConnectedParams* params =
+ MallocPOD<TfLiteFullyConnectedParams>();
+ if (auto* fully_connected_params =
+ op->builtin_options_as_FullyConnectedOptions()) {
+ params->activation = parse_activation(
+ fully_connected_params->fused_activation_function());
+ }
+ builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_HASHTABLE_LOOKUP:
+ // no-op.
+ break;
+ case BuiltinOperator_SOFTMAX: {
+ TfLiteSoftmaxParams* params = MallocPOD<TfLiteSoftmaxParams>();
+ if (auto* softmax_params = op->builtin_options_as_SoftmaxOptions()) {
+ params->beta = softmax_params->beta();
+ }
+ builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_CONCATENATION: {
+ TfLiteConcatenationParams* params =
+ MallocPOD<TfLiteConcatenationParams>();
+ if (auto* concatenation_params =
+ op->builtin_options_as_ConcatenationOptions()) {
+ params->activation =
+ parse_activation(concatenation_params->fused_activation_function());
+ params->axis = concatenation_params->axis();
+ }
+ builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_MUL: {
+ auto* params = MallocPOD<TfLiteMulParams>();
+ if (auto* schema_params = op->builtin_options_as_MulOptions()) {
+ params->activation =
+ parse_activation(schema_params->fused_activation_function());
+ }
+ builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_ADD: {
+ auto* params = MallocPOD<TfLiteAddParams>();
+ if (auto* schema_params = op->builtin_options_as_AddOptions()) {
+ params->activation =
+ parse_activation(schema_params->fused_activation_function());
+ }
+ builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_L2_NORMALIZATION: {
+ auto* params = MallocPOD<TfLiteL2NormParams>();
+ if (auto* schema_params = op->builtin_options_as_L2NormOptions()) {
+ params->activation =
+ parse_activation(schema_params->fused_activation_function());
+ }
+ builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION: {
+ auto* params = MallocPOD<TfLiteLocalResponseNormParams>();
+ if (auto* schema_params =
+ op->builtin_options_as_LocalResponseNormalizationOptions()) {
+ params->radius = schema_params->radius();
+ params->bias = schema_params->bias();
+ params->alpha = schema_params->alpha();
+ params->beta = schema_params->beta();
+ }
+ builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_LSTM: {
+ TfLiteLSTMParams* params = MallocPOD<TfLiteLSTMParams>();
+ if (auto* lstm_params = op->builtin_options_as_LSTMOptions()) {
+ params->activation =
+ parse_activation(lstm_params->fused_activation_function());
+ params->cell_clip = lstm_params->cell_clip();
+ params->proj_clip = lstm_params->proj_clip();
+ }
+ builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_RESIZE_BILINEAR: {
+ auto* params = MallocPOD<TfLiteResizeBilinearParams>();
+ if (auto* schema_params =
+ op->builtin_options_as_ResizeBilinearOptions()) {
+ params->new_height = schema_params->new_height();
+ params->new_width = schema_params->new_width();
+ }
+ builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_RESHAPE: {
+ auto* params = MallocPOD<TfLiteReshapeParams>();
+ if (auto* schema_params = op->builtin_options_as_ReshapeOptions()) {
+ auto* new_shape = schema_params->new_shape();
+ if (!new_shape) {
+ error_reporter->Report("No new_shape provided for Reshape\n");
+ } else {
+ params->num_dimensions = new_shape->Length();
+ if (params->num_dimensions > sizeof(params->shape) / sizeof(int)) {
+ error_reporter->Report(
+ "Found too many dimensions in Reshape's new_shape\n");
+ } else {
+ for (int i = 0; i < params->num_dimensions; ++i) {
+ params->shape[i] = new_shape->Get(i);
+ }
+ }
+ }
+ }
+ builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_SKIP_GRAM: {
+ TfLiteSkipGramParams* params = MallocPOD<TfLiteSkipGramParams>();
+ if (auto* skip_gram_params = op->builtin_options_as_SkipGramOptions()) {
+ params->ngram_size = skip_gram_params->ngram_size();
+ params->max_skip_size = skip_gram_params->max_skip_size();
+ params->include_all_ngrams = skip_gram_params->include_all_ngrams();
+ }
+ builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_SPACE_TO_DEPTH: {
+ auto* params = MallocPOD<TfLiteSpaceToDepthParams>();
+ if (auto* schema_params = op->builtin_options_as_SpaceToDepthOptions()) {
+ params->block_size = schema_params->block_size();
+ }
+ builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ }
+ return builtin_data;
+}
+
+} // namespace
+
+TfLiteStatus InterpreterBuilder::ParseNodes(
+ const flatbuffers::Vector<flatbuffers::Offset<Operator>>* operators,
+ Interpreter* interpreter) {
+ TfLiteStatus status = kTfLiteOk;
+ for (int i = 0; i < operators->Length(); ++i) {
+ const auto* op = operators->Get(i);
+ int index = op->opcode_index();
+ if (index < 0 || index >= flatbuffer_op_index_to_registration_.size()) {
+ error_reporter_->Report("Missing registration for opcode_index %d\n",
+ index);
+ status = kTfLiteError;
+ continue;
+ }
+ const TfLiteRegistration* reg =
+ flatbuffer_op_index_to_registration_[op->opcode_index()];
+ if (reg == nullptr) {
+ error_reporter_->Report("Skipping op for opcode_index %d\n", index);
+ status = kTfLiteError;
+ continue;
+ }
+
+ auto op_type =
+ flatbuffer_op_index_to_registration_types_[op->opcode_index()];
+ if (op_type != BuiltinOperator_CUSTOM && op->custom_options()) {
+ error_reporter_->Report(
+ "Found builtin operator %s with custom options.\n",
+ EnumNameBuiltinOperator(op_type));
+ }
+ if (op->custom_options()) {
+ interpreter->AddNodeWithParameters(
+ FlatBufferIntArrayToVector(op->inputs()),
+ FlatBufferIntArrayToVector(op->outputs()),
+ reinterpret_cast<const char*>(op->custom_options()->data()),
+ op->custom_options()->size(), nullptr, reg);
+ } else {
+ interpreter->AddNodeWithParameters(
+ FlatBufferIntArrayToVector(op->inputs()),
+ FlatBufferIntArrayToVector(op->outputs()), nullptr, 0,
+ ParseOpData(op, op_type, error_reporter_), reg);
+ }
+ }
+
+ return status;
+}
+
+TfLiteStatus InterpreterBuilder::ParseTensors(
+ const flatbuffers::Vector<flatbuffers::Offset<Buffer>>* buffers,
+ const flatbuffers::Vector<flatbuffers::Offset<Tensor>>* tensors,
+ Interpreter* interpreter) {
+ TfLiteStatus status = kTfLiteOk;
+
+ // A little helper to get the names of inputs and outputs. Note that they
+ // must outlive the interpreter.
+ auto get_name = [](const tflite::Tensor* t) -> const char* {
+ auto name = t->name();
+ if (name) return name->c_str();
+ return kEmptyTensorName;
+ };
+
+ for (int i = 0; i < tensors->Length(); ++i) {
+ const auto* tensor = tensors->Get(i);
+ std::vector<int> dims = FlatBufferIntArrayToVector(tensor->shape());
+
+ TfLiteQuantizationParams quantization;
+ quantization.scale = 0;
+ quantization.zero_point = 0;
+ auto* q_params = tensor->quantization();
+ if (q_params) {
+ // Note that the schema could hold per-channel quantization parameters
+ // but we really only support one value for the whole tensor.
+ // TODO(aselle): This breaks as well if these are nullptr's.
+ // TODO(aselle): This assumes non per-channel quantization.
+ if (q_params->scale()) quantization.scale = q_params->scale()->Get(0);
+ if (q_params->zero_point())
+ quantization.zero_point = q_params->zero_point()->Get(0);
+ }
+
+ TfLiteType type;
+ switch (tensor->type()) {
+ case TensorType_FLOAT32:
+ type = kTfLiteFloat32;
+ break;
+ case TensorType_INT32:
+ type = kTfLiteInt32;
+ break;
+ case TensorType_UINT8:
+ type = kTfLiteUInt8;
+ break;
+ case TensorType_INT64:
+ type = kTfLiteInt64;
+ break;
+ case TensorType_STRING:
+ type = kTfLiteString;
+ break;
+ default:
+ // tensorType = ArrayType::NONE;
+ error_reporter_->Report("Unimplemented data type %s (%d) in tensor\n",
+ EnumNameTensorType(tensor->type()),
+ tensor->type());
+ status = kTfLiteError;
+ continue;
+ }
+ auto get_readonly_data = [&](const char** buffer_data,
+ size_t* buffer_size) {
+ // TODO(aselle): Check what happens if we have an unspecified size
+ // constant.
+ *buffer_data = nullptr;
+ if (tensor->buffer() == 0) return kTfLiteOk;
+ if (tensor->buffer() >= buffers->size()) {
+ error_reporter_->Report(
+ "Tensor %d specifies out of range buffer %d (only %d buffers).\n",
+ i, tensor->buffer(), buffers->size());
+ return kTfLiteError;
+ }
+ if (auto* buffer = (*buffers)[tensor->buffer()]) {
+ if (auto* array = buffer->data()) {
+ if (size_t size = array->size()) {
+ *buffer_size = size;
+ *buffer_data = reinterpret_cast<const char*>(array->data());
+ return kTfLiteOk;
+ }
+ }
+ }
+ return kTfLiteOk;
+ };
+ size_t buffer_size = 0;
+ const char* buffer_ptr;
+ TF_LITE_ENSURE_STATUS(get_readonly_data(&buffer_ptr, &buffer_size));
+
+ if (buffer_ptr) {
+ if (interpreter->SetTensorParametersReadOnly(
+ i, type, get_name(tensor), dims, quantization, buffer_ptr,
+ buffer_size, allocation_) != kTfLiteOk) {
+ error_reporter_->Report("Tensor %d is invalidly specified in schema.\n",
+ i);
+ status = kTfLiteError;
+ }
+ } else {
+ if (interpreter->SetTensorParametersReadWrite(
+ i, type, get_name(tensor), dims, quantization) != kTfLiteOk) {
+ error_reporter_->Report("Tensor %d is invalidly specified in schema.\n",
+ i);
+ status = kTfLiteError;
+ }
+ }
+ }
+
+ return status;
+}
+
+TfLiteStatus InterpreterBuilder::operator()(
+ std::unique_ptr<Interpreter>* interpreter) {
+ if (!interpreter) {
+ error_reporter_->Report(
+ "Null output pointer passed to InterpreterBuilder.");
+ return kTfLiteError;
+ }
+
+  // Safe exit by deleting the partially created interpreter, to reduce
+  // verbosity on error conditions. Use via `return cleanup_and_error();`.
+ auto cleanup_and_error = [&interpreter]() {
+ interpreter->reset();
+ return kTfLiteError;
+ };
+
+ if (!model_) {
+ error_reporter_->Report("Null pointer passed in as model.");
+ return cleanup_and_error();
+ }
+
+ if (model_->version() != TFLITE_SCHEMA_VERSION) {
+ error_reporter_->Report(
+ "Model provided is schema version %d not equal "
+ "to supported version %d.\n",
+ model_->version(), TFLITE_SCHEMA_VERSION);
+ return cleanup_and_error();
+ }
+
+ if (BuildLocalIndexToRegistrationMapping() != kTfLiteOk) {
+ error_reporter_->Report("Registration failed.\n");
+ return cleanup_and_error();
+ }
+
+ // Flatbuffer model schemas define a list of opcodes independent of the graph.
+ // We first map those to registrations. This reduces string lookups for custom
+ // ops since we only do it once per custom op rather than once per custom op
+ // invocation in the model graph.
+ // Construct interpreter with correct number of tensors and operators.
+ auto* subgraphs = model_->subgraphs();
+ auto* buffers = model_->buffers();
+ if (subgraphs->size() != 1) {
+ error_reporter_->Report("Only 1 subgraph is currently supported.\n");
+ return cleanup_and_error();
+ }
+ const tflite::SubGraph* subgraph = (*subgraphs)[0];
+ auto operators = subgraph->operators();
+ auto tensors = subgraph->tensors();
+ if (!operators || !tensors || !buffers) {
+ error_reporter_->Report(
+ "Did not get operators, tensors, or buffers in input flat buffer.\n");
+ return cleanup_and_error();
+ }
+ interpreter->reset(new Interpreter(error_reporter_));
+ if ((**interpreter).AddTensors(tensors->Length()) != kTfLiteOk) {
+ return cleanup_and_error();
+ }
+
+ // Parse inputs/outputs
+ (**interpreter).SetInputs(FlatBufferIntArrayToVector(subgraph->inputs()));
+ (**interpreter).SetOutputs(FlatBufferIntArrayToVector(subgraph->outputs()));
+
+ // Finally setup nodes and tensors
+ if (ParseNodes(operators, interpreter->get()) != kTfLiteOk)
+ return cleanup_and_error();
+ if (ParseTensors(buffers, tensors, interpreter->get()) != kTfLiteOk)
+ return cleanup_and_error();
+
+ return kTfLiteOk;
+}
+
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/model.h b/tensorflow/contrib/lite/model.h
new file mode 100644
index 0000000000..15659d33f3
--- /dev/null
+++ b/tensorflow/contrib/lite/model.h
@@ -0,0 +1,165 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// Deserialization infrastructure for tflite. Provides functionality
+// to go from a serialized tflite model in flatbuffer format to an
+// interpreter.
+//
+// using namespace tflite;
+// StderrReporter error_reporter;
+// auto model = FlatBufferModel::BuildFromFile("interesting_model.tflite",
+// &error_reporter);
+// MyOpResolver resolver; // You need to subclass OpResolver to provide
+// // implementations.
+// InterpreterBuilder builder(*model, resolver);
+// std::unique_ptr<Interpreter> interpreter;
+// if(builder(&interpreter) == kTfLiteOk) {
+// .. run model inference with interpreter
+// }
+//
+// OpResolver must be defined to provide your kernel implementations to the
+// interpreter. This is environment specific and may consist of just the builtin
+// ops, or some custom operators you defined to extend tflite.
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_MODEL_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_MODEL_H_
+
+#include <memory>
+#include "tensorflow/contrib/lite/error_reporter.h"
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/schema/schema_generated.h"
+
+namespace tflite {
+
+// An RAII object that represents a read-only tflite model, copied from disk,
+// or mmapped. This uses flatbuffers as the serialization format.
+class FlatBufferModel {
+ public:
+ // Build a model based on a file. Return a nullptr in case of failure.
+ static std::unique_ptr<FlatBufferModel> BuildFromFile(
+ const char* filename,
+ ErrorReporter* error_reporter = DefaultErrorReporter());
+
+ // Build a model based on a pre-loaded flatbuffer. The caller retains
+ // ownership of the buffer and should keep it alive until the returned object
+ // is destroyed. Return a nullptr in case of failure.
+ static std::unique_ptr<FlatBufferModel> BuildFromBuffer(
+ const char* buffer, size_t buffer_size,
+ ErrorReporter* error_reporter = DefaultErrorReporter());
+
+  // Releases memory or unmaps mmapped memory.
+ ~FlatBufferModel();
+
+ // Copying or assignment is disallowed to simplify ownership semantics.
+ FlatBufferModel(const FlatBufferModel&) = delete;
+ FlatBufferModel& operator=(const FlatBufferModel&) = delete;
+
+ bool initialized() const { return model_ != nullptr; }
+ const tflite::Model* operator->() const { return model_; }
+ const tflite::Model* GetModel() const { return model_; }
+ ErrorReporter* error_reporter() const { return error_reporter_; }
+ const Allocation* allocation() const { return allocation_; }
+
+ // Returns true if the model identifier is correct (otherwise false and
+ // reports an error).
+ bool CheckModelIdentifier() const;
+
+ private:
+ // Load a model from `filename`. If `mmap_file` is true then use mmap,
+ // otherwise make a copy of the model in a buffer.
+ //
+ // Note, if `error_reporter` is null, then a DefaultErrorReporter() will be
+ // used.
+ explicit FlatBufferModel(
+ const char* filename, bool mmap_file = true,
+ ErrorReporter* error_reporter = DefaultErrorReporter(),
+ bool use_nnapi = false);
+
+ // Load a model from `ptr` and `num_bytes` of the model file. The `ptr` has to
+  // remain alive and unchanged until the end of this FlatBufferModel's
+ // lifetime.
+ //
+ // Note, if `error_reporter` is null, then a DefaultErrorReporter() will be
+ // used.
+ FlatBufferModel(const char* ptr, size_t num_bytes,
+ ErrorReporter* error_reporter = DefaultErrorReporter());
+
+  // Flatbuffer traverser pointer. (Model* is a pointer into the memory
+  // allocated by allocation_'s internals.)
+ const tflite::Model* model_ = nullptr;
+ ErrorReporter* error_reporter_;
+ Allocation* allocation_ = nullptr;
+};
+
+// Abstract interface that returns TfLiteRegistrations given op codes or custom
+// op names. This is the mechanism by which ops referenced in the flatbuffer
+// model are mapped to executable function pointers (TfLiteRegistrations).
+class OpResolver {
+ public:
+ // Find the op registration for a builtin operator by enum code.
+ virtual TfLiteRegistration* FindOp(tflite::BuiltinOperator op) const = 0;
+ // Find the op registration of a custom operator by op name.
+ virtual TfLiteRegistration* FindOp(const char* op) const = 0;
+ virtual ~OpResolver() {}
+};
+
+// Build an interpreter capable of interpreting `model`.
+//
+// model: a scoped model whose lifetime must be at least as long as
+// the interpreter. In principle multiple interpreters can be made from
+// a single model.
+// op_resolver: An instance that implements the Resolver interface which maps
+// custom op names and builtin op codes to op registrations.
+// error_reporter: a functor that is called to report errors and that handles
+//   printf-style var arg semantics. The lifetime of the error_reporter object
+//   must be greater than or equal to that of the Interpreter created by
+//   operator().
+//
+// Returns kTfLiteOk when successful and sets interpreter to a valid
+// Interpreter. Note: the user must ensure the model lifetime is at least as
+// long as interpreter's lifetime.
+class InterpreterBuilder {
+ public:
+ InterpreterBuilder(const FlatBufferModel& model,
+ const OpResolver& op_resolver);
+ // Build an interpreter given only the raw flatbuffer Model object (instead
+ // of a FlatBufferModel). Mostly used for testing.
+ // If `error_reporter` is null, then DefaultErrorReporter() is used.
+ InterpreterBuilder(const ::tflite::Model* model,
+ const OpResolver& op_resolver,
+ ErrorReporter* error_reporter = DefaultErrorReporter());
+ InterpreterBuilder(const InterpreterBuilder&) = delete;
+ InterpreterBuilder& operator=(const InterpreterBuilder&) = delete;
+ TfLiteStatus operator()(std::unique_ptr<Interpreter>* interpreter);
+
+ private:
+ TfLiteStatus BuildLocalIndexToRegistrationMapping();
+ TfLiteStatus ParseNodes(
+ const flatbuffers::Vector<flatbuffers::Offset<Operator>>* operators,
+ Interpreter* interpreter);
+ TfLiteStatus ParseTensors(
+ const flatbuffers::Vector<flatbuffers::Offset<Buffer>>* buffers,
+ const flatbuffers::Vector<flatbuffers::Offset<Tensor>>* tensors,
+ Interpreter* interpreter);
+
+ const ::tflite::Model* model_;
+ const OpResolver& op_resolver_;
+ ErrorReporter* error_reporter_;
+
+ std::vector<TfLiteRegistration*> flatbuffer_op_index_to_registration_;
+ std::vector<BuiltinOperator> flatbuffer_op_index_to_registration_types_;
+ const Allocation* allocation_ = nullptr;
+};
+
+} // namespace tflite
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_MODEL_H_
diff --git a/tensorflow/contrib/lite/model_test.cc b/tensorflow/contrib/lite/model_test.cc
new file mode 100644
index 0000000000..ae823650d6
--- /dev/null
+++ b/tensorflow/contrib/lite/model_test.cc
@@ -0,0 +1,258 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <fcntl.h>
+#include <stdint.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <sys/mman.h>
+#include <sys/stat.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include "tensorflow/contrib/lite/model.h"
+
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/error_reporter.h"
+
+// Comparison for TfLiteRegistration. Since TfLiteRegistration is a C object,
+// we must declare this in global namespace, so argument-dependent operator
+// lookup works.
+inline bool operator==(const TfLiteRegistration& a,
+ const TfLiteRegistration& b) {
+ return a.invoke == b.invoke && a.init == b.init && a.prepare == b.prepare &&
+ a.free == b.free;
+}
+
+namespace tflite {
+
+// Provide a dummy operation that does nothing.
+namespace {
+void* dummy_init(TfLiteContext*, const char*, size_t) { return nullptr; }
+void dummy_free(TfLiteContext*, void*) {}
+TfLiteStatus dummy_resize(TfLiteContext*, TfLiteNode*) { return kTfLiteOk; }
+TfLiteStatus dummy_invoke(TfLiteContext*, TfLiteNode*) { return kTfLiteOk; }
+TfLiteRegistration dummy_reg = {dummy_init, dummy_free, dummy_resize,
+ dummy_invoke};
+} // namespace
+
+// Provide a trivial resolver that returns a constant value no matter what
+// op is asked for.
+class TrivialResolver : public OpResolver {
+ public:
+ explicit TrivialResolver(TfLiteRegistration* constant_return = nullptr)
+ : constant_return_(constant_return) {}
+  // Find the op registration for a builtin operator by its enum code.
+ TfLiteRegistration* FindOp(tflite::BuiltinOperator op) const override {
+ return constant_return_;
+ }
+ // Find the op registration of a custom operator by op name.
+ TfLiteRegistration* FindOp(const char* op) const override {
+ return constant_return_;
+ }
+
+ private:
+ TfLiteRegistration* constant_return_;
+};
+
+TEST(BasicFlatBufferModel, TestNonExistentFiles) {
+ ASSERT_TRUE(!FlatBufferModel::BuildFromFile("/tmp/tflite_model_1234"));
+}
+
+// Make sure a model with nothing in it loads properly.
+TEST(BasicFlatBufferModel, TestEmptyModelsAndNullDestination) {
+ auto model = FlatBufferModel::BuildFromFile(
+ "tensorflow/contrib/lite/testdata/empty_model.bin");
+ ASSERT_TRUE(model);
+ // Now try to build it into a model.
+ std::unique_ptr<Interpreter> interpreter;
+ ASSERT_EQ(InterpreterBuilder(*model, TrivialResolver())(&interpreter),
+ kTfLiteOk);
+ ASSERT_NE(interpreter, nullptr);
+ ASSERT_NE(InterpreterBuilder(*model, TrivialResolver())(nullptr), kTfLiteOk);
+}
+
+// Make sure an unsupported number of subgraphs (0 or >1) is rejected.
+// TODO(aselle): Replace this test when multiple subgraphs are supported.
+TEST(BasicFlatBufferModel, TestZeroAndMultipleSubgraphs) {
+ auto m1 = FlatBufferModel::BuildFromFile(
+ "tensorflow/contrib/lite/testdata/0_subgraphs.bin");
+ ASSERT_TRUE(m1);
+ std::unique_ptr<Interpreter> interpreter1;
+ ASSERT_NE(InterpreterBuilder(*m1, TrivialResolver())(&interpreter1),
+ kTfLiteOk);
+
+ auto m2 = FlatBufferModel::BuildFromFile(
+ "tensorflow/contrib/lite/testdata/2_subgraphs.bin");
+ ASSERT_TRUE(m2);
+ std::unique_ptr<Interpreter> interpreter2;
+ ASSERT_NE(InterpreterBuilder(*m2, TrivialResolver())(&interpreter2),
+ kTfLiteOk);
+}
+
+// Test what happens if we cannot bind any of the ops.
+TEST(BasicFlatBufferModel, TestModelWithoutNullRegistrations) {
+ auto model = FlatBufferModel::BuildFromFile(
+ "tensorflow/contrib/lite/testdata/test_model.bin");
+ ASSERT_TRUE(model);
+ // Check that we get an error code and interpreter pointer is reset.
+ std::unique_ptr<Interpreter> interpreter(new Interpreter);
+ ASSERT_NE(InterpreterBuilder(*model, TrivialResolver(nullptr))(&interpreter),
+ kTfLiteOk);
+ ASSERT_EQ(interpreter, nullptr);
+}
+
+// Make sure the model is read into the interpreter properly.
+TEST(BasicFlatBufferModel, TestModelInInterpreter) {
+ auto model = FlatBufferModel::BuildFromFile(
+ "tensorflow/contrib/lite/testdata/test_model.bin");
+ ASSERT_TRUE(model);
+  // Check that the model builds into a valid interpreter.
+ std::unique_ptr<Interpreter> interpreter(new Interpreter);
+ ASSERT_EQ(
+ InterpreterBuilder(*model, TrivialResolver(&dummy_reg))(&interpreter),
+ kTfLiteOk);
+ ASSERT_NE(interpreter, nullptr);
+ ASSERT_EQ(interpreter->tensors_size(), 4);
+ ASSERT_EQ(interpreter->nodes_size(), 2);
+ std::vector<int> inputs = {0, 1};
+ std::vector<int> outputs = {2, 3};
+ ASSERT_EQ(interpreter->inputs(), inputs);
+ ASSERT_EQ(interpreter->outputs(), outputs);
+
+ EXPECT_EQ(std::string(interpreter->GetInputName(0)), "input0");
+ EXPECT_EQ(std::string(interpreter->GetInputName(1)), "input1");
+ EXPECT_EQ(std::string(interpreter->GetOutputName(0)), "out1");
+ EXPECT_EQ(std::string(interpreter->GetOutputName(1)), "out2");
+
+ // Make sure all input tensors are correct
+ TfLiteTensor* i0 = interpreter->tensor(0);
+ ASSERT_EQ(i0->type, kTfLiteFloat32);
+ ASSERT_NE(i0->data.raw, nullptr); // mmapped
+ ASSERT_EQ(i0->allocation_type, kTfLiteMmapRo);
+ TfLiteTensor* i1 = interpreter->tensor(1);
+ ASSERT_EQ(i1->type, kTfLiteFloat32);
+ ASSERT_EQ(i1->data.raw, nullptr);
+ ASSERT_EQ(i1->allocation_type, kTfLiteArenaRw);
+ TfLiteTensor* o0 = interpreter->tensor(2);
+ ASSERT_EQ(o0->type, kTfLiteFloat32);
+ ASSERT_EQ(o0->data.raw, nullptr);
+ ASSERT_EQ(o0->allocation_type, kTfLiteArenaRw);
+ TfLiteTensor* o1 = interpreter->tensor(3);
+ ASSERT_EQ(o1->type, kTfLiteFloat32);
+ ASSERT_EQ(o1->data.raw, nullptr);
+ ASSERT_EQ(o1->allocation_type, kTfLiteArenaRw);
+
+ // Check op 0 which has inputs {0, 1} outputs {2}.
+ {
+ const std::pair<TfLiteNode, TfLiteRegistration>* node_and_reg0 =
+ interpreter->node_and_registration(0);
+ ASSERT_NE(node_and_reg0, nullptr);
+ const TfLiteNode& node0 = node_and_reg0->first;
+ const TfLiteRegistration& reg0 = node_and_reg0->second;
+ TfLiteIntArray* desired_inputs = TfLiteIntArrayCreate(2);
+ desired_inputs->data[0] = 0;
+ desired_inputs->data[1] = 1;
+ TfLiteIntArray* desired_outputs = TfLiteIntArrayCreate(1);
+ desired_outputs->data[0] = 2;
+ ASSERT_TRUE(TfLiteIntArrayEqual(node0.inputs, desired_inputs));
+ ASSERT_TRUE(TfLiteIntArrayEqual(node0.outputs, desired_outputs));
+ TfLiteIntArrayFree(desired_inputs);
+ TfLiteIntArrayFree(desired_outputs);
+ ASSERT_EQ(reg0, dummy_reg);
+ }
+
+ // Check op 1 which has inputs {2} outputs {3}.
+ {
+ const std::pair<TfLiteNode, TfLiteRegistration>* node_and_reg1 =
+ interpreter->node_and_registration(1);
+ ASSERT_NE(node_and_reg1, nullptr);
+ const TfLiteNode& node1 = node_and_reg1->first;
+ const TfLiteRegistration& reg1 = node_and_reg1->second;
+ TfLiteIntArray* desired_inputs = TfLiteIntArrayCreate(1);
+ TfLiteIntArray* desired_outputs = TfLiteIntArrayCreate(1);
+ desired_inputs->data[0] = 2;
+ desired_outputs->data[0] = 3;
+ ASSERT_TRUE(TfLiteIntArrayEqual(node1.inputs, desired_inputs));
+ ASSERT_TRUE(TfLiteIntArrayEqual(node1.outputs, desired_outputs));
+ TfLiteIntArrayFree(desired_inputs);
+ TfLiteIntArrayFree(desired_outputs);
+ ASSERT_EQ(reg1, dummy_reg);
+ }
+}
+
+// This tests a flatbuffer that declares a memory-mapped buffer for a tensor
+// of shape 2, but supplies a buffer containing only 1 element.
+TEST(BasicFlatBufferModel, TestBrokenMmap) {
+ ASSERT_FALSE(FlatBufferModel::BuildFromFile(
+ "tensorflow/contrib/lite/testdata/test_model_broken.bin"));
+}
+
+TEST(BasicFlatBufferModel, TestNullModel) {
+ // Check that we get an error code and interpreter pointer is reset.
+ std::unique_ptr<Interpreter> interpreter(new Interpreter);
+ ASSERT_NE(
+ InterpreterBuilder(nullptr, TrivialResolver(&dummy_reg))(&interpreter),
+ kTfLiteOk);
+ ASSERT_EQ(interpreter.get(), nullptr);
+}
+
+struct TestErrorReporter : public ErrorReporter {
+ int Report(const char* format, va_list args) override {
+ calls++;
+ return 0;
+ }
+ int calls = 0;
+};
+
+// This makes sure the ErrorReporter is marshalled from FlatBufferModel to
+// the Interpreter.
+TEST(BasicFlatBufferModel, TestCustomErrorReporter) {
+ TestErrorReporter reporter;
+ auto model = FlatBufferModel::BuildFromFile(
+ "tensorflow/contrib/lite/testdata/empty_model.bin",
+ &reporter);
+ ASSERT_TRUE(model);
+
+ std::unique_ptr<Interpreter> interpreter;
+ TrivialResolver resolver;
+ InterpreterBuilder(*model, resolver)(&interpreter);
+ ASSERT_NE(interpreter->Invoke(), kTfLiteOk);
+ ASSERT_EQ(reporter.calls, 1);
+}
+
+// This makes sure that passing a null ErrorReporter falls back to the default
+// reporter without crashing.
+TEST(BasicFlatBufferModel, TestNullErrorReporter) {
+ auto model = FlatBufferModel::BuildFromFile(
+ "tensorflow/contrib/lite/testdata/empty_model.bin", nullptr);
+ ASSERT_TRUE(model);
+
+ std::unique_ptr<Interpreter> interpreter;
+ TrivialResolver resolver;
+ InterpreterBuilder(*model, resolver)(&interpreter);
+ ASSERT_NE(interpreter->Invoke(), kTfLiteOk);
+}
+
+// TODO(aselle): Add tests for serialization of builtin op data types.
+// These tests will occur with the evaluation tests of individual operators,
+// not here.
+
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ // On Linux, add: tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/models/smartreply/BUILD b/tensorflow/contrib/lite/models/smartreply/BUILD
new file mode 100644
index 0000000000..fbdf19f205
--- /dev/null
+++ b/tensorflow/contrib/lite/models/smartreply/BUILD
@@ -0,0 +1,15 @@
+package(default_visibility = ["//visibility:public"])
+
+licenses(["notice"]) # Apache 2.0
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/contrib/lite/models/smartreply/ops/extract_feature.cc b/tensorflow/contrib/lite/models/smartreply/ops/extract_feature.cc
new file mode 100644
index 0000000000..1c422b659a
--- /dev/null
+++ b/tensorflow/contrib/lite/models/smartreply/ops/extract_feature.cc
@@ -0,0 +1,119 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Convert a list of strings to integers via hashing.
+// Input:
+// Input[0]: A list of ngrams. string[num of input]
+//
+// Output:
+// Output[0]: Hashed features. int32[num of input]
+// Output[1]: Weights. float[num of input]
+
+#include <algorithm>
+#include <map>
+#include "re2/re2.h"
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/string_util.h"
+#include <farmhash.h>
+
+namespace tflite {
+namespace ops {
+namespace custom {
+
+namespace extract {
+
+static const int kMaxDimension = 1000000;
+static const std::vector<string> kBlacklistNgram = {"<S>", "<E>", "<S> <E>"};
+
+bool Equals(const string& x, const tflite::StringRef& strref) {
+ if (strref.len != x.length()) {
+ return false;
+ }
+ if (strref.len > 0) {
+ int r = memcmp(strref.str, x.data(), strref.len);
+ return r == 0;
+ }
+ return true;
+}
+
+bool IsValidNgram(const tflite::StringRef& strref) {
+ for (const auto& s : kBlacklistNgram) {
+ if (Equals(s, strref)) {
+ return false;
+ }
+ }
+ return true;
+}
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ TfLiteIntArray* outputSize1 = TfLiteIntArrayCreate(1);
+ TfLiteIntArray* outputSize2 = TfLiteIntArrayCreate(1);
+ TfLiteTensor* input = GetInput(context, node, 0);
+ int dim = input->dims->data[0];
+ if (dim == 0) {
+ // TFLite non-string output should have size greater than 0.
+ dim = 1;
+ }
+ TF_LITE_ENSURE_EQ(context, input->type, kTfLiteString);
+ outputSize1->data[0] = dim;
+ outputSize2->data[0] = dim;
+ context->ResizeTensor(context, GetOutput(context, node, 0), outputSize1);
+ context->ResizeTensor(context, GetOutput(context, node, 1), outputSize2);
+ return kTfLiteOk;
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ TfLiteTensor* input = GetInput(context, node, 0);
+ int num_strings = tflite::GetStringCount(input);
+ TfLiteTensor* label = GetOutput(context, node, 0);
+ TfLiteTensor* weight = GetOutput(context, node, 1);
+
+ std::map<int64, int> feature_id_counts;
+ for (int i = 0; i < num_strings; i++) {
+ // Use fingerprint of feature name as id.
+ auto strref = tflite::GetString(input, i);
+ if (!IsValidNgram(strref)) {
+ label->data.i32[i] = 0;
+      weight->data.f[i] = 0.0f;
+ continue;
+ }
+
+ int64 feature_id =
+ ::util::Fingerprint64(strref.str, strref.len) % kMaxDimension;
+
+ label->data.i32[i] = static_cast<int32>(feature_id);
+ weight->data.f[i] =
+ std::count(strref.str, strref.str + strref.len, ' ') + 1;
+ }
+  // Explicitly emit a default result so that succeeding ops still run.
+  if (num_strings == 0) {
+    label->data.i32[0] = 0;
+    weight->data.f[0] = 0.0f;
+ }
+ return kTfLiteOk;
+}
+
+} // namespace extract
+
+TfLiteRegistration* Register_EXTRACT_FEATURES() {
+ static TfLiteRegistration r = {nullptr, nullptr, extract::Prepare,
+ extract::Eval};
+ return &r;
+}
+
+} // namespace custom
+} // namespace ops
+} // namespace tflite
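To make the per-ngram computation above concrete, here is a standalone sketch of how the op derives each label and weight (not part of the patch; it assumes farmhash's ::util::Fingerprint64 as included by the op, and the helper names are illustrative only):

#include <algorithm>
#include <cstdint>
#include <string>

#include <farmhash.h>

// For an ngram such as "Hi !":
//   label  = Fingerprint64(ngram) % 1000000  (kMaxDimension)
//   weight = token count, i.e. number of spaces + 1  -> 2 for "Hi !"
int32_t FeatureId(const std::string& ngram) {
  return static_cast<int32_t>(
      ::util::Fingerprint64(ngram.data(), ngram.size()) % 1000000);
}

float FeatureWeight(const std::string& ngram) {
  return std::count(ngram.begin(), ngram.end(), ' ') + 1;
}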
diff --git a/tensorflow/contrib/lite/models/smartreply/ops/extract_feature_test.cc b/tensorflow/contrib/lite/models/smartreply/ops/extract_feature_test.cc
new file mode 100644
index 0000000000..9b8676bab6
--- /dev/null
+++ b/tensorflow/contrib/lite/models/smartreply/ops/extract_feature_test.cc
@@ -0,0 +1,100 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <vector>
+
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+#include <farmhash.h>
+
+namespace tflite {
+
+namespace ops {
+namespace custom {
+TfLiteRegistration* Register_EXTRACT_FEATURES();
+
+namespace {
+
+using ::testing::ElementsAre;
+
+class ExtractFeatureOpModel : public SingleOpModel {
+ public:
+ explicit ExtractFeatureOpModel(const std::vector<string>& input) {
+ input_ = AddInput(TensorType_STRING);
+ signature_ = AddOutput(TensorType_INT32);
+ weight_ = AddOutput(TensorType_FLOAT32);
+
+ SetCustomOp("ExtractFeatures", {}, Register_EXTRACT_FEATURES);
+ BuildInterpreter({{static_cast<int>(input.size())}});
+ PopulateStringTensor(input_, input);
+ }
+
+ std::vector<int> GetSignature() { return ExtractVector<int>(signature_); }
+ std::vector<float> GetWeight() { return ExtractVector<float>(weight_); }
+
+ private:
+ int input_;
+ int signature_;
+ int weight_;
+};
+
+int CalcFeature(const string& str) {
+ return ::util::Fingerprint64(str) % 1000000;
+}
+
+TEST(ExtractFeatureOpTest, RegularInput) {
+ ExtractFeatureOpModel m({"<S>", "<S> Hi", "Hi", "Hi !", "!", "! <E>", "<E>"});
+ m.Invoke();
+ EXPECT_THAT(m.GetSignature(),
+ ElementsAre(0, CalcFeature("<S> Hi"), CalcFeature("Hi"),
+ CalcFeature("Hi !"), CalcFeature("!"),
+ CalcFeature("! <E>"), 0));
+ EXPECT_THAT(m.GetWeight(), ElementsAre(0, 2, 1, 2, 1, 2, 0));
+}
+
+TEST(ExtractFeatureOpTest, OneInput) {
+ ExtractFeatureOpModel m({"Hi"});
+ m.Invoke();
+ EXPECT_THAT(m.GetSignature(), ElementsAre(CalcFeature("Hi")));
+ EXPECT_THAT(m.GetWeight(), ElementsAre(1));
+}
+
+TEST(ExtractFeatureOpTest, ZeroInput) {
+ ExtractFeatureOpModel m({});
+ m.Invoke();
+ EXPECT_THAT(m.GetSignature(), ElementsAre(0));
+ EXPECT_THAT(m.GetWeight(), ElementsAre(0));
+}
+
+TEST(ExtractFeatureOpTest, AllBlacklistInput) {
+ ExtractFeatureOpModel m({"<S>", "<E>"});
+ m.Invoke();
+ EXPECT_THAT(m.GetSignature(), ElementsAre(0, 0));
+ EXPECT_THAT(m.GetWeight(), ElementsAre(0, 0));
+}
+
+} // namespace
+} // namespace custom
+} // namespace ops
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ // On Linux, add: tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/models/smartreply/ops/normalize.cc b/tensorflow/contrib/lite/models/smartreply/ops/normalize.cc
new file mode 100644
index 0000000000..d0dc2a35a7
--- /dev/null
+++ b/tensorflow/contrib/lite/models/smartreply/ops/normalize.cc
@@ -0,0 +1,105 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Normalize the string input.
+//
+// Input:
+// Input[0]: One sentence. string[1]
+//
+// Output:
+// Output[0]: Normalized sentence. string[1]
+//
+#include "absl/strings/ascii.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/strip.h"
+#include "re2/re2.h"
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/string_util.h"
+
+namespace tflite {
+namespace ops {
+namespace custom {
+
+namespace normalize {
+
+// Predictor transforms.
+const char kPunctuationsRegex[] = "[.*()\"]";
+
+const std::map<string, string>* kRegexTransforms =
+ new std::map<string, string>({
+ {"([^\\s]+)n't", "\\1 not"},
+ {"([^\\s]+)'nt", "\\1 not"},
+ {"([^\\s]+)'ll", "\\1 will"},
+ {"([^\\s]+)'re", "\\1 are"},
+ {"([^\\s]+)'ve", "\\1 have"},
+ {"i'm", "i am"},
+ });
+
+static const char kStartToken[] = "<S>";
+static const char kEndToken[] = "<E>";
+static const int32 kMaxInputChars = 300;
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ tflite::StringRef input = tflite::GetString(GetInput(context, node, 0), 0);
+
+ string result(absl::AsciiStrToLower(absl::string_view(input.str, input.len)));
+ absl::StripAsciiWhitespace(&result);
+ // Do not remove commas, semi-colons or colons from the sentences as they can
+ // indicate the beginning of a new clause.
+ RE2::GlobalReplace(&result, kPunctuationsRegex, "");
+ RE2::GlobalReplace(&result, "\\s('t|'nt|n't|'d|'ll|'s|'m|'ve|'re)([\\s,;:/])",
+ "\\1\\2");
+ RE2::GlobalReplace(&result, "\\s('t|'nt|n't|'d|'ll|'s|'m|'ve|'re)$", "\\1");
+ for (auto iter = kRegexTransforms->begin(); iter != kRegexTransforms->end();
+ iter++) {
+ RE2::GlobalReplace(&result, iter->first, iter->second);
+ }
+
+ // Treat questions & interjections as special cases.
+ RE2::GlobalReplace(&result, "([?])+", "\\1");
+ RE2::GlobalReplace(&result, "([!])+", "\\1");
+ RE2::GlobalReplace(&result, "([^?!]+)([?!])", "\\1 \\2 ");
+ RE2::GlobalReplace(&result, "([?!])([?!])", "\\1 \\2");
+
+ RE2::GlobalReplace(&result, "[\\s,:;\\-&'\"]+$", "");
+ RE2::GlobalReplace(&result, "^[\\s,:;\\-&'\"]+", "");
+ absl::StripAsciiWhitespace(&result);
+
+ // Add start and end token.
+ // Truncate input to maximum allowed size.
+ if (result.length() <= kMaxInputChars) {
+ absl::StrAppend(&result, " ", kEndToken);
+ } else {
+ result = result.substr(0, kMaxInputChars);
+ }
+ result = absl::StrCat(kStartToken, " ", result);
+
+ tflite::DynamicBuffer buf;
+ buf.AddString(result.data(), result.length());
+ buf.WriteToTensor(GetOutput(context, node, 0));
+ return kTfLiteOk;
+}
+
+} // namespace normalize
+
+TfLiteRegistration* Register_NORMALIZE() {
+ static TfLiteRegistration r = {nullptr, nullptr, nullptr, normalize::Eval};
+ return &r;
+}
+
+} // namespace custom
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/models/smartreply/ops/normalize_test.cc b/tensorflow/contrib/lite/models/smartreply/ops/normalize_test.cc
new file mode 100644
index 0000000000..4d35dba9a6
--- /dev/null
+++ b/tensorflow/contrib/lite/models/smartreply/ops/normalize_test.cc
@@ -0,0 +1,90 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <vector>
+
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+#include "tensorflow/contrib/lite/string_util.h"
+
+namespace tflite {
+
+namespace ops {
+namespace custom {
+TfLiteRegistration* Register_NORMALIZE();
+
+namespace {
+
+using ::testing::ElementsAreArray;
+
+class NormalizeOpModel : public SingleOpModel {
+ public:
+ explicit NormalizeOpModel(const string& input) {
+ input_ = AddInput(TensorType_STRING);
+ output_ = AddOutput(TensorType_STRING);
+
+ SetCustomOp("Normalize", {}, Register_NORMALIZE);
+ BuildInterpreter({{static_cast<int>(input.size())}});
+ PopulateStringTensor(input_, {input});
+ }
+
+ std::vector<string> GetStringOutput() {
+ TfLiteTensor* output = interpreter_->tensor(output_);
+ int num = GetStringCount(output);
+ std::vector<string> result(num);
+ for (int i = 0; i < num; i++) {
+ auto ref = GetString(output, i);
+ result[i] = string(ref.str, ref.len);
+ }
+ return result;
+ }
+
+ private:
+ int input_;
+ int output_;
+};
+
+TEST(NormalizeOpTest, RegularInput) {
+ NormalizeOpModel m("I'm good; you're welcome");
+ m.Invoke();
+ EXPECT_THAT(m.GetStringOutput(),
+ ElementsAreArray({"<S> i am good; you are welcome <E>"}));
+}
+
+TEST(NormalizeOpTest, OneInput) {
+ NormalizeOpModel m("Hi!!!!");
+ m.Invoke();
+ EXPECT_THAT(m.GetStringOutput(), ElementsAreArray({"<S> hi ! <E>"}));
+}
+
+TEST(NormalizeOpTest, EmptyInput) {
+ NormalizeOpModel m("");
+ m.Invoke();
+ EXPECT_THAT(m.GetStringOutput(), ElementsAreArray({"<S> <E>"}));
+}
+
+} // namespace
+} // namespace custom
+} // namespace ops
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ // On Linux, add: tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/models/smartreply/ops/predict.cc b/tensorflow/contrib/lite/models/smartreply/ops/predict.cc
new file mode 100644
index 0000000000..7b23adb990
--- /dev/null
+++ b/tensorflow/contrib/lite/models/smartreply/ops/predict.cc
@@ -0,0 +1,174 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Lookup projected hash signatures in Predictor model,
+// output predicted labels and weights in decreasing order.
+//
+// Input:
+// Input[0]: A list of hash signatures. int32[num of input]
+// Input[1]: Hash signature keys in the model. int32[keys of model]
+// Input[2]: Labels in the model. int32[keys of model, item per entry]
+// Input[3]: Weights in the model. float[keys of model, item per entry]
+//
+// Output:
+// Output[0]: Predicted labels. int32[num of output]
+// Output[1]: Predicted weights. float[num of output]
+//
+
+#include <algorithm>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/context.h"
+
+namespace tflite {
+namespace ops {
+namespace custom {
+
+namespace predict {
+
+struct PredictOption {
+ int32_t num_output;
+ float weight_threshold;
+
+ static PredictOption* Cast(void* ptr) {
+ return reinterpret_cast<PredictOption*>(ptr);
+ }
+};
+
+bool WeightGreater(const std::pair<int32_t, float>& a,
+ const std::pair<int32_t, float>& b) {
+ return a.second > b.second;
+}
+
+void* Init(TfLiteContext* context, const char* custom_option, size_t length) {
+ if (custom_option == nullptr || length != sizeof(PredictOption)) {
+ fprintf(stderr, "No Custom option set\n");
+ exit(1);
+ }
+ PredictOption* option = new PredictOption;
+ int offset = 0;
+ option->num_output =
+ *reinterpret_cast<const int32_t*>(custom_option + offset);
+ offset += sizeof(int32_t);
+ option->weight_threshold =
+ *reinterpret_cast<const float*>(custom_option + offset);
+ return reinterpret_cast<void*>(option);
+}
+
+void Free(TfLiteContext* context, void* buffer) {
+ delete PredictOption::Cast(buffer);
+}
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ TF_LITE_ENSURE_EQ(context, node->inputs->size, 4);
+ TF_LITE_ENSURE_EQ(context, node->outputs->size, 2);
+
+ TfLiteTensor* lookup = &context->tensors[node->inputs->data[0]];
+ TfLiteTensor* model_key = &context->tensors[node->inputs->data[1]];
+ TfLiteTensor* model_label = &context->tensors[node->inputs->data[2]];
+ TfLiteTensor* model_weight = &context->tensors[node->inputs->data[3]];
+ TF_LITE_ENSURE_EQ(context, lookup->type, kTfLiteInt32);
+ TF_LITE_ENSURE_EQ(context, model_key->type, kTfLiteInt32);
+ TF_LITE_ENSURE_EQ(context, model_label->type, kTfLiteInt32);
+ TF_LITE_ENSURE_EQ(context, model_weight->type, kTfLiteFloat32);
+ TF_LITE_ENSURE_EQ(context, lookup->dims->size, 1);
+ TF_LITE_ENSURE_EQ(context, model_key->dims->size, 1);
+ TF_LITE_ENSURE_EQ(context, model_label->dims->size, 2);
+ TF_LITE_ENSURE_EQ(context, model_weight->dims->size, 2);
+ TF_LITE_ENSURE_EQ(context, model_key->dims->data[0],
+ model_label->dims->data[0]);
+ TF_LITE_ENSURE_EQ(context, model_key->dims->data[0],
+ model_weight->dims->data[0]);
+ TF_LITE_ENSURE_EQ(context, model_label->dims->data[1],
+ model_weight->dims->data[1]);
+
+ PredictOption* option = PredictOption::Cast(node->user_data);
+ TfLiteTensor* output_label = &context->tensors[node->outputs->data[0]];
+ TfLiteTensor* output_weight = &context->tensors[node->outputs->data[1]];
+ TF_LITE_ENSURE_EQ(context, output_label->type, kTfLiteInt32);
+ TF_LITE_ENSURE_EQ(context, output_weight->type, kTfLiteFloat32);
+
+ TfLiteIntArray* label_size = TfLiteIntArrayCreate(1);
+ label_size->data[0] = option->num_output;
+ TfLiteIntArray* weight_size = TfLiteIntArrayCreate(1);
+ weight_size->data[0] = option->num_output;
+ TfLiteStatus status =
+ context->ResizeTensor(context, output_label, label_size);
+ if (status != kTfLiteOk) {
+ return status;
+ }
+ return context->ResizeTensor(context, output_weight, weight_size);
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ TfLiteTensor* lookup = &context->tensors[node->inputs->data[0]];
+ TfLiteTensor* model_key = &context->tensors[node->inputs->data[1]];
+ TfLiteTensor* model_label = &context->tensors[node->inputs->data[2]];
+ TfLiteTensor* model_weight = &context->tensors[node->inputs->data[3]];
+
+ // Aggregate by key
+ std::unordered_map<int32_t, float> aggregation;
+ const int num_input = lookup->dims->data[0];
+ const int num_rows = model_key->dims->data[0];
+ const int items = model_label->dims->data[1];
+ int* model_key_end = model_key->data.i32 + num_rows;
+
+ for (int i = 0; i < num_input; i++) {
+ int* ptr = std::lower_bound(model_key->data.i32, model_key_end,
+ lookup->data.i32[i]);
+ if (ptr != nullptr && ptr != model_key_end && *ptr == lookup->data.i32[i]) {
+ int idx = ptr - model_key->data.i32;
+ for (int j = 0; j < items; j++) {
+ aggregation[model_label->data.i32[idx * items + j]] +=
+ model_weight->data.f[idx * items + j] / num_input;
+ }
+ }
+ }
+
+ // Sort by value
+ std::vector<std::pair<int32_t, float>> sorted_labels(aggregation.begin(),
+ aggregation.end());
+ std::sort(sorted_labels.begin(), sorted_labels.end(), WeightGreater);
+
+ PredictOption* option = PredictOption::Cast(node->user_data);
+ TfLiteTensor* output_label = &context->tensors[node->outputs->data[0]];
+ TfLiteTensor* output_weight = &context->tensors[node->outputs->data[1]];
+ for (int i = 0; i < output_label->dims->data[0]; i++) {
+ if (i >= sorted_labels.size() ||
+ sorted_labels[i].second < option->weight_threshold) {
+      // Set -1 to avoid message id 0, which is reserved for backoff.
+ output_label->data.i32[i] = -1;
+ output_weight->data.f[i] = 0.0f;
+ } else {
+ output_label->data.i32[i] = sorted_labels[i].first;
+ output_weight->data.f[i] = sorted_labels[i].second;
+ }
+ }
+
+ return kTfLiteOk;
+}
+
+} // namespace predict
+
+TfLiteRegistration* Register_PREDICT() {
+ static TfLiteRegistration r = {predict::Init, predict::Free, predict::Prepare,
+ predict::Eval};
+ return &r;
+}
+
+} // namespace custom
+} // namespace ops
+} // namespace tflite
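The raw custom_option blob parsed by predict::Init above is simply a packed int32 (num_output) followed by a float (weight_threshold). A hedged sketch of how a caller could serialize it, mirroring the writeInt32/writeFloat32 helpers in the test that follows (it assumes the blob is produced on a machine with the same byte order; the helper name is illustrative only):

#include <cstdint>
#include <cstring>
#include <vector>

// Packs {num_output, weight_threshold} into the byte layout expected by
// predict::Init (sizeof(int32_t) + sizeof(float) bytes, in that order).
std::vector<uint8_t> PackPredictOption(int32_t num_output,
                                       float weight_threshold) {
  std::vector<uint8_t> raw(sizeof(int32_t) + sizeof(float));
  std::memcpy(raw.data(), &num_output, sizeof(int32_t));
  std::memcpy(raw.data() + sizeof(int32_t), &weight_threshold, sizeof(float));
  return raw;
}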
diff --git a/tensorflow/contrib/lite/models/smartreply/ops/predict_test.cc b/tensorflow/contrib/lite/models/smartreply/ops/predict_test.cc
new file mode 100644
index 0000000000..e97c58cbd1
--- /dev/null
+++ b/tensorflow/contrib/lite/models/smartreply/ops/predict_test.cc
@@ -0,0 +1,183 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <vector>
+
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+#include "tensorflow/contrib/lite/string_util.h"
+
+namespace tflite {
+
+namespace ops {
+namespace custom {
+TfLiteRegistration* Register_PREDICT();
+
+namespace {
+
+using ::testing::ElementsAreArray;
+
+class PredictOpModel : public SingleOpModel {
+ public:
+ PredictOpModel(std::initializer_list<int> input_signature_shape,
+ std::initializer_list<int> key_shape,
+ std::initializer_list<int> labelweight_shape, int num_output,
+ float threshold) {
+ input_signature_ = AddInput(TensorType_INT32);
+ model_key_ = AddInput(TensorType_INT32);
+ model_label_ = AddInput(TensorType_INT32);
+ model_weight_ = AddInput(TensorType_FLOAT32);
+ output_label_ = AddOutput(TensorType_INT32);
+ output_weight_ = AddOutput(TensorType_FLOAT32);
+
+ std::vector<uint8_t> predict_option;
+ writeInt32(num_output, &predict_option);
+ writeFloat32(threshold, &predict_option);
+ SetCustomOp("Predict", predict_option, Register_PREDICT);
+ BuildInterpreter({{input_signature_shape, key_shape, labelweight_shape,
+ labelweight_shape}});
+ }
+
+ void SetInputSignature(std::initializer_list<int> data) {
+ PopulateTensor<int>(input_signature_, data);
+ }
+
+ void SetModelKey(std::initializer_list<int> data) {
+ PopulateTensor<int>(model_key_, data);
+ }
+
+ void SetModelLabel(std::initializer_list<int> data) {
+ PopulateTensor<int>(model_label_, data);
+ }
+
+ void SetModelWeight(std::initializer_list<float> data) {
+ PopulateTensor<float>(model_weight_, data);
+ }
+
+ std::vector<int> GetLabel() { return ExtractVector<int>(output_label_); }
+ std::vector<float> GetWeight() {
+ return ExtractVector<float>(output_weight_);
+ }
+
+ void writeFloat32(float value, std::vector<uint8_t>* data) {
+ union {
+ float v;
+ uint8_t r[4];
+ } float_to_raw;
+ float_to_raw.v = value;
+ for (unsigned char i : float_to_raw.r) {
+ data->push_back(i);
+ }
+ }
+
+ void writeInt32(int32_t value, std::vector<uint8_t>* data) {
+ union {
+ int32_t v;
+ uint8_t r[4];
+ } int32_to_raw;
+ int32_to_raw.v = value;
+ for (unsigned char i : int32_to_raw.r) {
+ data->push_back(i);
+ }
+ }
+
+ private:
+ int input_signature_;
+ int model_key_;
+ int model_label_;
+ int model_weight_;
+ int output_label_;
+ int output_weight_;
+};
+
+TEST(PredictOpTest, AllLabelsAreValid) {
+ PredictOpModel m({4}, {5}, {5, 2}, 2, 0.0001);
+ m.SetInputSignature({1, 3, 7, 9});
+ m.SetModelKey({1, 2, 4, 6, 7});
+ m.SetModelLabel({11, 12, 11, 12, 11, 12, 11, 12, 11, 12});
+ m.SetModelWeight({0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2});
+ m.Invoke();
+ EXPECT_THAT(m.GetLabel(), ElementsAreArray({12, 11}));
+ EXPECT_THAT(m.GetWeight(), ElementsAreArray(ArrayFloatNear({0.1, 0.05})));
+}
+
+TEST(PredictOpTest, MoreLabelsThanRequired) {
+ PredictOpModel m({4}, {5}, {5, 2}, 1, 0.0001);
+ m.SetInputSignature({1, 3, 7, 9});
+ m.SetModelKey({1, 2, 4, 6, 7});
+ m.SetModelLabel({11, 12, 11, 12, 11, 12, 11, 12, 11, 12});
+ m.SetModelWeight({0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2});
+ m.Invoke();
+ EXPECT_THAT(m.GetLabel(), ElementsAreArray({12}));
+ EXPECT_THAT(m.GetWeight(), ElementsAreArray(ArrayFloatNear({0.1})));
+}
+
+TEST(PredictOpTest, OneLabelDoesNotPassThreshold) {
+ PredictOpModel m({4}, {5}, {5, 2}, 2, 0.07);
+ m.SetInputSignature({1, 3, 7, 9});
+ m.SetModelKey({1, 2, 4, 6, 7});
+ m.SetModelLabel({11, 12, 11, 12, 11, 12, 11, 12, 11, 12});
+ m.SetModelWeight({0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2});
+ m.Invoke();
+ EXPECT_THAT(m.GetLabel(), ElementsAreArray({12, -1}));
+ EXPECT_THAT(m.GetWeight(), ElementsAreArray(ArrayFloatNear({0.1, 0})));
+}
+
+TEST(PredictOpTest, NoneLabelPassThreshold) {
+ PredictOpModel m({4}, {5}, {5, 2}, 2, 0.6);
+ m.SetInputSignature({1, 3, 7, 9});
+ m.SetModelKey({1, 2, 4, 6, 7});
+ m.SetModelLabel({11, 12, 11, 12, 11, 12, 11, 12, 11, 12});
+ m.SetModelWeight({0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2});
+ m.Invoke();
+ EXPECT_THAT(m.GetLabel(), ElementsAreArray({-1, -1}));
+ EXPECT_THAT(m.GetWeight(), ElementsAreArray(ArrayFloatNear({0, 0})));
+}
+
+TEST(PredictOpTest, OnlyOneLabelGenerated) {
+ PredictOpModel m({4}, {5}, {5, 2}, 2, 0.0001);
+ m.SetInputSignature({1, 3, 7, 9});
+ m.SetModelKey({1, 2, 4, 6, 7});
+ m.SetModelLabel({11, 0, 11, 0, 11, 0, 11, 0, 11, 0});
+ m.SetModelWeight({0.1, 0, 0.1, 0, 0.1, 0, 0.1, 0, 0.1, 0});
+ m.Invoke();
+ EXPECT_THAT(m.GetLabel(), ElementsAreArray({11, -1}));
+ EXPECT_THAT(m.GetWeight(), ElementsAreArray(ArrayFloatNear({0.05, 0})));
+}
+
+TEST(PredictOpTest, NoLabelGenerated) {
+ PredictOpModel m({4}, {5}, {5, 2}, 2, 0.0001);
+ m.SetInputSignature({5, 3, 7, 9});
+ m.SetModelKey({1, 2, 4, 6, 7});
+ m.SetModelLabel({11, 0, 11, 0, 11, 0, 11, 0, 0, 0});
+ m.SetModelWeight({0.1, 0, 0.1, 0, 0.1, 0, 0.1, 0, 0, 0});
+ m.Invoke();
+ EXPECT_THAT(m.GetLabel(), ElementsAreArray({-1, -1}));
+ EXPECT_THAT(m.GetWeight(), ElementsAreArray(ArrayFloatNear({0, 0})));
+}
+
+} // namespace
+} // namespace custom
+} // namespace ops
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ // On Linux, add: tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/models/smartreply/predictor.cc b/tensorflow/contrib/lite/models/smartreply/predictor.cc
new file mode 100644
index 0000000000..a28222213e
--- /dev/null
+++ b/tensorflow/contrib/lite/models/smartreply/predictor.cc
@@ -0,0 +1,116 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/models/smartreply/predictor.h"
+
+#include "absl/strings/str_split.h"
+#include "re2/re2.h"
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/model.h"
+#include "tensorflow/contrib/lite/string_util.h"
+#include "tensorflow/contrib/lite/tools/mutable_op_resolver.h"
+
+void RegisterSelectedOps(::tflite::MutableOpResolver* resolver);
+
+namespace tflite {
+namespace custom {
+namespace smartreply {
+
+// Split sentence into segments (using punctuation).
+std::vector<string> SplitSentence(const string& input) {
+ string result(input);
+
+ RE2::GlobalReplace(&result, "([?.!,])+", " \\1");
+ RE2::GlobalReplace(&result, "([?.!,])+\\s+", "\\1\t");
+ RE2::GlobalReplace(&result, "[ ]+", " ");
+ RE2::GlobalReplace(&result, "\t+$", "");
+
+ return strings::Split(result, '\t');
+}
+
+// Predict with TfLite model.
+void ExecuteTfLite(const string& sentence, ::tflite::Interpreter* interpreter,
+ std::map<string, float>* response_map) {
+ {
+ TfLiteTensor* input = interpreter->tensor(interpreter->inputs()[0]);
+ tflite::DynamicBuffer buf;
+ buf.AddString(sentence.data(), sentence.length());
+ buf.WriteToTensor(input);
+ interpreter->AllocateTensors();
+
+ interpreter->Invoke();
+
+ TfLiteTensor* messages = interpreter->tensor(interpreter->outputs()[0]);
+ TfLiteTensor* confidence = interpreter->tensor(interpreter->outputs()[1]);
+
+ for (int i = 0; i < confidence->dims->data[0]; i++) {
+ float weight = confidence->data.f[i];
+ auto response_text = tflite::GetString(messages, i);
+ if (response_text.len > 0) {
+ (*response_map)[string(response_text.str, response_text.len)] += weight;
+ }
+ }
+ }
+}
+
+void GetSegmentPredictions(
+ const std::vector<string>& input, const ::tflite::FlatBufferModel& model,
+ const SmartReplyConfig& config,
+ std::vector<PredictorResponse>* predictor_responses) {
+ // Initialize interpreter
+ std::unique_ptr<::tflite::Interpreter> interpreter;
+ ::tflite::MutableOpResolver resolver;
+ RegisterSelectedOps(&resolver);
+ ::tflite::InterpreterBuilder(model, resolver)(&interpreter);
+
+ if (!model.initialized()) {
+ fprintf(stderr, "Failed to mmap model \n");
+ return;
+ }
+
+ // Execute Tflite Model
+ std::map<string, float> response_map;
+ std::vector<string> sentences;
+ for (const string& str : input) {
+ std::vector<string> splitted_str = SplitSentence(str);
+ sentences.insert(sentences.end(), splitted_str.begin(), splitted_str.end());
+ }
+ for (const auto& sentence : sentences) {
+ ExecuteTfLite(sentence, interpreter.get(), &response_map);
+ }
+
+ // Generate the result.
+ for (const auto& iter : response_map) {
+ PredictorResponse prediction(iter.first, iter.second);
+ predictor_responses->emplace_back(prediction);
+ }
+ std::sort(predictor_responses->begin(), predictor_responses->end(),
+ [](const PredictorResponse& a, const PredictorResponse& b) {
+ return a.GetScore() > b.GetScore();
+ });
+
+ // Add backoff response.
+ for (const string& backoff : config.backoff_responses) {
+ if (predictor_responses->size() >= config.num_response) {
+ break;
+ }
+ predictor_responses->push_back({backoff, config.backoff_confidence});
+ }
+}
+
+} // namespace smartreply
+} // namespace custom
+} // namespace tflite
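As a quick illustration of the segmentation performed by SplitSentence above (a sketch of expected behavior as it would appear within this namespace, not a separate test):

// Punctuation is padded with a leading space, then punctuation followed by
// whitespace becomes a tab-delimited segment boundary before the split:
//   SplitSentence("Hello! How are you?") -> {"Hello !", "How are you ?"}
std::vector<string> segments = SplitSentence("Hello! How are you?");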
diff --git a/tensorflow/contrib/lite/models/smartreply/predictor.h b/tensorflow/contrib/lite/models/smartreply/predictor.h
new file mode 100644
index 0000000000..3b9a2b32e1
--- /dev/null
+++ b/tensorflow/contrib/lite/models/smartreply/predictor.h
@@ -0,0 +1,80 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_MODELS_SMARTREPLY_PREDICTOR_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_MODELS_SMARTREPLY_PREDICTOR_H_
+
+#include <string>
+#include <vector>
+
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace custom {
+namespace smartreply {
+
+const int kDefaultNumResponse = 10;
+const float kDefaultBackoffConfidence = 1e-4;
+
+class PredictorResponse;
+struct SmartReplyConfig;
+
+// Given a list of strings as input, predict responses with a TfLite model.
+// When config.backoff_responses is not empty, predictor_responses will be
+// padded with messages from the backoff responses.
+void GetSegmentPredictions(const std::vector<string>& input,
+ const ::tflite::FlatBufferModel& model,
+ const SmartReplyConfig& config,
+ std::vector<PredictorResponse>* predictor_responses);
+
+// Data object used to hold a single predictor response.
+// It holds the response text and its confidence score.
+class PredictorResponse {
+ public:
+ PredictorResponse(const string& response_text, float score) {
+ response_text_ = response_text;
+ prediction_score_ = score;
+ }
+
+ // Accessor methods.
+ const string& GetText() const { return response_text_; }
+ float GetScore() const { return prediction_score_; }
+
+ private:
+ string response_text_ = "";
+ float prediction_score_ = 0.0;
+};
+
+// Configurations for SmartReply.
+struct SmartReplyConfig {
+ // Maximum responses to return.
+ int num_response;
+ // Default confidence for backoff responses.
+ float backoff_confidence;
+  // Backoff responses are used when the predicted responses cannot fill the
+  // list up to num_response.
+  const std::vector<string>& backoff_responses;
+
+  SmartReplyConfig(const std::vector<string>& backoff_responses)
+      : num_response(kDefaultNumResponse),
+        backoff_confidence(kDefaultBackoffConfidence),
+        backoff_responses(backoff_responses) {}
+};
+
+} // namespace smartreply
+} // namespace custom
+} // namespace tflite
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_MODELS_SMARTREPLY_PREDICTOR_H_
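A minimal caller of this API might look like the sketch below (not part of the patch; the model path and the Demo function are hypothetical, and `string` is assumed to alias std::string as elsewhere in this codebase). Note that SmartReplyConfig only holds a reference to the backoff list, so the caller must keep that vector alive while the config is in use:

#include <cstdio>
#include <string>
#include <vector>

#include "tensorflow/contrib/lite/model.h"
#include "tensorflow/contrib/lite/models/smartreply/predictor.h"

void Demo() {
  auto model = tflite::FlatBufferModel::BuildFromFile(
      "/tmp/smartreply_ondevice_model.bin");  // hypothetical path
  if (!model) return;

  // Keep the backoff list alive as long as the config is used; the config
  // stores a reference, not a copy.
  std::vector<std::string> backoff = {"Ok"};
  tflite::custom::smartreply::SmartReplyConfig config(backoff);

  std::vector<tflite::custom::smartreply::PredictorResponse> responses;
  tflite::custom::smartreply::GetSegmentPredictions(
      {"How are you?"}, *model, config, &responses);

  for (const auto& r : responses) {
    printf("%s (%.3f)\n", r.GetText().c_str(), r.GetScore());
  }
}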
diff --git a/tensorflow/contrib/lite/models/smartreply/predictor_test.cc b/tensorflow/contrib/lite/models/smartreply/predictor_test.cc
new file mode 100644
index 0000000000..2fa9923bc9
--- /dev/null
+++ b/tensorflow/contrib/lite/models/smartreply/predictor_test.cc
@@ -0,0 +1,150 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/models/smartreply/predictor.h"
+
+#include <fstream>
+#include <unordered_set>
+
+#include "base/logging.h"
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_split.h"
+#include "tensorflow/contrib/lite/models/test_utils.h"
+
+namespace tflite {
+namespace custom {
+namespace smartreply {
+namespace {
+
+const char kModelName[] = "smartreply_ondevice_model.bin";
+const char kSamples[] = "smartreply_samples.tsv";
+
+MATCHER_P(IncludeAnyResponesIn, expected_response, "contains the response") {
+ bool has_expected_response = false;
+ for (const auto &item : *arg) {
+ const string &response = item.GetText();
+ if (expected_response.find(response) != expected_response.end()) {
+ has_expected_response = true;
+ break;
+ }
+ }
+ return has_expected_response;
+}
+
+class PredictorTest : public ::testing::Test {
+ protected:
+ PredictorTest() {
+ model_ = tflite::FlatBufferModel::BuildFromFile(
+ StrCat(TestDataPath(), "/", kModelName).c_str());
+ CHECK(model_);
+ }
+ ~PredictorTest() override {}
+
+ std::unique_ptr<::tflite::FlatBufferModel> model_;
+};
+
+TEST_F(PredictorTest, GetSegmentPredictions) {
+ std::vector<PredictorResponse> predictions;
+
+ GetSegmentPredictions({"Welcome"}, *model_, /*config=*/{{}}, &predictions);
+ EXPECT_GT(predictions.size(), 0);
+
+ float max = 0;
+ for (const auto &item : predictions) {
+ LOG(INFO) << "Response: " << item.GetText();
+ if (item.GetScore() > max) {
+ max = item.GetScore();
+ }
+ }
+
+ EXPECT_GT(max, 0.3);
+ EXPECT_THAT(
+ &predictions,
+ IncludeAnyResponesIn(std::unordered_set<string>({"Thanks very much"})));
+}
+
+TEST_F(PredictorTest, TestTwoSentences) {
+ std::vector<PredictorResponse> predictions;
+
+ GetSegmentPredictions({"Hello", "How are you?"}, *model_, /*config=*/{{}},
+ &predictions);
+ EXPECT_GT(predictions.size(), 0);
+
+ float max = 0;
+ for (const auto &item : predictions) {
+ LOG(INFO) << "Response: " << item.GetText();
+ if (item.GetScore() > max) {
+ max = item.GetScore();
+ }
+ }
+
+ EXPECT_GT(max, 0.3);
+ EXPECT_THAT(&predictions, IncludeAnyResponesIn(std::unordered_set<string>(
+ {"Hi, how are you doing?"})));
+}
+
+TEST_F(PredictorTest, TestBackoff) {
+ std::vector<PredictorResponse> predictions;
+
+ GetSegmentPredictions({"你好"}, *model_, /*config=*/{{}}, &predictions);
+ EXPECT_EQ(predictions.size(), 0);
+
+ // Backoff responses are returned in order.
+ GetSegmentPredictions({"你好"}, *model_, /*config=*/{{"Yes", "Ok"}},
+ &predictions);
+ EXPECT_EQ(predictions.size(), 2);
+ EXPECT_EQ(predictions[0].GetText(), "Yes");
+ EXPECT_EQ(predictions[1].GetText(), "Ok");
+}
+
+TEST_F(PredictorTest, BatchTest) {
+ int total_items = 0;
+ int total_responses = 0;
+ int total_triggers = 0;
+
+ string line;
+ std::ifstream fin(StrCat(TestDataPath(), "/", kSamples));
+ while (std::getline(fin, line)) {
+ const std::vector<string> &fields = strings::Split(line, '\t');
+ if (fields.empty()) {
+ continue;
+ }
+
+ // Parse sample file and predict
+ const string &msg = fields[0];
+ std::vector<PredictorResponse> predictions;
+ GetSegmentPredictions({msg}, *model_, /*config=*/{{}}, &predictions);
+
+ // Validate response and generate stats.
+ total_items++;
+ total_responses += predictions.size();
+ if (!predictions.empty()) {
+ total_triggers++;
+ }
+ EXPECT_THAT(&predictions, IncludeAnyResponesIn(std::unordered_set<string>(
+ fields.begin() + 1, fields.end())));
+ }
+
+ LOG(INFO) << "Responses: " << total_responses << " / " << total_items;
+ LOG(INFO) << "Triggers: " << total_triggers << " / " << total_items;
+ EXPECT_EQ(total_triggers, total_items);
+}
+
+} // namespace
+} // namespace smartreply
+} // namespace custom
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/models/speech_hotword_model_test.cc b/tensorflow/contrib/lite/models/speech_hotword_model_test.cc
new file mode 100644
index 0000000000..f5d1f436bc
--- /dev/null
+++ b/tensorflow/contrib/lite/models/speech_hotword_model_test.cc
@@ -0,0 +1,115 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// Unit test for speech Hotword model using TFLite Ops.
+
+#include <string.h>
+
+#include <memory>
+#include <string>
+
+#include "base/logging.h"
+#include "file/base/path.h"
+#include "testing/base/public/googletest.h"
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/model.h"
+#include "tensorflow/contrib/lite/models/test_utils.h"
+
+namespace tflite {
+namespace models {
+
+void RunTest(int model_input_tensor, int svdf_layer_state_tensor,
+ int model_output_tensor, const string& model_name,
+ const string& golden_in_name, const string& golden_out_name) {
+ // Read the model.
+ string tflite_file_path = file::JoinPath(TestDataPath(), model_name);
+ auto model = FlatBufferModel::BuildFromFile(tflite_file_path.c_str());
+ CHECK(model) << "Failed to read model from file " << tflite_file_path;
+
+ // Initialize the interpreter.
+ ops::builtin::BuiltinOpResolver builtins;
+ std::unique_ptr<Interpreter> interpreter;
+ InterpreterBuilder(*model, builtins)(&interpreter);
+ CHECK(interpreter != nullptr);
+ interpreter->AllocateTensors();
+
+ // Reset the SVDF layer state.
+ memset(interpreter->tensor(svdf_layer_state_tensor)->data.raw, 0,
+ interpreter->tensor(svdf_layer_state_tensor)->bytes);
+
+ // Load the input frames.
+ Frames input_frames;
+ const string input_file_path = file::JoinPath(TestDataPath(), golden_in_name);
+ ReadFrames(input_file_path, &input_frames);
+
+ // Load the golden output results.
+ Frames output_frames;
+ const string output_file_path =
+ file::JoinPath(TestDataPath(), golden_out_name);
+ ReadFrames(output_file_path, &output_frames);
+
+ const int speech_batch_size =
+ interpreter->tensor(model_input_tensor)->dims->data[0];
+ const int speech_input_size =
+ interpreter->tensor(model_input_tensor)->dims->data[1];
+ const int speech_output_size =
+ interpreter->tensor(model_output_tensor)->dims->data[1];
+ const int input_sequence_size =
+ input_frames[0].size() / (speech_input_size * speech_batch_size);
+ float* input_ptr = interpreter->tensor(model_input_tensor)->data.f;
+ float* output_ptr = interpreter->tensor(model_output_tensor)->data.f;
+
+  // The first layer (SVDF) input size is 40 (speech_input_size). Each speech
+  // input frame for this model is 1280 floats, which is fed to the input in a
+  // sequence of 32 steps (input_sequence_size).
+ for (int i = 0; i < TestInputSize(input_frames); i++) {
+ int frame_ptr = 0;
+ for (int s = 0; s < input_sequence_size; s++) {
+ for (int k = 0; k < speech_input_size * speech_batch_size; k++) {
+ input_ptr[k] = input_frames[i][frame_ptr++];
+ }
+ interpreter->Invoke();
+ }
+    // After the whole frame (1280 floats) is fed, check that the output frame
+    // matches the golden output frame.
+ for (int k = 0; k < speech_output_size; k++) {
+ ASSERT_NEAR(output_ptr[k], output_frames[i][k], 1e-5);
+ }
+ }
+}
+
+TEST(SpeechHotword, OkGoogleTestRank1) {
+ constexpr int kModelInputTensor = 0;
+ constexpr int kSvdfLayerStateTensor = 4;
+ constexpr int kModelOutputTensor = 18;
+
+ RunTest(kModelInputTensor, kSvdfLayerStateTensor, kModelOutputTensor,
+ "speech_hotword_model_rank1.tflite", "speech_hotword_model_in.csv",
+ "speech_hotword_model_out_rank1.csv");
+}
+
+TEST(SpeechHotword, OkGoogleTestRank2) {
+ constexpr int kModelInputTensor = 17;
+ constexpr int kSvdfLayerStateTensor = 1;
+ constexpr int kModelOutputTensor = 18;
+ RunTest(kModelInputTensor, kSvdfLayerStateTensor, kModelOutputTensor,
+ "speech_hotword_model_rank2.tflite", "speech_hotword_model_in.csv",
+ "speech_hotword_model_out_rank2.csv");
+}
+
+} // namespace models
+} // namespace tflite
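For reference, the sequence arithmetic used in RunTest above, restated as a sketch (the 1280-float frame and 40-float SVDF input come from the comment in the test; a batch size of 1 is assumed for the rank-1 model):

// Each golden input frame holds 1280 floats. With an SVDF input of 40 floats
// and a batch size of 1, one frame is consumed in 1280 / (40 * 1) = 32
// successive Invoke() calls, matching input_sequence_size in RunTest.
constexpr int kFrameFloats = 1280;
constexpr int kSvdfInputSize = 40;
constexpr int kBatchSize = 1;
static_assert(kFrameFloats / (kSvdfInputSize * kBatchSize) == 32,
              "one frame is fed over 32 invocations");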
diff --git a/tensorflow/contrib/lite/models/speech_speakerid_model_test.cc b/tensorflow/contrib/lite/models/speech_speakerid_model_test.cc
new file mode 100644
index 0000000000..687cfab0b2
--- /dev/null
+++ b/tensorflow/contrib/lite/models/speech_speakerid_model_test.cc
@@ -0,0 +1,114 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// Unit test for speech SpeakerId model using TFLite Ops.
+
+#include <string.h>
+
+#include <memory>
+#include <string>
+
+#include "base/logging.h"
+#include "file/base/path.h"
+#include "testing/base/public/googletest.h"
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/model.h"
+#include "tensorflow/contrib/lite/models/test_utils.h"
+#include "tensorflow/contrib/lite/tools/mutable_op_resolver.h"
+
+void RegisterSelectedOps(::tflite::MutableOpResolver* resolver);
+
+namespace tflite {
+namespace models {
+
+constexpr int kModelInputTensor = 0;
+constexpr int kLstmLayer1OutputStateTensor = 19;
+constexpr int kLstmLayer1CellStateTensor = 20;
+constexpr int kLstmLayer2OutputStateTensor = 40;
+constexpr int kLstmLayer2CellStateTensor = 41;
+constexpr int kLstmLayer3OutputStateTensor = 61;
+constexpr int kLstmLayer3CellStateTensor = 62;
+constexpr int kModelOutputTensor = 66;
+
+TEST(SpeechSpeakerId, OkGoogleTest) {
+ // Read the model.
+ string tflite_file_path =
+ file::JoinPath(TestDataPath(), "speech_speakerid_model.tflite");
+ auto model = FlatBufferModel::BuildFromFile(tflite_file_path.c_str());
+ CHECK(model) << "Failed to read model from file " << tflite_file_path;
+
+ // Initialize the interpreter.
+ ::tflite::MutableOpResolver resolver;
+ RegisterSelectedOps(&resolver);
+ std::unique_ptr<Interpreter> interpreter;
+ InterpreterBuilder(*model, resolver)(&interpreter);
+ CHECK(interpreter != nullptr);
+ interpreter->AllocateTensors();
+
+ // Load the input frames.
+ Frames input_frames;
+ const string input_file_path =
+ file::JoinPath(TestDataPath(), "speech_speakerid_model_in.csv");
+ ReadFrames(input_file_path, &input_frames);
+
+ // Load the golden output results.
+ Frames output_frames;
+ const string output_file_path =
+ file::JoinPath(TestDataPath(), "speech_speakerid_model_out.csv");
+ ReadFrames(output_file_path, &output_frames);
+
+ const int speech_batch_size =
+ interpreter->tensor(kModelInputTensor)->dims->data[0];
+ const int speech_input_size =
+ interpreter->tensor(kModelInputTensor)->dims->data[1];
+ const int speech_output_size =
+ interpreter->tensor(kModelOutputTensor)->dims->data[1];
+
+ float* input_ptr = interpreter->tensor(kModelInputTensor)->data.f;
+ float* output_ptr = interpreter->tensor(kModelOutputTensor)->data.f;
+
+ // Clear the LSTM state for layers.
+ memset(interpreter->tensor(kLstmLayer1OutputStateTensor)->data.raw, 0,
+ interpreter->tensor(kLstmLayer1OutputStateTensor)->bytes);
+ memset(interpreter->tensor(kLstmLayer1CellStateTensor)->data.raw, 0,
+ interpreter->tensor(kLstmLayer1CellStateTensor)->bytes);
+
+ memset(interpreter->tensor(kLstmLayer2OutputStateTensor)->data.raw, 0,
+ interpreter->tensor(kLstmLayer2OutputStateTensor)->bytes);
+ memset(interpreter->tensor(kLstmLayer2CellStateTensor)->data.raw, 0,
+ interpreter->tensor(kLstmLayer2CellStateTensor)->bytes);
+
+ memset(interpreter->tensor(kLstmLayer3OutputStateTensor)->data.raw, 0,
+ interpreter->tensor(kLstmLayer3OutputStateTensor)->bytes);
+ memset(interpreter->tensor(kLstmLayer3CellStateTensor)->data.raw, 0,
+ interpreter->tensor(kLstmLayer3CellStateTensor)->bytes);
+ for (int i = 0; i < input_frames.size(); i++) {
+ // Feed the input to model.
+ int frame_ptr = 0;
+ for (int k = 0; k < speech_input_size * speech_batch_size; k++) {
+ input_ptr[k] = input_frames[i][frame_ptr++];
+ }
+ // Run the model.
+ interpreter->Invoke();
+ // Validate the output.
+ for (int k = 0; k < speech_output_size; k++) {
+ ASSERT_NEAR(output_ptr[k], output_frames[i][k], 1e-5);
+ }
+ }
+}
+
+} // namespace models
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/models/speech_terse_am_model_test.cc b/tensorflow/contrib/lite/models/speech_terse_am_model_test.cc
new file mode 100644
index 0000000000..30d89a1354
--- /dev/null
+++ b/tensorflow/contrib/lite/models/speech_terse_am_model_test.cc
@@ -0,0 +1,127 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// Unit test for speech TERSE AM model using TFLite Ops.
+
+#include <string.h>
+
+#include <memory>
+#include <string>
+
+#include "base/logging.h"
+#include "file/base/path.h"
+#include "testing/base/public/googletest.h"
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/model.h"
+#include "tensorflow/contrib/lite/models/test_utils.h"
+
+namespace tflite {
+namespace models {
+
+constexpr int kModelInputTensor = 0;
+constexpr int kLstmLayer1OutputStateTensor = 19;
+constexpr int kLstmLayer1CellStateTensor = 20;
+constexpr int kLstmLayer2OutputStateTensor = 40;
+constexpr int kLstmLayer2CellStateTensor = 41;
+constexpr int kLstmLayer3OutputStateTensor = 61;
+constexpr int kLstmLayer3CellStateTensor = 62;
+constexpr int kLstmLayer4OutputStateTensor = 82;
+constexpr int kLstmLayer4CellStateTensor = 83;
+constexpr int kLstmLayer5OutputStateTensor = 103;
+constexpr int kLstmLayer5CellStateTensor = 104;
+constexpr int kModelOutputTensor = 109;
+
+TEST(SpeechTerseAm, RandomIOTest) {
+ // Read the model.
+ string tflite_file_path =
+ file::JoinPath(TestDataPath(), "speech_terse_am_model.tflite");
+ auto model = FlatBufferModel::BuildFromFile(tflite_file_path.c_str());
+ CHECK(model) << "Failed to mmap model " << tflite_file_path;
+
+ // Initialize the interpreter.
+ ops::builtin::BuiltinOpResolver builtins;
+ std::unique_ptr<Interpreter> interpreter;
+ InterpreterBuilder(*model, builtins)(&interpreter);
+ CHECK(interpreter != nullptr);
+ interpreter->AllocateTensors();
+
+ // Load the input frames.
+ Frames input_frames;
+ const string input_file_path =
+ file::JoinPath(TestDataPath(), "speech_terse_am_model_in.csv");
+ ReadFrames(input_file_path, &input_frames);
+
+ // Load the golden output results.
+ Frames output_frames;
+ const string output_file_path =
+ file::JoinPath(TestDataPath(), "speech_terse_am_model_out.csv");
+ ReadFrames(output_file_path, &output_frames);
+
+ const int speech_batch_size =
+ interpreter->tensor(kModelInputTensor)->dims->data[0];
+ const int speech_input_size =
+ interpreter->tensor(kModelInputTensor)->dims->data[1];
+ const int speech_output_size =
+ interpreter->tensor(kModelOutputTensor)->dims->data[1];
+
+ float* input_ptr = interpreter->tensor(kModelInputTensor)->data.f;
+ float* output_ptr = interpreter->tensor(kModelOutputTensor)->data.f;
+
+ // Clear the LSTM state for layers.
+ memset(interpreter->tensor(kLstmLayer1OutputStateTensor)->data.raw, 0,
+ interpreter->tensor(kLstmLayer1OutputStateTensor)->bytes);
+ memset(interpreter->tensor(kLstmLayer1CellStateTensor)->data.raw, 0,
+ interpreter->tensor(kLstmLayer1CellStateTensor)->bytes);
+
+ memset(interpreter->tensor(kLstmLayer2OutputStateTensor)->data.raw, 0,
+ interpreter->tensor(kLstmLayer2OutputStateTensor)->bytes);
+ memset(interpreter->tensor(kLstmLayer2CellStateTensor)->data.raw, 0,
+ interpreter->tensor(kLstmLayer2CellStateTensor)->bytes);
+
+ memset(interpreter->tensor(kLstmLayer3OutputStateTensor)->data.raw, 0,
+ interpreter->tensor(kLstmLayer3OutputStateTensor)->bytes);
+ memset(interpreter->tensor(kLstmLayer3CellStateTensor)->data.raw, 0,
+ interpreter->tensor(kLstmLayer3CellStateTensor)->bytes);
+
+ memset(interpreter->tensor(kLstmLayer4OutputStateTensor)->data.raw, 0,
+ interpreter->tensor(kLstmLayer4OutputStateTensor)->bytes);
+ memset(interpreter->tensor(kLstmLayer4CellStateTensor)->data.raw, 0,
+ interpreter->tensor(kLstmLayer4CellStateTensor)->bytes);
+
+ memset(interpreter->tensor(kLstmLayer5OutputStateTensor)->data.raw, 0,
+ interpreter->tensor(kLstmLayer5OutputStateTensor)->bytes);
+ memset(interpreter->tensor(kLstmLayer5CellStateTensor)->data.raw, 0,
+ interpreter->tensor(kLstmLayer5CellStateTensor)->bytes);
+
+ for (int i = 0; i < input_frames.size(); i++) {
+ // Feed the input to model.
+ int frame_ptr = 0;
+ for (int k = 0; k < speech_input_size * speech_batch_size; k++) {
+ input_ptr[k] = input_frames[i][frame_ptr++];
+ }
+ // Run the model.
+ interpreter->Invoke();
+ // Validate the output.
+ for (int k = 0; k < speech_output_size; k++) {
+ ASSERT_NEAR(output_ptr[k], output_frames[i][k], 5.2e-4);
+ }
+ }
+}
+
+} // namespace models
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/models/speech_tts_model_test.cc b/tensorflow/contrib/lite/models/speech_tts_model_test.cc
new file mode 100644
index 0000000000..e6f2673a42
--- /dev/null
+++ b/tensorflow/contrib/lite/models/speech_tts_model_test.cc
@@ -0,0 +1,116 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// Unit test for speech TTS model using TFLite Ops.
+
+#include <string.h>
+
+#include <memory>
+#include <string>
+
+#include "base/logging.h"
+#include "file/base/path.h"
+#include "testing/base/public/googletest.h"
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/model.h"
+#include "tensorflow/contrib/lite/models/test_utils.h"
+
+namespace tflite {
+namespace models {
+
+constexpr int kModelInputTensor = 0;
+constexpr int kLstmLayer1OutputStateTensor = 25;
+constexpr int kLstmLayer1CellStateTensor = 26;
+constexpr int kLstmLayer2OutputStateTensor = 46;
+constexpr int kLstmLayer2CellStateTensor = 47;
+constexpr int kLstmLayer3OutputStateTensor = 67;
+constexpr int kLstmLayer3CellStateTensor = 68;
+constexpr int kRnnLayerHiddenStateTensor = 73;
+constexpr int kModelOutputTensor = 74;
+
+TEST(SpeechTTS, RandomIOTest) {
+ // Read the model.
+ string tflite_file_path =
+ file::JoinPath(TestDataPath(), "speech_tts_model.tflite");
+ auto model = FlatBufferModel::BuildFromFile(tflite_file_path.c_str());
+ CHECK(model) << "Failed to mmap model " << tflite_file_path;
+
+ // Initialize the interpreter.
+ ops::builtin::BuiltinOpResolver builtins;
+ std::unique_ptr<Interpreter> interpreter;
+ InterpreterBuilder(*model, builtins)(&interpreter);
+ CHECK(interpreter != nullptr);
+ interpreter->AllocateTensors();
+
+ // Load the input frames.
+ Frames input_frames;
+ const string input_file_path =
+ file::JoinPath(TestDataPath(), "speech_tts_model_in.csv");
+ ReadFrames(input_file_path, &input_frames);
+
+ // Load the golden output results.
+ Frames output_frames;
+ const string output_file_path =
+ file::JoinPath(TestDataPath(), "speech_tts_model_out.csv");
+ ReadFrames(output_file_path, &output_frames);
+
+ const int speech_batch_size =
+ interpreter->tensor(kModelInputTensor)->dims->data[0];
+ const int speech_input_size =
+ interpreter->tensor(kModelInputTensor)->dims->data[1];
+ const int speech_output_size =
+ interpreter->tensor(kModelOutputTensor)->dims->data[1];
+
+ float* input_ptr = interpreter->tensor(kModelInputTensor)->data.f;
+ float* output_ptr = interpreter->tensor(kModelOutputTensor)->data.f;
+
+ // Clear the LSTM state for layers.
+ memset(interpreter->tensor(kLstmLayer1OutputStateTensor)->data.raw, 0,
+ interpreter->tensor(kLstmLayer1OutputStateTensor)->bytes);
+ memset(interpreter->tensor(kLstmLayer1CellStateTensor)->data.raw, 0,
+ interpreter->tensor(kLstmLayer1CellStateTensor)->bytes);
+
+ memset(interpreter->tensor(kLstmLayer2OutputStateTensor)->data.raw, 0,
+ interpreter->tensor(kLstmLayer2OutputStateTensor)->bytes);
+ memset(interpreter->tensor(kLstmLayer2CellStateTensor)->data.raw, 0,
+ interpreter->tensor(kLstmLayer2CellStateTensor)->bytes);
+
+ memset(interpreter->tensor(kLstmLayer3OutputStateTensor)->data.raw, 0,
+ interpreter->tensor(kLstmLayer3OutputStateTensor)->bytes);
+ memset(interpreter->tensor(kLstmLayer3CellStateTensor)->data.raw, 0,
+ interpreter->tensor(kLstmLayer3CellStateTensor)->bytes);
+
+ memset(interpreter->tensor(kRnnLayerHiddenStateTensor)->data.raw, 0,
+ interpreter->tensor(kRnnLayerHiddenStateTensor)->bytes);
+
+ for (int i = 0; i < input_frames.size(); i++) {
+ // Feed the input to model.
+ int frame_ptr = 0;
+ for (int k = 0; k < speech_input_size * speech_batch_size; k++) {
+ input_ptr[k] = input_frames[i][frame_ptr++];
+ }
+ // Run the model.
+ interpreter->Invoke();
+ // Validate the output.
+ for (int k = 0; k < speech_output_size; k++) {
+ ASSERT_NEAR(output_ptr[k], output_frames[i][k], 1e-5);
+ }
+ }
+}
+
+} // namespace models
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/models/test_utils.h b/tensorflow/contrib/lite/models/test_utils.h
new file mode 100644
index 0000000000..b2596babd0
--- /dev/null
+++ b/tensorflow/contrib/lite/models/test_utils.h
@@ -0,0 +1,84 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_MODELS_TEST_UTILS_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_MODELS_TEST_UTILS_H_
+
+#include <stdlib.h>
+#include <string.h>
+
+#include <fstream>
+#include <memory>
+#include <string>
+#include <vector>
+
+namespace tflite {
+namespace models {
+using Frames = std::vector<std::vector<float>>;
+} // namespace models
+} // namespace tflite
+
+#ifndef __ANDROID__
+#include "file/base/path.h"
+#include "tensorflow/core/platform/test.h"
+
+inline string TestDataPath() {
+ return string(file::JoinPath(tensorflow::testing::TensorFlowSrcRoot(),
+ "contrib/lite/models/testdata/"));
+}
+inline int TestInputSize(const tflite::models::Frames& input_frames) {
+ return input_frames.size();
+}
+#else
+inline string TestDataPath() {
+ return string("third_party/tensorflow/contrib/lite/models/testdata/");
+}
+
+inline int TestInputSize(const tflite::models::Frames& input_frames) {
+ // Android TAP is very slow, we only test the first 20 frames.
+ return 20;
+}
+#endif
+
+namespace tflite {
+namespace models {
+
+// Reads float data from a comma-separated file:
+// each line is read into a float vector, and the
+// result is a vector of float vectors appended to *frames.
+void ReadFrames(const string& csv_file_path, Frames* frames) {
+ std::ifstream csv_file(csv_file_path);
+ string line;
+ while (std::getline(csv_file, line, '\n')) {
+ std::vector<float> fields;
+ // Used by strtok_r internally for successive calls on the same string.
+ char* save_ptr = nullptr;
+
+ // Tokenize the line.
+ char* next_token =
+ strtok_r(const_cast<char*>(line.c_str()), ",", &save_ptr);
+ while (next_token != nullptr) {
+ float f = strtod(next_token, nullptr);
+ fields.push_back(f);
+ next_token = strtok_r(nullptr, ",", &save_ptr);
+ }
+ frames->push_back(fields);
+ }
+ csv_file.close();
+}
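+
+// Example usage (illustrative sketch; "some_model_in.csv" is a hypothetical
+// file name, not shipped in testdata):
+//   Frames frames;
+//   ReadFrames(TestDataPath() + "some_model_in.csv", &frames);
+//   // frames[i][k] is now the k-th float on the i-th line of the CSV.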
+
+} // namespace models
+} // namespace tflite
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_MODELS_TEST_UTILS_H_
diff --git a/tensorflow/contrib/lite/nnapi/BUILD b/tensorflow/contrib/lite/nnapi/BUILD
new file mode 100644
index 0000000000..402f1e949b
--- /dev/null
+++ b/tensorflow/contrib/lite/nnapi/BUILD
@@ -0,0 +1,25 @@
+licenses(["notice"]) # Apache 2.0
+
+package(default_visibility = [
+ "//visibility:public",
+])
+
+cc_library(
+ name = "nnapi_lib",
+ hdrs = [
+ "NeuralNetworksShim.h",
+ ],
+ linkopts = ["-ldl"],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h b/tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h
new file mode 100644
index 0000000000..5d06165772
--- /dev/null
+++ b/tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h
@@ -0,0 +1,1916 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef NN_API_SHIM_H0
+#define NN_API_SHIM_H0
+
+#include <dlfcn.h>
+#include <stdint.h>
+#include <stdio.h>
+#include <stdlib.h>
+
+// helpers
+
+#define NNAPI_LOG(format, ...) printf(format "\n", __VA_ARGS__);
+#define LOAD_FUNCTION(name) \
+ static name##_fn fn = reinterpret_cast<name##_fn>(loadFunction(#name));
+#define EXECUTE_FUNCTION(...) \
+ if (fn != nullptr) { \
+ fn(__VA_ARGS__); \
+ }
+#define EXECUTE_FUNCTION_RETURN(...) return fn != nullptr ? fn(__VA_ARGS__) : 0;
+
+inline void* loadLibrary(const char* name) {
+ // TODO: change RTLD_LOCAL? Assumes there can be multiple instances of nn
+ // api RT
+ void* handle = dlopen(name, RTLD_LAZY | RTLD_LOCAL);
+ if (handle == nullptr) {
+ NNAPI_LOG("nnapi error: unable to open library %s", name);
+ }
+ return handle;
+}
+
+inline void* getLibraryHandle() {
+ static void* handle = loadLibrary("libneuralnetworks.so");
+ return handle;
+}
+
+inline void* loadFunction(const char* name) {
+ void* fn = nullptr;
+ if (getLibraryHandle() != nullptr) {
+ fn = dlsym(getLibraryHandle(), name);
+ }
+ if (fn == nullptr) {
+ NNAPI_LOG("nnapi error: unable to open function %s", name);
+ }
+ return fn;
+}
+
+inline bool NNAPIExists() {
+ static bool nnapi_is_available = getLibraryHandle() != nullptr;
+ return nnapi_is_available;
+}
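+
+// Illustrative sketch (not part of the NN API): how the wrappers defined
+// later in this header combine the macros above.
+// "ANeuralNetworks_exampleSymbol" is a hypothetical symbol used only for
+// illustration.
+//
+//   typedef int (*ANeuralNetworks_exampleSymbol_fn)(int value);
+//   inline int ANeuralNetworks_exampleSymbol(int value) {
+//     LOAD_FUNCTION(ANeuralNetworks_exampleSymbol);
+//     EXECUTE_FUNCTION_RETURN(value);
+//   }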
+
+// nn api types
+
+/**
+ * Operand types.
+ *
+ * The type of operands that can be added to a model.
+ *
+ * Although we define many types, most operators accept just a few
+ * types. Most used are ANEURALNETWORKS_TENSOR_FLOAT32,
+ * ANEURALNETWORKS_TENSOR_QUANT8_ASYMM, and ANEURALNETWORKS_INT32.
+ */
+enum {
+ /** The following entries are used to declare scalars. */
+
+ /** A 32 bit floating point scalar value. */
+ ANEURALNETWORKS_FLOAT32 = 0,
+ /** A signed 32 bit integer scalar value. */
+ ANEURALNETWORKS_INT32 = 1,
+ /** An unsigned 32 bit integer scalar value. */
+ ANEURALNETWORKS_UINT32 = 2,
+
+ /** The following entries are used to declare tensors. */
+
+ /** A tensor of 32 bit floating point values. */
+ ANEURALNETWORKS_TENSOR_FLOAT32 = 3,
+ /** A tensor of 32 bit integer values. */
+ ANEURALNETWORKS_TENSOR_INT32 = 4,
+ /** A tensor of 8 bit integers that represent real numbers.
+ *
+ * Attached to this tensor are two numbers that can be used to convert
+ * the 8 bit integer to the real value and vice versa. These two numbers are:
+ * - scale: a 32 bit floating point value
+ * - zero_value: a 32 bit integer
+ *
+ * The formula is:
+ * real_value = (integer_value - zero_value) * scale.
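+ *
+ * For example (illustrative values): with scale = 0.5 and zero_value = 128,
+ * the stored value 130 represents (130 - 128) * 0.5 = 1.0.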
+ */
+ ANEURALNETWORKS_TENSOR_QUANT8_ASYMM = 5,
+};
+
+/**
+ * Operation types.
+ *
+ * The type of operations that can be added to a model.
+ */
+enum {
+ /** Adds two tensors, element-wise.
+ *
+ * Takes two input tensors of identical type and compatible dimensions. The
+ * output is the sum of both input tensors, optionally modified by an
+ * activation function.
+ *
+ * Two dimensions are compatible when:
+ * 1. they are equal, or
+ * 2. one of them is 1
+ *
+ * The size of the output is the maximum size along each dimension of the
+ * input operands. It starts with the trailing dimensions, and works its way
+ * forward.
+ *
+ * Example:
+ *
+ * input1.dimension = {4, 1, 2}
+ * input2.dimension = {5, 4, 3, 1}
+ * output.dimension = {5, 4, 3, 2}
+ *
+ * Supported tensor types:
+ * * {@link ANEURALNETWORKS_TENSOR_FLOAT32}
+ *
+ * Supported tensor rank: up to 4
+ *
+ * Inputs:
+ * * 0: A tensor.
+ * * 1: A tensor of the same type, and compatible dimensions as input0.
+ * * 2: An INT32 value, and has to be one of the {@link FuseCode} values.
+ * Specifies the activation to invoke on the result of each addition.
+ *
+ * Outputs:
+ * * 0: The sum, a tensor of the same type as input0.
+ */
+ ANEURALNETWORKS_ADD = 0,
+ /** Performs a 2-D average pooling operation.
+ *
+ * The output dimensions are functions of the filter dimensions, stride, and
+ * padding.
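+ * (As an illustrative assumption, one common convention is
+ * out_width = (width + padding_left + padding_right - filter_width) /
+ * stride_width + 1, and similarly for out_height.)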
+ *
+ * The values in the output tensor are computed as:
+ *
+ * output[batch, row, col, channel] =
+ * sum_{i, j}(input[batch, row + i, col + j, channel]) / sum(1)
+ *
+ * Supported tensor types:
+ * * {@link ANEURALNETWORKS_TENSOR_FLOAT32}
+ * * {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM}
+ *
+ * Supported tensor rank: 4, with "NHWC" data layout.
+ *
+ * Inputs:
+ * * 0: A 4-D tensor, of shape [batches, height, width, depth], specifying the
+ * input.
+ * * 1: An INT32 value, specifying the padding on the left, in the ‘width’
+ * dimension.
+ * * 2: An INT32 value, specifying the padding on the right, in the ‘width’
+ * dimension.
+ * * 3: An INT32 value, specifying the padding on the top, in the ‘height’
+ * dimension.
+ * * 4: An INT32 value, specifying the padding on the bottom, in the ‘height’
+ * dimension.
+ * * 5: An INT32 value, specifying the output stride in the ‘width’ dimension.
+ * * 6: An INT32 value, specifying the output stride in the ‘height’
+ * dimension.
+ * * 7: An INT32 value, specifying the filter width.
+ * * 8: An INT32 value, specifying the filter height.
+ * * 9: An INT32 value, and has to be one of the {@link FuseCode} values.
+ * Specifies the activation to invoke on the result of each addition.
+ *
+ * Outputs:
+ * * 0: The output 4-D tensor, of shape [batches, out_height, out_width,
+ * depth].
+ */
+ ANEURALNETWORKS_AVERAGE_POOL_2D = 1,
+ /** Concatenates the input tensors along the given dimension.
+ *
+ * The input tensors must have identical type and the same dimensions except
+ * the dimension along the concatenation axis.
+ *
+ * Supported tensor types:
+ * * {@link ANEURALNETWORKS_TENSOR_FLOAT32}
+ * * {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM}
+ *
+ * Supported tensor rank: up to 4
+ *
+ * Inputs:
+ * * 0 ~ n: The list of n input tensors, of shape [D0, D1, ..., Daxis(i), ...,
+ * Dm].
+ * * n+1: An INT32 value, specifying the concatenation axis.
+ * * n+2: An INT32 value, and has to be one of the {@link FuseCode} values.
+ * Specifies the activation to invoke on the result of each addition.
+ *
+ * Outputs:
+ * * 0: The output, a tensor of the same type as the input tensors.
+ * The output shape is [D0, D1, ..., sum(Daxis(i)), ..., Dm].
+ */
+ ANEURALNETWORKS_CONCATENATION = 2,
+ /** Performs a 2-D convolution operation.
+ *
+ * The CONV_2D op sweeps a 2-D filter that can mix channels together over a
+ * batch of images, applying the filter to each window of each image of the
+ * appropriate size.
+ *
+ * The output dimensions are functions of the filter dimensions, stride, and
+ * padding.
+ *
+ * The values in the output tensor are computed as:
+ *
+ * output[batch, row, col, channel] =
+ * sum_{i, j} (
+ * input[batch, row + i, col + j, k] *
+ * filter[channel, row + i, col + j, k] +
+ * bias[channel]
+ * )
+ *
+ * Supported tensor types:
+ * * {@link ANEURALNETWORKS_TENSOR_FLOAT32}
+ * * {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM}
+ *
+ * Supported tensor rank: 4, with "NHWC" data layout.
+ *
+ * Inputs:
+ * * 0: A 4-D tensor, of shape [batches, height, width, depth_in], specifying
+ * the input.
+ * * 1: A 4-D tensor, of shape [depth_out, filter_height, filter_width,
+ * depth_in], specifying the filter.
+ * * 2: A 1-D tensor, of shape [depth_out], specifying the bias.
+ * For input tensor of {@link ANEURALNETWORKS_TENSOR_FLOAT32} type, the
+ * bias should also be of {@link ANEURALNETWORKS_TENSOR_FLOAT32}. For input
+ * tensor of {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM} type, the bias should
+ * be of {@link ANEURALNETWORKS_TENSOR_INT32}.
+ * * 3: An INT32 value, specifying the padding on the left, in the ‘width’
+ * dimension.
+ * * 4: An INT32 value, specifying the padding on the right, in the ‘width’
+ * dimension.
+ * * 5: An INT32 value, specifying the padding on the top, in the ‘height’
+ * dimension.
+ * * 6: An INT32 value, specifying the padding on the bottom, in the ‘height’
+ * dimension.
+ * * 7: An INT32 value, specifying the output stride in the ‘width’ dimension.
+ * * 8: An INT32 value, specifying the output stride in the ‘height’
+ * dimension.
+ * * 9: An INT32 value, and has to be one of the {@link FuseCode} values.
+ * Specifies the activation to invoke on the result of each addition.
+ *
+ * Outputs:
+ * * 0: The output 4-D tensor, of shape [batches, out_height, out_width,
+ * depth_out].
+ */
+ ANEURALNETWORKS_CONV_2D = 3,
+ /** Performs a depthwise 2-D convolution operation.
+ *
+ * Given an input tensor of shape [batches, height, width, depth_in] and a
+ * filter tensor of shape [depth_out, filter_height, filter_width, depth_in]
+ * containing in_channels convolutional filters of depth 1, DEPTHWISE_CONV
+ * applies a different filter to each input channel (expanding from 1 channel
+ * to channel_multiplier channels for each), then concatenates the results
+ * together.
+ *
+ * The output has depth_out = depth_in * depth_multiplier channels.
+ * The output dimensions are functions of the filter dimensions, stride, and
+ * padding.
+ *
+ * The values in the output tensor are computed as:
+ *
+ * output[b, i, j, k * channel_multiplier + q] =
+ * sum_{di, dj} (
+ * input[b, strides[1] * i + di, strides[2] * j + dj, k] *
+ * filter[di, dj, k, q]
+ * )
+ *
+ * Supported tensor types:
+ * * {@link ANEURALNETWORKS_TENSOR_FLOAT32}
+ * * {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM}
+ *
+ * Supported tensor rank: 4, with "NHWC" data layout.
+ *
+ * Inputs:
+ * * 0: A 4-D tensor, of shape [batches, height, width, depth_in], specifying
+ * the input.
+ * * 1: A 4-D tensor, of shape [depth_out, filter_height, filter_width,
+ * depth_in], specifying the filter.
+ * * 2: A 1-D tensor, of shape [depth_out], specifying the bias.
+ * For input tensor of {@link ANEURALNETWORKS_TENSOR_FLOAT32} type, the
+ * bias should also be of {@link ANEURALNETWORKS_TENSOR_FLOAT32}. For input
+ * tensor of {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM} type, the bias should
+ * be of {@link ANEURALNETWORKS_TENSOR_INT32}.
+ * * 3: An INT32 value, specifying the padding on the left, in the ‘width’
+ * dimension.
+ * * 4: An INT32 value, specifying the padding on the right, in the ‘width’
+ * dimension.
+ * * 5: An INT32 value, specifying the padding on the top, in the ‘height’
+ * dimension.
+ * * 6: An INT32 value, specifying the padding on the bottom, in the ‘height’
+ * dimension.
+ * * 7: An INT32 value, specifying the output stride in the ‘width’ dimension.
+ * * 8: An INT32 value, specifying the output stride in the ‘height’
+ * dimension.
+ * * 9: An INT32 value, specifying the depthwise multiplier.
+ * * 10: An INT32 value, and has to be one of the {@link FuseCode} values.
+ * Specifies the activation to invoke on the result of each addition.
+ *
+ * Outputs:
+ * * 0: The output 4-D tensor, of shape [batches, out_height, out_width,
+ * depth_out].
+ */
+ ANEURALNETWORKS_DEPTHWISE_CONV_2D = 4,
+ /** Rearranges data from depth into blocks of spatial data.
+ *
+ * More specifically, this op outputs a copy of the input tensor where values
+ * from the depth dimension are moved in spatial blocks to the height and
+ * width dimensions. The value block_size indicates the input block size and
+ * how the data is moved.
+ *
+ * Chunks of data of size block_size * block_size from depth are rearranged
+ * into non-overlapping blocks of size block_size x block_size.
+ *
+ * The width of the output tensor is input_width * block_size, whereas the
+ * height is input_height * block_size. The depth of the input tensor must be
+ * divisible by block_size * block_size.
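+ *
+ * For example (illustrative): with block_size = 2, a [1, 2, 2, 12] input
+ * becomes a [1, 4, 4, 3] output.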
+ *
+ * Supported tensor types:
+ * * {@link ANEURALNETWORKS_TENSOR_FLOAT32}
+ * * {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM}
+ *
+ * Supported tensor rank: 4, with "NHWC" data layout.
+ *
+ * Inputs:
+ * * 0: A 4-D tensor, of shape [batches, height, width, depth_in], specifying
+ * the input.
+ * * 1: An INT32 value, specifying the block_size. block_size must be >=1 and
+ * block_size * block_size must be a divisor of the input depth.
+ *
+ * Outputs:
+ * * 0: The output 4-D tensor, of shape [batch, height*block_size,
+ * width*block_size, depth/(block_size*block_size)].
+ */
+ ANEURALNETWORKS_DEPTH_TO_SPACE = 5,
+ /** Dequantizes the input tensor.
+ *
+ * The formula is:
+ *
+ * output = (input - zero_value) * scale.
+ *
+ * Supported tensor types:
+ * * {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM}
+ *
+ * Supported tensor rank: up to 4
+ *
+ * Inputs:
+ * * 0: A tensor of type {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM}.
+ *
+ * Outputs:
+ * * 0: The output tensor of same shape as input0, but with type
+ * {@link ANEURALNETWORKS_TENSOR_FLOAT32}.
+ */
+ ANEURALNETWORKS_DEQUANTIZE = 6,
+
+ /**
+ * Looks up items from a given tensor.
+ *
+ * Each item in the output is a raw copy of the corresponding item in
+ * the input “values”. If the given “lookup” indices are out of bounds,
+ * the op will fail and an error will be reported.
+ *
+ * Inputs:
+ * * 0: Values. An n-D tensor of any type X (where n >= 2). E.g., if n is 2,
+ * then the shape would be [lookup_dimension, values_dimension], where
+ * “lookup_dimension” corresponds to the indexing dimension in the lookup
+ * table, and “values_dimension” to the contents.
+ * * 1: Lookups. A 1-D tensor of type T, of shape [lookup_size], where
+ * “lookup_size” is the number of elements to look for, and each entry
+ * corresponds to the first dimension of the “values” tensor.
+ *
+ * Output:
+ * * 0: An n-D tensor of type X and the same rank and shape as the “values”
+ * tensor, except for the first dimension which has size “lookup_size”.
+ */
+ ANEURALNETWORKS_EMBEDDING_LOOKUP = 7,
+
+ /** Computes element-wise floor() on the input tensor.
+ *
+ * Supported tensor types:
+ * * {@link ANEURALNETWORKS_TENSOR_FLOAT32}
+ *
+ * Supported tensor rank: up to 4
+ *
+ * Inputs:
+ * * 0: A tensor.
+ *
+ * Outputs:
+ * * 0: The output, a tensor of the same type and dimensions as input0.
+ */
+ ANEURALNETWORKS_FLOOR = 8,
+ /** Denotes a fully (densely) connected layer, which connects all elements in
+ * the input tensor with each element in the output tensor.
+ *
+ * This layer implements the operation:
+ *
+ * outputs = activation(inputs * weights’ + bias)
+ *
+ * Supported tensor types:
+ * * {@link ANEURALNETWORKS_TENSOR_FLOAT32}
+ * * {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM}
+ *
+ * Supported tensor rank: up to 4.
+ *
+ * Inputs:
+ * * 0: A tensor, specifying the input. If rank is greater than 2, then it
+ * gets flattened to a 2-D Tensor. The 2-D Tensor is handled as if dimensions
+ * corresponded to shape [batch_size, input_size], where “batch_size”
+ * corresponds to the batching dimension, and “input_size” is the size of the
+ * input.
+ * * 1: A 2-D tensor, specifying the weights, of shape [num_units,
+ * input_size], where "num_units" corresponds to the number of output nodes.
+ * * 2: A 1-D tensor, of shape [num_units], specifying the bias.
+ * For input tensor of {@link ANEURALNETWORKS_TENSOR_FLOAT32} type, the
+ * bias should also be of {@link ANEURALNETWORKS_TENSOR_FLOAT32}. For input
+ * tensor of {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM} type, the bias should
+ * be of {@link ANEURALNETWORKS_TENSOR_INT32}.
+ * * 3: An INT32 value, and has to be one of the {@link FuseCode} values.
+ * Specifies the activation to invoke on the result of each addition.
+ *
+ * Outputs:
+ * * 0: The output tensor, of shape [batch_size, num_units].
+ */
+ ANEURALNETWORKS_FULLY_CONNECTED = 9,
+
+ /**
+ * Looks up values of a hash table with given keys.
+ *
+ * Inputs:
+ * * 0: Lookups. A 1-D int32 tensor with shape [ k ].
+ * * 1: Keys. A 1-D int32 tensor with shape [ n ], *MUST* be sorted in
+ * ascending order.
+ * * 2: Values. A tensor with shape [ n … ].
+ *
+ * Outputs:
+ * * 0: Output. A tensor with shape [ k …].
+ * * 1: Hits. A uint8 tensor with shape [ k ] indicates whether the lookup
+ * hits or not.
+ */
+ ANEURALNETWORKS_HASHTABLE_LOOKUP = 10,
+
+ /** Applies L2 normalization along the depth dimension.
+ *
+ * The values in the output tensor are computed as:
+ *
+ * output[batch, row, col, channel] =
+ * input[batch, row, col, channel] /
+ * sqrt(sum_{c} pow(input[batch, row, col, c], 2))
+ *
+ * For x with more dimensions, independently normalizes each 1-D slice along
+ * dimension dim.
+ *
+ * Supported tensor types:
+ * * {@link ANEURALNETWORKS_TENSOR_FLOAT32}
+ *
+ * Supported tensor rank: 4, with "NHWC" data layout.
+ *
+ * Inputs:
+ * * 0: A 4-D tensor, of shape [batches, height, width, depth], specifying the
+ * input.
+ *
+ * Outputs:
+ * * 0: The output 4-D tensor, of shape [batches, out_height, out_width,
+ * depth].
+ */
+ ANEURALNETWORKS_L2_NORMALIZATION = 11,
+
+ /** Performs a 2-D L2 pooling operation.
+ *
+ * The output dimensions are functions of the filter dimensions, stride, and
+ * padding.
+ *
+ * The values in the output tensor are computed as:
+ *
+ * output[batch, row, col, channel] =
+ * sqrt(sum_{i, j} pow(input[batch, row + i, col + j, channel], 2) /
+ * sum(1))
+ *
+ * Supported tensor types:
+ * * {@link ANEURALNETWORKS_TENSOR_FLOAT32}
+ *
+ * Supported tensor rank: 4, with "NHWC" data layout.
+ *
+ * Inputs:
+ * * 0: A 4-D tensor, of shape [batches, height, width, depth], specifying the
+ * input.
+ * * 1: An INT32 value, specifying the padding on the left, in the ‘width’
+ * dimension.
+ * * 2: An INT32 value, specifying the padding on the right, in the ‘width’
+ * dimension.
+ * * 3: An INT32 value, specifying the padding on the top, in the ‘height’
+ * dimension.
+ * * 4: An INT32 value, specifying the padding on the bottom, in the ‘height’
+ * dimension.
+ * * 5: An INT32 value, specifying the output stride in the ‘width’ dimension.
+ * * 6: An INT32 value, specifying the output stride in the ‘height’
+ * dimension.
+ * * 7: An INT32 value, specifying the filter width.
+ * * 8: An INT32 value, specifying the filter height.
+ * * 9: An INT32 value, and has to be one of the {@link FuseCode} values.
+ * Specifies the activation to invoke on the result of each addition.
+ *
+ * Outputs:
+ * * 0: The output 4-D tensor, of shape [batches, out_height, out_width,
+ * depth].
+ */
+ ANEURALNETWORKS_L2_POOL_2D = 12,
+ /** Applies Local Response Normalization along the depth dimension.
+ *
+ * The 4-D input tensor is treated as a 3-D array of 1-D vectors (along the
+ * last dimension), and each vector is normalized independently. Within a
+ * given vector, each component is divided by the weighted, squared sum of
+ * inputs within depth_radius.
+ *
+ * The output is calculated using this formula:
+ *
+ * sqr_sum[a, b, c, d] =
+ * sum(pow(input[a, b, c, d - depth_radius : d + depth_radius + 1], 2))
+ * output = input / pow((bias + alpha * sqr_sum), beta)
+ *
+ * Supported tensor types:
+ * * {@link ANEURALNETWORKS_TENSOR_FLOAT32}
+ *
+ * Supported tensor rank: 4, with "NHWC" data layout.
+ *
+ * Inputs:
+ * * 0: A 4-D tensor, of shape [batches, height, width, depth], specifying the
+ * input.
+ * * 1: An INT32 value, specifying the radius of the normalization window.
+ * * 2: A FLOAT32 value, specifying the bias, must not be zero.
+ * * 3: A FLOAT32 value, specifying the scale factor, alpha.
+ * * 4: A FLOAT32 value, specifying the exponent, beta.
+ *
+ * Outputs:
+ * * 0: The output tensor of same shape as input0.
+ */
+ ANEURALNETWORKS_LOCAL_RESPONSE_NORMALIZATION = 13,
+ /** Computes sigmoid activation on the input tensor element-wise.
+ *
+ * The output is calculated using this formula:
+ *
+ * output = 1 / (1 + exp(-input))
+ *
+ * Supported tensor types:
+ * * {@link ANEURALNETWORKS_TENSOR_FLOAT32}
+ * * {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM}
+ *
+ * Supported tensor rank: up to 4.
+ *
+ * Inputs:
+ * * 0: A tensor, specifying the input.
+ *
+ * Outputs:
+ * * 0: The output tensor of same shape as input0.
+ */
+ ANEURALNETWORKS_LOGISTIC = 14,
+
+ /**
+ * Projects an input to a bit vector via locality-sensitive hashing.
+ *
+ * Inputs:
+ * * 0: Hash functions. Dim.size == 2, DataType: Float.
+ * Tensor[0].Dim[0]: Number of hash functions.
+ * Tensor[0].Dim[1]: Number of seeds per hash function.
+ * Tensor[0].Dim[1] <= 32 in sparse case.
+ *
+ * * 1: Input. Dim.size >= 1, no restriction on DataType.
+ * * 2: Weight. Optional. Dim.size == 1, DataType: Float.
+ * If not set, each input element is considered to have the same weight of
+ * 1.0.
+ * Tensor[1].Dim[0] == Tensor[2].Dim[0]
+ * * 3: Type:
+ * Sparse: Value LSHProjectionType_SPARSE(=1).
+ * Computed bit vector is considered to be sparse.
+ * Each output element is an int32 made up of multiple bits computed
+ * from hash functions.
+ *
+ * Dense: Value LSHProjectionType_DENSE(=2).
+ * Computed bit vector is considered to be dense. Each output element
+ * represents a bit and can take the value of either 0 or 1.
+ *
+ * Outputs:
+ * * 0: If the projection type is sparse:
+ * Output.Dim == { Tensor[0].Dim[0] }
+ * A tensor of int32 that represents hash signatures.
+ * If the projection type is Dense:
+ * Output.Dim == { Tensor[0].Dim[0] * Tensor[0].Dim[1] }
+ * A flattened tensor that represents projected bit vectors.
+ */
+ ANEURALNETWORKS_LSH_PROJECTION = 15,
+
+ /**
+ * Long short-term memory unit (LSTM) recurrent network layer.
+ *
+ * The default non-peephole implementation is based on:
+ * http://deeplearning.cs.cmu.edu/pdfs/Hochreiter97_lstm.pdf
+ * S. Hochreiter and J. Schmidhuber. "Long Short-Term Memory". Neural
+ * Computation, 9(8):1735-1780, 1997.
+ *
+ * The peephole implementation is based on:
+ * https://research.google.com/pubs/archive/43905.pdf
+ * Hasim Sak, Andrew Senior, and Francoise Beaufays. "Long short-term memory
+ * recurrent neural network architectures for large scale acoustic modeling."
+ * INTERSPEECH, 2014.
+ *
+ * The coupling of input and forget gate (CIFG) is based on:
+ * http://arxiv.org/pdf/1503.04069.pdf
+ * Greff et al. "LSTM: A Search Space Odyssey"
+ *
+ * The class has the following independently optional inputs:
+ * * If no input gate (i.e., CIFG is used): “input_to_input_weights”,
+ * “recurrent_to_input_weights”, “cell_to_input_weights”, “input_gate_bias”.
+ * * If no peephole connections: “cell_to_input_weights”,
+ * “cell_to_forget_weights”, “cell_to_output_weights”.
+ * * If no projection layer: “projection_weights” and “projection_bias”.
+ * * If no projection bias: “projection_bias”.
+ *
+ * Supported tensor types:
+ * * {@link ANEURALNETWORKS_TENSOR_FLOAT32}
+ *
+ * Inputs:
+ * * 0: Input.
+ * A 2-D tensor of type T, of shape [batch_size, input_size], where
+ * “batch_size” corresponds to the batching dimension, and “input_size”
+ * is the size of the input.
+ * * 1: input_to_input_weights.
+ * A 2-D tensor of type T, of shape [num_units, input_size], where
+ * “num_units” corresponds to the number of cell units.
+ * * 2: input_to_forget_weights.
+ * A 2-D tensor of type T, of shape [num_units, input_size].
+ * * 3: input_to_cell_weights.
+ * A 2-D tensor of type T, of shape [num_units, input_size].
+ * * 4: input_to_output_weights.
+ * A 2-D tensor of type T, of shape [num_units, input_size].
+ * * 5: recurrent_to_input_weights.
+ * A 2-D tensor of type T, of shape [num_units, output_size], where
+ * “output_size” corresponds to either the number of cell units (i.e.,
+ * “num_units”), or the second dimension of the “projection_weights”, if
+ * defined.
+ * * 6: recurrent_to_forget_weights.
+ * A 2-D tensor of type T, of shape [num_units, output_size].
+ * * 7: recurrent_to_cell_weights.
+ * A 2-D tensor of type T, of shape [num_units, output_size].
+ * * 8: recurrent_to_output_weights.
+ * A 2-D tensor of type T, of shape [num_units, output_size].
+ * * 9: cell_to_input_weights.
+ * A 1-D tensor of type T, of shape [num_units].
+ * * 10:cell_to_forget_weights.
+ * A 1-D tensor of type T, of shape [num_units].
+ * * 11:cell_to_output_weights.
+ * A 1-D tensor of type T, of shape [num_units].
+ * * 12:input_gate_bias.
+ * A 1-D tensor of type T, of shape [num_units].
+ * * 13:forget_gate_bias.
+ * A 1-D tensor of type T, of shape [num_units].
+ * * 14:cell_bias.
+ * A 1-D tensor of type T, of shape [num_units].
+ * * 15:output_gate_bias.
+ * A 1-D tensor of type T, of shape [num_units].
+ * * 16:projection_weights.
+ * A 2-D tensor of type T, of shape [output_size, num_units].
+ * * 17:projection_bias.
+ * A 1-D tensor of type T, of shape [output_size].
+ *
+ * Parameters:
+ * * 18:fused_activation_function.
+ * An (optional) ActivationFunctionType indicating the activation
+ * function.
+ * If “NONE” is specified then it results in a linear activation.
+ * * 19:cell_clip.
+ * A clipping threshold for the cell state, such that values are bound
+ * within [-cell_clip, cell_clip]. If set to 0.0 then clipping is
+ * disabled.
+ * * 20:proj_clip.
+ * A clipping threshold for the output from the projection layer, such
+ * that values are bound within [-proj_clip, proj_clip]. If set to 0.0
+ * then clipping is disabled.
+ *
+ * Outputs:
+ * * 0: scratch_buffer.
+ * A 3-D tensor of type T, of shape [batch_size, num_cell, 4].
+ * * 1: output_state.
+ * A 2-D tensor of type T, of shape [batch_size, output_size].
+ * * 2: cell_state.
+ * A 2-D tensor of type T, of shape [batch_size, num_units].
+ * * 3: output.
+ * A 2-D tensor of type T, of shape [batch_size, output_size]. This is
+ * effectively the same as the current “output_state” value.
+ */
+ ANEURALNETWORKS_LSTM = 16,
+
+ /** Performs a 2-D max pooling operation.
+ *
+ * The output dimensions are functions of the filter dimensions, stride, and
+ * padding.
+ *
+ * The values in the output tensor are computed as:
+ *
+ * output[batch, row, col, channel] =
+ * max_{i, j} (input[batch, row + i, col + j, channel])
+ *
+ * Supported tensor types:
+ * * {@link ANEURALNETWORKS_TENSOR_FLOAT32}
+ * * {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM}
+ *
+ * Supported tensor rank: 4, with "NHWC" data layout.
+ *
+ * Inputs:
+ * * 0: A 4-D tensor, of shape [batches, height, width, depth], specifying the
+ * input.
+ * * 1: An INT32 value, specifying the padding on the left, in the ‘width’
+ * dimension.
+ * * 2: An INT32 value, specifying the padding on the right, in the ‘width’
+ * dimension.
+ * * 3: An INT32 value, specifying the padding on the top, in the ‘height’
+ * dimension.
+ * * 4: An INT32 value, specifying the padding on the bottom, in the ‘height’
+ * dimension.
+ * * 5: An INT32 value, specifying the output stride in the ‘width’ dimension.
+ * * 6: An INT32 value, specifying the output stride in the ‘height’
+ * dimension.
+ * * 7: An INT32 value, specifying the filter width.
+ * * 8: An INT32 value, specifying the filter height.
+ * * 9: An INT32 value, and has to be one of the {@link FuseCode} values.
+ * Specifies the activation to invoke on the result of each addition.
+ *
+ * Outputs:
+ * * 0: The output 4-D tensor, of shape [batches, out_height, out_width,
+ * depth].
+ */
+ ANEURALNETWORKS_MAX_POOL_2D = 17,
+
+ /** Multiplies two tensors, element-wise.
+ *
+ * Takes two input tensors of identical type and compatible dimensions. The
+ * output is the product of both input tensors, optionally modified by an
+ * activation function.
+ *
+ * Two dimensions are compatible when:
+ * 1. they are equal, or
+ * 2. one of them is 1
+ *
+ * The size of the resulting output is the maximum size along each dimension
+ * of the input operands. It starts with the trailing dimensions, and works
+ * its way forward.
+ *
+ * Supported tensor types:
+ * * {@link ANEURALNETWORKS_TENSOR_FLOAT32}
+ *
+ * Supported tensor rank: up to 4
+ *
+ * Inputs:
+ * * 0: A tensor.
+ * * 1: A tensor of the same type, and compatible dimensions as input0.
+ * * 2: An INT32 value, and has to be one of the {@link FuseCode} values.
+ * Specifies the activation to invoke on the result of each addition.
+ *
+ * Outputs:
+ * * 0: The product, a tensor of the same type as input0.
+ */
+ ANEURALNETWORKS_MUL = 18,
+ /** Computes rectified linear activation on the input tensor element-wise.
+ *
+ * The output is calculated using this formula:
+ *
+ * output = max(0, input)
+ *
+ * Supported tensor types:
+ * * {@link ANEURALNETWORKS_TENSOR_FLOAT32}
+ * * {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM}
+ *
+ * Supported tensor rank: up to 4.
+ *
+ * Inputs:
+ * * 0: A tensor, specifying the input.
+ *
+ * Outputs:
+ * * 0: The output tensor of same shape as input0.
+ */
+ ANEURALNETWORKS_RELU = 19,
+ /** Computes rectified linear 1 activation on the input tensor element-wise.
+ *
+ * The output is calculated using this formula:
+ *
+ * output = min(1.f, max(-1.f, input))
+ *
+ * Supported tensor types:
+ * * {@link ANEURALNETWORKS_TENSOR_FLOAT32}
+ * * {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM}
+ *
+ * Supported tensor rank: up to 4.
+ *
+ * Inputs:
+ * * 0: A tensor, specifying the input.
+ *
+ * Outputs:
+ * * 0: The output tensor of same shape as input0.
+ */
+ ANEURALNETWORKS_RELU1 = 20,
+ /** Computes rectified linear 6 activation on the input tensor element-wise.
+ *
+ * The output is calculated using this formula:
+ *
+ * output = min(6, max(0, input))
+ *
+ * Supported tensor types:
+ * * {@link ANEURALNETWORKS_TENSOR_FLOAT32}
+ * * {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM}
+ *
+ * Supported tensor rank: up to 4.
+ *
+ * Inputs:
+ * * 0: A tensor, specifying the input.
+ *
+ * Outputs:
+ * * 0: The output tensor of same shape as input0.
+ */
+ ANEURALNETWORKS_RELU6 = 21,
+ /** Reshapes a tensor.
+ *
+ * Given tensor, this operation returns a tensor that has the same values as
+ * tensor, but with a newly specified shape.
+ *
+ * Supported tensor types:
+ * * {@link ANEURALNETWORKS_TENSOR_FLOAT32}
+ * * {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM}
+ *
+ * Supported tensor rank: up to 4.
+ *
+ * Inputs:
+ * * 0: A tensor, specifying the tensor to be reshaped.
+ * * 1: A 1-D tensor of type {@link ANEURALNETWORKS_TENSOR_INT32}, defining
+ * the shape of the output tensor. The number of elements implied by shape
+ * must be the same as the number of elements in the input tensor.
+ *
+ * Outputs:
+ * * 0: The output tensor, of shape specified by the input shape.
+ */
+ ANEURALNETWORKS_RESHAPE = 22,
+ /** Resizes images to a given size using bilinear interpolation.
+ *
+ * Resized images will be distorted if the output aspect ratio is not the
+ * same as the input aspect ratio.
+ *
+ * Supported tensor types:
+ * * {@link ANEURALNETWORKS_TENSOR_FLOAT32}
+ *
+ * Supported tensor rank: 4, with "NHWC" data layout.
+ *
+ * Inputs:
+ * * 0: A 4-D tensor, of shape [batches, height, width, depth], specifying the
+ * input.
+ * * 1: An INT32 value, specifying the output width of the output tensor.
+ * * 2: An INT32 value, specifying the output height of the output tensor.
+ *
+ * Outputs:
+ * * 0: The output 4-D tensor, of shape [batches, new_height, new_width,
+ * depth].
+ */
+ ANEURALNETWORKS_RESIZE_BILINEAR = 23,
+
+ /**
+ * A basic recurrent neural network layer.
+ *
+ * This layer implements the operation:
+ * outputs = state = activation(inputs * input_weights + state *
+ * recurrent_weights + bias)
+ *
+ * Where:
+ * * “input_weights” is a weight matrix that multiplies the inputs;
+ * * “recurrent_weights” is a weight matrix that multiplies the current
+ * “state” which itself is the output from the previous time step
+ * computation;
+ * * “bias” is a bias vector (added to each output vector in the batch);
+ * * “activation” is the function passed as the “fused_activation_function”
+ * argument (if not “NONE”).
+ *
+ * Supported tensor types:
+ * * {@link ANEURALNETWORKS_TENSOR_FLOAT32}
+ *
+ * Inputs:
+ * * 0: input.
+ * A 2-D tensor of type T, of shape [batch_size, input_size], where
+ * “batch_size” corresponds to the batching dimension, and “input_size”
+ * is the size of the input.
+ * * 1: weights.
+ * A 2-D tensor of type T, of shape [num_units, input_size], where
+ * “num_units” corresponds to the number of units.
+ * * 2: recurrent_weights.
+ * A 2-D tensor of type T, of shape [num_units, num_units], with columns
+ * corresponding to the weights from each unit.
+ * * 3: bias.
+ * A 1-D tensor of type T, of shape [num_units].
+ *
+ * For FLOAT32 input tensor, bias must also be FLOAT32.
+ * For UINT8 input tensor, bias must be INT32.
+ *
+ * Parameters
+ * * 4: fused_activation_function.
+ * An (optional) ActivationFunctionType indicating the activation
+ * function. If “NONE” is specified then it results in a linear
+ * activation.
+ *
+ * * 5: Hidden state.
+ * A 2-D tensor of type T, of shape [batch_size, num_units].
+ *
+ * Outputs:
+ * * 0: output.
+ * A 2-D tensor of type T, of shape [batch_size, num_units]. This is
+ * effectively the same as the current state value.
+ */
+ ANEURALNETWORKS_RNN = 24,
+
+ /** Computes the softmax activation on the input tensor element-wise, per
+ * batch, by normalizing the input vector so the maximum coefficient is zero.
+ *
+ * The output is calculated using this formula:
+ *
+ * output[batch, i] =
+ * exp((input[batch, i] - max(input[batch, :])) * beta) /
+ * sum_{k}{exp((input[batch, k] - max(input[batch, :])) * beta)}
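+ *
+ * For example (illustrative): with beta = 1 and input [1.0, 2.0, 3.0], the
+ * output is approximately [0.090, 0.245, 0.665].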
+ *
+ * Supported tensor types:
+ * * {@link ANEURALNETWORKS_TENSOR_FLOAT32}
+ * * {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM}
+ *
+ * Supported tensor rank: 2 or 4.
+ *
+ * Inputs:
+ * * 0: A 2-D or 4-D tensor, specifying the input.
+ * * 1: A FLOAT32 value, specifying the scaling factor for the exponent, beta.
+ *
+ * Outputs:
+ * * 0: The output tensor of same shape as input0.
+ */
+ ANEURALNETWORKS_SOFTMAX = 25,
+
+ /** Rearranges blocks of spatial data into depth.
+ *
+ * More specifically, this op outputs a copy of the input tensor where values
+ * from the height and width dimensions are moved to the depth dimension. The
+ * value block_size indicates the input block size and how the data is moved.
+ *
+ * Non-overlapping blocks of size block_size x block_size from the height and
+ * width dimensions are rearranged into chunks of depth of size
+ * block_size * block_size.
+ *
+ * The depth of the output tensor is input_depth * block_size * block_size.
+ * The input tensor's height and width must be divisible by block_size.
+ *
+ * Supported tensor types:
+ * * {@link ANEURALNETWORKS_TENSOR_FLOAT32}
+ * * {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM}
+ *
+ * Supported tensor rank: 4, with "NHWC" data layout.
+ *
+ * Inputs:
+ * * 0: A 4-D tensor, of shape [batches, height, width, depth_in], specifying
+ * the input.
+ * * 1: An INT32 value, specifying the block_size. block_size must be >=1 and
+ * block_size must be a divisor of both the input height and width.
+ *
+ * Outputs:
+ * * 0: The output 4-D tensor, of shape [batch, height/block_size,
+ * width/block_size, depth*block_size*block_size].
+ */
+ ANEURALNETWORKS_SPACE_TO_DEPTH = 26,
+
+ /**
+ * SVDF op is a kind of stateful layer derived from the notion that a
+ * densely connected layer that's processing a sequence of input frames can
+ * be approximated by using a singular value decomposition of each of its
+ * nodes. The implementation is based on:
+ *
+ * https://research.google.com/pubs/archive/43813.pdf
+ *
+ * P. Nakkiran, R. Alvarez, R. Prabhavalkar, C. Parada.
+ * “Compressing Deep Neural Networks using a Rank-Constrained Topology”.
+ * INTERSPEECH, 2015.
+ *
+ * It processes the incoming input using a 2-stage filtering mechanism:
+ * * stage 1 performs filtering on the "features" dimension, whose outputs get
+ * pushed into a memory of fixed-size memory_size.
+ * * stage 2 performs filtering on the "time" dimension of the memory_size
+ * memoized outputs of stage 1.
+ *
+ * Specifically, for rank 1, this layer implements the operation:
+ *
+ * memory = push(conv1d(inputs, weights_feature, feature_dim, "VALID"));
+ * outputs = activation(memory * weights_time + bias);
+ *
+ * Where:
+ * * “weights_feature” is a weights matrix that processes the inputs (by
+ * convolving the input with every “feature filter”), and whose outputs get
+ * pushed, stacked in order, into the fixed-size “memory” (the oldest entry
+ * gets dropped);
+ * * “weights_time” is a weights matrix that processes the “memory” (by a
+ * batched matrix multiplication on the num_units);
+ * * “bias” is an optional bias vector (added to each output vector in the
+ * batch); and
+ * * “activation” is the function passed as the “fused_activation_function”
+ * argument (if not “NONE”).
+ *
+ * Each rank adds a dimension to the weights matrices by means of stacking
+ * the filters.
+ *
+ * Supported tensor types:
+ * * {@link ANEURALNETWORKS_TENSOR_FLOAT32}
+ *
+ * Inputs:
+ * * 0: input.
+ * A 2-D tensor of type T, of shape [batch_size, input_size], where
+ * “batch_size” corresponds to the batching dimension, and “input_size”
+ * is the size of the input.
+ * * 1: weights_feature.
+ * A 2-D tensor of type T, of shape [num_units, input_size], where
+ * “num_units” corresponds to the number of units.
+ * * 2: weights_time.
+ * A 2-D tensor of type T, of shape [num_units, memory_size], where
+ * “memory_size” corresponds to the fixed-size of the memory.
+ * * 3: bias.
+ * An optional 1-D tensor of type T, of shape [num_units].
+ *
+ * For FLOAT32 input tensor, bias must also be FLOAT32.
+ * For UINT8 input tensor, bias must be INT32.
+ *
+ * Parameters:
+ * * 4: rank.
+ * The rank of the SVD approximation.
+ * * 5: fused_activation_function.
+ * An (optional) ActivationFunctionType indicating the activation
+ * function. If “NONE” is specified then it results in a linear activation.
+ *
+ * Outputs:
+ * * 0: state.
+ * A 2-D tensor of type T, of shape [batch_size, (memory_size - 1) *
+ * num_units * rank].
+ * * 1: output.
+ * A 2-D tensor of type T, of shape [batch_size, num_units].
+ */
+ ANEURALNETWORKS_SVDF = 27,
+
+ /** Computes hyperbolic tangent of input tensor element-wise.
+ *
+ * The output is calculated using this formula:
+ *
+ * output = tanh(input)
+ *
+ * Supported tensor types:
+ * * {@link ANEURALNETWORKS_TENSOR_FLOAT32}
+ *
+ * Supported tensor rank: up to 4.
+ *
+ * Inputs:
+ * * 0: A tensor, specifying the input.
+ *
+ * Outputs:
+ * * 0: The output tensor of same shape as input0.
+ */
+ ANEURALNETWORKS_TANH = 28,
+};
+
+/**
+ * Fused activation function types.
+ *
+ */
+enum {
+ /** NO fused activation function. */
+ ANEURALNETWORKS_FUSED_NONE = 0,
+ /** Fused ReLU activation function. */
+ ANEURALNETWORKS_FUSED_RELU = 1,
+ /** Fused ReLU1 activation function. */
+ ANEURALNETWORKS_FUSED_RELU1 = 2,
+ /** Fused ReLU6 activation function. */
+ ANEURALNETWORKS_FUSED_RELU6 = 3,
+};
+
+/**
+ * Execution preferences.
+ */
+enum {
+ /**
+ * Prefer executing in a way that minimizes battery drain.
+ * This is desirable for compilations that will be executed often.
+ */
+ ANEURALNETWORKS_PREFER_LOW_POWER = 0,
+ /**
+ * Prefer returning a single answer as fast as possible, even if this causes
+ * more power consumption.
+ */
+ ANEURALNETWORKS_PREFER_FAST_SINGLE_ANSWER = 1,
+ /**
+ * Prefer maximizing the throughput of successive frames, for example when
+ * processing successive frames coming from the camera.
+ */
+ ANEURALNETWORKS_PREFER_SUSTAINED_SPEED = 2,
+};
+
+/**
+ * Result codes.
+ */
+enum {
+ ANEURALNETWORKS_NO_ERROR = 0,
+ ANEURALNETWORKS_OUT_OF_MEMORY = 1,
+ ANEURALNETWORKS_INCOMPLETE = 2,
+ ANEURALNETWORKS_UNEXPECTED_NULL = 3,
+ ANEURALNETWORKS_BAD_DATA = 4,
+ ANEURALNETWORKS_OP_FAILED = 5,
+ ANEURALNETWORKS_UNMAPPABLE = 5,
+ ANEURALNETWORKS_BAD_STATE = 6,
+};
+
+/**
+ * ANeuralNetworksMemory is an opaque type that represents memory.
+ *
+ * This type is used to represent shared memory, memory mapped files,
+ * and similar memories.
+ *
+ * By using shared memory, a program can efficiently communicate to the
+ * runtime and drivers the tensors that define a model. See
+ * {@link ANeuralNetworksModel_setOperandValueFromMemory}. An application
+ * should typically create one shared memory object that contains every tensor
+ * needed to define a model. {@link ANeuralNetworksMemory_createFromFd} can be
+ * used to create shared memory from a file handle. {@link
+ * ANeuralNetworksMemory_createShared} can be used to directly create shared
+ * memory.
+ *
+ * Memory objects can also be used to specify the input and output arguments of
+ * an execution. See {@link ANeuralNetworksExecution_setInputFromMemory}
+ * and {@link ANeuralNetworksExecution_setOutputFromMemory}.
+ */
+typedef struct ANeuralNetworksMemory ANeuralNetworksMemory;
+
+/**
+ * ANeuralNetworksModel is an opaque type that contains a description of the
+ * mathematical operations that constitute the model.
+ *
+ * <p>The model will be built by calling<ul>
+ * <li>{@link ANeuralNetworksModel_create},</li>
+ * <li>{@link ANeuralNetworksModel_addOperation},</li>
+ * <li>{@link ANeuralNetworksModel_addOperand},</li>
+ * </ul>
+ *
+ * A model is completed by calling {@link ANeuralNetworksModel_finish}.
+ * A model is destroyed by calling {@link ANeuralNetworksModel_free}.
+ *
+ * <p>It is the application's responsibility to make sure that only one thread
+ * modifies a model at a given time. It is however safe for more than one
+ * thread to use the model once {@link ANeuralNetworksModel_finish} has
+ * returned.</p>
+ *
+ * <p>It is also the application's responsibility to ensure that there are no
+ * other uses of the model after calling {@link ANeuralNetworksModel_free}. This
+ * includes any compilation or execution object created using the model.</p>
+ */
+typedef struct ANeuralNetworksModel ANeuralNetworksModel;
+
+/**
+ * ANeuralNetworksCompilation is an opaque type that can be used to compile
+ * a machine learning model.
+ *
+ * <p>To use:<ul>
+ * <li>Create a new compilation instance by calling the
+ * {@link ANeuralNetworksCompilation_create} function.</li>
+ * <li>Optionally set preferences with {@link
+ * ANeuralNetworksCompilation_setPreference}.</li>
+ * <li>Complete the compilation with {@link
+ * ANeuralNetworksCompilation_finish}.</li>
+ * <li>Use the compilation as many times as needed with {@link
+ * ANeuralNetworksExecution_create}.</li>
+ * <li>Destroy the compilation with {@link ANeuralNetworksCompilation_free}
+ * once all executions using the compilation have completed.</li></ul></p>
+ *
+ * <p>A compilation cannot be modified once {@link
+ * ANeuralNetworksCompilation_finish} has been called on it.</p>
+ *
+ * <p>It is the application's responsibility to make sure that only one thread
+ * modifies a compilation at a given time. It is however safe for more than one
+ * thread to use a compilation once
+ * {@link ANeuralNetworksCompilation_finish} has returned.</p>
+ *
+ * <p>It is also the application's responsibility to ensure that there are no
+ * other uses of the compilation after calling {@link
+ * ANeuralNetworksCompilation_free}. This includes any execution object created
+ * using the compilation.</p>
+ */
+typedef struct ANeuralNetworksCompilation ANeuralNetworksCompilation;
+
+/**
+ * ANeuralNetworksExecution is an opaque type that can be used to apply a
+ * machine learning model to a set of inputs.
+ *
+ * <p>To use:<ul>
+ * <li>Create a new execution instance by calling the
+ * {@link ANeuralNetworksExecution_create} function.</li>
+ * <li>Associate data to the model inputs with
+ * {@link ANeuralNetworksExecution_setInput} or
+ * {@link ANeuralNetworksExecution_setInputFromMemory}.</li>
+ * <li>Associate output buffers to the model outputs with
+ * {@link ANeuralNetworksExecution_setOutput} or
+ * {@link ANeuralNetworksExecution_setOutputFromMemory}.</li>
+ * <li>Apply the model with {@link
+ * ANeuralNetworksExecution_startCompute}.</li> <li>Wait for the execution to
+ * complete with {@link ANeuralNetworksEvent_wait}.</li> <li>Destroy the
+ * execution with
+ * {@link ANeuralNetworksExecution_free}.</li></ul></p>
+ *
+ * <p>An execution cannot be modified once {@link
+ * ANeuralNetworksExecution_startCompute} has been called on it.</p>
+ *
+ * <p>An execution can be applied to a model with
+ * {@link ANeuralNetworksExecution_startCompute} only once. Create new
+ * executions to do new evaluations of the model.</p>
+ *
+ * <p>It is the application's responsibility to make sure that only one thread
+ * modifies an execution at a given time. It is however safe for more than one
+ * thread to use {@link ANeuralNetworksEvent_wait} at the same time.</p>
+ *
+ * <p>It is also the application's responsibility to ensure that there are no
+ * other uses of the execution after calling {@link
+ * ANeuralNetworksExecution_free}.</p>
+ */
+typedef struct ANeuralNetworksExecution ANeuralNetworksExecution;
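+
+/*
+ * A minimal sketch of the execution flow listed above (illustrative only;
+ * error handling is omitted, and `compilation`, `input_data`, `input_bytes`,
+ * `output_data` and `output_bytes` are assumed to exist and match the model):
+ *
+ *   ANeuralNetworksExecution* execution = nullptr;
+ *   ANeuralNetworksExecution_create(compilation, &execution);
+ *   ANeuralNetworksExecution_setInput(execution, 0, nullptr,
+ *                                     input_data, input_bytes);
+ *   ANeuralNetworksExecution_setOutput(execution, 0, nullptr,
+ *                                      output_data, output_bytes);
+ *   ANeuralNetworksEvent* event = nullptr;
+ *   ANeuralNetworksExecution_startCompute(execution, &event);
+ *   ANeuralNetworksEvent_wait(event);
+ *   ANeuralNetworksEvent_free(event);
+ *   ANeuralNetworksExecution_free(execution);
+ */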
+
+/**
+ * ANeuralNetworksOperandType describes the type of an operand.
+ * This structure is used to describe both scalars and tensors.
+ */
+typedef struct ANeuralNetworksOperandType {
+ /** The data type, e.g. ANEURALNETWORKS_FLOAT32. */
+ int32_t type;
+ /** The number of dimensions. It should be 0 for scalars. */
+ uint32_t dimensionCount;
+ /** The dimensions of the tensor. It should be nullptr for scalars. */
+ const uint32_t* dimensions;
+ /** These two fields are only used for quantized tensors.
+ * They should be zero for scalars and non-fixed point tensors.
+ * The dequantized value of each entry is (value - zeroPoint) * scale.
+ */
+ float scale;
+ int32_t zeroPoint;
+} ANeuralNetworksOperandType;
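+
+/*
+ * For example, a quantized tensor operand could be described as follows (an
+ * illustrative sketch; the shape and quantization parameters are made up).
+ * Each stored value v then dequantizes to (v - zeroPoint) * scale:
+ *
+ *   uint32_t dims[4] = {1, 224, 224, 3};
+ *   ANeuralNetworksOperandType quant8_type = {
+ *       ANEURALNETWORKS_TENSOR_QUANT8_ASYMM,  // type
+ *       4,                                    // dimensionCount
+ *       dims,                                 // dimensions
+ *       0.0078125f,                           // scale
+ *       128};                                 // zeroPoint
+ */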
+
+/**
+ * ANeuralNetworksEvent is an opaque type that represents an event
+ * that will be signaled once an execution completes.
+ */
+typedef struct ANeuralNetworksEvent ANeuralNetworksEvent;
+
+typedef int32_t ANeuralNetworksOperationType;
+
+// nn api function types
+
+typedef int (*ANeuralNetworksMemory_createFromFd_fn)(
+ size_t size, int protect, int fd, size_t offset,
+ ANeuralNetworksMemory** memory);
+
+typedef void (*ANeuralNetworksMemory_free_fn)(ANeuralNetworksMemory* memory);
+
+typedef int (*ANeuralNetworksModel_create_fn)(ANeuralNetworksModel** model);
+
+typedef int (*ANeuralNetworksModel_finish_fn)(ANeuralNetworksModel* model);
+
+typedef void (*ANeuralNetworksModel_free_fn)(ANeuralNetworksModel* model);
+
+typedef int (*ANeuralNetworksCompilation_create_fn)(
+ ANeuralNetworksModel* model, ANeuralNetworksCompilation** compilation);
+
+typedef void (*ANeuralNetworksCompilation_free_fn)(
+ ANeuralNetworksCompilation* compilation);
+
+typedef int (*ANeuralNetworksCompilation_setPreference_fn)(
+ ANeuralNetworksCompilation* compilation, int32_t preference);
+
+typedef int (*ANeuralNetworksCompilation_finish_fn)(
+ ANeuralNetworksCompilation* compilation);
+
+typedef int (*ANeuralNetworksModel_addOperand_fn)(
+ ANeuralNetworksModel* model, const ANeuralNetworksOperandType* type);
+
+typedef int (*ANeuralNetworksModel_setOperandValue_fn)(
+ ANeuralNetworksModel* model, int32_t index, const void* buffer,
+ size_t length);
+
+typedef int (*ANeuralNetworksModel_setOperandValueFromMemory_fn)(
+ ANeuralNetworksModel* model, int32_t index,
+ const ANeuralNetworksMemory* memory, size_t offset, size_t length);
+
+typedef int (*ANeuralNetworksModel_addOperation_fn)(
+ ANeuralNetworksModel* model, ANeuralNetworksOperationType type,
+ uint32_t inputCount, const uint32_t* inputs, uint32_t outputCount,
+ const uint32_t* outputs);
+
+typedef int (*ANeuralNetworksModel_identifyInputsAndOutputs_fn)(
+ ANeuralNetworksModel* model, uint32_t inputCount, const uint32_t* inputs,
+ uint32_t outputCount, const uint32_t* outputs);
+
+typedef int (*ANeuralNetworksExecution_create_fn)(
+ ANeuralNetworksCompilation* compilation,
+ ANeuralNetworksExecution** execution);
+
+typedef void (*ANeuralNetworksExecution_free_fn)(
+ ANeuralNetworksExecution* execution);
+
+typedef int (*ANeuralNetworksExecution_setInput_fn)(
+ ANeuralNetworksExecution* execution, int32_t index,
+ const ANeuralNetworksOperandType* type, const void* buffer, size_t length);
+
+typedef int (*ANeuralNetworksExecution_setInputFromMemory_fn)(
+ ANeuralNetworksExecution* execution, int32_t index,
+ const ANeuralNetworksOperandType* type, const ANeuralNetworksMemory* memory,
+ size_t offset, size_t length);
+
+typedef int (*ANeuralNetworksExecution_setOutput_fn)(
+ ANeuralNetworksExecution* execution, int32_t index,
+ const ANeuralNetworksOperandType* type, void* buffer, size_t length);
+
+typedef int (*ANeuralNetworksExecution_setOutputFromMemory_fn)(
+ ANeuralNetworksExecution* execution, int32_t index,
+ const ANeuralNetworksOperandType* type, const ANeuralNetworksMemory* memory,
+ size_t offset, size_t length);
+
+typedef int (*ANeuralNetworksExecution_startCompute_fn)(
+ ANeuralNetworksExecution* execution, ANeuralNetworksEvent** event);
+
+typedef int (*ANeuralNetworksEvent_wait_fn)(ANeuralNetworksEvent* event);
+
+typedef void (*ANeuralNetworksEvent_free_fn)(ANeuralNetworksEvent* event);
+
+/**
+ * Creates a shared memory object from a file descriptor.
+ *
+ * The shared memory is backed by a file descriptor via mmap.
+ * See {@link ANeuralNetworksMemory} for a description on how to use
+ * this shared memory.
+ *
+ * @param size The requested size in bytes.
+ * Must not be larger than the file size.
+ * @param protect The desired memory protection for the mapping.
+ * It is either PROT_NONE or the bitwise OR of one or
+ * more of the following flags: PROT_READ, PROT_WRITE.
+ * @param fd The requested file descriptor.
+ * The file descriptor has to be mmap-able. The file
+ * descriptor will be duplicated.
+ * @param offset The offset to the beginning of the file of the area to map.
+ * The offset has to be aligned to a page size.
+ * @param memory The memory object to be created.
+ * Set to NULL if unsuccessful.
+ *
+ * @return ANEURALNETWORKS_NO_ERROR if the request completed normally.
+ */
+inline int ANeuralNetworksMemory_createFromFd(size_t size, int protect, int fd,
+ size_t offset,
+ ANeuralNetworksMemory** memory) {
+ LOAD_FUNCTION(ANeuralNetworksMemory_createFromFd);
+ EXECUTE_FUNCTION_RETURN(size, protect, fd, offset, memory);
+}
+
+/**
+ * Delete a memory object.
+ *
+ * Destroys the object used by the run time to keep track of the memory.
+ * This will free the underlying actual memory if no other code has open
+ * handles to this memory.
+ *
+ * @param memory The memory object to be freed.
+ */
+inline void ANeuralNetworksMemory_free(ANeuralNetworksMemory* memory) {
+ LOAD_FUNCTION(ANeuralNetworksMemory_free);
+ EXECUTE_FUNCTION(memory);
+}
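+
+/*
+ * A minimal sketch of the shared-memory lifecycle (illustrative only; the
+ * path is made up, `file_size` is assumed to hold the file's size in bytes,
+ * and error handling is omitted):
+ *
+ *   int fd = open("/path/to/weights.bin", O_RDONLY);
+ *   ANeuralNetworksMemory* memory = nullptr;
+ *   ANeuralNetworksMemory_createFromFd(file_size, PROT_READ, fd, 0, &memory);
+ *   // ... hand slices of `memory` to the model or to executions ...
+ *   ANeuralNetworksMemory_free(memory);
+ *   close(fd);  // safe: the runtime duplicates the file descriptor
+ */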
+
+/**
+ * Create an empty {@link ANeuralNetworksModel}.
+ *
+ * <p>This only creates the object. Computation is performed once
+ * {@link ANeuralNetworksExecution_startCompute} is invoked.
+ *
+ * The model should be constructed with calls to
+ * {@link ANeuralNetworksModel_addOperation} and
+ * {@link ANeuralNetworksModel_addOperand}
+ *
+ * <p>{@link ANeuralNetworksModel_finish} should be called once the model
+ * has been fully constructed.</p>
+ *
+ * <p>{@link ANeuralNetworksModel_free} should be called once the model
+ * is no longer needed.</p>
+ *
+ * @param model The {@link ANeuralNetworksModel} to be created.
+ * Set to NULL if unsuccessful.
+ *
+ * @return ANEURALNETWORKS_NO_ERROR if successful.
+ */
+inline int ANeuralNetworksModel_create(ANeuralNetworksModel** model) {
+ LOAD_FUNCTION(ANeuralNetworksModel_create);
+ EXECUTE_FUNCTION_RETURN(model);
+}
+
+/**
+ * Destroy a model.
+ *
+ * The model need not have been finished by a call to
+ * {@link ANeuralNetworksModel_finish}.
+ *
+ * See {@link ANeuralNetworksModel} for information on multithreaded usage.
+ *
+ * @param model The model to be destroyed. Passing NULL is acceptable and
+ * results in no operation.
+ */
+inline void ANeuralNetworksModel_free(ANeuralNetworksModel* model) {
+ LOAD_FUNCTION(ANeuralNetworksModel_free);
+ EXECUTE_FUNCTION(model);
+}
+
+/**
+ * Indicate that we have finished modifying a model. Required before
+ * calling {@link ANeuralNetworksCompilation_create}.
+ *
+ * An application is responsible for making sure that no other thread uses
+ * the model at the same time.
+ *
+ * See {@link ANeuralNetworksModel} for information on multithreaded usage.
+ *
+ * @param model The model to be finished.
+ *
+ * @return ANEURALNETWORKS_NO_ERROR if successful.
+ */
+inline int ANeuralNetworksModel_finish(ANeuralNetworksModel* model) {
+ LOAD_FUNCTION(ANeuralNetworksModel_finish);
+ EXECUTE_FUNCTION_RETURN(model);
+}
+
+/**
+ * Add an operand to a model.
+ *
+ * The order in which the operands are added is important. The first one added
+ * to a model will have the index value 0, the second 1, etc. These indexes are
+ * used as operand identifiers in {@link ANeuralNetworksModel_addOperation},
+ * {@link ANeuralNetworksExecution_setInput},
+ * {@link ANeuralNetworksExecution_setInputFromMemory},
+ * {@link ANeuralNetworksExecution_setOutput},
+ * {@link ANeuralNetworksExecution_setOutputFromMemory} and
+ * {@link ANeuralNetworksModel_setOperandValue}.
+ *
+ * To build a model that can accommodate inputs of various sizes, as you may want
+ * to do for a CNN, set the size of the dimensions that will vary at run time to
+ * 0. If you do so, provide the full dimensions when calling
+ * {@link ANeuralNetworksExecution_setInput} or {@link
+ * ANeuralNetworksExecution_setInputFromMemory}.
+ *
+ * Attempting to modify a model once {@link ANeuralNetworksModel_finish} has
+ * been called will return an error.
+ *
+ * See {@link ANeuralNetworksModel} for information on multithreaded usage.
+ *
+ * @param model The model to be modified.
+ * @param type The {@link ANeuralNetworksOperandType} that describes the shape
+ * of the operand.
+ *
+ * @return ANEURALNETWORKS_NO_ERROR if successful.
+ */
+inline int ANeuralNetworksModel_addOperand(
+ ANeuralNetworksModel* model, const ANeuralNetworksOperandType* type) {
+ LOAD_FUNCTION(ANeuralNetworksModel_addOperand);
+ EXECUTE_FUNCTION_RETURN(model, type);
+}
+
+/**
+ * Sets an operand to a constant value.
+ *
+ * For scalar values, the content of buffer is copied into the model.
+ *
+ * For tensor values, a pointer to the buffer is stored within the model.
+ * The application is responsible for not changing the content of this region
+ * until all executions using this model have completed. As the data may
+ * be copied during processing, modifying the data after this call yields
+ * undefined results.
+ *
+ * Attempting to modify a model once {@link ANeuralNetworksModel_finish} has
+ * been called will return an error.
+ *
+ * See {@link ANeuralNetworksModel} for information on multithreaded usage.
+ *
+ * @param model The model to be modified.
+ * @param index The index of the model operand we're setting.
+ * @param buffer A pointer to the data to use.
+ * @param length The size in bytes of the data value.
+ *
+ * @return ANEURALNETWORKS_NO_ERROR if successful.
+ */
+inline int ANeuralNetworksModel_setOperandValue(ANeuralNetworksModel* model,
+ int32_t index,
+ const void* buffer,
+ size_t length) {
+ LOAD_FUNCTION(ANeuralNetworksModel_setOperandValue);
+ EXECUTE_FUNCTION_RETURN(model, index, buffer, length);
+}
+
+/**
+ * Sets an operand to a value stored in a memory object.
+ *
+ * The content of the memory is not copied. A reference to that memory is stored
+ * inside the model. The application is responsible for not changing the content
+ * of the memory region until all executions using this model have completed.
+ * As the data may be copied during processing, modifying the data after this
+ * call yields undefined results.
+ *
+ * Attempting to modify a model once {@link ANeuralNetworksModel_finish} has
+ * been called will return an error.
+ *
+ * See {@link ANeuralNetworksModel} for information on multithreaded usage.
+ *
+ * @param model The model to be modified.
+ * @param index The index of the model operand we're setting.
+ * @param memory The memory containing the data.
+ * @param offset This specifies the location of the data within the memory.
+ * The offset is in bytes from the start of memory.
+ * @param length The size in bytes of the data value.
+ *
+ * @return ANEURALNETWORKS_NO_ERROR if successful.
+ */
+inline int ANeuralNetworksModel_setOperandValueFromMemory(
+ ANeuralNetworksModel* model, int32_t index,
+ const ANeuralNetworksMemory* memory, size_t offset, size_t length) {
+ LOAD_FUNCTION(ANeuralNetworksModel_setOperandValueFromMemory);
+ EXECUTE_FUNCTION_RETURN(model, index, memory, offset, length);
+}
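+
+/*
+ * For example, a constant weights operand can reference a slice of a shared
+ * memory object instead of a heap buffer (an illustrative sketch; `memory`,
+ * `weights_offset` and `weights_bytes` are assumed to describe where the
+ * weights live inside the mapped file):
+ *
+ *   ANeuralNetworksModel_setOperandValueFromMemory(
+ *       model, 3, memory, weights_offset, weights_bytes);  // operand index 3
+ */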
+
+/**
+ * Add an operation to a model.
+ *
+ * @param model The model to be modified.
+ * @param type The type of the operation.
+ * @param inputCount The number of entries in the inputs array.
+ * @param inputs An array of indexes identifying each operand.
+ * @param outputCount The number of entries in the outputs array.
+ * @param outputs An array of indexes identifying each operand.
+ *
+ * The operands specified by inputs and outputs must have been
+ * previously added by calls to {@link ANeuralNetworksModel_addOperand}.
+ *
+ * Attempting to modify a model once {@link ANeuralNetworksModel_finish} has
+ * been called will return an error.
+ *
+ * See {@link ANeuralNetworksModel} for information on multithreaded usage.
+ *
+ * @return ANEURALNETWORKS_NO_ERROR if successful.
+ */
+inline int ANeuralNetworksModel_addOperation(ANeuralNetworksModel* model,
+ ANeuralNetworksOperationType type,
+ uint32_t inputCount,
+ const uint32_t* inputs,
+ uint32_t outputCount,
+ const uint32_t* outputs) {
+ LOAD_FUNCTION(ANeuralNetworksModel_addOperation);
+ EXECUTE_FUNCTION_RETURN(model, type, inputCount, inputs, outputCount,
+ outputs);
+}
+
+/**
+ * Specifies which operands will be the model's inputs and outputs.
+ *
+ * An operand cannot be used for both input and output. Doing so will
+ * return an error.
+ *
+ * @param model The model to be modified.
+ * @param inputCount The number of entries in the inputs array.
+ * @param inputs An array of indexes identifying the input operands.
+ * @param outputCount The number of entries in the outputs array.
+ * @param outputs An array of indexes identifying the output operands.
+ *
+ * The operands specified by inputs and outputs must have been
+ * previously added by calls to {@link ANeuralNetworksModel_addOperand}.
+ *
+ * Attempting to modify a model once {@link ANeuralNetworksModel_finish} has
+ * been called will return an error.
+ *
+ * See {@link ANeuralNetworksModel} for information on multithreaded usage.
+ *
+ * @return ANEURALNETWORKS_NO_ERROR if successful.
+ */
+inline int ANeuralNetworksModel_identifyInputsAndOutputs(
+ ANeuralNetworksModel* model, uint32_t inputCount, const uint32_t* inputs,
+ uint32_t outputCount, const uint32_t* outputs) {
+ LOAD_FUNCTION(ANeuralNetworksModel_identifyInputsAndOutputs);
+ EXECUTE_FUNCTION_RETURN(model, inputCount, inputs, outputCount, outputs);
+}
+
+/**
+ * Create a {@link ANeuralNetworksCompilation} to compile the given model.
+ * This only creates the object. Compilation is only performed once
+ * {@link ANeuralNetworksCompilation_finish} is invoked.
+ *
+ * <p>The provided model must outlive the compilation.</p>
+ *
+ * The model must already have been finished by a call to
+ * {@link ANeuralNetworksModel_finish}.
+ *
+ * See {@link ANeuralNetworksCompilation} for information on multithreaded
+ * usage.
+ *
+ * @param model The {@link ANeuralNetworksModel} to be compiled.
+ * @param compilation The newly created object or NULL if unsuccessful.
+ *
+ * @return ANEURALNETWORKS_NO_ERROR if successful, ANEURALNETWORKS_BAD_DATA
+ * if the model is invalid.
+ */
+inline int ANeuralNetworksCompilation_create(
+ ANeuralNetworksModel* model, ANeuralNetworksCompilation** compilation) {
+ LOAD_FUNCTION(ANeuralNetworksCompilation_create);
+ EXECUTE_FUNCTION_RETURN(model, compilation);
+}
+
+/**
+ * Destroy a compilation.
+ *
+ * <p>If called on a compilation for which
+ * {@link ANeuralNetworksCompilation_finish} has been called, the
+ * function will return immediately but will mark the compilation to be deleted
+ * once the compilation completes.
+ *
+ * See {@link ANeuralNetworksCompilation} for information on multithreaded
+ * usage.
+ *
+ * @param compilation The compilation to be destroyed. Passing NULL is
+ * acceptable and results in no operation.
+ */
+inline void ANeuralNetworksCompilation_free(
+ ANeuralNetworksCompilation* compilation) {
+ LOAD_FUNCTION(ANeuralNetworksCompilation_free);
+ EXECUTE_FUNCTION(compilation);
+}
+
+/**
+ * Sets the execution preference.
+ *
+ * <p>Provides guidance to the runtime when trade-offs are possible.</p>
+ *
+ * See {@link ANeuralNetworksCompilation} for information on multithreaded
+ * usage.
+ *
+ * @param compilation The compilation to be modified.
+ * @param preference Either {@link ANEURALNETWORKS_PREFER_LOW_POWER},
+ * {@link ANEURALNETWORKS_PREFER_FAST_SINGLE_ANSWER}, or
+ * {@link ANEURALNETWORKS_PREFER_SUSTAINED_SPEED}.
+ *
+ * @return ANEURALNETWORKS_NO_ERROR if successful.
+ */
+inline int ANeuralNetworksCompilation_setPreference(
+ ANeuralNetworksCompilation* compilation, int32_t preference) {
+ LOAD_FUNCTION(ANeuralNetworksCompilation_setPreference);
+ EXECUTE_FUNCTION_RETURN(compilation, preference);
+}
+
+/**
+ * Indicate that we have finished modifying a compilation. Required before
+ * calling {@link ANeuralNetworksExecution_create}.
+ *
+ * An application is responsible for making sure that no other thread uses
+ * the compilation at the same time.
+ *
+ * See {@link ANeuralNetworksCompilation} for information on multithreaded
+ * usage.
+ *
+ * @return ANEURALNETWORKS_NO_ERROR if the compilation completed normally.
+ */
+inline int ANeuralNetworksCompilation_finish(
+ ANeuralNetworksCompilation* compilation) {
+ LOAD_FUNCTION(ANeuralNetworksCompilation_finish);
+ EXECUTE_FUNCTION_RETURN(compilation);
+}
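+
+/*
+ * A minimal sketch of the compilation flow (illustrative only; `model` must
+ * already have been finished with ANeuralNetworksModel_finish, and error
+ * handling is omitted):
+ *
+ *   ANeuralNetworksCompilation* compilation = nullptr;
+ *   ANeuralNetworksCompilation_create(model, &compilation);
+ *   ANeuralNetworksCompilation_setPreference(
+ *       compilation, ANEURALNETWORKS_PREFER_SUSTAINED_SPEED);
+ *   ANeuralNetworksCompilation_finish(compilation);
+ *   // ... create and run executions ...
+ *   ANeuralNetworksCompilation_free(compilation);
+ */
+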
+/**
+ * Create a {@link ANeuralNetworksExecution} to apply the given compilation.
+ * This only creates the object. Computation is only performed once
+ * {@link ANeuralNetworksExecution_startCompute} is invoked.
+ *
+ * <p>The provided compilation must outlive the execution.</p>
+ *
+ * See {@link ANeuralNetworksExecution} for information on multithreaded usage.
+ *
+ * @param compilation The {@link ANeuralNetworksCompilation} to be evaluated.
+ * @param execution The newly created object or NULL if unsuccessful.
+ *
+ * @return ANEURALNETWORKS_NO_ERROR if successful, ANEURALNETWORKS_BAD_DATA
+ * if the compilation is invalid.
+ */
+inline int ANeuralNetworksExecution_create(
+ ANeuralNetworksCompilation* compilation,
+ ANeuralNetworksExecution** execution) {
+ LOAD_FUNCTION(ANeuralNetworksExecution_create);
+ EXECUTE_FUNCTION_RETURN(compilation, execution);
+}
+
+/**
+ * Destroy an execution.
+ *
+ * <p>If called on an execution for which
+ * {@link ANeuralNetworksExecution_startCompute} has been called, the
+ * function will return immediately but will mark the execution to be deleted
+ * once the computation completes. A subsequent
+ * {@link ANeuralNetworksEvent_wait} on the corresponding event will then
+ * report an error.
+ *
+ * See {@link ANeuralNetworksExecution} for information on multithreaded usage.
+ *
+ * @param execution The execution to be destroyed. Passing NULL is acceptable
+ * and results in no operation.
+ */
+inline void ANeuralNetworksExecution_free(ANeuralNetworksExecution* execution) {
+ LOAD_FUNCTION(ANeuralNetworksExecution_free);
+ EXECUTE_FUNCTION(execution);
+}
+
+/**
+ * Associate a user buffer with an input of the model of the
+ * {@link ANeuralNetworksExecution}.
+ *
+ * <p>The provided buffer must outlive the execution.</p>
+ *
+ * See {@link ANeuralNetworksExecution} for information on multithreaded usage.
+ *
+ * @param execution The execution to be modified.
+ * @param index The index of the input argument we are setting. It is
+ * an index into the lists passed to
+ * {@link ANeuralNetworksModel_identifyInputsAndOutputs}. It is not
+ * the index associated with {@link
+ * ANeuralNetworksModel_addOperand}.
+ * @param type The type of the operand. This should be used to specify the
+ * dimensions that were set to 0 when the operand was added to the
+ * model. All other properties of the type must be the same as
+ * specified in the model. If the type is the same as specified
+ * when the model was built, NULL can be passed.
+ * @param buffer The buffer containing the data.
+ * @param length The length in bytes of the buffer.
+ *
+ * @return ANEURALNETWORKS_NO_ERROR if successful, ANEURALNETWORKS_BAD_DATA if
+ * the index is not recognized or the buffer is too small for the input.
+ */
+inline int ANeuralNetworksExecution_setInput(
+ ANeuralNetworksExecution* execution, int32_t index,
+ const ANeuralNetworksOperandType* type, const void* buffer, size_t length) {
+ LOAD_FUNCTION(ANeuralNetworksExecution_setInput);
+ EXECUTE_FUNCTION_RETURN(execution, index, type, buffer, length);
+}
+
+/**
+ * Associate part of a memory object with an input of the model of the
+ * {@link ANeuralNetworksExecution}.
+ *
+ * <p>The provided memory must outlive the execution.</p>
+ *
+ * See {@link ANeuralNetworksExecution} for information on multithreaded usage.
+ *
+ * @param execution The execution to be modified.
+ * @param index The index of the input argument we are setting. It is
+ * an index into the lists passed to
+ * {@link ANeuralNetworksModel_identifyInputsAndOutputs}. It is not
+ * the index associated with {@link
+ * ANeuralNetworksModel_addOperand}.
+ * @param type The type of the operand. This can be used to specify the
+ * dimensions that were set to 0 when the operand was added to the
+ * model. All other values must be the same as specified in the
+ * model. If the type is the same as specified when the model
+ * was built, NULL can be passed.
+ * @param memory The memory containing the data.
+ * @param offset This specifies the location of the data within the memory.
+ * The offset is in bytes from the start of memory.
+ * @param length The size in bytes of the data value.
+ *
+ * @return ANEURALNETWORKS_NO_ERROR if successful, ANEURALNETWORKS_BAD_DATA if
+ * the index is not recognized or the buffer is too small for the input.
+ */
+inline int ANeuralNetworksExecution_setInputFromMemory(
+ ANeuralNetworksExecution* execution, int32_t index,
+ const ANeuralNetworksOperandType* type, const ANeuralNetworksMemory* memory,
+ size_t offset, size_t length) {
+ LOAD_FUNCTION(ANeuralNetworksExecution_setInputFromMemory);
+ EXECUTE_FUNCTION_RETURN(execution, index, type, memory, offset, length);
+}
+
+/**
+ * Associate a user buffer with an output of the model of the
+ * {@link ANeuralNetworksExecution}.
+ *
+ * <p>The provided buffer must outlive the execution.</p>
+ *
+ * See {@link ANeuralNetworksExecution} for information on multithreaded usage.
+ *
+ * @param execution The execution to be modified.
+ * @param index The index of the output argument we are setting. It is
+ * an index into the lists passed to
+ * {@link ANeuralNetworksModel_identifyInputsAndOutputs}. It is not
+ * the index associated with {@link
+ * ANeuralNetworksModel_addOperand}.
+ * @param type The type of the operand. This can be used to specify the
+ * dimensions that were set to 0 when the operand was added to the
+ * model. All other values must be the same as specified in the
+ * model. If the type is the same as specified when the model
+ * was built, NULL can be passed.
+ * @param buffer The buffer where the data is to be written.
+ * @param length The length in bytes of the buffer.
+ *
+ * @return ANEURALNETWORKS_NO_ERROR if successful, ANEURALNETWORKS_BAD_DATA if
+ * the index is not recognized or the buffer is too small for the output.
+ */
+inline int ANeuralNetworksExecution_setOutput(
+ ANeuralNetworksExecution* execution, int32_t index,
+ const ANeuralNetworksOperandType* type, void* buffer, size_t length) {
+ LOAD_FUNCTION(ANeuralNetworksExecution_setOutput);
+ EXECUTE_FUNCTION_RETURN(execution, index, type, buffer, length);
+}
+
+/**
+ * Associate part of a memory object with an output of the model of the
+ * {@link ANeuralNetworksExecution}.
+ *
+ * <p>The provided memory must outlive the execution.</p>
+ *
+ * See {@link ANeuralNetworksExecution} for information on multithreaded usage.
+ *
+ * @param execution The execution to be modified.
+ * @param index The index of the output argument we are setting. It is
+ * an index into the lists passed to
+ * {@link ANeuralNetworksModel_identifyInputsAndOutputs}. It is not
+ * the index associated with {@link
+ * ANeuralNetworksModel_addOperand}.
+ * @param type The type of the operand. This can be used to specify the
+ * dimensions that were set to 0 when the operand was added to the
+ * model. All other values must be the same as specified in the
+ * model. If the type is the same as specified when the model
+ * was built, NULL can be passed.
+ * @param memory The memory where the data is to be stored.
+ * @param offset This specifies the location of the data within the memory.
+ * The offset is in bytes from the start of memory.
+ * @param length The length in bytes of the data value.
+ *
+ * @return ANEURALNETWORKS_NO_ERROR if successful, ANEURALNETWORKS_BAD_DATA if
+ * the index is not recognized or the buffer is too small for the output.
+ */
+inline int ANeuralNetworksExecution_setOutputFromMemory(
+ ANeuralNetworksExecution* execution, int32_t index,
+ const ANeuralNetworksOperandType* type, const ANeuralNetworksMemory* memory,
+ size_t offset, size_t length) {
+ LOAD_FUNCTION(ANeuralNetworksExecution_setOutputFromMemory);
+ EXECUTE_FUNCTION_RETURN(execution, index, type, memory, offset, length);
+}
+
+/**
+ * Schedule evaluation of the execution.
+ *
+ * <p>Schedules evaluation of the execution. Once the model has been
+ * applied and the outputs are ready to be consumed, the returned event will
+ * be signaled. Use {@link ANeuralNetworksEvent_wait} to wait for that signal.
+ * </p>
+ *
+ * Multiple executions can be scheduled and evaluated concurrently. The
+ * runtime makes no guarantee on the ordering of their completion. If it is
+ * important to the application, the application should enforce the ordering
+ * by waiting on each execution's event with {@link ANeuralNetworksEvent_wait}.
+ *
+ * {@link ANeuralNetworksEvent_wait} must be called to reclaim the resources
+ * used by the execution.
+ *
+ * See {@link ANeuralNetworksExecution} for information on multithreaded usage.
+ *
+ * @param execution The execution to be scheduled and executed.
+ *
+ * @return ANEURALNETWORKS_NO_ERROR if successful.
+ */
+inline int ANeuralNetworksExecution_startCompute(
+ ANeuralNetworksExecution* execution, ANeuralNetworksEvent** event) {
+ LOAD_FUNCTION(ANeuralNetworksExecution_startCompute);
+ EXECUTE_FUNCTION_RETURN(execution, event);
+}
+
+/**
+ * Waits until the execution completes.
+ *
+ * More than one thread can wait on an event. When the execution completes,
+ * all threads will be released.
+ *
+ * See {@link ANeuralNetworksExecution} for information on multithreaded usage.
+ *
+ * @return ANEURALNETWORKS_NO_ERROR if the execution completed normally.
+ */
+inline int ANeuralNetworksEvent_wait(ANeuralNetworksEvent* event) {
+ LOAD_FUNCTION(ANeuralNetworksEvent_wait);
+ EXECUTE_FUNCTION_RETURN(event);
+}
+
+/**
+ * Destroys the event.
+ *
+ * See {@link ANeuralNetworksExecution} for information on multithreaded usage.
+ */
+inline void ANeuralNetworksEvent_free(ANeuralNetworksEvent* event) {
+ LOAD_FUNCTION(ANeuralNetworksEvent_free);
+ EXECUTE_FUNCTION(event);
+}
+
+/**/
+
+#endif // NN_API_SHIM_H0
diff --git a/tensorflow/contrib/lite/nnapi_delegate.cc b/tensorflow/contrib/lite/nnapi_delegate.cc
new file mode 100644
index 0000000000..6a199cc840
--- /dev/null
+++ b/tensorflow/contrib/lite/nnapi_delegate.cc
@@ -0,0 +1,386 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/nnapi_delegate.h"
+#include <fcntl.h>
+#include <sys/mman.h>
+#include <sys/stat.h>
+#include <sys/types.h>
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/error_reporter.h"
+#include "tensorflow/contrib/lite/model.h"
+#include "tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h"
+
+namespace tflite {
+
+// TODO(aselle): FATAL leaves resources hanging.
+void FATAL(const char* format, ...) {
+ va_list args;
+ va_start(args, format);
+ vfprintf(stderr, format, args);
+ va_end(args);
+ fflush(stderr);
+ exit(1);
+}
+
+// TODO(aselle): Change the error model to use status codes.
+#define CHECK_TFLITE_SUCCESS(x) \
+ if (x != kTfLiteOk) { \
+ FATAL("Aborting since tflite returned failure."); \
+ }
+
+#define CHECK_NN(x) \
+ if (x != ANEURALNETWORKS_NO_ERROR) { \
+ FATAL("Aborting since NNAPI returned failure."); \
+ }
+
+NNAPIAllocation::NNAPIAllocation(const char* filename,
+ ErrorReporter* error_reporter)
+ : MMAPAllocation(filename, error_reporter) {
+ if (mmapped_buffer_ != MAP_FAILED)
+ CHECK_NN(ANeuralNetworksMemory_createFromFd(buffer_size_bytes_, PROT_READ,
+ mmap_fd_, 0, &handle_));
+}
+
+NNAPIAllocation::~NNAPIAllocation() {
+ if (handle_) {
+ ANeuralNetworksMemory_free(handle_);
+ }
+}
+
+NNAPIDelegate::~NNAPIDelegate() {
+ if (nn_model_) {
+ ANeuralNetworksModel_free(nn_model_);
+ nn_model_ = nullptr;
+ // TODO(aselle): Is this thread-safe and callable multiple times?
+ }
+ // ANeuralNetworksShutdown();
+}
+
+// Adds the tensors of the interpreter to the NN API model.
+// Returns the number of operands added.
+uint32_t addTensorOperands(tflite::Interpreter* interpreter,
+ ANeuralNetworksModel* nn_model) {
+ uint32_t next_id = 0;
+ for (size_t i = 0; i < interpreter->tensors_size(); i++) {
+ int32_t nn_type = 0;
+ float scale = 1.0f;
+ int32_t zeroPoint = 0;
+ TfLiteTensor* tensor = interpreter->tensor(i);
+ switch (tensor->type) {
+ case kTfLiteNoType:
+ // Tensors added during initialization of Ops don't have a type yet and
+ // should not be registered with the NNAPI.
+ continue;
+ case kTfLiteFloat32:
+ nn_type = ANEURALNETWORKS_TENSOR_FLOAT32;
+ break;
+ case kTfLiteUInt8:
+ nn_type = ANEURALNETWORKS_TENSOR_QUANT8_ASYMM;
+ scale = tensor->params.scale;
+ zeroPoint = tensor->params.zero_point;
+ break;
+ case kTfLiteInt32:
+ nn_type = ANEURALNETWORKS_TENSOR_INT32;
+ scale = tensor->params.scale;
+ zeroPoint = tensor->params.zero_point;
+ break;
+ default:
+ FATAL("Unsupported type.");
+ }
+ // TODO(aselle): Note, many of these are intermediate results. Do I ever
+ // need to specify these sizes? I am currently calling setValue on all of
+ // them below, but I shouldn't need to in the future.
+ // Answer(jeanluc): If all the operators can set the dimension correctly,
+ // you won't need to.
+ ANeuralNetworksOperandType operand_type{
+ nn_type, static_cast<uint32_t>(tensor->dims->size),
+ reinterpret_cast<uint32_t*>(tensor->dims->data), scale, zeroPoint};
+ CHECK_NN(ANeuralNetworksModel_addOperand(nn_model, &operand_type));
+
+ // TODO(aselle): Based on Michael's suggestion, limiting this to read
+ // only memory
+ if (tensor->allocation_type == kTfLiteMmapRo) {
+ if (const NNAPIAllocation* alloc = dynamic_cast<const NNAPIAllocation*>(
+ static_cast<const Allocation*>(tensor->allocation))) {
+ CHECK_NN(ANeuralNetworksModel_setOperandValueFromMemory(
+ nn_model, i, alloc->memory(), alloc->offset(tensor->data.raw),
+ tensor->bytes));
+ } else {
+ CHECK_NN(ANeuralNetworksModel_setOperandValue(
+ nn_model, i, tensor->data.raw, tensor->bytes));
+ }
+ }
+ ++next_id;
+ }
+ return next_id;
+}
+
+// Adds the operations and their parameters to the NN API model.
+// 'next_id' is the operand ID of the next operand to be added to the model.
+void AddOpsAndParams(tflite::Interpreter* interpreter,
+ ANeuralNetworksModel* nn_model, uint32_t next_id) {
+ for (size_t i = 0; i < interpreter->nodes_size(); i++) {
+ const auto* node_and_registration = interpreter->node_and_registration(i);
+ const TfLiteNode& node = node_and_registration->first;
+ const TfLiteRegistration& registration = node_and_registration->second;
+ tflite::BuiltinOperator builtin =
+ static_cast<tflite::BuiltinOperator>(registration.builtin_code);
+
+ // Add the parameters.
+ std::vector<uint32_t> augmented_inputs(
+ node.inputs->data, node.inputs->data + node.inputs->size);
+
+ auto add_scalar_int32 = [&nn_model, &augmented_inputs,
+ &next_id](int value) {
+ ANeuralNetworksOperandType operand_type{.type = ANEURALNETWORKS_INT32};
+ CHECK_NN(ANeuralNetworksModel_addOperand(nn_model, &operand_type))
+ CHECK_NN(ANeuralNetworksModel_setOperandValue(nn_model, next_id, &value,
+ sizeof(int32_t)))
+ augmented_inputs.push_back(next_id++);
+ };
+
+ auto add_scalar_float32 = [&nn_model, &augmented_inputs,
+ &next_id](float value) {
+ ANeuralNetworksOperandType operand_type{.type = ANEURALNETWORKS_FLOAT32};
+ CHECK_NN(ANeuralNetworksModel_addOperand(nn_model, &operand_type))
+ CHECK_NN(ANeuralNetworksModel_setOperandValue(nn_model, next_id, &value,
+ sizeof(float)))
+ augmented_inputs.push_back(next_id++);
+ };
+
+ auto add_add_params = [&add_scalar_int32]() { add_scalar_int32(0); };
+
+ auto add_pooling_params = [&add_scalar_int32](void* data) {
+ auto builtin = reinterpret_cast<TfLitePoolParams*>(data);
+ add_scalar_int32(builtin->padding);
+ add_scalar_int32(builtin->stride_width);
+ add_scalar_int32(builtin->stride_height);
+ add_scalar_int32(builtin->filter_width);
+ add_scalar_int32(builtin->filter_height);
+ add_scalar_int32(builtin->activation);
+ };
+
+ auto add_convolution_params = [&add_scalar_int32](void* data) {
+ auto builtin = reinterpret_cast<TfLiteConvParams*>(data);
+ add_scalar_int32(builtin->padding);
+ add_scalar_int32(builtin->stride_width);
+ add_scalar_int32(builtin->stride_height);
+ add_scalar_int32(builtin->activation);
+ };
+
+ auto add_depthwise_conv_params = [&add_scalar_int32](void* data) {
+ auto builtin = reinterpret_cast<TfLiteDepthwiseConvParams*>(data);
+ add_scalar_int32(builtin->padding);
+ add_scalar_int32(builtin->stride_width);
+ add_scalar_int32(builtin->stride_height);
+ add_scalar_int32(builtin->depth_multiplier);
+ add_scalar_int32(builtin->activation);
+ };
+
+ auto add_fully_connected_params = [&add_scalar_int32](void* data) {
+ auto builtin = reinterpret_cast<TfLiteFullyConnectedParams*>(data);
+ add_scalar_int32(builtin->activation);
+ };
+
+ auto add_concatenation_params = [&add_scalar_int32](void* data) {
+ auto builtin = reinterpret_cast<TfLiteConcatenationParams*>(data);
+ add_scalar_int32(builtin->axis);
+ if (builtin->activation != kTfLiteActNone) {
+ FATAL("Concatenation does not support fused activation in NNAPI");
+ }
+ };
+
+ auto add_softmax_params = [&add_scalar_float32](void* data) {
+ auto builtin = reinterpret_cast<TfLiteSoftmaxParams*>(data);
+ add_scalar_float32(builtin->beta);
+ };
+
+#if 0
+ auto add_reshape_params = [&](void* data) {
+ auto builtin = reinterpret_cast<TfLiteReshapeParams*>(data);
+ uint32_t tensor_size_shape = builtin->num_dimensions;
+ ANeuralNetworksOperandType operand_type{
+ ANEURALNETWORKS_TENSOR_INT32,
+ {static_cast<uint32_t>(1),
+ reinterpret_cast<uint32_t*>(&tensor_size_shape)},
+ 0,
+ 0};
+ CHECK_NN(ANeuralNetworksModel_addOperand(nn_model, &operand_type))
+ CHECK_NN(ANeuralNetworksModel_setOperandValue(
+ nn_model, next_id, builtin->shape,
+ sizeof(int) * builtin->num_dimensions));
+ augmented_inputs.push_back(next_id++);
+ };
+#endif
+
+ ANeuralNetworksOperationType nn_op_type;
+ switch (builtin) {
+ case tflite::BuiltinOperator_ADD:
+ nn_op_type = ANEURALNETWORKS_ADD;
+ add_add_params();
+ break;
+ case tflite::BuiltinOperator_AVERAGE_POOL_2D:
+ add_pooling_params(node.builtin_data);
+ nn_op_type = ANEURALNETWORKS_AVERAGE_POOL_2D;
+ break;
+ case tflite::BuiltinOperator_MAX_POOL_2D:
+ add_pooling_params(node.builtin_data);
+ nn_op_type = ANEURALNETWORKS_MAX_POOL_2D;
+ break;
+ case tflite::BuiltinOperator_L2_POOL_2D:
+ add_pooling_params(node.builtin_data);
+ nn_op_type = ANEURALNETWORKS_L2_POOL_2D;
+ break;
+ case tflite::BuiltinOperator_CONV_2D:
+ add_convolution_params(node.builtin_data);
+ nn_op_type = ANEURALNETWORKS_CONV_2D;
+ break;
+ case tflite::BuiltinOperator_RELU:
+ nn_op_type = ANEURALNETWORKS_RELU;
+ break;
+ case tflite::BuiltinOperator_RELU6:
+ nn_op_type = ANEURALNETWORKS_RELU6;
+ break;
+ case tflite::BuiltinOperator_TANH:
+ nn_op_type = ANEURALNETWORKS_TANH;
+ break;
+ case tflite::BuiltinOperator_LOGISTIC:
+ nn_op_type = ANEURALNETWORKS_LOGISTIC;
+ break;
+ case tflite::BuiltinOperator_DEPTHWISE_CONV_2D:
+ add_depthwise_conv_params(node.builtin_data);
+ nn_op_type = ANEURALNETWORKS_DEPTHWISE_CONV_2D;
+ break;
+ case tflite::BuiltinOperator_CONCATENATION:
+ add_concatenation_params(node.builtin_data);
+ nn_op_type = ANEURALNETWORKS_CONCATENATION;
+ break;
+ case tflite::BuiltinOperator_SOFTMAX:
+ add_softmax_params(node.builtin_data);
+ nn_op_type = ANEURALNETWORKS_SOFTMAX;
+ break;
+ case tflite::BuiltinOperator_FULLY_CONNECTED:
+ add_fully_connected_params(node.builtin_data);
+ nn_op_type = ANEURALNETWORKS_FULLY_CONNECTED;
+ break;
+ case tflite::BuiltinOperator_RESHAPE:
+ nn_op_type = ANEURALNETWORKS_RESHAPE;
+ // add_reshape_params(node.builtin_data);
+ break;
+ case tflite::BuiltinOperator_CONCAT_EMBEDDINGS:
+ case tflite::BuiltinOperator_LSH_PROJECTION:
+ case tflite::BuiltinOperator_SVDF:
+ case tflite::BuiltinOperator_HASHTABLE_LOOKUP:
+ case tflite::BuiltinOperator_RNN:
+ case tflite::BuiltinOperator_EMBEDDING_LOOKUP:
+ case tflite::BuiltinOperator_EMBEDDING_LOOKUP_SPARSE:
+ case tflite::BuiltinOperator_LSTM:
+ case tflite::BuiltinOperator_L2_NORMALIZATION:
+ case tflite::BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION:
+ case tflite::BuiltinOperator_MUL:
+ case tflite::BuiltinOperator_RESIZE_BILINEAR:
+ case tflite::BuiltinOperator_CALL:
+ case tflite::BuiltinOperator_SKIP_GRAM:
+ case tflite::BuiltinOperator_RELU1:
+ case tflite::BuiltinOperator_SPACE_TO_DEPTH:
+ FATAL("Op code %d is currently not delegated to NNAPI", builtin);
+ nn_op_type = -1; // set to invalid
+ break;
+ case tflite::BuiltinOperator_CUSTOM:
+ FATAL("Custom operations are not supported when using NNAPI.");
+ nn_op_type = -1; // set to invalid
+ break;
+ }
+
+ // Add the operation.
+ CHECK_NN(ANeuralNetworksModel_addOperation(
+ nn_model, nn_op_type, static_cast<uint32_t>(augmented_inputs.size()),
+ augmented_inputs.data(), static_cast<uint32_t>(node.outputs->size),
+ reinterpret_cast<uint32_t*>(node.outputs->data)));
+ }
+}
+
+TfLiteStatus NNAPIDelegate::BuildGraph(Interpreter* interpreter) {
+ // TODO(aselle): This is not correct; we need to handle resize invalidation.
+ if (nn_model_ && nn_compiled_model_) return kTfLiteOk;
+
+ if (!nn_model_) {
+ CHECK_NN(ANeuralNetworksModel_create(&nn_model_));
+
+ uint32_t next_id = addTensorOperands(interpreter, nn_model_);
+ AddOpsAndParams(interpreter, nn_model_, next_id);
+ CHECK_NN(ANeuralNetworksModel_identifyInputsAndOutputs(
+ nn_model_, static_cast<uint32_t>(interpreter->inputs().size()),
+ reinterpret_cast<const uint32_t*>(interpreter->inputs().data()),
+ static_cast<uint32_t>(interpreter->outputs().size()),
+ reinterpret_cast<const uint32_t*>(interpreter->outputs().data())));
+ CHECK_NN(ANeuralNetworksModel_finish(nn_model_));
+ }
+ if (!nn_compiled_model_) {
+ CHECK_NN(ANeuralNetworksCompilation_create(nn_model_, &nn_compiled_model_));
+ CHECK_NN(ANeuralNetworksCompilation_finish(nn_compiled_model_));
+ }
+ return kTfLiteOk;
+}
+
+TfLiteStatus NNAPIDelegate::Invoke(Interpreter* interpreter) {
+ if (!nn_model_) {
+ TF_LITE_ENSURE_STATUS(BuildGraph(interpreter));
+ }
+
+ ANeuralNetworksExecution* execution = nullptr;
+ CHECK_NN(ANeuralNetworksExecution_create(nn_compiled_model_, &execution));
+
+ // Currently perform deep copy of input buffer
+ for (size_t i = 0; i < interpreter->inputs().size(); i++) {
+ int input = interpreter->inputs()[i];
+ // TODO(aselle): Is this what we want or do we want input instead?
+ // TODO(aselle): This should maybe be called setInputValue to be consistent.
+ TfLiteTensor* tensor = interpreter->tensor(input);
+ CHECK_NN(ANeuralNetworksExecution_setInput(
+ execution, i, nullptr, tensor->data.raw, tensor->bytes));
+ }
+ // Tell nn api where to place final data.
+ for (size_t i = 0; i < interpreter->outputs().size(); i++) {
+ int output = interpreter->outputs()[i];
+ TfLiteTensor* tensor = interpreter->tensor(output);
+ CHECK_NN(ANeuralNetworksExecution_setOutput(
+ execution, i, nullptr, tensor->data.raw, tensor->bytes));
+ }
+ // Currently use blocking compute.
+ ANeuralNetworksEvent* event = nullptr;
+ CHECK_NN(ANeuralNetworksExecution_startCompute(execution, &event));
+ CHECK_NN(ANeuralNetworksEvent_wait(event));
+ ANeuralNetworksEvent_free(event);
+ ANeuralNetworksExecution_free(execution);
+
+#if 0
+ printf("From the NN API:\n");
+ TfLiteTensor* tensor = interpreter->tensor(interpreter->outputs()[0]);
+ if (float* data =
+ interpreter->typed_tensor<float>(interpreter->outputs()[0])) {
+ size_t num = tensor->bytes / sizeof(float);
+ for (float* p = data; p < data + num; p++) {
+ printf(" %f", *p);
+ }
+ printf("\n");
+ }
+#endif
+
+ return kTfLiteOk;
+}
+
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/nnapi_delegate.h b/tensorflow/contrib/lite/nnapi_delegate.h
new file mode 100644
index 0000000000..f29aa9e18e
--- /dev/null
+++ b/tensorflow/contrib/lite/nnapi_delegate.h
@@ -0,0 +1,66 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_NNAPI_DELEGATE_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_NNAPI_DELEGATE_H_
+
+#include "tensorflow/contrib/lite/allocation.h"
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/error_reporter.h"
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h"
+
+struct ANeuralNetworksModel;
+
+namespace tflite {
+
+class NNAPIAllocation : public MMAPAllocation {
+ public:
+ NNAPIAllocation(const char* filename, ErrorReporter* error_reporter);
+ ~NNAPIAllocation();
+
+ size_t offset(const void* ptr) const {
+ auto signed_offset = reinterpret_cast<const uint8_t*>(ptr) -
+ reinterpret_cast<const uint8_t*>(mmapped_buffer_);
+
+ return static_cast<size_t>(signed_offset);
+ }
+
+ ANeuralNetworksMemory* memory() const { return handle_; }
+ bool valid() const override { return handle_ != nullptr; }
+
+ private:
+ mutable ANeuralNetworksMemory* handle_ = nullptr;
+};
+
+class NNAPIDelegate {
+ public:
+ ~NNAPIDelegate();
+
+ // Convert a tflite graph to NNAPI
+ TfLiteStatus BuildGraph(Interpreter* interpreter);
+
+ // Run inference on the graph using NNAPI.
+ TfLiteStatus Invoke(Interpreter* interpreter);
+
+ private:
+ // The NN API model handle
+ ANeuralNetworksModel* nn_model_ = nullptr;
+ // The NN API compilation handle
+ ANeuralNetworksCompilation* nn_compiled_model_ = nullptr;
+};
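+
+// A minimal usage sketch (illustrative only; assumes `interpreter` is a fully
+// constructed tflite::Interpreter whose graph only contains ops supported by
+// the delegate and whose input tensors have already been filled):
+//
+//   tflite::NNAPIDelegate delegate;
+//   if (delegate.BuildGraph(&interpreter) == kTfLiteOk) {
+//     delegate.Invoke(&interpreter);  // runs the model through NNAPI
+//   }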
+
+} // namespace tflite
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_NNAPI_DELEGATE_H_
diff --git a/tensorflow/contrib/lite/optional_debug_tools.cc b/tensorflow/contrib/lite/optional_debug_tools.cc
new file mode 100644
index 0000000000..1f762e6688
--- /dev/null
+++ b/tensorflow/contrib/lite/optional_debug_tools.cc
@@ -0,0 +1,108 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/optional_debug_tools.h"
+
+namespace tflite {
+
+void PrintIntVector(const std::vector<int>& v) {
+ for (const auto& it : v) {
+ printf(" %d", it);
+ }
+ printf("\n");
+}
+
+void PrintTfLiteIntVector(const TfLiteIntArray* v) {
+ if (!v) {
+ printf(" (null)");
+ return;
+ }
+ for (int k = 0; k < v->size; k++) {
+ printf(" %d", v->data[k]);
+ }
+ printf("\n");
+}
+
+const char* TensorTypeName(TfLiteType type) {
+ switch (type) {
+ case kTfLiteNoType:
+ return "kTfLiteNoType";
+ case kTfLiteFloat32:
+ return "kTfLiteFloat32";
+ case kTfLiteInt32:
+ return "kTfLiteInt32";
+ case kTfLiteUInt8:
+ return "kTfLiteUInt8";
+ case kTfLiteInt64:
+ return "kTfLiteInt64";
+ case kTfLiteString:
+ return "kTfLiteString";
+ }
+ return "(invalid)";
+}
+
+const char* AllocTypeName(TfLiteAllocationType type) {
+ switch (type) {
+ case kTfLiteMemNone:
+ return "kTfLiteMemNone";
+ case kTfLiteMmapRo:
+ return "kTfLiteMmapRo";
+ case kTfLiteDynamic:
+ return "kTfLiteDynamic";
+ case kTfLiteArenaRw:
+ return "kTfLiteArenaRw";
+ case kTfLiteArenaRwPersistent:
+ return "kTfLiteArenaRwPersistent";
+ }
+ return "(invalid)";
+}
+
+// Prints a dump of what tensors and what nodes are in the interpreter.
+void PrintInterpreterState(Interpreter* interpreter) {
+ printf("Interpreter has %d tensors and %d nodes\n",
+ static_cast<int>(interpreter->tensors_size()),
+ static_cast<int>(interpreter->nodes_size()));
+ printf("Inputs:");
+ PrintIntVector(interpreter->inputs());
+ printf("Outputs:");
+ PrintIntVector(interpreter->outputs());
+ printf("\n");
+ for (int tensor_index = 0; tensor_index < interpreter->tensors_size();
+ tensor_index++) {
+ TfLiteTensor* tensor = interpreter->tensor(tensor_index);
+ printf("Tensor %3d %10s %15s %10zu bytes (%4.1f MB) ", tensor_index,
+ TensorTypeName(tensor->type), AllocTypeName(tensor->allocation_type),
+ tensor->bytes, float(tensor->bytes) / float(1 << 20));
+ PrintTfLiteIntVector(tensor->dims);
+ printf("\n");
+ }
+
+ for (int node_index = 0; node_index < interpreter->nodes_size();
+ node_index++) {
+ const std::pair<TfLiteNode, TfLiteRegistration>* node_and_reg =
+ interpreter->node_and_registration(node_index);
+ const TfLiteNode& node = node_and_reg->first;
+ const TfLiteRegistration& reg = node_and_reg->second;
+ printf("Node %3d Operator Builtin Code %3d\n", node_index,
+ reg.builtin_code);
+ printf(" Inputs:");
+ PrintTfLiteIntVector(node.inputs);
+ printf(" Outputs:");
+ PrintTfLiteIntVector(node.outputs);
+ }
+}
+
+// Validates the state of the interpreter's tensors and nodes.
+TfLiteStatus ValidateInterpreterState(const Interpreter* interpreter);
+
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/optional_debug_tools.h b/tensorflow/contrib/lite/optional_debug_tools.h
new file mode 100644
index 0000000000..54d4876095
--- /dev/null
+++ b/tensorflow/contrib/lite/optional_debug_tools.h
@@ -0,0 +1,32 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// Optional debugging functionality. For small sized binaries, these are not
+// needed.
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_DEBUG_TOOLS_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_DEBUG_TOOLS_H_
+
+#include "tensorflow/contrib/lite/interpreter.h"
+
+namespace tflite {
+
+// Prints a dump of what tensors and what nodes are in the interpreter.
+void PrintInterpreterState(Interpreter* interpreter);
+
+// Validates the state of the interpreter's tensors and nodes.
+TfLiteStatus ValidateInterpreterState(const Interpreter* interpreter);
+
+} // namespace tflite
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_DEBUG_TOOLS_H_
diff --git a/tensorflow/contrib/lite/python/BUILD b/tensorflow/contrib/lite/python/BUILD
new file mode 100644
index 0000000000..b4aa032ff8
--- /dev/null
+++ b/tensorflow/contrib/lite/python/BUILD
@@ -0,0 +1,46 @@
+licenses(["notice"]) # Apache 2.0
+
+package(default_visibility = ["//tensorflow:internal"])
+
+load("//tensorflow:tensorflow.bzl", "py_test")
+
+py_library(
+ name = "lite",
+ srcs = ["lite.py"],
+ # data = [
+ # "//tensorflow/contrib/lite/toco/python:toco_from_protos",
+ # ],
+ srcs_version = "PY2AND3",
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/contrib/lite/toco:model_flags_proto_py",
+ "//tensorflow/contrib/lite/toco:toco_flags_proto_py",
+ "//tensorflow/contrib/lite/toco/python:tensorflow_wrap_toco",
+ "//tensorflow/python:platform",
+ ],
+)
+
+py_test(
+ name = "lite_test",
+ srcs = ["lite_test.py"],
+ deps = [
+ ":lite",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:platform_test",
+ "//tensorflow/python:session",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/contrib/lite/python/lite.py b/tensorflow/contrib/lite/python/lite.py
new file mode 100644
index 0000000000..5e8edbb937
--- /dev/null
+++ b/tensorflow/contrib/lite/python/lite.py
@@ -0,0 +1,199 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""TensorFlow Lite tooling helper functionality.
+
+EXPERIMENTAL: APIs here are unstable and likely to change without notice.
+
+@@toco_convert
+@@toco_convert_protos
+
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import subprocess
+import tempfile
+
+from tensorflow.contrib.lite.toco import model_flags_pb2 as _model_flags_pb2
+from tensorflow.contrib.lite.toco import toco_flags_pb2 as _toco_flags_pb2
+from tensorflow.contrib.lite.toco.python.tensorflow_wrap_toco import TocoConvert as _toco_convert_protos
+from tensorflow.python.framework import dtypes as _dtypes
+# from tensorflow.python.platform import
+# resource_loader as _resource_loader
+
+# Enum types from the protobuf promoted to the API
+FLOAT = _toco_flags_pb2.FLOAT
+INT32 = _toco_flags_pb2.INT32
+INT64 = _toco_flags_pb2.INT64
+STRING = _toco_flags_pb2.STRING
+QUANTIZED_UINT8 = _toco_flags_pb2.QUANTIZED_UINT8
+TENSORFLOW_GRAPHDEF = _toco_flags_pb2.TENSORFLOW_GRAPHDEF
+TFLITE = _toco_flags_pb2.TFLITE
+GRAPHVIZ_DOT = _toco_flags_pb2.GRAPHVIZ_DOT
+
+# When True (the current default), TOCO is called directly through the wrapped
+# C++ API. The commented-out code below instead shells out to another python
+# process, which protects against crashes in TOCO.
+EXPERIMENTAL_USE_TOCO_API_DIRECTLY = True
+
+# Find the toco_from_protos binary using the resource loader if using from
+# bazel, otherwise we are in a pip where console_scripts already has
+# the toco_from_protos tool.
+# toco_from_proto_bin = _resource_loader.get_path_to_datafile(
+# "../toco/python/toco_from_protos")
+# if not os.path.exists(toco_from_proto_bin):
+# toco_from_proto_bin = "toco_from_protos"
+
+
+def toco_convert_protos(model_flags_str, toco_flags_str, input_data_str):
+ """Convert `input_data_str` according to model and toco parameters.
+
+ Unless you know what you are doing, consider using
+ the more friendly @{tf.contrib.lite.toco_convert}.
+
+ Args:
+ model_flags_str: Serialized proto describing model properties, see
+ `toco/model_flags.proto`.
+ toco_flags_str: Serialized proto describing conversion properties, see
+ `toco/toco_flags.proto`.
+ input_data_str: Input data in serialized form (e.g. a graphdef is common)
+ Returns:
+ Converted model in serialized form (e.g. a TFLITE model is common).
+ Raises:
+ RuntimeError: When conversion fails, an exception is raised with the error
+ message embedded.
+ """
+ # TODO(aselle): When toco does not use fatal errors for failure, we can
+ # switch this on.
+ if EXPERIMENTAL_USE_TOCO_API_DIRECTLY:
+ return _toco_convert_protos(model_flags_str, toco_flags_str, input_data_str)
+
+ # with tempfile.NamedTemporaryFile() as fp_toco, \
+ # tempfile.NamedTemporaryFile() as fp_model, \
+ # tempfile.NamedTemporaryFile() as fp_input, \
+ # tempfile.NamedTemporaryFile() as fp_output:
+ # fp_model.write(model_flags_str)
+ # fp_toco.write(toco_flags_str)
+ # fp_input.write(input_data_str)
+ # fp_model.flush()
+ # fp_toco.flush()
+ # fp_input.flush()
+
+ # cmd = [
+ # toco_from_proto_bin, fp_model.name, fp_toco.name, fp_input.name,
+ # fp_output.name
+ # ]
+ # cmdline = " ".join(cmd)
+ # proc = subprocess.Popen(
+ # cmdline,
+ # shell=True,
+ # stdout=subprocess.PIPE,
+ # stderr=subprocess.STDOUT,
+ # close_fds=True)
+ # stdout, stderr = proc.communicate()
+ # exitcode = proc.returncode
+ # if exitcode == 0:
+ # stuff = fp_output.read()
+ # return stuff
+ # else:
+ # raise RuntimeError("TOCO failed see console for info.\n%s\n%s\n" %
+ # (stdout, stderr))
+
+
+def _tensor_name(x):
+ return x.name.split(":")[0]
+
+
+def toco_convert(input_data,
+ input_tensors,
+ output_tensors,
+ inference_type=FLOAT,
+ input_format=TENSORFLOW_GRAPHDEF,
+ output_format=TFLITE,
+ quantized_input_stats=None,
+ drop_control_dependency=True):
+ """Convert a model using TOCO from `input_format` to `output_format`.
+
+ Typically this is to convert from TensorFlow GraphDef to TFLite, in which
+ case the default `input_format` and `output_format` are sufficient.
+
+ Args:
+ input_data: Input data (i.e. often `sess.graph_def`).
+ input_tensors: List of input tensors. Type and shape are computed using
+ `foo.get_shape()` and `foo.dtype`.
+ output_tensors: List of output tensors (only .name is used from this).
+ inference_type: Currently must be `{FLOAT, QUANTIZED_UINT8}`.
+ input_format: Type of data to read (currently must be TENSORFLOW_GRAPHDEF).
+ output_format: Type of data to write (currently must be TFLITE or
+ GRAPHVIZ_DOT)
+ quantized_input_stats: For each member of input_tensors the mean and
+ std deviation of training data. Only needed if `inference_type` is
+ `QUANTIZED_UINT8`.
+ drop_control_dependency: Drops control dependencies silently. This is done
+ because TensorFlow Lite does not support control dependencies.
+
+ Returns:
+ The converted data. For example if tflite was the destination, then
+ this will be a tflite flatbuffer in a bytes array.
+
+ Raises:
+ ValueError: If the input tensor type is unknown
+ RuntimeError: If TOCO fails to convert (in which case the runtime error's
+ error text will contain the TOCO error log)
+ """
+ toco = _toco_flags_pb2.TocoFlags()
+ toco.input_format = input_format
+ toco.output_format = output_format
+ model = _model_flags_pb2.ModelFlags()
+ model.drop_control_dependency = drop_control_dependency
+ toco.inference_type = inference_type
+ for idx, input_tensor in enumerate(input_tensors):
+ if input_tensor.dtype == _dtypes.float32:
+ tflite_input_type = FLOAT
+ elif input_tensor.dtype == _dtypes.int32:
+ tflite_input_type = INT32
+ elif input_tensor.dtype == _dtypes.int64:
+ tflite_input_type = INT64
+ # TODO(aselle): Insert strings when they are available
+ else:
+ raise ValueError("Tensors %s not known type %r" % (input_tensor.name,
+ input_tensor.dtype))
+
+ input_array = model.input_arrays.add()
+
+ if inference_type == QUANTIZED_UINT8:
+ if tflite_input_type == FLOAT:
+ tflite_input_type = QUANTIZED_UINT8
+ input_array.mean, input_array.std = quantized_input_stats[idx]
+
+ input_array.name = _tensor_name(input_tensor)
+ input_array.shape.extend(map(int, input_tensor.get_shape()))
+ toco.input_types.append(tflite_input_type)
+
+ for output_tensor in output_tensors:
+ model.output_arrays.append(_tensor_name(output_tensor))
+
+ data = toco_convert_protos(model.SerializeToString(),
+ toco.SerializeToString(),
+ input_data.SerializeToString())
+ return data
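+
+# A minimal usage sketch (illustrative only; the tensor names, shapes, and the
+# output path below are hypothetical):
+#   img = tf.placeholder(shape=[1, 16, 16, 3], dtype=tf.float32)
+#   out = img + img
+#   with tf.Session() as sess:
+#     tflite_model = toco_convert(sess.graph_def, [img], [out])
+#   open("/tmp/converted_model.tflite", "wb").write(tflite_model)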
+
+
+# remove_undocumented(__name__)
+
+del os
+del subprocess
+del tempfile
diff --git a/tensorflow/contrib/lite/python/lite_test.py b/tensorflow/contrib/lite/python/lite_test.py
new file mode 100644
index 0000000000..da360aeb34
--- /dev/null
+++ b/tensorflow/contrib/lite/python/lite_test.py
@@ -0,0 +1,45 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""TensorFlow Lite Python Interface: Sanity check."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.lite.python import lite
+from tensorflow.python.client import session
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import test
+
+
+class LiteTest(test_util.TensorFlowTestCase):
+
+ def testBasic(self):
+ in_tensor = array_ops.placeholder(shape=[1, 16, 16, 3],
+ dtype=dtypes.float32)
+ out_tensor = in_tensor + in_tensor
+ sess = session.Session()
+ # Try running on valid graph
+ result = lite.toco_convert(sess.graph_def, [in_tensor], [out_tensor])
+ self.assertTrue(result)
+ # TODO(aselle): remove tests that fail.
+ # Try running on identity graph (known fail)
+ # with self.assertRaisesRegexp(RuntimeError, "!model->operators.empty()"):
+ # result = lite.toco_convert(sess.graph_def, [in_tensor], [in_tensor])
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/lite/schema/BUILD b/tensorflow/contrib/lite/schema/BUILD
new file mode 100644
index 0000000000..3e04d6f34f
--- /dev/null
+++ b/tensorflow/contrib/lite/schema/BUILD
@@ -0,0 +1,82 @@
+package(default_visibility = [
+ "//visibility:public",
+])
+
+licenses(["notice"]) # Apache 2.0
+
+load("//tensorflow:tensorflow.bzl", "py_test")
+
+py_binary(
+ name = "upgrade_schema",
+ srcs = [
+ "upgrade_schema.py",
+ ],
+ data = [
+ "schema_v0.fbs",
+ "schema_v1.fbs",
+ "schema_v2.fbs",
+ "schema_v3.fbs",
+ ],
+ deps = [
+ "//tensorflow:tensorflow_py",
+ "//tensorflow/python:platform",
+ "@flatbuffers//:flatc",
+ ],
+)
+
+py_test(
+ name = "upgrade_schema_test",
+ size = "small",
+ srcs = ["upgrade_schema_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ ":upgrade_schema",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_test_lib",
+ ],
+)
+
+exports_files([
+ "schema_v0.fbs",
+ "schema_v1.fbs",
+ "schema_v2.fbs",
+ "schema_v3.fbs",
+])
+
+load("//third_party/flatbuffers:build_defs.bzl", "flatbuffer_cc_library")
+
+# Generic schema for inference on device.
+flatbuffer_cc_library(
+ name = "schema_fbs",
+ srcs = ["schema.fbs"],
+)
+
+# Schema test to make sure we don't introduce backward incompatible changes
+# to schemas.
+cc_test(
+ name = "flatbuffer_compatibility_test",
+ size = "small",
+ srcs = ["flatbuffer_compatibility_test.cc"],
+ data = [
+ "schema.fbs",
+ "schema_v3.fbs",
+ ],
+ deps = [
+ "//tensorflow/core:lib_platform",
+ "@com_google_googletest//:gtest",
+ "@flatbuffers//:flatc_library",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/contrib/lite/schema/flatbuffer_compatibility_test.cc b/tensorflow/contrib/lite/schema/flatbuffer_compatibility_test.cc
new file mode 100644
index 0000000000..17ee0af8dd
--- /dev/null
+++ b/tensorflow/contrib/lite/schema/flatbuffer_compatibility_test.cc
@@ -0,0 +1,91 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <fstream>
+#include <gtest/gtest.h>
+#include "third_party/flatbuffers/include/flatbuffers/flatc.h"
+#include "tensorflow/core/platform/platform.h"
+
+#ifdef PLATFORM_GOOGLE
+#define TFLITE_TF_PREFIX "third_party/tensorflow/"
+#else
+#define TFLITE_TF_PREFIX "tensorflow/"
+#endif
+/// Loads the contents of file `name` into `buf`.
+bool LoadFileRaw(const char *name, std::string *buf) {
+ std::ifstream fp(name, std::ios::binary);
+ if (!fp) {
+ fprintf(stderr, "Failed to read '%s'\n", name);
+ return false;
+ }
+ std::string s((std::istreambuf_iterator<char>(fp)),
+ std::istreambuf_iterator<char>());
+ if (s.empty()) {
+ fprintf(stderr, "Read '%s' resulted in empty\n", name);
+ return false;
+ }
+ *buf = s;
+ return true;
+}
+
+bool ParseFile(flatbuffers::Parser *parser, const std::string &filename,
+ const std::string &contents) {
+ std::vector<const char *> include_directories;
+ auto local_include_directory = flatbuffers::StripFileName(filename);
+ include_directories.push_back(local_include_directory.c_str());
+ include_directories.push_back(nullptr);
+ if (!parser->Parse(contents.c_str(), include_directories.data(),
+ filename.c_str())) {
+ fprintf(stderr, "Failed to parse flatbuffer schema '%s'\n",
+ contents.c_str());
+ return false;
+ }
+ return true;
+}
+
+// Checks to make sure current schema in current code does not cause an
+// incompatibility.
+TEST(SchemaTest, TestCompatibility) {
+ // Read file contents of schemas into strings
+ // TODO(aselle): Need a reliable way to load files.
+ std::string base_contents, current_contents;
+ const char *base_filename =
+ TFLITE_TF_PREFIX "contrib/lite/schema/schema_v3.fbs";
+ const char *current_filename =
+ TFLITE_TF_PREFIX "contrib/lite/schema/schema.fbs";
+
+ ASSERT_TRUE(LoadFileRaw(base_filename, &base_contents));
+ ASSERT_TRUE(LoadFileRaw(current_filename, &current_contents));
+ // Parse the schemas
+ flatbuffers::Parser base_parser, current_parser;
+ ASSERT_TRUE(ParseFile(&base_parser, base_filename, base_contents));
+ ASSERT_TRUE(ParseFile(&current_parser, current_filename, current_contents));
+ // Check that the schemas conform and fail if they don't
+ auto err = current_parser.ConformTo(base_parser);
+ if (!err.empty()) {
+ fprintf(stderr,
+ "Schemas don't conform:\n%s\n"
+ "In other words some change you made means that new parsers can't"
+ "parse old files.\n",
+ err.c_str());
+ FAIL();
+ }
+}
+
+int main(int argc, char **argv) {
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/schema/schema.fbs b/tensorflow/contrib/lite/schema/schema.fbs
new file mode 100644
index 0000000000..ddb2ab792c
--- /dev/null
+++ b/tensorflow/contrib/lite/schema/schema.fbs
@@ -0,0 +1,346 @@
+// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Revision History
+// Version 0: Initial version.
+// Version 1: Add subgraphs to schema.
+// Version 2: Rename operators to conform to NN API.
+// Version 3: Move buffer data from Model.Subgraph.Tensors to Model.Buffers.
+
+namespace tflite;
+
+// This corresponds to the version.
+file_identifier "TFL3";
+// File extension of any written files.
+file_extension "tflite";
+
+// The type of data stored in a tensor.
+enum TensorType : byte {
+ FLOAT32 = 0,
+ FLOAT16 = 1,
+ INT32 = 2,
+ UINT8 = 3,
+ INT64 = 4,
+ STRING = 5,
+}
+
+// Parameters for converting a quantized tensor back to float. Given a
+// quantized value q, the corresponding float value f should be:
+// f = scale * (q - zero_point)
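+// For example (illustrative values, not a requirement of the schema): with
+// scale = 0.5 and zero_point = 10, a quantized value q = 14 dequantizes to
+// f = 0.5 * (14 - 10) = 2.0.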
+table QuantizationParameters {
+ min:[float]; // For importing back into tensorflow.
+ max:[float]; // For importing back into tensorflow.
+ scale:[float];
+ zero_point:[long];
+}
+
+table Tensor {
+ // The tensor shape. The meaning of each entry is operator-specific but
+ // builtin ops use: [batch size, number of channels, height, width] (That's
+ // Tensorflow's NCHW).
+ shape:[int];
+ type:TensorType;
+ // An index that refers to the buffers table at the root of the model. Or,
+ // if there is no data buffer associated (i.e. intermediate results), then
+ // this is 0 (which refers to an always existent empty buffer).
+ //
+ // The data_buffer itself is an opaque container, with the assumption that the
+ // target device is little-endian. In addition, all builtin operators assume
+ // the memory is ordered such that if `shape` is [4, 3, 2], then index
+ // [i, j, k] maps to data_buffer[i*3*2 + j*2 + k].
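+ // For example (illustrative values): with `shape` = [4, 3, 2], element
+ // [i, j, k] = [2, 1, 1] lives at offset 2*(3*2) + 1*2 + 1 = 15.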
+ buffer:uint;
+ name:string; // For debugging and importing back into tensorflow.
+ quantization:QuantizationParameters; // Optional.
+}
+
+// A list of builtin operators. Builtin operators are slightly faster than custom
+// ones, but not by much. Moreover, while custom operators accept an opaque
+// object containing configuration parameters, builtins have a predetermined
+// set of acceptable options.
+enum BuiltinOperator : byte {
+ ADD = 0,
+ AVERAGE_POOL_2D = 1,
+ CONCATENATION = 2,
+ CONV_2D = 3,
+ DEPTHWISE_CONV_2D = 4,
+ // DEPTH_TO_SPACE = 5,
+ // DEQUANTIZE = 6,
+ EMBEDDING_LOOKUP = 7,
+ // FLOOR = 8,
+ FULLY_CONNECTED = 9,
+ HASHTABLE_LOOKUP = 10,
+ L2_NORMALIZATION = 11,
+ L2_POOL_2D = 12,
+ LOCAL_RESPONSE_NORMALIZATION = 13,
+ LOGISTIC = 14,
+ LSH_PROJECTION = 15,
+ LSTM = 16,
+ MAX_POOL_2D = 17,
+ MUL = 18,
+ RELU = 19,
+ RELU1 = 20,
+ RELU6 = 21,
+ RESHAPE = 22,
+ RESIZE_BILINEAR = 23,
+ RNN = 24,
+ SOFTMAX = 25,
+ SPACE_TO_DEPTH = 26,
+ SVDF = 27,
+ TANH = 28,
+ // TODO(aselle): Consider rename to CONCATENATE_EMBEDDINGS
+ CONCAT_EMBEDDINGS = 29,
+ SKIP_GRAM = 30,
+ CALL = 31,
+ CUSTOM = 32,
+ EMBEDDING_LOOKUP_SPARSE = 33,
+}
+
+// Options for the builtin operators.
+union BuiltinOptions {
+ Conv2DOptions,
+ DepthwiseConv2DOptions,
+ ConcatEmbeddingsOptions,
+ LSHProjectionOptions,
+ Pool2DOptions,
+ SVDFOptions,
+ RNNOptions,
+ FullyConnectedOptions,
+ SoftmaxOptions,
+ ConcatenationOptions,
+ AddOptions,
+ L2NormOptions,
+ LocalResponseNormalizationOptions,
+ LSTMOptions,
+ ResizeBilinearOptions,
+ CallOptions,
+ ReshapeOptions,
+ SkipGramOptions,
+ SpaceToDepthOptions,
+ EmbeddingLookupSparseOptions,
+ MulOptions,
+}
+
+enum Padding : byte { SAME, VALID }
+
+enum ActivationFunctionType : byte {
+ NONE = 0,
+ RELU = 1,
+ RELU1 = 2,
+ RELU6 = 3,
+ TANH = 4,
+ SIGN_BIT = 5,
+}
+
+table Conv2DOptions {
+ padding:Padding;
+ stride_w:int;
+ stride_h:int;
+ fused_activation_function:ActivationFunctionType;
+}
+
+table Pool2DOptions {
+ padding:Padding;
+ stride_w:int;
+ stride_h:int;
+ filter_width:int;
+ filter_height:int;
+ fused_activation_function:ActivationFunctionType;
+}
+
+table DepthwiseConv2DOptions {
+ padding:Padding;
+ stride_w:int;
+ stride_h:int;
+ depth_multiplier:int;
+ fused_activation_function:ActivationFunctionType;
+}
+
+table ConcatEmbeddingsOptions {
+ num_channels:int;
+ num_columns_per_channel:[int];
+ embedding_dim_per_channel:[int]; // This could be inferred from parameters.
+}
+
+enum LSHProjectionType: byte {
+ UNKNOWN = 0,
+ SPARSE = 1,
+ DENSE = 2,
+}
+
+table LSHProjectionOptions {
+ type: LSHProjectionType;
+}
+
+table SVDFOptions {
+ rank:int;
+ fused_activation_function:ActivationFunctionType;
+}
+
+// An implementation of TensorFlow RNNCell.
+table RNNOptions {
+ fused_activation_function:ActivationFunctionType;
+}
+
+// An implementation of TensorFlow fully_connected (a.k.a Dense) layer.
+table FullyConnectedOptions {
+ fused_activation_function:ActivationFunctionType;
+}
+
+table SoftmaxOptions {
+ beta: float;
+}
+
+// An implementation of TensorFlow concat.
+table ConcatenationOptions {
+ axis:int;
+ fused_activation_function:ActivationFunctionType;
+}
+
+table AddOptions {
+ fused_activation_function:ActivationFunctionType;
+}
+
+table MulOptions {
+ fused_activation_function:ActivationFunctionType;
+}
+
+table L2NormOptions {
+ fused_activation_function:ActivationFunctionType;
+}
+
+table LocalResponseNormalizationOptions {
+ radius:int;
+ bias:float;
+ alpha:float;
+ beta:float;
+}
+
+// An implementation of TensorFlow LSTMCell and CoupledInputForgetGateLSTMCell
+table LSTMOptions {
+ fused_activation_function:ActivationFunctionType;
+ cell_clip: float; // Optional, 0.0 means no clipping
+ proj_clip: float; // Optional, 0.0 means no clipping
+}
+
+table ResizeBilinearOptions {
+ new_height:int;
+ new_width:int;
+}
+
+// Options for a call operation.
+table CallOptions {
+ // The subgraph index that needs to be called.
+ subgraph:uint;
+}
+
+table ReshapeOptions {
+ new_shape:[int];
+}
+
+table SkipGramOptions {
+ ngram_size: int;
+ max_skip_size: int;
+ include_all_ngrams: bool;
+}
+
+table SpaceToDepthOptions {
+ block_size: int;
+}
+
+enum CombinerType : byte {
+ SUM = 0,
+ MEAN = 1,
+ SQRTN = 2,
+}
+
+table EmbeddingLookupSparseOptions {
+ combiner:CombinerType;
+}
+
+// An OperatorCode can be an enum value (BuiltinOperator) if the operator is a
+// builtin, or a string if the operator is custom.
+table OperatorCode {
+ builtin_code:BuiltinOperator;
+ custom_code:string;
+}
+
+enum CustomOptionsFormat : byte {
+ FLEXBUFFERS = 0,
+}
+
+// An operator takes tensors as inputs and outputs. The type of operation being
+// performed is determined by an index into the list of valid OperatorCodes,
+// while the specifics of each operation are configured using builtin_options
+// or custom_options.
+table Operator {
+ // Index into the operator_codes array. Using an integer here avoids
+ // complicated map lookups.
+ opcode_index:uint;
+
+ // Optional input and output tensors are indicated by -1.
+ inputs:[int];
+ outputs:[int];
+
+ builtin_options:BuiltinOptions;
+ custom_options:[ubyte];
+ custom_options_format:CustomOptionsFormat;
+}
+
+// A subgraph of the model. (The root type of the file is Model, defined below.)
+table SubGraph {
+ // A list of all tensors used in this model.
+ tensors:[Tensor];
+
+ // Indices of the input tensors.
+ inputs:[int];
+
+ // Indices of the output tensors.
+ outputs:[int];
+
+ // All operators, in execution order.
+ operators:[Operator];
+
+ // Name of subgraph (used for debugging).
+ name:string;
+}
+
+// Table of raw data buffers (used for constant tensors). Referenced by tensors
+// by index.
+table Buffer {
+ data:[ubyte];
+}
+
+table Model {
+ // Version of the schema.
+ version:uint;
+
+ // A list of all operator codes used in this model. This is
+ // kept in order because operators carry an index into this
+ // vector.
+ operator_codes:[OperatorCode];
+
+ // All the subgraphs of the model. The 0th is assumed to be the main
+ // model.
+ subgraphs:[SubGraph];
+
+ // A description of the model.
+ description:string;
+
+ // Buffers of the model
+ buffers:[Buffer];
+
+}
+
+root_type Model;
+
diff --git a/tensorflow/contrib/lite/schema/schema_v0.fbs b/tensorflow/contrib/lite/schema/schema_v0.fbs
new file mode 100644
index 0000000000..852ea988f3
--- /dev/null
+++ b/tensorflow/contrib/lite/schema/schema_v0.fbs
@@ -0,0 +1,247 @@
+// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+namespace tflite;
+
+// The type of data stored in a tensor.
+enum TensorType : byte {
+ FLOAT32 = 0,
+ FLOAT16 = 1,
+ INT32 = 2,
+ UINT8 = 3,
+ INT64 = 4,
+}
+
+// Parameters for converting a quantized tensor back to float. Given a
+// quantized value q, the corresponding float value f should be:
+// f = scale * (q - zero_point)
+table QuantizationParameters {
+ min:[float]; // For importing back into tensorflow.
+ max:[float]; // For importing back into tensorflow.
+ scale:[float];
+ zero_point:[long];
+}
+
+table Tensor {
+ // The tensor shape. The meaning of each entry is operator-specific but
+ // builtin ops use: [batch size, number of channels, height, width] (That's
+ // Tensorflow's NCHW).
+ shape:[int];
+ type:TensorType;
+ // The data_buffer is an opaque container, with the assumption that the
+ // target device is little-endian. In addition, all builtin operators assume
+ // the memory is ordered such that if `shape` is [4, 3, 2], then index
+ // [i, j, k] maps to data_buffer[i*3*2 + j*2 + k].
+ data_buffer:[ubyte];
+ name:string; // For debugging and importing back into tensorflow.
+ quantization:QuantizationParameters; // Optional.
+}
+
+// A list of builtin operators. Builtin operators are slightly faster than custom
+// ones, but not by much. Moreover, while custom operators accept an opaque
+// object containing configuration parameters, builtins have a predetermined
+// set of acceptable options.
+enum BuiltinOperator : byte {
+ CUSTOM = 0,
+ CONVOLUTION = 1,
+ DEPTHWISE_CONVOLUTION = 2,
+ CONCAT_EMBEDDINGS = 3,
+ LSH_PROJECTION = 4,
+ TANH = 5,
+ RELU = 6,
+ AVERAGE_POOL = 7,
+ MAX_POOL = 8,
+ L2_POOL = 9,
+ SIGMOID = 10,
+ SVDF = 11,
+ BasicRNN = 12,
+ RELU6 = 13,
+ EMBEDDING_LOOKUP = 14,
+ FULLY_CONNECTED = 15,
+ HASHTABLE_LOOKUP = 16,
+ SOFTMAX = 17,
+ CONCATENATION = 18,
+ LSTM = 19,
+ ADD = 20,
+ L2NORM = 21,
+ LOCAL_RESPONSE_NORM = 22,
+ RESIZE_BILINEAR = 23,
+}
+
+// Options for the builtin operators.
+union BuiltinOptions {
+ ConvolutionOptions,
+ DepthwiseConvolutionOptions,
+ ConcatEmbeddingsOptions,
+ LSHProjectionOptions,
+ PoolOptions,
+ SVDFOptions,
+ BasicRNNOptions,
+ FullyConnectedOptions,
+ SoftmaxOptions,
+ ConcatenationOptions,
+ AddOptions,
+ L2NormOptions,
+ LocalResponseNormOptions,
+ LSTMOptions,
+ ResizeBilinearOptions,
+}
+
+enum Padding : byte { SAME, VALID }
+
+enum ActivationFunctionType : byte {
+ NONE = 0,
+ RELU = 1,
+ RELU1 = 2,
+ RELU6 = 3,
+ TANH = 4,
+ SIGN_BIT = 5,
+}
+
+table ConvolutionOptions {
+ padding:Padding;
+ stride_w:int;
+ stride_h:int;
+ fused_activation_function:ActivationFunctionType;
+}
+
+table PoolOptions {
+ padding:Padding;
+ stride_w:int;
+ stride_h:int;
+ filter_width:int;
+ filter_height:int;
+ fused_activation_function:ActivationFunctionType;
+}
+
+table DepthwiseConvolutionOptions {
+ padding:Padding;
+ stride_w:int;
+ stride_h:int;
+ depth_multiplier:int;
+ fused_activation_function:ActivationFunctionType;
+}
+
+table ConcatEmbeddingsOptions {
+ num_channels:int;
+ num_columns_per_channel:[int];
+ embedding_dim_per_channel:[int]; // This could be inferred from parameters.
+}
+
+enum LSHProjectionType: byte {
+ UNKNOWN = 0,
+ SPARSE = 1,
+ DENSE = 2,
+}
+
+table LSHProjectionOptions {
+ type: LSHProjectionType;
+}
+
+table SVDFOptions {
+ rank:int;
+ fused_activation_function:ActivationFunctionType;
+}
+
+// An implementation of TensorFlow BasicRNNCell.
+table BasicRNNOptions {
+ fused_activation_function:ActivationFunctionType;
+}
+
+// An implementation of TensorFlow fully_connected (a.k.a Dense) layer.
+table FullyConnectedOptions {
+ fused_activation_function:ActivationFunctionType;
+}
+
+table SoftmaxOptions {
+ beta: float;
+}
+
+// An implementation of TensorFlow concat.
+table ConcatenationOptions {
+ axis:int;
+ fused_activation_function:ActivationFunctionType;
+}
+
+table AddOptions {
+ fused_activation_function:ActivationFunctionType;
+}
+
+table L2NormOptions {
+ fused_activation_function:ActivationFunctionType;
+}
+
+table LocalResponseNormOptions {
+ radius:int;
+ bias:float;
+ alpha:float;
+ beta:float;
+}
+
+// An implementation of TensorFlow LSTMCell and CoupledInputForgetGateLSTMCell
+table LSTMOptions {
+ fused_activation_function:ActivationFunctionType;
+ cell_clip: float; // Optional, 0.0 means no clipping
+ proj_clip: float; // Optional, 0.0 means no clipping
+}
+
+table ResizeBilinearOptions {
+ new_height:int;
+ new_width:int;
+}
+
+// An OperatorCode can be an enum value (BuiltinOperator) if the operator is a
+// builtin, or a string if the operator is custom.
+table OperatorCode {
+ builtin_code:BuiltinOperator;
+ custom_code:string;
+}
+
+// An operator takes tensors as inputs and outputs. The type of operation being
+// performed is determined by an index into the list of valid OperatorCodes,
+// while the specifics of each operation are configured using builtin_options
+// or custom_options.
+table Operator {
+ // Index into the operator_codes array. Using an integer here avoids
+ // complicated map lookups.
+ opcode_index:int;
+
+ inputs:[int];
+ outputs:[int];
+
+ builtin_options:BuiltinOptions;
+ custom_options:[ubyte];
+}
+
+// The root type, defining a model.
+table Model {
+ // A list of all tensors used in this model.
+ tensors:[Tensor];
+
+ // Indices of the input tensors.
+ inputs:[int];
+
+ // Indices of the output tensors.
+ outputs:[int];
+
+ // A list of all operator codes used in this model. This is
+ // kept in order because operators carry an index into this
+ // vector.
+ operator_codes:[OperatorCode];
+
+ // All operators, in execution order.
+ operators:[Operator];
+}
+
+root_type Model;
diff --git a/tensorflow/contrib/lite/schema/schema_v1.fbs b/tensorflow/contrib/lite/schema/schema_v1.fbs
new file mode 100644
index 0000000000..06cd9408ed
--- /dev/null
+++ b/tensorflow/contrib/lite/schema/schema_v1.fbs
@@ -0,0 +1,295 @@
+// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Revision History
+// Version 0: Initial version.
+// Version 1: Add subgraphs to schema.
+
+namespace tflite;
+
+// The type of data stored in a tensor.
+enum TensorType : byte {
+ FLOAT32 = 0,
+ FLOAT16 = 1,
+ INT32 = 2,
+ UINT8 = 3,
+ INT64 = 4,
+ STRING = 5,
+}
+
+// Parameters for converting a quantized tensor back to float. Given a
+// quantized value q, the corresponding float value f should be:
+// f = scale * (q - zero_point)
+table QuantizationParameters {
+ min:[float]; // For importing back into tensorflow.
+ max:[float]; // For importing back into tensorflow.
+ scale:[float];
+ zero_point:[long];
+}
+
+table Tensor {
+ // The tensor shape. The meaning of each entry is operator-specific but
+ // builtin ops use: [batch size, number of channels, height, width] (That's
+ // Tensorflow's NCHW).
+ shape:[int];
+ type:TensorType;
+ // The data_buffer is an opaque container, with the assumption that the
+ // target device is little-endian. In addition, all builtin operators assume
+ // the memory is ordered such that if `shape` is [4, 3, 2], then index
+ // [i, j, k] maps to data_buffer[i*3*2 + j*2 + k].
+ data_buffer:[ubyte];
+ name:string; // For debugging and importing back into tensorflow.
+ quantization:QuantizationParameters; // Optional.
+}
+
+// A list of builtin operators. Builtin operators are slightly faster than custom
+// ones, but not by much. Moreover, while custom operators accept an opaque
+// object containing configuration parameters, builtins have a predetermined
+// set of acceptable options.
+enum BuiltinOperator : byte {
+ CUSTOM = 0,
+ CONVOLUTION = 1,
+ DEPTHWISE_CONVOLUTION = 2,
+ CONCAT_EMBEDDINGS = 3,
+ LSH_PROJECTION = 4,
+ TANH = 5,
+ RELU = 6,
+ AVERAGE_POOL = 7,
+ MAX_POOL = 8,
+ L2_POOL = 9,
+ SIGMOID = 10,
+ SVDF = 11,
+ BasicRNN = 12,
+ RELU6 = 13,
+ EMBEDDING_LOOKUP = 14,
+ FULLY_CONNECTED = 15,
+ HASHTABLE_LOOKUP = 16,
+ SOFTMAX = 17,
+ CONCATENATION = 18,
+ LSTM = 19,
+ ADD = 20,
+ L2NORM = 21,
+ LOCAL_RESPONSE_NORM = 22,
+ RESIZE_BILINEAR = 23,
+ CALL = 24,
+ RESHAPE = 25,
+ SKIP_GRAM = 26,
+ SPACE_TO_DEPTH = 27,
+}
+
+// Options for the builtin operators.
+union BuiltinOptions {
+ ConvolutionOptions,
+ DepthwiseConvolutionOptions,
+ ConcatEmbeddingsOptions,
+ LSHProjectionOptions,
+ PoolOptions,
+ SVDFOptions,
+ BasicRNNOptions,
+ FullyConnectedOptions,
+ SoftmaxOptions,
+ ConcatenationOptions,
+ AddOptions,
+ L2NormOptions,
+ LocalResponseNormOptions,
+ LSTMOptions,
+ ResizeBilinearOptions,
+ CallOptions,
+ ReshapeOptions,
+ SkipGramOptions,
+ SpaceToDepthOptions,
+}
+
+enum Padding : byte { SAME, VALID }
+
+enum ActivationFunctionType : byte {
+ NONE = 0,
+ RELU = 1,
+ RELU1 = 2,
+ RELU6 = 3,
+ TANH = 4,
+ SIGN_BIT = 5,
+}
+
+table ConvolutionOptions {
+ padding:Padding;
+ stride_w:int;
+ stride_h:int;
+ fused_activation_function:ActivationFunctionType;
+}
+
+table PoolOptions {
+ padding:Padding;
+ stride_w:int;
+ stride_h:int;
+ filter_width:int;
+ filter_height:int;
+ fused_activation_function:ActivationFunctionType;
+}
+
+table DepthwiseConvolutionOptions {
+ padding:Padding;
+ stride_w:int;
+ stride_h:int;
+ depth_multiplier:int;
+ fused_activation_function:ActivationFunctionType;
+}
+
+table ConcatEmbeddingsOptions {
+ num_channels:int;
+ num_columns_per_channel:[int];
+ embedding_dim_per_channel:[int]; // This could be inferred from parameters.
+}
+
+enum LSHProjectionType: byte {
+ UNKNOWN = 0,
+ SPARSE = 1,
+ DENSE = 2,
+}
+
+table LSHProjectionOptions {
+ type: LSHProjectionType;
+}
+
+table SVDFOptions {
+ rank:int;
+ fused_activation_function:ActivationFunctionType;
+}
+
+// An implementation of TensorFlow BasicRNNCell.
+table BasicRNNOptions {
+ fused_activation_function:ActivationFunctionType;
+}
+
+// An implementation of TensorFlow fully_connected (a.k.a Dense) layer.
+table FullyConnectedOptions {
+ fused_activation_function:ActivationFunctionType;
+}
+
+table SoftmaxOptions {
+ beta: float;
+}
+
+// An implementation of TensorFlow concat.
+table ConcatenationOptions {
+ axis:int;
+ fused_activation_function:ActivationFunctionType;
+}
+
+table AddOptions {
+ fused_activation_function:ActivationFunctionType;
+}
+
+table L2NormOptions {
+ fused_activation_function:ActivationFunctionType;
+}
+
+table LocalResponseNormOptions {
+ radius:int;
+ bias:float;
+ alpha:float;
+ beta:float;
+}
+
+// An implementation of TensorFlow LSTMCell and CoupledInputForgetGateLSTMCell
+table LSTMOptions {
+ fused_activation_function:ActivationFunctionType;
+ cell_clip: float; // Optional, 0.0 means no clipping
+ proj_clip: float; // Optional, 0.0 means no clipping
+}
+
+table ResizeBilinearOptions {
+ new_height:int;
+ new_width:int;
+}
+
+// Options for a call operation.
+table CallOptions {
+ // The subgraph index that needs to be called.
+ subgraph:int;
+}
+
+table ReshapeOptions {
+ new_shape:[int];
+}
+
+table SkipGramOptions {
+ ngram_size: int;
+ max_skip_size: int;
+ include_all_ngrams: bool;
+}
+
+table SpaceToDepthOptions {
+ block_size: int;
+}
+
+// An OperatorCode can be an enum value (BuiltinOperator) if the operator is a
+// builtin, or a string if the operator is custom.
+table OperatorCode {
+ builtin_code:BuiltinOperator;
+ custom_code:string;
+}
+
+// An operator takes tensors as inputs and outputs. The type of operation being
+// performed is determined by an index into the list of valid OperatorCodes,
+// while the specifics of each operation are configured using builtin_options
+// or custom_options.
+table Operator {
+ // Index into the operator_codes array. Using an integer here avoids
+ // complicated map lookups.
+ opcode_index:int;
+
+ inputs:[int];
+ outputs:[int];
+
+ builtin_options:BuiltinOptions;
+ custom_options:[ubyte];
+}
+
+// A subgraph of the model. (The root type of the file is Model, defined below.)
+table SubGraph {
+ // A list of all tensors used in this model.
+ tensors:[Tensor];
+
+ // Indices of the input tensors.
+ inputs:[int];
+
+ // Indices of the output tensors.
+ outputs:[int];
+
+ // All operators, in execution order.
+ operators:[Operator];
+
+ // Name of subgraph (used for debugging).
+ name:string;
+}
+
+table Model {
+ // Version of the schema.
+ version:int;
+
+ // A list of all operator codes used in this model. This is
+ // kept in order because operators carry an index into this
+ // vector.
+ operator_codes:[OperatorCode];
+
+ // All the subgraphs of the model. The 0th is assumed to be the main
+ // model.
+ subgraphs:[SubGraph];
+
+ // A description of the model.
+ description:string;
+}
+
+root_type Model;
diff --git a/tensorflow/contrib/lite/schema/schema_v2.fbs b/tensorflow/contrib/lite/schema/schema_v2.fbs
new file mode 100644
index 0000000000..96731c8aae
--- /dev/null
+++ b/tensorflow/contrib/lite/schema/schema_v2.fbs
@@ -0,0 +1,303 @@
+// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Revision History
+// Version 0: Initial version.
+// Version 1: Add subgraphs to schema.
+// Version 2: Rename operators to conform to NN API.
+
+namespace tflite;
+
+// The type of data stored in a tensor.
+enum TensorType : byte {
+ FLOAT32 = 0,
+ FLOAT16 = 1,
+ INT32 = 2,
+ UINT8 = 3,
+ INT64 = 4,
+ STRING = 5,
+}
+
+// Parameters for converting a quantized tensor back to float. Given a
+// quantized value q, the corresponding float value f should be:
+// f = scale * (q - zero_point)
+table QuantizationParameters {
+ min:[float]; // For importing back into tensorflow.
+ max:[float]; // For importing back into tensorflow.
+ scale:[float];
+ zero_point:[long];
+}
+
+table Tensor {
+ // The tensor shape. The meaning of each entry is operator-specific but
+ // builtin ops use: [batch size, number of channels, height, width] (That's
+ // Tensorflow's NCHW).
+ shape:[int];
+ type:TensorType;
+ // The data_buffer is an opaque container, with the assumption that the
+ // target device is little-endian. In addition, all builtin operators assume
+ // the memory is ordered such that if `shape` is [4, 3, 2], then index
+ // [i, j, k] maps to data_buffer[i*3*2 + j*2 + k].
+ data_buffer:[ubyte];
+ name:string; // For debugging and importing back into tensorflow.
+ quantization:QuantizationParameters; // Optional.
+}
+
+// A list of builtin operators. Builtin operators are slightly faster than custom
+// ones, but not by much. Moreover, while custom operators accept an opaque
+// object containing configuration parameters, builtins have a predetermined
+// set of acceptable options.
+enum BuiltinOperator : byte {
+ ADD = 0,
+ AVERAGE_POOL_2D = 1,
+ CONCATENATION = 2,
+ CONV_2D = 3,
+ DEPTHWISE_CONV_2D = 4,
+ // DEPTH_TO_SPACE = 5,
+ // DEQUANTIZE = 6,
+ EMBEDDING_LOOKUP = 7,
+ // FLOOR = 8,
+ FULLY_CONNECTED = 9,
+ HASHTABLE_LOOKUP = 10,
+ L2_NORMALIZATION = 11,
+ L2_POOL_2D = 12,
+ LOCAL_RESPONSE_NORMALIZATION = 13,
+ LOGISTIC = 14,
+ LSH_PROJECTION = 15,
+ LSTM = 16,
+ MAX_POOL_2D = 17,
+ // MUL = 18,
+ RELU = 19,
+ // RELU1=20,
+ RELU6 = 21,
+ RESHAPE = 22,
+ RESIZE_BILINEAR = 23,
+ RNN = 24,
+ SOFTMAX = 25,
+ SPACE_TO_DEPTH = 26,
+ SVDF = 27,
+ TANH = 28,
+ // TODO(aselle): Consider rename to CONCATENATE_EMBEDDINGS
+ CONCAT_EMBEDDINGS = 29,
+ SKIP_GRAM = 30,
+ CALL = 31,
+ CUSTOM = 32,
+
+}
+
+// Options for the builtin operators.
+union BuiltinOptions {
+ Conv2DOptions,
+ DepthwiseConv2DOptions,
+ ConcatEmbeddingsOptions,
+ LSHProjectionOptions,
+ Pool2DOptions,
+ SVDFOptions,
+ RNNOptions,
+ FullyConnectedOptions,
+ SoftmaxOptions,
+ ConcatenationOptions,
+ AddOptions,
+ L2NormOptions,
+ LocalResponseNormalizationOptions,
+ LSTMOptions,
+ ResizeBilinearOptions,
+ CallOptions,
+ ReshapeOptions,
+ SkipGramOptions,
+ SpaceToDepthOptions,
+}
+
+enum Padding : byte { SAME, VALID }
+
+enum ActivationFunctionType : byte {
+ NONE = 0,
+ RELU = 1,
+ RELU1 = 2,
+ RELU6 = 3,
+ TANH = 4,
+ SIGN_BIT = 5,
+}
+
+table Conv2DOptions {
+ padding:Padding;
+ stride_w:int;
+ stride_h:int;
+ fused_activation_function:ActivationFunctionType;
+}
+
+table Pool2DOptions {
+ padding:Padding;
+ stride_w:int;
+ stride_h:int;
+ filter_width:int;
+ filter_height:int;
+ fused_activation_function:ActivationFunctionType;
+}
+
+table DepthwiseConv2DOptions {
+ padding:Padding;
+ stride_w:int;
+ stride_h:int;
+ depth_multiplier:int;
+ fused_activation_function:ActivationFunctionType;
+}
+
+table ConcatEmbeddingsOptions {
+ num_channels:int;
+ num_columns_per_channel:[int];
+ embedding_dim_per_channel:[int]; // This could be inferred from parameters.
+}
+
+enum LSHProjectionType: byte {
+ UNKNOWN = 0,
+ SPARSE = 1,
+ DENSE = 2,
+}
+
+table LSHProjectionOptions {
+ type: LSHProjectionType;
+}
+
+table SVDFOptions {
+ rank:int;
+ fused_activation_function:ActivationFunctionType;
+}
+
+// An implementation of TensorFlow RNNCell.
+table RNNOptions {
+ fused_activation_function:ActivationFunctionType;
+}
+
+// An implementation of TensorFlow fully_connected (a.k.a Dense) layer.
+table FullyConnectedOptions {
+ fused_activation_function:ActivationFunctionType;
+}
+
+table SoftmaxOptions {
+ beta: float;
+}
+
+// An implementation of TensorFlow concat.
+table ConcatenationOptions {
+ axis:int;
+ fused_activation_function:ActivationFunctionType;
+}
+
+table AddOptions {
+ fused_activation_function:ActivationFunctionType;
+}
+
+table L2NormOptions {
+ fused_activation_function:ActivationFunctionType;
+}
+
+table LocalResponseNormalizationOptions {
+ radius:int;
+ bias:float;
+ alpha:float;
+ beta:float;
+}
+
+// An implementation of TensorFlow LSTMCell and CoupledInputForgetGateLSTMCell
+table LSTMOptions {
+ fused_activation_function:ActivationFunctionType;
+ cell_clip: float; // Optional, 0.0 means no clipping
+ proj_clip: float; // Optional, 0.0 means no clipping
+}
+
+table ResizeBilinearOptions {
+ new_height:int;
+ new_width:int;
+}
+
+// Options for a call operation.
+table CallOptions {
+ // The subgraph index that needs to be called.
+ subgraph:int;
+}
+
+table ReshapeOptions {
+ new_shape:[int];
+}
+
+table SkipGramOptions {
+ ngram_size: int;
+ max_skip_size: int;
+ include_all_ngrams: bool;
+}
+
+table SpaceToDepthOptions {
+ block_size: int;
+}
+
+// An OperatorCode can be an enum value (BuiltinOperator) if the operator is a
+// builtin, or a string if the operator is custom.
+table OperatorCode {
+ builtin_code:BuiltinOperator;
+ custom_code:string;
+}
+
+// An operator takes tensors as inputs and outputs. The type of operation being
+// performed is determined by an index into the list of valid OperatorCodes,
+// while the specifics of each operation are configured using builtin_options
+// or custom_options.
+table Operator {
+ // Index into the operator_codes array. Using an integer here avoids
+ // complicated map lookups.
+ opcode_index:int;
+
+ inputs:[int];
+ outputs:[int];
+
+ builtin_options:BuiltinOptions;
+ custom_options:[ubyte];
+}
+
+// A subgraph of the model. (The root type of the file is Model, defined below.)
+table SubGraph {
+ // A list of all tensors used in this model.
+ tensors:[Tensor];
+
+ // Indices of the input tensors.
+ inputs:[int];
+
+ // Indices of the output tensors.
+ outputs:[int];
+
+ // All operators, in execution order.
+ operators:[Operator];
+
+ // Name of subgraph (used for debugging).
+ name:string;
+}
+
+table Model {
+ // Version of the schema.
+ version:int;
+
+ // A list of all operator codes used in this model. This is
+ // kept in order because operators carry an index into this
+ // vector.
+ operator_codes:[OperatorCode];
+
+ // All the subgraphs of the model. The 0th is assumed to be the main
+ // model.
+ subgraphs:[SubGraph];
+
+ // A description of the model.
+ description:string;
+}
+
+root_type Model;
diff --git a/tensorflow/contrib/lite/schema/schema_v3.fbs b/tensorflow/contrib/lite/schema/schema_v3.fbs
new file mode 100644
index 0000000000..cedefe08f3
--- /dev/null
+++ b/tensorflow/contrib/lite/schema/schema_v3.fbs
@@ -0,0 +1,326 @@
+// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Revision History
+// Version 0: Initial version.
+// Version 1: Add subgraphs to schema.
+// Version 2: Rename operators to conform to NN API.
+// Version 3: Move buffer data from Model.Subgraph.Tensors to Model.Buffers.
+
+namespace tflite;
+
+// This corresponds to the version (3).
+file_identifier "TFL3";
+// File extension of any written files.
+file_extension "tflite";
+
+// The type of data stored in a tensor.
+enum TensorType : byte {
+ FLOAT32 = 0,
+ FLOAT16 = 1,
+ INT32 = 2,
+ UINT8 = 3,
+ INT64 = 4,
+ STRING = 5,
+}
+
+// Parameters for converting a quantized tensor back to float. Given a
+// quantized value q, the corresponding float value f should be:
+// f = scale * (q - zero_point)
+table QuantizationParameters {
+ min:[float]; // For importing back into tensorflow.
+ max:[float]; // For importing back into tensorflow.
+ scale:[float];
+ zero_point:[long];
+}
+
+table Tensor {
+ // The tensor shape. The meaning of each entry is operator-specific but
+ // builtin ops use: [batch size, number of channels, height, width] (That's
+ // Tensorflow's NCHW).
+ shape:[int];
+ type:TensorType;
+ // An index that refers to the buffers table at the root of the model. Or,
+ // if there is no data buffer associated (i.e. intermediate results), then
+ // this is 0 (which refers to an always existent empty buffer).
+ //
+ // The data_buffer itself is an opaque container, with the assumption that the
+ // target device is little-endian. In addition, all builtin operators assume
+ // the memory is ordered such that if `shape` is [4, 3, 2], then index
+ // [i, j, k] maps to data_buffer[i*3*2 + j*2 + k].
+ buffer:uint;
+ name:string; // For debugging and importing back into tensorflow.
+ quantization:QuantizationParameters; // Optional.
+}
+
+// A list of builtin operators. Builtin operators are slightly faster than custom
+// ones, but not by much. Moreover, while custom operators accept an opaque
+// object containing configuration parameters, builtins have a predetermined
+// set of acceptable options.
+enum BuiltinOperator : byte {
+ ADD = 0,
+ AVERAGE_POOL_2D = 1,
+ CONCATENATION = 2,
+ CONV_2D = 3,
+ DEPTHWISE_CONV_2D = 4,
+ // DEPTH_TO_SPACE = 5,
+ // DEQUANTIZE = 6,
+ EMBEDDING_LOOKUP = 7,
+ // FLOOR = 8,
+ FULLY_CONNECTED = 9,
+ HASHTABLE_LOOKUP = 10,
+ L2_NORMALIZATION = 11,
+ L2_POOL_2D = 12,
+ LOCAL_RESPONSE_NORMALIZATION = 13,
+ LOGISTIC = 14,
+ LSH_PROJECTION = 15,
+ LSTM = 16,
+ MAX_POOL_2D = 17,
+ // MUL = 18,
+ RELU = 19,
+ // RELU1=20,
+ RELU6 = 21,
+ RESHAPE = 22,
+ RESIZE_BILINEAR = 23,
+ RNN = 24,
+ SOFTMAX = 25,
+ SPACE_TO_DEPTH = 26,
+ SVDF = 27,
+ TANH = 28,
+ // TODO(aselle): Consider rename to CONCATENATE_EMBEDDINGS
+ CONCAT_EMBEDDINGS = 29,
+ SKIP_GRAM = 30,
+ CALL = 31,
+ CUSTOM = 32,
+
+}
+
+// Options for the builtin operators.
+union BuiltinOptions {
+ Conv2DOptions,
+ DepthwiseConv2DOptions,
+ ConcatEmbeddingsOptions,
+ LSHProjectionOptions,
+ Pool2DOptions,
+ SVDFOptions,
+ RNNOptions,
+ FullyConnectedOptions,
+ SoftmaxOptions,
+ ConcatenationOptions,
+ AddOptions,
+ L2NormOptions,
+ LocalResponseNormalizationOptions,
+ LSTMOptions,
+ ResizeBilinearOptions,
+ CallOptions,
+ ReshapeOptions,
+ SkipGramOptions,
+ SpaceToDepthOptions,
+}
+
+enum Padding : byte { SAME, VALID }
+
+enum ActivationFunctionType : byte {
+ NONE = 0,
+ RELU = 1,
+ RELU1 = 2,
+ RELU6 = 3,
+ TANH = 4,
+ SIGN_BIT = 5,
+}
+
+table Conv2DOptions {
+ padding:Padding;
+ stride_w:int;
+ stride_h:int;
+ fused_activation_function:ActivationFunctionType;
+}
+
+table Pool2DOptions {
+ padding:Padding;
+ stride_w:int;
+ stride_h:int;
+ filter_width:int;
+ filter_height:int;
+ fused_activation_function:ActivationFunctionType;
+}
+
+table DepthwiseConv2DOptions {
+ padding:Padding;
+ stride_w:int;
+ stride_h:int;
+ depth_multiplier:int;
+ fused_activation_function:ActivationFunctionType;
+}
+
+table ConcatEmbeddingsOptions {
+ num_channels:int;
+ num_columns_per_channel:[int];
+ embedding_dim_per_channel:[int]; // This could be inferred from parameters.
+}
+
+enum LSHProjectionType: byte {
+ UNKNOWN = 0,
+ SPARSE = 1,
+ DENSE = 2,
+}
+
+table LSHProjectionOptions {
+ type: LSHProjectionType;
+}
+
+table SVDFOptions {
+ rank:int;
+ fused_activation_function:ActivationFunctionType;
+}
+
+// An implementation of TensorFlow RNNCell.
+table RNNOptions {
+ fused_activation_function:ActivationFunctionType;
+}
+
+// An implementation of TensorFlow fully_connected (a.k.a Dense) layer.
+table FullyConnectedOptions {
+ fused_activation_function:ActivationFunctionType;
+}
+
+table SoftmaxOptions {
+ beta: float;
+}
+
+// An implementation of TensorFlow concat.
+table ConcatenationOptions {
+ axis:int;
+ fused_activation_function:ActivationFunctionType;
+}
+
+table AddOptions {
+ fused_activation_function:ActivationFunctionType;
+}
+
+table L2NormOptions {
+ fused_activation_function:ActivationFunctionType;
+}
+
+table LocalResponseNormalizationOptions {
+ radius:int;
+ bias:float;
+ alpha:float;
+ beta:float;
+}
+
+// An implementation of TensorFlow LSTMCell and CoupledInputForgetGateLSTMCell
+table LSTMOptions {
+ fused_activation_function:ActivationFunctionType;
+ cell_clip: float; // Optional, 0.0 means no clipping
+ proj_clip: float; // Optional, 0.0 means no clipping
+}
+
+table ResizeBilinearOptions {
+ new_height:int;
+ new_width:int;
+}
+
+// Options for a call operation.
+table CallOptions {
+ // The subgraph index that needs to be called.
+ subgraph:uint;
+}
+
+table ReshapeOptions {
+ new_shape:[int];
+}
+
+table SkipGramOptions {
+ ngram_size: int;
+ max_skip_size: int;
+ include_all_ngrams: bool;
+}
+
+table SpaceToDepthOptions {
+ block_size: int;
+}
+
+// An OperatorCode can be an enum value (BuiltinOperator) if the operator is a
+// builtin, or a string if the operator is custom.
+table OperatorCode {
+ builtin_code:BuiltinOperator;
+ custom_code:string;
+}
+
+// An operator takes tensors as inputs and outputs. The type of operation being
+// performed is determined by an index into the list of valid OperatorCodes,
+// while the specifics of each operation are configured using builtin_options
+// or custom_options.
+table Operator {
+ // Index into the operator_codes array. Using an integer here avoids
+ // complicated map lookups.
+ opcode_index:uint;
+
+ inputs:[int];
+ outputs:[int];
+
+ builtin_options:BuiltinOptions;
+ custom_options:[ubyte];
+}
+
+// A subgraph of the model. (The root type of the file is Model, defined below.)
+table SubGraph {
+ // A list of all tensors used in this model.
+ tensors:[Tensor];
+
+ // Indices of the input tensors.
+ inputs:[int];
+
+ // Indices of the output tensors.
+ outputs:[int];
+
+ // All operators, in execution order.
+ operators:[Operator];
+
+ // Name of subgraph (used for debugging).
+ name:string;
+}
+
+// Table of raw data buffers (used for constant tensors). Referenced by tensors
+// by index.
+table Buffer {
+ data:[ubyte];
+}
+
+table Model {
+ // Version of the schema.
+ version:uint;
+
+ // A list of all operator codes used in this model. This is
+ // kept in order because operators carry an index into this
+ // vector.
+ operator_codes:[OperatorCode];
+
+ // All the subgraphs of the model. The 0th is assumed to be the main
+ // model.
+ subgraphs:[SubGraph];
+
+ // A description of the model.
+ description:string;
+
+ // Buffers of the model.
+ // NOTE: It is required that the first entry in here is always an empty
+ // buffer. This is so that the default buffer index of zero in Tensor
+ // will always refer to a valid empty buffer.
+ buffers:[Buffer];
+
+}
+
+root_type Model;
diff --git a/tensorflow/contrib/lite/schema/upgrade_schema.py b/tensorflow/contrib/lite/schema/upgrade_schema.py
new file mode 100644
index 0000000000..320c7138d2
--- /dev/null
+++ b/tensorflow/contrib/lite/schema/upgrade_schema.py
@@ -0,0 +1,341 @@
+# ==============================================================================
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Upgrade script to move from pre-release schema to new schema.
+
+Usage examples:
+
+bazel run tensorflow/contrib/lite/schema/upgrade_schema -- in.json out.json
+bazel run tensorflow/contrib/lite/schema/upgrade_schema -- in.bin out.bin
+bazel run tensorflow/contrib/lite/schema/upgrade_schema -- in.bin out.json
+bazel run tensorflow/contrib/lite/schema/upgrade_schema -- in.json out.bin
+bazel run tensorflow/contrib/lite/schema/upgrade_schema -- in.tflite out.tflite
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import contextlib
+import json
+import os
+import shutil
+import subprocess
+import sys
+import tempfile
+
+import tensorflow as tf
+from tensorflow.python.platform import resource_loader
+
+parser = argparse.ArgumentParser(
+ description="Script to move TFLite models from pre-release schema to"
+ " new schema.")
+parser.add_argument(
+ "input",
+ type=str,
+ help="Input TensorFlow lite file in `.json`, `.bin` or `.tflite` format.")
+parser.add_argument(
+ "output",
+ type=str,
+ help="Output json or bin TensorFlow lite model compliant with"
+ "the new schema. Extension must be `.json`, `.bin` or `.tflite`.")
+
+
+# RAII Temporary Directory, because flatc doesn't allow direct use of tempfiles.
+@contextlib.contextmanager
+def TemporaryDirectoryResource():
+ temporary = tempfile.mkdtemp()
+ try:
+ yield temporary
+ finally:
+ shutil.rmtree(temporary)
+
+
+class Converter(object):
+ """Converts TensorFlow flatbuffer models from old to new version of schema.
+
+ This can convert a model from any older schema version to the latest version.
+ It uses an incremental upgrade strategy to go from version to version.
+
+ Usage:
+ converter = Converter()
+ converter.Convert("a.tflite", "a.json")
+ converter.Convert("b.json", "b.tflite")
+ """
+
+ def __init__(self):
+ # TODO(aselle): make this work in the open source version with better
+ # path.
+ self._flatc_path = resource_loader.get_path_to_datafile(
+ "../../../../flatbuffers/flatc")
+
+ def FindSchema(base_name):
+ return resource_loader.get_path_to_datafile("%s" % base_name)
+
+ # Supported schemas for upgrade.
+ self._schemas = [
+ (0, FindSchema("schema_v0.fbs"), True, self._Upgrade0To1),
+ (1, FindSchema("schema_v1.fbs"), True, self._Upgrade1To2),
+ (2, FindSchema("schema_v2.fbs"), True, self._Upgrade2To3),
+ (3, FindSchema("schema_v3.fbs"), False, None) # Non-callable by design.
+ ]
+ # Ensure schemas are sorted, and extract latest version and upgrade
+ # dispatch function table.
+ self._schemas.sort()
+ self._new_version, self._new_schema = self._schemas[-1][:2]
+ self._upgrade_dispatch = dict(
+ (version, dispatch)
+ for version, unused1, unused2, dispatch in self._schemas)
+
+ def _Read(self, input_file, schema, raw_binary=False):
+ """Read a tflite model assuming the given flatbuffer schema.
+
+ If `input_file` is in binary (.bin/.tflite) form, flatc is used to convert
+ the model to json before reading it.
+
+ Args:
+ input_file: a binary (flatbuffer) or json file to read from. Extension
+ must be `.tflite`, `.bin`, or `.json` for FlatBuffer Binary or
+ FlatBuffer JSON.
+ schema: which schema to use for reading
+ raw_binary: whether to pass --raw-binary to flatc. Schema versions prior
+ to v3 lacked a file_identifier, so they require this.
+
+ Raises:
+ RuntimeError: When flatc cannot be invoked.
+ ValueError: When the extension is not `.json`, `.bin`, or `.tflite`.
+
+ Returns:
+ A dictionary representing the read tflite model.
+ """
+ raw_binary = ["--raw-binary"] if raw_binary else []
+ with TemporaryDirectoryResource() as tempdir:
+ basename = os.path.basename(input_file)
+ basename_no_extension, extension = os.path.splitext(basename)
+ if extension in [".bin", ".tflite"]:
+ # Convert to json using flatc
+ returncode = subprocess.call([
+ self._flatc_path,
+ "-t",
+ "--strict-json",
+ "--defaults-json",
+ ] + raw_binary + ["-o", tempdir, schema, "--", input_file])
+ if returncode != 0:
+ raise RuntimeError("flatc failed to convert from binary to json.")
+ json_file = os.path.join(tempdir, basename_no_extension + ".json")
+ if not os.path.exists(json_file):
+ raise RuntimeError("Could not find %r" % json_file)
+ elif extension == ".json":
+ json_file = input_file
+ else:
+ raise ValueError("Invalid extension on input file %r" % input_file)
+ return json.load(open(json_file))
+
+ def _Write(self, data, output_file):
+ """Output a json or bin version of the flatbuffer model.
+
+ Args:
+ data: Dict representing the TensorFlow Lite model to write.
+ output_file: filename to write the converted flatbuffer to. (json,
+ tflite, or bin extension is required).
+ Raises:
+ ValueError: When the extension is not json, bin, or tflite.
+ RuntimeError: When flatc fails to convert json data to binary.
+ """
+ _, extension = os.path.splitext(output_file)
+ with TemporaryDirectoryResource() as tempdir:
+ if extension == ".json":
+ json.dump(data, open(output_file, "w"), sort_keys=True, indent=2)
+ elif extension in [".tflite", ".bin"]:
+ input_json = os.path.join(tempdir, "temp.json")
+ with open(input_json, "w") as fp:
+ json.dump(data, fp, sort_keys=True, indent=2)
+ returncode = subprocess.call([
+ self._flatc_path, "-b", "--defaults-json", "--strict-json", "-o",
+ tempdir, self._new_schema, input_json
+ ])
+ if returncode != 0:
+ raise RuntimeError("flatc failed to convert upgraded json to binary.")
+
+ shutil.copy(os.path.join(tempdir, "temp.tflite"), output_file)
+ else:
+ raise ValueError("Invalid extension on output file %r" % output_file)
+
+ def _Upgrade0To1(self, data):
+ """Upgrade data from Version 0 to Version 1.
+
+ Changes: Added subgraphs (which contain a subset of the formerly global
+ entries).
+
+ Args:
+ data: Dictionary representing the TensorFlow lite data to be upgraded.
+ This will be modified in-place to be an upgraded version.
+ """
+ subgraph = {}
+ for key_to_promote in ["tensors", "operators", "inputs", "outputs"]:
+ subgraph[key_to_promote] = data[key_to_promote]
+ del data[key_to_promote]
+ data["subgraphs"] = [subgraph]
+
+ def _Upgrade1To2(self, data):
+ """Upgrade data from Version 1 to Version 2.
+
+ Changes: Rename operators to conform to NN API.
+
+ Args:
+ data: Dictionary representing the TensorFlow lite data to be upgraded.
+ This will be modified in-place to be an upgraded version.
+ Raises:
+ ValueError: Raised when model builtins are numeric rather than symbolic.
+ """
+
+ def RemapOperator(opcode_name):
+ """Go from old schema op name to new schema op name.
+
+ Args:
+ opcode_name: String representing the ops (see :schema.fbs).
+ Returns:
+ Converted opcode_name from V1 to V2.
+ """
+ old_name_to_new_name = {
+ "CONVOLUTION": "CONV_2D",
+ "DEPTHWISE_CONVOLUTION": "DEPTHWISE_CONV_2D",
+ "AVERAGE_POOL": "AVERAGE_POOL_2D",
+ "MAX_POOL": "MAX_POOL_2D",
+ "L2_POOL": "L2_POOL_2D",
+ "SIGMOID": "LOGISTIC",
+ "L2NORM": "L2_NORMALIZATION",
+ "LOCAL_RESPONSE_NORM": "LOCAL_RESPONSE_NORMALIZATION",
+ "Basic_RNN": "RNN",
+ }
+
+ return (old_name_to_new_name[opcode_name]
+ if opcode_name in old_name_to_new_name else opcode_name)
+
+ def RemapOperatorType(operator_type):
+ """Remap operator structs from old names to new names.
+
+ Args:
+ operator_type: String representing the builtin operator data type
+ string.
+ (see :schema.fbs).
+ Returns:
+ Upgraded builtin operator data type as a string.
+ """
+ old_to_new = {
+ "PoolOptions": "Pool2DOptions",
+ "DepthwiseConvolutionOptions": "DepthwiseConv2DOptions",
+ "ConvolutionOptions": "Conv2DOptions",
+ "LocalResponseNormOptions": "LocalResponseNormalizationOptions",
+ "BasicRNNOptions": "RNNOptions",
+ }
+ return (old_to_new[operator_type]
+ if operator_type in old_to_new else operator_type)
+
+ for subgraph in data["subgraphs"]:
+ for ops in subgraph["operators"]:
+ ops["builtin_options_type"] = RemapOperatorType(
+ ops["builtin_options_type"])
+
+ # Upgrade the operator codes
+ for operator_code in data["operator_codes"]:
+ if not isinstance(operator_code["builtin_code"], unicode):
+ raise ValueError("builtin_code %r is non-string. this usually means"
+ "your model has consistency problems." %
+ (operator_code["builtin_code"]))
+ operator_code["builtin_code"] = (RemapOperator(
+ operator_code["builtin_code"]))
+
+ def _Upgrade2To3(self, data):
+ """Upgrade data from Version 2 to Version 3.
+
+ Changes: Moved read-only tensor data into a buffers table instead of
+ keeping it inline with the tensor.
+
+ Args:
+ data: Dictionary representing the TensorFlow lite data to be upgraded.
+ This will be modified in-place to be an upgraded version.
+ """
+ buffers = [{"data": []}] # Start with 1 empty buffer
+ for subgraph in data["subgraphs"]:
+ if "tensors" not in subgraph:
+ continue
+ for tensor in subgraph["tensors"]:
+ if "data_buffer" not in tensor:
+ tensor["buffer"] = 0
+ else:
+ if tensor["data_buffer"]:
+ tensor[u"buffer"] = len(buffers)
+ buffers.append({"data": tensor["data_buffer"]})
+ else:
+ tensor["buffer"] = 0
+ del tensor["data_buffer"]
+ data["buffers"] = buffers
+
+ def _PerformUpgrade(self, data):
+ """Manipulate the `data` (parsed JSON) based on changes in format.
+
+ This will incrementally upgrade `data` from version to version.
+
+ Args:
+ data: Dictionary representing the TensorFlow Lite data. This will be upgraded
+ in place.
+ """
+ while data["version"] < self._new_version:
+ self._upgrade_dispatch[data["version"]](data)
+ data["version"] += 1
+
+ def Convert(self, input_file, output_file):
+ """Perform schema conversion from input_file to output_file.
+
+ Args:
+ input_file: Filename of TensorFlow Lite data to convert from. Must
+ be `.json` or `.bin` extension files for JSON or Binary forms of
+ the TensorFlow FlatBuffer schema.
+ output_file: Filename to write to. Extension also must be `.json`
+ or `.bin`.
+
+ Raises:
+ RuntimeError: Generated when none of the schemas supported by the
+ upgrader match the `input_file` data.
+ """
+ # Read data in each schema (since they are incompatible). Version is
+ # always present. Use the read data that matches the version of the
+ # schema.
+ for version, schema, raw_binary, _ in self._schemas:
+ try:
+ data_candidate = self._Read(input_file, schema, raw_binary)
+ except RuntimeError:
+ continue # Skip and hope another schema works
+ if "version" not in data_candidate: # Assume version 1 if not present.
+ data_candidate["version"] = 1
+ elif data_candidate["version"] == 0: # Version 0 doesn't exist in wild.
+ data_candidate["version"] = 1
+
+ if data_candidate["version"] == version:
+ self._PerformUpgrade(data_candidate)
+ self._Write(data_candidate, output_file)
+ return
+ raise RuntimeError("No schema that the converter understands worked with "
+ "the data file you provided.")
+
+
+def main(argv):
+ del argv
+ Converter().Convert(FLAGS.input, FLAGS.output)
+
+
+if __name__ == "__main__":
+ FLAGS, unparsed = parser.parse_known_args()
+ tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
diff --git a/tensorflow/contrib/lite/schema/upgrade_schema_test.py b/tensorflow/contrib/lite/schema/upgrade_schema_test.py
new file mode 100644
index 0000000000..475cdb9d8b
--- /dev/null
+++ b/tensorflow/contrib/lite/schema/upgrade_schema_test.py
@@ -0,0 +1,317 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Testing for updating TensorFlow lite schema."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import json
+import tempfile
+from tensorflow.contrib.lite.schema import upgrade_schema as upgrade_schema_lib
+from tensorflow.python.framework import test_util
+from tensorflow.python.platform import test as test_lib
+
+EMPTY_TEST_SCHEMA_V1 = {
+ "version": 1,
+ "operator_codes": [],
+ "subgraphs": [],
+}
+
+EMPTY_TEST_SCHEMA_V3 = {
+ "version": 3,
+ "operator_codes": [],
+ "subgraphs": [],
+ "buffers": [{
+ "data": []
+ }]
+}
+
+TEST_SCHEMA_V0 = {
+ "operator_codes": [],
+ "tensors": [],
+ "inputs": [],
+ "outputs": [],
+ "operators": [],
+ "version": 0
+}
+
+TEST_SCHEMA_V3 = {
+ "operator_codes": [],
+ "buffers": [{
+ "data": []
+ }],
+ "subgraphs": [{
+ "tensors": [],
+ "inputs": [],
+ "outputs": [],
+ "operators": [],
+ }],
+ "version":
+ 3
+}
+
+FULL_TEST_SCHEMA_V1 = {
+ "version":
+ 1,
+ "operator_codes": [
+ {
+ "builtin_code": "CONVOLUTION"
+ },
+ {
+ "builtin_code": "DEPTHWISE_CONVOLUTION"
+ },
+ {
+ "builtin_code": "AVERAGE_POOL"
+ },
+ {
+ "builtin_code": "MAX_POOL"
+ },
+ {
+ "builtin_code": "L2_POOL"
+ },
+ {
+ "builtin_code": "SIGMOID"
+ },
+ {
+ "builtin_code": "L2NORM"
+ },
+ {
+ "builtin_code": "LOCAL_RESPONSE_NORM"
+ },
+ {
+ "builtin_code": "ADD"
+ },
+ {
+ "builtin_code": "Basic_RNN"
+ },
+ ],
+ "subgraphs": [{
+ "operators": [
+ {
+ "builtin_options_type": "PoolOptions"
+ },
+ {
+ "builtin_options_type": "DepthwiseConvolutionOptions"
+ },
+ {
+ "builtin_options_type": "ConvolutionOptions"
+ },
+ {
+ "builtin_options_type": "LocalResponseNormOptions"
+ },
+ {
+ "builtin_options_type": "BasicRNNOptions"
+ },
+ ],
+ }],
+ "description":
+ "",
+}
+
+FULL_TEST_SCHEMA_V3 = {
+ "version":
+ 3,
+ "operator_codes": [
+ {
+ "builtin_code": "CONV_2D"
+ },
+ {
+ "builtin_code": "DEPTHWISE_CONV_2D"
+ },
+ {
+ "builtin_code": "AVERAGE_POOL_2D"
+ },
+ {
+ "builtin_code": "MAX_POOL_2D"
+ },
+ {
+ "builtin_code": "L2_POOL_2D"
+ },
+ {
+ "builtin_code": "LOGISTIC"
+ },
+ {
+ "builtin_code": "L2_NORMALIZATION"
+ },
+ {
+ "builtin_code": "LOCAL_RESPONSE_NORMALIZATION"
+ },
+ {
+ "builtin_code": "ADD"
+ },
+ {
+ "builtin_code": "RNN"
+ },
+ ],
+ "subgraphs": [{
+ "operators": [
+ {
+ "builtin_options_type": "Pool2DOptions"
+ },
+ {
+ "builtin_options_type": "DepthwiseConv2DOptions"
+ },
+ {
+ "builtin_options_type": "Conv2DOptions"
+ },
+ {
+ "builtin_options_type": "LocalResponseNormalizationOptions"
+ },
+ {
+ "builtin_options_type": "RNNOptions"
+ },
+ ],
+ }],
+ "description":
+ "",
+ "buffers": [{
+ "data": []
+ }]
+}
+
+BUFFER_TEST_V2 = {
+ "operator_codes": [],
+ "buffers": [],
+ "subgraphs": [{
+ "tensors": [
+ {
+ "data_buffer": [1, 2, 3, 4]
+ },
+ {
+ "data_buffer": [1, 2, 3, 4, 5, 6, 7, 8]
+ },
+ {
+ "data_buffer": []
+ },
+ ],
+ "inputs": [],
+ "outputs": [],
+ "operators": [],
+ }],
+ "version":
+ 2
+}
+
+BUFFER_TEST_V3 = {
+ "operator_codes": [],
+ "subgraphs": [{
+ "tensors": [
+ {
+ "buffer": 1
+ },
+ {
+ "buffer": 2
+ },
+ {
+ "buffer": 0
+ },
+ ],
+ "inputs": [],
+ "outputs": [],
+ "operators": [],
+ }],
+ "buffers": [
+ {
+ "data": []
+ },
+ {
+ "data": [1, 2, 3, 4]
+ },
+ {
+ "data": [1, 2, 3, 4, 5, 6, 7, 8]
+ },
+ ],
+ "version":
+ 3
+}
+
+
+def JsonDumpAndFlush(data, fp):
+ """Write the dictionary `data` to a JSON file `fp` (and flush).
+
+ Args:
+ data: Dictionary that is JSON serializable.
+ fp: File-like object to write to.
+ """
+ json.dump(data, fp)
+ fp.flush()
+
+
+class TestSchemaUpgrade(test_util.TensorFlowTestCase):
+
+ def testNonExistantFile(self):
+ converter = upgrade_schema_lib.Converter()
+ non_existent = tempfile.mktemp(suffix=".json")
+ with self.assertRaisesRegexp(IOError, "No such file or directory"):
+ converter.Convert(non_existent, non_existent)
+
+ def testInvalidExtension(self):
+ converter = upgrade_schema_lib.Converter()
+ invalid_extension = tempfile.mktemp(suffix=".foo")
+ with self.assertRaisesRegexp(ValueError, "Invalid extension on input"):
+ converter.Convert(invalid_extension, invalid_extension)
+ with tempfile.NamedTemporaryFile(suffix=".json") as in_json:
+ JsonDumpAndFlush(EMPTY_TEST_SCHEMA_V1, in_json)
+ with self.assertRaisesRegexp(ValueError, "Invalid extension on output"):
+ converter.Convert(in_json.name, invalid_extension)
+
+ def CheckConversion(self, data_old, data_expected):
+ """Given a data dictionary, test upgrading to current version.
+
+ Args:
+ data_old: TFLite model as a dictionary (arbitrary version).
+ data_expected: TFLite model as a dictionary (upgraded).
+ """
+ converter = upgrade_schema_lib.Converter()
+ with tempfile.NamedTemporaryFile(suffix=".json") as in_json, \
+ tempfile.NamedTemporaryFile(suffix=".json") as out_json, \
+ tempfile.NamedTemporaryFile(suffix=".bin") as out_bin, \
+ tempfile.NamedTemporaryFile(suffix=".tflite") as out_tflite:
+ JsonDumpAndFlush(data_old, in_json)
+ # Test JSON output
+ converter.Convert(in_json.name, out_json.name)
+ # Test binary output
+ # Convert to .tflite and then to .bin and check if binary is equal
+ converter.Convert(in_json.name, out_tflite.name)
+ converter.Convert(out_tflite.name, out_bin.name)
+ self.assertEqual(open(out_bin.name).read(), open(out_tflite.name).read())
+ # Test that the conversion actually produced valid new JSON.
+ converted_schema = json.load(out_json)
+ self.assertEqual(converted_schema, data_expected)
+
+ def testAlreadyUpgraded(self):
+ """A file already at version 3 should stay at version 3."""
+ self.CheckConversion(EMPTY_TEST_SCHEMA_V3, EMPTY_TEST_SCHEMA_V3)
+ self.CheckConversion(TEST_SCHEMA_V3, TEST_SCHEMA_V3)
+ self.CheckConversion(BUFFER_TEST_V3, BUFFER_TEST_V3)
+
+ # Disable this while we have incorrectly versioned structures around.
+ # def testV0Upgrade_IntroducesSubgraphs(self):
+ # """V0 did not have subgraphs; check to make sure they get introduced."""
+ # self.CheckConversion(TEST_SCHEMA_V0, TEST_SCHEMA_V3)
+
+ def testV1Upgrade_RenameOps(self):
+ """V1 had many different names for ops; check to make sure they rename."""
+ self.CheckConversion(EMPTY_TEST_SCHEMA_V1, EMPTY_TEST_SCHEMA_V3)
+ self.CheckConversion(FULL_TEST_SCHEMA_V1, FULL_TEST_SCHEMA_V3)
+
+ def testV2Upgrade_CreateBuffers(self):
+ """V2 did not have buffers; check to make sure they are created."""
+ self.CheckConversion(BUFFER_TEST_V2, BUFFER_TEST_V3)
+
+
+if __name__ == "__main__":
+ test_lib.main()
diff --git a/tensorflow/contrib/lite/simple_memory_arena.cc b/tensorflow/contrib/lite/simple_memory_arena.cc
new file mode 100644
index 0000000000..4aab244989
--- /dev/null
+++ b/tensorflow/contrib/lite/simple_memory_arena.cc
@@ -0,0 +1,136 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/simple_memory_arena.h"
+
+#include <cstring>
+#include <limits>
+#include <vector>
+
+namespace {
+
+template <typename T>
+T AlignTo(size_t alignment, T offset) {
+ return offset % alignment == 0 ? offset
+ : offset + (alignment - offset % alignment);
+}
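+
+// For example (illustrative note): AlignTo(4, 6) == 8 and AlignTo(4, 8) == 8;
+// offsets are rounded up to the next multiple of `alignment`.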
+
+} // namespace
+
+namespace tflite {
+
+TfLiteStatus SimpleMemoryArena::Allocate(TfLiteContext* context,
+ size_t alignment, size_t size,
+ ArenaAlloc* new_alloc) {
+ TF_LITE_ENSURE(context, alignment < arena_alignment_);
+
+ size_t current_top = 0;
+
+ if (!allocs_.empty()) {
+ auto last = allocs_.rbegin();
+ current_top = last->offset + last->size;
+ }
+
+ // If we don't find a better gap just allocate at the end of the buffer.
+ size_t best_offset = AlignTo(alignment, current_top);
+ size_t best_offset_fit = std::numeric_limits<size_t>::max();
+ auto best_insertion_it = allocs_.end();
+
+ // Go through the sorted allocs and look at the gaps between them.
+ size_t current_offset = 0;
+ for (auto it = allocs_.begin(); it != allocs_.end(); ++it) {
+ size_t aligned_current_offset = AlignTo(alignment, current_offset);
+ // If we found a gap larger than required size, and smaller than previous
+ // best fit, take it.
+ if (aligned_current_offset + size <= it->offset &&
+ it->offset - current_offset < best_offset_fit) {
+ best_offset = aligned_current_offset;
+ best_offset_fit = it->offset - current_offset;
+ best_insertion_it = it;
+ }
+ current_offset = it->offset + it->size;
+ }
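+
+ // Worked example (illustrative sketch): with existing allocs covering
+ // [0, 2048) and [4096, 6144), a request of size 1024 with alignment 32 fits
+ // the gap starting at offset 2048, so best_offset becomes 2048 rather than
+ // the end of the buffer.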
+
+ // Update the required buffer size.
+ high_water_mark_ = std::max(high_water_mark_, best_offset + size);
+
+ new_alloc->offset = best_offset;
+ new_alloc->size = size;
+ allocs_.insert(best_insertion_it, *new_alloc);
+
+ return kTfLiteOk;
+}
+
+TfLiteStatus SimpleMemoryArena::Deallocate(TfLiteContext* context,
+ const ArenaAlloc& alloc) {
+ int erased_allocs_count = 0;
+ auto it = allocs_.begin();
+ while (it != allocs_.end()) {
+ if (it->offset == alloc.offset) {
+ TF_LITE_ENSURE_EQ(context, it->size, alloc.size);
+ erased_allocs_count++;
+ it = allocs_.erase(it);
+ } else {
+ ++it;
+ }
+ }
+ TF_LITE_ENSURE_EQ(context, erased_allocs_count, 1);
+ return kTfLiteOk;
+}
+
+TfLiteStatus SimpleMemoryArena::Commit(TfLiteContext* context) {
+ size_t required_size = RequiredBufferSize();
+ if (required_size > underlying_buffer_size_) {
+ char* new_alloc = new char[required_size];
+ char* new_underlying_buffer_aligned_ptr = reinterpret_cast<char*>(
+ AlignTo(arena_alignment_, reinterpret_cast<intptr_t>(new_alloc)));
+
+ // If the arena had been previously allocated, copy over the old memory.
+ // Since Alloc pointers are offset based, they will remain valid in the new
+ // memory block.
+ if (high_water_mark_ > 0 && underlying_buffer_size_ > 0) {
+ size_t copy_amount = std::min(
+ underlying_buffer_.get() + underlying_buffer_size_ -
+ underlying_buffer_aligned_ptr_,
+ new_alloc + required_size - new_underlying_buffer_aligned_ptr);
+ memcpy(new_underlying_buffer_aligned_ptr, underlying_buffer_aligned_ptr_,
+ copy_amount);
+ }
+
+ underlying_buffer_.reset(new_alloc);
+ underlying_buffer_size_ = required_size;
+ underlying_buffer_aligned_ptr_ = new_underlying_buffer_aligned_ptr;
+ }
+ commited_ = true;
+ return underlying_buffer_ != nullptr ? kTfLiteOk : kTfLiteError;
+}
+
+TfLiteStatus SimpleMemoryArena::ResolveAlloc(TfLiteContext* context,
+ const ArenaAlloc& alloc,
+ char** output_ptr) {
+ TF_LITE_ENSURE(context, commited_);
+ TF_LITE_ENSURE(context, output_ptr != nullptr);
+ *output_ptr = underlying_buffer_aligned_ptr_ + alloc.offset;
+ return kTfLiteOk;
+}
+
+TfLiteStatus SimpleMemoryArena::Clear() {
+ commited_ = false;
+ high_water_mark_ = 0;
+ allocs_.clear();
+ return kTfLiteOk;
+}
+
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/simple_memory_arena.h b/tensorflow/contrib/lite/simple_memory_arena.h
new file mode 100644
index 0000000000..0d0b7f9ff7
--- /dev/null
+++ b/tensorflow/contrib/lite/simple_memory_arena.h
@@ -0,0 +1,84 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_SIMPLE_MEMORY_ARENA_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_SIMPLE_MEMORY_ARENA_H_
+
+#include <list>
+#include <memory>
+#include "tensorflow/contrib/lite/context.h"
+
+namespace tflite {
+
+// This little structure holds the offset and the size for a dynamic memory
+// allocation in the memory arena. When the arena is committed and the
+// underlying buffer is set, the alloc can be resolved into an actual memory
+// pointer.
+struct ArenaAlloc {
+ ArenaAlloc() : offset(0), size(0) {}
+
+ size_t offset;
+ size_t size;
+
+ inline bool operator<(const ArenaAlloc& other) const {
+ return offset < other.offset;
+ }
+};
+
+// This small class is responsible for allocating, deallocating and reusing
+// dynamic memory from a common underlying buffer. The arena can be used in
+// scenarios when the pattern of memory allocations and deallocations is
+// repetitive, e.g. running NN inference in multiple iterations.
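+//
+// A typical usage sketch (illustrative only; `context` is a valid
+// TfLiteContext* and error checking is omitted):
+//
+// SimpleMemoryArena arena(/*arena_alignment=*/64);
+// ArenaAlloc alloc;
+// arena.Allocate(context, /*alignment=*/32, /*size=*/1024, &alloc);
+// arena.Commit(context);
+// char* ptr = nullptr;
+// arena.ResolveAlloc(context, alloc, &ptr); // ptr now points into the arena.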
+class SimpleMemoryArena {
+ public:
+ explicit SimpleMemoryArena(size_t arena_alignment)
+ : commited_(false),
+ arena_alignment_(arena_alignment),
+ high_water_mark_(0),
+ underlying_buffer_size_(0),
+ allocs_() {}
+
+ TfLiteStatus Allocate(TfLiteContext* context, size_t alignment, size_t size,
+ ArenaAlloc* new_alloc);
+
+ TfLiteStatus Deallocate(TfLiteContext* context, const ArenaAlloc& alloc);
+
+ inline size_t RequiredBufferSize() {
+ // Add in a small amount of padding to reduce the chance of resize events
+ // for small allocations.
+ size_t padding = arena_alignment_;
+ return arena_alignment_ + high_water_mark_ + padding;
+ }
+
+ TfLiteStatus Commit(TfLiteContext* context);
+
+ TfLiteStatus ResolveAlloc(TfLiteContext* context, const ArenaAlloc& alloc,
+ char** output_ptr);
+
+ TfLiteStatus Clear();
+
+ private:
+ bool commited_;
+ size_t arena_alignment_;
+ size_t high_water_mark_;
+ std::unique_ptr<char[]> underlying_buffer_;
+ size_t underlying_buffer_size_;
+ char* underlying_buffer_aligned_ptr_;
+ // TODO(maciekc): add list iterator to the ArenaAlloc to lookup quickly.
+ std::list<ArenaAlloc> allocs_;
+};
+
+} // namespace tflite
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_SIMPLE_MEMORY_ARENA_H_
diff --git a/tensorflow/contrib/lite/simple_memory_arena_test.cc b/tensorflow/contrib/lite/simple_memory_arena_test.cc
new file mode 100644
index 0000000000..ac676092c6
--- /dev/null
+++ b/tensorflow/contrib/lite/simple_memory_arena_test.cc
@@ -0,0 +1,91 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/simple_memory_arena.h"
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+namespace tflite {
+namespace {
+
+TEST(SimpleMemoryArenaTest, BasicArenaOperations) {
+ TfLiteContext context;
+ SimpleMemoryArena arena(64);
+ ArenaAlloc allocs[6];
+
+ arena.Allocate(&context, 32, 2047, &allocs[0]);
+ arena.Allocate(&context, 32, 2047, &allocs[1]);
+ arena.Allocate(&context, 32, 2047, &allocs[2]);
+ arena.Deallocate(&context, allocs[0]);
+ arena.Allocate(&context, 32, 1023, &allocs[3]);
+ arena.Allocate(&context, 32, 2047, &allocs[4]);
+ arena.Deallocate(&context, allocs[1]);
+ arena.Allocate(&context, 32, 1023, &allocs[5]);
+
+ EXPECT_EQ(allocs[0].offset, 0);
+ EXPECT_EQ(allocs[1].offset, 2048);
+ EXPECT_EQ(allocs[2].offset, 4096);
+ EXPECT_EQ(allocs[3].offset, 0);
+ EXPECT_EQ(allocs[4].offset, 6144);
+ EXPECT_EQ(allocs[5].offset, 1024);
+}
+
+TEST(SimpleMemoryArenaTest, TestAfterClear) {
+ TfLiteContext context;
+ SimpleMemoryArena arena(64);
+ ArenaAlloc allocs[9];
+
+ arena.Allocate(&context, 32, 2047, &allocs[0]);
+ arena.Allocate(&context, 32, 2047, &allocs[1]);
+ arena.Allocate(&context, 32, 2047, &allocs[2]);
+ arena.Commit(&context);
+
+ EXPECT_EQ(allocs[0].offset, 0);
+ EXPECT_EQ(allocs[1].offset, 2048);
+ EXPECT_EQ(allocs[2].offset, 4096);
+
+ arena.Clear();
+
+ // Test with smaller allocs.
+ arena.Allocate(&context, 32, 1023, &allocs[3]);
+ arena.Allocate(&context, 32, 1023, &allocs[4]);
+ arena.Allocate(&context, 32, 1023, &allocs[5]);
+ arena.Commit(&context);
+
+ EXPECT_EQ(allocs[3].offset, 0);
+ EXPECT_EQ(allocs[4].offset, 1024);
+ EXPECT_EQ(allocs[5].offset, 2048);
+
+ arena.Clear();
+
+ // Test larger allocs which should require a reallocation.
+ arena.Allocate(&context, 32, 4095, &allocs[6]);
+ arena.Allocate(&context, 32, 4095, &allocs[7]);
+ arena.Allocate(&context, 32, 4095, &allocs[8]);
+ arena.Commit(&context);
+
+ EXPECT_EQ(allocs[6].offset, 0);
+ EXPECT_EQ(allocs[7].offset, 4096);
+ EXPECT_EQ(allocs[8].offset, 8192);
+}
+
+} // namespace
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ // On Linux, add: tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/string.h b/tensorflow/contrib/lite/string.h
new file mode 100644
index 0000000000..ecd6f04ec2
--- /dev/null
+++ b/tensorflow/contrib/lite/string.h
@@ -0,0 +1,30 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// Abstract string. We don't want even absl at this level.
+#ifndef _THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_STRING_H_
+#define _THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_STRING_H_
+
+#include <string>
+#include "tensorflow/core/platform/platform.h"
+
+namespace tflite {
+
+#ifndef PLATFORM_GOOGLE
+using std::string;
+#endif
+
+} // namespace tflite
+
+#endif // _THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_STRING_H_
diff --git a/tensorflow/contrib/lite/string_util.cc b/tensorflow/contrib/lite/string_util.cc
new file mode 100644
index 0000000000..cd41299d38
--- /dev/null
+++ b/tensorflow/contrib/lite/string_util.cc
@@ -0,0 +1,117 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/string_util.h"
+
+#include <string.h>
+#include <vector>
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/interpreter.h"
+
+namespace tflite {
+namespace {
+
+// Convenience function to reinterpret a char pointer as an int32_t pointer.
+int32_t* GetIntPtr(char* ptr) { return reinterpret_cast<int32_t*>(ptr); }
+} // namespace
+
+void DynamicBuffer::AddString(const char* str, size_t len) {
+ data_.resize(data_.size() + len);
+ memcpy(data_.data() + offset_.back(), str, len);
+ offset_.push_back(offset_.back() + len);
+}
+
+void DynamicBuffer::AddString(const StringRef& string) {
+ AddString(string.str, string.len);
+}
+
+void DynamicBuffer::AddJoinedString(const std::vector<StringRef>& strings,
+ char separator) {
+ // Resize the data buffer.
+ int total_len = strings.size() - 1;
+ for (StringRef ref : strings) {
+ total_len += ref.len;
+ }
+ data_.resize(data_.size() + total_len);
+
+ int current_idx = 0;
+ for (StringRef ref : strings) {
+ char* dst = data_.data() + offset_.back() + current_idx;
+
+ // Fill separator if not first string.
+ if (current_idx != 0) {
+ *dst = separator;
+ ++dst;
+ ++current_idx;
+ }
+
+ // Fill content of the string.
+ memcpy(dst, ref.str, ref.len);
+ current_idx += ref.len;
+ }
+ offset_.push_back(offset_.back() + total_len);
+}
+
+void DynamicBuffer::WriteToTensor(TfLiteTensor* tensor) {
+ // Allocate sufficient memory for the tensor buffer.
+ int32_t num_strings = offset_.size() - 1;
+ // Total bytes include:
+ // * size of content (data_.size)
+ // * offset of each string (sizeof(int32_t) * num_strings)
+ // * length of whole buffer (int32_t)
+ // * num of strings (int32_t).
+ int32_t bytes = data_.size() // size of content
+ + sizeof(int32_t) * (num_strings + 2); // size of header
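+ // Worked example (illustrative): two strings "AB" and "" give
+ // data_.size() == 2 and num_strings == 2, so bytes == 2 + 4 * 4 == 18,
+ // matching the layout documented in string_util.h.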
+
+ // Output tensor will take over the ownership of tensor_buffer, and free it
+ // during Interpreter destruction.
+ char* tensor_buffer = static_cast<char*>(malloc(bytes));
+
+ // Set the number of strings.
+ memcpy(tensor_buffer, &num_strings, sizeof(int32_t));
+
+ // Set offset of strings.
+ int32_t start = sizeof(int32_t) * (num_strings + 2);
+ for (int i = 0; i < offset_.size(); i++) {
+ int32_t offset = start + offset_[i];
+ memcpy(tensor_buffer + sizeof(int32_t) * (i + 1), &offset, sizeof(int32_t));
+ }
+
+ // Copy data of strings.
+ memcpy(tensor_buffer + start, data_.data(), data_.size());
+
+ // Set tensor content pointer to tensor_buffer, and release original data.
+ auto dims = TfLiteIntArrayCreate(1);
+ dims->data[0] = num_strings;
+ TfLiteTensorReset(tensor->type, tensor->name, dims, tensor->params,
+ tensor_buffer, bytes, kTfLiteDynamic, tensor->allocation,
+ tensor);
+}
+
+int GetStringCount(const TfLiteTensor* tensor) {
+ // The first integer in the raw buffer is the number of strings.
+ return *GetIntPtr(tensor->data.raw);
+}
+
+StringRef GetString(const TfLiteTensor* tensor, int string_index) {
+ int32_t* offset =
+ GetIntPtr(tensor->data.raw + sizeof(int32_t) * (string_index + 1));
+ return {
+ tensor->data.raw + (*offset),
+ (*(offset + 1)) - (*offset),
+ };
+}
+
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/string_util.h b/tensorflow/contrib/lite/string_util.h
new file mode 100644
index 0000000000..12872d1123
--- /dev/null
+++ b/tensorflow/contrib/lite/string_util.h
@@ -0,0 +1,91 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Util methods to read and write String tensors.
+// String tensors are represented as a char tensor with the following layout.
+// [0, 3] 4 bytes: N, num of strings in the tensor in little endian.
+// [(i+1)*4, (i+1)*4+3] 4 bytes: offset of i-th string in little endian.
+// [(N+1)*4, (N+1)*4+3] 4 bytes: length of the whole char buffer.
+// [offset(i), offset(i+1) - 1] : content of i-th string.
+// Example of a string tensor:
+// [
+// 2, 0, 0, 0, # 2 strings.
+// 16, 0, 0, 0, # 0-th string starts from index 16.
+// 18, 0, 0, 0, # 1-st string starts from index 18.
+// 18, 0, 0, 0, # total length of array.
+// 'A', 'B', # 0-th string [16..17]: "AB"
+// ] # 1-st string, empty
+//
+// A typical usage:
+// In op.Eval(context, node):
+// DynamicBuffer buf;
+// # Add string "AB" to tensor, string is stored in dynamic buffer.
+// buf.AddString("AB", 2);
+// # Write content of DynamicBuffer to tensor in format of string tensor
+// # described above.
+// buf.WriteToTensor(tensor)
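+//
+// Reading back (illustrative sketch):
+// int count = GetStringCount(tensor);
+// StringRef ref = GetString(tensor, 0); // ref.str / ref.len, no copy made.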
+
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_STRING_UTIL_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_STRING_UTIL_H_
+
+#include <vector>
+
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/string.h"
+
+namespace tflite {
+
+// Convenient structure to store string pointer and length.
+typedef struct {
+ char* str;
+ int len;
+} StringRef;
+
+// DynamicBuffer holds a temporary buffer that will be used to create a
+// dynamic tensor. A typical usage is to initialize a DynamicBuffer object,
+// fill in content and call WriteToTensor in op.Eval().
+class DynamicBuffer {
+ public:
+ DynamicBuffer() : offset_({0}) {}
+
+ // Add string to dynamic buffer by resizing the buffer and copying the data.
+ void AddString(const StringRef& string);
+
+ // Add string to dynamic buffer by resizing the buffer and copying the data.
+ void AddString(const char* str, size_t len);
+
+ // Join a list of strings with a separator, and add the result as a single
+ // string to the buffer.
+ void AddJoinedString(const std::vector<StringRef>& strings, char separator);
+
+ // Fill content into a string tensor.
+ void WriteToTensor(TfLiteTensor* tensor);
+
+ private:
+ // Data buffer to store contents of strings, not including headers.
+ std::vector<char> data_;
+ // Offset of the starting index of each string in data buffer.
+ std::vector<int32_t> offset_;
+};
+
+// Return the number of strings in a string tensor.
+int GetStringCount(const TfLiteTensor* tensor);
+
+// Get String pointer and length of index-th string in tensor.
+// NOTE: This will not create a copy of string data.
+StringRef GetString(const TfLiteTensor* tensor, int string_index);
+} // namespace tflite
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_STRING_UTIL_H_
diff --git a/tensorflow/contrib/lite/string_util_test.cc b/tensorflow/contrib/lite/string_util_test.cc
new file mode 100644
index 0000000000..5c351638dc
--- /dev/null
+++ b/tensorflow/contrib/lite/string_util_test.cc
@@ -0,0 +1,117 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/string_util.h"
+
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/interpreter.h"
+
+namespace tflite {
+
+TEST(StringUtil, TestStringUtil) {
+ Interpreter interpreter;
+ interpreter.AddTensors(3);
+
+ TfLiteTensor* t0 = interpreter.tensor(0);
+ t0->type = kTfLiteString;
+ t0->allocation_type = kTfLiteDynamic;
+
+ TfLiteTensor* t1 = interpreter.tensor(1);
+ t1->type = kTfLiteString;
+ t1->allocation_type = kTfLiteDynamic;
+
+ char data[] = {1, 0, 0, 0, 12, 0, 0, 0, 15, 0, 0, 0, 'X', 'Y', 'Z'};
+
+ interpreter.SetTensorParametersReadOnly(2, kTfLiteString, "", {1}, {}, data,
+ 15);
+ TfLiteTensor* t2 = interpreter.tensor(2);
+ interpreter.AllocateTensors();
+
+ char s0[] = "ABC";
+ string s1 = "DEFG";
+ char s2[] = "";
+
+ // Write strings to tensors
+ DynamicBuffer buf0;
+ buf0.AddString(s0, 3);
+ DynamicBuffer buf1;
+ buf1.AddString(s1.data(), s1.length());
+ buf0.AddString(s2, 0);
+ buf0.WriteToTensor(t0);
+ buf1.WriteToTensor(t1);
+
+ // Read strings from tensors.
+ ASSERT_EQ(GetStringCount(t0), 2);
+ StringRef str_ref;
+ str_ref = GetString(t0, 0);
+ ASSERT_EQ(string(str_ref.str, str_ref.len), "ABC");
+ str_ref = GetString(t0, 1);
+ ASSERT_EQ(string(str_ref.str, str_ref.len), "");
+ ASSERT_EQ(t0->bytes, 19);
+
+ ASSERT_EQ(GetStringCount(t1), 1);
+ str_ref = GetString(t1, 0);
+ ASSERT_EQ(string(str_ref.str, str_ref.len), "DEFG");
+ ASSERT_EQ(t1->bytes, 16);
+
+ ASSERT_EQ(GetStringCount(t2), 1);
+ str_ref = GetString(t2, 0);
+ ASSERT_EQ(string(str_ref.str, str_ref.len), "XYZ");
+ ASSERT_EQ(t2->bytes, 15);
+}
+
+TEST(StringUtil, TestAddJoinedString) {
+ Interpreter interpreter;
+ interpreter.AddTensors(1);
+ TfLiteTensor* t0 = interpreter.tensor(0);
+ t0->type = kTfLiteString;
+ t0->allocation_type = kTfLiteDynamic;
+
+ char s0[] = "ABC";
+ char s1[] = "DEFG";
+ char s2[] = "";
+ char s3[] = "XYZ";
+
+ DynamicBuffer buf;
+ buf.AddJoinedString({{s0, 3}, {s1, 4}, {s2, 0}, {s3, 3}}, ' ');
+ buf.WriteToTensor(t0);
+
+ ASSERT_EQ(GetStringCount(t0), 1);
+ StringRef str_ref;
+ str_ref = GetString(t0, 0);
+ ASSERT_EQ(string(str_ref.str, str_ref.len), "ABC DEFG XYZ");
+ ASSERT_EQ(t0->bytes, 25);
+}
+
+TEST(StringUtil, TestEmptyList) {
+ Interpreter interpreter;
+ interpreter.AddTensors(1);
+ TfLiteTensor* t0 = interpreter.tensor(0);
+ t0->type = kTfLiteString;
+ t0->allocation_type = kTfLiteDynamic;
+ DynamicBuffer buf;
+ buf.WriteToTensor(t0);
+
+ ASSERT_EQ(GetStringCount(t0), 0);
+ ASSERT_EQ(t0->bytes, 8);
+}
+
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ // On Linux, add: tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/testdata/0_subgraphs.bin b/tensorflow/contrib/lite/testdata/0_subgraphs.bin
new file mode 100644
index 0000000000..5606898d7f
--- /dev/null
+++ b/tensorflow/contrib/lite/testdata/0_subgraphs.bin
Binary files differ
diff --git a/tensorflow/contrib/lite/testdata/2_subgraphs.bin b/tensorflow/contrib/lite/testdata/2_subgraphs.bin
new file mode 100644
index 0000000000..07308ba62b
--- /dev/null
+++ b/tensorflow/contrib/lite/testdata/2_subgraphs.bin
Binary files differ
diff --git a/tensorflow/contrib/lite/testdata/empty_model.bin b/tensorflow/contrib/lite/testdata/empty_model.bin
new file mode 100644
index 0000000000..1762ca3938
--- /dev/null
+++ b/tensorflow/contrib/lite/testdata/empty_model.bin
Binary files differ
diff --git a/tensorflow/contrib/lite/testdata/multi_add.bin b/tensorflow/contrib/lite/testdata/multi_add.bin
new file mode 100644
index 0000000000..e5048a3281
--- /dev/null
+++ b/tensorflow/contrib/lite/testdata/multi_add.bin
Binary files differ
diff --git a/tensorflow/contrib/lite/testdata/multi_add.json b/tensorflow/contrib/lite/testdata/multi_add.json
new file mode 100644
index 0000000000..97b931dba8
--- /dev/null
+++ b/tensorflow/contrib/lite/testdata/multi_add.json
@@ -0,0 +1,46 @@
+{
+ "version": 1,
+ "operator_codes": [
+ {
+ "builtin_code": "ADD"
+ }
+ ],
+ "subgraphs": [
+ {
+ "tensors": [
+ { "shape": [ 1, 8, 8, 3 ], "name": "a" },
+ { "shape": [ 1, 8, 8, 3 ], "name": "b" },
+ { "shape": [ 1, 8, 8, 3 ], "name": "c" },
+ { "shape": [ 1, 8, 8, 3 ], "name": "d" },
+ { "shape": [ 1, 8, 8, 3 ], "name": "i" },
+ { "shape": [ 1, 8, 8, 3 ], "name": "x" },
+ { "shape": [ 1, 8, 8, 3 ], "name": "y" }
+ ],
+ "inputs": [ 0, 1, 2, 3 ],
+ "outputs": [ 5, 6 ],
+ "operators": [
+ {
+ "inputs": [ 1, 2 ],
+ "outputs": [ 4 ],
+ "builtin_options_type": "AddOptions",
+ "builtin_options": {
+ }
+ },
+ {
+ "inputs": [ 0, 4 ],
+ "outputs": [ 5 ],
+ "builtin_options_type": "AddOptions",
+ "builtin_options": {
+ }
+ },
+ {
+ "inputs": [ 3, 4 ],
+ "outputs": [ 6 ],
+ "builtin_options_type": "AddOptions",
+ "builtin_options": {
+ }
+ }
+ ]
+ }
+ ]
+}
diff --git a/tensorflow/contrib/lite/testdata/no_subgraphs.bin b/tensorflow/contrib/lite/testdata/no_subgraphs.bin
new file mode 100644
index 0000000000..5606898d7f
--- /dev/null
+++ b/tensorflow/contrib/lite/testdata/no_subgraphs.bin
Binary files differ
diff --git a/tensorflow/contrib/lite/testdata/test_model.bin b/tensorflow/contrib/lite/testdata/test_model.bin
new file mode 100644
index 0000000000..2878b1f96e
--- /dev/null
+++ b/tensorflow/contrib/lite/testdata/test_model.bin
Binary files differ
diff --git a/tensorflow/contrib/lite/testdata/test_model_broken.bin b/tensorflow/contrib/lite/testdata/test_model_broken.bin
new file mode 100644
index 0000000000..9fd050cd4a
--- /dev/null
+++ b/tensorflow/contrib/lite/testdata/test_model_broken.bin
Binary files differ
diff --git a/tensorflow/contrib/lite/testdata/test_model_broken.json b/tensorflow/contrib/lite/testdata/test_model_broken.json
new file mode 100644
index 0000000000..b701eb9a25
--- /dev/null
+++ b/tensorflow/contrib/lite/testdata/test_model_broken.json
@@ -0,0 +1,62 @@
+{
+ "subgraphs": [
+ {
+ "inputs": [0, 1],
+ "outputs": [2, 3],
+ "operators": [
+ {
+ "opcode_index": 0,
+ "inputs": [0,1],
+ "outputs": [2]
+ },
+ {
+ "opcode_index": 1,
+ "inputs": [2],
+ "outputs": [3]
+ }
+ ],
+ "tensors": [
+ {
+ "shape" : [
+ 2
+ ],
+ "type" : "FLOAT32",
+ "name" : "input0",
+ "data_buffer" : [1,0,0,0]
+ },
+ {
+ "shape" : [
+ 3
+ ],
+ "type" : "FLOAT32",
+ "name" : "input1",
+ "data_buffer" : []
+ },
+ {
+ "shape" : [
+ 3
+ ],
+ "type" : "FLOAT32",
+ "name" : "out1",
+ "data_buffer" : []
+ },
+ {
+ "shape" : [
+ 3
+ ],
+ "type" : "FLOAT32",
+ "name" : "out2",
+ "data_buffer" : []
+ }
+ ],
+ }
+ ],
+ "operator_codes": [
+ {
+ "builtin_code": 0
+ },
+ {
+ "custom_code": "testing_op"
+ }
+ ]
+}
diff --git a/tensorflow/contrib/lite/testdata/two_subgraphs.bin b/tensorflow/contrib/lite/testdata/two_subgraphs.bin
new file mode 100644
index 0000000000..07308ba62b
--- /dev/null
+++ b/tensorflow/contrib/lite/testdata/two_subgraphs.bin
Binary files differ
diff --git a/tensorflow/contrib/lite/testing/BUILD b/tensorflow/contrib/lite/testing/BUILD
new file mode 100644
index 0000000000..5e40a13d3c
--- /dev/null
+++ b/tensorflow/contrib/lite/testing/BUILD
@@ -0,0 +1,213 @@
+package(default_visibility = [
+ "//visibility:public",
+])
+
+licenses(["notice"]) # Apache 2.0
+
+load(
+ "//tensorflow/contrib/lite:build_def.bzl",
+ "gen_zipped_test_files",
+)
+load(
+ "//tensorflow:tensorflow.bzl",
+ "tf_cc_test",
+)
+
+gen_zipped_test_files(
+ name = "optest",
+ files = [
+ "add.zip",
+ "avg_pool.zip",
+ "concat.zip",
+ "constant.zip",
+ "control_dep.zip",
+ "conv.zip",
+ "depthwiseconv.zip",
+ "fully_connected.zip",
+ "fused_batch_norm.zip",
+ "global_batch_norm.zip",
+ "l2_pool.zip",
+ "l2norm.zip",
+ "local_response_norm.zip",
+ "max_pool.zip",
+ "mul.zip",
+ "relu.zip",
+ "relu1.zip",
+ "relu6.zip",
+ "reshape.zip",
+ "resize_bilinear.zip",
+ "sigmoid.zip",
+ "softmax.zip",
+ "space_to_depth.zip",
+ ],
+)
+
+py_binary(
+ name = "generate_examples",
+ srcs = ["generate_examples.py"],
+ data = [
+ "//tensorflow/contrib/lite/toco",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":generate_examples_report",
+ "//tensorflow:tensorflow_py",
+ "//tensorflow/python:graph_util",
+ "//third_party/py/numpy",
+ "@six_archive//:six",
+ ],
+)
+
+py_library(
+ name = "generate_examples_report",
+ srcs = ["generate_examples_report.py"],
+ srcs_version = "PY2AND3",
+)
+
+cc_library(
+ name = "parse_testdata_lib",
+ srcs = ["parse_testdata.cc"],
+ hdrs = ["parse_testdata.h"],
+ deps = [
+ ":message",
+ ":split",
+ ":test_runner",
+ "//tensorflow/contrib/lite:framework",
+ ],
+)
+
+cc_library(
+ name = "message",
+ srcs = ["message.cc"],
+ hdrs = ["message.h"],
+ deps = [":tokenize"],
+)
+
+cc_test(
+ name = "message_test",
+ srcs = ["message_test.cc"],
+ deps = [
+ ":message",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_library(
+ name = "split",
+ srcs = ["split.cc"],
+ hdrs = ["split.h"],
+ deps = [
+ "//tensorflow/contrib/lite:string",
+ ],
+)
+
+cc_test(
+ name = "split_test",
+ size = "small",
+ srcs = ["split_test.cc"],
+ deps = [
+ ":split",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_library(
+ name = "tflite_driver",
+ srcs = ["tflite_driver.cc"],
+ hdrs = ["tflite_driver.h"],
+ deps = [
+ ":split",
+ ":test_runner",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:builtin_ops",
+ ],
+)
+
+cc_test(
+ name = "tflite_driver_test",
+ size = "small",
+ srcs = ["tflite_driver_test.cc"],
+ data = ["//tensorflow/contrib/lite:testdata/multi_add.bin"],
+ deps = [
+ ":tflite_driver",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_library(
+ name = "tokenize",
+ srcs = ["tokenize.cc"],
+ hdrs = ["tokenize.h"],
+ deps = [
+ "//tensorflow/contrib/lite:string",
+ ],
+)
+
+cc_test(
+ name = "tokenize_test",
+ srcs = ["tokenize_test.cc"],
+ deps = [
+ ":tokenize",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_library(
+ name = "test_runner",
+ hdrs = ["test_runner.h"],
+ deps = [
+ "//tensorflow/contrib/lite:string",
+ ],
+)
+
+cc_test(
+ name = "test_runner_test",
+ srcs = ["test_runner_test.cc"],
+ deps = [
+ ":test_runner",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_binary(
+ name = "nnapi_example",
+ srcs = ["nnapi_example.cc"],
+ deps = [
+ ":parse_testdata_lib",
+ "//tensorflow/contrib/lite:builtin_op_data",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:builtin_ops",
+ "//tensorflow/contrib/lite/nnapi:nnapi_lib",
+ ],
+)
+
+tf_cc_test(
+ name = "generated_examples_zip_test",
+ size = "medium",
+ srcs = ["generated_examples_zip_test.cc"],
+ data = [":optest"],
+ shard_count = 10,
+ deps = [
+ ":parse_testdata_lib",
+ "//tensorflow/contrib/lite:builtin_op_data",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:builtin_ops",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ "@com_google_googletest//:gtest",
+ "@com_googlesource_code_re2//:re2",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py
new file mode 100644
index 0000000000..86540d58a6
--- /dev/null
+++ b/tensorflow/contrib/lite/testing/generate_examples.py
@@ -0,0 +1,1189 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Generate a series of TensorFlow graphs that become tflite test cases.
+
+Usage:
+
+generate_examples <output directory> zipped
+
+bazel run //tensorflow/contrib/lite/testing:generate_examples
+ third_party/tensorflow/contrib/lite/testing/generated_examples zipped
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import itertools
+import os
+import re
+import sys
+import tempfile
+import traceback
+import zipfile
+import numpy as np
+from six import StringIO
+import tensorflow as tf
+from google.protobuf import text_format
+# TODO(aselle): switch to TensorFlow's resource_loader
+from tensorflow.contrib.lite.testing import generate_examples_report as report_lib
+from tensorflow.python.framework import graph_util as tf_graph_util
+
+parser = argparse.ArgumentParser(description="Script to generate TFLite tests.")
+parser.add_argument("output_path",
+ help="Directory where the outputs will be go.")
+# TODO(ahentz): remove this flag
+parser.add_argument("type", help="zipped")
+parser.add_argument("--zip_to_output",
+ type=str,
+ help="Particular zip to output.",
+ required=False)
+parser.add_argument("--toco",
+ type=str,
+ help="Path to toco tool.",
+ required=True)
+parser.add_argument(
+ "--known_bugs_are_errors",
+ action="store_true",
+ help=("If a particular model is affected by a known bug,"
+ " count it as a toco error."))
+parser.add_argument(
+ "--ignore_toco_errors",
+ action="store_true",
+ help="Raise an exception if any toco error is encountered.")
+parser.add_argument(
+ "--save_graphdefs",
+ action="store_true",
+ help="Include intermediate graphdefs in the output zip files.")
+
+
+RANDOM_SEED = 342
+TEST_INPUT_DEPTH = 3
+
+
+# A map from regular expression to bug number. Any test failure with label
+# matching the expression will be considered due to the corresponding bug.
+KNOWN_BUGS = {
+ # TOCO doesn't support scalars as input.
+ r"relu.*input_shape=\[\]": "67587484",
+ r"sigmoid.*input_shape=\[\]": "67645668",
+ # Concat doesn't work with a single input tensor
+ r"concat.*num_tensors=1": "67378344",
+ # Transposition in MatMul is not supported.
+ r"fully_connected.*transpose_.=True": "67586970",
+ # Softmax graphs are too complex.
+ r"softmax.*dim=0": "67749831",
+ r"softmax.*input_shape=\[1,3,4,3\]": "67749831",
+ # SpaceToDepth only supports float32.
+ r"space_to_depth.*(float16|int32|uint8|int64)": "68018134",
+}
+
+
+def toco_options(data_types,
+ input_arrays,
+ output_arrays,
+ shapes,
+ drop_control_dependency):
+ """Create TOCO options to process a model.
+
+ Args:
+ data_types: input and inference types used by TOCO.
+ input_arrays: names of the input tensors
+ output_arrays: names of the output tensors
+ shapes: shapes of the input tensors
+ drop_control_dependency: whether to ignore control dependency nodes.
+
+ Returns:
+ the options in a string.
+ """
+ shape_str = ":".join([",".join(str(y) for y in x) for x in shapes])
+ inference_type = "FLOAT"
+ # TODO(ahentz): if we get multi-input quantization to work we need this
+ # to change
+ if data_types[0] == "QUANTIZED_UINT8":
+ inference_type = "QUANTIZED_UINT8"
+ s = (" --input_types=%s" % ",".join(data_types) +
+ " --inference_type=%s" % inference_type +
+ " --input_format=TENSORFLOW_GRAPHDEF" + " --output_format=TFLITE" +
+ " --input_arrays=%s" % ",".join(input_arrays) +
+ " --input_shapes=%s" % shape_str +
+ " --output_arrays=%s" % ",".join(output_arrays))
+ if drop_control_dependency:
+ s += " --drop_control_dependency"
+ return s
+
+
+def write_toco_options(filename,
+ data_types,
+ input_arrays,
+ output_arrays,
+ shapes,
+ drop_control_dependency=False):
+ """Create TOCO options to process a model.
+
+ Args:
+ filename: Filename to write the options to.
+ data_types: input and inference types used by TOCO.
+ input_arrays: names of the input tensors
+ output_arrays: names of the output tensors
+ shapes: shapes of the input tensors
+ drop_control_dependency: whether to ignore control dependency nodes.
+ """
+ with open(filename, "w") as fp:
+ fp.write(
+ toco_options(
+ data_types=data_types,
+ input_arrays=input_arrays,
+ output_arrays=output_arrays,
+ shapes=shapes,
+ drop_control_dependency=drop_control_dependency))
+
+
+def write_examples(fp, examples):
+ """Given a list `examples`, write a text format representation.
+
+ The file format is CSV-like with a simple repeated pattern. We would like
+ to use proto here, but we can't yet due to interfacing with the Android
+ team using this format.
+
+ Args:
+ fp: File-like object to write to.
+ examples: List of example dictionaries, each with keys "inputs" and "outputs".
+ """
+
+ def write_tensor(fp, x):
+ """Write tensor in file format supported by TFLITE example."""
+ fp.write("dtype,%s\n" % x.dtype)
+ fp.write("shape," + ",".join(map(str, x.shape)) + "\n")
+ # Output 9 digits after the point to ensure the precision is good enough.
+ values = ["{:.9f}".format(value) for value in list(x.flatten())]
+ fp.write("values," + ",".join(values) + "\n")
+
+ fp.write("test_cases,%d\n" % len(examples))
+ for example in examples:
+ fp.write("inputs,%d\n" % len(example["inputs"]))
+ for i in example["inputs"]:
+ write_tensor(fp, i)
+ fp.write("outputs,%d\n" % len(example["outputs"]))
+ for i in example["outputs"]:
+ write_tensor(fp, i)
+
+
+def write_test_cases(fp, model_name, examples):
+ """Given a dictionary of `examples`, write a text format representation.
+
+ The file format is protocol-buffer-like, even though we don't use proto due
+ to the needs of the Android team.
+
+ Args:
+ fp: File-like object to write to.
+ model_name: Filename where the model was written to; only its basename is
+ recorded in the test file.
+ examples: List of example dictionaries, each with keys "inputs" and "outputs".
+ """
+
+ fp.write("load_model: %s\n" % os.path.basename(model_name))
+ for example in examples:
+ fp.write("reshape {\n")
+ for t in example["inputs"]:
+ fp.write(" input: \"" + ",".join(map(str, t.shape)) + "\"\n")
+ fp.write("}\n")
+ fp.write("invoke {\n")
+
+ for t in example["inputs"]:
+ values = ["{:.9f}".format(value) for value in list(t.flatten())]
+ fp.write(" input: \"" + ",".join(values) + "\"\n")
+ for t in example["outputs"]:
+ values = ["{:.9f}".format(value) for value in list(t.flatten())]
+ fp.write(" output: \"" + ",".join(values) + "\"\n")
+ fp.write("}\n")
+
+
+_TF_TYPE_INFO = {
+ tf.float32: (np.float32, "FLOAT"),
+ tf.float16: (np.float16, "FLOAT"),
+ tf.int32: (np.int32, "INT32"),
+ tf.uint8: (np.uint8, "QUANTIZED_UINT8"),
+ tf.int64: (np.int64, "INT64"),
+}
+
+
+def create_tensor_data(dtype, shape, min_value=-100, max_value=100):
+ """Build tensor data spreading the range [min_value, max_value)."""
+
+ if dtype in _TF_TYPE_INFO:
+ dtype = _TF_TYPE_INFO[dtype][0]
+
+ if dtype in (tf.float32, tf.float16):
+ value = (max_value-min_value)*np.random.random_sample(shape)+min_value
+ elif dtype in (tf.int32, tf.uint8, tf.int64):
+ value = np.random.random_integers(min_value, max_value, shape)
+ return value.astype(dtype)
+
+
+def freeze_graph(session, outputs):
+ """Freeze the current graph.
+
+ Args:
+ session: TensorFlow session containing the graph
+ outputs: List of output tensors
+
+ Returns:
+ The frozen graph_def.
+ """
+ return tf_graph_util.convert_variables_to_constants(
+ session, session.graph.as_graph_def(), [x.op.name for x in outputs])
+
+
+def make_control_dep_tests(zip_path):
+ """Make a set of tests that use control dependencies."""
+
+ test_parameters = [{
+ "input_shape": [[], [1, 1, 1, 1], [1, 15, 14, 1], [3, 15, 14, 3]],
+ }]
+
+ def build_graph(parameters):
+ input_tensor = tf.placeholder(
+ dtype=tf.float32, name="input", shape=parameters["input_shape"])
+ filter_value = tf.zeros((3, 3, TEST_INPUT_DEPTH, 8), tf.float32)
+ assert_op = tf.assert_greater_equal(input_tensor, input_tensor - 1)
+ with tf.control_dependencies([assert_op]):
+ out = tf.nn.conv2d(input_tensor, filter_value,
+ strides=(1, 1, 1, 1), padding="SAME")
+ return [input_tensor], [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ input_values = create_tensor_data(tf.float32, parameters["input_shape"])
+ return [input_values], sess.run(
+ outputs, feed_dict=dict(zip(inputs, [input_values])))
+
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs,
+ drop_control_dependency=True)
+
+
+def toco_convert(graph_def_str, input_tensors, output_tensors,
+ drop_control_dependency=False):
+ """Convert a model's graph def into a tflite model.
+
+ NOTE: this currently shells out to the toco binary, but we would like
+ to convert to Python API tooling in the future.
+
+ Args:
+ graph_def_str: Graph def proto in serialized string format.
+ input_tensors: List of input tensor tuples `(name, shape, type)`
+ output_tensors: List of output tensors (names)
+ drop_control_dependency: whether to ignore control dependency nodes.
+
+ Returns:
+ output tflite model, log_txt from conversion
+ or None, log_txt if it did not convert properly.
+ """
+ data_types = [_TF_TYPE_INFO[x[2]][1] for x in input_tensors]
+ opts = toco_options(
+ data_types=data_types,
+ input_arrays=[x[0] for x in input_tensors],
+ shapes=[x[1] for x in input_tensors],
+ output_arrays=output_tensors,
+ drop_control_dependency=drop_control_dependency)
+
+ with tempfile.NamedTemporaryFile() as graphdef_file, \
+ tempfile.NamedTemporaryFile() as output_file, \
+ tempfile.NamedTemporaryFile("w+") as stdout_file:
+ graphdef_file.write(graph_def_str)
+ graphdef_file.flush()
+
+ # TODO(aselle): Switch this to subprocess at some point.
+ cmd = ("%s --input_file=%s --output_file=%s %s > %s 2>&1" %
+ (bin_path, graphdef_file.name, output_file.name, opts,
+ stdout_file.name))
+ exit_code = os.system(cmd)
+ log = (
+ cmd + "exited with code %d" % exit_code + "\n------------------\n" +
+ stdout_file.read())
+ return (None if exit_code != 0 else output_file.read()), log
+
+
+def make_zip_of_tests(zip_path,
+ test_parameters,
+ make_graph,
+ make_test_inputs,
+ drop_control_dependency=False):
+ """Helper to make a zip file of a bunch of TensorFlow models.
+
+ This does a cartesian product of each dictionary in test_parameters and
+ calls make_graph() for each item in the cartesian product set.
+ If the graph is built successfully, then make_test_inputs() is called to
+ build expected input/output value pairs. The model is then converted to tflite
+ with toco, and the examples are serialized with the tflite model into a zip
+ file (a model, an example-inputs file and a test-spec file per item in the
+ cartesian product set).
+
+ Args:
+ zip_path: Path of zip file to write
+ test_parameters: List of dictionaries, each mapping parameter names to lists
+ of candidate values,
+ e.g. `[{"strides": [[1,3,3,1], [1,2,2,1]], "foo": [1.2, 1.3]}]`
+ make_graph: function that takes current parameters and returns tuple
+ `[input1, input2, ...], [output1, output2, ...]`
+ make_test_inputs: function taking `curr_params`, `session`, `input_tensors`,
+ `output_tensors` and returns tuple `(input_values, output_values)`.
+ drop_control_dependency: whether to ignore control dependency nodes.
+ Raises:
+ RuntimeError: if there are toco errors that can't be ignored.
+ """
+
+ # TODO(aselle): Make this allow multiple inputs and outputs.
+ archive = zipfile.PyZipFile(zip_path, "w")
+ zip_manifest = []
+ convert_report = []
+ toco_errors = 0
+ for parameters in test_parameters:
+ keys = parameters.keys()
+ for curr in itertools.product(*parameters.values()):
+ label = zip_path.replace(".zip", "") + (",".join(
+ "%s=%r" % z for z in sorted(zip(keys, curr))).replace(" ", ""))
+ if label[0] == "/":
+ label = label[1:]
+ param_dict = dict(zip(keys, curr))
+
+ def build_example(label, param_dict_real):
+ """Build the model with parameter values set in param_dict_real.
+
+ Args:
+ label: Label of the model (i.e. the filename in the zip).
+ param_dict_real: Parameter dictionary (arguments to the factories
+ make_graph and make_test_inputs)
+ Returns:
+ (tflite_model_binary, report) where tflite_model_binary is the
+ serialized flatbuffer as a string and report is a dictionary with
+ keys `toco_log` (log of toco conversion), `tf_log` (log of tf
+ conversion), `toco` (a string of success status of the conversion),
+ `tf` (a string success status of the conversion).
+ """
+
+ np.random.seed(RANDOM_SEED)
+ report = {"toco": report_lib.NOTRUN, "tf": report_lib.FAILED}
+
+ # Build graph
+ report["tf_log"] = ""
+ report["toco_log"] = ""
+ tf.reset_default_graph()
+
+ try:
+ inputs, outputs = make_graph(param_dict_real)
+ except (tf.errors.UnimplementedError, tf.errors.InvalidArgumentError,
+ ValueError):
+ report["tf_log"] += traceback.format_exc()
+ return None, report
+
+ sess = tf.Session()
+ try:
+ baseline_inputs, baseline_outputs = (make_test_inputs(
+ param_dict_real, sess, inputs, outputs))
+ except (tf.errors.UnimplementedError, tf.errors.InvalidArgumentError,
+ ValueError):
+ report["tf_log"] += traceback.format_exc()
+ return None, report
+ report["toco"] = report_lib.FAILED
+ report["tf"] = report_lib.SUCCESS
+
+ # Convert graph to toco
+ tflite_model_binary, toco_log = toco_convert(
+ sess.graph_def.SerializeToString(),
+ [(input_tensor.name.split(":")[0], input_tensor.get_shape(),
+ input_tensor.dtype) for input_tensor in inputs],
+ [out.name.split(":")[0]
+ for out in outputs], drop_control_dependency)
+ report["toco"] = (report_lib.SUCCESS if tflite_model_binary is not None
+ else report_lib.FAILED)
+ report["toco_log"] = toco_log
+
+ if FLAGS.save_graphdefs:
+ archive.writestr(label + ".pb",
+ text_format.MessageToString(sess.graph_def),
+ zipfile.ZIP_DEFLATED)
+
+ if tflite_model_binary:
+ archive.writestr(label + ".bin", tflite_model_binary,
+ zipfile.ZIP_DEFLATED)
+ example = {"inputs": baseline_inputs, "outputs": baseline_outputs}
+
+ example_fp = StringIO()
+ write_examples(example_fp, [example])
+ archive.writestr(label + ".inputs",
+ example_fp.getvalue(), zipfile.ZIP_DEFLATED)
+
+ example_fp2 = StringIO()
+ write_test_cases(example_fp2, label + ".bin", [example])
+ archive.writestr(label + "_tests.txt",
+ example_fp2.getvalue(), zipfile.ZIP_DEFLATED)
+
+ zip_manifest.append(label + "\n")
+
+ return tflite_model_binary, report
+
+ _, report = build_example(label, param_dict)
+
+ if report["toco"] == report_lib.FAILED:
+ ignore_error = False
+ if not FLAGS.known_bugs_are_errors:
+ for pattern, bug_number in KNOWN_BUGS.items():
+ if re.search(pattern, label):
+ print("Ignored TOCO error due to bug %s" % bug_number)
+ ignore_error = True
+ if not ignore_error:
+ toco_errors += 1
+ print("-----------------\ntoco error!\n%s\n-----------------\n" %
+ report["toco_log"])
+
+ convert_report.append((param_dict, report))
+ report_io = StringIO()
+ report_lib.make_report_table(report_io, zip_path, convert_report)
+ archive.writestr("report.html", report_io.getvalue())
+
+ archive.writestr("manifest.txt", "".join(zip_manifest), zipfile.ZIP_DEFLATED)
+
+ # Log statistics of what succeeded
+ total_conversions = len(convert_report)
+ tf_success = sum(1 for x in convert_report
+ if x[1]["tf"] == report_lib.SUCCESS)
+ toco_success = sum(1 for x in convert_report
+ if x[1]["toco"] == report_lib.SUCCESS)
+ percent = 0
+ if tf_success > 0:
+ percent = float(toco_success) / float(tf_success) * 100.
+  tf.logging.info(("Archive %s considered %d graphs, %d TF evaluated graphs "
+                   "and %d TOCO converted graphs (%.1f%%)"), zip_path,
+                  total_conversions, tf_success, toco_success, percent)
+
+ if not FLAGS.ignore_toco_errors and toco_errors > 0:
+ raise RuntimeError(
+ "Found %d errors while generating toco models" % toco_errors)
+
+
+def make_pool_tests(pool_op_in):
+  """Make a set of tests to do pooling (with the given pooling op).
+
+ Args:
+    pool_op_in: TensorFlow pooling operation to test, e.g. `tf.nn.avg_pool`.
+
+ Returns:
+    A function representing the true generator (with `pool_op_in` curried in).
+ """
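+  # For example, `make_pool_tests(tf.nn.avg_pool)` returns a function that,
+  # given a zip path (as in the dispatch table in main()), writes the
+  # corresponding pooling test archive.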
+
+ pool_op = pool_op_in
+
+ def f(zip_path):
+ """Actual function that generates examples.
+
+ Args:
+ zip_path: path to write zip to.
+ """
+
+    # Choose a set of parameters
+ test_parameters = [{
+ "ksize": [[2, 1, 1, 2], [1, 1, 1, 1], [1, 1, 2, 1], [1, 10, 11, 1]],
+ "strides": [[2, 1, 1, 2], [1, 1, 1, 1], [1, 1, 2, 1], [1, 10, 11, 1]],
+ # TODO(aselle): should add in a degenerate shape (e.g. [1, 0, 1, 1]).
+ "input_shape": [[], [1, 1, 1, 1], [1, 15, 14, 1], [3, 15, 14, 3]],
+ "padding": ["SAME", "VALID"],
+ "data_format": ["NHWC"], # TODO(aselle): NCHW would be good
+ }]
+
+ def build_graph(parameters):
+ input_tensor = tf.placeholder(
+ dtype=tf.float32, name="input", shape=parameters["input_shape"])
+ out = pool_op(
+ input_tensor,
+ ksize=parameters["ksize"],
+ strides=parameters["strides"],
+ data_format=parameters["data_format"],
+ padding=parameters["padding"])
+ return [input_tensor], [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ input_values = create_tensor_data(tf.float32, parameters["input_shape"])
+ return [input_values], sess.run(
+ outputs, feed_dict=dict(zip(inputs, [input_values])))
+
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+ return f
+
+
+def make_relu_tests(zip_path):
+ """Make a set of tests to do relu."""
+
+  # Choose a set of parameters
+ test_parameters = [{
+ "input_shape": [[], [1], [2, 3], [1, 1, 1, 1], [1, 3, 4, 3],
+ [3, 15, 14, 3], [3, 1, 2, 4, 6], [2, 2, 3, 4, 5, 6]],
+ }]
+
+ def build_graph(parameters):
+ input_tensor = tf.placeholder(
+ dtype=tf.float32, name="input", shape=parameters["input_shape"])
+ out = tf.nn.relu(input_tensor)
+ return [input_tensor], [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ input_values = create_tensor_data(
+ np.float32, parameters["input_shape"], min_value=-4, max_value=10)
+ return [input_values], sess.run(
+ outputs, feed_dict=dict(zip(inputs, [input_values])))
+
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+
+def make_relu1_tests(zip_path):
+ """Make a set of tests to do relu1."""
+
+  # Choose a set of parameters
+ test_parameters = [{
+ "input_shape": [[], [1, 1, 1, 1], [1, 3, 4, 3], [3, 15, 14, 3],
+ [3, 1, 2, 4, 6], [2, 2, 3, 4, 5, 6]],
+ }]
+
+ def build_graph(parameters):
+ input_tensor = tf.placeholder(
+ dtype=tf.float32, name="input", shape=parameters["input_shape"])
+ # Note that the following is not supported:
+ # out = tf.maximum(-1.0, tf.minimum(input_tensor, 1.0))
+ out = tf.minimum(1.0, tf.maximum(input_tensor, -1.0))
+ return [input_tensor], [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ input_values = create_tensor_data(
+ np.float32, parameters["input_shape"], min_value=-3, max_value=10)
+ return [input_values], sess.run(
+ outputs, feed_dict=dict(zip(inputs, [input_values])))
+
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+
+def make_relu6_tests(zip_path):
+ """Make a set of tests to do relu6."""
+
+  # Choose a set of parameters
+ test_parameters = [{
+ "input_shape": [[], [1, 1, 1, 1], [1, 3, 4, 3], [3, 15, 14, 3],
+ [3, 1, 2, 4, 6], [2, 2, 3, 4, 5, 6]],
+ }]
+
+ def build_graph(parameters):
+ input_tensor = tf.placeholder(
+ dtype=tf.float32, name="input", shape=parameters["input_shape"])
+    out = tf.nn.relu6(input_tensor)
+ return [input_tensor], [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ input_values = create_tensor_data(
+ np.float32, parameters["input_shape"], min_value=-3, max_value=10)
+ return [input_values], sess.run(
+ outputs, feed_dict=dict(zip(inputs, [input_values])))
+
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+
+# This function tests various TensorFlow functions that generate a Const op,
+# including `tf.ones`, `tf.zeros` and random functions.
+def make_constant_tests(zip_path):
+ """Make a set of tests to do constant ops."""
+
+ test_parameters = [{
+ "dtype": [tf.float32, tf.int32],
+ "input_shape": [[1], [2], [1, 1, 1, 1], [2, 2, 2, 2]],
+ }]
+
+ def build_graph(parameters):
+    # Since TOCO & TFLite can't handle a graph containing only a single
+    # constant op, this test adds a zero-valued placeholder tensor to the
+    # constant op tensor.
+ input1 = tf.placeholder(dtype=parameters["dtype"], name="input1",
+ shape=parameters["input_shape"])
+ out = tf.ones(parameters["input_shape"], dtype=parameters["dtype"]) + input1
+ return [input1], [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ input1 = np.zeros(parameters["input_shape"],
+ dtype=_TF_TYPE_INFO[parameters["dtype"]][0])
+ return [input1], sess.run(outputs, feed_dict={inputs[0]: input1})
+
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+
+def make_add_tests(zip_path):
+ """Make a set of tests to do add with and without broadcast."""
+
+ # These parameters are split because we don't support broadcasting.
+ test_parameters = [{
+ "dtype": [tf.float32, tf.int32],
+ "input_shape_1": [[1, 3, 4, 3]],
+ "input_shape_2": [[1, 3, 4, 3]],
+ }, {
+ "dtype": [tf.float32],
+ "input_shape_1": [[5]],
+ "input_shape_2": [[5]],
+ }, {
+ "dtype": [tf.float32],
+ "input_shape_1": [[1, 3, 4, 3]],
+ "input_shape_2": [[3]],
+ }]
+
+ def build_graph(parameters):
+ input1 = tf.placeholder(dtype=parameters["dtype"], name="input1",
+ shape=parameters["input_shape_1"])
+ input2 = tf.placeholder(dtype=parameters["dtype"], name="input2",
+ shape=parameters["input_shape_2"])
+ out = tf.add(input1, input2)
+ return [input1, input2], [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ input1 = create_tensor_data(parameters["dtype"],
+ parameters["input_shape_1"])
+ input2 = create_tensor_data(parameters["dtype"],
+ parameters["input_shape_2"])
+ return [input1, input2], sess.run(
+ outputs, feed_dict={
+ inputs[0]: input1,
+ inputs[1]: input2
+ })
+
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+
+def make_mul_tests(zip_path):
+ """Make a set of tests to do mul with and without broadcast."""
+
+ # These parameters are split because we don't support broadcasting.
+ test_parameters = [{
+ "dtype": [tf.float32, tf.int32],
+ "input_shape_1": [[1, 3, 4, 3]],
+ "input_shape_2": [[1, 3, 4, 3]],
+ }, {
+ "dtype": [tf.float32],
+ "input_shape_1": [[5]],
+ "input_shape_2": [[5]],
+ }, {
+ "dtype": [tf.float32],
+ "input_shape_1": [[1, 3, 4, 3]],
+ "input_shape_2": [[3]],
+ }]
+
+ def build_graph(parameters):
+ input1 = tf.placeholder(dtype=parameters["dtype"], name="input1",
+ shape=parameters["input_shape_1"])
+ input2 = tf.placeholder(dtype=parameters["dtype"], name="input2",
+ shape=parameters["input_shape_2"])
+ out = tf.multiply(input1, input2)
+ return [input1, input2], [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ input1 = create_tensor_data(parameters["dtype"],
+ parameters["input_shape_1"])
+ input2 = create_tensor_data(parameters["dtype"],
+ parameters["input_shape_2"])
+ return [input1, input2], sess.run(
+ outputs, feed_dict={inputs[0]: input1,
+ inputs[1]: input2})
+
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+
+def make_global_batch_norm_tests(zip_path):
+ """Make a set of tests to do batch_norm_with_global_normalization."""
+
+ test_parameters = [{
+ "dtype": [tf.float32],
+ "input_shape": [[1, 1, 6, 2], [3, 4, 5, 4]],
+ "epsilon": [0.1, 0.0001],
+ "scale_after": [True, False],
+ }]
+
+ def build_graph(parameters):
+ """Build the global batch norm testing graph."""
+ input_shape = parameters["input_shape"]
+ scale_shape = input_shape[3]
+
+ scale = create_tensor_data(parameters["dtype"], scale_shape)
+ offset = create_tensor_data(parameters["dtype"], scale_shape)
+ mean = create_tensor_data(parameters["dtype"], scale_shape)
+ variance = create_tensor_data(parameters["dtype"], scale_shape)
+
+ x = create_tensor_data(parameters["dtype"], parameters["input_shape"])
+ x_norm = tf.nn.batch_norm_with_global_normalization(
+ x, mean, variance, scale, offset,
+ parameters["epsilon"], parameters["scale_after"])
+
+ input_tensor = tf.placeholder(dtype=parameters["dtype"], name="input",
+ shape=parameters["input_shape"])
+ out = tf.add(input_tensor, x_norm)
+ return [input_tensor], [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ input_value = create_tensor_data(parameters["dtype"],
+ parameters["input_shape"])
+ return [input_value], sess.run(
+ outputs, feed_dict=dict(zip(inputs, [input_value])))
+
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+
+def make_fused_batch_norm_tests(zip_path):
+ """Make a set of tests to do fused_batch_norm."""
+
+ test_parameters = [{
+ "dtype": [tf.float32],
+ "input_shape": [[1, 1, 6, 2]],
+ "epsilon": [0.001, 0.1],
+ }]
+
+ def build_graph(parameters):
+ """Build the testing graph for fused batch normalization."""
+ input_shape = parameters["input_shape"]
+ scale_shape = input_shape[3]
+
+ scale = create_tensor_data(parameters["dtype"], scale_shape)
+ offset = create_tensor_data(parameters["dtype"], scale_shape)
+ mean = create_tensor_data(parameters["dtype"], scale_shape)
+ variance = create_tensor_data(parameters["dtype"], scale_shape)
+
+ x = create_tensor_data(parameters["dtype"], parameters["input_shape"])
+ [x_norm, _, _] = tf.nn.fused_batch_norm(
+ x, scale, offset, mean, variance,
+ parameters["epsilon"], data_format="NHWC", is_training=False)
+
+ input_tensor = tf.placeholder(dtype=parameters["dtype"], name="input",
+ shape=parameters["input_shape"])
+ out = tf.add(input_tensor, x_norm)
+ return [input_tensor], [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ input_value = create_tensor_data(parameters["dtype"],
+ parameters["input_shape"])
+ return [input_value], sess.run(
+ outputs, feed_dict=dict(zip(inputs, [input_value])))
+
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+
+def make_conv_tests(zip_path):
+ """Make a set of tests to do convolution."""
+
+ test_parameters = [{
+ "input_shape": [[1, 3, 4, 3]],
+ "filter_shape": [[1, 1, 3, 2]],
+ "strides": [[1, 1, 1, 1], [1, 2, 3, 1]],
+ "padding": ["SAME", "VALID"],
+ "data_format": ["NHWC"], # TODO(aselle): NCHW would be good
+ }, {
+ "input_shape": [[2, 14, 14, 2]],
+ "filter_shape": [[6, 6, 2, 2]],
+ "strides": [[1, 1, 1, 1], [1, 2, 3, 1]],
+ "padding": ["SAME", "VALID"],
+ "data_format": ["NHWC"], # TODO(aselle): NCHW would be good
+ }]
+
+ def build_graph(parameters):
+ input_tensor = tf.placeholder(
+ dtype=tf.float32, name="input", shape=parameters["input_shape"])
+ filter_values = create_tensor_data(np.float32, parameters["filter_shape"])
+ out = tf.nn.conv2d(input_tensor, filter_values,
+ strides=parameters["strides"],
+ padding=parameters["padding"],
+ data_format=parameters["data_format"])
+ return [input_tensor], [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ input_values = create_tensor_data(np.float32, parameters["input_shape"])
+ return [input_values], sess.run(
+ outputs, feed_dict=dict(zip(inputs, [input_values])))
+
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+
+def make_depthwiseconv_tests(zip_path):
+  """Make a set of tests to do depthwise convolution."""
+
+  # TensorFlow only supports equal strides
+ test_parameters = [{
+ "input_shape": [[1, 3, 4, 3], [1, 10, 10, 3]],
+ "filter_size": [[1, 1], [1, 2], [3, 3]],
+ "strides": [[1, 1, 1, 1], [1, 3, 3, 1]],
+ "channel_multiplier": [1, 2],
+ "rate": [[1, 1]],
+ "padding": ["SAME", "VALID"],
+ "data_format": ["NHWC"],
+ }, {
+ "input_shape": [[1, 3, 4, 3]],
+ "filter_size": [[1, 1]],
+ "strides": [[1, 1, 2, 1]], # TF needs [1, x, x, 1]
+ "channel_multiplier": [2],
+ "rate": [[2, 2]], # Only [1, 1] is supported
+ "padding": ["SAME"],
+ "data_format": ["NHWC"],
+ }]
+
+ def build_graph(parameters):
+ """Build a depthwise conv graph given `parameters`."""
+ input_shape = parameters["input_shape"]
+ filter_size = parameters["filter_size"]
+ input_tensor = tf.placeholder(
+ dtype=tf.float32, name="input", shape=input_shape)
+ filter_shape = filter_size + [
+ input_shape[3], parameters["channel_multiplier"]]
+ filter_values = create_tensor_data(np.float32, filter_shape)
+ out = tf.nn.depthwise_conv2d(
+ input_tensor, filter_values,
+ strides=parameters["strides"],
+ rate=parameters["rate"],
+ padding=parameters["padding"],
+ data_format=parameters["data_format"])
+ return [input_tensor], [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ input_values = create_tensor_data(np.float32, parameters["input_shape"])
+ return [input_values], sess.run(
+ outputs, feed_dict=dict(zip(inputs, [input_values])))
+
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+
+def make_concatenation_tests(zip_path):
+  """Make a set of tests to do concatenation."""
+
+ test_parameters = [{
+ "base_shape": [[1, 3, 4, 3], [3, 4]],
+ "num_tensors": [1, 2, 3, 4, 5, 6],
+ "axis": [0, 1, 2, 3],
+ }]
+
+ def get_shape(parameters, delta):
+ """Return a tweaked version of 'base_shape'."""
+ axis = parameters["axis"]
+ shape = parameters["base_shape"][:]
+ if axis < len(shape):
+ shape[axis] += delta
+ return shape
+
+ def build_graph(parameters):
+ all_tensors = []
+ for n in range(0, parameters["num_tensors"]):
+ input_tensor = tf.placeholder(dtype=tf.float32, name=("input%d" % n),
+ shape=get_shape(parameters, n))
+ all_tensors.append(input_tensor)
+ out = tf.concat(all_tensors, parameters["axis"])
+ return all_tensors, [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ all_values = []
+ for n in range(0, parameters["num_tensors"]):
+ input_values = create_tensor_data(np.float32,
+ get_shape(parameters, n))
+ all_values.append(input_values)
+ return all_values, sess.run(
+ outputs, feed_dict=dict(zip(inputs, all_values)))
+
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+
+def make_fully_connected_tests(zip_path):
+ """Make a set of tests to do fully_connected."""
+
+ test_parameters = [{
+ "shape1": [[3, 3]],
+ "shape2": [[3, 3]],
+ "transpose_a": [True, False],
+ "transpose_b": [True, False],
+ }, {
+ "shape1": [[4, 4], [1, 4], [4]],
+ "shape2": [[4, 4], [4, 1], [4]],
+ "transpose_a": [False],
+ "transpose_b": [False],
+ }, {
+ "shape1": [[40, 37]],
+ "shape2": [[37, 40]],
+ "transpose_a": [False],
+ "transpose_b": [False],
+
+ }]
+
+ def build_graph(parameters):
+ input_tensor1 = tf.placeholder(dtype=tf.float32, name="input1",
+ shape=parameters["shape1"])
+ input_tensor2 = create_tensor_data(np.float32, parameters["shape2"])
+ out = tf.matmul(input_tensor1, input_tensor2,
+ transpose_a=parameters["transpose_a"],
+ transpose_b=parameters["transpose_b"])
+ return [input_tensor1], [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ input_values1 = create_tensor_data(np.float32, shape=parameters["shape1"])
+ return [input_values1], sess.run(
+ outputs, feed_dict=dict(zip(inputs, [input_values1])))
+
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+
+def make_l2norm_tests(zip_path):
+ """Make a set of tests to do l2norm."""
+
+  # Choose a set of parameters
+ test_parameters = [{
+ "input_shape": [[5, 7], [1, 1, 1, 1], [1, 3, 4, 3], [3, 15, 14, 3],
+ [3, 1, 2, 4, 6], [2, 2, 3, 4, 5, 6]],
+ "dim": [0, 1, 2, 3, [2, 3], -2],
+ "epsilon": [None, 1e-12, 1e-3],
+ }]
+
+ def build_graph(parameters):
+ input_tensor = tf.placeholder(
+ dtype=tf.float32, name="input", shape=parameters["input_shape"])
+ if parameters["epsilon"]:
+ out = tf.nn.l2_normalize(
+ input_tensor, parameters["dim"], epsilon=parameters["epsilon"])
+ else:
+ out = tf.nn.l2_normalize(input_tensor, parameters["dim"])
+ return [input_tensor], [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ input_values = create_tensor_data(
+ np.float32, parameters["input_shape"], min_value=-4, max_value=10)
+ return [input_values], sess.run(
+ outputs, feed_dict=dict(zip(inputs, [input_values])))
+
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+
+def make_local_response_norm_tests(zip_path):
+ """Make a set of tests to do local_response_norm."""
+
+  # Choose a set of parameters
+ test_parameters = [{
+ "input_shape": [[1, 1, 1, 1], [1, 3, 4, 3], [3, 15, 14, 3]],
+ "depth_radius": [None, 0, 1, 3, 4, 5],
+ "bias": [None, 0.1, 0.3, -0.1],
+ "alpha": [None, 1, 2, -3],
+ "beta": [None, 0.5, 0.25, 2],
+ }]
+
+ def build_graph(parameters):
+ input_tensor = tf.placeholder(
+ dtype=tf.float32, name="input", shape=parameters["input_shape"])
+ out = tf.nn.local_response_normalization(
+ input_tensor, depth_radius=parameters["depth_radius"],
+ bias=parameters["bias"], alpha=parameters["alpha"],
+ beta=parameters["beta"])
+ return [input_tensor], [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ input_values = create_tensor_data(
+ np.float32, parameters["input_shape"], min_value=-4, max_value=10)
+ return [input_values], sess.run(
+ outputs, feed_dict=dict(zip(inputs, [input_values])))
+
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+
+def make_reshape_tests(zip_path):
+ """Make a set of tests to do reshape."""
+
+  # All shapes below are suitable for tensors with 420 elements.
+ test_parameters = [{
+ "dtype": [tf.float32, tf.int32],
+ "input_shape": [[3, 4, 5, 7], [4, 105], [21, 5, 2, 2], [420]],
+ "output_shape": [[15, 28], [420], [1, -1, 5, 7], [-1]],
+ }]
+
+ def build_graph(parameters):
+ input_tensor = tf.placeholder(dtype=parameters["dtype"], name="input",
+ shape=parameters["input_shape"])
+ out = tf.reshape(input_tensor, shape=parameters["output_shape"])
+ return [input_tensor], [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ input_values = create_tensor_data(parameters["dtype"],
+ parameters["input_shape"])
+ return [input_values], sess.run(
+ outputs, feed_dict=dict(zip(inputs, [input_values])))
+
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+
+def make_resize_bilinear_tests(zip_path):
+ """Make a set of tests to do resize_bilinear."""
+
+ test_parameters = [{
+ "dtype": [tf.float32, tf.int32],
+ "input_shape": [[1, 3, 4, 3], [1, 10, 2, 1]],
+ "size": [[1, 1], [4, 3], [2, 2], [5, 6]],
+ "align_corners": [None, True, False],
+ }]
+
+ def build_graph(parameters):
+ input_tensor = tf.placeholder(dtype=parameters["dtype"], name="input",
+ shape=parameters["input_shape"])
+ out = tf.image.resize_bilinear(input_tensor, size=parameters["size"],
+ align_corners=parameters["align_corners"])
+ return [input_tensor], [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ input_values = create_tensor_data(parameters["dtype"],
+ parameters["input_shape"])
+ return [input_values], sess.run(
+ outputs, feed_dict=dict(zip(inputs, [input_values])))
+
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+
+def make_sigmoid_tests(zip_path):
+ """Make a set of tests to do sigmoid."""
+
+ test_parameters = [{
+ "dtype": [tf.float32],
+ "input_shape": [[1, 3, 4, 3], [4], [], [1, 2, 3, 4, 5, 6]],
+ }]
+
+ def build_graph(parameters):
+ input_tensor = tf.placeholder(dtype=parameters["dtype"], name="input",
+ shape=parameters["input_shape"])
+ out = tf.sigmoid(input_tensor)
+ return [input_tensor], [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ input_values = create_tensor_data(parameters["dtype"],
+ parameters["input_shape"])
+ return [input_values], sess.run(
+ outputs, feed_dict=dict(zip(inputs, [input_values])))
+
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+
+def make_softmax_tests(zip_path):
+ """Make a set of tests to do softmax."""
+
+ test_parameters = [{
+ "dtype": [tf.float32],
+ "input_shape": [[1, 3, 4, 3], [2, 3]],
+ "dim": [-1, 0],
+ }, {
+ "dtype": [tf.float32],
+ "input_shape": [[4, 7]],
+ "dim": [-1, 1],
+ }]
+
+ def build_graph(parameters):
+ input_tensor = tf.placeholder(dtype=parameters["dtype"], name="input",
+ shape=parameters["input_shape"])
+ out = tf.nn.softmax(input_tensor, dim=parameters["dim"])
+ return [input_tensor], [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ input_values = create_tensor_data(parameters["dtype"],
+ parameters["input_shape"])
+ return [input_values], sess.run(
+ outputs, feed_dict=dict(zip(inputs, [input_values])))
+
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+
+def make_space_to_depth_tests(zip_path):
+ """Make a set of tests to do space_to_depth."""
+
+ test_parameters = [{
+ "dtype": [tf.float32, tf.float16, tf.int32, tf.uint8, tf.int64],
+ "input_shape": [[2, 12, 24, 1]],
+ "block_size": [2, 3, 4],
+ }]
+
+ def build_graph(parameters):
+ input_tensor = tf.placeholder(dtype=parameters["dtype"], name="input",
+ shape=parameters["input_shape"])
+ out = tf.space_to_depth(input_tensor, block_size=parameters["block_size"])
+ return [input_tensor], [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ input_values = create_tensor_data(parameters["dtype"],
+ parameters["input_shape"])
+ return [input_values], sess.run(
+ outputs, feed_dict=dict(zip(inputs, [input_values])))
+
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+
+def make_l2_pool(input_tensor, ksize, strides, padding, data_format):
+  """Given an input, perform a sequence of TensorFlow ops to produce l2pool."""
+ return tf.sqrt(tf.nn.avg_pool(
+ tf.square(input_tensor), ksize=ksize, strides=strides,
+ padding=padding, data_format=data_format))
+
+
+# Toco binary path provided by the generate rule.
+bin_path = None
+
+
+def main(unused_args):
+ global bin_path
+ def mkdir_if_not_exist(x):
+ if not os.path.isdir(x):
+ os.mkdir(x)
+ if not os.path.isdir(x):
+ raise RuntimeError("Failed to create dir %r" % x)
+
+ if FLAGS.type == "zipped":
+ opstest_path = os.path.join(FLAGS.output_path)
+ mkdir_if_not_exist(opstest_path)
+ def _path(filename):
+ return os.path.join(opstest_path, filename)
+
+ dispatch = {
+ "control_dep.zip": make_control_dep_tests,
+ "add.zip": make_add_tests,
+ "conv.zip": make_conv_tests,
+ "constant.zip": make_constant_tests,
+ "depthwiseconv.zip": make_depthwiseconv_tests,
+ "concat.zip": make_concatenation_tests,
+ "fully_connected.zip": make_fully_connected_tests,
+ "global_batch_norm.zip": make_global_batch_norm_tests,
+ "fused_batch_norm.zip": make_fused_batch_norm_tests,
+ "l2norm.zip": make_l2norm_tests,
+ "local_response_norm.zip": make_local_response_norm_tests,
+ "mul.zip": make_mul_tests,
+ "relu.zip": make_relu_tests,
+ "relu1.zip": make_relu1_tests,
+ "relu6.zip": make_relu6_tests,
+ "l2_pool.zip": make_pool_tests(make_l2_pool),
+ "avg_pool.zip": make_pool_tests(tf.nn.avg_pool),
+ "max_pool.zip": make_pool_tests(tf.nn.max_pool),
+ "reshape.zip": make_reshape_tests,
+ "resize_bilinear.zip": make_resize_bilinear_tests,
+ "sigmoid.zip": make_sigmoid_tests,
+ "softmax.zip": make_softmax_tests,
+ "space_to_depth.zip": make_space_to_depth_tests,
+ }
+ out = FLAGS.zip_to_output
+ bin_path = FLAGS.toco
+ if out in dispatch:
+ dispatch[out](_path(out))
+ else:
+ raise RuntimeError("Invalid zip to output %r" % out)
+
+ else:
+ raise RuntimeError("Invalid argument for type of generation.")
+
+
+if __name__ == "__main__":
+ FLAGS, unparsed = parser.parse_known_args()
+
+ if unparsed:
+    print("Usage: %s <path out> zipped <zip file to generate>" % sys.argv[0])
+ else:
+ tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
diff --git a/tensorflow/contrib/lite/testing/generate_examples_report.py b/tensorflow/contrib/lite/testing/generate_examples_report.py
new file mode 100644
index 0000000000..7bcf8cd86a
--- /dev/null
+++ b/tensorflow/contrib/lite/testing/generate_examples_report.py
@@ -0,0 +1,125 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Make HTML tables that report where TF and TOCO failed to convert models.
+
+This is primarily used by generate_examples.py. See it or
+`make_report_table` for more details on usage.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import cgi
+import json
+
+FAILED = "FAILED"
+SUCCESS = "SUCCESS"
+NOTRUN = "NOTRUN"
+
+
+def make_report_table(fp, title, reports):
+ """Make an HTML report of the success/failure reports.
+
+ Args:
+ fp: File-like object in which to put the html.
+    title: Title of the zip file this pertains to.
+    reports: a list of conversion attempts, each a tuple of
+      (report_args, report_vals), e.g.
+      ({"shape": [1,2,3], "type": "tf.float32"},
+       {"tf": "SUCCESS", "toco": "FAILED", "toco_log": "Unsupported type.",
+        "tf_log": ""})
+ """
+  # Sort reports by TOCO failure and then by TF failure (reversed).
+ reports.sort(key=lambda x: x[1]["toco"], reverse=False)
+ reports.sort(key=lambda x: x[1]["tf"], reverse=True)
+ def result_cell(x, row, col):
+ """Produce a cell with the condition string `x`."""
+ s = cgi.escape(repr(x), quote=True)
+ color = "#44ff44" if x == SUCCESS else (
+ "#ff4444" if x == FAILED else "#eeeeee")
+ handler = "ShowLog(%d, %d)" % (row, col)
+ fp.write("<td style='background-color: %s' onclick='%s'>%s</td>\n" % (
+ color, handler, s))
+
+ fp.write("""<html>
+<head>
+<title>tflite report</title>
+<style>
+body { font-family: Arial; }
+th { background-color: #555555; color: #eeeeee; }
+td { vertical-align: top; }
+td.horiz {width: 50%;}
+pre { white-space: pre-wrap; word-break: keep-all; }
+table {width: 100%;}
+</style>
+</head>
+""")
+  # Write the log data to a JavaScript variable and also make a function
+  # in JavaScript to show the log when an item is clicked.
+ fp.write("<script> \n")
+ fp.write("""
+function ShowLog(row, col) {
+
+var log = document.getElementById("log");
+log.innerHTML = "<pre>" + data[row][col] + "</pre>";
+}
+""")
+ fp.write("var data = \n")
+ fp.write(json.dumps([[cgi.escape(x[1]["tf_log"], quote=True),
+ cgi.escape(x[1]["toco_log"], quote=True)]
+ for x in reports]))
+ fp.write(";</script>\n")
+
+ # Write the main table and use onclick on the items that have log items.
+ fp.write("""
+<body>
+<h1>TOCO Conversion</h1>
+<h2>%s</h2>
+""" % title)
+
+ # Get a list of keys that are in any of the records.
+ param_keys = {}
+ for params, _ in reports:
+ for k in params.keys():
+ param_keys[k] = True
+
+ fp.write("<table>\n")
+ fp.write("<tr><td class='horiz'>\n")
+ fp.write("<div style='height:1000px; overflow:auto'>\n")
+ fp.write("<table>\n")
+ fp.write("<tr>\n")
+ for p in param_keys:
+ fp.write("<th>%s</th>\n" % cgi.escape(p, quote=True))
+ fp.write("<th>TensorFlow</th>\n")
+ fp.write("<th>TOCO</th>\n")
+ fp.write("</tr>\n")
+ for idx, (params, vals) in enumerate(reports):
+ fp.write("<tr>\n")
+ for p in param_keys:
+ fp.write(" <td>%s</td>\n" % cgi.escape(repr(params[p]), quote=True))
+
+ result_cell(vals["tf"], idx, 0)
+ result_cell(vals["toco"], idx, 1)
+ fp.write("</tr>\n")
+ fp.write("</table>\n")
+ fp.write("</div>\n")
+ fp.write("</td>\n")
+ fp.write("<td class='horiz' id='log'></td></tr>\n")
+ fp.write("</table>\n")
+ fp.write("<script>\n")
+ fp.write("</script>\n")
+ fp.write("""
+ </body>
+ </html>
+ """)
diff --git a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
new file mode 100644
index 0000000000..e7df97ee54
--- /dev/null
+++ b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
@@ -0,0 +1,279 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <cstdarg>
+#include <cstdio>
+#include <cstdlib>
+#include <fstream>
+#include <map>
+#include <sstream>
+#include <gtest/gtest.h>
+#include "re2/re2.h"
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/model.h"
+#include "tensorflow/contrib/lite/testing/parse_testdata.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/subprocess.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/util/command_line_flags.h"
+
+namespace {
+bool FLAGS_ignore_known_bugs = true;
+} // namespace
+
+namespace tflite {
+namespace testing {
+
+// TensorFlow system environment used for file system calls.
+tensorflow::Env* env = tensorflow::Env::Default();
+
+// List of tests that are expected to fail when
+// --test_arg=--ignore_known_bugs=false
+// Key is a regular expression matched against the test name and the value is
+// a bug number.
+// TODO(ahentz): make sure we clean this list up frequently.
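+// Note: test names are built by concatenating the zip base name directly with
+// the "key=value" parameter string, so a pattern like R"(addd.*int32)" matches
+// a label such as "adddtype=tf.int32,input_shape_1=[1,3,4,3],..." (the extra
+// "d" is not a typo).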
+std::map<string, string> kBrokenTests = {
+ // Add doesn't support broadcasting.
+ {R"(addd.*input_shape_1=\[1,3,4,3\],input_shape_2=\[3\])", "68500195"},
+ {R"(muld.*input_shape_1=\[1,3,4,3\],input_shape_2=\[3\])", "68500195"},
+
+ // Add only supports float32. (and "constant" tests use Add)
+ {R"(addd.*int32)", "68808744"},
+ {R"(constant.*int32)", "68808744"},
+ {R"(mul.*int32)", "68808744"},
+
+ // Toco or TFLite has a bug to deal with some constant functions with
+ // more than 1 element.
+ {R"(constant.*input_shape=\[(2|2,2,2,2)\])", "68721522"},
+
+ // L2Norm only supports 4D tensors.
+ {R"(l2normdim=.*,epsilon=.*,input_shape=\[.,.\])", "67963684"},
+ {R"(l2normdim=.*,epsilon=.*,input_shape=\[.,.,.,.,.*\])", "67963684"},
+
+ // L2Norm only works for dim=-1.
+ {R"(l2normdim=-2,epsilon=.*,input_shape=\[3,15,14,3\])", "67963812"},
+ {R"(l2normdim=-2,epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"},
+ {R"(l2normdim=2,epsilon=.*,input_shape=\[3,15,14,3\])", "67963812"},
+ {R"(l2normdim=2,epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"},
+ {R"(l2normdim=0,epsilon=.*,input_shape=\[3,15,14,3\])", "67963812"},
+ {R"(l2normdim=0,epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"},
+ {R"(l2normdim=1,epsilon=.*,input_shape=\[3,15,14,3\])", "67963812"},
+ {R"(l2normdim=1,epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"},
+ {R"(l2normdim=\[2,3\],epsilon=.*,input_shape=\[3,15,14,3\])", "67963812"},
+ {R"(l2normdim=\[2,3\],epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"},
+
+    // ResizeBilinear looks completely incompatible with TensorFlow.
+ {R"(resize_bilinear)", "67964336"},
+};
+
+// Allows test data to be unzipped into a temporary directory and makes
+// sure those temporary directories are removed later.
+class ZipEnvironment : public ::testing::Environment {
+ public:
+ ~ZipEnvironment() override {}
+
+ // Delete all temporary directories on teardown.
+ void TearDown() override {
+ for (const auto& dir : temporary_directories_) {
+ tensorflow::int64 undeleted_dirs, undeleted_files;
+ TF_CHECK_OK(
+ env->DeleteRecursively(dir, &undeleted_dirs, &undeleted_files));
+ }
+ temporary_directories_.clear();
+ }
+
+ // Unzip `zip` file into a new temporary directory `out_dir`.
+ tensorflow::Status UnZip(const std::string& zip, std::string* out_dir) {
+ string dir;
+ TF_CHECK_OK(MakeTemporaryDirectory(&dir));
+ tensorflow::SubProcess proc;
+ std::string unzip_binary =
+ "/usr/bin/unzip";
+ proc.SetProgram(unzip_binary, {"unzip", "-d", dir, zip.c_str()});
+ proc.SetChannelAction(tensorflow::CHAN_STDOUT, tensorflow::ACTION_PIPE);
+ proc.SetChannelAction(tensorflow::CHAN_STDERR, tensorflow::ACTION_PIPE);
+ if (!proc.Start())
+ return tensorflow::Status(tensorflow::error::UNKNOWN,
+ "unzip couldn't start");
+ string out, err;
+ int status = proc.Communicate(nullptr, &out, &err);
+ if (WEXITSTATUS(status) == 0) {
+ *out_dir = dir;
+ return tensorflow::Status::OK();
+ } else {
+ return tensorflow::Status(tensorflow::error::UNKNOWN, "unzip failed");
+ }
+ }
+
+ private:
+ // Make a temporary directory and return its name in `temporary`.
+ tensorflow::Status MakeTemporaryDirectory(string* temporary) {
+ if (env->LocalTempFilename(temporary)) {
+ TF_CHECK_OK(env->CreateDir(*temporary));
+ temporary_directories_.push_back(*temporary);
+ return tensorflow::Status::OK();
+ }
+ return tensorflow::Status(tensorflow::error::UNKNOWN,
+ "make temporary directory failed");
+ }
+
+ std::vector<string> temporary_directories_;
+};
+
+// Return the singleton zip_environment.
+ZipEnvironment* zip_environment() {
+ static ZipEnvironment* env = new ZipEnvironment;
+ return env;
+}
+
+// Read the manifest.txt out of the unarchived zip file. Specifically
+// `original_file` is the original zip file for error messages. `dir` is
+// the temporary directory where the zip file has been unarchived and
+// `test_paths` is the list of test prefixes that were in the manifest.
+// Note, it is an error for a manifest to contain no tests.
+tensorflow::Status ReadManifest(const std::string& original_file,
+ const std::string& dir,
+ std::vector<std::string>* test_paths) {
+ // Read the newline delimited list of entries in the manifest.
+ std::ifstream manifest_fp(dir + "/manifest.txt");
+ std::string manifest((std::istreambuf_iterator<char>(manifest_fp)),
+ std::istreambuf_iterator<char>());
+ size_t pos = 0;
+ int added = 0;
+ while (true) {
+ size_t end_pos = manifest.find("\n", pos);
+ if (end_pos == std::string::npos) break;
+ std::string filename = manifest.substr(pos, end_pos - pos);
+ test_paths->push_back(dir + "/" + filename);
+ pos = end_pos + 1;
+ added += 1;
+ }
+ if (!added) {
+ std::string message = "Test had no examples: " + original_file;
+ return tensorflow::Status(tensorflow::error::UNKNOWN, message.c_str());
+ }
+ return tensorflow::Status::OK();
+}
+
+// Get a list of tests from a zip file `zip_file_name`.
+std::vector<std::string> UnarchiveZipAndFindTestNames(
+ const std::string& zip_file_name) {
+ std::string zip_file = ::tensorflow::testing::TensorFlowSrcRoot() +
+ "/contrib/lite/testing/optest/" + zip_file_name;
+ std::string decompress_tmp_dir;
+ TF_CHECK_OK(zip_environment()->UnZip(zip_file, &decompress_tmp_dir));
+ std::vector<std::string> stuff;
+ TF_CHECK_OK(ReadManifest(zip_file, decompress_tmp_dir, &stuff));
+ return stuff;
+}
+
+class OpsTest : public ::testing::TestWithParam<std::string> {};
+
+TEST_P(OpsTest, RunStuff) {
+ std::string test_path = GetParam();
+ std::string tflite_file = test_path + ".bin";
+ std::string tflite_examples = test_path + ".inputs";
+ auto model = tflite::FlatBufferModel::BuildFromFile(tflite_file.c_str());
+ std::unique_ptr<tflite::Interpreter> interpreter;
+
+ tflite::ops::builtin::BuiltinOpResolver builtins;
+ ASSERT_EQ(tflite::InterpreterBuilder(*model, builtins)(&interpreter),
+ kTfLiteOk);
+
+ std::vector<tflite::testing::Example> examples;
+ ASSERT_EQ(tflite::testing::ParseExamples(tflite_examples.c_str(), &examples),
+ kTfLiteOk);
+
+ string bug_number;
+ for (const auto& p : kBrokenTests) {
+ if (RE2::PartialMatch(test_path, p.first)) {
+ bug_number = p.second;
+ }
+ }
+
+ for (const auto& example : examples) {
+ ASSERT_EQ(interpreter->inputs().size(), example.inputs.size());
+ auto result = [&]() {
+ TF_LITE_ENSURE_STATUS(FeedExample(interpreter.get(), example));
+ TF_LITE_ENSURE_STATUS(interpreter->Invoke());
+ TF_LITE_ENSURE_STATUS(CheckOutputs(interpreter.get(), example));
+ return kTfLiteOk;
+ }();
+
+ if (bug_number.empty()) {
+ ASSERT_EQ(result, kTfLiteOk);
+ } else {
+ if (FLAGS_ignore_known_bugs) {
+ ASSERT_EQ(result, kTfLiteError)
+            << "Not failing as expected due to http://b/" << bug_number;
+ } else {
+ ASSERT_EQ(result, kTfLiteOk)
+ << "Possibly due to http://b/" << bug_number;
+ }
+ }
+ }
+}
+
+// Instantiate a test. This assumes `zip_base`.zip is a declared data file
+// of this test.
+#define INSTANTIATE_TESTS(zip_base) \
+ INSTANTIATE_TEST_CASE_P( \
+ zip_base, OpsTest, \
+ ::testing::ValuesIn(UnarchiveZipAndFindTestNames(#zip_base ".zip")));
+
+INSTANTIATE_TESTS(add)
+INSTANTIATE_TESTS(avg_pool)
+INSTANTIATE_TESTS(concat)
+INSTANTIATE_TESTS(constant)
+INSTANTIATE_TESTS(control_dep)
+INSTANTIATE_TESTS(conv)
+INSTANTIATE_TESTS(depthwiseconv)
+INSTANTIATE_TESTS(fully_connected)
+INSTANTIATE_TESTS(fused_batch_norm)
+INSTANTIATE_TESTS(global_batch_norm)
+INSTANTIATE_TESTS(l2norm)
+INSTANTIATE_TESTS(l2_pool)
+INSTANTIATE_TESTS(local_response_norm)
+INSTANTIATE_TESTS(max_pool)
+INSTANTIATE_TESTS(mul)
+INSTANTIATE_TESTS(relu)
+INSTANTIATE_TESTS(relu1)
+INSTANTIATE_TESTS(relu6)
+INSTANTIATE_TESTS(reshape)
+INSTANTIATE_TESTS(resize_bilinear)
+INSTANTIATE_TESTS(sigmoid)
+INSTANTIATE_TESTS(softmax)
+INSTANTIATE_TESTS(space_to_depth)
+
+} // namespace testing
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ ::testing::AddGlobalTestEnvironment(tflite::testing::zip_environment());
+
+ std::vector<tensorflow::Flag> flags = {tensorflow::Flag(
+ "ignore_known_bugs", &FLAGS_ignore_known_bugs,
+ "If a particular model is affected by a known bug, the "
+ "corresponding test should expect the outputs to not match.")};
+ bool success = tensorflow::Flags::Parse(&argc, argv, flags);
+ if (!success || (argc == 2 && !strcmp(argv[1], "--helpfull"))) {
+ fprintf(stderr, "%s", tensorflow::Flags::Usage(argv[0], flags).c_str());
+ return 1;
+ }
+
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/testing/message.cc b/tensorflow/contrib/lite/testing/message.cc
new file mode 100644
index 0000000000..03fae4bb86
--- /dev/null
+++ b/tensorflow/contrib/lite/testing/message.cc
@@ -0,0 +1,96 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/testing/message.h"
+
+#include <stack>
+
+#include "tensorflow/contrib/lite/testing/tokenize.h"
+
+namespace tflite {
+namespace testing {
+
+// A token processor that builds messages and forwards calls to the current
+// message object. It places a new message at the top of the stack when one
+// starts and removes it when it is finished.
+class MessageStack : public TokenProcessor {
+ public:
+ // Start a new MessageStack with the given first_node, which will be used to
+ // process freestanding fields and submessages.
+ explicit MessageStack(Message* first_node) {
+ nodes_.push(first_node);
+ valid_ = true;
+ }
+
+ void ConsumeToken(std::string* token) override {
+ if (!valid_) return;
+ Message* current_node = nodes_.top();
+ if (*token == "{") {
+      // This is the beginning of a new message, named after the previous
+      // token.
+ if (previous_token_.empty()) {
+ valid_ = false;
+ return;
+ }
+ nodes_.push(current_node ? current_node->AddChild(previous_token_)
+ : nullptr);
+ previous_token_.clear();
+ } else if (*token == "}") {
+ // A message is being completed. There should be no previous token. Note
+ // that the top-level message never closes, so we should always have at
+ // least one entry in the stack.
+ if (nodes_.size() == 1 || !previous_token_.empty()) {
+ valid_ = false;
+ return;
+ }
+ if (current_node) {
+ current_node->Finish();
+ }
+ nodes_.pop();
+ } else if (*token == ":") {
+ // We reached the end of the 'key' portion of a field. Store the token
+ // until we have the 'value' portion.
+ if (previous_token_.empty()) {
+ valid_ = false;
+ return;
+ }
+ } else {
+ if (previous_token_.empty()) {
+ previous_token_.swap(*token);
+ } else {
+ // This is the 'value' portion of a field. The previous token is the
+ // 'key'.
+ if (current_node) {
+ current_node->SetField(previous_token_, *token);
+ }
+ previous_token_.clear();
+ }
+ }
+ }
+
+ bool valid() const { return valid_; }
+
+ private:
+ std::stack<Message*> nodes_;
+ std::string previous_token_;
+ bool valid_;
+};
+
+bool Message::Read(std::istream* input, Message* message) {
+ MessageStack stack(message);
+ Tokenize(input, &stack);
+ return stack.valid();
+}
+
+} // namespace testing
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/testing/message.h b/tensorflow/contrib/lite/testing/message.h
new file mode 100644
index 0000000000..78ef7e2cbe
--- /dev/null
+++ b/tensorflow/contrib/lite/testing/message.h
@@ -0,0 +1,82 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_MESSAGE_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_MESSAGE_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+
+namespace tflite {
+namespace testing {
+
+// A Message is a textual protobuf-like structure that looks like:
+// tag {
+// f : "values"
+// child {
+// a : 1
+// }
+// }
+// This class provides the framework for processing messages but does not
+// associate any particular behavior with fields and submessages. In order
+// to properly parse a stream, this class must be derived from.
+class Message {
+ public:
+  // Reads a stream, tokenizes it and creates a new message under the given
+  // top-level message. Returns true if the parsing succeeded.
+ static bool Read(std::istream* input, Message* message);
+
+ Message() {}
+ virtual ~Message() {}
+
+ // Called when a new field is found. For example, when:
+ // f : "values"
+ // is found, it triggers:
+ // SetField("f", "values");
+ virtual void SetField(const std::string& name, const std::string& value) {}
+
+ // Called when a submessage is started. For example, when:
+ // child {
+ // is found, it triggers
+ // AddChild("child");
+ // If nullptr is returned, the contents of the submessage will be ignored.
+ // Otherwise, the returned Message will be used to handle new fields and new
+ // submessages. The caller should not take ownership of the returned pointer.
+ virtual Message* AddChild(const std::string& name) { return nullptr; }
+
+ // Called when a submessage is completed, that is, whenever a '}' is found.
+ virtual void Finish() {}
+
+ protected:
+ // Takes ownership of the given pointer. Subclasses can use this method if
+ // they don't want to implement their own ownership semantics.
+ Message* Store(Message* n) {
+ children_.emplace_back(n);
+ return n;
+ }
+
+ // Returns a list of all owned submessages.
+ const std::vector<std::unique_ptr<Message>>& Children() const {
+ return children_;
+ }
+
+ private:
+ std::vector<std::unique_ptr<Message>> children_;
+};
+
+} // namespace testing
+} // namespace tflite
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_MESSAGE_H_
diff --git a/tensorflow/contrib/lite/testing/message_test.cc b/tensorflow/contrib/lite/testing/message_test.cc
new file mode 100644
index 0000000000..fb6a49bd6f
--- /dev/null
+++ b/tensorflow/contrib/lite/testing/message_test.cc
@@ -0,0 +1,121 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/testing/message.h"
+
+#include <map>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+namespace tflite {
+namespace testing {
+namespace {
+
+// A hierarchical, key-value store.
+class TestMessage : public Message {
+ public:
+ TestMessage() {}
+ explicit TestMessage(const std::string& text_to_parse) {
+ std::stringstream ss(text_to_parse);
+ finished_ = Message::Read(&ss, this);
+ }
+ void SetField(const std::string& name, const std::string& value) override {
+ fields_[name] = value;
+ }
+ Message* AddChild(const std::string& name) override {
+ TestMessage* m = new TestMessage;
+ m->name_ = name;
+ return Store(m);
+ }
+ void Finish() override { finished_ = true; }
+
+ int NumChildren() const { return Children().size(); }
+
+ const TestMessage* GetChild(int i) const {
+ return dynamic_cast<TestMessage*>(Children()[i].get());
+ }
+
+ int NumFields() const { return fields_.size(); }
+ const std::string& GetField(const std::string& key) const {
+ return fields_.at(key);
+ }
+
+ const std::string& name() const { return name_; }
+ bool finished() const { return finished_; }
+
+ protected:
+ std::string name_;
+ std::map<std::string, std::string> fields_;
+ bool finished_ = false;
+};
+
+TEST(MessageTest, Simple) {
+ TestMessage message("x{a:1 b:2} y{} z{c:3} d:4");
+ ASSERT_TRUE(message.finished());
+
+ ASSERT_EQ(message.NumFields(), 1);
+ EXPECT_EQ(message.GetField("d"), "4");
+
+ ASSERT_EQ(message.NumChildren(), 3);
+
+ auto* x = message.GetChild(0);
+ EXPECT_EQ(x->name(), "x");
+ ASSERT_EQ(x->NumFields(), 2);
+ EXPECT_EQ(x->GetField("a"), "1");
+ EXPECT_EQ(x->GetField("b"), "2");
+
+ auto* y = message.GetChild(1);
+ EXPECT_EQ(y->name(), "y");
+ ASSERT_EQ(y->NumFields(), 0);
+
+ auto* z = message.GetChild(2);
+ EXPECT_EQ(z->name(), "z");
+ ASSERT_EQ(z->NumFields(), 1);
+ EXPECT_EQ(z->GetField("c"), "3");
+}
+
+TEST(MessageTest, Unnamed) {
+ TestMessage message("x{c:3} {} y{d:4}");
+ ASSERT_FALSE(message.finished());
+ EXPECT_EQ(message.NumChildren(), 1);
+}
+
+TEST(MessageTest, TooManyBraces) {
+ TestMessage message("x{c:3} } y{d:4}");
+ ASSERT_FALSE(message.finished());
+ EXPECT_EQ(message.NumChildren(), 1);
+}
+
+TEST(MessageTest, LeftoverToken) {
+ TestMessage message("x{c:3} z{test} y{d:4}");
+ ASSERT_FALSE(message.finished());
+ EXPECT_EQ(message.NumChildren(), 2);
+}
+
+TEST(MessageTest, MissingKey) {
+ TestMessage message("x{c:3} z{:test} y{d:4}");
+ ASSERT_FALSE(message.finished());
+ EXPECT_EQ(message.NumChildren(), 2);
+}
+
+TEST(MessageTest, MissingValue) {
+ TestMessage message("x{c:3} z{test:} y{d:4}");
+ ASSERT_FALSE(message.finished());
+ EXPECT_EQ(message.NumChildren(), 2);
+}
+
+} // namespace
+} // namespace testing
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/testing/nnapi_example.cc b/tensorflow/contrib/lite/testing/nnapi_example.cc
new file mode 100644
index 0000000000..74f6cfc3de
--- /dev/null
+++ b/tensorflow/contrib/lite/testing/nnapi_example.cc
@@ -0,0 +1,114 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// NOTE: this is an example driver that runs a tflite model using the NN API.
+// This is an example that will be integrated more tightly into tflite in
+// the future.
+//
+// Usage: bazel run -c opt \
+// tensorflow/contrib/lite/nnapi:nnapi_example -- <filename>
+//
+#include <cstdarg>
+#include <cstdio>
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/model.h"
+#include "tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h"
+#include "tensorflow/contrib/lite/testing/parse_testdata.h"
+
+// TODO(aselle): FATAL leaves resources hanging.
+void FATAL(const char* format, ...) {
+ va_list args;
+ va_start(args, format);
+ vfprintf(stderr, format, args);
+ va_end(args);
+ fflush(stderr);
+ exit(1);
+}
+
+#define CHECK_TFLITE_SUCCESS(x) \
+ if (x != kTfLiteOk) { \
+ FATAL("Aborting since tflite returned failure."); \
+ }
+
+void Interpret(const char* filename, const char* examples_filename,
+ bool use_nnapi) {
+  // TODO(aselle): Resizing of the input image should go here.
+  // ...
+  // For now all tensors are allocated ahead of time, so their sizes are
+  // fixed and the variable-size ability is not used yet.
+ fprintf(stderr, "example file %s\n", examples_filename);
+ std::vector<tflite::testing::Example> examples;
+ CHECK_TFLITE_SUCCESS(
+ tflite::testing::ParseExamples(examples_filename, &examples));
+
+ for (const tflite::testing::Example& example : examples) {
+ auto model = tflite::FlatBufferModel::BuildFromFile(filename);
+ if (!model) FATAL("Cannot read file %s\n", filename);
+ std::unique_ptr<tflite::Interpreter> interpreter;
+ tflite::ops::builtin::BuiltinOpResolver builtins;
+
+ CHECK_TFLITE_SUCCESS(
+ tflite::InterpreterBuilder(*model, builtins)(&interpreter));
+
+ printf("Use nnapi is set to: %d\n", use_nnapi);
+ interpreter->UseNNAPI(use_nnapi);
+ CHECK_TFLITE_SUCCESS(
+ tflite::testing::FeedExample(interpreter.get(), example));
+
+ {
+ TfLiteTensor* tensor = interpreter->tensor(interpreter->outputs()[0]);
+ if (float* data =
+ interpreter->typed_tensor<float>(interpreter->outputs()[0])) {
+ size_t num = tensor->bytes / sizeof(float);
+ for (float* p = data; p < data + num; p++) {
+ *p = 0;
+ }
+ }
+ }
+ interpreter->Invoke();
+
+ CHECK_TFLITE_SUCCESS(
+ tflite::testing::CheckOutputs(interpreter.get(), example));
+
+ printf("Result:\n");
+ TfLiteTensor* tensor = interpreter->tensor(interpreter->outputs()[0]);
+ if (float* data =
+ interpreter->typed_tensor<float>(interpreter->outputs()[0])) {
+ size_t num = tensor->bytes / sizeof(float);
+ for (float* p = data; p < data + num; p++) {
+ printf(" %f", *p);
+ }
+ }
+ }
+}
+
+int main(int argc, char* argv[]) {
+ bool use_nnapi = true;
+ if (argc == 4) {
+    use_nnapi = strcmp(argv[3], "1") == 0;
+ }
+ if (argc < 3) {
+ fprintf(stderr,
+            "Compiled " __DATE__ " " __TIME__
+            "\n"
+            "Usage: %s <tflite model> <examples to test> "
+            "{use nn api, i.e. 0 or 1}\n",
+ argv[0]);
+ return 1;
+ }
+ Interpret(argv[1], argv[2], use_nnapi);
+ return 0;
+}
diff --git a/tensorflow/contrib/lite/testing/parse_testdata.cc b/tensorflow/contrib/lite/testing/parse_testdata.cc
new file mode 100644
index 0000000000..2b67052cad
--- /dev/null
+++ b/tensorflow/contrib/lite/testing/parse_testdata.cc
@@ -0,0 +1,335 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// Parses tflite example input data.
+// Format is ASCII
+// TODO(aselle): Switch to protobuf, but the android team requested a simple
+// ASCII file.
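+//
+// An illustrative sketch of the layout consumed by ParseExamples() below
+// (the leading token of each values line is not checked by the parser):
+//   test_cases,<number of examples>
+//   inputs,<number of input tensors>
+//   dtype,<dtype>
+//   shape,<d0>,<d1>,...
+//   <values label>,<v0>,<v1>,...   (one value per element of the shape)
+//   outputs,<number of output tensors>
+//   ... followed by a dtype/shape/values block for each output tensor.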
+#include "tensorflow/contrib/lite/testing/parse_testdata.h"
+
+#include <cmath>
+#include <cstdint>
+#include <cstdio>
+#include <fstream>
+#include <iostream>
+#include <streambuf>
+
+#include "tensorflow/contrib/lite/error_reporter.h"
+#include "tensorflow/contrib/lite/testing/message.h"
+#include "tensorflow/contrib/lite/testing/split.h"
+
+namespace tflite {
+namespace testing {
+namespace {
+
+// Fatal error if parse error occurs
+#define PARSE_CHECK_EQ(filename, current_line, x, y) \
+ if ((x) != (y)) { \
+ fprintf(stderr, "Parse Error @ %s:%d\n File %s\n Line %d, %s != %s\n", \
+ __FILE__, __LINE__, filename, current_line + 1, #x, #y); \
+ return kTfLiteError; \
+ }
+
+// Break up a "," delimited line into a std::vector<std::string>.
+// This is extremely inefficient, and just used for testing code.
+// TODO(aselle): replace with absl when we use it.
+std::vector<std::string> ParseLine(const std::string& line) {
+ size_t pos = 0;
+ std::vector<std::string> elements;
+ while (true) {
+ size_t end = line.find(',', pos);
+ if (end == std::string::npos) {
+ elements.push_back(line.substr(pos));
+ break;
+ } else {
+ elements.push_back(line.substr(pos, end - pos));
+ }
+ pos = end + 1;
+ }
+ return elements;
+}
+
+} // namespace
+
+// Given a `filename`, produce a vector of Examples corresponding
+// to test cases that can be applied to a tflite model.
+TfLiteStatus ParseExamples(const char* filename,
+ std::vector<Example>* examples) {
+ std::ifstream fp(filename);
+ if (!fp.good()) {
+ fprintf(stderr, "Could not read '%s'\n", filename);
+ return kTfLiteError;
+ }
+ std::string str((std::istreambuf_iterator<char>(fp)),
+ std::istreambuf_iterator<char>());
+ size_t pos = 0;
+
+  // The file is delimited by '\n' and ','; parse it into a table of strings.
+ std::vector<std::vector<std::string>> csv;
+ while (true) {
+ size_t end = str.find('\n', pos);
+
+ if (end == std::string::npos) {
+ csv.emplace_back(ParseLine(str.substr(pos)));
+ break;
+ }
+ csv.emplace_back(ParseLine(str.substr(pos, end - pos)));
+ pos = end + 1;
+ }
+
+ int current_line = 0;
+ PARSE_CHECK_EQ(filename, current_line, csv[0][0], "test_cases");
+ int example_count = std::stoi(csv[0][1]);
+ current_line++;
+
+ auto parse_tensor = [&filename, &current_line,
+ &csv](FloatTensor* tensor_ptr) {
+ PARSE_CHECK_EQ(filename, current_line, csv[current_line][0], "dtype");
+ current_line++;
+ // parse shape
+ PARSE_CHECK_EQ(filename, current_line, csv[current_line][0], "shape");
+ size_t elements = 1;
+ FloatTensor& tensor = *tensor_ptr;
+
+ for (size_t i = 1; i < csv[current_line].size(); i++) {
+ const auto& shape_part_to_parse = csv[current_line][i];
+ if (shape_part_to_parse.empty()) {
+ // Case of a 0-dimensional shape
+ break;
+ }
+ int shape_part = std::stoi(shape_part_to_parse);
+ elements *= shape_part;
+ tensor.shape.push_back(shape_part);
+ }
+ current_line++;
+ // parse data
+ PARSE_CHECK_EQ(filename, current_line, csv[current_line].size() - 1,
+ elements);
+ for (size_t i = 1; i < csv[current_line].size(); i++) {
+ tensor.flat_data.push_back(std::stof(csv[current_line][i]));
+ }
+ current_line++;
+
+ return kTfLiteOk;
+ };
+
+ for (int example_idx = 0; example_idx < example_count; example_idx++) {
+ Example example;
+ PARSE_CHECK_EQ(filename, current_line, csv[current_line][0], "inputs");
+ int inputs = std::stoi(csv[current_line][1]);
+ current_line++;
+    // Parse each input tensor.
+ for (int input_index = 0; input_index < inputs; input_index++) {
+ example.inputs.push_back(FloatTensor());
+ TF_LITE_ENSURE_STATUS(parse_tensor(&example.inputs.back()));
+ }
+
+ PARSE_CHECK_EQ(filename, current_line, csv[current_line][0], "outputs");
+ int outputs = std::stoi(csv[current_line][1]);
+ current_line++;
+    for (int output_index = 0; output_index < outputs; output_index++) {
+ example.outputs.push_back(FloatTensor());
+ TF_LITE_ENSURE_STATUS(parse_tensor(&example.outputs.back()));
+ }
+ examples->emplace_back(example);
+ }
+ return kTfLiteOk;
+}
+
+TfLiteStatus FeedExample(tflite::Interpreter* interpreter,
+ const Example& example) {
+ // Resize inputs to match example & allocate.
+ for (size_t i = 0; i < interpreter->inputs().size(); i++) {
+ int input_index = interpreter->inputs()[i];
+
+ TF_LITE_ENSURE_STATUS(
+ interpreter->ResizeInputTensor(input_index, example.inputs[i].shape));
+ }
+ TF_LITE_ENSURE_STATUS(interpreter->AllocateTensors());
+ // Copy data into tensors.
+ for (size_t i = 0; i < interpreter->inputs().size(); i++) {
+ int input_index = interpreter->inputs()[i];
+ if (float* data = interpreter->typed_tensor<float>(input_index)) {
+ for (size_t idx = 0; idx < example.inputs[i].flat_data.size(); idx++) {
+ data[idx] = example.inputs[i].flat_data[idx];
+ }
+ } else if (int32_t* data =
+ interpreter->typed_tensor<int32_t>(input_index)) {
+ for (size_t idx = 0; idx < example.inputs[i].flat_data.size(); idx++) {
+ data[idx] = example.inputs[i].flat_data[idx];
+ }
+ } else {
+ fprintf(stderr, "input[%zu] was not float or int data\n", i);
+ return kTfLiteError;
+ }
+ }
+ return kTfLiteOk;
+}
+
+TfLiteStatus CheckOutputs(tflite::Interpreter* interpreter,
+ const Example& example) {
+ constexpr double kRelativeThreshold = 1e-2f;
+ constexpr double kAbsoluteThreshold = 1e-4f;
+
+ ErrorReporter* context = DefaultErrorReporter();
+ int model_outputs = interpreter->outputs().size();
+ TF_LITE_ENSURE_EQ(context, model_outputs, example.outputs.size());
+ for (size_t i = 0; i < interpreter->outputs().size(); i++) {
+ int output_index = interpreter->outputs()[i];
+ if (const float* data = interpreter->typed_tensor<float>(output_index)) {
+ for (size_t idx = 0; idx < example.outputs[i].flat_data.size(); idx++) {
+ float computed = data[idx];
+        float reference = example.outputs[i].flat_data[idx];
+ float diff = std::abs(computed - reference);
+ bool error_is_large = false;
+ // For very small numbers, try absolute error, otherwise go with
+ // relative.
+ if (std::abs(reference) < kRelativeThreshold) {
+ error_is_large = (diff > kAbsoluteThreshold);
+ } else {
+ error_is_large = (diff > kRelativeThreshold * std::abs(reference));
+ }
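+        // Illustrative numbers: with reference 1e-3 the absolute bound of
+        // 1e-4 applies, while with reference 2.0 a computed value within
+        // 0.02 (1% relative error) still passes.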
+ if (error_is_large) {
+ fprintf(stdout, "output[%zu][%zu] did not match %f vs reference %f\n",
+ i, idx, data[idx], reference);
+ return kTfLiteError;
+ }
+ }
+ fprintf(stderr, "\n");
+ } else if (const int32_t* data =
+ interpreter->typed_tensor<int32_t>(output_index)) {
+ for (size_t idx = 0; idx < example.outputs[i].flat_data.size(); idx++) {
+ int32_t computed = data[idx];
+        int32_t reference = example.outputs[i].flat_data[idx];
+ if (std::abs(computed - reference) > 0) {
+ fprintf(stderr, "output[%zu][%zu] did not match %d vs reference %f\n",
+                  i, idx, data[idx], example.outputs[i].flat_data[idx]);
+ return kTfLiteError;
+ }
+ }
+ fprintf(stderr, "\n");
+ } else {
+ fprintf(stderr, "output[%zu] was not float or int data\n", i);
+ return kTfLiteError;
+ }
+ }
+ return kTfLiteOk;
+}
+
+// Process an 'invoke' message, triggering execution of the test runner, as
+// well as verification of outputs. An 'invoke' message looks like:
+// invoke {
+// id: xyz
+// input: 1,2,1,1,1,2,3,4
+//   output: 4,5,6
+// }
+class Invoke : public Message {
+ public:
+ explicit Invoke(TestRunner* test_runner) : test_runner_(test_runner) {
+ expected_inputs_ = test_runner->GetInputs();
+ expected_outputs_ = test_runner->GetOutputs();
+ }
+
+ void SetField(const std::string& name, const std::string& value) override {
+ if (name == "id") {
+ test_runner_->SetInvocationId(value);
+ } else if (name == "input") {
+ if (expected_inputs_.empty()) {
+ return test_runner_->Invalidate("Too many inputs");
+ }
+ test_runner_->SetInput(*expected_inputs_.begin(), value);
+ expected_inputs_.erase(expected_inputs_.begin());
+ } else if (name == "output") {
+ if (expected_outputs_.empty()) {
+ return test_runner_->Invalidate("Too many outputs");
+ }
+ test_runner_->SetExpectation(*expected_outputs_.begin(), value);
+ expected_outputs_.erase(expected_outputs_.begin());
+ }
+ }
+ void Finish() override {
+ test_runner_->Invoke();
+ test_runner_->CheckResults();
+ }
+
+ private:
+ std::vector<int> expected_inputs_;
+ std::vector<int> expected_outputs_;
+
+ TestRunner* test_runner_;
+};
+
+// Process a 'reshape' message, triggering resizing of the input tensors via
+// the test runner. A 'reshape' message looks like:
+// reshape {
+// input: 1,2,1,1,1,2,3,4
+// }
+class Reshape : public Message {
+ public:
+ explicit Reshape(TestRunner* test_runner) : test_runner_(test_runner) {
+ expected_inputs_ = test_runner->GetInputs();
+ }
+
+ void SetField(const std::string& name, const std::string& value) override {
+ if (name == "input") {
+ if (expected_inputs_.empty()) {
+ return test_runner_->Invalidate("Too many inputs to reshape");
+ }
+ test_runner_->ReshapeTensor(*expected_inputs_.begin(), value);
+ expected_inputs_.erase(expected_inputs_.begin());
+ }
+ }
+
+ private:
+ std::vector<int> expected_inputs_;
+ TestRunner* test_runner_;
+};
+
+// This is the top-level message in a test file.
+class TestData : public Message {
+ public:
+ explicit TestData(TestRunner* test_runner) : test_runner_(test_runner) {}
+
+ void SetField(const std::string& name, const std::string& value) override {
+ if (name == "load_model") {
+ test_runner_->LoadModel(value);
+ } else if (name == "init_state") {
+ test_runner_->AllocateTensors();
+ for (int id : Split<int>(value, ",")) {
+ test_runner_->ResetTensor(id);
+ }
+ }
+ }
+ Message* AddChild(const std::string& s) override {
+ if (s == "invoke") {
+ test_runner_->AllocateTensors();
+ return Store(new Invoke(test_runner_));
+ } else if (s == "reshape") {
+ return Store(new Reshape(test_runner_));
+ }
+ return nullptr;
+ }
+
+ private:
+ TestRunner* test_runner_;
+};
+
+bool ParseAndRunTests(std::istream* input, TestRunner* test_runner) {
+ TestData test_data(test_runner);
+ Message::Read(input, &test_data);
+ return test_runner->IsValid() && test_runner->GetOverallSuccess();
+}
+
+} // namespace testing
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/testing/parse_testdata.h b/tensorflow/contrib/lite/testing/parse_testdata.h
new file mode 100644
index 0000000000..90839fe245
--- /dev/null
+++ b/tensorflow/contrib/lite/testing/parse_testdata.h
@@ -0,0 +1,74 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_PARSE_TESTDATA_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_PARSE_TESTDATA_H_
+
+#include <vector>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/testing/test_runner.h"
+
+namespace tflite {
+namespace testing {
+
+// Shape and data for a float tensor
+struct FloatTensor {
+ std::vector<int> shape;
+ std::vector<float> flat_data;
+};
+
+// A prescribed input/output example
+struct Example {
+ std::vector<FloatTensor> inputs;
+ std::vector<FloatTensor> outputs;
+};
+
+// Parses an example input and output file (used for unit tests)
+TfLiteStatus ParseExamples(const char* filename,
+ std::vector<Example>* examples);
+
+// Feeds the Example's input tensors into a TensorFlow Lite interpreter.
+// Note that this will call interpreter->AllocateTensors().
+TfLiteStatus FeedExample(tflite::Interpreter* interpreter, const Example&);
+
+// Check outputs against the (already) evaluated reference results.
+TfLiteStatus CheckOutputs(tflite::Interpreter* interpreter, const Example&);
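+
+// Typical usage (a sketch only; error handling is elided, the file name is
+// hypothetical, and 'interpreter' is assumed to be a
+// std::unique_ptr<tflite::Interpreter>):
+//   std::vector<Example> examples;
+//   ParseExamples("add_tests.txt", &examples);
+//   for (const Example& example : examples) {
+//     FeedExample(interpreter.get(), example);
+//     interpreter->Invoke();
+//     CheckOutputs(interpreter.get(), example);
+//   }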
+
+// Parses a test description and feeds the given test runner with data.
+// The input format is similar to an ASCII proto:
+// // Loads model 'add.bin' from the TestRunner's model directory.
+// load_model: "add.bin"
+// // Changes the shape of inputs, provided in the same order they appear
+// // in the model.
+// reshape {
+// input: "1,224,224,3"
+// input: "1,3,4,1"
+// }
+// // Fills the given persistent tensors with zeros.
+// init_state: 0,1,2,3
+// // Invokes the interpreter with the given input and checks that it
+// // produces the expected output. Inputs and outputs should be specified in
+// // the order they appear in the model.
+// invoke {
+// input: "1,2,3,4,56"
+// input: "0.1,0.2,0.3,4.3,56.4"
+// output: "12,3,4,545,3"
+// output: "0.01,0.02"
+// }
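+//
+// A minimal calling sketch (assuming a TfLiteDriver as the test runner; the
+// spec file name is hypothetical):
+//   std::ifstream spec("tests/add.txt");
+//   TfLiteDriver driver(/*use_nnapi=*/false);
+//   driver.SetModelBaseDir("tensorflow/contrib/lite");
+//   bool ok = ParseAndRunTests(&spec, &driver);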
+bool ParseAndRunTests(std::istream* input, TestRunner* test_runner);
+
+} // namespace testing
+} // namespace tflite
+
+#endif  // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_PARSE_TESTDATA_H_
diff --git a/tensorflow/contrib/lite/testing/split.cc b/tensorflow/contrib/lite/testing/split.cc
new file mode 100644
index 0000000000..5836f4ff04
--- /dev/null
+++ b/tensorflow/contrib/lite/testing/split.cc
@@ -0,0 +1,42 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/testing/split.h"
+
+namespace tflite {
+namespace testing {
+
+std::vector<std::pair<size_t, size_t>> SplitToPos(const string& s,
+ const string& delimiter) {
+ std::vector<std::pair<size_t, size_t>> fields;
+ if (delimiter.length() == 0) {
+ fields.emplace_back(0, s.length());
+ return fields;
+ }
+ size_t pos = 0;
+ size_t start = 0;
+ while ((pos = s.find(delimiter, start)) != string::npos) {
+ if (pos != start) {
+ fields.emplace_back(start, pos);
+ }
+ start = pos + delimiter.length();
+ }
+ if (start != s.length()) {
+ fields.emplace_back(start, s.length());
+ }
+ return fields;
+}
+
+} // namespace testing
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/testing/split.h b/tensorflow/contrib/lite/testing/split.h
new file mode 100644
index 0000000000..24071442e8
--- /dev/null
+++ b/tensorflow/contrib/lite/testing/split.h
@@ -0,0 +1,77 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_SPLIT_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_SPLIT_H_
+
+#include <cstdlib>
+#include <string>
+#include <utility>
+#include <vector>
+#include "tensorflow/contrib/lite/string.h"
+
+namespace tflite {
+namespace testing {
+
+// Splits a string based on the given delimiter string. Each pair in the
+// returned vector has the start and past-the-end positions for each of the
+// parts of the original string. Empty fields are not represented in the
+// output.
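+//
+// For example (illustrative): SplitToPos("a,,b", ",") yields {{0, 1}, {3, 4}};
+// the empty middle field is not represented.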
+std::vector<std::pair<size_t, size_t>> SplitToPos(const string& s,
+ const string& delimiter);
+
+// Splits the given string and converts each part to the given T.
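+//
+// For example (illustrative, mirroring split_test.cc):
+//   Split<int>("1,-1,258", ",") yields {1, -1, 258} and
+//   Split<float>("1.0 B 1e-5", " ") yields {1.0, 0.0, 1e-5}.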
+template <typename T>
+std::vector<T> Split(const string& s, const string& delimiter);
+
+template <>
+inline std::vector<string> Split(const string& s, const string& delimiter) {
+ std::vector<string> fields;
+ for (const auto& p : SplitToPos(s, delimiter)) {
+ fields.push_back(s.substr(p.first, p.second - p.first));
+ }
+ return fields;
+}
+
+template <>
+inline std::vector<int> Split(const string& s, const string& delimiter) {
+ std::vector<int> fields;
+ for (const auto& p : SplitToPos(s, delimiter)) {
+ fields.push_back(strtol(s.data() + p.first, nullptr, 10));
+ }
+ return fields;
+}
+
+template <>
+inline std::vector<float> Split(const string& s, const string& delimiter) {
+ std::vector<float> fields;
+ for (const auto& p : SplitToPos(s, delimiter)) {
+ fields.push_back(strtod(s.data() + p.first, nullptr));
+ }
+ return fields;
+}
+
+template <>
+inline std::vector<uint8_t> Split(const string& s, const string& delimiter) {
+ std::vector<uint8_t> fields;
+ for (const auto& p : SplitToPos(s, delimiter)) {
+ fields.push_back(strtol(s.data() + p.first, nullptr, 10));
+ }
+ return fields;
+}
+
+} // namespace testing
+} // namespace tflite
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_SPLIT_H_
diff --git a/tensorflow/contrib/lite/testing/split_test.cc b/tensorflow/contrib/lite/testing/split_test.cc
new file mode 100644
index 0000000000..3d1e25d9c7
--- /dev/null
+++ b/tensorflow/contrib/lite/testing/split_test.cc
@@ -0,0 +1,57 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/testing/split.h"
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+namespace tflite {
+namespace testing {
+namespace {
+
+using ::testing::ElementsAre;
+using ::testing::Pair;
+
+TEST(SplitTest, SplitToPos) {
+ EXPECT_THAT(SplitToPos("test;:1-2-3 ;: test", ";:"),
+ ElementsAre(Pair(0, 4), Pair(6, 12), Pair(14, 19)));
+ EXPECT_THAT(SplitToPos("test;:1-2-3 ;: test", ":"),
+ ElementsAre(Pair(0, 5), Pair(6, 13), Pair(14, 19)));
+ EXPECT_THAT(SplitToPos("test", ":"), ElementsAre(Pair(0, 4)));
+ EXPECT_THAT(SplitToPos("test ", ":"), ElementsAre(Pair(0, 5)));
+ EXPECT_THAT(SplitToPos("", ":"), ElementsAre());
+ EXPECT_THAT(SplitToPos("test ", ""), ElementsAre(Pair(0, 5)));
+ EXPECT_THAT(SplitToPos("::::", ":"), ElementsAre());
+}
+
+TEST(SplitTest, SplitString) {
+ EXPECT_THAT(Split<string>("A;B;C", ";"), ElementsAre("A", "B", "C"));
+}
+
+TEST(SplitTest, SplitFloat) {
+ EXPECT_THAT(Split<float>("1.0 B 1e-5", " "), ElementsAre(1.0, 0.0, 1e-5));
+}
+
+TEST(SplitTest, SplitInt) {
+ EXPECT_THAT(Split<int>("1,-1,258", ","), ElementsAre(1, -1, 258));
+}
+
+TEST(SplitTest, SplitUint8) {
+ EXPECT_THAT(Split<uint8_t>("1,-1,258", ","), ElementsAre(1, 255, 2));
+}
+
+} // namespace
+} // namespace testing
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/testing/test_runner.h b/tensorflow/contrib/lite/testing/test_runner.h
new file mode 100644
index 0000000000..04ee4d9f7d
--- /dev/null
+++ b/tensorflow/contrib/lite/testing/test_runner.h
@@ -0,0 +1,124 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_TEST_RUNNER_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_TEST_RUNNER_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+#include "tensorflow/contrib/lite/string.h"
+
+namespace tflite {
+namespace testing {
+
+// This is the base class for processing test data. Each one of the virtual
+// methods must be implemented to forward the data to the appropriate executor
+// (e.g. TF Lite's interpreter, or the NNAPI).
+class TestRunner {
+ public:
+ TestRunner() {}
+ virtual ~TestRunner() {}
+
+ // Load the given model, as a path relative to SetModelBaseDir().
+ virtual void LoadModel(const string& bin_file_path) = 0;
+
+ // Return the list of input tensors in the loaded model.
+ virtual const std::vector<int>& GetInputs() = 0;
+
+ // Return the list of output tensors in the loaded model.
+ virtual const std::vector<int>& GetOutputs() = 0;
+
+  // Prepare for a run by resizing the given tensor. The given 'id' is
+ // guaranteed to be one of the ids returned by GetInputs().
+ virtual void ReshapeTensor(int id, const string& csv_values) = 0;
+
+ // Reserve memory for all tensors.
+ virtual void AllocateTensors() = 0;
+
+ // Set the given tensor to some initial state, usually zero. This is
+ // used to reset persistent buffers in a model.
+ virtual void ResetTensor(int id) = 0;
+
+ // Define the contents of the given input tensor. The given 'id' is
+ // guaranteed to be one of the ids returned by GetInputs().
+ virtual void SetInput(int id, const string& csv_values) = 0;
+
+ // Define what should be expected for an output tensor after Invoke() runs.
+ // The given 'id' is guaranteed to be one of the ids returned by
+ // GetOutputs().
+ virtual void SetExpectation(int id, const string& csv_values) = 0;
+
+ // Run the model.
+ virtual void Invoke() = 0;
+
+  // Verify that the contents of all outputs conform to the existing
+ // expectations. Return true if there are no expectations or they are all
+ // satisfied.
+ virtual bool CheckResults() = 0;
+
+ // Set the base path for loading models.
+ void SetModelBaseDir(const string& path) {
+ model_base_dir_ = path;
+ if (path[path.length() - 1] != '/') {
+ model_base_dir_ += "/";
+ }
+ }
+
+ // Return the full path of a model.
+ string GetFullPath(const string& path) { return model_base_dir_ + path; }
+
+ // Give an id to the next invocation to make error reporting more meaningful.
+ void SetInvocationId(const string& id) { invocation_id_ = id; }
+ const string& GetInvocationId() const { return invocation_id_; }
+
+ // Invalidate the test runner, preventing it from executing any further.
+ void Invalidate(const string& error_message) {
+ error_message_ = error_message;
+ }
+ bool IsValid() const { return error_message_.empty(); }
+ const string& GetErrorMessage() const { return error_message_; }
+
+ // Handle the overall success of this test runner. This will be true if all
+ // invocations were successful.
+ void SetOverallSuccess(bool value) { overall_success_ = value; }
+ bool GetOverallSuccess() const { return overall_success_; }
+
+ protected:
+  // A helper to check whether the given number of values is consistent with the
+  // number of bytes in a tensor of type T. When incompatible sizes are found,
+ // the test runner is invalidated and false is returned.
+ template <typename T>
+ bool CheckSizes(size_t tensor_bytes, size_t num_values) {
+ size_t num_tensor_elements = tensor_bytes / sizeof(T);
+ if (num_tensor_elements != num_values) {
+ Invalidate("Expected '" + std::to_string(num_tensor_elements) +
+ "' elements for a tensor, but only got '" +
+ std::to_string(num_values) + "'");
+ return false;
+ }
+ return true;
+ }
+
+ private:
+ string model_base_dir_;
+ string invocation_id_;
+ bool overall_success_ = true;
+
+ string error_message_;
+};
+
+} // namespace testing
+} // namespace tflite
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_TEST_RUNNER_H_
diff --git a/tensorflow/contrib/lite/testing/test_runner_test.cc b/tensorflow/contrib/lite/testing/test_runner_test.cc
new file mode 100644
index 0000000000..f712a5347a
--- /dev/null
+++ b/tensorflow/contrib/lite/testing/test_runner_test.cc
@@ -0,0 +1,84 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/testing/test_runner.h"
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+namespace tflite {
+namespace testing {
+namespace {
+
+class ConcreteTestRunner : public TestRunner {
+ public:
+ void LoadModel(const string& bin_file_path) override {}
+ const std::vector<int>& GetInputs() override { return ids_; }
+ const std::vector<int>& GetOutputs() override { return ids_; }
+ void ReshapeTensor(int id, const string& csv_values) override {}
+ void AllocateTensors() override {}
+ void ResetTensor(int id) override {}
+ void SetInput(int id, const string& csv_values) override {}
+ void SetExpectation(int id, const string& csv_values) override {}
+ void Invoke() override {}
+ bool CheckResults() override { return true; }
+ bool CheckFloatSizes(size_t bytes, size_t values) {
+ return CheckSizes<float>(bytes, values);
+ }
+
+ private:
+ std::vector<int> ids_;
+};
+
+TEST(TestRunner, ModelPath) {
+ ConcreteTestRunner runner;
+ EXPECT_EQ(runner.GetFullPath("test.bin"), "test.bin");
+ runner.SetModelBaseDir("/tmp");
+ EXPECT_EQ(runner.GetFullPath("test.bin"), "/tmp/test.bin");
+}
+
+TEST(TestRunner, InvocationId) {
+ ConcreteTestRunner runner;
+ EXPECT_EQ(runner.GetInvocationId(), "");
+ runner.SetInvocationId("X");
+ EXPECT_EQ(runner.GetInvocationId(), "X");
+}
+
+TEST(TestRunner, Invalidation) {
+ ConcreteTestRunner runner;
+ EXPECT_TRUE(runner.IsValid());
+ EXPECT_EQ(runner.GetErrorMessage(), "");
+ runner.Invalidate("Some Error");
+ EXPECT_FALSE(runner.IsValid());
+ EXPECT_EQ(runner.GetErrorMessage(), "Some Error");
+}
+
+TEST(TestRunner, OverallSuccess) {
+ ConcreteTestRunner runner;
+ EXPECT_TRUE(runner.GetOverallSuccess());
+ runner.SetOverallSuccess(false);
+ EXPECT_FALSE(runner.GetOverallSuccess());
+}
+
+TEST(TestRunner, CheckSizes) {
+ ConcreteTestRunner runner;
+ EXPECT_TRUE(runner.CheckFloatSizes(16, 4));
+ EXPECT_FALSE(runner.CheckFloatSizes(16, 2));
+ EXPECT_EQ(runner.GetErrorMessage(),
+ "Expected '4' elements for a tensor, but only got '2'");
+}
+
+} // namespace
+} // namespace testing
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/testing/tflite_driver.cc b/tensorflow/contrib/lite/testing/tflite_driver.cc
new file mode 100644
index 0000000000..cf9df2ec26
--- /dev/null
+++ b/tensorflow/contrib/lite/testing/tflite_driver.cc
@@ -0,0 +1,208 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/testing/tflite_driver.h"
+
+#include <iostream>
+
+#include "tensorflow/contrib/lite/testing/split.h"
+
+namespace tflite {
+namespace testing {
+
+namespace {
+
+// Returns the value in the given position in a tensor.
+template <typename T>
+T Value(const TfLitePtrUnion& data, int index);
+template <>
+float Value(const TfLitePtrUnion& data, int index) {
+ return data.f[index];
+}
+template <>
+uint8_t Value(const TfLitePtrUnion& data, int index) {
+ return data.uint8[index];
+}
+
+template <typename T>
+void SetTensorData(const std::vector<T>& values, TfLitePtrUnion* data) {
+ T* input_ptr = reinterpret_cast<T*>(data->raw);
+ for (const T& v : values) {
+ *input_ptr = v;
+ ++input_ptr;
+ }
+}
+
+} // namespace
+
+class TfLiteDriver::Expectation {
+ public:
+ Expectation() { data_.raw = nullptr; }
+ ~Expectation() { delete[] data_.raw; }
+ template <typename T>
+ void SetData(const string& csv_values) {
+ const auto& values = testing::Split<T>(csv_values, ",");
+ data_.raw = new char[values.size() * sizeof(T)];
+ SetTensorData(values, &data_);
+ }
+
+ bool Check(bool verbose, const TfLiteTensor& tensor) {
+ switch (tensor.type) {
+ case kTfLiteFloat32:
+ return TypedCheck<float>(verbose, tensor);
+ case kTfLiteUInt8:
+ return TypedCheck<uint8_t>(verbose, tensor);
+ default:
+ return false;
+ }
+ }
+
+ private:
+ template <typename T>
+ bool TypedCheck(bool verbose, const TfLiteTensor& tensor) {
+ int tensor_size = tensor.bytes / sizeof(T);
+
+ bool good_output = true;
+ for (int i = 0; i < tensor_size; ++i) {
+ if (std::abs(Value<T>(data_, i) - Value<T>(tensor.data, i)) > 1e-5) {
+ good_output = false;
+ if (verbose) {
+ std::cerr << " index " << i << ": " << Value<T>(data_, i)
+ << " != " << Value<T>(tensor.data, i) << std::endl;
+ }
+ }
+ }
+ return good_output;
+ }
+
+ TfLitePtrUnion data_;
+};
+
+TfLiteDriver::TfLiteDriver(bool use_nnapi) : use_nnapi_(use_nnapi) {}
+TfLiteDriver::~TfLiteDriver() {}
+
+void TfLiteDriver::AllocateTensors() {
+ if (must_allocate_tensors_) {
+ if (interpreter_->AllocateTensors() != kTfLiteOk) {
+ std::cerr << "Failed to allocate tensors" << std::endl;
+ abort();
+ }
+ must_allocate_tensors_ = false;
+ }
+}
+
+void TfLiteDriver::LoadModel(const string& bin_file_path) {
+ if (!IsValid()) return;
+ std::cout << std::endl << "Loading model: " << bin_file_path << std::endl;
+
+ model_ = FlatBufferModel::BuildFromFile(GetFullPath(bin_file_path).c_str());
+ if (!model_) {
+ Invalidate("Failed to mmap model " + bin_file_path);
+ return;
+ }
+ ops::builtin::BuiltinOpResolver builtins;
+ InterpreterBuilder(*model_, builtins)(&interpreter_);
+ if (!interpreter_) {
+    Invalidate("Failed to build interpreter");
+ return;
+ }
+
+ must_allocate_tensors_ = true;
+}
+
+void TfLiteDriver::ResetTensor(int id) {
+ if (!IsValid()) return;
+ auto* tensor = interpreter_->tensor(id);
+ memset(tensor->data.raw, 0, tensor->bytes);
+}
+
+void TfLiteDriver::ReshapeTensor(int id, const string& csv_values) {
+ if (!IsValid()) return;
+ if (interpreter_->ResizeInputTensor(
+ id, testing::Split<int>(csv_values, ",")) != kTfLiteOk) {
+ Invalidate("Failed to resize input tensor " + std::to_string(id));
+ return;
+ }
+ must_allocate_tensors_ = true;
+}
+
+void TfLiteDriver::SetInput(int id, const string& csv_values) {
+ if (!IsValid()) return;
+ auto* tensor = interpreter_->tensor(id);
+ switch (tensor->type) {
+ case kTfLiteFloat32: {
+ const auto& values = testing::Split<float>(csv_values, ",");
+ if (!CheckSizes<float>(tensor->bytes, values.size())) return;
+ SetTensorData(values, &tensor->data);
+ break;
+ }
+ case kTfLiteUInt8: {
+ const auto& values = testing::Split<uint8_t>(csv_values, ",");
+ if (!CheckSizes<uint8_t>(tensor->bytes, values.size())) return;
+ SetTensorData(values, &tensor->data);
+ break;
+ }
+ default:
+ Invalidate("Unsupported tensor data type");
+ return;
+ }
+}
+
+void TfLiteDriver::SetExpectation(int id, const string& csv_values) {
+ if (!IsValid()) return;
+ auto* tensor = interpreter_->tensor(id);
+ expected_output_[id].reset(new Expectation);
+ switch (tensor->type) {
+ case kTfLiteFloat32:
+ expected_output_[id]->SetData<float>(csv_values);
+ break;
+ case kTfLiteUInt8:
+ expected_output_[id]->SetData<uint8_t>(csv_values);
+ break;
+ default:
+ Invalidate("Unsupported tensor data type");
+ return;
+ }
+}
+
+void TfLiteDriver::Invoke() {
+ if (!IsValid()) return;
+ if (interpreter_->Invoke() != kTfLiteOk) {
+ Invalidate("Failed to invoke interpreter");
+ }
+}
+
+bool TfLiteDriver::CheckResults() {
+ if (!IsValid()) return false;
+ bool success = true;
+ for (const auto& p : expected_output_) {
+ int id = p.first;
+ auto* tensor = interpreter_->tensor(id);
+ if (!p.second->Check(/*verbose=*/false, *tensor)) {
+ // Do not invalidate anything here. Instead, simply output the
+ // differences and return false. Invalidating would prevent all
+      // subsequent invocations from running.
+ std::cerr << "There were errors in invocation '" << GetInvocationId()
+ << "', output tensor '" << id << "':" << std::endl;
+ p.second->Check(/*verbose=*/true, *tensor);
+ success = false;
+ SetOverallSuccess(false);
+ }
+ }
+ expected_output_.clear();
+ return success;
+}
+
+} // namespace testing
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/testing/tflite_driver.h b/tensorflow/contrib/lite/testing/tflite_driver.h
new file mode 100644
index 0000000000..4440d4285e
--- /dev/null
+++ b/tensorflow/contrib/lite/testing/tflite_driver.h
@@ -0,0 +1,62 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_TFLITE_DRIVER_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_TFLITE_DRIVER_H_
+
+#include <map>
+
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/model.h"
+#include "tensorflow/contrib/lite/testing/test_runner.h"
+
+namespace tflite {
+namespace testing {
+
+// A test runner that feeds inputs into TF Lite and verifies its outputs.
+class TfLiteDriver : public TestRunner {
+ public:
+ explicit TfLiteDriver(bool use_nnapi);
+ ~TfLiteDriver() override;
+
+ void LoadModel(const string& bin_file_path) override;
+ const std::vector<int>& GetInputs() override {
+ return interpreter_->inputs();
+ }
+ const std::vector<int>& GetOutputs() override {
+ return interpreter_->outputs();
+ }
+ void ReshapeTensor(int id, const string& csv_values) override;
+ void AllocateTensors() override;
+ void ResetTensor(int id) override;
+ void SetInput(int id, const string& csv_values) override;
+ void SetExpectation(int id, const string& csv_values) override;
+ void Invoke() override;
+ bool CheckResults() override;
+
+ private:
+ class Expectation;
+
+ bool use_nnapi_ = false;
+ std::unique_ptr<FlatBufferModel> model_;
+ std::unique_ptr<Interpreter> interpreter_;
+ std::map<int, std::unique_ptr<Expectation>> expected_output_;
+ bool must_allocate_tensors_ = true;
+};
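+
+// Example usage (a sketch mirroring tflite_driver_test.cc; the model path and
+// tensor ids depend on the model being driven):
+//   TfLiteDriver driver(/*use_nnapi=*/false);
+//   driver.SetModelBaseDir("tensorflow/contrib/lite");
+//   driver.LoadModel("testdata/multi_add.bin");
+//   driver.ReshapeTensor(0, "1,2,2,1");
+//   driver.AllocateTensors();
+//   driver.SetInput(0, "0.1,0.2,0.3,0.4");
+//   driver.SetExpectation(5, "0.101,0.202,0.303,0.404");
+//   driver.Invoke();
+//   bool ok = driver.CheckResults();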
+
+} // namespace testing
+} // namespace tflite
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_TFLITE_DRIVER_H_
diff --git a/tensorflow/contrib/lite/testing/tflite_driver_test.cc b/tensorflow/contrib/lite/testing/tflite_driver_test.cc
new file mode 100644
index 0000000000..79e8a86972
--- /dev/null
+++ b/tensorflow/contrib/lite/testing/tflite_driver_test.cc
@@ -0,0 +1,61 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/testing/tflite_driver.h"
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+namespace tflite {
+namespace testing {
+namespace {
+
+using ::testing::ElementsAre;
+
+TEST(TfliteDriverTest, SimpleTest) {
+ std::unique_ptr<TestRunner> runner(new TfLiteDriver(/*use_nnapi=*/false));
+
+ runner->SetModelBaseDir("third_party/tensorflow/contrib/lite");
+ runner->LoadModel("testdata/multi_add.bin");
+ ASSERT_TRUE(runner->IsValid());
+
+ ASSERT_THAT(runner->GetInputs(), ElementsAre(0, 1, 2, 3));
+ ASSERT_THAT(runner->GetOutputs(), ElementsAre(5, 6));
+
+ for (int i : {0, 1, 2, 3}) {
+ runner->ReshapeTensor(i, "1,2,2,1");
+ }
+ ASSERT_TRUE(runner->IsValid());
+
+ runner->AllocateTensors();
+
+ runner->SetInput(0, "0.1,0.2,0.3,0.4");
+ runner->SetInput(1, "0.001,0.002,0.003,0.004");
+ runner->SetInput(2, "0.001,0.002,0.003,0.004");
+ runner->SetInput(3, "0.01,0.02,0.03,0.04");
+
+ runner->ResetTensor(2);
+
+ runner->SetExpectation(5, "0.101,0.202,0.303,0.404");
+ runner->SetExpectation(6, "0.011,0.022,0.033,0.044");
+
+ runner->Invoke();
+ ASSERT_TRUE(runner->IsValid());
+
+ ASSERT_TRUE(runner->CheckResults());
+}
+
+} // namespace
+} // namespace testing
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/testing/tokenize.cc b/tensorflow/contrib/lite/testing/tokenize.cc
new file mode 100644
index 0000000000..2e84ea475c
--- /dev/null
+++ b/tensorflow/contrib/lite/testing/tokenize.cc
@@ -0,0 +1,95 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/testing/tokenize.h"
+#include <istream>
+#include <string>
+#include "tensorflow/contrib/lite/string.h"
+
+namespace tflite {
+namespace testing {
+
+void Tokenize(std::istream* input, TokenProcessor* processor) {
+ enum State { kBuildQuotedToken, kBuildToken, kIdle };
+
+ std::string current_token;
+ State state = kIdle;
+ auto start_token = [&](char c) {
+ state = kBuildToken;
+ current_token.clear();
+ current_token = c;
+ };
+ auto issue_token = [&]() {
+ state = kIdle;
+ processor->ConsumeToken(&current_token);
+ current_token.clear();
+ };
+ auto start_quoted_token = [&]() {
+ state = kBuildQuotedToken;
+ current_token.clear();
+ };
+ auto issue_quoted_token = [&]() {
+ state = kIdle;
+ processor->ConsumeToken(&current_token);
+ current_token.clear();
+ };
+ auto issue_delim = [&](char d) {
+ current_token = string(1, d);
+ processor->ConsumeToken(&current_token);
+ current_token.clear();
+ };
+ auto is_delim = [](char c) { return c == '{' || c == '}' || c == ':'; };
+ auto is_quote = [](char c) { return c == '"'; };
+
+ for (auto it = std::istreambuf_iterator<char>(*input);
+ it != std::istreambuf_iterator<char>(); ++it) {
+ switch (state) {
+ case kIdle:
+ if (is_delim(*it)) {
+ issue_delim(*it);
+ } else if (is_quote(*it)) {
+ start_quoted_token();
+ } else if (!isspace(*it)) {
+ start_token(*it);
+ }
+ break;
+ case kBuildToken:
+ if (is_delim(*it)) {
+ issue_token();
+ issue_delim(*it);
+ } else if (is_quote(*it)) {
+ issue_token();
+ start_quoted_token();
+ } else if (isspace(*it)) {
+ issue_token();
+ } else {
+ current_token += *it;
+ }
+ break;
+ case kBuildQuotedToken:
+ if (is_quote(*it)) {
+ issue_quoted_token();
+ } else {
+ current_token += *it;
+ }
+ break;
+ }
+ }
+ if (state != kIdle) {
+ issue_token();
+ }
+}
+
+} // namespace testing
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/testing/tokenize.h b/tensorflow/contrib/lite/testing/tokenize.h
new file mode 100644
index 0000000000..daccf0e84a
--- /dev/null
+++ b/tensorflow/contrib/lite/testing/tokenize.h
@@ -0,0 +1,42 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_TOKENIZER_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_TOKENIZER_H_
+
+#include <istream>
+#include <string>
+
+namespace tflite {
+namespace testing {
+
+// Process tokens coming from Tokenize().
+class TokenProcessor {
+ public:
+ virtual ~TokenProcessor() {}
+ // Process a single token. The token won't be reused, so it is OK to call
+ // token.swap().
+ virtual void ConsumeToken(std::string* token) = 0;
+};
+
+// Tokenize a stream on whitespace, colons and curly braces. Whitespace is
+// removed from the tokens; double-quotes can be used to prevent that. Note
+// that there is no way to escape double-quotes, so there's no way to have a
+// double-quote inside a token.
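+//
+// For example (illustrative): the input
+//   x: "a b" {
+// is tokenized into "x", ":", "a b", "{".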
+void Tokenize(std::istream* input, TokenProcessor* processor);
+
+} // namespace testing
+} // namespace tflite
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_TOKENIZER_H_
diff --git a/tensorflow/contrib/lite/testing/tokenize_test.cc b/tensorflow/contrib/lite/testing/tokenize_test.cc
new file mode 100644
index 0000000000..80f44aacca
--- /dev/null
+++ b/tensorflow/contrib/lite/testing/tokenize_test.cc
@@ -0,0 +1,105 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/testing/tokenize.h"
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+namespace tflite {
+namespace testing {
+namespace {
+
+using ::testing::ElementsAre;
+using ::testing::ElementsAreArray;
+
+class TokenCollector : public TokenProcessor {
+ public:
+ void ConsumeToken(std::string* token) override { tokens_.push_back(*token); }
+ const std::vector<std::string>& Tokens() { return tokens_; }
+
+ private:
+ std::vector<std::string> tokens_;
+};
+
+std::vector<std::string> TokenizeString(const std::string& s) {
+ std::stringstream ss(s);
+ TokenCollector collector;
+ Tokenize(&ss, &collector);
+ return collector.Tokens();
+}
+
+TEST(TokenizeTest, TokenDetection) {
+ EXPECT_THAT(TokenizeString("x :1"), ElementsAre("x", ":", "1"));
+ EXPECT_THAT(TokenizeString("x:1"), ElementsAre("x", ":", "1"));
+ EXPECT_THAT(TokenizeString("x {1"), ElementsAre("x", "{", "1"));
+ EXPECT_THAT(TokenizeString("x{1"), ElementsAre("x", "{", "1"));
+ EXPECT_THAT(TokenizeString("x }1"), ElementsAre("x", "}", "1"));
+ EXPECT_THAT(TokenizeString("x}1"), ElementsAre("x", "}", "1"));
+ EXPECT_THAT(TokenizeString("x \"1"), ElementsAre("x", "1"));
+ EXPECT_THAT(TokenizeString("x\"1"), ElementsAre("x", "1"));
+}
+
+TEST(TokenizeTest, QuotedTokenDetection) {
+ EXPECT_THAT(TokenizeString("\"w:x{y}z\"1"), ElementsAre("w:x{y}z", "1"));
+ EXPECT_THAT(TokenizeString("\"w:x{y}z\"\"1\""), ElementsAre("w:x{y}z", "1"));
+}
+
+TEST(TokenizeTest, Delimiters) {
+ EXPECT_THAT(TokenizeString("}"), ElementsAre("}"));
+ EXPECT_THAT(TokenizeString("}}"), ElementsAre("}", "}"));
+ EXPECT_THAT(TokenizeString("{"), ElementsAre("{"));
+ EXPECT_THAT(TokenizeString("{{"), ElementsAre("{", "{"));
+ EXPECT_THAT(TokenizeString(":"), ElementsAre(":"));
+ EXPECT_THAT(TokenizeString("::"), ElementsAre(":", ":"));
+}
+
+TEST(TokenizeTest, CornerCases) {
+ EXPECT_THAT(TokenizeString(" i { b:a } "),
+ ElementsAre("i", "{", "b", ":", "a", "}"));
+ EXPECT_THAT(TokenizeString(" }"), ElementsAre("}"));
+ EXPECT_THAT(TokenizeString(" } "), ElementsAre("}"));
+ EXPECT_THAT(TokenizeString(" {} "), ElementsAre("{", "}"));
+ EXPECT_THAT(TokenizeString(" x{} y{} "),
+ ElementsAre("x", "{", "}", "y", "{", "}"));
+ EXPECT_THAT(TokenizeString("x:1 y:2 "),
+ ElementsAre("x", ":", "1", "y", ":", "2"));
+ EXPECT_THAT(TokenizeString("x:\"1\" y:2 "),
+ ElementsAre("x", ":", "1", "y", ":", "2"));
+ EXPECT_THAT(TokenizeString("x:\"1, 2\" y:\"\" "),
+ ElementsAre("x", ":", "1, 2", "y", ":", ""));
+}
+
+TEST(TokenizeTest, NewLines) {
+ EXPECT_THAT(TokenizeString("x:\n1,\n 2 \n y :\n3 \n"),
+ ElementsAre("x", ":", "1,", "2", "y", ":", "3"));
+}
+
+TEST(TokenizeTest, LongString) {
+ EXPECT_THAT(
+ TokenizeString(" i { b:a } input {"
+ "a: \"1e-1, 2,3\" b:\"1,2,3\"\n c{ "
+ "id:1 x{d{a:"
+ "1}}} f:2 "
+ "\n}\n t:1"),
+ ElementsAreArray({"i", "{", "b", ":", "a", "}", "input", "{",
+ "a", ":", "1e-1, 2,3", "b", ":", "1,2,3", "c", "{",
+ "id", ":", "1", "x", "{", "d", "{", "a",
+ ":", "1", "}", "}", "}", "f", ":", "2",
+ "}", "t", ":", "1"}));
+}
+
+} // namespace
+} // namespace testing
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/toco/BUILD b/tensorflow/contrib/lite/toco/BUILD
new file mode 100644
index 0000000000..1c73ab8f4a
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/BUILD
@@ -0,0 +1,350 @@
+package(default_visibility = ["//visibility:public"])
+
+licenses(["notice"]) # Apache 2.0
+
+load(
+ "//tensorflow/core:platform/default/build_config.bzl",
+ "tf_proto_library_cc",
+ "tf_proto_library_py",
+)
+load(
+ "//tensorflow:tensorflow.bzl",
+ "tf_cc_binary",
+ "tf_cc_test",
+)
+
+tf_proto_library_cc(
+ name = "toco_flags_proto",
+ srcs = ["toco_flags.proto"],
+ visibility = ["//visibility:public"],
+)
+
+tf_proto_library_cc(
+ name = "model_flags_proto",
+ srcs = ["model_flags.proto"],
+ visibility = ["//visibility:public"],
+)
+
+tf_proto_library_py(
+ name = "toco_flags_proto",
+ srcs = [
+ "toco_flags.proto",
+ ],
+ visibility = ["//visibility:public"],
+)
+
+tf_proto_library_py(
+ name = "model_flags_proto",
+ srcs = [
+ "model_flags.proto",
+ ],
+ visibility = ["//visibility:public"],
+)
+
+cc_library(
+ name = "tensorflow_core_cc_protos_all",
+ deps = ["//tensorflow/core:protos_all_cc"],
+)
+
+cc_library(
+ name = "runtime",
+ hdrs = [
+ "runtime/common.h",
+ "runtime/types.h",
+ ],
+ linkstatic = 1,
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/contrib/lite/kernels/internal:reference_base",
+ "//tensorflow/contrib/lite/kernels/internal:types",
+ ],
+)
+
+# :model offers the core data structures representing a model (a.k.a. "graph")
+# for tooling purposes (not needed at inference runtime).
+# That includes the top-level Model structure, and the lower-level Operator,
+# Array, Buffer structures, etc.
+cc_library(
+ name = "model",
+ hdrs = [
+ "model.h",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":model_flags_proto_cc",
+ ":runtime",
+ ":toco_port",
+ "//tensorflow/core:lib",
+ "@com_google_absl//absl/base:core_headers",
+ "@protobuf_archive//:protobuf_headers",
+ ],
+)
+
+cc_library(
+ name = "toco_graphviz_dump_options",
+ srcs = [
+ "toco_graphviz_dump_options.cc",
+ ],
+ hdrs = [
+ "toco_graphviz_dump_options.h",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ "@com_google_absl//absl/strings",
+ ],
+)
+
+cc_library(
+ name = "toco_cmdline_flags",
+ srcs = [
+ "toco_cmdline_flags.cc",
+ ],
+ hdrs = [
+ "toco_cmdline_flags.h",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":model_cmdline_flags",
+ ":toco_flags_proto_cc",
+ ":toco_port",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
+ ],
+)
+
+cc_library(
+ name = "model_cmdline_flags",
+ srcs = [
+ "model_cmdline_flags.cc",
+ ],
+ hdrs = [
+ "args.h",
+ "model_cmdline_flags.h",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":model_flags_proto_cc",
+ ":toco_graphviz_dump_options",
+ ":toco_port",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
+ ],
+)
+
+cc_library(
+ name = "toco_port",
+ srcs = [
+ "toco_port.cc",
+ ],
+ hdrs = [
+ "format_port.h",
+ "toco_port.h",
+ "toco_types.h",
+ ],
+ deps = [
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ ] + select({
+ "//tensorflow:android": [],
+ "//tensorflow:darwin": [],
+ "//tensorflow:ios": [],
+ "//conditions:default": [],
+ "//tensorflow:dummy_disabled_internal": [],
+ }),
+)
+
+cc_library(
+ name = "graph_transformations",
+ srcs = [
+ "graph_transformations/convert_pure_conv_to_depthwise.cc",
+ "graph_transformations/create_im2col_arrays.cc",
+ "graph_transformations/dequantize.cc",
+ "graph_transformations/drop_fake_quant.cc",
+ "graph_transformations/drop_im2col_arrays.cc",
+ "graph_transformations/ensure_bias_vectors.cc",
+ "graph_transformations/fuse_activation_functions.cc",
+ "graph_transformations/fuse_binary_into_following_affine.cc",
+ "graph_transformations/fuse_binary_into_preceding_affine.cc",
+ "graph_transformations/graph_transformations.cc",
+ "graph_transformations/hardcode_min_max.cc",
+ "graph_transformations/identify_l2_normalization.cc",
+ "graph_transformations/identify_l2_pool.cc",
+ "graph_transformations/identify_lstm.cc",
+ "graph_transformations/identify_relu1.cc",
+ "graph_transformations/make_initial_dequantize_operator.cc",
+ "graph_transformations/propagate_array_data_types.cc",
+ "graph_transformations/propagate_fixed_sizes.cc",
+ "graph_transformations/quantize.cc",
+ "graph_transformations/read_fake_quant_min_max.cc",
+ "graph_transformations/remove_final_dequantize_op.cc",
+ "graph_transformations/remove_tensorflow_assert.cc",
+ "graph_transformations/remove_tensorflow_identity.cc",
+ "graph_transformations/remove_trivial_binary.cc",
+ "graph_transformations/remove_trivial_concatenation.cc",
+ "graph_transformations/remove_trivial_concatenation_input.cc",
+ "graph_transformations/remove_trivial_passthrough.cc",
+ "graph_transformations/remove_trivial_passthrough.h",
+ "graph_transformations/remove_trivial_quantized_activation_func.cc",
+ "graph_transformations/remove_trivial_reshape.cc",
+ "graph_transformations/remove_unused_op.cc",
+ "graph_transformations/resolve_batch_normalization.cc",
+ "graph_transformations/resolve_constant_binary.cc",
+ "graph_transformations/resolve_constant_concatenation.cc",
+ "graph_transformations/resolve_constant_fake_quant.cc",
+ "graph_transformations/resolve_constant_tensorflow_shape.cc",
+ "graph_transformations/resolve_constant_unary.cc",
+ "graph_transformations/resolve_mean_attributes.cc",
+ "graph_transformations/resolve_pad_attributes.cc",
+ "graph_transformations/resolve_reorder_axes.cc",
+ "graph_transformations/resolve_reshape_attributes.cc",
+ "graph_transformations/resolve_slice_attributes.cc",
+ "graph_transformations/resolve_strided_slice_attributes.cc",
+ "graph_transformations/resolve_tensorflow_concat.cc",
+ "graph_transformations/resolve_tensorflow_matmul.cc",
+ "graph_transformations/resolve_tensorflow_merge.cc",
+ "graph_transformations/resolve_tensorflow_squeeze.cc",
+ "graph_transformations/resolve_tensorflow_switch.cc",
+ "graph_transformations/resolve_tensorflow_tile.cc",
+ "graph_transformations/unfuse_activation_functions.cc",
+ ],
+ hdrs = [
+ "graph_transformations/graph_transformations.h",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":model",
+ ":model_flags_proto_cc",
+ ":runtime",
+ ":toco_port",
+ ":tooling_util",
+ "//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
+ ],
+)
+
+# :toco_tooling is the library providing the offline tooling functionality
+# exposed by the :toco command-line tool.
+cc_library(
+ name = "toco_tooling",
+ srcs = [
+ "allocate_transient_arrays.cc",
+ "export_tensorflow.cc",
+ "import_tensorflow.cc",
+ "tensorflow_util.cc",
+ "toco_tooling.cc",
+ ],
+ hdrs = [
+ "allocate_transient_arrays.h",
+ "export_tensorflow.h",
+ "import_tensorflow.h",
+ "tensorflow_util.h",
+ "toco_tooling.h",
+ ],
+ copts = select({
+ "//tensorflow:darwin": ["-DTOCO_SUPPORT_PORTABLE_PROTOS=0"],
+ "//conditions:default": [],
+ }),
+ visibility = ["//visibility:public"],
+ deps = [
+ ":graph_transformations",
+ ":model_flags_proto_cc",
+ ":toco_flags_proto_cc",
+ ":model",
+ ":runtime",
+ "//tensorflow/core:protos_all_cc",
+ ":toco_port",
+ ":tooling_util",
+ ":toco_graphviz_dump_options",
+ "@protobuf_archive//:protobuf_headers",
+ "//tensorflow/contrib/lite/toco/tflite:export",
+ "//tensorflow/contrib/lite/toco/tflite:import",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/memory",
+ "//tensorflow/core:lib",
+ "//tensorflow/contrib/lite/toco/tensorflow_graph_matching:resolve_cluster",
+ "//tensorflow/core:core_cpu_lib",
+ "//tensorflow/core:framework",
+ ] + select({
+ # Placeholder for internal darwin rule.
+ "//conditions:default": [],
+ }),
+)
+
+cc_library(
+ name = "tooling_util",
+ srcs = [
+ "dump_graphviz.cc",
+ "tooling_util.cc",
+ ],
+ hdrs = [
+ "dump_graphviz.h",
+ "tooling_util.h",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":model",
+ ":model_flags_proto_cc",
+ ":runtime",
+ ":toco_flags_proto_cc",
+ ":toco_graphviz_dump_options",
+ ":toco_port",
+ "//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
+ "@protobuf_archive//:protobuf_headers",
+ ],
+)
+
+tf_cc_test(
+ name = "tooling_util_test",
+ srcs = ["tooling_util_test.cc"],
+ deps = [
+ ":model",
+ ":tooling_util",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+# :toco is the main public command-line tool exposing the functionality
+# of the :toco_tooling library.
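+#
+# A typical invocation (assuming a configured TensorFlow workspace) would be:
+#   bazel build tensorflow/contrib/lite/toco:toco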
+tf_cc_binary(
+ name = "toco",
+ srcs = ["toco.cc"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":model",
+ ":model_cmdline_flags",
+ ":model_flags_proto_cc",
+ ":toco_cmdline_flags",
+ ":toco_flags_proto_cc",
+ ":toco_port",
+ ":toco_tooling",
+ "//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
+ ],
+)
+
+tf_cc_test(
+ name = "toco_port_test",
+ srcs = ["toco_port_test.cc"],
+ data = [
+ "toco_port_test.cc",
+ ],
+ deps = [
+ ":toco_port",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/contrib/lite/toco/allocate_transient_arrays.cc b/tensorflow/contrib/lite/toco/allocate_transient_arrays.cc
new file mode 100644
index 0000000000..2f4454d7c8
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/allocate_transient_arrays.cc
@@ -0,0 +1,318 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <algorithm>
+#include <memory>
+#include <set>
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/allocate_transient_arrays.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/model_flags.pb.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+namespace {
+
+// The life span of an array.
+struct ArrayLifespan {
+  // If true, the array is persistent state (as in an RNN). In that case,
+ // its allocation is permanent and the first_op, last_op members are
+ // unused. (The term 'transient' is a misnomer and we should think in
+ // terms of 'workspace' instead).
+ bool persistent = false;
+ // Index of the first op addressing that array. The array must be allocated
+ // just before executing this op.
+ std::size_t first_op = 0;
+ // Index of the last op addressing that array. We want to deallocate the array
+ // immediately after executing this op.
+ std::size_t last_op = 0;
+};
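+
+// For example (illustrative): if a non-persistent array is consumed or
+// produced only by the ops at indices 2, 5 and 7, its lifespan ends up with
+// first_op == 2 and last_op == 7, so its buffer only needs to exist while
+// those ops execute.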
+
+bool StartsAt(const ArrayLifespan& lifespan, std::size_t op_index) {
+ return !lifespan.persistent && lifespan.first_op == op_index;
+}
+
+bool EndsAt(const ArrayLifespan& lifespan, std::size_t op_index) {
+ return !lifespan.persistent && lifespan.last_op == op_index;
+}
+
+// Helper function for ComputeArrayLifespans: updates one ArrayLifespan for
+// one array for one op.
+void UpdateArrayLifespan(
+ const string& array_name, std::size_t op_index,
+ std::unordered_map<string, ArrayLifespan>* array_lifespans) {
+ if (array_lifespans->count(array_name)) {
+ auto& lifespan = array_lifespans->at(array_name);
+ if (!lifespan.persistent) {
+ lifespan.first_op = std::min(lifespan.first_op, op_index);
+ lifespan.last_op = std::max(lifespan.last_op, op_index);
+ }
+ } else {
+ ArrayLifespan lifespan;
+ lifespan.first_op = op_index;
+ lifespan.last_op = op_index;
+ (*array_lifespans)[array_name] = lifespan;
+ }
+}
+
+// Computes the ArrayLifespan for each array.
+void ComputeArrayLifespans(
+ const Model& model,
+ std::unordered_map<string, ArrayLifespan>* array_lifespans) {
+ CHECK(array_lifespans->empty());
+ for (const auto& rnn_state : model.flags.rnn_states()) {
+ ArrayLifespan lifespan;
+ lifespan.persistent = true;
+ (*array_lifespans)[rnn_state.state_array()] = lifespan;
+ }
+ for (std::size_t op_index = 0; op_index < model.operators.size();
+ op_index++) {
+ const auto& op = model.operators[op_index];
+ for (const auto& input : op->inputs) {
+ UpdateArrayLifespan(input, op_index, array_lifespans);
+ }
+ for (const auto& output : op->outputs) {
+ UpdateArrayLifespan(output, op_index, array_lifespans);
+ }
+ }
+}
+
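+// Allocs compare equal when they have the same start offset; the CHECK below
+// asserts that two live allocations never share a start offset without also
+// sharing the same end offset.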
+inline bool operator==(const Alloc& a, const Alloc& b) {
+ CHECK(a.start != b.start || a.end == b.end);
+ return a.start == b.start;
+}
+
+// Helper to keep track of total allocation size and of currently live
+// allocations, and containing the core allocation routine.
+class Allocator {
+ public:
+ Allocator() : total_size_(0) {}
+
+ // Core allocation routine.
+ void Allocate(std::size_t size, Alloc* result) {
+    // Naive first-fit algorithm: pick the first gap between live allocations
+    // that is wide enough for the new array.
+ std::size_t pos = 0;
+ for (const auto& a : live_allocs_) {
+ if (a.start >= pos + size) {
+ result->start = pos;
+ result->end = pos + size;
+ live_allocs_.insert(*result);
+ return;
+ }
+ pos = a.end;
+ }
+ // No sufficiently wide gap was found before an existing live allocation,
+ // so we allocate the new array at the end of the allocation space.
+ // We may then have to grow total_size_.
+ total_size_ = std::max(total_size_, pos + size);
+ result->start = pos;
+ result->end = pos + size;
+ live_allocs_.insert(*result);
+ }
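+  // Illustrative example: with live allocations [0, 64) and [128, 192),
+  // Allocate(64, &a) fills the gap and yields [64, 128), while a subsequent
+  // Allocate(128, &a) finds no wide-enough gap and is placed at [192, 320),
+  // growing total_size_ accordingly.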
+
+ void Deallocate(const Alloc& a) {
+ auto iter = std::lower_bound(live_allocs_.begin(), live_allocs_.end(), a);
+ CHECK(iter != live_allocs_.end());
+ CHECK(*iter == a);
+ live_allocs_.erase(iter);
+ }
+
+ std::size_t total_size() const { return total_size_; }
+
+ private:
+ std::size_t total_size_;
+ std::set<Alloc> live_allocs_;
+};
+
+// Returns the required transient allocation size (in bytes) for a given array,
+// or 0 if it's not a transient array.
+std::size_t TransientArraySize(const Model& model, const string& array_name,
+ std::size_t transient_data_alignment) {
+ if (!IsAllocatableTransientArray(model, array_name)) {
+ return 0;
+ }
+ const auto& array = model.arrays.at(array_name);
+ CHECK(array->has_shape())
+ << "Array '" << array_name << "' doesn't have a shape";
+ if (array->data_type == ArrayDataType::kNone) {
+ // Catch a typical issue at the moment with RNN states
+ for (const auto& rnn_state : model.flags.rnn_states()) {
+ if (rnn_state.state_array() == array_name) {
+ LOG(FATAL)
+ << "A RNN state array, " << array_name << ", still does not "
+ << "have a known data type after all graph transformations have "
+ << "run. That's mostly a toco bug --- sorry. For now, you can "
+ << "work around this issue by adding manually_create:true in the "
+ << "--rnn_state description of this RNN state.";
+ }
+ }
+ LOG(FATAL) << "An array, " << array_name << ", still does not "
+ << "have a known data type after all graph transformations have "
+ << "run.";
+ }
+ const std::size_t elem_size = ElementSize(array->data_type);
+ const std::size_t raw_size =
+ elem_size * RequiredBufferSizeForShape(array->shape());
+ const std::size_t rounded_size =
+ RoundUpToNextMultipleOf(raw_size, transient_data_alignment);
+ return rounded_size;
+}
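+// For example (illustrative): a float array of shape [1, 10] occupies 40 raw
+// bytes and is rounded up to 64 bytes under the default 64-byte alignment.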
+
+// Allocates an array: call this for every array just before the first
+// op where it is used.
+void AllocateTransientArray(const Model& model, const string& array_name,
+ Allocator* allocator,
+ std::size_t transient_data_alignment) {
+ if (!IsAllocatableTransientArray(model, array_name)) {
+ return;
+ }
+ const std::size_t size =
+ TransientArraySize(model, array_name, transient_data_alignment);
+ const auto& array = model.arrays.at(array_name);
+ CHECK(!array->alloc);
+ allocator->Allocate(size, &array->GetOrCreateAlloc());
+}
+
+// Deallocates an array: call this for every array just after the last
+// op where it is used.
+void DeallocateTransientArray(const Model& model, const string& array_name,
+ Allocator* allocator) {
+ if (!IsAllocatableTransientArray(model, array_name)) {
+ return;
+ }
+ const auto& array = model.arrays.at(array_name);
+ CHECK(!!array->alloc);
+ allocator->Deallocate(*array->alloc);
+}
+
+} // namespace
+
+void AllocateTransientArrays(Model* model,
+ std::size_t transient_data_alignment) {
+ // Precompute the lifespans for all arrays.
+ std::unordered_map<string, ArrayLifespan> array_lifespans;
+ ComputeArrayLifespans(*model, &array_lifespans);
+
+ // In case of variable batch, our convention will be to compute the
+ // allocations for batch==1, then let the inference code multiply all
+ // the offsets by the actual runtime batch size. Conveniently,
+ // the variable_batch and batch flags are mutually exclusive, and the default
+ // value of batch is 1, so we have nothing special to do here. Let us
+ // just guard this assumption with a CHECK:
+ bool batchless_input_shapes = true;
+ for (const auto& input_array : model->flags.input_arrays()) {
+ if (input_array.shape().empty() || input_array.shape(0) != 1) {
+ batchless_input_shapes = false;
+ break;
+ }
+ }
+ CHECK(!model->flags.variable_batch() || batchless_input_shapes);
+
+ Allocator allocator;
+
+ // Construct a sorted map of array names, so that other layout engines can
+ // match exactly.
+ std::map<string, const Array*> ordered_arrays_map;
+ for (const auto& pair : model->arrays) {
+ ordered_arrays_map[pair.first] = pair.second.get();
+ }
+
+ // Allocate persistent arrays (like RNN states). For them, 'transient'
+  // is a misnomer; it should read 'workspace'.
+ for (const auto& array_pair : ordered_arrays_map) {
+ const string& array_name = array_pair.first;
+ const auto& array_lifespan = array_lifespans.find(array_name)->second;
+ if (array_lifespan.persistent) {
+ AllocateTransientArray(*model, array_name, &allocator,
+ transient_data_alignment);
+ }
+ }
+
+ for (std::size_t op_index = 0; op_index < model->operators.size();
+ op_index++) {
+ const auto& op = model->operators[op_index];
+ // Allocate those arrays whose lifespan starts exactly here.
+ for (const auto& input : op->inputs) {
+ if (StartsAt(array_lifespans[input], op_index)) {
+ AllocateTransientArray(*model, input, &allocator,
+ transient_data_alignment);
+ }
+ }
+ for (const auto& output : op->outputs) {
+ if (StartsAt(array_lifespans[output], op_index)) {
+ AllocateTransientArray(*model, output, &allocator,
+ transient_data_alignment);
+ }
+ }
+ // Deallocate those arrays whose lifespan ends exactly here.
+ for (const auto& input : op->inputs) {
+ if (EndsAt(array_lifespans[input], op_index)) {
+ DeallocateTransientArray(*model, input, &allocator);
+ }
+ }
+ for (const auto& output : op->outputs) {
+ if (EndsAt(array_lifespans[output], op_index)) {
+ DeallocateTransientArray(*model, output, &allocator);
+ }
+ }
+ }
+
+  // Out of curiosity (this is not used in the actual allocation process),
+  // evaluate the optimal total allocated size.
+ // First, compute the size of persistent arrays.
+ std::size_t optimal_transient_alloc_size = 0;
+ std::size_t persistent_alloc_size = 0;
+ for (const auto& array_pair : ordered_arrays_map) {
+ const string& array_name = array_pair.first;
+ const auto& array_lifespan = array_lifespans.find(array_name)->second;
+ if (array_lifespan.persistent) {
+ persistent_alloc_size +=
+ TransientArraySize(*model, array_name, transient_data_alignment);
+ }
+ }
+ for (const auto& op : model->operators) {
+    // For each operator, compute the sum of the sizes of the arrays that must
+    // be live during the execution of this operator, plus the size of the
+    // persistent arrays that must be live at all times.
+ std::size_t size = persistent_alloc_size;
+ for (const auto& input : op->inputs) {
+ if (!array_lifespans[input].persistent) {
+ size += TransientArraySize(*model, input, transient_data_alignment);
+ }
+ }
+ for (const auto& output : op->outputs) {
+ if (!array_lifespans[output].persistent) {
+ size += TransientArraySize(*model, output, transient_data_alignment);
+ }
+ }
+ // The optimal total size is the maximum of all operator-specific sizes.
+ optimal_transient_alloc_size = std::max(optimal_transient_alloc_size, size);
+ }
+
+ model->transient_data_size = allocator.total_size();
+ model->transient_data_alignment = transient_data_alignment;
+ CHECK_GE(model->transient_data_size, optimal_transient_alloc_size);
+ LOG(INFO) << "Total transient array allocated size: "
+ << model->transient_data_size << " bytes, "
+ << "theoretical optimal value: " << optimal_transient_alloc_size
+ << " bytes.";
+ CheckInvariants(*model);
+}
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/allocate_transient_arrays.h b/tensorflow/contrib/lite/toco/allocate_transient_arrays.h
new file mode 100644
index 0000000000..12d0d0498f
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/allocate_transient_arrays.h
@@ -0,0 +1,44 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_ALLOCATE_TRANSIENT_ARRAYS_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_ALLOCATE_TRANSIENT_ARRAYS_H_
+
+#include "tensorflow/contrib/lite/toco/model.h"
+
+namespace toco {
+
+// We align the allocated sizes to the next multiple of a cache line,
+// to get simple performance characteristics without side effects of
+// accesses to one buffer on accesses to another buffer.
+// That also takes care of data type alignment for any reasonable type
+// (no reasonable data type should have alignment greater than a cache line).
+// Here we make CPU-centric assumptions, in particular, we assume 64-byte cache
+// lines. Getting this wrong by a factor of 2x (if this ever changes) wouldn't
+// be terrible.
+// Embedded architectures may use a different value for alignment.
+constexpr std::size_t kDefaultTransientDataAlignment = 64;
+
+// Rounds up dividend to a value divisible by divisor.
+inline std::size_t RoundUpToNextMultipleOf(std::size_t dividend,
+ std::size_t divisor) {
+ return ((dividend + divisor - 1) / divisor) * divisor;
+}
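+// For example, RoundUpToNextMultipleOf(40, 64) == 64 and
+// RoundUpToNextMultipleOf(128, 64) == 128.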
+
+void AllocateTransientArrays(Model* model,
+ std::size_t transient_data_alignment);
+
+} // namespace toco
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_ALLOCATE_TRANSIENT_ARRAYS_H_
diff --git a/tensorflow/contrib/lite/toco/args.h b/tensorflow/contrib/lite/toco/args.h
new file mode 100644
index 0000000000..28661d4ff0
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/args.h
@@ -0,0 +1,225 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// This abstracts command line arguments in toco.
+// Arg<T> is a parseable type that can register a default value, parse itself,
+// and keep track of whether it was specified.
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_ARGS_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_ARGS_H_
+
+#include <functional>
+#include <unordered_map>
+#include <vector>
+#include "absl/strings/numbers.h"
+#include "absl/strings/str_split.h"
+#include "tensorflow/contrib/lite/toco/toco_port.h"
+#include "tensorflow/contrib/lite/toco/toco_types.h"
+
+namespace toco {
+
+// Since std::vector<int32> is in the std namespace, and we are not allowed
+// to add ParseFlag/UnparseFlag to std, we introduce a simple wrapper type
+// to use as the flag type:
+struct IntList {
+ std::vector<int32> elements;
+};
+struct StringMapList {
+ std::vector<std::unordered_map<string, string>> elements;
+};
+
+// command_line_flags.h doesn't track whether or not a flag is specified. Arg
+// contains the value (which will be the default if not specified) and also
+// whether the flag is specified.
+// TODO(aselle): consider putting doc string and ability to construct the
+// tensorflow argument into this, so declaration of parameters can be less
+// distributed.
+// Every template specialization of Arg is required to implement
+// default_value(), specified(), value(), parse(), bind().
+template <class T>
+class Arg final {
+ public:
+ explicit Arg(T default_ = T()) : value_(default_) {}
+ virtual ~Arg() {}
+
+  // Returns the default value, used when registering the flag.
+  T default_value() const { return value_; }
+  // Returns true if the argument was specified on the command line.
+ bool specified() const { return specified_; }
+ // Const reference to parsed value.
+ const T& value() const { return value_; }
+
+ // Parsing callback for the tensorflow::Flags code
+ bool parse(T value_in) {
+ value_ = value_in;
+ specified_ = true;
+ return true;
+ }
+
+ // Bind the parse member function so tensorflow::Flags can call it.
+ std::function<bool(T)> bind() {
+ return std::bind(&Arg::parse, this, std::placeholders::_1);
+ }
+
+ private:
+ // Becomes true after parsing if the value was specified
+ bool specified_ = false;
+ // Value of the argument (initialized to the default in the constructor).
+ T value_;
+};
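+// Typical registration sketch (the exact tensorflow::Flag overloads used by
+// the cmdline-flags code may differ):
+//   Arg<string> output_file;
+//   std::vector<tensorflow::Flag> flags = {tensorflow::Flag(
+//       "output_file", output_file.bind(), output_file.default_value(),
+//       "Path to the output file.")};
+//   tensorflow::Flags::Parse(&argc, argv, flags);
+//   if (output_file.specified()) { /* use output_file.value() */ }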
+
+template <>
+class Arg<toco::IntList> final {
+ public:
+  // Returns the default value, used when registering the flag.
+  string default_value() const { return ""; }
+  // Returns true if the argument was specified on the command line.
+  bool specified() const { return specified_; }
+  // Parsing callback for the tensorflow::Flags code.
+ bool parse(string text) {
+ parsed_value_.elements.clear();
+ specified_ = true;
+ // strings::Split("") produces {""}, but we need {} on empty input.
+    // TODO(aselle): Moved this from elsewhere, but ahentz recommends we could
+    // use absl::SplitLeadingDec32Values(text.c_str(), &parsed_value_.elements)
+ if (!text.empty()) {
+ int32 element;
+ for (absl::string_view part : absl::StrSplit(text, ',')) {
+ if (!SimpleAtoi(part, &element)) return false;
+ parsed_value_.elements.push_back(element);
+ }
+ }
+ return true;
+ }
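+  // For example, parse("1,224,224,3") yields elements == {1, 224, 224, 3},
+  // parse("") yields an empty list, and parse("1,x") returns false.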
+
+ std::function<bool(string)> bind() {
+ return std::bind(&Arg::parse, this, std::placeholders::_1);
+ }
+
+ const toco::IntList& value() const { return parsed_value_; }
+
+ private:
+ toco::IntList parsed_value_;
+ bool specified_ = false;
+};
+
+template <>
+class Arg<toco::StringMapList> final {
+ public:
+  // Returns the default value, used when registering the flag.
+  string default_value() const { return ""; }
+  // Returns true if the argument was specified on the command line.
+  bool specified() const { return specified_; }
+  // Parsing callback for the tensorflow::Flags code.
+
+ bool parse(string text) {
+ parsed_value_.elements.clear();
+ specified_ = true;
+
+ if (text.empty()) {
+ return true;
+ }
+
+#if defined(PLATFORM_GOOGLE)
+ std::vector<absl::string_view> outer_vector;
+ absl::string_view text_disposable_copy = text;
+ SplitStructuredLine(text_disposable_copy, ',', "{}", &outer_vector);
+ for (const absl::string_view& outer_member_stringpiece : outer_vector) {
+ string outer_member(outer_member_stringpiece);
+ if (outer_member.empty()) {
+ continue;
+ }
+ string outer_member_copy = outer_member;
+ absl::StripAsciiWhitespace(&outer_member);
+ if (!TryStripPrefixString(outer_member, "{", &outer_member)) return false;
+ if (!TryStripSuffixString(outer_member, "}", &outer_member)) return false;
+ const std::vector<string> inner_fields_vector =
+ strings::Split(outer_member, ',');
+
+ std::unordered_map<string, string> element;
+ for (const string& member_field : inner_fields_vector) {
+ std::vector<string> outer_member_key_value =
+ strings::Split(member_field, ':');
+ if (outer_member_key_value.size() != 2) return false;
+ string& key = outer_member_key_value[0];
+ string& value = outer_member_key_value[1];
+ absl::StripAsciiWhitespace(&key);
+ absl::StripAsciiWhitespace(&value);
+ if (element.count(key) != 0) return false;
+ element[key] = value;
+ }
+ parsed_value_.elements.push_back(element);
+ }
+ return true;
+#else
+ // TODO(aselle): Fix argument parsing when absl supports structuredline
+ fprintf(stderr, "%s:%d StringMapList arguments not supported\n", __FILE__,
+ __LINE__);
+ abort();
+#endif
+ }
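+  // For example (when SplitStructuredLine is available), parsing
+  // "{state_array:s,back_edge_source_array:b}" yields a single map
+  // {{"state_array", "s"}, {"back_edge_source_array", "b"}}.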
+
+ std::function<bool(string)> bind() {
+ return std::bind(&Arg::parse, this, std::placeholders::_1);
+ }
+
+ const toco::StringMapList& value() const { return parsed_value_; }
+
+ private:
+ toco::StringMapList parsed_value_;
+ bool specified_ = false;
+};
+
+// Flags that describe a model. See model_cmdline_flags.cc for details.
+struct ParsedModelFlags {
+ Arg<string> input_array;
+ Arg<string> input_arrays;
+ Arg<string> output_array;
+ Arg<string> output_arrays;
+ Arg<string> input_shapes;
+ Arg<float> mean_value = Arg<float>(0.f);
+ Arg<string> mean_values;
+ Arg<float> std_value = Arg<float>(1.f);
+ Arg<string> std_values;
+ Arg<bool> variable_batch = Arg<bool>(false);
+ Arg<bool> drop_control_dependency = Arg<bool>(false);
+ Arg<toco::IntList> input_shape;
+ Arg<toco::StringMapList> rnn_states;
+ Arg<toco::StringMapList> model_checks;
+ // Debugging output options
+ Arg<string> graphviz_first_array;
+ Arg<string> graphviz_last_array;
+ Arg<string> dump_graphviz;
+ Arg<bool> dump_graphviz_video = Arg<bool>(false);
+};
+
+// Flags that describe the operation you would like to do (what conversion
+// you want). See toco_cmdline_flags.cc for details.
+struct ParsedTocoFlags {
+ Arg<string> input_file;
+ Arg<string> output_file;
+ Arg<string> input_format;
+ Arg<string> output_format;
+ // TODO(aselle): command_line_flags doesn't support doubles
+ Arg<float> default_ranges_min = Arg<float>(0.);
+ Arg<float> default_ranges_max = Arg<float>(0.);
+ Arg<string> input_type;
+ Arg<string> input_types;
+ Arg<string> inference_type;
+ Arg<bool> drop_fake_quant = Arg<bool>(false);
+ Arg<bool> reorder_across_fake_quant = Arg<bool>(false);
+ Arg<bool> allow_custom_ops = Arg<bool>(false);
+};
+
+} // namespace toco
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_ARGS_H_
diff --git a/tensorflow/contrib/lite/toco/dump_graphviz.cc b/tensorflow/contrib/lite/toco/dump_graphviz.cc
new file mode 100644
index 0000000000..f5e2868dc0
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/dump_graphviz.cc
@@ -0,0 +1,293 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/toco/dump_graphviz.h"
+
+#include <memory>
+#include <set>
+#include <unordered_set>
+#include <vector>
+
+#include "absl/strings/str_replace.h"
+#include "tensorflow/contrib/lite/toco/model_flags.pb.h"
+#include "tensorflow/contrib/lite/toco/toco_graphviz_dump_options.h"
+#include "tensorflow/contrib/lite/toco/toco_port.h"
+#include "tensorflow/contrib/lite/toco/toco_types.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+using toco::port::AppendF;
+using toco::port::StringF;
+
+namespace toco {
+namespace {
+
+class Color {
+ public:
+ Color() {}
+ Color(uint8 r, uint8 g, uint8 b) : r_(r), g_(g), b_(b) {}
+ // Returns the string serialization of this color in graphviz format,
+ // for use as 'fillcolor' in boxes.
+ string FillColorString() const { return StringF("%.2X%.2X%.2X", r_, g_, b_); }
+  // Returns the serialization in graphviz format of a suitable color to use
+  // as 'fontcolor' in the same boxes. It should be black or white, whichever
+  // offers the better contrast against FillColorString().
+ string TextColorString() const {
+ // https://en.wikipedia.org/wiki/Relative_luminance
+ const float luminance = 0.2126f * r_ + 0.7152f * g_ + 0.0722f * b_;
+ const uint8 l = luminance > 128.f ? 0 : 255;
+ return StringF("%.2X%.2X%.2X", l, l, l);
+ }
+
+ private:
+ uint8 r_ = 0, g_ = 0, b_ = 0;
+};
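+// For example (illustrative): Color(0xC5, 0x39, 0x29) has a relative luminance
+// of about 85.6 < 128, so TextColorString() returns "FFFFFF" (white text on a
+// dark box), while Color(0xF5, 0xF5, 0xF5) yields "000000" (black text).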
+
+struct NodeProperties {
+ // The text to display inside the box for this node.
+ string label;
+ // The color to use for this node; will be used as 'fillcolor'
+ // for its box. See Color::FillColorString. A suitable, different
+ // color will be chosen for the 'fontcolor' for the inside text
+ // label, see Color::TextColorString.
+ Color color;
+};
+
+// All colors in this file are from:
+// https://material.io/guidelines/style/color.html
+
+Color GetColorForArray(const Model& model, const string& array_name) {
+ // Arrays involved in RNN back-edges have a different color
+ for (const auto& rnn_state : model.flags.rnn_states()) {
+ // RNN state, fed by a back-edge. Bold color.
+ if (array_name == rnn_state.state_array()) {
+ return Color(0x0F, 0x9D, 0x58);
+ }
+    // RNN back-edge source, feeding an RNN state.
+ // Light tone of the same color as RNN states.
+ if (array_name == rnn_state.back_edge_source_array()) {
+ return Color(0xB7, 0xE1, 0xCD);
+ }
+ }
+ // Constant parameter arrays have their own bold color
+ if (model.GetArray(array_name).buffer) {
+ return Color(0x42, 0x85, 0xF4);
+ }
+ // Remaining arrays are activations.
+ // We use gray colors for them because they are the majority
+ // of arrays so we want to highlight other arrays instead of them.
+ // First, we use a bolder gray for input/output arrays:
+ const auto& dump_options = *GraphVizDumpOptions::singleton();
+ if (IsInputArray(model, array_name) ||
+ array_name == dump_options.graphviz_first_array ||
+ array_name == dump_options.graphviz_last_array) {
+ return Color(0x9E, 0x9E, 0x9E);
+ }
+ for (const string& output_array : model.flags.output_arrays()) {
+ if (array_name == output_array) {
+ return Color(0x9E, 0x9E, 0x9E);
+ }
+ }
+ // Remaining arrays are intermediate activation arrays.
+  // Lighter tone of the same gray as for input/output arrays:
+  // we want these to be visually discreet.
+ return Color(0xF5, 0xF5, 0xF5);
+}
+
+NodeProperties GetPropertiesForArray(const Model& model,
+ const string& array_name) {
+ NodeProperties node_properties;
+ node_properties.color = GetColorForArray(model, array_name);
+ node_properties.label = absl::StrReplaceAll(array_name, {{"/", "/\\n"}});
+
+ // Append array shape to the label.
+ auto& array = model.GetArray(array_name);
+
+ if (array.data_type == ArrayDataType::kFloat) {
+ AppendF(&node_properties.label, "\\nType: float");
+ } else if (array.data_type == ArrayDataType::kInt32) {
+ AppendF(&node_properties.label, "\\nType: int32");
+ } else if (array.data_type == ArrayDataType::kUint8) {
+ AppendF(&node_properties.label, "\\nType: uint8");
+ }
+
+ if (array.has_shape()) {
+ auto& array_shape = array.shape();
+ node_properties.label += "\\n[";
+ for (int id = 0; id < array_shape.dimensions_count(); id++) {
+ if (id == 0) {
+ AppendF(&node_properties.label, "%d", array_shape.dims(id));
+ } else {
+ AppendF(&node_properties.label, "x%d", array_shape.dims(id));
+ }
+ }
+ node_properties.label += "]";
+ }
+
+ if (array.minmax) {
+ AppendF(&node_properties.label, "\\nMinMax: [%.3g, %.3g]",
+ array.minmax->min, array.minmax->max);
+ }
+
+ if (array.quantization_params) {
+ AppendF(&node_properties.label, "\\nQuantization: %.3g * (x - %d)",
+ array.quantization_params->scale,
+ array.quantization_params->zero_point);
+ }
+
+ if (array.alloc) {
+ AppendF(&node_properties.label, "\\nTransient Alloc: [%d, %d)",
+ array.alloc->start, array.alloc->end);
+ }
+
+ return node_properties;
+}
+
+NodeProperties GetPropertiesForOperator(const Operator& op) {
+ NodeProperties node_properties;
+ if (op.type == OperatorType::kTensorFlowUnsupported) {
+ node_properties.label =
+ static_cast<const TensorFlowUnsupportedOperator&>(op).tensorflow_op;
+ } else {
+ node_properties.label = OperatorTypeName(op.type);
+ }
+ // Additional information for some of the operators.
+ switch (op.type) {
+ case OperatorType::kConv: {
+ const auto& conv_op = static_cast<const ConvOperator&>(op);
+ node_properties.color = Color(0xC5, 0x39, 0x29); // Bolder color
+ AppendF(&node_properties.label, "\\n%dx%d/%s", conv_op.stride_width,
+ conv_op.stride_height,
+ conv_op.padding.type == PaddingType::kSame ? "S" : "V");
+ break;
+ }
+ case OperatorType::kDepthwiseConv: {
+ const auto& conv_op = static_cast<const DepthwiseConvOperator&>(op);
+ node_properties.color = Color(0xC5, 0x39, 0x29); // Bolder color
+ AppendF(&node_properties.label, "\\n%dx%d/%s", conv_op.stride_width,
+ conv_op.stride_height,
+ conv_op.padding.type == PaddingType::kSame ? "S" : "V");
+ break;
+ }
+ case OperatorType::kFullyConnected: {
+ node_properties.color = Color(0xC5, 0x39, 0x29); // Bolder color
+ break;
+ }
+ default:
+ node_properties.color = Color(0xDB, 0x44, 0x37);
+ break;
+ }
+
+ return node_properties;
+}
+
+std::vector<const Operator*> OperatorsToDump(const Model& model) {
+ const auto& dump_options = *GraphVizDumpOptions::singleton();
+ bool first_specified = !dump_options.graphviz_first_array.empty();
+ bool last_specified = !dump_options.graphviz_last_array.empty();
+ CHECK_EQ(first_specified, last_specified);
+ std::vector<const Operator*> ops_to_dump;
+ if (last_specified) {
+ // Return only the part of the graph between graphviz_first_array
+ // and graphviz_last_array.
+ CHECK(model.arrays.count(dump_options.graphviz_first_array));
+ CHECK(model.arrays.count(dump_options.graphviz_last_array));
+ std::unordered_set<string> arrays_already_produced;
+ std::vector<string> arrays_to_produce;
+ arrays_to_produce.push_back(dump_options.graphviz_last_array);
+ while (!arrays_to_produce.empty()) {
+ const string array = arrays_to_produce.back();
+ arrays_to_produce.pop_back();
+ CHECK(!arrays_already_produced.count(array));
+ arrays_already_produced.insert(array);
+ const Operator* op = GetOpWithOutput(model, array);
+ if (!op) {
+ continue;
+ }
+ ops_to_dump.push_back(op);
+ for (const string& input : op->inputs) {
+ if (arrays_already_produced.count(input) ||
+ input == dump_options.graphviz_first_array) {
+ continue;
+ }
+ arrays_to_produce.push_back(input);
+ }
+ }
+ } else {
+ // Return the whole graph.
+ for (const auto& op : model.operators) {
+ ops_to_dump.push_back(op.get());
+ }
+ }
+ return ops_to_dump;
+}
+
+} // namespace
+
+void DumpGraphviz(const Model& model, string* output_file_contents) {
+ AppendF(output_file_contents, "digraph Computegraph {\n");
+
+ constexpr char kNodeFormat[] =
+ "\t \"%s\" [label=\"%s\", shape=%s, style=filled, fillcolor=\"#%s\", "
+ "fontcolor = \"#%sDD\"];\n";
+
+ constexpr char kEdgeFormat[] = "\t \"%s\" -> \"%s\";\n";
+
+ constexpr char kRNNBackEdgeFormat[] =
+ "\t \"%s\" -> \"%s\" [color=\"#0F9D58\"];\n";
+
+ std::vector<const Operator*> ops_to_dump = OperatorsToDump(model);
+ std::set<string> already_added_arrays;
+ for (int op_index = 0; op_index < ops_to_dump.size(); op_index++) {
+ const Operator& op = *ops_to_dump[op_index];
+ // Add node for operator.
+ auto op_properties = GetPropertiesForOperator(op);
+ string operator_id = StringF("op%05d", op_index);
+ AppendF(output_file_contents, kNodeFormat, operator_id, op_properties.label,
+ "box", op_properties.color.FillColorString().c_str(),
+ op_properties.color.TextColorString().c_str());
+ // Add nodes and edges for all inputs of the operator.
+ for (const auto& input : op.inputs) {
+ auto array_properties = GetPropertiesForArray(model, input);
+ if (!already_added_arrays.count(input)) {
+ AppendF(output_file_contents, kNodeFormat, input,
+ array_properties.label, "octagon",
+ array_properties.color.FillColorString().c_str(),
+ array_properties.color.TextColorString().c_str());
+ }
+ AppendF(output_file_contents, kEdgeFormat, input, operator_id);
+ already_added_arrays.insert(input);
+ }
+ // Add nodes and edges for all outputs of the operator.
+ for (const auto& output : op.outputs) {
+ auto array_properties = GetPropertiesForArray(model, output);
+ if (!already_added_arrays.count(output)) {
+ AppendF(output_file_contents, kNodeFormat, output,
+ array_properties.label, "octagon",
+ array_properties.color.FillColorString().c_str(),
+ array_properties.color.TextColorString().c_str());
+ }
+ AppendF(output_file_contents, kEdgeFormat, operator_id, output);
+ already_added_arrays.insert(output);
+ }
+ }
+
+ for (const auto& rnn_state : model.flags.rnn_states()) {
+ AppendF(output_file_contents, kRNNBackEdgeFormat,
+ rnn_state.back_edge_source_array(), rnn_state.state_array());
+ }
+
+ AppendF(output_file_contents, "}\n");
+}
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/dump_graphviz.h b/tensorflow/contrib/lite/toco/dump_graphviz.h
new file mode 100644
index 0000000000..0fb28e3de8
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/dump_graphviz.h
@@ -0,0 +1,28 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_DUMP_GRAPHVIZ_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_DUMP_GRAPHVIZ_H_
+
+#include <string>
+
+#include "tensorflow/contrib/lite/toco/model.h"
+
+namespace toco {
+
+void DumpGraphviz(const Model& model, string* output_file_contents);
+
+} // namespace toco
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_DUMP_GRAPHVIZ_H_
diff --git a/tensorflow/contrib/lite/toco/export_tensorflow.cc b/tensorflow/contrib/lite/toco/export_tensorflow.cc
new file mode 100644
index 0000000000..16b9fa2260
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/export_tensorflow.cc
@@ -0,0 +1,1570 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <algorithm>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "google/protobuf/map.h"
+#include "google/protobuf/text_format.h"
+#include "absl/memory/memory.h"
+#include "absl/strings/string_view.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/model_flags.pb.h"
+#include "tensorflow/contrib/lite/toco/runtime/types.h"
+#include "tensorflow/contrib/lite/toco/tensorflow_util.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/tensor.pb.h"
+#include "tensorflow/core/framework/tensor_shape.pb.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/platform/logging.h"
+
+using tensorflow::DT_FLOAT;
+using tensorflow::DT_INT32;
+using tensorflow::GraphDef;
+using tensorflow::TensorProto;
+
+namespace toco {
+namespace {
+
+// TensorFlow sometimes forbids what it calls "legacy scalars",
+// which are 1-D shapes where the unique shape size is 1.
+// See OpKernel::IsLegacyScalar and OpKernel::allow_legacy_scalars.
+// For that reason, we generally avoid creating legacy scalars,
+// by detecting the case where a 1-D shape would be of size 1 and
+// replacing that by a 0-D shape.
+// However, there is a special circumstance where we must not do that
+// and must unconditionally create a 1-D shape even if it is going to
+// be of size 1: that is the case of bias vectors, with BiasAdd nodes.
+// Indeed, TensorFlow requires bias vectors to be 1-D; in the case of
+// a depth of 1, that would be a legacy scalar, so in that case we
+// must go ahead and keep the shape 1-D, letting it be a legacy scalar.
+enum class LegacyScalarPolicy { kAvoidLegacyScalars, kDoCreateLegacyScalars };
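+// For example, a bias vector of depth 1 exported under kDoCreateLegacyScalars
+// keeps its 1-D shape [1], while under kAvoidLegacyScalars it would be
+// exported with a 0-D (scalar) shape.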
+
+void ExportFloatArray(const Shape& input_shape, const float* input_data,
+ TensorProto* output_tensor,
+ LegacyScalarPolicy legacy_scalar_policy) {
+ output_tensor->set_dtype(DT_FLOAT);
+ const int input_flat_size = RequiredBufferSizeForShape(input_shape);
+ auto* shape = output_tensor->mutable_tensor_shape();
+
+ const int kDims = input_shape.dimensions_count();
+ if (legacy_scalar_policy == LegacyScalarPolicy::kDoCreateLegacyScalars ||
+ kDims > 1 || (kDims == 1 && input_shape.dims(0) > 1)) {
+ for (int i = 0; i < kDims; ++i) {
+ shape->add_dim()->set_size(input_shape.dims(i));
+ }
+ }
+ output_tensor->set_tensor_content(
+ string(reinterpret_cast<const char*>(input_data),
+ sizeof(*input_data) * input_flat_size));
+}
+
+void ExportFloatArray(AxesOrder input_axes_order, const Shape& input_shape,
+ const float* input_data, AxesOrder output_axes_order,
+ TensorProto* output_tensor,
+ LegacyScalarPolicy legacy_scalar_policy) {
+ CHECK_EQ(AxesCount(output_axes_order), AxesCount(input_axes_order));
+ output_tensor->set_dtype(DT_FLOAT);
+ CHECK_EQ(input_shape.dimensions_count(), AxesCount(input_axes_order));
+ const int input_flat_size = RequiredBufferSizeForShape(input_shape);
+
+ Shape shuffled_shape;
+ ShuffleDims(input_shape, input_axes_order, output_axes_order,
+ &shuffled_shape);
+ std::vector<float> shuffled_data(input_flat_size);
+ ShuffleArray(input_shape, input_axes_order, output_axes_order, shuffled_shape,
+ input_data, shuffled_data.data());
+
+ ExportFloatArray(shuffled_shape, shuffled_data.data(), output_tensor,
+ legacy_scalar_policy);
+}
+
+bool HasAlreadyExportedConst(const string& name,
+ const GraphDef& tensorflow_graph) {
+ for (const auto& node : tensorflow_graph.node()) {
+ if (node.op() == "Const" && node.name() == name) {
+ return true;
+ }
+ }
+ return false;
+}
+
+void ConvertFloatTensorConst(const string& name, const Shape& input_shape,
+ const float* input_data,
+ AxesOrder input_axes_order,
+ AxesOrder output_axes_order,
+ GraphDef* tensorflow_graph,
+ LegacyScalarPolicy legacy_scalar_policy) {
+ if (HasAlreadyExportedConst(name, *tensorflow_graph)) {
+ return;
+ }
+ auto* const_op = tensorflow_graph->add_node();
+ const_op->set_op("Const");
+ const_op->set_name(name);
+ (*const_op->mutable_attr())["dtype"].set_type(DT_FLOAT);
+ auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor();
+ ExportFloatArray(input_axes_order, input_shape, input_data, output_axes_order,
+ tensor, legacy_scalar_policy);
+}
+
+void ConvertFloatTensorConst(const string& name, const Shape& input_shape,
+ const float* input_data,
+ AxesOrder input_axes_order,
+ AxesOrder output_axes_order,
+ GraphDef* tensorflow_graph) {
+ if (HasAlreadyExportedConst(name, *tensorflow_graph)) {
+ return;
+ }
+ auto* const_op = tensorflow_graph->add_node();
+ const_op->set_op("Const");
+ const_op->set_name(name);
+ (*const_op->mutable_attr())["dtype"].set_type(DT_FLOAT);
+ auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor();
+ ExportFloatArray(input_axes_order, input_shape, input_data, output_axes_order,
+ tensor, LegacyScalarPolicy::kAvoidLegacyScalars);
+}
+
+void ConvertFloatTensorConst(const Model& model, const string& name,
+ AxesOrder input_axes_order,
+ AxesOrder output_axes_order,
+ GraphDef* tensorflow_graph) {
+ if (HasAlreadyExportedConst(name, *tensorflow_graph)) {
+ return;
+ }
+ auto* const_op = tensorflow_graph->add_node();
+ const_op->set_op("Const");
+ const_op->set_name(name);
+ (*const_op->mutable_attr())["dtype"].set_type(DT_FLOAT);
+ auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor();
+ CHECK(model.arrays.count(name));
+ const auto& input_array = *model.arrays.at(name);
+ const auto& input_shape = input_array.shape();
+ CHECK(input_array.buffer);
+ CHECK(input_array.buffer->type == ArrayDataType::kFloat);
+ const float* input_data =
+ input_array.GetBuffer<ArrayDataType::kFloat>().data.data();
+ ExportFloatArray(input_axes_order, input_shape, input_data, output_axes_order,
+ tensor, LegacyScalarPolicy::kAvoidLegacyScalars);
+}
+
+void ConvertFloatTensorConst(const Model& model, const string& name,
+ GraphDef* tensorflow_graph) {
+ if (HasAlreadyExportedConst(name, *tensorflow_graph)) {
+ return;
+ }
+ auto* const_op = tensorflow_graph->add_node();
+ const_op->set_op("Const");
+ const_op->set_name(name);
+ (*const_op->mutable_attr())["dtype"].set_type(DT_FLOAT);
+ auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor();
+ CHECK(model.arrays.count(name));
+ const auto& input_array = *model.arrays.at(name);
+ const auto& input_shape = input_array.shape();
+ CHECK(input_array.buffer);
+ CHECK(input_array.buffer->type == ArrayDataType::kFloat);
+ const float* input_data =
+ input_array.GetBuffer<ArrayDataType::kFloat>().data.data();
+ ExportFloatArray(input_shape, input_data, tensor,
+ LegacyScalarPolicy::kAvoidLegacyScalars);
+}
+
+void ConvertIntTensorConst(const Model& model, const string& name,
+ GraphDef* tensorflow_graph) {
+ if (HasAlreadyExportedConst(name, *tensorflow_graph)) {
+ return;
+ }
+ CHECK(model.arrays.count(name));
+ const auto& array = *model.arrays.at(name);
+ auto* const_op = tensorflow_graph->add_node();
+ const_op->set_op("Const");
+ const_op->set_name(name);
+ (*const_op->mutable_attr())["dtype"].set_type(DT_INT32);
+ auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor();
+ tensor->set_dtype(DT_INT32);
+ const auto& data = array.GetBuffer<ArrayDataType::kInt32>().data;
+ for (auto index : data) {
+ tensor->add_int_val(index);
+ }
+ const auto& array_shape = array.shape();
+ auto* shape = tensor->mutable_tensor_shape();
+ for (int i = 0; i < array_shape.dimensions_count(); i++) {
+ shape->add_dim()->set_size(array_shape.dims(i));
+ }
+}
+
+void CreateMatrixShapeTensorConst(const string& name, int rows, int cols,
+ GraphDef* tensorflow_graph) {
+ if (HasAlreadyExportedConst(name, *tensorflow_graph)) {
+ return;
+ }
+ auto* const_op = tensorflow_graph->add_node();
+ const_op->set_op("Const");
+ const_op->set_name(name);
+ (*const_op->mutable_attr())["dtype"].set_type(DT_INT32);
+ auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor();
+ tensor->set_dtype(DT_INT32);
+ const int32 data[2] = {cols, rows};
+ tensor->set_tensor_content(
+ string(reinterpret_cast<const char*>(data), sizeof(data)));
+ auto* shape = tensor->mutable_tensor_shape();
+ shape->add_dim()->set_size(2);
+}
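+// Note that the tensor content is {cols, rows}; e.g. the FullyConnected
+// converter below passes cols == -1 to build a Reshape shape of [-1, rows].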
+
+void CreateDummyConcatDimTensorConst(const string& name, int dim,
+ GraphDef* tensorflow_graph) {
+ if (HasAlreadyExportedConst(name, *tensorflow_graph)) {
+ return;
+ }
+ auto* const_op = tensorflow_graph->add_node();
+ const_op->set_op("Const");
+ const_op->set_name(name);
+ (*const_op->mutable_attr())["dtype"].set_type(DT_INT32);
+ auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor();
+ tensor->set_dtype(DT_INT32);
+ tensor->add_int_val(dim);
+}
+
+void CreateReshapeShapeTensorConst(const string& name,
+ const std::vector<int32>& shape,
+ GraphDef* tensorflow_graph) {
+ if (HasAlreadyExportedConst(name, *tensorflow_graph)) {
+ return;
+ }
+ auto* const_op = tensorflow_graph->add_node();
+ const_op->set_op("Const");
+ const_op->set_name(name);
+ (*const_op->mutable_attr())["dtype"].set_type(DT_INT32);
+ auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor();
+ tensor->set_dtype(DT_INT32);
+ for (auto s : shape) {
+ tensor->add_int_val(s);
+ }
+ // TensorFlow sometimes forbids what it calls "legacy scalars",
+  // which are 1-D shapes where the unique shape size is 1.
+ // See OpKernel::IsLegacyScalar and OpKernel::allow_legacy_scalars.
+ if (shape.size() > 1) {
+ auto* tensor_shape = tensor->mutable_tensor_shape();
+ tensor_shape->add_dim()->set_size(shape.size());
+ }
+}
+
+string WalkUpToConstantArray(const Model& model, const string& name) {
+ const Array& original_array = model.GetArray(name);
+ if (original_array.buffer) {
+ return name;
+ }
+ const auto* op = GetOpWithOutput(model, name);
+ CHECK(op);
+ CHECK(op->type == OperatorType::kFakeQuant);
+ const string& input_of_fakequant_name = op->inputs[0];
+ const Array& input_of_fakequant = model.GetArray(input_of_fakequant_name);
+ CHECK(input_of_fakequant.buffer);
+ return input_of_fakequant_name;
+}
+
+void ConvertConvOperator(const Model& model, const ConvOperator& src_op,
+ GraphDef* tensorflow_graph) {
+ const bool has_bias = src_op.inputs.size() >= 3;
+ string conv_output = src_op.outputs[0];
+ if (has_bias) {
+ conv_output += "/conv";
+ }
+
+ auto* conv2d_op = tensorflow_graph->add_node();
+ conv2d_op->set_op("Conv2D");
+ conv2d_op->set_name(conv_output);
+ *conv2d_op->add_input() = src_op.inputs[0];
+ *conv2d_op->add_input() = src_op.inputs[1];
+ (*conv2d_op->mutable_attr())["T"].set_type(DT_FLOAT);
+ const string& weights_array_name =
+ WalkUpToConstantArray(model, src_op.inputs[1]);
+ const auto& weights_array = model.GetArray(weights_array_name);
+ CHECK(weights_array.buffer->type == ArrayDataType::kFloat);
+ ConvertFloatTensorConst(model, weights_array_name, AxesOrder::kOHWI,
+ AxesOrder::kHWIO, tensorflow_graph);
+ auto& strides = (*conv2d_op->mutable_attr())["strides"];
+ strides.mutable_list()->add_i(1);
+ strides.mutable_list()->add_i(src_op.stride_height);
+ strides.mutable_list()->add_i(src_op.stride_width);
+ strides.mutable_list()->add_i(1);
+ string padding;
+ if (src_op.padding.type == PaddingType::kSame) {
+ padding = "SAME";
+ } else if (src_op.padding.type == PaddingType::kValid) {
+ padding = "VALID";
+ } else {
+ LOG(FATAL) << "Bad padding (only SAME and VALID are supported)";
+ }
+ (*conv2d_op->mutable_attr())["padding"].set_s(padding);
+
+ if (has_bias) {
+ auto* biasadd_op = tensorflow_graph->add_node();
+ biasadd_op->set_op("BiasAdd");
+ biasadd_op->set_name(src_op.outputs[0]);
+ biasadd_op->add_input(conv_output);
+ biasadd_op->add_input(src_op.inputs[2]);
+ (*biasadd_op->mutable_attr())["T"].set_type(DT_FLOAT);
+ CHECK(model.arrays.count(src_op.inputs[2]));
+ const string& bias_array_name =
+ WalkUpToConstantArray(model, src_op.inputs[2]);
+ const auto& bias_array = model.GetArray(bias_array_name);
+ // TODO(b/62904716) Bias arrays should be 1-D, and used directly.
+ Shape bias_shape_1d = bias_array.shape();
+ UnextendShape(&bias_shape_1d, 1);
+ CHECK(bias_array.buffer->type == ArrayDataType::kFloat);
+ const float* bias_data =
+ bias_array.GetBuffer<ArrayDataType::kFloat>().data.data();
+ ConvertFloatTensorConst(bias_array_name, bias_shape_1d, bias_data,
+ AxesOrder::kOneAxis, AxesOrder::kOneAxis,
+ tensorflow_graph,
+ LegacyScalarPolicy::kDoCreateLegacyScalars);
+ }
+}
+
+void ConvertDepthwiseConvOperator(const Model& model,
+ const DepthwiseConvOperator& src_op,
+ GraphDef* tensorflow_graph) {
+ const bool has_bias = src_op.inputs.size() >= 3;
+ string conv_output = src_op.outputs[0];
+ if (has_bias) {
+ conv_output += "/conv";
+ }
+
+ auto* dc2d_op = tensorflow_graph->add_node();
+ dc2d_op->set_op("DepthwiseConv2dNative");
+ dc2d_op->set_name(conv_output);
+ *dc2d_op->add_input() = src_op.inputs[0];
+ *dc2d_op->add_input() = src_op.inputs[1];
+ (*dc2d_op->mutable_attr())["T"].set_type(DT_FLOAT);
+
+ // Our internal DepthwiseConv weights are 1 x H x W x OutputDepth.
+ // We need to convert that to H x W x InputDepth x Multiplier.
+ // That's only a matter of constructing a Dims object; the actual
+ // array layout is the same.
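+  // For example (illustrative): 1 x 3 x 3 x 32 weights with a
+  // depth_multiplier of 2 become 3 x 3 x 16 x 2.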
+ CHECK(model.arrays.count(src_op.inputs[1]));
+ const string& src_weights_name =
+ WalkUpToConstantArray(model, src_op.inputs[1]);
+ const auto& src_weights_array = model.GetArray(src_weights_name);
+ const auto& src_weights_shape = src_weights_array.shape();
+ CHECK_EQ(src_weights_shape.dimensions_count(), 4);
+ const Shape dst_weights_shape =
+ Shape({src_weights_shape.dims(1), src_weights_shape.dims(2),
+ src_weights_shape.dims(3) / src_op.depth_multiplier,
+ src_op.depth_multiplier});
+ CHECK_EQ(src_weights_shape.dims(3) % src_op.depth_multiplier, 0);
+ CHECK(dst_weights_shape.dims(2) * dst_weights_shape.dims(3) ==
+ src_weights_shape.dims(3));
+ CHECK_EQ(src_weights_shape.dims(0), 1);
+
+ CHECK(src_weights_array.buffer->type == ArrayDataType::kFloat);
+ const float* src_weights_data =
+ src_weights_array.GetBuffer<ArrayDataType::kFloat>().data.data();
+ ConvertFloatTensorConst(src_weights_name, dst_weights_shape, src_weights_data,
+ AxesOrder::kHWIM, AxesOrder::kHWIM, tensorflow_graph);
+
+ auto& strides = (*dc2d_op->mutable_attr())["strides"];
+ strides.mutable_list()->add_i(1);
+ strides.mutable_list()->add_i(src_op.stride_height);
+ strides.mutable_list()->add_i(src_op.stride_width);
+ strides.mutable_list()->add_i(1);
+ string padding;
+ if (src_op.padding.type == PaddingType::kSame) {
+ padding = "SAME";
+ } else if (src_op.padding.type == PaddingType::kValid) {
+ padding = "VALID";
+ } else {
+ LOG(FATAL) << "Bad padding (only SAME and VALID are supported)";
+ }
+ (*dc2d_op->mutable_attr())["padding"].set_s(padding);
+
+ if (has_bias) {
+ auto* biasadd_op = tensorflow_graph->add_node();
+ biasadd_op->set_op("BiasAdd");
+ biasadd_op->set_name(src_op.outputs[0]);
+ biasadd_op->add_input(conv_output);
+ biasadd_op->add_input(src_op.inputs[2]);
+ (*biasadd_op->mutable_attr())["T"].set_type(DT_FLOAT);
+ CHECK(model.arrays.count(src_op.inputs[2]));
+ const string& bias_name = WalkUpToConstantArray(model, src_op.inputs[2]);
+ const auto& bias_array = model.GetArray(bias_name);
+ // TODO(b/62904716) Bias arrays should be 1-D, and used directly.
+ Shape bias_shape_1d = bias_array.shape();
+ UnextendShape(&bias_shape_1d, 1);
+ CHECK(bias_array.buffer->type == ArrayDataType::kFloat);
+ const float* bias_data =
+ bias_array.GetBuffer<ArrayDataType::kFloat>().data.data();
+ ConvertFloatTensorConst(bias_name, bias_shape_1d, bias_data,
+ AxesOrder::kOneAxis, AxesOrder::kOneAxis,
+ tensorflow_graph,
+ LegacyScalarPolicy::kDoCreateLegacyScalars);
+ }
+}
+
+void ConvertDepthToSpaceOperator(const Model& model,
+ const DepthToSpaceOperator& src_op,
+ GraphDef* tensorflow_graph) {
+ auto* op = tensorflow_graph->add_node();
+ op->set_op("DepthToSpace");
+ op->set_name(src_op.outputs[0]);
+ *op->add_input() = src_op.inputs[0];
+ (*op->mutable_attr())["T"].set_type(DT_FLOAT);
+ (*op->mutable_attr())["block_size"].set_i(src_op.block_size);
+}
+
+void ConvertSpaceToDepthOperator(const Model& model,
+ const SpaceToDepthOperator& src_op,
+ GraphDef* tensorflow_graph) {
+ auto* op = tensorflow_graph->add_node();
+ op->set_op("SpaceToDepth");
+ op->set_name(src_op.outputs[0]);
+ *op->add_input() = src_op.inputs[0];
+ (*op->mutable_attr())["T"].set_type(DT_FLOAT);
+ (*op->mutable_attr())["block_size"].set_i(src_op.block_size);
+}
+
+void ConvertFullyConnectedOperator(const Model& model,
+ const FullyConnectedOperator& src_op,
+ GraphDef* tensorflow_graph) {
+ const string reshape_output = src_op.outputs[0] + "/reshape";
+ const string reshape_shape = src_op.outputs[0] + "/reshape/shape";
+ auto* reshape_op = tensorflow_graph->add_node();
+ reshape_op->set_op("Reshape");
+ reshape_op->set_name(reshape_output);
+ reshape_op->add_input(src_op.inputs[0]);
+ reshape_op->add_input(reshape_shape);
+ (*reshape_op->mutable_attr())["T"].set_type(DT_FLOAT);
+
+ const bool has_bias = src_op.inputs.size() >= 3;
+ string matmul_output = src_op.outputs[0];
+ if (has_bias) {
+ matmul_output += "/matmul";
+ }
+
+ auto* matmul_op = tensorflow_graph->add_node();
+ matmul_op->set_op("MatMul");
+
+ matmul_op->set_name(matmul_output);
+ *matmul_op->add_input() = reshape_output;
+ *matmul_op->add_input() = src_op.inputs[1];
+ (*matmul_op->mutable_attr())["T"].set_type(DT_FLOAT);
+ (*matmul_op->mutable_attr())["transpose_a"].set_b(false);
+ (*matmul_op->mutable_attr())["transpose_b"].set_b(false);
+ CHECK(model.arrays.count(src_op.inputs[1]));
+ const string& fc_weights_name =
+ WalkUpToConstantArray(model, src_op.inputs[1]);
+ const auto& fc_weights_array = *model.arrays.at(fc_weights_name);
+ const auto& fc_weights_shape = fc_weights_array.shape();
+ CHECK_EQ(fc_weights_shape.dimensions_count(), 2);
+ CreateMatrixShapeTensorConst(reshape_shape, fc_weights_shape.dims(1), -1,
+ tensorflow_graph);
+
+ CHECK(fc_weights_array.buffer);
+ CHECK(fc_weights_array.buffer->type == ArrayDataType::kFloat);
+ const float* fc_weights_data =
+ fc_weights_array.GetBuffer<ArrayDataType::kFloat>().data.data();
+ ConvertFloatTensorConst(fc_weights_name, fc_weights_shape, fc_weights_data,
+ AxesOrder::kCR, AxesOrder::kRC, tensorflow_graph);
+
+ if (has_bias) {
+ auto* biasadd_op = tensorflow_graph->add_node();
+ biasadd_op->set_op("BiasAdd");
+ biasadd_op->set_name(src_op.outputs[0]);
+ biasadd_op->add_input(matmul_output);
+ biasadd_op->add_input(src_op.inputs[2]);
+ (*biasadd_op->mutable_attr())["T"].set_type(DT_FLOAT);
+ CHECK(model.arrays.count(src_op.inputs[2]));
+ const auto& bias_array = *model.arrays.at(src_op.inputs[2]);
+ // TODO(b/62904716) Bias arrays should be 1-D, and used directly.
+ Shape bias_shape_1d = bias_array.shape();
+ UnextendShape(&bias_shape_1d, 1);
+ CHECK(bias_array.buffer);
+ CHECK(bias_array.buffer->type == ArrayDataType::kFloat);
+ const float* bias_data =
+ bias_array.GetBuffer<ArrayDataType::kFloat>().data.data();
+ ConvertFloatTensorConst(WalkUpToConstantArray(model, src_op.inputs[2]),
+ bias_shape_1d, bias_data, AxesOrder::kOneAxis,
+ AxesOrder::kOneAxis, tensorflow_graph,
+ LegacyScalarPolicy::kDoCreateLegacyScalars);
+ }
+}
+
+void ConvertAddOperator(const Model& model, const AddOperator& src_op,
+ GraphDef* tensorflow_graph) {
+ auto* add_op = tensorflow_graph->add_node();
+ add_op->set_op("Add");
+ add_op->set_name(src_op.outputs[0]);
+ CHECK_EQ(src_op.inputs.size(), 2);
+ *add_op->add_input() = src_op.inputs[0];
+ *add_op->add_input() = src_op.inputs[1];
+ (*add_op->mutable_attr())["T"].set_type(DT_FLOAT);
+}
+
+void ConvertMulOperator(const Model& model, const MulOperator& src_op,
+ GraphDef* tensorflow_graph) {
+ auto* add_op = tensorflow_graph->add_node();
+ add_op->set_op("Mul");
+ add_op->set_name(src_op.outputs[0]);
+ CHECK_EQ(src_op.inputs.size(), 2);
+ *add_op->add_input() = src_op.inputs[0];
+ *add_op->add_input() = src_op.inputs[1];
+ (*add_op->mutable_attr())["T"].set_type(DT_FLOAT);
+}
+
+void ConvertReluOperator(const ReluOperator& src_op,
+ GraphDef* tensorflow_graph) {
+ auto* relu_op = tensorflow_graph->add_node();
+ relu_op->set_op("Relu");
+ relu_op->set_name(src_op.outputs[0]);
+ *relu_op->add_input() = src_op.inputs[0];
+ (*relu_op->mutable_attr())["T"].set_type(DT_FLOAT);
+}
+
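+// Relu1 has no direct TensorFlow op, so it is emitted here as
+// Minimum(Maximum(x, -1), 1). Note that 'max_bounds' below holds the lower
+// bound (-1.0) fed to Maximum, and 'min_bounds' holds the upper bound (+1.0)
+// fed to Minimum.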
+void ConvertRelu1Operator(const Relu1Operator& src_op,
+ GraphDef* tensorflow_graph) {
+ const string max_bounds = src_op.outputs[0] + "/max_bounds";
+ const string min_bounds = src_op.outputs[0] + "/min_bounds";
+ const string max_output = src_op.outputs[0] + "/max_output";
+
+ auto* max_bounds_const_op = tensorflow_graph->add_node();
+ max_bounds_const_op->set_op("Const");
+ max_bounds_const_op->set_name(max_bounds);
+ (*max_bounds_const_op->mutable_attr())["dtype"].set_type(DT_FLOAT);
+ auto* max_bounds_const_op_tensor =
+ (*max_bounds_const_op->mutable_attr())["value"].mutable_tensor();
+ max_bounds_const_op_tensor->set_dtype(DT_FLOAT);
+ max_bounds_const_op_tensor->add_float_val(-1.0f);
+
+ auto* min_bounds_const_op = tensorflow_graph->add_node();
+ min_bounds_const_op->set_op("Const");
+ min_bounds_const_op->set_name(min_bounds);
+ (*min_bounds_const_op->mutable_attr())["dtype"].set_type(DT_FLOAT);
+ auto* min_bounds_const_op_tensor =
+ (*min_bounds_const_op->mutable_attr())["value"].mutable_tensor();
+ min_bounds_const_op_tensor->set_dtype(DT_FLOAT);
+ min_bounds_const_op_tensor->add_float_val(1.0f);
+
+ auto* max_op = tensorflow_graph->add_node();
+ max_op->set_op("Maximum");
+ max_op->set_name(max_output);
+ *max_op->add_input() = src_op.inputs[0];
+ *max_op->add_input() = max_bounds;
+ (*max_op->mutable_attr())["T"].set_type(DT_FLOAT);
+
+ auto* min_op = tensorflow_graph->add_node();
+ min_op->set_op("Minimum");
+ min_op->set_name(src_op.outputs[0]);
+ *min_op->add_input() = max_output;
+ *min_op->add_input() = min_bounds;
+ (*min_op->mutable_attr())["T"].set_type(DT_FLOAT);
+}
+
+void ConvertRelu6Operator(const Relu6Operator& src_op,
+ GraphDef* tensorflow_graph) {
+ auto* relu_op = tensorflow_graph->add_node();
+ relu_op->set_op("Relu6");
+ relu_op->set_name(src_op.outputs[0]);
+ *relu_op->add_input() = src_op.inputs[0];
+ (*relu_op->mutable_attr())["T"].set_type(DT_FLOAT);
+}
+
+void ConvertLogisticOperator(const LogisticOperator& src_op,
+ GraphDef* tensorflow_graph) {
+ auto* relu_op = tensorflow_graph->add_node();
+ relu_op->set_op("Sigmoid");
+ relu_op->set_name(src_op.outputs[0]);
+ *relu_op->add_input() = src_op.inputs[0];
+ (*relu_op->mutable_attr())["T"].set_type(DT_FLOAT);
+}
+
+void ConvertTanhOperator(const TanhOperator& src_op,
+ GraphDef* tensorflow_graph) {
+ auto* tanh_op = tensorflow_graph->add_node();
+ tanh_op->set_op("Tanh");
+ tanh_op->set_name(src_op.outputs[0]);
+ *tanh_op->add_input() = src_op.inputs[0];
+ (*tanh_op->mutable_attr())["T"].set_type(DT_FLOAT);
+}
+
+void ConvertSoftmaxOperator(const Model& model, const SoftmaxOperator& src_op,
+ GraphDef* tensorflow_graph) {
+ string softmax_input;
+ Operator* providing_op = GetOpWithOutput(model, src_op.inputs[0]);
+ if (providing_op->type == OperatorType::kTensorFlowReshape) {
+ softmax_input = src_op.inputs[0];
+ } else {
+ // Insert a reshape operator that reduces the dimensions down to the 2 that
+ // are required for TensorFlow Logits.
+ const string reshape_output = src_op.outputs[0] + "/softmax_insert_reshape";
+ const string softmax_size = src_op.outputs[0] + "/softmax_insert_size";
+ softmax_input = reshape_output;
+
+ auto* reshape_op = tensorflow_graph->add_node();
+ reshape_op->set_op("Reshape");
+ reshape_op->set_name(reshape_output);
+ *reshape_op->add_input() = src_op.inputs[0];
+ *reshape_op->add_input() = softmax_size;
+ (*reshape_op->mutable_attr())["T"].set_type(DT_FLOAT);
+
+ const auto& input_shape = model.arrays.at(src_op.inputs[0])->shape();
+ int32 flattened_size = 1;
+ for (int i = 0; i < input_shape.dimensions_count() - 1; ++i) {
+ flattened_size *= input_shape.dims(i);
+ }
+ const std::vector<int32> shape_data = {
+ flattened_size, input_shape.dims(input_shape.dimensions_count() - 1)};
+ CreateReshapeShapeTensorConst(softmax_size, shape_data, tensorflow_graph);
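+    // For example (illustrative): a [1, 10, 10, 5] input gets reshaped to
+    // {100, 5}: all leading dimensions are flattened and the last dimension
+    // is preserved.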
+ }
+
+ auto* softmax_op = tensorflow_graph->add_node();
+ softmax_op->set_op("Softmax");
+ softmax_op->set_name(src_op.outputs[0]);
+ *softmax_op->add_input() = softmax_input;
+ // TensorFlow's Softmax doesn't seem to admit a 'beta' parameter
+ CHECK_EQ(src_op.beta, 1.f);
+ (*softmax_op->mutable_attr())["T"].set_type(DT_FLOAT);
+}
+
+void ConvertL2NormalizationOperator(const L2NormalizationOperator& src_op,
+ GraphDef* tensorflow_graph) {
+ const string square_output = src_op.outputs[0] + "/square";
+ const string sum_reduction_indices = src_op.outputs[0] + "/reduction_indices";
+ const string sum_output = src_op.outputs[0] + "/sum";
+ const string rsqrt_output = src_op.outputs[0] + "/rsqrt";
+ const string rsqrt_tiled_output = src_op.outputs[0] + "/rsqrt_tiled";
+
+ auto* sum_reduction_indices_op = tensorflow_graph->add_node();
+ sum_reduction_indices_op->set_op("Const");
+ sum_reduction_indices_op->set_name(sum_reduction_indices);
+ (*sum_reduction_indices_op->mutable_attr())["dtype"].set_type(DT_INT32);
+ auto* sum_reduction_indices_tensor =
+ (*sum_reduction_indices_op->mutable_attr())["value"].mutable_tensor();
+ sum_reduction_indices_tensor->set_dtype(DT_INT32);
+ auto* sum_reduction_indices_shape =
+ sum_reduction_indices_tensor->mutable_tensor_shape();
+ auto* sum_reduction_indices_dim = sum_reduction_indices_shape->add_dim();
+ sum_reduction_indices_dim->set_size(2);
+ sum_reduction_indices_tensor->add_int_val(0);
+ sum_reduction_indices_tensor->add_int_val(1);
+
+ auto* square_op = tensorflow_graph->add_node();
+ square_op->set_op("Square");
+ square_op->set_name(square_output);
+ *square_op->add_input() = src_op.inputs[0];
+ (*square_op->mutable_attr())["T"].set_type(DT_FLOAT);
+
+ auto* sum_op = tensorflow_graph->add_node();
+ sum_op->set_op("Sum");
+ sum_op->set_name(sum_output);
+ *sum_op->add_input() = square_output;
+ *sum_op->add_input() = sum_reduction_indices;
+ (*sum_op->mutable_attr())["T"].set_type(DT_FLOAT);
+
+ auto* rsqrt_op = tensorflow_graph->add_node();
+ rsqrt_op->set_op("Rsqrt");
+ rsqrt_op->set_name(rsqrt_output);
+ *rsqrt_op->add_input() = sum_output;
+ (*rsqrt_op->mutable_attr())["T"].set_type(DT_FLOAT);
+
+ auto* mul_op = tensorflow_graph->add_node();
+ mul_op->set_op("Mul");
+ mul_op->set_name(src_op.outputs[0]);
+ *mul_op->add_input() = src_op.inputs[0];
+ *mul_op->add_input() = rsqrt_output;
+ (*mul_op->mutable_attr())["T"].set_type(DT_FLOAT);
+}
+
+void ConvertLocalResponseNormalizationOperator(
+ const LocalResponseNormalizationOperator& src_op,
+ GraphDef* tensorflow_graph) {
+ auto* lrn_op = tensorflow_graph->add_node();
+ lrn_op->set_op("LRN");
+ lrn_op->set_name(src_op.outputs[0]);
+ *lrn_op->add_input() = src_op.inputs[0];
+ (*lrn_op->mutable_attr())["depth_radius"].set_i(src_op.range);
+ (*lrn_op->mutable_attr())["bias"].set_f(src_op.bias);
+ (*lrn_op->mutable_attr())["alpha"].set_f(src_op.alpha);
+ (*lrn_op->mutable_attr())["beta"].set_f(src_op.beta);
+}
+
+void ConvertFakeQuantOperator(const FakeQuantOperator& src_op,
+ GraphDef* tensorflow_graph) {
+ auto* fakequant_op = tensorflow_graph->add_node();
+ fakequant_op->set_op("FakeQuantWithMinMaxArgs");
+ fakequant_op->set_name(src_op.outputs[0]);
+ CHECK_EQ(src_op.inputs.size(), 1);
+ *fakequant_op->add_input() = src_op.inputs[0];
+ CHECK(src_op.minmax);
+ (*fakequant_op->mutable_attr())["min"].set_f(src_op.minmax->min);
+ (*fakequant_op->mutable_attr())["max"].set_f(src_op.minmax->max);
+}
+
+void ConvertMaxPoolOperator(const MaxPoolOperator& src_op,
+ GraphDef* tensorflow_graph) {
+ auto* maxpool_op = tensorflow_graph->add_node();
+ maxpool_op->set_op("MaxPool");
+ maxpool_op->set_name(src_op.outputs[0]);
+ *maxpool_op->add_input() = src_op.inputs[0];
+ auto& strides = (*maxpool_op->mutable_attr())["strides"];
+ strides.mutable_list()->add_i(1);
+ strides.mutable_list()->add_i(src_op.stride_height);
+ strides.mutable_list()->add_i(src_op.stride_width);
+ strides.mutable_list()->add_i(1);
+ string padding;
+ if (src_op.padding.type == PaddingType::kSame) {
+ padding = "SAME";
+ } else if (src_op.padding.type == PaddingType::kValid) {
+ padding = "VALID";
+ } else {
+ LOG(FATAL) << "Bad padding (only SAME and VALID are supported)";
+ }
+ (*maxpool_op->mutable_attr())["padding"].set_s(padding);
+ (*maxpool_op->mutable_attr())["T"].set_type(DT_FLOAT);
+ auto& ksize = (*maxpool_op->mutable_attr())["ksize"];
+ ksize.mutable_list()->add_i(1);
+ ksize.mutable_list()->add_i(src_op.kheight);
+ ksize.mutable_list()->add_i(src_op.kwidth);
+ ksize.mutable_list()->add_i(1);
+}
+
+void ConvertAveragePoolOperator(const AveragePoolOperator& src_op,
+ GraphDef* tensorflow_graph) {
+ auto* avgpool_op = tensorflow_graph->add_node();
+ avgpool_op->set_op("AvgPool");
+ avgpool_op->set_name(src_op.outputs[0]);
+ *avgpool_op->add_input() = src_op.inputs[0];
+ auto& strides = (*avgpool_op->mutable_attr())["strides"];
+ strides.mutable_list()->add_i(1);
+ strides.mutable_list()->add_i(src_op.stride_height);
+ strides.mutable_list()->add_i(src_op.stride_width);
+ strides.mutable_list()->add_i(1);
+ string padding;
+ if (src_op.padding.type == PaddingType::kSame) {
+ padding = "SAME";
+ } else if (src_op.padding.type == PaddingType::kValid) {
+ padding = "VALID";
+ } else {
+ LOG(FATAL) << "Bad padding (only SAME and VALID are supported)";
+ }
+ (*avgpool_op->mutable_attr())["padding"].set_s(padding);
+ (*avgpool_op->mutable_attr())["T"].set_type(DT_FLOAT);
+ auto& ksize = (*avgpool_op->mutable_attr())["ksize"];
+ ksize.mutable_list()->add_i(1);
+ ksize.mutable_list()->add_i(src_op.kheight);
+ ksize.mutable_list()->add_i(src_op.kwidth);
+ ksize.mutable_list()->add_i(1);
+}
+
+void ConvertConcatenationOperator(const Model& model,
+ const ConcatenationOperator& src_op,
+ GraphDef* tensorflow_graph) {
+ auto* dc_op = tensorflow_graph->add_node();
+ dc_op->set_op("ConcatV2");
+ dc_op->set_name(src_op.outputs[0]);
+ const string dummy_concat_dim = src_op.outputs[0] + "/concat_dim";
+ CreateDummyConcatDimTensorConst(dummy_concat_dim, src_op.concat_dim,
+ tensorflow_graph);
+ for (const auto& input : src_op.inputs) {
+ *dc_op->add_input() = input;
+ }
+ *dc_op->add_input() = dummy_concat_dim;
+ (*dc_op->mutable_attr())["T"].set_type(DT_FLOAT);
+ (*dc_op->mutable_attr())["Tidx"].set_type(DT_INT32);
+ (*dc_op->mutable_attr())["N"].set_i(src_op.inputs.size());
+}
+
+void ConvertTensorFlowReshapeOperator(const Model& model,
+ const TensorFlowReshapeOperator& src_op,
+ GraphDef* tensorflow_graph) {
+ auto* reshape_op = tensorflow_graph->add_node();
+ reshape_op->set_op("Reshape");
+ reshape_op->set_name(src_op.outputs[0]);
+ CHECK_EQ(src_op.inputs.size(), 2);
+ *reshape_op->add_input() = src_op.inputs[0];
+ *reshape_op->add_input() = src_op.inputs[1];
+ (*reshape_op->mutable_attr())["T"].set_type(DT_FLOAT);
+ const auto& shape_array = model.GetArray(src_op.inputs[1]);
+ CHECK(shape_array.data_type == ArrayDataType::kInt32);
+ CHECK(shape_array.buffer != nullptr);
+ const auto& shape_data = shape_array.GetBuffer<ArrayDataType::kInt32>().data;
+ CreateReshapeShapeTensorConst(src_op.inputs[1], shape_data, tensorflow_graph);
+}
+
+void ConvertL2PoolOperator(const L2PoolOperator& src_op,
+ GraphDef* tensorflow_graph) {
+ const string square_output = src_op.outputs[0] + "/square";
+ const string avgpool_output = src_op.outputs[0] + "/avgpool";
+
+ auto* square_op = tensorflow_graph->add_node();
+ square_op->set_op("Square");
+ square_op->set_name(square_output);
+ *square_op->add_input() = src_op.inputs[0];
+ (*square_op->mutable_attr())["T"].set_type(DT_FLOAT);
+
+ string padding;
+ if (src_op.padding.type == PaddingType::kSame) {
+ padding = "SAME";
+ } else if (src_op.padding.type == PaddingType::kValid) {
+ padding = "VALID";
+ } else {
+ LOG(FATAL) << "Bad padding (only SAME and VALID are supported)";
+ }
+
+ auto* avgpool_op = tensorflow_graph->add_node();
+ avgpool_op->set_op("AvgPool");
+ avgpool_op->set_name(avgpool_output);
+ *avgpool_op->add_input() = square_output;
+ auto& strides = (*avgpool_op->mutable_attr())["strides"];
+ strides.mutable_list()->add_i(1);
+ strides.mutable_list()->add_i(src_op.stride_height);
+ strides.mutable_list()->add_i(src_op.stride_width);
+ strides.mutable_list()->add_i(1);
+
+ (*avgpool_op->mutable_attr())["padding"].set_s(padding);
+ (*avgpool_op->mutable_attr())["T"].set_type(DT_FLOAT);
+ auto& ksize = (*avgpool_op->mutable_attr())["ksize"];
+ ksize.mutable_list()->add_i(1);
+ ksize.mutable_list()->add_i(src_op.kheight);
+ ksize.mutable_list()->add_i(src_op.kwidth);
+ ksize.mutable_list()->add_i(1);
+
+ auto* sqrt_op = tensorflow_graph->add_node();
+ sqrt_op->set_op("Sqrt");
+ sqrt_op->set_name(src_op.outputs[0]);
+ *sqrt_op->add_input() = avgpool_output;
+ (*sqrt_op->mutable_attr())["T"].set_type(DT_FLOAT);
+}
+
+void ConvertSquareOperator(const TensorFlowSquareOperator& src_op,
+ GraphDef* tensorflow_graph) {
+ auto* square_op = tensorflow_graph->add_node();
+ square_op->set_op("Square");
+ square_op->set_name(src_op.outputs[0]);
+ CHECK_EQ(src_op.inputs.size(), 1);
+ *square_op->add_input() = src_op.inputs[0];
+ (*square_op->mutable_attr())["T"].set_type(DT_FLOAT);
+}
+
+void ConvertSqrtOperator(const TensorFlowSqrtOperator& src_op,
+ GraphDef* tensorflow_graph) {
+ auto* sqrt_op = tensorflow_graph->add_node();
+ sqrt_op->set_op("Sqrt");
+ sqrt_op->set_name(src_op.outputs[0]);
+ CHECK_EQ(src_op.inputs.size(), 1);
+ *sqrt_op->add_input() = src_op.inputs[0];
+ (*sqrt_op->mutable_attr())["T"].set_type(DT_FLOAT);
+}
+
+void ConvertSplitOperator(const Model& model,
+ const TensorFlowSplitOperator& src_op,
+ GraphDef* tensorflow_graph) {
+ auto* split_op = tensorflow_graph->add_node();
+ split_op->set_op("Split");
+ split_op->set_name(src_op.outputs[0]);
+ for (const auto& input : src_op.inputs) {
+ *split_op->add_input() = input;
+ }
+ (*split_op->mutable_attr())["T"].set_type(DT_FLOAT);
+ (*split_op->mutable_attr())["num_split"].set_i(src_op.num_split);
+ const auto& split_dim_array = model.GetArray(src_op.inputs[0]);
+ CHECK(split_dim_array.buffer);
+ CHECK(split_dim_array.data_type == ArrayDataType::kInt32);
+ const auto& split_dim_data =
+ split_dim_array.GetBuffer<ArrayDataType::kInt32>().data;
+ CHECK_EQ(split_dim_data.size(), 1);
+ const int split_dim = split_dim_data[0];
+ CreateDummyConcatDimTensorConst(src_op.inputs[0], split_dim,
+ tensorflow_graph);
+}
+
+tensorflow::DataType GetTensorFlowDataType(const Model& model,
+ const string& array_name) {
+ auto& dtype = model.GetArray(array_name).data_type;
+ CHECK(dtype == ArrayDataType::kFloat || dtype == ArrayDataType::kInt32 ||
+ dtype == ArrayDataType::kUint8);
+ if (dtype == ArrayDataType::kFloat) {
+ return tensorflow::DT_FLOAT;
+ } else if (dtype == ArrayDataType::kInt32) {
+ return tensorflow::DT_INT32;
+ } else if (dtype == ArrayDataType::kUint8) {
+ return tensorflow::DT_UINT8;
+ } else {
+ LOG(FATAL) << "Wrong data type";
+ }
+}
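+// The mapping above covers the only array data types this exporter handles:
+// kFloat -> DT_FLOAT, kInt32 -> DT_INT32, kUint8 -> DT_UINT8; any other data
+// type is a fatal error at export time.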
+
+void ConvertCastOperator(const Model& model, const CastOperator& src_op,
+ GraphDef* tensorflow_graph) {
+ auto* cast_op = tensorflow_graph->add_node();
+ cast_op->set_op("Cast");
+ cast_op->set_name(src_op.outputs[0]);
+ CHECK_EQ(src_op.inputs.size(), 1);
+ *cast_op->add_input() = src_op.inputs[0];
+
+ (*cast_op->mutable_attr())["DstT"].set_type(
+ GetTensorFlowDataType(model, src_op.outputs[0]));
+ (*cast_op->mutable_attr())["SrcT"].set_type(
+ GetTensorFlowDataType(model, src_op.inputs[0]));
+}
+
+void ConvertFloorOperator(const Model& model, const FloorOperator& src_op,
+ GraphDef* tensorflow_graph) {
+ auto* floor_op = tensorflow_graph->add_node();
+ floor_op->set_op("Floor");
+ floor_op->set_name(src_op.outputs[0]);
+ CHECK_EQ(src_op.inputs.size(), 1);
+ *floor_op->add_input() = src_op.inputs[0];
+ (*floor_op->mutable_attr())["T"].set_type(DT_FLOAT);
+}
+
+void ConvertGatherOperator(const Model& model, const GatherOperator& src_op,
+ GraphDef* tensorflow_graph) {
+ auto* gather_op = tensorflow_graph->add_node();
+ gather_op->set_op("Gather");
+ gather_op->set_name(src_op.outputs[0]);
+ CHECK_EQ(src_op.inputs.size(), 2);
+ *gather_op->add_input() = src_op.inputs[0];
+ *gather_op->add_input() = src_op.inputs[1];
+
+ (*gather_op->mutable_attr())["Tindices"].set_type(DT_INT32);
+ const auto params_type = GetTensorFlowDataType(model, src_op.inputs[0]);
+ (*gather_op->mutable_attr())["Tparams"].set_type(params_type);
+}
+
+void ConvertResizeBilinearOperator(const Model& model,
+ const ResizeBilinearOperator& src_op,
+ GraphDef* tensorflow_graph) {
+ auto* resize_op = tensorflow_graph->add_node();
+ resize_op->set_op("ResizeBilinear");
+ resize_op->set_name(src_op.outputs[0]);
+ CHECK_EQ(src_op.inputs.size(), 2);
+ *resize_op->add_input() = src_op.inputs[0];
+ *resize_op->add_input() = src_op.inputs[1];
+ (*resize_op->mutable_attr())["T"].set_type(DT_FLOAT);
+}
+
+namespace {
+// TODO(aselle): Remove when available in absl
+absl::string_view FindLongestCommonPrefix(absl::string_view a,
+ absl::string_view b) {
+ if (a.empty() || b.empty()) return absl::string_view();
+
+ const char* pa = a.data();
+ const char* pb = b.data();
+ string::difference_type count = 0;
+ const string::difference_type limit = std::min(a.size(), b.size());
+ while (count < limit && *pa == *pb) {
+ ++pa;
+ ++pb;
+ ++count;
+ }
+
+ return absl::string_view(a.data(), count);
+}
+} // namespace
+
+void ConvertLstmCellOperator(const Model& model, const LstmCellOperator& src_op,
+ GraphDef* tensorflow_graph) {
+ // Find the base name
+ const string base(
+ FindLongestCommonPrefix(src_op.outputs[LstmCellOperator::STATE_OUTPUT],
+ src_op.outputs[LstmCellOperator::ACTIV_OUTPUT]));
+
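+  // Illustrative example (not from the original change): if the two outputs
+  // are named "lstm_1/state" and "lstm_1/activ", FindLongestCommonPrefix
+  // returns "lstm_1/", and every node emitted below is named under that base.
+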
+ // Concatenate inputs
+ const string concat_output = base + "basic_lstm_cell/concat";
+ // Op names have been chosen to match the tf.slim LSTM naming
+ // as closely as possible.
+ const int concat_dim =
+ model.arrays.at(src_op.inputs[LstmCellOperator::PREV_ACTIV_INPUT])
+ ->shape()
+ .dimensions_count() -
+ 1;
+ // Note that DATA_INPUT may have extra size 1 dimensions, but TF concat
+ // works the same since the tensor has the same underlying data layout.
+ const string concat_dim_output = concat_output + "/concat_dim";
+ CreateDummyConcatDimTensorConst(concat_dim_output, concat_dim,
+ tensorflow_graph);
+ auto* concat_op = tensorflow_graph->add_node();
+ concat_op->set_op("ConcatV2");
+ concat_op->set_name(concat_output);
+ *concat_op->add_input() = src_op.inputs[LstmCellOperator::DATA_INPUT];
+ *concat_op->add_input() = src_op.inputs[LstmCellOperator::PREV_ACTIV_INPUT];
+ *concat_op->add_input() = concat_dim_output;
+ (*concat_op->mutable_attr())["T"].set_type(DT_FLOAT);
+ (*concat_op->mutable_attr())["Tidx"].set_type(DT_INT32);
+ (*concat_op->mutable_attr())["N"].set_i(2); // Number of inputs
+
+ // Write weights
+ const string weights_output = base + "weights";
+ CHECK(model.arrays.count(src_op.inputs[LstmCellOperator::WEIGHTS_INPUT]));
+ const auto& weights_array =
+ *model.arrays.at(src_op.inputs[LstmCellOperator::WEIGHTS_INPUT]);
+  // Write the 2-D weights, converting from toco's CR axes order to TF's RC
+  // order.
+ const auto& weights_shape = weights_array.shape();
+ CHECK_EQ(weights_shape.dimensions_count(), 2);
+ CHECK(weights_array.buffer);
+ CHECK(weights_array.buffer->type == ArrayDataType::kFloat);
+ const float* weights_data =
+ weights_array.GetBuffer<ArrayDataType::kFloat>().data.data();
+ ConvertFloatTensorConst(weights_output, weights_shape, weights_data,
+ AxesOrder::kCR, AxesOrder::kRC, tensorflow_graph);
+
+ // Fully connected matrix multiply
+ const string matmul_output = base + "MatMul";
+ auto* matmul_op = tensorflow_graph->add_node();
+ matmul_op->set_op("MatMul");
+ matmul_op->set_name(matmul_output);
+ *matmul_op->add_input() = concat_output;
+ *matmul_op->add_input() = weights_output;
+ (*matmul_op->mutable_attr())["transpose_a"].set_b(false);
+ (*matmul_op->mutable_attr())["transpose_b"].set_b(false);
+ (*matmul_op->mutable_attr())["T"].set_type(DT_FLOAT);
+
+ // Write biases
+ const string biases_output = base + "biases";
+ CHECK(model.arrays.count(src_op.inputs[LstmCellOperator::BIASES_INPUT]));
+ const auto& bias_array =
+ *model.arrays.at(src_op.inputs[LstmCellOperator::BIASES_INPUT]);
+ // TODO(b/62904716) Bias arrays should be 1-D, and used directly.
+ Shape bias_shape_1d = bias_array.shape();
+ UnextendShape(&bias_shape_1d, 1);
+ CHECK(bias_array.buffer);
+ CHECK(bias_array.buffer->type == ArrayDataType::kFloat);
+ const float* bias_data =
+ bias_array.GetBuffer<ArrayDataType::kFloat>().data.data();
+ ConvertFloatTensorConst(biases_output, bias_shape_1d, bias_data,
+ AxesOrder::kOneAxis, AxesOrder::kOneAxis,
+ tensorflow_graph,
+ LegacyScalarPolicy::kDoCreateLegacyScalars);
+
+ // Add biases
+ string biasadd_output = base + "BiasAdd";
+ auto* biasadd_op = tensorflow_graph->add_node();
+ biasadd_op->set_op("BiasAdd");
+ biasadd_op->set_name(biasadd_output);
+ biasadd_op->add_input(matmul_output);
+ biasadd_op->add_input(biases_output);
+ (*biasadd_op->mutable_attr())["data_format"].set_s("NHWC");
+ (*biasadd_op->mutable_attr())["T"].set_type(DT_FLOAT);
+
+ // Split
+ string split_dim_output = base + "split/split_dim";
+ // The dimension is the same as the concatenation dimension
+ CreateDummyConcatDimTensorConst(split_dim_output, concat_dim,
+ tensorflow_graph);
+ string split_output = base + "split";
+ auto* split_op = tensorflow_graph->add_node();
+ split_op->set_op("Split");
+ split_op->set_name(split_output);
+ *split_op->add_input() = split_dim_output;
+ *split_op->add_input() = biasadd_output;
+ (*split_op->mutable_attr())["T"].set_type(DT_FLOAT);
+ (*split_op->mutable_attr())["num_split"].set_i(4); // Split into four outputs
+
+ // Activation functions and memory computations
+ const string tanh_0_output = base + "Tanh";
+ auto* tanh_0_op = tensorflow_graph->add_node();
+ tanh_0_op->set_op("Tanh");
+ tanh_0_op->set_name(tanh_0_output);
+ *tanh_0_op->add_input() = split_output + ":1";
+ (*tanh_0_op->mutable_attr())["T"].set_type(DT_FLOAT);
+
+ const string sigmoid_1_output = base + "Sigmoid_1";
+ auto* logistic_1_op = tensorflow_graph->add_node();
+ logistic_1_op->set_op("Sigmoid");
+ logistic_1_op->set_name(sigmoid_1_output);
+ *logistic_1_op->add_input() = split_output;
+ (*logistic_1_op->mutable_attr())["T"].set_type(DT_FLOAT);
+
+ const string mul_1_output = base + "mul_1";
+ auto* mul_1_op = tensorflow_graph->add_node();
+ mul_1_op->set_op("Mul");
+ mul_1_op->set_name(mul_1_output);
+ *mul_1_op->add_input() = sigmoid_1_output;
+ *mul_1_op->add_input() = tanh_0_output;
+ (*mul_1_op->mutable_attr())["T"].set_type(DT_FLOAT);
+
+ const string sigmoid_0_output = base + "Sigmoid";
+ auto* logistic_2_op = tensorflow_graph->add_node();
+ logistic_2_op->set_op("Sigmoid");
+ logistic_2_op->set_name(sigmoid_0_output);
+ *logistic_2_op->add_input() = split_output + ":2";
+ (*logistic_2_op->mutable_attr())["T"].set_type(DT_FLOAT);
+
+ const string sigmoid_2_output = base + "Sigmoid_2";
+ auto* logistic_3_op = tensorflow_graph->add_node();
+ logistic_3_op->set_op("Sigmoid");
+ logistic_3_op->set_name(sigmoid_2_output);
+ *logistic_3_op->add_input() = split_output + ":3";
+ (*logistic_3_op->mutable_attr())["T"].set_type(DT_FLOAT);
+
+ const string mul_0_output = base + "mul";
+ auto* mul_0_op = tensorflow_graph->add_node();
+ mul_0_op->set_op("Mul");
+ mul_0_op->set_name(mul_0_output);
+ *mul_0_op->add_input() = src_op.inputs[LstmCellOperator::PREV_STATE_INPUT];
+ *mul_0_op->add_input() = sigmoid_0_output;
+ (*mul_0_op->mutable_attr())["T"].set_type(DT_FLOAT);
+
+ const string add_1_output = src_op.outputs[LstmCellOperator::STATE_OUTPUT];
+ auto* add_1_op = tensorflow_graph->add_node();
+ add_1_op->set_op("Add");
+ add_1_op->set_name(add_1_output);
+ *add_1_op->add_input() = mul_0_output;
+ *add_1_op->add_input() = mul_1_output;
+ (*add_1_op->mutable_attr())["T"].set_type(DT_FLOAT);
+
+ const string tanh_1_output = base + "Tanh_1";
+ auto* tanh_1_op = tensorflow_graph->add_node();
+ tanh_1_op->set_op("Tanh");
+ tanh_1_op->set_name(tanh_1_output);
+ *tanh_1_op->add_input() = add_1_output;
+ (*tanh_1_op->mutable_attr())["T"].set_type(DT_FLOAT);
+
+ const string mul_2_output = src_op.outputs[LstmCellOperator::ACTIV_OUTPUT];
+ auto* mul_2_op = tensorflow_graph->add_node();
+ mul_2_op->set_op("Mul");
+ mul_2_op->set_name(mul_2_output);
+ *mul_2_op->add_input() = tanh_1_output;
+ *mul_2_op->add_input() = sigmoid_2_output;
+ (*mul_2_op->mutable_attr())["T"].set_type(DT_FLOAT);
+}
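+// Summary of the subgraph emitted above (comment added for clarity, derived
+// from the node wiring): with s0..s3 denoting the four Split outputs of
+// BiasAdd(MatMul(Concat(x, prev_activ), W), b),
+//   state = prev_state * Sigmoid(s2) + Sigmoid(s0) * Tanh(s1)
+//   activ = Tanh(state) * Sigmoid(s3)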
+
+void ConvertSpaceToBatchNDOperator(const Model& model,
+ const SpaceToBatchNDOperator& src_op,
+ GraphDef* tensorflow_graph) {
+ auto* new_op = tensorflow_graph->add_node();
+ new_op->set_op("SpaceToBatchND");
+ new_op->set_name(src_op.outputs[0]);
+ CHECK_EQ(src_op.inputs.size(), 3);
+ *new_op->add_input() = src_op.inputs[0];
+ *new_op->add_input() = src_op.inputs[1];
+ *new_op->add_input() = src_op.inputs[2];
+ const auto params_type = GetTensorFlowDataType(model, src_op.inputs[0]);
+ (*new_op->mutable_attr())["T"].set_type(params_type);
+ (*new_op->mutable_attr())["Tblock_shape"].set_type(DT_INT32);
+ (*new_op->mutable_attr())["Tpaddings"].set_type(DT_INT32);
+}
+
+void ConvertBatchToSpaceNDOperator(const Model& model,
+ const BatchToSpaceNDOperator& src_op,
+ GraphDef* tensorflow_graph) {
+ auto* new_op = tensorflow_graph->add_node();
+ new_op->set_op("BatchToSpaceND");
+ new_op->set_name(src_op.outputs[0]);
+ CHECK_EQ(src_op.inputs.size(), 3);
+ *new_op->add_input() = src_op.inputs[0];
+ *new_op->add_input() = src_op.inputs[1];
+ *new_op->add_input() = src_op.inputs[2];
+ const auto params_type = GetTensorFlowDataType(model, src_op.inputs[0]);
+ (*new_op->mutable_attr())["T"].set_type(params_type);
+ (*new_op->mutable_attr())["Tblock_shape"].set_type(DT_INT32);
+ (*new_op->mutable_attr())["Tcrops"].set_type(DT_INT32);
+}
+
+void ConvertPadOperator(const Model& model, const PadOperator& src_op,
+ GraphDef* tensorflow_graph) {
+ auto* new_op = tensorflow_graph->add_node();
+ new_op->set_op("Pad");
+ new_op->set_name(src_op.outputs[0]);
+ CHECK_EQ(src_op.inputs.size(), 2);
+ *new_op->add_input() = src_op.inputs[0];
+ *new_op->add_input() = src_op.inputs[1];
+
+ const auto params_type = GetTensorFlowDataType(model, src_op.inputs[0]);
+ (*new_op->mutable_attr())["T"].set_type(params_type);
+
+ // Create the params tensor.
+ auto* params_op = tensorflow_graph->add_node();
+ params_op->set_op("Const");
+ params_op->set_name(src_op.inputs[1]);
+ (*params_op->mutable_attr())["dtype"].set_type(DT_INT32);
+ auto* tensor = (*params_op->mutable_attr())["value"].mutable_tensor();
+ tensor->set_dtype(DT_INT32);
+
+ CHECK_EQ(src_op.left_padding.size(), src_op.right_padding.size());
+ for (int i = 0; i < src_op.left_padding.size(); ++i) {
+ tensor->add_int_val(src_op.left_padding[i]);
+ tensor->add_int_val(src_op.right_padding[i]);
+ }
+ auto* shape = tensor->mutable_tensor_shape();
+ shape->add_dim()->set_size(src_op.left_padding.size());
+ shape->add_dim()->set_size(2);
+}
+
+void CreateSliceInput(const string& input_name, const std::vector<int>& values,
+ GraphDef* tensorflow_graph) {
+ auto* params_op = tensorflow_graph->add_node();
+ params_op->set_op("Const");
+ params_op->set_name(input_name);
+ (*params_op->mutable_attr())["dtype"].set_type(DT_INT32);
+ auto* tensor = (*params_op->mutable_attr())["value"].mutable_tensor();
+ tensor->set_dtype(DT_INT32);
+
+ for (int i = 0; i < values.size(); ++i) {
+ tensor->add_int_val(values[i]);
+ }
+ auto* shape = tensor->mutable_tensor_shape();
+ shape->add_dim()->set_size(values.size());
+}
+
+void ConvertStridedSliceOperator(const Model& model,
+ const StridedSliceOperator& src_op,
+ GraphDef* tensorflow_graph) {
+ auto* new_op = tensorflow_graph->add_node();
+ new_op->set_op("StridedSlice");
+ new_op->set_name(src_op.outputs[0]);
+ CHECK_EQ(src_op.inputs.size(), 4);
+ *new_op->add_input() = src_op.inputs[0];
+ *new_op->add_input() = src_op.inputs[1];
+ *new_op->add_input() = src_op.inputs[2];
+ *new_op->add_input() = src_op.inputs[3];
+
+ const auto params_type = GetTensorFlowDataType(model, src_op.inputs[0]);
+ (*new_op->mutable_attr())["T"].set_type(params_type);
+
+ (*new_op->mutable_attr())["Index"].set_type(DT_INT32);
+ (*new_op->mutable_attr())["begin_mask"].set_i(src_op.begin_mask);
+ (*new_op->mutable_attr())["ellipsis_mask"].set_i(src_op.ellipsis_mask);
+ (*new_op->mutable_attr())["end_mask"].set_i(src_op.end_mask);
+ (*new_op->mutable_attr())["new_axis_mask"].set_i(src_op.new_axis_mask);
+ (*new_op->mutable_attr())["shrink_axis_mask"].set_i(src_op.shrink_axis_mask);
+
+ // Create tensors for start/stop indices and strides.
+ CreateSliceInput(src_op.inputs[1], src_op.start_indices, tensorflow_graph);
+ CreateSliceInput(src_op.inputs[2], src_op.stop_indices, tensorflow_graph);
+ CreateSliceInput(src_op.inputs[3], src_op.strides, tensorflow_graph);
+}
+
+void ConvertSliceOperator(const Model& model, const SliceOperator& src_op,
+ GraphDef* tensorflow_graph) {
+ auto* new_op = tensorflow_graph->add_node();
+ new_op->set_op("Slice");
+ new_op->set_name(src_op.outputs[0]);
+ CHECK_EQ(src_op.inputs.size(), 3);
+ *new_op->add_input() = src_op.inputs[0];
+ *new_op->add_input() = src_op.inputs[1];
+ *new_op->add_input() = src_op.inputs[2];
+
+ const auto params_type = GetTensorFlowDataType(model, src_op.inputs[0]);
+ (*new_op->mutable_attr())["T"].set_type(params_type);
+ (*new_op->mutable_attr())["Index"].set_type(DT_INT32);
+
+ // Create tensors for begin and size inputs.
+ CreateSliceInput(src_op.inputs[1], src_op.begin, tensorflow_graph);
+ CreateSliceInput(src_op.inputs[2], src_op.size, tensorflow_graph);
+}
+
+void ConvertMeanOperator(const Model& model, const MeanOperator& src_op,
+ GraphDef* tensorflow_graph) {
+ auto* new_op = tensorflow_graph->add_node();
+ new_op->set_op("Mean");
+ new_op->set_name(src_op.outputs[0]);
+ CHECK_EQ(src_op.inputs.size(), 2);
+ *new_op->add_input() = src_op.inputs[0];
+ *new_op->add_input() = src_op.inputs[1];
+
+ const auto params_type = GetTensorFlowDataType(model, src_op.inputs[0]);
+ (*new_op->mutable_attr())["T"].set_type(params_type);
+
+ // Create the params tensor.
+ auto* params_op = tensorflow_graph->add_node();
+ params_op->set_op("Const");
+ params_op->set_name(src_op.inputs[1]);
+ (*params_op->mutable_attr())["dtype"].set_type(DT_INT32);
+ auto* tensor = (*params_op->mutable_attr())["value"].mutable_tensor();
+ tensor->set_dtype(DT_INT32);
+
+ for (int i = 0; i < src_op.reduction_indices.size(); ++i) {
+ tensor->add_int_val(src_op.reduction_indices[i]);
+ }
+ auto* shape = tensor->mutable_tensor_shape();
+ shape->add_dim()->set_size(src_op.reduction_indices.size());
+}
+
+void ConvertSqueezeOperator(const Model& model, const SqueezeOperator& src_op,
+ GraphDef* tensorflow_graph) {
+ auto* new_op = tensorflow_graph->add_node();
+ new_op->set_op("Squeeze");
+ new_op->set_name(src_op.outputs[0]);
+ CHECK_EQ(src_op.inputs.size(), 1);
+ *new_op->add_input() = src_op.inputs[0];
+
+ const auto params_type = GetTensorFlowDataType(model, src_op.inputs[0]);
+ (*new_op->mutable_attr())["T"].set_type(params_type);
+
+ auto& squeeze_dims = (*new_op->mutable_attr())["squeeze_dims"];
+ for (int i : src_op.squeeze_dims) {
+ squeeze_dims.mutable_list()->add_i(i);
+ }
+}
+
+void ConvertSubOperator(const Model& model, const SubOperator& src_op,
+ GraphDef* tensorflow_graph) {
+ auto* sub_op = tensorflow_graph->add_node();
+ sub_op->set_op("Sub");
+ sub_op->set_name(src_op.outputs[0]);
+ CHECK_EQ(src_op.inputs.size(), 2);
+ *sub_op->add_input() = src_op.inputs[0];
+ *sub_op->add_input() = src_op.inputs[1];
+ const auto data_type = GetTensorFlowDataType(model, src_op.inputs[0]);
+ (*sub_op->mutable_attr())["T"].set_type(data_type);
+}
+
+void ConvertTensorFlowMinimumOperator(const Model& model,
+ const TensorFlowMinimumOperator& src_op,
+ GraphDef* tensorflow_graph) {
+  auto* min_op = tensorflow_graph->add_node();
+  min_op->set_op("Minimum");
+  min_op->set_name(src_op.outputs[0]);
+  CHECK_EQ(src_op.inputs.size(), 2);
+  *min_op->add_input() = src_op.inputs[0];
+  *min_op->add_input() = src_op.inputs[1];
+  const auto data_type = GetTensorFlowDataType(model, src_op.inputs[0]);
+  (*min_op->mutable_attr())["T"].set_type(data_type);
+}
+
+void ConvertTensorFlowMaximumOperator(const Model& model,
+ const TensorFlowMaximumOperator& src_op,
+ GraphDef* tensorflow_graph) {
+  auto* max_op = tensorflow_graph->add_node();
+  max_op->set_op("Maximum");
+  max_op->set_name(src_op.outputs[0]);
+  CHECK_EQ(src_op.inputs.size(), 2);
+  *max_op->add_input() = src_op.inputs[0];
+  *max_op->add_input() = src_op.inputs[1];
+  const auto data_type = GetTensorFlowDataType(model, src_op.inputs[0]);
+  (*max_op->mutable_attr())["T"].set_type(data_type);
+}
+
+void ConvertOperator(const Model& model, const Operator& src_op,
+ GraphDef* tensorflow_graph) {
+ if (src_op.fused_activation_function != FusedActivationFunctionType::kNone) {
+ LOG(FATAL)
+ << "Unsupported: the input model has a fused activation function";
+ }
+
+ if (src_op.type == OperatorType::kConv) {
+ ConvertConvOperator(model, static_cast<const ConvOperator&>(src_op),
+ tensorflow_graph);
+ } else if (src_op.type == OperatorType::kDepthwiseConv) {
+ ConvertDepthwiseConvOperator(
+ model, static_cast<const DepthwiseConvOperator&>(src_op),
+ tensorflow_graph);
+ } else if (src_op.type == OperatorType::kDepthToSpace) {
+ ConvertDepthToSpaceOperator(
+ model, static_cast<const DepthToSpaceOperator&>(src_op),
+ tensorflow_graph);
+ } else if (src_op.type == OperatorType::kSpaceToDepth) {
+ ConvertSpaceToDepthOperator(
+ model, static_cast<const SpaceToDepthOperator&>(src_op),
+ tensorflow_graph);
+ } else if (src_op.type == OperatorType::kFullyConnected) {
+ ConvertFullyConnectedOperator(
+ model, static_cast<const FullyConnectedOperator&>(src_op),
+ tensorflow_graph);
+ } else if (src_op.type == OperatorType::kAdd) {
+ ConvertAddOperator(model, static_cast<const AddOperator&>(src_op),
+ tensorflow_graph);
+ } else if (src_op.type == OperatorType::kMul) {
+ ConvertMulOperator(model, static_cast<const MulOperator&>(src_op),
+ tensorflow_graph);
+ } else if (src_op.type == OperatorType::kRelu) {
+ ConvertReluOperator(static_cast<const ReluOperator&>(src_op),
+ tensorflow_graph);
+ } else if (src_op.type == OperatorType::kRelu1) {
+ ConvertRelu1Operator(static_cast<const Relu1Operator&>(src_op),
+ tensorflow_graph);
+ } else if (src_op.type == OperatorType::kRelu6) {
+ ConvertRelu6Operator(static_cast<const Relu6Operator&>(src_op),
+ tensorflow_graph);
+ } else if (src_op.type == OperatorType::kLogistic) {
+ ConvertLogisticOperator(static_cast<const LogisticOperator&>(src_op),
+ tensorflow_graph);
+ } else if (src_op.type == OperatorType::kTanh) {
+ ConvertTanhOperator(static_cast<const TanhOperator&>(src_op),
+ tensorflow_graph);
+ } else if (src_op.type == OperatorType::kL2Normalization) {
+ ConvertL2NormalizationOperator(
+ static_cast<const L2NormalizationOperator&>(src_op), tensorflow_graph);
+ } else if (src_op.type == OperatorType::kSoftmax) {
+ ConvertSoftmaxOperator(model, static_cast<const SoftmaxOperator&>(src_op),
+ tensorflow_graph);
+ } else if (src_op.type == OperatorType::kLocalResponseNormalization) {
+ ConvertLocalResponseNormalizationOperator(
+ static_cast<const LocalResponseNormalizationOperator&>(src_op),
+ tensorflow_graph);
+ } else if (src_op.type == OperatorType::kLstmCell) {
+ ConvertLstmCellOperator(model, static_cast<const LstmCellOperator&>(src_op),
+ tensorflow_graph);
+ } else if (src_op.type == OperatorType::kMaxPool) {
+ ConvertMaxPoolOperator(static_cast<const MaxPoolOperator&>(src_op),
+ tensorflow_graph);
+ } else if (src_op.type == OperatorType::kAveragePool) {
+ ConvertAveragePoolOperator(static_cast<const AveragePoolOperator&>(src_op),
+ tensorflow_graph);
+ } else if (src_op.type == OperatorType::kConcatenation) {
+ ConvertConcatenationOperator(
+ model, static_cast<const ConcatenationOperator&>(src_op),
+ tensorflow_graph);
+ } else if (src_op.type == OperatorType::kTensorFlowReshape) {
+ ConvertTensorFlowReshapeOperator(
+ model, static_cast<const TensorFlowReshapeOperator&>(src_op),
+ tensorflow_graph);
+ } else if (src_op.type == OperatorType::kL2Pool) {
+ ConvertL2PoolOperator(static_cast<const L2PoolOperator&>(src_op),
+ tensorflow_graph);
+ } else if (src_op.type == OperatorType::kTensorFlowSquare) {
+ ConvertSquareOperator(static_cast<const TensorFlowSquareOperator&>(src_op),
+ tensorflow_graph);
+ } else if (src_op.type == OperatorType::kTensorFlowSqrt) {
+ ConvertSqrtOperator(static_cast<const TensorFlowSqrtOperator&>(src_op),
+ tensorflow_graph);
+ } else if (src_op.type == OperatorType::kTensorFlowSplit) {
+ ConvertSplitOperator(model,
+ static_cast<const TensorFlowSplitOperator&>(src_op),
+ tensorflow_graph);
+ } else if (src_op.type == OperatorType::kFakeQuant) {
+ ConvertFakeQuantOperator(static_cast<const FakeQuantOperator&>(src_op),
+ tensorflow_graph);
+ } else if (src_op.type == OperatorType::kCast) {
+ ConvertCastOperator(model, static_cast<const CastOperator&>(src_op),
+ tensorflow_graph);
+ } else if (src_op.type == OperatorType::kFloor) {
+ ConvertFloorOperator(model, static_cast<const FloorOperator&>(src_op),
+ tensorflow_graph);
+ } else if (src_op.type == OperatorType::kGather) {
+ ConvertGatherOperator(model, static_cast<const GatherOperator&>(src_op),
+ tensorflow_graph);
+ } else if (src_op.type == OperatorType::kResizeBilinear) {
+ ConvertResizeBilinearOperator(
+ model, static_cast<const ResizeBilinearOperator&>(src_op),
+ tensorflow_graph);
+ } else if (src_op.type == OperatorType::kSpaceToBatchND) {
+ ConvertSpaceToBatchNDOperator(
+ model, static_cast<const SpaceToBatchNDOperator&>(src_op),
+ tensorflow_graph);
+ } else if (src_op.type == OperatorType::kBatchToSpaceND) {
+ ConvertBatchToSpaceNDOperator(
+ model, static_cast<const BatchToSpaceNDOperator&>(src_op),
+ tensorflow_graph);
+ } else if (src_op.type == OperatorType::kPad) {
+ ConvertPadOperator(model, static_cast<const PadOperator&>(src_op),
+ tensorflow_graph);
+ } else if (src_op.type == OperatorType::kStridedSlice) {
+ ConvertStridedSliceOperator(
+ model, static_cast<const StridedSliceOperator&>(src_op),
+ tensorflow_graph);
+ } else if (src_op.type == OperatorType::kMean) {
+ ConvertMeanOperator(model, static_cast<const MeanOperator&>(src_op),
+ tensorflow_graph);
+ } else if (src_op.type == OperatorType::kSub) {
+ ConvertSubOperator(model, static_cast<const SubOperator&>(src_op),
+ tensorflow_graph);
+ } else if (src_op.type == OperatorType::kTensorFlowMinimum) {
+ ConvertTensorFlowMinimumOperator(
+ model, static_cast<const TensorFlowMinimumOperator&>(src_op),
+ tensorflow_graph);
+ } else if (src_op.type == OperatorType::kTensorFlowMaximum) {
+ ConvertTensorFlowMaximumOperator(
+ model, static_cast<const TensorFlowMaximumOperator&>(src_op),
+ tensorflow_graph);
+ } else if (src_op.type == OperatorType::kSqueeze) {
+ ConvertSqueezeOperator(model, static_cast<const SqueezeOperator&>(src_op),
+ tensorflow_graph);
+ } else if (src_op.type == OperatorType::kSlice) {
+ ConvertSliceOperator(model, static_cast<const SliceOperator&>(src_op),
+ tensorflow_graph);
+ } else {
+ LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(src_op.type);
+ }
+}
+
+void AddPlaceholder(const string& name, GraphDef* tensorflow_graph) {
+ auto* placeholder = tensorflow_graph->add_node();
+ placeholder->set_op("Placeholder");
+ (*placeholder->mutable_attr())["dtype"].set_type(DT_FLOAT);
+ placeholder->set_name(name);
+}
+
+void AddPlaceholderForRNNState(const Model& model, const string& name, int size,
+ GraphDef* tensorflow_graph) {
+ auto* placeholder = tensorflow_graph->add_node();
+ placeholder->set_op("Placeholder");
+ placeholder->set_name(name);
+ (*placeholder->mutable_attr())["dtype"].set_type(DT_FLOAT);
+
+ auto* shape = (*placeholder->mutable_attr())["shape"].mutable_shape();
+ const auto& state_array = *model.arrays.at(name);
+ if (state_array.has_shape()) {
+ const auto& state_shape = state_array.shape();
+ const int kDims = state_shape.dimensions_count();
+ for (int i = 0; i < kDims; ++i) {
+ shape->add_dim()->set_size(state_shape.dims(i));
+ }
+ } else {
+ shape->add_dim()->set_size(1);
+ shape->add_dim()->set_size(size);
+ }
+}
+
+void ExportTensorFlowGraphDefImplementation(const Model& model,
+ GraphDef* tensorflow_graph) {
+ for (const auto& input_array : model.flags.input_arrays()) {
+ AddPlaceholder(input_array.name(), tensorflow_graph);
+ }
+ for (const auto& rnn_state : model.flags.rnn_states()) {
+ AddPlaceholderForRNNState(model, rnn_state.state_array(), rnn_state.size(),
+ tensorflow_graph);
+ }
+ for (const auto& op : model.operators) {
+ ConvertOperator(model, *op, tensorflow_graph);
+ }
+ // Generically export arrays that haven't been exported already
+ // by the above operators export. It's important that this comes
+ // after, as some operators need to export arrays that they reference
+ // in a specific way, rather than in the generic way done below.
+ for (const auto& array_pair : model.arrays) {
+ const string& array_name = array_pair.first;
+ const auto& array = *array_pair.second;
+ if (array.buffer) {
+ switch (array.data_type) {
+ case ArrayDataType::kFloat:
+ ConvertFloatTensorConst(model, array_name, tensorflow_graph);
+ break;
+ case ArrayDataType::kInt32:
+ ConvertIntTensorConst(model, array_name, tensorflow_graph);
+ break;
+ default:
+ break;
+ }
+ }
+ }
+}
+} // namespace
+
+void ExportTensorFlowGraphDef(const Model& model,
+ string* output_file_contents) {
+ CHECK(output_file_contents->empty());
+ GraphDef tensorflow_graph;
+ ExportTensorFlowGraphDefImplementation(model, &tensorflow_graph);
+ LogDumpGraphDef(kLogLevelModelChanged, "AT EXPORT", tensorflow_graph);
+ CHECK(tensorflow_graph.SerializeToString(output_file_contents));
+}
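+
+// Illustrative usage (hypothetical caller, not part of this change):
+//   string serialized_graph_def;
+//   ExportTensorFlowGraphDef(model, &serialized_graph_def);
+//   // serialized_graph_def now holds a binary tensorflow.GraphDef, ready to
+//   // be written to a .pb file.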
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/export_tensorflow.h b/tensorflow/contrib/lite/toco/export_tensorflow.h
new file mode 100644
index 0000000000..eca9774576
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/export_tensorflow.h
@@ -0,0 +1,27 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_EXPORT_TENSORFLOW_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_EXPORT_TENSORFLOW_H_
+
+#include <string>
+#include "tensorflow/contrib/lite/toco/model.h"
+
+namespace toco {
+
+void ExportTensorFlowGraphDef(const Model& model, string* output_file_contents);
+
+} // namespace toco
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_EXPORT_TENSORFLOW_H_
diff --git a/tensorflow/contrib/lite/toco/format_port.h b/tensorflow/contrib/lite/toco/format_port.h
new file mode 100644
index 0000000000..3bc3295d04
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/format_port.h
@@ -0,0 +1,77 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// This file provides equivalents of the internal util::format::FormatF and
+// util::format::AppendF helpers. Unfortunately, type safety is not as good as
+// that of a full C++ formatting implementation.
+// TODO(aselle): When absl adds support for StrFormat, use that instead.
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_FORMAT_PORT_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_FORMAT_PORT_H_
+
+#include "tensorflow/contrib/lite/toco/toco_types.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+
+namespace toco {
+namespace port {
+
+// Identity pass-through (default case).
+template <class T>
+T IdentityOrConvertStringToRaw(T foo) {
+ return foo;
+}
+
+// Overload taking std::string, converted to a raw C string.
+inline const char* IdentityOrConvertStringToRaw(const std::string& foo) {
+ return foo.c_str();
+}
+
+#if defined(PLATFORM_GOOGLE)
+// Overload taking the Google string type, converted to a raw C string.
+inline const char* IdentityOrConvertStringToRaw(const string& foo) {
+ return foo.c_str();
+}
+#endif // PLATFORM_GOOGLE
+// Delegate to TensorFlow Appendf function until absl has an equivalent.
+template <typename... Args>
+inline void AppendFHelper(string* destination, const char* fmt,
+ Args&&... args) {
+ tensorflow::strings::Appendf(destination, fmt, args...);
+}
+
+// Overload for a format string with no arguments; prints fmt literally so a
+// stray '%' in fmt is never interpreted (avoids a format-string bug).
+inline void AppendFHelper(string* destination, const char* fmt) {
+ tensorflow::strings::Appendf(destination, "%s", fmt);
+}
+
+// Appends the string formatted from fmt and args to the string pointed to by
+// destination. fmt follows C printf semantics; one departure is that %s
+// arguments may be std::string or string.
+template <typename... Args>
+inline void AppendF(string* destination, const char* fmt, Args&&... args) {
+ AppendFHelper(destination, fmt, IdentityOrConvertStringToRaw(args)...);
+}
+
+// Returns the string formatted from fmt and args. fmt follows C printf
+// semantics; one departure is that %s arguments may be std::string or string.
+template <typename... Args>
+inline string StringF(const char* fmt, Args&&... args) {
+ string result;
+ AppendFHelper(&result, fmt, IdentityOrConvertStringToRaw(args)...);
+ return result;
+}
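+
+// Illustrative usage (hypothetical names, not part of this change):
+//   string msg = StringF("array %s has %d dims", array_name, num_dims);
+//   AppendF(&msg, " (from %s)", source_file);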
+
+} // namespace port
+} // namespace toco
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_FORMAT_PORT_H_
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc b/tensorflow/contrib/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc
new file mode 100644
index 0000000000..bf454c40c7
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc
@@ -0,0 +1,98 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+bool ConvertPureConvToDepthwise::Run(Model* model, std::size_t op_index) {
+ auto conv_it = model->operators.begin() + op_index;
+ if (conv_it->get()->type != OperatorType::kConv) {
+ return false;
+ }
+ const auto* conv_op = static_cast<ConvOperator*>(conv_it->get());
+ if (conv_op->stride_width != conv_op->stride_height) {
+ return false;
+ }
+ auto& weights_array = model->GetArray(conv_op->inputs[1]);
+ if (!weights_array.buffer) {
+ // Yield until the weights are resolved as a constant array.
+ return false;
+ }
+ if (weights_array.data_type != ArrayDataType::kFloat) {
+ return false;
+ }
+ if (weights_array.shape().dims(3) != 1) {
+ // Not a pure convolution: Conv does accumulation across the depth
+ // dimension.
+ return false;
+ }
+ // At this point we know we have a pure conv. Rewrite it as DepthwiseConv.
+ AddMessageF(
+ "%s is purely convolutional (input/weights depth is 1), replacing it by "
+ "a DepthwiseConv.",
+ LogName(*conv_op));
+ auto* depthwiseconv_op = new DepthwiseConvOperator;
+ // Conv and DepthwiseConv take the same inputs
+ depthwiseconv_op->inputs = conv_op->inputs;
+ // Conv may have a 2nd output for im2col
+ depthwiseconv_op->outputs = {conv_op->outputs[0]};
+ if (conv_op->outputs.size() > 1) {
+    // Delete the im2col array, which DepthwiseConv does not use.
+ model->arrays.erase(conv_op->outputs[1]);
+ }
+ depthwiseconv_op->fused_activation_function =
+ conv_op->fused_activation_function;
+ // Let PropagateFixedSizes recompute fixed padding, just in case some day it
+ // may be different for Conv vs DepthwiseConv.
+ depthwiseconv_op->padding.type = conv_op->padding.type;
+ depthwiseconv_op->stride_height = conv_op->stride_height;
+ depthwiseconv_op->stride_width = conv_op->stride_width;
+ depthwiseconv_op->depth_multiplier = weights_array.shape().dims(0);
+ // Replace the operator in the graph.
+ const auto depthwiseconv_it =
+ model->operators.emplace(conv_it, depthwiseconv_op);
+ conv_it = depthwiseconv_it + 1;
+ CHECK_EQ(conv_it->get(), conv_op);
+ model->operators.erase(conv_it);
+ // Shuffle the weights.
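+  // Illustrative example (not in the original change): for a kernel of shape
+  // [depth = 8, 3, 3, 1], width_height is 9, so the element at conv index
+  // xy + 9 * c moves to depthwise index c + 8 * xy, and the shape is rewritten
+  // to [1, 3, 3, 8] below.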
+ const auto& weights_shape = weights_array.shape();
+ auto& weights_buffer =
+ weights_array.GetMutableBuffer<ArrayDataType::kFloat>();
+ const std::vector<float>& conv_weights_data = weights_buffer.data;
+ std::vector<float> depthwise_conv_weights_data(conv_weights_data.size());
+ const int depth = weights_shape.dims(0);
+ const int width = weights_shape.dims(1);
+ const int height = weights_shape.dims(2);
+ const int width_height = width * height;
+ for (int c = 0; c < depth; c++) {
+ for (int xy = 0; xy < width_height; xy++) {
+ depthwise_conv_weights_data[c + depth * xy] =
+ conv_weights_data[xy + width_height * c];
+ }
+ }
+ *weights_array.mutable_shape()->mutable_dims() = {1, width, height, depth};
+ weights_buffer.data = depthwise_conv_weights_data;
+ return true;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/create_im2col_arrays.cc b/tensorflow/contrib/lite/toco/graph_transformations/create_im2col_arrays.cc
new file mode 100644
index 0000000000..1735b51e5b
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/create_im2col_arrays.cc
@@ -0,0 +1,69 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "absl/strings/str_cat.h"
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+bool CreateIm2colArrays::Run(Model* model, std::size_t op_index) {
+ auto conv_it = model->operators.begin() + op_index;
+ if (conv_it->get()->type != OperatorType::kConv) {
+ return false;
+ }
+ auto* conv_op = static_cast<ConvOperator*>(conv_it->get());
+ if (conv_op->outputs.size() == 2) {
+ // We already have an im2col array
+ return false;
+ }
+ const auto& weights_array = *model->arrays[conv_op->inputs[1]];
+ if (!weights_array.has_shape()) {
+    // Yield until the weights' dims have been resolved: they determine
+    // whether an im2col array is needed.
+ return false;
+ }
+ const auto& weights_shape = weights_array.shape();
+ const int kheight = weights_shape.dims(1);
+ const int kwidth = weights_shape.dims(2);
+ if (kwidth == 1 && kheight == 1 && conv_op->stride_width == 1 &&
+ conv_op->stride_height == 1) {
+ // 1x1 unstrided conv does not need an im2col array.
+ return false;
+ }
+
+ // Create the im2col array.
+ CHECK_EQ(conv_op->outputs.size(), 1);
+ const string& im2col_array_name =
+ AvailableArrayName(*model, conv_op->inputs[0] + "_im2col");
+ model->GetOrCreateArray(im2col_array_name);
+ conv_op->outputs.push_back(im2col_array_name);
+ AddMessageF(
+ "Created an im2col array for %s, with %dx%d kernel and stride_width=%d, "
+ "stride_height=%d",
+ LogName(*conv_op), kwidth, kheight, conv_op->stride_width,
+ conv_op->stride_height);
+
+ return true;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/dequantize.cc b/tensorflow/contrib/lite/toco/graph_transformations/dequantize.cc
new file mode 100644
index 0000000000..b89e3f5310
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/dequantize.cc
@@ -0,0 +1,223 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+namespace {
+
+template <ArrayDataType A>
+void DequantizeBuffer(Array* array) {
+ const auto old_data = array->GetBuffer<A>().data;
+ array->buffer = nullptr;
+ array->data_type = ArrayDataType::kFloat;
+ auto& new_data = array->GetMutableBuffer<ArrayDataType::kFloat>().data;
+ new_data.resize(old_data.size());
+ const auto& qparams = array->GetQuantizationParams();
+ for (int i = 0; i < old_data.size(); i++) {
+ new_data[i] = qparams.scale * (old_data[i] - qparams.zero_point);
+ }
+}
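+// Illustrative values (not from the original change): with scale = 0.5 and
+// zero_point = 128, a stored quantized value of 130 dequantizes to
+// 0.5f * (130 - 128) = 1.0f.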
+
+std::vector<std::unique_ptr<Operator>>::iterator FindFirstOpWithInput(
+ Model* model, const string& array_name) {
+ for (auto it = model->operators.begin(); it != model->operators.end(); ++it) {
+ for (const auto& input : it->get()->inputs) {
+ if (input == array_name) {
+ return it;
+ }
+ }
+ }
+ return model->operators.end();
+}
+
+void ClearArrayQuantizationParams(const string& array_name, Model* model) {
+ auto* array = model->arrays.at(array_name).get();
+ CHECK(array->quantization_params);
+ for (auto& input_array : *model->flags.mutable_input_arrays()) {
+ if (input_array.name() == array_name) {
+ auto& qparams = *array->quantization_params;
+ const double new_std_value = 1. / qparams.scale;
+ const double new_mean_value = qparams.zero_point;
+ if (input_array.has_std_value()) {
+ CHECK_LE(std::abs(new_std_value - input_array.std_value()), 0.001);
+ } else {
+ input_array.set_std_value(new_std_value);
+ }
+ if (input_array.has_mean_value()) {
+ CHECK_LE(std::abs(new_mean_value - input_array.mean_value()), 0.001);
+ } else {
+ input_array.set_mean_value(new_mean_value);
+ }
+ }
+ }
+ array->quantization_params = nullptr;
+}
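+// Illustrative values (not from the original change): an input array quantized
+// with scale = 1/255 and zero_point = 0 leaves std_value = 255 and
+// mean_value = 0 in the model flags after the function above runs.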
+
+bool DequantizeArray(const string& array_name,
+ GraphTransformation* transformation, Model* model) {
+ auto* array = model->arrays.at(array_name).get();
+ if (!array->quantization_params) {
+ return false;
+ }
+ transformation->AddMessageF("Dequantizing array: %s", array_name);
+
+ // Dequantize any buffer
+ if (array->buffer) {
+ if (array->data_type == ArrayDataType::kUint8) {
+ DequantizeBuffer<ArrayDataType::kUint8>(array);
+ } else if (array->data_type == ArrayDataType::kInt32) {
+ DequantizeBuffer<ArrayDataType::kInt32>(array);
+ } else {
+ LOG(FATAL) << "Unhandled data type";
+ }
+ CHECK(array->data_type == ArrayDataType::kFloat);
+ CHECK(array->buffer->type == ArrayDataType::kFloat);
+
+    // Clear the quantization params; this officially makes the array
+    // non-quantized.
+ ClearArrayQuantizationParams(array_name, model);
+ return true;
+ } else {
+ array->data_type = ArrayDataType::kFloat;
+ }
+
+  // Clear the quantization params; this officially makes the array
+  // non-quantized.
+ ClearArrayQuantizationParams(array_name, model);
+
+ if (array->buffer) {
+ return true;
+ }
+
+ auto* op_outputting_array = GetOpWithOutput(*model, array_name);
+ if (op_outputting_array) {
+ if (op_outputting_array->type == OperatorType::kTensorFlowReshape) {
+ return true;
+ }
+ }
+
+ // If there was no minmax info, we can return now. Indeed,
+ // the below only serves to create a FakeQuant node, but some arrays are
+ // quantized without MinMax (see the CHECK above) and that corresponds to
+ // places where a FakeQuant node is actually not wanted, because the
+ // quantization params are meant to be inferred in another way (e.g. bias
+ // vector for a Conv op, see their special-casing in quantize.cc).
+ if (!array->minmax) {
+ return true;
+ }
+
+ // Determine whether to insert a FakeQuant before or after
+ // this array.
+ bool must_insert_fakequant_before = false;
+ bool must_insert_fakequant_after = false;
+ if (IsInputArray(*model, array_name)) {
+ must_insert_fakequant_after = true;
+ }
+ for (const string& output_array : model->flags.output_arrays()) {
+ if (array_name == output_array) {
+ must_insert_fakequant_before = true;
+ }
+ }
+ for (const auto& rnn_state : model->flags.rnn_states()) {
+ if (array_name == rnn_state.state_array()) {
+ must_insert_fakequant_after = true;
+ }
+ if (array_name == rnn_state.back_edge_source_array()) {
+ must_insert_fakequant_before = true;
+ }
+ }
+ CHECK(!(must_insert_fakequant_before && must_insert_fakequant_after));
+
+ // Create and insert the FakeQuant node
+ auto* fakequant_op = new FakeQuantOperator;
+ model->operators.emplace(FindFirstOpWithInput(model, array_name),
+ fakequant_op);
+ const string& new_array_name = AvailableArrayName(*model, array_name);
+ auto& new_array = model->GetOrCreateArray(new_array_name);
+ new_array.data_type = ArrayDataType::kFloat;
+ new_array.copy_shape(array->shape());
+ new_array.GetOrCreateMinMax() = array->GetMinMax();
+ fakequant_op->minmax.reset(new MinMax);
+ *fakequant_op->minmax = array->GetMinMax();
+ if (must_insert_fakequant_before) {
+ for (const auto& op : model->operators) {
+ for (string& output : op->outputs) {
+ if (output == array_name) {
+ output = new_array_name;
+ }
+ }
+ }
+ fakequant_op->inputs = {new_array_name};
+ fakequant_op->outputs = {array_name};
+ } else {
+ for (const auto& op : model->operators) {
+ for (string& input : op->inputs) {
+ if (input == array_name) {
+ input = new_array_name;
+ }
+ }
+ }
+ fakequant_op->inputs = {array_name};
+ fakequant_op->outputs = {new_array_name};
+ }
+ return true;
+}
+
+} // namespace
+
+bool Dequantize::Run(Model* model, std::size_t op_index) {
+ const auto op_it = model->operators.begin() + op_index;
+ auto* op = op_it->get();
+
+ if (op->type == OperatorType::kDequantize) {
+ auto& input_array = model->GetArray(op->inputs[0]);
+ if (input_array.data_type == ArrayDataType::kFloat) {
+ return false;
+ }
+ if (input_array.final_data_type != ArrayDataType::kFloat) {
+ return false;
+ }
+ input_array.data_type = ArrayDataType::kFloat;
+ input_array.quantization_params = nullptr;
+ auto& output_array = model->GetArray(op->outputs[0]);
+ output_array.data_type = ArrayDataType::kFloat;
+ output_array.quantization_params = nullptr;
+ return RemoveTrivialPassthroughOp(this, model, op_index);
+ }
+
+ std::vector<string> arrays;
+ for (const string& input : op->inputs) {
+ arrays.push_back(input);
+ }
+ for (const string& output : op->outputs) {
+ arrays.push_back(output);
+ }
+ bool changed = false;
+ for (const string& array : arrays) {
+ changed |= DequantizeArray(array, this, model);
+ }
+
+ return changed;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/drop_fake_quant.cc b/tensorflow/contrib/lite/toco/graph_transformations/drop_fake_quant.cc
new file mode 100644
index 0000000000..fea360740f
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/drop_fake_quant.cc
@@ -0,0 +1,56 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+bool DropFakeQuant::Run(Model* model, std::size_t op_index) {
+ const auto fakequant_it = model->operators.begin() + op_index;
+ auto* fakequant_base_op = fakequant_it->get();
+ if (fakequant_base_op->type != OperatorType::kFakeQuant) {
+ return false;
+ }
+ auto* fakequant_op = static_cast<FakeQuantOperator*>(fakequant_base_op);
+
+ if (!fakequant_op->minmax) {
+ return false;
+ }
+
+ const auto& output_array = model->GetArray(fakequant_op->outputs[0]);
+ if (!output_array.minmax) {
+ return false;
+ }
+
+ // Drop min/max inputs
+ for (int i = 1; i < fakequant_op->inputs.size(); i++) {
+ if (CountOpsWithInput(*model, fakequant_op->inputs[i]) == 1) {
+ model->arrays.erase(fakequant_op->inputs[i]);
+ }
+ }
+ fakequant_op->inputs.resize(1);
+
+ return RemoveTrivialPassthroughOp(this, model, op_index);
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/drop_im2col_arrays.cc b/tensorflow/contrib/lite/toco/graph_transformations/drop_im2col_arrays.cc
new file mode 100644
index 0000000000..a3ed6663bc
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/drop_im2col_arrays.cc
@@ -0,0 +1,42 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+bool DropIm2colArrays::Run(Model* model, std::size_t op_index) {
+ auto conv_it = model->operators.begin() + op_index;
+ if (conv_it->get()->type != OperatorType::kConv) {
+ return false;
+ }
+ auto* conv_op = static_cast<ConvOperator*>(conv_it->get());
+ if (conv_op->outputs.size() < 2) {
+ // Conv op does not have im2col.
+ return false;
+ }
+
+ // Drop the im2col array.
+ CHECK_EQ(conv_op->outputs.size(), 2);
+ model->arrays.erase(conv_op->outputs[1]);
+ conv_op->outputs.resize(1);
+ AddMessageF("Dropped an im2col array for %s", LogName(*conv_op));
+
+ return true;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/ensure_bias_vectors.cc b/tensorflow/contrib/lite/toco/graph_transformations/ensure_bias_vectors.cc
new file mode 100644
index 0000000000..badefeca88
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/ensure_bias_vectors.cc
@@ -0,0 +1,57 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+namespace {
+
+bool ProcessLinearOperator(Model* model, Operator* op) {
+ if (op->inputs.size() >= 3) {
+ return false;
+ }
+ const string& output_name = op->outputs[0];
+ const string& bias_name = AvailableArrayName(*model, output_name + "_bias");
+ op->inputs.push_back(bias_name);
+ DCHECK_EQ(op->inputs.size(), 3);
+ auto& bias_array = model->GetOrCreateArray(bias_name);
+ bias_array.data_type = ArrayDataType::kFloat;
+
+ return true;
+}
+} // namespace
+
+bool EnsureBiasVectors::Run(Model* model, std::size_t op_index) {
+ auto* op = model->operators[op_index].get();
+ if (op->type == OperatorType::kConv ||
+ op->type == OperatorType::kDepthwiseConv ||
+ op->type == OperatorType::kFullyConnected) {
+ if (ProcessLinearOperator(model, op)) {
+ AddMessageF("Added bias vector to %s", LogName(*op));
+ return true;
+ }
+ }
+ return false;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/fuse_activation_functions.cc b/tensorflow/contrib/lite/toco/graph_transformations/fuse_activation_functions.cc
new file mode 100644
index 0000000000..7a86510025
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/fuse_activation_functions.cc
@@ -0,0 +1,98 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/runtime/types.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+bool FuseActivationFunctions::Run(Model* model, std::size_t op_index) {
+ const auto ac_it = model->operators.begin() + op_index;
+ const auto* ac_op = ac_it->get();
+
+ if (ac_op->type != OperatorType::kRelu6 &&
+ ac_op->type != OperatorType::kRelu1 &&
+ ac_op->type != OperatorType::kRelu) {
+ return false;
+ }
+
+ // Find the op producing the array passed to this activation function
+ Operator* op = GetOpWithOutput(*model, ac_op->inputs[0]);
+
+ if (!op) return false;
+
+ if (CountTrueOutputs(*model, *op) > 1) {
+ AddMessageF(
+ "Not fusing activation function into %s because it has more than one "
+ " consumed output",
+ LogName(*op));
+ return false;
+ }
+
+ CHECK_EQ(op->outputs[0], ac_op->inputs[0]);
+
+ int count_ops_consuming_output = CountOpsWithInput(*model, ac_op->inputs[0]);
+ DCHECK_GE(count_ops_consuming_output, 1);
+ if (count_ops_consuming_output > 1) {
+ AddMessageF(
+ "Not fusing activation function into %s because it is consumed by more "
+ "than 1 other operator",
+ LogName(*op));
+ return false;
+ }
+
+ if (op->fused_activation_function != FusedActivationFunctionType::kNone) {
+ AddMessageF(
+ "Not fusing activation function into %s because it already has a fused "
+ "activation function",
+ LogName(*op));
+ return false;
+ }
+
+  // TODO(dkalenichenko): A great many ops don't support activation function
+  // fusing. Switch to a whitelist approach instead.
+ if (op->type == OperatorType::kConcatenation ||
+ op->type == OperatorType::kSlice) {
+ AddMessageF(
+ "Not fusing activation function because the %s op doesn't support it",
+ LogName(*op));
+ return false;
+ }
+
+ AddMessageF("Fusing activation function %s into the preceding %s",
+ LogName(*ac_op), LogName(*op));
+ if (ac_op->type == OperatorType::kRelu6) {
+ op->fused_activation_function = FusedActivationFunctionType::kRelu6;
+ } else if (ac_op->type == OperatorType::kRelu1) {
+ op->fused_activation_function = FusedActivationFunctionType::kRelu1;
+ } else if (ac_op->type == OperatorType::kRelu) {
+ op->fused_activation_function = FusedActivationFunctionType::kRelu;
+ } else {
+ LOG(FATAL) << "Unhandled activation function type";
+ }
+ model->arrays.erase(ac_op->inputs[0]);
+ op->outputs[0] = ac_op->outputs[0];
+ model->operators.erase(ac_it);
+ return true;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_following_affine.cc b/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_following_affine.cc
new file mode 100644
index 0000000000..4619d8bbee
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_following_affine.cc
@@ -0,0 +1,300 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <algorithm>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/runtime/types.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+namespace {
+
+void FuseAddOrSubParamsIntoFollowingAffine(Model* model, Operator* following_op,
+ const Operator* add_or_sub_op,
+ int index_of_constant_input) {
+ CHECK(add_or_sub_op->type == OperatorType::kAdd ||
+ add_or_sub_op->type == OperatorType::kSub);
+ CHECK(index_of_constant_input == 0 || index_of_constant_input == 1);
+ // If the op is a subtraction, the constant input should be the right hand
+ // side.
+ // This should have been checked before this point.
+ CHECK(add_or_sub_op->type != OperatorType::kSub ||
+ index_of_constant_input == 1);
+ if (following_op->inputs.size() < 3) {
+ LOG(FATAL) << "Missing bias parameter";
+ }
+ const auto& weights = model->GetArray(following_op->inputs[1]);
+ auto& bias = model->GetArray(following_op->inputs[2]);
+ bias.minmax = nullptr;
+ const auto& operand =
+ model->GetArray(add_or_sub_op->inputs[index_of_constant_input]);
+ // We're only supporting the case of a scalar operand. Should have
+ // been checked earlier.
+ CHECK_EQ(RequiredBufferSizeForShape(operand.shape()), 1);
+
+ const float scalar_operand =
+ operand.GetBuffer<ArrayDataType::kFloat>().data[0];
+ // At this point we reduce the case of subtraction to that of addition
+ // by negating the operand.
+ float add_scalar_operand = 0.f;
+ if (add_or_sub_op->type == OperatorType::kAdd) {
+ add_scalar_operand = scalar_operand;
+ } else if (add_or_sub_op->type == OperatorType::kSub &&
+ index_of_constant_input == 1) {
+ add_scalar_operand = -scalar_operand;
+ } else {
+ LOG(FATAL) << "Should not get here";
+ }
+ // From here on we are fusing an addition. add_or_sub_op->type does not
+ // matter anymore.
+
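+  // A sketch of the algebra being applied: fusing "affine(x + c)" for a
+  // scalar c means that, for each output channel d,
+  //   W_d . (x + c) + b_d = W_d . x + (b_d + c * sum_i W[d, i]),
+  // so each channel's bias receives c times the sum of that channel's
+  // weights, which is what the per-depth accumulation below computes.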
+ const Shape& weights_shape = weights.shape();
+ const Shape& bias_shape = bias.shape();
+ const auto& weights_buffer = weights.GetBuffer<ArrayDataType::kFloat>();
+ const float* const weights_data = weights_buffer.data.data();
+ auto& bias_buffer = bias.GetMutableBuffer<ArrayDataType::kFloat>();
+ float* const bias_data = bias_buffer.data.data();
+
+ if (following_op->type == OperatorType::kConv ||
+ following_op->type == OperatorType::kFullyConnected) {
+ const int output_depth = weights_shape.dims(0);
+ // TODO(b/62904716): Bias array should become 1-D when padding removed.
+ CHECK_EQ(output_depth, bias_shape.dims(bias_shape.dimensions_count() - 1));
+ const int weights_size = RequiredBufferSizeForShape(weights_shape);
+ const int weights_per_depth = weights_size / output_depth;
+ CHECK_EQ(weights_size, weights_per_depth * output_depth);
+
+ for (int d = 0; d < output_depth; d++) {
+ float accumulation = 0;
+ for (int i = 0; i < weights_per_depth; i++) {
+ accumulation +=
+ add_scalar_operand * weights_data[d * weights_per_depth + i];
+ }
+ bias_data[d] += accumulation;
+ }
+ } else if (following_op->type == OperatorType::kDepthwiseConv) {
+ const int output_depth =
+ weights_shape.dims(weights_shape.dimensions_count() - 1);
+ const int weights_size = RequiredBufferSizeForShape(weights_shape);
+ const int weights_per_depth = weights_size / output_depth;
+ CHECK_EQ(weights_size, weights_per_depth * output_depth);
+
+ for (int c = 0; c < output_depth; c++) {
+ float accumulation = 0;
+ for (int k = 0; k < weights_per_depth; k++) {
+ accumulation += add_scalar_operand * weights_data[k * output_depth + c];
+ }
+ bias_data[c] += accumulation;
+ }
+ } else {
+ LOG(FATAL) << "Should not get here.";
+ }
+}
+
+void FuseMulOrDivParamsIntoFollowingAffine(Model* model, Operator* following_op,
+ const Operator* mul_or_div_op,
+ int index_of_constant_input) {
+ CHECK(mul_or_div_op->type == OperatorType::kMul ||
+ mul_or_div_op->type == OperatorType::kDiv);
+ CHECK(index_of_constant_input == 0 || index_of_constant_input == 1);
+ // If the op is a division, the constant input should be the right hand side.
+ // This should have been checked before this point.
+ CHECK(mul_or_div_op->type != OperatorType::kDiv ||
+ index_of_constant_input == 1);
+ const auto& weights_name = following_op->inputs[1];
+ const auto& bias_name = following_op->inputs[2];
+ auto& weights = model->GetArray(weights_name);
+ DropMinMax(model, weights_name);
+ DropMinMax(model, bias_name);
+ const auto& operand =
+ model->GetArray(mul_or_div_op->inputs[index_of_constant_input]);
+ // We're only supporting the case of a scalar operand. Should have
+ // been checked earlier.
+ CHECK_EQ(RequiredBufferSizeForShape(operand.shape()), 1);
+
+ const float scalar_operand =
+ operand.GetBuffer<ArrayDataType::kFloat>().data[0];
+
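+  // A sketch of the algebra: fusing "affine(x * c)" (or "affine(x / c)") for
+  // a scalar c gives
+  //   W * (c * x) + b = (c * W) * x + b,
+  // so only the weights need rescaling here; the bias is left untouched.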
+ float* weights_data =
+ weights.GetMutableBuffer<ArrayDataType::kFloat>().data.data();
+ const int weights_size = RequiredBufferSizeForShape(weights.shape());
+ for (int i = 0; i < weights_size; i++) {
+ if (mul_or_div_op->type == OperatorType::kMul) {
+ weights_data[i] *= scalar_operand;
+ } else if (mul_or_div_op->type == OperatorType::kDiv) {
+ weights_data[i] /= scalar_operand;
+ } else {
+ LOG(FATAL) << "Should not get here";
+ }
+ }
+}
+
+} // namespace
+
+bool FuseBinaryIntoFollowingAffine::Run(Model* model, std::size_t op_index) {
+ const auto binary_it = model->operators.begin() + op_index;
+ auto* binary_op = binary_it->get();
+ if (binary_op->type != OperatorType::kAdd &&
+ binary_op->type != OperatorType::kMul &&
+ binary_op->type != OperatorType::kSub &&
+ binary_op->type != OperatorType::kDiv) {
+ return false;
+ }
+
+ CHECK_EQ(binary_op->inputs.size(), 2);
+
+  // We can only fuse a binary op when its two operands break down as follows:
+  // 1. One operand is the (variable) output of a typical affine (linear plus
+  //    bias) op of a finite list of possible types: at the moment Conv,
+  //    DepthwiseConv and FullyConnected are supported.
+  // 2. The other operand is a constant param array.
+ const bool is_input_constant[2] = {
+ IsConstantParameterArray(*model, binary_op->inputs[0]),
+ IsConstantParameterArray(*model, binary_op->inputs[1]),
+ };
+ if (!is_input_constant[0] && !is_input_constant[1]) {
+ // Neither input is constant, so nothing we can fuse into a constant.
+ return false;
+ }
+ if (is_input_constant[0] && is_input_constant[1]) {
+ // Both inputs are constants. That's a job for constants
+ // propagation, not for us to handle here.
+ return false;
+ }
+ const int index_of_constant_input = is_input_constant[0] ? 0 : 1;
+ const int index_of_variable_input = is_input_constant[0] ? 1 : 0;
+ CHECK(is_input_constant[index_of_constant_input]);
+ CHECK(!is_input_constant[index_of_variable_input]);
+
+ // For division, we can only fuse if the denominator is constant.
+ if (binary_op->type == OperatorType::kDiv) {
+ if (index_of_constant_input != 1) {
+ AddMessageF("Not fusing %s because the denominator is not constant",
+ LogName(*binary_op));
+ return false;
+ }
+ }
+
+ const auto& operand_shape =
+ model->GetArray(binary_op->inputs[index_of_constant_input]).shape();
+ for (const auto& dim : operand_shape.dims()) {
+ if (dim > 1) {
+ AddMessageF(
+ "Not fusing %s into the following affine op, because we only know "
+ "how to do so when the constant operand is a scalar",
+ LogName(*binary_op));
+ return false;
+ }
+ }
+
+ if (binary_op->fused_activation_function !=
+ FusedActivationFunctionType::kNone) {
+ AddMessageF("Not fusing %s because it has a fused activation function",
+ LogName(*binary_op));
+ return false;
+ }
+
+ Operator* following_op = GetOpWithInput(*model, binary_op->outputs[0]);
+
+ if (!following_op) {
+ AddMessageF(
+ "Not fusing %s because it is not consumed by exactly one other op",
+ LogName(*binary_op));
+ return false;
+ }
+
+ if (following_op->type != OperatorType::kConv &&
+ following_op->type != OperatorType::kFullyConnected &&
+ following_op->type != OperatorType::kDepthwiseConv) {
+ AddMessageF(
+ "Not fusing %s because the following %s is not of one of the supported "
+ "types",
+ LogName(*binary_op), LogName(*following_op));
+ return false;
+ }
+
+ if (following_op->inputs.size() < 3) {
+ AddMessageF(
+ "Not fusing %s because the following %s does not have a bias vector",
+        LogName(*binary_op), LogName(*following_op));
+ return false;
+ }
+
+ const auto& weights = model->GetArray(following_op->inputs[1]);
+ const auto& bias = model->GetArray(following_op->inputs[2]);
+ if (!weights.buffer || !bias.buffer) {
+ AddMessageF(
+ "Not fusing %s because the following %s has non-constant weights or "
+ "bias arrays",
+ LogName(*binary_op), LogName(*following_op));
+ return false;
+ }
+
+ // Try to fuse the binary params into the following op's params
+ if (binary_op->type == OperatorType::kAdd ||
+ binary_op->type == OperatorType::kSub) {
+ if (following_op->type == OperatorType::kConv) {
+ if (static_cast<ConvOperator*>(following_op)->padding.type !=
+ PaddingType::kValid) {
+ AddMessageF(
+ "Not fusing %s because the following %s does not use VALID padding",
+ LogName(*binary_op), LogName(*following_op));
+ return false;
+ }
+ }
+ if (following_op->type == OperatorType::kDepthwiseConv) {
+ if (static_cast<DepthwiseConvOperator*>(following_op)->padding.type !=
+ PaddingType::kValid) {
+ AddMessageF(
+ "Not fusing %s because the following %s does not use VALID padding",
+ LogName(*binary_op), LogName(*following_op));
+ return false;
+ }
+ }
+ FuseAddOrSubParamsIntoFollowingAffine(model, following_op, binary_op,
+ index_of_constant_input);
+ } else if (binary_op->type == OperatorType::kMul ||
+ binary_op->type == OperatorType::kDiv) {
+ FuseMulOrDivParamsIntoFollowingAffine(model, following_op, binary_op,
+ index_of_constant_input);
+ } else {
+ LOG(FATAL) << "should not get here";
+ }
+
+ AddMessageF("Fusing %s into the following %s", LogName(*binary_op),
+ LogName(*following_op));
+
+ model->arrays.erase(binary_op->outputs[0]);
+ following_op->inputs[0] = binary_op->inputs[index_of_variable_input];
+ const auto& old_constant_param_name =
+ binary_op->inputs[index_of_constant_input];
+ CHECK(IsConstantParameterArray(*model, old_constant_param_name));
+ if (CountOpsWithInput(*model, old_constant_param_name) == 1) {
+ model->arrays.erase(old_constant_param_name);
+ }
+ model->operators.erase(binary_it);
+ return true;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc b/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc
new file mode 100644
index 0000000000..8948653ec3
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc
@@ -0,0 +1,326 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/runtime/types.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+namespace {
+
+void FuseAddOrSubParamsIntoPrecedingAffine(Model* model, Operator* preceding_op,
+ const Operator* add_or_sub_op,
+ int index_of_constant_input) {
+ CHECK(add_or_sub_op->type == OperatorType::kAdd ||
+ add_or_sub_op->type == OperatorType::kSub);
+ CHECK(index_of_constant_input == 0 || index_of_constant_input == 1);
+ if (preceding_op->inputs.size() < 3) {
+ LOG(FATAL) << "Missing bias parameter";
+ }
+ auto& bias = model->GetArray(preceding_op->inputs[2]);
+ bias.minmax = nullptr;
+ const auto& operand =
+ model->GetArray(add_or_sub_op->inputs[index_of_constant_input]);
+
+ const Shape& bias_shape = bias.shape();
+ const Shape& operand_shape = operand.shape();
+ auto& bias_buffer = bias.GetMutableBuffer<ArrayDataType::kFloat>();
+ float* const bias_data = bias_buffer.data.data();
+ const auto& operand_buffer = operand.GetBuffer<ArrayDataType::kFloat>();
+ const float* const operand_data = operand_buffer.data.data();
+
+ // TODO(b/62904716): Bias array should become 1-D when padding removed.
+ const int depth = bias_shape.dims(bias_shape.dimensions_count() - 1);
+ CHECK_EQ(depth, operand_shape.dims(operand_shape.dimensions_count() - 1));
+
+ enum class OpType { BiasPlusOperand, BiasMinusOperand, OperandMinusBias };
+
+ const OpType optype = (add_or_sub_op->type == OperatorType::kAdd)
+ ? OpType::BiasPlusOperand
+ : (index_of_constant_input == 1)
+ ? OpType::BiasMinusOperand
+ : OpType::OperandMinusBias;
+
+ for (int i = 0; i < depth; i++) {
+ float& bias_val = bias_data[i];
+ const float operand_val = operand_data[i];
+ if (optype == OpType::BiasPlusOperand) {
+ bias_val += operand_val;
+ } else if (optype == OpType::BiasMinusOperand) {
+ bias_val -= operand_val;
+ } else if (optype == OpType::OperandMinusBias) {
+ bias_val = operand_val - bias_val;
+ } else {
+ LOG(FATAL) << "Should not get here.";
+ }
+ }
+}
+
+void FuseMulOrDivParamsIntoPrecedingAffine(Model* model, Operator* preceding_op,
+ const Operator* mul_or_div_op,
+ int index_of_constant_input) {
+ CHECK(mul_or_div_op->type == OperatorType::kMul ||
+ mul_or_div_op->type == OperatorType::kDiv);
+ CHECK(index_of_constant_input == 0 || index_of_constant_input == 1);
+ // If the op is a division, the constant input should be the right hand side.
+ // This should have been checked before this point.
+ CHECK(mul_or_div_op->type != OperatorType::kDiv ||
+ index_of_constant_input == 1);
+ if (preceding_op->inputs.size() < 3) {
+ LOG(FATAL) << "Missing bias parameter";
+ }
+ const auto& weights_name = preceding_op->inputs[1];
+ const auto& bias_name = preceding_op->inputs[2];
+ auto& weights = model->GetArray(weights_name);
+ DropMinMax(model, weights_name);
+ auto& bias = model->GetArray(bias_name);
+ DropMinMax(model, bias_name);
+ const auto& operand =
+ model->GetArray(mul_or_div_op->inputs[index_of_constant_input]);
+
+ const Shape& weights_shape = weights.shape();
+ const Shape& bias_shape = bias.shape();
+ const Shape& operand_shape = operand.shape();
+ auto& weights_buffer = weights.GetMutableBuffer<ArrayDataType::kFloat>();
+ float* const weights_data = weights_buffer.data.data();
+ auto& bias_buffer = bias.GetMutableBuffer<ArrayDataType::kFloat>();
+ float* const bias_data = bias_buffer.data.data();
+ const auto& operand_buffer = operand.GetBuffer<ArrayDataType::kFloat>();
+ const float* const operand_data = operand_buffer.data.data();
+
+ // We support broadcasting the operand along the depth dimension,
+ // when the operand's depth is 1.
+ int operand_channel_increment = 0;
+ if (operand_shape.dimensions_count() >= 1 &&
+ operand_shape.dims(operand_shape.dimensions_count() - 1) ==
+ bias_shape.dims(bias_shape.dimensions_count() - 1)) {
+ operand_channel_increment = 1;
+ } else if (operand_shape.dimensions_count() == 0 ||
+ operand_shape.dims(operand_shape.dimensions_count() - 1) == 1) {
+ operand_channel_increment = 0;
+ } else {
+ LOG(FATAL) << "Operand shape mismatch.";
+ }
+
+ int output_depth;
+
+ if (preceding_op->type == OperatorType::kConv ||
+ preceding_op->type == OperatorType::kFullyConnected) {
+ output_depth = weights_shape.dims(0);
+ } else if (preceding_op->type == OperatorType::kDepthwiseConv) {
+ output_depth = weights_shape.dims(weights_shape.dimensions_count() - 1);
+ } else {
+ LOG(FATAL) << "Should not get here";
+ }
+
+ const int weights_size = RequiredBufferSizeForShape(weights_shape);
+ const int weights_per_depth = weights_size / output_depth;
+ CHECK_EQ(weights_size, weights_per_depth * output_depth);
+
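+  // A sketch of the algebra: fusing "affine(x) * c" (or "affine(x) / c"),
+  // where c broadcasts along the output depth, gives for each channel d
+  //   (W_d . x + b_d) * c_d = (c_d * W_d) . x + (c_d * b_d),
+  // so both that channel's weights and its bias entry are rescaled below.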
+ int operand_channel = 0;
+ for (int c = 0; c < output_depth; c++) {
+ if (mul_or_div_op->type == OperatorType::kMul) {
+ bias_data[c] *= operand_data[operand_channel];
+ } else if (mul_or_div_op->type == OperatorType::kDiv) {
+ bias_data[c] /= operand_data[operand_channel];
+ } else {
+ LOG(FATAL) << "Should not get here";
+ }
+ if (preceding_op->type == OperatorType::kConv ||
+ preceding_op->type == OperatorType::kFullyConnected) {
+ for (int i = 0; i < weights_per_depth; i++) {
+ if (mul_or_div_op->type == OperatorType::kMul) {
+ weights_data[c * weights_per_depth + i] *=
+ operand_data[operand_channel];
+ } else if (mul_or_div_op->type == OperatorType::kDiv) {
+ weights_data[c * weights_per_depth + i] /=
+ operand_data[operand_channel];
+ } else {
+ LOG(FATAL) << "Should not get here";
+ }
+ }
+ } else if (preceding_op->type == OperatorType::kDepthwiseConv) {
+ for (int k = 0; k < weights_per_depth; k++) {
+ if (mul_or_div_op->type == OperatorType::kMul) {
+ weights_data[k * output_depth + c] *= operand_data[operand_channel];
+ } else if (mul_or_div_op->type == OperatorType::kDiv) {
+ weights_data[k * output_depth + c] /= operand_data[operand_channel];
+ } else {
+ LOG(FATAL) << "Should not get here";
+ }
+ }
+ } else {
+ LOG(FATAL) << "Should not get here";
+ }
+ operand_channel += operand_channel_increment;
+ }
+}
+} // namespace
+
+bool FuseBinaryIntoPrecedingAffine::Run(Model* model, std::size_t op_index) {
+ const auto binary_it = model->operators.begin() + op_index;
+ const auto* binary_op = binary_it->get();
+ if (binary_op->type != OperatorType::kAdd &&
+ binary_op->type != OperatorType::kMul &&
+ binary_op->type != OperatorType::kSub &&
+ binary_op->type != OperatorType::kDiv) {
+ return false;
+ }
+
+ CHECK_EQ(binary_op->inputs.size(), 2);
+
+  // We can only fuse a binary op when its two operands break down as follows:
+  // 1. One operand is the (variable) output of a typical affine (linear plus
+  //    bias) op of a finite list of possible types: at the moment Conv,
+  //    DepthwiseConv and FullyConnected are supported.
+  // 2. The other operand is a constant param array.
+ const bool is_input_constant[2] = {
+ IsConstantParameterArray(*model, binary_op->inputs[0]),
+ IsConstantParameterArray(*model, binary_op->inputs[1]),
+ };
+ if (!is_input_constant[0] && !is_input_constant[1]) {
+ // Neither input is constant, so nothing we can fuse into a constant.
+ return false;
+ }
+ if (is_input_constant[0] && is_input_constant[1]) {
+ // Both inputs are constants. That's a job for constants
+ // propagation, not for us to handle here.
+ return false;
+ }
+ const int index_of_constant_input = is_input_constant[0] ? 0 : 1;
+ const int index_of_variable_input = is_input_constant[0] ? 1 : 0;
+ CHECK(is_input_constant[index_of_constant_input]);
+ CHECK(!is_input_constant[index_of_variable_input]);
+
+ // For division, we can only fuse if the denominator is constant.
+ if (binary_op->type == OperatorType::kDiv) {
+ if (index_of_constant_input != 1) {
+ AddMessageF("Not fusing %s because the denominator is not constant",
+ LogName(*binary_op));
+ return false;
+ }
+ }
+
+ Operator* preceding_op =
+ GetOpWithOutput(*model, binary_op->inputs[index_of_variable_input]);
+ if (!preceding_op) {
+ AddMessageF("Not fusing %s because it is not the output of another op",
+ LogName(*binary_op));
+ return false;
+ }
+
+ for (const string& output_array : model->flags.output_arrays()) {
+ if (preceding_op->outputs[0] == output_array) {
+ return false;
+ }
+ }
+
+ if (preceding_op->type != OperatorType::kConv &&
+ preceding_op->type != OperatorType::kFullyConnected &&
+ preceding_op->type != OperatorType::kDepthwiseConv) {
+ AddMessageF(
+ "Not fusing %s because the preceding %s is not of one of the supported "
+ "types",
+ LogName(*binary_op), LogName(*preceding_op));
+ return false;
+ }
+
+ if (preceding_op->fused_activation_function !=
+ FusedActivationFunctionType::kNone) {
+ AddMessageF(
+ "Not fusing %s because the preceding %s has a fused activation "
+ "function",
+ LogName(*binary_op), LogName(*preceding_op));
+ return false;
+ }
+
+ if (preceding_op->inputs.size() < 3) {
+ AddMessageF(
+ "Not fusing %s because the preceding %s does not have a bias vector",
+ LogName(*binary_op), LogName(*preceding_op));
+ return false;
+ }
+
+ const auto& weights = model->GetArray(preceding_op->inputs[1]);
+ const auto& bias = model->GetArray(preceding_op->inputs[2]);
+ if (binary_op->type == OperatorType::kAdd ||
+ binary_op->type == OperatorType::kSub) {
+ if (!bias.buffer) {
+ AddMessageF(
+ "Not fusing %s because the preceding %s has a non-constant bias "
+ "array",
+ LogName(*binary_op), LogName(*preceding_op));
+ return false;
+ }
+ } else {
+ if (!weights.buffer || !bias.buffer) {
+ AddMessageF(
+ "Not fusing %s because the preceding %s has non-constant weights or "
+ "bias arrays",
+ LogName(*binary_op), LogName(*preceding_op));
+ return false;
+ }
+ }
+
+ int count_ops_consuming_output =
+ CountOpsWithInput(*model, preceding_op->outputs[0]);
+ DCHECK_GE(count_ops_consuming_output, 1);
+ if (count_ops_consuming_output > 1) {
+ AddMessageF(
+ "Not fusing %s because the output of the preceding %s is consumed by "
+ "another op",
+ LogName(*binary_op), LogName(*preceding_op));
+ return false;
+ }
+
+ AddMessageF("Fusing %s into the preceding %s", LogName(*binary_op),
+ LogName(*preceding_op));
+
+ if (binary_op->type == OperatorType::kAdd ||
+ binary_op->type == OperatorType::kSub) {
+ FuseAddOrSubParamsIntoPrecedingAffine(model, preceding_op, binary_op,
+ index_of_constant_input);
+ } else if (binary_op->type == OperatorType::kMul ||
+ binary_op->type == OperatorType::kDiv) {
+ FuseMulOrDivParamsIntoPrecedingAffine(model, preceding_op, binary_op,
+ index_of_constant_input);
+ } else {
+ LOG(FATAL) << "should not get here";
+ }
+
+ model->arrays.erase(preceding_op->outputs[0]);
+ preceding_op->outputs[0] = binary_op->outputs[0];
+ preceding_op->fused_activation_function =
+ binary_op->fused_activation_function;
+ const auto& old_constant_param_name =
+ binary_op->inputs[index_of_constant_input];
+ CHECK(IsConstantParameterArray(*model, old_constant_param_name));
+ if (CountOpsWithInput(*model, old_constant_param_name) == 1) {
+ model->arrays.erase(old_constant_param_name);
+ }
+ model->operators.erase(binary_it);
+ return true;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.cc b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.cc
new file mode 100644
index 0000000000..323fec6cf8
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.cc
@@ -0,0 +1,108 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+
+#include <algorithm>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/toco_port.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+namespace {
+
+void PrintModelStats(const string& label, const Model& model) {
+ int quantized_arrays = 0;
+ for (const auto& array : model.arrays) {
+ if (array.second->quantization_params) {
+ quantized_arrays++;
+ }
+ }
+ LOG(INFO) << label << ": " << model.operators.size() << " operators, "
+ << model.arrays.size() << " arrays (" << quantized_arrays
+ << " quantized)";
+}
+
+bool GraphTransformationsPass(int increment, Model* model,
+ const GraphTransformationsSet& transformations) {
+ CHECK(increment == 1 || increment == -1);
+ bool changed = false;
+ CHECK(!model->operators.empty());
+ int op_index = increment == 1 ? 0 : model->operators.size() - 1;
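+  // Sweep op_index from one end of the list of operators to the other.
+  // Whenever some transformation changes the graph, stay at the current
+  // (clamped) op_index and re-run all transformations there before moving on.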
+ while (true) {
+ bool changed_now = false;
+ // Loop over all transformations at the current position in the graph.
+ for (const auto& transformation : transformations) {
+ CHECK(!changed_now);
+ CHECK(transformation->Messages().empty());
+ changed_now = transformation->Run(model, op_index);
+ if (changed_now) {
+ DumpGraphvizVideoFrame(*model);
+ CHECK(!model->operators.empty());
+ op_index = std::min<int>(op_index, model->operators.size() - 1);
+ // Uncomment for debugging
+ // CheckInvariants(*model);
+ }
+ const char* made_a_change_msg =
+ changed_now ? "made a change" : "did NOT make a change";
+ const int log_level =
+ changed_now ? kLogLevelModelChanged : kLogLevelModelUnchanged;
+ for (const string& message : transformation->Messages()) {
+ VLOG(log_level) << transformation->Name() << " " << made_a_change_msg
+ << " at op_index=" << op_index << "/"
+ << model->operators.size() - 1 << ": " << message;
+ }
+ transformation->ClearMessages();
+ if (changed_now) {
+ break;
+ }
+ }
+ if (changed_now) {
+ changed = true;
+ } else {
+ const int op_index_last =
+ increment == 1 ? model->operators.size() - 1 : 0;
+ if (op_index == op_index_last) {
+ break;
+ }
+ op_index += increment;
+ }
+ }
+ return changed;
+}
+
+} // namespace
+
+void RunGraphTransformations(Model* model, const string& msg,
+ const GraphTransformationsSet& transformations) {
+ PrintModelStats(toco::port::StringF("Before %s", msg), *model);
+ int pass_index = 0;
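+  // Alternate between forward (increment +1) and backward (increment -1)
+  // passes until a full pass no longer changes the model.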
+ while (GraphTransformationsPass((pass_index % 2) ? -1 : 1, model,
+ transformations)) {
+ pass_index++;
+ const auto& label =
+ toco::port::StringF("After %s pass %d", msg, pass_index);
+ PrintModelStats(label, *model);
+ CheckInvariants(*model);
+ }
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
new file mode 100644
index 0000000000..2cc24ff361
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
@@ -0,0 +1,186 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_GRAPH_TRANSFORMATIONS_GRAPH_TRANSFORMATIONS_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_GRAPH_TRANSFORMATIONS_GRAPH_TRANSFORMATIONS_H_
+
+#include <cstddef>
+#include <initializer_list>
+#include <unordered_set>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/toco_port.h"
+
+namespace toco {
+
+class GraphTransformation {
+ public:
+ virtual bool Run(Model* model, std::size_t op_index) = 0;
+ virtual const char* Name() const = 0;
+ virtual ~GraphTransformation() {}
+ // Returns the list of messages that this graph transformation
+ // generated since ClearMessages() was called.
+ const std::vector<string>& Messages() const { return messages_; }
+ // Clears the list of messages; should be called after every
+ // run of this graph transformation.
+  void ClearMessages() { messages_.clear(); }
+ // Adds a message; normally only called by the graph transformation
+ // itself during its run (this function could be protected).
+ template <typename... Args>
+ void AddMessageF(const char* format, const Args&... args) {
+    messages_.push_back(toco::port::StringF(format, args...));
+ }
+
+ protected:
+ GraphTransformation() {}
+
+ // List of messages generated by this graph transformation.
+ std::vector<string> messages_;
+
+ private:
+ GraphTransformation(const GraphTransformation& other) = delete;
+ GraphTransformation(const GraphTransformation&& other) = delete;
+};
+
+class GraphTransformationsSet {
+ public:
+ // The choice of a container with fully-specified iteration order
+ // ensures that graph transformations are always run in the same order,
+ // which avoids having toco randomly fail or produce different results
+ // depending on the toolchain. Ideally success/results should be independent
+ // of the order in which graph transformations are run, but that's
+ // unfortunately not currently guaranteed to be the case.
+ using TransformationsContainer =
+ std::vector<std::unique_ptr<GraphTransformation>>;
+
+ GraphTransformationsSet() {}
+ GraphTransformationsSet(
+ const std::initializer_list<GraphTransformation*> transformations) {
+ for (GraphTransformation* t : transformations) {
+ Add(t);
+ }
+ }
+ void Add(GraphTransformation* transformation) {
+ const string& name = transformation->Name();
+ CHECK(!names_.count(name));
+ names_.insert(name);
+ transformations_.emplace_back(transformation);
+ }
+ TransformationsContainer::const_iterator begin() const {
+ return transformations_.begin();
+ }
+ TransformationsContainer::const_iterator end() const {
+ return transformations_.end();
+ }
+ bool empty() const { return transformations_.empty(); }
+
+ private:
+ GraphTransformationsSet(const GraphTransformationsSet& other) = delete;
+ GraphTransformationsSet(const GraphTransformationsSet&& other) = delete;
+ std::vector<std::unique_ptr<GraphTransformation>> transformations_;
+ // Names of transformations in the set. Only used to guard against dupes.
+ std::unordered_set<string> names_;
+};
+
+// Runs the given set of graph transformations on the model.
+// The message is only for logging purposes.
+// The GraphTransformationsSet owns the transformation objects: callers are
+// expected to construct GraphTransformation objects with 'new' and hand the
+// resulting raw pointers to GraphTransformationsSet::Add(), which wraps them
+// in std::unique_ptr, so the set takes care of deleting them;
+// RunGraphTransformations itself only needs a const reference to the set.
+void RunGraphTransformations(Model* model, const string& message,
+ const GraphTransformationsSet& transformations);
+
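+// A minimal usage sketch (illustrative only; given a Model* model, the
+// particular transformations chosen here are just examples from the list
+// declared below):
+//
+//   GraphTransformationsSet transformations;
+//   transformations.Add(new EnsureBiasVectors);
+//   transformations.Add(new FuseActivationFunctions);
+//   RunGraphTransformations(model, "general graph transformations",
+//                           transformations);
+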
+#define DECLARE_GRAPH_TRANSFORMATION(GTName) \
+ class GTName : public GraphTransformation { \
+ public: \
+ bool Run(Model* model, std::size_t op_index) override; \
+ const char* Name() const { return #GTName; } \
+ };
+
+// List of all graph transformations
+DECLARE_GRAPH_TRANSFORMATION(ConvertPureConvToDepthwise)
+DECLARE_GRAPH_TRANSFORMATION(EnsureBiasVectors)
+DECLARE_GRAPH_TRANSFORMATION(FuseActivationFunctions)
+DECLARE_GRAPH_TRANSFORMATION(FuseBinaryIntoFollowingAffine)
+DECLARE_GRAPH_TRANSFORMATION(FuseBinaryIntoPrecedingAffine)
+DECLARE_GRAPH_TRANSFORMATION(IdentifyL2Normalization)
+DECLARE_GRAPH_TRANSFORMATION(IdentifyL2Pool)
+DECLARE_GRAPH_TRANSFORMATION(IdentifyLstmCell)
+DECLARE_GRAPH_TRANSFORMATION(IdentifyRelu1)
+DECLARE_GRAPH_TRANSFORMATION(MakeInitialDequantizeOperator)
+DECLARE_GRAPH_TRANSFORMATION(PropagateArrayDataTypes)
+DECLARE_GRAPH_TRANSFORMATION(PropagateFixedSizes)
+DECLARE_GRAPH_TRANSFORMATION(HardcodeMinMax)
+DECLARE_GRAPH_TRANSFORMATION(Quantize)
+DECLARE_GRAPH_TRANSFORMATION(RemoveFinalDequantizeOp)
+DECLARE_GRAPH_TRANSFORMATION(RemoveTensorFlowAssert)
+DECLARE_GRAPH_TRANSFORMATION(RemoveTensorFlowIdentity)
+DECLARE_GRAPH_TRANSFORMATION(RemoveTrivialBinaryOperator)
+DECLARE_GRAPH_TRANSFORMATION(RemoveTrivialConcatenation)
+DECLARE_GRAPH_TRANSFORMATION(RemoveTrivialConcatenationInput)
+DECLARE_GRAPH_TRANSFORMATION(RemoveTrivialQuantizedActivationFunc)
+DECLARE_GRAPH_TRANSFORMATION(RemoveUnusedOp)
+DECLARE_GRAPH_TRANSFORMATION(ResolveBatchNormalization)
+DECLARE_GRAPH_TRANSFORMATION(ResolveConstantBinaryOperator)
+DECLARE_GRAPH_TRANSFORMATION(ResolveConstantUnaryOperator)
+DECLARE_GRAPH_TRANSFORMATION(CreateIm2colArrays)
+DECLARE_GRAPH_TRANSFORMATION(DropIm2colArrays)
+DECLARE_GRAPH_TRANSFORMATION(ReadFakeQuantMinMax)
+DECLARE_GRAPH_TRANSFORMATION(ResolveReorderAxes)
+DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowConcat)
+DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowMatMul)
+DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowMerge)
+DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowSqueeze)
+DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowSwitch)
+DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowTile)
+DECLARE_GRAPH_TRANSFORMATION(ResolveConstantFakeQuant)
+DECLARE_GRAPH_TRANSFORMATION(ResolveConstantConcatenation)
+DECLARE_GRAPH_TRANSFORMATION(DropFakeQuant)
+DECLARE_GRAPH_TRANSFORMATION(UnfuseActivationFunctions)
+DECLARE_GRAPH_TRANSFORMATION(ResolvePadAttributes)
+DECLARE_GRAPH_TRANSFORMATION(ResolveStridedSliceAttributes)
+DECLARE_GRAPH_TRANSFORMATION(ResolveSliceAttributes)
+DECLARE_GRAPH_TRANSFORMATION(ResolveMeanAttributes)
+DECLARE_GRAPH_TRANSFORMATION(ResolveConstantTensorFlowShape)
+DECLARE_GRAPH_TRANSFORMATION(Dequantize)
+
+class ResolveReshapeAttributes : public GraphTransformation {
+ public:
+ bool Run(Model* model, std::size_t op_index) override;
+ const char* Name() const override { return "ResolveReshapeAttributes"; }
+};
+
+class RemoveTrivialReshape : public GraphTransformation {
+ public:
+ bool Run(Model* model, std::size_t op_index) override;
+ const char* Name() const override { return "RemoveTrivialReshape"; }
+ bool treat_expand_dims_as_trivial() const {
+ return treat_expand_dims_as_trivial_;
+ }
+ void set_treat_expand_dims_as_trivial(bool val) {
+ treat_expand_dims_as_trivial_ = val;
+ }
+
+ private:
+ bool treat_expand_dims_as_trivial_ = false;
+};
+
+#undef DECLARE_GRAPH_TRANSFORMATION
+
+} // end namespace toco
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_GRAPH_TRANSFORMATIONS_GRAPH_TRANSFORMATIONS_H_
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc b/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc
new file mode 100644
index 0000000000..d44b5dc7b0
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc
@@ -0,0 +1,229 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <algorithm>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+namespace {
+
+bool HardcodeMinMaxForIm2colArray(Model* model, Operator* op) {
+ if (op->outputs.size() != 2) {
+ return false;
+ }
+ auto& im2col_array = model->GetArray(op->outputs[1]);
+ if (im2col_array.minmax) {
+ return false;
+ }
+ const auto& input_array = model->GetArray(op->inputs[0]);
+ if (!input_array.minmax) {
+ return false;
+ }
+ const auto& input_minmax = input_array.GetMinMax();
+ CHECK(!im2col_array.minmax);
+ auto& im2col_minmax = im2col_array.GetOrCreateMinMax();
+ im2col_minmax.min = input_minmax.min;
+ im2col_minmax.max = input_minmax.max;
+ return true;
+}
+
+bool HardcodeMinMaxForL2Normalization(Model* model, Operator* op) {
+ auto& output_array = model->GetArray(op->outputs[0]);
+ if (output_array.minmax) {
+ return false;
+ }
+ const auto& input_array = model->GetArray(op->inputs[0]);
+ if (!input_array.minmax) {
+ return false;
+ }
+ const auto& input_minmax = input_array.GetMinMax();
+ CHECK(!output_array.minmax);
+ auto& output_minmax = output_array.GetOrCreateMinMax();
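+  // An L2-normalized output always lies within [-1, 1]; when the input is
+  // known to be non-negative (resp. non-positive), the output range further
+  // restricts to [0, 1] (resp. [-1, 0]), which is what the conditionals below
+  // encode.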
+ output_minmax.min = input_minmax.min >= 0. ? 0. : -1.;
+ output_minmax.max = input_minmax.max <= 0. ? 0. : 1.;
+ return true;
+}
+
+bool HardcodeMinMaxForConcatenation(Model* model, Operator* op) {
+ // Do not early return if the output already has min/max:
+ // we may still need to adjust the inputs min/max.
+ bool has_minmax = false;
+ double overall_min = std::numeric_limits<double>::infinity();
+ double overall_max = -std::numeric_limits<double>::infinity();
+ for (const auto& input : op->inputs) {
+ if (model->GetArray(input).minmax) {
+ has_minmax = true;
+ const auto* minmax = model->GetArray(input).minmax.get();
+ if (minmax) {
+ overall_min = std::min(overall_min, minmax->min);
+ overall_max = std::max(overall_max, minmax->max);
+ }
+ }
+ }
+ auto& output = model->GetArray(op->outputs[0]);
+ if (output.minmax) {
+ has_minmax = true;
+ const auto* minmax = model->GetArray(op->outputs[0]).minmax.get();
+ if (minmax) {
+ overall_min = std::min(overall_min, minmax->min);
+ overall_max = std::max(overall_max, minmax->max);
+ }
+ }
+ if (!has_minmax) {
+ return false;
+ }
+ MinMax overall_minmax;
+ overall_minmax.min = overall_min;
+ overall_minmax.max = overall_max;
+ bool changed = false;
+ for (const auto& input : op->inputs) {
+ auto& array = model->GetArray(input);
+ if (!array.minmax) {
+ changed = true;
+ } else if (!(overall_minmax == array.GetMinMax())) {
+ changed = true;
+ LOG(WARNING)
+ << "Tweaking the MinMax of array " << input << ", which is "
+ << "an input to " << LogName(*op) << ", because we want all inputs "
+ << "and outputs of a Concatenation operator to have the same MinMax "
+ << "so that it can be implemented as a pure byte-copy, no "
+ "arithmetic.";
+ }
+ array.GetOrCreateMinMax() = overall_minmax;
+ }
+ if (!output.minmax) {
+ changed = true;
+ } else if (!(overall_minmax == output.GetMinMax())) {
+ changed = true;
+ LOG(WARNING)
+ << "Tweaking the MinMax of the output array of " << LogName(*op)
+ << ", because we want all inputs "
+ << "and outputs of a Concatenation operator to have the same MinMax "
+ << "so that it can be implemented as a pure byte-copy, no arithmetic.";
+ }
+ output.GetOrCreateMinMax() = overall_minmax;
+
+ return changed;
+}
+
+// The output of average or max pooling is within the same range as its input.
+bool HardcodeMinMaxForAverageOrMaxPool(Model* model, Operator* op) {
+ auto& output_array = model->GetArray(op->outputs[0]);
+ if (output_array.minmax) {
+ return false;
+ }
+ const auto& input_array = model->GetArray(op->inputs[0]);
+ if (!input_array.minmax) {
+ return false;
+ }
+ const auto& input_minmax = input_array.GetMinMax();
+ CHECK(!output_array.minmax);
+ auto& output_minmax = output_array.GetOrCreateMinMax();
+ output_minmax.min = std::min(input_minmax.min, 0.);
+ output_minmax.max = std::max(input_minmax.max, 0.);
+ return true;
+}
+
+bool HardcodeMinMaxForReshape(Model* model, Operator* op) {
+ auto& output_array = model->GetArray(op->outputs[0]);
+ if (output_array.minmax) {
+ return false;
+ }
+ const auto& input_array = model->GetArray(op->inputs[0]);
+ if (!input_array.minmax) {
+ return false;
+ }
+ const auto& input_minmax = input_array.GetMinMax();
+ CHECK(!output_array.minmax);
+ auto& output_minmax = output_array.GetOrCreateMinMax();
+ output_minmax.min = input_minmax.min;
+ output_minmax.max = input_minmax.max;
+ return true;
+}
+
+bool HardcodeMinMaxForOutput(Model* model, Operator* op, double min,
+ double max) {
+ CHECK_EQ(op->outputs.size(), 1);
+ auto& output_array = model->GetArray(op->outputs[0]);
+ if (output_array.minmax) {
+ return false;
+ }
+ const auto& input_array = model->GetArray(op->inputs[0]);
+ if (!input_array.minmax) {
+ return false;
+ }
+ CHECK(!output_array.minmax);
+ auto& output_minmax = output_array.GetOrCreateMinMax();
+ output_minmax.min = min;
+ output_minmax.max = max;
+ return true;
+}
+} // namespace
+
+bool HardcodeMinMax::Run(Model* model, std::size_t op_index) {
+ auto it = model->operators.begin() + op_index;
+ auto* op = it->get();
+ bool changed = false;
+ switch (op->type) {
+ case OperatorType::kConv:
+ changed = HardcodeMinMaxForIm2colArray(model, op);
+ break;
+
+ case OperatorType::kL2Normalization:
+ changed = HardcodeMinMaxForL2Normalization(model, op);
+ break;
+
+ case OperatorType::kConcatenation:
+ changed = HardcodeMinMaxForConcatenation(model, op);
+ break;
+
+ case OperatorType::kAveragePool:
+ case OperatorType::kMaxPool:
+ changed = HardcodeMinMaxForAverageOrMaxPool(model, op);
+ break;
+
+ case OperatorType::kTensorFlowReshape:
+ changed = HardcodeMinMaxForReshape(model, op);
+ break;
+
+ case OperatorType::kLogistic:
+ // We hardcode quantization_params to: zero_point=0, scale=1/256.
+ // This choice of minmax is the one that is equivalent to that.
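+      // (Assuming 8-bit quantization: with q in [0, 255], real = scale *
+      // (q - zero_point) = q / 256, so the representable real range is
+      // exactly [0, 255/256].)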
+ changed = HardcodeMinMaxForOutput(model, op, 0, 255. / 256.);
+ break;
+
+ case OperatorType::kSoftmax:
+ // We hardcode quantization_params to: zero_point=0, scale=1/256.
+ // This choice of minmax is the one that is equivalent to that.
+ changed = HardcodeMinMaxForOutput(model, op, 0, 255. / 256.);
+ break;
+
+ default:
+ break;
+ }
+ if (changed) {
+ AddMessageF("Hardcoded min-max through %s", LogName(*op));
+ }
+ return changed;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_normalization.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_normalization.cc
new file mode 100644
index 0000000000..01b75e37c6
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_normalization.cc
@@ -0,0 +1,170 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <cmath>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+namespace {
+
+std::vector<std::unique_ptr<Operator>>::iterator FindOperator(
+ Model* model, const Operator* op) {
+ auto it = model->operators.begin();
+ for (; it != model->operators.end(); ++it) {
+ if (it->get() == op) {
+ break;
+ }
+ }
+ return it;
+}
+} // namespace
+
+bool IdentifyL2Normalization::Run(Model* model, std::size_t op_index) {
+ const auto div_it = model->operators.begin() + op_index;
+ const auto* div_or_mul_op = div_it->get();
+ OperatorType expected_op_type_producing_div_or_mul_input;
+ if (div_or_mul_op->type == OperatorType::kDiv) {
+ expected_op_type_producing_div_or_mul_input = OperatorType::kTensorFlowSqrt;
+ } else if (div_or_mul_op->type == OperatorType::kMul) {
+ expected_op_type_producing_div_or_mul_input =
+ OperatorType::kTensorFlowRsqrt;
+ } else {
+ return false;
+ }
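+  // The subgraph being matched computes, elementwise,
+  //   y = x / sqrt(sum(x^2) [+ small constant])    (Div / Sqrt form), or
+  //   y = x * rsqrt(sum(x^2) [+ small constant])   (Mul / Rsqrt form),
+  // i.e. an L2 normalization, optionally with a small additive or clamping
+  // constant under the square root for numerical stability.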
+ CHECK_EQ(div_or_mul_op->inputs.size(), 2);
+ Operator* op_producing_div_or_mul_input[2] = {
+ GetOpWithOutput(*model, div_or_mul_op->inputs[0]),
+ GetOpWithOutput(*model, div_or_mul_op->inputs[1]),
+ };
+ if (!op_producing_div_or_mul_input[1] ||
+ op_producing_div_or_mul_input[1]->type !=
+ expected_op_type_producing_div_or_mul_input) {
+ return false;
+ }
+ Operator* sqrt_or_rsqrt_op = op_producing_div_or_mul_input[1];
+ CHECK_EQ(sqrt_or_rsqrt_op->inputs.size(), 1);
+ Operator* op_producing_sqrt_or_rsqrt_input =
+ GetOpWithOutput(*model, sqrt_or_rsqrt_op->inputs[0]);
+ if (!op_producing_sqrt_or_rsqrt_input) {
+ return false;
+ }
+
+ // There may be an Add or a Maximum here, adding or clamping to a "small"
+ // constant scalar.
+ // Reported bug: b/29395854
+ Operator* add_op = nullptr;
+ Operator* op_producing_add_input = nullptr;
+ if (op_producing_sqrt_or_rsqrt_input->type == OperatorType::kAdd ||
+ op_producing_sqrt_or_rsqrt_input->type ==
+ OperatorType::kTensorFlowMaximum) {
+ add_op = op_producing_sqrt_or_rsqrt_input;
+ bool add_can_be_removed = false;
+ CHECK_EQ(op_producing_sqrt_or_rsqrt_input->inputs.size(), 2);
+ for (int i = 0; i < 2; i++) {
+ const auto& input_array =
+ model->GetArray(op_producing_sqrt_or_rsqrt_input->inputs[i]);
+ if (!input_array.buffer) {
+ continue;
+ }
+ if (input_array.buffer->type != ArrayDataType::kFloat) {
+ continue;
+ }
+ if (RequiredBufferSizeForShape(input_array.shape()) != 1) {
+ continue;
+ }
+ const auto& input_float_data =
+ input_array.GetBuffer<ArrayDataType::kFloat>().data;
+ if (std::abs(input_float_data[0]) > 1e-3f) {
+ continue;
+ }
+ add_can_be_removed = true;
+ op_producing_add_input = GetOpWithOutput(*model, add_op->inputs[1 - i]);
+ break;
+ }
+ if (!add_can_be_removed) {
+ AddMessageF(
+ "Giving up trying to identify L2Normalization subgraph "
+ " because the operator producing the input to the square root, %s,"
+ ", does not match the expected pattern",
+ LogName(*op_producing_sqrt_or_rsqrt_input));
+ return false;
+ }
+ }
+
+ Operator* sum_op =
+ add_op ? op_producing_add_input : op_producing_sqrt_or_rsqrt_input;
+ if (sum_op->type != OperatorType::kTensorFlowSum) {
+ AddMessageF(
+ "Giving up trying to identify L2Normalization subgraph: "
+ "expected Sum op, got %s",
+ LogName(*sum_op));
+ return false;
+ }
+
+  Operator* square_op = GetOpWithOutput(*model, sum_op->inputs[0]);
+  if (!square_op) {
+    return false;
+  }
+  if (square_op->type != OperatorType::kTensorFlowSquare) {
+ AddMessageF(
+ "Giving up trying to identify L2Normalization subgraph: "
+ "expected Square op, got %s",
+ LogName(*square_op));
+ return false;
+ }
+
+ CHECK_EQ(square_op->inputs.size(), 1);
+
+ if (square_op->inputs[0] != div_or_mul_op->inputs[0]) {
+ AddMessageF(
+ "Giving up trying to identify L2Normalization subgraph: %s does not "
+ "take the same input as the Mul/Div node",
+ LogName(*square_op));
+ return false;
+ }
+
+ // Create and emplace the new L2Normalization
+ auto* l2norm_op = new L2NormalizationOperator;
+ l2norm_op->inputs = {div_or_mul_op->inputs[0]};
+ l2norm_op->outputs = div_or_mul_op->outputs;
+ model->operators.emplace(div_it, l2norm_op);
+
+ AddMessageF("Creating %s replacing equivalent subgraph", LogName(*l2norm_op));
+
+ // Erase the subgraph that is now replaced by L2Normalization
+ model->operators.erase(FindOperator(model, square_op));
+ model->arrays.erase(sum_op->inputs[0]);
+ if (sum_op->inputs.size() > 1) {
+ model->arrays.erase(sum_op->inputs[1]);
+ }
+ model->operators.erase(FindOperator(model, sum_op));
+ if (add_op) {
+ model->arrays.erase(add_op->inputs[0]);
+ model->arrays.erase(add_op->inputs[1]);
+ model->operators.erase(FindOperator(model, add_op));
+ }
+ model->arrays.erase(sqrt_or_rsqrt_op->inputs[0]);
+ model->operators.erase(FindOperator(model, sqrt_or_rsqrt_op));
+ model->arrays.erase(div_or_mul_op->inputs[1]);
+ model->operators.erase(FindOperator(model, div_or_mul_op));
+ return true;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_pool.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_pool.cc
new file mode 100644
index 0000000000..1865416fc2
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_pool.cc
@@ -0,0 +1,106 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+namespace {
+
+std::vector<std::unique_ptr<Operator>>::iterator FindOperator(
+ Model* model, const Operator* op) {
+ auto it = model->operators.begin();
+ for (; it != model->operators.end(); ++it) {
+ if (it->get() == op) {
+ break;
+ }
+ }
+ return it;
+}
+} // namespace
+
+bool IdentifyL2Pool::Run(Model* model, std::size_t op_index) {
+ const auto sqrt_it = model->operators.begin() + op_index;
+ const auto* sqrt_op = sqrt_it->get();
+ if (sqrt_op->type != OperatorType::kTensorFlowSqrt) {
+ return false;
+ }
+
+ CHECK_EQ(sqrt_op->inputs.size(), 1);
+ CHECK_EQ(sqrt_op->outputs.size(), 1);
+
+ const AveragePoolOperator* avpool_op;
+ const Operator* square_op;
+
+  Operator* prev_to_sqrt_op = GetOpWithOutput(*model, sqrt_op->inputs[0]);
+  if (!prev_to_sqrt_op) {
+    return false;
+  }
+  if (prev_to_sqrt_op->type != OperatorType::kAveragePool) {
+ AddMessageF(
+ "Giving up trying to identify L2Pool subgraph: "
+ "expected AveragePool op, got %s",
+ LogName(*prev_to_sqrt_op));
+ return false;
+ }
+
+ avpool_op = static_cast<const AveragePoolOperator*>(prev_to_sqrt_op);
+ CHECK_EQ(avpool_op->inputs.size(), 1);
+
+  square_op = GetOpWithOutput(*model, avpool_op->inputs[0]);
+  if (!square_op) {
+    return false;
+  }
+  CHECK_EQ(square_op->inputs.size(), 1);
+  if (square_op->type != OperatorType::kTensorFlowSquare) {
+    AddMessageF(
+        "Giving up trying to identify L2Pool subgraph: "
+        "expected Square op, got %s",
+        LogName(*square_op));
+    return false;
+  }
+
+ // Create and emplace L2Pool node.
+ auto* l2pool_op = new L2PoolOperator;
+
+ l2pool_op->inputs = {square_op->inputs[0]};
+ l2pool_op->outputs = sqrt_op->outputs;
+
+ l2pool_op->padding.type = avpool_op->padding.type;
+  // Note that we do not set up l2pool_op->padding.fixed here. This is done by
+  // the PropagateFixedSizes graph transformation.
+
+ l2pool_op->stride_height = avpool_op->stride_height;
+ l2pool_op->stride_width = avpool_op->stride_width;
+ l2pool_op->kheight = avpool_op->kheight;
+ l2pool_op->kwidth = avpool_op->kwidth;
+ model->operators.emplace(sqrt_it, l2pool_op);
+
+ AddMessageF("Creating %s replacing equivalent subgraph", LogName(*l2pool_op));
+
+ // Erase intermediate arrays, keeping input to square op.
+ model->arrays.erase(avpool_op->inputs[0]);
+ model->arrays.erase(sqrt_op->inputs[0]);
+
+ // Erase three operators being replaced.
+ model->operators.erase(FindOperator(model, square_op));
+ model->operators.erase(FindOperator(model, avpool_op));
+ model->operators.erase(FindOperator(model, sqrt_op));
+
+ return true;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm.cc
new file mode 100644
index 0000000000..082820fddc
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm.cc
@@ -0,0 +1,396 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "absl/strings/string_view.h"
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+
+namespace toco {
+
+namespace {
+
+std::vector<std::unique_ptr<Operator>>::iterator FindOperator(
+ Model* model, const Operator& op) {
+ auto it = model->operators.begin();
+ for (; it != model->operators.end(); ++it) {
+ if (it->get() == &op) {
+ break;
+ }
+ }
+ return it;
+}
+
+bool GetStateArrayForBackEdge(const Model& model,
+ const string& back_edge_source_array,
+ string* state_array = nullptr) {
+ for (const auto& rnn_state : model.flags.rnn_states()) {
+ if (back_edge_source_array == rnn_state.back_edge_source_array()) {
+ // Found LSTM cell output
+ if (state_array) {
+ *state_array = rnn_state.state_array();
+ }
+ return true;
+ }
+ }
+ return false;
+}
+
+// Returns true if the given operator has exactly 1 input, and is connected to
+// the given op_type.
+// We use kNone to indicate an input unattached to an operator output. Usually
+// these are the static input arrays.
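+// For example, MatchOperatorInputs(op, model, OperatorType::kAdd, &add_op)
+// succeeds only if op has exactly one input and that input is produced by an
+// Add operator, in which case add_op is set to point at that Add.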
+bool MatchOperatorInputs(const Operator& op, const Model& model,
+ OperatorType op_type, Operator** connected_op) {
+ // Check for required number of inputs
+ if (op.inputs.size() != 1) {
+ return false;
+ }
+
+ // Check if first input is disconnected/connected to an operator
+ Operator* x = GetOpWithOutput(model, op.inputs[0]);
+ if ((op_type == OperatorType::kNone) && (x != nullptr)) {
+ return false;
+ }
+ if ((op_type != OperatorType::kNone) && (x == nullptr)) {
+ return false;
+ }
+
+ // Check that first operator, if connected, is of correct type
+ if ((x != nullptr) && (x->type != op_type)) {
+ return false;
+ }
+
+ // Successfully matched. Optionally return matching input operators.
+ if (connected_op) {
+ *connected_op = x;
+ }
+
+ return true;
+}
+
+// Returns true if the given operator has exactly 2 inputs, which are connected
+// to the given op_types.
+// We use kNone to indicate an input unattached to an operator output. Usually
+// these are the static input arrays.
+bool MatchOperatorInputs(const Operator& op, const Model& model,
+ OperatorType a_op_type, Operator** a_op,
+ OperatorType b_op_type, Operator** b_op) {
+ // Check for required number of inputs
+ if (op.inputs.size() != 2) {
+ return false;
+ }
+
+ // Check if first input is disconnected/connected to an operator
+ Operator* x = GetOpWithOutput(model, op.inputs[0]);
+ if ((a_op_type == OperatorType::kNone) && (x != nullptr)) {
+ return false;
+ }
+ if ((a_op_type != OperatorType::kNone) && (x == nullptr)) {
+ return false;
+ }
+
+ // Check that first operator, if connected, is of correct type
+ if ((x != nullptr) && (x->type != a_op_type)) {
+ return false;
+ }
+
+ // Check if second input is disconnected/connected to an operator
+ Operator* y = GetOpWithOutput(model, op.inputs[1]);
+ if ((b_op_type == OperatorType::kNone) && (y != nullptr)) {
+ return false;
+ }
+ if ((b_op_type != OperatorType::kNone) && (y == nullptr)) {
+ return false;
+ }
+
+ // Check that second operator, if connected, is of correct type
+ if ((y != nullptr) && (y->type != b_op_type)) {
+ return false;
+ }
+
+ // Successfully matched. Optionally return matching input operators.
+ if (a_op != nullptr) {
+ *a_op = x;
+ }
+ if (b_op != nullptr) {
+ *b_op = y;
+ }
+ return true;
+}
+
+// Returns true if the given operator has exactly 3 inputs, which are connected
+// to the given op_types.
+// We use kNone to indicate an input unattached to an operator output. Usually
+// these are the static input arrays.
+bool MatchOperatorInputs(const Operator& op, const Model& model,
+ OperatorType a_op_type, Operator** a_op,
+ OperatorType b_op_type, Operator** b_op,
+ OperatorType c_op_type, Operator** c_op) {
+ // Check for required number of inputs
+ if (op.inputs.size() != 3) {
+ return false;
+ }
+
+ // Check if first input is disconnected/connected to an operator
+ Operator* x = GetOpWithOutput(model, op.inputs[0]);
+ if ((a_op_type == OperatorType::kNone) && (x != nullptr)) {
+ return false;
+ }
+ if ((a_op_type != OperatorType::kNone) && (x == nullptr)) {
+ return false;
+ }
+
+ // Check that first operator, if connected, is of correct type
+ if ((x != nullptr) && (x->type != a_op_type)) {
+ return false;
+ }
+
+ // Check if second input is disconnected/connected to an operator
+ Operator* y = GetOpWithOutput(model, op.inputs[1]);
+ if ((b_op_type == OperatorType::kNone) && (y != nullptr)) {
+ return false;
+ }
+ if ((b_op_type != OperatorType::kNone) && (y == nullptr)) {
+ return false;
+ }
+
+ // Check that second operator, if connected, is of correct type
+ if ((y != nullptr) && (y->type != b_op_type)) {
+ return false;
+ }
+
+ // Check if third input is disconnected/connected to an operator
+ Operator* z = GetOpWithOutput(model, op.inputs[2]);
+ if ((c_op_type == OperatorType::kNone) && (z != nullptr)) {
+ return false;
+ }
+ if ((c_op_type != OperatorType::kNone) && (z == nullptr)) {
+ return false;
+ }
+
+ // Check that third operator, if connected, is of correct type
+ if ((z != nullptr) && (z->type != c_op_type)) {
+ return false;
+ }
+
+ // Successfully matched. Optionally return matching input operators.
+ if (a_op != nullptr) {
+ *a_op = x;
+ }
+ if (b_op != nullptr) {
+ *b_op = y;
+ }
+ if (c_op != nullptr) {
+ *c_op = z;
+ }
+ return true;
+}
+
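+// Returns the longest common prefix of the two strings. For example, given
+// "lstm/cell_state" and "lstm/cell_activ" this would return "lstm/cell_".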
+absl::string_view FindLongestCommonPrefix(absl::string_view a,
+ absl::string_view b) {
+ if (a.empty() || b.empty()) return absl::string_view();
+
+ const char* pa = a.data();
+ const char* pb = b.data();
+ size_t count = 0;
+  const size_t limit = std::min(a.size(), b.size());
+ while (count < limit && *pa == *pb) {
+ ++pa;
+ ++pb;
+ ++count;
+ }
+
+ return absl::string_view(a.data(), count);
+}
+
+} // namespace
+
+bool IdentifyLstmCell::Run(Model* model, std::size_t op_index) {
+ // This LSTM cell identification method is not invariant to commutation of
+ // commutative operator inputs. For example, if input[0] and input[1] of the
+ // final output multiplication were swapped, this method would not identify it
+ // as an LSTM cell. This is OK in most cases, because
+ // tf.rnn.contrib.BasicLSTMCell always generates LSTM cells the same way.
+
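+  // Sketch of the subgraph being matched, in terms of the four outputs of the
+  // single Split following the FullyConnected layer (their order is not
+  // checked here); i, j, f, o are just labels for those outputs:
+  //   [i, j, f, o] = Split(FullyConnected(Concat(data, prev_activ)))
+  //   new_state    = prev_state * Logistic(f) + Logistic(i) * Tanh(j)
+  //   new_activ    = Tanh(new_state) * Logistic(o)
+  // The code below walks this pattern backwards, starting from the final Mul.
+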
+ // Final output multiply
+ auto op_it = model->operators.begin() + op_index;
+ Operator* final_output_mul = op_it->get();
+ if (final_output_mul->type != OperatorType::kMul) {
+ return false;
+ }
+ Operator *state_output_tanh, *fc_output_sig;
+ if (!MatchOperatorInputs(*final_output_mul, *model, OperatorType::kTanh,
+ &state_output_tanh, OperatorType::kLogistic,
+ &fc_output_sig)) {
+ return false;
+ }
+
+ // State output TanH
+ // (We don't count an operator as ID'd until we verify it has the correct
+ // operator types feeding into it.)
+ Operator* state_combine_add;
+ if (!MatchOperatorInputs(*state_output_tanh, *model, OperatorType::kAdd,
+ &state_combine_add)) {
+ return false;
+ }
+ string prev_state;
+ if (!GetStateArrayForBackEdge(*model, state_output_tanh->inputs[0],
+ &prev_state)) {
+ return false;
+ }
+
+ // State forget & remember addition
+ Operator *state_forget_mul, *state_remember_mul;
+ if (!MatchOperatorInputs(*state_combine_add, *model, OperatorType::kMul,
+ &state_forget_mul, OperatorType::kMul,
+ &state_remember_mul)) {
+ return false;
+ }
+ if (state_forget_mul->inputs[0] != prev_state) {
+ return false;
+ }
+
+ // State forget gate
+ Operator* state_forget_sig;
+ if (!MatchOperatorInputs(*state_forget_mul, *model, OperatorType::kNone,
+ nullptr, OperatorType::kLogistic,
+ &state_forget_sig)) {
+ return false;
+ }
+
+ // State remember gate
+ Operator *state_remember_sig, *state_info_tanh;
+ if (!MatchOperatorInputs(*state_remember_mul, *model, OperatorType::kLogistic,
+ &state_remember_sig, OperatorType::kTanh,
+ &state_info_tanh)) {
+ return false;
+ }
+
+ // State remember "information" activation function
+ Operator* fc_output_split;
+ if (!MatchOperatorInputs(*state_info_tanh, *model,
+ OperatorType::kTensorFlowSplit, &fc_output_split)) {
+ return false;
+ }
+ // State remember gate activation function
+ Operator* tmp;
+ if (!MatchOperatorInputs(*state_remember_sig, *model,
+ OperatorType::kTensorFlowSplit, &tmp) ||
+ (tmp != fc_output_split)) {
+ return false;
+ }
+ // State forget gate activation function
+ if (!MatchOperatorInputs(*state_forget_sig, *model,
+ OperatorType::kTensorFlowSplit, &tmp) ||
+ (tmp != fc_output_split)) {
+ return false;
+ }
+ // Fully connected output activation function
+ if (!MatchOperatorInputs(*fc_output_sig, *model,
+ OperatorType::kTensorFlowSplit, &tmp) ||
+ (tmp != fc_output_split)) {
+ return false;
+ }
+ // Fully connected output split
+ Operator* fully_connected;
+ if (!MatchOperatorInputs(*fc_output_split, *model, OperatorType::kNone,
+ nullptr, OperatorType::kFullyConnected,
+ &fully_connected)) {
+ return false;
+ }
+
+ // Fully connected op
+ Operator* concat_inputs;
+ if (!MatchOperatorInputs(*fully_connected, *model,
+ OperatorType::kConcatenation, &concat_inputs,
+ OperatorType::kNone, nullptr, OperatorType::kNone,
+ nullptr)) {
+ return false;
+ }
+
+ // Emplace a new LSTM cell operator
+ auto* lstm_cell_op = new LstmCellOperator;
+ lstm_cell_op->inputs.resize(LstmCellOperator::NUM_INPUTS);
+ lstm_cell_op->inputs[LstmCellOperator::DATA_INPUT] = concat_inputs->inputs[0];
+ lstm_cell_op->inputs[LstmCellOperator::PREV_ACTIV_INPUT] =
+ concat_inputs->inputs[1];
+ lstm_cell_op->inputs[LstmCellOperator::WEIGHTS_INPUT] =
+ fully_connected->inputs[1];
+ lstm_cell_op->inputs[LstmCellOperator::BIASES_INPUT] =
+ fully_connected->inputs[2];
+ lstm_cell_op->inputs[LstmCellOperator::PREV_STATE_INPUT] = prev_state;
+ lstm_cell_op->outputs.resize(LstmCellOperator::NUM_OUTPUTS);
+ lstm_cell_op->outputs[LstmCellOperator::STATE_OUTPUT] =
+ state_output_tanh->inputs[0];
+ lstm_cell_op->outputs[LstmCellOperator::ACTIV_OUTPUT] =
+ final_output_mul->outputs[0];
+ model->operators.emplace(op_it, lstm_cell_op);
+ AddMessageF("Creating %s replacing equivalent subgraph",
+ LogName(*lstm_cell_op));
+
+ // Create temp arrays used internally during runtime.
+ const string base_name(FindLongestCommonPrefix(
+ lstm_cell_op->outputs[LstmCellOperator::STATE_OUTPUT],
+ lstm_cell_op->outputs[LstmCellOperator::ACTIV_OUTPUT]));
+ const string& concat_temp_array_name =
+ AvailableArrayName(*model, base_name + "concat_temp");
+ model->GetOrCreateArray(concat_temp_array_name);
+ lstm_cell_op->outputs[LstmCellOperator::CONCAT_TEMP] = concat_temp_array_name;
+ const string& activ_temp_array_name =
+ AvailableArrayName(*model, base_name + "activ_temp");
+ model->GetOrCreateArray(activ_temp_array_name);
+ lstm_cell_op->outputs[LstmCellOperator::ACTIV_TEMP] = activ_temp_array_name;
+ AddMessageF("Created temp outputs %s and %s on operator %s",
+ concat_temp_array_name, activ_temp_array_name,
+ LogName(*lstm_cell_op));
+
+ // Delete arrays and operators replaced by the LSTM cell operator. Order is
+ // important - DeleteArrayIfUnused() only succeeds if dependent operators
+ // have been removed first. Start at the output and work towards the input.
+ model->operators.erase(FindOperator(model, *final_output_mul));
+ DeleteArrayIfUnused(state_output_tanh->outputs[0], model);
+ DeleteArrayIfUnused(fc_output_sig->outputs[0], model);
+ model->operators.erase(FindOperator(model, *state_output_tanh));
+ model->operators.erase(FindOperator(model, *fc_output_sig));
+ model->operators.erase(FindOperator(model, *state_combine_add));
+ DeleteArrayIfUnused(state_forget_mul->outputs[0], model);
+ DeleteArrayIfUnused(state_remember_mul->outputs[0], model);
+ model->operators.erase(FindOperator(model, *state_forget_mul));
+ model->operators.erase(FindOperator(model, *state_remember_mul));
+ DeleteArrayIfUnused(state_forget_sig->outputs[0], model);
+ DeleteArrayIfUnused(state_info_tanh->outputs[0], model);
+ DeleteArrayIfUnused(state_remember_sig->outputs[0], model);
+ model->operators.erase(FindOperator(model, *state_forget_sig));
+ model->operators.erase(FindOperator(model, *state_info_tanh));
+ model->operators.erase(FindOperator(model, *state_remember_sig));
+ DeleteArrayIfUnused(fc_output_split->outputs[0], model);
+ DeleteArrayIfUnused(fc_output_split->outputs[1], model);
+ DeleteArrayIfUnused(fc_output_split->outputs[2], model);
+ DeleteArrayIfUnused(fc_output_split->outputs[3], model);
+ string dims_array = fc_output_split->inputs[0];
+ model->operators.erase(FindOperator(model, *fc_output_split));
+ DeleteArrayIfUnused(dims_array, model);
+ DeleteArrayIfUnused(fully_connected->outputs[0], model);
+ model->operators.erase(FindOperator(model, *fully_connected));
+ DeleteArrayIfUnused(concat_inputs->outputs[0], model);
+ model->operators.erase(FindOperator(model, *concat_inputs));
+ return true;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_relu1.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_relu1.cc
new file mode 100644
index 0000000000..cfc77024e7
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_relu1.cc
@@ -0,0 +1,103 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+namespace {
+
+std::vector<std::unique_ptr<Operator>>::iterator FindOperator(
+ Model* model, const Operator* op) {
+ auto it = model->operators.begin();
+ for (; it != model->operators.end(); ++it) {
+ if (it->get() == op) {
+ break;
+ }
+ }
+ return it;
+}
+
+bool CheckArrayIsScalarFloat(Model* model, const std::string& name, float val) {
+ const auto& op_array = model->GetArray(name);
+ if (!op_array.buffer || op_array.buffer->type != ArrayDataType::kFloat ||
+ RequiredBufferSizeForShape(op_array.shape()) != 1) {
+ return false;
+ }
+ const auto& op_data = op_array.GetBuffer<ArrayDataType::kFloat>().data;
+ return op_data[0] == val;
+}
+
+// Returns index of scalar input when there is exactly one scalar, -1 otherwise
+int GetSingleScalarInputIndexOfBinaryOp(Model* model, const Operator* op,
+ float val) {
+ bool input0_is_scalar = CheckArrayIsScalarFloat(model, op->inputs[0], val);
+ bool input1_is_scalar = CheckArrayIsScalarFloat(model, op->inputs[1], val);
+ return input0_is_scalar == input1_is_scalar ? -1 : input0_is_scalar ? 0 : 1;
+}
+} // namespace
+
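+// Matches the subgraph Minimum(Maximum(x, -1.0), 1.0), with the scalar
+// allowed on either input of each op, and replaces it with a single Relu1,
+// i.e. a clamp of x to [-1, 1].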
+bool IdentifyRelu1::Run(Model* model, std::size_t op_index) {
+ const auto maximum_it = model->operators.begin() + op_index;
+ const auto* maximum_op = maximum_it->get();
+ if (maximum_op->type != OperatorType::kTensorFlowMaximum) {
+ return false;
+ }
+ CHECK_EQ(maximum_op->inputs.size(), 2);
+ if (maximum_op->outputs.size() != 1) {
+ return false;
+ }
+ int scalar_input_index =
+ GetSingleScalarInputIndexOfBinaryOp(model, maximum_op, -1.0f);
+ if (scalar_input_index == -1) {
+ return false;
+ }
+ const auto* minimum_op = GetOpWithInput(*model, maximum_op->outputs[0]);
+ if (!minimum_op || minimum_op->type != OperatorType::kTensorFlowMinimum) {
+ return false;
+ }
+ if (GetSingleScalarInputIndexOfBinaryOp(model, minimum_op, 1.0f) == -1) {
+ return false;
+ }
+ CHECK_EQ(minimum_op->inputs.size(), 2);
+
+ // Create and emplace Relu1 node
+ auto* relu1_op = new Relu1Operator;
+  relu1_op->inputs = {maximum_op->inputs[1 - scalar_input_index]};
+ relu1_op->outputs = minimum_op->outputs;
+ model->operators.emplace(maximum_it, relu1_op);
+
+ AddMessageF("Creating %s replacing equivalent subgraph", LogName(*relu1_op));
+
+ // Erase Maximum scalar input & operator
+ model->arrays.erase(maximum_op->inputs[scalar_input_index]);
+ model->operators.erase(FindOperator(model, maximum_op));
+
+ // Erase Minimum inputs & operator
+ model->arrays.erase(minimum_op->inputs[0]);
+ model->arrays.erase(minimum_op->inputs[1]);
+ model->operators.erase(FindOperator(model, minimum_op));
+
+ return true;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/make_initial_dequantize_operator.cc b/tensorflow/contrib/lite/toco/graph_transformations/make_initial_dequantize_operator.cc
new file mode 100644
index 0000000000..d83603e9a2
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/make_initial_dequantize_operator.cc
@@ -0,0 +1,120 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/model_flags.pb.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+// This inserts an operator whose output is a float array (name:
+// flags.input_array()). It has to wait for any existing operators that
+// generate this output to be removed by graph transformations. Note that there
+// may be more than one operator that takes the input_array as their input, and
+// that some of these may be removed by graph transformations.
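+// Schematically, for an input array "input" with float data but a uint8
+// final_data_type, consumed by some operator SomeOp (a name used here only
+// for illustration):
+//   before:  input (float) -> SomeOp
+//   after:   input (uint8) -> Dequantize -> input_dequantized (float) -> SomeOp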
+bool AddDequantizeOperatorToInput(const string& input_name, const Operator* op,
+ GraphTransformation* transformation,
+ Model* model) {
+ // An operator with the required output may be a dequantize operator already
+ // created. Alternatively it may be an operator that needs to be removed
+ // because it is unused, in which case we wait for RemoveUnusedOp to do its
+ // work.
+ if (GetOpWithOutput(*model, input_name)) {
+ return false;
+ }
+
+  // If there is more than one operator consuming this input, we only apply
+  // the transformation for the first one. This is not strictly necessary for
+  // ordering correctness, since we insert the dequant operator at the
+  // beginning of the op sequence, but it makes the insertion more predictable
+  // (e.g. forward vs. backward operator sweeps).
+ if (CountOpsWithInput(*model, input_name) > 1) {
+ if (op != GetFirstOpWithInput(*model, input_name)) {
+ return false;
+ }
+ }
+
+ auto& input_array = model->GetArray(input_name);
+ if (input_array.data_type != ArrayDataType::kFloat) {
+ return false;
+ }
+
+ if (input_array.final_data_type == input_array.data_type ||
+ input_array.final_data_type == ArrayDataType::kNone) {
+ return false;
+ }
+
+ const auto& dequantized_input_name =
+ AvailableArrayName(*model, input_name + "_dequantized");
+ for (auto& other_op : model->operators) {
+ for (string& other_op_input : other_op->inputs) {
+ if (other_op_input == input_name) {
+ other_op_input = dequantized_input_name;
+ }
+ }
+ }
+
+ auto& dequantized_input_array =
+ model->GetOrCreateArray(dequantized_input_name);
+ auto* image_input_op = new DequantizeOperator;
+ image_input_op->inputs = {input_name};
+ image_input_op->outputs = {dequantized_input_name};
+ model->operators.emplace(model->operators.begin(), image_input_op);
+
+ CHECK(input_array.final_data_type == ArrayDataType::kUint8);
+ input_array.data_type = ArrayDataType::kUint8;
+ dequantized_input_array.data_type = ArrayDataType::kFloat;
+ const auto& input_minmax = input_array.GetMinMax();
+ auto& dequantized_input_minmax = dequantized_input_array.GetOrCreateMinMax();
+ dequantized_input_minmax = input_minmax;
+ auto& input_qparams = input_array.GetOrCreateQuantizationParams();
+ GetQuantizationParamsFromMinMax<ArrayDataType::kUint8>(
+ model->flags, input_minmax, &input_qparams);
+
+ transformation->AddMessageF(
+ "Created %s"
+ " to handle quantized input image data, taking over existing"
+ " mean_value and std_value flags. Cleared those flags.",
+ LogName(*image_input_op));
+
+ return true;
+}
+
+bool MakeInitialDequantizeOperator::Run(Model* model, std::size_t op_index) {
+ // This is effectively a transformation applied to edges. We iterate over the
+ // specified node (op) and proceed for input edges.
+ const auto it = model->operators.begin() + op_index;
+ const auto* op = it->get();
+ bool change_made = false;
+ for (auto& input : op->inputs) {
+ for (auto& input_array : *model->flags.mutable_input_arrays()) {
+ if (input_array.name() == input) {
+ if (AddDequantizeOperatorToInput(input_array.name(), op, this, model)) {
+ change_made = true;
+ input_array.clear_mean_value();
+ input_array.clear_std_value();
+ }
+ }
+ }
+ }
+ return change_made;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
new file mode 100644
index 0000000000..1ff4e827aa
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
@@ -0,0 +1,142 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+namespace {
+
+ArrayDataType CommonDataTypeOfAllInputs(const Model& model,
+ const Operator& op) {
+ CHECK_GT(op.inputs.size(), 0);
+ const ArrayDataType data_type = model.GetArray(op.inputs[0]).data_type;
+ for (const auto& input : op.inputs) {
+ const auto& array = model.GetArray(input);
+ CHECK(array.data_type == data_type)
+ << " Unexpected: this operator has inputs with different data types.";
+ }
+ return data_type;
+}
+
+void SetDataTypeForAllOutputs(Model* model, Operator* op,
+ ArrayDataType data_type) {
+ for (const auto& output : op->outputs) {
+ model->arrays[output]->data_type = data_type;
+ }
+}
+} // namespace
+
+bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) {
+ auto it = model->operators.begin() + op_index;
+ auto* op = it->get();
+
+ // If the data type of some input is unknown, we need to yield.
+ for (const auto& input : op->inputs) {
+ if (model->arrays[input]->data_type == ArrayDataType::kNone) {
+ return false;
+ }
+ }
+ // Record data types of output before processing, so we can see at the
+ // end if we changed anything, and return the correct boolean value.
+ std::unordered_map<string, ArrayDataType> old_output_data_types;
+ for (const auto& output : op->outputs) {
+ old_output_data_types[output] = model->arrays[output]->data_type;
+ }
+ // Do the actual output data types propagation.
+ if (op->type == OperatorType::kDequantize ||
+ op->type == OperatorType::kResizeBilinear) {
+ // These operators unconditionally produce float outputs
+ SetDataTypeForAllOutputs(model, op, ArrayDataType::kFloat);
+ } else if (op->type == OperatorType::kTensorFlowLess ||
+ op->type == OperatorType::kTensorFlowLessEqual ||
+ op->type == OperatorType::kTensorFlowGreater ||
+ op->type == OperatorType::kTensorFlowGreaterEqual) {
+ // These operators unconditionally produce bool outputs
+ SetDataTypeForAllOutputs(model, op, ArrayDataType::kBool);
+ } else if (op->type == OperatorType::kTensorFlowShape) {
+ // These operators are assumed to produce int32 outputs.
+ SetDataTypeForAllOutputs(model, op, ArrayDataType::kInt32);
+ } else if (op->type == OperatorType::kAveragePool ||
+ op->type == OperatorType::kMaxPool ||
+ op->type == OperatorType::kL2Pool ||
+ op->type == OperatorType::kConv ||
+ op->type == OperatorType::kDepthwiseConv ||
+ op->type == OperatorType::kFullyConnected ||
+ op->type == OperatorType::kTensorFlowMax ||
+ op->type == OperatorType::kTensorFlowMin ||
+ op->type == OperatorType::kPad ||
+ op->type == OperatorType::kStridedSlice ||
+ op->type == OperatorType::kTensorFlowReshape ||
+ op->type == OperatorType::kSlice ||
+ op->type == OperatorType::kSqueeze ||
+ op->type == OperatorType::kTensorFlowSum ||
+ op->type == OperatorType::kTensorFlowSwitch ||
+ op->type == OperatorType::kTensorFlowTile ||
+ op->type == OperatorType::kTensorFlowAll ||
+ op->type == OperatorType::kReorderAxes ||
+ op->type == OperatorType::kTensorFlowConcatV2 ||
+ op->type == OperatorType::kFloor ||
+ op->type == OperatorType::kGather ||
+ op->type == OperatorType::kSpaceToBatchND ||
+ op->type == OperatorType::kBatchToSpaceND ||
+ op->type == OperatorType::kMean) {
+ // These operators produce outputs with the same type as their 1st input
+ CHECK_GT(op->inputs.size(), 0);
+ const ArrayDataType data_type = model->arrays[op->inputs[0]]->data_type;
+ SetDataTypeForAllOutputs(model, op, data_type);
+ } else if (op->type == OperatorType::kTensorFlowSplit ||
+ op->type == OperatorType::kTensorFlowConcat) {
+ // These operators produce an output with the same type as their 2nd input
+ CHECK_GT(op->inputs.size(), 1);
+ const ArrayDataType data_type = model->arrays[op->inputs[1]]->data_type;
+ SetDataTypeForAllOutputs(model, op, data_type);
+ } else if (op->type == OperatorType::kCast) {
+ // Data type of the Cast op is specified.
+ CHECK_EQ(op->outputs.size(), 1);
+ auto* cast_op = static_cast<CastOperator*>(op);
+ model->arrays[op->outputs[0]]->data_type = cast_op->dst_data_type;
+ } else if (op->type == OperatorType::kTensorFlowUnsupported) {
+ auto* unsupported_op = static_cast<TensorFlowUnsupportedOperator*>(op);
+ if (unsupported_op->output_data_types.size() != op->outputs.size()) {
+ return false;
+ }
+ for (int i = 0; i < unsupported_op->output_data_types.size(); ++i) {
+ auto output = op->outputs[i];
+ auto data_type = unsupported_op->output_data_types[i];
+ model->arrays[output]->data_type = data_type;
+ }
+ } else {
+    // These operators produce outputs with the same type as their inputs,
+    // which are all required to have the same type.
+ const ArrayDataType data_type = CommonDataTypeOfAllInputs(*model, *op);
+ SetDataTypeForAllOutputs(model, op, data_type);
+ }
+ // Return true if any output data type changed, false if none changed.
+ for (const auto& output : op->outputs) {
+ if (old_output_data_types[output] != model->arrays[output]->data_type) {
+ return true;
+ }
+ }
+ return false;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
new file mode 100644
index 0000000000..82a43bc2ce
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
@@ -0,0 +1,1129 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <algorithm>
+#include <iterator>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+namespace {
+
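+// Computes the output spatial dims and implied fixed padding of a conv-like
+// op. For example, with input_height = 5, kheight = 3, stride_height = 1:
+//   kValid: output_height = (5 + 1 - 3) / 1 = 3, fixed height padding 0;
+//   kSame:  output_height = (5 + 1 - 1) / 1 = 5,
+//           fixed height padding ((5 - 1) * 1 + 3 - 5) / 2 = 1.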
+void ComputeConvSizes(const Shape& input_shape, int output_depth, int kwidth,
+ int kheight, int stride_width, int stride_height,
+ PaddingType padding_type, Shape* output_shape,
+ FixedPadding* fixed_padding) {
+ const int input_width = input_shape.dims(2);
+ const int input_height = input_shape.dims(1);
+ const int batch = input_shape.dims(0);
+
+ int output_height = 0;
+ int output_width = 0;
+ if (padding_type == PaddingType::kValid) {
+ output_height = (input_height + stride_height - kheight) / stride_height;
+ output_width = (input_width + stride_width - kwidth) / stride_width;
+ } else if (padding_type == PaddingType::kSame) {
+ output_height = (input_height + stride_height - 1) / stride_height;
+ output_width = (input_width + stride_width - 1) / stride_width;
+ } else {
+ LOG(FATAL) << "Only supporting SAME or VALID padding";
+ }
+
+ fixed_padding->height =
+ ((output_height - 1) * stride_height + kheight - input_height) / 2;
+ fixed_padding->width =
+ ((output_width - 1) * stride_width + kwidth - input_width) / 2;
+
+ // Actually had to debug a situation where those were negative due to bad
+ // propagation of placeholder -1 sizes in TensorFlowReshape.
+ CHECK_GT(output_width, 0);
+ CHECK_GT(output_height, 0);
+ output_shape->ReplaceDims({batch, output_height, output_width, output_depth});
+}
+
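+// Picks the output shape of a simple elementwise binary op: the input shape
+// with the larger flat size wins; on a tie, the shape with more dimensions
+// wins. For example, input shapes {8, 1, 3} and {3} give output {8, 1, 3}.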
+void ComputeBinaryOperatorOutputSize(const Shape& input_shape1,
+ const Shape& input_shape2,
+ Array* output_array) {
+ const int size1 = RequiredBufferSizeForShape(input_shape1);
+ const int size2 = RequiredBufferSizeForShape(input_shape2);
+ if (size1 > size2) {
+ output_array->copy_shape(input_shape1);
+ } else if (size2 > size1) {
+ output_array->copy_shape(input_shape2);
+ } else {
+ CHECK_EQ(size1, size2);
+ const int dims1 = input_shape1.dimensions_count();
+ const int dims2 = input_shape2.dimensions_count();
+ if (dims1 >= dims2) {
+ output_array->copy_shape(input_shape1);
+ } else {
+ output_array->copy_shape(input_shape2);
+ }
+ }
+ CHECK(output_array->has_shape());
+}
+
+int GetOutputDepthFromWeights(const Model& model, const Operator& op) {
+ const string& weights_name = op.inputs[1];
+ const auto& weights_shape = model.arrays.at(weights_name)->shape();
+ if (op.type == OperatorType::kConv ||
+ op.type == OperatorType::kFullyConnected) {
+ return weights_shape.dims(0);
+ } else if (op.type == OperatorType::kDepthwiseConv) {
+ return weights_shape.dims(3);
+ } else {
+ LOG(FATAL) << "Unhandled operator type";
+ }
+}
+
+bool EnsureBiasVectorShape(Model* model, Operator* op) {
+ const string& weights_name = op->inputs[1];
+ const auto& weights_array = *model->arrays[weights_name];
+ // Yield until weights shape has been resolved.
+ if (!weights_array.has_shape()) {
+ return false;
+ }
+
+ if (op->inputs.size() < 3) {
+ return false;
+ }
+ auto& bias_array = *model->arrays[op->inputs[2]];
+ if (bias_array.has_shape()) {
+ return true;
+ }
+
+ const int output_depth = GetOutputDepthFromWeights(*model, *op);
+ bias_array.copy_shape(Shape({output_depth}));
+
+ auto& float_buffer = bias_array.GetMutableBuffer<ArrayDataType::kFloat>();
+ float_buffer.data.resize(output_depth, 0);
+
+ return true;
+}
+
+void ProcessConvOperator(Model* model, ConvOperator* op) {
+ if (!EnsureBiasVectorShape(model, op)) {
+ return;
+ }
+
+ const auto& input_array = *model->arrays[op->inputs[0]];
+ // Yield until input dims have been resolved.
+ if (!input_array.has_shape()) {
+ return;
+ }
+ const auto& input_shape = input_array.shape();
+ CHECK_EQ(input_shape.dimensions_count(), 4);
+
+ const auto& weights_array = *model->arrays[op->inputs[1]];
+ // Yield until weights dims have been resolved.
+ if (!weights_array.has_shape()) {
+ return;
+ }
+ const auto& weights_shape = weights_array.shape();
+ CHECK_EQ(weights_shape.dimensions_count(), 4);
+
+ auto& output_array = model->GetArray(op->outputs[0]);
+ const int output_depth = weights_shape.dims(0);
+ const int kheight = weights_shape.dims(1);
+ const int kwidth = weights_shape.dims(2);
+ ComputeConvSizes(input_shape, output_depth, kwidth, kheight, op->stride_width,
+ op->stride_height, op->padding.type,
+ output_array.mutable_shape(),
+ &op->padding.GetOrCreateFixedPadding());
+ CHECK_EQ(output_array.shape().dimensions_count(), 4);
+
+ // Set im2col array dimensions if there is one.
+ if (op->outputs.size() == 2) {
+ const auto& output_shape = output_array.shape();
+ const int input_depth = weights_shape.dims(3);
+ auto& im2col_array = *model->arrays[op->outputs[1]];
+ im2col_array.copy_shape(Shape{output_shape.dims(0), output_shape.dims(1),
+ output_shape.dims(2),
+ input_depth * kheight * kwidth});
+ }
+}
+
+void ProcessDepthwiseConvOperator(Model* model, DepthwiseConvOperator* op) {
+ if (!EnsureBiasVectorShape(model, op)) {
+ return;
+ }
+
+ const auto& input_array = *model->arrays[op->inputs[0]];
+ // Yield until input dims have been resolved.
+ if (!input_array.has_shape()) {
+ return;
+ }
+ const auto& input_shape = input_array.shape();
+ CHECK_EQ(input_shape.dimensions_count(), 4);
+
+ const auto& weights_array = *model->arrays[op->inputs[1]];
+ // Yield until weights dims have been resolved.
+ if (!weights_array.has_shape()) {
+ return;
+ }
+ const auto& weights_shape = weights_array.shape();
+ CHECK_EQ(weights_shape.dimensions_count(), 4);
+
+ const string& output_name = op->outputs[0];
+ const int input_depth = input_shape.dims(3);
+ const int output_depth = weights_shape.dims(3);
+  // TensorFlow doesn't define the depth_multiplier value on DepthwiseConv ops;
+  // instead it has to be inferred from the weights dims. However, once we are
+  // here, weights dims have already been converted to our own internal format,
+  // where the multiplier is no longer readily apparent. So instead we get it
+  // as the quotient of output and input depths. We only want to do that when
+  // depth_multiplier still has its zero default value; any other value is
+  // validated by the QCHECK_EQ below.
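+  // For example, input_depth = 8 and output_depth = 24 imply a
+  // depth_multiplier of 3.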
+ if (!op->depth_multiplier) {
+ op->depth_multiplier = output_depth / input_depth;
+ }
+ QCHECK_EQ(output_depth, input_depth * op->depth_multiplier)
+ << "input/output depths and depth_multiplier don't match";
+
+ const int kheight = weights_shape.dims(1);
+ const int kwidth = weights_shape.dims(2);
+ ComputeConvSizes(input_shape, output_depth, kwidth, kheight, op->stride_width,
+ op->stride_height, op->padding.type,
+ model->GetArray(output_name).mutable_shape(),
+ &op->padding.GetOrCreateFixedPadding());
+}
+
+void ProcessDepthToSpaceOperator(Model* model, DepthToSpaceOperator* op) {
+ const auto& input_array = *model->arrays[op->inputs[0]];
+ // Yield until input dims have been resolved.
+ if (!input_array.has_shape()) {
+ return;
+ }
+ const auto& input_shape = input_array.shape();
+ CHECK_EQ(input_shape.dimensions_count(), 4);
+
+ const string& output_name = op->outputs[0];
+ const int block_size = op->block_size;
+ CHECK_NE(block_size, 0) << "Invalid block_size in " << output_name;
+ const int batch = input_shape.dims(0);
+ const int height = input_shape.dims(1);
+ const int width = input_shape.dims(2);
+ const int depth = input_shape.dims(3);
+ QCHECK_EQ(depth % (block_size * block_size), 0);
+
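+  // For example, block_size = 2 maps an input of shape {1, 10, 10, 16} to an
+  // output of shape {1, 20, 20, 4}.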
+ model->GetArray(output_name)
+ .copy_shape(Shape({batch, height * block_size, width * block_size,
+ depth / block_size / block_size}));
+}
+
+void ProcessSpaceToDepthOperator(Model* model, SpaceToDepthOperator* op) {
+ const auto& input_array = *model->arrays[op->inputs[0]];
+ // Yield until input dims have been resolved.
+ if (!input_array.has_shape()) {
+ return;
+ }
+ const auto& input_shape = input_array.shape();
+ CHECK_EQ(input_shape.dimensions_count(), 4);
+
+ const string& output_name = op->outputs[0];
+ const int block_size = op->block_size;
+ CHECK_NE(block_size, 0) << "Invalid block_size in " << output_name;
+ const int batch = input_shape.dims(0);
+ const int height = input_shape.dims(1);
+ const int width = input_shape.dims(2);
+ const int depth = input_shape.dims(3);
+ QCHECK_EQ(width % block_size, 0);
+ QCHECK_EQ(height % block_size, 0);
+
+ model->GetArray(output_name)
+ .copy_shape(Shape({batch, height / block_size, width / block_size,
+ depth * block_size * block_size}));
+}
+
+void ProcessFullyConnectedOperator(Model* model, FullyConnectedOperator* op) {
+ if (!EnsureBiasVectorShape(model, op)) {
+ return;
+ }
+
+ const auto& input_array = *model->arrays[op->inputs[0]];
+ // Yield until input dims have been resolved.
+ if (!input_array.has_shape()) {
+ return;
+ }
+ const auto& input_shape = input_array.shape();
+ CHECK_GE(input_shape.dimensions_count(), 1);
+
+ const auto& weights_array = *model->arrays[op->inputs[1]];
+ // Yield until weights dims have been resolved.
+ if (!weights_array.has_shape()) {
+ return;
+ }
+ const auto& weights_shape = weights_array.shape();
+
+ const int weights_output_depth = weights_shape.dims(0);
+ CHECK_EQ(weights_shape.dimensions_count(), 2);
+
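+  // For example, an input of shape {2, 1024} with weights of shape
+  // {512, 1024} yields matmul_repeats = 2048 / 1024 = 2 and an output of
+  // shape {2, 512}.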
+ const int input_overall_size = RequiredBufferSizeForShape(input_shape);
+ const int matmul_repeats = input_overall_size / weights_shape.dims(1);
+ CHECK_EQ(matmul_repeats * weights_shape.dims(1), input_overall_size);
+
+ auto& output_array = model->GetArray(op->outputs[0]);
+ output_array.copy_shape(Shape({matmul_repeats, weights_output_depth}));
+}
+
+void ProcessTensorFlowReshapeOperator(Model* model,
+ TensorFlowReshapeOperator* op) {
+ auto& output_array = *model->arrays[op->outputs[0]];
+ // Bail if we already have output dims
+ if (output_array.has_shape()) {
+ return;
+ }
+
+ const auto& input_array = *model->arrays[op->inputs[0]];
+ // Yield until input dims have been resolved.
+ if (!input_array.has_shape()) {
+ return;
+ }
+ const auto& input_shape = input_array.shape();
+
+ const string& shape_name = op->inputs[1];
+ auto& shape_array = model->GetArray(shape_name);
+ // Yield until the shape is resolved as a constant array
+ if (!shape_array.buffer) {
+ return;
+ }
+ CHECK(shape_array.data_type == ArrayDataType::kInt32);
+ // shape_data is the raw array of ints describing the shape
+ // in the TensorFlow node. We intentionally make a copy here, rather than
+ // modify wildcards in-place below, because in some graphs, the same shape
+ // array with a wildcard may be referenced from multiple Reshape nodes, where
+  // the wildcard needs to be resolved to distinct values.
+ std::vector<int32> shape_data =
+ shape_array.GetBuffer<ArrayDataType::kInt32>().data;
+ // The Reshape shape may have a wildcard dim, encoded as -1.
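+  // For example, an input with 24 elements and shape_data {2, -1, 3} resolves
+  // the wildcard to 24 / (2 * 3) = 4.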
+ bool has_wildcard = false;
+ int wildcard_index = 0;
+ int product_non_wildcard_dims = 1;
+ for (int i = 0; i < shape_data.size(); i++) {
+ if (shape_data[i] == -1) {
+ CHECK(!has_wildcard);
+ has_wildcard = true;
+ wildcard_index = i;
+ } else {
+ product_non_wildcard_dims *= shape_data[i];
+ }
+ }
+ const int input_flat_size = RequiredBufferSizeForShape(input_shape);
+ if (has_wildcard) {
+ shape_data[wildcard_index] = input_flat_size / product_non_wildcard_dims;
+ }
+ auto& output_shape = *output_array.mutable_shape();
+ *output_shape.mutable_dims() = shape_data;
+ const int output_flat_size = RequiredBufferSizeForShape(output_shape);
+ CHECK_EQ(output_flat_size, input_flat_size);
+}
+
+void ProcessSimpleOperator(Model* model, Operator* op) {
+ const auto& input_array = *model->arrays[op->inputs[0]];
+ // Yield until input dims have been resolved.
+ if (!input_array.has_shape()) {
+ return;
+ }
+
+ const string& output_name = op->outputs[0];
+ auto& output_array = *model->arrays[output_name];
+ if (output_array.has_shape()) {
+ return;
+ }
+
+ output_array.copy_shape(input_array.shape());
+}
+
+void ProcessSimpleBinaryOperator(Model* model, Operator* op) {
+ CHECK_EQ(op->inputs.size(), 2);
+ const auto& input0_array = *model->arrays[op->inputs[0]];
+ const auto& input1_array = *model->arrays[op->inputs[1]];
+ // Yield until input dims have been resolved.
+ if (!input0_array.has_shape() || !input1_array.has_shape()) {
+ return;
+ }
+ const string& output_name = op->outputs[0];
+ auto& output_array = *model->arrays[output_name];
+ ComputeBinaryOperatorOutputSize(input0_array.shape(), input1_array.shape(),
+ &output_array);
+}
+
+void ProcessTensorFlowReductionOperator(Model* model, Operator* op) {
+ CHECK_LE(op->inputs.size(), 2);
+ auto& output_array = *model->arrays[op->outputs[0]];
+ if (output_array.has_shape()) {
+ return;
+ }
+ if (op->inputs.size() == 2) {
+ // There is a reduction_indices input.
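+    // For example, an input of shape {2, 3, 4} reduced over indices {0, 2}
+    // yields an output of shape {3}.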
+ const auto& input_array = *model->arrays[op->inputs[0]];
+ const auto& reduction_array = *model->arrays[op->inputs[1]];
+ if (!reduction_array.buffer) {
+ return;
+ }
+ if (!input_array.has_shape()) {
+ return;
+ }
+ auto& input_shape = input_array.shape();
+ CHECK(reduction_array.buffer->type == ArrayDataType::kInt32);
+ const auto& reduction_array_vals =
+ reduction_array.GetBuffer<ArrayDataType::kInt32>().data;
+ auto& output_dims = *output_array.mutable_shape()->mutable_dims();
+ output_dims.clear();
+ for (int i = 0; i < input_shape.dimensions_count(); i++) {
+ bool is_reduction_dim = false;
+ for (int r : reduction_array_vals) {
+ if (i == r) {
+ is_reduction_dim = true;
+ }
+ }
+ if (!is_reduction_dim) {
+ output_dims.push_back(input_shape.dims(i));
+ }
+ }
+ } else {
+ // No reduction_indices means complete reduction to a single scalar.
+ output_array.copy_shape(Shape({}));
+ }
+}
+
+void ProcessSliceOperator(Model* model, SliceOperator* op) {
+ CHECK_EQ(op->inputs.size(), 3);
+ CHECK_EQ(op->outputs.size(), 1);
+
+ // Yield until the Slice params have been resolved.
+ if (op->begin.empty()) return;
+
+ // Yield until input dims have been resolved.
+ const auto& input_array = *model->arrays[op->inputs[0]];
+ if (!input_array.has_shape()) return;
+ const Shape& input_shape = input_array.shape();
+
+ auto& output_array = *model->arrays[op->outputs[0]];
+ if (output_array.has_shape()) return;
+
+ CHECK_EQ(input_shape.dims().size(), op->size.size());
+ CHECK_EQ(op->begin.size(), op->size.size());
+
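+  // A size of -1 means "up to the end of the dimension": e.g. begin = 3 and
+  // size = -1 on an input dimension of 10 gives an output dimension of 7.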
+ std::vector<int> output_dims;
+ for (int i = 0; i < op->begin.size(); ++i) {
+ int size = op->size[i];
+ if (size == -1) {
+ size = input_array.shape().dims(i) - op->begin[i];
+ }
+ output_dims.push_back(size);
+ }
+
+ *output_array.mutable_shape()->mutable_dims() = output_dims;
+}
+
+void ProcessReorderAxesOperator(Model* model, ReorderAxesOperator* op) {
+ const string& input_name = op->inputs[0];
+ const auto& input_array = *model->arrays[input_name];
+ // Yield until input dims have been resolved.
+ if (!input_array.has_shape()) {
+ return;
+ }
+ const auto& input_shape = input_array.shape();
+ const string& output_name = op->outputs[0];
+ Shape* output_shape = model->GetArray(output_name).mutable_shape();
+ ShuffleDims(input_shape, op->input_axes_order, op->output_axes_order,
+ output_shape);
+}
+
+void ProcessConcatenationOperator(Model* model, ConcatenationOperator* op) {
+ // Yield until input dims have been resolved.
+ for (const auto& input_name : op->inputs) {
+ auto& input_array = *model->arrays[input_name];
+ if (!input_array.has_shape()) {
+ return;
+ }
+ }
+ auto& output_array = model->GetArray(op->outputs[0]);
+ // Use 0 input as basis for output dimensions.
+ const auto& first_input_array = *model->arrays[op->inputs[0]];
+ output_array.copy_shape(first_input_array.shape());
+  // Determine the concat size, and enforce that all inputs have
+  // the same dimensions count.
+ int concat_size = 0;
+ for (const auto& input_name : op->inputs) {
+ auto& input_array = *model->arrays[input_name];
+ CHECK(input_array.has_shape());
+ if (input_array.shape().dimensions_count() == 0) {
+ continue;
+ }
+ CHECK_EQ(input_array.shape().dimensions_count(),
+ output_array.shape().dimensions_count());
+ const std::vector<int>& input_dims = input_array.shape().dims();
+ CHECK_LT(op->concat_dim, input_dims.size());
+ concat_size += input_dims[op->concat_dim];
+ }
+ // Write out the concat_size on the output array shape.
+ auto& output_shape = *output_array.mutable_shape();
+ auto& output_dims = *output_shape.mutable_dims();
+ CHECK_LT(op->concat_dim, output_shape.dimensions_count());
+ output_dims[op->concat_dim] = concat_size;
+}
+
+void ProcessTensorFlowSplitOperator(Model* model, TensorFlowSplitOperator* op) {
+ CHECK_EQ(op->inputs.size(), 2);
+ const string& input_name = op->inputs[1];
+ const auto& input_array = *model->arrays[input_name];
+ // Yield until input dims have been resolved.
+ if (!input_array.has_shape()) {
+ return;
+ }
+ const Shape& input_shape = input_array.shape();
+
+ // This code is slightly suspect. The TensorFlow docs say that the axis
+ // selection defaults to 0, but we are splitting across the final axis.
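+  // For example, an input of shape {1, 16} with num_split = 4 yields four
+  // outputs of shape {1, 4}.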
+ const int input_dims_count = input_shape.dimensions_count();
+ const int input_depth = input_shape.dims(input_dims_count - 1);
+ CHECK_EQ(input_depth % op->num_split, 0);
+ const int split_depth = input_depth / op->num_split;
+
+ Shape output_shape = input_shape;
+ (*output_shape.mutable_dims())[input_dims_count - 1] = split_depth;
+
+ CHECK_EQ(op->outputs.size(), op->num_split);
+ for (const auto& output : op->outputs) {
+ model->arrays[output]->copy_shape(output_shape);
+ }
+}
+
+void ProcessAveragePoolOperator(Model* model, AveragePoolOperator* op) {
+ const string& input_name = op->inputs[0];
+ const auto& input_array = *model->arrays[input_name];
+ // Yield until input dims have been resolved.
+ if (!input_array.has_shape()) {
+ return;
+ }
+ const auto& input_shape = input_array.shape();
+ CHECK_EQ(input_shape.dimensions_count(), 4);
+ const string& output_name = op->outputs[0];
+ const int output_depth = input_shape.dims(3);
+ ComputeConvSizes(input_shape, output_depth, op->kwidth, op->kheight,
+ op->stride_width, op->stride_height, op->padding.type,
+ model->GetArray(output_name).mutable_shape(),
+ &op->padding.GetOrCreateFixedPadding());
+}
+
+void ProcessMaxPoolOperator(Model* model, MaxPoolOperator* op) {
+ const string& input_name = op->inputs[0];
+ const auto& input_array = *model->arrays[input_name];
+ // Yield until input dims have been resolved.
+ if (!input_array.has_shape()) {
+ return;
+ }
+ const auto& input_shape = input_array.shape();
+ CHECK_EQ(input_shape.dimensions_count(), 4);
+ const string& output_name = op->outputs[0];
+ const int output_depth = input_shape.dims(3);
+ ComputeConvSizes(input_shape, output_depth, op->kwidth, op->kheight,
+ op->stride_width, op->stride_height, op->padding.type,
+ model->GetArray(output_name).mutable_shape(),
+ &op->padding.GetOrCreateFixedPadding());
+}
+
+void ProcessL2PoolOperator(Model* model, L2PoolOperator* op) {
+ const string& input_name = op->inputs[0];
+ const auto& input_array = *model->arrays[input_name];
+ // Yield until input dims have been resolved.
+ if (!input_array.has_shape()) {
+ return;
+ }
+ const auto& input_shape = input_array.shape();
+ if (input_shape.dimensions_count() < 4) {
+ LOG(FATAL) << "missing dimensions for " << input_name;
+ }
+ const string& output_name = op->outputs[0];
+ const int output_depth = input_shape.dims(3);
+ ComputeConvSizes(input_shape, output_depth, op->kwidth, op->kheight,
+ op->stride_width, op->stride_height, op->padding.type,
+ model->GetArray(output_name).mutable_shape(),
+ &op->padding.GetOrCreateFixedPadding());
+}
+
+void ProcessResizeBilinearOperator(Model* model, ResizeBilinearOperator* op) {
+ CHECK_EQ(op->inputs.size(), 2);
+ CHECK_EQ(op->outputs.size(), 1);
+
+ if (!model->arrays[op->inputs[0]]->has_shape() ||
+ !model->arrays[op->inputs[1]]->has_shape()) {
+ return;
+ }
+ const auto& input_data_shape = model->arrays[op->inputs[0]]->shape();
+
+ const string& output_size_name = op->inputs[1];
+ const auto& output_size_array = *model->arrays[output_size_name];
+ CHECK(output_size_array.data_type == ArrayDataType::kInt32);
+ CHECK(output_size_array.has_shape());
+ const auto& output_size_shape = output_size_array.shape();
+ CHECK_EQ(output_size_shape.dimensions_count(), 1);
+ CHECK_EQ(output_size_shape.dims(0), 2);
+ std::vector<int32> output_shape =
+ output_size_array.GetBuffer<ArrayDataType::kInt32>().data;
+ model->arrays[op->outputs[0]]->copy_shape(
+ Shape({input_data_shape.dims(0), output_shape[0], output_shape[1],
+ input_data_shape.dims(3)}));
+}
+
+void ProcessLstmCellOperator(Model* model, LstmCellOperator* op) {
+ // I/O arrays should be allocated on creation of op.
+ QCHECK_EQ(op->inputs.size(), LstmCellOperator::NUM_INPUTS);
+ QCHECK_EQ(op->outputs.size(), LstmCellOperator::NUM_OUTPUTS);
+
+ const auto& input_array =
+ *model->arrays[op->inputs[LstmCellOperator::DATA_INPUT]];
+ // Yield until all input dims have been resolved.
+ if (!input_array.has_shape()) {
+ return;
+ }
+ const auto& input_shape = input_array.shape();
+ CHECK_GE(input_shape.dimensions_count(), 2);
+
+ const auto& prev_activ_array =
+ *model->arrays[op->inputs[LstmCellOperator::PREV_ACTIV_INPUT]];
+ // Yield until all input dims have been resolved.
+ if (!prev_activ_array.has_shape()) {
+ return;
+ }
+ const auto& prev_activ_shape = prev_activ_array.shape();
+ CHECK_GE(prev_activ_shape.dimensions_count(), 2);
+
+ const auto& weights_array =
+ *model->arrays[op->inputs[LstmCellOperator::WEIGHTS_INPUT]];
+ // Yield until weights dims have been resolved.
+ if (!weights_array.has_shape()) {
+ return;
+ }
+ const auto& weights_shape = weights_array.shape();
+ CHECK_EQ(weights_shape.dimensions_count(), 2);
+
+ const auto& bias_array =
+ *model->arrays[op->inputs[LstmCellOperator::BIASES_INPUT]];
+ // Yield until bias dims have been resolved.
+ if (!bias_array.has_shape()) {
+ return;
+ }
+ const auto& bias_shape = bias_array.shape();
+ CHECK_GE(bias_shape.dimensions_count(), 1);
+
+ const auto& prev_state_array =
+ *model->arrays[op->inputs[LstmCellOperator::PREV_STATE_INPUT]];
+ // Yield until all input dims have been resolved.
+ if (!prev_state_array.has_shape()) {
+ return;
+ }
+ const auto& prev_state_shape = prev_state_array.shape();
+ CHECK_GE(prev_state_shape.dimensions_count(), 2);
+
+ const int fc_output_depth = weights_shape.dims(0);
+ CHECK_EQ(fc_output_depth, bias_shape.dims(0));
+ CHECK_EQ(fc_output_depth % 4, 0);
+ const int depth = fc_output_depth / 4;
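+  // The fused fully-connected output feeds the four LSTM gates, so its depth
+  // is 4x the number of LSTM cells; 'depth' here is the per-gate (cell) depth.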
+
+ const int input_depth = input_shape.dims(input_shape.dimensions_count() - 1);
+ const int fc_input_depth = weights_shape.dims(1);
+ CHECK_EQ(input_depth + depth, fc_input_depth);
+ Shape output_shape(input_shape);
+ (*output_shape.mutable_dims())[output_shape.dimensions_count() - 1] = depth;
+
+ // Set output dimensions
+ model->GetArray(op->outputs[LstmCellOperator::STATE_OUTPUT])
+ .copy_shape(output_shape);
+ model->GetArray(op->outputs[LstmCellOperator::ACTIV_OUTPUT])
+ .copy_shape(output_shape);
+
+ Shape concat_temp_shape(input_shape);
+ (*concat_temp_shape
+ .mutable_dims())[concat_temp_shape.dimensions_count() - 1] =
+ fc_input_depth;
+ model->GetArray(op->outputs[LstmCellOperator::CONCAT_TEMP])
+ .copy_shape(concat_temp_shape);
+
+ Shape activ_temp_shape(input_shape);
+ (*activ_temp_shape.mutable_dims())[activ_temp_shape.dimensions_count() - 1] =
+ fc_output_depth;
+ model->GetArray(op->outputs[LstmCellOperator::ACTIV_TEMP])
+ .copy_shape(activ_temp_shape);
+}
+
+void ProcessSpaceToBatchNDOperator(Model* model, SpaceToBatchNDOperator* op) {
+ const auto& input_array = *model->arrays[op->inputs[0]];
+ // Yield until input dims have been resolved.
+ if (!input_array.has_shape()) {
+ return;
+ }
+ const auto& input_shape = input_array.shape();
+ CHECK_EQ(input_shape.dimensions_count(), 4);
+ const auto input_height = input_shape.dims(1);
+ const auto input_width = input_shape.dims(2);
+
+ const auto& block_shape_array = *model->arrays[op->inputs[1]];
+ const auto& paddings_array = *model->arrays[op->inputs[2]];
+ const auto& block_shape_array_shape = block_shape_array.shape();
+ const auto& paddings_array_shape = paddings_array.shape();
+ QCHECK_EQ(block_shape_array_shape.dimensions_count(), 1);
+ QCHECK_EQ(paddings_array_shape.dimensions_count(), 2);
+
+  // We only support 2-D block shapes (i.e. blocking along height and width).
+ QCHECK_EQ(block_shape_array_shape.dims(0), 2);
+ if (!block_shape_array.buffer) {
+ return;
+ }
+ QCHECK(block_shape_array.data_type == ArrayDataType::kInt32);
+ const auto& block_shape_data =
+ block_shape_array.GetBuffer<ArrayDataType::kInt32>().data;
+ auto block_height = block_shape_data[0];
+ auto block_width = block_shape_data[1];
+
+ QCHECK_EQ(paddings_array_shape.dims(0), 2); // Number of block dimensions
+ QCHECK_EQ(paddings_array_shape.dims(1), 2); // Two parameters per dimension.
+ if (!paddings_array.buffer) {
+ return;
+ }
+ QCHECK(paddings_array.data_type == ArrayDataType::kInt32);
+ const auto& paddings_data =
+ paddings_array.GetBuffer<ArrayDataType::kInt32>().data;
+ int height_with_paddings = input_height + paddings_data[0] + paddings_data[1];
+ int width_with_paddings = input_width + paddings_data[2] + paddings_data[3];
+ QCHECK_EQ(height_with_paddings % block_height, 0);
+ QCHECK_EQ(width_with_paddings % block_width, 0);
+ int output_height = height_with_paddings / block_height;
+ int output_width = width_with_paddings / block_width;
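+  // For example, a [1, 4, 4, 1] input with a 2x2 block shape and zero
+  // paddings yields a [4, 2, 2, 1] output.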
+
+ model->arrays[op->outputs[0]]->copy_shape(
+ Shape({input_shape.dims(0) * block_height * block_width, output_height,
+ output_width, input_shape.dims(3)}));
+}
+
+void ProcessBatchToSpaceNDOperator(Model* model, BatchToSpaceNDOperator* op) {
+ const auto& input_array = *model->arrays[op->inputs[0]];
+ // Yield until input dims have been resolved.
+ if (!input_array.has_shape()) {
+ return;
+ }
+ const auto& input_shape = input_array.shape();
+ CHECK_EQ(input_shape.dimensions_count(), 4);
+ const auto input_height = input_shape.dims(1);
+ const auto input_width = input_shape.dims(2);
+
+ const auto& block_shape_array = *model->arrays[op->inputs[1]];
+ const auto& crops_array = *model->arrays[op->inputs[2]];
+ const auto& block_shape_array_shape = block_shape_array.shape();
+ const auto& crops_array_shape = crops_array.shape();
+ QCHECK_EQ(block_shape_array_shape.dimensions_count(), 1);
+ QCHECK_EQ(crops_array_shape.dimensions_count(), 2);
+
+  // We only support 2-D block shapes (i.e. blocking along height and width).
+ QCHECK_EQ(block_shape_array_shape.dims(0), 2);
+ if (!block_shape_array.buffer) {
+ return;
+ }
+ QCHECK(block_shape_array.data_type == ArrayDataType::kInt32);
+ const auto& block_shape_data =
+ block_shape_array.GetBuffer<ArrayDataType::kInt32>().data;
+ auto block_height = block_shape_data[0];
+ auto block_width = block_shape_data[1];
+
+ QCHECK_EQ(crops_array_shape.dims(0), 2); // Number of block dimensions
+ QCHECK_EQ(crops_array_shape.dims(1), 2); // Two parameters per dimension.
+ if (!crops_array.buffer) {
+ return;
+ }
+ QCHECK(crops_array.data_type == ArrayDataType::kInt32);
+ const auto& crops_data = crops_array.GetBuffer<ArrayDataType::kInt32>().data;
+ // We don't support crops now.
+ QCHECK_EQ(crops_data[0], 0);
+ QCHECK_EQ(crops_data[1], 0);
+ QCHECK_EQ(crops_data[2], 0);
+ QCHECK_EQ(crops_data[3], 0);
+
+ QCHECK_EQ(input_shape.dims(0) % (block_height * block_width), 0);
+
+ int output_height = input_height * block_height;
+ int output_width = input_width * block_width;
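+  // For example, a [4, 2, 2, 1] input with a 2x2 block shape and zero crops
+  // yields a [1, 4, 4, 1] output.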
+
+ model->arrays[op->outputs[0]]->copy_shape(
+ Shape({input_shape.dims(0) / (block_height * block_width), output_height,
+ output_width, input_shape.dims(3)}));
+}
+
+void ProcessGatherOperator(Model* model, GatherOperator* op) {
+ const auto& input_array = *model->arrays[op->inputs[0]];
+ const auto& indices_array = *model->arrays[op->inputs[1]];
+ auto& output_array = *model->arrays[op->outputs[0]];
+
+ // Bail if we already know the output shape.
+ if (output_array.has_shape()) {
+ return;
+ }
+
+ // Yield until input dims have been resolved.
+ if (!input_array.has_shape() || !indices_array.has_shape()) {
+ return;
+ }
+
+ const auto& input_shape = input_array.shape();
+ const auto& indices_shape = indices_array.shape();
+ QCHECK_GE(input_shape.dimensions_count(), 1);
+ op->input_rank = input_shape.dimensions_count();
+
+ // We only support 1-D indices.
+ QCHECK_EQ(indices_shape.dimensions_count(), 1);
+
+ // Copy the input dimensions to the output except for dimension 0,
+ // where the dimension of indices_shape is used.
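+  // For example, gathering with indices of shape [3] from an input of shape
+  // [10, 5] produces an output of shape [3, 5].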
+ auto output_dims = output_array.mutable_shape()->mutable_dims();
+ output_dims->push_back(indices_shape.dims(0));
+ for (int dim = 1; dim < input_shape.dimensions_count(); dim++) {
+ output_dims->push_back(input_shape.dims(dim));
+ }
+}
+
+void ProcessPadOperator(Model* model, PadOperator* op) {
+ CHECK_EQ(op->inputs.size(), 2);
+ CHECK_EQ(op->outputs.size(), 1);
+
+ const auto& input_array = *model->arrays[op->inputs[0]];
+
+ // Yield until input dims have been resolved.
+ if (!input_array.has_shape()) return;
+
+ if (op->left_padding.empty()) return;
+ CHECK_EQ(op->left_padding.size(), op->right_padding.size());
+
+ auto& output_array = *model->arrays[op->outputs[0]];
+ if (output_array.has_shape()) return;
+
+ Shape output_shape = input_array.shape();
+ std::vector<int>& dims = *output_shape.mutable_dims();
+ CHECK_EQ(op->left_padding.size(), dims.size());
+
+ for (int i = 0; i < op->left_padding.size(); ++i) {
+ dims[i] += op->left_padding[i] + op->right_padding[i];
+ }
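+  // For example, a [1, 4, 4, 3] input with left_padding {0, 1, 1, 0} and
+  // right_padding {0, 1, 1, 0} yields a [1, 6, 6, 3] output shape.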
+
+ output_array.copy_shape(output_shape);
+}
+
+void ProcessMeanOperator(Model* model, MeanOperator* op) {
+ CHECK_EQ(op->inputs.size(), 2);
+ CHECK_EQ(op->outputs.size(), 1);
+
+ const auto& input_array = *model->arrays[op->inputs[0]];
+
+ // Yield until input dims have been resolved.
+ if (!input_array.has_shape()) return;
+ const std::vector<int>& indices = op->reduction_indices;
+ if (indices.empty()) return;
+
+ auto& output_array = *model->arrays[op->outputs[0]];
+ if (output_array.has_shape()) return;
+
+ const std::vector<int>& input_dims = input_array.shape().dims();
+ std::vector<int> output_dims;
+ for (int i = 0; i < input_dims.size(); ++i) {
+ if (std::find(indices.begin(), indices.end(), i) == indices.end()) {
+ output_dims.push_back(input_dims[i]);
+ }
+ }
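+  // For example, a [1, 64, 64, 128] input with reduction_indices {1, 2}
+  // yields output dims {1, 128}. Only the case where exactly two dimensions
+  // remain is supported below.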
+ CHECK(!output_dims.empty());
+ CHECK_EQ(output_dims.size(), 2);
+
+ *output_array.mutable_shape()->mutable_dims() = output_dims;
+}
+
+void ProcessStridedSliceOperator(Model* model, StridedSliceOperator* op) {
+ CHECK_EQ(op->inputs.size(), 4);
+ CHECK_EQ(op->outputs.size(), 1);
+
+ const auto& input_array = *model->arrays[op->inputs[0]];
+
+ // Yield until input dims have been resolved.
+ if (!input_array.has_shape()) return;
+
+ if (op->start_indices.empty()) return;
+ CHECK_EQ(op->start_indices.size(), op->stop_indices.size());
+ CHECK_EQ(op->start_indices.size(), op->strides.size());
+
+ auto& output_array = *model->arrays[op->outputs[0]];
+ if (output_array.has_shape()) return;
+
+ Shape output_shape = input_array.shape();
+ std::vector<int>& dims = *output_shape.mutable_dims();
+ CHECK_EQ(op->start_indices.size(), dims.size());
+
+ for (int i = 0; i < op->start_indices.size(); ++i) {
+ const int mask = 1 << i;
+ const int start = (op->begin_mask & mask) ? 0 : op->start_indices[i];
+ const int stop = (op->end_mask & mask) ? input_array.shape().dims()[i]
+ : op->stop_indices[i];
+ dims[i] = (stop - start) / op->strides[i];
+ }
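+  // For example, start = 2, stop = 8, stride = 2 yields (8 - 2) / 2 = 3
+  // elements along that dimension (indices 2, 4 and 6).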
+
+ output_array.copy_shape(output_shape);
+}
+
+void ProcessSqueezeOperator(Model* model, SqueezeOperator* op) {
+ CHECK_EQ(op->inputs.size(), 1);
+ CHECK_EQ(op->outputs.size(), 1);
+
+ const auto& input_array = *model->arrays[op->inputs[0]];
+
+ // Yield until input dims have been resolved.
+ if (!input_array.has_shape()) return;
+
+ auto& output_array = *model->arrays[op->outputs[0]];
+ if (output_array.has_shape()) return;
+
+ const std::vector<int>& input_dims = input_array.shape().dims();
+ std::vector<int> output_dims;
+
+ for (int i = 0; i < input_dims.size(); ++i) {
+ if (input_dims[i] != 1 ||
+ (!op->squeeze_dims.empty() &&
+ std::find(op->squeeze_dims.begin(), op->squeeze_dims.end(), i) ==
+ op->squeeze_dims.end())) {
+ output_dims.push_back(input_dims[i]);
+ }
+ }
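+  // For example, a [1, 5, 1, 3] input with squeeze_dims {2} yields [1, 5, 3],
+  // while an empty squeeze_dims drops every size-1 dimension, yielding [5, 3].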
+ *output_array.mutable_shape()->mutable_dims() = output_dims;
+}
+
+void ProcessSvdfOperator(Model* model, SvdfOperator* op) {
+ CHECK(op->inputs.size() == 3 || op->inputs.size() == 4);
+ const auto& input_array = *model->arrays[op->inputs[0]];
+ if (!input_array.has_shape()) return;
+
+ auto& weights_feature_array = *model->arrays[op->inputs[1]];
+ if (!weights_feature_array.has_shape()) return;
+
+ const auto& weights_time_array = *model->arrays[op->inputs[2]];
+ if (!weights_time_array.has_shape()) return;
+
+ const bool has_bias = (op->inputs.size() == 4);
+ if (has_bias) {
+ const auto& bias_array = *model->arrays[op->inputs[3]];
+ if (!bias_array.has_shape()) return;
+ }
+
+ const int batch_size = input_array.shape().dims()[0];
+ const int num_units = weights_feature_array.shape().dims()[0];
+ const int memory_size = weights_time_array.shape().dims()[1];
+
+ auto& state_array = model->GetArray(op->outputs[0]);
+ state_array.mutable_shape()->ReplaceDims(
+ {batch_size, memory_size * num_units});
+
+ auto& output_array = model->GetArray(op->outputs[1]);
+ output_array.mutable_shape()->ReplaceDims({batch_size, num_units});
+}
+} // namespace
+
+bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
+ auto it = model->operators.begin() + op_index;
+ auto* op = it->get();
+ std::unordered_map<string, std::vector<int>> old_output_dims;
+ for (const auto& output : op->outputs) {
+ if (model->arrays[output]->has_shape()) {
+ old_output_dims[output] = model->arrays[output]->shape().dims();
+ }
+ }
+
+ switch (op->type) {
+ case OperatorType::kBatchNormalization:
+ case OperatorType::kL2Normalization:
+ case OperatorType::kDequantize:
+ case OperatorType::kRelu:
+ case OperatorType::kRelu1:
+ case OperatorType::kRelu6:
+ case OperatorType::kSoftmax:
+ case OperatorType::kLogistic:
+ case OperatorType::kTanh:
+ case OperatorType::kLocalResponseNormalization:
+ case OperatorType::kTensorFlowIdentity:
+ case OperatorType::kFakeQuant:
+ case OperatorType::kTensorFlowRsqrt:
+ case OperatorType::kTensorFlowSqrt:
+ case OperatorType::kTensorFlowSquare:
+ case OperatorType::kTensorFlowAll:
+ case OperatorType::kTensorFlowAssert:
+ case OperatorType::kCast:
+ case OperatorType::kFloor:
+ ProcessSimpleOperator(model, op);
+ break;
+ case OperatorType::kGather:
+ ProcessGatherOperator(model, static_cast<GatherOperator*>(op));
+ break;
+
+ case OperatorType::kAdd:
+ case OperatorType::kSub:
+ case OperatorType::kMul:
+ case OperatorType::kDiv:
+ case OperatorType::kTensorFlowLess:
+ case OperatorType::kTensorFlowLessEqual:
+ case OperatorType::kTensorFlowGreater:
+ case OperatorType::kTensorFlowMaximum:
+ case OperatorType::kTensorFlowMinimum:
+ case OperatorType::kTensorFlowGreaterEqual:
+ ProcessSimpleBinaryOperator(model, op);
+ break;
+ case OperatorType::kConv:
+ ProcessConvOperator(model, static_cast<ConvOperator*>(op));
+ break;
+ case OperatorType::kDepthwiseConv:
+ ProcessDepthwiseConvOperator(model,
+ static_cast<DepthwiseConvOperator*>(op));
+ break;
+ case OperatorType::kDepthToSpace:
+ ProcessDepthToSpaceOperator(model,
+ static_cast<DepthToSpaceOperator*>(op));
+ break;
+ case OperatorType::kSpaceToDepth:
+ ProcessSpaceToDepthOperator(model,
+ static_cast<SpaceToDepthOperator*>(op));
+ break;
+ case OperatorType::kFullyConnected:
+ ProcessFullyConnectedOperator(model,
+ static_cast<FullyConnectedOperator*>(op));
+ break;
+ case OperatorType::kTensorFlowReshape:
+ ProcessTensorFlowReshapeOperator(
+ model, static_cast<TensorFlowReshapeOperator*>(op));
+ break;
+ case OperatorType::kAveragePool:
+ ProcessAveragePoolOperator(model, static_cast<AveragePoolOperator*>(op));
+ break;
+ case OperatorType::kMaxPool:
+ ProcessMaxPoolOperator(model, static_cast<MaxPoolOperator*>(op));
+ break;
+ case OperatorType::kL2Pool:
+ ProcessL2PoolOperator(model, static_cast<L2PoolOperator*>(op));
+ break;
+ case OperatorType::kTensorFlowMin:
+ case OperatorType::kTensorFlowMax:
+ case OperatorType::kTensorFlowSum:
+ ProcessTensorFlowReductionOperator(model, op);
+ break;
+
+ case OperatorType::kSlice:
+ ProcessSliceOperator(model, static_cast<SliceOperator*>(op));
+ break;
+
+ case OperatorType::kTensorFlowTile:
+ // We don't currently implement the propagation of fixed sizes through
+ // a TensorFlow Tile.
+ //
+ // Fortunately, we don't need to: so far, we have only dealt with Tile
+ // or Slice ops in subgraphs that are identified as L2Normalization.
+ // See IdentifyL2Normalization.
+ break;
+ case OperatorType::kTensorFlowSwitch:
+ // We can't know the sizes of the outputs until we have resolved the
+ // predicate, and once we have resolved the predicate, the whole
+ // Switch node will get resolved away.
+ // See ResolveTensorFlowSwitch.
+ break;
+ case OperatorType::kTensorFlowMerge:
+ // No need to bother resolving TensorFlow Merge ops: other graph
+ // transformations will remove them anyway.
+ // See ResolveTensorFlowMerge.
+ break;
+ case OperatorType::kTensorFlowSplit:
+ ProcessTensorFlowSplitOperator(model,
+ static_cast<TensorFlowSplitOperator*>(op));
+ break;
+ case OperatorType::kSqueeze:
+ ProcessSqueezeOperator(model, static_cast<SqueezeOperator*>(op));
+ break;
+ case OperatorType::kTensorFlowConcat:
+ case OperatorType::kTensorFlowConcatV2:
+ // Unimplemented, hopefully another graph transformation will
+ // drop it or rewrite it. Concretely, either ResolveTensorFlowConcat
+ // will resolve this node to a DepthConcatenation, or else we have
+ // a more general non-depth concatenation that will hopefully be dropped,
+ // or else at the moment we will abort.
+ break;
+ case OperatorType::kTensorFlowShape:
+ // Unimplemented, hopefully another graph transformation will drop it or
+ // rewrite it.
+ break;
+ case OperatorType::kReorderAxes:
+ ProcessReorderAxesOperator(model, static_cast<ReorderAxesOperator*>(op));
+ break;
+ case OperatorType::kConcatenation:
+ ProcessConcatenationOperator(model,
+ static_cast<ConcatenationOperator*>(op));
+ break;
+ case OperatorType::kResizeBilinear:
+ ProcessResizeBilinearOperator(model,
+ static_cast<ResizeBilinearOperator*>(op));
+ break;
+ case OperatorType::kLstmCell:
+ ProcessLstmCellOperator(model, static_cast<LstmCellOperator*>(op));
+ break;
+ case OperatorType::kTensorFlowMatMul:
+ // MatMul operators are converted to FullyConnected, after which their
+ // shapes are propagated.
+ break;
+ case OperatorType::kSpaceToBatchND:
+ ProcessSpaceToBatchNDOperator(model,
+ static_cast<SpaceToBatchNDOperator*>(op));
+ break;
+ case OperatorType::kBatchToSpaceND:
+ ProcessBatchToSpaceNDOperator(model,
+ static_cast<BatchToSpaceNDOperator*>(op));
+ break;
+ case OperatorType::kPad:
+ ProcessPadOperator(model, static_cast<PadOperator*>(op));
+ break;
+ case OperatorType::kMean:
+ ProcessMeanOperator(model, static_cast<MeanOperator*>(op));
+ break;
+ case OperatorType::kStridedSlice:
+ ProcessStridedSliceOperator(model,
+ static_cast<StridedSliceOperator*>(op));
+ break;
+ case OperatorType::kTensorFlowUnsupported:
+ break;
+ case OperatorType::kSvdf:
+ ProcessSvdfOperator(model, static_cast<SvdfOperator*>(op));
+ break;
+ default:
+ // Unimplemented, another graph transformation should drop it.
+ LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(op->type);
+ }
+
+ // Return true if any output dim changed, false if none changed.
+ // Assumption: no transformation clears an output shape, they only add shapes.
+ for (const auto& output : op->outputs) {
+ if (model->arrays[output]->has_shape() &&
+ (old_output_dims[output] != model->arrays[output]->shape().dims())) {
+ return true;
+ }
+ }
+ return false;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc b/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc
new file mode 100644
index 0000000000..5551755ea7
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc
@@ -0,0 +1,467 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <algorithm>
+#include <cmath>
+#include <limits>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/model_flags.pb.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+namespace {
+
+bool SupportsQuantization(const Operator& op) {
+ auto type = op.type;
+ if (type == OperatorType::kTensorFlowUnsupported) {
+ auto* unsupported = static_cast<const TensorFlowUnsupportedOperator*>(&op);
+ return unsupported->quantized;
+ }
+ return type == OperatorType::kConv || type == OperatorType::kDepthwiseConv ||
+ type == OperatorType::kFullyConnected ||
+ type == OperatorType::kConcatenation ||
+ type == OperatorType::kL2Normalization || type == OperatorType::kAdd ||
+ type == OperatorType::kAveragePool || type == OperatorType::kMaxPool ||
+ type == OperatorType::kLogistic || type == OperatorType::kSoftmax ||
+ type == OperatorType::kTensorFlowReshape ||
+ type == OperatorType::kMul || type == OperatorType::kSpaceToDepth ||
+ type == OperatorType::kDepthToSpace;
+}
+
+template <ArrayDataType A>
+std::unique_ptr<GenericBuffer> QuantizeBuffer(
+ const GenericBuffer& buffer,
+ const QuantizationParams& quantization_params) {
+ const auto inverse_scale = 1. / quantization_params.scale;
+ CHECK(buffer.type == ArrayDataType::kFloat);
+ const auto& float_buffer =
+ static_cast<const Buffer<ArrayDataType::kFloat>&>(buffer);
+ auto* quantized_buffer = new Buffer<A>;
+ quantized_buffer->data.resize(float_buffer.data.size());
+ const auto qmin = static_cast<int32>(std::numeric_limits<DataType<A>>::min());
+ const auto qmax = static_cast<int32>(std::numeric_limits<DataType<A>>::max());
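+  // Each float value is mapped to zero_point + value / scale, rounded to the
+  // nearest integer and clamped to [qmin, qmax]. For example, with
+  // scale = 1. / 128. and zero_point = 128, -1.0f maps to 0 and 0.0f to 128.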
+ for (std::size_t i = 0; i < float_buffer.data.size(); i++) {
+ const float src_val = float_buffer.data[i];
+ double scaled_val; // Astonishingly, using 'float' degrades accuracy just
+ // enough to make a few tests fail!
+ if (quantization_params.scale == 0) {
+ CHECK_EQ(src_val, 0) << "The quantization scale for this array is 0, "
+ << "so all its values should be 0.";
+ scaled_val = quantization_params.zero_point;
+ } else {
+ scaled_val = quantization_params.zero_point + inverse_scale * src_val;
+ }
+ const auto rounded_val = static_cast<int32>(std::round(scaled_val));
+ const auto clamped_val = std::min(qmax, std::max(qmin, rounded_val));
+ quantized_buffer->data[i] = static_cast<DataType<A>>(clamped_val);
+ }
+ return std::unique_ptr<GenericBuffer>(quantized_buffer);
+}
+
+template <ArrayDataType A>
+void QuantizeArray(GraphTransformation* transformation, Model* model,
+ const string& name,
+ const QuantizationParams& quantization_params) {
+ auto& array = model->GetArray(name);
+ CHECK(array.data_type == ArrayDataType::kFloat);
+ CHECK(!array.quantization_params);
+ array.GetOrCreateQuantizationParams() = quantization_params;
+ if (array.buffer) {
+ array.buffer = QuantizeBuffer<A>(*array.buffer, quantization_params);
+ }
+ array.data_type = A;
+ transformation->AddMessageF("Quantized array %s", name);
+}
+
+void QuantizeArray(GraphTransformation* transformation, Model* model,
+ const string& name, ArrayDataType quantized_data_type,
+ const QuantizationParams& quantization_params) {
+ switch (quantized_data_type) {
+ case ArrayDataType::kUint8:
+ return QuantizeArray<ArrayDataType::kUint8>(transformation, model, name,
+ quantization_params);
+ case ArrayDataType::kInt32:
+ return QuantizeArray<ArrayDataType::kInt32>(transformation, model, name,
+ quantization_params);
+ default:
+ LOG(FATAL) << "Unhandled case.";
+ }
+}
+
+const MinMax& GetOrComputeMinMax(Model* model, const string& array_name) {
+ auto& array = model->GetArray(array_name);
+ // Normally we should have a MinMax recorded on this Array,
+ // so we just use it.
+ if (array.minmax != nullptr) {
+ return *array.minmax;
+ }
+
+ // We don't have a MinMax. That's bad news: we need
+ // the graph to provide MinMax info for all arrays in order
+ // for inference to reproduce faithfully the same quantization
+ // error as the training process had.
+ //
+ // But we still want to support a fallback for constant arrays,
+ // just using the plain min and max computed from array elements.
+ // We should hopefully never rely on that in production, as that
+ // will not give very good accuracy as that typically won't be
+ // exactly what the training process used. But it will be useful
+ // to allow easily trying out quantization even if the graph
+ // lacks some minmax information.
+ if (array.buffer != nullptr) {
+    LOG(WARNING)
+        << "Constant array " << array_name
+        << " lacks MinMax information. To make up for that, we will now "
+           "compute the MinMax from actual array elements. That will result "
+           "in quantization parameters that probably do not match whichever "
+           "arithmetic was used during training, and thus will probably be a "
+           "cause of poor inference accuracy.";
+ CHECK(array.buffer->type == ArrayDataType::kFloat);
+ const auto& data = array.GetBuffer<ArrayDataType::kFloat>().data;
+ // We always want [min, max] to contain 0.
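+    // (so that the real value 0 is exactly representable in the quantized
+    // form, which quantized kernels rely on, e.g. for zero padding).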
+ float min = 0.f;
+ float max = 0.f;
+ for (auto val : data) {
+ min = std::min(min, val);
+ max = std::max(max, val);
+ }
+ auto& minmax = array.GetOrCreateMinMax();
+ minmax.min = min;
+ minmax.max = max;
+ return minmax;
+ }
+
+ LOG(FATAL) << "Array " << array_name
+ << " does not have MinMax information, "
+ "and is not a constant array. Cannot "
+ "proceed with quantization.";
+}
+
+bool ChooseQuantizationForOperatorInput(
+ GraphTransformation* transformation, Model* model, const Operator& op,
+ std::size_t input_index, ArrayDataType* quantized_data_type,
+ QuantizationParams* quantization_params) {
+ const auto& input = op.inputs[input_index];
+ auto& array = model->GetArray(input);
+ if (array.data_type != ArrayDataType::kFloat) {
+ return false;
+ }
+ if (op.type == OperatorType::kConv ||
+ op.type == OperatorType::kDepthwiseConv ||
+ op.type == OperatorType::kFullyConnected) {
+ if (input_index == 2) {
+ // Quantization of bias vector.
+      // We need both of the mandatory inputs (input activations and weights)
+      // to have already been quantized.
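+      // The bias is added to int32 accumulators holding sums of
+      // (input activation * weight) products, so its scale must be the
+      // product of the input and weights scales, with zero_point = 0.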
+ const auto& input_activations = model->GetArray(op.inputs[0]);
+ const auto& input_weights = model->GetArray(op.inputs[1]);
+ if (!input_activations.quantization_params ||
+ !input_weights.quantization_params) {
+ return false;
+ }
+ const auto input_activations_scale =
+ input_activations.quantization_params->scale;
+ const auto input_weights_scale = input_weights.quantization_params->scale;
+ quantization_params->scale =
+ input_activations_scale * input_weights_scale;
+ quantization_params->zero_point = 0;
+ *quantized_data_type = ArrayDataType::kInt32;
+ transformation->AddMessageF(
+ "Input array %s is a bias vector. Choosing quantization params "
+ "accordingly.",
+ input);
+ return true;
+ }
+ }
+
+ const MinMax& minmax = GetOrComputeMinMax(model, input);
+ GetQuantizationParamsFromMinMax<ArrayDataType::kUint8>(model->flags, minmax,
+ quantization_params);
+ transformation->AddMessageF(
+ "For input array %s with min=%g"
+ ", max=%g"
+ ", chose to quantize as uint8 with zero_point=%d"
+ ", scale=%g",
+ input, minmax.min, minmax.max, quantization_params->zero_point,
+ quantization_params->scale);
+ *quantized_data_type = ArrayDataType::kUint8;
+ return true;
+}
+
+bool IsExactlyRepresentable(double real_value, ArrayDataType data_type,
+ const QuantizationParams& quantization_params) {
+ const double scaled_value =
+ quantization_params.zero_point + real_value / quantization_params.scale;
+ const double fractional_scaled_value =
+ scaled_value - std::round(scaled_value);
+ if (std::abs(fractional_scaled_value) > 1e-12) {
+ return false;
+ }
+ const double rounded_scaled_value = std::round(scaled_value);
+ if (data_type == ArrayDataType::kUint8) {
+ if (rounded_scaled_value < 0 || rounded_scaled_value > 255) {
+ return false;
+ }
+ }
+ return true;
+}
+
+bool ChooseHardcodedQuantizationForOperatorOutput(
+ const Operator& op, ArrayDataType* quantized_data_type,
+ QuantizationParams* quantization_params) {
+ if (op.type == OperatorType::kL2Normalization) {
+ // L2Normalization has range: [-1, 1].
+ // 0 should be exactly representable, as values will typically be centered
+ // around 0, with many values near 0.
+ *quantized_data_type = ArrayDataType::kUint8;
+ quantization_params->zero_point = 128;
+ quantization_params->scale = 1. / 128.;
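+    // With these parameters a quantized value q represents (q - 128) / 128.,
+    // so 0. is exactly representable (q == 128) and the representable range
+    // is [-1, 127/128].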
+ CHECK(
+ IsExactlyRepresentable(0., *quantized_data_type, *quantization_params));
+ return true;
+ }
+ if ((op.type == OperatorType::kLogistic) ||
+ (op.type == OperatorType::kSoftmax)) {
+ // Logistic and Softmax have range: [0, 1].
+ //
+ // For Logistic, 0.5 should be exactly representable, as implementations
+ // will typically exploit the symmetry logistic(-x) = 1 - logistic(x), and
+ // the glueing of the two halves of the graph will only be seamless if we
+ // are accurately representing logistic(0) == 0.5.
+ *quantized_data_type = ArrayDataType::kUint8;
+ quantization_params->zero_point = 0;
+ quantization_params->scale = 1. / 256.;
+ CHECK(IsExactlyRepresentable(0.5, *quantized_data_type,
+ *quantization_params));
+ return true;
+ }
+ return false;
+}
+
+bool ChooseQuantizationForOperatorOutput(
+ GraphTransformation* transformation, Model* model, const Operator& op,
+ std::size_t output_index, ArrayDataType* quantized_data_type,
+ QuantizationParams* quantization_params) {
+ const auto& output = op.outputs[output_index];
+ auto& array = model->GetArray(output);
+ if (array.data_type != ArrayDataType::kFloat) {
+ return false;
+ }
+ if (ChooseHardcodedQuantizationForOperatorOutput(op, quantized_data_type,
+ quantization_params)) {
+ transformation->AddMessageF(
+ "Output array %s is produced by a %s operator. Choosing fixed "
+ "quantization params accordingly.",
+ output, OperatorTypeName(op.type));
+ return true;
+ }
+ if ((op.type == OperatorType::kDepthToSpace) ||
+ (op.type == OperatorType::kSpaceToDepth)) {
+ // DepthToSpace and SpaceToDepth should preserve the quantization parameters
+ // of the input array, as these are simple reshape operations.
+ const auto& input_quantization_params =
+ model->GetArray(op.inputs[0]).GetQuantizationParams();
+ *quantized_data_type = ArrayDataType::kUint8;
+ quantization_params->zero_point = input_quantization_params.zero_point;
+ quantization_params->scale = input_quantization_params.scale;
+
+ transformation->AddMessageF(
+ "Output array %s is produced by a %s operator. Copying quantization "
+ "params from input array.",
+ output, OperatorTypeName(op.type));
+ return true;
+ }
+ const MinMax& minmax = GetOrComputeMinMax(model, output);
+ GetQuantizationParamsFromMinMax<ArrayDataType::kUint8>(model->flags, minmax,
+ quantization_params);
+ *quantized_data_type = ArrayDataType::kUint8;
+ transformation->AddMessageF(
+ "For output array %s with min=%g, max=%g"
+ ", chose to quantize as uint8 with zero_point=%d"
+ ", scale=%g",
+ output, minmax.min, minmax.max, quantization_params->zero_point,
+ quantization_params->scale);
+
+ return true;
+}
+} // namespace
+
+bool Quantize::Run(Model* model, std::size_t op_index) {
+ // Our general "quantization" graph transformation consists in replacing
+ // QuantizedInputArrays[] ->
+ // DequantizeOperators[] ->
+ // FloatInputArrays[] ->
+ // Operator ->
+ // FloatOutputArray
+ // by
+ // QuantizedInputArrays[] ->
+ // Operator ->
+ // QuantizedOutputArray ->
+ // DequantizeOperator ->
+ // FloatOutputArray
+ //
+ // In other words, this is pushing Dequantize operators to the right of
+ // other operators.
+ //
+
+ auto& op = *model->operators[op_index];
+ if (op.type == OperatorType::kDequantize ||
+ op.type == OperatorType::kFakeQuant) {
+ return false;
+ }
+
+ // Our assumption here is that the input arrays are already quantized -
+ // that is typically the case in models operating on an input bitmap
+ // image, and MakeInitialDequantizeOp should have already resolved
+ // the handling of the input image as an initial Dequantize op.
+ //
+ // Thus we are building around the assumption that the graph always starts
+ // with a quantized input array, and only after some Dequantize op do we have
+ // float arrays. The problem of quantizing the graph thus becomes a problem of
+ // pushing Dequantize ops to the right of other ops.
+ //
+ // Let us just guard this assumption by the following assertion:
+ for (const auto& input : op.inputs) {
+ if (IsInputArray(*model, input)) {
+ const auto& input_array = model->GetArray(input);
+ CHECK(input_array.quantization_params);
+ }
+ }
+ if (!SupportsQuantization(op)) {
+ LOG(FATAL) << "Unimplemented: this graph contains an operator of type "
+ << HelpfulOperatorTypeName(op)
+ << " for which the quantized form is not yet implemented. "
+ "Sorry, and patches welcome (that's a relatively fun patch "
+ "to write, mostly providing the actual quantized arithmetic "
+ "code for this op).";
+ }
+
+ for (const auto& input : op.inputs) {
+ const auto& array = model->GetArray(input);
+ if (array.data_type == ArrayDataType::kFloat) {
+ if (!array.minmax && !array.buffer) {
+ LOG(ERROR) << "Can't quantize input array " << input
+ << " because it lacks min/max info";
+ return false;
+ }
+ const auto* other_op = GetOpWithOutput(*model, input);
+ if (other_op && other_op->type != OperatorType::kDequantize) {
+ AddMessageF(
+ "Not quantizing %s for now, because its input array %s is not "
+ "produced by a Dequantize op, "
+ "which means that we should yield and let other ops "
+ "get quantized first",
+ LogName(op), input);
+ return false;
+ }
+ }
+ }
+
+ bool changed = false;
+
+ // Quantize inputs, remove any Dequantize op on the inputs side
+ for (std::size_t input_index = 0; input_index < op.inputs.size();
+ input_index++) {
+ ArrayDataType quantized_data_type;
+ QuantizationParams quantization_params;
+ if (ChooseQuantizationForOperatorInput(this, model, op, input_index,
+ &quantized_data_type,
+ &quantization_params)) {
+ changed = true;
+ const auto& input = op.inputs[input_index];
+ if (IsConstantParameterArray(*model, input)) {
+ QuantizeArray(this, model, input, quantized_data_type,
+ quantization_params);
+ } else {
+ auto dequantize_it = FindOpWithOutput(*model, input);
+ CHECK(dequantize_it != model->operators.end());
+ auto* dequantize_op = dequantize_it->get();
+ CHECK(dequantize_op->type == OperatorType::kDequantize);
+ op.inputs[input_index] = dequantize_op->inputs[0];
+ // Check if the output of that Dequantize op was not used by any
+ // other operator. We will then erase that Dequantize op.
+ if (!CountOpsWithInput(*model, dequantize_op->outputs[0])) {
+ // If any of the model's output_arrays was pointing to the
+ // Dequantize op's output, let it point to the Dequantize op's
+ // input instead.
+ for (int i = 0; i < model->flags.output_arrays_size(); i++) {
+ if (model->flags.output_arrays(i) == dequantize_op->outputs[0]) {
+ model->flags.set_output_arrays(i, dequantize_op->inputs[0]);
+ }
+ }
+ model->arrays.erase(dequantize_op->outputs[0]);
+ model->operators.erase(dequantize_it);
+ }
+ }
+ }
+ }
+
+ // Quantize outputs, add Dequantize ops as needed on the outputs side
+ for (std::size_t output_index = 0; output_index < op.outputs.size();
+ output_index++) {
+ ArrayDataType quantized_data_type;
+ QuantizationParams quantization_params;
+ if (ChooseQuantizationForOperatorOutput(this, model, op, output_index,
+ &quantized_data_type,
+ &quantization_params)) {
+ changed = true;
+ const auto& output = op.outputs[output_index];
+ QuantizeArray(this, model, output, quantized_data_type,
+ quantization_params);
+ const auto& dequantized_output =
+ AvailableArrayName(*model, output + "_dequantized");
+ const auto& output_array = model->GetArray(output);
+ const auto& output_minmax = output_array.GetMinMax();
+ auto& dequantized_output_array =
+ model->GetOrCreateArray(dequantized_output);
+ dequantized_output_array.data_type = ArrayDataType::kFloat;
+ auto& dequantized_output_minmax =
+ dequantized_output_array.GetOrCreateMinMax();
+ dequantized_output_minmax.min = output_minmax.min;
+ dequantized_output_minmax.max = output_minmax.max;
+ for (const auto& other_op : model->operators) {
+ for (auto& other_op_input : other_op->inputs) {
+ if (other_op_input == output) {
+ other_op_input = dequantized_output;
+ }
+ }
+ }
+ auto* dequantize_op = new DequantizeOperator;
+ dequantize_op->inputs = {output};
+ dequantize_op->outputs = {dequantized_output};
+ for (int i = 0; i < model->flags.output_arrays_size(); i++) {
+ if (model->flags.output_arrays(i) == output) {
+ model->flags.set_output_arrays(i, dequantized_output);
+ }
+ }
+ const auto op_it = FindOp(*model, &op);
+ model->operators.emplace(op_it + 1, dequantize_op);
+ }
+ }
+
+ return changed;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/read_fake_quant_min_max.cc b/tensorflow/contrib/lite/toco/graph_transformations/read_fake_quant_min_max.cc
new file mode 100644
index 0000000000..371ced388a
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/read_fake_quant_min_max.cc
@@ -0,0 +1,105 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <algorithm>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+namespace {
+
+bool ApplyMinMaxToArray(GraphTransformation* transformation, Model* model,
+ const MinMax& minmax, const string& array_name) {
+ auto& annotated_array = model->GetArray(array_name);
+ if (annotated_array.minmax) {
+ return false;
+ }
+ annotated_array.GetOrCreateMinMax() = minmax;
+ transformation->AddMessageF(
+ "Read min/max annotation for array %s: min=%g, max=%g", array_name,
+ minmax.min, minmax.max);
+ return true;
+}
+
+} // end namespace
+
+bool ReadFakeQuantMinMax::Run(Model* model, std::size_t op_index) {
+ const auto fakequant_it = model->operators.begin() + op_index;
+ auto* fakequant_base_op = fakequant_it->get();
+ if (fakequant_base_op->type != OperatorType::kFakeQuant) {
+ return false;
+ }
+ auto* fakequant_op = static_cast<FakeQuantOperator*>(fakequant_base_op);
+
+ bool changed = false;
+
+ if (!fakequant_op->minmax) {
+ CHECK_EQ(fakequant_op->inputs.size(), 3);
+ // We need to yield until the min and max parameters have been
+ // resolved to constant arrays.
+ for (int i = 1; i <= 2; i++) {
+      if (!IsConstantParameterArray(*model, fakequant_op->inputs[i])) {
+ return false;
+ }
+ }
+
+ // Obtain the final min/max values
+ const auto& min_array = model->GetArray(fakequant_op->inputs[1]);
+ const auto& max_array = model->GetArray(fakequant_op->inputs[2]);
+ CHECK_EQ(RequiredBufferSizeForShape(min_array.shape()), 1);
+ CHECK_EQ(RequiredBufferSizeForShape(max_array.shape()), 1);
+ fakequant_op->minmax.reset(new MinMax);
+ MinMax& minmax = *fakequant_op->minmax;
+ minmax.min = min_array.GetBuffer<ArrayDataType::kFloat>().data[0];
+ minmax.max = max_array.GetBuffer<ArrayDataType::kFloat>().data[0];
+ // We always want [min, max] to contain 0.
+ minmax.min = std::min(minmax.min, 0.);
+ minmax.max = std::max(minmax.max, 0.);
+
+ // We won't use the input arrays that provided these min and max
+ // values, anymore. Delete them unless they are used by something
+ // else.
+ for (int i = 1; i <= 2; i++) {
+ if (CountOpsWithInput(*model, fakequant_op->inputs[i]) == 1) {
+ model->arrays.erase(fakequant_op->inputs[i]);
+ }
+ }
+ fakequant_op->inputs.resize(1);
+ changed = true;
+ }
+
+ // At this point, this FakeQuantOperator should have a MinMax
+ // attached to it, and should only have 1 input (it should not have
+ // 2nd and 3rd input arrays giving min and max anymore).
+ CHECK(fakequant_op->minmax);
+ CHECK_EQ(1, fakequant_op->inputs.size());
+
+ const MinMax& minmax = *fakequant_op->minmax;
+
+ // Record the MinMax info on the input and output arrays
+ changed |= ApplyMinMaxToArray(this, model, minmax, fakequant_op->inputs[0]);
+ changed |= ApplyMinMaxToArray(this, model, minmax, fakequant_op->outputs[0]);
+
+ return changed;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_final_dequantize_op.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_final_dequantize_op.cc
new file mode 100644
index 0000000000..3992e7d1ef
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_final_dequantize_op.cc
@@ -0,0 +1,59 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/model_flags.pb.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+bool RemoveFinalDequantizeOp::Run(Model* model, std::size_t op_index) {
+ const auto dequantize_it = model->operators.begin() + op_index;
+ const auto* dequantize_op = dequantize_it->get();
+ if (dequantize_op->type != OperatorType::kDequantize) {
+ return false;
+ }
+ const auto& output = dequantize_op->outputs[0];
+ // We can remove any dequantize op whose output is not consumed by
+ // any op. This is not necessarily equivalent to the output being
+ // one of the model's output arrays, as some intermediate array
+ // in the middle of the graph might be designated as an output
+ // array.
+ if (CountOpsWithInput(*model, output)) {
+ return false;
+ }
+
+ // If one of the model's output arrays was actually the Dequantize op's
+ // output, then we need to update it to point to the Dequantize op's input.
+ for (int i = 0; i < model->flags.output_arrays_size(); i++) {
+ if (output == model->flags.output_arrays(i)) {
+ model->flags.set_output_arrays(i, dequantize_op->inputs[0]);
+ }
+ }
+
+ // Remove the node and its output array.
+ AddMessageF("Removed final %s", LogName(*dequantize_op));
+ model->arrays.erase(output);
+ model->operators.erase(dequantize_it);
+ return true;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_assert.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_assert.cc
new file mode 100644
index 0000000000..35a0c46532
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_assert.cc
@@ -0,0 +1,60 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+bool RemoveTensorFlowAssert::Run(Model* model, std::size_t op_index) {
+ const auto assert_it = model->operators.begin() + op_index;
+ const auto* assert_op = assert_it->get();
+ if (assert_op->type != OperatorType::kTensorFlowAssert) {
+ return false;
+ }
+
+ bool changed = false;
+ // Remove any other node's dependency on this assert node
+ for (const auto& op : model->operators) {
+ auto it = op->inputs.begin();
+ while (it != op->inputs.end()) {
+ if (*it == assert_op->outputs[0]) {
+ op->inputs.erase(it);
+ changed = true;
+ } else {
+ ++it;
+ }
+ }
+ }
+ CHECK(!CountOpsWithInput(*model, assert_op->outputs[0]));
+
+ if (changed) {
+ AddMessageF(
+ "Prepared for the removal of %s by removing any other op's dependency "
+ "on it",
+ LogName(*assert_op));
+ }
+
+ // That's it. We can stop here, no need to duplicate the work that
+ // RemoveUnusedOp will do removing this now-unused node.
+ return changed;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_identity.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_identity.cc
new file mode 100644
index 0000000000..404269bbfd
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_identity.cc
@@ -0,0 +1,38 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+bool RemoveTensorFlowIdentity::Run(Model* model, std::size_t op_index) {
+ const auto passthru_it = model->operators.begin() + op_index;
+ const auto* passthru_op = passthru_it->get();
+ if (passthru_op->type != OperatorType::kTensorFlowIdentity) {
+ return false;
+ }
+
+ return RemoveTrivialPassthroughOp(this, model, op_index);
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_binary.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_binary.cc
new file mode 100644
index 0000000000..6add443f2d
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_binary.cc
@@ -0,0 +1,113 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <iterator>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+namespace {
+
+template <typename Scalar>
+bool AreAllBufferElementsEqualTo(const std::vector<Scalar>& buffer_data,
+ Scalar value) {
+ for (auto x : buffer_data) {
+ if (x != value) {
+ return false;
+ }
+ }
+ return true;
+}
+} // namespace
+
+// A binary operator is called trivial when exactly one of its operands is
+// a constant and is such that the binary operation is equivalent to
+// the identity operation on its other input.
+// For example, an Add operator is trivial if
+// one of its operands is constant 0, a Mul operator is trivial
+// if one of its operands is constant 1, etc.
+bool RemoveTrivialBinaryOperator::Run(Model* model, std::size_t op_index) {
+ const auto binary_it = model->operators.begin() + op_index;
+ auto* binary_op = binary_it->get();
+ if (binary_op->type != OperatorType::kAdd &&
+ binary_op->type != OperatorType::kMul &&
+ binary_op->type != OperatorType::kSub &&
+ binary_op->type != OperatorType::kDiv) {
+ return false;
+ }
+
+ CHECK_EQ(binary_op->inputs.size(), 2);
+
+ // This graph transformation is only concerned with the case
+ // when one input is constant and the other is not constant.
+ const bool is_input_constant[2] = {
+ IsConstantParameterArray(*model, binary_op->inputs[0]),
+ IsConstantParameterArray(*model, binary_op->inputs[1]),
+ };
+ if (!is_input_constant[0] && !is_input_constant[1]) {
+ // Neither input is constant, so nothing we can resolve here.
+ return false;
+ }
+ if (is_input_constant[0] && is_input_constant[1]) {
+ // Both inputs are constants. That's a job for constants
+ // propagation, not for us to handle here.
+ return false;
+ }
+ const int index_of_constant_input = is_input_constant[0] ? 0 : 1;
+ const int index_of_variable_input = is_input_constant[0] ? 1 : 0;
+ CHECK(is_input_constant[index_of_constant_input]);
+ CHECK(!is_input_constant[index_of_variable_input]);
+
+ // Now check if the constant operand makes this binary
+ // operator trivial.
+ const auto& constant_input_array =
+ *model->arrays[binary_op->inputs[index_of_constant_input]];
+ // For now, we only handle floats here.
+ if (constant_input_array.data_type != ArrayDataType::kFloat) {
+ return false;
+ }
+ const auto& constant_input_float_data =
+ constant_input_array.GetBuffer<ArrayDataType::kFloat>().data;
+ bool is_trivial = false;
+  if (binary_op->type == OperatorType::kAdd) {
+    is_trivial = AreAllBufferElementsEqualTo(constant_input_float_data, 0.f);
+  } else if (binary_op->type == OperatorType::kSub) {
+    is_trivial = index_of_constant_input == 1 &&
+                 AreAllBufferElementsEqualTo(constant_input_float_data, 0.f);
+  } else if (binary_op->type == OperatorType::kMul) {
+    is_trivial = AreAllBufferElementsEqualTo(constant_input_float_data, 1.f);
+  } else if (binary_op->type == OperatorType::kDiv) {
+    is_trivial = index_of_constant_input == 1 &&
+                 AreAllBufferElementsEqualTo(constant_input_float_data, 1.f);
+  }
+
+ if (!is_trivial) {
+ return false;
+ }
+
+ // Now we know that this node is trivial, so we can remove it.
+ AddMessageF("Removing trivial %s", LogName(*binary_op));
+ return RemoveTrivialPassthroughOp(this, model, op_index);
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_concatenation.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_concatenation.cc
new file mode 100644
index 0000000000..3ceb93d8ee
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_concatenation.cc
@@ -0,0 +1,40 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+bool RemoveTrivialConcatenation::Run(Model* model, std::size_t op_index) {
+ const auto concat_it = model->operators.begin() + op_index;
+ auto* concat_op = concat_it->get();
+ if (concat_op->type != OperatorType::kConcatenation) {
+ return false;
+ }
+ if (concat_op->inputs.size() != 1) {
+ return false;
+ }
+ return RemoveTrivialPassthroughOp(this, model, op_index);
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_concatenation_input.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_concatenation_input.cc
new file mode 100644
index 0000000000..b603735704
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_concatenation_input.cc
@@ -0,0 +1,68 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+bool RemoveTrivialConcatenationInput::Run(Model* model, std::size_t op_index) {
+ // TensorFlow allows Concatenation nodes to have 0-D inputs,
+ // and they are then treated as empty i.e. omitted from concatenation,
+ // in violation of the notion that 0-D is equivalent to 1x1x1x1.
+ // Thus we have to drop these 0-D inputs from Concatenation nodes.
+ // Sometimes, there will remain only one non-trivial input, and
+ // the other graph transformation RemoveTrivialConcatenation will then drop
+ // it.
+ const auto concat_it = model->operators.begin() + op_index;
+ auto* concat_op = concat_it->get();
+ if (concat_op->type != OperatorType::kConcatenation) {
+ return false;
+ }
+ std::vector<string> trivial_inputs;
+ std::vector<string> nontrivial_inputs;
+ for (const string& input : concat_op->inputs) {
+ const auto& input_array = model->GetArray(input);
+ const bool is_trivial =
+ input_array.has_shape() && input_array.shape().dimensions_count() == 0;
+ if (is_trivial) {
+ trivial_inputs.push_back(input);
+ } else {
+ nontrivial_inputs.push_back(input);
+ }
+ }
+
+ if (trivial_inputs.empty()) {
+ return false;
+ }
+
+ // Drop trivial inputs.
+ for (const string& input : trivial_inputs) {
+ if (CountOpsWithInput(*model, input) == 1) {
+ model->arrays.erase(input);
+ }
+ }
+ concat_op->inputs = nontrivial_inputs;
+ return true;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.cc
new file mode 100644
index 0000000000..a0d1338298
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.cc
@@ -0,0 +1,107 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/model_flags.pb.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+namespace {
+// Reroute all edges involving a given discardable array to another
+// array instead. from_array is assumed to be discardable, and consequently
+// this only updates operator edges (since discardable arrays only
+// appear there, and not e.g. in model flags).
+void RerouteEdges(const string& from_array, const string& to_array,
+ Model* model) {
+ for (const auto& op : model->operators) {
+ for (auto& output : op->outputs) {
+ if (output == from_array) {
+ output = to_array;
+ }
+ }
+ for (auto& input : op->inputs) {
+ if (input == from_array) {
+ input = to_array;
+ }
+ }
+ }
+}
+
+} // end anonymous namespace
+
+bool RemoveTrivialPassthroughOp(GraphTransformation* transformation,
+ Model* model, std::size_t op_index) {
+ const auto passthru_it = model->operators.begin() + op_index;
+ auto* passthru_op = passthru_it->get();
+ CHECK_EQ(passthru_op->outputs.size(), 1);
+ CHECK_GE(passthru_op->inputs.size(), 1);
+ int count_nonconstant_input_arrays = 0;
+ // We call 'main input' the unique nonconstant input array if there is one,
+ // or else the 0-th input.
+ int main_input_array_index = 0;
+ for (int i = 0; i < passthru_op->inputs.size(); i++) {
+ if (!model->GetArray(passthru_op->inputs[i]).buffer) {
+ count_nonconstant_input_arrays++;
+ main_input_array_index = i;
+ }
+ }
+ CHECK_LE(count_nonconstant_input_arrays, 1);
+
+ const string main_input_name = passthru_op->inputs[main_input_array_index];
+ const string output_name = passthru_op->outputs[0];
+ if (IsDiscardableArray(*model, output_name)) {
+ transformation->AddMessageF(
+ "Removing %s, keeping its non-constant input array",
+ LogName(*passthru_op));
+ model->arrays.erase(output_name);
+ for (const string& input : passthru_op->inputs) {
+ if (IsDiscardableArray(*model, input) && input != main_input_name &&
+ CountOpsWithInput(*model, input) == 1) {
+ model->arrays.erase(input);
+ }
+ }
+ RerouteEdges(output_name, main_input_name, model);
+ } else if (IsDiscardableArray(*model, main_input_name)) {
+ transformation->AddMessageF("Removing %s, keeping its output array",
+ LogName(*passthru_op));
+ for (const string& input : passthru_op->inputs) {
+ if (IsDiscardableArray(*model, input) &&
+ (input == main_input_name || CountOpsWithInput(*model, input) == 1)) {
+ model->arrays.erase(input);
+ }
+ }
+ RerouteEdges(main_input_name, output_name, model);
+ } else {
+ transformation->AddMessageF(
+ "Cannot remove %s: neither its nonconstant input nor its output may be "
+ "discarded",
+ LogName(*passthru_op));
+ return false;
+ }
+
+ // Remove the pass-through node.
+ model->operators.erase(passthru_it);
+
+ return true;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h
new file mode 100644
index 0000000000..b72c85c0e5
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h
@@ -0,0 +1,55 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_GRAPH_TRANSFORMATIONS_REMOVE_TRIVIAL_PASSTHROUGH_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_GRAPH_TRANSFORMATIONS_REMOVE_TRIVIAL_PASSTHROUGH_H_
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+
+namespace toco {
+
+// A "passthrough op" is an op that satisfies the following conditions:
+// 1. It has at most one non-constant input (it may have other constant
+// inputs).
+// 2. It has exactly one output.
+// 3. It forwards exactly its single non-constant input to its single output.
+//
+// Examples include:
+// 1. TensorFlow Identity ops. (Have one input).
+// 2. TensorFlow Reshape ops when the input and output shapes agree.
+// 3. Any binary operator, one of whose two inputs is a constant and is the
+// neutral value for that operation. For example, a binary Add operator
+// where one of its inputs is a constant array filled with zeros.
+//
+// A passthrough op is "trivial" and can be removed when it is possible to
+// discard either its single non-constant input array or its output array,
+// rerouting any edge involving the discarded array to the other of the two.
+//
+// It is only possible to discard such an array if it is not explicitly
+// designated as a global input/output array of the graph, e.g. the model's
+// input arrays, output arrays, and any array involved in an RNN back-edge
+// specified by the model.
+//
+// This function does not check that the given operator is a passthrough op:
+// that's the responsibility of the caller.
+// Given a passthrough op, this function checks whether it is trivial; if so,
+// it discards the op and returns true. If it is not trivial (neither the
+// input nor the output may be discarded), it returns false.
+bool RemoveTrivialPassthroughOp(GraphTransformation* transformation,
+ Model* model, std::size_t op_index);
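+//
+// A minimal usage sketch (mirroring how e.g. RemoveTrivialConcatenation::Run
+// uses this helper elsewhere in this directory; the names SomeTransformation
+// and OpIsPassthrough below are purely illustrative):
+//
+//   bool SomeTransformation::Run(Model* model, std::size_t op_index) {
+//     const auto it = model->operators.begin() + op_index;
+//     if (!OpIsPassthrough(it->get())) return false;
+//     return RemoveTrivialPassthroughOp(this, model, op_index);
+//   }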
+
+} // namespace toco
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_GRAPH_TRANSFORMATIONS_REMOVE_TRIVIAL_PASSTHROUGH_H_
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_quantized_activation_func.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_quantized_activation_func.cc
new file mode 100644
index 0000000000..28f76c9d36
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_quantized_activation_func.cc
@@ -0,0 +1,87 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/runtime/types.h"
+#include "tensorflow/contrib/lite/toco/toco_types.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+bool RemoveTrivialQuantizedActivationFunc::Run(Model* model,
+ std::size_t op_index) {
+ const auto it = model->operators.begin() + op_index;
+ auto* op = it->get();
+ if (op->fused_activation_function != FusedActivationFunctionType::kRelu &&
+ op->fused_activation_function != FusedActivationFunctionType::kRelu6) {
+ return false;
+ }
+ const auto& output_array = model->GetArray(op->outputs[0]);
+ if (!output_array.quantization_params) {
+ return false;
+ }
+ if (output_array.data_type != ArrayDataType::kUint8) {
+ return false;
+ }
+ const auto& quantization_params = output_array.GetQuantizationParams();
+
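+ // The uint8 quantized range covers real values from
+ // (0 - zero_point) * scale up to (255 - zero_point) * scale. If that range
+ // already lies within the activation's clamp interval ([0, +inf) for Relu,
+ // [0, 6] for Relu6), the fused activation cannot change any representable
+ // value, so it is trivial and can be dropped.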
+ bool has_nontrivial_min_bound = false;
+ bool has_nontrivial_max_bound = false;
+
+ if (op->fused_activation_function == FusedActivationFunctionType::kRelu ||
+ op->fused_activation_function == FusedActivationFunctionType::kRelu6) {
+ double lowest_representable_output =
+ (0. - quantization_params.zero_point) * quantization_params.scale;
+ if (lowest_representable_output < 0.) {
+ has_nontrivial_min_bound = true;
+ AddMessageF(
+ "Quantized activation function is not trivial: "
+ "the lowest representable output value %g"
+ " is less than the clamp min bound.",
+ lowest_representable_output);
+ }
+ }
+ if (op->fused_activation_function == FusedActivationFunctionType::kRelu6) {
+ double highest_representable_output =
+ (255. - quantization_params.zero_point) * quantization_params.scale;
+ if (highest_representable_output > 6.) {
+ has_nontrivial_max_bound = true;
+ AddMessageF(
+ "Quantized activation function is not trivial: "
+ "the highest representable output value %g"
+ " is greater than the clamp max bound.",
+ highest_representable_output);
+ }
+ }
+
+ if (has_nontrivial_min_bound || has_nontrivial_max_bound) {
+ return false;
+ }
+
+ op->fused_activation_function = FusedActivationFunctionType::kNone;
+ AddMessageF(
+ "Removing trivial quantized activation function on %s"
+ " because the output quantization parameters imply at least as tight"
+ " a clamp anyway.",
+ LogName(*op));
+ return true;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_reshape.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_reshape.cc
new file mode 100644
index 0000000000..90f9381ec1
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_reshape.cc
@@ -0,0 +1,92 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <iterator>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+namespace {
+
+bool IsReshapeTrivial(const Model& model, const Operator& op,
+ RemoveTrivialReshape* transformation) {
+ CHECK(op.type == OperatorType::kTensorFlowReshape);
+
+ // One way in which a reshape can be trivial is if its
+ // output shape equals its input shape.
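+ // (For instance, a Reshape from [3, 4] to [3, 4]; with
+ // treat_expand_dims_as_trivial, presumably also [3, 4] to [1, 3, 4],
+ // assuming ShapesAgreeUpToExtending treats shapes that differ only by
+ // extra size-1 dimensions as agreeing.)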
+ const auto& input_array = model.GetArray(op.inputs[0]);
+ const auto& output_array = model.GetArray(op.outputs[0]);
+ if (input_array.has_shape() && output_array.has_shape()) {
+ if (transformation->treat_expand_dims_as_trivial() &&
+ ShapesAgreeUpToExtending(input_array.shape(), output_array.shape())) {
+ transformation->AddMessageF(
+ "%s is trivial because its input and output shapes are equal up to "
+ "extending "
+ "by 1's, and we are told to aggressively discard such Reshape ops.",
+ LogName(op));
+ return true;
+ }
+ if (input_array.shape().dims() == output_array.shape().dims()) {
+ transformation->AddMessageF(
+ "%s is trivial because its input and output shapes are equal",
+ LogName(op));
+ return true;
+ }
+ }
+
+ // Another way in which a reshape can be trivial is if its output
+ // is only consumed by another reshape.
+ if (CountOpsWithInput(model, op.outputs[0]) == 1) {
+ const auto* next_op = GetOpWithInput(model, op.outputs[0]);
+ if (next_op->type == OperatorType::kTensorFlowReshape) {
+ transformation->AddMessageF(
+ "%s is trivial because its output is only consumed by another "
+ "Reshape op",
+ LogName(op));
+ return true;
+ }
+ }
+
+ return false;
+}
+
+} // namespace
+
+bool RemoveTrivialReshape::Run(Model* model, std::size_t op_index) {
+ const auto reshape_it = model->operators.begin() + op_index;
+ auto* reshape_op = reshape_it->get();
+ if (reshape_op->type != OperatorType::kTensorFlowReshape) {
+ return false;
+ }
+
+ if (!IsReshapeTrivial(*model, *reshape_op, this)) {
+ return false;
+ }
+
+ AddMessageF("Removing trivial %s", LogName(*reshape_op));
+
+ CHECK_EQ(reshape_op->inputs.size(), 2);
+ return RemoveTrivialPassthroughOp(this, model, op_index);
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_unused_op.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_unused_op.cc
new file mode 100644
index 0000000000..1f1f1f6948
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_unused_op.cc
@@ -0,0 +1,122 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/model_flags.pb.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+bool RemoveUnusedOp::Run(Model* model, std::size_t op_index) {
+ const auto it = model->operators.begin() + op_index;
+ const auto* op = it->get();
+
+ // Bail if any output is used, and is not an input_array of
+ // the model. We allow specifying an arbitrary input_array,
+ // treating the part of the graph leading up to it as unused.
+ for (const auto& output : op->outputs) {
+ CHECK(model->arrays.count(output));
+ // If this output is provided as the model's input array,
+ // then we don't need this operator to produce its contents.
+ if (IsInputArray(*model, output)) {
+ continue;
+ }
+ // If this output is provided as an RNN's state array,
+ // then we don't need this operator to produce its contents.
+ // So far this case has only been encountered with TensorFlow
+ // Fill ops used to zero-initialize RNN states, which is
+ // redundant for us as we zero-initialize RNN states anyway.
+ bool found_output_as_rnn_state_array = false;
+ for (const auto& rnn_state : model->flags.rnn_states()) {
+ if (output == rnn_state.state_array()) {
+ CHECK(op->type == OperatorType::kTensorFlowUnsupported);
+ CHECK_EQ(static_cast<const TensorFlowUnsupportedOperator*>(op)
+ ->tensorflow_op,
+ "Fill");
+ found_output_as_rnn_state_array = true;
+ break;
+ }
+ }
+ if (found_output_as_rnn_state_array) {
+ continue;
+ }
+ for (const string& output_array : model->flags.output_arrays()) {
+ if (output == output_array) {
+ return false;
+ }
+ }
+ for (const auto& rnn_state : model->flags.rnn_states()) {
+ if (output == rnn_state.back_edge_source_array()) {
+ return false;
+ }
+ }
+ if (CountOpsWithInput(*model, output)) {
+ return false;
+ }
+ }
+
+ if (op->unresolved_outputs) {
+ AddMessageF("Not discarding %s because it has unresolved outputs.",
+ LogName(*op));
+ return false;
+ }
+
+ AddMessageF("Discarding %s because none of its outputs is used.",
+ LogName(*op));
+
+ // At that point we know that none of the outputs is used, so we will
+ // definitely remove the node and all its outputs.
+
+ // Remove any input array that is not used by anything else,
+ // and that is not the output of some other operator.
+ for (const auto& input : op->inputs) {
+ if (CountOpsWithInput(*model, input) == 1 &&
+ !GetOpWithOutput(*model, input)) {
+ model->arrays.erase(input);
+ }
+ }
+
+ // Remove the node and its now-unused output arrays.
+ for (const auto& output : op->outputs) {
+ // If the output array is the model's input array, don't remove that.
+ // That's the case when cropping a model at a given --input_array.
+ if (IsInputArray(*model, output)) {
+ continue;
+ }
+ // Likewise, if the output array is an RNN state array, don't remove that.
+ bool found_output_as_rnn_state_array = false;
+ for (const auto& rnn_state : model->flags.rnn_states()) {
+ if (output == rnn_state.state_array()) {
+ found_output_as_rnn_state_array = true;
+ break;
+ }
+ }
+ if (found_output_as_rnn_state_array) {
+ continue;
+ }
+ // Generic case: do delete this output array.
+ model->arrays.erase(output);
+ }
+ model->operators.erase(it);
+ return true;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_normalization.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_normalization.cc
new file mode 100644
index 0000000000..3eb7fa3896
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_normalization.cc
@@ -0,0 +1,135 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/runtime/types.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+bool ResolveBatchNormalization::Run(Model* model, std::size_t op_index) {
+ auto bn_it = model->operators.begin() + op_index;
+ if (bn_it->get()->type != OperatorType::kBatchNormalization) {
+ return false;
+ }
+ const auto* bn_op =
+ static_cast<const BatchNormalizationOperator*>(bn_it->get());
+
+ const auto& mean_array = model->GetArray(bn_op->inputs[1]);
+ const auto& multiplier_array = model->GetArray(bn_op->inputs[2]);
+ const auto& offset_array = model->GetArray(bn_op->inputs[3]);
+
+ CHECK(IsConstantParameterArray(*model, bn_op->inputs[1]) &&
+ IsConstantParameterArray(*model, bn_op->inputs[2]) &&
+ IsConstantParameterArray(*model, bn_op->inputs[3]))
+ << "Batch normalization resolution requires that mean, multiplier and "
+ "offset arrays be constant.";
+
+ // We should only have *float* BatchNormalizations... let's guard this
+ // assumption by CHECK's.
+ CHECK(mean_array.data_type == ArrayDataType::kFloat);
+ CHECK(multiplier_array.data_type == ArrayDataType::kFloat);
+ CHECK(offset_array.data_type == ArrayDataType::kFloat);
+
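+ // A BatchNormalization computes output = (input - mean) * multiplier +
+ // offset, which equals input * multiplier + (offset - mean * multiplier).
+ // It can therefore be split into a Mul by `multiplier` followed by an Add
+ // of `offset - mean * multiplier`; those are exactly the constant parameter
+ // arrays computed below.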
+ // Create the new Mul, Add operators
+ auto* mul_op = new MulOperator;
+ auto* add_op = new AddOperator;
+ const string mul_name =
+ AvailableArrayName(*model, bn_op->outputs[0] + "_mul");
+ const string add_name =
+ AvailableArrayName(*model, bn_op->outputs[0] + "_add");
+ const string mul_param_name = AvailableArrayName(*model, mul_name + "_param");
+ const string add_param_name = AvailableArrayName(*model, add_name + "_param");
+ mul_op->inputs = {bn_op->inputs[0], mul_param_name};
+ mul_op->outputs = {mul_name};
+ add_op->inputs = {mul_name, add_param_name};
+ add_op->outputs = {bn_op->outputs[0]};
+ AddMessageF("Splitting %s into %s and %s", LogName(*bn_op), LogName(*mul_op),
+ LogName(*add_op));
+
+ // Create the intermediate activation array (output of mul, input of add)
+ auto& intermediate_array = model->GetOrCreateArray(mul_op->outputs[0]);
+ intermediate_array.data_type = model->GetArray(bn_op->inputs[0]).data_type;
+
+ // Insert the new operators in the graph
+ auto add_it = model->operators.emplace(bn_it, add_op);
+ auto mul_it = model->operators.emplace(add_it, mul_op);
+ // update invalidated iterators.
+ DCHECK_EQ(mul_it->get(), mul_op);
+ add_it = mul_it + 1;
+ DCHECK_EQ(add_it->get(), add_op);
+ bn_it = add_it + 1;
+ DCHECK_EQ(bn_it->get(), bn_op);
+
+ // Create the new param arrays
+ const auto& mean_shape = mean_array.shape();
+ const auto& multiplier_shape = multiplier_array.shape();
+ const auto& offset_shape = offset_array.shape();
+ CHECK(mean_shape.dims() == multiplier_shape.dims());
+ CHECK(mean_shape.dims() == offset_shape.dims());
+ const auto& param_shape = mean_shape;
+ const int buffer_size = RequiredBufferSizeForShape(param_shape);
+ auto& mul_param_array = model->GetOrCreateArray(mul_param_name);
+ auto& add_param_array = model->GetOrCreateArray(add_param_name);
+ DropMinMax(model, mul_param_name);
+ DropMinMax(model, add_param_name);
+ mul_param_array.copy_shape(param_shape);
+ add_param_array.copy_shape(param_shape);
+ mul_param_array.data_type = ArrayDataType::kFloat;
+ add_param_array.data_type = ArrayDataType::kFloat;
+ auto& mul_float_data =
+ mul_param_array.GetMutableBuffer<ArrayDataType::kFloat>().data;
+ auto& add_float_data =
+ add_param_array.GetMutableBuffer<ArrayDataType::kFloat>().data;
+ mul_float_data.resize(buffer_size);
+ add_float_data.resize(buffer_size);
+ const auto& mean_float_data =
+ mean_array.GetBuffer<ArrayDataType::kFloat>().data;
+ const auto& multiplier_float_data =
+ multiplier_array.GetBuffer<ArrayDataType::kFloat>().data;
+ const auto& offset_float_data =
+ offset_array.GetBuffer<ArrayDataType::kFloat>().data;
+
+ CHECK(mul_float_data.size() == buffer_size);
+ CHECK(add_float_data.size() == buffer_size);
+ CHECK(mean_float_data.size() == buffer_size);
+ CHECK(multiplier_float_data.size() == buffer_size);
+ CHECK(offset_float_data.size() == buffer_size);
+
+ for (int i = 0; i < buffer_size; i++) {
+ mul_float_data[i] = multiplier_float_data[i];
+ add_float_data[i] =
+ offset_float_data[i] - mean_float_data[i] * multiplier_float_data[i];
+ }
+
+ // Remove the old param arrays
+ model->arrays.erase(bn_op->inputs[1]);
+ model->arrays.erase(bn_op->inputs[2]);
+ model->arrays.erase(bn_op->inputs[3]);
+
+ // Remove the old operator
+ DCHECK_EQ(bn_it->get(), bn_op);
+ model->operators.erase(bn_it);
+
+ return true;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc
new file mode 100644
index 0000000000..53e1be7a05
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc
@@ -0,0 +1,247 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <algorithm>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/runtime/types.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+namespace {
+
+std::vector<bool> VectorGreaterThan(const std::vector<int>& a,
+ const std::vector<int>& b) {
+ DCHECK_EQ(a.size(), b.size());
+ const int size = a.size();
+ std::vector<bool> result(size);
+ for (int i = 0; i < size; i++) {
+ result[i] = a[i] > b[i];
+ }
+ return result;
+}
+
+void PairwiseVectorSelect(const std::vector<bool>& selector,
+ const std::vector<int>& input_a,
+ const std::vector<int>& input_b,
+ std::vector<int>* output_a,
+ std::vector<int>* output_b) {
+ DCHECK_EQ(input_a.size(), input_b.size());
+ DCHECK_EQ(output_a->size(), output_b->size());
+ DCHECK_EQ(input_a.size(), output_a->size());
+ DCHECK_EQ(selector.size(), input_a.size());
+ const int size = input_a.size();
+ for (int i = 0; i < size; i++) {
+ if (selector[i]) {
+ (*output_a)[i] = input_a[i];
+ (*output_b)[i] = input_b[i];
+ } else {
+ (*output_a)[i] = input_b[i];
+ (*output_b)[i] = input_a[i];
+ }
+ }
+}
+
+template <ArrayDataType InputsDataType, ArrayDataType OutputDataType>
+void EvaluateBinaryOperatorOnConstantInputs(Model* model,
+ const Operator* binary_op) {
+ CHECK(IsConstantParameterArray(*model, binary_op->inputs[0]));
+ CHECK(IsConstantParameterArray(*model, binary_op->inputs[1]));
+ CHECK(binary_op->fused_activation_function ==
+ FusedActivationFunctionType::kNone);
+ const auto& input0_array = model->GetArray(binary_op->inputs[0]);
+ const auto& input1_array = model->GetArray(binary_op->inputs[1]);
+ const auto& output_name = binary_op->outputs[0];
+ auto& output_array = model->GetArray(output_name);
+ CHECK(input0_array.data_type == InputsDataType);
+ CHECK(input1_array.data_type == InputsDataType);
+ CHECK(output_array.data_type == OutputDataType);
+
+ // We have already tested above for existence of input buffers
+ // (synonymous with being a constant param).
+ CHECK(input0_array.buffer);
+ CHECK(input1_array.buffer);
+ // On the other hand, the output should not already have a buffer.
+ CHECK(!output_array.buffer);
+
+ const auto& input0_data = input0_array.GetBuffer<InputsDataType>().data;
+ const auto& input1_data = input1_array.GetBuffer<InputsDataType>().data;
+ // Create the buffer on the output array, effectively turning it into
+ // a constant parameter
+
+ const Shape& output_shape = output_array.shape();
+ auto& output_data = output_array.GetMutableBuffer<OutputDataType>().data;
+ const int output_buffer_size = RequiredBufferSizeForShape(output_shape);
+ output_data.resize(output_buffer_size);
+ const int dims_count = output_shape.dimensions_count();
+
+ // It will be convenient here to have copies of the operands shapes
+ // extended to match the number of dimensions of the output shape.
+ Shape input0_shape = input0_array.shape();
+ Shape input1_shape = input1_array.shape();
+ ExtendShape(&input0_shape, dims_count);
+ ExtendShape(&input1_shape, dims_count);
+ // Now we may still have operands of different sizes, which would indicate
+ // that we have to "broadcast" the smaller dimension. We do this using a
+ // vector of Booleans indicating which input is the larger in each
+ // dimension.
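+ // Illustrative (hypothetical) example: with input0_shape = [2, 3, 1] and
+ // input1_shape = [2, 1, 4], the output shape is [2, 3, 4]; for the output
+ // index (1, 2, 3), the smaller dimensions are indexed modulo their size, so
+ // input0 is read at (1, 2, 0) and input1 at (1, 0, 3).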
+ CHECK_EQ(input0_shape.dimensions_count(), input1_shape.dimensions_count());
+ CHECK_EQ(input0_shape.dimensions_count(), dims_count);
+ const std::vector<bool> input0_larger =
+ VectorGreaterThan(input0_shape.dims(), input1_shape.dims());
+
+ std::vector<int> big_sizes(dims_count);
+ std::vector<int> small_sizes(dims_count);
+ PairwiseVectorSelect(input0_larger, input0_shape.dims(), input1_shape.dims(),
+ &big_sizes, &small_sizes);
+
+ // The output should already be correctly sized to match the big dimensions.
+ for (int i = 0; i < dims_count; i++) {
+ CHECK_EQ(output_shape.dims(i), big_sizes[i]);
+ }
+
+ std::vector<int> input0_indices(dims_count);
+ std::vector<int> input1_indices(dims_count);
+ std::vector<int> modulo_indices(dims_count);
+
+ for (int k = 0; k < output_buffer_size; k++) {
+ const std::vector<int> output_indices = ReverseOffset(output_shape, k);
+ for (int i = 0; i < dims_count; i++) {
+ modulo_indices[i] = output_indices[i] % small_sizes[i];
+ }
+ PairwiseVectorSelect(input0_larger, output_indices, modulo_indices,
+ &input0_indices, &input1_indices);
+ const auto val0 = input0_data[Offset(input0_shape, input0_indices)];
+ const auto val1 = input1_data[Offset(input1_shape, input1_indices)];
+
+ DataType<OutputDataType> outval;
+ if (binary_op->type == OperatorType::kAdd) {
+ outval = val0 + val1;
+ } else if (binary_op->type == OperatorType::kMul) {
+ outval = val0 * val1;
+ } else if (binary_op->type == OperatorType::kSub) {
+ outval = val0 - val1;
+ } else if (binary_op->type == OperatorType::kDiv) {
+ outval = val0 / val1;
+ } else if (binary_op->type == OperatorType::kTensorFlowMinimum) {
+ outval = std::min(val0, val1);
+ } else if (binary_op->type == OperatorType::kTensorFlowMaximum) {
+ outval = std::max(val0, val1);
+ } else if (binary_op->type == OperatorType::kTensorFlowLess) {
+ outval = val0 < val1;
+ } else if (binary_op->type == OperatorType::kTensorFlowLessEqual) {
+ outval = val0 <= val1;
+ } else if (binary_op->type == OperatorType::kTensorFlowGreater) {
+ outval = val0 > val1;
+ } else if (binary_op->type == OperatorType::kTensorFlowGreaterEqual) {
+ outval = val0 >= val1;
+ } else {
+ LOG(FATAL) << "should not get here";
+ }
+ output_data[Offset(output_shape, output_indices)] = outval;
+ }
+}
+
+void EvaluateBinaryOperatorOnConstantInputs(Model* model,
+ const Operator* binary_op) {
+ const auto inputs_data_type = model->arrays[binary_op->inputs[0]]->data_type;
+ const auto output_data_type = model->arrays[binary_op->outputs[0]]->data_type;
+#define TOCO_HANDLE_CASE(InputsDataType, OutputDataType) \
+ if (inputs_data_type == InputsDataType && \
+ output_data_type == OutputDataType) { \
+ EvaluateBinaryOperatorOnConstantInputs<InputsDataType, OutputDataType>( \
+ model, binary_op); \
+ return; \
+ }
+ TOCO_HANDLE_CASE(ArrayDataType::kFloat, ArrayDataType::kFloat)
+ TOCO_HANDLE_CASE(ArrayDataType::kFloat, ArrayDataType::kBool)
+ TOCO_HANDLE_CASE(ArrayDataType::kInt32, ArrayDataType::kInt32)
+ TOCO_HANDLE_CASE(ArrayDataType::kInt32, ArrayDataType::kBool)
+ TOCO_HANDLE_CASE(ArrayDataType::kInt64, ArrayDataType::kInt64)
+ TOCO_HANDLE_CASE(ArrayDataType::kInt64, ArrayDataType::kBool)
+ LOG(FATAL) << "Unimplemented: don't know how to resolve a constant "
+ << "binary operator for these data types.";
+#undef TOCO_HANDLE_CASE
+}
+} // namespace
+
+bool ResolveConstantBinaryOperator::Run(Model* model, std::size_t op_index) {
+ const auto binary_it = model->operators.begin() + op_index;
+ const auto* binary_op = binary_it->get();
+ // Test for binary ops of types that we know how to resolve
+ if (binary_op->type != OperatorType::kAdd &&
+ binary_op->type != OperatorType::kMul &&
+ binary_op->type != OperatorType::kSub &&
+ binary_op->type != OperatorType::kDiv &&
+ binary_op->type != OperatorType::kTensorFlowMinimum &&
+ binary_op->type != OperatorType::kTensorFlowMaximum &&
+ binary_op->type != OperatorType::kTensorFlowLess &&
+ binary_op->type != OperatorType::kTensorFlowLessEqual &&
+ binary_op->type != OperatorType::kTensorFlowGreater &&
+ binary_op->type != OperatorType::kTensorFlowGreaterEqual) {
+ return false;
+ }
+ CHECK_EQ(binary_op->inputs.size(), 2);
+
+ const auto& input0_array = model->GetArray(binary_op->inputs[0]);
+ const auto& input1_array = model->GetArray(binary_op->inputs[1]);
+ // Check if both inputs are constant parameters.
+ if (!input0_array.buffer || !input1_array.buffer) {
+ return false;
+ }
+
+ auto& output_array = *model->arrays[binary_op->outputs[0]];
+ // Yield until the output array dims have been resolved.
+ if (!output_array.has_shape()) {
+ return false;
+ }
+
+ // At the moment we don't want to care about fused activation functions.
+ // The idea is that we should do the present constants-propagation before
+ // activation functions get fused.
+ if (binary_op->fused_activation_function !=
+ FusedActivationFunctionType::kNone) {
+ AddMessageF(
+ "Not resolving constant %s because it has a fused activation function",
+ LogName(*binary_op));
+ return false;
+ }
+
+ // Check that input data types agree.
+ CHECK(input0_array.data_type == input1_array.data_type);
+
+ // Do the actual constants propagation
+ EvaluateBinaryOperatorOnConstantInputs(model, binary_op);
+
+ // Remove the binary operator and its inputs
+ if (CountOpsWithInput(*model, binary_op->inputs[0]) == 1) {
+ model->arrays.erase(binary_op->inputs[0]);
+ }
+ if (CountOpsWithInput(*model, binary_op->inputs[1]) == 1) {
+ model->arrays.erase(binary_op->inputs[1]);
+ }
+ AddMessageF("Resolved constant %s to the equivalent constant array",
+ LogName(*binary_op));
+ model->operators.erase(binary_it);
+ return true;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_concatenation.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_concatenation.cc
new file mode 100644
index 0000000000..0983c43849
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_concatenation.cc
@@ -0,0 +1,196 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "absl/strings/str_cat.h"
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+namespace {
+
+// Copies data from multiple source arrays to a destination array based on a
+// concatenation dimension. From each array in input_arrays, it copies chunks
+// of the size given in the array_copy_size vector (one entry per array). It
+// uses the buffer in concatenated_array as the destination buffer.
+template <ArrayDataType A, typename T>
+void CopyTensorSegments(const std::vector<Array*>& input_arrays,
+ const std::vector<int>& array_copy_size,
+ const int num_elements_concatenated_array,
+ Array* concatenated_array) {
+ for (Array* input_array : input_arrays) {
+ if (!input_array->buffer) {
+ return;
+ }
+ }
+
+ auto& concatenated_array_buffer =
+ concatenated_array->GetMutableBuffer<A>().data;
+ concatenated_array_buffer.resize(num_elements_concatenated_array);
+
+ // It does not matter which array to use to find the value for the total
+ // number of copy steps.
+ CHECK(!input_arrays.empty());
+ CHECK_NE(array_copy_size[0], 0);
+ const int total_copy_steps =
+ input_arrays[0]->GetBuffer<A>().data.size() / array_copy_size[0];
+
+ // Initialize the source pointers to point to beginning of the array buffers.
+ std::vector<const T*> src_ptr;
+ src_ptr.reserve(input_arrays.size());
+ for (Array* input_array : input_arrays) {
+ src_ptr.push_back(input_array->GetBuffer<A>().data.data());
+ }
+
+ // Copy the data from input_arrays to concatenated_array_buffer.
+ T* dest_ptr = concatenated_array_buffer.data();
+ for (int s = 0; s < total_copy_steps; s++) {
+ for (int i = 0; i < input_arrays.size(); i++) {
+ std::copy(src_ptr[i], src_ptr[i] + array_copy_size[i], dest_ptr);
+ src_ptr[i] += array_copy_size[i];
+ dest_ptr += array_copy_size[i];
+ }
+ }
+}
+
+// Receives a series of input arrays of type Array and an integer giving the
+// axis on which those arrays will be concatenated. It writes the concatenated
+// result into concatenated_array.
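+//
+// Illustrative (hypothetical) example: concatenating shapes [2, 3] and [2, 5]
+// along axis 1 yields a [2, 8] result; array_copy_size becomes {3, 5} and the
+// copy runs in 2 steps, each copying 3 elements from the first array followed
+// by 5 from the second.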
+template <ArrayDataType A>
+void ConcatenateTensorBuffers(const std::vector<Array*>& input_arrays,
+ int concatenation_axis,
+ Array* concatenated_array) {
+ int num_elements_concatenated_array = 1;
+ for (int i = 0; i < concatenated_array->shape().dimensions_count(); i++) {
+ num_elements_concatenated_array *= concatenated_array->shape().dims()[i];
+ }
+ // Prepare the data needed for segmented copy from multiple source arrays to
+ // a destination array based on the concatenation dimension.
+ std::vector<int> array_copy_size(input_arrays.size());
+ int count = 0;
+ for (Array* input_array : input_arrays) {
+ const Shape array_shape = input_array->shape();
+ array_copy_size[count] = 1;
+ for (int i = concatenation_axis; i < array_shape.dimensions_count(); i++) {
+ array_copy_size[count] *= array_shape.dims()[i];
+ }
+ count++;
+ }
+
+ // Do the actual data copy.
+ CopyTensorSegments<A, DataType<A>>(input_arrays, array_copy_size,
+ num_elements_concatenated_array,
+ concatenated_array);
+}
+
+// Sets the minimum and maximum values for the concatenated array. If it's
+// already set (e.g. by a previous TOCO pass), it is left unchanged. Otherwise
+// the input arrays' min and max values are used to compute the concatenated
+// array's min and max.
+void SetMinMaxForConcatenedArray(const std::vector<Array*>& input_arrays,
+ Array* concatenated_array) {
+ CHECK(concatenated_array->data_type == ArrayDataType::kFloat);
+ // If the minmax is already set, use it
+ if (concatenated_array->minmax) return;
+
+ double concat_min = std::numeric_limits<double>::infinity();
+ double concat_max = -std::numeric_limits<double>::infinity();
+
+ for (Array* input_array : input_arrays) {
+ // If any of the input arrays' minmax is not set, return.
+ // TODO(ghodrat): shall we add the logic to compute the minmax?
+ if (!input_array->minmax) return;
+ const MinMax& input_minmax = input_array->GetMinMax();
+ concat_min = std::min(concat_min, input_minmax.min);
+ concat_max = std::max(concat_max, input_minmax.max);
+ }
+ MinMax& minmax = concatenated_array->GetOrCreateMinMax();
+ minmax.min = concat_min;
+ minmax.max = concat_max;
+}
+
+} // namespace
+
+// Resolves the concatenation operator if all its inputs are constant arrays.
+bool ResolveConstantConcatenation::Run(Model* model, std::size_t op_index) {
+ const auto concat_it = model->operators.begin() + op_index;
+ const auto* concat_base_op = concat_it->get();
+ if (concat_base_op->type != OperatorType::kConcatenation) {
+ return false;
+ }
+ const auto* concat_op =
+ static_cast<const ConcatenationOperator*>(concat_base_op);
+
+ for (const string& input_name : concat_op->inputs) {
+ // We only expect constant unquantized arrays as input; otherwise we return.
+ // We also make sure the shapes of the input arrays are known and they are
+ // all discardable.
+ const Operator* input_op = GetOpWithOutput(*model, input_name);
+ if (input_op) return false;
+ if (!IsConstantParameterArray(*model, input_name)) return false;
+ if (!model->GetArray(input_name).has_shape()) return false;
+ if (model->GetArray(input_name).quantization_params) return false;
+ if (!IsDiscardableArray(*model, input_name)) return false;
+ }
+
+ const int concatenation_axis = concat_op->concat_dim;
+
+ CHECK_EQ(concat_op->outputs.size(), 1);
+ string concatenated_array_name = concat_op->outputs[0];
+ Array& concatenated_array = model->GetOrCreateArray(concatenated_array_name);
+ std::vector<Array*> input_arrays;
+ for (const string& input_name : concat_op->inputs) {
+ input_arrays.push_back(&model->GetArray(input_name));
+ }
+
+ switch (concatenated_array.data_type) {
+ case ArrayDataType::kFloat:
+ ConcatenateTensorBuffers<ArrayDataType::kFloat>(
+ input_arrays, concatenation_axis, &concatenated_array);
+ SetMinMaxForConcatenedArray(input_arrays, &concatenated_array);
+ break;
+ case ArrayDataType::kUint8:
+ ConcatenateTensorBuffers<ArrayDataType::kUint8>(
+ input_arrays, concatenation_axis, &concatenated_array);
+ break;
+ case ArrayDataType::kInt32:
+ ConcatenateTensorBuffers<ArrayDataType::kInt32>(
+ input_arrays, concatenation_axis, &concatenated_array);
+ break;
+ case ArrayDataType::kInt64:
+ ConcatenateTensorBuffers<ArrayDataType::kInt64>(
+ input_arrays, concatenation_axis, &concatenated_array);
+ break;
+ default:
+ LOG(FATAL) << "ArrayDataType not supported";
+ }
+
+ // Remove all the resolved arrays.
+ for (const string& input_name : concat_op->inputs) {
+ model->arrays.erase(input_name);
+ }
+
+ // Remove concatenate operator
+ model->operators.erase(concat_it);
+ return true;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc
new file mode 100644
index 0000000000..244adcc4c4
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc
@@ -0,0 +1,76 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+bool ResolveConstantFakeQuant::Run(Model* model, std::size_t op_index) {
+ const auto fakequant_it = model->operators.begin() + op_index;
+ const auto* fakequant_base_op = fakequant_it->get();
+ if (fakequant_base_op->type != OperatorType::kFakeQuant) {
+ return false;
+ }
+
+ const auto* fakequant_op =
+ static_cast<const FakeQuantOperator*>(fakequant_base_op);
+
+ // Yield until the fakequant MinMax has been resolved.
+ if (!fakequant_op->minmax) {
+ return false;
+ }
+
+ // This transformation only applies when the input array is constant.
+ if (!IsConstantParameterArray(*model, fakequant_op->inputs[0])) {
+ return false;
+ }
+
+ const auto& input_array = model->GetArray(fakequant_op->inputs[0]);
+ auto& output_array = model->GetArray(fakequant_op->outputs[0]);
+ CHECK(input_array.data_type == ArrayDataType::kFloat);
+ output_array.data_type = ArrayDataType::kFloat;
+ CHECK(!output_array.buffer);
+ const auto& input_buffer = input_array.GetBuffer<ArrayDataType::kFloat>();
+ auto& output_buffer = output_array.GetMutableBuffer<ArrayDataType::kFloat>();
+ const int size = input_buffer.data.size();
+ output_buffer.data.resize(size);
+ QuantizationParams qparams;
+ GetQuantizationParamsFromMinMax<ArrayDataType::kUint8>(
+ model->flags, *fakequant_op->minmax, &qparams);
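+ // Emulate the fake-quant roundtrip in float: quantize each value to the
+ // nearest uint8 step, clamp to [0, 255], then dequantize back, reproducing
+ // the quantization error that the FakeQuant node represents.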
+ for (int i = 0; i < size; i++) {
+ const double src_val = input_buffer.data[i];
+ const double unclamped_quantized_val =
+ std::round(qparams.zero_point + src_val / qparams.scale);
+ const double quantized_val =
+ std::min(255., std::max(0., unclamped_quantized_val));
+ const double dst_val = qparams.scale * (quantized_val - qparams.zero_point);
+ output_buffer.data[i] = dst_val;
+ }
+ if (CountOpsWithInput(*model, fakequant_op->inputs[0]) == 1) {
+ model->arrays.erase(fakequant_op->inputs[0]);
+ }
+ model->operators.erase(fakequant_it);
+
+ return true;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_tensorflow_shape.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_tensorflow_shape.cc
new file mode 100644
index 0000000000..8cc6db1619
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_tensorflow_shape.cc
@@ -0,0 +1,62 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <cstddef>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+bool ResolveConstantTensorFlowShape::Run(Model* model, std::size_t op_index) {
+ const auto tfshape_it = model->operators.begin() + op_index;
+ const auto* tfshape_base_op = tfshape_it->get();
+ if (tfshape_base_op->type != OperatorType::kTensorFlowShape) {
+ return false;
+ }
+
+ const auto* tfshape_op =
+ static_cast<const TensorFlowShapeOperator*>(tfshape_base_op);
+
+ const auto& input_array = model->GetArray(tfshape_op->inputs[0]);
+ auto& output_array = model->GetArray(tfshape_op->outputs[0]);
+
+ // Yield until the input array's shape has been resolved.
+ if (!input_array.has_shape()) {
+ return false;
+ }
+
+ // Create a buffer for the output array, making it a constant array, and
+ // copy the input shape into the output buffer.
+ CHECK(!output_array.buffer);
+ auto& output_buffer = output_array.GetMutableBuffer<ArrayDataType::kInt32>();
+ output_buffer.data = input_array.shape().dims();
+
+ // Erase the input array if no longer used
+ if (IsDiscardableArray(*model, tfshape_op->inputs[0]) &&
+ CountOpsWithInput(*model, tfshape_op->inputs[0]) == 1) {
+ model->arrays.erase(tfshape_op->inputs[0]);
+ }
+ model->operators.erase(tfshape_it);
+
+ return true;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc
new file mode 100644
index 0000000000..bb9bda3c82
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc
@@ -0,0 +1,175 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <string.h>
+#include <algorithm>
+#include <cmath>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/runtime/types.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) {
+ const auto unary_it = model->operators.begin() + op_index;
+ const auto* unary_op = unary_it->get();
+ // Test for unary ops of types that we know how to resolve
+ if (unary_op->type != OperatorType::kTensorFlowRsqrt &&
+ unary_op->type != OperatorType::kTensorFlowSqrt &&
+ unary_op->type != OperatorType::kTensorFlowSquare &&
+ unary_op->type != OperatorType::kTensorFlowSum &&
+ unary_op->type != OperatorType::kTensorFlowMin &&
+ unary_op->type != OperatorType::kTensorFlowMax &&
+ unary_op->type != OperatorType::kTensorFlowReshape) {
+ return false;
+ }
+ // Check if the input is a constant parameter.
+ if (!IsConstantParameterArray(*model, unary_op->inputs[0])) {
+ return false;
+ }
+
+ // If the unary op involves a tensor required by an RNN state, ignore it.
+ for (const auto& rnn_state : model->flags.rnn_states()) {
+ if (unary_op->inputs[0] == rnn_state.back_edge_source_array()) {
+ return false;
+ }
+ if (unary_op->inputs[0] == rnn_state.state_array()) {
+ return false;
+ }
+ }
+
+ // At the moment we don't want to care about fused activation functions.
+ // The idea is that we should do the present constants-propagation before
+ // activation functions get fused.
+ if (unary_op->fused_activation_function !=
+ FusedActivationFunctionType::kNone) {
+ AddMessageF(
+ "Not resolving constant %s"
+ " because it has a fused activation function",
+ LogName(*unary_op));
+ return false;
+ }
+ const auto& input_array = model->GetArray(unary_op->inputs[0]);
+ // We have already tested above for existence of buffers (synonymous with being
+ // a constant param).
+ CHECK(input_array.buffer);
+ // At the moment we only support float buffers.
+ if (input_array.buffer->type != ArrayDataType::kFloat) {
+ return false;
+ }
+ const auto& input_float_data =
+ input_array.GetBuffer<ArrayDataType::kFloat>().data;
+ // Create the float buffer on the output array, effectively turning it into
+ // a constant parameter
+ const auto& output_name = unary_op->outputs[0];
+ auto& output_array = model->GetArray(output_name);
+ // Yield until the output array dims have been resolved.
+ if (!output_array.has_shape()) {
+ return false;
+ }
+
+ int input_buffer_size = RequiredBufferSizeForShape(input_array.shape());
+ int output_buffer_size = RequiredBufferSizeForShape(output_array.shape());
+ const Shape& input_shape = input_array.shape();
+ const Shape& output_shape = output_array.shape();
+
+ auto& output_float_data =
+ output_array.GetMutableBuffer<ArrayDataType::kFloat>().data;
+ output_float_data.resize(output_buffer_size);
+
+ const int output_dims_count = output_shape.dimensions_count();
+ if (unary_op->type == OperatorType::kTensorFlowReshape) {
+ CHECK(input_buffer_size == output_buffer_size);
+ memcpy(output_float_data.data(), input_float_data.data(),
+ input_buffer_size * sizeof(input_float_data[0]));
+ } else if (unary_op->type == OperatorType::kTensorFlowSum) {
+ // At the moment only full reduction across all dimensions is supported.
+ for (int i = 0; i < output_dims_count; i++) {
+ CHECK_EQ(output_shape.dims(i), 1);
+ }
+ float sum = 0.f;
+ const int input_size = RequiredBufferSizeForShape(input_shape);
+ for (int i = 0; i < input_size; i++) {
+ sum += input_float_data[i];
+ }
+ output_float_data[0] = sum;
+ } else if (unary_op->type == OperatorType::kTensorFlowMin) {
+ // At the moment only full reduction across all dimensions is supported.
+ // TODO(starka): Output should not be padded.
+ for (int i = 0; i < output_dims_count; i++) {
+ CHECK_EQ(output_shape.dims(i), 1);
+ }
+ float min = input_float_data[0];
+ const int input_size = RequiredBufferSizeForShape(input_shape);
+ for (int i = 0; i < input_size; i++) {
+ min = std::min(min, input_float_data[i]);
+ }
+ output_float_data[0] = min;
+ } else if (unary_op->type == OperatorType::kTensorFlowMax) {
+ // At the moment only full reduction across all dimensions is supported.
+ // TODO(starka): Output should not be padded.
+ for (int i = 0; i < output_dims_count; i++) {
+ CHECK_EQ(output_shape.dims(i), 1);
+ }
+ float max = input_float_data[0];
+ const int input_size = RequiredBufferSizeForShape(input_shape);
+ for (int i = 0; i < input_size; i++) {
+ max = std::max(max, input_float_data[i]);
+ }
+ output_float_data[0] = max;
+ } else if (unary_op->type == OperatorType::kTensorFlowRsqrt ||
+ unary_op->type == OperatorType::kTensorFlowSqrt ||
+ unary_op->type == OperatorType::kTensorFlowSquare) {
+ // Element-wise ops. Should have perfectly matching sizes here.
+ const int input_size = RequiredBufferSizeForShape(input_shape);
+ for (int i = 0; i < output_dims_count; i++) {
+ CHECK_EQ(output_shape.dims(i), input_shape.dims(i));
+ }
+
+ for (int i = 0; i < input_size; i++) {
+ const float val = input_float_data[i];
+ float outval = 0.f;
+ if (unary_op->type == OperatorType::kTensorFlowRsqrt) {
+ outval = 1.0f / std::sqrt(val);
+ } else if (unary_op->type == OperatorType::kTensorFlowSqrt) {
+ outval = std::sqrt(val);
+ } else if (unary_op->type == OperatorType::kTensorFlowSquare) {
+ outval = val * val;
+ } else {
+ LOG(FATAL) << "should not get here.";
+ }
+ output_float_data[i] = outval;
+ }
+ } else {
+ LOG(FATAL) << "should not get here.";
+ }
+ for (const auto& input : unary_op->inputs) {
+ if (CountOpsWithInput(*model, input) == 1) {
+ model->arrays.erase(input);
+ }
+ }
+ AddMessageF("Resolved constant %s to the equivalent constant array",
+ LogName(*unary_op));
+ model->operators.erase(unary_it);
+ return true;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_mean_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_mean_attributes.cc
new file mode 100644
index 0000000000..d25c773f19
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_mean_attributes.cc
@@ -0,0 +1,51 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+bool ResolveMeanAttributes::Run(Model* model, std::size_t op_index) {
+ auto* mean_op = model->operators[op_index].get();
+ if (mean_op->type != OperatorType::kMean) return false;
+ auto* op = static_cast<MeanOperator*>(mean_op);
+
+ if (!op->reduction_indices.empty()) return false;
+ if (op->inputs.size() != 2) return false;
+ if (!IsConstantParameterArray(*model, op->inputs[1])) return false;
+
+ const auto& indices_array = *model->arrays[op->inputs[1]];
+ if (!indices_array.has_shape()) return false;
+
+ op->reduction_indices = indices_array.GetBuffer<ArrayDataType::kInt32>().data;
+
+ // At the moment, we only support simultaneous reduction over width and
+ // height. This is mainly limited by the fact that currently, the runtime
+ // arrays are always 4-dimensional.
+ CHECK_EQ(op->reduction_indices.size(), 2);
+ CHECK((op->reduction_indices[0] == 1 && op->reduction_indices[1] == 2) ||
+ (op->reduction_indices[0] == 2 && op->reduction_indices[1] == 1));
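+ // e.g. for a [1, 8, 8, 32] NHWC input, reduction_indices {1, 2} means the
+ // mean is taken jointly over the height and width dimensions.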
+
+ return true;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_pad_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_pad_attributes.cc
new file mode 100644
index 0000000000..d5f5869c62
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_pad_attributes.cc
@@ -0,0 +1,55 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+bool ResolvePadAttributes::Run(Model* model, std::size_t op_index) {
+ const auto pad_it = model->operators.begin() + op_index;
+ auto* pad_op = pad_it->get();
+ if (pad_op->type != OperatorType::kPad) return false;
+
+ auto* op = static_cast<PadOperator*>(pad_op);
+ if (!op->left_padding.empty()) return false;
+
+ CHECK_EQ(op->inputs.size(), 2);
+ if (!IsConstantParameterArray(*model, op->inputs[1])) return false;
+
+ const auto& array = *model->arrays[op->inputs[1]];
+ if (!array.has_shape()) return false;
+
+ const std::vector<int>& dims = array.shape().dims();
+ CHECK_EQ(dims.size(), 2);
+
+ std::vector<int> buffer = array.GetBuffer<ArrayDataType::kInt32>().data;
+
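+ // The paddings array has shape [rank, 2]: row i holds the amount of padding
+ // to add before and after dimension i, stored row-major in `buffer`.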
+ for (int i = 0; i < dims[0]; ++i) {
+ op->left_padding.push_back(buffer[i * 2]);
+ op->right_padding.push_back(buffer[i * 2 + 1]);
+ }
+
+ // TODO(dkalenichenko): Delete the extra input?
+
+ return true;
+}
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_reorder_axes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_reorder_axes.cc
new file mode 100644
index 0000000000..8fa7b83bed
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_reorder_axes.cc
@@ -0,0 +1,93 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <algorithm>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+bool ResolveReorderAxes::Run(Model* model, std::size_t op_index) {
+ auto reorder_it = model->operators.begin() + op_index;
+ auto* reorder_op = static_cast<ReorderAxesOperator*>(reorder_it->get());
+ if (reorder_op->type != OperatorType::kReorderAxes) {
+ return false;
+ }
+ const auto& input_array_name = reorder_op->inputs[0];
+ const auto& output_array_name = reorder_op->outputs[0];
+ auto& input_array = model->GetArray(input_array_name);
+ auto& output_array = model->GetArray(output_array_name);
+ string constant_input_array_name = input_array_name;
+ if (!input_array.buffer) {
+ const auto* op_producing_input = GetOpWithOutput(*model, input_array_name);
+ if (op_producing_input &&
+ op_producing_input->type == OperatorType::kFakeQuant) {
+ constant_input_array_name = op_producing_input->inputs[0];
+ }
+ }
+ auto& constant_input_array = model->GetArray(constant_input_array_name);
+ if (!constant_input_array.buffer) {
+ return false;
+ }
+ // Yield until output dims have been resolved.
+ if (!output_array.has_shape()) {
+ return false;
+ }
+ // Reorder the input array dims and buffer data
+ CHECK(constant_input_array.buffer->type == ArrayDataType::kFloat);
+ CHECK(!output_array.buffer);
+ auto& input_data =
+ constant_input_array.GetMutableBuffer<ArrayDataType::kFloat>().data;
+ std::vector<float> reordered_data;
+ reordered_data.resize(RequiredBufferSizeForShape(output_array.shape()));
+ const auto input_axes_order = reorder_op->input_axes_order;
+ const auto output_axes_order = reorder_op->output_axes_order;
+ // TODO(b/62904716) Shapes should be used directly.
+ Shape input_shape = constant_input_array.shape();
+ Shape output_shape = output_array.shape();
+ if (AxesCount(input_axes_order) == 2) {
+ UnextendShape(&input_shape, 2);
+ UnextendShape(&output_shape, 2);
+ }
+ ShuffleArray(input_shape, input_axes_order, output_axes_order, output_shape,
+ input_data.data(), reordered_data.data());
+ input_data = reordered_data;
+ input_array.copy_shape(output_array.shape());
+ constant_input_array.copy_shape(output_array.shape());
+
+ // Update the edges of the graph to point to the input array
+ for (const auto& other_op : model->operators) {
+ for (auto& input : other_op->inputs) {
+ if (input == output_array_name) {
+ input = input_array_name;
+ }
+ }
+ }
+
+ AddMessageF("Reordered axes for array %s", input_array_name);
+
+ // Remove the op and output array.
+ model->arrays.erase(output_array_name);
+ model->operators.erase(reorder_it);
+ return true;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_reshape_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_reshape_attributes.cc
new file mode 100644
index 0000000000..bed2a85bd2
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_reshape_attributes.cc
@@ -0,0 +1,49 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <iterator>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+bool ResolveReshapeAttributes::Run(Model* model, std::size_t op_index) {
+ const auto reshape_it = model->operators.begin() + op_index;
+ auto* reshape_op = reshape_it->get();
+ if (reshape_op->type != OperatorType::kTensorFlowReshape) {
+ return false;
+ }
+
+ auto* op = static_cast<TensorFlowReshapeOperator*>(reshape_op);
+
+ if (!op->shape.empty()) return false;
+
+ if (IsConstantParameterArray(*model, reshape_op->inputs[1])) {
+ const auto& constant_input_array = *model->arrays[reshape_op->inputs[1]];
+ op->shape = constant_input_array.GetBuffer<ArrayDataType::kInt32>().data;
+ }
+
+ if (op->shape.empty()) return false;
+
+ return true;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_slice_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_slice_attributes.cc
new file mode 100644
index 0000000000..1d0a2ec8f6
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_slice_attributes.cc
@@ -0,0 +1,52 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+bool ResolveSliceAttributes::Run(Model* model, std::size_t op_index) {
+ const auto slice_it = model->operators.begin() + op_index;
+ auto* slice_op = slice_it->get();
+ if (slice_op->type != OperatorType::kSlice) return false;
+
+ auto* op = static_cast<SliceOperator*>(slice_op);
+ if (!op->begin.empty()) return false;
+
+ CHECK_EQ(op->inputs.size(), 3);
+ if (!IsConstantParameterArray(*model, op->inputs[1])) return false;
+ if (!IsConstantParameterArray(*model, op->inputs[2])) return false;
+
+ const auto& begin_array = *model->arrays[op->inputs[1]];
+ if (!begin_array.has_shape()) return false;
+
+ const auto& size_array = *model->arrays[op->inputs[2]];
+ if (!size_array.has_shape()) return false;
+
+ op->begin = begin_array.GetBuffer<ArrayDataType::kInt32>().data;
+ op->size = size_array.GetBuffer<ArrayDataType::kInt32>().data;
+
+ // TODO(dkalenichenko): Delete the extra inputs?
+
+ return true;
+}
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc
new file mode 100644
index 0000000000..5fc3b25bc1
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc
@@ -0,0 +1,62 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+bool ResolveStridedSliceAttributes::Run(Model* model, std::size_t op_index) {
+ const auto slice_it = model->operators.begin() + op_index;
+ auto* slice_op = slice_it->get();
+ if (slice_op->type != OperatorType::kStridedSlice) return false;
+
+ auto* op = static_cast<StridedSliceOperator*>(slice_op);
+ if (!op->start_indices.empty()) return false;
+
+ CHECK_EQ(op->inputs.size(), 4);
+ if (!IsConstantParameterArray(*model, op->inputs[1])) return false;
+ if (!IsConstantParameterArray(*model, op->inputs[2])) return false;
+ if (!IsConstantParameterArray(*model, op->inputs[3])) return false;
+
+ const auto& start_array = *model->arrays[op->inputs[1]];
+ if (!start_array.has_shape()) return false;
+
+ const auto& stop_array = *model->arrays[op->inputs[2]];
+ if (!stop_array.has_shape()) return false;
+
+ const auto& stride_array = *model->arrays[op->inputs[3]];
+ if (!stride_array.has_shape()) return false;
+
+ op->start_indices = start_array.GetBuffer<ArrayDataType::kInt32>().data;
+ op->stop_indices = stop_array.GetBuffer<ArrayDataType::kInt32>().data;
+ op->strides = stride_array.GetBuffer<ArrayDataType::kInt32>().data;
+
+ // Only 4D arrays are supported for now.
+ CHECK_EQ(op->start_indices.size(), 4);
+ CHECK_EQ(op->stop_indices.size(), 4);
+ CHECK_EQ(op->strides.size(), 4);
+
+ // TODO(dkalenichenko): Delete the extra inputs?
+
+ return true;
+}
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_concat.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_concat.cc
new file mode 100644
index 0000000000..b482f5cf51
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_concat.cc
@@ -0,0 +1,86 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <algorithm>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+bool ResolveTensorFlowConcat::Run(Model* model, std::size_t op_index) {
+ auto concat_it = model->operators.begin() + op_index;
+ const auto* tf_concat_op = concat_it->get();
+ if (tf_concat_op->type != OperatorType::kTensorFlowConcat &&
+ tf_concat_op->type != OperatorType::kTensorFlowConcatV2) {
+ return false;
+ }
+
+ CHECK_GE(tf_concat_op->inputs.size(), 2);
+ // TensorFlow Concat and ConcatV2 nodes only differ by the ordering
+ // of inputs: in Concat, the concat_dim is the first input, while in
+ // ConcatV2, it is the last input.
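+ // e.g. Concat(concat_dim, x0, x1, ...) vs. ConcatV2(x0, x1, ..., axis).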
+ std::size_t concat_dim_pos = 0;
+ if (tf_concat_op->type == OperatorType::kTensorFlowConcatV2) {
+ concat_dim_pos = tf_concat_op->inputs.size() - 1;
+ }
+ const string concat_dim_name = tf_concat_op->inputs[concat_dim_pos];
+ std::vector<string> concat_input_names;
+ for (std::size_t i = 0; i < tf_concat_op->inputs.size(); i++) {
+ if (i != concat_dim_pos) {
+ concat_input_names.push_back(tf_concat_op->inputs[i]);
+ }
+ }
+ // If the concat_dim array hasn't been resolved to a constant yet,
+ // we need to yield.
+ const auto& concat_dim_array = model->GetArray(concat_dim_name);
+ if (!concat_dim_array.buffer) {
+ AddMessageF("Waiting for the concat_dim of %s to be resolved to a constant",
+ LogName(*tf_concat_op));
+ return false;
+ }
+
+ CHECK(concat_dim_array.data_type == ArrayDataType::kInt32);
+ const auto& concat_dim_data =
+ concat_dim_array.GetBuffer<ArrayDataType::kInt32>().data;
+ CHECK_EQ(concat_dim_data.size(), 1);
+ const int concat_dim = concat_dim_data[0];
+
+ // Create the Concatenation op replacing the TensorFlowConcat op.
+ auto* concatenation_op = new ConcatenationOperator;
+ concatenation_op->concat_dim = concat_dim;
+ concatenation_op->inputs = concat_input_names;
+ concatenation_op->outputs = {tf_concat_op->outputs[0]};
+ auto depth_concat_it = model->operators.emplace(concat_it, concatenation_op);
+ CHECK_EQ(depth_concat_it->get(), concatenation_op);
+ // Update invalidated iterator
+ concat_it = depth_concat_it + 1;
+ CHECK_EQ(concat_it->get(), tf_concat_op);
+
+ // Remove the concat_dim array if it is not used by anything else.
+ if (CountOpsWithInput(*model, concat_dim_name) == 1) {
+ model->arrays.erase(concat_dim_name);
+ }
+ // Remove the TensorFlowConcat op
+ model->operators.erase(concat_it);
+ return true;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc
new file mode 100644
index 0000000000..bea7487051
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc
@@ -0,0 +1,106 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+bool ResolveTensorFlowMatMul::Run(Model* model, std::size_t op_index) {
+ auto matmul_it = model->operators.begin() + op_index;
+ if (matmul_it->get()->type != OperatorType::kTensorFlowMatMul) {
+ return false;
+ }
+ const auto* matmul_op = matmul_it->get();
+
+ // Find the op producing the array passed to this MatMul
+ auto previous_op_it = model->operators.begin();
+ bool found = false;
+ for (; previous_op_it != model->operators.end(); ++previous_op_it) {
+ for (const auto& output : (*previous_op_it)->outputs) {
+ if (output == matmul_op->inputs[0]) {
+ found = true;
+ break;
+ }
+ }
+ if (found) {
+ break;
+ }
+ }
+ Operator* previous_op = (found) ? previous_op_it->get() : nullptr;
+
+ // construct the new FullyConnectedOperator
+ auto* fc_op = new FullyConnectedOperator;
+ fc_op->outputs = matmul_op->outputs;
+
+ // insert the newly constructed FullyConnectedOperator
+ auto fc_it = model->operators.emplace(matmul_it, fc_op);
+
+ // refresh invalidated iterator
+ matmul_it = fc_it + 1;
+ DCHECK_EQ(matmul_it->get(), matmul_op);
+
+ // The way that TensorFlow encodes FullyConnected ops is as a pair
+ // (Reshape, MatMul), so we want to remove the Reshape op and rewrite the
+ // MatMul op as a FullyConnected. However, TensorFlow skips the Reshape ops
+ // if the input doesn't need reshaping, so we can't just match
+ // (Reshape, MatMul) pairs.
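+ // e.g. input -> Reshape -> MatMul(weights) becomes input -> FullyConnected(weights),
+ // provided the Reshape output is consumed only by this MatMul.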
+ if (previous_op && previous_op->type == OperatorType::kTensorFlowReshape) {
+ AddMessageF("Combining %s and %s into %s", LogName(*previous_op),
+ LogName(*matmul_op), LogName(*fc_op));
+ const auto& previous_op_output = previous_op->outputs[0];
+ if (CountOpsWithInput(*model, previous_op_output) == 1) {
+ model->arrays.erase(previous_op_output);
+ }
+ CHECK_EQ(previous_op->inputs.size(), 2);
+ fc_op->inputs = {previous_op->inputs[0], matmul_op->inputs[1]};
+ // Only remove Reshape node if no other node uses its output.
+ if (CountOpsWithInput(*model, previous_op_output) == 1) {
+ const auto& previous_op_shape = previous_op->inputs[1];
+ if (CountOpsWithInput(*model, previous_op_shape) == 1 &&
+ !GetOpWithOutput(*model, previous_op_shape)) {
+ model->arrays.erase(previous_op_shape);
+ }
+ model->operators.erase(previous_op_it);
+ }
+
+ // We may have just invalidated matmul_it, so let's refresh it now.
+ matmul_it = model->operators.begin();
+ for (; matmul_it != model->operators.end(); ++matmul_it) {
+ if (matmul_it->get() == matmul_op) {
+ break;
+ }
+ }
+ CHECK(matmul_it != model->operators.end());
+ CHECK(matmul_it->get() == matmul_op);
+ } else {
+ AddMessageF("Replacing %s by a FullyConnected operator",
+ LogName(*matmul_op));
+ fc_op->inputs = {matmul_op->inputs[0], matmul_op->inputs[1]};
+ }
+
+ // erase the MatMul operator
+ model->operators.erase(matmul_it);
+ return true;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_merge.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_merge.cc
new file mode 100644
index 0000000000..cfa5ce0716
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_merge.cc
@@ -0,0 +1,63 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+bool ResolveTensorFlowMerge::Run(Model* model, std::size_t op_index) {
+ const auto merge_it = model->operators.begin() + op_index;
+ const auto* merge_op = merge_it->get();
+ if (merge_op->type != OperatorType::kTensorFlowMerge) {
+ return false;
+ }
+
+ // We need to yield until this Merge node has only 1 input, which will mean
+ // that it is the selected input. Other graph transformations on other nodes,
+ // such as ResolveTensorFlowSwitch, will take care of trimming the
+ // non-selected inputs, so that at some point there will be only 1 input left.
+ if (merge_op->inputs.size() > 1) {
+ AddMessageF("Waiting for %s to be resolved", LogName(*merge_op));
+ return false;
+ }
+
+ // Now that the merge node has exactly 1 input, it is the same as an Identity
+ // node and can be resolved trivially.
+ CHECK_EQ(merge_op->inputs.size(), 1);
+
+ // Update the edges of the graph ahead of removing the node.
+ for (const auto& other_op : model->operators) {
+ for (auto& input : other_op->inputs) {
+ if (input == merge_op->outputs[0]) {
+ input = merge_op->inputs[0];
+ }
+ }
+ }
+
+ // Remove the node and its output array.
+ AddMessageF("Removing already-resolved %s", LogName(*merge_op));
+ model->arrays.erase(merge_op->outputs[0]);
+ model->operators.erase(merge_it);
+ return true;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_squeeze.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_squeeze.cc
new file mode 100644
index 0000000000..1d3f42b5ec
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_squeeze.cc
@@ -0,0 +1,54 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+bool ResolveTensorFlowSqueeze::Run(Model* model, std::size_t op_index) {
+ const auto squeeze_it = model->operators.begin() + op_index;
+ const auto* squeeze_op = squeeze_it->get();
+ if (squeeze_op->type != OperatorType::kSqueeze) {
+ return false;
+ }
+
+ CHECK_EQ(squeeze_op->inputs.size(), 1);
+ CHECK_EQ(squeeze_op->outputs.size(), 1);
+
+ // If the output is consumed by a reshape op, it's a trivial squeeze.
+ if (CountOpsWithInput(*model, squeeze_op->outputs[0]) == 1) {
+ const auto* next_op = GetOpWithInput(*model, squeeze_op->outputs[0]);
+ if (next_op->type == OperatorType::kTensorFlowReshape) {
+ AddMessageF(
+ "%s is trivial because its output is only consumed by a "
+ "Reshape op",
+ LogName(*squeeze_op));
+
+ return RemoveTrivialPassthroughOp(this, model, op_index);
+ }
+ }
+
+ return false;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_switch.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_switch.cc
new file mode 100644
index 0000000000..55adfca037
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_switch.cc
@@ -0,0 +1,123 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+bool ResolveTensorFlowSwitch::Run(Model* model, std::size_t op_index) {
+ const auto switch_it = model->operators.begin() + op_index;
+ const auto* switch_op = switch_it->get();
+ if (switch_op->type != OperatorType::kTensorFlowSwitch) {
+ return false;
+ }
+
+ CHECK_EQ(switch_op->inputs.size(), 2);
+ CHECK_EQ(switch_op->outputs.size(), 2);
+ const string& predicate_name = switch_op->inputs[1];
+ // If the predicate array hasn't been resolved to a constant yet,
+ // we need to yield.
+ if (!IsConstantParameterArray(*model, predicate_name)) {
+ AddMessageF(
+ "Waiting for the boolean predicate of %s to be resolved to a constant",
+ LogName(*switch_op));
+ return false;
+ }
+
+ // The predicate should be boolean, and should consist of a single value.
+ const auto& predicate_array = model->GetArray(predicate_name);
+ CHECK(predicate_array.data_type == ArrayDataType::kBool);
+ for (const auto& dim : predicate_array.shape().dims()) {
+ CHECK_EQ(dim, 1);
+ }
+
+ // Obtain the predicate boolean value.
+ const auto& predicate_data =
+ predicate_array.GetBuffer<ArrayDataType::kBool>().data;
+ CHECK_EQ(predicate_data.size(), 1);
+ const bool predicate_value = predicate_data[0];
+
+ // From the TensorFlow docs on .switch() in
+ // third_party/tensorflow/python/ops/control_flow_ops.py
+ //
+ // If `pred` is false, the `data` input is forwarded to the first output.
+ // Otherwise, the data goes to the second output.
+ //
+ // Note that this comment used to say the opposite and was recently fixed:
+ // https://github.com/tensorflow/tensorflow/commit/bc456e361d49d1d89a74b80060c70efb51fd7d87#diff-76ab9dafbe12c20ddc3769c6b108986c
+ const int selected_output_index = predicate_value ? 1 : 0;
+ const int nonselected_output_index = predicate_value ? 0 : 1;
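+ // e.g. with a true predicate, consumers of outputs[1] get rewired to
+ // inputs[0] below, and edges into outputs[0] are discarded.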
+
+ // Update the edges of the graph ahead of removing the node:
+ // edges that were pointing to the selected output, should instead
+ // point to the input of the Switch node.
+ for (const auto& other_op : model->operators) {
+ for (auto& input : other_op->inputs) {
+ if (input == switch_op->outputs[selected_output_index]) {
+ input = switch_op->inputs[0];
+ }
+ }
+ }
+
+ // It remains to handle the edges that were pointing to the nonselected
+ // output; we simply discard those edges. Concretely, at the moment, our only
+ // examples of graphs with Switch nodes have them feeding into Merge nodes,
+ // so the convention in our toco internal representation is that Merge nodes
+ // with only 1 input have already been resolved and should behave as
+ // Identity nodes, simply forwarding their input.
+ //
+ for (const auto& other_op : model->operators) {
+ auto input_it = other_op->inputs.begin();
+ while (input_it != other_op->inputs.end()) {
+ if (*input_it == switch_op->outputs[nonselected_output_index]) {
+ // Let us guard our assumption that only Merge nodes consume the outputs
+ // of Switch nodes:
+ CHECK(other_op->type == OperatorType::kTensorFlowMerge);
+ input_it = other_op->inputs.erase(input_it);
+ } else {
+ ++input_it;
+ }
+ }
+ }
+
+ // Remove the output arrays if they are now unused.
+ for (int i = 0; i < 2; i++) {
+ if (!GetOpWithInput(*model, switch_op->outputs[i])) {
+ model->arrays.erase(switch_op->outputs[i]);
+ }
+ }
+ // Remove input arrays if they are only used by the switch itself and aren't
+ // the output of another op (will get handled by RemoveUnusedOp in that case).
+ for (const auto& input : switch_op->inputs) {
+ if (CountOpsWithInput(*model, input) == 1 &&
+ !GetOpWithOutput(*model, input)) {
+ model->arrays.erase(input);
+ }
+ }
+ // Remove the switch node itself.
+ AddMessageF("Removing already-resolved %s", LogName(*switch_op));
+ model->operators.erase(switch_it);
+ return true;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_tile.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_tile.cc
new file mode 100644
index 0000000000..9f7e7c42a2
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_tile.cc
@@ -0,0 +1,97 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+namespace {
+
+void RemoveTileOperator(Model* model, Operator* tile_op, Operator* binary_op,
+ int operand_index) {
+ CHECK(tile_op->type == OperatorType::kTensorFlowTile);
+ CHECK_EQ(binary_op->inputs.size(), 2);
+ CHECK_EQ(tile_op->inputs.size(), 2);
+ const string tile_multiplier_array = tile_op->inputs[1];
+ const string tile_output_array = tile_op->outputs[0];
+ binary_op->inputs[operand_index] = tile_op->inputs[0];
+ auto tile_it = model->operators.begin();
+ for (; tile_it != model->operators.end(); ++tile_it) {
+ if (tile_it->get() == tile_op) {
+ break;
+ }
+ }
+ CHECK(tile_it != model->operators.end());
+ CHECK(tile_it->get() == tile_op);
+ model->operators.erase(tile_it);
+ if (!CountOpsWithInput(*model, tile_multiplier_array) &&
+ !GetOpWithOutput(*model, tile_multiplier_array)) {
+ model->arrays.erase(tile_multiplier_array);
+ }
+ if (!CountOpsWithInput(*model, tile_output_array)) {
+ model->arrays.erase(tile_output_array);
+ }
+}
+} // namespace
+
+bool ResolveTensorFlowTile::Run(Model* model, std::size_t op_index) {
+ const auto binary_it = model->operators.begin() + op_index;
+ auto* binary_op = binary_it->get();
+ // Test for binary ops of types that we know how to resolve
+ if (binary_op->inputs.size() != 2) {
+ return false;
+ }
+ if (binary_op->type != OperatorType::kAdd &&
+ binary_op->type != OperatorType::kMul &&
+ binary_op->type != OperatorType::kSub &&
+ binary_op->type != OperatorType::kDiv) {
+ return false;
+ }
+
+ Operator* const op[2] = {
+ GetOpWithOutput(*model, binary_op->inputs[0]),
+ GetOpWithOutput(*model, binary_op->inputs[1]),
+ };
+
+ // In the unlikely case where both operands are Tile, we can't infer the
+ // output size without the Tile nodes, so we have to bail out.
+ if (op[0] && op[0]->type == OperatorType::kTensorFlowTile && op[1] &&
+ op[1]->type == OperatorType::kTensorFlowTile) {
+ return false;
+ }
+
+ for (int i = 0; i < 2; i++) {
+ if (op[i] && op[i]->type == OperatorType::kTensorFlowTile) {
+ // We can only remove a Tile operator if no other op than the present
+ // binary op was consuming its tiled output.
+ if (CountOpsWithInput(*model, binary_op->inputs[i]) == 1) {
+ AddMessageF("Removing %s", LogName(*op[i]));
+ RemoveTileOperator(model, op[i], binary_op, i);
+ return true;
+ }
+ }
+ }
+ return false;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD b/tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD
new file mode 100644
index 0000000000..8931498782
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD
@@ -0,0 +1,31 @@
+package(default_visibility = ["//visibility:public"])
+
+licenses(["notice"]) # Apache 2.0
+
+load(
+ "//tensorflow:tensorflow.bzl",
+ "tf_cc_test",
+)
+
+tf_cc_test(
+ name = "resolve_constant_concatenation_test",
+ srcs = ["resolve_constant_concatenation_test.cc"],
+ deps = [
+ "//tensorflow/contrib/lite/toco:graph_transformations",
+ "//tensorflow/contrib/lite/toco:model",
+ "//tensorflow/contrib/lite/toco:tooling_util",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_concatenation_test.cc b/tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_concatenation_test.cc
new file mode 100644
index 0000000000..c6705ad305
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_concatenation_test.cc
@@ -0,0 +1,221 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+//#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+
+namespace toco {
+
+namespace {
+ // A gmock matcher that checks that the elements of a float vector match to
+ // within a given tolerance.
+std::vector<testing::Matcher<float>> ArrayFloatNear(
+ const std::vector<float>& values, float max_abs_error = 1e-5) {
+ std::vector<testing::Matcher<float>> matchers;
+ matchers.reserve(values.size());
+ for (const float& v : values) {
+ matchers.emplace_back(testing::FloatNear(v, max_abs_error));
+ }
+ return matchers;
+}
+} // namespace
+
+ // The following 3 tests make sure that concatenation along different axis
+ // values matches the TensorFlow results listed below:
+//
+// x0 = [[[0, 1], [2, 3]], [[4, 5], [6, 7]]]
+// x1 = [[[10, 11], [12, 13]], [[14, 15], [16, 17]]]
+// x2 = [[[20, 21], [22, 23]], [[24, 25], [26, 27]]]
+// x3 = [[[30, 31], [32, 33]], [[34, 35], [36, 37]]]
+//
+// ConcatAtAxis0 test:
+// t0 = tf.concat([x0, x1, x2, x3], 0)
+// [[[ 0 1]
+// [ 2 3]]
+//
+// [[ 4 5]
+// [ 6 7]]
+//
+// [[10 11]
+// [12 13]]
+//
+// [[14 15]
+// [16 17]]
+//
+// [[20 21]
+// [22 23]]
+//
+// [[24 25]
+// [26 27]]
+//
+// [[30 31]
+// [32 33]]
+//
+// [[34 35]
+// [36 37]]]
+//
+// ConcatAtAxis1 test:
+// t1 = tf.concat([x0, x1, x2, x3], 1)
+// [[[ 0 1]
+// [ 2 3]
+// [10 11]
+// [12 13]
+// [20 21]
+// [22 23]
+// [30 31]
+// [32 33]]
+//
+// [[ 4 5]
+// [ 6 7]
+// [14 15]
+// [16 17]
+// [24 25]
+// [26 27]
+// [34 35]
+// [36 37]]]
+//
+// ConcatAtAxis2 test:
+// t2 = tf.concat([x0, x1, x2, x3], 2)
+// [[[ 0 1 10 11 20 21 30 31]
+// [ 2 3 12 13 22 23 32 33]]
+//
+// [[ 4 5 14 15 24 25 34 35]
+// [ 6 7 16 17 26 27 36 37]]]
+
+class ResolveConstantConcatenationTest : public ::testing::Test {
+ protected:
+ ResolveConstantConcatenationTest() {}
+
+ // Prepare a hypothetical TOCO model with one Concatenation operator in it
+ // together with 4 arrays as its inputs.
+ // The dimension along which to concatenate is passed in as `concat_dim`.
+ void PrepareModel(Model* model, int concat_dim) {
+ std::vector<string> concat_input_names = {"array0", "array1", "array2",
+ "array3"};
+
+ const int kDim = 3;
+ const int kElementPerDim = 2;
+ const int kBufSize = 8;
+ const int kNumArrays = 4;
+ static float in_buf[kNumArrays][kBufSize] = {
+ {0., 1., 2., 3., 4., 5., 6., 7.},
+ {10., 11., 12., 13., 14., 15., 16., 17.},
+ {20., 21., 22., 23., 24., 25., 26., 27.},
+ {30., 31., 32., 33., 34., 35., 36., 37.}};
+ int cnt = 0;
+ for (const string& concat_input_name : concat_input_names) {
+ Array& in_array = model->GetOrCreateArray(concat_input_name);
+ in_array.data_type = ArrayDataType::kFloat;
+
+ // Initialize shape for the input array.
+ Shape* in_array_shape = in_array.mutable_shape();
+ std::vector<int>* in_array_shape_dim = in_array_shape->mutable_dims();
+ for (int i = 0; i < kDim; i++) {
+ in_array_shape_dim->push_back(kElementPerDim);
+ }
+ auto& in_array_buffer =
+ in_array.GetMutableBuffer<toco::ArrayDataType::kFloat>();
+ in_array_buffer.data.resize(kBufSize);
+ float* buf_ptr =
+ in_array.GetMutableBuffer<toco::ArrayDataType::kFloat>().data.data();
+ std::copy(in_buf[cnt], in_buf[cnt] + kBufSize, buf_ptr);
+ cnt++;
+ }
+ auto* concatenation_op = new ConcatenationOperator;
+ concatenation_op->concat_dim = concat_dim;
+ concatenation_op->inputs = concat_input_names;
+ concatenation_op->outputs = {"concat_op_outputs"};
+ Array& out_array = model->GetOrCreateArray(concatenation_op->outputs[0]);
+ out_array.data_type = ArrayDataType::kFloat;
+ Shape* out_array_shape = out_array.mutable_shape();
+ std::vector<int>* out_array_shape_dim = out_array_shape->mutable_dims();
+ out_array_shape_dim->resize(kDim);
+ for (int i = 0; i < kDim; i++) {
+ if (i == concat_dim) {
+ (*out_array_shape_dim)[i] = kNumArrays * kElementPerDim;
+ } else {
+ (*out_array_shape_dim)[i] = kElementPerDim;
+ }
+ }
+ model->operators.push_back(std::unique_ptr<Operator>(concatenation_op));
+ }
+};
+
+TEST_F(ResolveConstantConcatenationTest, ConcatAtAxis0) {
+ Model model;
+ const int concat_dim = 0;
+ PrepareModel(&model, concat_dim);
+
+ GraphTransformationsSet graph_transformation_set;
+ graph_transformation_set.Add(new toco::ResolveConstantConcatenation);
+ EXPECT_THAT(model.arrays.size(), 5);
+ (*graph_transformation_set.begin())->Run(&model, /*op_index=*/0);
+ EXPECT_THAT(model.arrays.size(), 1);
+
+ auto& concatenated_array = (*model.arrays.begin()).second;
+ EXPECT_THAT(concatenated_array->GetBuffer<toco::ArrayDataType::kFloat>().data,
+ ElementsAreArray(ArrayFloatNear(
+ {0., 1., 2., 3., 4., 5., 6., 7., 10., 11., 12.,
+ 13., 14., 15., 16., 17., 20., 21., 22., 23., 24., 25.,
+ 26., 27., 30., 31., 32., 33., 34., 35., 36., 37.})));
+}
+
+TEST_F(ResolveConstantConcatenationTest, ConcatAtAxis1) {
+ Model model;
+ const int concat_dim = 1;
+ PrepareModel(&model, concat_dim);
+
+ GraphTransformationsSet graph_transformation_set;
+ graph_transformation_set.Add(new toco::ResolveConstantConcatenation);
+ EXPECT_THAT(model.arrays.size(), 5);
+ (*graph_transformation_set.begin())->Run(&model, /*op_index=*/0);
+ EXPECT_THAT(model.arrays.size(), 1);
+
+ auto& concatenated_array = (*model.arrays.begin()).second;
+ EXPECT_THAT(concatenated_array->GetBuffer<toco::ArrayDataType::kFloat>().data,
+ ElementsAreArray(ArrayFloatNear(
+ {0., 1., 2., 3., 10., 11., 12., 13., 20., 21., 22.,
+ 23., 30., 31., 32., 33., 4., 5., 6., 7., 14., 15.,
+ 16., 17., 24., 25., 26., 27., 34., 35., 36., 37.})));
+}
+
+TEST_F(ResolveConstantConcatenationTest, ConcatAtAxis2) {
+ Model model;
+ const int concat_dim = 2;
+ PrepareModel(&model, concat_dim);
+
+ GraphTransformationsSet graph_transformation_set;
+ graph_transformation_set.Add(new toco::ResolveConstantConcatenation);
+ EXPECT_THAT(model.arrays.size(), 5);
+ (*graph_transformation_set.begin())->Run(&model, /*op_index=*/0);
+ EXPECT_THAT(model.arrays.size(), 1);
+
+ auto& concatenated_array = (*model.arrays.begin()).second;
+ EXPECT_THAT(concatenated_array->GetBuffer<toco::ArrayDataType::kFloat>().data,
+ ElementsAreArray(ArrayFloatNear(
+ {0., 1., 10., 11., 20., 21., 30., 31., 2., 3., 12.,
+ 13., 22., 23., 32., 33., 4., 5., 14., 15., 24., 25.,
+ 34., 35., 6., 7., 16., 17., 26., 27., 36., 37.})));
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/unfuse_activation_functions.cc b/tensorflow/contrib/lite/toco/graph_transformations/unfuse_activation_functions.cc
new file mode 100644
index 0000000000..4e273343df
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/unfuse_activation_functions.cc
@@ -0,0 +1,73 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/runtime/types.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+bool UnfuseActivationFunctions::Run(Model* model, std::size_t op_index) {
+ const auto it = model->operators.begin() + op_index;
+ auto* op = it->get();
+
+ // If a conv operation has an im2col array, yield: it should be dropped first.
+ if ((op->type == OperatorType::kConv) && (op->outputs.size() == 2)) {
+ return false;
+ }
+
+ Operator* ac_op = nullptr;
+ switch (op->fused_activation_function) {
+ case FusedActivationFunctionType::kRelu:
+ ac_op = new ReluOperator;
+ break;
+ case FusedActivationFunctionType::kRelu6:
+ ac_op = new Relu6Operator;
+ break;
+ case FusedActivationFunctionType::kRelu1:
+ ac_op = new Relu1Operator;
+ break;
+ default:
+ return false;
+ }
+
+ // At this point we know that the op has a fused activation function. At the
+ // moment that only happens with ops having a single output; this may be
+ // relaxed in the future.
+ CHECK_EQ(op->outputs.size(), 1);
+
+ // Emplace unfused activation function, drop the fused one.
+ model->operators.emplace(it + 1, ac_op);
+ op->fused_activation_function = FusedActivationFunctionType::kNone;
+
+ // Wire up arrays, constructing a new intermediate array to connect the
+ // op to its new unfused activation function.
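+ // e.g. if the op originally produced array "x", it now produces "x_unfused"
+ // (or another available name) and the new activation op maps "x_unfused" -> "x".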
+ ac_op->outputs = op->outputs;
+ const string& tmp_array_name =
+ AvailableArrayName(*model, op->outputs[0] + "_unfused");
+ CHECK(!model->arrays.count(tmp_array_name));
+ model->GetOrCreateArray(tmp_array_name);
+ ac_op->inputs = {tmp_array_name};
+ op->outputs = {tmp_array_name};
+ return true;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc
new file mode 100644
index 0000000000..c889149ada
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc
@@ -0,0 +1,1508 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "google/protobuf/map.h"
+#include "google/protobuf/text_format.h"
+#include "absl/strings/numbers.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_split.h"
+#include "absl/strings/strip.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/model_flags.pb.h"
+#include "tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_cluster.h"
+#include "tensorflow/contrib/lite/toco/tensorflow_util.h"
+#include "tensorflow/contrib/lite/toco/toco_port.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/function.pb.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/tensor.pb.h"
+#include "tensorflow/core/framework/tensor_shape.pb.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/public/session_options.h"
+#include "tensorflow/core/public/version.h"
+
+using tensorflow::AttrValue;
+using tensorflow::DT_BOOL;
+using tensorflow::DT_FLOAT;
+using tensorflow::DT_INT32;
+using tensorflow::DT_INT64;
+using tensorflow::DT_UINT8;
+using tensorflow::GraphDef;
+using tensorflow::NodeDef;
+using tensorflow::TensorProto;
+using tensorflow::TensorShapeProto;
+
+namespace toco {
+namespace {
+bool HasAttr(const NodeDef& node, const string& attr_name) {
+ return node.attr().count(attr_name) > 0;
+}
+
+const string& GetStringAttr(const NodeDef& node, const string& attr_name) {
+ CHECK(HasAttr(node, attr_name));
+ const auto& attr = node.attr().at(attr_name);
+ CHECK_EQ(attr.value_case(), AttrValue::kS);
+ return attr.s();
+}
+
+int GetIntAttr(const NodeDef& node, const string& attr_name) {
+ CHECK(HasAttr(node, attr_name)) << attr_name << " not found in:\n"
+ << node.DebugString();
+ const auto& attr = node.attr().at(attr_name);
+ CHECK_EQ(attr.value_case(), AttrValue::kI);
+ return attr.i();
+}
+
+float GetFloatAttr(const NodeDef& node, const string& attr_name) {
+ CHECK(HasAttr(node, attr_name));
+ const auto& attr = node.attr().at(attr_name);
+ CHECK_EQ(attr.value_case(), AttrValue::kF);
+ return attr.f();
+}
+
+bool GetBoolAttr(const NodeDef& node, const string& attr_name) {
+ CHECK(HasAttr(node, attr_name));
+ const auto& attr = node.attr().at(attr_name);
+ CHECK_EQ(attr.value_case(), AttrValue::kB);
+ return attr.b();
+}
+
+tensorflow::DataType GetDataTypeAttr(const NodeDef& node,
+ const string& attr_name) {
+ CHECK(HasAttr(node, attr_name));
+ const auto& attr = node.attr().at(attr_name);
+ CHECK_EQ(attr.value_case(), AttrValue::kType);
+ return attr.type();
+}
+
+const TensorShapeProto& GetShapeAttr(const NodeDef& node,
+ const string& attr_name) {
+ CHECK(HasAttr(node, attr_name));
+ const auto& attr = node.attr().at(attr_name);
+ CHECK_EQ(attr.value_case(), AttrValue::kShape);
+ return attr.shape();
+}
+
+const TensorProto& GetTensorAttr(const NodeDef& node, const string& attr_name) {
+ CHECK(HasAttr(node, attr_name));
+ const auto& attr = node.attr().at(attr_name);
+ CHECK_EQ(attr.value_case(), AttrValue::kTensor);
+ return attr.tensor();
+}
+
+const AttrValue::ListValue& GetListAttr(const NodeDef& node,
+ const string& attr_name) {
+ CHECK(HasAttr(node, attr_name));
+ const auto& attr = node.attr().at(attr_name);
+ CHECK_EQ(attr.value_case(), AttrValue::kList);
+ return attr.list();
+}
+
+ArrayDataType ConvertDataType(tensorflow::DataType dtype) {
+ if (dtype == DT_UINT8)
+ return ArrayDataType::kUint8;
+ else if (dtype == DT_FLOAT)
+ return ArrayDataType::kFloat;
+ else if (dtype == DT_BOOL)
+ return ArrayDataType::kBool;
+ else if (dtype == DT_INT32)
+ return ArrayDataType::kInt32;
+ else if (dtype == DT_INT64)
+ return ArrayDataType::kInt64;
+ else
+ LOG(INFO) << "Unsupported data type in placehoder op: " << dtype;
+ return ArrayDataType::kNone;
+}
+
+void ImportShape(const TFLITE_PROTO_NS::RepeatedPtrField<
+ tensorflow::TensorShapeProto_Dim>& input_dims,
+ Shape* shape) {
+ std::vector<int> input_dims_only_sizes;
+ for (auto& d : input_dims) {
+ if (d.size() == 0) {
+ // Some TensorFlow shapes contain a 0 dim, effectively making
+ // them of flat size 0 even though they have other nonzero dims.
+ // This breaks our invariant that array dims can't be 0.
+ // For now, we tweak this to record a 0-D shape instead.
+ input_dims_only_sizes.clear();
+ break;
+ }
+ input_dims_only_sizes.push_back(d.size());
+ }
+ *shape->mutable_dims() = input_dims_only_sizes;
+}
+
+void ImportFloatArray(const TensorProto& input_tensor, Array* output_array) {
+ CHECK_EQ(input_tensor.dtype(), DT_FLOAT);
+ const auto& input_shape = input_tensor.tensor_shape();
+ CHECK_LE(input_shape.dim_size(), 4);
+ ImportShape(input_shape.dim(), output_array->mutable_shape());
+ int input_flat_size = 1;
+ for (int k = 0; k < input_shape.dim_size(); k++) {
+ input_flat_size *= input_shape.dim(k).size();
+ }
+ auto& output_float_data =
+ output_array->GetMutableBuffer<ArrayDataType::kFloat>().data;
+ output_float_data.resize(input_flat_size);
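+ // TF constants store their payload either as repeated float_val entries or
+ // as a raw tensor_content byte string; handle both encodings.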
+ if (input_tensor.float_val_size()) {
+ for (int i = 0; i < input_tensor.float_val_size(); i++) {
+ output_float_data[i] = input_tensor.float_val(i);
+ }
+ } else if (input_tensor.tensor_content().size() ==
+ input_flat_size * sizeof(float)) {
+ toco::port::CopyToBuffer(input_tensor.tensor_content(),
+ reinterpret_cast<char*>(output_float_data.data()));
+ } else {
+ LOG(FATAL) << "Neither input_content nor float_val have the right "
+ "dimensions for this float tensor.";
+ }
+}
+
+void ImportInt32Array(const TensorProto& input_tensor, Array* output_array) {
+ CHECK_EQ(input_tensor.dtype(), DT_INT32);
+ const auto& input_shape = input_tensor.tensor_shape();
+ CHECK_LE(input_shape.dim_size(), 4);
+ ImportShape(input_shape.dim(), output_array->mutable_shape());
+ int input_flat_size = 1;
+ for (int k = 0; k < input_shape.dim_size(); k++) {
+ input_flat_size *= input_shape.dim(k).size();
+ }
+ auto& output_int_data =
+ output_array->GetMutableBuffer<ArrayDataType::kInt32>().data;
+ output_int_data.resize(input_flat_size);
+ if (input_tensor.int_val_size()) {
+ for (int i = 0; i < input_tensor.int_val_size(); i++) {
+ output_int_data[i] = input_tensor.int_val(i);
+ }
+ } else if (input_tensor.tensor_content().size() ==
+ input_flat_size * sizeof(int32)) {
+ toco::port::CopyToBuffer(input_tensor.tensor_content(),
+ reinterpret_cast<char*>(output_int_data.data()));
+ } else {
+ LOG(FATAL) << "Neither input_content nor int_val have the right "
+ "dimensions for this int32 tensor.";
+ }
+}
+
+void ImportInt64Array(const TensorProto& input_tensor, Array* output_array) {
+ CHECK_EQ(input_tensor.dtype(), DT_INT64);
+ const auto& input_shape = input_tensor.tensor_shape();
+ CHECK_LE(input_shape.dim_size(), 4);
+ ImportShape(input_shape.dim(), output_array->mutable_shape());
+ int input_flat_size = 1;
+ for (int k = 0; k < input_shape.dim_size(); k++) {
+ input_flat_size *= input_shape.dim(k).size();
+ }
+ auto& output_int_data =
+ output_array->GetMutableBuffer<ArrayDataType::kInt64>().data;
+ output_int_data.resize(input_flat_size);
+ if (input_tensor.int64_val_size()) {
+ for (int i = 0; i < input_tensor.int64_val_size(); i++) {
+ output_int_data[i] = input_tensor.int64_val(i);
+ }
+ } else if (input_tensor.tensor_content().size() ==
+ input_flat_size * sizeof(int64)) {
+ toco::port::CopyToBuffer(input_tensor.tensor_content(),
+ reinterpret_cast<char*>(output_int_data.data()));
+ } else {
+ LOG(FATAL) << "Neither input_content nor int64_val have the right "
+ "dimensions for this int64 tensor.";
+ }
+}
+
+// Count the number of inputs of a given node. If `drop_control_dependency` is
+// true, count the number of non-control-dependency inputs.
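+ // Control-dependency inputs are written with a leading '^' (e.g. "^init_op");
+ // the code below assumes they appear after the regular inputs in the NodeDef.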
+size_t GetInputsCount(const NodeDef& node, bool drop_control_dependency) {
+ if (drop_control_dependency) {
+ for (size_t i = 0; i < node.input_size(); ++i) {
+ if (node.input(i)[0] == '^') {
+ LOG(INFO) << "Reached first control dependency input: "
+ << node.input(i);
+ return i;
+ }
+ }
+ return node.input_size();
+ } else {
+ return node.input_size();
+ }
+}
+
+void ConvertConstOperator(const NodeDef& node, Model* model) {
+ CHECK_EQ(node.op(), "Const");
+ const auto& tensor = GetTensorAttr(node, "value");
+ const auto dtype = GetDataTypeAttr(node, "dtype");
+
+ auto& array = model->GetOrCreateArray(node.name());
+ array.data_type = dtype == DT_FLOAT
+ ? ArrayDataType::kFloat
+ : dtype == DT_INT32
+ ? ArrayDataType::kInt32
+ : dtype == DT_INT64 ? ArrayDataType::kInt64
+ : ArrayDataType::kNone;
+ if (dtype == DT_FLOAT) {
+ ImportFloatArray(tensor, &array);
+ } else if (dtype == DT_INT32) {
+ ImportInt32Array(tensor, &array);
+ } else if (dtype == DT_INT64) {
+ ImportInt64Array(tensor, &array);
+ } else {
+    // Do nothing: silently ignore the Const data. For example, there are
+    // consts of string type. We just make a dummy buffer to indicate that
+    // this array does not rely on external input.
+ array.GetMutableBuffer<ArrayDataType::kNone>();
+ }
+}
+
+void ConvertConvOperator(const NodeDef& node, Model* model) {
+ CHECK_EQ(node.op(), "Conv2D");
+ CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 2);
+
+ // We only support NHWC, which is the default data_format.
+ // So if data_format is not defined, we're all good.
+ if (node.attr().count("data_format")) {
+ CHECK_EQ(GetStringAttr(node, "data_format"), "NHWC");
+ }
+ CHECK_EQ(GetDataTypeAttr(node, "T"), DT_FLOAT);
+
+ const auto& input_name = node.input(0);
+ const auto& weights_name = node.input(1);
+ const auto& reordered_weights_name = weights_name + "_reordered";
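+  // TensorFlow stores Conv2D weights in HWIO order, while our standard order
+  // is OHWI (see AxesOrder); a ReorderAxesOperator performs the permutation.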
+ // Check if a ReorderAxesOperator was already created for these weights
+ // (that happens when multiple layers share the same weights).
+ const Operator* existing_reorder =
+ GetOpWithOutput(*model, reordered_weights_name);
+ if (existing_reorder) {
+ // Check that it is safe to rely on the _reordered naming of the output
+ // array!
+ CHECK(existing_reorder->type == OperatorType::kReorderAxes);
+ } else {
+ // Create a new ReorderAxesOperator
+ auto* reorder = new ReorderAxesOperator;
+ reorder->inputs = {weights_name};
+ reorder->outputs = {reordered_weights_name};
+ reorder->input_axes_order = AxesOrder::kHWIO;
+ reorder->output_axes_order = AxesOrder::kOHWI;
+ model->operators.emplace_back(reorder);
+ }
+ auto* conv = new ConvOperator;
+ conv->inputs = {input_name, reordered_weights_name};
+ conv->outputs = {node.name()};
+ const auto& strides = GetListAttr(node, "strides");
+ CHECK_EQ(strides.i_size(), 4);
+ CHECK_EQ(strides.i(0), 1);
+ CHECK_EQ(strides.i(3), 1);
+ conv->stride_height = strides.i(1);
+ conv->stride_width = strides.i(2);
+ const auto& padding = GetStringAttr(node, "padding");
+ if (padding == "SAME") {
+ conv->padding.type = PaddingType::kSame;
+ } else if (padding == "VALID") {
+ conv->padding.type = PaddingType::kValid;
+ } else {
+ LOG(FATAL) << "Bad padding (only SAME and VALID are supported)";
+ }
+ model->operators.emplace_back(conv);
+}
+
+void ConvertDepthwiseConvOperator(const NodeDef& node, Model* model) {
+ CHECK_EQ(node.op(), "DepthwiseConv2dNative");
+ CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 2);
+
+ // We only support NHWC, which is the default data_format.
+ // So if data_format is not defined, we're all good.
+ if (node.attr().count("data_format")) {
+ CHECK_EQ(GetStringAttr(node, "data_format"), "NHWC");
+ }
+ CHECK_EQ(GetDataTypeAttr(node, "T"), DT_FLOAT);
+
+ const auto& input_name = node.input(0);
+ const auto& weights_name = node.input(1);
+ const auto& reordered_weights_name = weights_name + "_reordered";
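+  // TensorFlow stores DepthwiseConv2dNative weights in HWIM order, while our
+  // standard order is 1HWO (see AxesOrder); a ReorderAxesOperator performs
+  // the permutation.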
+ // Check if a ReorderAxesOperator was already created for these weights
+ // (that happens when multiple layers share the same weights).
+ const Operator* existing_reorder =
+ GetOpWithOutput(*model, reordered_weights_name);
+ if (existing_reorder) {
+ // Check that it is safe to rely on the _reordered naming of the output
+ // array!
+ CHECK(existing_reorder->type == OperatorType::kReorderAxes);
+ } else {
+ // Create a new ReorderAxesOperator
+ auto* reorder = new ReorderAxesOperator;
+ reorder->inputs = {weights_name};
+ reorder->outputs = {reordered_weights_name};
+ reorder->input_axes_order = AxesOrder::kHWIM;
+ reorder->output_axes_order = AxesOrder::k1HWO;
+ model->operators.emplace_back(reorder);
+ }
+ auto* conv = new DepthwiseConvOperator;
+ conv->inputs = {input_name, reordered_weights_name};
+ conv->outputs = {node.name()};
+ const auto& strides = GetListAttr(node, "strides");
+ CHECK_EQ(strides.i_size(), 4);
+ CHECK_EQ(strides.i(0), 1);
+ CHECK_EQ(strides.i(3), 1);
+ conv->stride_height = strides.i(1);
+ conv->stride_width = strides.i(2);
+ const auto& padding = GetStringAttr(node, "padding");
+ if (padding == "SAME") {
+ conv->padding.type = PaddingType::kSame;
+ } else if (padding == "VALID") {
+ conv->padding.type = PaddingType::kValid;
+ } else {
+ LOG(FATAL) << "Bad padding (only SAME and VALID are supported)";
+ }
+ model->operators.emplace_back(conv);
+}
+
+void ConvertDepthToSpaceOperator(const NodeDef& node, Model* model) {
+ CHECK_EQ(node.op(), "DepthToSpace");
+ CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 1);
+ CHECK_EQ(GetDataTypeAttr(node, "T"), DT_FLOAT);
+ auto* op = new DepthToSpaceOperator;
+ op->inputs.push_back(node.input(0));
+ op->outputs.push_back(node.name());
+ op->block_size = GetIntAttr(node, "block_size");
+ QCHECK_GE(op->block_size, 2);
+ model->operators.emplace_back(op);
+}
+
+void ConvertSpaceToDepthOperator(const NodeDef& node, Model* model) {
+ CHECK_EQ(node.op(), "SpaceToDepth");
+ CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 1);
+ CHECK_EQ(GetDataTypeAttr(node, "T"), DT_FLOAT);
+ auto* op = new SpaceToDepthOperator;
+ op->inputs.push_back(node.input(0));
+ op->outputs.push_back(node.name());
+ op->block_size = GetIntAttr(node, "block_size");
+ QCHECK_GE(op->block_size, 2);
+ model->operators.emplace_back(op);
+}
+
+void ConvertBiasAddOperator(const NodeDef& node, Model* model) {
+ CHECK_EQ(node.op(), "BiasAdd");
+ CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 2);
+ const auto& input_name = node.input(0);
+ const auto& bias_name = node.input(1);
+ CHECK_EQ(GetDataTypeAttr(node, "T"), DT_FLOAT);
+ auto* biasadd = new AddOperator;
+ biasadd->inputs.push_back(input_name);
+ biasadd->inputs.push_back(bias_name);
+ biasadd->outputs.push_back(node.name());
+ model->operators.emplace_back(biasadd);
+}
+
+void ConvertReluOperator(const NodeDef& node, Model* model) {
+ CHECK_EQ(node.op(), "Relu");
+ CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 1);
+ const auto& input_name = node.input(0);
+ auto* relu = new ReluOperator;
+ relu->inputs.push_back(input_name);
+ relu->outputs.push_back(node.name());
+ model->operators.emplace_back(relu);
+}
+
+void ConvertRelu6Operator(const NodeDef& node, Model* model) {
+ CHECK_EQ(node.op(), "Relu6");
+ CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 1);
+ const auto& input_name = node.input(0);
+ auto* op = new Relu6Operator;
+ op->inputs.push_back(input_name);
+ op->outputs.push_back(node.name());
+ model->operators.emplace_back(op);
+}
+
+void ConvertLogisticOperator(const NodeDef& node, Model* model) {
+ CHECK_EQ(node.op(), "Sigmoid");
+ CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 1);
+ const auto& input_name = node.input(0);
+ auto* op = new LogisticOperator;
+ op->inputs.push_back(input_name);
+ op->outputs.push_back(node.name());
+ model->operators.emplace_back(op);
+}
+
+void ConvertTanhOperator(const NodeDef& node, Model* model) {
+ CHECK_EQ(node.op(), "Tanh");
+ CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 1);
+ const auto& input_name = node.input(0);
+ auto* op = new TanhOperator;
+ op->inputs.push_back(input_name);
+ op->outputs.push_back(node.name());
+ model->operators.emplace_back(op);
+}
+
+void ConvertDivOperator(const NodeDef& node, Model* model) {
+ CHECK(node.op() == "Div" || node.op() == "RealDiv");
+ CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 2);
+ auto* op = new DivOperator;
+ op->inputs.push_back(node.input(0));
+ op->inputs.push_back(node.input(1));
+ op->outputs.push_back(node.name());
+ model->operators.emplace_back(op);
+}
+
+void ConvertIdentityOperator(const NodeDef& node, Model* model) {
+ CHECK(node.op() == "Identity" || node.op() == "CheckNumerics" ||
+ node.op() == "PlaceholderWithDefault");
+ auto* op = new TensorFlowIdentityOperator;
+ // Amazingly, some TensorFlow graphs (at least rajeev_lstm.pb) have
+ // identity nodes with multiple inputs, but the other inputs seem
+ // to be gratuitous (in the case of rajeev_lstm.pb, these are
+ // enumerating the LSTM state arrays). We will just ignore extra
+ // inputs beyond the first input.
+ CHECK_GE(node.input_size(), 1);
+ const auto& input_name = node.input(0);
+ op->inputs.push_back(input_name);
+ op->outputs.push_back(node.name());
+ model->operators.emplace_back(op);
+}
+
+void ConvertFakeQuantWithMinMaxArgs(const NodeDef& node, Model* model) {
+ CHECK_EQ(node.op(), "FakeQuantWithMinMaxArgs");
+ CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 1);
+ auto* op = new FakeQuantOperator;
+ op->inputs.push_back(node.input(0));
+ op->minmax.reset(new MinMax);
+ auto& minmax = *op->minmax;
+ minmax.min = GetFloatAttr(node, "min");
+ minmax.max = GetFloatAttr(node, "max");
+ op->outputs.push_back(node.name());
+ model->operators.emplace_back(op);
+}
+
+void ConvertFakeQuantWithMinMaxVars(const NodeDef& node, Model* model) {
+ CHECK_EQ(node.op(), "FakeQuantWithMinMaxVars");
+ const int num_inputs =
+ GetInputsCount(node, model->flags.drop_control_dependency());
+ CHECK(num_inputs == 3 || num_inputs == 4);
+ auto* op = new FakeQuantOperator;
+ for (int i = 0; i < 3; i++) {
+ op->inputs.push_back(node.input(i));
+ }
+ op->outputs.push_back(node.name());
+ model->operators.emplace_back(op);
+}
+
+void ConvertRsqrtOperator(const NodeDef& node, Model* model) {
+ CHECK_EQ(node.op(), "Rsqrt");
+ CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 1);
+ auto* op = new TensorFlowRsqrtOperator;
+ op->inputs.push_back(node.input(0));
+ op->outputs.push_back(node.name());
+ model->operators.emplace_back(op);
+}
+
+void ConvertSqrtOperator(const NodeDef& node, Model* model) {
+ CHECK_EQ(node.op(), "Sqrt");
+ CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 1);
+ auto* op = new TensorFlowSqrtOperator;
+ op->inputs.push_back(node.input(0));
+ op->outputs.push_back(node.name());
+ model->operators.emplace_back(op);
+}
+
+void ConvertSqueezeOperator(const NodeDef& node, Model* model) {
+ CHECK_EQ(node.op(), "Squeeze");
+ CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 1);
+ auto* op = new SqueezeOperator;
+ op->inputs.push_back(node.input(0));
+ op->outputs.push_back(node.name());
+
+ const auto& squeeze_dims = GetListAttr(node, "squeeze_dims");
+ for (int i = 0; i < squeeze_dims.i_size(); ++i) {
+ op->squeeze_dims.push_back(squeeze_dims.i(i));
+ }
+
+ model->operators.emplace_back(op);
+}
+
+void ConvertSquareOperator(const NodeDef& node, Model* model) {
+ CHECK_EQ(node.op(), "Square");
+ CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 1);
+ auto* op = new TensorFlowSquareOperator;
+ op->inputs.push_back(node.input(0));
+ op->outputs.push_back(node.name());
+ model->operators.emplace_back(op);
+}
+
+void ConvertAddOperator(const NodeDef& node, Model* model) {
+ CHECK_EQ(node.op(), "Add");
+ CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 2);
+ auto* op = new AddOperator;
+ op->inputs.push_back(node.input(0));
+ op->inputs.push_back(node.input(1));
+ op->outputs.push_back(node.name());
+ model->operators.emplace_back(op);
+}
+
+void ConvertMulOperator(const NodeDef& node, Model* model) {
+ CHECK_EQ(node.op(), "Mul");
+ CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 2);
+ auto* op = new MulOperator;
+ op->inputs.push_back(node.input(0));
+ op->inputs.push_back(node.input(1));
+ op->outputs.push_back(node.name());
+ model->operators.emplace_back(op);
+}
+
+void ConvertSubOperator(const NodeDef& node, Model* model) {
+ CHECK_EQ(node.op(), "Sub");
+ CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 2);
+ auto* op = new SubOperator;
+ op->inputs.push_back(node.input(0));
+ op->inputs.push_back(node.input(1));
+ op->outputs.push_back(node.name());
+ model->operators.emplace_back(op);
+}
+
+void ConvertSumOperator(const NodeDef& node, Model* model) {
+ CHECK_EQ(node.op(), "Sum");
+ CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 2);
+ auto* op = new TensorFlowSumOperator;
+ op->inputs.push_back(node.input(0));
+ op->inputs.push_back(node.input(1));
+ op->outputs.push_back(node.name());
+ model->operators.emplace_back(op);
+}
+
+void ConvertTileOperator(const NodeDef& node, Model* model) {
+ CHECK_EQ(node.op(), "Tile");
+ CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 2);
+ auto* op = new TensorFlowTileOperator;
+ op->inputs.push_back(node.input(0));
+ op->inputs.push_back(node.input(1));
+ op->outputs.push_back(node.name());
+ model->operators.emplace_back(op);
+}
+
+void ConvertSliceOperator(const NodeDef& node, Model* model) {
+ CHECK_EQ(node.op(), "Slice");
+ CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 3);
+ auto* op = new SliceOperator;
+ for (int i = 0; i < 3; ++i) {
+ op->inputs.push_back(node.input(i));
+ }
+ op->outputs.push_back(node.name());
+ model->operators.emplace_back(op);
+}
+
+void ConvertPadOperator(const NodeDef& node, Model* model) {
+ CHECK_EQ(node.op(), "Pad");
+ CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 2);
+ auto* op = new PadOperator;
+ op->inputs.push_back(node.input(0));
+ op->inputs.push_back(node.input(1));
+ op->outputs.push_back(node.name());
+ model->operators.emplace_back(op);
+}
+
+void ConvertShapeOperator(const NodeDef& node, Model* model) {
+ CHECK_EQ(node.op(), "Shape");
+ CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 1);
+ auto* op = new TensorFlowShapeOperator;
+ op->inputs.push_back(node.input(0));
+ op->outputs.push_back(node.name());
+ model->operators.emplace_back(op);
+}
+
+void ConvertSplitOperator(const NodeDef& node, Model* model) {
+ CHECK_EQ(node.op(), "Split");
+ CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 2);
+ auto* op = new TensorFlowSplitOperator;
+ op->inputs.push_back(node.input(0));
+ op->inputs.push_back(node.input(1));
+ const int num_split = GetIntAttr(node, "num_split");
+ op->outputs.push_back(node.name());
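+  // TensorFlow names the extra outputs of a multi-output node "<name>:i";
+  // output 0 is just "<name>".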
+ for (int i = 1; i < num_split; i++) {
+ op->outputs.push_back(absl::StrCat(node.name(), ":", i));
+ }
+ op->num_split = num_split;
+ model->operators.emplace_back(op);
+}
+
+void ConvertMergeOperator(const NodeDef& node, Model* model) {
+ CHECK_EQ(node.op(), "Merge");
+ CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 2);
+ auto* op = new TensorFlowMergeOperator;
+ op->inputs.push_back(node.input(0));
+ op->inputs.push_back(node.input(1));
+ op->outputs.push_back(node.name());
+ model->operators.emplace_back(op);
+}
+
+void ConvertSwitchOperator(const NodeDef& node, Model* model) {
+ CHECK_EQ(node.op(), "Switch");
+ CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 2);
+ auto* op = new TensorFlowSwitchOperator;
+ op->inputs.push_back(node.input(0));
+ op->inputs.push_back(node.input(1));
+ op->outputs.push_back(node.name());
+ // Switch operators have two outputs: "name" and "name:1".
+ op->outputs.push_back(node.name() + ":1");
+ model->operators.emplace_back(op);
+}
+
+void ConvertSoftmaxOperator(const NodeDef& node, Model* model) {
+ CHECK_EQ(node.op(), "Softmax");
+ CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 1);
+ const auto& input_name = node.input(0);
+ auto* softmax = new SoftmaxOperator;
+ softmax->inputs.push_back(input_name);
+ softmax->outputs.push_back(node.name());
+ // TensorFlow's Softmax doesn't seem to admit a 'beta' parameter.
+ CHECK(!node.attr().count("beta")); // Stab in the dark, just in case.
+ softmax->beta = 1.f;
+ model->operators.emplace_back(softmax);
+}
+
+void ConvertLRNOperator(const NodeDef& node, Model* model) {
+ CHECK_EQ(node.op(), "LRN");
+ CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 1);
+ const auto& input_name = node.input(0);
+ auto* lrn = new LocalResponseNormalizationOperator;
+ lrn->inputs.push_back(input_name);
+ lrn->outputs.push_back(node.name());
+ lrn->range = GetIntAttr(node, "depth_radius");
+ lrn->bias = GetFloatAttr(node, "bias");
+ lrn->alpha = GetFloatAttr(node, "alpha");
+ lrn->beta = GetFloatAttr(node, "beta");
+ model->operators.emplace_back(lrn);
+}
+
+void ConvertMaxPoolOperator(const NodeDef& node, Model* model) {
+ CHECK_EQ(node.op(), "MaxPool");
+ CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 1);
+ const auto& input_name = node.input(0);
+ if (HasAttr(node, "T")) {
+ CHECK_EQ(GetDataTypeAttr(node, "T"), DT_FLOAT);
+ } else {
+ LOG(WARNING) << "Found MaxPool operator missing 'T' attribute";
+ }
+ auto* maxpool = new MaxPoolOperator;
+ maxpool->inputs.push_back(input_name);
+ maxpool->outputs.push_back(node.name());
+ const auto& strides = GetListAttr(node, "strides");
+ CHECK_EQ(strides.i_size(), 4);
+ CHECK_EQ(strides.i(0), 1);
+ CHECK_EQ(strides.i(3), 1);
+ maxpool->stride_height = strides.i(1);
+ maxpool->stride_width = strides.i(2);
+ const auto& ksize = GetListAttr(node, "ksize");
+ CHECK_EQ(ksize.i_size(), 4);
+ CHECK_EQ(ksize.i(0), 1);
+ CHECK_EQ(ksize.i(3), 1);
+ maxpool->kheight = ksize.i(1);
+ maxpool->kwidth = ksize.i(2);
+ const auto& padding = GetStringAttr(node, "padding");
+ if (padding == "SAME") {
+ maxpool->padding.type = PaddingType::kSame;
+ } else if (padding == "VALID") {
+ maxpool->padding.type = PaddingType::kValid;
+ } else {
+ LOG(FATAL) << "Bad padding (only SAME and VALID are supported)";
+ }
+ model->operators.emplace_back(maxpool);
+}
+
+void ConvertAvgPoolOperator(const NodeDef& node, Model* model) {
+ CHECK_EQ(node.op(), "AvgPool");
+ CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 1);
+ const auto& input_name = node.input(0);
+ CHECK_EQ(GetDataTypeAttr(node, "T"), DT_FLOAT);
+ auto* avgpool = new AveragePoolOperator;
+ avgpool->inputs.push_back(input_name);
+ avgpool->outputs.push_back(node.name());
+ const auto& strides = GetListAttr(node, "strides");
+ CHECK_EQ(strides.i_size(), 4);
+ CHECK_EQ(strides.i(0), 1);
+ CHECK_EQ(strides.i(3), 1);
+ avgpool->stride_height = strides.i(1);
+ avgpool->stride_width = strides.i(2);
+ const auto& ksize = GetListAttr(node, "ksize");
+ CHECK_EQ(ksize.i_size(), 4);
+ CHECK_EQ(ksize.i(0), 1);
+ CHECK_EQ(ksize.i(3), 1);
+ avgpool->kheight = ksize.i(1);
+ avgpool->kwidth = ksize.i(2);
+ const auto& padding = GetStringAttr(node, "padding");
+ if (padding == "SAME") {
+ avgpool->padding.type = PaddingType::kSame;
+ } else if (padding == "VALID") {
+ avgpool->padding.type = PaddingType::kValid;
+ } else {
+ LOG(FATAL) << "Bad padding (only SAME and VALID are supported)";
+ }
+ model->operators.emplace_back(avgpool);
+}
+
+void ConvertReshapeOperator(const NodeDef& node, Model* model) {
+ CHECK_EQ(node.op(), "Reshape");
+ CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 2);
+ auto* op = new TensorFlowReshapeOperator;
+ op->inputs.push_back(node.input(0));
+ op->inputs.push_back(node.input(1));
+ op->outputs.push_back(node.name());
+ model->operators.emplace_back(op);
+}
+
+void ConvertMatMulOperator(const NodeDef& node, Model* model) {
+ CHECK_EQ(node.op(), "MatMul");
+ CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 2);
+ // Transpose flags should be easy to support, but we don't have a
+ // GraphDef with them to test on at the moment.
+ CHECK_EQ(GetBoolAttr(node, "transpose_a"), false);
+ CHECK_EQ(GetBoolAttr(node, "transpose_b"), false);
+ const auto& input_name = node.input(0);
+ const auto& weights_name = node.input(1);
+ const auto& reordered_weights_name = weights_name + "_reordered";
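+  // TensorFlow MatMul weights use row-major (RC) storage order, while our
+  // standard is column-major (CR), see AxesOrder; a ReorderAxesOperator
+  // transposes them accordingly.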
+ // Check if a ReorderAxesOperator was already created for these weights
+ // (that happens when multiple layers share the same weights).
+ const Operator* existing_reorder =
+ GetOpWithOutput(*model, reordered_weights_name);
+ if (existing_reorder) {
+ // Check that it is safe to rely on the _reordered naming of the output
+ // array!
+ CHECK(existing_reorder->type == OperatorType::kReorderAxes);
+ } else {
+ // Create a new ReorderAxesOperator
+ auto* reorder = new ReorderAxesOperator;
+ reorder->inputs = {weights_name};
+ reorder->outputs = {reordered_weights_name};
+ reorder->input_axes_order = AxesOrder::kRC;
+ reorder->output_axes_order = AxesOrder::kCR;
+ model->operators.emplace_back(reorder);
+ }
+ auto* matmul = new TensorFlowMatMulOperator;
+ matmul->inputs = {input_name, reordered_weights_name};
+ matmul->outputs = {node.name()};
+ model->operators.emplace_back(matmul);
+}
+
+void ConvertConcatOperator(const NodeDef& node, Model* model) {
+ Operator* op = nullptr;
+ if (node.op() == "Concat") {
+ op = new TensorFlowConcatOperator;
+ } else if (node.op() == "ConcatV2") {
+ op = new TensorFlowConcatV2Operator;
+ } else {
+ LOG(FATAL) << "Expected Concat or ConcatV2";
+ }
+ const int num_inputs =
+ GetInputsCount(node, model->flags.drop_control_dependency());
+ CHECK_GE(num_inputs, 2);
+ CHECK_EQ(num_inputs, 1 + GetIntAttr(node, "N"));
+ for (int i = 0; i < num_inputs; ++i) {
+ op->inputs.push_back(node.input(i));
+ }
+ op->outputs.push_back(node.name());
+ model->operators.emplace_back(op);
+}
+
+void ConvertAllOperator(const NodeDef& node, Model* model) {
+ CHECK_EQ(node.op(), "All");
+ auto* op = new TensorFlowAllOperator;
+ const int num_inputs =
+ GetInputsCount(node, model->flags.drop_control_dependency());
+ for (int i = 0; i < num_inputs; ++i) {
+ op->inputs.push_back(node.input(i));
+ }
+ op->outputs.push_back(node.name());
+ model->operators.emplace_back(op);
+}
+
+void ConvertAssertOperator(const NodeDef& node, Model* model) {
+ CHECK_EQ(node.op(), "Assert");
+ auto* op = new TensorFlowAssertOperator;
+ const int num_inputs =
+ GetInputsCount(node, model->flags.drop_control_dependency());
+ for (int i = 0; i < num_inputs; ++i) {
+ op->inputs.push_back(node.input(i));
+ }
+ op->outputs.push_back(node.name());
+ model->operators.emplace_back(op);
+}
+
+void ConvertLessOperator(const NodeDef& node, Model* model) {
+ CHECK_EQ(node.op(), "Less");
+ auto* op = new TensorFlowLessOperator;
+ const int num_inputs =
+ GetInputsCount(node, model->flags.drop_control_dependency());
+ for (int i = 0; i < num_inputs; ++i) {
+ op->inputs.push_back(node.input(i));
+ }
+ op->outputs.push_back(node.name());
+ model->operators.emplace_back(op);
+}
+
+void ConvertLessEqualOperator(const NodeDef& node, Model* model) {
+ CHECK_EQ(node.op(), "LessEqual");
+ auto* op = new TensorFlowLessEqualOperator;
+ const int num_inputs =
+ GetInputsCount(node, model->flags.drop_control_dependency());
+ for (int i = 0; i < num_inputs; ++i) {
+ op->inputs.push_back(node.input(i));
+ }
+ op->outputs.push_back(node.name());
+ model->operators.emplace_back(op);
+}
+
+void ConvertGreaterOperator(const NodeDef& node, Model* model) {
+ CHECK_EQ(node.op(), "Greater");
+ auto* op = new TensorFlowGreaterOperator;
+ const int num_inputs =
+ GetInputsCount(node, model->flags.drop_control_dependency());
+ for (int i = 0; i < num_inputs; ++i) {
+ op->inputs.push_back(node.input(i));
+ }
+ op->outputs.push_back(node.name());
+ model->operators.emplace_back(op);
+}
+
+void ConvertGreaterEqualOperator(const NodeDef& node, Model* model) {
+ CHECK_EQ(node.op(), "GreaterEqual");
+ auto* op = new TensorFlowGreaterEqualOperator;
+ const int num_inputs =
+ GetInputsCount(node, model->flags.drop_control_dependency());
+ for (int i = 0; i < num_inputs; ++i) {
+ op->inputs.push_back(node.input(i));
+ }
+ op->outputs.push_back(node.name());
+ model->operators.emplace_back(op);
+}
+
+void ConvertMaxOperator(const NodeDef& node, Model* model) {
+ CHECK_EQ(node.op(), "Max");
+ CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 2);
+ auto* op = new TensorFlowMaxOperator;
+ op->inputs.push_back(node.input(0));
+ op->inputs.push_back(node.input(1));
+ op->outputs.push_back(node.name());
+ model->operators.emplace_back(op);
+}
+
+void ConvertMinOperator(const NodeDef& node, Model* model) {
+ CHECK_EQ(node.op(), "Min");
+ CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 2);
+ auto* op = new TensorFlowMinOperator;
+ op->inputs.push_back(node.input(0));
+ op->inputs.push_back(node.input(1));
+ op->outputs.push_back(node.name());
+ model->operators.emplace_back(op);
+}
+
+void ConvertMaximumOperator(const NodeDef& node, Model* model) {
+ CHECK_EQ(node.op(), "Maximum");
+ CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 2);
+ auto* op = new TensorFlowMaximumOperator;
+ op->inputs.push_back(node.input(0));
+ op->inputs.push_back(node.input(1));
+ op->outputs.push_back(node.name());
+ model->operators.emplace_back(op);
+}
+
+void ConvertMinimumOperator(const NodeDef& node, Model* model) {
+ CHECK_EQ(node.op(), "Minimum");
+ CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 2);
+ auto* op = new TensorFlowMinimumOperator;
+ op->inputs.push_back(node.input(0));
+ op->inputs.push_back(node.input(1));
+ op->outputs.push_back(node.name());
+ model->operators.emplace_back(op);
+}
+
+void ConvertUnsupportedOperator(const NodeDef& node, Model* model) {
+ LOG(INFO) << "Converting unsupported operation: " << node.op();
+ auto* op = new TensorFlowUnsupportedOperator;
+ const int num_inputs =
+ GetInputsCount(node, model->flags.drop_control_dependency());
+ for (int i = 0; i < num_inputs; ++i) {
+ op->inputs.push_back(node.input(i));
+ }
+ op->outputs.push_back(node.name());
+ op->tensorflow_op = node.op();
+ node.SerializeToString(&op->tensorflow_node_def);
+ model->operators.emplace_back(op);
+ if (HasAttr(node, "_output_quantized")) {
+ op->quantized = GetBoolAttr(node, "_output_quantized");
+ }
+ if (HasAttr(node, "_output_types")) {
+ const auto& output_types = GetListAttr(node, "_output_types");
+ for (int i = 0; i < output_types.type_size(); ++i) {
+ op->output_data_types.push_back(ConvertDataType(output_types.type(i)));
+ }
+ }
+}
+
+void ConvertStridedSliceOperator(const NodeDef& node, Model* model) {
+ CHECK_EQ(node.op(), "StridedSlice");
+ CHECK_EQ(node.input_size(), 4);
+
+ // Only a subset of the full TF op functionality is supported now.
+ if ( // No 64-bit indices.
+ GetDataTypeAttr(node, "Index") != DT_INT32 ||
+ // No dimensionality changes.
+ GetIntAttr(node, "new_axis_mask") != 0 ||
+ GetIntAttr(node, "shrink_axis_mask") != 0 ||
+ // No sparse indices.
+ GetIntAttr(node, "ellipsis_mask") != 0 ||
+      // Only 4D tensors are supported: begin_mask / end_mask are per-dimension
+      // bitmasks, so any value above 15 would refer to a fifth dimension.
+ GetIntAttr(node, "begin_mask") > 15 ||
+ GetIntAttr(node, "end_mask") > 15) {
+ ConvertUnsupportedOperator(node, model);
+ return;
+ }
+
+ auto* op = new StridedSliceOperator;
+ for (const auto& input : node.input()) {
+ op->inputs.push_back(input);
+ }
+ op->outputs.push_back(node.name());
+
+ op->begin_mask = GetIntAttr(node, "begin_mask");
+ op->ellipsis_mask = GetIntAttr(node, "ellipsis_mask");
+ op->end_mask = GetIntAttr(node, "end_mask");
+ op->new_axis_mask = GetIntAttr(node, "new_axis_mask");
+ op->shrink_axis_mask = GetIntAttr(node, "shrink_axis_mask");
+ model->operators.emplace_back(op);
+}
+
+void ConvertPlaceholderOperator(const NodeDef& node, Model* model) {
+ CHECK(node.op() == "Placeholder" || node.op() == "LegacyFedInput");
+ if (node.op() == "Placeholder") {
+ CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 0);
+ }
+ auto& array = model->GetOrCreateArray(node.name());
+ if (node.attr().count("dtype")) {
+ array.data_type = ConvertDataType(GetDataTypeAttr(node, "dtype"));
+ }
+ if (node.attr().count("shape")) {
+ const auto& shape = GetShapeAttr(node, "shape");
+ auto num_dims = shape.dim_size();
+ bool has_wildcard = false;
+    for (int i = 0; i < num_dims; i++) {
+ if (shape.dim(i).size() == -1) {
+ has_wildcard = true;
+ }
+ }
+    // TODO(b/62716978): This logic needs to be revisited. It is an interim
+    // fix during the dims refactoring.
+ if (num_dims > 0 && !has_wildcard) {
+ auto& dst_array_dims = *array.mutable_shape()->mutable_dims();
+ dst_array_dims.resize(num_dims);
+      for (int i = 0; i < num_dims; i++) {
+ dst_array_dims[i] = shape.dim(i).size();
+ }
+ }
+ }
+}
+
+void ConvertNoOpOperator(const NodeDef& node, Model* model) {}
+
+ArrayDataType GetArrayDataType(tensorflow::DataType tf_data_type) {
+ if (tf_data_type == DT_UINT8) {
+ return ArrayDataType::kUint8;
+ } else if (tf_data_type == DT_INT32) {
+ return ArrayDataType::kInt32;
+ } else if (tf_data_type == DT_FLOAT) {
+ return ArrayDataType::kFloat;
+ } else {
+ return ArrayDataType::kNone;
+ }
+}
+
+void ConvertCastOperator(const NodeDef& node, Model* model) {
+ CHECK_EQ(node.op(), "Cast");
+ CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 1);
+ const auto tf_src_dtype = GetDataTypeAttr(node, "SrcT");
+ const auto tf_dst_dtype = GetDataTypeAttr(node, "DstT");
+ CHECK(tf_src_dtype == DT_UINT8 || tf_src_dtype == DT_INT32 ||
+ tf_src_dtype == DT_FLOAT);
+ CHECK(tf_dst_dtype == DT_UINT8 || tf_dst_dtype == DT_INT32 ||
+ tf_dst_dtype == DT_FLOAT);
+ CHECK_NE(tf_src_dtype, tf_dst_dtype)
+ << "Same input and output data type. No need to cast.";
+ auto* op = new CastOperator;
+ op->src_data_type = GetArrayDataType(tf_src_dtype);
+ op->dst_data_type = GetArrayDataType(tf_dst_dtype);
+ op->inputs.push_back(node.input(0));
+ op->outputs.push_back(node.name());
+ model->operators.emplace_back(op);
+}
+
+void ConvertFloorOperator(const NodeDef& node, Model* model) {
+ CHECK_EQ(node.op(), "Floor");
+ CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 1);
+ const auto data_type = GetDataTypeAttr(node, "T");
+ CHECK(data_type == DT_FLOAT);
+ auto* op = new FloorOperator;
+ op->inputs.push_back(node.input(0));
+ op->outputs.push_back(node.name());
+ model->operators.emplace_back(op);
+}
+
+void ConvertGatherOperator(const NodeDef& node, Model* model) {
+ CHECK_EQ(node.op(), "Gather");
+ CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 2);
+ const auto indices_data_type = GetDataTypeAttr(node, "Tindices");
+ CHECK(indices_data_type == DT_INT32);
+ auto* op = new GatherOperator;
+ op->inputs.push_back(node.input(0));
+ op->inputs.push_back(node.input(1));
+ op->outputs.push_back(node.name());
+ model->operators.emplace_back(op);
+}
+
+void ConvertResizeBilinearOperator(const NodeDef& node, Model* model) {
+ CHECK_EQ(node.op(), "ResizeBilinear");
+ CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 2);
+ auto* op = new ResizeBilinearOperator;
+ op->inputs.push_back(node.input(0));
+ op->inputs.push_back(node.input(1));
+ op->outputs.push_back(node.name());
+ model->operators.emplace_back(op);
+}
+
+void ConvertBatchNormWithGlobalNormalizationOperator(const NodeDef& node,
+ Model* model) {
+ CHECK_EQ(node.op(), "BatchNormWithGlobalNormalization");
+ CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 5);
+
+ // TODO(ahentz): to really match tensorflow we need to add variance_epsilon
+ // to the input, before feeding it into TensorFlowRsqrtOperator.
+ // CHECK_EQ(GetFloatAttr(node, "variance_epsilon"), 0.001f);
+
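+  // The subgraph built below computes, element-wise (ignoring the
+  // variance_epsilon term for now, as per the TODO above):
+  //   output = (input - mean) * multiplier + beta
+  // where multiplier is rsqrt(variance), additionally multiplied by gamma
+  // when scale_after_normalization is set.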
+ string multiplier = node.name() + "_mul";
+ if (GetBoolAttr(node, "scale_after_normalization")) {
+ // Create graph:
+ // v -> RSQRT ->
+ // MUL -> multiplier
+ // gamma ----->
+ string rsqrt = node.name() + "_rsqrt";
+
+ auto* rsqrt_op = new TensorFlowRsqrtOperator;
+ rsqrt_op->inputs.push_back(node.input(2));
+ rsqrt_op->outputs.push_back(rsqrt);
+ model->operators.emplace_back(rsqrt_op);
+
+ auto* mul_op = new MulOperator;
+ mul_op->inputs.push_back(rsqrt);
+ mul_op->inputs.push_back(node.input(4));
+ mul_op->outputs.push_back(multiplier);
+ model->operators.emplace_back(mul_op);
+ } else {
+ // Create graph:
+ // v -> RSQRT -> multiplier
+ auto* rsqrt_op = new TensorFlowRsqrtOperator;
+ rsqrt_op->inputs.push_back(node.input(2));
+ rsqrt_op->outputs.push_back(multiplier);
+ model->operators.emplace_back(rsqrt_op);
+ }
+
+ auto* op = new BatchNormalizationOperator;
+ op->global_normalization = true;
+
+ op->inputs.push_back(node.input(0));
+ op->inputs.push_back(node.input(1));
+ op->inputs.push_back(multiplier);
+ op->inputs.push_back(node.input(3));
+ op->outputs.push_back(node.name());
+
+ model->operators.emplace_back(op);
+}
+
+void ConvertFusedBatchNormOperator(const NodeDef& node, Model* model) {
+ CHECK_EQ(node.op(), "FusedBatchNorm");
+ CHECK_EQ(node.input_size(), 5);
+
+ // Declare shortcuts for the inputs.
+ const string& gamma_input = node.input(1);
+ const string& beta_input = node.input(2);
+ const string& moving_mean_input = node.input(3);
+ const string& moving_variance_input = node.input(4);
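+
+  // For reference, the decomposition built below computes, element-wise:
+  //   output = (input - moving_mean)
+  //            * gamma * rsqrt(moving_variance + epsilon) + beta
+  // which matches FusedBatchNorm behavior in inference mode.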
+
+ // Create an array holding the epsilon value (typically, 0.001).
+ const string epsilon_array_name = node.name() + "_epsilon_array";
+ auto& epsilon_array = model->GetOrCreateArray(epsilon_array_name);
+ epsilon_array.data_type = ArrayDataType::kFloat;
+ *epsilon_array.mutable_shape()->mutable_dims() = {1};
+ epsilon_array.GetMutableBuffer<ArrayDataType::kFloat>().data.push_back(
+ GetFloatAttr(node, "epsilon"));
+
+ // Add epsilon to the moving variance.
+ const string epsilon_add_op_name = node.name() + "_epsilon";
+ auto* epsilon_add_op = new AddOperator;
+ epsilon_add_op->inputs.push_back(moving_variance_input);
+ epsilon_add_op->inputs.push_back(epsilon_array_name);
+ epsilon_add_op->outputs.push_back(epsilon_add_op_name);
+ model->operators.emplace_back(epsilon_add_op);
+
+ // Take the inverse square root of the (variance + epsilon).
+ const string rsqrt_op_name = node.name() + "_rsqrt";
+ auto* rsqrt_op = new TensorFlowRsqrtOperator;
+ rsqrt_op->inputs.push_back(epsilon_add_op_name);
+ rsqrt_op->outputs.push_back(rsqrt_op_name);
+ model->operators.emplace_back(rsqrt_op);
+
+ // Multiply the result by gamma.
+ const string multiplier = node.name() + "_mul";
+ auto* mul_op = new MulOperator;
+ mul_op->inputs.push_back(rsqrt_op_name);
+ mul_op->inputs.push_back(gamma_input);
+ mul_op->outputs.push_back(multiplier);
+ model->operators.emplace_back(mul_op);
+
+ // Now we have all required inputs for the BatchNormalizationOperator.
+ auto* op = new BatchNormalizationOperator;
+ op->global_normalization = true;
+
+ op->inputs.push_back(node.input(0));
+ op->inputs.push_back(moving_mean_input);
+ op->inputs.push_back(multiplier);
+ op->inputs.push_back(beta_input);
+ op->outputs.push_back(node.name());
+
+ model->operators.emplace_back(op);
+}
+
+void ConvertSpaceToBatchNDOperator(const NodeDef& node, Model* model) {
+ CHECK_EQ(node.op(), "SpaceToBatchND");
+ CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 3);
+ CHECK_EQ(GetDataTypeAttr(node, "Tblock_shape"), DT_INT32);
+ CHECK_EQ(GetDataTypeAttr(node, "Tpaddings"), DT_INT32);
+ auto* op = new SpaceToBatchNDOperator;
+ op->inputs.push_back(node.input(0));
+ op->inputs.push_back(node.input(1));
+ op->inputs.push_back(node.input(2));
+ op->outputs.push_back(node.name());
+ model->operators.emplace_back(op);
+}
+
+void ConvertBatchToSpaceNDOperator(const NodeDef& node, Model* model) {
+ CHECK_EQ(node.op(), "BatchToSpaceND");
+ CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 3);
+ CHECK_EQ(GetDataTypeAttr(node, "Tblock_shape"), DT_INT32);
+ CHECK_EQ(GetDataTypeAttr(node, "Tcrops"), DT_INT32);
+ auto* op = new BatchToSpaceNDOperator;
+ op->inputs.push_back(node.input(0));
+ op->inputs.push_back(node.input(1));
+ op->inputs.push_back(node.input(2));
+ op->outputs.push_back(node.name());
+ model->operators.emplace_back(op);
+}
+
+void ConvertMeanOperator(const NodeDef& node, Model* model) {
+ CHECK_EQ(node.op(), "Mean");
+ CHECK_EQ(node.input_size(), 2);
+ auto* op = new MeanOperator;
+ op->inputs.push_back(node.input(0));
+ op->inputs.push_back(node.input(1));
+ op->outputs.push_back(node.name());
+ model->operators.emplace_back(op);
+}
+
+void ConvertSvdfOperator(const NodeDef& node, Model* model) {
+ CHECK_EQ(node.op(), "Svdf");
+ bool has_bias = (node.input_size() == 4);
+ auto* op = new SvdfOperator;
+ op->inputs.push_back(node.input(0));
+ op->inputs.push_back(node.input(1));
+ op->inputs.push_back(node.input(2));
+ if (has_bias) {
+ op->inputs.push_back(node.input(3));
+ }
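+  // In addition to its main output, the SVDF op emits its recurrent state as
+  // a separate output array.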
+ op->outputs.push_back(node.name() + "_state");
+ op->outputs.push_back(node.name());
+ if (node.attr().at("ActivationFunction").s() == "Relu") {
+ op->fused_activation_function = FusedActivationFunctionType::kRelu;
+ } else {
+ op->fused_activation_function = FusedActivationFunctionType::kNone;
+ }
+ op->rank = node.attr().at("Rank").i();
+ model->operators.emplace_back(op);
+}
+
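+// In TensorFlow GraphDefs, a '^' prefix on an input name denotes a control
+// dependency. Strip it so that operator input/output names match plain array
+// names.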
+void StripCaretFromArrayNames(Model* model) {
+ for (auto& op : model->operators) {
+ for (auto& input : op->inputs) {
+ input = string(absl::StripPrefix(input, "^"));
+ }
+ for (auto& output : op->outputs) {
+ output = string(absl::StripPrefix(output, "^"));
+ }
+ }
+ for (auto& array : model->arrays) {
+ if (absl::StartsWith(array.first, "^")) {
+      LOG(FATAL) << "Unexpected '^' prefix in array name: " << array.first;
+ }
+ }
+}
+
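+// If some operator consumes an output such as "foo:2" but the producer of
+// "foo" lists fewer outputs, append the missing "foo:1", "foo:2", ... entries
+// so that every consumed array has a corresponding producer output.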
+void AddExtraOutputsFedIntoOtherOps(Model* model) {
+ for (const auto& consumer_op : model->operators) {
+ for (const string& input : consumer_op->inputs) {
+ const std::vector<string>& split = absl::StrSplit(input, ':');
+ if (split.size() != 2) {
+ continue;
+ }
+ int output_index = 0;
+ if (!absl::SimpleAtoi(split[1], &output_index)) {
+ continue;
+ }
+ auto* producer_op = GetOpWithOutput(*model, split[0]);
+ if (!producer_op) {
+ continue;
+ }
+      while (static_cast<int>(producer_op->outputs.size()) <= output_index) {
+ using toco::port::StringF;
+ producer_op->outputs.push_back(
+ StringF("%s:%d", split[0], producer_op->outputs.size()));
+ }
+ }
+ }
+}
+
+bool InlineAllFunctions(GraphDef* graphdef) {
+ if (graphdef->library().function().empty()) {
+ VLOG(kLogLevelModelUnchanged) << "No functions to inline.";
+ return false;
+ }
+
+ // Override "_noinline" attribute on all functions
+ GraphDef graphdef_copy(*graphdef);
+ for (auto& function :
+ (*graphdef_copy.mutable_library()->mutable_function())) {
+ auto* attributes = function.mutable_attr();
+ if (attributes->count(tensorflow::kNoInlineAttr) != 0) {
+ (*attributes)[tensorflow::kNoInlineAttr].set_b(false);
+ }
+ }
+
+ // Construct minimum resources needed to use ExpandInlineFunctions().
+ tensorflow::SessionOptions options;
+ auto* device_count = options.config.mutable_device_count();
+ device_count->insert({"CPU", 1});
+ std::vector<tensorflow::Device*> devices;
+ TF_CHECK_OK(tensorflow::DeviceFactory::AddDevices(
+ options, "/job:localhost/replica:0/task:0", &devices));
+
+ tensorflow::FunctionLibraryDefinition fld(tensorflow::OpRegistry::Global(),
+ graphdef_copy.library());
+ tensorflow::DeviceMgr device_mgr(devices);
+ tensorflow::OptimizerOptions o_opts;
+ tensorflow::ProcessFunctionLibraryRuntime pflr(
+ &device_mgr, tensorflow::Env::Default(), TF_GRAPH_DEF_VERSION, &fld,
+ o_opts, nullptr);
+ tensorflow::FunctionLibraryRuntime* flr;
+ flr = pflr.GetFLR("/job:localhost/replica:0/task:0/cpu:0");
+
+ tensorflow::Graph graph(fld);
+ tensorflow::GraphConstructorOptions gc_opts;
+ TF_CHECK_OK(
+ tensorflow::ConvertGraphDefToGraph(gc_opts, graphdef_copy, &graph));
+
+ // Iterate over the graph until there are no more nodes to be inlined.
+ bool graph_modified = false;
+ while (tensorflow::ExpandInlineFunctions(flr, &graph)) {
+ graph_modified = true;
+ LOG(INFO) << "Found functions that were inlined.";
+ }
+
+ // Output inlined graph
+ if (graph_modified) {
+ graph.ToGraphDef(graphdef);
+ }
+ return graph_modified;
+}
+} // namespace
+
+std::unique_ptr<Model> ImportTensorFlowGraphDef(const ModelFlags& model_flags,
+ const GraphDef& tf_graph) {
+ LogDumpGraphDef(kLogLevelModelChanged, "AT IMPORT", tf_graph);
+
+ GraphDef inlined_graph(tf_graph);
+ if (InlineAllFunctions(&inlined_graph)) {
+ LogDumpGraphDef(kLogLevelModelChanged, "AFTER INLINING", inlined_graph);
+ }
+
+ Model* model = new Model;
+ ResolveModelFlags(model_flags, model);
+
+ for (const auto& node : inlined_graph.node()) {
+ if (node.op() == "Const") {
+ ConvertConstOperator(node, model);
+ } else if (node.op() == "Conv2D") {
+ ConvertConvOperator(node, model);
+ } else if (node.op() == "DepthwiseConv2dNative") {
+ ConvertDepthwiseConvOperator(node, model);
+ } else if (node.op() == "DepthToSpace") {
+ ConvertDepthToSpaceOperator(node, model);
+ } else if (node.op() == "SpaceToDepth") {
+ ConvertSpaceToDepthOperator(node, model);
+ } else if (node.op() == "BiasAdd") {
+ ConvertBiasAddOperator(node, model);
+ } else if (node.op() == "Relu") {
+ ConvertReluOperator(node, model);
+ } else if (node.op() == "Relu6") {
+ ConvertRelu6Operator(node, model);
+ } else if (node.op() == "Sigmoid") {
+ ConvertLogisticOperator(node, model);
+ } else if (node.op() == "Tanh") {
+ ConvertTanhOperator(node, model);
+ } else if (node.op() == "MaxPool") {
+ ConvertMaxPoolOperator(node, model);
+ } else if (node.op() == "AvgPool") {
+ ConvertAvgPoolOperator(node, model);
+ } else if (node.op() == "Reshape") {
+ ConvertReshapeOperator(node, model);
+ } else if (node.op() == "MatMul") {
+ ConvertMatMulOperator(node, model);
+ } else if (node.op() == "Div" || node.op() == "RealDiv") {
+ ConvertDivOperator(node, model);
+ } else if (node.op() == "Identity" || node.op() == "CheckNumerics") {
+ ConvertIdentityOperator(node, model);
+ } else if (node.op() == "FakeQuantWithMinMaxVars") {
+ ConvertFakeQuantWithMinMaxVars(node, model);
+ } else if (node.op() == "FakeQuantWithMinMaxArgs") {
+ ConvertFakeQuantWithMinMaxArgs(node, model);
+ } else if (node.op() == "Rsqrt") {
+ ConvertRsqrtOperator(node, model);
+ } else if (node.op() == "Squeeze") {
+ ConvertSqueezeOperator(node, model);
+ } else if (node.op() == "Sqrt") {
+ ConvertSqrtOperator(node, model);
+ } else if (node.op() == "Square") {
+ ConvertSquareOperator(node, model);
+ } else if (node.op() == "Add") {
+ ConvertAddOperator(node, model);
+ } else if (node.op() == "Mul") {
+ ConvertMulOperator(node, model);
+ } else if (node.op() == "Sub") {
+ ConvertSubOperator(node, model);
+ } else if (node.op() == "Sum") {
+ ConvertSumOperator(node, model);
+ } else if (node.op() == "Tile") {
+ ConvertTileOperator(node, model);
+ } else if (node.op() == "Concat" || node.op() == "ConcatV2") {
+ ConvertConcatOperator(node, model);
+ } else if (node.op() == "LRN") {
+ ConvertLRNOperator(node, model);
+ } else if (node.op() == "Softmax") {
+ ConvertSoftmaxOperator(node, model);
+ } else if (node.op() == "All") {
+ ConvertAllOperator(node, model);
+ } else if (node.op() == "Assert") {
+ ConvertAssertOperator(node, model);
+ } else if (node.op() == "Less") {
+ ConvertLessOperator(node, model);
+ } else if (node.op() == "LessEqual") {
+ ConvertLessEqualOperator(node, model);
+ } else if (node.op() == "Greater") {
+ ConvertGreaterOperator(node, model);
+ } else if (node.op() == "GreaterEqual") {
+ ConvertGreaterEqualOperator(node, model);
+ } else if (node.op() == "Max") {
+ ConvertMaxOperator(node, model);
+ } else if (node.op() == "Min") {
+ ConvertMinOperator(node, model);
+ } else if (node.op() == "Maximum") {
+ ConvertMaximumOperator(node, model);
+ } else if (node.op() == "Minimum") {
+ ConvertMinimumOperator(node, model);
+ } else if (node.op() == "Merge") {
+ ConvertMergeOperator(node, model);
+ } else if (node.op() == "Pad") {
+ ConvertPadOperator(node, model);
+ } else if (node.op() == "StridedSlice") {
+ ConvertStridedSliceOperator(node, model);
+ } else if (node.op() == "Shape") {
+ ConvertShapeOperator(node, model);
+ } else if (node.op() == "Slice") {
+ ConvertSliceOperator(node, model);
+ } else if (node.op() == "Split") {
+ ConvertSplitOperator(node, model);
+ } else if (node.op() == "Switch") {
+ ConvertSwitchOperator(node, model);
+ } else if (node.op() == "Placeholder") {
+ ConvertPlaceholderOperator(node, model);
+ } else if (node.op() == "PlaceholderWithDefault") {
+ ConvertIdentityOperator(node, model);
+ } else if (node.op() == "LegacyFedInput") {
+ ConvertPlaceholderOperator(node, model);
+ } else if (node.op() == "NoOp") {
+ ConvertNoOpOperator(node, model);
+ } else if (node.op() == "Cast") {
+ ConvertCastOperator(node, model);
+ } else if (node.op() == "Floor") {
+ ConvertFloorOperator(node, model);
+ } else if (node.op() == "Gather") {
+ ConvertGatherOperator(node, model);
+ } else if (node.op() == "ResizeBilinear") {
+ ConvertResizeBilinearOperator(node, model);
+ } else if (node.op() == "BatchNormWithGlobalNormalization") {
+ ConvertBatchNormWithGlobalNormalizationOperator(node, model);
+ } else if (node.op() == "FusedBatchNorm") {
+ ConvertFusedBatchNormOperator(node, model);
+ } else if (node.op() == "SpaceToBatchND") {
+ ConvertSpaceToBatchNDOperator(node, model);
+ } else if (node.op() == "BatchToSpaceND") {
+ ConvertBatchToSpaceNDOperator(node, model);
+ } else if (node.op() == "Mean") {
+ ConvertMeanOperator(node, model);
+ } else if (node.op() == "Svdf") {
+ ConvertSvdfOperator(node, model);
+ } else {
+ ConvertUnsupportedOperator(node, model);
+ }
+ }
+
+ StripCaretFromArrayNames(model);
+ AddExtraOutputsFedIntoOtherOps(model);
+ FixNoMissingArray(model);
+ FixNoOrphanedArray(model);
+ FixOperatorOrdering(model);
+ CheckInvariants(*model);
+
+  // If RNN state arrays are constant, make them transient.
+ for (const auto& rnn_state : model->flags.rnn_states()) {
+ model->GetArray(rnn_state.state_array()).buffer = nullptr;
+ }
+
+ return std::unique_ptr<Model>(model);
+}
+
+std::unique_ptr<Model> ImportTensorFlowGraphDef(
+ const ModelFlags& model_flags, const string& input_file_contents) {
+ std::unique_ptr<GraphDef> tf_graph(new GraphDef);
+ CHECK(ParseFromStringEitherTextOrBinary(input_file_contents, tf_graph.get()));
+
+ std::unique_ptr<GraphDef> pruned_graph =
+ MaybeReplaceCompositeSubgraph(*tf_graph);
+ if (pruned_graph) {
+ tf_graph = std::move(pruned_graph);
+ }
+ return ImportTensorFlowGraphDef(model_flags, *tf_graph);
+}
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.h b/tensorflow/contrib/lite/toco/import_tensorflow.h
new file mode 100644
index 0000000000..d2eb423ca4
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/import_tensorflow.h
@@ -0,0 +1,34 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_IMPORT_TENSORFLOW_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_IMPORT_TENSORFLOW_H_
+
+#include <memory>
+#include <string>
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/model_flags.pb.h"
+#include "tensorflow/core/framework/graph.pb.h"
+
+namespace toco {
+
+std::unique_ptr<Model> ImportTensorFlowGraphDef(
+ const ModelFlags& model_flags, const tensorflow::GraphDef& graph_def);
+
+std::unique_ptr<Model> ImportTensorFlowGraphDef(
+ const ModelFlags& model_flags, const string& input_file_contents);
+
+} // namespace toco
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_IMPORT_TENSORFLOW_H_
diff --git a/tensorflow/contrib/lite/toco/model.h b/tensorflow/contrib/lite/toco/model.h
new file mode 100644
index 0000000000..d992f8458f
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/model.h
@@ -0,0 +1,1372 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_MODEL_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_MODEL_H_
+
+#include <initializer_list>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/model_flags.pb.h"
+#include "tensorflow/contrib/lite/toco/runtime/types.h"
+#include "tensorflow/contrib/lite/toco/toco_types.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+enum class OperatorType {
+ kNone,
+ // General-purpose neural network operators.
+ kAdd,
+ kAveragePool,
+ kBatchNormalization,
+ kConv,
+ kConcatenation,
+ kDepthwiseConv,
+ kDepthToSpace,
+ kSpaceToDepth,
+ kDequantize,
+ kDiv,
+ kFullyConnected,
+ kL2Normalization,
+ kL2Pool,
+ kLstmCell,
+ kLocalResponseNormalization,
+ kLogistic,
+ kMaxPool,
+ kFakeQuant,
+ kMul,
+ kRelu,
+ kRelu1,
+ kRelu6,
+ kSoftmax,
+ kSub,
+ kTanh,
+ kCast,
+ kFloor,
+ kGather,
+ kResizeBilinear,
+ kSpaceToBatchND,
+ kBatchToSpaceND,
+ kPad,
+ kStridedSlice,
+ kSlice,
+ kSqueeze,
+ kMean,
+ // The SVDF Op is a decomposition of a densely connected Op into
+ // low rank filters. For details:
+ // https://research.google.com/pubs/pub43813.html
+ kSvdf,
+ // Special operators used for importing TensorFlow nodes.
+ // The general intent is to have some graph transformation either
+ // drop them or rewrite them as general-purpose operators.
+ kTensorFlowAll,
+ kTensorFlowAssert,
+ kTensorFlowConcat,
+ kTensorFlowConcatV2,
+ kTensorFlowGreater,
+ kTensorFlowGreaterEqual,
+ kTensorFlowIdentity,
+ kTensorFlowLess,
+ kTensorFlowLessEqual,
+ kTensorFlowMax,
+ kTensorFlowMaximum,
+ kTensorFlowMin,
+ kTensorFlowMinimum,
+ kTensorFlowMatMul,
+ kTensorFlowMerge,
+ kTensorFlowReshape,
+ kTensorFlowRsqrt,
+ kTensorFlowShape,
+ kTensorFlowSplit,
+ kTensorFlowSqrt,
+ kTensorFlowSquare,
+ kTensorFlowSum,
+ kTensorFlowSwitch,
+ kTensorFlowTile,
+  // An unsupported TF operation. It is only needed to be able to represent
+  // the TF graph internally and is expected to be dropped by graph
+  // transformations.
+ kTensorFlowUnsupported,
+ // Finally, TensorFlow uses different conventions for axes ordering,
+ // see AxesOrder, and this cannot always be resolved at the time of importing
+ // nodes, as TensorFlow parameters may be constant-expression subgraphs
+ // instead of being given as plain constant arrays. So we need to insert
+ // special nodes in the graph to shuffle axes.
+ kReorderAxes,
+};
+
+// Helper to deal with TensorFlow arrays using a different ordering of
+// dimensions ("axes") than our own.
+// TODO(benoitjacob): Ultimately, we shouldn't have any "ordering" of axes,
+// we should have associative arrays mapping symbolic axes identifiers (like
+// "output_depth") to dimensions. We would then not need this anymore.
+enum class AxesOrder {
+ kOneAxis, // one-dimensional array, one unique axis.
+ kCR, // column-major matrix storage order. Our standard.
+ kRC, // row-major matrix storage order. TensorFlow default.
+ kOHWI, // Our standard for conv weights
+ kHWIO, // TensorFlow conv weights
+ k1HWO, // Our standard for DepthwiseConv weights
+ kHWIM, // TensorFlow DepthwiseConv weights
+ kNHWC, // TensorFlow activations
+};
+
+// The type of the scalars in an array.
+// Note that this does not by itself tell whether the values in the array are
+// real (are literally interpreted as real numbers) or quantized (only acquire
+// a meaning as real numbers in conjunction with QuantizationParams).
+//
+// In practice though:
+// float values are always real
+// uint8 values are always quantized
+// int32 values are either real or quantized (depending on whether
+// QuantizationParams are present).
+// other types are unused at the moment.
+//
+// kNone means that we don't know the data type yet, or that we don't care
+// because we'll be dropping the array anyway (e.g. some exotic array types
+// may be involved only in debug-only subgraphs that we may not be interested
+// in actually supporting).
+enum class ArrayDataType { kNone, kBool, kFloat, kUint8, kInt32, kInt64 };
+
+// Compile-time logic to map ArrayDataType to the corresponding C++ scalar type
+template <ArrayDataType A>
+struct DataTypeImpl {};
+template <>
+struct DataTypeImpl<ArrayDataType::kNone> {
+ typedef int Type;
+};
+template <>
+struct DataTypeImpl<ArrayDataType::kBool> {
+ typedef bool Type;
+};
+template <>
+struct DataTypeImpl<ArrayDataType::kFloat> {
+ typedef float Type;
+};
+template <>
+struct DataTypeImpl<ArrayDataType::kUint8> {
+ typedef uint8 Type;
+};
+template <>
+struct DataTypeImpl<ArrayDataType::kInt32> {
+ typedef int32 Type;
+};
+template <>
+struct DataTypeImpl<ArrayDataType::kInt64> {
+ typedef int64 Type;
+};
+
+template <ArrayDataType A>
+using DataType = typename DataTypeImpl<A>::Type;
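+// For example, DataType<ArrayDataType::kFloat> is float and
+// DataType<ArrayDataType::kInt32> is int32.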
+
+// Base class for type-specific buffer types.
+struct GenericBuffer {
+ // Non-default-constructible: only ArrayDataType-specific subclass
+ // objects may be constructed.
+ GenericBuffer() = delete;
+  // Non-copyable-or-movable: we should only store pointers-to-Buffer
+  // in containers, not Buffer objects themselves, so there should be no
+  // copy or move.
+ GenericBuffer(const GenericBuffer&) = delete;
+ GenericBuffer(const GenericBuffer&&) = delete;
+
+ // We need a virtual destructor so we can store pointers-to-Buffer
+ // in containers and have the containers call the right subclass destructor.
+ virtual ~GenericBuffer() {}
+
+ const ArrayDataType type;
+
+ protected:
+ // Constructor used by subclasses for specific ArrayDataType's.
+ explicit GenericBuffer(ArrayDataType t) : type(t) {}
+};
+
+// Type-specific buffer, containing type-specific storage.
+template <ArrayDataType A>
+struct Buffer : GenericBuffer {
+ Buffer() : GenericBuffer(A) {}
+
+ std::vector<DataType<A>> data;
+};
+
+// Base class for all operator classes.
+struct Operator {
+ // Non-default-constructible: only OperatorType-specific subclass
+ // objects may be constructed.
+ Operator() = delete;
+ // Non-copyable-or-movable: we should only store pointers-to-Operator
+ // in containers, not Operators themselves, so there should be no
+ // copy or move.
+ Operator(const Operator&) = delete;
+ Operator(const Operator&&) = delete;
+
+ // We need a virtual destructor so we can store pointers-to-Operator
+ // in containers and have the containers call the right subclass destructor.
+ virtual ~Operator() {}
+
+ // The specific type of operator. Corresponds 1:1 to subclasses.
+ const OperatorType type;
+
+ // The activation function that may be fused into this operator,
+ // or None if no activation function is fused.
+ FusedActivationFunctionType fused_activation_function;
+
+ // Input arrays: either activation arrays or constant array parameters.
+ // We refer to them by their name, not by their address; the mapping of
+ // names to addresses is given by the Model, which owns both Operator's and
+ // Array's. Thus, an Operator on its own doesn't contain much information,
+ // it is meant to be used in conjunction with the Model that owns it.
+ std::vector<string> inputs;
+
+ // Output activation arrays. Same comments as for inputs apply here too.
+ std::vector<string> outputs;
+
+  // If true, the operator has more outputs than are listed in the 'outputs'
+ // member. These need to be resolved by some graph transformation.
+ // This flag is only here to indicate that an operator should not be
+ // discarded as unused, even if from its 'outputs' member alone it
+ // looks unused.
+ bool unresolved_outputs = false;
+
+ protected:
+ // Constructor used by subclasses for specific OperatorType's.
+ explicit Operator(OperatorType t)
+ : type(t),
+ fused_activation_function(FusedActivationFunctionType::kNone) {}
+};
+
+// Padding types for Conv-like operators. This is how padding is typically
+// specified in model files. But for inference, we will need to resolve this
+// to a FixedPadding, see below.
+enum class PaddingType { kNone, kSame, kValid };
+
+// Padding as resolved for a specific layer shape, as needed for inference.
+// For a given layer shape, a given padding type will resolve to a choice of
+// a number of padding rows and columns, which we call the padding height and
+// width respectively.
+struct FixedPadding {
+ int width = 0;
+ int height = 0;
+};
+
+// "Universal" padding struct containing both a generic PaddingType (as
+// represented in a model file), and a FixedPadding (as needed for inference).
+// The latter is resolved during the PropagateFixedSizes pass.
+struct Padding {
+ FixedPadding& GetOrCreateFixedPadding() {
+ if (!fixed) {
+ FixedPadding* ptr = new FixedPadding;
+ fixed = std::unique_ptr<FixedPadding>(ptr);
+ }
+ return *fixed;
+ }
+
+ Padding() : type(PaddingType::kNone) {}
+ PaddingType type;
+ std::unique_ptr<FixedPadding> fixed;
+};
+
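As a hedged usage sketch of the Padding struct (illustration only, not part of the diff): a graph transformation that knows the layer shape could resolve a kSame padding roughly as below. The arithmetic is the conventional "same" padding formula and is stated here only as an assumption about what such a pass would compute.

    #include <algorithm>  // std::max

    // Sketch: resolve kSame padding for a layer once its shape is known.
    void ResolveSamePadding(int in_width, int in_height, int k_width,
                            int k_height, int stride_width, int stride_height,
                            toco::Padding* padding) {
      auto pad_along = [](int in, int k, int stride) {
        const int out = (in + stride - 1) / stride;
        return std::max(0, (out - 1) * stride + k - in) / 2;
      };
      toco::FixedPadding& fixed = padding->GetOrCreateFixedPadding();
      fixed.width = pad_along(in_width, k_width, stride_width);
      fixed.height = pad_along(in_height, k_height, stride_height);
    }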
+// "Convolutional" layer, as represented in model files.
+//
+// Inputs:
+// inputs[0]: required: the input activations array
+// inputs[1]: required: the Conv weights
+// inputs[2]: optional: the bias vector, specifying the biases for each output
+// channel.
+//
+// Outputs:
+// outputs[0]: required: the output activations array
+// outputs[1]: optional: the intermediate array of im2col-replicated input
+// activations. Present when targeting implementations
+// of Conv layers as Im2col+GEMM.
+//
+// TensorFlow equivalent: Conv2D
+struct ConvOperator : Operator {
+ ConvOperator() : Operator(OperatorType::kConv) {}
+ Padding padding;
+ int stride_width = 0;
+ int stride_height = 0;
+};
+
+// Depthwise-separable convolution operator.
+//
+// Inputs:
+// inputs[0]: required: the input activations array
+// inputs[1]: required: the DepthwiseConv weights
+// inputs[2]: optional: the bias vector, specifying the biases for each output
+// channel.
+//
+// TensorFlow equivalent: DepthwiseConv2dNative
+struct DepthwiseConvOperator : Operator {
+ DepthwiseConvOperator() : Operator(OperatorType::kDepthwiseConv) {}
+ Padding padding;
+ int stride_height = 0;
+ int stride_width = 0;
+ int depth_multiplier = 0;
+};
+
+// Depth-to-space transform operator.
+//
+// Inputs:
+// inputs[0]: required: the input activations array
+//
+// TensorFlow equivalent: DepthToSpace
+struct DepthToSpaceOperator : Operator {
+ DepthToSpaceOperator() : Operator(OperatorType::kDepthToSpace) {}
+ int block_size = 0;
+};
+
+// Space-to-depth transform operator.
+//
+// Inputs:
+// inputs[0]: required: the input activations array
+//
+// TensorFlow equivalent: SpaceToDepth
+struct SpaceToDepthOperator : Operator {
+ SpaceToDepthOperator() : Operator(OperatorType::kSpaceToDepth) {}
+ int block_size = 0;
+};
+
+// Fully-connected operator.
+//
+// Inputs:
+// inputs[0]: required: the input activations array
+// inputs[1]: required: the FullyConnected weights
+// inputs[2]: optional: the bias vector, specifying the biases for each output
+// channel.
+//
+// TensorFlow equivalent: a pair consisting of a Reshape node reshaping the
+// input activations as a matrix, followed by a MatMul node.
+struct FullyConnectedOperator : Operator {
+ FullyConnectedOperator() : Operator(OperatorType::kFullyConnected) {}
+};
+
+// Dequantization operator, converting a quantized array of integers with
+// quantization parameters specifying how these integers correspond to real
+// numbers (see QuantizationParams) to an output activations array of
+// floating-point values.
+//
+// In floating-point image models, there is typically a Dequantization operator
+// at the very beginning, converting the input image RGB data, consisting of
+// uint8 integer values, to floating-point input activations. That is where
+// image model parameters such as "mean_value" and "std_value" are typically
+// handled.
+//
+// This is the only operator type that converts from quantized to
+// floating-point, and there is at the moment no operator type at all to
+// convert from floating-point to quantized. Every other operator does either
+// float->float or quantized->quantized.
+//
+// Inputs:
+// inputs[0]: required: the input quantized activations array
+//
+// TensorFlow equivalent: Dequantize
+struct DequantizeOperator : Operator {
+ DequantizeOperator() : Operator(OperatorType::kDequantize) {}
+};
+
+// Batch-normalization operator.
+//
+// We only support batch-normalization using pre-learned moments, so this is
+// just computing (input - mean) * multiplier + offset. As such, this can be
+// expressed as a combination of Add and Mul nodes, and indeed this is how
+// we break it down during tooling for the purpose of fusing it into
+// other operators.
+//
+// Inputs:
+// inputs[0]: required: the input activations array
+// inputs[1]: required: the learned mean array
+// inputs[2]: required: the learned multiplier array
+// inputs[3]: required: the learned offset array
+//
+// TensorFlow equivalent: a combination of Add and Mul nodes
+struct BatchNormalizationOperator : Operator {
+ BatchNormalizationOperator()
+ : Operator(OperatorType::kBatchNormalization),
+ global_normalization(false) {}
+ bool global_normalization;
+};
+
+// L2-normalization operator.
+//
+// Inputs:
+// inputs[0]: required: the input activations array
+//
+// TensorFlow equivalent: none. In TensorFlow, L2 normalization is implemented
+// by a sub-graph of operators implementing L2-normalization from lower-level
+// arithmetic nodes; during tooling, we identify such sub-graphs and replace
+// them by L2NormalizationOperator's. See IdentifyL2Normalization.
+struct L2NormalizationOperator : Operator {
+ L2NormalizationOperator() : Operator(OperatorType::kL2Normalization) {}
+};
+
+// LSTM Cell operator.
+//
+// Inputs:
+// inputs[0]: required: the input data array
+// inputs[1]: required: the previous output activations array
+// inputs[2]: required: the learned weights array
+// inputs[3]: required: the learned biases array
+// inputs[4]: required: the previous output state
+// outputs[0]: required: the output activations array
+// outputs[1]: required: the new state array
+//
+// TensorFlow equivalent: none. In TensorFlow, an LSTM is implemented
+// with a sub-graph of lower-level arithmetic nodes; during tooling, we identify
+// such sub-graphs and replace them with LstmCells. See IdentifyLstmCell().
+struct LstmCellOperator : Operator {
+ enum Inputs {
+ DATA_INPUT = 0,
+ PREV_ACTIV_INPUT = 1,
+ WEIGHTS_INPUT = 2,
+ BIASES_INPUT = 3,
+ PREV_STATE_INPUT = 4,
+ NUM_INPUTS = 5
+ };
+ enum Outputs {
+ ACTIV_OUTPUT = 0,
+ STATE_OUTPUT = 1,
+ CONCAT_TEMP = 2,
+ ACTIV_TEMP = 3,
+ NUM_OUTPUTS = 4
+ };
+ LstmCellOperator() : Operator(OperatorType::kLstmCell) {}
+};
+
+// Element-wise multiplication operator.
+//
+// Inputs:
+// inputs[0]: required: the left-hand side array
+// inputs[1]: required: the right-hand side array
+//
+// TensorFlow equivalent: Mul
+struct MulOperator : Operator {
+ MulOperator() : Operator(OperatorType::kMul) {}
+};
+
+// Element-wise Relu operator:
+// x -> max(0, x)
+//
+// Inputs:
+// inputs[0]: required: the input array
+//
+// TensorFlow equivalent: Relu
+struct ReluOperator : Operator {
+ ReluOperator() : Operator(OperatorType::kRelu) {}
+};
+
+// Element-wise Relu1 operator:
+// x -> min(max(x, -1), 1)
+//
+// Inputs:
+// inputs[0]: required: the input array
+//
+// TensorFlow equivalent: none. We can construct the operator with Minimum
+// and Maximum operations.
+struct Relu1Operator : Operator {
+ Relu1Operator() : Operator(OperatorType::kRelu1) {}
+};
+
+// Element-wise Relu6 operator:
+// x -> max(0, min(6, x))
+//
+// Inputs:
+// inputs[0]: required: the input array
+//
+// TensorFlow equivalent: Relu6
+struct Relu6Operator : Operator {
+ Relu6Operator() : Operator(OperatorType::kRelu6) {}
+};
+
+// Element-wise Logistic operator:
+// x -> Logistic(x) = 1 / (1 + exp(-x))
+//
+// Inputs:
+// inputs[0]: required: the input array
+//
+// TensorFlow equivalent: Sigmoid
+struct LogisticOperator : Operator {
+ LogisticOperator() : Operator(OperatorType::kLogistic) {}
+};
+
+// Element-wise Tanh operator:
+// x -> Tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))
+//
+// Inputs:
+// inputs[0]: required: the input array
+//
+// TensorFlow equivalent: Tanh
+struct TanhOperator : Operator {
+ TanhOperator() : Operator(OperatorType::kTanh) {}
+};
+
+// Element-wise addition operator.
+//
+// Inputs:
+// inputs[0]: required: the left-hand side array
+// inputs[1]: required: the right-hand side array
+//
+// TensorFlow equivalent: Add
+struct AddOperator : Operator {
+ AddOperator() : Operator(OperatorType::kAdd) {}
+};
+
+// Concatenation operator: concatenates its inputs
+// along the concat_dim dimension.
+//
+// Inputs: this operator accepts any number >= 1 of inputs.
+// inputs[i]: the i-th array to concatenate.
+//
+// TensorFlow equivalent: Concat.
+struct ConcatenationOperator : Operator {
+ ConcatenationOperator() : Operator(OperatorType::kConcatenation) {}
+ int concat_dim = 0;
+};
+
+// Reordering dimensions. Used only during tooling to transform graphs from
+// the TensorFlow format.
+//
+// Inputs:
+// inputs[0]: required: the input array
+//
+// TensorFlow equivalent: none. This is only useful to convert between formats.
+struct ReorderAxesOperator : Operator {
+ ReorderAxesOperator() : Operator(OperatorType::kReorderAxes) {}
+ AxesOrder input_axes_order;
+ AxesOrder output_axes_order;
+};
+
+// Average-pooling operator.
+//
+// Inputs:
+// inputs[0]: required: the input array
+//
+// TensorFlow equivalent: AveragePool
+struct AveragePoolOperator : Operator {
+ AveragePoolOperator() : Operator(OperatorType::kAveragePool) {}
+ Padding padding;
+ int stride_height = 0;
+ int stride_width = 0;
+ int kheight = 0;
+ int kwidth = 0;
+};
+
+// Local response normalization operator.
+//
+// Inputs:
+// inputs[0]: required: the input array
+//
+// TensorFlow equivalent: LRN
+struct LocalResponseNormalizationOperator : Operator {
+ LocalResponseNormalizationOperator()
+ : Operator(OperatorType::kLocalResponseNormalization) {}
+
+ int range = 0;
+ float bias = 0.f;
+ float alpha = 0.f;
+ float beta = 0.f;
+};
+
+// Max-pooling operator.
+//
+// Inputs:
+// inputs[0]: required: the input array
+//
+// TensorFlow equivalent: MaxPool
+struct MaxPoolOperator : Operator {
+ MaxPoolOperator() : Operator(OperatorType::kMaxPool) {}
+ Padding padding;
+ int stride_height = 0;
+ int stride_width = 0;
+ int kheight = 0;
+ int kwidth = 0;
+};
+
+// L2-pooling operator.
+//
+// Inputs:
+// inputs[0]: required: the input array
+//
+// TensorFlow equivalent: none. Can be shimmed by squaring+avgpool+sqrt.
+struct L2PoolOperator : Operator {
+ L2PoolOperator() : Operator(OperatorType::kL2Pool) {}
+ Padding padding;
+ int stride_height = 0;
+ int stride_width = 0;
+ int kheight = 0;
+ int kwidth = 0;
+};
+
+// The expected [min, max] range of values in a given array.
+// Used for quantization only.
+// This information typically comes from special nodes found in quantized
+// models (see FakeQuantOperator), and is used during quantization to resolve
+// actual quantization parameters (see QuantizationParams).
+struct MinMax {
+ double min = 0.;
+ double max = 0.;
+};
+
+inline bool operator==(const MinMax& m1, const MinMax& m2) {
+ return m1.min == m2.min && m1.max == m2.max;
+}
+
+// Fake-quantization operator. This does two things:
+// - Annotate its input and output arrays with MinMax information,
+// - Arithmetic-wise, this operator rounds incoming activation values
+// to the nearest representable value on the scale of 256
+// values from the min to the max value dictated by its MinMax info.
+//
+// Inputs:
+// inputs[0]: required: the input array
+// inputs[1]: optional: the 'min' value, if it has not yet been resolved
+// to a constant.
+// inputs[2]: optional: the 'max' value, if it has not yet been resolved
+// to a constant.
+//
+// TensorFlow equivalent: FakeQuantWithMinMaxVars, FakeQuantWithMinMaxArgs.
+struct FakeQuantOperator : Operator {
+ FakeQuantOperator() : Operator(OperatorType::kFakeQuant) {}
+ std::unique_ptr<MinMax> minmax;
+};
+
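To make the "round to 256 representable values between min and max" behaviour described above concrete, here is a hedged sketch of the arithmetic. This is the textbook fake-quantization formula; the actual kernels may differ in detail (e.g. in how they nudge the zero point), so treat it only as an illustration.

    #include <algorithm>
    #include <cmath>

    // Sketch: round x to the nearest of 256 evenly spaced values in [min, max].
    // Assumes minmax.max > minmax.min.
    float FakeQuantize(float x, const toco::MinMax& minmax) {
      const double scale = (minmax.max - minmax.min) / 255.0;
      const double clamped =
          std::min(minmax.max, std::max(minmax.min, static_cast<double>(x)));
      const double rounded = std::round((clamped - minmax.min) / scale);
      return static_cast<float>(minmax.min + rounded * scale);
    }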
+// Element-wise division operator.
+//
+// Inputs:
+// inputs[0]: required: the left-hand side array
+// inputs[1]: required: the right-hand side array
+//
+// TensorFlow equivalent: Div
+struct DivOperator : Operator {
+ DivOperator() : Operator(OperatorType::kDiv) {}
+};
+
+// Element-wise identity (x->x) operator.
+//
+// Inputs:
+// inputs[0]: required: the input array
+//
+// TensorFlow equivalent: Identity
+struct TensorFlowIdentityOperator : Operator {
+ TensorFlowIdentityOperator() : Operator(OperatorType::kTensorFlowIdentity) {}
+};
+
+// General matrix multiplication operator. We don't want to support general
+// matrix multiplication at inference time, so we resolve it during tooling
+// to more specific operator types, namely, FullyConnected.
+//
+// Inputs:
+// inputs[0]: required: the left-hand side matrix
+// inputs[1]: required: the right-hand side matrix
+//
+// TensorFlow equivalent: MatMul
+struct TensorFlowMatMulOperator : Operator {
+ TensorFlowMatMulOperator() : Operator(OperatorType::kTensorFlowMatMul) {}
+};
+
+// Padding operator. Pads a tensor with zeros.
+//
+// Inputs:
+// inputs[0]: required: the input array
+// inputs[1]: required: the padding array
+//
+// This operation pads an `input` with zeros according to the `paddings` you
+// specify. `paddings` is an integer tensor with shape `[Dn, 2]`, where n is the
+// rank of `input`. For each dimension D of `input`, `paddings[D, 0]` indicates
+// how many zeros to add before the contents of `input` in that dimension, and
+// `paddings[D, 1]` indicates how many zeros to add after the contents of
+// `input` in that dimension.
+//
+// TensorFlow equivalent: Pad
+struct PadOperator : Operator {
+ PadOperator() : Operator(OperatorType::kPad) {}
+
+ std::vector<int> left_padding;
+ std::vector<int> right_padding;
+};
+
+// Strided slice operator.
+//
+// Inputs:
+// inputs[0]: required: the input array
+//
+// TensorFlow equivalent: StridedSlice
+struct StridedSliceOperator : Operator {
+ StridedSliceOperator() : Operator(OperatorType::kStridedSlice) {}
+
+ std::vector<int> start_indices;
+ std::vector<int> stop_indices;
+ std::vector<int> strides;
+
+ int begin_mask = 0;
+ int ellipsis_mask = 0;
+ int end_mask = 0;
+ int new_axis_mask = 0;
+ int shrink_axis_mask = 0;
+};
+
+// Reshaping operator, reshaping its input array to a two-dimensional shape
+// (a "matrix"). This is used in the TensorFlow format, in conjunction with
+// MatMul nodes, to implement fully-connected layers.
+//
+// Inputs:
+// inputs[0]: required: the input array
+//
+// TensorFlow equivalent: Reshape --- except that we only support a special case
+// here, where the output shape is a matrix (2D) shape.
+struct TensorFlowReshapeOperator : Operator {
+ TensorFlowReshapeOperator() : Operator(OperatorType::kTensorFlowReshape) {}
+ std::vector<int> shape;
+};
+
+// Removes dimensions of size 1 from the shape of a tensor.
+// https://www.tensorflow.org/api_docs/python/tf/squeeze
+//
+// Inputs:
+// inputs[0]: required: the input array
+//
+// TensorFlow equivalent: Squeeze
+struct SqueezeOperator : Operator {
+ SqueezeOperator() : Operator(OperatorType::kSqueeze) {}
+
+ std::vector<int> squeeze_dims;
+};
+
+// Element-wise reciprocal-square-root (x^-0.5) operator.
+//
+// Inputs:
+// inputs[0]: required: the input array
+//
+// TensorFlow equivalent: Rsqrt
+struct TensorFlowRsqrtOperator : Operator {
+ TensorFlowRsqrtOperator() : Operator(OperatorType::kTensorFlowRsqrt) {}
+};
+
+// Shape operator. Extracts the shape of the tensor.
+//
+// Inputs:
+// inputs[0]: required: the input array
+//
+// This operation outputs a 1-D integer tensor representing the shape of
+// the input.
+//
+// TensorFlow equivalent: Shape. We currently assume that the output is int32
+// and not int64. The output type could be stored herein.
+struct TensorFlowShapeOperator : Operator {
+ TensorFlowShapeOperator() : Operator(OperatorType::kTensorFlowShape) {}
+};
+
+// Element-wise square-root (x^0.5) operator.
+//
+// Inputs:
+// inputs[0]: required: the input array
+//
+// TensorFlow equivalent: Sqrt
+struct TensorFlowSqrtOperator : Operator {
+ TensorFlowSqrtOperator() : Operator(OperatorType::kTensorFlowSqrt) {}
+};
+
+// Element-wise square (x*x) operator.
+//
+// Inputs:
+// inputs[0]: required: the input array
+//
+// TensorFlow equivalent: Square
+struct TensorFlowSquareOperator : Operator {
+ TensorFlowSquareOperator() : Operator(OperatorType::kTensorFlowSquare) {}
+};
+
+// Element-wise subtraction operator.
+//
+// Inputs:
+// inputs[0]: required: the left-hand side array
+// inputs[1]: required: the right-hand side array
+//
+// TensorFlow equivalent: Sub
+struct SubOperator : Operator {
+ SubOperator() : Operator(OperatorType::kSub) {}
+};
+
+// Global sum reduction: computes the sum of all entries in the input array.
+// Thus the output is "0-dimensional": it consists of a single scalar value.
+//
+// Inputs:
+// inputs[0]: required: the input array
+//
+// TensorFlow equivalent: Sum --- except that we only support the special case
+// of global reduction across all dimensions.
+struct TensorFlowSumOperator : Operator {
+ TensorFlowSumOperator() : Operator(OperatorType::kTensorFlowSum) {}
+};
+
+// TensorFlow Tile equivalent. Refer to TensorFlow documentation for details.
+// Not fully supported, just a placeholder to handle TensorFlow graphs and
+// support graph transformations to other operator types by matching sub-graphs.
+struct TensorFlowTileOperator : Operator {
+ TensorFlowTileOperator() : Operator(OperatorType::kTensorFlowTile) {}
+};
+
+// TensorFlow Slice equivalent. Refer to TensorFlow documentation for details.
+struct SliceOperator : Operator {
+ SliceOperator() : Operator(OperatorType::kSlice) {}
+
+ std::vector<int> begin;
+ std::vector<int> size;
+};
+
+// TensorFlow Split equivalent. Refer to TensorFlow documentation for details.
+// Not fully supported, just a placeholder to handle TensorFlow graphs and
+// support graph transformations to other operator types by matching sub-graphs.
+struct TensorFlowSplitOperator : Operator {
+ TensorFlowSplitOperator() : Operator(OperatorType::kTensorFlowSplit) {}
+ int num_split = 0;
+};
+
+// TensorFlow Concat equivalent. Refer to TensorFlow documentation for details.
+// Not fully supported, just a placeholder to handle TensorFlow graphs and
+// support graph transformations to other operator types by matching sub-graphs.
+// Concretely, once the concat dim becomes known, if it is the depth
+// dimension then we can change this op into a DepthConcatenation op.
+// Otherwise, we hope for some other graph transformation to drop this node.
+struct TensorFlowConcatOperator : Operator {
+ TensorFlowConcatOperator() : Operator(OperatorType::kTensorFlowConcat) {}
+};
+
+// TensorFlow ConcatV2 equivalent. Refer to TensorFlow documentation for
+// details.
+// Not fully supported, just a placeholder to handle TensorFlow graphs and
+// support graph transformations to other operator types by matching sub-graphs.
+// Concretely, once the concat dim becomes known, if it is the depth
+// dimension then we can change this op into a DepthConcatenation op.
+// Otherwise, we hope for some other graph transformation to drop this node.
+struct TensorFlowConcatV2Operator : Operator {
+ TensorFlowConcatV2Operator() : Operator(OperatorType::kTensorFlowConcatV2) {}
+};
+
+// TensorFlow Merge equivalent. Refer to TensorFlow documentation for details.
+//
+// Inputs: this operator accepts any number >= 1 of inputs.
+// inputs[i]: the i-th array to merge.
+//
+// It is expected that graph transformations will drop all but exactly one
+// of the inputs, at which point the Merge node will be equivalent to an
+// Identity node forwarding the remaining input.
+//
+// Note: We do not currently support runtime control flow: we only support
+// control flow that can be resolved at tooling time (independently of input
+// activations).
+struct TensorFlowMergeOperator : Operator {
+ TensorFlowMergeOperator() : Operator(OperatorType::kTensorFlowMerge) {}
+};
+
+// TensorFlow Switch equivalent. Refer to TensorFlow documentation for details.
+//
+// Inputs:
+// inputs[0]: required: the input array
+// inputs[1]: required: the boolean predicate, given as an array of size 1
+// and of type kBool, will determine which output gets selected.
+//
+// Outputs: a TensorFlow Switch node always has exactly two outputs. Depending
+// on the boolean value that the input predicate resolves to (see note below),
+// one or the other of the outputs will be 'selected': the input array will be
+// forwarded to the 'selected output' as if by an Identity node, while the other
+// output will be discarded, and any graph edge connecting that discarded output
+// will be dropped. The rule for selecting outputs is as follows:
+// outputs[0] will be selected if the input predicate resolves to 'true'.
+// outputs[1] will be selected if the input predicate resolves to 'false'.
+//
+// Note: We do not currently support runtime control flow: we only support
+// control flow that can be resolved at tooling time (independently of input
+// activations).
+struct TensorFlowSwitchOperator : Operator {
+ TensorFlowSwitchOperator() : Operator(OperatorType::kTensorFlowSwitch) {}
+};
+
+// TensorFlow All equivalent. Refer to TensorFlow documentation for details.
+// Not fully supported, just a placeholder to handle TensorFlow graphs and
+// support graph transformations to other operator types by matching sub-graphs.
+// Typically, this is only used as an input to an Assert node, so can be
+// removed as an unused node as we drop Assert nodes.
+struct TensorFlowAllOperator : Operator {
+ TensorFlowAllOperator() : Operator(OperatorType::kTensorFlowAll) {}
+};
+
+// TensorFlow Assert equivalent. Refer to TensorFlow documentation for details.
+// Not fully supported, just a placeholder to handle TensorFlow graphs and
+// support graph transformations to other operator types by matching sub-graphs.
+// Typically, we just drop Assert nodes.
+struct TensorFlowAssertOperator : Operator {
+ TensorFlowAssertOperator() : Operator(OperatorType::kTensorFlowAssert) {}
+};
+
+// TensorFlow Less equivalent. Refer to TensorFlow documentation for details.
+// Not fully supported, just a placeholder to handle TensorFlow graphs and
+// support graph transformations to other operator types by matching sub-graphs.
+// Typically, this is only used as an input to an Assert node, so can be
+// removed as an unused node as we drop Assert nodes.
+struct TensorFlowLessOperator : Operator {
+ TensorFlowLessOperator() : Operator(OperatorType::kTensorFlowLess) {}
+};
+
+// TensorFlow LessEqual equivalent. Refer to TensorFlow documentation for
+// details.
+// Not fully supported, just a placeholder to handle TensorFlow graphs and
+// support graph transformations to other operator types by matching sub-graphs.
+// Typically, this is only used as an input to an Assert node, so can be
+// removed as an unused node as we drop Assert nodes.
+struct TensorFlowLessEqualOperator : Operator {
+ TensorFlowLessEqualOperator()
+ : Operator(OperatorType::kTensorFlowLessEqual) {}
+};
+
+// TensorFlow Greater equivalent. Refer to TensorFlow documentation for details.
+// Not fully supported, just a placeholder to handle TensorFlow graphs and
+// support graph transformations to other operator types by matching sub-graphs.
+// Typically, this is only used as an input to an Assert node, so can be
+// removed as an unused node as we drop Assert nodes.
+struct TensorFlowGreaterOperator : Operator {
+ TensorFlowGreaterOperator() : Operator(OperatorType::kTensorFlowGreater) {}
+};
+
+// TensorFlow GreaterEqual equivalent. Refer to TensorFlow documentation for
+// details.
+// Not fully supported, just a placeholder to handle TensorFlow graphs and
+// support graph transformations to other operator types by matching sub-graphs.
+// Typically, this is only used as an input to an Assert node, so can be
+// removed as an unused node as we drop Assert nodes.
+struct TensorFlowGreaterEqualOperator : Operator {
+ TensorFlowGreaterEqualOperator()
+ : Operator(OperatorType::kTensorFlowGreaterEqual) {}
+};
+
+// Global max reduction: computes the max of all entries in the input array.
+// Thus the output is "0-dimensional": it consists of a single scalar value.
+//
+// Inputs:
+// inputs[0]: required: the input array
+//
+// TensorFlow equivalent: Max --- except that we only support the special case
+// of global reduction across all dimensions.
+struct TensorFlowMaxOperator : Operator {
+ TensorFlowMaxOperator() : Operator(OperatorType::kTensorFlowMax) {}
+};
+
+// Global min reduction: computes the min of all entries in the input array.
+// Thus the output is "0-dimensional": it consists of a single scalar value.
+//
+// Inputs:
+// inputs[0]: required: the input array
+//
+// TensorFlow equivalent: Min --- except that we only support the special case
+// of global reduction across all dimensions.
+struct TensorFlowMinOperator : Operator {
+ TensorFlowMinOperator() : Operator(OperatorType::kTensorFlowMin) {}
+};
+
+// Element-wise maximum operator. Currently it only supports scalar as
+// the second operand.
+//
+// Inputs:
+// inputs[0]: required: the left-hand side array
+// inputs[1]: required: the right-hand side array
+//
+// TensorFlow equivalent: Maximum
+struct TensorFlowMaximumOperator : Operator {
+ TensorFlowMaximumOperator() : Operator(OperatorType::kTensorFlowMaximum) {}
+};
+
+// Element-wise minimum operator. Currently it only supports scalar as
+// the second operand.
+//
+// Inputs:
+// inputs[0]: required: the left-hand side array
+// inputs[1]: required: the right-hand side array
+//
+// TensorFlow equivalent: Minimum
+struct TensorFlowMinimumOperator : Operator {
+ TensorFlowMinimumOperator() : Operator(OperatorType::kTensorFlowMinimum) {}
+};
+
+// General TF operation, unsupported by tf.mini. Expected to be dropped by
+// graph transformations.
+struct TensorFlowUnsupportedOperator : Operator {
+ TensorFlowUnsupportedOperator()
+ : Operator(OperatorType::kTensorFlowUnsupported) {}
+
+ // The original TF operation type. Used for diagnostic purposes.
+ string tensorflow_op;
+ // A serialized tensorflow::NodeDef string.
+ string tensorflow_node_def;
+ // A boolean indicating if the unsupported op should be treated as quantized.
+ bool quantized = false;
+ // Output data types
+ std::vector<ArrayDataType> output_data_types;
+};
+
+// Softmax activation function.
+//
+// Inputs:
+// inputs[0]: required: the input array
+//
+// TensorFlow equivalent: Softmax
+struct SoftmaxOperator : Operator {
+ SoftmaxOperator() : Operator(OperatorType::kSoftmax) {}
+ float beta = 0.f;
+};
+
+// Cast operator.
+//
+// Inputs:
+// inputs[0]: required: the input array
+//
+// TensorFlow equivalent: Cast
+struct CastOperator : Operator {
+ CastOperator() : Operator(OperatorType::kCast) {}
+ ArrayDataType src_data_type = ArrayDataType::kNone;
+ ArrayDataType dst_data_type = ArrayDataType::kNone;
+};
+
+// Floor operator.
+//
+// Inputs:
+// inputs[0]: required: the input array
+//
+// TensorFlow equivalent: Floor
+struct FloorOperator : Operator {
+ FloorOperator() : Operator(OperatorType::kFloor) {}
+};
+
+// Gather operator. It gathers slices from params according to indices.
+// Only 1-D indices are supported at the moment.
+//
+// Inputs:
+// inputs[0]: required: the params array
+// inputs[1]: required: the indices to gather
+//
+// TensorFlow equivalent: Gather
+struct GatherOperator : Operator {
+ GatherOperator() : Operator(OperatorType::kGather) {}
+ int input_rank = 0;
+};
+
+// ResizeBilinear operator. It resizes input images with bilinear interpolation.
+// It does not support align_corners at the moment.
+//
+// Inputs:
+// inputs[0]: required: the input array
+// inputs[1]: required: the new image size
+//
+// TensorFlow equivalent: ResizeBilinear
+struct ResizeBilinearOperator : Operator {
+ ResizeBilinearOperator() : Operator(OperatorType::kResizeBilinear) {}
+};
+
+// SpaceToBatchND operator. It divides spatial dimensions into a grid of
+// blocks and interleaves these blocks with the batch dimension. Currently,
+// only 2-d blocks are supported.
+//
+// Inputs:
+// inputs[0]: required: the input array
+// inputs[1]: required: the block shape
+// inputs[2]: required: the paddings
+//
+// TensorFlow equivalent: SpaceToBatchND
+struct SpaceToBatchNDOperator : Operator {
+ SpaceToBatchNDOperator() : Operator(OperatorType::kSpaceToBatchND) {}
+};
+
+// BatchToSpaceND operator. Rearranges data from batch into blocks of
+// spatial data. Currently, only 2-d blocks are supported. Cropping is not
+// supported, either, and the crops array should be all zero.
+//
+// Inputs:
+// inputs[0]: required: the input array
+// inputs[1]: required: the block shape
+// inputs[2]: required: the crops
+//
+// TensorFlow equivalent: BatchToSpaceND
+struct BatchToSpaceNDOperator : Operator {
+ BatchToSpaceNDOperator() : Operator(OperatorType::kBatchToSpaceND) {}
+};
+
+// Mean operator.
+//
+// Inputs:
+// inputs[0]: required: the input array
+//
+// TensorFlow equivalent: Mean
+struct MeanOperator : Operator {
+ MeanOperator() : Operator(OperatorType::kMean) {}
+
+ std::vector<int> reduction_indices;
+};
+
+// Svdf operator:
+//
+// Inputs:
+// inputs[0]: required: the input array
+// inputs[1]: required: weights_feature
+// inputs[2]: required: weights_time
+// inputs[3]: optional: bias
+struct SvdfOperator : Operator {
+ SvdfOperator() : Operator(OperatorType::kSvdf) {}
+ int rank = 0;
+};
+
+// Alloc's are used for transient arrays only. An Alloc specifies which interval
+// of the "transient_data" workspace buffer passed to inference functions, is to
+// be used for the transient array at hand. The 'start' and 'end' values are
+// offsets from the start of the workspace buffer, expressed in bytes.
+struct Alloc {
+ int start = 0;
+ int end = 0;
+};
+
+inline bool operator<(const Alloc& a, const Alloc& b) {
+ return a.start < b.start;
+}
+
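A small hedged sketch of how the ordering defined by operator< above might be used; the real transient-array allocation logic is not shown in this excerpt, so this is only an illustrative consistency check.

    #include <algorithm>
    #include <vector>

    // Sketch: check that a set of transient allocations is non-overlapping.
    bool AllocationsAreDisjoint(std::vector<toco::Alloc> allocs) {
      std::sort(allocs.begin(), allocs.end());  // sorts by 'start' via operator<
      for (size_t i = 1; i < allocs.size(); ++i) {
        if (allocs[i].start < allocs[i - 1].end) return false;
      }
      return true;
    }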
+// Quantization parameters, determining the mapping of quantized values
+// to real values (i.e. determining how quantized values are mathematically
+// interpreted).
+//
+// The correspondence is as follows:
+//
+// real_value = scale * (quantized_value - zero_point);
+//
+// In other words, zero_point designates which quantized value corresponds to
+// the real 0 value, and scale designates the difference between the real values
+// corresponding to consecutive quantized values differing by 1.
+struct QuantizationParams {
+ int32 zero_point = 0;
+ double scale = 0.;
+};
+
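The correspondence stated above can be summarised by a tiny helper pair (illustration only; toco's own quantization code lives elsewhere, and clamping to the quantized range is deliberately omitted here):

    #include <cmath>

    // Sketch: the mapping real = scale * (quantized - zero_point), both ways.
    double Dequantize(int quantized_value, const toco::QuantizationParams& qp) {
      return qp.scale * (quantized_value - qp.zero_point);
    }

    int Quantize(double real_value, const toco::QuantizationParams& qp) {
      // Rounded to nearest; clamping to the uint8 range is omitted in this sketch.
      return qp.zero_point + static_cast<int>(std::round(real_value / qp.scale));
    }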
+class Shape {
+ public:
+ // For Shape, we stick to half-way encapsulation for now:
+ // we hide the raw dims_ member, but expose it raw by accessors
+ // because from some brainstorming, it's not at all easy to
+ // anticipate which flavor of more hermetic encapsulation would
+ // actually buy us future-proof-ness without being needlessly
+ // cumbersome.
+ Shape() {}
+ Shape(std::initializer_list<int> dim_list) : dims_(dim_list) {}
+
+ void ReplaceDims(std::initializer_list<int> dim_list) {
+ dims_ = std::vector<int>(dim_list);
+ }
+
+ const std::vector<int>& dims() const { return dims_; }
+ std::vector<int>* mutable_dims() { return &dims_; }
+ int dimensions_count() const { return dims_.size(); }
+
+ // We still have that one convenience accessor to avoid
+ // the awkward double bracket issue: shape.dims()[i].
+ int dims(int i) const { return dims_[i]; }
+
+ bool operator==(const Shape& comp) const {
+ return (this->dims_ == comp.dims());
+ }
+
+ bool operator!=(const Shape& comp) const { return !((*this) == comp); }
+
+ private:
+ std::vector<int> dims_;
+};
+
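A short usage sketch for the Shape class (illustration only), computing the element count implied by its dims:

    // Sketch: number of elements implied by a Shape, as the product of its dims.
    int ElementCount(const toco::Shape& shape) {
      int count = 1;
      for (int i = 0; i < shape.dimensions_count(); ++i) {
        count *= shape.dims(i);
      }
      return count;
    }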
+// Array represents an array (either a constant parameter array or an
+// activations array) in a Model.
+struct Array {
+ template <ArrayDataType A>
+ const Buffer<A>& GetBuffer() const {
+ DCHECK(buffer);
+ DCHECK(buffer->type == A);
+ return *static_cast<const Buffer<A>*>(buffer.get());
+ }
+ template <ArrayDataType A>
+ Buffer<A>& GetMutableBuffer() {
+ if (!buffer) {
+ Buffer<A>* ptr = new Buffer<A>;
+ buffer = std::unique_ptr<GenericBuffer>(ptr);
+ }
+ DCHECK(buffer);
+ DCHECK(buffer->type == A);
+ return *static_cast<Buffer<A>*>(buffer.get());
+ }
+ Alloc& GetOrCreateAlloc() {
+ if (!alloc) {
+ alloc = std::unique_ptr<Alloc>(new Alloc);
+ }
+ return *alloc;
+ }
+ MinMax& GetOrCreateMinMax() {
+ if (!minmax) {
+ minmax = std::unique_ptr<MinMax>(new MinMax);
+ }
+ return *minmax;
+ }
+ MinMax& GetMinMax() const {
+ DCHECK(minmax);
+ return *minmax;
+ }
+ QuantizationParams& GetOrCreateQuantizationParams() {
+ if (!quantization_params) {
+ quantization_params =
+ std::unique_ptr<QuantizationParams>(new QuantizationParams);
+ }
+ return *quantization_params;
+ }
+ QuantizationParams& GetQuantizationParams() const {
+ DCHECK(quantization_params);
+ return *quantization_params;
+ }
+
+ // The data type of the actual elements of this array, that is:
+ // - If there is a buffer (see 'buffer' member), it must be of the same
+ // type.
+ // - If there is no buffer, meaning that this is a runtime (i.e. activations)
+ // array, then this specifies the type of elements that there will be
+ // at runtime.
+ //
+ // Note that this only specifies the storage type of elements; this does
+ // not specify whether these are to be treated as 'real' or 'quantized'
+ // values.
+ // That is decided by whether the 'quantization_params' member is null.
+ ArrayDataType data_type = ArrayDataType::kNone;
+ // The final value that data_type should have at the end of graph
+ // transformations
+ ArrayDataType final_data_type = ArrayDataType::kNone;
+ // The dimensions of this array --- this specifies both sizes and strides
+ // (the storage layout).
+ //
+ // Issues with shape handling that remain include:
+ // - No way to distinguish between 0-dimensional dims and missing dims.
+ // - No way to describe dims that may be runtime-variable.
+ // - Addressing of dims by integer index differs in different graph formats
+ // (TensorFlow vs. other frameworks vs. what we have informally grown
+ // within toco).
+ // This is currently quite messy; see ReorderAxesOperator which is how we
+ // bridge some of these discrepancies at the moment. This is overdue for
+ // a redesign; I'm thinking that it would be nice to have more flexible
+ // dims that allow mapping 1:1, cleanly, dims as they are in various formats,
+ // then explicitly convert between different conventions.
+
+ // Proto-style accessors
+ bool has_shape() const { return array_shape != nullptr; }
+ const Shape& shape() const {
+ CHECK(has_shape());
+ return *array_shape;
+ }
+ Shape* mutable_shape() {
+ if (!array_shape) {
+ array_shape.reset(new Shape);
+ }
+ return array_shape.get();
+ }
+ void copy_shape(const Shape& src_shape) { *mutable_shape() = src_shape; }
+ void clear_shape() { array_shape = nullptr; }
+
+ // The constant buffer backing this array. This is non-null if and only if
+ // this is a constant parameter array. Conversely, this is null for
+ // activations arrays.
+ //
+ // Note that this buffer is pure storage. In the case of quantized values,
+ // it only stores the quantized values; it does not know by itself about the
+ // quantization parameters necessary to interpret these values, which are
+ // in the separate 'quantization_params' field. In fact, this 'buffer' field
+ // does not even know whether values are quantized. It only has a data_type,
+ // which must equal the 'data_type' member here, and which only describes
+ // the storage type of elements; it does not tell whether they are quantized,
+ // i.e. whether they are to be interpreted with quantization_params.
+ std::unique_ptr<GenericBuffer> buffer;
+ // Only for activation arrays (i.e. when 'buffer' is null).
+ // Only for code generation.
+ //
+ // Describes the allocation of this array within the workspace buffer
+ // allocated for all transient arrays.
+ std::unique_ptr<Alloc> alloc;
+ // Describes the [min, max] range of values
+ // to be assumed when determining quantization_params.
+ //
+ // Only used for quantization. In fact, only used for determining
+ // quantization_params.
+ //
+ // Used for both constant arrays (those having a 'buffer') and non-constant
+ // arrays (activations). Indeed, it is important to use the same min-max range
+ // as was used during training, even if that min-max range is slightly wrong
+ // w.r.t. actual buffer elements. Doing otherwise would defeat the point of
+ // re-training for quantization.
+ std::unique_ptr<MinMax> minmax;
+ // Quantization parameters. The non-null-ness of this pointer is what
+ // defines whether this array is quantized or not.
+ //
+ // If this is non-null, then these quantization parameters are to be used
+ // to assign a meaning as real numbers to the elements of this array.
+ std::unique_ptr<QuantizationParams> quantization_params;
+
+ private:
+ std::unique_ptr<Shape> array_shape;
+};
+
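As an illustrative sketch (not part of this change) of how a constant parameter array is typically assembled through the accessors above:

    // Sketch: assemble a constant 1x3 float parameter array.
    void FillConstantArray(toco::Array* array) {
      array->data_type = toco::ArrayDataType::kFloat;
      array->copy_shape(toco::Shape({1, 3}));
      auto& buffer = array->GetMutableBuffer<toco::ArrayDataType::kFloat>();
      buffer.data = {0.5f, 1.0f, 1.5f};
      // Optionally annotate the value range used later for quantization.
      toco::MinMax& minmax = array->GetOrCreateMinMax();
      minmax.min = 0.0;
      minmax.max = 1.5;
    }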
+// Our Model struct, represents an entire model (our "top-level" struct).
+// Owns everything.
+struct Model {
+ Array& GetArray(const string& name) const {
+ DCHECK(arrays.count(name));
+ return *arrays.at(name);
+ }
+ Array& GetOrCreateArray(const string& name) {
+ if (!arrays.count(name)) {
+ Array* ptr = new Array;
+ arrays[name] = std::unique_ptr<Array>(ptr);
+ }
+ Array& result = GetArray(name);
+ return result;
+ }
+
+ // The list of operators. Notice how it's a list of unique_ptr's, implying
+ // that the Model is what owns Operator's and keeps them alive.
+ std::vector<std::unique_ptr<Operator>> operators;
+ // The associative array mapping names to Array's.
+ // Notice how it's a container of unique_ptr's, implying
+ // that the Model is what owns Array's and keeps them alive.
+ // The Operator's refer to these Array's by their name strings, not by their
+ // addresses. See Operator::inputs, Operator::outputs.
+ std::unordered_map<string, std::unique_ptr<Array>> arrays;
+ // Generic flags, a place where we combine information passed to us via
+ // command-line parameters (e.g. --input_width=N) with information that
+ // we may or may not find in the input model file.
+ ModelFlags flags;
+ // For code-generation only: required size of the transient_data buffer
+ std::size_t transient_data_size = 0;
+ // For code-generation only: required alignment of the transient_data buffer
+ std::size_t transient_data_alignment = 0;
+};
+} // namespace toco
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_MODEL_H_
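Pulling the pieces of model.h together, here is a hedged sketch (illustration only) of how a caller might assemble a minimal Model containing a single Add operator; it relies only on the classes declared above and on arrays being referenced by name.

    #include <memory>
    #include "tensorflow/contrib/lite/toco/model.h"

    // Sketch: a Model with one Add operator reading "a" and "b", writing "sum".
    std::unique_ptr<toco::Model> MakeTinyModel() {
      std::unique_ptr<toco::Model> model(new toco::Model);
      model->GetOrCreateArray("a");
      model->GetOrCreateArray("b");
      model->GetOrCreateArray("sum");

      auto* add = new toco::AddOperator;
      add->inputs = {"a", "b"};   // Operators refer to Arrays by name.
      add->outputs = {"sum"};
      model->operators.emplace_back(add);
      return model;
    }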
diff --git a/tensorflow/contrib/lite/toco/model_cmdline_flags.cc b/tensorflow/contrib/lite/toco/model_cmdline_flags.cc
new file mode 100644
index 0000000000..699c95753f
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/model_cmdline_flags.cc
@@ -0,0 +1,374 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/toco/model_cmdline_flags.h"
+
+#include <string>
+#include <vector>
+
+#include "absl/strings/ascii.h"
+#include "absl/strings/numbers.h"
+#include "absl/strings/str_join.h"
+#include "absl/strings/str_split.h"
+#include "absl/strings/string_view.h"
+#include "absl/strings/strip.h"
+#include "tensorflow/contrib/lite/toco/args.h"
+#include "tensorflow/contrib/lite/toco/toco_graphviz_dump_options.h"
+#include "tensorflow/contrib/lite/toco/toco_port.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/util/command_line_flags.h"
+// "batch" flag only exists internally
+#ifdef PLATFORM_GOOGLE
+#include "base/commandlineflags.h"
+#endif
+
+namespace toco {
+
+bool ParseModelFlagsFromCommandLineFlags(
+ int* argc, char* argv[], string* msg,
+ ParsedModelFlags* parsed_model_flags_ptr) {
+ ParsedModelFlags& parsed_flags = *parsed_model_flags_ptr;
+ using tensorflow::Flag;
+ std::vector<tensorflow::Flag> flags = {
+ Flag("input_array", parsed_flags.input_array.bind(),
+ parsed_flags.input_array.default_value(),
+ "Name of the input array. If not specified, will try to read "
+ "that information from the input file."),
+ Flag("input_arrays", parsed_flags.input_arrays.bind(),
+ parsed_flags.input_arrays.default_value(),
+ "Names of the output arrays, comma-separated. If not specified, "
+ "will try to read that information from the input file."),
+ Flag("output_array", parsed_flags.output_array.bind(),
+ parsed_flags.output_array.default_value(),
+ "Name of the output array, when specifying a unique output array. "
+ "If not specified, will try to read that information from the "
+ "input file."),
+ Flag("output_arrays", parsed_flags.output_arrays.bind(),
+ parsed_flags.output_arrays.default_value(),
+ "Names of the output arrays, comma-separated. "
+ "If not specified, will try to read "
+ "that information from the input file."),
+ Flag("input_shape", parsed_flags.input_shape.bind(),
+ parsed_flags.input_shape.default_value(),
+ "Input array shape. For many models the shape takes the form "
+ "batch size, input array height, input array width, input array "
+ "depth."),
+ Flag("input_shapes", parsed_flags.input_shapes.bind(),
+ parsed_flags.input_shapes.default_value(),
+ "Shapes corresponding to --input_arrays, colon-separated. For "
+ "many models each shape takes the form batch size, input array "
+ "height, input array width, input array depth."),
+ Flag("mean_value", parsed_flags.mean_value.bind(),
+ parsed_flags.mean_value.default_value(),
+ "mean_value parameter for image models, used to compute input "
+ "activations from input pixel data."),
+ Flag("mean_values", parsed_flags.mean_values.bind(),
+ parsed_flags.mean_values.default_value(),
+ "mean_values parameter for image models, comma-separated list of "
+ "doubles, used to compute input activations from input pixel "
+ "data. Each entry in the list should match an entry in "
+ "--input_arrays."),
+ Flag("std_value", parsed_flags.std_value.bind(),
+ parsed_flags.std_value.default_value(),
+ "std_value parameter for image models, used to compute input "
+ "activations from input pixel data."),
+ Flag("std_values", parsed_flags.std_values.bind(),
+ parsed_flags.std_values.default_value(),
+ "std_value parameter for image models, comma-separated list of "
+ "doubles, used to compute input activations from input pixel "
+ "data. Each entry in the list should match an entry in "
+ "--input_arrays."),
+ Flag("variable_batch", parsed_flags.variable_batch.bind(),
+ parsed_flags.variable_batch.default_value(),
+ "If true, the model accepts an arbitrary batch size. Mutually "
+ "exclusive "
+ "with the 'batch' field: at most one of these two fields can be "
+ "set."),
+ Flag(
+ "drop_control_dependency",
+ parsed_flags.drop_control_dependency.bind(),
+ parsed_flags.drop_control_dependency.default_value(),
+ "If true, ignore control dependency requirements in input TensorFlow "
+ "GraphDef. Otherwise an error will be raised upon control dependency "
+ "inputs."),
+ Flag("rnn_states", parsed_flags.rnn_states.bind(),
+ parsed_flags.rnn_states.default_value(), ""),
+ Flag("model_checks", parsed_flags.model_checks.bind(),
+ parsed_flags.model_checks.default_value(),
+ "A list of model checks to be applied to verify the form of the "
+ "model. Applied after the graph transformations after import."),
+ Flag("graphviz_first_array", parsed_flags.graphviz_first_array.bind(),
+ parsed_flags.graphviz_first_array.default_value(),
+ "If set, defines the start of the sub-graph to be dumped to "
+ "GraphViz."),
+ Flag(
+ "graphviz_last_array", parsed_flags.graphviz_last_array.bind(),
+ parsed_flags.graphviz_last_array.default_value(),
+ "If set, defines the end of the sub-graph to be dumped to GraphViz."),
+ Flag("dump_graphviz", parsed_flags.dump_graphviz.bind(),
+ parsed_flags.dump_graphviz.default_value(),
+ "Dump graphviz during LogDump call. If string is non-empty then "
+ "it defines path to dump, otherwise will skip dumping."),
+ Flag("dump_graphviz_video", parsed_flags.dump_graphviz_video.bind(),
+ parsed_flags.dump_graphviz_video.default_value(),
+ "If true, will dump graphviz at each "
+ "graph transformation, which may be used to generate a video."),
+ };
+ bool asked_for_help =
+ *argc == 2 && (!strcmp(argv[1], "--help") || !strcmp(argv[1], "-help"));
+ if (asked_for_help) {
+ *msg += tensorflow::Flags::Usage(argv[0], flags);
+ return false;
+ } else {
+ if (!tensorflow::Flags::Parse(argc, argv, flags)) return false;
+ }
+ auto& dump_options = *GraphVizDumpOptions::singleton();
+ dump_options.graphviz_first_array = parsed_flags.graphviz_first_array.value();
+ dump_options.graphviz_last_array = parsed_flags.graphviz_last_array.value();
+ dump_options.dump_graphviz_video = parsed_flags.dump_graphviz_video.value();
+ dump_options.dump_graphviz = parsed_flags.dump_graphviz.value();
+
+ return true;
+}
+
+void ReadModelFlagsFromCommandLineFlags(
+ const ParsedModelFlags& parsed_model_flags, ModelFlags* model_flags) {
+ toco::port::CheckInitGoogleIsDone("InitGoogle is not done yet");
+
+// "batch" flag only exists internally
+#ifdef PLATFORM_GOOGLE
+ CHECK(!((base::SpecifiedOnCommandLine("batch") &&
+ parsed_model_flags.variable_batch.specified())))
+ << "The --batch and --variable_batch flags are mutually exclusive.";
+#endif
+ CHECK(!(parsed_model_flags.output_array.specified() &&
+ parsed_model_flags.output_arrays.specified()))
+ << "The --output_array and --vs flags are mutually exclusive.";
+
+ if (parsed_model_flags.output_array.specified()) {
+ model_flags->add_output_arrays(parsed_model_flags.output_array.value());
+ }
+
+ if (parsed_model_flags.output_arrays.specified()) {
+ std::vector<string> output_arrays =
+ absl::StrSplit(parsed_model_flags.output_arrays.value(), ',');
+ for (const string& output_array : output_arrays) {
+ model_flags->add_output_arrays(output_array);
+ }
+ }
+
+ const bool uses_single_input_flags =
+ parsed_model_flags.input_array.specified() ||
+ parsed_model_flags.mean_value.specified() ||
+ parsed_model_flags.std_value.specified() ||
+ parsed_model_flags.input_shape.specified();
+
+ const bool uses_multi_input_flags =
+ parsed_model_flags.input_arrays.specified() ||
+ parsed_model_flags.mean_values.specified() ||
+ parsed_model_flags.std_values.specified() ||
+ parsed_model_flags.input_shapes.specified();
+
+ QCHECK(!(uses_single_input_flags && uses_multi_input_flags))
+ << "Use either the singular-form input flags (--input_array, "
+ "--input_shape, --mean_value, --std_value) or the plural form input "
+ "flags (--input_arrays, --input_shapes, --mean_values, --std_values), "
+ "but not both forms within the same command line.";
+
+ if (parsed_model_flags.input_array.specified()) {
+ QCHECK(uses_single_input_flags);
+ model_flags->add_input_arrays()->set_name(
+ parsed_model_flags.input_array.value());
+ }
+ if (parsed_model_flags.input_arrays.specified()) {
+ QCHECK(uses_multi_input_flags);
+ for (const auto& input_array :
+ absl::StrSplit(parsed_model_flags.input_arrays.value(), ',')) {
+ model_flags->add_input_arrays()->set_name(string(input_array));
+ }
+ }
+ if (parsed_model_flags.mean_value.specified()) {
+ QCHECK(uses_single_input_flags);
+ model_flags->mutable_input_arrays(0)->set_mean_value(
+ parsed_model_flags.mean_value.value());
+ }
+ if (parsed_model_flags.mean_values.specified()) {
+ QCHECK(uses_multi_input_flags);
+ std::vector<string> mean_values =
+ absl::StrSplit(parsed_model_flags.mean_values.value(), ',');
+ QCHECK(mean_values.size() == model_flags->input_arrays_size());
+ for (int i = 0; i < mean_values.size(); ++i) {
+ char* last = nullptr;
+ model_flags->mutable_input_arrays(i)->set_mean_value(
+ strtod(mean_values[i].data(), &last));
+ CHECK(last != mean_values[i].data());
+ }
+ }
+ if (parsed_model_flags.std_value.specified()) {
+ QCHECK(uses_single_input_flags);
+ model_flags->mutable_input_arrays(0)->set_std_value(
+ parsed_model_flags.std_value.value());
+ }
+ if (parsed_model_flags.std_values.specified()) {
+ QCHECK(uses_multi_input_flags);
+ std::vector<string> std_values =
+ absl::StrSplit(parsed_model_flags.std_values.value(), ',');
+ QCHECK(std_values.size() == model_flags->input_arrays_size());
+ for (int i = 0; i < std_values.size(); ++i) {
+ char* last = nullptr;
+ model_flags->mutable_input_arrays(i)->set_std_value(
+ strtod(std_values[i].data(), &last));
+ CHECK(last != std_values[i].data());
+ }
+ }
+ if (parsed_model_flags.input_shape.specified()) {
+ QCHECK(uses_single_input_flags);
+ if (model_flags->input_arrays().empty()) {
+ model_flags->add_input_arrays();
+ }
+ auto* shape = model_flags->mutable_input_arrays(0)->mutable_shape();
+ shape->Clear();
+ const IntList& list = parsed_model_flags.input_shape.value();
+ for (auto& dim : list.elements) {
+ shape->Add(dim);
+ }
+ }
+ if (parsed_model_flags.input_shapes.specified()) {
+ QCHECK(uses_multi_input_flags);
+ std::vector<string> input_shapes =
+ absl::StrSplit(parsed_model_flags.input_shapes.value(), ':');
+ QCHECK(input_shapes.size() == model_flags->input_arrays_size());
+ for (int i = 0; i < input_shapes.size(); ++i) {
+ auto* shape = model_flags->mutable_input_arrays(i)->mutable_shape();
+ shape->Clear();
+ if (input_shapes[i].empty()) {
+ // empty i.e. 0-dimensional input shape.
+ // Unfortunately, the current toco::InputArray
+ // proto does not allow to distinguish between a known 0-D shape,
+ // and an unknown shape. Indeed, shape is currently a plain array,
+ // and it being empty means unknown shape. So here, we import a
+ // 0-D shape as a 1-D shape of size 1.
+ // TODO(benoitjacob): fix toco::InputArray to allow 0-D shape,
+ // probably by making shape an optional message,
+ // encapsulating the array.
+ shape->Add(1);
+ } else {
+ for (const auto& dim_str : absl::StrSplit(input_shapes[i], ',')) {
+ int size;
+ CHECK(absl::SimpleAtoi(dim_str, &size))
+ << "Failed to parse input_shape: " << input_shapes[i];
+ shape->Add(size);
+ }
+ }
+ }
+ }
+
+#define READ_MODEL_FLAG(name) \
+ do { \
+ if (parsed_model_flags.name.specified()) { \
+ model_flags->set_##name(parsed_model_flags.name.value()); \
+ } \
+ } while (false)
+
+ READ_MODEL_FLAG(variable_batch);
+ READ_MODEL_FLAG(drop_control_dependency);
+
+#undef READ_MODEL_FLAG
+
+ for (const auto& element : parsed_model_flags.rnn_states.value().elements) {
+ auto* rnn_state_proto = model_flags->add_rnn_states();
+ for (const auto& kv_pair : element) {
+ const string& key = kv_pair.first;
+ const string& value = kv_pair.second;
+ if (key == "state_array") {
+ rnn_state_proto->set_state_array(value);
+ } else if (key == "back_edge_source_array") {
+ rnn_state_proto->set_back_edge_source_array(value);
+ } else if (key == "size") {
+ int32 size = 0;
+ CHECK(absl::SimpleAtoi(value, &size));
+ CHECK_GT(size, 0);
+ rnn_state_proto->set_size(size);
+ } else if (key == "manually_create") {
+ CHECK_EQ(absl::AsciiStrToLower(value), "true");
+ rnn_state_proto->set_manually_create(true);
+ } else {
+ LOG(FATAL) << "Unknown key '" << key << "' in --rnn_states";
+ }
+ }
+ CHECK(rnn_state_proto->has_state_array() &&
+ rnn_state_proto->has_back_edge_source_array() &&
+ rnn_state_proto->has_size())
+ << "--rnn_states must include state_array, back_edge_source_array and "
+ "size.";
+ }
+
+ for (const auto& element : parsed_model_flags.model_checks.value().elements) {
+ auto* model_check_proto = model_flags->add_model_checks();
+ for (const auto& kv_pair : element) {
+ const string& key = kv_pair.first;
+ const string& value = kv_pair.second;
+ if (key == "count_type") {
+ model_check_proto->set_count_type(value);
+ } else if (key == "count_min") {
+ int32 count = 0;
+ CHECK(absl::SimpleAtoi(value, &count));
+ CHECK_GE(count, -1);
+ model_check_proto->set_count_min(count);
+ } else if (key == "count_max") {
+ int32 count = 0;
+ CHECK(absl::SimpleAtoi(value, &count));
+ CHECK_GE(count, -1);
+ model_check_proto->set_count_max(count);
+ } else {
+ LOG(FATAL) << "Unknown key '" << key << "' in --model_checks";
+ }
+ }
+ }
+}
+
+ParsedModelFlags* UncheckedGlobalParsedModelFlags(bool must_already_exist) {
+ static auto* flags = [must_already_exist]() {
+ if (must_already_exist) {
+ fprintf(stderr, __FILE__
+ ":"
+ "GlobalParsedModelFlags() used without initialization\n");
+ fflush(stderr);
+ abort();
+ }
+ return new toco::ParsedModelFlags;
+ }();
+ return flags;
+}
+
+ParsedModelFlags* GlobalParsedModelFlags() {
+ return UncheckedGlobalParsedModelFlags(true);
+}
+
+void ParseModelFlagsOrDie(int* argc, char* argv[]) {
+ // TODO(aselle): in the future allow Google version to use
+ // flags, and only use this mechanism for open source
+ auto* flags = UncheckedGlobalParsedModelFlags(false);
+ string msg;
+ bool model_success =
+ toco::ParseModelFlagsFromCommandLineFlags(argc, argv, &msg, flags);
+ if (!model_success || !msg.empty()) {
+ // Log in non-standard way since this happens pre InitGoogle.
+ fprintf(stderr, "%s", msg.c_str());
+ fflush(stderr);
+ abort();
+ }
+}
+
+} // namespace toco
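For context, a hedged sketch of how a command-line tool could wire these entry points together; the real driver lives elsewhere in toco, so this is only an outline under that assumption.

    #include "tensorflow/contrib/lite/toco/model_cmdline_flags.h"
    #include "tensorflow/contrib/lite/toco/model_flags.pb.h"

    int main(int argc, char* argv[]) {
      // Parse and strip the model flags first; this aborts with a usage
      // message on malformed flags or --help.
      toco::ParseModelFlagsOrDie(&argc, argv);
      // ... platform/port initialization would go here, before the next call,
      // since ReadModelFlagsFromCommandLineFlags checks that it has been done ...
      toco::ModelFlags model_flags;
      toco::ReadModelFlagsFromCommandLineFlags(*toco::GlobalParsedModelFlags(),
                                               &model_flags);
      // model_flags is now ready to drive import of the input model.
      return 0;
    }

An invocation of such a tool might then pass the plural-form flags parsed above, for example "--input_arrays=a,b --input_shapes=1,224,224,3:1,10 --output_arrays=sum" (hypothetical array names).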
diff --git a/tensorflow/contrib/lite/toco/model_cmdline_flags.h b/tensorflow/contrib/lite/toco/model_cmdline_flags.h
new file mode 100644
index 0000000000..dfa3d3c1ef
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/model_cmdline_flags.h
@@ -0,0 +1,43 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_MODEL_CMDLINE_FLAGS_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_MODEL_CMDLINE_FLAGS_H_
+
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/args.h"
+#include "tensorflow/contrib/lite/toco/model_flags.pb.h"
+
+namespace toco {
+// Parse and remove arguments for models (in toco). Returns true if parsing
+// is successful. msg has the usage string if there was an error or
+// "--help" was specified
+bool ParseModelFlagsFromCommandLineFlags(
+ int* argc, char* argv[], string* msg,
+ ParsedModelFlags* parsed_model_flags_ptr);
+// Populate the ModelFlags proto with model data.
+void ReadModelFlagsFromCommandLineFlags(
+ const ParsedModelFlags& parsed_model_flags, ModelFlags* model_flags);
+// Parse the command-line model flags into the global ParsedModelFlags static.
+void ParseModelFlagsOrDie(int* argc, char* argv[]);
+// Get the global parsed model flags
+ParsedModelFlags* GlobalParsedModelFlags();
+
+} // namespace toco
+
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_MODEL_CMDLINE_FLAGS_H_
diff --git a/tensorflow/contrib/lite/toco/model_flags.proto b/tensorflow/contrib/lite/toco/model_flags.proto
new file mode 100644
index 0000000000..743e08b16f
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/model_flags.proto
@@ -0,0 +1,119 @@
+// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+syntax = "proto2";
+
+package toco;
+
+// Next ID to USE: 5.
+message InputArray {
+ // Name of the input arrays, i.e. the arrays from which input activations
+ // will be read.
+ optional string name = 1;
+
+ // Shape of the input. For many applications the dimensions are {batch,
+ // height, width, depth}. Often the batch is left "unspecified" by providing
+ // a value of -1.
+ //
+ // The last dimension is typically called 'depth' or 'channels'. For example,
+ // for an image model taking RGB images as input, this would have the value 3.
+ repeated int32 shape = 2;
+
+ // mean_value and std_value parameters control the interpretation of raw input
+ // activation values (elements of the input array) as real numbers. The
+ // mapping is given by:
+ //
+ // real_value = (raw_input_value - mean_value) / std_value
+ //
+ // In particular, the defaults (mean_value=0, std_value=1) yield
+ // real_value = raw_input_value. Often, non-default values are used in image
+ // models. For example, an image model taking uint8 image channel values as
+ // its raw inputs, in [0, 255] range, may use mean_value=128, std_value=128 to
+ // map them into the interval [-1, 1).
+ //
+ // Note: this matches exactly the meaning of mean_value and std_value in
+ // (TensorFlow via LegacyFedInput).
+ optional float mean_value = 3;
+ optional float std_value = 4 [default = 1.];
+}
+
+// ModelFlags encodes properties of a model that, depending on the file
+// format, may or may not be recorded in the model file. The purpose of
+// representing these properties in ModelFlags is to allow passing them
+// separately from the input model file, for instance as command-line
+// parameters, so that we can offer a single uniform interface that can
+// handle files from different input formats.
+//
+// For each of these properties, and each supported file format, we
+// detail in comments below whether the property exists in the given file
+// format.
+//
+// Obsolete flags that have been removed:
+// optional int32 input_depth = 3;
+// optional int32 input_width = 4;
+// optional int32 input_height = 5;
+// optional int32 batch = 6 [ default = 1];
+// optional float mean_value = 7;
+// optional float std_value = 8 [default = 1.];
+// optional int32 input_dims = 11 [ default = 4];
+// repeated int32 input_shape = 13;
+//
+// Next ID to USE: 16.
+message ModelFlags {
+ // Information about the input arrays, i.e. the arrays from which input
+ // activations will be read.
+ repeated InputArray input_arrays = 1;
+
+ // Name of the output arrays, i.e. the arrays into which output activations
+ // will be written.
+ repeated string output_arrays = 2;
+
+ // If true, the model accepts an arbitrary batch size. Mutually exclusive with
+ // the 'batch' field: at most one of these two fields can be set.
+ optional bool variable_batch = 10;
+
+ message RnnState {
+ optional string state_array = 1;
+ optional string back_edge_source_array = 2;
+ optional int32 size = 3;
+ // TODO(benoitjacob): manually_create is a temporary hack:
+ // due to discrepancies between the current toco dims tracking and
+ // TensorFlow shapes, for some models we need to manually create RNN state
+ // arrays with a specified shape.
+ // Maybe we should actually implement back-edges as operators of their own,
+ // which would remove the need for much special-casing, including here,
+ // we could probably consistently let PropagateFixedSizes handle state
+ // arrays.
+ optional bool manually_create = 4;
+ }
+ repeated RnnState rnn_states = 12;
+
+ // Checks applied to the model, typically after toco's comprehensive
+ // graph transformations.
+ // Next ID to USE: 4.
+ message ModelCheck {
+ // Use the name of a type of operator to check its counts.
+ // Use "Total" for overall operator counts.
+ // Use "Arrays" for overall array counts.
+ optional string count_type = 1 [default = "None"];
+    // A count of zero is a meaningful check, so a negative value is used to mean "disabled".
+ optional int32 count_min = 2 [default = -1];
+    // If count_max < count_min, then count_min is the only allowed value.
+ optional int32 count_max = 3 [default = -1];
+ }
+ repeated ModelCheck model_checks = 14;
+
+ // If true, ignore control dependency requirements in input TensorFlow
+ // GraphDef. Otherwise an error will be raised upon control dependency inputs.
+ optional bool drop_control_dependency = 15;
+}
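For illustration only (a hedged sketch, not part of this change), the ModelFlags proto above can be populated from Python through the generated model_flags_pb2 module that the tests below import; the mean_value/std_value pair in this sketch maps raw uint8 values in [0, 255] into [-1, 1) exactly as described for InputArray. The array names are hypothetical.

    from tensorflow.contrib.lite.toco import model_flags_pb2

    model_flags = model_flags_pb2.ModelFlags()

    input_array = model_flags.input_arrays.add()
    input_array.name = "input"                   # hypothetical input array name
    input_array.shape.extend([-1, 224, 224, 3])  # {batch, height, width, depth}; batch left unspecified
    input_array.mean_value = 128.0
    input_array.std_value = 128.0                # real_value = (raw_input_value - 128) / 128

    model_flags.output_arrays.append("output")   # hypothetical output array name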
diff --git a/tensorflow/contrib/lite/toco/python/BUILD b/tensorflow/contrib/lite/toco/python/BUILD
new file mode 100644
index 0000000000..92246a8aed
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/python/BUILD
@@ -0,0 +1,76 @@
+package(default_visibility = ["//visibility:public"])
+
+licenses(["notice"]) # Apache 2.0
+
+load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc")
+load("//tensorflow:tensorflow.bzl", "tf_py_test")
+
+cc_library(
+ name = "toco_python_api",
+ srcs = ["toco_python_api.cc"],
+ hdrs = ["toco_python_api.h"],
+ deps = [
+ "//tensorflow/contrib/lite/toco:model_flags_proto_cc",
+ "//tensorflow/contrib/lite/toco:toco_flags_proto_cc",
+ "//tensorflow/contrib/lite/toco:toco_port",
+ "//tensorflow/contrib/lite/toco:toco_tooling",
+ "//tensorflow/core:lib",
+ "//util/python:python_headers",
+ ],
+)
+
+tf_py_wrap_cc(
+ name = "tensorflow_wrap_toco",
+ srcs = ["toco.i"],
+ deps = [
+ ":toco_python_api",
+ "//tensorflow/contrib/lite/toco:model_flags_proto_cc",
+ "//tensorflow/contrib/lite/toco:toco_flags_proto_cc",
+ "//util/python:python_headers",
+ "@com_google_absl//absl/strings",
+ ],
+)
+
+py_binary(
+ name = "toco_from_protos",
+ srcs = ["toco_from_protos.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":tensorflow_wrap_toco",
+ "//tensorflow/python:platform",
+ ],
+)
+
+py_binary(
+ name = "toco_wrapper",
+ srcs = ["toco_wrapper.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow:tensorflow_py",
+ ],
+)
+
+tf_py_test(
+ name = "toco_from_protos_test",
+ srcs = ["toco_from_protos_test.py"],
+ additional_deps = [
+ "//tensorflow:tensorflow_py",
+ "//tensorflow/contrib/lite/toco:model_flags_proto_py",
+ "//tensorflow/contrib/lite/toco:toco_flags_proto_py",
+ ],
+ data = [
+ ":toco_from_protos",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/contrib/lite/toco/python/toco.i b/tensorflow/contrib/lite/toco/python/toco.i
new file mode 100644
index 0000000000..3787cba4a3
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/python/toco.i
@@ -0,0 +1,32 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+%include "std_string.i"
+
+%{
+#include "tensorflow/contrib/lite/toco/python/toco_python_api.h"
+%}
+
+namespace toco {
+
+// Convert a model represented in `input_contents`. `model_flags_proto`
+// describes model parameters. `toco_flags_proto` describes conversion
+// parameters (see relevant .protos for more information). Returns a string
+// representing the contents of the converted model.
+PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw,
+ PyObject* toco_flags_proto_txt_raw,
+ PyObject* input_contents_txt_raw);
+
+} // namespace toco
\ No newline at end of file
diff --git a/tensorflow/contrib/lite/toco/python/toco_from_protos.py b/tensorflow/contrib/lite/toco/python/toco_from_protos.py
new file mode 100644
index 0000000000..c0b032083b
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/python/toco_from_protos.py
@@ -0,0 +1,63 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Python console command to invoke TOCO from serialized protos."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import sys
+from tensorflow.contrib.lite.toco.python import tensorflow_wrap_toco
+from tensorflow.python.platform import app
+
+FLAGS = None
+
+
+def execute(unused_args):
+ model_str = open(FLAGS.model_proto_file, "rb").read()
+ toco_str = open(FLAGS.toco_proto_file, "rb").read()
+ input_str = open(FLAGS.model_input_file, "rb").read()
+
+ output_str = tensorflow_wrap_toco.TocoConvert(model_str, toco_str, input_str)
+ open(FLAGS.model_output_file, "wb").write(output_str)
+ sys.exit(0)
+
+
+def main():
+ global FLAGS
+ parser = argparse.ArgumentParser(
+ description="Invoke toco using protos as input.")
+ parser.add_argument(
+ "model_proto_file",
+ type=str,
+ help="File containing serialized proto that describes the model.")
+ parser.add_argument(
+ "toco_proto_file",
+ type=str,
+ help="File containing serialized proto describing how TOCO should run.")
+ parser.add_argument(
+ "model_input_file", type=str, help="Input model is read from this file.")
+ parser.add_argument(
+ "model_output_file",
+ type=str,
+ help="Result of applying TOCO conversion is written here.")
+
+ FLAGS, unparsed = parser.parse_known_args()
+
+ app.run(main=execute, argv=[sys.argv[0]] + unparsed)
+
+
+if __name__ == "__main__":
+ main()
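As a usage sketch (all file names are hypothetical), the console command above takes its four positional arguments in the order model proto, toco proto, input model, output model; the test that follows shells out to the binary in essentially the same way.

    import subprocess

    # Hypothetical paths; the first three files are read, the last one is written.
    subprocess.check_call([
        "toco_from_protos",
        "/tmp/model_flags.pb",  # serialized ModelFlags proto
        "/tmp/toco_flags.pb",   # serialized TocoFlags proto
        "/tmp/graph_def.pb",    # input model contents
        "/tmp/model.tflite",    # converted model is written here
    ])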
diff --git a/tensorflow/contrib/lite/toco/python/toco_from_protos_test.py b/tensorflow/contrib/lite/toco/python/toco_from_protos_test.py
new file mode 100644
index 0000000000..2a593beeca
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/python/toco_from_protos_test.py
@@ -0,0 +1,96 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import tempfile
+
+import tensorflow as tf
+from tensorflow.contrib.lite.toco import model_flags_pb2
+from tensorflow.contrib.lite.toco import toco_flags_pb2
+from tensorflow.python.platform import googletest
+from tensorflow.python.platform import resource_loader
+
+
+def TensorName(x):
+ """Get the canonical (non foo:0 name)."""
+ return x.name.split(":")[0]
+
+
+class TocoFromProtosTest(googletest.TestCase):
+
+ def _run(self, sess, in_tensor, out_tensor, should_succeed):
+ """Use toco binary to check conversion from graphdef to tflite.
+
+ Args:
+ sess: Active TensorFlow session containing graph.
+ in_tensor: TensorFlow tensor to use as input.
+ out_tensor: TensorFlow tensor to use as output.
+ should_succeed: Whether this is a valid conversion.
+ """
+ # Build all protos and extract graphdef
+ graph_def = sess.graph_def
+ toco_flags = toco_flags_pb2.TocoFlags()
+ toco_flags.input_format = toco_flags_pb2.TENSORFLOW_GRAPHDEF
+ toco_flags.output_format = toco_flags_pb2.TFLITE
+ toco_flags.input_types.append(toco_flags_pb2.FLOAT)
+ toco_flags.inference_type = toco_flags_pb2.FLOAT
+ model_flags = model_flags_pb2.ModelFlags()
+ input_array = model_flags.input_arrays.add()
+ input_array.name = TensorName(in_tensor)
+ input_array.shape.extend(map(int, in_tensor.get_shape()))
+ model_flags.output_arrays.append(TensorName(out_tensor))
+ # Shell out to run toco (in case it crashes)
+ with tempfile.NamedTemporaryFile() as fp_toco, \
+ tempfile.NamedTemporaryFile() as fp_model, \
+ tempfile.NamedTemporaryFile() as fp_input, \
+ tempfile.NamedTemporaryFile() as fp_output:
+ fp_model.write(model_flags.SerializeToString())
+ fp_toco.write(toco_flags.SerializeToString())
+ fp_input.write(graph_def.SerializeToString())
+ fp_model.flush()
+ fp_toco.flush()
+ fp_input.flush()
+ tflite_bin = resource_loader.get_path_to_datafile("toco_from_protos")
+ cmdline = " ".join([
+ tflite_bin, fp_model.name, fp_toco.name, fp_input.name, fp_output.name
+ ])
+ exitcode = os.system(cmdline)
+ if exitcode == 0:
+ stuff = fp_output.read()
+ self.assertEqual(stuff is not None, should_succeed)
+ else:
+ self.assertFalse(should_succeed)
+
+ def test_toco(self):
+ """Run a couple of TensorFlow graphs against TOCO through the python bin."""
+ with tf.Session() as sess:
+ img = tf.placeholder(name="img", dtype=tf.float32, shape=(1, 64, 64, 3))
+ val = img + tf.constant([1., 2., 3.]) + tf.constant([1., 4., 4.])
+ out = tf.identity(val, name="out")
+ out2 = tf.sin(val, name="out2")
+      # This is a valid model.
+ self._run(sess, img, out, True)
+ # This uses an invalid function.
+ # TODO(aselle): Check to make sure a warning is included.
+ self._run(sess, img, out2, True)
+ # This is an identity graph, which doesn't work
+ self._run(sess, img, img, False)
+
+
+if __name__ == "__main__":
+ googletest.main()
diff --git a/tensorflow/contrib/lite/toco/python/toco_python_api.cc b/tensorflow/contrib/lite/toco/python/toco_python_api.cc
new file mode 100644
index 0000000000..8a5e483f3f
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/python/toco_python_api.cc
@@ -0,0 +1,85 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <string>
+#include <vector>
+#include "tensorflow/core/platform/logging.h"
+
+#include "tensorflow/contrib/lite/toco/model_flags.pb.h"
+#include "tensorflow/contrib/lite/toco/python/toco_python_api.h"
+#include "tensorflow/contrib/lite/toco/toco_flags.pb.h"
+#include "tensorflow/contrib/lite/toco/toco_port.h"
+#include "tensorflow/contrib/lite/toco/toco_tooling.h"
+#include "tensorflow/contrib/lite/toco/toco_types.h"
+
+namespace toco {
+
+#if PY_MAJOR_VERSION >= 3
+#define TOCO_PY_TO_CPPSTRING PyBytes_AsStringAndSize
+#define TOCO_FROM_CPPSTRING_TO_PY PyBytes_FromStringAndSize
+#else
+#define TOCO_PY_TO_CPPSTRING PyString_AsStringAndSize
+#define TOCO_FROM_CPPSTRING_TO_PY PyString_FromStringAndSize
+#endif
+
+// NOTE(aselle): We are using raw PyObject's here because we want to make
+// sure we input and output bytes rather than unicode strings for Python3.
+PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw,
+ PyObject* toco_flags_proto_txt_raw,
+ PyObject* input_contents_txt_raw) {
+  // Use the Python C API to validate and convert arguments. In py3 the
+  // arguments are bytes; in py2 they are str.
+ auto ConvertArg = [&](PyObject* obj, bool* error) {
+ char* buf;
+ Py_ssize_t len;
+ if (TOCO_PY_TO_CPPSTRING(obj, &buf, &len) == -1) {
+ *error = true;
+ return std::string();
+ } else {
+ *error = false;
+ return std::string(buf, len);
+ }
+ };
+
+ bool error;
+ std::string model_flags_proto_txt =
+ ConvertArg(model_flags_proto_txt_raw, &error);
+ if (error) return nullptr;
+ std::string toco_flags_proto_txt =
+ ConvertArg(toco_flags_proto_txt_raw, &error);
+ if (error) return nullptr;
+ std::string input_contents_txt = ConvertArg(input_contents_txt_raw, &error);
+ if (error) return nullptr;
+
+ // Use toco to produce new outputs
+ toco::ModelFlags model_flags;
+ if (!model_flags.ParseFromString(model_flags_proto_txt)) {
+ LOG(FATAL) << "Model proto failed to parse." << std::endl;
+ }
+ toco::TocoFlags toco_flags;
+ if (!toco_flags.ParseFromString(toco_flags_proto_txt)) {
+ LOG(FATAL) << "Toco proto failed to parse." << std::endl;
+ }
+ std::unique_ptr<toco::Model> model =
+ toco::Import(toco_flags, model_flags, input_contents_txt);
+ toco::Transform(toco_flags, model.get());
+ string output_file_contents_txt;
+ Export(toco_flags, *model, &output_file_contents_txt);
+
+  // Convert the result back to bytes (py3) or str (py2).
+ return TOCO_FROM_CPPSTRING_TO_PY(output_file_contents_txt.data(),
+ output_file_contents_txt.size());
+}
+
+} // namespace toco
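A minimal sketch of the calling convention (assuming the SWIG module is importable as tensorflow_wrap_toco, and that model_flags, toco_flags and graph_def have been built as in toco_from_protos_test.py above): every argument must be a serialized bytes object, and the return value is the serialized converted model, matching the bytes-vs-unicode NOTE in the implementation.

    from tensorflow.contrib.lite.toco.python import tensorflow_wrap_toco

    # All three arguments are bytes (str in py2); unicode strings are rejected.
    tflite_model = tensorflow_wrap_toco.TocoConvert(
        model_flags.SerializeToString(),  # serialized toco.ModelFlags
        toco_flags.SerializeToString(),   # serialized toco.TocoFlags
        graph_def.SerializeToString())    # serialized input model contents

    with open("/tmp/model.tflite", "wb") as f:  # hypothetical output path
      f.write(tflite_model)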
diff --git a/tensorflow/contrib/lite/toco/python/toco_python_api.h b/tensorflow/contrib/lite/toco/python/toco_python_api.h
new file mode 100644
index 0000000000..dc378353f7
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/python/toco_python_api.h
@@ -0,0 +1,33 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef _THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_PYTHON_TOCO_PYTHON_API_H_
+#define _THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_PYTHON_TOCO_PYTHON_API_H_
+
+#include <string>
+#include <Python.h>
+
+namespace toco {
+
+// Convert a model represented in `input_contents`. `model_flags_proto`
+// describes model parameters. `toco_flags_proto` describes conversion
+// parameters (see relevant .protos for more information). Returns a string
+// representing the contents of the converted model.
+PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw,
+ PyObject* toco_flags_proto_txt_raw,
+ PyObject* input_contents_txt_raw);
+
+} // namespace toco
+
+#endif // _THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_PYTHON_TOCO_PYTHON_API_H_
diff --git a/tensorflow/contrib/lite/toco/python/toco_wrapper.py b/tensorflow/contrib/lite/toco/python/toco_wrapper.py
new file mode 100644
index 0000000000..e39b5f22c7
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/python/toco_wrapper.py
@@ -0,0 +1,35 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Wrapper for runninmg toco binary embedded in pip site-package.
+
+NOTE: this mainly exists since PIP setup.py cannot install binaries to bin/.
+It can only install Python "console-scripts." This will work as a console
+script. See tools/pip_package/setup.py (search for CONSOLE_SCRIPTS).
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import sys
+import tensorflow as tf
+
+
+def main():
+  # Pip installs the binary under aux-bin/ in the main site-packages install.
+  # Just find it and exec it, passing all arguments through in the process.
+  # TODO(aselle): it is unfortunate to import all of tensorflow to find the binary.
+ binary = os.path.join(tf.__path__[0], 'aux-bin/toco')
+ os.execvp(binary, sys.argv)
diff --git a/tensorflow/contrib/lite/toco/runtime/common.h b/tensorflow/contrib/lite/toco/runtime/common.h
new file mode 100644
index 0000000000..bd55544f57
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/runtime/common.h
@@ -0,0 +1,26 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_RUNTIME_COMMON_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_RUNTIME_COMMON_H_
+
+#ifndef ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK
+#ifdef GEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK
+#define ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK
+#endif
+#endif
+
+#include "tensorflow/contrib/lite/kernels/internal/common.h"
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_RUNTIME_COMMON_H_
diff --git a/tensorflow/contrib/lite/toco/runtime/types.h b/tensorflow/contrib/lite/toco/runtime/types.h
new file mode 100644
index 0000000000..df63b2d59e
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/runtime/types.h
@@ -0,0 +1,32 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_RUNTIME_TYPES_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_RUNTIME_TYPES_H_
+
+#include "tensorflow/contrib/lite/kernels/internal/common.h"
+#include "tensorflow/contrib/lite/kernels/internal/compatibility.h"
+#include "tensorflow/contrib/lite/kernels/internal/types.h"
+
+namespace toco {
+
+// TODO(ahentz): These are just stopgaps for now, until we move all
+// the code over to tflite.
+using tflite::Dims;
+using tflite::FusedActivationFunctionType;
+using tflite::RequiredBufferSizeForDims;
+
+} // namespace toco
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_RUNTIME_TYPES_H_
diff --git a/tensorflow/contrib/lite/toco/tensorflow_graph_matching/BUILD b/tensorflow/contrib/lite/toco/tensorflow_graph_matching/BUILD
new file mode 100644
index 0000000000..0c1a1141fc
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/tensorflow_graph_matching/BUILD
@@ -0,0 +1,102 @@
+package(default_visibility = ["//visibility:public"])
+
+licenses(["notice"]) # Apache 2.0
+
+load(
+ "//tensorflow:tensorflow.bzl",
+ "tf_cc_test",
+)
+
+cc_library(
+ name = "cluster_utils",
+ srcs = [
+ "cluster_utils.cc",
+ ],
+ hdrs = [
+ "cluster_utils.h",
+ ],
+ deps = [
+ "//tensorflow/contrib/lite/toco:toco_port",
+ ],
+)
+
+cc_library(
+ name = "cluster",
+ srcs = [
+ "cluster.cc",
+ ],
+ hdrs = [
+ "cluster.h",
+ ],
+ deps = [
+ ":cluster_utils",
+ "//tensorflow/contrib/lite/toco:model",
+ "//tensorflow/contrib/lite/toco:tooling_util",
+ "//tensorflow/core:protos_all_cc",
+ ],
+)
+
+cc_library(
+ name = "resolve_svdf",
+ srcs = [
+ "resolve_svdf.cc",
+ ],
+ hdrs = [
+ "resolve_svdf.h",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":cluster",
+ ":cluster_utils",
+ "//tensorflow/contrib/lite/toco:model",
+ "//tensorflow/contrib/lite/toco:toco_port",
+ "//tensorflow/contrib/lite/toco:tooling_util",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "@protobuf_archive//:protobuf_headers",
+ ],
+)
+
+tf_cc_test(
+ name = "resolve_svdf_test",
+ srcs = ["resolve_svdf_test.cc"],
+ deps = [
+ ":cluster",
+ ":cluster_utils",
+ ":resolve_cluster",
+ ":resolve_svdf",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_library(
+ name = "resolve_cluster",
+ srcs = [
+ "resolve_cluster.cc",
+ ],
+ hdrs = [
+ "resolve_cluster.h",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":cluster",
+ ":cluster_utils",
+ ":resolve_svdf",
+ "//tensorflow/contrib/lite/toco:tooling_util",
+ "//tensorflow/core:protos_all_cc",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster.cc b/tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster.cc
new file mode 100644
index 0000000000..98a130ea39
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster.cc
@@ -0,0 +1,52 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster.h"
+
+namespace toco {
+
+void Cluster::SetGraphDefInfo(const tensorflow::GraphDef* graph_def) {
+ graph_def_ = graph_def;
+ for (const tensorflow::NodeDef& node : graph_def_->node()) {
+ if (StrContains(node.name(), name_)) {
+ nodes_.push_back(&node);
+ }
+ }
+}
+
+bool Cluster::FindClusterInputsAndOutputs() {
+  // For every node N in the graph:
+  // If N belongs to this cluster C, then N's inputs that are not part of C
+  // become inputs of C.
+  // If N does not belong to cluster C, then N's inputs that do belong to C
+  // become outputs of C.
+ for (const tensorflow::NodeDef& node : graph_def_->node()) {
+ if (StrContains(node.name(), name_)) {
+ for (int i = 0; i < node.input_size(); i++) {
+ if (!StrContains(node.input(i), name_)) {
+ inputs_.push_back(node.input(i));
+ }
+ }
+ } else {
+ for (int i = 0; i < node.input_size(); i++) {
+ if (StrContains(node.input(i), name_)) {
+ outputs_.push_back(node.input(i));
+ }
+ }
+ }
+ }
+ return (!inputs_.empty()) && (!outputs_.empty());
+}
+
+} // end namespace toco
diff --git a/tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster.h b/tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster.h
new file mode 100644
index 0000000000..18ff73ac39
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster.h
@@ -0,0 +1,101 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_GRAPH_MATCHING_CLUSTER_H
+#define TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_GRAPH_MATCHING_CLUSTER_H
+
+#include <string>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster_utils.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/tensor.pb.h"
+#include "tensorflow/core/framework/tensor_shape.pb.h"
+
+namespace toco {
+
+// The base class for Cluster. A cluster is a group of nodes that are related
+// to each other because their names match a given "pattern", which shows they
+// all belong to a composite op supported in TFLite. The nodes in a cluster
+// will be collapsed into a single composite op node plus a series of constant
+// nodes holding the input parameters to that node. The nodes in a cluster are
+// assumed to be using the same device. By changing the "pattern" we can have
+// different subclasses of the base Cluster class.
+class Cluster {
+ public:
+ virtual ~Cluster() {}
+
+ virtual void CreateNodes() = 0;
+
+ // Save the following info from the original GraphDef this cluster is from:
+ // 1- a pointer to the GraphDef
+ // 2- All the nodes in GraphDef which belong to this cluster.
+ void SetGraphDefInfo(const tensorflow::GraphDef* graph_def);
+
+ const string& GetName() const { return name_; }
+
+ const std::vector<std::unique_ptr<tensorflow::NodeDef>>& GetNewNodes() const {
+ return new_nodes_;
+ }
+
+ const std::vector<const tensorflow::NodeDef*>& GetNodes() { return nodes_; }
+
+ void SetName(const string& name) { name_ = name; }
+
+ void SetDevice(const string& device) { device_ = device; }
+
+ // Find the input(s) and output(s) of this Cluster.
+ bool FindClusterInputsAndOutputs();
+
+ protected:
+ string name_;
+ string device_;
+ std::vector<string> inputs_;
+ std::vector<string> outputs_;
+
+ // Used to hold the pointers to nodes which are in this cluster. These nodes
+ // are pointing to the nodes in graph_def_.
+ std::vector<const tensorflow::NodeDef*> nodes_;
+
+  // Used to cache the newly generated nodes, such as the nodes created by
+  // collapsing Const nodes, or the node that is used to represent the
+  // composite op.
+ std::vector<std::unique_ptr<tensorflow::NodeDef>> new_nodes_;
+
+ const tensorflow::GraphDef* graph_def_; /*Not owned*/
+};
+
+// A factory interface for the Cluster class.
+// It defines a virtual function interface which is responsible for creating
+// a cluster. Each cluster factory is responsible for packing a group of nodes
+// into a cluster using a name-based pattern matching approach.
+class ClusterFactoryInterface {
+ public:
+ virtual ~ClusterFactoryInterface() {}
+
+ // Creates a cluster of nodes using a name-based pattern matching approach. It
+ // uses a node as a seed and if its name matches a certain pattern, then it
+ // builds the cluster around that node.
+ virtual std::unique_ptr<Cluster> CreateCluster(
+ const tensorflow::NodeDef& node,
+ const tensorflow::GraphDef& graph_def) const = 0;
+};
+
+} // end namespace toco
+
+#endif // TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_GRAPH_MATCHING_CLUSTER_H
diff --git a/tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster_utils.cc b/tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster_utils.cc
new file mode 100644
index 0000000000..14c3cd6487
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster_utils.cc
@@ -0,0 +1,34 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <string>
+#include "tensorflow/contrib/lite/toco/toco_types.h"
+namespace toco {
+
+bool StrContains(const string& x, const string& search_pattern) {
+ return x.find(search_pattern) != string::npos;
+}
+
+void Transpose2DTensor(const float* tensor, int row, int col,
+ float* transposed_tensor) {
+ float* result = transposed_tensor;
+ for (int r = 0; r < row; ++r) {
+ for (int c = 0; c < col; ++c) {
+ *(result + c * row) = *tensor++;
+ }
+ ++result;
+ }
+}
+
+} // end namespace toco
diff --git a/tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster_utils.h b/tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster_utils.h
new file mode 100644
index 0000000000..a15e480e70
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster_utils.h
@@ -0,0 +1,33 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_GRAPH_MATCHING_CLUSTERUTILS_H
+#define TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_GRAPH_MATCHING_CLUSTERUTILS_H
+
+#include <string>
+
+namespace toco {
+
+// Check if string x includes string search_pattern.
+bool StrContains(const string& x, const string& search_pattern);
+
+// Transpose a 2D tensor of size row * col pointed to by "tensor" and return
+// the result in "transposed_tensor". "transposed_tensor" must be pre-allocated
+// with the same size as "tensor".
+void Transpose2DTensor(const float* tensor, int row, int col,
+ float* transposed_tensor);
+
+} // end namespace toco
+
+#endif // TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_GRAPH_MATCHING_CLUSTERUTILS_H
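To make the layout contract above concrete, a NumPy sketch (illustrative only, not part of the toco API) of what Transpose2DTensor computes: the input is a flat, row-major row x col buffer, and the output is the flat col x row transpose.

    import numpy as np

    def transpose_2d(tensor, row, col):
      """Python sketch of the semantics of toco::Transpose2DTensor."""
      # "tensor" is a flat, row-major buffer of row * col floats.
      return np.asarray(tensor, dtype=np.float32).reshape(row, col).T.reshape(-1)

    # A 2x3 buffer [[1, 2, 3], [4, 5, 6]] becomes the 3x2 transpose
    # [[1, 4], [2, 5], [3, 6]], stored flat as [1, 4, 2, 5, 3, 6].
    print(transpose_2d([1, 2, 3, 4, 5, 6], row=2, col=3))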
diff --git a/tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_cluster.cc b/tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_cluster.cc
new file mode 100644
index 0000000000..fddf6cc836
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_cluster.cc
@@ -0,0 +1,151 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_cluster.h"
+
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster.h"
+#include "tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster_utils.h"
+#include "tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_svdf.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/function.pb.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+
+namespace toco {
+
+using tensorflow::GraphDef;
+using tensorflow::NodeDef;
+
+void AddNodeToGraph(const NodeDef& node,
+ const std::vector<string>& cluster_names, GraphDef* graph) {
+ NodeDef* new_node = graph->add_node();
+ new_node->set_op(node.op());
+ new_node->set_name(node.name());
+ new_node->set_device(node.device());
+ // If the inputs are coming from a node which belongs to another cluster, then
+ // those inputs are renamed to the source cluster name. Otherwise the original
+ // input name is used.
+ for (const string& node_input : node.input()) {
+ bool input_from_cluster = false;
+ for (const string& cluster_name : cluster_names) {
+ if (StrContains(node_input, cluster_name) &&
+ !StrContains(node.name(), cluster_name)) {
+ new_node->add_input(cluster_name);
+ input_from_cluster = true;
+ break;
+ }
+ }
+ if (!input_from_cluster) {
+ new_node->add_input(node_input);
+ }
+ }
+ for (const auto& attr : node.attr()) {
+ (*new_node->mutable_attr())[attr.first] = attr.second;
+ }
+}
+
+bool FindCluster(const ClusterFactoryInterface& cluster_factory,
+ const GraphDef& graph_def,
+ std::unordered_map<string, bool>* is_node_in_cluster,
+ std::vector<std::unique_ptr<Cluster>>* clusters) {
+ for (const NodeDef& node : graph_def.node()) {
+    // If the node is not assigned to any cluster, then we check if it belongs
+    // to the cluster_factory.
+ bool node_in_cluster = (*is_node_in_cluster)[node.name()];
+ if (!node_in_cluster) {
+ std::unique_ptr<Cluster> cluster =
+ cluster_factory.CreateCluster(node, graph_def);
+ if (cluster) {
+        // Label all the nodes in is_node_in_cluster which are in this cluster
+        // as belonging to this cluster.
+ for (const NodeDef* cluster_node : cluster->GetNodes()) {
+ (*is_node_in_cluster)[cluster_node->name()] = true;
+ }
+ clusters->push_back(std::move(cluster));
+ }
+ }
+ }
+ return (!clusters->empty());
+}
+
+std::unique_ptr<GraphDef> MaybeResolveClusters(
+ const GraphDef& graph_def,
+ const std::vector<ClusterFactoryInterface*>& cluster_factories) {
+ std::unique_ptr<GraphDef> pruned_graph(new GraphDef);
+  // The structure to keep track of which cluster each node is assigned to;
+  // initialize all nodes as unassigned.
+ std::unordered_map<string, bool> is_node_in_cluster;
+ for (const NodeDef& node : graph_def.node()) {
+ is_node_in_cluster[node.name()] = false;
+ }
+
+ std::vector<string> cluster_names;
+ std::vector<std::unique_ptr<Cluster>> all_clusters;
+ // Find the clusters for all available cluster factories.
+ for (const ClusterFactoryInterface* cluster_factory : cluster_factories) {
+ std::vector<std::unique_ptr<Cluster>> clusters;
+ if (FindCluster(*cluster_factory, graph_def, &is_node_in_cluster,
+ &clusters)) {
+ for (auto itr = clusters.begin(); itr != clusters.end(); ++itr) {
+ cluster_names.push_back((*itr)->GetName());
+ (*itr)->CreateNodes();
+ all_clusters.push_back(std::move(*itr));
+ }
+ }
+ }
+
+ for (const std::unique_ptr<Cluster>& cluster : all_clusters) {
+ for (const std::unique_ptr<tensorflow::NodeDef>& src_node :
+ cluster->GetNewNodes()) {
+ // Add it to the output GraphDef.
+ AddNodeToGraph(*src_node, cluster_names, pruned_graph.get());
+ }
+ }
+
+ // Add any node which is not part of a cluster.
+ for (const NodeDef& node : graph_def.node()) {
+ bool node_in_cluster = is_node_in_cluster[node.name()];
+ if (!node_in_cluster) {
+ AddNodeToGraph(node, cluster_names, pruned_graph.get());
+ }
+ }
+
+ if (pruned_graph->node_size() == 0) {
+ return nullptr;
+ } else {
+ return pruned_graph;
+ }
+}
+
+std::unique_ptr<GraphDef> MaybeReplaceCompositeSubgraph(
+ const GraphDef& tf_graph) {
+ SvdfClusterFactory svdf_cluster_factory;
+
+ std::vector<ClusterFactoryInterface*> cluster_factories;
+ cluster_factories.push_back(&svdf_cluster_factory);
+
+ std::unique_ptr<GraphDef> pruned_graph =
+ MaybeResolveClusters(tf_graph, cluster_factories);
+
+ // Copy function definitions
+ *(pruned_graph->mutable_library()) = tf_graph.library();
+ return pruned_graph;
+}
+
+} // end namespace toco
diff --git a/tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_cluster.h b/tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_cluster.h
new file mode 100644
index 0000000000..7d33dd1885
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_cluster.h
@@ -0,0 +1,63 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_GRAPH_MATCHING_RESOLVE_CLUSTER_H
+#define TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_GRAPH_MATCHING_RESOLVE_CLUSTER_H
+
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster.h"
+#include "tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_svdf.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+
+namespace toco {
+
+// Given a graph and a list of cluster factories (cluster_factories), it
+// partitions the graph into clusters, and then collapses each cluster into its
+// corresponding composite op. It generates a new graph using the newly
+// generated composite ops. Each cluster factory is responsible for recognizing
+// a group of nodes as a cluster using a name-based pattern matching approach.
+std::unique_ptr<tensorflow::GraphDef> MaybeResolveClusters(
+ const tensorflow::GraphDef& graph_def,
+ const std::vector<ClusterFactoryInterface*>& cluster_factories);
+
+// Adds a node to a given graph. The added node will be a copy of a given source
+// node, except for the inputs. If the inputs are coming from a node which
+// belongs to another cluster, then those inputs are renamed to the source
+// cluster name.
+void AddNodeToGraph(const tensorflow::NodeDef& node,
+ const std::vector<string>& cluster_names,
+ tensorflow::GraphDef* graph);
+
+// Given a graph and a cluster class, it finds all the nodes which belong to a
+// given class factory, encapsulate them inside a cluster of the given type and
+// returns a vector of those clusters. It also labels the nodes in that graph if
+// they belong to the generated clusters.
+bool FindCluster(const ClusterFactoryInterface& cluster_factory,
+ const tensorflow::GraphDef& graph_def,
+ std::unordered_map<string, bool>* is_node_in_cluster,
+ std::vector<std::unique_ptr<Cluster>>* clusters);
+
+// Receives a graph and generates another graph by replacing the cluster of
+// nodes which matches a given composite op. Each composite op is represented
+// using a class factory.
+std::unique_ptr<tensorflow::GraphDef> MaybeReplaceCompositeSubgraph(
+ const tensorflow::GraphDef& tf_graph);
+
+} // end namespace toco
+
+#endif  // TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_GRAPH_MATCHING_RESOLVE_CLUSTER_H
diff --git a/tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_svdf.cc b/tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_svdf.cc
new file mode 100644
index 0000000000..d6a099817c
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_svdf.cc
@@ -0,0 +1,285 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_svdf.h"
+
+#include <ctype.h>
+#include <stddef.h>
+#include <algorithm>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "google/protobuf/map.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster.h"
+#include "tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster_utils.h"
+#include "tensorflow/contrib/lite/toco/toco_port.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/tensor.pb.h"
+#include "tensorflow/core/framework/tensor_shape.pb.h"
+#include "tensorflow/core/platform/logging.h"
+
+using tensorflow::GraphDef;
+using tensorflow::NodeDef;
+
+namespace toco {
+
+namespace {
+
+// Receives a vector of cluster nodes and returns only those which are array
+// partitions (of type 'Const' and with the pattern 'part_<.*>' in their name).
+// Since these nodes are connected to a Concatenate node, it makes sure the
+// axis value input of the Concatenate operator is 0.
+void FilterPartitionedConstNodes(
+ const string& const_pattern,
+ const std::vector<const NodeDef*>& cluster_nodes,
+ std::vector<const NodeDef*>* const_node_parts) {
+ for (const NodeDef* node : cluster_nodes) {
+ string node_name_to_upper = node->name();
+ std::transform(node_name_to_upper.begin(), node_name_to_upper.end(),
+ node_name_to_upper.begin(), ::toupper);
+ if (StrContains(node->name(), const_pattern) && node->op() == "Const") {
+ if (StrContains(node_name_to_upper, "/PART_")) {
+ const_node_parts->push_back(node);
+ } else if (StrContains(node->name(), "AXIS") &&
+ StrContains(node->name(), "CONCAT")) {
+        // For now only supporting Concatenate on Axis 0
+ const auto& value_attr = node->attr().at("value");
+ const tensorflow::TensorProto& tensor = value_attr.tensor();
+ CHECK_EQ(tensor.int_val(0), 0);
+ }
+ }
+ }
+ sort(const_node_parts->begin(), const_node_parts->end(),
+ [](const NodeDef* a, const NodeDef* b) {
+ return (a->name().compare(b->name()) < 0 &&
+ (a->name().size() < b->name().size()));
+ });
+}
+
+} // namespace
+
+// SvdfCluster methods
+
+int SvdfCluster::InferFilterRank() {
+ for (const NodeDef* node : nodes_) {
+ if (StrContains(node->name(), "Reshape/shape")) {
+ const auto& value_attr = node->attr().at("value");
+ const tensorflow::TensorProto& tensor = value_attr.tensor();
+ std::vector<int32> shape_values(
+ tensor.tensor_content().size() / sizeof(int), 0);
+ port::CopyToBuffer(tensor.tensor_content(),
+ reinterpret_cast<char*>(shape_values.data()));
+ CHECK_EQ(shape_values.size(), 3);
+ // shape_value array is arranged as:
+ // [num_units, rank, -1]
+ CHECK_EQ(shape_values[2], -1);
+ return shape_values[1];
+ }
+ }
+ return -1;
+}
+
+void SvdfCluster::CreateNodes() {
+ for (const string& const_pattern : const_node_patterns_) {
+ CreateConstNode(const_pattern);
+ }
+ std::unique_ptr<tensorflow::NodeDef> svdf_node(new NodeDef);
+ svdf_node->set_op("Svdf");
+ svdf_node->set_name(name_);
+ svdf_node->set_device(device_);
+
+ // Add the main input.
+ svdf_node->add_input(inputs_[0]);
+
+ // Add the rest of the inputs to Svdf cell: weights and bias.
+ CHECK(new_nodes_.size() == 3 || new_nodes_.size() == 2);
+ string* weights_feature_input = svdf_node->add_input();
+ string* weights_time_input = svdf_node->add_input();
+ string* bias_input;
+ if (new_nodes_.size() == 3) {
+ bias_input = svdf_node->add_input();
+ }
+ for (const std::unique_ptr<tensorflow::NodeDef>& node : new_nodes_) {
+ const string node_name = node->name();
+ if (StrContains(node_name, "SVDF_weights_feature")) {
+ *weights_feature_input = node_name;
+ } else if (StrContains(node_name, "SVDF_weights_time")) {
+ *weights_time_input = node_name;
+ } else if (StrContains(node_name, "SVDF_bias")) {
+ CHECK(bias_input) << "Bias input cannot be provided when there are only "
+ "two Const input nodes!";
+ *bias_input = node_name;
+ } else {
+ // Unexpected input for Svdf op.
+ LOG(FATAL) << "Unexpected input node for SVDF op! Accepted inputs are: "
+ "weights_feature, weights_time and bias.";
+ }
+ }
+ const int rank = InferFilterRank();
+ CHECK_GT(rank, 0);
+
+ // Add Svdf activation and rank.
+ string activation_function =
+ StrContains(outputs_[0], "Relu") ? "Relu" : "None";
+ (*svdf_node->mutable_attr())["ActivationFunction"].set_s(activation_function);
+ (*svdf_node->mutable_attr())["Rank"].set_i(rank);
+
+ // Finally add it to the list of the newly created nodes.
+ new_nodes_.push_back(std::move(svdf_node));
+}
+
+void SvdfCluster::CreateConstNode(const string& const_pattern) {
+ // Find the nodes with pattern like: "const_pattern"/part_xxx of type Const.
+ std::vector<const NodeDef*> const_node_parts;
+ FilterPartitionedConstNodes(const_pattern, nodes_, &const_node_parts);
+
+ if (const_node_parts.empty()) return;
+
+ bool transpose_tensor_value =
+ StrContains(const_pattern, "SVDF_weights_feature");
+
+ // Merge them if necessary.
+ std::unique_ptr<tensorflow::NodeDef> merged_node(new NodeDef);
+ MaybeMergeConstNodes(const_node_parts, transpose_tensor_value, merged_node);
+ new_nodes_.push_back(std::move(merged_node));
+}
+
+void SvdfCluster::MaybeMergeConstNodes(
+ const std::vector<const NodeDef*>& const_node_parts,
+ bool transpose_tensor_value,
+ const std::unique_ptr<tensorflow::NodeDef>& merged_node) {
+ merged_node->set_name(const_node_parts[0]->name());
+ merged_node->set_op("Const");
+ merged_node->set_device(const_node_parts[0]->device());
+ (*merged_node->mutable_attr())["dtype"].set_type(
+ const_node_parts[0]->attr().at("dtype").type());
+
+ // Figuring out Value attribute for the merged node.
+ // Assuming the partitioning is done on Axis 0.
+ // The attributes which are inferred:
+ // * Shape and dimensions
+ // * Float content values
+
+ // Inferring shape and dimension
+ int dim0_size = 0;
+ int dim1_size = 1;
+ tensorflow::TensorProto* allocated_tensor =
+ (*merged_node->mutable_attr())["value"].mutable_tensor();
+ tensorflow::TensorShapeProto* allocated_tensor_shape =
+ allocated_tensor->mutable_tensor_shape();
+ auto tensor_shape_dim0 = allocated_tensor_shape->add_dim();
+ int allocated_content_flat_size = 0;
+ for (int i = 0; i < const_node_parts.size(); i++) {
+ const auto& value_attr = const_node_parts[i]->attr().at("value");
+ const tensorflow::TensorProto& tensor = value_attr.tensor();
+ if (i == 0) {
+ allocated_tensor->set_dtype(tensor.dtype());
+ } else {
+ CHECK_EQ(allocated_tensor->dtype(), tensor.dtype());
+ }
+ allocated_content_flat_size += tensor.tensor_content().size();
+ CHECK(tensor.has_tensor_shape());
+ const tensorflow::TensorShapeProto shape = tensor.tensor_shape();
+ dim0_size += shape.dim(0).size();
+ for (int d = 1; d < shape.dim_size(); d++) {
+ if (i == 0) {
+ allocated_tensor_shape->add_dim()->set_size(shape.dim(d).size());
+ allocated_tensor_shape->set_unknown_rank(shape.unknown_rank());
+ dim1_size *= shape.dim(d).size();
+ } else {
+ CHECK_EQ(shape.dim(d).size(), allocated_tensor_shape->dim(d).size());
+ CHECK_EQ(allocated_tensor_shape->unknown_rank(), shape.unknown_rank());
+ }
+ }
+ }
+
+ // Copying the float content from each array partition.
+ std::unique_ptr<char[]> allocated_content(
+ new char[allocated_content_flat_size]);
+ char* content_ptr = allocated_content.get();
+ for (int i = 0; i < const_node_parts.size(); i++) {
+ const auto& value_attr = const_node_parts[i]->attr().at("value");
+ const tensorflow::TensorProto& tensor = value_attr.tensor();
+ port::CopyToBuffer(tensor.tensor_content(), content_ptr);
+ content_ptr += tensor.tensor_content().size();
+ }
+
+ // Transpose the tensor if needed.
+ if (transpose_tensor_value) {
+    // We use dimension 0 as the row size of the tensor.
+    // We use the product of the remaining dimension sizes as the column size
+    // of the tensor.
+ std::unique_ptr<float[]> transposed_tensor(
+ new float[dim0_size * dim1_size]);
+ Transpose2DTensor(reinterpret_cast<float*>(allocated_content.get()),
+ dim0_size, dim1_size, transposed_tensor.get());
+ allocated_tensor_shape->clear_dim();
+ allocated_tensor_shape->add_dim()->set_size(dim1_size);
+ allocated_tensor_shape->add_dim()->set_size(dim0_size);
+
+ // Set the tensor attributes.
+ allocated_tensor->set_tensor_content(
+ string(reinterpret_cast<const char*>(transposed_tensor.get()),
+ allocated_content_flat_size));
+ } else {
+ tensor_shape_dim0->set_size(dim0_size);
+
+ // Set the tensor attributes.
+ allocated_tensor->set_tensor_content(
+ string(reinterpret_cast<const char*>(allocated_content.get()),
+ allocated_content_flat_size));
+ }
+}
+
+// SvdfClusterFactory methods
+
+std::unique_ptr<Cluster> SvdfClusterFactory::CreateCluster(
+ const NodeDef& node, const GraphDef& graph_def) const {
+ std::vector<string> node_patterns = {"SVDF_weights_feature",
+ "SVDF_weights_time", "SVDF_bias"};
+
+ string node_name_to_upper = node.name();
+ std::transform(node_name_to_upper.begin(), node_name_to_upper.end(),
+ node_name_to_upper.begin(), ::toupper);
+ std::unique_ptr<SvdfCluster> cluster = nullptr;
+ if (node_name_to_upper.find("SVDF", 0) != string::npos) {
+ size_t weights_pos = node.name().find(node_patterns[0]);
+ if (weights_pos != string::npos) {
+ // Assuming the node name has a pattern like:
+ // "SOMESTRING1/CELLNAME/SEARCH_PATTERN/SOMESTRING2", we use
+ // CELLNAME as the cluster name.
+ size_t cell_pos = node.name().rfind("/", weights_pos - 2) + 1;
+ string cell_name =
+ node.name().substr(cell_pos, weights_pos - cell_pos - 1);
+ cluster = std::unique_ptr<SvdfCluster>(new SvdfCluster);
+ cluster->SetName(cell_name);
+ cluster->SetDevice(node.device());
+ cluster->SetGraphDefInfo(&graph_def);
+ CHECK(cluster->FindClusterInputsAndOutputs());
+
+ for (const string& const_pattern : node_patterns) {
+ cluster->AddConstNodePattern(const_pattern);
+ }
+ }
+ }
+ return std::move(cluster);
+}
+
+} // end namespace toco
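To make the name-based matching in SvdfClusterFactory::CreateCluster above concrete, a small Python sketch (purely illustrative, with a hypothetical node name) of how the cluster name is carved out of a node name of the form "SOMESTRING1/CELLNAME/SEARCH_PATTERN/SOMESTRING2":

    def infer_cell_name(node_name, pattern="SVDF_weights_feature"):
      """Illustrative sketch of the cell-name extraction in CreateCluster."""
      weights_pos = node_name.find(pattern)
      if weights_pos == -1:
        return None
      # Step back past the '/' preceding the pattern to the start of CELLNAME.
      cell_pos = node_name.rfind("/", 0, weights_pos - 1) + 1
      return node_name[cell_pos:weights_pos - 1]

    # Hypothetical node name following the documented pattern:
    print(infer_cell_name("model/svdf_cell/SVDF_weights_feature/part_0"))  # -> "svdf_cell"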
diff --git a/tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_svdf.h b/tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_svdf.h
new file mode 100644
index 0000000000..c4c6c34117
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_svdf.h
@@ -0,0 +1,82 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_GRAPH_MATCHING_RESOLVE_SVDF_H
+#define TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_GRAPH_MATCHING_RESOLVE_SVDF_H
+
+#include <string>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster.h"
+#include "tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster_utils.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+
+namespace toco {
+
+class SvdfCluster : public Cluster {
+ public:
+  // Collapses all the nodes of this cluster (in nodes_) into a composite op
+  // and returns all the newly generated ops in new_nodes_.
+ void CreateNodes() override;
+
+  // A helper function to add a pattern for Const nodes which CreateNodes()
+  // should handle specially.
+ void AddConstNodePattern(const string& const_pattern) {
+ const_node_patterns_.push_back(const_pattern);
+ }
+
+ virtual ~SvdfCluster() {}
+
+ private:
+ // The main function which is used to create Const nodes for this cluster.
+ // These Const nodes are the inputs to the composite op generated for this
+ // cluster.
+ void CreateConstNode(const string& const_pattern);
+
+  // Receives a vector of Const nodes, merges them (if necessary) and returns
+  // a single Const node holding all the arrays' contents, transposing it if
+  // needed.
+ void MaybeMergeConstNodes(
+ const std::vector<const tensorflow::NodeDef*>& const_node_parts,
+ bool transpose_tensor_value,
+ const std::unique_ptr<tensorflow::NodeDef>& merged_node);
+
+  // Infers the SVDF filter rank by looking up the reshape operator applied to
+  // the 'output', which reshapes it from [num_filters, batch, 1] to
+  // [num_units, rank, batch]. The second shape element is the rank.
+ int InferFilterRank();
+
+ std::vector<string> const_node_patterns_;
+};
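A minimal sketch of the rank lookup that InferFilterRank describes, assuming the reshape's target shape has already been read out of its Const shape node (e.g. {10, 2, -1} in the unit test below):

    #include <cassert>
    #include <vector>

    // The target shape is [num_units, rank, batch]; the rank is the second
    // element. {10, 1, -1} yields rank 1 and {10, 2, -1} yields rank 2.
    int RankFromReshapeShape(const std::vector<int>& shape) {
      assert(shape.size() == 3);
      return shape[1];
    }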
+
+class SvdfClusterFactory : public ClusterFactoryInterface {
+ public:
+  // Creates a cluster of nodes using a name-based pattern-matching approach:
+  // a node is used as a seed and, if its name matches a certain pattern, the
+  // cluster is built around that node.
+  // This factory expects nodes whose names contain the "SVDF_weights_feature"
+  // and "SVDF_weights_time" patterns (and optionally "SVDF_bias") and creates
+  // an SVDF op from them.
+ std::unique_ptr<Cluster> CreateCluster(
+ const tensorflow::NodeDef& node,
+ const tensorflow::GraphDef& graph_def) const;
+};
+
+} // end namespace toco
+
+#endif // TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_GRAPH_MATCHING_RESOLVE_SVDF_H
diff --git a/tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_svdf_test.cc b/tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_svdf_test.cc
new file mode 100644
index 0000000000..664e828c19
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_svdf_test.cc
@@ -0,0 +1,212 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_svdf.h"
+
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster.h"
+#include "tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster_utils.h"
+#include "tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_cluster.h"
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/tensor.pb.h"
+#include "tensorflow/core/framework/tensor_shape.pb.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/platform/logging.h"
+
+using tensorflow::GraphDef;
+using tensorflow::NodeDef;
+
+namespace toco {
+
+class ResolveSvdfTest : public ::testing::Test {
+ public:
+ ResolveSvdfTest() {
+ AddNewNode("Input1", "Const", {});
+ AddNewNode("Svdf1/SVDF_weights_feature/part_0", "Const", {},
+ {0.1, 0.2, 0.3});
+ AddNewNode("Svdf1/SVDF_weights_feature/part_0/read", "Identity",
+ {"Svdf1/SVDF_weights_feature/part_0"});
+ AddNewNode("Svdf1/SVDF_weights_time/part_0", "Const", {}, {0.1, 0.2, 0.3});
+ AddNewNode("Svdf1/SVDF_weights_time/part_0/read", "Identity",
+ {"Svdf1/SVDF_weights_time/part_0"});
+
+ AddNewNode("Svdf1/f1", "SVDF_F1",
+ {"Input1", "Svdf1/SVDF_weights_feature/part_0/read"});
+ AddNewNode("Svdf1/f2", "SVDF_F2",
+ {"Svdf1/SVDF_weights_time/part_0/read", "Svdf1/f1"});
+ AddNewNode("Svdf1/Relu", "Relu", {"Svdf1/f2"});
+ AddShapeNode("Svdf1/Reshape/shape", {10, 1, -1});
+ AddNewNode("Output1", "Const", {"Svdf1/Relu"});
+
+ AddNewNode("Input2", "Const", {});
+ AddNewNode("Svdf2/SVDF_weights_feature/part_0", "Const", {},
+ {0.1, 0.2, 0.3});
+ AddNewNode("Svdf2/SVDF_weights_feature/part_0/read", "Identity",
+ {"Svdf2/SVDF_weights_feature/part_0"});
+ AddNewNode("Svdf2/SVDF_weights_time/part_0", "Const", {}, {0.1, 0.2, 0.3});
+ AddNewNode("Svdf2/SVDF_weights_time/part_0/read", "Identity",
+ {"Svdf2/SVDF_weights_time/part_0"});
+
+ AddNewNode("Svdf2/f1", "SVDF_F1",
+ {"Input1", "Svdf2/SVDF_weights_feature/part_0/read"});
+ AddNewNode("Svdf2/f2", "SVDF_F2",
+ {"Svdf2/SVDF_weights_time/part_0/read", "Svdf2/f1"});
+ AddNewNode("Svdf2/Relu", "Relu", {"Svdf2/f2"});
+ AddShapeNode("Svdf2/Reshape/shape", {10, 2, -1});
+ AddNewNode("Output2", "Const", {"Svdf2/Relu"});
+ }
+
+ ~ResolveSvdfTest() override {}
+
+ protected:
+ void AddNewNode(const string& name, const string& op,
+ const std::vector<string>& inputs) {
+ NodeDef* node = graph_.add_node();
+ node->set_name(name);
+ node->set_op(op);
+ node->set_device("");
+ for (int i = 0; i < inputs.size(); i++) {
+ node->add_input();
+ node->set_input(i, inputs[i]);
+ }
+ }
+
+ void AddNewNode(const string& name, const string& op,
+ const std::vector<string>& inputs,
+ const std::vector<float>& values) {
+ NodeDef* node = graph_.add_node();
+ node->set_name(name);
+ node->set_op(op);
+ node->set_device("");
+ for (int i = 0; i < inputs.size(); i++) {
+ node->add_input();
+ node->set_input(i, inputs[i]);
+ }
+ // Add the float vector as an attribute to the node.
+ (*node->mutable_attr())["dtype"].set_type(tensorflow::DT_FLOAT);
+ tensorflow::TensorProto* allocated_tensor = new tensorflow::TensorProto;
+    tensorflow::TensorShapeProto* allocated_tensor_shape =
+        new tensorflow::TensorShapeProto;
+    auto tensor_shape_dim0 = allocated_tensor_shape->add_dim();
+    tensor_shape_dim0->set_size(values.size());
+    allocated_tensor->set_allocated_tensor_shape(allocated_tensor_shape);
+ allocated_tensor->set_tensor_content(
+ string(reinterpret_cast<const char*>(values.data()),
+ values.size() * sizeof(float)));
+ (*node->mutable_attr())["value"].set_allocated_tensor(allocated_tensor);
+ }
+
+ void AddShapeNode(const string& name, const std::vector<int>& values) {
+ NodeDef* node = graph_.add_node();
+ node->set_name(name);
+ node->set_op("Const");
+ node->set_device("");
+    // Add the int vector as an attribute to the node.
+    (*node->mutable_attr())["dtype"].set_type(tensorflow::DT_INT32);
+    tensorflow::TensorProto* allocated_tensor = new tensorflow::TensorProto;
+    tensorflow::TensorShapeProto* allocated_tensor_shape =
+        new tensorflow::TensorShapeProto;
+    auto tensor_shape_dim0 = allocated_tensor_shape->add_dim();
+    tensor_shape_dim0->set_size(values.size());
+    allocated_tensor->set_allocated_tensor_shape(allocated_tensor_shape);
+ allocated_tensor->set_tensor_content(
+ string(reinterpret_cast<const char*>(values.data()),
+ values.size() * sizeof(int)));
+ (*node->mutable_attr())["value"].set_allocated_tensor(allocated_tensor);
+ }
+
+ GraphDef graph_;
+ SvdfClusterFactory svdf_cluster_factory_;
+ std::vector<std::unique_ptr<Cluster>> clusters_;
+};
+
+TEST_F(ResolveSvdfTest, TestTranspose2DTensor) {
+ static float matrix[] = {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.};
+ static float expected_transposed_matrix[] = {1., 5., 9., 2., 6., 10.,
+ 3., 7., 11., 4., 8., 12.};
+ float* transposed_matrix = new float[12];
+ Transpose2DTensor(matrix, 3, 4, transposed_matrix);
+
+ std::vector<float> actual;
+ actual.insert(
+ actual.end(), transposed_matrix,
+ transposed_matrix + sizeof(expected_transposed_matrix) / sizeof(float));
+  std::vector<float> expected;
+  expected.insert(expected.end(), expected_transposed_matrix,
+                  expected_transposed_matrix +
+                      sizeof(expected_transposed_matrix) / sizeof(float));
+  EXPECT_THAT(actual, testing::ElementsAreArray(expected));
+  delete[] transposed_matrix;
+}
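The expected values above correspond to a plain row-major transpose; a plausible reference implementation (the real Transpose2DTensor is defined elsewhere in this change) would be:

    // Transposes a row-major `rows` x `cols` matrix into `out`, which must
    // hold rows * cols floats.
    void Transpose2DTensorSketch(const float* in, int rows, int cols,
                                 float* out) {
      for (int r = 0; r < rows; ++r) {
        for (int c = 0; c < cols; ++c) {
          out[c * rows + r] = in[r * cols + c];
        }
      }
    }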
+
+TEST_F(ResolveSvdfTest, TestResolveSvdfFlow) {
+ std::unordered_map<string, bool> is_node_in_cluster;
+ for (const NodeDef& node : graph_.node()) {
+ is_node_in_cluster[node.name()] = false;
+ }
+
+ std::vector<string> cluster_names;
+ CHECK(FindCluster(svdf_cluster_factory_, graph_, &is_node_in_cluster,
+ &clusters_));
+
+ for (const std::unique_ptr<Cluster>& cluster : clusters_) {
+ cluster_names.push_back(cluster->GetName());
+ cluster->CreateNodes();
+ }
+
+ EXPECT_THAT(cluster_names,
+ testing::UnorderedElementsAreArray({"Svdf1", "Svdf2"}));
+
+ std::vector<string> new_node_names;
+ std::vector<float> content_array(3);
+ for (const std::unique_ptr<Cluster>& cluster : clusters_) {
+ // After CreateNodes in each cluster we have three nodes: Svdf,
+ // weights_feature and weights_time.
+ CHECK_EQ(cluster->GetNewNodes().size(), 3);
+ for (const std::unique_ptr<tensorflow::NodeDef>& node :
+ cluster->GetNewNodes()) {
+ new_node_names.push_back(node->name());
+ if (node->op() == "Const") {
+ CHECK_EQ(node->attr().at("dtype").type(), tensorflow::DT_FLOAT);
+ toco::port::CopyToBuffer(
+ node->attr().at("value").tensor().tensor_content(),
+ reinterpret_cast<char*>(content_array.data()));
+ EXPECT_THAT(content_array,
+ testing::UnorderedElementsAreArray({0.1, 0.2, 0.3}));
+ } else {
+ // Checking the Svdf node attributes (rank and activation type) are
+ // correct.
+ if (node->name() == "Svdf1") {
+ CHECK_EQ(node->attr().at("Rank").i(), 1);
+ } else if (node->name() == "Svdf2") {
+ CHECK_EQ(node->attr().at("Rank").i(), 2);
+ }
+ CHECK_EQ(node->attr().at("ActivationFunction").s(), "Relu");
+ }
+ }
+ }
+ EXPECT_THAT(new_node_names, testing::UnorderedElementsAreArray(
+ {"Svdf2/SVDF_weights_feature/part_0",
+ "Svdf2/SVDF_weights_time/part_0", "Svdf2",
+ "Svdf1/SVDF_weights_feature/part_0",
+ "Svdf1/SVDF_weights_time/part_0", "Svdf1"}));
+}
+
+} // end namespace toco
diff --git a/tensorflow/contrib/lite/toco/tensorflow_util.cc b/tensorflow/contrib/lite/toco/tensorflow_util.cc
new file mode 100644
index 0000000000..82e2800ca2
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/tensorflow_util.cc
@@ -0,0 +1,197 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/toco/tensorflow_util.h"
+
+#include <string.h>
+#include <memory>
+#include <set>
+
+#ifdef GOOGLE_PLATFORM
+#include "file/logging/log_lines.h"
+#endif
+#include "google/protobuf/map.h"
+#include "absl/strings/str_split.h"
+#include "absl/strings/string_view.h"
+#include "tensorflow/contrib/lite/toco/toco_port.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/tensor.pb.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+using tensorflow::AttrValue;
+using tensorflow::GraphDef;
+
+void LogDumpGraphDef(int log_level, const string& message,
+ const GraphDef& tf_graph) {
+ if (!VLOG_IS_ON(log_level)) {
+ return;
+ }
+ std::set<string> ops;
+ for (const auto& node : tf_graph.node()) {
+ ops.insert(node.op());
+ }
+ string dump;
+ toco::port::AppendF(&dump, R"MSG(
+BEGIN DUMP OF TENSORFLOW GRAPHDEF (%s)
+There are %d nodes.
+There are %zu different op types:
+)MSG", message, tf_graph.node_size(), ops.size());
+ for (const auto& op : ops) {
+ toco::port::AppendF(&dump, " %s\n", op);
+ }
+ dump.append(R"MSG(
+PROTO DUMP
+)MSG");
+ for (const auto& node : tf_graph.node()) {
+ toco::port::AppendF(&dump, R"MSG(
+BEGIN NODE: name = %s
+ op = %s
+ inputs = [
+)MSG", node.name(), node.op());
+ for (const auto& input : node.input()) {
+ toco::port::AppendF(&dump, " %s\n", input);
+ }
+ dump.append(" ]\n");
+ for (const auto& attr : node.attr()) {
+ toco::port::AppendF(&dump, " ATTR: name = %s\n", attr.first);
+ if (attr.second.value_case() == AttrValue::kFunc) {
+ dump.append(" func\n");
+ } else if (attr.second.value_case() == AttrValue::kPlaceholder) {
+ toco::port::AppendF(&dump, " placeholder: %s\n",
+ attr.second.placeholder());
+ } else if (attr.second.value_case() == AttrValue::kS) {
+ dump.append(" string:\n");
+ dump.append(R"MSG(
+ BEGIN EMBEDDED STRING
+)MSG");
+ const auto& lines = absl::StrSplit(attr.second.s(), '\n');
+ for (const auto& line : lines) {
+ toco::port::AppendF(&dump, " %s\n", line);
+ }
+ dump.append(R"MSG(
+ END EMBEDDED STRING
+)MSG");
+ } else if (attr.second.value_case() == AttrValue::kI) {
+ toco::port::AppendF(&dump, " int: %lld\n", attr.second.i());
+ } else if (attr.second.value_case() == AttrValue::kF) {
+ toco::port::AppendF(&dump, " float: %g\n", attr.second.f());
+ } else if (attr.second.value_case() == AttrValue::kB) {
+ toco::port::AppendF(&dump, " bool: %s\n",
+ attr.second.b() ? "true" : "false");
+ } else if (attr.second.value_case() == AttrValue::kType) {
+ toco::port::AppendF(&dump, " type: %s\n",
+ tensorflow::DataType_Name(attr.second.type()));
+ } else if (attr.second.value_case() == AttrValue::kShape) {
+ dump.append(" shape: [ ");
+ const auto& shape = attr.second.shape();
+ for (int i = 0; i < shape.dim_size(); i++) {
+ toco::port::AppendF(&dump, "%lld ", shape.dim(i).size());
+ }
+ dump.append("]\n");
+ } else if (attr.second.value_case() == AttrValue::kTensor) {
+ const auto& tensor = attr.second.tensor();
+ dump.append(" TENSOR:\n");
+ toco::port::AppendF(&dump, " type: %s\n",
+ tensorflow::DataType_Name(tensor.dtype()));
+ const auto& shape = tensor.tensor_shape();
+ dump.append(" shape: [ ");
+ for (int i = 0; i < shape.dim_size(); i++) {
+ toco::port::AppendF(&dump, "%lld ", shape.dim(i).size());
+ }
+ dump.append("]\n");
+ if (!tensor.tensor_content().empty()) {
+ toco::port::AppendF(&dump, " tensor_content: %zu bytes\n",
+ tensor.tensor_content().size());
+ }
+ if (tensor.dtype() == tensorflow::DT_INT32) {
+ CHECK_EQ(0, tensor.tensor_content().size() % sizeof(int32));
+ const int size = tensor.tensor_content().size() / sizeof(int32);
+ std::vector<int32> data(size);
+ toco::port::CopyToBuffer(tensor.tensor_content(),
+ reinterpret_cast<char*>(data.data()));
+ const int kMaxValsToPrint = 4;
+ dump.append(" tensor_content as ints: [ ");
+ for (int i = 0; i < kMaxValsToPrint && i < size; i++) {
+ toco::port::AppendF(&dump, "%d ", data[i]);
+ }
+ if (size > kMaxValsToPrint) {
+ dump.append("... ");
+ }
+ dump.append("]\n");
+ }
+ if (tensor.dtype() == tensorflow::DT_FLOAT) {
+ CHECK_EQ(0, tensor.tensor_content().size() % sizeof(float));
+ const int size = tensor.tensor_content().size() / sizeof(float);
+ std::vector<float> data(size);
+ toco::port::CopyToBuffer(tensor.tensor_content(),
+ reinterpret_cast<char*>(data.data()));
+ const int kMaxValsToPrint = 4;
+ dump.append(" tensor_content as floats: [ ");
+ for (int i = 0; i < kMaxValsToPrint && i < size; i++) {
+ toco::port::AppendF(&dump, "%g ", data[i]);
+ }
+ if (size > kMaxValsToPrint) {
+ dump.append("... ");
+ }
+ dump.append("]\n");
+ }
+ if (tensor.int_val_size()) {
+ toco::port::AppendF(&dump, " int_val: %d ints: [ ",
+ tensor.int_val_size());
+ const int kMaxValsToPrint = 4;
+ for (int i = 0; i < kMaxValsToPrint && i < tensor.int_val_size();
+ i++) {
+ toco::port::AppendF(&dump, "%d ", tensor.int_val(i));
+ }
+ if (tensor.int_val_size() > kMaxValsToPrint) {
+ dump.append("... ");
+ }
+ dump.append("]\n");
+ }
+ if (tensor.float_val_size()) {
+ toco::port::AppendF(&dump, " float_val: %d floats: [ ",
+ tensor.float_val_size());
+ const int kMaxValsToPrint = 4;
+ for (int i = 0; i < kMaxValsToPrint && i < tensor.float_val_size();
+ i++) {
+ toco::port::AppendF(&dump, "%g ", tensor.float_val(i));
+ }
+ if (tensor.float_val_size() > kMaxValsToPrint) {
+ dump.append("... ");
+ }
+ dump.append("]\n");
+ }
+ if (tensor.string_val_size()) {
+ toco::port::AppendF(&dump, " string_val: %d strings\n",
+ tensor.string_val_size());
+ }
+ } else if (attr.second.value_case() == AttrValue::kList) {
+ dump.append(" LIST\n");
+ }
+ }
+ dump.append("END NODE\n");
+ }
+ toco::port::AppendF(&dump, "END DUMP OF TENSORFLOW GRAPHDEF (%s)\n", message);
+#if defined(GOOGLE_PLATFORM)
+ VLOG_LINES(log_level, dump);
+#else
+ VLOG(log_level) << dump;
+#endif
+}
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/tensorflow_util.h b/tensorflow/contrib/lite/toco/tensorflow_util.h
new file mode 100644
index 0000000000..152b4f7a72
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/tensorflow_util.h
@@ -0,0 +1,32 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_UTIL_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_UTIL_H_
+
+#include <string>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/tensor_shape.pb.h"
+
+namespace toco {
+
+void LogDumpGraphDef(int log_level, const string& message,
+ const tensorflow::GraphDef& tf_graph);
+
+} // namespace toco
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_UTIL_H_
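A minimal usage sketch (the wrapper function name is illustrative); the dump is only produced when the requested vlog level is enabled:

    #include "tensorflow/contrib/lite/toco/tensorflow_util.h"
    #include "tensorflow/core/framework/graph.pb.h"

    void DumpAfterImport(const tensorflow::GraphDef& graph_def) {
      // Emits the full per-node dump only if vlog level 1 is enabled.
      toco::LogDumpGraphDef(1, "AFTER IMPORT", graph_def);
    }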
diff --git a/tensorflow/contrib/lite/toco/tflite/BUILD b/tensorflow/contrib/lite/toco/tflite/BUILD
new file mode 100644
index 0000000000..e910e3957f
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/tflite/BUILD
@@ -0,0 +1,142 @@
+licenses(["notice"]) # Apache 2.0
+
+load(
+ "//tensorflow:tensorflow.bzl",
+ "tf_cc_test",
+)
+
+cc_library(
+ name = "operator",
+ srcs = [
+ "operator.cc",
+ ],
+ hdrs = [
+ "builtin_operator.h",
+ "custom_operator.h",
+ "operator.h",
+ "simple_operator.h",
+ ],
+ deps = [
+ ":types",
+ "//tensorflow/contrib/lite/schema:schema_fbs",
+ "//tensorflow/contrib/lite/toco:model",
+ "//tensorflow/core:protos_all_cc",
+ "@com_google_absl//absl/memory",
+ "@flatbuffers//:flatbuffers",
+ ],
+)
+
+tf_cc_test(
+ name = "operator_test",
+ srcs = [
+ "operator_test.cc",
+ ],
+ deps = [
+ ":operator",
+ "//tensorflow/contrib/lite/toco:tooling_util",
+ "//tensorflow/core:protos_all_cc",
+ "@com_google_googletest//:gtest_main",
+ "@flatbuffers//:flatbuffers",
+ ],
+)
+
+cc_library(
+ name = "types",
+ srcs = [
+ "types.cc",
+ ],
+ hdrs = [
+ "types.h",
+ ],
+ deps = [
+ "//tensorflow/contrib/lite/schema:schema_fbs",
+ "//tensorflow/contrib/lite/toco:model",
+ ],
+)
+
+tf_cc_test(
+ name = "types_test",
+ srcs = [
+ "types_test.cc",
+ ],
+ deps = [
+ ":types",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_library(
+ name = "export",
+ srcs = [
+ "export.cc",
+ ],
+ hdrs = [
+ "export.h",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":operator",
+ ":types",
+ "//tensorflow/contrib/lite:schema_fbs_version",
+ "//tensorflow/contrib/lite/schema:schema_fbs",
+ "//tensorflow/contrib/lite/toco:model",
+ "//tensorflow/contrib/lite/toco:tooling_util",
+ "@com_google_absl//absl/strings",
+ "@flatbuffers//:flatbuffers",
+ ],
+)
+
+tf_cc_test(
+ name = "export_test",
+ srcs = [
+ "export_test.cc",
+ ],
+ deps = [
+ ":export",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_library(
+ name = "import",
+ srcs = [
+ "import.cc",
+ ],
+ hdrs = [
+ "import.h",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":operator",
+ ":types",
+ "//tensorflow/contrib/lite/schema:schema_fbs",
+ "//tensorflow/contrib/lite/toco:model",
+ "@flatbuffers//:flatbuffers",
+ ],
+)
+
+tf_cc_test(
+ name = "import_test",
+ srcs = [
+ "import_test.cc",
+ ],
+ deps = [
+ ":import",
+ "//tensorflow/contrib/lite:schema_fbs_version",
+ "//tensorflow/contrib/lite/schema:schema_fbs",
+ "@com_google_googletest//:gtest_main",
+ "@flatbuffers//:flatbuffers",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/contrib/lite/toco/tflite/builtin_operator.h b/tensorflow/contrib/lite/toco/tflite/builtin_operator.h
new file mode 100644
index 0000000000..93cc79ddb6
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/tflite/builtin_operator.h
@@ -0,0 +1,74 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_BUILTIN_OPERATOR_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_BUILTIN_OPERATOR_H_
+
+#include "absl/memory/memory.h"
+#include "tensorflow/contrib/lite/toco/tflite/operator.h"
+
+namespace toco {
+
+namespace tflite {
+
+// Builtin operators have special TF Lite objects describing their options.
+// This class has the boilerplate code for creating those.
+//
+// Template arguments:
+// - T1 must derive from ::toco::Operator.
+// - T2 must be one of TF Lite's objects defining Builtin Options, such as
+// ::tflite::Conv2DOptions.
+template <typename T1, typename T2, ::tflite::BuiltinOptions TfLiteEnum>
+class BuiltinOperator : public BaseOperator {
+ public:
+ using TocoOperator = T1;
+ using TfLiteOptions = T2;
+
+ BuiltinOperator(::tflite::BuiltinOperator op, OperatorType type)
+ : BaseOperator(::tflite::EnumNameBuiltinOperator(op), type) {}
+
+ // Build the configuration object in the given flatbuffer builder. Return
+ // its offset.
+ virtual flatbuffers::Offset<TfLiteOptions> WriteOptions(
+ const TocoOperator& op,
+ flatbuffers::FlatBufferBuilder* builder) const = 0;
+
+ // Read options from the TF Lite object and set the corresponding values in
+ // the tf.mini operator.
+ virtual void ReadOptions(const TfLiteOptions& opt,
+ TocoOperator* op) const = 0;
+
+ Options Serialize(const Operator& op,
+ flatbuffers::FlatBufferBuilder* builder) const override {
+ auto options = WriteOptions(static_cast<const TocoOperator&>(op), builder);
+ return Options::Builtin(TfLiteEnum, options.Union());
+ }
+
+ std::unique_ptr<Operator> Deserialize(
+ const BuiltinOptions* builtin_options,
+ const CustomOptions* custom_options) const override {
+ auto op = absl::make_unique<TocoOperator>();
+ auto* options = static_cast<const TfLiteOptions*>(builtin_options);
+ if (options) {
+ ReadOptions(*options, op.get());
+ }
+ return std::unique_ptr<Operator>(op.release());
+ }
+};
+
+} // namespace tflite
+
+} // namespace toco
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_BUILTIN_OPERATOR_H_
diff --git a/tensorflow/contrib/lite/toco/tflite/custom_operator.h b/tensorflow/contrib/lite/toco/tflite/custom_operator.h
new file mode 100644
index 0000000000..1a4bfac7d4
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/tflite/custom_operator.h
@@ -0,0 +1,74 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_CUSTOM_OPERATOR_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_CUSTOM_OPERATOR_H_
+
+#include "flatbuffers/flexbuffers.h"
+#include "absl/memory/memory.h"
+#include "tensorflow/contrib/lite/toco/tflite/operator.h"
+
+namespace toco {
+
+namespace tflite {
+
+// Custom operators have a generic byte buffer describing their options. This
+// class provides the boilerplate code for populating those options using
+// flexbuffers. Note that most of toco's operators will likely be supported
+// as builtin operators in TF Lite.
+//
+// Template argument T must derive from ::toco::Operator.
+template <typename T>
+class CustomOperator : public BaseOperator {
+ public:
+ using TocoOperator = T;
+ using BaseOperator::BaseOperator;
+
+ // Populate the given flexbuffer with options obtained from the tf.mini
+ // operator.
+ virtual void WriteOptions(const TocoOperator& op,
+ flexbuffers::Builder* fbb) const {}
+
+ // Set options in the given tf.mini operator using values from the flexbuffer
+ // map.
+ virtual void ReadOptions(const flexbuffers::Map& m, TocoOperator* op) const {}
+
+ Options Serialize(const Operator& op,
+ flatbuffers::FlatBufferBuilder* builder) const override {
+ flexbuffers::Builder fbb;
+ fbb.Map(
+ [&]() { WriteOptions(static_cast<const TocoOperator&>(op), &fbb); });
+ fbb.Finish();
+ return Options::Custom(builder->CreateVector(fbb.GetBuffer()));
+ }
+
+ std::unique_ptr<Operator> Deserialize(
+ const BuiltinOptions* builtin_options,
+ const CustomOptions* custom_options) const override {
+ auto op = absl::make_unique<TocoOperator>();
+ if (custom_options) {
+ auto flexbuffer_map =
+ flexbuffers::GetRoot(custom_options->data(), custom_options->size())
+ .AsMap();
+ ReadOptions(flexbuffer_map, op.get());
+ }
+ return std::unique_ptr<Operator>(op.release());
+ }
+};
+
+} // namespace tflite
+
+} // namespace toco
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_CUSTOM_OPERATOR_H_
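For illustration, a hypothetical subclass (the operator and its "alpha" option are made up) showing how WriteOptions/ReadOptions round-trip a single field through the flexbuffer map:

    // Hypothetical: assumes a toco operator type MyOpOperator with a float
    // member `alpha`.
    class MyOp : public CustomOperator<MyOpOperator> {
     public:
      using CustomOperator::CustomOperator;

      void WriteOptions(const TocoOperator& op,
                        flexbuffers::Builder* fbb) const override {
        fbb->Float("alpha", op.alpha);
      }

      void ReadOptions(const flexbuffers::Map& m,
                       TocoOperator* op) const override {
        op->alpha = m["alpha"].AsFloat();
      }
    };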
diff --git a/tensorflow/contrib/lite/toco/tflite/export.cc b/tensorflow/contrib/lite/toco/tflite/export.cc
new file mode 100644
index 0000000000..beda710614
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/tflite/export.cc
@@ -0,0 +1,322 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/toco/tflite/export.h"
+
+#include "flatbuffers/flexbuffers.h"
+#include "absl/strings/str_join.h"
+#include "tensorflow/contrib/lite/schema/schema_generated.h"
+#include "tensorflow/contrib/lite/toco/tflite/operator.h"
+#include "tensorflow/contrib/lite/toco/tflite/types.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/contrib/lite/version.h"
+
+namespace toco {
+
+namespace tflite {
+
+using ::tflite::Buffer;
+using ::tflite::BuiltinOperator;
+using ::tflite::BuiltinOperator_CUSTOM;
+using ::tflite::BuiltinOperator_MAX;
+using ::tflite::BuiltinOperator_MIN;
+using ::tflite::CreateBuffer;
+using ::tflite::CreateModel;
+using ::tflite::CreateOperator;
+using ::tflite::CreateTensor;
+using ::tflite::Operator;
+using ::tflite::OperatorCode;
+using ::tflite::SubGraph;
+using ::tflite::Tensor;
+using flatbuffers::FlatBufferBuilder;
+using flatbuffers::Offset;
+using flatbuffers::Vector;
+
+namespace {
+
+details::OperatorKey GetOperatorKey(const ::toco::Operator& op) {
+ string custom_code;
+ if (op.type == OperatorType::kTensorFlowUnsupported) {
+ const TensorFlowUnsupportedOperator& unsupported_op =
+ static_cast<const TensorFlowUnsupportedOperator&>(op);
+ custom_code = unsupported_op.tensorflow_op;
+ }
+ return details::OperatorKey(op.type, custom_code);
+}
+
+} // Anonymous namespace.
+
+namespace details {
+
+void LoadTensorsMap(const Model& model, TensorsMap* tensors_map) {
+ // First find a list of unique array names.
+ std::set<string> names;
+ for (const auto& array_pair : model.arrays) {
+ names.insert(array_pair.first);
+ }
+
+ // Now assign indices to them and fill in the map.
+ int index = 0;
+ for (const auto& name : names) {
+ (*tensors_map)[name] = index;
+ ++index;
+ }
+}
+
+void LoadOperatorsMap(const Model& model, OperatorsMap* operators_map) {
+ // First find a list of unique operator types.
+ std::set<OperatorKey> keys;
+ for (const auto& op : model.operators) {
+ keys.insert(GetOperatorKey(*op));
+ }
+ // Now assign indices to them and fill in the map.
+ int index = 0;
+ for (const auto& key : keys) {
+ (*operators_map)[key] = index;
+ ++index;
+ }
+}
+} // namespace details
+
+Offset<Vector<Offset<Tensor>>> ExportTensors(
+ const Model& model, const details::TensorsMap& tensors_map,
+ FlatBufferBuilder* builder, std::vector<const Array*>* buffers_to_write) {
+ // In the end we will need to produce a vector sorted by the indices of the
+ // tensors in the tensors_map.
+ std::map<int, Offset<Tensor>> ordered_tensors;
+
+ for (const auto& array_pair : model.arrays) {
+ const string& tensor_name = array_pair.first;
+ const toco::Array& array = *array_pair.second;
+
+ int buffer_index = buffers_to_write->size();
+ auto type = DataType::Serialize(array.data_type);
+ buffers_to_write->push_back(&array);
+
+ std::vector<int> shape;
+ if (array.has_shape()) {
+ for (int d : array.shape().dims()) {
+ shape.push_back(d);
+ }
+ }
+
+ Offset<Vector<float>> min;
+ Offset<Vector<float>> max;
+ Offset<Vector<float>> scale;
+ Offset<Vector<int64_t>> zero_point;
+ if (array.minmax) {
+ min = builder->CreateVector(
+ std::vector<float>{static_cast<float>(array.minmax->min)});
+ max = builder->CreateVector(
+ std::vector<float>{static_cast<float>(array.minmax->max)});
+ }
+ if (array.quantization_params) {
+ scale = builder->CreateVector(std::vector<float>{
+ static_cast<float>(array.quantization_params->scale)});
+ zero_point = builder->CreateVector(
+ std::vector<int64_t>{array.quantization_params->zero_point});
+ }
+ auto q_param = ::tflite::CreateQuantizationParameters(*builder, min, max,
+ scale, zero_point);
+
+ int index = tensors_map.at(tensor_name);
+ ordered_tensors[index] =
+ CreateTensor(*builder, builder->CreateVector(shape), type, buffer_index,
+ builder->CreateString(tensor_name), q_param);
+ }
+
+ std::vector<Offset<Tensor>> tensor_vector;
+ tensor_vector.reserve(ordered_tensors.size());
+ for (const auto& tensor : ordered_tensors) {
+ tensor_vector.push_back(tensor.second);
+ }
+
+ return builder->CreateVector(tensor_vector);
+}
+
+Offset<Vector<int32_t>> ExportInputTensors(
+ const Model& model, const details::TensorsMap& tensors_map,
+ FlatBufferBuilder* builder) {
+ std::vector<int32_t> inputs;
+ for (const auto& input : model.flags.input_arrays()) {
+ inputs.push_back(tensors_map.at(input.name()));
+ }
+ return builder->CreateVector<int32_t>(inputs);
+}
+
+Offset<Vector<int32_t>> ExportOutputTensors(
+ const Model& model, const details::TensorsMap& tensors_map,
+ FlatBufferBuilder* builder) {
+ std::vector<int32_t> outputs;
+ for (const string& output : model.flags.output_arrays()) {
+ outputs.push_back(tensors_map.at(output));
+ }
+ return builder->CreateVector<int32_t>(outputs);
+}
+
+Offset<Vector<Offset<OperatorCode>>> ExportOperatorCodes(
+ const Model& model,
+ const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type,
+ const details::OperatorsMap& operators_map, FlatBufferBuilder* builder,
+ std::set<string>* error_summary) {
+ // Map from operator name to TF Lite enum value, for all builtins.
+ std::map<string, BuiltinOperator> builtin_ops;
+ for (int i = BuiltinOperator_MIN; i <= BuiltinOperator_MAX; ++i) {
+ BuiltinOperator op = static_cast<BuiltinOperator>(i);
+ string name = EnumNameBuiltinOperator(op);
+ if (op != BuiltinOperator_CUSTOM && !name.empty()) {
+ builtin_ops[name] = op;
+ }
+ }
+
+ // We will need to produce a vector of codes in the same order as they
+ // appear in the operators_map.
+ std::map<int, Offset<OperatorCode>> ordered_opcodes;
+
+ for (const auto& op : model.operators) {
+ const details::OperatorKey operator_key = GetOperatorKey(*op);
+ int op_index = operators_map.at(operator_key);
+
+ if (ops_by_type.count(op->type) == 0) {
+ LOG(FATAL) << "Unsupported operator: " << HelpfulOperatorTypeName(*op);
+ }
+
+ string name = ops_by_type.at(op->type)->name();
+ if (builtin_ops.count(name) > 0) {
+ ordered_opcodes[op_index] =
+ CreateOperatorCode(*builder, builtin_ops[name], 0);
+ } else {
+      // Use the custom operation code if it's available in the OperatorKey.
+ if (!operator_key.custom_code.empty()) {
+ name = operator_key.custom_code;
+ }
+ if (error_summary) {
+ error_summary->insert(name);
+ }
+ ordered_opcodes[op_index] = CreateOperatorCode(
+ *builder, BuiltinOperator_CUSTOM, builder->CreateString(name));
+ }
+ }
+
+ std::vector<Offset<OperatorCode>> opcode_vector;
+ opcode_vector.reserve(ordered_opcodes.size());
+ for (const auto& opcode : ordered_opcodes) {
+ opcode_vector.push_back(opcode.second);
+ }
+
+ return builder->CreateVector(opcode_vector);
+}
+
+Offset<Vector<Offset<Operator>>> ExportOperators(
+ const Model& model,
+ const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type,
+ const details::OperatorsMap& operators_map,
+ const details::TensorsMap& tensors_map, FlatBufferBuilder* builder) {
+ // The operators are in execution order, so we just follow tf.mini order.
+ std::vector<Offset<Operator>> op_vector;
+ for (const auto& op : model.operators) {
+ if (ops_by_type.count(op->type) == 0) {
+ LOG(FATAL) << "Op type '" << OperatorTypeName(op->type)
+ << "' not supported";
+ }
+
+ std::vector<int32_t> inputs;
+ for (const string& input : op->inputs) {
+ inputs.push_back(tensors_map.at(input));
+ }
+
+ std::vector<int32_t> outputs;
+ for (const string& output : op->outputs) {
+ outputs.push_back(tensors_map.at(output));
+ }
+
+ auto options = ops_by_type.at(op->type)->Serialize(*op, builder);
+ int op_index = operators_map.at(GetOperatorKey(*op));
+    // FLEXBUFFERS is currently the only supported CustomOptionsFormat.
+ op_vector.push_back(CreateOperator(
+ *builder, op_index, builder->CreateVector(inputs),
+ builder->CreateVector(outputs), options.type, options.builtin,
+ options.custom, ::tflite::CustomOptionsFormat_FLEXBUFFERS));
+ }
+
+ return builder->CreateVector(op_vector);
+}
+
+Offset<Vector<Offset<Buffer>>> ExportBuffers(
+ const Model& model, const std::vector<const Array*>& buffers_to_write,
+ FlatBufferBuilder* builder) {
+ std::vector<Offset<Buffer>> buffer_vector;
+ size_t index = 0;
+ for (const Array* array_ptr : buffers_to_write) {
+ const Array& array = *array_ptr;
+ Offset<Vector<uint8_t>> data_buffer = DataBuffer::Serialize(array, builder);
+ buffer_vector.push_back(CreateBuffer(*builder, data_buffer));
+ index++;
+ }
+ return builder->CreateVector(buffer_vector);
+}
+
+void Export(const Model& model, bool allow_custom_ops,
+ string* output_file_contents) {
+ flatbuffers::FlatBufferBuilder builder(/*initial_size=*/10240);
+
+ const auto ops_by_type = BuildOperatorByTypeMap();
+
+ details::TensorsMap tensors_map;
+ details::LoadTensorsMap(model, &tensors_map);
+
+ details::OperatorsMap operators_map;
+ details::LoadOperatorsMap(model, &operators_map);
+
+ std::vector<const Array*> buffers_to_write;
+ Array empty_array;
+ buffers_to_write.push_back(&empty_array);
+
+ auto tensors = ExportTensors(model, tensors_map, &builder, &buffers_to_write);
+ auto inputs = ExportInputTensors(model, tensors_map, &builder);
+ auto outputs = ExportOutputTensors(model, tensors_map, &builder);
+
+ std::set<string> error_summary;
+ auto op_codes = ExportOperatorCodes(model, ops_by_type, operators_map,
+ &builder, &error_summary);
+ if (!allow_custom_ops && !error_summary.empty()) {
+ LOG(QFATAL) << "Some of the operators in the model are not supported by "
+ "the standard TensorFlow Lite runtime. If you have a custom "
+ "implementation for them you can disable this error with "
+ "--allow_custom_ops. Here is a list of operators for which "
+ "you will need custom implementations: "
+ << absl::StrJoin(error_summary, ", ") << ".";
+ }
+
+ auto ops =
+ ExportOperators(model, ops_by_type, operators_map, tensors_map, &builder);
+
+ // TODO(aselle): add support to toco for multiple subgraphs.
+ auto subgraph = CreateSubGraph(builder, tensors, inputs, outputs, ops);
+ std::vector<flatbuffers::Offset<SubGraph>> subgraphs = {subgraph};
+
+ auto buffers = ExportBuffers(model, buffers_to_write, &builder);
+ auto description = builder.CreateString("TOCO Converted.");
+ auto new_model_location =
+ CreateModel(builder, TFLITE_SCHEMA_VERSION, op_codes,
+ builder.CreateVector(subgraphs), description, buffers);
+ ::tflite::FinishModelBuffer(builder, new_model_location);
+ const uint8_t* buffer = builder.GetBufferPointer();
+ int size = builder.GetSize();
+ *output_file_contents = string(reinterpret_cast<const char*>(buffer), size);
+}
+
+} // namespace tflite
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/tflite/export.h b/tensorflow/contrib/lite/toco/tflite/export.h
new file mode 100644
index 0000000000..44012b7126
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/tflite/export.h
@@ -0,0 +1,76 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_EXPORT_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_EXPORT_H_
+
+#include "tensorflow/contrib/lite/toco/model.h"
+
+namespace toco {
+
+namespace tflite {
+
+// Transform the given tf.mini model into a TF Lite flatbuffer and deposit the
+// result in the given string.
+void Export(const Model& model, bool allow_custom_ops,
+ string* output_file_contents);
+// This overload is kept for backward compatibility (it allows custom ops).
+inline void Export(const Model& model, string* output_file_contents) {
+ Export(model, true, output_file_contents);
+}
+
+namespace details {
+
+// A map from tensor name to its final position in the TF Lite buffer.
+using TensorsMap = std::unordered_map<string, int>;
+
+// A key that identifies an operator.
+// `custom_code` is filled only when `type` is `kTensorFlowUnsupported`, in
+// order to identify which TensorFlow operation is used.
+struct OperatorKey {
+ OperatorKey(OperatorType type, const std::string& custom_code)
+ : type(type), custom_code(custom_code) {}
+ const OperatorType type;
+ const std::string custom_code;
+
+ bool operator<(const OperatorKey& other) const {
+ if (type < other.type) return true;
+ if (type > other.type) return false;
+ return custom_code < other.custom_code;
+ }
+
+ bool operator==(const OperatorKey& other) const {
+ return type == other.type && custom_code == other.custom_code;
+ }
+
+ struct Hash {
+ std::size_t operator()(const OperatorKey& key) const {
+ return std::hash<size_t>()(static_cast<size_t>(key.type)) ^
+ std::hash<std::string>()(key.custom_code);
+ }
+ };
+};
+
+// A map from operator type to its final position in the TF Lite buffer.
+using OperatorsMap = std::unordered_map<OperatorKey, int, OperatorKey::Hash>;
+
+void LoadTensorsMap(const Model& model, TensorsMap* tensors_map);
+void LoadOperatorsMap(const Model& model, OperatorsMap* operators_map);
+
+} // namespace details
+} // namespace tflite
+
+} // namespace toco
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_EXPORT_H_
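A minimal usage sketch of the exporter (the file-writing wrapper is illustrative; error handling omitted):

    #include <fstream>
    #include <string>

    #include "tensorflow/contrib/lite/toco/tflite/export.h"

    void WriteTfLiteModel(const toco::Model& model,
                          const std::string& filename) {
      std::string contents;
      // Fail if the model contains ops the standard TF Lite runtime lacks.
      toco::tflite::Export(model, /*allow_custom_ops=*/false, &contents);
      std::ofstream out(filename, std::ios::binary);
      out << contents;
    }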
diff --git a/tensorflow/contrib/lite/toco/tflite/export_test.cc b/tensorflow/contrib/lite/toco/tflite/export_test.cc
new file mode 100644
index 0000000000..e395645383
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/tflite/export_test.cc
@@ -0,0 +1,69 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/toco/tflite/export.h"
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+namespace toco {
+
+namespace tflite {
+namespace {
+
+class ExportTest : public ::testing::Test {
+ protected:
+ // This is a very simplistic model. We are not interested in testing all the
+ // details here, since tf.mini's testing framework will be exercising all the
+ // conversions multiple times, and the conversion of operators is tested by
+ // separate unittests.
+ void BuildTestModel() {
+ input_model_.GetOrCreateArray("tensor_one");
+ input_model_.GetOrCreateArray("tensor_two");
+ input_model_.operators.emplace_back(new ConvOperator);
+ input_model_.operators.emplace_back(new AddOperator);
+ auto unsupported_operator = new TensorFlowUnsupportedOperator;
+ unsupported_operator->tensorflow_op = "MyCrazyOp";
+ input_model_.operators.emplace_back(unsupported_operator);
+ }
+
+ Model input_model_;
+};
+
+TEST_F(ExportTest, LoadTensorsMap) {
+ BuildTestModel();
+
+ details::TensorsMap tensors;
+ details::LoadTensorsMap(input_model_, &tensors);
+ EXPECT_EQ(0, tensors["tensor_one"]);
+ EXPECT_EQ(1, tensors["tensor_two"]);
+}
+
+TEST_F(ExportTest, LoadOperatorsMap) {
+ BuildTestModel();
+
+ details::OperatorsMap operators;
+ details::LoadOperatorsMap(input_model_, &operators);
+ EXPECT_EQ(0, operators[details::OperatorKey(OperatorType::kAdd, "")]);
+ EXPECT_EQ(1, operators[details::OperatorKey(OperatorType::kConv, "")]);
+ EXPECT_EQ(2, operators[details::OperatorKey(
+ OperatorType::kTensorFlowUnsupported, "MyCrazyOp")]);
+}
+
+// TODO(ahentz): tests for tensors, inputs, outputs, opcodes and operators.
+
+} // namespace
+} // namespace tflite
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/tflite/import.cc b/tensorflow/contrib/lite/toco/tflite/import.cc
new file mode 100644
index 0000000000..bbf201fd28
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/tflite/import.cc
@@ -0,0 +1,183 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/toco/tflite/import.h"
+
+#include "flatbuffers/flexbuffers.h"
+#include "tensorflow/contrib/lite/schema/schema_generated.h"
+#include "tensorflow/contrib/lite/toco/tflite/operator.h"
+#include "tensorflow/contrib/lite/toco/tflite/types.h"
+
+namespace toco {
+
+namespace tflite {
+
+namespace details {
+void LoadTensorsTable(const ::tflite::Model& input_model,
+ TensorsTable* tensors_table) {
+ // TODO(aselle): add support to toco for multiple subgraphs.
+ auto tensors = (*input_model.subgraphs())[0]->tensors();
+ if (!tensors) return;
+ for (const auto* tensor : *tensors) {
+ tensors_table->push_back(tensor->name()->c_str());
+ }
+}
+
+void LoadOperatorsTable(const ::tflite::Model& input_model,
+ OperatorsTable* operators_table) {
+ auto opcodes = input_model.operator_codes();
+ if (!opcodes) return;
+ for (const auto* opcode : *opcodes) {
+ if (opcode->builtin_code() != ::tflite::BuiltinOperator_CUSTOM) {
+ operators_table->push_back(
+ EnumNameBuiltinOperator(opcode->builtin_code()));
+ } else {
+ operators_table->push_back(opcode->custom_code()->c_str());
+ }
+ }
+}
+} // namespace details
+
+void ImportTensors(const ::tflite::Model& input_model, Model* model) {
+ auto tensors = (*input_model.subgraphs())[0]->tensors();
+ auto* buffers = input_model.buffers();
+ if (!tensors) return;
+ for (const auto* input_tensor : *tensors) {
+ Array& array = model->GetOrCreateArray(input_tensor->name()->c_str());
+ array.data_type = DataType::Deserialize(input_tensor->type());
+ int buffer_index = input_tensor->buffer();
+ auto* buffer = buffers->Get(buffer_index);
+ DataBuffer::Deserialize(*input_tensor, *buffer, &array);
+
+ auto shape = input_tensor->shape();
+ if (shape) {
+ for (int i = 0; i < shape->Length(); ++i) {
+ auto d = shape->Get(i);
+ array.mutable_shape()->mutable_dims()->push_back(d);
+ }
+ }
+
+ auto quantization = input_tensor->quantization();
+ if (quantization) {
+ // Note that tf.mini only supports a single quantization parameters for
+ // the whole array.
+ if (quantization->min() && quantization->max()) {
+ CHECK_EQ(1, quantization->min()->Length());
+ CHECK_EQ(1, quantization->max()->Length());
+ MinMax& minmax = array.GetOrCreateMinMax();
+ minmax.min = quantization->min()->Get(0);
+ minmax.max = quantization->max()->Get(0);
+ }
+ if (quantization->scale() && quantization->zero_point()) {
+ CHECK_EQ(1, quantization->scale()->Length());
+ CHECK_EQ(1, quantization->zero_point()->Length());
+ QuantizationParams& q = array.GetOrCreateQuantizationParams();
+ q.scale = quantization->scale()->Get(0);
+ q.zero_point = quantization->zero_point()->Get(0);
+ }
+ }
+ }
+}
+
+void ImportOperators(
+ const ::tflite::Model& input_model,
+ const std::map<string, std::unique_ptr<BaseOperator>>& ops_by_name,
+ const details::TensorsTable& tensors_table,
+ const details::OperatorsTable& operators_table, Model* model) {
+ // TODO(aselle): add support for multiple subgraphs.
+ auto ops = (*input_model.subgraphs())[0]->operators();
+
+ if (!ops) return;
+ for (const auto* input_op : *ops) {
+ int index = input_op->opcode_index();
+    if (index < 0 || index >= operators_table.size()) {
+      LOG(FATAL) << "Index " << index << " must be nonnegative and less than "
+                 << operators_table.size();
+ }
+ string opname = operators_table.at(index);
+ if (ops_by_name.count(opname) == 0) {
+ LOG(FATAL) << "Op '" << opname << "' not supported";
+ }
+
+ auto new_op = ops_by_name.at(opname)->Deserialize(
+ input_op->builtin_options(), input_op->custom_options());
+ model->operators.emplace_back(new_op.release());
+ auto* op = model->operators.back().get();
+
+ auto inputs = input_op->inputs();
+ for (int i = 0; i < inputs->Length(); i++) {
+ auto input_index = inputs->Get(i);
+ const string& input_name = tensors_table.at(input_index);
+ op->inputs.push_back(input_name);
+ }
+ auto outputs = input_op->outputs();
+ for (int i = 0; i < outputs->Length(); i++) {
+ auto output_index = outputs->Get(i);
+ const string& output_name = tensors_table.at(output_index);
+ op->outputs.push_back(output_name);
+ }
+ }
+}
+
+void ImportIOTensors(const ::tflite::Model& input_model,
+ const details::TensorsTable& tensors_table, Model* model) {
+ auto inputs = (*input_model.subgraphs())[0]->inputs();
+ if (inputs) {
+ for (int input : *inputs) {
+ const string& input_name = tensors_table.at(input);
+ model->flags.add_input_arrays()->set_name(input_name);
+ }
+ }
+
+ auto outputs = (*input_model.subgraphs())[0]->outputs();
+ if (outputs) {
+ for (int output : *outputs) {
+ const string& output_name = tensors_table.at(output);
+ model->flags.add_output_arrays(output_name);
+ }
+ }
+}
+
+std::unique_ptr<Model> Import(const ModelFlags& model_flags,
+ const string& input_file_contents) {
+ const ::tflite::Model* input_model =
+ ::tflite::GetModel(input_file_contents.data());
+
+ // Full list of all known operators.
+ const auto ops_by_name = BuildOperatorByNameMap();
+
+ if (input_model->subgraphs()->size() != 1) {
+ LOG(FATAL) << "# of subgraphs in tflite should be exactly 1 for now.";
+ }
+ std::unique_ptr<Model> model;
+ model.reset(new Model);
+
+ details::TensorsTable tensors_table;
+ details::LoadTensorsTable(*input_model, &tensors_table);
+
+ details::OperatorsTable operators_table;
+ details::LoadOperatorsTable(*input_model, &operators_table);
+
+ ImportTensors(*input_model, model.get());
+ ImportOperators(*input_model, ops_by_name, tensors_table, operators_table,
+ model.get());
+ ImportIOTensors(*input_model, tensors_table, model.get());
+
+ return model;
+}
+
+} // namespace tflite
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/tflite/import.h b/tensorflow/contrib/lite/toco/tflite/import.h
new file mode 100644
index 0000000000..3c27a2843c
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/tflite/import.h
@@ -0,0 +1,49 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_IMPORT_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_IMPORT_H_
+
+#include "tensorflow/contrib/lite/schema/schema_generated.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+
+namespace toco {
+
+namespace tflite {
+
+// Parse the given string as TF Lite flatbuffer and return a new tf.mini model.
+std::unique_ptr<Model> Import(const ModelFlags &model_flags,
+ const string &input_file_contents);
+
+namespace details {
+
+// The names of all tensors found in a TF Lite model.
+using TensorsTable = std::vector<string>;
+
+// The names of all operators found in a TF Lite model. If the operator is a
+// builtin, the string representation of the corresponding enum value is used
+// as its name.
+using OperatorsTable = std::vector<string>;
+
+void LoadTensorsTable(const ::tflite::Model &input_model,
+ TensorsTable *tensors_table);
+void LoadOperatorsTable(const ::tflite::Model &input_model,
+ OperatorsTable *operators_table);
+
+} // namespace details
+} // namespace tflite
+
+} // namespace toco
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_IMPORT_H_
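And the reverse direction, as a sketch (the file-reading wrapper is illustrative; a default-constructed ModelFlags is used, as in the unit test below):

    #include <fstream>
    #include <memory>
    #include <sstream>
    #include <string>

    #include "tensorflow/contrib/lite/toco/tflite/import.h"

    std::unique_ptr<toco::Model> ReadTfLiteModel(const std::string& filename) {
      std::ifstream in(filename, std::ios::binary);
      std::stringstream buffer;
      buffer << in.rdbuf();
      return toco::tflite::Import(toco::ModelFlags(), buffer.str());
    }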
diff --git a/tensorflow/contrib/lite/toco/tflite/import_test.cc b/tensorflow/contrib/lite/toco/tflite/import_test.cc
new file mode 100644
index 0000000000..309fa6d7f6
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/tflite/import_test.cc
@@ -0,0 +1,141 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/toco/tflite/import.h"
+
+#include "flatbuffers/flexbuffers.h"
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/schema/schema_generated.h"
+#include "tensorflow/contrib/lite/version.h"
+
+namespace toco {
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAre;
+
+class ImportTest : public ::testing::Test {
+ protected:
+ template <typename T>
+ flatbuffers::Offset<flatbuffers::Vector<unsigned char>> CreateDataVector(
+ const std::vector<T>& data) {
+ return builder_.CreateVector(reinterpret_cast<const uint8_t*>(data.data()),
+ sizeof(T) * data.size());
+ }
+ // This is a very simplistic model. We are not interested in testing all the
+ // details here, since tf.mini's testing framework will be exercising all the
+ // conversions multiple times, and the conversion of operators is tested by
+ // separate unittests.
+ void BuildTestModel() {
+ // The tensors
+ auto q = ::tflite::CreateQuantizationParameters(
+ builder_,
+ /*min=*/builder_.CreateVector<float>({0.1f}),
+ /*max=*/builder_.CreateVector<float>({0.2f}),
+ /*scale=*/builder_.CreateVector<float>({0.3f}),
+ /*zero_point=*/builder_.CreateVector<int64_t>({100ll}));
+ auto buf0 = ::tflite::CreateBuffer(builder_, CreateDataVector<float>({}));
+ auto buf1 =
+ ::tflite::CreateBuffer(builder_, CreateDataVector<float>({1.0f, 2.0f}));
+ auto buf2 =
+ ::tflite::CreateBuffer(builder_, CreateDataVector<float>({3.0f}));
+ auto buffers = builder_.CreateVector(
+ std::vector<flatbuffers::Offset<::tflite::Buffer>>({buf0, buf1, buf2}));
+ auto t1 = ::tflite::CreateTensor(builder_,
+ builder_.CreateVector<int>({1, 2, 3, 4}),
+ ::tflite::TensorType_FLOAT32, 1,
+ builder_.CreateString("tensor_one"), q);
+ auto t2 =
+ ::tflite::CreateTensor(builder_, builder_.CreateVector<int>({2, 1}),
+ ::tflite::TensorType_FLOAT32, 2,
+ builder_.CreateString("tensor_two"), q);
+ auto tensors = builder_.CreateVector(
+ std::vector<flatbuffers::Offset<::tflite::Tensor>>({t1, t2}));
+
+ // The operator codes.
+ auto c1 =
+ ::tflite::CreateOperatorCode(builder_, ::tflite::BuiltinOperator_CUSTOM,
+ builder_.CreateString("custom_op_one"));
+ auto c2 = ::tflite::CreateOperatorCode(
+ builder_, ::tflite::BuiltinOperator_CONV_2D, 0);
+ auto opcodes = builder_.CreateVector(
+ std::vector<flatbuffers::Offset<::tflite::OperatorCode>>({c1, c2}));
+
+ auto subgraph = ::tflite::CreateSubGraph(builder_, tensors, 0, 0, 0);
+ std::vector<flatbuffers::Offset<::tflite::SubGraph>> subgraph_vector(
+ {subgraph});
+ auto subgraphs = builder_.CreateVector(subgraph_vector);
+ auto s = builder_.CreateString("");
+ builder_.Finish(::tflite::CreateModel(builder_, TFLITE_SCHEMA_VERSION,
+ opcodes, subgraphs, s, buffers));
+
+ input_model_ = ::tflite::GetModel(builder_.GetBufferPointer());
+ }
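+  // Return the serialized flatbuffer as a string, suitable for passing to
+  // Import().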
+ string InputModelAsString() {
+ return string(reinterpret_cast<char*>(builder_.GetBufferPointer()),
+ builder_.GetSize());
+ }
+ flatbuffers::FlatBufferBuilder builder_;
+ const ::tflite::Model* input_model_ = nullptr;
+};
+
+TEST_F(ImportTest, LoadTensorsTable) {
+ BuildTestModel();
+
+ details::TensorsTable tensors;
+ details::LoadTensorsTable(*input_model_, &tensors);
+ EXPECT_THAT(tensors, ElementsAre("tensor_one", "tensor_two"));
+}
+
+TEST_F(ImportTest, LoadOperatorsTable) {
+ BuildTestModel();
+
+ details::OperatorsTable operators;
+ details::LoadOperatorsTable(*input_model_, &operators);
+ EXPECT_THAT(operators, ElementsAre("custom_op_one", "CONV_2D"));
+}
+
+TEST_F(ImportTest, Tensors) {
+ BuildTestModel();
+
+ auto model = Import(ModelFlags(), InputModelAsString());
+
+ ASSERT_GT(model->arrays.count("tensor_one"), 0);
+ Array& a1 = model->GetArray("tensor_one");
+ EXPECT_EQ(ArrayDataType::kFloat, a1.data_type);
+ EXPECT_THAT(a1.GetBuffer<ArrayDataType::kFloat>().data,
+ ElementsAre(1.0f, 2.0f));
+ ASSERT_TRUE(a1.has_shape());
+ EXPECT_THAT(a1.shape().dims(), ElementsAre(1, 2, 3, 4));
+
+ const auto& mm = a1.minmax;
+ ASSERT_TRUE(mm.get());
+ EXPECT_FLOAT_EQ(0.1, mm->min);
+ EXPECT_FLOAT_EQ(0.2, mm->max);
+
+ const auto& q = a1.quantization_params;
+ ASSERT_TRUE(q.get());
+ EXPECT_FLOAT_EQ(0.3, q->scale);
+ EXPECT_EQ(100, q->zero_point);
+}
+
+// TODO(ahentz): still need tests for Operators and IOTensors.
+
+} // namespace
+} // namespace tflite
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc
new file mode 100644
index 0000000000..8a33500ddc
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/tflite/operator.cc
@@ -0,0 +1,627 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/toco/tflite/operator.h"
+
+#include "tensorflow/contrib/lite/toco/tflite/builtin_operator.h"
+#include "tensorflow/contrib/lite/toco/tflite/custom_operator.h"
+#include "tensorflow/contrib/lite/toco/tflite/simple_operator.h"
+#include "tensorflow/contrib/lite/toco/tflite/types.h"
+
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+
+namespace toco {
+
+namespace tflite {
+
+class AveragePool
+ : public BuiltinOperator<AveragePoolOperator, ::tflite::Pool2DOptions,
+ ::tflite::BuiltinOptions_Pool2DOptions> {
+ public:
+ using BuiltinOperator::BuiltinOperator;
+
+ flatbuffers::Offset<TfLiteOptions> WriteOptions(
+ const TocoOperator& op,
+ flatbuffers::FlatBufferBuilder* builder) const override {
+ auto padding = Padding::Serialize(op.padding.type);
+ auto activation_function =
+ ActivationFunction::Serialize(op.fused_activation_function);
+ return ::tflite::CreatePool2DOptions(*builder, padding, op.stride_width,
+ op.stride_height, op.kwidth,
+ op.kheight, activation_function);
+ }
+
+ void ReadOptions(const TfLiteOptions& options,
+ TocoOperator* op) const override {
+ op->padding.type = Padding::Deserialize(options.padding());
+ op->stride_width = options.stride_w();
+ op->stride_height = options.stride_h();
+ op->kwidth = options.filter_width();
+ op->kheight = options.filter_height();
+ op->fused_activation_function =
+ ActivationFunction::Deserialize(options.fused_activation_function());
+ }
+};
+
+class Convolution
+ : public BuiltinOperator<ConvOperator, ::tflite::Conv2DOptions,
+ ::tflite::BuiltinOptions_Conv2DOptions> {
+ public:
+ using BuiltinOperator::BuiltinOperator;
+
+ flatbuffers::Offset<TfLiteOptions> WriteOptions(
+ const TocoOperator& op,
+ flatbuffers::FlatBufferBuilder* builder) const override {
+ auto padding = Padding::Serialize(op.padding.type);
+ auto activation_function =
+ ActivationFunction::Serialize(op.fused_activation_function);
+ return ::tflite::CreateConv2DOptions(*builder, padding, op.stride_width,
+ op.stride_height, activation_function);
+ }
+
+ void ReadOptions(const TfLiteOptions& options,
+ TocoOperator* op) const override {
+ op->padding.type = Padding::Deserialize(options.padding());
+ op->stride_width = options.stride_w();
+ op->stride_height = options.stride_h();
+ op->fused_activation_function =
+ ActivationFunction::Deserialize(options.fused_activation_function());
+ }
+};
+
+class DepthwiseConvolution
+ : public BuiltinOperator<DepthwiseConvOperator,
+ ::tflite::DepthwiseConv2DOptions,
+ ::tflite::BuiltinOptions_DepthwiseConv2DOptions> {
+ public:
+ using BuiltinOperator::BuiltinOperator;
+
+ flatbuffers::Offset<TfLiteOptions> WriteOptions(
+ const TocoOperator& op,
+ flatbuffers::FlatBufferBuilder* builder) const override {
+ auto padding = Padding::Serialize(op.padding.type);
+ auto activation_function =
+ ActivationFunction::Serialize(op.fused_activation_function);
+ return ::tflite::CreateDepthwiseConv2DOptions(
+ *builder, padding, op.stride_width, op.stride_height,
+ op.depth_multiplier, activation_function);
+ }
+
+ void ReadOptions(const TfLiteOptions& options,
+ TocoOperator* op) const override {
+ op->padding.type = Padding::Deserialize(options.padding());
+ op->stride_width = options.stride_w();
+ op->stride_height = options.stride_h();
+ op->depth_multiplier = options.depth_multiplier();
+ op->fused_activation_function =
+ ActivationFunction::Deserialize(options.fused_activation_function());
+ }
+};
+
+class Add : public BuiltinOperator<AddOperator, ::tflite::AddOptions,
+ ::tflite::BuiltinOptions_AddOptions> {
+ public:
+ using BuiltinOperator::BuiltinOperator;
+
+ flatbuffers::Offset<TfLiteOptions> WriteOptions(
+ const TocoOperator& op,
+ flatbuffers::FlatBufferBuilder* builder) const override {
+ auto activation_function =
+ ActivationFunction::Serialize(op.fused_activation_function);
+ return ::tflite::CreateAddOptions(*builder, activation_function);
+ }
+
+ void ReadOptions(const TfLiteOptions& options,
+ TocoOperator* op) const override {
+ op->fused_activation_function =
+ ActivationFunction::Deserialize(options.fused_activation_function());
+ }
+};
+
+class Cast : public CustomOperator<CastOperator> {
+ public:
+ using CustomOperator::CustomOperator;
+ void WriteOptions(const TocoOperator& op,
+ flexbuffers::Builder* fbb) const override {
+ fbb->Int("src_data_type", DataType::Serialize(op.src_data_type));
+ fbb->Int("dst_data_type", DataType::Serialize(op.dst_data_type));
+ }
+ void ReadOptions(const flexbuffers::Map& m, TocoOperator* op) const override {
+ op->src_data_type = DataType::Deserialize(m["src_data_type"].AsInt64());
+ op->dst_data_type = DataType::Deserialize(m["dst_data_type"].AsInt64());
+ }
+};
+
+class Concatenation
+ : public BuiltinOperator<ConcatenationOperator,
+ ::tflite::ConcatenationOptions,
+ ::tflite::BuiltinOptions_ConcatenationOptions> {
+ public:
+ using BuiltinOperator::BuiltinOperator;
+ flatbuffers::Offset<TfLiteOptions> WriteOptions(
+ const TocoOperator& op,
+ flatbuffers::FlatBufferBuilder* builder) const override {
+ return ::tflite::CreateConcatenationOptions(*builder, op.concat_dim);
+ }
+
+ void ReadOptions(const TfLiteOptions& options,
+ TocoOperator* op) const override {
+ op->concat_dim = options.axis();
+ }
+};
+
+class DepthToSpace : public CustomOperator<DepthToSpaceOperator> {
+ public:
+ using CustomOperator::CustomOperator;
+ void WriteOptions(const TocoOperator& op,
+ flexbuffers::Builder* fbb) const override {
+ fbb->Int("block_size", op.block_size);
+ }
+ void ReadOptions(const flexbuffers::Map& m, TocoOperator* op) const override {
+ op->block_size = m["block_size"].AsInt64();
+ }
+};
+
+class FakeQuant : public CustomOperator<FakeQuantOperator> {
+ public:
+ using CustomOperator::CustomOperator;
+ void WriteOptions(const TocoOperator& op,
+ flexbuffers::Builder* fbb) const override {
+ fbb->Float("min", op.minmax->min);
+ fbb->Float("max", op.minmax->max);
+ }
+ void ReadOptions(const flexbuffers::Map& m, TocoOperator* op) const override {
+ auto* minmax = new MinMax;
+ minmax->min = m["min"].AsFloat();
+ minmax->max = m["max"].AsFloat();
+ op->minmax.reset(minmax);
+ }
+};
+
+class FullyConnected
+ : public BuiltinOperator<FullyConnectedOperator,
+ ::tflite::FullyConnectedOptions,
+ ::tflite::BuiltinOptions_FullyConnectedOptions> {
+ public:
+ using BuiltinOperator::BuiltinOperator;
+ flatbuffers::Offset<TfLiteOptions> WriteOptions(
+ const TocoOperator& op,
+ flatbuffers::FlatBufferBuilder* builder) const override {
+ auto activation_function =
+ ActivationFunction::Serialize(op.fused_activation_function);
+ return ::tflite::CreateFullyConnectedOptions(*builder, activation_function);
+ }
+
+ void ReadOptions(const TfLiteOptions& options,
+ TocoOperator* op) const override {
+ op->fused_activation_function =
+ ActivationFunction::Deserialize(options.fused_activation_function());
+ }
+};
+
+class Svdf : public BuiltinOperator<SvdfOperator, ::tflite::SVDFOptions,
+ ::tflite::BuiltinOptions_SVDFOptions> {
+ public:
+ using BuiltinOperator::BuiltinOperator;
+ flatbuffers::Offset<TfLiteOptions> WriteOptions(
+ const TocoOperator& op,
+ flatbuffers::FlatBufferBuilder* builder) const override {
+ auto activation_function =
+ ActivationFunction::Serialize(op.fused_activation_function);
+ return ::tflite::CreateSVDFOptions(*builder, op.rank, activation_function);
+ }
+
+ void ReadOptions(const TfLiteOptions& options,
+ TocoOperator* op) const override {
+ op->fused_activation_function =
+ ActivationFunction::Deserialize(options.fused_activation_function());
+ op->rank = options.rank();
+ }
+};
+
+class L2Normalization
+ : public BuiltinOperator<L2NormalizationOperator, ::tflite::L2NormOptions,
+ ::tflite::BuiltinOptions_L2NormOptions> {
+ public:
+ using BuiltinOperator::BuiltinOperator;
+ flatbuffers::Offset<TfLiteOptions> WriteOptions(
+ const TocoOperator& op,
+ flatbuffers::FlatBufferBuilder* builder) const override {
+ auto activation_function =
+ ActivationFunction::Serialize(op.fused_activation_function);
+ return ::tflite::CreateL2NormOptions(*builder, activation_function);
+ }
+
+ void ReadOptions(const TfLiteOptions& options,
+ TocoOperator* op) const override {
+ op->fused_activation_function =
+ ActivationFunction::Deserialize(options.fused_activation_function());
+ }
+};
+
+class L2Pool : public BuiltinOperator<L2PoolOperator, ::tflite::Pool2DOptions,
+ ::tflite::BuiltinOptions_Pool2DOptions> {
+ public:
+ using BuiltinOperator::BuiltinOperator;
+ flatbuffers::Offset<TfLiteOptions> WriteOptions(
+ const TocoOperator& op,
+ flatbuffers::FlatBufferBuilder* builder) const override {
+ auto padding = Padding::Serialize(op.padding.type);
+ auto activation_function =
+ ActivationFunction::Serialize(op.fused_activation_function);
+ return ::tflite::CreatePool2DOptions(*builder, padding, op.stride_width,
+ op.stride_height, op.kwidth,
+ op.kheight, activation_function);
+ }
+
+ void ReadOptions(const TfLiteOptions& options,
+ TocoOperator* op) const override {
+ op->padding.type = Padding::Deserialize(options.padding());
+ op->stride_width = options.stride_w();
+ op->stride_height = options.stride_h();
+ op->kwidth = options.filter_width();
+ op->kheight = options.filter_height();
+ op->fused_activation_function =
+ ActivationFunction::Deserialize(options.fused_activation_function());
+ }
+};
+
+class LocalResponseNormalization
+ : public BuiltinOperator<
+ LocalResponseNormalizationOperator,
+ ::tflite::LocalResponseNormalizationOptions,
+ ::tflite::BuiltinOptions_LocalResponseNormalizationOptions> {
+ public:
+ using BuiltinOperator::BuiltinOperator;
+ flatbuffers::Offset<TfLiteOptions> WriteOptions(
+ const TocoOperator& op,
+ flatbuffers::FlatBufferBuilder* builder) const override {
+ return ::tflite::CreateLocalResponseNormalizationOptions(
+ *builder, op.range, op.bias, op.alpha, op.beta);
+ }
+
+ void ReadOptions(const TfLiteOptions& options,
+ TocoOperator* op) const override {
+ op->range = options.radius();
+ op->bias = options.bias();
+ op->alpha = options.alpha();
+ op->beta = options.beta();
+ }
+};
+
+class MaxPool : public BuiltinOperator<MaxPoolOperator, ::tflite::Pool2DOptions,
+ ::tflite::BuiltinOptions_Pool2DOptions> {
+ public:
+ using BuiltinOperator::BuiltinOperator;
+ flatbuffers::Offset<TfLiteOptions> WriteOptions(
+ const TocoOperator& op,
+ flatbuffers::FlatBufferBuilder* builder) const override {
+ auto padding = Padding::Serialize(op.padding.type);
+ auto activation_function =
+ ActivationFunction::Serialize(op.fused_activation_function);
+ return ::tflite::CreatePool2DOptions(*builder, padding, op.stride_width,
+ op.stride_height, op.kwidth,
+ op.kheight, activation_function);
+ }
+
+ void ReadOptions(const TfLiteOptions& options,
+ TocoOperator* op) const override {
+ op->padding.type = Padding::Deserialize(options.padding());
+ op->stride_width = options.stride_w();
+ op->stride_height = options.stride_h();
+ op->kwidth = options.filter_width();
+ op->kheight = options.filter_height();
+ op->fused_activation_function =
+ ActivationFunction::Deserialize(options.fused_activation_function());
+ }
+};
+
+class Mul : public BuiltinOperator<MulOperator, ::tflite::MulOptions,
+ ::tflite::BuiltinOptions_MulOptions> {
+ public:
+ using BuiltinOperator::BuiltinOperator;
+
+ flatbuffers::Offset<TfLiteOptions> WriteOptions(
+ const TocoOperator& op,
+ flatbuffers::FlatBufferBuilder* builder) const override {
+ auto activation_function =
+ ActivationFunction::Serialize(op.fused_activation_function);
+ return ::tflite::CreateMulOptions(*builder, activation_function);
+ }
+
+ void ReadOptions(const TfLiteOptions& options,
+ TocoOperator* op) const override {
+ op->fused_activation_function =
+ ActivationFunction::Deserialize(options.fused_activation_function());
+ }
+};
+
+class Reshape
+ : public BuiltinOperator<TensorFlowReshapeOperator,
+ ::tflite::ReshapeOptions,
+ ::tflite::BuiltinOptions_ReshapeOptions> {
+ public:
+ using BuiltinOperator::BuiltinOperator;
+
+ flatbuffers::Offset<TfLiteOptions> WriteOptions(
+ const TocoOperator& op,
+ flatbuffers::FlatBufferBuilder* builder) const override {
+ return ::tflite::CreateReshapeOptions(*builder,
+ builder->CreateVector(op.shape));
+ }
+
+ void ReadOptions(const TfLiteOptions& options,
+ TocoOperator* op) const override {
+ op->shape.insert(op->shape.end(), options.new_shape()->begin(),
+ options.new_shape()->end());
+ }
+};
+
+class Softmax
+ : public BuiltinOperator<SoftmaxOperator, ::tflite::SoftmaxOptions,
+ ::tflite::BuiltinOptions_SoftmaxOptions> {
+ public:
+ using BuiltinOperator::BuiltinOperator;
+ flatbuffers::Offset<TfLiteOptions> WriteOptions(
+ const TocoOperator& op,
+ flatbuffers::FlatBufferBuilder* builder) const override {
+ return ::tflite::CreateSoftmaxOptions(*builder, op.beta);
+ }
+
+ void ReadOptions(const TfLiteOptions& options,
+ TocoOperator* op) const override {
+ op->beta = options.beta();
+ }
+};
+
+class SpaceToDepth
+ : public BuiltinOperator<SpaceToDepthOperator,
+ ::tflite::SpaceToDepthOptions,
+ ::tflite::BuiltinOptions_SpaceToDepthOptions> {
+ public:
+ using BuiltinOperator::BuiltinOperator;
+ flatbuffers::Offset<TfLiteOptions> WriteOptions(
+ const TocoOperator& op,
+ flatbuffers::FlatBufferBuilder* builder) const override {
+ return ::tflite::CreateSpaceToDepthOptions(*builder, op.block_size);
+ }
+
+ void ReadOptions(const TfLiteOptions& options,
+ TocoOperator* op) const override {
+ op->block_size = options.block_size();
+ }
+};
+
+class Split : public CustomOperator<TensorFlowSplitOperator> {
+ public:
+ using CustomOperator::CustomOperator;
+ void WriteOptions(const TocoOperator& op,
+ flexbuffers::Builder* fbb) const override {
+ fbb->Int("num_split", op.num_split);
+ }
+ void ReadOptions(const flexbuffers::Map& m, TocoOperator* op) const override {
+ op->num_split = m["num_split"].AsInt64();
+ }
+};
+
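+// Handles operators that have no TF Lite counterpart. The original TensorFlow
+// NodeDef attributes are carried through as a flexbuffer-encoded custom
+// options map, so they can be recovered when the model is imported again.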
+class TensorFlowUnsupported : public BaseOperator {
+ public:
+ using BaseOperator::BaseOperator;
+
+ Options Serialize(const Operator& op,
+ flatbuffers::FlatBufferBuilder* builder) const override {
+ auto fbb =
+ WriteOptions(static_cast<const TensorFlowUnsupportedOperator&>(op));
+ if (fbb) {
+ return Options::Custom(builder->CreateVector(fbb->GetBuffer()));
+ } else {
+ return Options::Custom(0);
+ }
+ }
+
+ std::unique_ptr<Operator> Deserialize(
+ const BuiltinOptions* builtin_options,
+ const CustomOptions* custom_options) const override {
+ auto op = absl::make_unique<TensorFlowUnsupportedOperator>();
+ if (custom_options) {
+ auto flexbuffer_map =
+ flexbuffers::GetRoot(custom_options->data(), custom_options->size())
+ .AsMap();
+ ReadOptions(flexbuffer_map, op.get());
+ }
+ return std::unique_ptr<Operator>(op.release());
+ }
+
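+  // Encode the supported NodeDef attributes (string, int, float and bool)
+  // into a flexbuffer map. Returns a null builder if the NodeDef cannot be
+  // parsed or none of its attributes can be serialized.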
+ std::unique_ptr<flexbuffers::Builder> WriteOptions(
+ const TensorFlowUnsupportedOperator& op) const {
+ auto fbb = absl::make_unique<flexbuffers::Builder>();
+
+ ::tensorflow::NodeDef node_def;
+ if (!node_def.ParseFromString(op.tensorflow_node_def)) {
+ LOG(ERROR) << "Failed to parse TensorFlow NodeDef";
+ return std::unique_ptr<flexbuffers::Builder>();
+ }
+
+ bool has_valid_attr = false;
+ size_t map_start = fbb->StartMap();
+ for (const auto& pair : node_def.attr()) {
+ const char* key = pair.first.c_str();
+ const auto& attr = pair.second;
+ switch (attr.value_case()) {
+ case ::tensorflow::AttrValue::kS:
+ fbb->String(key, attr.s());
+ has_valid_attr = true;
+ break;
+ case ::tensorflow::AttrValue::kI:
+ fbb->Int(key, attr.i());
+ has_valid_attr = true;
+ break;
+ case ::tensorflow::AttrValue::kF:
+ fbb->Float(key, attr.f());
+ has_valid_attr = true;
+ break;
+ case ::tensorflow::AttrValue::kB:
+ fbb->Bool(key, attr.b());
+ has_valid_attr = true;
+ break;
+ default:
+ LOG(WARNING) << "Ignoring unsupported attribute type with key '"
+ << key << "'";
+ break;
+ }
+ }
+ if (!has_valid_attr) {
+ return std::unique_ptr<flexbuffers::Builder>();
+ }
+ fbb->EndMap(map_start);
+ fbb->Finish();
+ return std::unique_ptr<flexbuffers::Builder>(fbb.release());
+ }
+
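+  // Rebuild a NodeDef from the flexbuffer map and store its serialized form
+  // back on the operator.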
+ void ReadOptions(const flexbuffers::Map& m,
+ TensorFlowUnsupportedOperator* op) const {
+ ::tensorflow::NodeDef node_def;
+ auto attr = node_def.mutable_attr();
+
+ const auto& keys = m.Keys();
+ for (size_t i = 0; i < keys.size(); ++i) {
+ const auto key = keys[i].AsKey();
+ const auto& value = m[key];
+ switch (value.GetType()) {
+ case flexbuffers::TYPE_STRING:
+ (*attr)[key].set_s(value.AsString().c_str());
+ break;
+ case flexbuffers::TYPE_INT:
+ (*attr)[key].set_i(value.AsInt64());
+ break;
+ case flexbuffers::TYPE_FLOAT:
+ (*attr)[key].set_f(value.AsFloat());
+ break;
+ case flexbuffers::TYPE_BOOL:
+ (*attr)[key].set_b(value.AsBool());
+ break;
+ default:
+ LOG(WARNING) << "Ignoring unsupported attribute type with key '"
+ << key << "'";
+ break;
+ }
+ }
+ node_def.SerializeToString(&op->tensorflow_node_def);
+ }
+};
+
+namespace {
+// Build a vector containing all the known operators.
+std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() {
+ std::vector<std::unique_ptr<BaseOperator>> ops;
+
+ // Builtin Operators.
+ ops.emplace_back(new Add(::tflite::BuiltinOperator_ADD, OperatorType::kAdd));
+ ops.emplace_back(new AveragePool(::tflite::BuiltinOperator_AVERAGE_POOL_2D,
+ OperatorType::kAveragePool));
+ ops.emplace_back(new Concatenation(::tflite::BuiltinOperator_CONCATENATION,
+ OperatorType::kConcatenation));
+ ops.emplace_back(
+ new Convolution(::tflite::BuiltinOperator_CONV_2D, OperatorType::kConv));
+ ops.emplace_back(
+ new DepthwiseConvolution(::tflite::BuiltinOperator_DEPTHWISE_CONV_2D,
+ OperatorType::kDepthwiseConv));
+ ops.emplace_back(new FullyConnected(::tflite::BuiltinOperator_FULLY_CONNECTED,
+ OperatorType::kFullyConnected));
+ ops.emplace_back(
+ new L2Normalization(::tflite::BuiltinOperator_L2_NORMALIZATION,
+ OperatorType::kL2Normalization));
+ ops.emplace_back(
+ new L2Pool(::tflite::BuiltinOperator_L2_POOL_2D, OperatorType::kL2Pool));
+ ops.emplace_back(new LocalResponseNormalization(
+ ::tflite::BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION,
+ OperatorType::kLocalResponseNormalization));
+ ops.emplace_back(new MaxPool(::tflite::BuiltinOperator_MAX_POOL_2D,
+ OperatorType::kMaxPool));
+ ops.emplace_back(new Mul(::tflite::BuiltinOperator_MUL, OperatorType::kMul));
+ ops.emplace_back(new Reshape(::tflite::BuiltinOperator_RESHAPE,
+ OperatorType::kTensorFlowReshape));
+ ops.emplace_back(
+ new Softmax(::tflite::BuiltinOperator_SOFTMAX, OperatorType::kSoftmax));
+ ops.emplace_back(new SpaceToDepth(::tflite::BuiltinOperator_SPACE_TO_DEPTH,
+ OperatorType::kSpaceToDepth));
+ ops.emplace_back(
+ new Svdf(::tflite::BuiltinOperator_SVDF, OperatorType::kSvdf));
+
+ // Custom Operators.
+ ops.emplace_back(new Cast("CAST", OperatorType::kCast));
+ ops.emplace_back(
+ new DepthToSpace("DEPTH_TO_SPACE", OperatorType::kDepthToSpace));
+ ops.emplace_back(new FakeQuant("FAKE_QUANT", OperatorType::kFakeQuant));
+ ops.emplace_back(new Split("SPLIT", OperatorType::kTensorFlowSplit));
+ ops.emplace_back(new TensorFlowUnsupported(
+ "TENSORFLOW_UNSUPPORTED", OperatorType::kTensorFlowUnsupported));
+
+  // These operators are supported by Toco, but not by TF Lite, and have no
+  // attributes.
+ ops.emplace_back(new SimpleOperator<TensorFlowRsqrtOperator>(
+ "RSQRT", OperatorType::kTensorFlowRsqrt));
+  ops.emplace_back(
+      new SimpleOperator<DivOperator>("DIV", OperatorType::kDiv));
+
+ // Simple Operators.
+ ops.emplace_back(new SimpleOperator<DequantizeOperator>(
+ "DEQUANTIZE", OperatorType::kDequantize));
+ ops.emplace_back(
+ new SimpleOperator<FloorOperator>("FLOOR", OperatorType::kFloor));
+ ops.emplace_back(
+ new SimpleOperator<GatherOperator>("GATHER", OperatorType::kGather));
+ ops.emplace_back(
+ new SimpleOperator<ReluOperator>("RELU", OperatorType::kRelu));
+ ops.emplace_back(
+ new SimpleOperator<Relu1Operator>("RELU1", OperatorType::kRelu1));
+ ops.emplace_back(
+ new SimpleOperator<Relu6Operator>("RELU6", OperatorType::kRelu6));
+ ops.emplace_back(new SimpleOperator<ResizeBilinearOperator>(
+ "RESIZE_BILINEAR", OperatorType::kResizeBilinear));
+ ops.emplace_back(new SimpleOperator<LogisticOperator>(
+ "LOGISTIC", OperatorType::kLogistic));
+ ops.emplace_back(
+ new SimpleOperator<TanhOperator>("TANH", OperatorType::kTanh));
+
+ return ops;
+}
+} // namespace
+
+std::map<OperatorType, std::unique_ptr<BaseOperator>> BuildOperatorByTypeMap() {
+ std::map<OperatorType, std::unique_ptr<BaseOperator>> result;
+
+ std::vector<std::unique_ptr<BaseOperator>> ops = BuildOperatorList();
+ for (auto& op : ops) {
+ result[op->type()] = std::move(op);
+ }
+
+ return result;
+}
+
+std::map<string, std::unique_ptr<BaseOperator>> BuildOperatorByNameMap() {
+ std::map<string, std::unique_ptr<BaseOperator>> result;
+
+ std::vector<std::unique_ptr<BaseOperator>> ops = BuildOperatorList();
+ for (auto& op : ops) {
+ result[op->name()] = std::move(op);
+ }
+
+ return result;
+}
+
+} // namespace tflite
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/tflite/operator.h b/tensorflow/contrib/lite/toco/tflite/operator.h
new file mode 100644
index 0000000000..37df302d46
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/tflite/operator.h
@@ -0,0 +1,89 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_OPERATOR_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_OPERATOR_H_
+
+#include "flatbuffers/flatbuffers.h"
+#include "tensorflow/contrib/lite/schema/schema_generated.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+
+namespace toco {
+
+namespace tflite {
+
+class BaseOperator;
+
+// Return a map containing all known TF Lite operators, keyed by their names.
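+// A typical lookup, mirroring the unit tests:
+//   auto ops_by_name = BuildOperatorByNameMap();
+//   const BaseOperator& conv_op = *ops_by_name.at("CONV_2D");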
+std::map<string, std::unique_ptr<BaseOperator>> BuildOperatorByNameMap();
+
+// Return a map containing all known TF Lite operators, keyed by the type of
+// their tf.mini counterparts.
+std::map<OperatorType, std::unique_ptr<BaseOperator>> BuildOperatorByTypeMap();
+
+// These are the flatbuffer types for custom and builtin options.
+using CustomOptions = flatbuffers::Vector<uint8_t>;
+using BuiltinOptions = void;
+
+// A simple wrapper around the flatbuffer objects used to describe options that
+// configure operators.
+struct Options {
+ // Build custom options.
+ static Options Custom(flatbuffers::Offset<CustomOptions> offset) {
+ return {::tflite::BuiltinOptions_NONE, 0, offset};
+ }
+
+ // Build builtin options of the given type.
+ static Options Builtin(::tflite::BuiltinOptions type,
+ flatbuffers::Offset<BuiltinOptions> offset) {
+ return {type, offset, 0};
+ }
+
+ ::tflite::BuiltinOptions type;
+ flatbuffers::Offset<BuiltinOptions> builtin;
+ flatbuffers::Offset<CustomOptions> custom;
+};
+
+// A BaseOperator encapsulates the relationship between operators in tf.mini
+// and TF Lite, and provides methods for converting between those two formats.
+class BaseOperator {
+ public:
+ // Build an operator with the given TF Lite name and tf.mini type.
+ BaseOperator(const string& name, OperatorType type)
+ : name_(name), type_(type) {}
+ virtual ~BaseOperator() = default;
+
+ string name() const { return name_; }
+ OperatorType type() const { return type_; }
+
+ // Given a tf.mini operator, create the corresponding flatbuffer options and
+ // return their offsets.
+ virtual Options Serialize(const Operator& op,
+ flatbuffers::FlatBufferBuilder* builder) const = 0;
+
+ // Read TF Lite options and create the appropriate tf.mini operator.
+ virtual std::unique_ptr<Operator> Deserialize(
+ const BuiltinOptions* builtin_options,
+ const CustomOptions* custom_options) const = 0;
+
+ private:
+ string name_;
+ OperatorType type_;
+};
+
+} // namespace tflite
+
+} // namespace toco
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_OPERATOR_H_
diff --git a/tensorflow/contrib/lite/toco/tflite/operator_test.cc b/tensorflow/contrib/lite/toco/tflite/operator_test.cc
new file mode 100644
index 0000000000..543a9bd06c
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/tflite/operator_test.cc
@@ -0,0 +1,370 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/toco/tflite/operator.h"
+
+#include "flatbuffers/flexbuffers.h"
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+
+namespace toco {
+
+namespace tflite {
+namespace {
+
+class OperatorTest : public ::testing::Test {
+ protected:
+ // Return the operator for the given name and type.
+ const BaseOperator& GetOperator(const string& name, OperatorType type) {
+ using OpsByName = std::map<string, std::unique_ptr<BaseOperator>>;
+ using OpsByType = std::map<OperatorType, std::unique_ptr<BaseOperator>>;
+
+ static auto* by_name = new OpsByName(BuildOperatorByNameMap());
+ static auto* by_type = new OpsByType(BuildOperatorByTypeMap());
+
+    // Make sure the two maps were consistently built.
+ CHECK(by_name->count(name)) << "No operator for '" << name << "'.";
+ BaseOperator* op1 = by_name->at(name).get();
+ CHECK(op1->type() == type) << "while verifying '" << name << "'.";
+
+ CHECK(by_type->count(type))
+ << "No operator for '" << OperatorTypeName(type) << "'.";
+ BaseOperator* op2 = by_type->at(type).get();
+ CHECK(op2->name() == name)
+ << "while verifying '" << OperatorTypeName(type) << "'.";
+
+ return *op1;
+ }
+
+ // Use the given BaseOperator to serialize the tf.mini operator into a set of
+ // TF Lite options. Proceed to deserialize the options back into a new
+ // tf.mini operator, which is then returned. If `options` is given, it will
+ // be populated with the serialized options.
+ template <typename T>
+ std::unique_ptr<T> SerializeAndDeserialize(const BaseOperator& op,
+ const T& toco_op,
+ Options* options = nullptr) {
+ flatbuffers::FlatBufferBuilder builder;
+ Options input_options = op.Serialize(toco_op, &builder);
+
+ if (options) {
+ *options = input_options;
+ }
+
+ builder.Finish(CreateOperator(builder, 0, 0, 0, input_options.type,
+ input_options.builtin, input_options.custom,
+ ::tflite::CustomOptionsFormat_FLEXBUFFERS));
+ auto* output_options =
+ flatbuffers::GetRoot<::tflite::Operator>(builder.GetBufferPointer());
+ auto new_toco_op = op.Deserialize(output_options->builtin_options(),
+ output_options->custom_options());
+
+ CHECK(dynamic_cast<T*>(new_toco_op.get()))
+ << "Cannot cast " << HelpfulOperatorTypeName(*new_toco_op) << " to "
+ << HelpfulOperatorTypeName(toco_op);
+
+ return std::unique_ptr<T>(dynamic_cast<T*>(new_toco_op.release()));
+ }
+
+ // Verify serialization and deserialization of simple operators (those
+ // that don't have any configuration parameters).
+ template <typename T>
+ void CheckSimpleOperator(const string& name, OperatorType type) {
+ Options options;
+ auto output_toco_op =
+ SerializeAndDeserialize(GetOperator(name, type), T(), &options);
+
+ ASSERT_EQ(0, options.builtin.o);
+ ASSERT_EQ(0, options.custom.o);
+ ASSERT_EQ(::tflite::BuiltinOptions_NONE, options.type);
+
+ ASSERT_NE(nullptr, output_toco_op.get());
+ }
+};
+
+TEST_F(OperatorTest, SimpleOperators) {
+ CheckSimpleOperator<DequantizeOperator>("DEQUANTIZE",
+ OperatorType::kDequantize);
+ CheckSimpleOperator<FloorOperator>("FLOOR", OperatorType::kFloor);
+ CheckSimpleOperator<GatherOperator>("GATHER", OperatorType::kGather);
+ CheckSimpleOperator<ReluOperator>("RELU", OperatorType::kRelu);
+ CheckSimpleOperator<Relu1Operator>("RELU1", OperatorType::kRelu1);
+ CheckSimpleOperator<Relu6Operator>("RELU6", OperatorType::kRelu6);
+ CheckSimpleOperator<ResizeBilinearOperator>("RESIZE_BILINEAR",
+ OperatorType::kResizeBilinear);
+ CheckSimpleOperator<LogisticOperator>("LOGISTIC", OperatorType::kLogistic);
+ CheckSimpleOperator<TanhOperator>("TANH", OperatorType::kTanh);
+}
+
+TEST_F(OperatorTest, BuiltinAdd) {
+ AddOperator op;
+ op.fused_activation_function = FusedActivationFunctionType::kRelu6;
+ auto output_toco_op =
+ SerializeAndDeserialize(GetOperator("ADD", OperatorType::kAdd), op);
+ EXPECT_EQ(op.fused_activation_function,
+ output_toco_op->fused_activation_function);
+}
+
+TEST_F(OperatorTest, CustomCast) {
+ CastOperator op;
+ op.src_data_type = ArrayDataType::kFloat;
+ op.dst_data_type = ArrayDataType::kUint8;
+ auto output_toco_op =
+ SerializeAndDeserialize(GetOperator("CAST", OperatorType::kCast), op);
+ EXPECT_EQ(op.src_data_type, output_toco_op->src_data_type);
+ EXPECT_EQ(op.dst_data_type, output_toco_op->dst_data_type);
+}
+
+TEST_F(OperatorTest, CustomConcatenation) {
+ ConcatenationOperator op;
+ op.concat_dim = 123;
+ auto output_toco_op = SerializeAndDeserialize(
+ GetOperator("CONCATENATION", OperatorType::kConcatenation), op);
+ EXPECT_EQ(op.concat_dim, output_toco_op->concat_dim);
+}
+
+TEST_F(OperatorTest, CustomDepthToSpace) {
+ DepthToSpaceOperator op;
+ op.block_size = 123;
+ auto output_toco_op = SerializeAndDeserialize(
+ GetOperator("DEPTH_TO_SPACE", OperatorType::kDepthToSpace), op);
+ EXPECT_EQ(op.block_size, output_toco_op->block_size);
+}
+
+TEST_F(OperatorTest, CustomFakeQuant) {
+ FakeQuantOperator op;
+ auto* minmax = new MinMax;
+ minmax->min = -10;
+ minmax->max = 200;
+ op.minmax.reset(minmax);
+ auto output_toco_op = SerializeAndDeserialize(
+ GetOperator("FAKE_QUANT", OperatorType::kFakeQuant), op);
+ EXPECT_EQ(op.minmax->min, output_toco_op->minmax->min);
+ EXPECT_EQ(op.minmax->max, output_toco_op->minmax->max);
+}
+
+TEST_F(OperatorTest, CustomFullyConnected) {
+ FullyConnectedOperator op;
+ op.fused_activation_function = FusedActivationFunctionType::kRelu6;
+ auto output_toco_op = SerializeAndDeserialize(
+ GetOperator("FULLY_CONNECTED", OperatorType::kFullyConnected), op);
+ EXPECT_EQ(op.fused_activation_function,
+ output_toco_op->fused_activation_function);
+}
+
+TEST_F(OperatorTest, BuiltinL2Pool) {
+ L2PoolOperator op;
+ op.stride_width = 123;
+ op.stride_height = 124;
+ op.padding.type = PaddingType::kValid;
+ op.kwidth = 480;
+ op.kheight = 1080;
+ auto output_toco_op = SerializeAndDeserialize(
+ GetOperator("L2_POOL_2D", OperatorType::kL2Pool), op);
+ EXPECT_EQ(op.stride_width, output_toco_op->stride_width);
+ EXPECT_EQ(op.stride_height, output_toco_op->stride_height);
+ EXPECT_EQ(op.padding.type, output_toco_op->padding.type);
+ EXPECT_EQ(op.kwidth, output_toco_op->kwidth);
+ EXPECT_EQ(op.kheight, output_toco_op->kheight);
+}
+
+TEST_F(OperatorTest, BuiltinLocalResponseNormalization) {
+ LocalResponseNormalizationOperator op;
+ op.range = 123;
+ op.bias = 1.23;
+ op.alpha = 12.3;
+ op.beta = .123;
+ auto output_toco_op = SerializeAndDeserialize(
+ GetOperator("LOCAL_RESPONSE_NORMALIZATION",
+ OperatorType::kLocalResponseNormalization),
+ op);
+ EXPECT_EQ(op.range, output_toco_op->range);
+ EXPECT_EQ(op.bias, output_toco_op->bias);
+ EXPECT_EQ(op.alpha, output_toco_op->alpha);
+ EXPECT_EQ(op.beta, output_toco_op->beta);
+}
+
+TEST_F(OperatorTest, BuiltinMaxPool) {
+ MaxPoolOperator op;
+ op.stride_width = 123;
+ op.stride_height = 124;
+ op.padding.type = PaddingType::kValid;
+ op.kwidth = 480;
+ op.kheight = 1080;
+ auto output_toco_op = SerializeAndDeserialize(
+ GetOperator("MAX_POOL_2D", OperatorType::kMaxPool), op);
+ EXPECT_EQ(op.stride_width, output_toco_op->stride_width);
+ EXPECT_EQ(op.stride_height, output_toco_op->stride_height);
+ EXPECT_EQ(op.padding.type, output_toco_op->padding.type);
+ EXPECT_EQ(op.kwidth, output_toco_op->kwidth);
+ EXPECT_EQ(op.kheight, output_toco_op->kheight);
+}
+
+TEST_F(OperatorTest, BuiltinReshape) {
+ TensorFlowReshapeOperator op;
+ op.shape = {1, 2, 4, 5, 8};
+ auto output_toco_op = SerializeAndDeserialize(
+ GetOperator("RESHAPE", OperatorType::kTensorFlowReshape), op);
+ EXPECT_EQ(op.shape, output_toco_op->shape);
+}
+
+TEST_F(OperatorTest, CustomSoftmax) {
+ SoftmaxOperator op;
+ op.beta = 123.1;
+ auto output_toco_op = SerializeAndDeserialize(
+ GetOperator("SOFTMAX", OperatorType::kSoftmax), op);
+ EXPECT_EQ(op.beta, output_toco_op->beta);
+}
+
+TEST_F(OperatorTest, BuiltinSpaceToDepth) {
+ SpaceToDepthOperator op;
+ op.block_size = 123;
+ auto output_toco_op = SerializeAndDeserialize(
+ GetOperator("SPACE_TO_DEPTH", OperatorType::kSpaceToDepth), op);
+ EXPECT_EQ(op.block_size, output_toco_op->block_size);
+}
+
+TEST_F(OperatorTest, CustomSplit) {
+ TensorFlowSplitOperator op;
+ op.num_split = 123;
+ auto output_toco_op = SerializeAndDeserialize(
+ GetOperator("SPLIT", OperatorType::kTensorFlowSplit), op);
+ EXPECT_EQ(op.num_split, output_toco_op->num_split);
+}
+
+TEST_F(OperatorTest, BuiltinAveragePool) {
+ AveragePoolOperator op;
+ op.fused_activation_function = FusedActivationFunctionType::kRelu6;
+ op.stride_width = 123;
+ op.stride_height = 124;
+ op.padding.type = PaddingType::kValid;
+ op.kwidth = 480;
+ op.kheight = 1080;
+ auto output_toco_op = SerializeAndDeserialize(
+ GetOperator("AVERAGE_POOL_2D", OperatorType::kAveragePool), op);
+ EXPECT_EQ(op.fused_activation_function,
+ output_toco_op->fused_activation_function);
+ EXPECT_EQ(op.stride_width, output_toco_op->stride_width);
+ EXPECT_EQ(op.stride_height, output_toco_op->stride_height);
+ EXPECT_EQ(op.padding.type, output_toco_op->padding.type);
+ EXPECT_EQ(op.kwidth, output_toco_op->kwidth);
+ EXPECT_EQ(op.kheight, output_toco_op->kheight);
+}
+
+TEST_F(OperatorTest, BuiltinConvolution) {
+ ConvOperator op;
+ op.stride_width = 123;
+ op.stride_height = 124;
+ op.padding.type = PaddingType::kValid;
+ op.fused_activation_function = FusedActivationFunctionType::kRelu6;
+ auto output_toco_op =
+ SerializeAndDeserialize(GetOperator("CONV_2D", OperatorType::kConv), op);
+ EXPECT_EQ(op.stride_width, output_toco_op->stride_width);
+ EXPECT_EQ(op.stride_height, output_toco_op->stride_height);
+ EXPECT_EQ(op.padding.type, output_toco_op->padding.type);
+ EXPECT_EQ(op.fused_activation_function,
+ output_toco_op->fused_activation_function);
+}
+
+TEST_F(OperatorTest, BuiltinDepthwiseConvolution) {
+ DepthwiseConvOperator op;
+ op.stride_width = 123;
+ op.stride_height = 124;
+ op.padding.type = PaddingType::kValid;
+ op.depth_multiplier = 6;
+ op.fused_activation_function = FusedActivationFunctionType::kRelu6;
+ auto output_toco_op = SerializeAndDeserialize(
+ GetOperator("DEPTHWISE_CONV_2D", OperatorType::kDepthwiseConv), op);
+ EXPECT_EQ(op.stride_width, output_toco_op->stride_width);
+ EXPECT_EQ(op.stride_height, output_toco_op->stride_height);
+ EXPECT_EQ(op.padding.type, output_toco_op->padding.type);
+ EXPECT_EQ(op.depth_multiplier, output_toco_op->depth_multiplier);
+ EXPECT_EQ(op.fused_activation_function,
+ output_toco_op->fused_activation_function);
+}
+
+TEST_F(OperatorTest, BuiltinL2Norm) {
+ L2NormalizationOperator op;
+ op.fused_activation_function = FusedActivationFunctionType::kRelu6;
+ auto output_toco_op = SerializeAndDeserialize(
+ GetOperator("L2_NORMALIZATION", OperatorType::kL2Normalization), op);
+ EXPECT_EQ(op.fused_activation_function,
+ output_toco_op->fused_activation_function);
+}
+
+TEST_F(OperatorTest, BuiltinMul) {
+ MulOperator op;
+ op.fused_activation_function = FusedActivationFunctionType::kRelu6;
+ auto output_toco_op =
+ SerializeAndDeserialize(GetOperator("MUL", OperatorType::kMul), op);
+ EXPECT_EQ(op.fused_activation_function,
+ output_toco_op->fused_activation_function);
+}
+
+TEST_F(OperatorTest, Svdf) {
+ SvdfOperator op;
+ op.fused_activation_function = FusedActivationFunctionType::kRelu;
+ auto output_toco_op =
+ SerializeAndDeserialize(GetOperator("SVDF", OperatorType::kSvdf), op);
+ EXPECT_EQ(op.fused_activation_function,
+ output_toco_op->fused_activation_function);
+}
+
+TEST_F(OperatorTest, TensorFlowUnsupported) {
+ TensorFlowUnsupportedOperator op;
+ op.tensorflow_op = "MyCustomUnsupportedOp";
+
+ ::tensorflow::NodeDef node_def;
+ auto attr = node_def.mutable_attr();
+ (*attr)["float_attr"].set_f(2.0);
+ (*attr)["str_attr"].set_s("Hello World");
+ (*attr)["int_attr"].set_i(17);
+ (*attr)["bool_attr"].set_b(true);
+ node_def.SerializeToString(&op.tensorflow_node_def);
+
+ auto output_toco_op =
+ SerializeAndDeserialize(GetOperator("TENSORFLOW_UNSUPPORTED",
+ OperatorType::kTensorFlowUnsupported),
+ op);
+
+ ::tensorflow::NodeDef output_node_def;
+ output_node_def.ParseFromString(output_toco_op->tensorflow_node_def);
+ const auto& output_attr = output_node_def.attr();
+ EXPECT_EQ(2.0, output_attr.at("float_attr").f());
+ EXPECT_EQ("Hello World", output_attr.at("str_attr").s());
+ EXPECT_EQ(17, output_attr.at("int_attr").i());
+ EXPECT_EQ(true, output_attr.at("bool_attr").b());
+}
+
+TEST_F(OperatorTest, TensorFlowUnsupportedWithoutAttr) {
+ TensorFlowUnsupportedOperator op;
+ op.tensorflow_op = "MyCustomUnsupportedOp";
+ auto output_toco_op =
+ SerializeAndDeserialize(GetOperator("TENSORFLOW_UNSUPPORTED",
+ OperatorType::kTensorFlowUnsupported),
+ op);
+
+ ::tensorflow::NodeDef output_node_def;
+ output_node_def.ParseFromString(output_toco_op->tensorflow_node_def);
+ EXPECT_TRUE(output_node_def.attr().empty());
+}
+
+} // namespace
+} // namespace tflite
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/tflite/simple_operator.h b/tensorflow/contrib/lite/toco/tflite/simple_operator.h
new file mode 100644
index 0000000000..992b98baca
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/tflite/simple_operator.h
@@ -0,0 +1,50 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_SIMPLE_OPERATOR_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_SIMPLE_OPERATOR_H_
+
+#include "tensorflow/contrib/lite/toco/tflite/operator.h"
+
+namespace toco {
+
+namespace tflite {
+
+// Simple operators don't have any configuration options and can be trivially
+// serialized and deserialized. Note that most of toco's operators will
+// likely be supported as builtin operators in TF Lite. Simple (and custom)
+// operators are mostly a convenience for the times when tf.mini supports more
+// operators than TF Lite.
+//
+// Template argument T must derive from ::toco::Operator.
+template <typename T>
+class SimpleOperator : public BaseOperator {
+ public:
+ using BaseOperator::BaseOperator;
+ Options Serialize(const Operator& op,
+ flatbuffers::FlatBufferBuilder* builder) const override {
+ return Options();
+ }
+ std::unique_ptr<Operator> Deserialize(
+ const BuiltinOptions* builtin_options,
+ const CustomOptions* custom_options) const override {
+ return std::unique_ptr<Operator>(new T);
+ }
+};
+
+} // namespace tflite
+
+} // namespace toco
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_SIMPLE_OPERATOR_H_
diff --git a/tensorflow/contrib/lite/toco/tflite/types.cc b/tensorflow/contrib/lite/toco/tflite/types.cc
new file mode 100644
index 0000000000..5b4dbfae24
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/tflite/types.cc
@@ -0,0 +1,165 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/toco/tflite/types.h"
+
+namespace toco {
+
+namespace tflite {
+
+namespace {
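+// Copy the contents of a toco Array into a flatbuffer vector of raw bytes.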
+template <ArrayDataType T>
+DataBuffer::FlatBufferOffset CopyBuffer(
+ const Array& array, flatbuffers::FlatBufferBuilder* builder) {
+ using NativeT = ::toco::DataType<T>;
+ const auto& src_data = array.GetBuffer<T>().data;
+ const uint8_t* dst_data = reinterpret_cast<const uint8_t*>(src_data.data());
+ auto size = src_data.size() * sizeof(NativeT);
+ return builder->CreateVector(dst_data, size);
+}
+
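+// Copy the raw bytes of a TF Lite buffer into a toco Array, one element at a
+// time.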
+template <ArrayDataType T>
+void CopyBuffer(const ::tflite::Buffer& buffer, Array* array) {
+ using NativeT = ::toco::DataType<T>;
+ auto* src_buffer = buffer.data();
+ const NativeT* src_data =
+ reinterpret_cast<const NativeT*>(src_buffer->data());
+ int num_items = src_buffer->size() / sizeof(NativeT);
+
+ std::vector<NativeT>* dst_data = &array->GetMutableBuffer<T>().data;
+ for (int i = 0; i < num_items; ++i) {
+ dst_data->push_back(*src_data);
+ ++src_data;
+ }
+}
+} // namespace
+
+::tflite::TensorType DataType::Serialize(ArrayDataType array_data_type) {
+ switch (array_data_type) {
+ case ArrayDataType::kFloat:
+ return ::tflite::TensorType_FLOAT32;
+ case ArrayDataType::kInt32:
+ return ::tflite::TensorType_INT32;
+ case ArrayDataType::kUint8:
+ return ::tflite::TensorType_UINT8;
+ default:
+      // FLOAT32 is returned for unknown data types.
+ // TODO(ycling): Implement type inference in TF Lite interpreter.
+ return ::tflite::TensorType_FLOAT32;
+ }
+}
+
+ArrayDataType DataType::Deserialize(int tensor_type) {
+ switch (::tflite::TensorType(tensor_type)) {
+ case ::tflite::TensorType_FLOAT32:
+ return ArrayDataType::kFloat;
+ case ::tflite::TensorType_INT32:
+ return ArrayDataType::kInt32;
+ case ::tflite::TensorType_UINT8:
+ return ArrayDataType::kUint8;
+ default:
+ LOG(FATAL) << "Unhandled tensor type '" << tensor_type << "'.";
+ }
+}
+
+flatbuffers::Offset<flatbuffers::Vector<uint8_t>> DataBuffer::Serialize(
+ const Array& array, flatbuffers::FlatBufferBuilder* builder) {
+ if (!array.buffer) return 0; // an empty buffer, usually an output.
+
+ switch (array.data_type) {
+ case ArrayDataType::kFloat:
+ return CopyBuffer<ArrayDataType::kFloat>(array, builder);
+ case ArrayDataType::kInt32:
+ return CopyBuffer<ArrayDataType::kInt32>(array, builder);
+ case ArrayDataType::kUint8:
+ return CopyBuffer<ArrayDataType::kUint8>(array, builder);
+ default:
+ LOG(FATAL) << "Unhandled array data type.";
+ }
+}
+
+void DataBuffer::Deserialize(const ::tflite::Tensor& tensor,
+ const ::tflite::Buffer& buffer, Array* array) {
+ if (tensor.buffer() == 0) return; // an empty buffer, usually an output.
+ if (buffer.data() == nullptr) return; // a non-defined buffer.
+
+ switch (tensor.type()) {
+ case ::tflite::TensorType_FLOAT32:
+ return CopyBuffer<ArrayDataType::kFloat>(buffer, array);
+ case ::tflite::TensorType_INT32:
+ return CopyBuffer<ArrayDataType::kInt32>(buffer, array);
+ case ::tflite::TensorType_UINT8:
+ return CopyBuffer<ArrayDataType::kUint8>(buffer, array);
+ default:
+ LOG(FATAL) << "Unhandled tensor type.";
+ }
+}
+
+::tflite::Padding Padding::Serialize(PaddingType padding_type) {
+ switch (padding_type) {
+ case PaddingType::kSame:
+ return ::tflite::Padding_SAME;
+ case PaddingType::kValid:
+ return ::tflite::Padding_VALID;
+ default:
+ LOG(FATAL) << "Unhandled padding type.";
+ }
+}
+
+PaddingType Padding::Deserialize(int padding) {
+ switch (::tflite::Padding(padding)) {
+ case ::tflite::Padding_SAME:
+ return PaddingType::kSame;
+ case ::tflite::Padding_VALID:
+ return PaddingType::kValid;
+ default:
+ LOG(FATAL) << "Unhandled padding.";
+ }
+}
+
+::tflite::ActivationFunctionType ActivationFunction::Serialize(
+ FusedActivationFunctionType faf_type) {
+ switch (faf_type) {
+ case FusedActivationFunctionType::kNone:
+ return ::tflite::ActivationFunctionType_NONE;
+ case FusedActivationFunctionType::kRelu:
+ return ::tflite::ActivationFunctionType_RELU;
+ case FusedActivationFunctionType::kRelu6:
+ return ::tflite::ActivationFunctionType_RELU6;
+ case FusedActivationFunctionType::kRelu1:
+ return ::tflite::ActivationFunctionType_RELU1;
+ default:
+ LOG(FATAL) << "Unhandled fused activation function type.";
+ }
+}
+
+FusedActivationFunctionType ActivationFunction::Deserialize(
+ int activation_function) {
+ switch (::tflite::ActivationFunctionType(activation_function)) {
+ case ::tflite::ActivationFunctionType_NONE:
+ return FusedActivationFunctionType::kNone;
+ case ::tflite::ActivationFunctionType_RELU:
+ return FusedActivationFunctionType::kRelu;
+ case ::tflite::ActivationFunctionType_RELU6:
+ return FusedActivationFunctionType::kRelu6;
+ case ::tflite::ActivationFunctionType_RELU1:
+ return FusedActivationFunctionType::kRelu1;
+ default:
+ LOG(FATAL) << "Unhandled fused activation function type.";
+ }
+}
+
+} // namespace tflite
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/tflite/types.h b/tensorflow/contrib/lite/toco/tflite/types.h
new file mode 100644
index 0000000000..f7c5140510
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/tflite/types.h
@@ -0,0 +1,58 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_TYPES_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_TYPES_H_
+
+#include "tensorflow/contrib/lite/schema/schema_generated.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+
+namespace toco {
+
+namespace tflite {
+
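+// Converts between toco's ArrayDataType and the TF Lite TensorType enum.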
+struct DataType {
+ static ::tflite::TensorType Serialize(ArrayDataType array_data_type);
+ static ArrayDataType Deserialize(int tensor_type);
+};
+
+struct DataBuffer {
+ using FlatBufferOffset = flatbuffers::Offset<flatbuffers::Vector<uint8_t>>;
+
+ // Build the flatbuffer representation of a toco's Array and return the
+ // corresponding offset into the flatbuffer. Note that data from the array
+ // will be copied into the flatbuffer.
+ static FlatBufferOffset Serialize(const Array& array,
+ flatbuffers::FlatBufferBuilder* builder);
+ // Copy data from the given tensor into toco's Array.
+ static void Deserialize(const ::tflite::Tensor& tensor,
+ const ::tflite::Buffer& buffer, Array* array);
+};
+
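+// Converts between toco's PaddingType and the TF Lite Padding enum.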
+struct Padding {
+ static ::tflite::Padding Serialize(PaddingType padding_type);
+ static PaddingType Deserialize(int padding);
+};
+
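+// Converts between toco's fused activation function type and the TF Lite
+// ActivationFunctionType enum.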
+struct ActivationFunction {
+ static ::tflite::ActivationFunctionType Serialize(
+ FusedActivationFunctionType faf_type);
+ static FusedActivationFunctionType Deserialize(int activation_function);
+};
+
+} // namespace tflite
+
+} // namespace toco
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_TYPES_H_
diff --git a/tensorflow/contrib/lite/toco/tflite/types_test.cc b/tensorflow/contrib/lite/toco/tflite/types_test.cc
new file mode 100644
index 0000000000..174b78f3e6
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/tflite/types_test.cc
@@ -0,0 +1,191 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/toco/tflite/types.h"
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+namespace toco {
+
+namespace tflite {
+namespace {
+
+using flatbuffers::FlatBufferBuilder;
+using flatbuffers::Offset;
+using flatbuffers::Vector;
+
+// These are types that exist in TF Mini but don't have a correspondence
+// in TF Lite.
+static const ArrayDataType kUnsupportedTocoTypes[] = {
+ ArrayDataType::kNone, ArrayDataType::kBool, ArrayDataType::kInt64};
+
+// These are TF Lite types for which there is no correspondence in TF Mini.
+static const ::tflite::TensorType kUnsupportedTfLiteTypes[] = {
+ ::tflite::TensorType_FLOAT16};
+
+// A little helper to match flatbuffer offsets.
+MATCHER_P(HasOffset, value, "") { return arg.o == value; }
+
+// Helper function that creates an array, writes it into a flatbuffer, and then
+// reads it back in.
+template <ArrayDataType T>
+Array ToFlatBufferAndBack(std::initializer_list<::toco::DataType<T>> items) {
+ // NOTE: This test does not construct the full buffers list. Since
+ // Deserialize normally takes a buffer, we need to synthesize one and provide
+  // an index that is non-zero so the buffer is not assumed to be empty.
+ Array src;
+ src.data_type = T;
+ src.GetMutableBuffer<T>().data = items;
+
+ Array result;
+ flatbuffers::FlatBufferBuilder builder;
+ builder.Finish(CreateTensor(builder, 0, DataType::Serialize(T),
+ /*buffer*/ 1)); // Can't use 0 which means empty.
+ flatbuffers::FlatBufferBuilder buffer_builder;
+ Offset<Vector<uint8_t>> data_buffer =
+ DataBuffer::Serialize(src, &buffer_builder);
+ buffer_builder.Finish(::tflite::CreateBuffer(buffer_builder, data_buffer));
+
+ auto* tensor =
+ flatbuffers::GetRoot<::tflite::Tensor>(builder.GetBufferPointer());
+ auto* buffer =
+ flatbuffers::GetRoot<::tflite::Buffer>(buffer_builder.GetBufferPointer());
+ DataBuffer::Deserialize(*tensor, *buffer, &result);
+ return result;
+}
+
+TEST(DataType, SupportedTypes) {
+ std::vector<std::pair<ArrayDataType, ::tflite::TensorType>> testdata = {
+ {ArrayDataType::kUint8, ::tflite::TensorType_UINT8},
+ {ArrayDataType::kInt32, ::tflite::TensorType_INT32},
+ {ArrayDataType::kFloat, ::tflite::TensorType_FLOAT32}};
+ for (auto x : testdata) {
+ EXPECT_EQ(x.second, DataType::Serialize(x.first));
+ EXPECT_EQ(x.first, DataType::Deserialize(x.second));
+ }
+}
+
+TEST(DataType, UnsupportedTypes) {
+ for (::tflite::TensorType t : kUnsupportedTfLiteTypes) {
+ EXPECT_DEATH(DataType::Deserialize(t), "Unhandled tensor type.");
+ }
+
+ // Unsupported types are all serialized as FLOAT32 currently.
+ for (ArrayDataType t : kUnsupportedTocoTypes) {
+ EXPECT_EQ(::tflite::TensorType_FLOAT32, DataType::Serialize(t));
+ }
+}
+
+TEST(DataBuffer, EmptyBuffers) {
+ flatbuffers::FlatBufferBuilder builder;
+ Array array;
+ EXPECT_THAT(DataBuffer::Serialize(array, &builder), HasOffset(0));
+
+ builder.Finish(::tflite::CreateTensor(builder));
+ auto* tensor =
+ flatbuffers::GetRoot<::tflite::Tensor>(builder.GetBufferPointer());
+ flatbuffers::FlatBufferBuilder buffer_builder;
+ Offset<Vector<uint8_t>> v = buffer_builder.CreateVector<uint8_t>({});
+ buffer_builder.Finish(::tflite::CreateBuffer(buffer_builder, v));
+ auto* buffer =
+ flatbuffers::GetRoot<::tflite::Buffer>(buffer_builder.GetBufferPointer());
+
+ DataBuffer::Deserialize(*tensor, *buffer, &array);
+ EXPECT_EQ(nullptr, array.buffer);
+}
+
+TEST(DataBuffer, UnsupportedTypes) {
+ for (ArrayDataType t : kUnsupportedTocoTypes) {
+ flatbuffers::FlatBufferBuilder builder;
+ Array array;
+ array.data_type = t;
+ array.GetMutableBuffer<ArrayDataType::kFloat>(); // This is OK.
+ EXPECT_DEATH(DataBuffer::Serialize(array, &builder),
+ "Unhandled array data type.");
+ }
+
+ for (::tflite::TensorType t : kUnsupportedTfLiteTypes) {
+ flatbuffers::FlatBufferBuilder builder;
+ builder.Finish(::tflite::CreateTensor(builder, 0, t, /*buffer*/ 1));
+ flatbuffers::FlatBufferBuilder buffer_builder;
+ Offset<Vector<uint8_t>> v = buffer_builder.CreateVector<uint8_t>({1});
+ buffer_builder.Finish(::tflite::CreateBuffer(buffer_builder, v));
+ auto* buffer = flatbuffers::GetRoot<::tflite::Buffer>(
+ buffer_builder.GetBufferPointer());
+ auto* tensor =
+ flatbuffers::GetRoot<::tflite::Tensor>(builder.GetBufferPointer());
+ Array array;
+ EXPECT_DEATH(DataBuffer::Deserialize(*tensor, *buffer, &array),
+ "Unhandled tensor type.");
+ }
+}
+
+TEST(DataBuffer, Float) {
+ Array recovered = ToFlatBufferAndBack<ArrayDataType::kFloat>({1.0f, 2.0f});
+ EXPECT_THAT(recovered.GetBuffer<ArrayDataType::kFloat>().data,
+ ::testing::ElementsAre(1.0f, 2.0f));
+}
+
+TEST(DataBuffer, Uint8) {
+ Array recovered = ToFlatBufferAndBack<ArrayDataType::kUint8>({127, 244});
+ EXPECT_THAT(recovered.GetBuffer<ArrayDataType::kUint8>().data,
+ ::testing::ElementsAre(127, 244));
+}
+
+TEST(DataBuffer, Int32) {
+ Array recovered = ToFlatBufferAndBack<ArrayDataType::kInt32>({1, 1 << 30});
+ EXPECT_THAT(recovered.GetBuffer<ArrayDataType::kInt32>().data,
+ ::testing::ElementsAre(1, 1 << 30));
+}
+
+TEST(Padding, All) {
+ EXPECT_EQ(::tflite::Padding_SAME, Padding::Serialize(PaddingType::kSame));
+ EXPECT_EQ(PaddingType::kSame, Padding::Deserialize(::tflite::Padding_SAME));
+
+ EXPECT_EQ(::tflite::Padding_VALID, Padding::Serialize(PaddingType::kValid));
+ EXPECT_EQ(PaddingType::kValid, Padding::Deserialize(::tflite::Padding_VALID));
+
+ EXPECT_DEATH(Padding::Serialize(static_cast<PaddingType>(10000)),
+ "Unhandled padding type.");
+ EXPECT_DEATH(Padding::Deserialize(10000), "Unhandled padding.");
+}
+
+TEST(ActivationFunction, All) {
+ std::vector<
+ std::pair<FusedActivationFunctionType, ::tflite::ActivationFunctionType>>
+ testdata = {{FusedActivationFunctionType::kNone,
+ ::tflite::ActivationFunctionType_NONE},
+ {FusedActivationFunctionType::kRelu,
+ ::tflite::ActivationFunctionType_RELU},
+ {FusedActivationFunctionType::kRelu6,
+ ::tflite::ActivationFunctionType_RELU6},
+ {FusedActivationFunctionType::kRelu1,
+ ::tflite::ActivationFunctionType_RELU1}};
+ for (auto x : testdata) {
+ EXPECT_EQ(x.second, ActivationFunction::Serialize(x.first));
+ EXPECT_EQ(x.first, ActivationFunction::Deserialize(x.second));
+ }
+
+ EXPECT_DEATH(ActivationFunction::Serialize(
+ static_cast<FusedActivationFunctionType>(10000)),
+ "Unhandled fused activation function type.");
+ EXPECT_DEATH(ActivationFunction::Deserialize(10000),
+ "Unhandled fused activation function type.");
+}
+
+} // namespace
+} // namespace tflite
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/toco.cc b/tensorflow/contrib/lite/toco/toco.cc
new file mode 100644
index 0000000000..f01ec0ec61
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/toco.cc
@@ -0,0 +1,119 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <cstdio>
+#include <memory>
+#include <string>
+
+#include "absl/strings/string_view.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/model_cmdline_flags.h"
+#include "tensorflow/contrib/lite/toco/model_flags.pb.h"
+#include "tensorflow/contrib/lite/toco/toco_cmdline_flags.h"
+#include "tensorflow/contrib/lite/toco/toco_flags.pb.h"
+#include "tensorflow/contrib/lite/toco/toco_port.h"
+#include "tensorflow/contrib/lite/toco/toco_tooling.h"
+#include "tensorflow/contrib/lite/toco/toco_types.h"
+#include "tensorflow/core/platform/logging.h"
+
+#ifndef CHECK_OK
+#define CHECK_OK(val) CHECK_EQ((val).ok(), true)
+#define QCHECK_OK(val) QCHECK_EQ((val).ok(), true)
+#endif
+
+namespace toco {
+namespace {
+
+#define QCHECK_REQUIRE_TOCO_FLAG(arg) \
+ QCHECK(parsed_toco_flags.arg.specified()) << "Missing required flag: " #arg;
+
+void CheckFilePermissions(const ParsedTocoFlags& parsed_toco_flags,
+ const ParsedModelFlags& parsed_model_flags,
+ const TocoFlags& toco_flags) {
+ port::CheckInitGoogleIsDone("InitGoogle is not done yet");
+
+ QCHECK_REQUIRE_TOCO_FLAG(input_file)
+ QCHECK_OK(port::file::Exists(parsed_toco_flags.input_file.value(),
+ port::file::Defaults()))
+ << "Specified input_file does not exist: "
+ << parsed_toco_flags.input_file.value();
+ QCHECK_OK(port::file::Readable(parsed_toco_flags.input_file.value(),
+ port::file::Defaults()))
+ << "Specified input_file exists, but is not readable: "
+ << parsed_toco_flags.input_file.value();
+
+ QCHECK_REQUIRE_TOCO_FLAG(output_file);
+  QCHECK_OK(port::file::Writable(parsed_toco_flags.output_file.value()))
+      << "Specified output_file is not writable: "
+      << parsed_toco_flags.output_file.value();
+}
+
+void ToolMain(const ParsedTocoFlags& parsed_toco_flags,
+ const ParsedModelFlags& parsed_model_flags) {
+ ModelFlags model_flags;
+ ReadModelFlagsFromCommandLineFlags(parsed_model_flags, &model_flags);
+
+ TocoFlags toco_flags;
+ ReadTocoFlagsFromCommandLineFlags(parsed_toco_flags, &toco_flags);
+
+ CheckFilePermissions(parsed_toco_flags, parsed_model_flags, toco_flags);
+
+ string input_file_contents;
+ CHECK_OK(port::file::GetContents(parsed_toco_flags.input_file.value(),
+ &input_file_contents,
+ port::file::Defaults()));
+ std::unique_ptr<Model> model =
+ Import(toco_flags, model_flags, input_file_contents);
+ Transform(toco_flags, model.get());
+ string output_file_contents;
+ Export(toco_flags, *model, toco_flags.allow_custom_ops(),
+ &output_file_contents);
+ CHECK_OK(port::file::SetContents(parsed_toco_flags.output_file.value(),
+ output_file_contents,
+ port::file::Defaults()));
+}
+
+} // namespace
+} // namespace toco
+
+int main(int argc, char** argv) {
+ toco::string msg;
+ toco::ParsedTocoFlags parsed_toco_flags;
+ toco::ParsedModelFlags parsed_model_flags;
+
+ // If no args were specified, give a help string to be helpful.
+ int* effective_argc = &argc;
+ char** effective_argv = argv;
+ if (argc == 1) {
+ // No arguments, so manufacture help argv.
+ static int dummy_argc = 2;
+ static char* dummy_argv[] = {argv[0], const_cast<char*>("--help")};
+ effective_argc = &dummy_argc;
+ effective_argv = dummy_argv;
+ }
+
+  // Parse toco flags and model flags in sequence; each pass strips off the
+  // args it handles, giving InitGoogle a chance to handle all remaining
+  // arguments.
+ bool toco_success = toco::ParseTocoFlagsFromCommandLineFlags(
+ effective_argc, effective_argv, &msg, &parsed_toco_flags);
+ bool model_success = toco::ParseModelFlagsFromCommandLineFlags(
+ effective_argc, effective_argv, &msg, &parsed_model_flags);
+ if (!toco_success || !model_success || !msg.empty()) {
+ fprintf(stderr, "%s", msg.c_str());
+ fflush(stderr);
+ return 1;
+ }
+ toco::port::InitGoogle(argv[0], effective_argc, &effective_argv, true);
+ toco::ToolMain(parsed_toco_flags, parsed_model_flags);
+}
diff --git a/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc b/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc
new file mode 100644
index 0000000000..d43c3b4a8e
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc
@@ -0,0 +1,206 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <string>
+#include <vector>
+
+#include "absl/strings/numbers.h"
+#include "absl/strings/str_join.h"
+#include "absl/strings/str_split.h"
+#include "absl/strings/strip.h"
+#include "tensorflow/contrib/lite/toco/toco_cmdline_flags.h"
+#include "tensorflow/contrib/lite/toco/toco_port.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/util/command_line_flags.h"
+
+namespace toco {
+
+bool ParseTocoFlagsFromCommandLineFlags(
+ int* argc, char* argv[], string* msg,
+ ParsedTocoFlags* parsed_toco_flags_ptr) {
+ using tensorflow::Flag;
+ ParsedTocoFlags& parsed_flags = *parsed_toco_flags_ptr;
+ std::vector<tensorflow::Flag> flags = {
+ Flag("input_file", parsed_flags.input_file.bind(),
+ parsed_flags.input_file.default_value(),
+ "Input file (model of any supported format). For Protobuf "
+ "formats, both text and binary are supported regardless of file "
+ "extension."),
+ Flag("output_file", parsed_flags.output_file.bind(),
+ parsed_flags.output_file.default_value(),
+ "Output file. "
+ "For Protobuf formats, the binary format will be used."),
+ Flag("input_format", parsed_flags.input_format.bind(),
+ parsed_flags.input_format.default_value(),
+ "Input file format. One of: tensorflow_graphdef, "),
+ Flag("output_format", parsed_flags.output_format.bind(),
+ parsed_flags.output_format.default_value(), "Output file format."),
+ Flag("default_ranges_min", parsed_flags.default_ranges_min.bind(),
+ parsed_flags.default_ranges_min.default_value(),
+ "If defined, will be used as the default value for the min bound "
+ "of min/max ranges used for quantization."),
+ Flag("default_ranges_max", parsed_flags.default_ranges_max.bind(),
+ parsed_flags.default_ranges_max.default_value(),
+ "If defined, will be used as the default value for the max bound "
+ "of min/max ranges used for quantization."),
+ Flag("input_type", parsed_flags.input_type.bind(),
+ parsed_flags.input_type.default_value(),
+ "Data type of the input array in the "
+ "output file. "),
+ Flag("input_types", parsed_flags.input_types.bind(),
+ parsed_flags.input_types.default_value(),
+ "Data types of the input arrays in the "
+ "output file. "
+ "Comma-separated list matching the enumeration order of "
+ "input_arrays."),
+ Flag("inference_type", parsed_flags.inference_type.bind(),
+ parsed_flags.inference_type.default_value(),
+ "Data type, in the output file, of internal and output arrays "
+ "that are FLOAT in the input file. Thus, the value FLOAT means "
+ "keep doing floating-point inference, while the value "
+ "QUANTIZED_UINT8 means replace all internal floating-point "
+ "arithmetic by integer arithmetic producing 8-bit integer "
+ "activations instead of float activations --- which we call "
+ "\'quantized inference\'."),
+ Flag("drop_fake_quant", parsed_flags.drop_fake_quant.bind(),
+ parsed_flags.drop_fake_quant.default_value(),
+ "Ignore and discard FakeQuant nodes. For instance, that can be used "
+ "to "
+ "generate plain float code without fake-quantization from a "
+ "quantized "
+ "graph."),
+ Flag(
+ "reorder_across_fake_quant",
+ parsed_flags.reorder_across_fake_quant.bind(),
+ parsed_flags.reorder_across_fake_quant.default_value(),
+ "Normally, FakeQuant nodes must be strict boundaries for graph "
+ "transformations, in order to ensure that quantized inference has "
+ "the "
+ "exact same arithmetic behavior as quantized training --- which is "
+ "the "
+ "whole point of quantized training and of FakeQuant nodes in the "
+ "first "
+ "place. However, that entails subtle requirements on where exactly "
+ "FakeQuant nodes must be placed in the graph. Some quantized graphs "
+ "have FakeQuant nodes at unexpected locations, that prevent graph "
+ "transformations that are necessary in order to generate inference "
+ "code for these graphs. Such graphs should be fixed, but as a "
+ "temporary work-around, setting this reorder_across_fake_quant flag "
+          "allows toco to perform necessary graph transformations on them, "
+ "at the cost of no longer faithfully matching inference and training "
+ "arithmetic."),
+ Flag("allow_custom_ops", parsed_flags.allow_custom_ops.bind(),
+ parsed_flags.allow_custom_ops.default_value(),
+           "If true, allow TOCO to create TF Lite Custom operators for all the "
+           "unsupported TensorFlow ops."),
+ };
+ bool asked_for_help =
+ *argc == 2 && (!strcmp(argv[1], "--help") || !strcmp(argv[1], "-help"));
+ if (asked_for_help) {
+ *msg += tensorflow::Flags::Usage(argv[0], flags);
+ return false;
+ } else {
+ return tensorflow::Flags::Parse(argc, argv, flags);
+ }
+}
+
+void ReadTocoFlagsFromCommandLineFlags(const ParsedTocoFlags& parsed_toco_flags,
+ TocoFlags* toco_flags) {
+ namespace port = toco::port;
+ port::CheckInitGoogleIsDone("InitGoogle is not done yet");
+
+ enum class FlagRequirement { kNone, kMustBeSpecified, kMustNotBeSpecified };
+
+#define ENFORCE_FLAG_REQUIREMENT(name, requirement) \
+ do { \
+ if (requirement == FlagRequirement::kMustBeSpecified) { \
+ QCHECK(parsed_toco_flags.name.specified()) \
+ << "Missing required flag: " << #name; \
+ } \
+ if (requirement == FlagRequirement::kMustNotBeSpecified) { \
+ QCHECK(!parsed_toco_flags.name.specified()) \
+ << "Given other flags, this flag should not have been specified: " \
+ << #name; \
+ } \
+ } while (false)
+
+#define READ_TOCO_FLAG(name, requirement) \
+ ENFORCE_FLAG_REQUIREMENT(name, requirement); \
+ do { \
+ if (parsed_toco_flags.name.specified()) { \
+ toco_flags->set_##name(parsed_toco_flags.name.value()); \
+ } \
+ } while (false)
+
+#define PARSE_TOCO_FLAG(Type, name, requirement) \
+ ENFORCE_FLAG_REQUIREMENT(name, requirement); \
+ do { \
+ if (parsed_toco_flags.name.specified()) { \
+ Type x; \
+ QCHECK(Type##_Parse(parsed_toco_flags.name.value(), &x)) \
+ << "Unrecognized " << #Type << " value " \
+ << parsed_toco_flags.name.value(); \
+ toco_flags->set_##name(x); \
+ } \
+ } while (false)
+
+ PARSE_TOCO_FLAG(FileFormat, input_format, FlagRequirement::kMustBeSpecified);
+ PARSE_TOCO_FLAG(FileFormat, output_format, FlagRequirement::kMustBeSpecified);
+ FlagRequirement tflite_flags_requirement =
+ toco_flags->output_format() == TFLITE
+ ? FlagRequirement::kMustBeSpecified
+ : FlagRequirement::kMustNotBeSpecified;
+ PARSE_TOCO_FLAG(IODataType, inference_type, tflite_flags_requirement);
+ READ_TOCO_FLAG(default_ranges_min, FlagRequirement::kNone);
+ READ_TOCO_FLAG(default_ranges_max, FlagRequirement::kNone);
+ READ_TOCO_FLAG(drop_fake_quant, FlagRequirement::kNone);
+ READ_TOCO_FLAG(reorder_across_fake_quant, FlagRequirement::kNone);
+ READ_TOCO_FLAG(allow_custom_ops, FlagRequirement::kNone);
+
+#undef READ_TOCO_FLAG
+#undef PARSE_TOCO_FLAG
+
+ const bool input_type_specified = parsed_toco_flags.input_type.specified();
+ const bool input_types_specified = parsed_toco_flags.input_types.specified();
+ if (toco_flags->output_format() == TFLITE) {
+ QCHECK(input_type_specified || input_types_specified)
+ << "When output_format=TFLITE, either input_type or input_types needs "
+ "to be specified.";
+ } else {
+ QCHECK(!input_type_specified && !input_types_specified)
+ << "With this output_format, neither input_type nor input_types must "
+ "be specified.";
+ }
+ QCHECK(!(input_type_specified && input_types_specified))
+ << "input_type and input_types are mutually exclusive";
+ if (input_type_specified) {
+ IODataType type;
+ QCHECK(IODataType_Parse(parsed_toco_flags.input_type.value(), &type))
+ << "Unrecognized input_type: " << parsed_toco_flags.input_type.value();
+ toco_flags->add_input_types(type);
+ }
+ if (input_types_specified) {
+ std::vector<string> input_types =
+ absl::StrSplit(parsed_toco_flags.input_types.value(), ',');
+ for (const string& t : input_types) {
+ IODataType type;
+ QCHECK(IODataType_Parse(t, &type))
+ << "Unrecognized input_types value " << t
+ << " in input_types=" << parsed_toco_flags.input_types.value();
+ toco_flags->add_input_types(type);
+ }
+ }
+}
+} // namespace toco
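For reference, the READ_TOCO_FLAG/PARSE_TOCO_FLAG helpers above only gate a parse-and-copy into the TocoFlags proto on the flag having been specified. As an illustration only (not part of the file above), PARSE_TOCO_FLAG(FileFormat, input_format, FlagRequirement::kMustBeSpecified) expands, after folding the requirement argument, to approximately:

    // Requirement check from ENFORCE_FLAG_REQUIREMENT:
    QCHECK(parsed_toco_flags.input_format.specified())
        << "Missing required flag: " << "input_format";
    // Conditional parse-and-copy into the proto:
    if (parsed_toco_flags.input_format.specified()) {
      FileFormat x;
      QCHECK(FileFormat_Parse(parsed_toco_flags.input_format.value(), &x))
          << "Unrecognized " << "FileFormat" << " value "
          << parsed_toco_flags.input_format.value();
      toco_flags->set_input_format(x);
    }
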
diff --git a/tensorflow/contrib/lite/toco/toco_cmdline_flags.h b/tensorflow/contrib/lite/toco/toco_cmdline_flags.h
new file mode 100644
index 0000000000..155a6fea87
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/toco_cmdline_flags.h
@@ -0,0 +1,35 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_CMDLINE_FLAGS_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_CMDLINE_FLAGS_H_
+
+#include <string>
+#include <vector>
+#include "tensorflow/contrib/lite/toco/args.h"
+#include "tensorflow/contrib/lite/toco/toco_flags.pb.h"
+
+namespace toco {
+// Parses and removes the arguments handled by toco. Returns true if parsing
+// is successful. msg holds the usage string if there was an error or if
+// "--help" was specified.
+bool ParseTocoFlagsFromCommandLineFlags(int* argc, char* argv[], string* msg,
+ ParsedTocoFlags* parsed_toco_flags_ptr);
+// Populate the TocoFlags proto with parsed_toco_flags data.
+void ReadTocoFlagsFromCommandLineFlags(const ParsedTocoFlags& parsed_toco_flags,
+ TocoFlags* toco_flags);
+
+} // namespace toco
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_CMDLINE_FLAGS_H_
diff --git a/tensorflow/contrib/lite/toco/toco_flags.proto b/tensorflow/contrib/lite/toco/toco_flags.proto
new file mode 100644
index 0000000000..fd7c29fdc7
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/toco_flags.proto
@@ -0,0 +1,126 @@
+// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+syntax = "proto2";
+package toco;
+
+// Supported I/O file formats. Some formats may be input-only or output-only.
+enum FileFormat {
+ FILE_FORMAT_UNKNOWN = 0;
+
+ // GraphDef, third_party/tensorflow/core/framework/graph.proto
+ TENSORFLOW_GRAPHDEF = 1;
+
+  // TensorFlow's mobile inference model.
+ // third_party/tensorflow/contrib/tflite/schema.fbs
+ TFLITE = 2;
+
+ // GraphViz
+ // Export-only.
+ GRAPHVIZ_DOT = 3;
+}
+
+// IODataType describes the numeric data types to be used by the output format.
+// See input_type and inference_type below.
+enum IODataType {
+ IO_DATA_TYPE_UNKNOWN = 0;
+
+ // Float32, not quantized
+ FLOAT = 1;
+
+ // Uint8, quantized
+ QUANTIZED_UINT8 = 2;
+
+ // Int32, not quantized
+ INT32 = 3;
+
+ // Int64, not quantized
+ INT64 = 4;
+
+ // String, not quantized
+ STRING = 5;
+}
+
+// TocoFlags encodes extra parameters that drive tooling operations. They are
+// not normally encoded in model files and in general are not properties of
+// models; rather, they describe how models are to be processed in the context
+// of the present tooling job.
+// Next Id: 11
+message TocoFlags {
+ // Input file format
+ optional FileFormat input_format = 1;
+
+ // Output file format
+ optional FileFormat output_format = 2;
+
+ // Numeric data types of the input arrays in the output format.
+ // This controls what input types the output file will be expecting.
+ // This is not a description of the input types of the input file.
+ // For example, the input file may have a float input placeholder,
+ // but we may want to generate a quantized TFLite file from it,
+ // or a float TFLite file taking a quantized input.
+ //
+ // The length of this list should match the length of the input_arrays
+ // list in ModelFlags.
+ repeated IODataType input_types = 9;
+
+ // Numeric data type of the internal activations array and output array.
+ //
+ // As a matter of implementation detail, most model
+ // parameter arrays (weights, etc) will tend to also use this data type.
+ // Not all will, though: for instance, bias vectors will typically
+ // get quantized as int32 when weights and activations get quantized as uint8.
+ optional IODataType inference_type = 4;
+
+ // default_ranges_min and default_ranges_max are helpers to experiment
+ // with quantization of models. Normally, quantization requires the input
+ // model to have (min, max) range information for every activations array.
+ // This is needed in order to know how to quantize arrays and still achieve
+ // satisfactory accuracy. However, in some circumstances one would just like
+ // to estimate the performance of quantized inference, without caring about
+ // accuracy. That is what default_ranges_min and default_ranges_max are for:
+ // when specified, they will be used as default (min, max) range boundaries
+ // for all activation arrays that lack (min, max) range information, thus
+ // allowing for quantization to proceed.
+ //
+ // It should be clear from the above explanation that these parameters are
+ // for experimentation purposes only and should not be used in production:
+ // they make it easy to quantize models, but the resulting quantized model
+ // will be inaccurate.
+ optional float default_ranges_min = 5;
+ optional float default_ranges_max = 6;
+
+ // Ignore and discard FakeQuant nodes. For instance, that can be used to
+ // generate plain float code without fake-quantization from a quantized
+ // graph.
+ optional bool drop_fake_quant = 7;
+
+ // Normally, FakeQuant nodes must be strict boundaries for graph
+ // transformations, in order to ensure that quantized inference has the
+ // exact same arithmetic behavior as quantized training --- which is the
+ // whole point of quantized training and of FakeQuant nodes in the first
+ // place. However, that entails subtle requirements on where exactly
+ // FakeQuant nodes must be placed in the graph. Some quantized graphs
+ // have FakeQuant nodes at unexpected locations, that prevent graph
+ // transformations that are necessary in order to generate inference
+ // code for these graphs. Such graphs should be fixed, but as a
+ // temporary work-around, setting this reorder_across_fake_quant flag
+  // allows toco to perform necessary graph transformations on them,
+ // at the cost of no longer faithfully matching inference and training
+ // arithmetic.
+ optional bool reorder_across_fake_quant = 8;
+
+ // If true, allow TOCO to create TF Lite Custom operators for all the
+  // unsupported TensorFlow ops.
+ optional bool allow_custom_ops = 10;
+}
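For context, these fields are consumed through the standard protoc-generated accessors (set_input_format(), add_input_types(), set_inference_type(), set_default_ranges_min()/max()), as used in toco_cmdline_flags.cc and toco_tooling.cc. A sketch (illustration only, assuming those generated setters) of building a TocoFlags message for an experimental quantized conversion, equivalent to passing --input_format=TENSORFLOW_GRAPHDEF --output_format=TFLITE --input_type=QUANTIZED_UINT8 --inference_type=QUANTIZED_UINT8 --default_ranges_min=0 --default_ranges_max=6 on the command line:

    #include "tensorflow/contrib/lite/toco/toco_flags.pb.h"

    // Sketch only: uses the protoc-generated setters for the message above.
    toco::TocoFlags MakeQuantizedExperimentFlags() {
      toco::TocoFlags flags;
      flags.set_input_format(toco::TENSORFLOW_GRAPHDEF);
      flags.set_output_format(toco::TFLITE);
      flags.add_input_types(toco::QUANTIZED_UINT8);
      flags.set_inference_type(toco::QUANTIZED_UINT8);
      // Experimental defaults for arrays lacking (min, max) ranges; see the
      // caveat on default_ranges_min/max above.
      flags.set_default_ranges_min(0.0f);
      flags.set_default_ranges_max(6.0f);
      return flags;
    }
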
diff --git a/tensorflow/contrib/lite/toco/toco_graphviz_dump_options.cc b/tensorflow/contrib/lite/toco/toco_graphviz_dump_options.cc
new file mode 100644
index 0000000000..4e98e7081d
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/toco_graphviz_dump_options.cc
@@ -0,0 +1,22 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/toco/toco_graphviz_dump_options.h"
+
+namespace toco {
+GraphVizDumpOptions* GraphVizDumpOptions::singleton() {
+ static auto* ptr = new GraphVizDumpOptions;
+ return ptr;
+}
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/toco_graphviz_dump_options.h b/tensorflow/contrib/lite/toco/toco_graphviz_dump_options.h
new file mode 100644
index 0000000000..ae0541f62b
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/toco_graphviz_dump_options.h
@@ -0,0 +1,34 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_GRAPHVIZ_DUMP_OPTIONS_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_GRAPHVIZ_DUMP_OPTIONS_H_
+
+#include <string>
+
+namespace toco {
+
+// Global options controlling whether and how toco dumps GraphViz output.
+struct GraphVizDumpOptions {
+ std::string graphviz_first_array;
+ std::string graphviz_last_array;
+ std::string dump_graphviz;
+ bool dump_graphviz_video = false;
+
+ static GraphVizDumpOptions* singleton();
+};
+
+} // namespace toco
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_GRAPHVIZ_DUMP_OPTIONS_H_
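The dump options are plain global state: a caller sets the singleton's fields before running the tooling. A minimal sketch follows; the path, and the reading of dump_graphviz as a dump destination, are assumptions rather than something this header specifies.

    #include "tensorflow/contrib/lite/toco/toco_graphviz_dump_options.h"

    void EnableGraphvizDump() {
      auto* opts = toco::GraphVizDumpOptions::singleton();
      // Assumption: dump_graphviz names where GraphViz output should be written.
      opts->dump_graphviz = "/tmp/toco_dump";  // hypothetical path
      opts->dump_graphviz_video = false;
    }
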
diff --git a/tensorflow/contrib/lite/toco/toco_port.cc b/tensorflow/contrib/lite/toco/toco_port.cc
new file mode 100644
index 0000000000..a1c8696cd0
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/toco_port.cc
@@ -0,0 +1,227 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <cstring>
+
+#include "tensorflow/contrib/lite/toco/toco_port.h"
+#include "tensorflow/contrib/lite/toco/toco_types.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+namespace port {
+void CopyToBuffer(const string& src, char* dest) {
+ memcpy(dest, src.data(), src.size());
+}
+
+#ifdef PLATFORM_GOOGLE
+void CopyToBuffer(const Cord& src, char* dest) { src.CopyToArray(dest); }
+#endif
+} // namespace port
+} // namespace toco
+
+#if defined(PLATFORM_GOOGLE) && !defined(__APPLE__) && !defined(__ANDROID__)
+
+// Wrap Google file operations.
+
+#include "base/init_google.h"
+#include "file/base/file.h"
+#include "file/base/filesystem.h"
+#include "file/base/helpers.h"
+#include "file/base/options.h"
+#include "file/base/path.h"
+
+namespace toco {
+namespace port {
+
+void InitGoogle(const char* usage, int* argc, char*** argv, bool remove_flags) {
+ ::InitGoogle(usage, argc, argv, remove_flags);
+}
+
+void CheckInitGoogleIsDone(const char* message) {
+ ::CheckInitGoogleIsDone(message);
+}
+
+namespace file {
+
+// Conversion to our wrapper Status.
+Status ToStatus(const ::util::Status& uts) {
+ return Status(uts.ok(), uts.error_message());
+}
+
+// Conversion to our wrapper Options.
+toco::port::file::Options ToOptions(const ::file::Options& options) {
+ CHECK_EQ(&options, &::file::Defaults());
+ return Options();
+}
+
+Status Writable(const string& filename) {
+ File* f = nullptr;
+ const auto status = ::file::Open(filename, "w", &f, ::file::Defaults());
+ if (f) {
+ QCHECK_OK(f->Close(::file::Defaults()));
+ }
+ return ToStatus(status);
+}
+
+Status Readable(const string& filename, const file::Options& options) {
+ return ToStatus(::file::Readable(filename, ::file::Defaults()));
+}
+
+Status Exists(const string& filename, const file::Options& options) {
+ auto status = ::file::Exists(filename, ::file::Defaults());
+ return ToStatus(status);
+}
+
+Status GetContents(const string& filename, string* contents,
+ const file::Options& options) {
+ return ToStatus(::file::GetContents(filename, contents, ::file::Defaults()));
+}
+
+Status SetContents(const string& filename, const string& contents,
+ const file::Options& options) {
+ return ToStatus(::file::SetContents(filename, contents, ::file::Defaults()));
+}
+
+string JoinPath(const string& a, const string& b) {
+ return ::file::JoinPath(a, b);
+}
+
+} // namespace file
+} // namespace port
+} // namespace toco
+
+#else // (__APPLE__ || __ANDROID__)
+
+#include <fcntl.h>
+#include <sys/stat.h>
+#include <sys/types.h>
+#include <unistd.h>
+#include <cstdio>
+
+#if defined(PLATFORM_GOOGLE)
+#include "base/commandlineflags.h"
+#endif
+
+namespace toco {
+namespace port {
+
+static bool port_initialized = false;
+
+void InitGoogle(const char* usage, int* argc, char*** argv, bool remove_flags) {
+ if (!port_initialized) {
+#if defined(PLATFORM_GOOGLE)
+ ParseCommandLineFlags(argc, argv, remove_flags);
+#endif
+ port_initialized = true;
+ }
+}
+
+void CheckInitGoogleIsDone(const char* message) {
+ CHECK(port_initialized) << message;
+}
+
+namespace file {
+
+Status Writable(const string& filename) {
+ FILE* f = fopen(filename.c_str(), "w");
+ if (f) {
+ fclose(f);
+ return Status(true, "");
+ }
+ return Status(false, "not writable");
+}
+
+Status Readable(const string& filename, const file::Options& options) {
+ FILE* f = fopen(filename.c_str(), "r");
+ if (f) {
+ fclose(f);
+ return Status(true, "");
+ }
+ return Status(false, "not readable");
+}
+
+Status Exists(const string& filename, const file::Options& options) {
+ struct stat statbuf;
+ int ret = stat(filename.c_str(), &statbuf);
+ return Status(ret != -1, "");
+}
+
+Status GetContents(const string& path, string* output,
+ const file::Options& options) {
+ output->clear();
+
+ int fd = open(path.c_str(), O_RDONLY);
+ if (fd == -1) {
+ return Status(false, "can't open() for read");
+ }
+
+ // Direct read, for speed.
+ const int kBufSize = 1 << 16;
+ char buffer[kBufSize];
+ while (true) {
+ int size = read(fd, buffer, kBufSize);
+ if (size == 0) {
+ // Done.
+ close(fd);
+ return Status(true, "");
+ } else if (size == -1) {
+ // Error.
+ close(fd);
+ return Status(false, "error during read()");
+ } else {
+ output->append(buffer, size);
+ }
+ }
+
+ CHECK(0);
+ return Status(false, "internal error");
+}
+
+Status SetContents(const string& filename, const string& contents,
+ const file::Options& options) {
+ int fd = open(filename.c_str(), O_WRONLY | O_CREAT, 0664);
+ if (fd == -1) {
+ return Status(false, "can't open() for write");
+ }
+
+ size_t i = 0;
+ while (i < contents.size()) {
+ size_t to_write = contents.size() - i;
+ ssize_t written = write(fd, &contents[i], to_write);
+ if (written == -1) {
+ close(fd);
+ return Status(false, "write() error");
+ }
+ i += written;
+ }
+ close(fd);
+
+ return Status(true, "");
+}
+
+string JoinPath(const string& base, const string& filename) {
+ if (base.empty()) return filename;
+ string base_fixed = base;
+ if (!base_fixed.empty() && base_fixed.back() == '/') base_fixed.pop_back();
+ string filename_fixed = filename;
+ if (!filename_fixed.empty() && filename_fixed.front() == '/')
+ filename_fixed.erase(0, 1);
+ return base_fixed + "/" + filename_fixed;
+}
+
+} // namespace file
+} // namespace port
+} // namespace toco
+
+#endif  // (__APPLE__ || __ANDROID__)
diff --git a/tensorflow/contrib/lite/toco/toco_port.h b/tensorflow/contrib/lite/toco/toco_port.h
new file mode 100644
index 0000000000..b5cb7a11e7
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/toco_port.h
@@ -0,0 +1,80 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_PORT_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_PORT_H_
+
+// Portability layer for the toco tool. Mainly abstracts filesystem access so
+// that toco can be built and used both in Google-internal environments and on
+// OSX.
+
+#include <string>
+#include "tensorflow/contrib/lite/toco/format_port.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/platform.h"
+#if defined(PLATFORM_GOOGLE)
+#include "absl/strings/cord.h"
+#endif // PLATFORM_GOOGLE
+
+#ifdef PLATFORM_GOOGLE
+#define TFLITE_PROTO_NS proto2
+#else
+#define TFLITE_PROTO_NS google::protobuf
+#endif
+
+namespace toco {
+namespace port {
+
+class Status {
+ public:
+ Status() {}
+
+ Status(bool ok, const string& message) : ok_(ok), message_(message) {}
+
+ bool ok() const { return ok_; }
+
+ const string error_message() const { return message_; }
+
+ private:
+ bool ok_ = false;
+ string message_;
+};
+
+void InitGoogle(const char* usage, int* argc, char*** argv, bool remove_flags);
+void CheckInitGoogleIsDone(const char* message);
+
+namespace file {
+class Options {};
+inline Options Defaults() {
+ Options o;
+ return o;
+}
+Status GetContents(const string& filename, string* contents,
+ const Options& options);
+Status SetContents(const string& filename, const string& contents,
+ const Options& options);
+string JoinPath(const string& base, const string& filename);
+Status Writable(const string& filename);
+Status Readable(const string& filename, const Options& options);
+Status Exists(const string& filename, const Options& options);
+} // namespace file
+
+// Copy `src` string to `dest`. User must ensure `dest` has enough space.
+#if defined(PLATFORM_GOOGLE)
+void CopyToBuffer(const ::Cord& src, char* dest);
+#endif // PLATFORM_GOOGLE
+void CopyToBuffer(const string& src, char* dest);
+} // namespace port
+} // namespace toco
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_PORT_H_
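A short usage sketch for the file wrappers declared above, mirroring the calls made in toco.cc; the paths are illustrative only.

    #include <string>

    #include "tensorflow/contrib/lite/toco/toco_port.h"
    #include "tensorflow/core/platform/logging.h"

    // Sketch: copy a file through the portability wrappers.
    void CopyFileViaPort() {
      namespace file = toco::port::file;
      const std::string src = file::JoinPath("/tmp", "model.pb");       // hypothetical
      const std::string dst = file::JoinPath("/tmp", "model_copy.pb");  // hypothetical

      CHECK(file::Exists(src, file::Defaults()).ok()) << "missing: " << src;
      std::string contents;
      CHECK(file::GetContents(src, &contents, file::Defaults()).ok());
      CHECK(file::SetContents(dst, contents, file::Defaults()).ok());
    }
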
diff --git a/tensorflow/contrib/lite/toco/toco_port_test.cc b/tensorflow/contrib/lite/toco/toco_port_test.cc
new file mode 100644
index 0000000000..650a617aeb
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/toco_port_test.cc
@@ -0,0 +1,58 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/toco/toco_port.h"
+#include "tensorflow/contrib/lite/toco/toco_types.h"
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+namespace toco {
+namespace port {
+namespace {
+
+#ifdef PLATFORM_GOOGLE
+#define TFLITE_PREFIX "third_party/tensorflow/contrib/lite/"
+#else
+#define TFLITE_PREFIX "tensorflow/contrib/lite/"
+#endif
+
+TEST(TocoPortTest, Exists) {
+ EXPECT_TRUE(
+ file::Exists(TFLITE_PREFIX "toco/toco_port_test.cc", file::Defaults())
+ .ok());
+
+ EXPECT_FALSE(
+ file::Exists("non-existent_file_asldjflasdjf", file::Defaults()).ok());
+}
+
+TEST(TocoPortTest, Readable) {
+ EXPECT_TRUE(
+ file::Readable(TFLITE_PREFIX "toco/toco_port_test.cc", file::Defaults())
+ .ok());
+
+ EXPECT_FALSE(
+ file::Readable("non-existent_file_asldjflasdjf", file::Defaults()).ok());
+}
+
+TEST(TocoPortTest, JoinPath) {
+ EXPECT_EQ("part1/part2", file::JoinPath("part1", "part2"));
+ EXPECT_EQ("part1/part2", file::JoinPath("part1/", "part2"));
+ EXPECT_EQ("part1/part2", file::JoinPath("part1", "/part2"));
+ EXPECT_EQ("part1/part2", file::JoinPath("part1/", "/part2"));
+}
+
+} // namespace
+} // namespace port
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/toco_tooling.cc b/tensorflow/contrib/lite/toco/toco_tooling.cc
new file mode 100644
index 0000000000..232538a841
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/toco_tooling.cc
@@ -0,0 +1,277 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/toco/toco_tooling.h"
+
+#include <cstdlib>
+#include <memory>
+#include <set>
+
+#include "absl/strings/str_join.h"
+#include "tensorflow/contrib/lite/toco/allocate_transient_arrays.h"
+#include "tensorflow/contrib/lite/toco/dump_graphviz.h"
+#include "tensorflow/contrib/lite/toco/export_tensorflow.h"
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/import_tensorflow.h"
+#include "tensorflow/contrib/lite/toco/model_flags.pb.h"
+#include "tensorflow/contrib/lite/toco/tflite/export.h"
+#include "tensorflow/contrib/lite/toco/tflite/import.h"
+#include "tensorflow/contrib/lite/toco/toco_flags.pb.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+namespace {
+// CHECK-fails if the model contains a kTensorFlowUnsupported operation.
+void CheckUnsupportedOperations(const Model& model) {
+ std::set<string> unsupported_ops;
+ for (auto& op : model.operators) {
+ if (op->type == OperatorType::kTensorFlowUnsupported) {
+ unsupported_ops.insert(
+ static_cast<const TensorFlowUnsupportedOperator*>(op.get())
+ ->tensorflow_op);
+ }
+ }
+ QCHECK(unsupported_ops.empty())
+ << "These unsupported ops were not removed by graph transformations: "
+ << absl::StrJoin(unsupported_ops, ", ");
+}
+
+void MakeGeneralGraphTransformationsSet(
+ GraphTransformationsSet* transformations) {
+ CHECK(transformations->empty());
+ transformations->Add(new ResolveReshapeAttributes);
+ transformations->Add(new PropagateArrayDataTypes);
+ transformations->Add(new PropagateFixedSizes);
+ transformations->Add(new RemoveTensorFlowAssert);
+ transformations->Add(new RemoveTensorFlowIdentity);
+ transformations->Add(new RemoveTrivialConcatenation);
+ transformations->Add(new RemoveTrivialConcatenationInput);
+ transformations->Add(new RemoveUnusedOp);
+ transformations->Add(new EnsureBiasVectors);
+ transformations->Add(new ResolveReorderAxes);
+ transformations->Add(new ResolveTensorFlowMatMul);
+ transformations->Add(new FuseBinaryIntoPrecedingAffine);
+ transformations->Add(new FuseBinaryIntoFollowingAffine);
+ transformations->Add(new ResolveBatchNormalization);
+ transformations->Add(new ResolveConstantBinaryOperator);
+ transformations->Add(new ResolveConstantUnaryOperator);
+ transformations->Add(new ResolveTensorFlowMerge);
+ transformations->Add(new ResolveTensorFlowSqueeze);
+ transformations->Add(new ResolveTensorFlowSwitch);
+ transformations->Add(new ResolveTensorFlowTile);
+ transformations->Add(new ResolveTensorFlowConcat);
+ transformations->Add(new IdentifyL2Normalization);
+ transformations->Add(new IdentifyL2Pool);
+ transformations->Add(new IdentifyRelu1);
+ transformations->Add(new RemoveTrivialBinaryOperator);
+ transformations->Add(new ReadFakeQuantMinMax);
+ transformations->Add(new ResolvePadAttributes);
+ transformations->Add(new ResolveStridedSliceAttributes);
+ transformations->Add(new ResolveSliceAttributes);
+ transformations->Add(new ResolveMeanAttributes);
+ transformations->Add(new ResolveConstantTensorFlowShape);
+ transformations->Add(new MakeInitialDequantizeOperator);
+}
+
+void SetArrayFinalDataTypes(const TocoFlags& toco_flags, Model* model) {
+ const bool output_is_tflite = toco_flags.output_format() == TFLITE;
+
+ if (output_is_tflite) {
+ if (!toco_flags.input_types().empty()) {
+ for (int i = 0; i < model->flags.input_arrays_size(); i++) {
+ int input_types_index = toco_flags.input_types_size() == 1 ? 0 : i;
+ const auto input_type = toco_flags.input_types(input_types_index);
+ ArrayDataType final_data_type = ArrayDataType::kNone;
+ switch (input_type) {
+ case FLOAT:
+ final_data_type = ArrayDataType::kFloat;
+ break;
+ case QUANTIZED_UINT8:
+ final_data_type = ArrayDataType::kUint8;
+ break;
+ case INT32:
+ final_data_type = ArrayDataType::kInt32;
+ break;
+ case INT64:
+ final_data_type = ArrayDataType::kInt64;
+ break;
+ default:
+ LOG(FATAL) << "Unknown data type";
+ }
+ model->arrays[model->flags.input_arrays(i).name()]->final_data_type =
+ final_data_type;
+ }
+ }
+ } else {
+ for (int i = 0; i < model->flags.input_arrays_size(); i++) {
+ model->arrays[model->flags.input_arrays(i).name()]->final_data_type =
+ ArrayDataType::kFloat;
+ }
+ }
+}
+
+} // namespace
+
+std::unique_ptr<Model> Import(const TocoFlags& toco_flags,
+ const ModelFlags& model_flags,
+ const string& input_file_contents) {
+ std::unique_ptr<Model> model;
+ switch (toco_flags.input_format()) {
+ case TENSORFLOW_GRAPHDEF:
+ model = ImportTensorFlowGraphDef(model_flags, input_file_contents);
+ break;
+ case TFLITE:
+ model = toco::tflite::Import(model_flags, input_file_contents);
+ ResolveModelFlags(model_flags, model.get());
+ CheckInvariants(*model);
+ break;
+ default:
+ LOG(FATAL) << "Unhandled input_format";
+ }
+
+ LogDump(kLogLevelModelChanged, "AT IMPORT", *model);
+
+ return model;
+}
+
+void Transform(const TocoFlags& toco_flags, Model* model) {
+ const FileFormat output_format = toco_flags.output_format();
+ const IODataType inference_type = toco_flags.inference_type();
+
+ const bool output_is_tflite = output_format == TFLITE;
+
+ const bool output_is_tflite_quantized =
+ output_is_tflite && inference_type == QUANTIZED_UINT8;
+
+ if (output_is_tflite) {
+ QCHECK(toco_flags.input_types_size() == 1 ||
+ toco_flags.input_types_size() == model->flags.input_arrays_size())
+ << "Mismatched numbers of input_arrays and input_types";
+ }
+
+ if (output_is_tflite_quantized) {
+ for (const auto& input_type : toco_flags.input_types()) {
+ QCHECK_NE(input_type, FLOAT)
+ << "Quantized inference is not allowed with float inputs.";
+ }
+ }
+
+ SetArrayFinalDataTypes(toco_flags, model);
+
+ GraphTransformationsSet transformations;
+ MakeGeneralGraphTransformationsSet(&transformations);
+ auto* remove_trivial_reshape = new RemoveTrivialReshape;
+ transformations.Add(remove_trivial_reshape);
+ if (output_format == TFLITE) {
+ transformations.Add(new FuseActivationFunctions);
+ } else {
+ transformations.Add(new UnfuseActivationFunctions);
+ }
+ if (output_format != TENSORFLOW_GRAPHDEF) {
+ transformations.Add(new ResolveConstantFakeQuant);
+ }
+ if (toco_flags.drop_fake_quant()) {
+ transformations.Add(new DropFakeQuant);
+ } else {
+ // See the doc for --reorder_across_fake_quant: that flag is needed to
+ // support some existing models, e.g. WordLens, that have FakeQuant
+ // nodes in the wrong places.
+ // We currently unconditionally enable that behavior when the output
+ // format is DarwiNN because the DarwiNN test code does not make it
+ // easy to pass a new toco flag. Once that is resolved on the DarwiNN
+ // tests side, the special-casing of DarwiNN here can go away.
+ // TODO(benoitjacob): so drop it when we can.
+ if ((output_is_tflite_quantized &&
+ toco_flags.reorder_across_fake_quant())) {
+ transformations.Add(new DropFakeQuant);
+ }
+ }
+ transformations.Add(new ConvertPureConvToDepthwise);
+ // TFLite export does not yet support fused LSTM cell.
+ if (output_format == TENSORFLOW_GRAPHDEF) {
+ transformations.Add(new IdentifyLstmCell);
+ }
+ transformations.Add(new ResolveConstantConcatenation);
+ RunGraphTransformations(model, "general graph transformations",
+ transformations);
+ if (output_is_tflite_quantized) {
+ RunGraphTransformations(model, "pre-quantization graph transformations",
+ {new HardcodeMinMax, new DropFakeQuant});
+ }
+
+ if (output_is_tflite_quantized) {
+ if (toco_flags.has_default_ranges_min() &&
+ toco_flags.has_default_ranges_max()) {
+ UseDefaultMinMaxRangeValues(model, toco_flags.default_ranges_min(),
+ toco_flags.default_ranges_max());
+ }
+ CheckIsReadyForQuantization(*model);
+ RunGraphTransformations(
+ model, "quantization graph transformations",
+ {new Quantize, new RemoveTrivialQuantizedActivationFunc,
+ new RemoveFinalDequantizeOp});
+ } else {
+ GraphTransformationsSet dequantization_transformations{new Dequantize};
+ // Dequantize creates FakeQuant nodes. We may want to discard
+ // those immediately.
+ if (toco_flags.drop_fake_quant()) {
+ dequantization_transformations.Add(new DropFakeQuant);
+ }
+
+ RunGraphTransformations(model, "dequantization graph transformations",
+ dequantization_transformations);
+ }
+
+ LogDump(kLogLevelModelChanged, "AFTER TRANSFORMATIONS", *model);
+
+ if (output_format != GRAPHVIZ_DOT && output_format != TFLITE) {
+ // By now there shouldn't be any unsupported ops when exporting to
+ // TensorFlow GraphDef.
+ CheckUnsupportedOperations(*model);
+ }
+
+ if (output_is_tflite) {
+ AllocateTransientArrays(model, kDefaultTransientDataAlignment);
+ LogDump(kLogLevelModelChanged, "AFTER ALLOCATION", *model);
+ }
+
+ CheckModelCounts(*model);
+ CheckFinalDataTypesSatisfied(*model);
+
+ int64 ops_count;
+ if (EstimateArithmeticOpsCount(*model, &ops_count)) {
+ LOG(INFO) << "Estimated count of arithmetic ops: " << 1e-9 * ops_count
+ << " billion (note that a multiply-add is counted as 2 ops).";
+ }
+}
+
+void Export(const TocoFlags& toco_flags, const Model& model,
+ bool allow_custom_ops, string* output_file_contents) {
+ switch (toco_flags.output_format()) {
+ case TENSORFLOW_GRAPHDEF:
+ ExportTensorFlowGraphDef(model, output_file_contents);
+ break;
+ case TFLITE:
+ toco::tflite::Export(model, allow_custom_ops, output_file_contents);
+ break;
+ case GRAPHVIZ_DOT:
+ DumpGraphviz(model, output_file_contents);
+ break;
+ default:
+ LOG(FATAL) << "Unhandled output_format";
+ }
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/toco_tooling.h b/tensorflow/contrib/lite/toco/toco_tooling.h
new file mode 100644
index 0000000000..9c5a93a211
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/toco_tooling.h
@@ -0,0 +1,50 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_TOOLING_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_TOOLING_H_
+
+#include <memory>
+#include <string>
+
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/model_flags.pb.h"
+#include "tensorflow/contrib/lite/toco/toco_flags.pb.h"
+
+namespace toco {
+
+// Imports the input file into a Model object.
+std::unique_ptr<Model> Import(const TocoFlags& toco_flags,
+ const ModelFlags& model_flags,
+ const string& input_file_contents);
+
+// Transforms a Model. The resulting Model is ready to be passed
+// to Export with the exact same toco_flags.
+void Transform(const TocoFlags& toco_flags, Model* model);
+
+// Exports the Model, which must be of the 'lowered' form returned by
+// Transform, to a file of the format given by
+// toco_flags.output_format().
+void Export(const TocoFlags& toco_flags, const Model& model,
+ bool allow_custom_ops, string* output_file_contents);
+
+// This is for backward-compatibility with internal tools.
+inline void Export(const TocoFlags& toco_flags, const Model& model,
+ string* output_file_contents) {
+ Export(toco_flags, model, true, output_file_contents);
+}
+
+} // namespace toco
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_TOOLING_H_
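Taken together, the three entry points compose the same pipeline that ToolMain in toco.cc drives from the command line. A hedged sketch of using them as a library (the flag values shown are one valid combination, and ModelFlags setup such as input/output array names is omitted):

    #include <memory>
    #include <string>

    #include "tensorflow/contrib/lite/toco/model.h"
    #include "tensorflow/contrib/lite/toco/model_flags.pb.h"
    #include "tensorflow/contrib/lite/toco/toco_flags.pb.h"
    #include "tensorflow/contrib/lite/toco/toco_tooling.h"

    // Sketch: convert an in-memory GraphDef to a float TFLite flatbuffer.
    std::string ConvertGraphDefToTfLite(const std::string& graphdef_contents,
                                        const toco::ModelFlags& model_flags) {
      toco::TocoFlags toco_flags;
      toco_flags.set_input_format(toco::TENSORFLOW_GRAPHDEF);
      toco_flags.set_output_format(toco::TFLITE);
      toco_flags.add_input_types(toco::FLOAT);
      toco_flags.set_inference_type(toco::FLOAT);

      std::unique_ptr<toco::Model> model =
          toco::Import(toco_flags, model_flags, graphdef_contents);
      toco::Transform(toco_flags, model.get());

      std::string output;
      toco::Export(toco_flags, *model, toco_flags.allow_custom_ops(), &output);
      return output;
    }
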
diff --git a/tensorflow/contrib/lite/toco/toco_types.h b/tensorflow/contrib/lite/toco/toco_types.h
new file mode 100644
index 0000000000..ad42497ada
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/toco_types.h
@@ -0,0 +1,45 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TYPES_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TYPES_H_
+
+#include <string>
+#include "tensorflow/core/platform/platform.h"
+
+#if defined(PLATFORM_GOOGLE) || defined(GOOGLE_INTEGRAL_TYPES)
+#include "tensorflow/core/platform/google/integral_types.h"
+#else
+#include "tensorflow/core/platform/default/integral_types.h"
+#endif
+
+namespace toco {
+#ifdef PLATFORM_GOOGLE
+using ::string;
+#else
+using std::string;
+#endif
+
+using tensorflow::int16;
+using tensorflow::int32;
+using tensorflow::int64;
+using tensorflow::int8;
+using tensorflow::uint16;
+using tensorflow::uint32;
+using tensorflow::uint64;
+using tensorflow::uint8;
+
+} // namespace toco
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TYPES_H_
diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc
new file mode 100644
index 0000000000..bcbfed62d3
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/tooling_util.cc
@@ -0,0 +1,1552 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+
+#include <functional>
+#include <iterator>
+#include <set>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
+#include "absl/strings/str_replace.h"
+#include "tensorflow/contrib/lite/toco/dump_graphviz.h"
+#include "tensorflow/contrib/lite/toco/model_flags.pb.h"
+#include "tensorflow/contrib/lite/toco/toco_graphviz_dump_options.h"
+#include "tensorflow/contrib/lite/toco/toco_port.h"
+#include "tensorflow/core/platform/logging.h"
+
+
+namespace toco {
+
+string LogName(const Operator& op) {
+ const string& opname = HelpfulOperatorTypeName(op);
+ if (op.outputs.empty()) {
+ return toco::port::StringF("{%s operator}", opname);
+ } else {
+ return toco::port::StringF("{%s operator with output %s}", opname,
+ op.outputs[0]);
+ }
+}
+
+bool IsInputArray(const Model& model, const string& name) {
+ for (const auto& input_array : model.flags.input_arrays()) {
+ if (input_array.name() == name) {
+ return true;
+ }
+ }
+ return false;
+}
+
+bool IsArrayConsumed(const Model& model, const string& name) {
+ if (GetOpWithInput(model, name)) {
+ return true;
+ }
+ for (const string& model_output : model.flags.output_arrays()) {
+ if (model_output == name) {
+ return true;
+ }
+ }
+ for (const auto& rnn_state : model.flags.rnn_states()) {
+ if (rnn_state.back_edge_source_array() == name) {
+ return true;
+ }
+ }
+ return false;
+}
+
+int CountTrueOutputs(const Model& model, const Operator& op) {
+ int count = 0;
+ for (const string& output : op.outputs) {
+ if (IsArrayConsumed(model, output)) {
+ ++count;
+ }
+ }
+ return count;
+}
+
+int CountOpsWithInput(const Model& model, const string& array_name) {
+ int count = 0;
+ for (const auto& op : model.operators) {
+ for (auto& input : op->inputs) {
+ if (input == array_name) {
+ count++;
+ }
+ }
+ }
+ return count;
+}
+
+bool DeleteArrayIfUnused(const string& array_name, Model* model) {
+ if (CountOpsWithInput(*model, array_name) == 0) {
+ model->arrays.erase(array_name);
+ return true;
+ }
+ return false;
+}
+
+std::vector<std::unique_ptr<Operator>>::const_iterator FindOpWithOutput(
+ const Model& model, const string& array_name) {
+ for (auto it = model.operators.begin(); it != model.operators.end(); ++it) {
+ for (auto& output : it->get()->outputs) {
+ if (output == array_name) {
+ return it;
+ }
+ }
+ }
+ return model.operators.end();
+}
+
+std::vector<std::unique_ptr<Operator>>::iterator FindOpWithOutput(
+ Model& model, const string& array_name) {
+ for (auto it = model.operators.begin(); it != model.operators.end(); ++it) {
+ for (auto& output : it->get()->outputs) {
+ if (output == array_name) {
+ return it;
+ }
+ }
+ }
+ return model.operators.end();
+}
+
+Operator* GetOpWithOutput(const Model& model, const string& array_name) {
+ auto it = FindOpWithOutput(model, array_name);
+ return it == model.operators.end() ? nullptr : it->get();
+}
+
+// Returns the first op with the given input; GetFirstOpWithInput relies on this.
+std::vector<std::unique_ptr<Operator>>::const_iterator FindOpWithInput(
+ const Model& model, const string& array_name) {
+ for (auto it = model.operators.begin(); it != model.operators.end(); ++it) {
+ for (auto& input : it->get()->inputs) {
+ if (input == array_name) {
+ return it;
+ }
+ }
+ }
+ return model.operators.end();
+}
+
+std::vector<std::unique_ptr<Operator>>::const_iterator FindOp(
+ const Model& model, const Operator* op) {
+ for (auto it = model.operators.begin(); it != model.operators.end(); ++it) {
+ if (it->get() == op) {
+ return it;
+ }
+ }
+ return model.operators.end();
+}
+
+std::vector<std::unique_ptr<Operator>>::iterator FindOp(Model& model,
+ const Operator* op) {
+ for (auto it = model.operators.begin(); it != model.operators.end(); ++it) {
+ if (it->get() == op) {
+ return it;
+ }
+ }
+ return model.operators.end();
+}
+
+Operator* GetOpWithInput(const Model& model, const string& array_name) {
+ auto it = FindOpWithInput(model, array_name);
+ return it == model.operators.end() ? nullptr : it->get();
+}
+
+Operator* GetFirstOpWithInput(const Model& model, const string& array_name) {
+ auto it = FindOpWithInput(model, array_name);
+ return it == model.operators.end() ? nullptr : it->get();
+}
+
+string FormatArraysList(const Model& model, const std::vector<string>& list) {
+ if (list.empty()) {
+ return "[]";
+ }
+ string result = "";
+ if (list.size() > 1) {
+ result += "[ ";
+ }
+ for (std::size_t i = 0; i < list.size(); i++) {
+ if (i > 0) {
+ result += ", ";
+ }
+ result += list[i];
+ }
+ if (list.size() > 1) {
+ result += " ]";
+ }
+ return result;
+}
+
+const char* OperatorTypeName(OperatorType type) {
+ switch (type) {
+#define HANDLE_OPERATORTYPENAME_CASE(c) \
+ case OperatorType::k##c: \
+ return #c;
+ HANDLE_OPERATORTYPENAME_CASE(Add)
+ HANDLE_OPERATORTYPENAME_CASE(AveragePool)
+ HANDLE_OPERATORTYPENAME_CASE(BatchNormalization)
+ HANDLE_OPERATORTYPENAME_CASE(Conv)
+ HANDLE_OPERATORTYPENAME_CASE(Concatenation)
+ HANDLE_OPERATORTYPENAME_CASE(DepthwiseConv)
+ HANDLE_OPERATORTYPENAME_CASE(DepthToSpace)
+ HANDLE_OPERATORTYPENAME_CASE(SpaceToDepth)
+ HANDLE_OPERATORTYPENAME_CASE(FullyConnected)
+ HANDLE_OPERATORTYPENAME_CASE(Dequantize)
+ HANDLE_OPERATORTYPENAME_CASE(L2Normalization)
+ HANDLE_OPERATORTYPENAME_CASE(LocalResponseNormalization)
+ HANDLE_OPERATORTYPENAME_CASE(Logistic)
+ HANDLE_OPERATORTYPENAME_CASE(LstmCell)
+ HANDLE_OPERATORTYPENAME_CASE(MaxPool)
+ HANDLE_OPERATORTYPENAME_CASE(L2Pool)
+ HANDLE_OPERATORTYPENAME_CASE(FakeQuant)
+ HANDLE_OPERATORTYPENAME_CASE(Mul)
+ HANDLE_OPERATORTYPENAME_CASE(Relu)
+ HANDLE_OPERATORTYPENAME_CASE(Relu1)
+ HANDLE_OPERATORTYPENAME_CASE(Relu6)
+ HANDLE_OPERATORTYPENAME_CASE(ReorderAxes)
+ HANDLE_OPERATORTYPENAME_CASE(Softmax)
+ HANDLE_OPERATORTYPENAME_CASE(Div)
+ HANDLE_OPERATORTYPENAME_CASE(Tanh)
+ HANDLE_OPERATORTYPENAME_CASE(TensorFlowAll)
+ HANDLE_OPERATORTYPENAME_CASE(TensorFlowAssert)
+ HANDLE_OPERATORTYPENAME_CASE(TensorFlowGreater)
+ HANDLE_OPERATORTYPENAME_CASE(TensorFlowGreaterEqual)
+ HANDLE_OPERATORTYPENAME_CASE(TensorFlowIdentity)
+ HANDLE_OPERATORTYPENAME_CASE(TensorFlowLess)
+ HANDLE_OPERATORTYPENAME_CASE(TensorFlowLessEqual)
+ HANDLE_OPERATORTYPENAME_CASE(TensorFlowMatMul)
+ HANDLE_OPERATORTYPENAME_CASE(TensorFlowMax)
+ HANDLE_OPERATORTYPENAME_CASE(TensorFlowMaximum)
+ HANDLE_OPERATORTYPENAME_CASE(TensorFlowMerge)
+ HANDLE_OPERATORTYPENAME_CASE(TensorFlowMin)
+ HANDLE_OPERATORTYPENAME_CASE(TensorFlowMinimum)
+ HANDLE_OPERATORTYPENAME_CASE(Pad)
+ HANDLE_OPERATORTYPENAME_CASE(StridedSlice)
+ HANDLE_OPERATORTYPENAME_CASE(TensorFlowReshape)
+ HANDLE_OPERATORTYPENAME_CASE(Squeeze)
+ HANDLE_OPERATORTYPENAME_CASE(TensorFlowRsqrt)
+ HANDLE_OPERATORTYPENAME_CASE(TensorFlowShape)
+ HANDLE_OPERATORTYPENAME_CASE(Slice)
+ HANDLE_OPERATORTYPENAME_CASE(TensorFlowSplit)
+ HANDLE_OPERATORTYPENAME_CASE(TensorFlowSqrt)
+ HANDLE_OPERATORTYPENAME_CASE(TensorFlowSquare)
+ HANDLE_OPERATORTYPENAME_CASE(TensorFlowSwitch)
+ HANDLE_OPERATORTYPENAME_CASE(Sub)
+ HANDLE_OPERATORTYPENAME_CASE(TensorFlowSum)
+ HANDLE_OPERATORTYPENAME_CASE(TensorFlowTile)
+ HANDLE_OPERATORTYPENAME_CASE(TensorFlowConcat)
+ HANDLE_OPERATORTYPENAME_CASE(TensorFlowConcatV2)
+ HANDLE_OPERATORTYPENAME_CASE(Cast)
+ HANDLE_OPERATORTYPENAME_CASE(Floor)
+ HANDLE_OPERATORTYPENAME_CASE(Gather)
+ HANDLE_OPERATORTYPENAME_CASE(ResizeBilinear)
+ HANDLE_OPERATORTYPENAME_CASE(SpaceToBatchND)
+ HANDLE_OPERATORTYPENAME_CASE(BatchToSpaceND)
+ HANDLE_OPERATORTYPENAME_CASE(Mean)
+ HANDLE_OPERATORTYPENAME_CASE(Svdf)
+ HANDLE_OPERATORTYPENAME_CASE(TensorFlowUnsupported)
+ default:
+ LOG(FATAL) << "Unhandled op type";
+#undef HANDLE_OPERATORTYPENAME_CASE
+ }
+}
+
+string HelpfulOperatorTypeName(const Operator& op) {
+ if (op.type == OperatorType::kTensorFlowUnsupported) {
+ return toco::port::StringF(
+ "(Unsupported TensorFlow op: %s)",
+ static_cast<const TensorFlowUnsupportedOperator&>(op).tensorflow_op);
+ }
+ return OperatorTypeName(op.type);
+}
+
+void LogSummary(int log_level, const Model& model) {
+ VLOG(log_level) << "Operators summary (" << model.operators.size()
+ << " operators): ";
+ std::unordered_multiset<OperatorType> ops_by_type;
+ for (const auto& op : model.operators) {
+ ops_by_type.insert(op->type);
+ }
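+ // Equal keys are stored adjacently in an unordered_multiset, so advancing
+ // the iterator by count(*it) below jumps to the next distinct operator type.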
+ auto it = ops_by_type.begin();
+ while (it != ops_by_type.end()) {
+ int count = ops_by_type.count(*it);
+ VLOG(log_level) << " " << OperatorTypeName(*it) << ": " << count;
+ std::advance(it, count);
+ }
+}
+
+void LogArray(int log_level, const Model& model, const string& name) {
+ const auto& array = model.GetArray(name);
+ VLOG(log_level) << "Array: " << name;
+ switch (array.data_type) {
+ case ArrayDataType::kNone:
+ break;
+ case ArrayDataType::kFloat:
+ VLOG(log_level) << " Data type: kFloat";
+ break;
+ case ArrayDataType::kInt32:
+ VLOG(log_level) << " Data type: kInt32";
+ break;
+ case ArrayDataType::kUint8:
+ VLOG(log_level) << " Data type: kUint8";
+ break;
+ default:
+ VLOG(log_level) << " Data type: other (numerical value: "
+ << static_cast<int>(array.data_type) << ")";
+ break;
+ }
+ if (array.buffer) {
+ VLOG(log_level) << " Constant Buffer";
+ }
+ if (array.alloc) {
+ VLOG(log_level) << " Transient Alloc";
+ }
+ if (array.has_shape()) {
+ const Shape& array_shape = array.shape();
+ if (array_shape.dimensions_count() == 0) {
+ VLOG(log_level) << " (Zero dimensions)";
+ } else {
+ string message = " Dims: ";
+ bool first = true;
+ for (const int dim : array_shape.dims()) {
+ if (!first) {
+ message += ", ";
+ }
+ first = false;
+ toco::port::AppendF(&message, "%d", dim);
+ }
+ VLOG(log_level) << message;
+ }
+ }
+ if (array.minmax) {
+ VLOG(log_level) << " MinMax: " << array.minmax->min << " .. "
+ << array.minmax->max;
+ }
+ if (array.quantization_params) {
+ VLOG(log_level) << " QuantizationParams: zero_point="
+ << array.quantization_params->zero_point
+ << ", scale=" << array.quantization_params->scale;
+ }
+}
+
+void DumpGraphvizVideoFrame(const Model& model) {
+ namespace port = toco::port;
+
+ const auto& dump_options = *GraphVizDumpOptions::singleton();
+ if (!dump_options.dump_graphviz_video) {
+ return;
+ }
+ CHECK(!dump_options.dump_graphviz.empty());
+ // TODO(benoitjacob): the static data here means that this function
+ // is stateful, not reentrant, and effectively leaks memory till exit
+ // (since dump_hashes can only grow in size). It also means that it
+ // really only is intended to be called for a single model during the
+ // process' lifetime. So it's not great design at all. The overriding
+ // design aspect here is to make the video-dumping code as unintrusive
+ // and self-contained as possible. Eventually, we'll want to have that
+ // cleaned-up, but that will require some form of general statefulness
+ // in toco (some kind of 'tooling state' data structure) that does
+ // not exist at present, and would be premature to design here just for
+ // this new video-dumping feature.
+ static int dump_id = 0;
+ static std::unordered_set<std::size_t> dump_hashes;
+ string graphviz_dump;
+ DumpGraphviz(model, &graphviz_dump);
+ std::size_t hash = std::hash<string>{}(graphviz_dump);
+ if (!dump_hashes.count(hash)) {
+ dump_hashes.insert(hash);
+ CHECK(port::file::SetContents(
+ port::file::JoinPath(
+ dump_options.dump_graphviz,
+ toco::port::StringF("toco_video_%05d.dot", dump_id)),
+ graphviz_dump, port::file::Defaults())
+ .ok());
+ dump_id++;
+ }
+}
+
+void LogDump(int log_level, const string& message, const Model& model) {
+ namespace port = toco::port;
+ const auto& dump_options = *GraphVizDumpOptions::singleton();
+
+ DumpGraphvizVideoFrame(model);
+ if (!dump_options.dump_graphviz.empty()) {
+ string graphviz_dump;
+
+ DumpGraphviz(model, &graphviz_dump);
+ CHECK(port::file::SetContents(
+ port::file::JoinPath(
+ dump_options.dump_graphviz,
+ absl::StrCat("toco_",
+ absl::StrReplaceAll(message, {{" ", "_"}}),
+ ".dot")),
+ graphviz_dump, port::file::Defaults())
+ .ok());
+ }
+
+ if (!VLOG_IS_ON(log_level)) {
+ return;
+ }
+ VLOG(log_level) << "BEGIN DUMP OF TOCO MODEL (" << message << ")";
+ LogSummary(log_level, model);
+ std::unordered_set<string> already_printed_arrays;
+ for (const auto& op : model.operators) {
+ for (const auto& input : op->inputs) {
+ if (!already_printed_arrays.count(input)) {
+ already_printed_arrays.insert(input);
+ LogArray(log_level, model, input);
+ }
+ }
+ VLOG(log_level) << HelpfulOperatorTypeName(*op) << " : ";
+ VLOG(log_level) << " " << FormatArraysList(model, op->inputs) << " -> "
+ << FormatArraysList(model, op->outputs);
+ if (op->fused_activation_function != FusedActivationFunctionType::kNone) {
+ VLOG(log_level) << " (with fused activation function)";
+ }
+ for (const auto& output : op->outputs) {
+ if (!already_printed_arrays.count(output)) {
+ already_printed_arrays.insert(output);
+ LogArray(log_level, model, output);
+ }
+ }
+ }
+ VLOG(log_level) << "END DUMP OF TOCO MODEL (" << message << ")";
+}
+
+// Note remaining raw-array extension in ProcessTensorFlowReshapeOperator().
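+// Pads 'shape' with leading 1s up to new_shape_size, e.g. {3, 4} becomes
+// {1, 1, 3, 4} when new_shape_size is 4.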
+void ExtendShape(Shape* shape, int new_shape_size) {
+ CHECK_GE(new_shape_size, shape->dimensions_count());
+ const int size_increase = new_shape_size - shape->dimensions_count();
+ auto* shape_dims = shape->mutable_dims();
+ shape_dims->insert(shape_dims->begin(), size_increase, 1);
+}
+
+// TODO(b/62904716) Remove along with remaining uses.
+void UnextendShape(Shape* shape, int new_shape_size) {
+ CHECK_LE(new_shape_size, shape->dimensions_count());
+ const int size_reduction = shape->dimensions_count() - new_shape_size;
+ for (int i = 0; i < size_reduction; i++) {
+ CHECK_EQ(shape->dims(i), 1);
+ }
+ std::vector<int>& shape_dims = *shape->mutable_dims();
+ shape_dims.erase(shape_dims.begin(), shape_dims.begin() + size_reduction);
+}
+
+void CheckShapeDimensions(const Shape& shape) {
+ for (int i = 0; i < shape.dimensions_count(); ++i) {
+ CHECK_GE(shape.dims()[i], 1) << "shape has dimension 0 at index " << i
+ << ". shape = " << ShapeToString(shape);
+ }
+}
+
+bool ShapesAgreeUpToBroadcasting(const Shape& shape0, const Shape& shape1) {
+ CheckShapeDimensions(shape0);
+ CheckShapeDimensions(shape1);
+
+ const Shape* longer = &shape0;
+ const Shape* shorter = &shape1;
+ if (shape1.dimensions_count() > shape0.dimensions_count()) {
+ longer = &shape1;
+ shorter = &shape0;
+ }
+
+ // Walk dimensions back to front until we run out of dimensions in the shorter
+ // shape.
+ int longer_index = longer->dimensions_count() - 1;
+ int shorter_index = shorter->dimensions_count() - 1;
+ while (shorter_index >= 0) {
+ const int d_long = longer->dims(longer_index);
+ const int d_short = shorter->dims(shorter_index);
+ // Broadcasting fails if the dimensions are different *and* neither is 1.
+ if ((d_long != d_short) && (d_long != 1) && (d_short != 1)) {
+ return false;
+ }
+ longer_index--;
+ shorter_index--;
+ }
+ return true;
+}
+
+bool ShapesAgreeUpToExtending(const Shape& shape0, const Shape& shape1) {
+ CheckShapeDimensions(shape0);
+ CheckShapeDimensions(shape1);
+
+ const Shape* longer = &shape0;
+ const Shape* shorter = &shape1;
+ if (shape1.dimensions_count() > shape0.dimensions_count()) {
+ longer = &shape1;
+ shorter = &shape0;
+ }
+
+ // Walk dimensions back to front until we run out of dimensions in the shorter
+ // shape.
+ int longer_index = longer->dimensions_count() - 1;
+ int shorter_index = shorter->dimensions_count() - 1;
+ while (shorter_index >= 0) {
+ const int d_long = longer->dims(longer_index);
+ const int d_short = shorter->dims(shorter_index);
+ // Extending fails if the dimensions are different.
+ if (d_long != d_short) {
+ return false;
+ }
+ longer_index--;
+ shorter_index--;
+ }
+
+ // The remaining dimensions in the longer shape must be 1.
+ while (longer_index >= 0) {
+ const int d_long = longer->dims(longer_index);
+ if (d_long != 1) {
+ return false;
+ }
+ longer_index--;
+ }
+
+ return true;
+}
+
+int RequiredBufferSizeForShape(const Shape& shape) {
+ int max_offset = 1;
+ for (const auto& dim : shape.dims()) {
+ CHECK_GE(dim, 1);
+ max_offset *= dim;
+ }
+ return max_offset;
+}
+
+bool IsConstantParameterArray(const Model& model, const string& name) {
+ if (!model.arrays.count(name)) {
+ return false;
+ }
+
+ return !!model.arrays.at(name)->buffer;
+}
+
+void CheckNoMissingArray(const Model& model) {
+ for (const auto& op : model.operators) {
+ for (const auto& input : op->inputs) {
+ CHECK(model.arrays.count(input));
+ }
+ for (const auto& output : op->outputs) {
+ CHECK(model.arrays.count(output));
+ }
+ }
+ for (const auto& input_array : model.flags.input_arrays()) {
+ CHECK(model.arrays.count(input_array.name()))
+ << "Input array not found: " << input_array.name();
+ }
+ for (const string& output_array : model.flags.output_arrays()) {
+ CHECK(model.arrays.count(output_array))
+ << "Output array not found: " << output_array;
+ }
+ for (const auto& rnn_state : model.flags.rnn_states()) {
+ CHECK(model.arrays.count(rnn_state.state_array()));
+ CHECK(model.arrays.count(rnn_state.back_edge_source_array()));
+ }
+}
+
+void FixNoMissingArray(Model* model) {
+ for (const auto& op : model->operators) {
+ for (const auto& input : op->inputs) {
+ if (!model->arrays.count(input)) {
+ model->GetOrCreateArray(input);
+ }
+ }
+ for (const auto& output : op->outputs) {
+ if (!model->arrays.count(output)) {
+ model->GetOrCreateArray(output);
+ }
+ }
+ }
+ for (const string& output_array : model->flags.output_arrays()) {
+ if (!model->arrays.count(output_array)) {
+ model->GetOrCreateArray(output_array);
+ }
+ }
+}
+
+void CheckNoOrphanedArray(const Model& model) {
+ std::unordered_set<string> arrays_without_known_use;
+ for (const auto& array : model.arrays) {
+ arrays_without_known_use.insert(array.first);
+ }
+ for (const auto& op : model.operators) {
+ for (const auto& input : op->inputs) {
+ arrays_without_known_use.erase(input);
+ }
+ for (const auto& output : op->outputs) {
+ arrays_without_known_use.erase(output);
+ }
+ }
+ if (!arrays_without_known_use.empty()) {
+ for (const auto& array : arrays_without_known_use) {
+ LOG(INFO) << "Error: Orphaned array: " << array;
+ }
+ }
+ CHECK(arrays_without_known_use.empty());
+}
+
+void FixNoOrphanedArray(Model* model) {
+ std::unordered_set<string> arrays_without_known_use;
+ for (const auto& array : model->arrays) {
+ arrays_without_known_use.insert(array.first);
+ }
+ for (const auto& op : model->operators) {
+ for (const auto& input : op->inputs) {
+ arrays_without_known_use.erase(input);
+ }
+ for (const auto& output : op->outputs) {
+ arrays_without_known_use.erase(output);
+ }
+ }
+ for (const auto& array : arrays_without_known_use) {
+ model->arrays.erase(array);
+ }
+}
+
+void CheckArrayFieldsConsistent(const Model& model) {
+ for (const auto& array_entry : model.arrays) {
+ const auto& array = array_entry.second;
+ if (array->has_shape()) {
+ for (int d : array->shape().dims()) {
+ CHECK_GE(d, 1);
+ }
+ }
+ // It's OK to have a buffer or an alloc, but not both.
+ // (Since allocs are for transient arrays without a buffer).
+ CHECK(!array->buffer || !array->alloc);
+ // If there is a buffer, its type should be consistent with data_type.
+ if (array->buffer) {
+ CHECK(array->buffer->type == array->data_type);
+ }
+ }
+}
+
+void CheckOperatorOrdering(const Model& model) {
+ std::unordered_set<string> arrays_behind_us;
+ for (const auto& array_entry : model.arrays) {
+ if (!GetOpWithOutput(model, array_entry.first)) {
+ arrays_behind_us.insert(array_entry.first);
+ }
+ }
+ for (const auto& op : model.operators) {
+ for (const auto& input : op->inputs) {
+ if (!IsConstantParameterArray(model, input)) {
+ CHECK(arrays_behind_us.count(input));
+ }
+ }
+ for (const auto& output : op->outputs) {
+ CHECK(!arrays_behind_us.count(output));
+ arrays_behind_us.insert(output);
+ }
+ }
+ for (const string& output_array : model.flags.output_arrays()) {
+ CHECK(arrays_behind_us.count(output_array));
+ }
+}
+
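+// Greedily re-inserts operators into model->operators in an order where each
+// operator appears after the producers of all of its non-constant inputs
+// (i.e. a topological ordering), and diagnoses the offending subgraph with a
+// fatal error if no such ordering exists.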
+void FixOperatorOrdering(Model* model) {
+ std::unordered_set<string> arrays_behind_us;
+ for (const auto& array_entry : model->arrays) {
+ if (!GetOpWithOutput(*model, array_entry.first)) {
+ arrays_behind_us.insert(array_entry.first);
+ }
+ }
+ std::vector<std::unique_ptr<Operator>> old_operators;
+ std::swap(old_operators, model->operators);
+ std::set<std::size_t> remaining;
+ for (std::size_t i = 0; i < old_operators.size(); i++) {
+ remaining.insert(i);
+ }
+ std::unordered_map<string, string> reason_why_leftover;
+ while (true) {
+ bool inserted_something = false;
+ for (auto i : remaining) {
+ bool can_insert = true;
+ auto& op = old_operators[i];
+ CHECK(op.get());
+ for (const auto& input : op->inputs) {
+ if (!IsConstantParameterArray(*model, input) &&
+ !arrays_behind_us.count(input)) {
+ for (const string& output : op->outputs) {
+ reason_why_leftover[output] = input;
+ }
+ can_insert = false;
+ break;
+ }
+ }
+ if (can_insert) {
+ model->operators.emplace_back(nullptr);
+ for (const auto& output : op->outputs) {
+ arrays_behind_us.insert(output);
+ }
+ std::swap(op, model->operators.back());
+ remaining.erase(i);
+ inserted_something = true;
+ break;
+ }
+ }
+ if (!inserted_something) {
+ break;
+ }
+ }
+ if (!remaining.empty()) {
+ LOG(ERROR)
+ << "No viable ordering of operators was found. "
+ << "Here is a 'backtrace' of at least one part of the graph that is "
+ << "problematic. It starts with the first operator that has as "
+ << "problematic input array, and then walks back the graph to "
+ << "the operator that produced that input array, etc., until we find "
+ << "the root cause:";
+ LOG(ERROR) << "BEGIN TRACE OF OPERATOR WITH BAD INPUT";
+ LOG(ERROR) << "Here is the first-encountered operator with a bad input: ";
+ const Operator* bad_op = old_operators[*remaining.begin()].get();
+ std::unordered_set<string> bad_inputs_already_traced;
+ // The following while(true) loop should always end with a LOG(FATAL).
+ while (true) {
+ LOG(ERROR) << HelpfulOperatorTypeName(*bad_op) << " : "
+ << FormatArraysList(*model, bad_op->inputs) << " -> "
+ << FormatArraysList(*model, bad_op->outputs);
+ bool found_bad_output = false;
+ string bad_output;
+ for (const string& output : bad_op->outputs) {
+ if (reason_why_leftover.count(output)) {
+ found_bad_output = true;
+ bad_output = output;
+ break;
+ }
+ }
+ CHECK(found_bad_output);
+ const string& bad_input = reason_why_leftover[bad_output];
+ LOG(ERROR) << "The bad input here is: " << bad_input;
+ if (bad_inputs_already_traced.count(bad_input)) {
+ LOG(FATAL)
+ << "Cycle found! We already encountered that "
+ << "input array, " << bad_input << ", earlier in the "
+ << "above trace! We expect graphs to be acyclic, even "
+ << "RNNs. Let us know if some graph actually needs to have "
+ << "cycles, but first, please check if it really is "
+ << "an *inference* graph. *Training* graphs are out-of-scope "
+ << "for toco.";
+ }
+ bad_inputs_already_traced.insert(bad_input);
+ bad_op = nullptr;
+ for (auto i : remaining) {
+ const Operator* op = old_operators[i].get();
+ for (const string& output : op->outputs) {
+ if (bad_input == output) {
+ bad_op = op;
+ break;
+ }
+ }
+ if (bad_op) {
+ break;
+ }
+ }
+ if (!bad_op) {
+ LOG(ERROR) << "And that's the root cause: "
+ << "that array, " << bad_input << ", isn't produced by any "
+ << "operator, or provided in any other way.";
+ LOG(ERROR) << "END TRACE OF OPERATOR WITH BAD INPUT";
+ LOG(FATAL) << "(The above was a multi-line fatal error)";
+ }
+ LOG(ERROR) << "And that array is the output of the following operator:";
+ }
+ }
+ CHECK(remaining.empty())
+ << "Should never get here! In case of bad graph, "
+ << "the above code should have generated a FATAL error already!";
+}
+
+// Checks that the --input_arrays of the Model are actually used by at least
+// one of the --output_arrays i.e. that the graph contains a path from each one
+// of the inputs to at least one of the outputs. This catches cases where the
+// user passed the wrong --input_arrays or --output_arrays, which otherwise may
+// result in cryptic error messages.
+void CheckInputUsedByOutputs(const Model& model) {
+ std::set<string> used_arrays;
+ for (const string& output : model.flags.output_arrays()) {
+ used_arrays.insert(output);
+ }
+ for (int i = model.operators.size() - 1; i >= 0; i--) {
+ bool is_op_used = false;
+ for (const string& op_output : model.operators[i]->outputs) {
+ if (used_arrays.count(op_output)) {
+ is_op_used = true;
+ break;
+ }
+ }
+ if (!is_op_used) {
+ continue;
+ }
+ for (const string& op_input : model.operators[i]->inputs) {
+ used_arrays.insert(op_input);
+ }
+ }
+ for (const auto& input_array : model.flags.input_arrays()) {
+ QCHECK(used_arrays.count(input_array.name()))
+ << "The graph does not connect the input (" << input_array.name()
+ << ") specified by --input_arrays to any of the specified "
+ << "--output_arrays ("
+ << absl::StrJoin(model.flags.output_arrays(), ", ")
+ << "). Did you pass the wrong flags for this model, "
+ << "or is that model's graph actually incomplete?";
+ }
+}
+
+void CheckInvariants(const Model& model) {
+ CheckNoMissingArray(model);
+ CheckNoOrphanedArray(model);
+ CheckArrayFieldsConsistent(model);
+ CheckOperatorOrdering(model);
+ CheckInputUsedByOutputs(model);
+}
+
+void CheckCountInRange(const ::toco::ModelFlags::ModelCheck& model_check,
+ const int count, const string& count_description) {
+ if (model_check.count_min() >= 0) {
+ CHECK_GE(count, model_check.count_min())
+ << "Mismatch in " << count_description << ": count was " << count
+ << ", but the specified "
+ << (model_check.count_max() > model_check.count_min() ? "minimum"
+ : "value")
+ << " was " << model_check.count_min() << ".";
+ }
+ if (model_check.count_max() > model_check.count_min()) {
+ CHECK_LE(count, model_check.count_max())
+ << "Mismatch in " << count_description << ": count was " << count
+ << ", but the specified maximum was " << model_check.count_max() << ".";
+ }
+}
+
+void CheckModelCounts(const Model& model) {
+ std::unordered_multiset<OperatorType> ops_by_type;
+ std::unordered_map<string, OperatorType> op_type_by_name;
+ if (model.flags.model_checks_size() == 0) {
+ return;
+ }
+
+ for (const auto& op : model.operators) {
+ ops_by_type.insert(op->type);
+ op_type_by_name[OperatorTypeName(op->type)] = op->type;
+ }
+ for (const auto& model_check : model.flags.model_checks()) {
+ string count_type = model_check.count_type();
+ if (count_type == "None") {
+ continue;
+ } else if (count_type == "Arrays") {
+ CheckCountInRange(model_check, model.arrays.size(), "count of arrays");
+ } else if (count_type == "Total") {
+ CheckCountInRange(model_check, model.operators.size(),
+ "count of all operator instances");
+ } else {
+ // The check type is not itself checked against the set of valid
+ // operators, mainly because the enum set cannot be iterated in C++.
+ const int found_count =
+ op_type_by_name.count(count_type) > 0
+ ? ops_by_type.count(op_type_by_name[count_type])
+ : 0;
+ CheckCountInRange(model_check, found_count,
+ "count of instances of " + count_type + " operator");
+ }
+ }
+}
+
+void MakeArrayDims(int num_dims, int batch, int height, int width, int depth,
+ std::vector<int>* out_dims) {
+ CHECK(out_dims->empty());
+ if (num_dims == 1) {
+ CHECK_EQ(batch, 1);
+ *out_dims = {depth};
+ } else if (num_dims == 2) {
+ *out_dims = {batch, depth};
+ } else if (num_dims == 3) {
+ CHECK_EQ(batch, 1);
+ *out_dims = {height, width, depth};
+ } else if (num_dims == 4) {
+ *out_dims = {batch, height, width, depth};
+ } else {
+ LOG(FATAL) << "Should not get here: " << num_dims;
+ }
+}
+
+void CreateOrCheckRnnStateArray(const string& name, int size, Model* model) {
+ int batch = 1;
+ int num_dims = -1;
+ for (const auto& input_array : model->flags.input_arrays()) {
+ // Pick 'num_dims' and 'batch' from the first input_arrays, unless we find
+ // a better match by name.
+ if (input_array.name() == name || num_dims == -1) {
+ num_dims = input_array.shape_size();
+ if (num_dims != 0) {
+ batch = input_array.shape(0);
+ }
+ }
+ }
+ Array& array = model->GetOrCreateArray(name);
+ if (array.has_shape()) {
+ num_dims = array.shape().dimensions_count();
+ }
+ std::vector<int> dims;
+ MakeArrayDims(num_dims, batch, 1, 1, size, &dims);
+ CHECK(array.data_type == ArrayDataType::kFloat ||
+ array.data_type == ArrayDataType::kNone);
+ array.data_type = ArrayDataType::kFloat;
+ if (!array.has_shape()) {
+ Shape* shape = array.mutable_shape();
+ *shape->mutable_dims() = dims;
+ }
+}
+
+void ResolveModelFlags(const ModelFlags& model_flags, Model* model) {
+ // Merge info about input_arrays from model_flags into model->flags
+ for (const auto& specified_input_array : model_flags.input_arrays()) {
+ toco::InputArray* dst_input_array = nullptr;
+ for (int i = 0; i < model->flags.input_arrays_size(); i++) {
+ toco::InputArray* candidate_dst_input_array =
+ model->flags.mutable_input_arrays(i);
+ if (candidate_dst_input_array->name() == specified_input_array.name()) {
+ // specified_input_array from model_flags maps to dst_input_array
+ // in model->flags
+ dst_input_array = candidate_dst_input_array;
+ break;
+ }
+ }
+ if (!dst_input_array) {
+ // specified_input_array from model_flags is not found in model->flags.
+ // Match a name-less specified input array when there can be no ambiguity
+ // as there is only 1 input array.
+ if (model->flags.input_arrays_size() == 1 &&
+ model_flags.input_arrays_size() == 1 &&
+ !specified_input_array.has_name()) {
+ dst_input_array = model->flags.mutable_input_arrays(0);
+ }
+ }
+ if (!dst_input_array) {
+ // Still no match, so create a new input array to copy
+ // specified_input_array into.
+ dst_input_array = model->flags.add_input_arrays();
+ dst_input_array->set_name(specified_input_array.name());
+ }
+
+#define RESOLVE_MODEL_FLAG(field_name) \
+ if (specified_input_array.has_##field_name()) { \
+ if (dst_input_array->has_##field_name()) { \
+ QCHECK_EQ(dst_input_array->field_name(), \
+ specified_input_array.field_name()) \
+ << "For input array '" << dst_input_array->name() << "', " \
+ << "specified " #field_name " flag with value: " \
+ << specified_input_array.field_name() \
+ << " does not agree with already defined " #field_name \
+ " of this model, with value: " \
+ << dst_input_array->field_name(); \
+ } else { \
+ dst_input_array->set_##field_name(specified_input_array.field_name()); \
+ } \
+ }
+ RESOLVE_MODEL_FLAG(std_value);
+ RESOLVE_MODEL_FLAG(mean_value);
+#undef RESOLVE_MODEL_FLAG
+
+ if (!specified_input_array.shape().empty()) {
+ if (!dst_input_array->shape().empty()) {
+ QCHECK_EQ(specified_input_array.shape().size(),
+ dst_input_array->shape().size())
+ << "For input array '" << specified_input_array.name() << "', "
+ << "size of specified input shape flag with size: "
+ << specified_input_array.shape().size()
+ << " does not agree with already defined input shape"
+ " of this model, with size: "
+ << dst_input_array->shape().size();
+ // We treat the first dimension as a special case, since it is often
+ // a batch size and the input_shape flag is effectively overriding
+ // the model.
+ for (int i = 1; i < specified_input_array.shape().size(); i++) {
+ QCHECK_EQ(specified_input_array.shape().Get(i),
+ dst_input_array->shape().Get(i))
+ << "At dimension number " << i << " of input array "
+ << specified_input_array.name() << ", the specified shape's "
+ << "dimension flag with dimension: "
+ << specified_input_array.shape().Get(i)
+ << " does not agree with already defined shape"
+ << " of this model, with dimension: "
+ << dst_input_array->shape().Get(i);
+ }
+ } else {
+ dst_input_array->mutable_shape()->CopyFrom(
+ specified_input_array.shape());
+ }
+ }
+ }
+
+ if (model_flags.output_arrays_size() > 0) {
+ model->flags.mutable_output_arrays()->CopyFrom(model_flags.output_arrays());
+ }
+
+#define RESOLVE_MODEL_FLAG(name) \
+ if (model_flags.has_##name()) { \
+ if (model->flags.has_##name()) { \
+ QCHECK_EQ(model_flags.name(), model->flags.name()) \
+ << "Specified " #name " flag with value: " << model_flags.name() \
+ << " does not agree with already defined " #name \
+ " of this model, with value: " \
+ << model->flags.name(); \
+ } else { \
+ model->flags.set_##name(model_flags.name()); \
+ } \
+ }
+
+ RESOLVE_MODEL_FLAG(variable_batch)
+ RESOLVE_MODEL_FLAG(drop_control_dependency)
+
+#undef RESOLVE_MODEL_FLAG
+
+ if (model->flags.rnn_states_size() == 0) {
+ model->flags.mutable_rnn_states()->CopyFrom(model_flags.rnn_states());
+ } else {
+ CHECK_EQ(model->flags.rnn_states_size(), model_flags.rnn_states_size());
+ for (int i = 0; i < model->flags.rnn_states_size(); i++) {
+ CHECK_EQ(model->flags.rnn_states(i).state_array(),
+ model_flags.rnn_states(i).state_array());
+ CHECK_EQ(model->flags.rnn_states(i).back_edge_source_array(),
+ model_flags.rnn_states(i).back_edge_source_array());
+ }
+ }
+
+ if (model->flags.model_checks_size() == 0) {
+ model->flags.mutable_model_checks()->CopyFrom(model_flags.model_checks());
+ }
+
+ QCHECK_GT(model->flags.input_arrays_size(), 0)
+ << "This model does not define input arrays, so a "
+ "--input_arrays flag must be given on the command-line.";
+ QCHECK_GT(model->flags.output_arrays_size(), 0)
+ << "This model does not define output arrays, so a "
+ "--output_arrays flag must be given on the command-line.";
+
+ for (const auto& input_array_proto : model->flags.input_arrays()) {
+ QCHECK(!input_array_proto.shape().empty())
+ << "This model does not have shape defined for input array "
+ << input_array_proto.name()
+ << ", so one must be specified by a non-empty --input_shape "
+ "command-line flag.";
+
+ auto& input_array = model->GetOrCreateArray(input_array_proto.name());
+ if (input_array.data_type == ArrayDataType::kNone) {
+ // We start out with a float input array;
+ // that may get replaced by a uint8 array later, by
+ // MakeInitialDequantizeOp.
+ input_array.data_type = ArrayDataType::kFloat;
+ }
+
+ // Compare/merge the model->flags describing the input_shape with
+ // the actual input array's shape.
+ auto& input_array_dims = *input_array.mutable_shape()->mutable_dims();
+ if (input_array_dims.empty()) {
+ for (auto dim : input_array_proto.shape()) {
+ CHECK_GE(dim, 1);
+ input_array_dims.push_back(dim);
+ }
+ } else {
+ CHECK_EQ(input_array_dims.size(), input_array_proto.shape_size());
+ for (int i = 0; i < input_array_dims.size(); i++) {
+ CHECK_EQ(input_array_dims[i], input_array_proto.shape(i));
+ }
+ }
+
+ const float mean_value = input_array_proto.mean_value();
+ const float std_value = input_array_proto.std_value();
+ MinMax input_minmax;
+ input_minmax.min = (0.f - mean_value) / std_value;
+ input_minmax.max = (255.f - mean_value) / std_value;
+ if (input_array.minmax) {
+ if (input_array_proto.has_mean_value() ||
+ input_array_proto.has_std_value()) {
+ CHECK(input_minmax == *input_array.minmax)
+ << input_minmax.min << ", " << input_minmax.max
+ << " != " << input_array.minmax->min << ", "
+ << input_array.minmax->max;
+ }
+ } else {
+ input_array.GetOrCreateMinMax() = input_minmax;
+ }
+ }
+ // Creation of the RNN state arrays
+ for (const auto& rnn_state : model->flags.rnn_states()) {
+ if (!rnn_state.manually_create()) {
+ continue;
+ }
+ CreateOrCheckRnnStateArray(rnn_state.state_array(), rnn_state.size(),
+ model);
+ }
+}
+
+void CheckIsReadyForQuantization(const Model& model) {
+ for (const auto& op : model.operators) {
+ for (const auto& input : op->inputs) {
+ const auto& input_array = model.GetArray(input);
+ if (input_array.data_type != ArrayDataType::kFloat) {
+ // The array is not floats, no quantization needed.
+ continue;
+ }
+ if (input_array.minmax) {
+ // The array has minmax, we're good.
+ continue;
+ }
+ if (input_array.buffer) {
+ // The array has a constant buffer, so we can
+ // fall back to computing the minmax from actual array entries
+ // (with a WARNING about possible accuracy implications).
+ continue;
+ }
+ LOG(FATAL)
+ << "Array " << input << ", which is an input to the "
+ << HelpfulOperatorTypeName(*op) << " operator producing the output "
+ << "array " << op->outputs[0] << ", is lacking min/max data, "
+ << "which is necessary for quantization. Either target a "
+ << "non-quantized output format, or change the input graph to "
+ << "contain min/max information, or pass --default_ranges_min= and "
+ << "--default_ranges_max= if you do not care about the accuracy of "
+ << "results.";
+ }
+ }
+}
+
+void UseDefaultMinMaxRangeValues(Model* model, double default_ranges_min,
+ double default_ranges_max) {
+ for (const auto& op : model->operators) {
+ for (const auto& input : op->inputs) {
+ auto& input_array = model->GetArray(input);
+ if (!input_array.minmax && !input_array.buffer) {
+ auto& minmax = input_array.GetOrCreateMinMax();
+ minmax.min = default_ranges_min;
+ minmax.max = default_ranges_max;
+ }
+ }
+ for (const auto& output : op->outputs) {
+ auto& output_array = model->GetArray(output);
+ if (!output_array.minmax && !output_array.buffer) {
+ auto& minmax = output_array.GetOrCreateMinMax();
+ minmax.min = default_ranges_min;
+ minmax.max = default_ranges_max;
+ }
+ }
+ }
+}
+
+int ElementSize(ArrayDataType data_type) {
+ switch (data_type) {
+ case ArrayDataType::kFloat:
+ return 4;
+ case ArrayDataType::kInt32:
+ return 4;
+ case ArrayDataType::kUint8:
+ return 1;
+ default:
+ LOG(FATAL) << "Should not get here.";
+ return 0;
+ }
+}
+
+void DropMinMax(Model* model, const string& array_name) {
+ auto& array = model->GetArray(array_name);
+ if (!!array.minmax) {
+ LOG(WARNING) << "Dropping MinMax information in array " << array_name
+ << ". Expect inaccuracy in quantized inference.";
+ array.minmax = nullptr;
+ }
+}
+
+bool IsAllocatableTransientArray(const Model& model, const string& array_name) {
+ // The model's input and output arrays are externally allocated.
+ // They are not transient arrays.
+ if (IsInputArray(model, array_name)) {
+ return false;
+ }
+ for (const string& output_array : model.flags.output_arrays()) {
+ if (array_name == output_array) {
+ return false;
+ }
+ }
+ const auto& array = model.arrays.at(array_name);
+ // An array with a constant buffer isn't a transient array.
+ if (!!array->buffer) {
+ return false;
+ }
+ // An array without shape isn't allocatable.
+ if (!array->has_shape()) {
+ return false;
+ }
+ return true;
+}
+
+string AvailableArrayName(const Model& model, const string& name) {
+ if (!model.arrays.count(name)) {
+ return name;
+ }
+ const int kNumSuffixesToTry = 1000;
+ for (int i = 0; i < kNumSuffixesToTry; i++) {
+ const string& name_with_suffix = toco::port::StringF("%s_%d", name, i);
+ if (!model.arrays.count(name_with_suffix)) {
+ return name_with_suffix;
+ }
+ }
+ LOG(FATAL) << "Could not find an available array name starting with " << name
+ << ". Tried " << kNumSuffixesToTry << " suffixes, all were taken!";
+ return "";
+}
+
+string ShapeToString(const Shape& shape) {
+ if (shape.dimensions_count() == 0) {
+ return "[]";
+ }
+
+ return absl::StrCat("[ ", absl::StrJoin(shape.dims(), ", "), " ]");
+}
+
+void PrintArrayShape(Model* model, const string& name) {
+ if (!model->arrays[name]->has_shape()) {
+ LOG(INFO) << name << " has no shape";
+ return;
+ }
+ LOG(INFO) << name
+ << " has shape: " << ShapeToString(model->arrays[name]->shape());
+}
+
+bool IsArrayFullyConnectedWeights(const Model& model, const string& name) {
+ bool is_fc_weights = false;
+ bool is_something_else = false;
+ for (const auto& op : model.operators) {
+ for (int input_index = 0; input_index < op->inputs.size(); input_index++) {
+ if (op->inputs[input_index] == name) {
+ if (op->type == OperatorType::kFullyConnected && input_index == 1) {
+ is_fc_weights = true;
+ } else {
+ is_something_else = true;
+ }
+ }
+ }
+ }
+ CHECK(!(is_fc_weights && is_something_else));
+ return is_fc_weights;
+}
+
+bool EstimateArithmeticOpsCount(const Model& model, int64* result) {
+ int64 total = 0;
+ for (const auto& op : model.operators) {
+ switch (op->type) {
+ case OperatorType::kFullyConnected:
+ case OperatorType::kConv:
+ case OperatorType::kDepthwiseConv: {
+ const auto& output_array = model.GetArray(op->outputs[0]);
+ const auto& weights_array = model.GetArray(op->inputs[1]);
+ if (!output_array.has_shape() || !weights_array.has_shape()) {
+ return false;
+ }
+ int cols = 1;
+ for (int i = 0; i < output_array.shape().dimensions_count() - 1; i++) {
+ cols *= output_array.shape().dims(i);
+ }
+ const int64 cost_per_col =
+ 2 * RequiredBufferSizeForShape(weights_array.shape());
+ total += cost_per_col * cols;
+ if (op->inputs.size() > 2) {
+ // There is a bias vector. One more op per output value.
+ total += RequiredBufferSizeForShape(output_array.shape());
+ }
+ break;
+ }
+ case OperatorType::kAdd:
+ case OperatorType::kSub:
+ case OperatorType::kMul: {
+ const auto& output_array = model.GetArray(op->outputs[0]);
+ if (!output_array.has_shape()) {
+ return false;
+ }
+ total += RequiredBufferSizeForShape(output_array.shape());
+ break;
+ }
+ case OperatorType::kLogistic:
+ case OperatorType::kSoftmax:
+ case OperatorType::kTanh: {
+ const auto& output_array = model.GetArray(op->outputs[0]);
+ if (!output_array.has_shape()) {
+ return false;
+ }
+ // As a very rough ballpark, the cost of evaluating a math function
+ // such as tanh or logistic is about 32 multiplications, and about as
+ // many additions/subtractions. (Just a rough power-of-two order of
+ // magnitude, based on the actual implementations we use in runtime/ code.)
+ total += 64 * RequiredBufferSizeForShape(output_array.shape());
+ break;
+ }
+ case OperatorType::kMaxPool: {
+ const auto& maxpool = *static_cast<const MaxPoolOperator*>(op.get());
+ const auto& output_array = model.GetArray(op->outputs[0]);
+ if (!output_array.has_shape()) {
+ return false;
+ }
+ total += RequiredBufferSizeForShape(output_array.shape()) *
+ maxpool.kheight * maxpool.kwidth;
+ break;
+ }
+ case OperatorType::kAveragePool: {
+ const auto& avgpool =
+ *static_cast<const AveragePoolOperator*>(op.get());
+ const auto& output_array = model.GetArray(op->outputs[0]);
+ if (!output_array.has_shape()) {
+ return false;
+ }
+ total += RequiredBufferSizeForShape(output_array.shape()) *
+ avgpool.kheight * avgpool.kwidth;
+ break;
+ }
+ case OperatorType::kL2Pool: {
+ const auto* l2pool = static_cast<const L2PoolOperator*>(op.get());
+ const auto& output_array = model.GetArray(op->outputs[0]);
+ if (!output_array.has_shape()) {
+ return false;
+ }
+ // The sum of squares requires (kheight*kwidth) multiply-adds,
+ // and then there is the sqrt which we ballpark at 32 ops.
+ const int64 cost_per_val = 2 * l2pool->kheight * l2pool->kwidth + 32;
+ total +=
+ RequiredBufferSizeForShape(output_array.shape()) * cost_per_val;
+ break;
+ }
+ case OperatorType::kL2Normalization: {
+ const auto& output_array = model.GetArray(op->outputs[0]);
+ if (!output_array.has_shape()) {
+ return false;
+ }
+ // Computing the squared L2 norm is N multiply-adds so 2N ops,
+ // then the single inverse-sqrt is negligible, then we multiply each
+ // value by the resulting multiplier, so an extra N ops. Total 3N ops.
+ total += 3 * RequiredBufferSizeForShape(output_array.shape());
+ break;
+ }
+ default:
+ break;
+ }
+ }
+ *result = total;
+ return true;
+}
+
+namespace {
+
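+// Computes the permutation mapping output axes to input axes: after the call,
+// shuffle[output_axis] is the input axis that the given output axis reads from.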
+void GetShuffleShape(AxesOrder input_axes_order, AxesOrder output_axes_order,
+ std::vector<int>* shuffle) {
+ CHECK_EQ(AxesCount(input_axes_order), AxesCount(output_axes_order));
+ shuffle->resize(4);
+ for (int i = 0; i < 4; i++) {
+ (*shuffle)[i] = i;
+ }
+ if (input_axes_order == output_axes_order) {
+ // nothing to do
+ } else if (AxesCount(input_axes_order) == 2) {
+ shuffle->resize(2);
+ (*shuffle)[0] = 1;
+ (*shuffle)[1] = 0;
+ } else if (input_axes_order == AxesOrder::kOHWI &&
+ output_axes_order == AxesOrder::kHWIO) {
+ // 3210 <- 3210
+ // HWIO <- OHWI
+ (*shuffle)[0] = 1;
+ (*shuffle)[1] = 2;
+ (*shuffle)[2] = 3;
+ (*shuffle)[3] = 0;
+ } else if (input_axes_order == AxesOrder::kHWIO &&
+ output_axes_order == AxesOrder::kOHWI) {
+ // 3210 <- 3210
+ // OHWI <- HWIO
+ (*shuffle)[0] = 3;
+ (*shuffle)[1] = 0;
+ (*shuffle)[2] = 1;
+ (*shuffle)[3] = 2;
+ } else {
+ LOG(FATAL) << "Bad shuffle";
+ }
+}
+
+// Extend shuffle is designed to match ExtendShape, which pads the shape with
+// unit dimensions at the beginning.
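+// For example, extending the 2-D transpose shuffle {1, 0} to 4 dimensions
+// yields {0, 1, 3, 2}: the two padded leading axes stay in place and the
+// original axes are shifted by the pad size.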
+void ExtendShuffle(const std::vector<int>& input_shuffle, int newdim,
+ std::vector<int>* extended_shuffle) {
+ *extended_shuffle = input_shuffle;
+ CHECK(newdim >= input_shuffle.size());
+ const int pad_size = newdim - input_shuffle.size();
+ extended_shuffle->resize(newdim);
+ for (int i = 0; i < pad_size; i++) {
+ (*extended_shuffle)[i] = i;
+ }
+ for (int i = pad_size; i < newdim; i++) {
+ (*extended_shuffle)[i] = input_shuffle[i - pad_size] + pad_size;
+ }
+}
+
+} // end anonymous namespace
+
+void ShuffleDims(const Shape& input_shape, AxesOrder input_axes_order,
+ AxesOrder output_axes_order, Shape* output_shape) {
+ if (input_axes_order == AxesOrder::kHWIM &&
+ output_axes_order == AxesOrder::k1HWO) {
+ // This special case isn't just a permutation: the I and M dims get
+ // merged into dim 3 (the output O dim), so we have to special-case it.
+ *output_shape = Shape({1, input_shape.dims(0), input_shape.dims(1),
+ input_shape.dims(3) * input_shape.dims(2)});
+ } else {
+ std::vector<int> shuffle;
+ GetShuffleShape(input_axes_order, output_axes_order, &shuffle);
+ std::vector<int>* output_dims = output_shape->mutable_dims();
+ output_dims->resize(input_shape.dimensions_count());
+ for (int i = 0; i < input_shape.dimensions_count(); i++) {
+ (*output_dims)[i] = input_shape.dims(shuffle[i]);
+ }
+ }
+}
+
+void ShuffleArray(const Shape& input_shape, AxesOrder input_axes_order,
+ AxesOrder output_axes_order, const Shape& output_shape,
+ const float* input_data, float* output_data) {
+ if (input_axes_order == AxesOrder::kHWIM &&
+ output_axes_order == AxesOrder::k1HWO) {
+ // This special case isn't just a permutation, the IM pair of dims get
+ // merged into the O dim, so we have to special-case it. Fortunately,
+ // as far as array shuffling is concerned, it's just the identity
+ // transformation.
+ memcpy(output_data, input_data,
+ RequiredBufferSizeForShape(input_shape) * sizeof(output_data[0]));
+ return;
+ }
+ CHECK(input_shape.dimensions_count() == output_shape.dimensions_count());
+ const int dim = input_shape.dimensions_count();
+ CHECK_LE(dim, 4);
+ std::vector<int> shuffle;
+ GetShuffleShape(input_axes_order, output_axes_order, &shuffle);
+ CHECK(shuffle.size() >= dim);
+ for (int i = 0; i < dim; i++) {
+ CHECK(shuffle[i] >= 0 && shuffle[i] < dim);
+ CHECK(input_shape.dims(shuffle[i]) == output_shape.dims(i));
+ }
+ Shape extended_input_shape = input_shape;
+ ExtendShape(&extended_input_shape, 4);
+ Shape extended_output_shape = output_shape;
+ ExtendShape(&extended_output_shape, 4);
+ std::vector<int> extended_shuffle;
+ ExtendShuffle(shuffle, 4, &extended_shuffle);
+
+ const std::vector<int>& extended_input_dims = extended_input_shape.dims();
+ const std::vector<int>& extended_output_dims = extended_output_shape.dims();
+
+ // TODO(starka): Rework to handle different numbers of dimensions.
+ int input_strides[4];
+ input_strides[3] = 1;
+ input_strides[2] = extended_input_dims[3];
+ input_strides[1] = input_strides[2] * extended_input_dims[2];
+ input_strides[0] = input_strides[1] * extended_input_dims[1];
+ const int input_stride_0 = input_strides[extended_shuffle[3]];
+ const int input_stride_1 = input_strides[extended_shuffle[2]];
+ const int input_stride_2 = input_strides[extended_shuffle[1]];
+ const int input_stride_3 = input_strides[extended_shuffle[0]];
+
+ const int output_size_0 = extended_output_dims[3];
+ const int output_size_1 = extended_output_dims[2];
+ const int output_size_2 = extended_output_dims[1];
+ const int output_size_3 = extended_output_dims[0];
+ const int output_stride_0 = 1;
+ const int output_stride_1 = output_size_0;
+ const int output_stride_2 = output_stride_1 * output_size_1;
+ const int output_stride_3 = output_stride_2 * output_size_2;
+
+ for (int i3 = 0; i3 < output_size_3; i3++) {
+ const float* const input_ptr_3 = input_data + i3 * input_stride_3;
+ float* const output_ptr_3 = output_data + i3 * output_stride_3;
+ for (int i2 = 0; i2 < output_size_2; i2++) {
+ const float* const input_ptr_2 = input_ptr_3 + i2 * input_stride_2;
+ float* const output_ptr_2 = output_ptr_3 + i2 * output_stride_2;
+ for (int i1 = 0; i1 < output_size_1; i1++) {
+ const float* input_ptr = input_ptr_2 + i1 * input_stride_1;
+ float* output_ptr = output_ptr_2 + i1 * output_stride_1;
+ float* const output_ptr_end =
+ output_ptr + output_size_0 * output_stride_0;
+ while (output_ptr != output_ptr_end) {
+ *output_ptr = *input_ptr;
+ input_ptr += input_stride_0;
+ output_ptr += output_stride_0;
+ }
+ }
+ }
+ }
+}
+
+int AxesCount(AxesOrder axes_order) {
+ switch (axes_order) {
+ case AxesOrder::kOneAxis:
+ return 1;
+ case AxesOrder::kRC:
+ return 2;
+ case AxesOrder::kCR:
+ return 2;
+ case AxesOrder::kHWIO:
+ return 4;
+ case AxesOrder::kOHWI:
+ return 4;
+ case AxesOrder::kHWIM:
+ return 4;
+ case AxesOrder::k1HWO:
+ return 4;
+ case AxesOrder::kNHWC:
+ return 4;
+ default:
+ LOG(FATAL) << "Bad AxesOrder";
+ return 0;
+ }
+}
+
+bool IsDiscardableArray(const Model& model, const string& array_name) {
+ for (const auto& input_array : model.flags.input_arrays()) {
+ if (array_name == input_array.name()) {
+ return false;
+ }
+ }
+ for (const string& output_array : model.flags.output_arrays()) {
+ if (array_name == output_array) {
+ return false;
+ }
+ }
+ for (const auto& rnn_state : model.flags.rnn_states()) {
+ if (array_name == rnn_state.state_array()) {
+ return false;
+ }
+ if (array_name == rnn_state.back_edge_source_array()) {
+ return false;
+ }
+ }
+ return true;
+}
+
+void CheckFinalDataTypesSatisfied(const Model& model) {
+ for (const auto& array_entry : model.arrays) {
+ const auto& array = *array_entry.second;
+ if (array.final_data_type != ArrayDataType::kNone) {
+ CHECK(array.final_data_type == array.data_type);
+ }
+ }
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/tooling_util.h b/tensorflow/contrib/lite/toco/tooling_util.h
new file mode 100644
index 0000000000..093945edb3
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/tooling_util.h
@@ -0,0 +1,292 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TOOLING_UTIL_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TOOLING_UTIL_H_
+
+#include <algorithm>
+#include <cmath>
+#include <iostream>
+#include <limits>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "google/protobuf/text_format.h"
+#include "tensorflow/core/platform/logging.h"
+#if TOCO_SUPPORT_PORTABLE_PROTOS
+#include "third_party/protobuf/src/google/protobuf/text_format.h"
+#endif // TOCO_SUPPORT_PORTABLE_PROTOS
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/model_flags.pb.h"
+#include "tensorflow/contrib/lite/toco/runtime/types.h"
+#include "tensorflow/contrib/lite/toco/toco_flags.pb.h"
+#include "tensorflow/contrib/lite/toco/toco_port.h"
+
+// TODO(aselle): Replace with using a container specific hash override instead.
+namespace std {
+template <>
+struct hash<toco::OperatorType> {
+ size_t operator()(const toco::OperatorType& op) const {
+ return std::hash<size_t>()(static_cast<size_t>(op));
+ }
+};
+} // namespace std
+
+namespace toco {
+
+constexpr int kLogLevelModelChanged = 1;
+constexpr int kLogLevelModelUnchanged = 2;
+
+string LogName(const Operator& op);
+
+bool IsInputArray(const Model& model, const string& name);
+bool IsArrayConsumed(const Model& model, const string& name);
+int CountTrueOutputs(const Model& model, const Operator& op);
+
+int CountOpsWithInput(const Model& model, const string& array_name);
+bool DeleteArrayIfUnused(const string& array_name, Model* model);
+
+std::vector<std::unique_ptr<Operator>>::const_iterator FindOpWithOutput(
+ const Model& model, const string& array_name);
+Operator* GetOpWithOutput(const Model& model, const string& array_name);
+
+std::vector<std::unique_ptr<Operator>>::iterator FindOpWithOutput(
+ Model& model, const string& array_name);
+Operator* GetOpWithOutput(const Model& model, const string& array_name);
+
+std::vector<std::unique_ptr<Operator>>::const_iterator FindOpWithInput(
+ const Model& model, const string& array_name);
+Operator* GetOpWithInput(const Model& model, const string& array_name);
+Operator* GetFirstOpWithInput(const Model& model, const string& array_name);
+
+std::vector<std::unique_ptr<Operator>>::const_iterator FindOp(
+ const Model& model, const Operator* op);
+std::vector<std::unique_ptr<Operator>>::iterator FindOp(Model& model,
+ const Operator* op);
+
+const char* OperatorTypeName(OperatorType type);
+string HelpfulOperatorTypeName(const Operator& op);
+
+void DumpGraphvizVideoFrame(const Model& model);
+void LogDump(int log_level, const string& message, const Model& model);
+void LogSummary(int log_level, const string& message, const Model& model);
+
+inline bool ParseFromStringOverload(const std::string& in,
+ TFLITE_PROTO_NS::Message* proto) {
+ return TFLITE_PROTO_NS::TextFormat::ParseFromString(in, proto);
+}
+
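+// Parses a proto from 'input_file_contents', accepting either the binary
+// serialization or the text format; the binary form is tried first.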
+template <typename Proto>
+bool ParseFromStringEitherTextOrBinary(const std::string& input_file_contents,
+ Proto* proto) {
+ if (proto->ParseFromString(input_file_contents)) {
+ return true;
+ }
+
+ if (ParseFromStringOverload(input_file_contents, proto)) {
+ return true;
+ }
+
+ return false;
+}
+
+// TODO(b/36075966): Clean up when dims superseded by array shape.
+void ExtendShape(Shape* shape, int new_shape_size);
+
+// TODO(b/36075966): Clean up when dims superseded by array shape.
+void UnextendShape(Shape* shape, int new_shape_size);
+
+// Checks (using CHECK) that all dimensions of 'shape' are at least 1.
+void CheckShapeDimensions(const Shape& shape);
+
+// Given two shapes with potentially different dimensionality and dimension
+// arrays d0 and d1. Without loss of generality, assume that shape0 may have
+// higher dimensionality (length(d0) >= length(d1)). Then shape0 and shape1
+// "agree up to broadcasting" if:
+// - When walking the d0 and d1 from back to front with indices i0, i1,
+// d0[i0] == d1[i1] or d0[i0] == 1 or d1[i1] == 1, for each dimension until
+// i1 == 0 (inclusive).
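+// For example, [8, 1, 6, 1] and [7, 1, 5] agree up to broadcasting, while
+// [5, 4] and [3] do not (4 vs 3, and neither is 1).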
+bool ShapesAgreeUpToBroadcasting(const Shape& shape0, const Shape& shape1);
+
+// A stricter constraint than ShapesAgreeUpToBroadcasting().
+//
+// Given two shapes with potentially different dimensionality and dimension
+// arrays d0 and d1. Without loss of generality, assume that shape0 may have
+// higher dimensionality (length(d0) >= length(d1)). Then shape0 and shape1
+// "agree up to extending" if:
+// - When walking the d0 and d1 from back to front with indices i0, i1,
+// d0[i0] == d1[i1] for each dimension until i1 == 0 (inclusive).
+// - For the remaining indices [0..i0), d0[i0] == 1.
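+// For example, [1, 1, 3] and [3] agree up to extending, while [256, 256, 3]
+// and [3] only agree up to broadcasting (the leading 256s are not 1).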
+bool ShapesAgreeUpToExtending(const Shape& shape0, const Shape& shape1);
+
+bool IsArrayFullyConnectedWeights(const Model& model, const string& name);
+
+// If there is a wildcard dimension (-1), this may return a negative value.
+int RequiredBufferSizeForShape(const Shape& shape);
+
+bool IsConstantParameterArray(const Model& model, const string& name);
+
+void CheckNoMissingArray(const Model& model);
+void CheckInvariants(const Model& model);
+
+void CheckModelCounts(const Model& model);
+
+void FixOperatorOrdering(Model* model);
+void FixNoMissingArray(Model* model);
+void FixNoOrphanedArray(Model* model);
+
+void ResolveModelFlags(const ModelFlags& model_flags, Model* model);
+
+template <ArrayDataType A>
+void GetQuantizationParamsFromMinMax(const ModelFlags& model_flags,
+ const MinMax& minmax,
+ QuantizationParams* quantization_params) {
+ using Integer = DataType<A>;
+ const Integer qmin = std::numeric_limits<Integer>::min();
+ const Integer qmax = std::numeric_limits<Integer>::max();
+ const double qmin_double = qmin;
+ const double qmax_double = qmax;
+ const double rmin = minmax.min;
+ const double rmax = minmax.max;
+ // 0 should always be a representable value. Let's assume that the initial
+ // min,max range contains 0.
+ CHECK_LE(rmin, 0.);
+ CHECK_GE(rmax, 0.);
+ if (rmin == rmax) {
+ // Special case where the min,max range is a point. Should be {0}.
+ CHECK_EQ(rmin, 0.);
+ CHECK_EQ(rmax, 0.);
+ quantization_params->zero_point = 0;
+ quantization_params->scale = 0.;
+ return;
+ }
+
+ // General case.
+ //
+ // First determine the scale.
+ const double scale = (rmax - rmin) / (qmax_double - qmin_double);
+
+ // Zero-point computation.
+ // First the initial floating-point computation. The zero-point can be
+ // determined from solving an affine equation for any known pair
+ // (real value, corresponding quantized value).
+ // We know two such pairs: (rmin, qmin) and (rmax, qmax).
+ // The arithmetic error on the zero point computed from either pair
+ // will be roughly machine_epsilon * (sum of absolute values of terms)
+ // so we want to use the variant that adds the smaller terms.
+ const double zero_point_from_min = qmin_double - rmin / scale;
+ const double zero_point_from_max = qmax_double - rmax / scale;
+ const double zero_point_from_min_error =
+ std::abs(qmin_double) + std::abs(rmin / scale);
+ const double zero_point_from_max_error =
+ std::abs(qmax_double) + std::abs(rmax / scale);
+
+ const double zero_point_double =
+ zero_point_from_min_error < zero_point_from_max_error
+ ? zero_point_from_min
+ : zero_point_from_max;
+
+ // Now we need to nudge the zero point to be an integer
+ // (our zero points are integer, and this is motivated by the requirement
+ // to be able to represent the real value "0" exactly as a quantized value,
+ // which is required in multiple places, for example in Im2col with SAME
+ // padding).
+ Integer nudged_zero_point = 0;
+ if (zero_point_double < qmin_double) {
+ nudged_zero_point = qmin;
+ } else if (zero_point_double > qmax_double) {
+ nudged_zero_point = qmax;
+ } else {
+ nudged_zero_point = static_cast<Integer>(std::round(zero_point_double));
+ }
+ // The zero point should always be in the range of quantized value,
+ // [qmin, qmax].
+ CHECK_GE(nudged_zero_point, qmin);
+ CHECK_LE(nudged_zero_point, qmax);
+
+ // Finally, store the result nudged quantization params.
+ quantization_params->zero_point = nudged_zero_point;
+ quantization_params->scale = scale;
+}
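+// Worked example of the above: for A = kUint8 and minmax = {-1.0, 1.0},
+// scale = (1 - (-1)) / (255 - 0) = 2/255, both zero-point candidates equal
+// 0 - (-1)/scale = 127.5, and the nudged integer zero_point is 128.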
+
+void CheckIsReadyForQuantization(const Model& model);
+void UseDefaultMinMaxRangeValues(Model* model, double default_ranges_min,
+ double default_ranges_max);
+
+inline int Offset(const Shape& shape, const std::vector<int>& indices) {
+ DCHECK_EQ(shape.dimensions_count(), indices.size());
+ const int dims_count = shape.dimensions_count();
+ int offset = 0;
+ for (int i = 0; i < dims_count; i++) {
+ const int index = indices[i];
+ DCHECK(index >= 0 && index < shape.dims(i));
+ offset *= shape.dims(i);
+ offset += index;
+ }
+ return offset;
+}
+
+inline std::vector<int> ReverseOffset(const Shape& shape, int index) {
+ DCHECK_GE(index, 0);
+ DCHECK_LT(index, RequiredBufferSizeForShape(shape));
+ const int dims_count = shape.dimensions_count();
+ std::vector<int> indices(dims_count);
+ int residual = index;
+ for (int i = dims_count - 1; i >= 0; i--) {
+ indices[i] = residual % shape.dims(i);
+ residual /= shape.dims(i);
+ }
+ return indices;
+}
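+// For example, with shape [2, 3, 4] (row-major), Offset(shape, {1, 2, 3})
+// == (1 * 3 + 2) * 4 + 3 == 23, and ReverseOffset(shape, 23) == {1, 2, 3}.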
+
+int ElementSize(ArrayDataType data_type);
+
+void DropMinMax(Model* model, const string& array_name);
+
+bool IsAllocatableTransientArray(const Model& model, const string& array_name);
+
+void CreateOrCheckRnnStateArray(const string& name, int size, Model* model);
+
+string AvailableArrayName(const Model& model, const string& name);
+
+// Formats a shape as a string: [ dims(0), dims(1), ..., dims(num_dims-1) ].
+string ShapeToString(const Shape& shape);
+
+void PrintArrayShape(Model* model, const string& name);
+
+void MakeArrayDims(int num_dims, int batch, int height, int width, int depth,
+ std::vector<int>* out_dims);
+
+bool EstimateArithmeticOpsCount(const Model& model, int64* result);
+
+int AxesCount(AxesOrder axes_order);
+
+void ShuffleDims(const Shape& input_shape, AxesOrder input_axes_order,
+ AxesOrder output_axes_order, Shape* output_shape);
+void ShuffleArray(const Shape& input_shape, AxesOrder input_axes_order,
+ AxesOrder output_axes_order, const Shape& output_shape,
+ const float* input_data, float* output_data);
+
+// Returns true if it may be OK for any graph transformation to ever discard
+// that array. The idea is that we can't ever discard arrays that are either
+// an input or an output of the whole graph, or that appear in RNN back-edges,
+// as that would undercut explicit flags that the user might pass.
+bool IsDiscardableArray(const Model& model, const string& array_name);
+
+void CheckFinalDataTypesSatisfied(const Model& model);
+
+} // namespace toco
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TOOLING_UTIL_H_
diff --git a/tensorflow/contrib/lite/toco/tooling_util_test.cc b/tensorflow/contrib/lite/toco/tooling_util_test.cc
new file mode 100644
index 0000000000..22955ce956
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/tooling_util_test.cc
@@ -0,0 +1,96 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <tuple>
+#include <vector>
+
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+
+namespace toco {
+
+enum class Agreement { kBroadcast, kExtend, kBroadcastNotExtend, kNeither };
+
+// A pair of Shapes and whether they should agree up to broadcasting, extending
+// or neither.
+struct ShapePair {
+ Shape left;
+ Shape right;
+ Agreement agreement;
+};
+
+std::vector<ShapePair> CreateShapePairs() {
+ return std::vector<ShapePair>(
+ {// These agree up to broadcast.
+ {Shape({3}), Shape({3}), Agreement::kBroadcast},
+ {Shape({256, 256, 3}), Shape({256, 256, 3}), Agreement::kBroadcast},
+ {Shape({256, 256, 3}), Shape({3}), Agreement::kBroadcast},
+ {Shape({8, 1, 6, 1}), Shape({7, 1, 5}), Agreement::kBroadcast},
+
+ // These extend (and therefore broadcast).
+ {Shape({3}), Shape({3}), Agreement::kExtend},
+ {Shape({256, 256, 3}), Shape({256, 256, 3}), Agreement::kExtend},
+ {Shape({1, 1, 3}), Shape({1, 1, 3}), Agreement::kExtend},
+ {Shape({1, 1, 3}), Shape({3}), Agreement::kExtend},
+ {Shape({1, 1, 3}), Shape({1, 3}), Agreement::kExtend},
+
+ // These strictly broadcast and do not extend.
+ {Shape({256, 256, 3}), Shape({3}), Agreement::kBroadcastNotExtend},
+ {Shape({5, 4}), Shape({1}), Agreement::kBroadcastNotExtend},
+ {Shape({5, 4}), Shape({4}), Agreement::kBroadcastNotExtend},
+ {Shape({15, 3, 5}), Shape({15, 1, 5}), Agreement::kBroadcastNotExtend},
+ {Shape({15, 3, 5}), Shape({3, 5}), Agreement::kBroadcastNotExtend},
+ {Shape({15, 3, 5}), Shape({3, 1}), Agreement::kBroadcastNotExtend},
+
+ // These do not broadcast (and therefore also do not extend).
+ {Shape({3}), Shape({4}), Agreement::kNeither},
+ {Shape({2, 1}), Shape({8, 4, 3}), Agreement::kNeither}});
+}
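To make the intended semantics of the two predicates concrete, here is a hedged sketch, not toco's implementation, that is consistent with every pair above: shapes agree up to broadcasting when, aligned from the trailing dimension, each pair of dims is equal or one of them is 1 (missing leading dims always agree); they agree up to extending only when the shorter shape, after prepending 1s, is exactly equal to the longer one.

#include <vector>

// Sketch of the broadcasting check: align trailing dims; each pair must match
// or contain a 1; missing leading dims are compatible.
bool SketchAgreeUpToBroadcasting(std::vector<int> a, std::vector<int> b) {
  while (!a.empty() && !b.empty()) {
    const int da = a.back(), db = b.back();
    a.pop_back();
    b.pop_back();
    if (da != db && da != 1 && db != 1) return false;
  }
  return true;
}

// Sketch of the extending check: pad the shorter shape with leading 1s, then
// require exact equality (so every "extend" is also a "broadcast").
bool SketchAgreeUpToExtending(std::vector<int> a, std::vector<int> b) {
  while (a.size() < b.size()) a.insert(a.begin(), 1);
  while (b.size() < a.size()) b.insert(b.begin(), 1);
  return a == b;
}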
+
+// ShapeTest is an empty parameterized test fixture since there is no state.
+class ShapeTest : public ::testing::TestWithParam<ShapePair> {};
+
+TEST_P(ShapeTest, Agrees) {
+ const ShapePair& param = GetParam();
+
+ switch (param.agreement) {
+ case Agreement::kBroadcast: {
+ EXPECT_TRUE(ShapesAgreeUpToBroadcasting(param.left, param.right));
+ break;
+ }
+ case Agreement::kExtend: {
+ EXPECT_TRUE(ShapesAgreeUpToExtending(param.left, param.right));
+ // Anything that extends should also broadcast.
+ EXPECT_TRUE(ShapesAgreeUpToBroadcasting(param.left, param.right));
+ break;
+ }
+ case Agreement::kBroadcastNotExtend: {
+ // Verify that it strictly broadcasts but does not extend.
+ EXPECT_TRUE(ShapesAgreeUpToBroadcasting(param.left, param.right));
+ EXPECT_FALSE(ShapesAgreeUpToExtending(param.left, param.right));
+ break;
+ }
+ case Agreement::kNeither: {
+ EXPECT_FALSE(ShapesAgreeUpToExtending(param.left, param.right));
+ EXPECT_FALSE(ShapesAgreeUpToBroadcasting(param.left, param.right));
+ break;
+ }
+ }
+}
+
+INSTANTIATE_TEST_CASE_P(AgreeBroadcast, ShapeTest,
+ ::testing::ValuesIn(CreateShapePairs()));
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/tools/BUILD b/tensorflow/contrib/lite/tools/BUILD
new file mode 100644
index 0000000000..2d918fd4e8
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/BUILD
@@ -0,0 +1,60 @@
+package(default_visibility = [
+ "//visibility:public",
+])
+
+licenses(["notice"]) # Apache 2.0
+
+cc_binary(
+ name = "generate_op_registrations",
+ srcs = ["gen_op_registration_main.cc"],
+ deps = [
+ "//tensorflow/contrib/lite/tools:gen_op_registration",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_library(
+ name = "gen_op_registration",
+ srcs = ["gen_op_registration.cc"],
+ hdrs = ["gen_op_registration.h"],
+ deps = [
+ "//tensorflow/contrib/lite:framework",
+ "@com_googlesource_code_re2//:re2",
+ ],
+)
+
+cc_test(
+ name = "gen_op_registration_test",
+ srcs = ["gen_op_registration_test.cc"],
+ data = [
+ "//tensorflow/contrib/lite:testdata/0_subgraphs.bin",
+ "//tensorflow/contrib/lite:testdata/2_subgraphs.bin",
+ "//tensorflow/contrib/lite:testdata/empty_model.bin",
+ "//tensorflow/contrib/lite:testdata/test_model.bin",
+ "//tensorflow/contrib/lite:testdata/test_model_broken.bin",
+ ],
+ deps = [
+ ":gen_op_registration",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+cc_library(
+ name = "mutable_op_resolver",
+ srcs = ["mutable_op_resolver.cc"],
+ hdrs = ["mutable_op_resolver.h"],
+ deps = ["//tensorflow/contrib/lite:framework"],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/contrib/lite/tools/gen_op_registration.cc b/tensorflow/contrib/lite/tools/gen_op_registration.cc
new file mode 100644
index 0000000000..57c2567e3b
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/gen_op_registration.cc
@@ -0,0 +1,46 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <algorithm>
+#include <string>
+#include <vector>
+
+#include "third_party/re2/re2.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+
+string NormalizeCustomOpName(const string& op) {
+ string method(op);
+ RE2::GlobalReplace(&method, "([a-z])([A-Z])", "\\1_\\2");
+ std::transform(method.begin(), method.end(), method.begin(), ::toupper);
+ return method;
+}
+
+void ReadOpsFromModel(const ::tflite::Model* model,
+ std::vector<string>* builtin_ops,
+ std::vector<string>* custom_ops) {
+ if (!model) return;
+ auto opcodes = model->operator_codes();
+ if (!opcodes) return;
+ for (const auto* opcode : *opcodes) {
+ if (opcode->builtin_code() != ::tflite::BuiltinOperator_CUSTOM) {
+ builtin_ops->push_back(
+ tflite::EnumNameBuiltinOperator(opcode->builtin_code()));
+ } else {
+ custom_ops->push_back(opcode->custom_code()->c_str());
+ }
+ }
+}
+
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/tools/gen_op_registration.h b/tensorflow/contrib/lite/tools/gen_op_registration.h
new file mode 100644
index 0000000000..363bb2335c
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/gen_op_registration.h
@@ -0,0 +1,38 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOOLS_GEN_OP_REGISTRATION_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOOLS_GEN_OP_REGISTRATION_H_
+
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+
+// Converts a custom op name to its registration name, following the convention.
+// Example:
+// "custom_op" -> "CUSTOM_OP"
+// "CustomOp" -> "CUSTOM_OP"
+// Note "Register_" suffix will be added later in the tool.
+string NormalizeCustomOpName(const string& op);
+
+// Reads the ops from a TFLite model.
+// The enum names of builtin ops are stored, e.g. "CONV_2D".
+// Custom op names are stored as-is.
+void ReadOpsFromModel(const ::tflite::Model* model,
+ std::vector<string>* builtin_ops,
+ std::vector<string>* custom_ops);
+
+} // namespace tflite
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOOLS_GEN_OP_REGISTRATION_H_
diff --git a/tensorflow/contrib/lite/tools/gen_op_registration_main.cc b/tensorflow/contrib/lite/tools/gen_op_registration_main.cc
new file mode 100644
index 0000000000..7b27066a21
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/gen_op_registration_main.cc
@@ -0,0 +1,98 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <fstream>
+#include <sstream>
+#include <string>
+#include <vector>
+
+#include "tensorflow/contrib/lite/tools/gen_op_registration.h"
+#include "tensorflow/core/platform/init_main.h"
+#include "tensorflow/core/util/command_line_flags.h"
+
+using tensorflow::Flag;
+using tensorflow::Flags;
+
+namespace {
+
+void GenerateFileContent(const string& filename,
+ const std::vector<string>& builtin_ops,
+ const std::vector<string>& custom_ops) {
+ std::ofstream fout(filename);
+
+ fout << "#include "
+ "\"third_party/tensorflow/contrib/lite/model.h\"\n";
+ fout << "#include "
+ "\"third_party/tensorflow/contrib/lite/tools/mutable_op_resolver.h\"\n";
+ fout << "namespace tflite {\n";
+ fout << "namespace ops {\n";
+ if (!builtin_ops.empty()) {
+ fout << "namespace builtin {\n";
+ fout << "// Forward-declarations for the builtin ops.\n";
+ for (const auto& op : builtin_ops) {
+ fout << "TfLiteRegistration* Register_" << op << "();\n";
+ }
+ fout << "} // namespace builtin\n";
+ }
+
+ if (!custom_ops.empty()) {
+ fout << "namespace custom {\n";
+ fout << "// Forward-declarations for the custom ops.\n";
+ for (const auto& op : custom_ops) {
+ fout << "TfLiteRegistration* Register_"
+ << ::tflite::NormalizeCustomOpName(op) << "();\n";
+ }
+ fout << "} // namespace custom\n";
+ }
+ fout << "} // namespace ops\n";
+ fout << "} // namespace tflite\n";
+
+ fout << "void RegisterSelectedOps(::tflite::MutableOpResolver* resolver) {\n";
+ for (const auto& op : builtin_ops) {
+ fout << " resolver->AddBuiltin(::tflite::BuiltinOperator_" << op
+ << ", ::tflite::ops::builtin::Register_" << op << "());\n";
+ }
+ for (const auto& op : custom_ops) {
+ fout << " resolver->AddCustom(\"" << op
+ << "\", ::tflite::ops::custom::Register_"
+ << ::tflite::NormalizeCustomOpName(op) << "());\n";
+ }
+ fout << "}\n";
+ fout.close();
+}
+} // namespace
+
+int main(int argc, char** argv) {
+ string input_model;
+ string output_registration;
+ std::vector<tensorflow::Flag> flag_list = {
+ Flag("input_model", &input_model, "path to the tflite model"),
+ Flag("output_registration", &output_registration,
+ "filename for generated registration code"),
+ };
+ Flags::Parse(&argc, argv, flag_list);
+
+ tensorflow::port::InitMain(argv[0], &argc, &argv);
+ std::vector<string> builtin_ops;
+ std::vector<string> custom_ops;
+
+ std::ifstream fin(input_model);
+ std::stringstream content;
+ content << fin.rdbuf();
+ const ::tflite::Model* model = ::tflite::GetModel(content.str().data());
+ ::tflite::ReadOpsFromModel(model, &builtin_ops, &custom_ops);
+ GenerateFileContent(output_registration, builtin_ops, custom_ops);
+ return 0;
+}
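As an illustration of GenerateFileContent's output, for a hypothetical model containing the CONV_2D builtin and a custom op named "TestingOp", the generated registration file would look roughly like this (modulo line wrapping):

// Hypothetical invocation (paths are illustrative):
//   generate_op_registrations --input_model=model.tflite \
//       --output_registration=registration.cc
//
// Approximate contents of the emitted registration.cc:
#include "third_party/tensorflow/contrib/lite/model.h"
#include "third_party/tensorflow/contrib/lite/tools/mutable_op_resolver.h"
namespace tflite {
namespace ops {
namespace builtin {
// Forward-declarations for the builtin ops.
TfLiteRegistration* Register_CONV_2D();
}  // namespace builtin
namespace custom {
// Forward-declarations for the custom ops.
TfLiteRegistration* Register_TESTING_OP();
}  // namespace custom
}  // namespace ops
}  // namespace tflite
void RegisterSelectedOps(::tflite::MutableOpResolver* resolver) {
  resolver->AddBuiltin(::tflite::BuiltinOperator_CONV_2D,
                       ::tflite::ops::builtin::Register_CONV_2D());
  resolver->AddCustom("TestingOp",
                      ::tflite::ops::custom::Register_TESTING_OP());
}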
diff --git a/tensorflow/contrib/lite/tools/gen_op_registration_test.cc b/tensorflow/contrib/lite/tools/gen_op_registration_test.cc
new file mode 100644
index 0000000000..c65cffe340
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/gen_op_registration_test.cc
@@ -0,0 +1,87 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/tools/gen_op_registration.h"
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+using ::testing::ElementsAreArray;
+
+namespace tflite {
+
+class GenOpRegistrationTest : public ::testing::Test {
+ protected:
+ GenOpRegistrationTest() {}
+
+ void ReadOps(const string& model_path) {
+ auto model = FlatBufferModel::BuildFromFile(model_path.data());
+ if (model) {
+ ReadOpsFromModel(model->GetModel(), &builtin_ops_, &custom_ops_);
+ }
+ }
+
+ std::vector<string> builtin_ops_;
+ std::vector<string> custom_ops_;
+};
+
+TEST_F(GenOpRegistrationTest, TestNonExistentFiles) {
+ ReadOps("/tmp/tflite_model_1234");
+ EXPECT_EQ(builtin_ops_.size(), 0);
+ EXPECT_EQ(custom_ops_.size(), 0);
+}
+
+TEST_F(GenOpRegistrationTest, TestModels) {
+ ReadOps("third_party/tensorflow/contrib/lite/testdata/test_model.bin");
+ EXPECT_THAT(builtin_ops_, ElementsAreArray({"CONV_2D"}));
+ EXPECT_THAT(custom_ops_, ElementsAreArray({"testing_op"}));
+}
+
+TEST_F(GenOpRegistrationTest, TestEmptyModels) {
+ ReadOps("third_party/tensorflow/contrib/lite/testdata/empty_model.bin");
+ EXPECT_EQ(builtin_ops_.size(), 0);
+ EXPECT_EQ(custom_ops_.size(), 0);
+}
+
+TEST_F(GenOpRegistrationTest, TestZeroSubgraphs) {
+ ReadOps("third_party/tensorflow/contrib/lite/testdata/0_subgraphs.bin");
+ EXPECT_EQ(builtin_ops_.size(), 0);
+ EXPECT_EQ(custom_ops_.size(), 0);
+}
+
+TEST_F(GenOpRegistrationTest, TestBrokenMmap) {
+ ReadOps("third_party/tensorflow/contrib/lite/testdata/test_model_broken.bin");
+ EXPECT_EQ(builtin_ops_.size(), 0);
+ EXPECT_EQ(custom_ops_.size(), 0);
+}
+
+TEST_F(GenOpRegistrationTest, TestNormalizeCustomOpName) {
+ std::vector<std::pair<string, string>> testcase = {
+ {"CustomOp", "CUSTOM_OP"},
+ {"a", "A"},
+ {"custom_op", "CUSTOM_OP"},
+ {"customop", "CUSTOMOP"},
+ };
+
+ for (const auto& test : testcase) {
+ EXPECT_EQ(NormalizeCustomOpName(test.first), test.second);
+ }
+}
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ // On Linux, add: FLAGS_logtostderr = true;
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/tools/mutable_op_resolver.cc b/tensorflow/contrib/lite/tools/mutable_op_resolver.cc
new file mode 100644
index 0000000000..8a921d7c5a
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/mutable_op_resolver.cc
@@ -0,0 +1,43 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/tools/mutable_op_resolver.h"
+
+namespace tflite {
+
+TfLiteRegistration* MutableOpResolver::FindOp(
+ tflite::BuiltinOperator op) const {
+ auto it = builtins_.find(op);
+ return it != builtins_.end() ? it->second : nullptr;
+}
+
+TfLiteRegistration* MutableOpResolver::FindOp(const char* op) const {
+ auto it = custom_ops_.find(op);
+ return it != custom_ops_.end() ? it->second : nullptr;
+}
+
+void MutableOpResolver::AddBuiltin(tflite::BuiltinOperator op,
+ TfLiteRegistration* registration) {
+ registration->builtin_code = op;
+ builtins_.insert(std::make_pair(op, registration));
+}
+
+void MutableOpResolver::AddCustom(const char* name,
+ TfLiteRegistration* registration) {
+ registration->builtin_code = BuiltinOperator_CUSTOM;
+ custom_ops_.insert(std::make_pair(std::string(name), registration));
+}
+
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/tools/mutable_op_resolver.h b/tensorflow/contrib/lite/tools/mutable_op_resolver.h
new file mode 100644
index 0000000000..9546c32427
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/mutable_op_resolver.h
@@ -0,0 +1,45 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOOLS_MUTABLE_OP_RESOLVER_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOOLS_MUTABLE_OP_RESOLVER_H_
+
+#include <unordered_map>
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+
+// An OpResolver that is mutable; it is also used by the op registration code
+// generated by gen_op_registration.
+// A typical usage:
+// MutableOpResolver resolver;
+// resolver.AddBuiltin(BuiltinOperator_ADD, Register_ADD());
+// resolver.AddCustom("CustomOp", Register_CUSTOM_OP());
+// InterpreterBuilder(model, resolver)(&interpreter);
+class MutableOpResolver : public OpResolver {
+ public:
+ MutableOpResolver() {}
+ TfLiteRegistration* FindOp(tflite::BuiltinOperator op) const override;
+ TfLiteRegistration* FindOp(const char* op) const override;
+ void AddBuiltin(tflite::BuiltinOperator op, TfLiteRegistration* registration);
+ void AddCustom(const char* name, TfLiteRegistration* registration);
+
+ private:
+ std::unordered_map<tflite::BuiltinOperator, TfLiteRegistration*> builtins_;
+ std::unordered_map<std::string, TfLiteRegistration*> custom_ops_;
+};
+
+} // namespace tflite
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOOLS_MUTABLE_OP_RESOLVER_H_
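A hedged sketch of how an application might pair this resolver with the RegisterSelectedOps function emitted by the generator above; the model path, the function name BuildSelectedOpsInterpreter, and the exact InterpreterBuilder call are assumptions for illustration, not part of the patch.

#include <memory>

#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/model.h"
#include "tensorflow/contrib/lite/tools/mutable_op_resolver.h"

// Defined in the file produced by generate_op_registrations.
void RegisterSelectedOps(::tflite::MutableOpResolver* resolver);

std::unique_ptr<tflite::Interpreter> BuildSelectedOpsInterpreter(
    const char* model_path) {
  auto model = tflite::FlatBufferModel::BuildFromFile(model_path);
  if (!model) return nullptr;

  tflite::MutableOpResolver resolver;
  RegisterSelectedOps(&resolver);  // Registers only the ops the model needs.

  std::unique_ptr<tflite::Interpreter> interpreter;
  tflite::InterpreterBuilder(*model, resolver)(&interpreter);
  return interpreter;
}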
diff --git a/tensorflow/contrib/lite/version.h b/tensorflow/contrib/lite/version.h
new file mode 100644
index 0000000000..a751afabe7
--- /dev/null
+++ b/tensorflow/contrib/lite/version.h
@@ -0,0 +1,23 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_VERSION_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_VERSION_H_
+
+// The version number of the schema. Ideally all changes will be backward
+// compatible. If that ever changes, version must remain the first entry in
+// the new tflite root so that readers can detect that the version is not 1.
+#define TFLITE_SCHEMA_VERSION (3)
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_VERSION_H_
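To show where the constant is consumed, a minimal sketch of a version check against a loaded flatbuffer model; it assumes the flatbuffer-generated Model type exposes a version() accessor, and it is illustrative rather than the library's own validation.

#include <cstdio>

#include "tensorflow/contrib/lite/model.h"
#include "tensorflow/contrib/lite/version.h"

// Illustrative check only.
bool IsSupportedSchemaVersion(const ::tflite::Model* model) {
  if (model == nullptr) return false;
  if (model->version() != TFLITE_SCHEMA_VERSION) {
    std::fprintf(stderr, "Model schema version %u, expected %d\n",
                 model->version(), TFLITE_SCHEMA_VERSION);
    return false;
  }
  return true;
}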