From 942545c98db7f29dbbd7b9ed765e1289656f93d5 Mon Sep 17 00:00:00 2001 From: Austin Anderson Date: Tue, 7 Nov 2017 18:39:17 -0800 Subject: Fix Bazel builds for the TF Lite demo app Adds a new remote repository for the mobilenet tflite models necessary for running the TF Lite demo app. PiperOrigin-RevId: 174946867 --- third_party/tflite_mobilenet.BUILD | 13 +++++++++++++ 1 file changed, 13 insertions(+) create mode 100644 third_party/tflite_mobilenet.BUILD (limited to 'third_party') diff --git a/third_party/tflite_mobilenet.BUILD b/third_party/tflite_mobilenet.BUILD new file mode 100644 index 0000000000..75663eff48 --- /dev/null +++ b/third_party/tflite_mobilenet.BUILD @@ -0,0 +1,13 @@ +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Apache 2.0 + +filegroup( + name = "model_files", + srcs = glob( + ["**/*"], + exclude = [ + "BUILD", + ], + ), +) -- cgit v1.2.3 From 0b15439f8f0f2d4755587f4096c3ea04cb199d23 Mon Sep 17 00:00:00 2001 From: Andrew Selle Date: Fri, 10 Nov 2017 10:35:35 -0800 Subject: Internal Change. PiperOrigin-RevId: 175307445 --- configure.py | 2 + tensorflow/BUILD | 19 + tensorflow/contrib/BUILD | 1 + tensorflow/contrib/__init__.py | 1 + tensorflow/contrib/cmake/tf_python.cmake | 13 + tensorflow/contrib/lite/BUILD | 280 ++ tensorflow/contrib/lite/allocation.cc | 122 + tensorflow/contrib/lite/allocation.h | 94 + tensorflow/contrib/lite/build_def.bzl | 233 ++ tensorflow/contrib/lite/builtin_op_data.h | 164 + tensorflow/contrib/lite/context.c | 92 + tensorflow/contrib/lite/context.h | 298 ++ tensorflow/contrib/lite/context_test.cc | 74 + tensorflow/contrib/lite/error_reporter.cc | 50 + tensorflow/contrib/lite/error_reporter.h | 54 + tensorflow/contrib/lite/interpreter.cc | 567 +++ tensorflow/contrib/lite/interpreter.h | 376 ++ tensorflow/contrib/lite/interpreter_test.cc | 526 +++ tensorflow/contrib/lite/java/BUILD | 164 + tensorflow/contrib/lite/java/demo/.gitignore | 9 + tensorflow/contrib/lite/java/demo/app/build.gradle | 58 + .../java/demo/app/src/main/AndroidManifest.xml | 42 + .../contrib/lite/java/demo/app/src/main/BUILD | 43 + .../lite/java/demo/app/src/main/assets/BUILD | 26 + .../lite/java/demo/app/src/main/assets/labels.txt | 1001 ++++++ .../tflitecamerademo/AutoFitTextureView.java | 72 + .../tflitecamerademo/Camera2BasicFragment.java | 708 ++++ .../android/tflitecamerademo/CameraActivity.java | 35 + .../android/tflitecamerademo/ImageClassifier.java | 184 + .../src/main/res/drawable-hdpi/ic_action_info.png | Bin 0 -> 490 bytes .../app/src/main/res/drawable-hdpi/ic_launcher.png | Bin 0 -> 3136 bytes .../demo/app/src/main/res/drawable-hdpi/tile.9.png | Bin 0 -> 116 bytes .../src/main/res/drawable-mdpi/ic_action_info.png | Bin 0 -> 320 bytes .../app/src/main/res/drawable-mdpi/ic_launcher.png | Bin 0 -> 1915 bytes .../src/main/res/drawable-xhdpi/ic_action_info.png | Bin 0 -> 611 bytes .../src/main/res/drawable-xhdpi/ic_launcher.png | Bin 0 -> 4294 bytes .../main/res/drawable-xxhdpi/ic_action_info.png | Bin 0 -> 952 bytes .../src/main/res/drawable-xxhdpi/ic_launcher.png | Bin 0 -> 7279 bytes .../res/layout-land/fragment_camera2_basic.xml | 50 + .../app/src/main/res/layout/activity_camera.xml | 22 + .../src/main/res/layout/fragment_camera2_basic.xml | 45 + .../main/res/values-sw600dp/template-dimens.xml | 24 + .../main/res/values-sw600dp/template-styles.xml | 25 + .../src/main/res/values-v11/template-styles.xml | 22 + .../app/src/main/res/values-v21/base-colors.xml | 21 + .../main/res/values-v21/base-template-styles.xml | 24 + 
.../demo/app/src/main/res/values/base-strings.xml | 30 + .../java/demo/app/src/main/res/values/colors.xml | 19 + .../java/demo/app/src/main/res/values/strings.xml | 24 + .../java/demo/app/src/main/res/values/styles.xml | 18 + .../app/src/main/res/values/template-dimens.xml | 32 + .../app/src/main/res/values/template-styles.xml | 42 + tensorflow/contrib/lite/java/demo/build.gradle | 23 + .../contrib/lite/java/demo/gradle.properties | 17 + .../java/demo/gradle/wrapper/gradle-wrapper.jar | Bin 0 -> 53636 bytes .../demo/gradle/wrapper/gradle-wrapper.properties | 6 + tensorflow/contrib/lite/java/demo/gradlew | 160 + tensorflow/contrib/lite/java/demo/gradlew.bat | 90 + tensorflow/contrib/lite/java/demo/settings.gradle | 1 + .../main/java/org/tensorflow/lite/DataType.java | 76 + .../main/java/org/tensorflow/lite/Interpreter.java | 172 + .../tensorflow/lite/NativeInterpreterWrapper.java | 276 ++ .../src/main/java/org/tensorflow/lite/Tensor.java | 71 + .../java/org/tensorflow/lite/TensorFlowLite.java | 44 + .../java/org/tensorflow/lite/package-info.java | 17 + tensorflow/contrib/lite/java/src/main/native/BUILD | 70 + .../lite/java/src/main/native/builtin_ops_jni.cc | 29 + .../lite/java/src/main/native/exception_jni.cc | 66 + .../lite/java/src/main/native/exception_jni.h | 50 + .../main/native/nativeinterpreterwrapper_jni.cc | 446 +++ .../src/main/native/nativeinterpreterwrapper_jni.h | 151 + .../lite/java/src/main/native/tensor_jni.cc | 242 ++ .../contrib/lite/java/src/main/native/tensor_jni.h | 74 + .../java/src/main/native/tensorflow_lite_jni.cc | 26 + .../java/src/main/native/tensorflow_lite_jni.h | 36 + .../lite/java/src/main/native/version_script.lds | 11 + .../java/org/tensorflow/lite/DataTypeTest.java | 34 + .../java/org/tensorflow/lite/InterpreterTest.java | 221 ++ .../lite/NativeInterpreterWrapperTest.java | 406 +++ .../org/tensorflow/lite/TensorFlowLiteTest.java | 32 + .../test/java/org/tensorflow/lite/TensorTest.java | 105 + .../src/testhelper/java/org/tensorflow/lite/BUILD | 30 + .../java/org/tensorflow/lite/TestHelper.java | 35 + tensorflow/contrib/lite/kernels/BUILD | 408 +++ .../contrib/lite/kernels/activation_functor.h | 58 + tensorflow/contrib/lite/kernels/activations.cc | 389 ++ .../contrib/lite/kernels/activations_test.cc | 323 ++ tensorflow/contrib/lite/kernels/add.cc | 184 + tensorflow/contrib/lite/kernels/add_test.cc | 171 + tensorflow/contrib/lite/kernels/basic_rnn.cc | 161 + tensorflow/contrib/lite/kernels/basic_rnn_test.cc | 267 ++ tensorflow/contrib/lite/kernels/concatenation.cc | 200 ++ .../contrib/lite/kernels/concatenation_test.cc | 162 + tensorflow/contrib/lite/kernels/conv.cc | 425 +++ tensorflow/contrib/lite/kernels/conv_test.cc | 440 +++ tensorflow/contrib/lite/kernels/depthwise_conv.cc | 289 ++ .../contrib/lite/kernels/depthwise_conv_test.cc | 186 + .../contrib/lite/kernels/embedding_lookup.cc | 104 + .../lite/kernels/embedding_lookup_sparse.cc | 248 ++ .../lite/kernels/embedding_lookup_sparse_test.cc | 166 + .../contrib/lite/kernels/embedding_lookup_test.cc | 94 + tensorflow/contrib/lite/kernels/fully_connected.cc | 307 ++ .../contrib/lite/kernels/fully_connected_test.cc | 377 ++ tensorflow/contrib/lite/kernels/gemm_support.cc | 68 + tensorflow/contrib/lite/kernels/gemm_support.h | 54 + .../contrib/lite/kernels/hashtable_lookup.cc | 155 + .../contrib/lite/kernels/hashtable_lookup_test.cc | 176 + tensorflow/contrib/lite/kernels/internal/BUILD | 359 ++ tensorflow/contrib/lite/kernels/internal/common.h | 107 + .../contrib/lite/kernels/internal/compatibility.h | 78 + 
.../lite/kernels/internal/optimized/cpu_check.h | 65 + .../internal/optimized/depthwiseconv_float.h | 987 ++++++ .../internal/optimized/depthwiseconv_uint8.h | 1916 ++++++++++ .../optimized/eigen_spatial_convolutions.h | 231 ++ .../eigen_tensor_reduced_instantiations_google.h | 143 + .../eigen_tensor_reduced_instantiations_oss.h | 167 + .../internal/optimized/multithreaded_conv.h | 195 + .../internal/optimized/neon_tensor_utils.cc | 337 ++ .../kernels/internal/optimized/neon_tensor_utils.h | 113 + .../kernels/internal/optimized/optimized_ops.h | 3715 ++++++++++++++++++++ .../kernels/internal/optimized/tensor_utils_impl.h | 138 + .../lite/kernels/internal/quantization_util.cc | 95 + .../lite/kernels/internal/quantization_util.h | 55 + .../kernels/internal/quantization_util_test.cc | 108 + .../internal/reference/depthwiseconv_float.h | 115 + .../internal/reference/depthwiseconv_uint8.h | 138 + .../internal/reference/portable_tensor_utils.cc | 165 + .../internal/reference/portable_tensor_utils.h | 189 + .../kernels/internal/reference/reference_ops.h | 2455 +++++++++++++ tensorflow/contrib/lite/kernels/internal/round.h | 39 + tensorflow/contrib/lite/kernels/internal/tensor.h | 87 + .../contrib/lite/kernels/internal/tensor_test.cc | 55 + .../contrib/lite/kernels/internal/tensor_utils.cc | 27 + .../contrib/lite/kernels/internal/tensor_utils.h | 116 + .../lite/kernels/internal/tensor_utils_test.cc | 192 + tensorflow/contrib/lite/kernels/internal/types.h | 81 + tensorflow/contrib/lite/kernels/kernel_util.cc | 87 + tensorflow/contrib/lite/kernels/kernel_util.h | 65 + tensorflow/contrib/lite/kernels/l2norm.cc | 112 + tensorflow/contrib/lite/kernels/l2norm_test.cc | 63 + .../contrib/lite/kernels/local_response_norm.cc | 109 + .../lite/kernels/local_response_norm_test.cc | 101 + tensorflow/contrib/lite/kernels/lsh_projection.cc | 204 ++ .../contrib/lite/kernels/lsh_projection_test.cc | 123 + tensorflow/contrib/lite/kernels/lstm.cc | 515 +++ tensorflow/contrib/lite/kernels/lstm_test.cc | 1088 ++++++ tensorflow/contrib/lite/kernels/mul.cc | 167 + tensorflow/contrib/lite/kernels/mul_test.cc | 127 + tensorflow/contrib/lite/kernels/op_macros.h | 32 + .../contrib/lite/kernels/optional_tensor_test.cc | 343 ++ tensorflow/contrib/lite/kernels/padding.h | 28 + tensorflow/contrib/lite/kernels/pooling.cc | 355 ++ tensorflow/contrib/lite/kernels/pooling_test.cc | 161 + tensorflow/contrib/lite/kernels/register.cc | 109 + tensorflow/contrib/lite/kernels/register.h | 50 + tensorflow/contrib/lite/kernels/reshape.cc | 91 + tensorflow/contrib/lite/kernels/reshape_test.cc | 90 + tensorflow/contrib/lite/kernels/resize_bilinear.cc | 129 + .../contrib/lite/kernels/resize_bilinear_test.cc | 117 + tensorflow/contrib/lite/kernels/skip_gram.cc | 160 + tensorflow/contrib/lite/kernels/skip_gram_test.cc | 257 ++ tensorflow/contrib/lite/kernels/softmax_test.cc | 143 + tensorflow/contrib/lite/kernels/space_to_depth.cc | 146 + .../contrib/lite/kernels/space_to_depth_test.cc | 102 + tensorflow/contrib/lite/kernels/svdf.cc | 224 ++ tensorflow/contrib/lite/kernels/svdf_test.cc | 312 ++ tensorflow/contrib/lite/kernels/test_util.cc | 183 + tensorflow/contrib/lite/kernels/test_util.h | 202 ++ tensorflow/contrib/lite/model.cc | 673 ++++ tensorflow/contrib/lite/model.h | 165 + tensorflow/contrib/lite/model_test.cc | 258 ++ tensorflow/contrib/lite/models/smartreply/BUILD | 15 + .../lite/models/smartreply/ops/extract_feature.cc | 119 + .../models/smartreply/ops/extract_feature_test.cc | 100 + .../lite/models/smartreply/ops/normalize.cc | 105 + 
.../lite/models/smartreply/ops/normalize_test.cc | 90 + .../contrib/lite/models/smartreply/ops/predict.cc | 174 + .../lite/models/smartreply/ops/predict_test.cc | 183 + .../contrib/lite/models/smartreply/predictor.cc | 116 + .../contrib/lite/models/smartreply/predictor.h | 80 + .../lite/models/smartreply/predictor_test.cc | 150 + .../lite/models/speech_hotword_model_test.cc | 115 + .../lite/models/speech_speakerid_model_test.cc | 114 + .../lite/models/speech_terse_am_model_test.cc | 127 + .../contrib/lite/models/speech_tts_model_test.cc | 116 + tensorflow/contrib/lite/models/test_utils.h | 84 + tensorflow/contrib/lite/nnapi/BUILD | 25 + tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h | 1916 ++++++++++ tensorflow/contrib/lite/nnapi_delegate.cc | 386 ++ tensorflow/contrib/lite/nnapi_delegate.h | 66 + tensorflow/contrib/lite/optional_debug_tools.cc | 108 + tensorflow/contrib/lite/optional_debug_tools.h | 32 + tensorflow/contrib/lite/python/BUILD | 46 + tensorflow/contrib/lite/python/lite.py | 199 ++ tensorflow/contrib/lite/python/lite_test.py | 45 + tensorflow/contrib/lite/schema/BUILD | 82 + .../lite/schema/flatbuffer_compatibility_test.cc | 91 + tensorflow/contrib/lite/schema/schema.fbs | 346 ++ tensorflow/contrib/lite/schema/schema_v0.fbs | 247 ++ tensorflow/contrib/lite/schema/schema_v1.fbs | 295 ++ tensorflow/contrib/lite/schema/schema_v2.fbs | 303 ++ tensorflow/contrib/lite/schema/schema_v3.fbs | 326 ++ tensorflow/contrib/lite/schema/upgrade_schema.py | 341 ++ .../contrib/lite/schema/upgrade_schema_test.py | 317 ++ tensorflow/contrib/lite/simple_memory_arena.cc | 136 + tensorflow/contrib/lite/simple_memory_arena.h | 84 + .../contrib/lite/simple_memory_arena_test.cc | 91 + tensorflow/contrib/lite/string.h | 30 + tensorflow/contrib/lite/string_util.cc | 117 + tensorflow/contrib/lite/string_util.h | 91 + tensorflow/contrib/lite/string_util_test.cc | 117 + tensorflow/contrib/lite/testdata/0_subgraphs.bin | Bin 0 -> 80 bytes tensorflow/contrib/lite/testdata/2_subgraphs.bin | Bin 0 -> 172 bytes tensorflow/contrib/lite/testdata/empty_model.bin | Bin 0 -> 132 bytes tensorflow/contrib/lite/testdata/multi_add.bin | Bin 0 -> 652 bytes tensorflow/contrib/lite/testdata/multi_add.json | 46 + tensorflow/contrib/lite/testdata/no_subgraphs.bin | Bin 0 -> 80 bytes tensorflow/contrib/lite/testdata/test_model.bin | Bin 0 -> 496 bytes .../contrib/lite/testdata/test_model_broken.bin | Bin 0 -> 432 bytes .../contrib/lite/testdata/test_model_broken.json | 62 + tensorflow/contrib/lite/testdata/two_subgraphs.bin | Bin 0 -> 172 bytes tensorflow/contrib/lite/testing/BUILD | 213 ++ .../contrib/lite/testing/generate_examples.py | 1189 +++++++ .../lite/testing/generate_examples_report.py | 125 + .../lite/testing/generated_examples_zip_test.cc | 279 ++ tensorflow/contrib/lite/testing/message.cc | 96 + tensorflow/contrib/lite/testing/message.h | 82 + tensorflow/contrib/lite/testing/message_test.cc | 121 + tensorflow/contrib/lite/testing/nnapi_example.cc | 114 + tensorflow/contrib/lite/testing/parse_testdata.cc | 335 ++ tensorflow/contrib/lite/testing/parse_testdata.h | 74 + tensorflow/contrib/lite/testing/split.cc | 42 + tensorflow/contrib/lite/testing/split.h | 77 + tensorflow/contrib/lite/testing/split_test.cc | 57 + tensorflow/contrib/lite/testing/test_runner.h | 124 + .../contrib/lite/testing/test_runner_test.cc | 84 + tensorflow/contrib/lite/testing/tflite_driver.cc | 208 ++ tensorflow/contrib/lite/testing/tflite_driver.h | 62 + .../contrib/lite/testing/tflite_driver_test.cc | 61 + 
tensorflow/contrib/lite/testing/tokenize.cc | 95 + tensorflow/contrib/lite/testing/tokenize.h | 42 + tensorflow/contrib/lite/testing/tokenize_test.cc | 105 + tensorflow/contrib/lite/toco/BUILD | 350 ++ .../contrib/lite/toco/allocate_transient_arrays.cc | 318 ++ .../contrib/lite/toco/allocate_transient_arrays.h | 44 + tensorflow/contrib/lite/toco/args.h | 225 ++ tensorflow/contrib/lite/toco/dump_graphviz.cc | 293 ++ tensorflow/contrib/lite/toco/dump_graphviz.h | 28 + tensorflow/contrib/lite/toco/export_tensorflow.cc | 1570 +++++++++ tensorflow/contrib/lite/toco/export_tensorflow.h | 27 + tensorflow/contrib/lite/toco/format_port.h | 77 + .../convert_pure_conv_to_depthwise.cc | 98 + .../graph_transformations/create_im2col_arrays.cc | 69 + .../lite/toco/graph_transformations/dequantize.cc | 223 ++ .../toco/graph_transformations/drop_fake_quant.cc | 56 + .../graph_transformations/drop_im2col_arrays.cc | 42 + .../graph_transformations/ensure_bias_vectors.cc | 57 + .../fuse_activation_functions.cc | 98 + .../fuse_binary_into_following_affine.cc | 300 ++ .../fuse_binary_into_preceding_affine.cc | 326 ++ .../graph_transformations/graph_transformations.cc | 108 + .../graph_transformations/graph_transformations.h | 186 + .../toco/graph_transformations/hardcode_min_max.cc | 229 ++ .../identify_l2_normalization.cc | 170 + .../toco/graph_transformations/identify_l2_pool.cc | 106 + .../toco/graph_transformations/identify_lstm.cc | 396 +++ .../toco/graph_transformations/identify_relu1.cc | 103 + .../make_initial_dequantize_operator.cc | 120 + .../propagate_array_data_types.cc | 142 + .../graph_transformations/propagate_fixed_sizes.cc | 1129 ++++++ .../lite/toco/graph_transformations/quantize.cc | 467 +++ .../read_fake_quant_min_max.cc | 105 + .../remove_final_dequantize_op.cc | 59 + .../remove_tensorflow_assert.cc | 60 + .../remove_tensorflow_identity.cc | 38 + .../graph_transformations/remove_trivial_binary.cc | 113 + .../remove_trivial_concatenation.cc | 40 + .../remove_trivial_concatenation_input.cc | 68 + .../remove_trivial_passthrough.cc | 107 + .../remove_trivial_passthrough.h | 55 + .../remove_trivial_quantized_activation_func.cc | 87 + .../remove_trivial_reshape.cc | 92 + .../toco/graph_transformations/remove_unused_op.cc | 122 + .../resolve_batch_normalization.cc | 135 + .../resolve_constant_binary.cc | 247 ++ .../resolve_constant_concatenation.cc | 196 ++ .../resolve_constant_fake_quant.cc | 76 + .../resolve_constant_tensorflow_shape.cc | 62 + .../resolve_constant_unary.cc | 175 + .../resolve_mean_attributes.cc | 51 + .../resolve_pad_attributes.cc | 55 + .../graph_transformations/resolve_reorder_axes.cc | 93 + .../resolve_reshape_attributes.cc | 49 + .../resolve_slice_attributes.cc | 52 + .../resolve_strided_slice_attributes.cc | 62 + .../resolve_tensorflow_concat.cc | 86 + .../resolve_tensorflow_matmul.cc | 106 + .../resolve_tensorflow_merge.cc | 63 + .../resolve_tensorflow_squeeze.cc | 54 + .../resolve_tensorflow_switch.cc | 123 + .../resolve_tensorflow_tile.cc | 97 + .../lite/toco/graph_transformations/tests/BUILD | 31 + .../tests/resolve_constant_concatenation_test.cc | 221 ++ .../unfuse_activation_functions.cc | 73 + tensorflow/contrib/lite/toco/import_tensorflow.cc | 1508 ++++++++ tensorflow/contrib/lite/toco/import_tensorflow.h | 34 + tensorflow/contrib/lite/toco/model.h | 1372 ++++++++ .../contrib/lite/toco/model_cmdline_flags.cc | 374 ++ tensorflow/contrib/lite/toco/model_cmdline_flags.h | 43 + tensorflow/contrib/lite/toco/model_flags.proto | 119 + 
tensorflow/contrib/lite/toco/python/BUILD | 76 + tensorflow/contrib/lite/toco/python/toco.i | 32 + .../contrib/lite/toco/python/toco_from_protos.py | 63 + .../lite/toco/python/toco_from_protos_test.py | 96 + .../contrib/lite/toco/python/toco_python_api.cc | 85 + .../contrib/lite/toco/python/toco_python_api.h | 33 + .../contrib/lite/toco/python/toco_wrapper.py | 35 + tensorflow/contrib/lite/toco/runtime/common.h | 26 + tensorflow/contrib/lite/toco/runtime/types.h | 32 + .../lite/toco/tensorflow_graph_matching/BUILD | 102 + .../lite/toco/tensorflow_graph_matching/cluster.cc | 52 + .../lite/toco/tensorflow_graph_matching/cluster.h | 101 + .../tensorflow_graph_matching/cluster_utils.cc | 34 + .../toco/tensorflow_graph_matching/cluster_utils.h | 33 + .../tensorflow_graph_matching/resolve_cluster.cc | 151 + .../tensorflow_graph_matching/resolve_cluster.h | 63 + .../toco/tensorflow_graph_matching/resolve_svdf.cc | 285 ++ .../toco/tensorflow_graph_matching/resolve_svdf.h | 82 + .../tensorflow_graph_matching/resolve_svdf_test.cc | 212 ++ tensorflow/contrib/lite/toco/tensorflow_util.cc | 197 ++ tensorflow/contrib/lite/toco/tensorflow_util.h | 32 + tensorflow/contrib/lite/toco/tflite/BUILD | 142 + .../contrib/lite/toco/tflite/builtin_operator.h | 74 + .../contrib/lite/toco/tflite/custom_operator.h | 74 + tensorflow/contrib/lite/toco/tflite/export.cc | 322 ++ tensorflow/contrib/lite/toco/tflite/export.h | 76 + tensorflow/contrib/lite/toco/tflite/export_test.cc | 69 + tensorflow/contrib/lite/toco/tflite/import.cc | 183 + tensorflow/contrib/lite/toco/tflite/import.h | 49 + tensorflow/contrib/lite/toco/tflite/import_test.cc | 141 + tensorflow/contrib/lite/toco/tflite/operator.cc | 627 ++++ tensorflow/contrib/lite/toco/tflite/operator.h | 89 + .../contrib/lite/toco/tflite/operator_test.cc | 370 ++ .../contrib/lite/toco/tflite/simple_operator.h | 50 + tensorflow/contrib/lite/toco/tflite/types.cc | 165 + tensorflow/contrib/lite/toco/tflite/types.h | 58 + tensorflow/contrib/lite/toco/tflite/types_test.cc | 191 + tensorflow/contrib/lite/toco/toco.cc | 119 + tensorflow/contrib/lite/toco/toco_cmdline_flags.cc | 206 ++ tensorflow/contrib/lite/toco/toco_cmdline_flags.h | 35 + tensorflow/contrib/lite/toco/toco_flags.proto | 126 + .../lite/toco/toco_graphviz_dump_options.cc | 22 + .../contrib/lite/toco/toco_graphviz_dump_options.h | 34 + tensorflow/contrib/lite/toco/toco_port.cc | 227 ++ tensorflow/contrib/lite/toco/toco_port.h | 80 + tensorflow/contrib/lite/toco/toco_port_test.cc | 58 + tensorflow/contrib/lite/toco/toco_tooling.cc | 277 ++ tensorflow/contrib/lite/toco/toco_tooling.h | 50 + tensorflow/contrib/lite/toco/toco_types.h | 45 + tensorflow/contrib/lite/toco/tooling_util.cc | 1552 ++++++++ tensorflow/contrib/lite/toco/tooling_util.h | 292 ++ tensorflow/contrib/lite/toco/tooling_util_test.cc | 96 + tensorflow/contrib/lite/tools/BUILD | 60 + .../contrib/lite/tools/gen_op_registration.cc | 46 + .../contrib/lite/tools/gen_op_registration.h | 38 + .../contrib/lite/tools/gen_op_registration_main.cc | 98 + .../contrib/lite/tools/gen_op_registration_test.cc | 87 + .../contrib/lite/tools/mutable_op_resolver.cc | 43 + .../contrib/lite/tools/mutable_op_resolver.h | 45 + tensorflow/contrib/lite/version.h | 23 + tensorflow/tools/ci_build/ci_sanity.sh | 7 +- .../tools/ci_build/linux/cpu/run_py3_contrib.sh | 33 +- tensorflow/tools/ci_build/osx/cpu/run_contrib.sh | 2 +- tensorflow/tools/pip_package/BUILD | 3 + tensorflow/tools/pip_package/MANIFEST.in | 1 + tensorflow/tools/pip_package/build_pip_package.sh | 3 + 
tensorflow/tools/pip_package/setup.py | 3 +- third_party/flatbuffers/flatbuffers.BUILD | 4 + 378 files changed, 66985 insertions(+), 4 deletions(-) create mode 100644 tensorflow/contrib/lite/BUILD create mode 100644 tensorflow/contrib/lite/allocation.cc create mode 100644 tensorflow/contrib/lite/allocation.h create mode 100644 tensorflow/contrib/lite/build_def.bzl create mode 100644 tensorflow/contrib/lite/builtin_op_data.h create mode 100644 tensorflow/contrib/lite/context.c create mode 100644 tensorflow/contrib/lite/context.h create mode 100644 tensorflow/contrib/lite/context_test.cc create mode 100644 tensorflow/contrib/lite/error_reporter.cc create mode 100644 tensorflow/contrib/lite/error_reporter.h create mode 100644 tensorflow/contrib/lite/interpreter.cc create mode 100644 tensorflow/contrib/lite/interpreter.h create mode 100644 tensorflow/contrib/lite/interpreter_test.cc create mode 100644 tensorflow/contrib/lite/java/BUILD create mode 100644 tensorflow/contrib/lite/java/demo/.gitignore create mode 100644 tensorflow/contrib/lite/java/demo/app/build.gradle create mode 100644 tensorflow/contrib/lite/java/demo/app/src/main/AndroidManifest.xml create mode 100644 tensorflow/contrib/lite/java/demo/app/src/main/BUILD create mode 100644 tensorflow/contrib/lite/java/demo/app/src/main/assets/BUILD create mode 100644 tensorflow/contrib/lite/java/demo/app/src/main/assets/labels.txt create mode 100644 tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/AutoFitTextureView.java create mode 100644 tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/Camera2BasicFragment.java create mode 100644 tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/CameraActivity.java create mode 100644 tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifier.java create mode 100644 tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-hdpi/ic_action_info.png create mode 100644 tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-hdpi/ic_launcher.png create mode 100644 tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-hdpi/tile.9.png create mode 100644 tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-mdpi/ic_action_info.png create mode 100644 tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-mdpi/ic_launcher.png create mode 100644 tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xhdpi/ic_action_info.png create mode 100644 tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xhdpi/ic_launcher.png create mode 100644 tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xxhdpi/ic_action_info.png create mode 100644 tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xxhdpi/ic_launcher.png create mode 100644 tensorflow/contrib/lite/java/demo/app/src/main/res/layout-land/fragment_camera2_basic.xml create mode 100644 tensorflow/contrib/lite/java/demo/app/src/main/res/layout/activity_camera.xml create mode 100644 tensorflow/contrib/lite/java/demo/app/src/main/res/layout/fragment_camera2_basic.xml create mode 100644 tensorflow/contrib/lite/java/demo/app/src/main/res/values-sw600dp/template-dimens.xml create mode 100644 tensorflow/contrib/lite/java/demo/app/src/main/res/values-sw600dp/template-styles.xml create mode 100644 tensorflow/contrib/lite/java/demo/app/src/main/res/values-v11/template-styles.xml create mode 100644 tensorflow/contrib/lite/java/demo/app/src/main/res/values-v21/base-colors.xml 
create mode 100644 tensorflow/contrib/lite/java/demo/app/src/main/res/values-v21/base-template-styles.xml create mode 100644 tensorflow/contrib/lite/java/demo/app/src/main/res/values/base-strings.xml create mode 100644 tensorflow/contrib/lite/java/demo/app/src/main/res/values/colors.xml create mode 100644 tensorflow/contrib/lite/java/demo/app/src/main/res/values/strings.xml create mode 100644 tensorflow/contrib/lite/java/demo/app/src/main/res/values/styles.xml create mode 100644 tensorflow/contrib/lite/java/demo/app/src/main/res/values/template-dimens.xml create mode 100644 tensorflow/contrib/lite/java/demo/app/src/main/res/values/template-styles.xml create mode 100644 tensorflow/contrib/lite/java/demo/build.gradle create mode 100644 tensorflow/contrib/lite/java/demo/gradle.properties create mode 100644 tensorflow/contrib/lite/java/demo/gradle/wrapper/gradle-wrapper.jar create mode 100644 tensorflow/contrib/lite/java/demo/gradle/wrapper/gradle-wrapper.properties create mode 100755 tensorflow/contrib/lite/java/demo/gradlew create mode 100644 tensorflow/contrib/lite/java/demo/gradlew.bat create mode 100644 tensorflow/contrib/lite/java/demo/settings.gradle create mode 100644 tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/DataType.java create mode 100644 tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java create mode 100644 tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java create mode 100644 tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Tensor.java create mode 100644 tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/TensorFlowLite.java create mode 100644 tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/package-info.java create mode 100644 tensorflow/contrib/lite/java/src/main/native/BUILD create mode 100644 tensorflow/contrib/lite/java/src/main/native/builtin_ops_jni.cc create mode 100644 tensorflow/contrib/lite/java/src/main/native/exception_jni.cc create mode 100644 tensorflow/contrib/lite/java/src/main/native/exception_jni.h create mode 100644 tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc create mode 100644 tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h create mode 100644 tensorflow/contrib/lite/java/src/main/native/tensor_jni.cc create mode 100644 tensorflow/contrib/lite/java/src/main/native/tensor_jni.h create mode 100644 tensorflow/contrib/lite/java/src/main/native/tensorflow_lite_jni.cc create mode 100644 tensorflow/contrib/lite/java/src/main/native/tensorflow_lite_jni.h create mode 100644 tensorflow/contrib/lite/java/src/main/native/version_script.lds create mode 100644 tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/DataTypeTest.java create mode 100644 tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java create mode 100644 tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java create mode 100644 tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorFlowLiteTest.java create mode 100644 tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java create mode 100644 tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/BUILD create mode 100644 tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/TestHelper.java create mode 100644 tensorflow/contrib/lite/kernels/BUILD create mode 100644 tensorflow/contrib/lite/kernels/activation_functor.h create mode 
100644 tensorflow/contrib/lite/kernels/activations.cc create mode 100644 tensorflow/contrib/lite/kernels/activations_test.cc create mode 100644 tensorflow/contrib/lite/kernels/add.cc create mode 100644 tensorflow/contrib/lite/kernels/add_test.cc create mode 100644 tensorflow/contrib/lite/kernels/basic_rnn.cc create mode 100644 tensorflow/contrib/lite/kernels/basic_rnn_test.cc create mode 100644 tensorflow/contrib/lite/kernels/concatenation.cc create mode 100644 tensorflow/contrib/lite/kernels/concatenation_test.cc create mode 100644 tensorflow/contrib/lite/kernels/conv.cc create mode 100644 tensorflow/contrib/lite/kernels/conv_test.cc create mode 100644 tensorflow/contrib/lite/kernels/depthwise_conv.cc create mode 100644 tensorflow/contrib/lite/kernels/depthwise_conv_test.cc create mode 100644 tensorflow/contrib/lite/kernels/embedding_lookup.cc create mode 100644 tensorflow/contrib/lite/kernels/embedding_lookup_sparse.cc create mode 100644 tensorflow/contrib/lite/kernels/embedding_lookup_sparse_test.cc create mode 100644 tensorflow/contrib/lite/kernels/embedding_lookup_test.cc create mode 100644 tensorflow/contrib/lite/kernels/fully_connected.cc create mode 100644 tensorflow/contrib/lite/kernels/fully_connected_test.cc create mode 100644 tensorflow/contrib/lite/kernels/gemm_support.cc create mode 100644 tensorflow/contrib/lite/kernels/gemm_support.h create mode 100644 tensorflow/contrib/lite/kernels/hashtable_lookup.cc create mode 100644 tensorflow/contrib/lite/kernels/hashtable_lookup_test.cc create mode 100644 tensorflow/contrib/lite/kernels/internal/BUILD create mode 100644 tensorflow/contrib/lite/kernels/internal/common.h create mode 100644 tensorflow/contrib/lite/kernels/internal/compatibility.h create mode 100644 tensorflow/contrib/lite/kernels/internal/optimized/cpu_check.h create mode 100644 tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h create mode 100644 tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h create mode 100644 tensorflow/contrib/lite/kernels/internal/optimized/eigen_spatial_convolutions.h create mode 100644 tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_google.h create mode 100644 tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_oss.h create mode 100644 tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h create mode 100644 tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc create mode 100644 tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h create mode 100644 tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h create mode 100644 tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h create mode 100644 tensorflow/contrib/lite/kernels/internal/quantization_util.cc create mode 100644 tensorflow/contrib/lite/kernels/internal/quantization_util.h create mode 100644 tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc create mode 100644 tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h create mode 100644 tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h create mode 100644 tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc create mode 100644 tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h create mode 100644 tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h create mode 100644 tensorflow/contrib/lite/kernels/internal/round.h create 
mode 100644 tensorflow/contrib/lite/kernels/internal/tensor.h create mode 100644 tensorflow/contrib/lite/kernels/internal/tensor_test.cc create mode 100644 tensorflow/contrib/lite/kernels/internal/tensor_utils.cc create mode 100644 tensorflow/contrib/lite/kernels/internal/tensor_utils.h create mode 100644 tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc create mode 100644 tensorflow/contrib/lite/kernels/internal/types.h create mode 100644 tensorflow/contrib/lite/kernels/kernel_util.cc create mode 100644 tensorflow/contrib/lite/kernels/kernel_util.h create mode 100644 tensorflow/contrib/lite/kernels/l2norm.cc create mode 100644 tensorflow/contrib/lite/kernels/l2norm_test.cc create mode 100644 tensorflow/contrib/lite/kernels/local_response_norm.cc create mode 100644 tensorflow/contrib/lite/kernels/local_response_norm_test.cc create mode 100644 tensorflow/contrib/lite/kernels/lsh_projection.cc create mode 100644 tensorflow/contrib/lite/kernels/lsh_projection_test.cc create mode 100644 tensorflow/contrib/lite/kernels/lstm.cc create mode 100644 tensorflow/contrib/lite/kernels/lstm_test.cc create mode 100644 tensorflow/contrib/lite/kernels/mul.cc create mode 100644 tensorflow/contrib/lite/kernels/mul_test.cc create mode 100644 tensorflow/contrib/lite/kernels/op_macros.h create mode 100644 tensorflow/contrib/lite/kernels/optional_tensor_test.cc create mode 100644 tensorflow/contrib/lite/kernels/padding.h create mode 100644 tensorflow/contrib/lite/kernels/pooling.cc create mode 100644 tensorflow/contrib/lite/kernels/pooling_test.cc create mode 100644 tensorflow/contrib/lite/kernels/register.cc create mode 100644 tensorflow/contrib/lite/kernels/register.h create mode 100644 tensorflow/contrib/lite/kernels/reshape.cc create mode 100644 tensorflow/contrib/lite/kernels/reshape_test.cc create mode 100644 tensorflow/contrib/lite/kernels/resize_bilinear.cc create mode 100644 tensorflow/contrib/lite/kernels/resize_bilinear_test.cc create mode 100644 tensorflow/contrib/lite/kernels/skip_gram.cc create mode 100644 tensorflow/contrib/lite/kernels/skip_gram_test.cc create mode 100644 tensorflow/contrib/lite/kernels/softmax_test.cc create mode 100644 tensorflow/contrib/lite/kernels/space_to_depth.cc create mode 100644 tensorflow/contrib/lite/kernels/space_to_depth_test.cc create mode 100644 tensorflow/contrib/lite/kernels/svdf.cc create mode 100644 tensorflow/contrib/lite/kernels/svdf_test.cc create mode 100644 tensorflow/contrib/lite/kernels/test_util.cc create mode 100644 tensorflow/contrib/lite/kernels/test_util.h create mode 100644 tensorflow/contrib/lite/model.cc create mode 100644 tensorflow/contrib/lite/model.h create mode 100644 tensorflow/contrib/lite/model_test.cc create mode 100644 tensorflow/contrib/lite/models/smartreply/BUILD create mode 100644 tensorflow/contrib/lite/models/smartreply/ops/extract_feature.cc create mode 100644 tensorflow/contrib/lite/models/smartreply/ops/extract_feature_test.cc create mode 100644 tensorflow/contrib/lite/models/smartreply/ops/normalize.cc create mode 100644 tensorflow/contrib/lite/models/smartreply/ops/normalize_test.cc create mode 100644 tensorflow/contrib/lite/models/smartreply/ops/predict.cc create mode 100644 tensorflow/contrib/lite/models/smartreply/ops/predict_test.cc create mode 100644 tensorflow/contrib/lite/models/smartreply/predictor.cc create mode 100644 tensorflow/contrib/lite/models/smartreply/predictor.h create mode 100644 tensorflow/contrib/lite/models/smartreply/predictor_test.cc create mode 100644 
tensorflow/contrib/lite/models/speech_hotword_model_test.cc create mode 100644 tensorflow/contrib/lite/models/speech_speakerid_model_test.cc create mode 100644 tensorflow/contrib/lite/models/speech_terse_am_model_test.cc create mode 100644 tensorflow/contrib/lite/models/speech_tts_model_test.cc create mode 100644 tensorflow/contrib/lite/models/test_utils.h create mode 100644 tensorflow/contrib/lite/nnapi/BUILD create mode 100644 tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h create mode 100644 tensorflow/contrib/lite/nnapi_delegate.cc create mode 100644 tensorflow/contrib/lite/nnapi_delegate.h create mode 100644 tensorflow/contrib/lite/optional_debug_tools.cc create mode 100644 tensorflow/contrib/lite/optional_debug_tools.h create mode 100644 tensorflow/contrib/lite/python/BUILD create mode 100644 tensorflow/contrib/lite/python/lite.py create mode 100644 tensorflow/contrib/lite/python/lite_test.py create mode 100644 tensorflow/contrib/lite/schema/BUILD create mode 100644 tensorflow/contrib/lite/schema/flatbuffer_compatibility_test.cc create mode 100644 tensorflow/contrib/lite/schema/schema.fbs create mode 100644 tensorflow/contrib/lite/schema/schema_v0.fbs create mode 100644 tensorflow/contrib/lite/schema/schema_v1.fbs create mode 100644 tensorflow/contrib/lite/schema/schema_v2.fbs create mode 100644 tensorflow/contrib/lite/schema/schema_v3.fbs create mode 100644 tensorflow/contrib/lite/schema/upgrade_schema.py create mode 100644 tensorflow/contrib/lite/schema/upgrade_schema_test.py create mode 100644 tensorflow/contrib/lite/simple_memory_arena.cc create mode 100644 tensorflow/contrib/lite/simple_memory_arena.h create mode 100644 tensorflow/contrib/lite/simple_memory_arena_test.cc create mode 100644 tensorflow/contrib/lite/string.h create mode 100644 tensorflow/contrib/lite/string_util.cc create mode 100644 tensorflow/contrib/lite/string_util.h create mode 100644 tensorflow/contrib/lite/string_util_test.cc create mode 100644 tensorflow/contrib/lite/testdata/0_subgraphs.bin create mode 100644 tensorflow/contrib/lite/testdata/2_subgraphs.bin create mode 100644 tensorflow/contrib/lite/testdata/empty_model.bin create mode 100644 tensorflow/contrib/lite/testdata/multi_add.bin create mode 100644 tensorflow/contrib/lite/testdata/multi_add.json create mode 100644 tensorflow/contrib/lite/testdata/no_subgraphs.bin create mode 100644 tensorflow/contrib/lite/testdata/test_model.bin create mode 100644 tensorflow/contrib/lite/testdata/test_model_broken.bin create mode 100644 tensorflow/contrib/lite/testdata/test_model_broken.json create mode 100644 tensorflow/contrib/lite/testdata/two_subgraphs.bin create mode 100644 tensorflow/contrib/lite/testing/BUILD create mode 100644 tensorflow/contrib/lite/testing/generate_examples.py create mode 100644 tensorflow/contrib/lite/testing/generate_examples_report.py create mode 100644 tensorflow/contrib/lite/testing/generated_examples_zip_test.cc create mode 100644 tensorflow/contrib/lite/testing/message.cc create mode 100644 tensorflow/contrib/lite/testing/message.h create mode 100644 tensorflow/contrib/lite/testing/message_test.cc create mode 100644 tensorflow/contrib/lite/testing/nnapi_example.cc create mode 100644 tensorflow/contrib/lite/testing/parse_testdata.cc create mode 100644 tensorflow/contrib/lite/testing/parse_testdata.h create mode 100644 tensorflow/contrib/lite/testing/split.cc create mode 100644 tensorflow/contrib/lite/testing/split.h create mode 100644 tensorflow/contrib/lite/testing/split_test.cc create mode 100644 
tensorflow/contrib/lite/testing/test_runner.h create mode 100644 tensorflow/contrib/lite/testing/test_runner_test.cc create mode 100644 tensorflow/contrib/lite/testing/tflite_driver.cc create mode 100644 tensorflow/contrib/lite/testing/tflite_driver.h create mode 100644 tensorflow/contrib/lite/testing/tflite_driver_test.cc create mode 100644 tensorflow/contrib/lite/testing/tokenize.cc create mode 100644 tensorflow/contrib/lite/testing/tokenize.h create mode 100644 tensorflow/contrib/lite/testing/tokenize_test.cc create mode 100644 tensorflow/contrib/lite/toco/BUILD create mode 100644 tensorflow/contrib/lite/toco/allocate_transient_arrays.cc create mode 100644 tensorflow/contrib/lite/toco/allocate_transient_arrays.h create mode 100644 tensorflow/contrib/lite/toco/args.h create mode 100644 tensorflow/contrib/lite/toco/dump_graphviz.cc create mode 100644 tensorflow/contrib/lite/toco/dump_graphviz.h create mode 100644 tensorflow/contrib/lite/toco/export_tensorflow.cc create mode 100644 tensorflow/contrib/lite/toco/export_tensorflow.h create mode 100644 tensorflow/contrib/lite/toco/format_port.h create mode 100644 tensorflow/contrib/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc create mode 100644 tensorflow/contrib/lite/toco/graph_transformations/create_im2col_arrays.cc create mode 100644 tensorflow/contrib/lite/toco/graph_transformations/dequantize.cc create mode 100644 tensorflow/contrib/lite/toco/graph_transformations/drop_fake_quant.cc create mode 100644 tensorflow/contrib/lite/toco/graph_transformations/drop_im2col_arrays.cc create mode 100644 tensorflow/contrib/lite/toco/graph_transformations/ensure_bias_vectors.cc create mode 100644 tensorflow/contrib/lite/toco/graph_transformations/fuse_activation_functions.cc create mode 100644 tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_following_affine.cc create mode 100644 tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc create mode 100644 tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.cc create mode 100644 tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h create mode 100644 tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc create mode 100644 tensorflow/contrib/lite/toco/graph_transformations/identify_l2_normalization.cc create mode 100644 tensorflow/contrib/lite/toco/graph_transformations/identify_l2_pool.cc create mode 100644 tensorflow/contrib/lite/toco/graph_transformations/identify_lstm.cc create mode 100644 tensorflow/contrib/lite/toco/graph_transformations/identify_relu1.cc create mode 100644 tensorflow/contrib/lite/toco/graph_transformations/make_initial_dequantize_operator.cc create mode 100644 tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc create mode 100644 tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc create mode 100644 tensorflow/contrib/lite/toco/graph_transformations/quantize.cc create mode 100644 tensorflow/contrib/lite/toco/graph_transformations/read_fake_quant_min_max.cc create mode 100644 tensorflow/contrib/lite/toco/graph_transformations/remove_final_dequantize_op.cc create mode 100644 tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_assert.cc create mode 100644 tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_identity.cc create mode 100644 tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_binary.cc create mode 100644 
tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_concatenation.cc create mode 100644 tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_concatenation_input.cc create mode 100644 tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.cc create mode 100644 tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h create mode 100644 tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_quantized_activation_func.cc create mode 100644 tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_reshape.cc create mode 100644 tensorflow/contrib/lite/toco/graph_transformations/remove_unused_op.cc create mode 100644 tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_normalization.cc create mode 100644 tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc create mode 100644 tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_concatenation.cc create mode 100644 tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc create mode 100644 tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_tensorflow_shape.cc create mode 100644 tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc create mode 100644 tensorflow/contrib/lite/toco/graph_transformations/resolve_mean_attributes.cc create mode 100644 tensorflow/contrib/lite/toco/graph_transformations/resolve_pad_attributes.cc create mode 100644 tensorflow/contrib/lite/toco/graph_transformations/resolve_reorder_axes.cc create mode 100644 tensorflow/contrib/lite/toco/graph_transformations/resolve_reshape_attributes.cc create mode 100644 tensorflow/contrib/lite/toco/graph_transformations/resolve_slice_attributes.cc create mode 100644 tensorflow/contrib/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc create mode 100644 tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_concat.cc create mode 100644 tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc create mode 100644 tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_merge.cc create mode 100644 tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_squeeze.cc create mode 100644 tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_switch.cc create mode 100644 tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_tile.cc create mode 100644 tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD create mode 100644 tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_concatenation_test.cc create mode 100644 tensorflow/contrib/lite/toco/graph_transformations/unfuse_activation_functions.cc create mode 100644 tensorflow/contrib/lite/toco/import_tensorflow.cc create mode 100644 tensorflow/contrib/lite/toco/import_tensorflow.h create mode 100644 tensorflow/contrib/lite/toco/model.h create mode 100644 tensorflow/contrib/lite/toco/model_cmdline_flags.cc create mode 100644 tensorflow/contrib/lite/toco/model_cmdline_flags.h create mode 100644 tensorflow/contrib/lite/toco/model_flags.proto create mode 100644 tensorflow/contrib/lite/toco/python/BUILD create mode 100644 tensorflow/contrib/lite/toco/python/toco.i create mode 100644 tensorflow/contrib/lite/toco/python/toco_from_protos.py create mode 100644 tensorflow/contrib/lite/toco/python/toco_from_protos_test.py create mode 100644 tensorflow/contrib/lite/toco/python/toco_python_api.cc create mode 100644 
tensorflow/contrib/lite/toco/python/toco_python_api.h create mode 100644 tensorflow/contrib/lite/toco/python/toco_wrapper.py create mode 100644 tensorflow/contrib/lite/toco/runtime/common.h create mode 100644 tensorflow/contrib/lite/toco/runtime/types.h create mode 100644 tensorflow/contrib/lite/toco/tensorflow_graph_matching/BUILD create mode 100644 tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster.cc create mode 100644 tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster.h create mode 100644 tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster_utils.cc create mode 100644 tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster_utils.h create mode 100644 tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_cluster.cc create mode 100644 tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_cluster.h create mode 100644 tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_svdf.cc create mode 100644 tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_svdf.h create mode 100644 tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_svdf_test.cc create mode 100644 tensorflow/contrib/lite/toco/tensorflow_util.cc create mode 100644 tensorflow/contrib/lite/toco/tensorflow_util.h create mode 100644 tensorflow/contrib/lite/toco/tflite/BUILD create mode 100644 tensorflow/contrib/lite/toco/tflite/builtin_operator.h create mode 100644 tensorflow/contrib/lite/toco/tflite/custom_operator.h create mode 100644 tensorflow/contrib/lite/toco/tflite/export.cc create mode 100644 tensorflow/contrib/lite/toco/tflite/export.h create mode 100644 tensorflow/contrib/lite/toco/tflite/export_test.cc create mode 100644 tensorflow/contrib/lite/toco/tflite/import.cc create mode 100644 tensorflow/contrib/lite/toco/tflite/import.h create mode 100644 tensorflow/contrib/lite/toco/tflite/import_test.cc create mode 100644 tensorflow/contrib/lite/toco/tflite/operator.cc create mode 100644 tensorflow/contrib/lite/toco/tflite/operator.h create mode 100644 tensorflow/contrib/lite/toco/tflite/operator_test.cc create mode 100644 tensorflow/contrib/lite/toco/tflite/simple_operator.h create mode 100644 tensorflow/contrib/lite/toco/tflite/types.cc create mode 100644 tensorflow/contrib/lite/toco/tflite/types.h create mode 100644 tensorflow/contrib/lite/toco/tflite/types_test.cc create mode 100644 tensorflow/contrib/lite/toco/toco.cc create mode 100644 tensorflow/contrib/lite/toco/toco_cmdline_flags.cc create mode 100644 tensorflow/contrib/lite/toco/toco_cmdline_flags.h create mode 100644 tensorflow/contrib/lite/toco/toco_flags.proto create mode 100644 tensorflow/contrib/lite/toco/toco_graphviz_dump_options.cc create mode 100644 tensorflow/contrib/lite/toco/toco_graphviz_dump_options.h create mode 100644 tensorflow/contrib/lite/toco/toco_port.cc create mode 100644 tensorflow/contrib/lite/toco/toco_port.h create mode 100644 tensorflow/contrib/lite/toco/toco_port_test.cc create mode 100644 tensorflow/contrib/lite/toco/toco_tooling.cc create mode 100644 tensorflow/contrib/lite/toco/toco_tooling.h create mode 100644 tensorflow/contrib/lite/toco/toco_types.h create mode 100644 tensorflow/contrib/lite/toco/tooling_util.cc create mode 100644 tensorflow/contrib/lite/toco/tooling_util.h create mode 100644 tensorflow/contrib/lite/toco/tooling_util_test.cc create mode 100644 tensorflow/contrib/lite/tools/BUILD create mode 100644 tensorflow/contrib/lite/tools/gen_op_registration.cc create mode 100644 tensorflow/contrib/lite/tools/gen_op_registration.h create mode 
100644 tensorflow/contrib/lite/tools/gen_op_registration_main.cc create mode 100644 tensorflow/contrib/lite/tools/gen_op_registration_test.cc create mode 100644 tensorflow/contrib/lite/tools/mutable_op_resolver.cc create mode 100644 tensorflow/contrib/lite/tools/mutable_op_resolver.h create mode 100644 tensorflow/contrib/lite/version.h (limited to 'third_party') diff --git a/configure.py b/configure.py index e98367ef9f..3c0df9475d 100644 --- a/configure.py +++ b/configure.py @@ -492,6 +492,8 @@ def set_cc_opt_flags(environ_cp): write_to_bazelrc( 'build:opt --host_cxxopt=%s --host_copt=%s' % (host_opt, host_opt)) write_to_bazelrc('build:opt --define with_default_optimizations=true') + write_to_bazelrc('build --copt=-DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK') + write_to_bazelrc('build --host_copt=-DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK') def set_tf_cuda_clang(environ_cp): diff --git a/tensorflow/BUILD b/tensorflow/BUILD index 8cb7edcc50..82a57ac185 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -475,6 +475,25 @@ filegroup( "//tensorflow/contrib/learn/python/learn/datasets:all_files", "//tensorflow/contrib/linalg:all_files", "//tensorflow/contrib/linear_optimizer:all_files", + "//tensorflow/contrib/lite:all_files", + "//tensorflow/contrib/lite/java:all_files", + "//tensorflow/contrib/lite/java/demo/app/src/main:all_files", + "//tensorflow/contrib/lite/java/demo/app/src/main/assets:all_files", + "//tensorflow/contrib/lite/java/src/main/native:all_files", + "//tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite:all_files", + "//tensorflow/contrib/lite/kernels:all_files", + "//tensorflow/contrib/lite/kernels/internal:all_files", + "//tensorflow/contrib/lite/models/smartreply:all_files", + "//tensorflow/contrib/lite/nnapi:all_files", + "//tensorflow/contrib/lite/python:all_files", + "//tensorflow/contrib/lite/schema:all_files", + "//tensorflow/contrib/lite/testing:all_files", + "//tensorflow/contrib/lite/toco:all_files", + "//tensorflow/contrib/lite/toco/graph_transformations/tests:all_files", + "//tensorflow/contrib/lite/toco/python:all_files", + "//tensorflow/contrib/lite/toco/tensorflow_graph_matching:all_files", + "//tensorflow/contrib/lite/toco/tflite:all_files", + "//tensorflow/contrib/lite/tools:all_files", "//tensorflow/contrib/lookup:all_files", "//tensorflow/contrib/losses:all_files", "//tensorflow/contrib/makefile:all_files", diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD index 3d53cbba56..b7ade95115 100644 --- a/tensorflow/contrib/BUILD +++ b/tensorflow/contrib/BUILD @@ -51,6 +51,7 @@ py_library( "//tensorflow/contrib/linalg:linalg_py", "//tensorflow/contrib/linear_optimizer:sdca_estimator_py", "//tensorflow/contrib/linear_optimizer:sdca_ops_py", + "//tensorflow/contrib/lite/python:lite", "//tensorflow/contrib/lookup:lookup_py", "//tensorflow/contrib/losses:losses_py", "//tensorflow/contrib/losses:metric_learning_py", diff --git a/tensorflow/contrib/__init__.py b/tensorflow/contrib/__init__.py index 3068e9ed8f..1eda1abfcf 100644 --- a/tensorflow/contrib/__init__.py +++ b/tensorflow/contrib/__init__.py @@ -79,6 +79,7 @@ from tensorflow.contrib import tpu from tensorflow.contrib import training from tensorflow.contrib import util from tensorflow.contrib.eager.python import tfe as eager +from tensorflow.contrib.lite.python import lite from tensorflow.contrib.ndlstm import python as ndlstm from tensorflow.contrib.remote_fused_graph import pylib as remote_fused_graph from tensorflow.contrib.specs import python as specs diff --git 
a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake index 7636e9ba6e..9517aa4963 100755 --- a/tensorflow/contrib/cmake/tf_python.cmake +++ b/tensorflow/contrib/cmake/tf_python.cmake @@ -499,6 +499,19 @@ add_python_module("tensorflow/contrib/linear_optimizer/kernels/g3doc") add_python_module("tensorflow/contrib/linear_optimizer/python") add_python_module("tensorflow/contrib/linear_optimizer/python/kernel_tests") add_python_module("tensorflow/contrib/linear_optimizer/python/ops") +add_custom_command(TARGET tf_python_touchup_modules PRE_BUILD + COMMAND ${CMAKE_COMMAND} -E make_directory + "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/lite") +add_custom_command(TARGET tf_python_touchup_modules PRE_BUILD + COMMAND ${CMAKE_COMMAND} -E make_directory + "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/lite/python") +add_custom_command(TARGET tf_python_touchup_modules PRE_BUILD + COMMAND ${CMAKE_COMMAND} -E touch + "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/lite/python/__init__.py") +add_custom_command( + TARGET tf_python_copy_scripts_to_destination PRE_BUILD + COMMAND ${CMAKE_COMMAND} -E touch + ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/lite/python/lite.py) add_python_module("tensorflow/contrib/lookup") add_python_module("tensorflow/contrib/losses") add_python_module("tensorflow/contrib/losses/python") diff --git a/tensorflow/contrib/lite/BUILD b/tensorflow/contrib/lite/BUILD new file mode 100644 index 0000000000..c58f77cb11 --- /dev/null +++ b/tensorflow/contrib/lite/BUILD @@ -0,0 +1,280 @@ +package(default_visibility = [ + "//visibility:public", +]) + +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts", "gen_selected_ops") + +exports_files(glob([ + "testdata/*.bin", + "models/testdata/*", +])) + +config_setting( + name = "mips", + values = { + "cpu": "mips", + }, +) + +config_setting( + name = "mips64", + values = { + "cpu": "mips64", + }, +) + +load( + "//tensorflow:tensorflow.bzl", + "tf_cc_test", +) + +cc_library( + name = "schema_fbs_version", + hdrs = ["version.h"], +) + +# Main library. No ops are included here. +# TODO(aselle): Resolve problems preventing C99 usage. +cc_library( + name = "context", + srcs = ["context.c"], + hdrs = ["context.h"], +) + +cc_library( + name = "builtin_op_data", + hdrs = [ + "builtin_op_data.h", + ], +) + +cc_library( + name = "string", + hdrs = [ + "string.h", + ], + deps = [ + "//tensorflow/core:lib_platform", + ], +) + +# TODO(ahentz): investigate dependency on gemm_support requiring usage of tf_copts. 
+cc_library( + name = "framework", + srcs = [ + "allocation.cc", + "error_reporter.cc", + "interpreter.cc", + "model.cc", + "nnapi_delegate.cc", + "optional_debug_tools.cc", + "simple_memory_arena.cc", + ], + hdrs = [ + "allocation.h", + "context.h", + "error_reporter.h", + "interpreter.h", + "model.h", + "nnapi_delegate.h", + "optional_debug_tools.h", + "simple_memory_arena.h", + ], + copts = tflite_copts(), + deps = [ + ":builtin_op_data", + ":context", + ":schema_fbs_version", + "//tensorflow/contrib/lite/kernels:gemm_support", + "//tensorflow/contrib/lite/nnapi:nnapi_lib", + "//tensorflow/contrib/lite/schema:schema_fbs", + "//tensorflow/core:lib_platform", + ], +) + +cc_library( + name = "string_util", + srcs = ["string_util.cc"], + hdrs = ["string_util.h"], + deps = [ + ":framework", + ":string", + ], +) + +cc_test( + name = "string_util_test", + size = "small", + srcs = ["string_util_test.cc"], + deps = [ + ":framework", + ":string_util", + "@com_google_googletest//:gtest", + ], +) + +# Test main interpreter +cc_test( + name = "interpreter_test", + size = "small", + srcs = ["interpreter_test.cc"], + deps = [ + ":framework", + ":string_util", + "@com_google_googletest//:gtest", + ], +) + +# Test arena allocator +cc_test( + name = "simple_memory_arena_test", + size = "small", + srcs = ["simple_memory_arena_test.cc"], + deps = [ + ":framework", + "@com_google_googletest//:gtest", + ], +) + +# Test model framework. +cc_test( + name = "model_test", + size = "small", + srcs = ["model_test.cc"], + data = [ + "testdata/0_subgraphs.bin", + "testdata/2_subgraphs.bin", + "testdata/empty_model.bin", + "testdata/test_model.bin", + "testdata/test_model_broken.bin", + ], + deps = [ + ":framework", + "@com_google_googletest//:gtest", + ], +) + +# Test the C extension API code. +cc_test( + name = "context_test", + size = "small", + srcs = ["context_test.cc"], + deps = [ + ":framework", + "@com_google_googletest//:gtest", + ], +) + +# Test the serialization of a model with optional tensors. 
+ +# Model tests + +cc_library( + name = "models_test_utils", + testonly = 1, + hdrs = ["models/test_utils.h"], + deps = select({ + "//tensorflow:android": [], + "//conditions:default": [ + #"//file/base:path", + "//tensorflow/core:test", + ], + }), +) + +cc_test( + name = "speech_hotword_model_test", + size = "small", + srcs = ["models/speech_hotword_model_test.cc"], + data = [ + "models/testdata/speech_hotword_model_in.csv", + "models/testdata/speech_hotword_model_out_rank1.csv", + "models/testdata/speech_hotword_model_out_rank2.csv", + "models/testdata/speech_hotword_model_rank1.tflite", + "models/testdata/speech_hotword_model_rank2.tflite", + ], + deps = [ + ":framework", + ":models_test_utils", + #"//file/base:path", + "//tensorflow/contrib/lite/kernels:builtin_ops", + "@com_google_googletest//:gtest_main", + ], +) + +gen_selected_ops( + name = "speech_speakerid_ops", + model = "models/testdata/speech_speakerid_model.tflite", +) + +cc_test( + name = "speech_speakerid_model_test", + size = "small", + srcs = [ + "models/speech_speakerid_model_test.cc", + ":speech_speakerid_ops", + ], + data = [ + "models/testdata/speech_speakerid_model.tflite", + "models/testdata/speech_speakerid_model_in.csv", + "models/testdata/speech_speakerid_model_out.csv", + ], + deps = [ + ":framework", + ":models_test_utils", + #"//file/base:path", + "//tensorflow/contrib/lite/kernels:builtin_ops", + "//tensorflow/contrib/lite/tools:mutable_op_resolver", + "@com_google_googletest//:gtest_main", + ], +) + +cc_test( + name = "speech_terse_am_model_test", + size = "small", + srcs = ["models/speech_terse_am_model_test.cc"], + data = [ + "models/testdata/speech_terse_am_model.tflite", + "models/testdata/speech_terse_am_model_in.csv", + "models/testdata/speech_terse_am_model_out.csv", + ], + deps = [ + ":framework", + ":models_test_utils", + #"//file/base:path", + "//tensorflow/contrib/lite/kernels:builtin_ops", + "@com_google_googletest//:gtest_main", + ], +) + +cc_test( + name = "speech_tts_model_test", + size = "small", + srcs = ["models/speech_tts_model_test.cc"], + data = [ + "models/testdata/speech_tts_model.tflite", + "models/testdata/speech_tts_model_in.csv", + "models/testdata/speech_tts_model_out.csv", + ], + deps = [ + ":framework", + ":models_test_utils", + #"//file/base:path", + "//tensorflow/contrib/lite/kernels:builtin_ops", + "@com_google_googletest//:gtest_main", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/contrib/lite/allocation.cc b/tensorflow/contrib/lite/allocation.cc new file mode 100644 index 0000000000..4b322e027d --- /dev/null +++ b/tensorflow/contrib/lite/allocation.cc @@ -0,0 +1,122 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/allocation.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/error_reporter.h" +#include "tensorflow/contrib/lite/nnapi_delegate.h" + +namespace tflite { + +MMAPAllocation::MMAPAllocation(const char* filename, + ErrorReporter* error_reporter) + : Allocation(error_reporter), mmapped_buffer_(MAP_FAILED) { + mmap_fd_ = open(filename, O_RDONLY); + if (mmap_fd_ == -1) { + error_reporter_->Report("Could not open '%s'.", filename); + return; + } + struct stat sb; + fstat(mmap_fd_, &sb); + buffer_size_bytes_ = sb.st_size; + mmapped_buffer_ = + mmap(nullptr, buffer_size_bytes_, PROT_READ, MAP_SHARED, mmap_fd_, 0); + if (mmapped_buffer_ == MAP_FAILED) { + error_reporter_->Report("Mmap of '%s' failed.", filename); + return; + } +} + +MMAPAllocation::~MMAPAllocation() { + if (valid()) { + munmap(const_cast(mmapped_buffer_), buffer_size_bytes_); + } + if (mmap_fd_ != -1) close(mmap_fd_); +} + +const void* MMAPAllocation::base() const { return mmapped_buffer_; } + +size_t MMAPAllocation::bytes() const { return buffer_size_bytes_; } + +bool MMAPAllocation::valid() const { return mmapped_buffer_ != MAP_FAILED; } + +FileCopyAllocation::FileCopyAllocation(const char* filename, + ErrorReporter* error_reporter) + : Allocation(error_reporter) { + // Obtain the file size, using an alternative method that is does not + // require fstat for more compatibility. + std::unique_ptr file(fopen(filename, "rb"), fclose); + if (!file) { + error_reporter_->Report("Could not open '%s'.", filename); + return; + } + // TODO(ahentz): Why did you think using fseek here was better for finding + // the size? + struct stat sb; + if (fstat(fileno(file.get()), &sb) != 0) { + error_reporter_->Report("Failed to get file size of '%s'.", filename); + return; + } + buffer_size_bytes_ = sb.st_size; + std::unique_ptr buffer(new char[buffer_size_bytes_]); + if (!buffer) { + error_reporter_->Report("Malloc of buffer to hold copy of '%s' failed.", + filename); + return; + } + size_t bytes_read = + fread(buffer.get(), sizeof(char), buffer_size_bytes_, file.get()); + if (bytes_read != buffer_size_bytes_) { + error_reporter_->Report("Read of '%s' failed (too few bytes read).", + filename); + return; + } + copied_buffer_ = std::move(buffer); +} + +FileCopyAllocation::~FileCopyAllocation() {} + +const void* FileCopyAllocation::base() const { return copied_buffer_.get(); } + +size_t FileCopyAllocation::bytes() const { return buffer_size_bytes_; } + +bool FileCopyAllocation::valid() const { return copied_buffer_ != nullptr; } + +MemoryAllocation::MemoryAllocation(const void* ptr, size_t num_bytes, + ErrorReporter* error_reporter) + : Allocation(error_reporter) { + buffer_ = ptr; + buffer_size_bytes_ = num_bytes; +} + +MemoryAllocation::~MemoryAllocation() {} + +const void* MemoryAllocation::base() const { return buffer_; } + +size_t MemoryAllocation::bytes() const { return buffer_size_bytes_; } + +bool MemoryAllocation::valid() const { return buffer_ != nullptr; } + +} // namespace tflite diff --git a/tensorflow/contrib/lite/allocation.h b/tensorflow/contrib/lite/allocation.h new file mode 100644 index 0000000000..ee8a7ccd0b --- /dev/null +++ b/tensorflow/contrib/lite/allocation.h @@ -0,0 +1,94 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Main abstraction controlling the tflite interpreter. +// See context.h for the API for defining operations (TfLiteRegistration). +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_ALLOCATION_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_ALLOCATION_H_ + +#include +#include +#include +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/error_reporter.h" +#include "tensorflow/contrib/lite/simple_memory_arena.h" + +namespace tflite { + +// A memory allocation handle. This could be a mmap or shared memory. +class Allocation { + public: + Allocation(ErrorReporter* error_reporter) : error_reporter_(error_reporter) {} + virtual ~Allocation() {} + + // Base pointer of this allocation + virtual const void* base() const = 0; + // Size in bytes of the allocation + virtual size_t bytes() const = 0; + // Whether the allocation is valid + virtual bool valid() const = 0; + + protected: + ErrorReporter* error_reporter_; +}; + +class MMAPAllocation : public Allocation { + public: + MMAPAllocation(const char* filename, ErrorReporter* error_reporter); + virtual ~MMAPAllocation(); + const void* base() const override; + size_t bytes() const override; + bool valid() const override; + + protected: + // Data required for mmap. + int mmap_fd_ = -1; // mmap file descriptor + const void* mmapped_buffer_; + size_t buffer_size_bytes_ = 0; +}; + +class FileCopyAllocation : public Allocation { + public: + FileCopyAllocation(const char* filename, ErrorReporter* error_reporter); + virtual ~FileCopyAllocation(); + const void* base() const override; + size_t bytes() const override; + bool valid() const override; + + private: + // Data required for mmap. + std::unique_ptr copied_buffer_; + size_t buffer_size_bytes_ = 0; +}; + +class MemoryAllocation : public Allocation { + public: + // Allocates memory with the pointer and the number of bytes of the memory. + // The pointer has to remain alive and unchanged until the destructor is + // called. 
+ MemoryAllocation(const void* ptr, size_t num_bytes, + ErrorReporter* error_reporter); + virtual ~MemoryAllocation(); + const void* base() const override; + size_t bytes() const override; + bool valid() const override; + + private: + const void* buffer_; + size_t buffer_size_bytes_ = 0; +}; + +} // namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_ALLOCATION_H_ diff --git a/tensorflow/contrib/lite/build_def.bzl b/tensorflow/contrib/lite/build_def.bzl new file mode 100644 index 0000000000..e3c9cdd99b --- /dev/null +++ b/tensorflow/contrib/lite/build_def.bzl @@ -0,0 +1,233 @@ +"""Generate Flatbuffer binary from json.""" + +def tflite_copts(): + """Defines compile time flags.""" + copts = [ + "-DFARMHASH_NO_CXX_STRING", + ] + select({ + "//tensorflow:android_arm64": [ + "-std=c++11", + "-O3", + ], + "//tensorflow:android_arm": [ + "-mfpu=neon", + "-mfloat-abi=softfp", + "-std=c++11", + "-O3", + ], + "//tensorflow:android_x86": [ + "-DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK", + ], + "//tensorflow:ios_x86_64": [ + "-msse4.1", + ], + "//conditions:default": [], + }) + select({ + "//tensorflow:with_default_optimizations": [], + "//conditions:default": ["-DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK"], + }) + + return copts + +LINKER_SCRIPT = "//tensorflow/contrib/lite/java/src/main/native:version_script.lds" + +def tflite_linkopts_unstripped(): + """Defines linker flags to reduce size of TFLite binary. + + These are useful when trying to investigate the relative size of the + symbols in TFLite. + + Returns: + a select object with proper linkopts + """ + return select({ + "//tensorflow:android": [ + "-Wl,--no-export-dynamic", # Only inc syms referenced by dynamic obj. + "-Wl,--exclude-libs,ALL", # Exclude syms in all libs from auto export. + "-Wl,--gc-sections", # Eliminate unused code and data. + "-Wl,--as-needed", # Don't link unused libs. + ], + "//tensorflow/contrib/lite:mips": [], + "//tensorflow/contrib/lite:mips64": [], + "//conditions:default": [ + "-Wl,--icf=all", # Identical code folding. + ], + }) + +def tflite_jni_linkopts_unstripped(): + """Defines linker flags to reduce size of TFLite binary with JNI. + + These are useful when trying to investigate the relative size of the + symbols in TFLite. + + Returns: + a select object with proper linkopts + """ + return select({ + "//tensorflow:android": [ + "-Wl,--gc-sections", # Eliminate unused code and data. + "-Wl,--as-needed", # Don't link unused libs. + ], + "//tensorflow/contrib/lite:mips": [], + "//tensorflow/contrib/lite:mips64": [], + "//conditions:default": [ + "-Wl,--icf=all", # Identical code folding. + ], + }) + +def tflite_linkopts(): + """Defines linker flags to reduce size of TFLite binary.""" + return tflite_linkopts_unstripped() + select({ + "//tensorflow:android": [ + "-s", # Omit symbol table. + ], + "//conditions:default": [], + }) + +def tflite_jni_linkopts(): + """Defines linker flags to reduce size of TFLite binary with JNI.""" + return tflite_jni_linkopts_unstripped() + select({ + "//tensorflow:android": [ + "-s", # Omit symbol table. + ], + "//conditions:default": [], + }) + + +def tflite_jni_binary(name, + copts=tflite_copts(), + linkopts=tflite_jni_linkopts(), + linkscript=LINKER_SCRIPT, + linkshared=1, + linkstatic=1, + deps=[]): + """Builds a jni binary for TFLite.""" + linkopts = linkopts + [ + "-Wl,--version-script", # Export only jni functions & classes. 
+ linkscript, + ] + native.cc_binary( + name=name, + copts=copts, + linkshared=linkshared, + linkstatic=linkstatic, + deps= deps + [linkscript], + linkopts=linkopts) + +def tf_to_tflite(name, src, options, out): + """Convert a frozen tensorflow graphdef to TF Lite's flatbuffer. + + Args: + name: Name of rule. + src: name of the input graphdef file. + options: options passed to TOCO. + out: name of the output flatbuffer file. + """ + + toco = "//tensorflow/contrib/lite/toco:toco" + native.genrule( + name = name, + srcs=[src, options], + outs=[out], + cmd = ("$(location %s) " + + " --input_file=$(location %s) " + + " --output_file=$(location %s) " + + " --input_format=TENSORFLOW_GRAPHDEF" + + " --output_format=TFLITE" + + " `cat $(location %s)`") + % (toco, src, out, options), + tools= [toco], + ) + +def tflite_to_json(name, src, out): + """Convert a TF Lite flatbuffer to JSON. + + Args: + name: Name of rule. + src: name of the input flatbuffer file. + out: name of the output JSON file. + """ + + flatc = "@flatbuffers//:flatc" + schema = "//tensorflow/contrib/lite/schema:schema.fbs" + native.genrule( + name = name, + srcs = [schema, src], + outs = [out], + cmd = ("TMP=`mktemp`; cp $(location %s) $${TMP}.bin &&" + + "$(location %s) --raw-binary --strict-json -t" + + " -o /tmp $(location %s) -- $${TMP}.bin &&" + + "cp $${TMP}.json $(location %s)") + % (src, flatc, schema, out), + tools = [flatc], + ) + +def json_to_tflite(name, src, out): + """Convert a JSON file to TF Lite's flatbuffer. + + Args: + name: Name of rule. + src: name of the input JSON file. + out: name of the output flatbuffer file. + """ + + flatc = "@flatbuffers//:flatc" + schema = "//tensorflow/contrib/lite/schema:schema_fbs" + native.genrule( + name = name, + srcs = [schema, src], + outs = [out], + cmd = ("TMP=`mktemp`; cp $(location %s) $${TMP}.json &&" + + "$(location %s) --raw-binary --unknown-json --allow-non-utf8 -b" + + " -o /tmp $(location %s) $${TMP}.json &&" + + "cp $${TMP}.bin $(location %s)") + % (src, flatc, schema, out), + tools = [flatc], + ) + +def gen_zipped_test_files(name, files): + """Generate a zip file of tests by using :generate_examples. + + Args: + name: Name of output. We will produce "`name`_files" as a target. + files: A list of zip file basenames. + """ + toco = "//tensorflow/contrib/lite/toco:toco" + out_files = [] + for f in files: + out_file = name + "/" + f + out_files.append(out_file) + native.genrule( + name = name + "_" + f + ".files", + cmd = ("$(locations :generate_examples) --toco $(locations %s) " % toco + + " --zip_to_output " + f + + " $(@D) zipped"), + outs = [out_file], + tools = [ + ":generate_examples", + toco, + ], + ) + + native.filegroup( + name = name, + srcs = out_files, + ) + +def gen_selected_ops(name, model): + """Generate the library that includes only used ops. + + Args: + name: Name of the generated library. + model: TFLite model to interpret. + """ + out = name + "_registration.cc" + tool = "//tensorflow/contrib/lite/tools:generate_op_registrations" + native.genrule( + name = name, + srcs = [model], + outs = [out], + cmd = ("$(location %s) --input_model=$(location %s) --output_registration=$(location %s)") + % (tool, model, out), + tools = [tool], + ) diff --git a/tensorflow/contrib/lite/builtin_op_data.h b/tensorflow/contrib/lite/builtin_op_data.h new file mode 100644 index 0000000000..93072bf90b --- /dev/null +++ b/tensorflow/contrib/lite/builtin_op_data.h @@ -0,0 +1,164 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_BUILTIN_OP_DATA_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_BUILTIN_OP_DATA_H_ + +#include + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +// TODO(aselle): Consider using "if this then that" for testing. + +// Possible padding types (for convolutions) +typedef enum { + kTfLitePaddingUnknown = 0, + kTfLitePaddingSame, + kTfLitePaddingValid, +} TfLitePadding; + +typedef struct { + int width; + int height; +} TfLitePaddingValues; + +// Possible fused activation functions. +// TODO(aselle): rename to TfLiteActivation +typedef enum { + kTfLiteActNone = 0, + kTfLiteActRelu, + kTfLiteActRelu1, + kTfLiteActRelu6, + kTfLiteActTanh, + kTfLiteActSignBit, + kTfLiteActSigmoid, +} TfLiteFusedActivation; + +typedef struct { + TfLitePadding padding; + int stride_width; + int stride_height; + TfLiteFusedActivation activation; +} TfLiteConvParams; + +typedef struct { + TfLitePadding padding; + int stride_width; + int stride_height; + int filter_width; + int filter_height; + TfLiteFusedActivation activation; + struct { + TfLitePaddingValues padding; + } computed; +} TfLitePoolParams; + +typedef struct { + TfLitePadding padding; + int stride_width; + int stride_height; + int depth_multiplier; + TfLiteFusedActivation activation; +} TfLiteDepthwiseConvParams; + +typedef struct { + int rank; + TfLiteFusedActivation activation; +} TfLiteSVDFParams; + +typedef struct { + TfLiteFusedActivation activation; +} TfLiteRNNParams; + +typedef struct { TfLiteFusedActivation activation; } TfLiteFullyConnectedParams; + +typedef enum { + kTfLiteLshProjectionUnknown = 0, + kTfLiteLshProjectionSparse = 1, + kTfLiteLshProjectionDense = 2, +} TfLiteLSHProjectionType; + +typedef struct { TfLiteLSHProjectionType type; } TfLiteLSHProjectionParams; + +typedef struct { float beta; } TfLiteSoftmaxParams; + +typedef struct { + int axis; + TfLiteFusedActivation activation; +} TfLiteConcatenationParams; + +typedef struct { + TfLiteFusedActivation activation; +} TfLiteAddParams; + +typedef struct { + TfLiteFusedActivation activation; +} TfLiteMulParams; + +typedef struct { + TfLiteFusedActivation activation; +} TfLiteL2NormParams; + +typedef struct { + int radius; + float bias; + float alpha; + float beta; +} TfLiteLocalResponseNormParams; + +typedef struct { + TfLiteFusedActivation activation; + float cell_clip; + float proj_clip; +} TfLiteLSTMParams; + +typedef struct { + int new_height; + int new_width; +} TfLiteResizeBilinearParams; + +typedef struct { + // TODO(ahentz): We can't have dynamic data in this struct, at least not yet. + // For now we will fix the maximum possible number of dimensions. 
+ int shape[8]; + int num_dimensions; +} TfLiteReshapeParams; + +typedef struct { + int ngram_size; + int max_skip_size; + bool include_all_ngrams; +} TfLiteSkipGramParams; + +typedef struct { + int block_size; +} TfLiteSpaceToDepthParams; + +typedef enum { + kTfLiteCombinerTypeSum = 0, + kTfLiteCombinerTypeMean = 1, + kTfLiteCombinerTypeSqrtn = 2, +} TfLiteCombinerType; + +typedef struct { + TfLiteCombinerType combiner; +} TfLiteEmbeddingLookupSparseParams; + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_BUILTIN_OP_DATA_H_ diff --git a/tensorflow/contrib/lite/context.c b/tensorflow/contrib/lite/context.c new file mode 100644 index 0000000000..c09e838c5c --- /dev/null +++ b/tensorflow/contrib/lite/context.c @@ -0,0 +1,92 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/context.h" +#include +#include + +TfLiteIntArray* TfLiteIntArrayCreate(int size) { + TfLiteIntArray* ret = + (TfLiteIntArray*)malloc(sizeof(*ret) + sizeof(ret->data[0]) * size); + ret->size = size; + return ret; +} + +void TfLiteIntArrayPrint(const char* s, TfLiteIntArray* a) { + printf("%s: length=%d [", s, a->size); + if (a->size) printf("%d", a->data[0]); + int i = 1; + for (; i < a->size; i++) { + printf(" %d", a->data[i]); + } + printf("]\n"); +} + +int TfLiteIntArrayEqual(TfLiteIntArray* a, TfLiteIntArray* b) { + if (a == b) return 1; + if (a == NULL || b == NULL) return 0; + if (a->size != b->size) return 0; + int i = 0; + for (; i < a->size; i++) + if (a->data[i] != b->data[i]) return 0; + return 1; +} + +TfLiteIntArray* TfLiteIntArrayCopy(TfLiteIntArray* src) { + if (!src) return NULL; + TfLiteIntArray* ret = TfLiteIntArrayCreate(src->size); + if (ret) { + memcpy(ret->data, src->data, src->size * sizeof(int)); + } + return ret; +} + +void TfLiteIntArrayFree(TfLiteIntArray* a) { free(a); } + +void TfLiteTensorFree(TfLiteTensor* t) { + if (t->allocation_type == kTfLiteDynamic && t->data.raw) { + free(t->data.raw); + } + if (t->dims) TfLiteIntArrayFree(t->dims); + t->data.raw = NULL; + t->dims = NULL; +} + +void TfLiteTensorReset(TfLiteType type, const char* name, TfLiteIntArray* dims, + TfLiteQuantizationParams quantization, char* buffer, + size_t size, TfLiteAllocationType allocation_type, + const void* allocation, TfLiteTensor* tensor) { + TfLiteTensorFree(tensor); + tensor->type = type; + tensor->name = name; + tensor->dims = dims; + tensor->params = quantization; + tensor->data.raw = buffer; + tensor->bytes = size; + tensor->allocation_type = allocation_type; + tensor->allocation = allocation; +} + +void TfLiteTensorRealloc(size_t num_bytes, TfLiteTensor* tensor) { + if (tensor->allocation_type != kTfLiteDynamic) { + return; + } + if (!tensor->data.raw) { + tensor->data.raw = malloc(num_bytes); + } else if (num_bytes > tensor->bytes) { + tensor->data.raw = realloc(tensor->data.raw, num_bytes); 
+ } + tensor->bytes = num_bytes; +} diff --git a/tensorflow/contrib/lite/context.h b/tensorflow/contrib/lite/context.h new file mode 100644 index 0000000000..41257a53b1 --- /dev/null +++ b/tensorflow/contrib/lite/context.h @@ -0,0 +1,298 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// This file defines a C API for implementing operations in tflite. +// These operations can be defined using c++ but the interface between +// the interpreter and the operations are C. +// +// Summary of abstractions +// TF_LITE_ENSURE - Self-sufficient error checking +// TfLiteStatus - Status reporting +// TfLiteIntArray - stores tensor shapes (dims), +// TfLiteContext - allows an op to access the tensors +// TfLiteTensor - tensor (a multidimensional array) +// TfLiteNode - a single node or operation +// TfLiteRegistration - the implementation of a conceptual operation. +// +// Some abstractions in this file are created and managed by Interpreter. +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_CONTEXT_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_CONTEXT_H_ + +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +typedef enum { kTfLiteOk = 0, kTfLiteError = 1 } TfLiteStatus; + +#define kOptionalTensor (-1) + +// Fixed size list of integers. Used for dimensions and inputs/outputs tensor +// indices +typedef struct { + int size; +// gcc 6.1+ have a bug where flexible members aren't properly handled +// https://github.com/google/re2/commit/b94b7cd42e9f02673cd748c1ac1d16db4052514c +#if !defined(__clang__) && defined(__GNUC__) && __GNUC__ == 6 && \ + __GNUC_MINOR__ >= 1 + int data[0]; +#else + int data[]; +#endif +} TfLiteIntArray; + +// Create a array of a given `size` (uninitialized entries). +// This returns a pointer, that you must free using TfLiteIntArrayFree(). +TfLiteIntArray* TfLiteIntArrayCreate(int size); + +// Check if two tensors are equal. Returns 1 if they are equal, 0 otherwise. +int TfLiteIntArrayEqual(TfLiteIntArray* a, TfLiteIntArray* b); + +// Create a copy of an array passed as `src`. +// You are expected to free memory with TfLiteIntArrayFree +TfLiteIntArray* TfLiteIntArrayCopy(TfLiteIntArray* src); + +// Free memory of array `v`. +void TfLiteIntArrayFree(TfLiteIntArray* v); + +// Since we must not depend on any libraries, define a minimal subset of +// error macros while avoiding names that have pre-conceived meanings like +// assert and check. + +// Check whether value is true, and if not return kTfLiteError from +// the current function (and report the error string msg). +#define TF_LITE_ENSURE_MSG(context, value, msg) \ + do { \ + if (!(value)) { \ + (context)->ReportError((context), __FILE__ " " msg); \ + return kTfLiteError; \ + } \ + } while (0) + +// Check whether the value `a` is true, and if not return kTfLiteError from +// the current function, while also reporting the location of the error. 
+#define TF_LITE_ENSURE(context, a) \ + do { \ + if (!(a)) { \ + (context)->ReportError((context), "%s:%d %s was not true.", __FILE__, \ + __LINE__, #a); \ + return kTfLiteError; \ + } \ + } while (0) + +#define TF_LITE_ENSURE_STATUS(a) \ + do { \ + if ((a) != kTfLiteOk) { \ + return kTfLiteError; \ + } \ + } while (0) + +// Check whether the value `a == b` is true, and if not return kTfLiteError from +// the current function, while also reporting the location of the error. +// `a` and `b` may be evaluated more than once, so no side effects or +// extremely expensive computations should be done. +#define TF_LITE_ENSURE_EQ(context, a, b) \ + do { \ + if ((a) != (b)) { \ + (context)->ReportError((context), "%s:%d %s != %s (%d != %d)", __FILE__, \ + __LINE__, #a, #b, (a), (b)); \ + return kTfLiteError; \ + } \ + } while (0) + +#define TF_LITE_ENSURE_OK(context, status) \ + do { \ + if ((status) != kTfLiteOk) { \ + return status; \ + } \ + } while (0) + +// Types supported by tensor +typedef enum { + kTfLiteNoType = 0, + kTfLiteFloat32 = 1, + kTfLiteInt32 = 2, + kTfLiteUInt8 = 3, + kTfLiteInt64 = 4, + kTfLiteString = 5, +} TfLiteType; + +// Parameters for asymmetric quantization. Quantized values can be converted +// back to float using: +// real_value = scale * (quantized_value - zero_point); +typedef struct { + float scale; + int32_t zero_point; +} TfLiteQuantizationParams; + +// A union of points that points to memory for a given tensor. +typedef union { + int* i32; + float* f; + char* raw; + const char* raw_const; + uint8_t* uint8; +} TfLitePtrUnion; + +// Memory allocation strategies. kTfLiteMmapRo is for read-only memory-mapped +// data (or data externally allocated). kTfLiteArenaRw is arena allocated +// data. kTfLiteDynamic is for tensors that are allocated during evaluation. +typedef enum { + kTfLiteMemNone = 0, + kTfLiteMmapRo, + kTfLiteArenaRw, + kTfLiteArenaRwPersistent, + kTfLiteDynamic, +} TfLiteAllocationType; + +// An tensor in the interpreter system which is a wrapper around a buffer of +// data including a dimensionality (or NULL if not currently defined). +typedef struct { + // The data type specification for data stored in `data`. This affects + // what member of `data` union should be used. + TfLiteType type; + // A union of data pointers. The appropriate type should be used for a typed + // tensor based on `type`. + TfLitePtrUnion data; + // A pointer to a structure representing the dimensionality interpretation + // that the buffer should have. NOTE: the product of elements of `dims` + // and the element datatype size should be equal to `bytes` below. + TfLiteIntArray* dims; + // Quantization information. + TfLiteQuantizationParams params; + // How memory is mapped + // kTfLiteMmapRo: Memory mapped read only. + // i.e. weights + // kTfLiteArenaRw: Arena allocated read write memory + // (i.e. temporaries, outputs). + TfLiteAllocationType allocation_type; + // The number of bytes required to store the data of this Tensor. I.e. + // (bytes of each element) * dims[0] * ... * dims[n-1]. For example, if + // type is kTfLiteFloat32 and dims = {3, 2} then + // bytes = sizeof(float) * 3 * 2 = 4 * 3 * 2 = 24. + size_t bytes; + + // An opaque pointer to a tflite::MMapAllocation + const void* allocation; + + // Null-terminated name of this tensor. + const char* name; +} TfLiteTensor; + +// Free memory of tensor `t`; +void TfLiteTensorFree(TfLiteTensor* t); + +// Set all of a tensor's fields (and free any previously allocated data). 
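[Editor's illustration] The quantization comment above gives real_value = scale * (quantized_value - zero_point); the sketch below applies that formula to a kTfLiteUInt8 tensor using only the types declared in this header. It is an illustrative sketch rather than code from this commit.

// Sketch: dequantize a kTfLiteUInt8 tensor using TfLiteQuantizationParams.
#include <cstdint>
#include <vector>

#include "tensorflow/contrib/lite/context.h"

std::vector<float> Dequantize(const TfLiteTensor& tensor) {
  std::vector<float> result;
  if (tensor.type != kTfLiteUInt8) return result;  // only handles uint8 here
  const size_t num_elements = tensor.bytes;        // 1 byte per uint8 element
  result.reserve(num_elements);
  const uint8_t* data = tensor.data.uint8;
  const float scale = tensor.params.scale;
  const int32_t zero_point = tensor.params.zero_point;
  for (size_t i = 0; i < num_elements; ++i) {
    // real_value = scale * (quantized_value - zero_point)
    result.push_back(scale * (static_cast<int32_t>(data[i]) - zero_point));
  }
  return result;
}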
+void TfLiteTensorReset(TfLiteType type, const char* name, TfLiteIntArray* dims, + TfLiteQuantizationParams quantization, char* buffer, + size_t size, TfLiteAllocationType allocation_type, + const void* allocation, TfLiteTensor* tensor); + +// Resize the allocated data of a (dynamic) tensor. +void TfLiteTensorRealloc(size_t num_bytes, TfLiteTensor* tensor); + +typedef struct TfLiteContext { + // Number of tensors in the context. + int tensors_size; + // An tensor of tensors in the interpreter context (of length `tensors_size`) + TfLiteTensor* tensors; + + // opaque full context ptr (an opaque c++ data structure) + void* impl_; + + // Request memory pointer be resized. Updates dimensions on the tensor. + // NOTE: ResizeTensor takes ownership of newSize. + TfLiteStatus (*ResizeTensor)(struct TfLiteContext*, TfLiteTensor* tensor, + TfLiteIntArray* new_size); + // Request that a error be reported with format string msg. + void (*ReportError)(struct TfLiteContext*, const char* msg, ...); + + // Add `tensors_to_add` tensors, preserving pre-existing Tensor entries. If + // non-null, the value pointed to by `first_new_tensor_index` will be set to + // the index of the first new tensor. + TfLiteStatus (*AddTensors)(struct TfLiteContext*, int tensors_to_add, + int* first_new_tensor_index); + + // TODO(ahentz): we should create a more general mechanism for this sort of + // library-global objects. + void* gemm_context; +} TfLiteContext; + +// A structure representing an instance of a node. +// This structure only exhibits the inputs, outputs and user defined data, not +// other features like the type. +typedef struct { + // Inputs to this node expressed as indices into the simulator's tensors. + TfLiteIntArray* inputs; + + // Outputs to this node expressed as indices into the simulator's tensors. + TfLiteIntArray* outputs; + + // Temporary tensors uses during the computations. This usually contains no + // tensors, but ops are allowed to change that if they need scratch space of + // any sort. + TfLiteIntArray* temporaries; + + // Opaque data provided by the node implementer through `Registration.init`. + void* user_data; + + // Opaque data provided to the node if the node is a builtin. + void* builtin_data; +} TfLiteNode; + +typedef struct { + // Initializes the op from serialized data. + // If a built-in op: + // `buffer` is the op's params data (TfLiteLSTMParams*). + // `length` is zero. + // If custom op: + // `buffer` is the op's `custom_options`. + // `length` is the size of the buffer. + // + // Returns a type-punned (i.e. void*) opaque data (e.g. a primitive pointer + // or an instance of a struct). + // + // The returned pointer will be stored with the node in the `user_data` field, + // accessible within prepare and invoke functions below. + // NOTE: if the data is already in the desired format, simply implement this + // function to return `nullptr` and implement the free function to be a no-op. + void* (*init)(TfLiteContext* context, const char* buffer, size_t length); + + // The pointer `buffer` is the data previously returned by an init invocation. + void (*free)(TfLiteContext* context, void* buffer); + + // prepare is called when the inputs this node depends on have been resized. + // context->ResizeTensor() can be called to request output tensors to be + // resized. + // + // Returns kTfLiteOk on success. + TfLiteStatus (*prepare)(TfLiteContext* context, TfLiteNode* node); + + // Execute the node (should read node->inputs and output to node->outputs). 
+ // Returns kTfLiteOk on success. + TfLiteStatus (*invoke)(TfLiteContext* context, TfLiteNode* node); + + // Builtin codes. If this kernel refers to a builtin this is the code + // of the builtin. This is so we can do marshaling to other frameworks like + // NN API. Note, it is the responsibility of the registration binder to + // set this properly. + int32_t builtin_code; +} TfLiteRegistration; + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_CONTEXT_H_ diff --git a/tensorflow/contrib/lite/context_test.cc b/tensorflow/contrib/lite/context_test.cc new file mode 100644 index 0000000000..d0a104f43d --- /dev/null +++ b/tensorflow/contrib/lite/context_test.cc @@ -0,0 +1,74 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/context.h" +#include + +namespace tflite { + +// NOTE: this tests only the TfLiteIntArray part of context. +// most of context.h is provided in the context of using it with interpreter.h +// and interpreter.cc, so interpreter_test.cc tests context structures more +// thoroughly. + +TEST(IntArray, TestIntArrayCreate) { + TfLiteIntArray* a = TfLiteIntArrayCreate(0); + TfLiteIntArray* b = TfLiteIntArrayCreate(3); + TfLiteIntArrayFree(a); + TfLiteIntArrayFree(b); +} + +TEST(IntArray, TestIntArrayCopy) { + TfLiteIntArray* a = TfLiteIntArrayCreate(2); + a->data[0] = 22; + a->data[1] = 24; + TfLiteIntArray* b = TfLiteIntArrayCopy(a); + ASSERT_NE(a, b); + ASSERT_EQ(a->size, b->size); + ASSERT_EQ(a->data[0], b->data[0]); + ASSERT_EQ(a->data[1], b->data[1]); + TfLiteIntArrayFree(a); + TfLiteIntArrayFree(b); +} + +TEST(IntArray, TestIntArrayEqual) { + TfLiteIntArray* a = TfLiteIntArrayCreate(1); + a->data[0] = 1; + TfLiteIntArray* b = TfLiteIntArrayCreate(2); + b->data[0] = 5; + b->data[1] = 6; + TfLiteIntArray* c = TfLiteIntArrayCreate(2); + c->data[0] = 5; + c->data[1] = 6; + TfLiteIntArray* d = TfLiteIntArrayCreate(2); + d->data[0] = 6; + d->data[1] = 6; + ASSERT_FALSE(TfLiteIntArrayEqual(a, b)); + ASSERT_TRUE(TfLiteIntArrayEqual(b, c)); + ASSERT_TRUE(TfLiteIntArrayEqual(b, b)); + ASSERT_FALSE(TfLiteIntArrayEqual(c, d)); + TfLiteIntArrayFree(a); + TfLiteIntArrayFree(b); + TfLiteIntArrayFree(c); + TfLiteIntArrayFree(d); +} + +} // namespace tflite + +int main(int argc, char** argv) { + // On Linux, add: tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/error_reporter.cc b/tensorflow/contrib/lite/error_reporter.cc new file mode 100644 index 0000000000..6ba5384a94 --- /dev/null +++ b/tensorflow/contrib/lite/error_reporter.cc @@ -0,0 +1,50 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/error_reporter.h" +#include +#include + +namespace tflite { + +ErrorReporter::~ErrorReporter() {} + +int ErrorReporter::Report(const char* format, ...) { + va_list args; + va_start(args, format); + int code = Report(format, args); + va_end(args); + return code; +} + +// TODO(aselle): Make the name of ReportError on context the same, so +// we can use the ensure functions w/o a context and w/ a reporter. +int ErrorReporter::ReportError(void*, const char* format, ...) { + va_list args; + va_start(args, format); + int code = Report(format, args); + va_end(args); + return code; +} + +int StderrReporter::Report(const char* format, va_list args) { + return vfprintf(stderr, format, args); +} + +ErrorReporter* DefaultErrorReporter() { + static StderrReporter* error_reporter = new StderrReporter; + return error_reporter; +} + +} // namespace tflite diff --git a/tensorflow/contrib/lite/error_reporter.h b/tensorflow/contrib/lite/error_reporter.h new file mode 100644 index 0000000000..637d456ce7 --- /dev/null +++ b/tensorflow/contrib/lite/error_reporter.h @@ -0,0 +1,54 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_ERROR_REPORTER_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_ERROR_REPORTER_H_ + +#include +#include "tensorflow/contrib/lite/context.h" + +namespace tflite { + +// A functor that reports error to supporting system. Invoked similar to +// printf. +// +// Usage: +// ErrorReporter foo; +// foo.Report("test %d\n", 5); +// or +// va_list args; +// foo.Report("test %d\n", args); // where args is va_list +// +// Sublclass ErrorReporter to provide another reporting destination. +// For example, if you have a GUI program, you might redirect to a buffer +// that drives a GUI error log box. +class ErrorReporter { + public: + virtual ~ErrorReporter(); + virtual int Report(const char* format, va_list args) = 0; + int Report(const char* format, ...); + int ReportError(void*, const char* format, ...); +}; + +// An error reporter that simplify writes the message to stderr. +struct StderrReporter : public ErrorReporter { + int Report(const char* format, va_list args) override; +}; + +// Return the default error reporter (output to stderr). 
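[Editor's illustration] Since the header above recommends subclassing ErrorReporter to redirect messages, here is a minimal sketch of such a subclass that collects formatted messages into a string; the class name and buffer size are illustrative assumptions, not part of this commit.

// Sketch: an ErrorReporter subclass that appends formatted messages to a
// string buffer instead of writing to stderr.
#include <cstdarg>
#include <cstdio>
#include <string>

#include "tensorflow/contrib/lite/error_reporter.h"

class BufferErrorReporter : public tflite::ErrorReporter {
 public:
  int Report(const char* format, va_list args) override {
    char buf[1024];
    int formatted = vsnprintf(buf, sizeof(buf), format, args);
    if (formatted > 0) buffer_.append(buf);
    return formatted;
  }
  const std::string& message() const { return buffer_; }

 private:
  std::string buffer_;
};

// Usage: pass a BufferErrorReporter* to the Interpreter constructor so errors
// accumulate in message() rather than going to stderr.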
+ErrorReporter* DefaultErrorReporter(); + +} // namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_ERROR_REPORTER_H_ diff --git a/tensorflow/contrib/lite/interpreter.cc b/tensorflow/contrib/lite/interpreter.cc new file mode 100644 index 0000000000..954e236ac8 --- /dev/null +++ b/tensorflow/contrib/lite/interpreter.cc @@ -0,0 +1,567 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/interpreter.h" +#include +#include +#include +#include +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/error_reporter.h" +#include "tensorflow/contrib/lite/kernels/gemm_support.h" +#include "tensorflow/contrib/lite/nnapi_delegate.h" + +namespace { + +// Memory allocation tuning +constexpr const int kDefaultArenaAlignment = 64; +constexpr const int kDefaultTensorAlignment = 4; +// std::vector preallocation tuning. +constexpr const int kSlotsToReserve = 128; + +} // namespace + +namespace tflite { + +Interpreter::Interpreter(ErrorReporter* error_reporter) + : arena_(kDefaultArenaAlignment), + persistent_arena_(kDefaultArenaAlignment), + error_reporter_(error_reporter ? error_reporter + : DefaultErrorReporter()) { + context_.impl_ = static_cast(this); + context_.ResizeTensor = ResizeTensor; + context_.ReportError = ReportError; + context_.AddTensors = AddTensors; + context_.tensors = nullptr; + context_.tensors_size = 0; + context_.gemm_context = nullptr; + // Reserve some space for the tensors to avoid excessive resizing. + tensors_.reserve(kSlotsToReserve); + nodes_and_registration_.reserve(kSlotsToReserve); + next_allocate_node_id_ = 0; + UseNNAPI(false); +} + +Interpreter::~Interpreter() { + for (auto& nodeAndReg : nodes_and_registration_) { + TfLiteNode& node = nodeAndReg.first; + TfLiteIntArrayFree(node.inputs); + TfLiteIntArrayFree(node.outputs); + TfLiteIntArrayFree(node.temporaries); + if (node.builtin_data) free(node.builtin_data); + OpFree(nodeAndReg.second, node.user_data); + node.builtin_data = nullptr; + } + + for (int i = 0; i < context_.tensors_size; i++) { + TfLiteTensorFree(&context_.tensors[i]); + } +} + +TfLiteStatus Interpreter::SetInputs(std::vector inputs) { + TF_LITE_ENSURE_OK(&context_, + CheckTensorIndices("inputs", inputs.data(), inputs.size())); + inputs_ = std::move(inputs); + return kTfLiteOk; +} + +TfLiteStatus Interpreter::SetOutputs(std::vector outputs) { + TF_LITE_ENSURE_OK( + &context_, CheckTensorIndices("outputs", outputs.data(), outputs.size())); + outputs_ = std::move(outputs); + return kTfLiteOk; +} + +TfLiteStatus Interpreter::CheckTensorIndices(const char* label, + const int* indices, int length) { + // Making sure kOptionalTensor is not re-defined to something other than -1. 
+ static_assert(kOptionalTensor == -1, "kOptionalTensor should be defined -1"); + + for (int i = 0; i < length; i++) { + int index = indices[i]; + if (index < kOptionalTensor || index >= context_.tensors_size) { + ReportError(&context_, "Invalid tensor index %d in %s\n", index, label); + consistent_ = false; + return kTfLiteError; + } + } + return kTfLiteOk; +} + +TfLiteStatus Interpreter::BytesRequired(TfLiteType type, const int* dims, + int dims_size, size_t* bytes) { + // TODO(aselle): Check for overflow here using overflow.h in TensorFlow + // MultiplyWithoutOverflow. + TF_LITE_ENSURE(&context_, bytes != nullptr); + size_t count = 1; + for (int k = 0; k < dims_size; k++) count *= dims[k]; + switch (type) { + case kTfLiteFloat32: + *bytes = sizeof(float) * count; + break; + case kTfLiteInt32: + *bytes = sizeof(int32_t) * count; + break; + case kTfLiteUInt8: + *bytes = sizeof(uint8_t) * count; + break; + case kTfLiteInt64: + *bytes = sizeof(int64_t) * count; + break; + default: + ReportError(&context_, + "Only float32, int32, int64, uint8 supported currently."); + return kTfLiteError; + } + return kTfLiteOk; +} + +TfLiteStatus Interpreter::AllocateTensorsWhoseSizesAreKnown() { + if (!consistent_) { + ReportError(&context_, "AllocateTensors() called on inconsistent model."); + return kTfLiteError; + } + if (next_allocate_node_id_ == nodes_and_registration_.size() && invokable_) { + return kTfLiteOk; + } + allocs_and_refcounts_.resize(context_.tensors_size); + + int new_next_allocate_node_id = next_allocate_node_id_; + invokable_ = false; + + // Allocate graph input nodes. + if (next_allocate_node_id_ == 0) { + for (int i = 0; i < inputs_.size(); ++i) { + int tensor_index = inputs_[i]; + if (tensor_index == kOptionalTensor) { + continue; + } + TfLiteTensor& tensor = context_.tensors[tensor_index]; + if (tensor.allocation_type == kTfLiteArenaRw) { + TF_LITE_ENSURE_OK( + &context_, + arena_.Allocate(&context_, kDefaultTensorAlignment, tensor.bytes, + &allocs_and_refcounts_[tensor_index].alloc)); + } + } + // Add 1 to output tensors, so they will not get overwritten. + for (int i = 0; i < outputs_.size(); ++i) { + allocs_and_refcounts_[outputs_[i]].count++; + } + } + + // Count references to node input tensors, and resize node-referenced tensors + // until we encounter a node that has a dynamic output tensor. + for (int k = next_allocate_node_id_; k < nodes_and_registration_.size(); + k++) { + new_next_allocate_node_id++; + TfLiteNode& node = nodes_and_registration_[k].first; + const TfLiteRegistration& registration = nodes_and_registration_[k].second; + if (OpPrepare(registration, &node) == kTfLiteError) { + return kTfLiteError; + } + + TfLiteIntArray* node_inputs = node.inputs; + for (int i = 0; i < node_inputs->size; ++i) { + int tensor_index = node_inputs->data[i]; + if (tensor_index != kOptionalTensor) { + allocs_and_refcounts_[node_inputs->data[i]].count++; + } + } + + // Discontinue if the node has dynamic outputs. + bool has_unallocated_dynamic_tensor = false; + TfLiteIntArray* node_outputs = node.outputs; + for (int i = 0; i < node_outputs->size; ++i) { + TfLiteTensor& tensor = context_.tensors[node_outputs->data[i]]; + if (tensor.allocation_type == kTfLiteDynamic) { + has_unallocated_dynamic_tensor = true; + break; + } + } + if (has_unallocated_dynamic_tensor) { + break; + } + } + + // Allocate graph persistent outputs, e.g. RNN cell states, etc. 
+ for (int k = next_allocate_node_id_; k < new_next_allocate_node_id; k++) { + TfLiteNode& node = nodes_and_registration_[k].first; + + // Go through output tensors and allocate the persistent ones first. + TfLiteIntArray* node_outputs = node.outputs; + for (int i = 0; i < node_outputs->size; ++i) { + int tensor_index = node_outputs->data[i]; + TfLiteTensor& tensor = context_.tensors[tensor_index]; + if (tensor.allocation_type == kTfLiteArenaRwPersistent) { + TF_LITE_ENSURE_OK(&context_, + persistent_arena_.Allocate( + &context_, kDefaultTensorAlignment, tensor.bytes, + &allocs_and_refcounts_[tensor_index].alloc)); + } + } + } + + // Go through the graph in execution order. + for (int k = next_allocate_node_id_; k < new_next_allocate_node_id; k++) { + TfLiteNode& node = nodes_and_registration_[k].first; + + // First allocate output tensors. + TfLiteIntArray* node_outputs = node.outputs; + for (int i = 0; i < node_outputs->size; ++i) { + int tensor_index = node_outputs->data[i]; + TfLiteTensor& tensor = context_.tensors[tensor_index]; + if (tensor.allocation_type == kTfLiteArenaRw) { + TF_LITE_ENSURE_OK( + &context_, + arena_.Allocate(&context_, kDefaultTensorAlignment, tensor.bytes, + &allocs_and_refcounts_[tensor_index].alloc)); + } + } + // Then the temporaries, in two passes. First allocate them all, them + // deallocate them. + TfLiteIntArray* node_temporaries = node.temporaries; + for (int i = 0; i < node_temporaries->size; ++i) { + int tensor_index = node_temporaries->data[i]; + TfLiteTensor& tensor = context_.tensors[tensor_index]; + if (tensor.allocation_type == kTfLiteArenaRw) { + TF_LITE_ENSURE_OK( + &context_, + arena_.Allocate(&context_, kDefaultTensorAlignment, tensor.bytes, + &allocs_and_refcounts_[tensor_index].alloc)); + } + } + for (int i = 0; i < node_temporaries->size; ++i) { + int tensor_index = node_temporaries->data[i]; + TfLiteTensor& tensor = context_.tensors[tensor_index]; + allocs_and_refcounts_[tensor_index].count--; + if (tensor.allocation_type == kTfLiteArenaRw && + allocs_and_refcounts_[tensor_index].count == 0) { + TF_LITE_ENSURE_OK( + &context_, + arena_.Deallocate(&context_, + allocs_and_refcounts_[tensor_index].alloc)); + } + } + + // Then process the node's inputs. + TfLiteIntArray* node_inputs = node.inputs; + for (int i = 0; i < node_inputs->size; ++i) { + int tensor_index = node_inputs->data[i]; + if (tensor_index == kOptionalTensor) { + continue; + } + TfLiteTensor& tensor = context_.tensors[tensor_index]; + + // Decrease reference count and deallocate if not needed anymore. + allocs_and_refcounts_[tensor_index].count--; + if (tensor.allocation_type == kTfLiteArenaRw && + allocs_and_refcounts_[tensor_index].count == 0) { + TF_LITE_ENSURE_OK( + &context_, + arena_.Deallocate(&context_, + allocs_and_refcounts_[tensor_index].alloc)); + } + } + } + + // Resize the buffer and commit the arena. + TF_LITE_ENSURE_OK(&context_, arena_.Commit(&context_)); + TF_LITE_ENSURE_OK(&context_, persistent_arena_.Commit(&context_)); + + // Rewire the tensors to use the underlying arena buffer. 
+ for (int i = 0; i < context_.tensors_size; ++i) { + TfLiteTensor& tensor = context_.tensors[i]; + if (tensor.allocation_type == kTfLiteArenaRw) { + TF_LITE_ENSURE_OK( + &context_, + arena_.ResolveAlloc(&context_, allocs_and_refcounts_[i].alloc, + &tensor.data.raw)); + } + if (tensor.allocation_type == kTfLiteArenaRwPersistent) { + TF_LITE_ENSURE_OK( + &context_, + persistent_arena_.ResolveAlloc( + &context_, allocs_and_refcounts_[i].alloc, &tensor.data.raw)); + } + } + + invokable_ = true; + next_allocate_node_id_ = new_next_allocate_node_id; + return kTfLiteOk; +} + +namespace { +TfLiteIntArray* convertVectorToTfLiteIntArray(const std::vector& x) { + TfLiteIntArray* lite = TfLiteIntArrayCreate(x.size()); + for (size_t i = 0; i < x.size(); i++) lite->data[i] = x[i]; + return lite; +} +} // namespace + +TfLiteStatus Interpreter::AllocateTensors() { + next_allocate_node_id_ = 0; + TF_LITE_ENSURE_OK(&context_, arena_.Clear()); + TF_LITE_ENSURE_OK(&context_, persistent_arena_.Clear()); + allocs_and_refcounts_.clear(); + return AllocateTensorsWhoseSizesAreKnown(); +} + +TfLiteStatus Interpreter::AddNodeWithParameters( + const std::vector& inputs, const std::vector& outputs, + const char* init_data, size_t init_data_size, void* builtin_data, + const TfLiteRegistration* registration, int* node_index) { + invokable_ = false; + + std::unique_ptr builtin_data_deleter(builtin_data, + free); + + TF_LITE_ENSURE_OK(&context_, CheckTensorIndices("node inputs", inputs.data(), + inputs.size())); + TF_LITE_ENSURE_OK( + &context_, + CheckTensorIndices("node outputs", outputs.data(), outputs.size())); + + if (node_index) *node_index = nodes_and_registration_.size(); + nodes_and_registration_.resize(nodes_and_registration_.size() + 1); + auto& node_and_reg = nodes_and_registration_.back(); + TfLiteNode& node = node_and_reg.first; + if (node.inputs) TfLiteIntArrayFree(node.inputs); + if (node.outputs) TfLiteIntArrayFree(node.outputs); + if (node.temporaries) TfLiteIntArrayFree(node.temporaries); + + // NOTE, here we are not using move semantics yet, since our internal + // representation isn't std::vector, but in the future we would like to avoid + // copies, so we want the interface to take r-value references now. + node.inputs = convertVectorToTfLiteIntArray(inputs); + node.outputs = convertVectorToTfLiteIntArray(outputs); + node.temporaries = TfLiteIntArrayCreate(0); + if (init_data) { + node.user_data = OpInit(*registration, init_data, init_data_size); + } else { + node.user_data = + OpInit(*registration, + reinterpret_cast(builtin_data_deleter.get()), 0); + } + node.builtin_data = builtin_data_deleter.release(); + node_and_reg.second = *registration; + return kTfLiteOk; +} + +TfLiteStatus Interpreter::ResizeInputTensor(int tensor_index, + const std::vector& dims) { + // TODO(aselle): All bounds checks can be implemented as one-sided bounds + // checks by casting to unsigned for efficiency. Profile before doing this. 
+ + TF_LITE_ENSURE(&context_, + tensor_index < context_.tensors_size && tensor_index >= 0); + invokable_ = false; + TfLiteIntArray* dims_lite = convertVectorToTfLiteIntArray(dims); + return ResizeTensorImpl(&context_.tensors[tensor_index], dims_lite); +} + +TfLiteStatus Interpreter::Invoke() { + if (!consistent_) { + ReportError(&context_, "Invoke called on model that is not consistent."); + return kTfLiteError; + } + if (!invokable_) { + ReportError(&context_, "Invoke called on model that is not ready."); + return kTfLiteError; + } + + TfLiteStatus status = kTfLiteOk; + if (nnapi_delegate_) { + if (AllocateTensorsWhoseSizesAreKnown() == kTfLiteError) { + return kTfLiteError; + } + if (next_allocate_node_id_ == nodes_and_registration_.size()) { + TF_LITE_ENSURE_OK(&context_, nnapi_delegate_->Invoke(this)); + return kTfLiteOk; + } else { + // TODO(aselle): In the future, we would like this to be an + // automatic tflite CPU fallback. + ReportError(&context_, + "NNAPI was requested, but dependent sized tensors " + "being used.\n"); + return kTfLiteError; + } + } + + for (int i = 0; i < nodes_and_registration_.size(); i++) { + // Ensure we have allocated up to this node. The point of this is to + // allocate as much as possible before running any evaluation, but + // dynamic shapes can prevent this from being possible. + if (i >= next_allocate_node_id_) { + if (AllocateTensorsWhoseSizesAreKnown() == kTfLiteError) { + return kTfLiteError; + } + } + TfLiteNode& node = nodes_and_registration_[i].first; + const TfLiteRegistration& registration = nodes_and_registration_[i].second; + if (OpInvoke(registration, &node) == kTfLiteError) { + status = kTfLiteError; + } + } + return status; +} + +TfLiteStatus Interpreter::ResizeTensor(TfLiteContext* context, + TfLiteTensor* tensor, + TfLiteIntArray* new_size) { + // Note here that context->impl_ is recovering the this pointer for an + // instance of Interpreter to call into the member function ResizeTensorImpl + // (this function is static). + return static_cast(context->impl_) + ->ResizeTensorImpl(tensor, new_size); +} + +void Interpreter::ReportErrorImpl(const char* format, va_list args) { + error_reporter_->Report(format, args); +} + +void Interpreter::ReportError(TfLiteContext* context, const char* format, ...) { + va_list args; + va_start(args, format); + auto* f = static_cast(context->impl_); + // Note here that context->impl_ is recovering the this pointer for an + // instance of Interpreter to call into the member function ReportErrorImpl + // (this function is static). + f->ReportErrorImpl(format, args); + va_end(args); +} + +TfLiteStatus Interpreter::AddTensors(int tensors_to_add, + int* first_new_tensor_index) { + int base_index = tensors_.size(); + if (first_new_tensor_index) *first_new_tensor_index = base_index; + tensors_.resize(tensors_.size() + tensors_to_add); + for (int i = base_index; i < tensors_.size(); i++) { + memset(&tensors_[i], 0, sizeof(tensors_[i])); + } + context_.tensors = tensors_.data(); + context_.tensors_size = tensors_.size(); + return kTfLiteOk; +} + +TfLiteStatus Interpreter::AddTensors(TfLiteContext* context, int tensors_to_add, + int* first_new_tensor_index) { + // Note here that context->impl_ is recovering the this pointer for an + // instance of Interpreter to call into the member function AddTensors + // (this function is static). 
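[Editor's illustration] The comments in these static members describe recovering the C++ this pointer from context->impl_ so plain C function pointers can dispatch into member functions. The stand-alone sketch below shows that trampoline pattern in isolation, with hypothetical names; it is not code from this commit.

// Generic sketch of the "impl_ trampoline": a C-style struct carries an opaque
// pointer back to its C++ owner so function-pointer callbacks can forward into
// member functions.
#include <cstdio>

struct CContext {
  void* impl_;                        // opaque pointer to the C++ owner
  void (*Log)(CContext* ctx, int v);  // C-style callback
};

class Owner {
 public:
  Owner() {
    ctx_.impl_ = static_cast<void*>(this);
    ctx_.Log = LogTrampoline;
  }
  CContext* context() { return &ctx_; }

 private:
  // Static, so it matches the plain function-pointer signature.
  static void LogTrampoline(CContext* ctx, int v) {
    static_cast<Owner*>(ctx->impl_)->LogImpl(v);
  }
  void LogImpl(int v) { std::printf("value=%d\n", v); }

  CContext ctx_;
};

int main() {
  Owner owner;
  CContext* ctx = owner.context();
  ctx->Log(ctx, 42);  // dispatches into Owner::LogImpl via impl_
  return 0;
}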
+ return static_cast(context->impl_) + ->AddTensors(tensors_to_add, first_new_tensor_index); +} + +TfLiteStatus Interpreter::SetTensorParametersReadOnly( + int tensor_index, TfLiteType type, const char* name, + const std::vector& dims, TfLiteQuantizationParams quantization, + const char* buffer, size_t bytes, const Allocation* allocation) { + TF_LITE_ENSURE(&context_, + tensor_index < context_.tensors_size && tensor_index >= 0); + // For most tensors we know exactly how much memory is necessary so we can + // ensure the buffer is large enough. However, we need to skip string tensors + // because their sizes change with the contents of the individual strings. + if (type != kTfLiteString) { + size_t required_bytes; + TF_LITE_ENSURE_OK(&context_, BytesRequired(type, dims.data(), dims.size(), + &required_bytes)); + TF_LITE_ENSURE_EQ(&context_, required_bytes, bytes); + } + invokable_ = false; + TfLiteTensorReset(type, name, convertVectorToTfLiteIntArray(dims), + quantization, const_cast(buffer), bytes, + kTfLiteMmapRo, allocation, &context_.tensors[tensor_index]); + return kTfLiteOk; +} + +// Set description of inputs/outputs/data/fptrs for node `node_index`. +// This variant assumes an external buffer has been allocated of size +// bytes. The lifetime of buffer must be ensured to be greater or equal +// to Interpreter. +TfLiteStatus Interpreter::SetTensorParametersReadWrite( + int tensor_index, TfLiteType type, const char* name, + const std::vector& dims, TfLiteQuantizationParams quantization) { + invokable_ = false; + TF_LITE_ENSURE(&context_, + tensor_index < context_.tensors_size && tensor_index >= 0); + size_t required_bytes = 0; + if (type != kTfLiteString) { + // These types will be allocated in our arena so we need to record how + // many bytes we will need based on the dimensions. String tensors are + // allocated dynamically and we can't know ahead of time how much space + // they will require. + TF_LITE_ENSURE_OK(&context_, BytesRequired(type, dims.data(), dims.size(), + &required_bytes)); + } + TfLiteTensorReset(type, name, convertVectorToTfLiteIntArray(dims), + quantization, + /*buffer=*/nullptr, required_bytes, + type == kTfLiteString ? kTfLiteDynamic : kTfLiteArenaRw, + nullptr, &context_.tensors[tensor_index]); + return kTfLiteOk; +} + +TfLiteStatus Interpreter::ResizeTensorImpl(TfLiteTensor* tensor, + TfLiteIntArray* new_size) { + // Note that in theory we could resize kTfLiteArenaRwPersistent tensors too. + if (tensor->allocation_type == kTfLiteArenaRw || + tensor->allocation_type == kTfLiteDynamic) { + if (tensor->type != kTfLiteString) { + size_t bytesRequired; + TfLiteStatus status = BytesRequired(tensor->type, new_size->data, + new_size->size, &bytesRequired); + if (status != kTfLiteOk) { + TfLiteIntArrayFree(new_size); + return kTfLiteError; + } + tensor->bytes = bytesRequired; + } + if (tensor->dims) TfLiteIntArrayFree(tensor->dims); + tensor->dims = new_size; + + if (tensor->allocation_type != kTfLiteDynamic) { + tensor->data.raw = nullptr; + } + } else { + // kTfLiteMmapRo tensors are stored in the flatbuffer and are therefore + // of fixed size. + TfLiteIntArrayFree(new_size); + ReportError(&context_, "Attempting to resize a fixed-size tensor."); + return kTfLiteError; + } + return kTfLiteOk; +} + +void Interpreter::UseNNAPI(bool enable) { + // TODO(aselle): This is a workaround for finding if NNAPI exists. + // We also need to make sure getLibraryHandle() is renamed to be NNAPI + // prefixed. 
+  if (!NNAPIExists()) enable = false;
+  if (!enable) {
+    nnapi_delegate_.reset();
+  } else if (!nnapi_delegate_) {
+    nnapi_delegate_.reset(new NNAPIDelegate);
+  }
+}
+
+void Interpreter::SetNumThreads(int num_threads) {
+  // TODO(ahentz): this forces us to link against gemmlowp even when the ops
+  // don't use it. We should implement some dynamic mechanism for this sort of
+  // library-specific initialization.
+  tflite::gemm_support::SetMaxNumThreads(&context_, num_threads);
+}
+
+}  // namespace tflite
diff --git a/tensorflow/contrib/lite/interpreter.h b/tensorflow/contrib/lite/interpreter.h
new file mode 100644
index 0000000000..8bf60e91f7
--- /dev/null
+++ b/tensorflow/contrib/lite/interpreter.h
@@ -0,0 +1,376 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// Main abstraction controlling the tflite interpreter.
+// See context.h for the API for defining operations (TfLiteRegistration).
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_INTERPRETER_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_INTERPRETER_H_
+
+#include <cstdio>
+#include <cstdlib>
+#include <vector>
+#include "tensorflow/contrib/lite/allocation.h"
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/error_reporter.h"
+#include "tensorflow/contrib/lite/simple_memory_arena.h"
+#include "tensorflow/core/platform/platform.h"
+
+namespace tflite {
+
+// Map statically from a C++ type to a TfLiteType (used below for safe casts).
+template <class T>
+constexpr TfLiteType typeToTfLiteType() {
+  return kTfLiteNoType;
+}
+template <>
+constexpr TfLiteType typeToTfLiteType<int>() {
+  return kTfLiteInt32;
+}
+template <>
+constexpr TfLiteType typeToTfLiteType<int64_t>() {
+  return kTfLiteInt64;
+}
+template <>
+constexpr TfLiteType typeToTfLiteType<float>() {
+  return kTfLiteFloat32;
+}
+template <>
+constexpr TfLiteType typeToTfLiteType<unsigned char>() {
+  return kTfLiteUInt8;
+}
+
+struct ArenaAllocRefCount {
+  ArenaAllocRefCount() : alloc(), count(0) {}
+
+  ArenaAlloc alloc;
+  int count;
+};
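The typeToTfLiteType<T>() specializations above form a compile-time map from C++ element types to TfLiteType tags, and typed_tensor<T>() further down uses that map to refuse mismatched casts. A small sketch of how the map can be exercised with static_assert, assuming the specializations read as reconstructed above and that the TfLiteType constants from context.h are visible at global scope:

    #include <cstdint>
    #include "tensorflow/contrib/lite/interpreter.h"

    // Each supported element type maps to its tag; an unmapped type falls back to
    // kTfLiteNoType, so typed_tensor<T>() would return nullptr rather than
    // reinterpret memory as the wrong type.
    static_assert(tflite::typeToTfLiteType<int>() == kTfLiteInt32, "int maps to Int32");
    static_assert(tflite::typeToTfLiteType<int64_t>() == kTfLiteInt64, "int64_t maps to Int64");
    static_assert(tflite::typeToTfLiteType<float>() == kTfLiteFloat32, "float maps to Float32");
    static_assert(tflite::typeToTfLiteType<unsigned char>() == kTfLiteUInt8, "uchar maps to UInt8");
    static_assert(tflite::typeToTfLiteType<double>() == kTfLiteNoType, "double is unmapped");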
+// Forward declare since NNAPIDelegate uses Interpreter.
+class NNAPIDelegate;
+
+// An interpreter for a graph of nodes that input and output from tensors.
+// Each node of the graph processes a set of input tensors and produces a
+// set of output Tensors. All inputs/output tensors are referenced by index.
+//
+// Usage:
+//
+// -- Create basic model
+// Interpreter foo(2, 1);
+// foo.SetTensorParametersReadWrite(0, ...);
+// foo.SetTensorParametersReadOnly(1, ...);
+// foo.SetNodeParameters(0, ...)
+//
+// -- Resize input array to 1 length.
+// foo.ResizeInputTensor(0, 1);
+// foo.AllocateTensors();
+// -- Install array data
+// foo.typed_tensor<float>(0)[0] = 3;
+// foo.Invoke();
+// foo.typed_tensor<float>(0)[0] = 4;
+// foo.Invoke();
+// -- Resize input array and set data.
+// foo.ResizeInputTensor(0, 2);
+// foo.AllocateTensors();
+// foo.typed_tensor<float>(0)[0] = 4;
+// foo.typed_tensor<float>(0)[1] = 8;
+// foo.Invoke();
+//
+
+class Interpreter {
+ public:
+  // Instantiate an interpreter. All errors associated with reading and
+  // processing this model will be forwarded to the error_reporter object.
+  //
+  // Note, if error_reporter is nullptr, then a default StderrReporter is
+  // used.
+  explicit Interpreter(ErrorReporter* error_reporter = DefaultErrorReporter());
+
+  ~Interpreter();
+
+  Interpreter(const Interpreter&) = delete;
+  Interpreter& operator=(const Interpreter&) = delete;
+
+  // Functions to build interpreter
+
+  // Provide a list of tensor indexes that are inputs to the model.
+  // Each index is bound check and this modifies the consistent_ flag of the
+  // interpreter.
+  TfLiteStatus SetInputs(std::vector<int> inputs);
+
+  // Provide a list of tensor indexes that are outputs to the model
+  // Each index is bound check and this modifies the consistent_ flag of the
+  // interpreter.
+  TfLiteStatus SetOutputs(std::vector<int> outputs);
+
+  // Adds a node with the given parameters and returns the index of the new
+  // node in `node_index` (optionally). Interpreter will take ownership of
+  // `builtin_data` and destroy it with `delete`. Ownership of 'init_data'
+  // remains with the caller.
+  TfLiteStatus AddNodeWithParameters(const std::vector<int>& inputs,
+                                     const std::vector<int>& outputs,
+                                     const char* init_data,
+                                     size_t init_data_size, void* builtin_data,
+                                     const TfLiteRegistration* registration,
+                                     int* node_index = nullptr);
+
+  // Adds `tensors_to_add` tensors, preserving pre-existing Tensor entries.
+  // The value pointed to by `first_new_tensor_index` will be set to the
+  // index of the first new tensor if `first_new_tensor_index` is non-null.
+  TfLiteStatus AddTensors(int tensors_to_add,
+                          int* first_new_tensor_index = nullptr);
+
+  // Set description of inputs/outputs/data/fptrs for node `node_index`.
+  // This variant assumes an external buffer has been allocated of size
+  // bytes. The lifetime of buffer must be ensured to be greater or equal
+  // to Interpreter.
+  TfLiteStatus SetTensorParametersReadOnly(
+      int tensor_index, TfLiteType type, const char* name,
+      const std::vector<int>& dims, TfLiteQuantizationParams quantization,
+      const char* buffer, size_t bytes, const Allocation* allocation = nullptr);
+
+  // Set description of inputs/outputs/data/fptrs for node `node_index`.
+  // This variant assumes an external buffer has been allocated of size
+  // bytes. The lifetime of buffer must be ensured to be greater or equal
+  // to Interpreter.
+  TfLiteStatus SetTensorParametersReadWrite(
+      int tensor_index, TfLiteType type, const char* name,
+      const std::vector<int>& dims, TfLiteQuantizationParams quantization);
+
+  // Functions to access tensor data
+
+  // Read only access to list of inputs.
+  const std::vector<int>& inputs() const { return inputs_; }
+
+  // Return the name of a given input. The given index must be between 0 and
+  // inputs().size().
+  const char* GetInputName(int index) const {
+    return context_.tensors[inputs_[index]].name;
+  }
+
+  // Read only access to list of outputs.
+  const std::vector<int>& outputs() const { return outputs_; }
+
+  // Return the name of a given output. The given index must be between 0 and
+  // outputs().size().
+  const char* GetOutputName(int index) const {
+    return context_.tensors[outputs_[index]].name;
+  }
+
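A hedged end-to-end sketch of the builder API declared above, modeled on the interpreter tests added later in this change; the identity-op registration, tensor names, and standalone main() are illustrative only, not part of the patch:

    #include <cstdio>
    #include "tensorflow/contrib/lite/interpreter.h"

    int main() {
      tflite::Interpreter interpreter;
      interpreter.AddTensors(2);
      interpreter.SetInputs({0});
      interpreter.SetOutputs({1});

      TfLiteQuantizationParams quant = {};  // default (no) quantization for float tensors
      interpreter.SetTensorParametersReadWrite(0, kTfLiteFloat32, "in", {3}, quant);
      interpreter.SetTensorParametersReadWrite(1, kTfLiteFloat32, "out", {3}, quant);

      // Illustrative identity op: the output takes the input's shape and copies its data.
      TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr};
      reg.prepare = [](TfLiteContext* context, TfLiteNode* node) {
        TfLiteTensor* in = &context->tensors[node->inputs->data[0]];
        TfLiteTensor* out = &context->tensors[node->outputs->data[0]];
        return context->ResizeTensor(context, out, TfLiteIntArrayCopy(in->dims));
      };
      reg.invoke = [](TfLiteContext* context, TfLiteNode* node) {
        TfLiteTensor* in = &context->tensors[node->inputs->data[0]];
        TfLiteTensor* out = &context->tensors[node->outputs->data[0]];
        for (int i = 0; i < in->dims->data[0]; i++) out->data.f[i] = in->data.f[i];
        return kTfLiteOk;
      };
      interpreter.AddNodeWithParameters({0}, {1}, nullptr, 0, nullptr, &reg);

      if (interpreter.AllocateTensors() != kTfLiteOk) return 1;
      float* input = interpreter.typed_input_tensor<float>(0);
      input[0] = 1.f;
      input[1] = 2.f;
      input[2] = 3.f;
      if (interpreter.Invoke() != kTfLiteOk) return 1;
      std::printf("out[0] = %f\n", interpreter.typed_output_tensor<float>(0)[0]);
      return 0;
    }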
+  // Return the number of tensors in the model.
+  int tensors_size() const { return context_.tensors_size; }
+
+  // Return the number of ops in the model.
+  int nodes_size() const { return nodes_and_registration_.size(); }
+
+  // Get a tensor data structure.
+  // TODO(aselle): Create a safe ArrayHandle interface to avoid exposing this
+  // read/write access to structure
+  TfLiteTensor* tensor(int tensor_index) {
+    if (tensor_index >= context_.tensors_size || tensor_index < 0)
+      return nullptr;
+    return &context_.tensors[tensor_index];
+  }
+
+  // Get a pointer to an operation and registration data structure if in bounds.
+  // TODO(aselle): Create a safe ArrayHandle interface to avoid exposing this
+  // read/write access to structure
+  const std::pair<TfLiteNode, TfLiteRegistration>* node_and_registration(
+      int node_index) {
+    if (node_index >= nodes_and_registration_.size() || node_index < 0)
+      return nullptr;
+    return &nodes_and_registration_[node_index];
+  }
+
+  // Perform a checked cast to the appropriate tensor type.
+  template <class T>
+  T* typed_tensor(int tensor_index) {
+    if (TfLiteTensor* tensor_ptr = tensor(tensor_index)) {
+      if (tensor_ptr->type == typeToTfLiteType<T>()) {
+        return reinterpret_cast<T*>(tensor_ptr->data.raw);
+      }
+    }
+    return nullptr;
+  }
+
+  // Return a pointer into the data of a given input tensor. The given index
+  // must be between 0 and inputs().size().
+  template <class T>
+  T* typed_input_tensor(int index) {
+    return typed_tensor<T>(inputs_[index]);
+  }
+
+  // Return a pointer into the data of a given output tensor. The given index
+  // must be between 0 and outputs().size().
+  template <class T>
+  T* typed_output_tensor(int index) {
+    return typed_tensor<T>(outputs_[index]);
+  }
+
+  // Change the dimensionality of a given tensor. Note, this is only acceptable
+  // for tensor indices that are inputs.
+  // Returns status of failure or success.
+  // TODO(aselle): Consider implementing ArraySlice equivalent to make this
+  // more adept at accepting data without an extra copy. Use absl::ArraySlice
+  // if our partners determine that dependency is acceptable.
+  TfLiteStatus ResizeInputTensor(int tensor_index,
+                                 const std::vector<int>& dims);
+
+  // Update allocations for all tensors. This will redim dependent tensors using
+  // the input tensor dimensionality as given. This is relatively expensive.
+  // If you know that your sizes are not changing, you need not call this.
+
+  // Returns status of success or failure.
+  // TODO(aselle): Madde
+  TfLiteStatus AllocateTensors();
+
+  // Invoke the interpreter (run the whole graph in dependency order).
+  //
+  // NOTE: It is possible that the interpreter is not in a ready state
+  // to evaluate (i.e. if a ResizeTensor() has been performed without an
+  // AllocateTensors().
+  // Returns status of success or failure.
+  TfLiteStatus Invoke();
+
+  // Enable or disable the NN API (true to enable)
+  void UseNNAPI(bool enable);
+
+  // Set the number of threads available to the interpreter.
+  void SetNumThreads(int num_threads);
+
+ private:
+  // Give 'op_reg' a chance to initialize itself using the contents of
+  // 'buffer'.
+  void* OpInit(const TfLiteRegistration& op_reg, const char* buffer,
+               size_t length) {
+    if (op_reg.init == nullptr) return nullptr;
+    return op_reg.init(&context_, buffer, length);
+  }
+
+  // Let 'op_reg' release any memory it might have allocated via 'OpInit'.
+  void OpFree(const TfLiteRegistration& op_reg, void* buffer) {
+    if (op_reg.free == nullptr) return;
+    if (buffer) {
+      op_reg.free(&context_, buffer);
+    }
+  }
+
+  // Prepare the given 'node' for execution.
+ TfLiteStatus OpPrepare(const TfLiteRegistration& op_reg, TfLiteNode* node) { + if (op_reg.prepare == nullptr) return kTfLiteOk; + return op_reg.prepare(&context_, node); + } + + // Invoke the operator represented by 'node'. + TfLiteStatus OpInvoke(const TfLiteRegistration& op_reg, TfLiteNode* node) { + if (op_reg.invoke == nullptr) return kTfLiteError; + return op_reg.invoke(&context_, node); + } + + // Allocate tensors whose sizes are known in order of nodes. Discontinue when + // we encounter a node that has a dynamic output tensor. + TfLiteStatus AllocateTensorsWhoseSizesAreKnown(); + + // Tensors needed by the interpreter. Use `AddTensors` to add more blank + // tensor entries. Note, `tensors_.data()` needs to be synchronized to the + // `context_` whenever this std::vector is reallocated. Currently this + // only happens in `AddTensors()`. + std::vector tensors_; + + // Check if an array of tensor indices are valid with respect to the Tensor + // array. + // NOTE: this changes consistent_ to be false if indices are out of bounds. + TfLiteStatus CheckTensorIndices(const char* label, const int* indices, + int length); + + // Compute the number of bytes required to represent a tensor with dimensions + // specified by the array dims (of length dims_size). Returns the status code + // and bytes. + TfLiteStatus BytesRequired(TfLiteType type, const int* dims, int dims_size, + size_t* bytes); + + // Request an tensor be resized implementation. + TfLiteStatus ResizeTensorImpl(TfLiteTensor* tensor, TfLiteIntArray* new_size); + + // Report a detailed error string (will be printed to stderr). + // TODO(aselle): allow user of class to provide alternative destinations. + void ReportErrorImpl(const char* format, va_list args); + + // Entry point for C node plugin API to request an tensor be resized. + static TfLiteStatus ResizeTensor(TfLiteContext* context, TfLiteTensor* tensor, + TfLiteIntArray* new_size); + // Entry point for C node plugin API to report an error. + static void ReportError(TfLiteContext* context, const char* format, ...); + + // Entry point for C node plugin API to add new tensors. + static TfLiteStatus AddTensors(TfLiteContext* context, int tensors_to_add, + int* first_new_tensor_index); + + // A pure C data structure used to communicate with the pure C plugin + // interface. To avoid copying tensor metadata, this is also the definitive + // structure to store tensors. + TfLiteContext context_; + + // Node inputs/outputs are stored in TfLiteNode and TfLiteRegistration stores + // function pointers to actual implementation. + std::vector> + nodes_and_registration_; + + // Raw memory buffer that is allocated for all temporary and graph outputs. + // that are declared kTfLiteArenaRw. + SimpleMemoryArena arena_; + + // Raw memory buffer that is allocated for persistent tensors that are + // declared as kTfLiteArenaRwPersistent. + SimpleMemoryArena persistent_arena_; + + // Stores allocation and reference counts of all tensors. + std::vector allocs_and_refcounts_; + + // Whether the model is consistent. That is to say if the inputs and outputs + // of every node and the global inputs and outputs are valid indexes into + // the tensor array. + bool consistent_ = true; + + // Whether the model is safe to invoke (if any errors occurred this + // will be false). + bool invokable_ = false; + + // Array of indices representing the tensors that are inputs to the + // interpreter. 
+ std::vector inputs_; + + // Array of indices representing the tensors that are outputs to the + // interpreter. + std::vector outputs_; + + // The error reporter delegate that tflite will forward queries errors to. + ErrorReporter* error_reporter_; + + // Next node to allocate output tensors. + // During Invoke(), Interpreter will allocate input tensors first, which are + // known to be fixed size. Then it will allocate outputs from nodes as many + // as possible. When there is a node that produces dynamic sized tensor. + // Intepreter will stop allocating tensors, set the value of next allocate + // node id, and execute the node to generate the output tensor before continue + // to allocate successors. This process repeats until all nodes are executed. + // NOTE: this relies on the order of nodes that is in topological order. + int next_allocate_node_id_; + + // Whether to delegate to NN API + std::unique_ptr nnapi_delegate_; +}; + +} // namespace tflite +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_INTERPRETER_H_ diff --git a/tensorflow/contrib/lite/interpreter_test.cc b/tensorflow/contrib/lite/interpreter_test.cc new file mode 100644 index 0000000000..edff210943 --- /dev/null +++ b/tensorflow/contrib/lite/interpreter_test.cc @@ -0,0 +1,526 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/interpreter.h" +#include +#include "tensorflow/contrib/lite/error_reporter.h" +#include "tensorflow/contrib/lite/string_util.h" + +namespace tflite { +namespace { + +// Make an interpreter that has no tensors and no nodes +TEST(BasicInterpreter, ZeroInterpreter) { + Interpreter interpreter; + interpreter.SetInputs({}); + interpreter.SetOutputs({}); + ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + ASSERT_EQ(interpreter.Invoke(), kTfLiteOk); +} + +// Test various error conditions. +TEST(BasicInterpreter, InvokeInvalidModel) { + Interpreter interpreter; + ASSERT_NE(interpreter.Invoke(), kTfLiteOk); + ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + ASSERT_EQ(interpreter.Invoke(), kTfLiteOk); +} + +// Test size accesser functions. +TEST(BasicInterpreter, TestSizeFunctions) { + Interpreter interpreter; + int base_index; + ASSERT_EQ(interpreter.nodes_size(), 0); + ASSERT_EQ(interpreter.tensors_size(), 0); + ASSERT_EQ(interpreter.AddTensors(2, &base_index), kTfLiteOk); + ASSERT_EQ(interpreter.tensors_size(), 2); + ASSERT_EQ(base_index, 0); + ASSERT_EQ(interpreter.AddTensors(3, &base_index), kTfLiteOk); + ASSERT_EQ(interpreter.tensors_size(), 5); + ASSERT_EQ(interpreter.AddTensors(1), kTfLiteOk); + ASSERT_EQ(interpreter.tensors_size(), 6); + ASSERT_EQ(base_index, 2); +} + +// Test if invalid indices make a model inconsistent (and conversely if +// valid indices keep a model consistent). 
+TEST(BasicInterpreter, InconsistentModel) { + // Invalid inputs + { + Interpreter interpreter; + ASSERT_NE(interpreter.SetInputs({5}), kTfLiteOk); + ASSERT_NE(interpreter.AllocateTensors(), kTfLiteOk); + ASSERT_NE(interpreter.Invoke(), kTfLiteOk); + ASSERT_EQ(interpreter.inputs(), std::vector()); + } + // Invalid outputs + { + Interpreter interpreter; + ASSERT_NE(interpreter.SetOutputs({5}), kTfLiteOk); + ASSERT_NE(interpreter.AllocateTensors(), kTfLiteOk); + ASSERT_NE(interpreter.Invoke(), kTfLiteOk); + ASSERT_EQ(interpreter.outputs(), std::vector()); + } + // Invalid node inputs + { + Interpreter interpreter; + TfLiteRegistration registration = {nullptr, nullptr, nullptr, nullptr}; + ASSERT_NE(interpreter.AddNodeWithParameters({3}, {0}, nullptr, 0, nullptr, + ®istration), + kTfLiteOk); + ASSERT_NE(interpreter.AllocateTensors(), kTfLiteOk); + ASSERT_NE(interpreter.Invoke(), kTfLiteOk); + } + // Valid inputs and outputs and a node with valid inputs and outputs + { + Interpreter interpreter; + ASSERT_EQ(interpreter.AddTensors(2), kTfLiteOk); + TfLiteRegistration registration = {nullptr, nullptr, nullptr, nullptr}; + ASSERT_EQ(interpreter.SetInputs({0}), kTfLiteOk); + ASSERT_EQ(interpreter.SetOutputs({0}), kTfLiteOk); + ASSERT_EQ(interpreter.AddNodeWithParameters({0}, {1}, nullptr, 0, nullptr, + ®istration), + kTfLiteOk); + } +} + +// Make an interpreter that has one tensor but no ops +TEST(BasicInterpreter, CheckAllocate) { + struct { + TfLiteType type; + size_t size; + } cases[] = { + {kTfLiteFloat32, sizeof(float)}, + {kTfLiteInt32, sizeof(int32_t)}, + {kTfLiteUInt8, sizeof(uint8_t)}, + {kTfLiteInt64, sizeof(int64_t)}, + }; + + for (auto test : cases) { + Interpreter interpreter; + ASSERT_EQ(interpreter.AddTensors(2), kTfLiteOk); + interpreter.SetInputs({0, 1}); + interpreter.SetOutputs({}); + TfLiteQuantizationParams quant; + + interpreter.SetTensorParametersReadWrite(0, test.type, "", {3}, quant); + interpreter.SetTensorParametersReadWrite(1, test.type, "", {4}, quant); + ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + ASSERT_EQ(interpreter.tensor(0)->bytes, 3 * test.size); + ASSERT_NE(interpreter.tensor(0)->data.raw, nullptr); + ASSERT_EQ(interpreter.tensor(1)->bytes, 4 * test.size); + ASSERT_NE(interpreter.tensor(1)->data.raw, nullptr); + } +} + +TEST(BasicInterpreter, CheckResize) { + const float floats[] = {-3., -4.}; + const int32_t int32s[] = {-3, -4}; + const uint8_t uint8s[] = {3, 4}; + const int64_t int64s[] = {6, -7}; + + struct { + TfLiteType type; + size_t size; + const char* array; + } cases[] = { + {kTfLiteFloat32, sizeof(float), reinterpret_cast(floats)}, + {kTfLiteInt32, sizeof(int32_t), reinterpret_cast(int32s)}, + {kTfLiteUInt8, sizeof(uint8_t), reinterpret_cast(uint8s)}, + {kTfLiteInt64, sizeof(int64_t), reinterpret_cast(int64s)}, + }; + + for (auto test : cases) { + Interpreter interpreter; + + ASSERT_EQ(interpreter.AddTensors(2), kTfLiteOk); + interpreter.SetInputs({0, 1}); + interpreter.SetOutputs({}); + TfLiteQuantizationParams quant; + + ASSERT_EQ( + interpreter.SetTensorParametersReadWrite(0, test.type, "", {3}, quant), + kTfLiteOk); + ASSERT_EQ(interpreter.SetTensorParametersReadOnly( + 1, test.type, "", {2}, quant, test.array, 2 * test.size), + kTfLiteOk); + ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + ASSERT_EQ(interpreter.ResizeInputTensor(0, {1, 2}), kTfLiteOk); + // Resizing a mmapped tensor is not allowed and should produce error. 
+ ASSERT_NE(interpreter.ResizeInputTensor(1, {3}), kTfLiteOk); + // Set the tensor to be mmapped but with a buffer size that is insufficient + // to match the dimensionality. + ASSERT_NE(interpreter.SetTensorParametersReadOnly( + 1, test.type, "", {2}, quant, test.array, 1 * test.size), + kTfLiteOk); + // Allocating should work since we should have our last correct array + // values in place. + ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + } +} + +TEST(BasicInterpreter, CheckAlignment) { + struct { + TfLiteType type; + } cases[] = { + {kTfLiteFloat32}, + {kTfLiteInt32}, + {kTfLiteUInt8}, + {kTfLiteInt64}, + }; + + for (auto test : cases) { + Interpreter interpreter; + + ASSERT_EQ(interpreter.AddTensors(4), kTfLiteOk); + + for (int i = 0; i < 4; i++) { + TfLiteQuantizationParams quant; + interpreter.SetTensorParametersReadWrite(i, test.type, "", {2 * i + 1}, + quant); + } + interpreter.AllocateTensors(); + for (int i = 0; i < 4; i++) { + const TfLiteTensor& tensor = *interpreter.tensor(i); + ASSERT_EQ(reinterpret_cast(tensor.data.raw) % 4, 0); + } + } +} + +TEST(BasicInterpreter, CheckArenaAllocation) { + Interpreter interpreter; + ASSERT_EQ(interpreter.AddTensors(10), kTfLiteOk); + + TfLiteQuantizationParams quant; + TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr}; + + std::vector sizes{2048, 4096, 1023, 2047, 1021, + 2047, 1023, 2046, 1021, 2048}; + for (int i = 0; i < sizes.size(); ++i) { + interpreter.SetTensorParametersReadWrite(i, kTfLiteUInt8, "", {sizes[i]}, + quant); + } + interpreter.SetInputs({0, 1}); + interpreter.SetOutputs({9, 4}); + interpreter.AddNodeWithParameters({0, 1}, {2, 3}, nullptr, 0, nullptr, ®); + interpreter.AddNodeWithParameters({2, 1}, {4, 5}, nullptr, 0, nullptr, ®); + interpreter.AddNodeWithParameters({4, 3}, {6, 7}, nullptr, 0, nullptr, ®); + interpreter.AddNodeWithParameters({6, 5}, {8}, nullptr, 0, nullptr, ®); + interpreter.AddNodeWithParameters({8, 7}, {9}, nullptr, 0, nullptr, ®); + + ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + + ASSERT_EQ(interpreter.tensor(0)->data.raw, interpreter.tensor(4)->data.raw); + ASSERT_EQ(interpreter.tensor(1)->data.raw, interpreter.tensor(7)->data.raw); + + ASSERT_LT(interpreter.tensor(4)->data.raw, interpreter.tensor(1)->data.raw); + ASSERT_LT(interpreter.tensor(6)->data.raw, interpreter.tensor(1)->data.raw); + ASSERT_LT(interpreter.tensor(0)->data.raw, interpreter.tensor(1)->data.raw); + + ASSERT_LT(interpreter.tensor(0)->data.raw, interpreter.tensor(3)->data.raw); + ASSERT_LT(interpreter.tensor(1)->data.raw, interpreter.tensor(3)->data.raw); + ASSERT_LT(interpreter.tensor(2)->data.raw, interpreter.tensor(3)->data.raw); + ASSERT_LT(interpreter.tensor(4)->data.raw, interpreter.tensor(3)->data.raw); + ASSERT_LT(interpreter.tensor(6)->data.raw, interpreter.tensor(3)->data.raw); + ASSERT_LT(interpreter.tensor(7)->data.raw, interpreter.tensor(3)->data.raw); + ASSERT_LT(interpreter.tensor(8)->data.raw, interpreter.tensor(3)->data.raw); + ASSERT_LT(interpreter.tensor(9)->data.raw, interpreter.tensor(3)->data.raw); + + ASSERT_LT(interpreter.tensor(0)->data.raw, interpreter.tensor(5)->data.raw); + ASSERT_LT(interpreter.tensor(1)->data.raw, interpreter.tensor(5)->data.raw); + ASSERT_LT(interpreter.tensor(2)->data.raw, interpreter.tensor(5)->data.raw); + ASSERT_LT(interpreter.tensor(3)->data.raw, interpreter.tensor(5)->data.raw); + ASSERT_LT(interpreter.tensor(4)->data.raw, interpreter.tensor(5)->data.raw); + ASSERT_LT(interpreter.tensor(6)->data.raw, interpreter.tensor(5)->data.raw); + 
ASSERT_LT(interpreter.tensor(7)->data.raw, interpreter.tensor(5)->data.raw); + ASSERT_LT(interpreter.tensor(8)->data.raw, interpreter.tensor(5)->data.raw); + ASSERT_LT(interpreter.tensor(9)->data.raw, interpreter.tensor(5)->data.raw); +} + +TEST(BasicInterpreter, BufferAccess) { + Interpreter interpreter; + ASSERT_EQ(interpreter.AddTensors(1), kTfLiteOk); + ASSERT_EQ(interpreter.SetInputs({0}), kTfLiteOk); + + ASSERT_EQ(interpreter.SetTensorParametersReadWrite( + 0, kTfLiteFloat32, "", {3}, TfLiteQuantizationParams()), + kTfLiteOk); + ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + // Verify we get a valid pointer.r + ASSERT_NE(interpreter.typed_tensor(0), nullptr); + // Verify incorrect pointer will not returned. + ASSERT_EQ(interpreter.typed_tensor(0), nullptr); + // Verify that raw c interface ptr matches safe interface. + ASSERT_EQ(interpreter.typed_tensor(0), interpreter.tensor(0)->data.f); +} + +TEST(BasicInterpreter, NoOpInterpreter) { + Interpreter interpreter; + ASSERT_EQ(interpreter.AddTensors(1), kTfLiteOk); + ASSERT_EQ(interpreter.SetInputs({0}), kTfLiteOk); + ASSERT_EQ(interpreter.SetOutputs({0}), kTfLiteOk); + + ASSERT_EQ(interpreter.SetTensorParametersReadWrite( + 0, kTfLiteFloat32, "", {3}, TfLiteQuantizationParams()), + kTfLiteOk); + + ASSERT_EQ(interpreter.ResizeInputTensor(interpreter.inputs()[0], {1, 2, 3}), + kTfLiteOk); + ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + ASSERT_EQ(interpreter.Invoke(), kTfLiteOk); +} + +TEST(BasicInterpreter, OneOpInterpreter) { + Interpreter interpreter; + ASSERT_EQ(interpreter.AddTensors(2), kTfLiteOk); + ASSERT_EQ(interpreter.SetInputs({0}), kTfLiteOk); + ASSERT_EQ(interpreter.SetOutputs({1}), kTfLiteOk); + + TfLiteQuantizationParams quantized; + ASSERT_EQ(interpreter.SetTensorParametersReadWrite(0, kTfLiteFloat32, "in1", + {3}, quantized), + kTfLiteOk); + ASSERT_EQ(interpreter.SetTensorParametersReadWrite(1, kTfLiteFloat32, "out0", + {3}, quantized), + kTfLiteOk); + + ASSERT_EQ(interpreter.GetInputName(0), "in1"); + ASSERT_EQ(interpreter.GetOutputName(0), "out0"); + + TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr}; + reg.init = [](TfLiteContext* context, const char*, size_t) -> void* { + auto* first_new_tensor = new int; + context->AddTensors(context, 2, first_new_tensor); + return first_new_tensor; + }; + reg.free = [](TfLiteContext* context, void* buffer) { + delete reinterpret_cast(buffer); + }; + reg.prepare = [](TfLiteContext* context, TfLiteNode* node) { + auto* first_new_tensor = reinterpret_cast(node->user_data); + + TfLiteTensor* tensor0 = &context->tensors[node->inputs->data[0]]; + TfLiteTensor* tensor1 = &context->tensors[node->outputs->data[0]]; + + TfLiteIntArray* newSize = TfLiteIntArrayCopy(tensor0->dims); + TF_LITE_ENSURE_STATUS(context->ResizeTensor(context, tensor1, newSize)); + + TfLiteIntArrayFree(node->temporaries); + node->temporaries = TfLiteIntArrayCreate(2); + for (int i = 0; i < 2; ++i) { + node->temporaries->data[i] = *(first_new_tensor) + i; + } + + auto setup_temporary = [&](int id) { + TfLiteTensor* tmp = &context->tensors[id]; + tmp->type = kTfLiteFloat32; + tmp->allocation_type = kTfLiteArenaRw; + return context->ResizeTensor(context, tmp, + TfLiteIntArrayCopy(tensor0->dims)); + }; + TF_LITE_ENSURE_STATUS(setup_temporary(node->temporaries->data[0])); + TF_LITE_ENSURE_STATUS(setup_temporary(node->temporaries->data[1])); + + return kTfLiteOk; + }; + reg.invoke = [](TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor* a0 = &context->tensors[node->inputs->data[0]]; + + 
auto populate = [&](int id) { + TfLiteTensor* t = &context->tensors[id]; + int num = a0->dims->data[0]; + for (int i = 0; i < num; i++) { + t->data.f[i] = a0->data.f[i]; + } + }; + + populate(node->outputs->data[0]); + populate(node->temporaries->data[0]); + populate(node->temporaries->data[1]); + return kTfLiteOk; + }; + ASSERT_EQ( + interpreter.AddNodeWithParameters({0}, {1}, nullptr, 0, nullptr, ®), + kTfLiteOk); + ASSERT_EQ(interpreter.ResizeInputTensor(0, {3}), kTfLiteOk); + ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + + ASSERT_EQ(interpreter.Invoke(), kTfLiteOk); +} + +// Forcefully divides tensor allocation in three steps: one before invocation +// and two more at invocation time. This happens because we use string tensors +// and their sizes can't be determined until invocation time. +TEST(BasicInterpreter, ThreeStepAllocate) { + Interpreter interpreter; + ASSERT_EQ(interpreter.AddTensors(5), kTfLiteOk); + ASSERT_EQ(interpreter.SetInputs({0}), kTfLiteOk); + ASSERT_EQ(interpreter.SetOutputs({4}), kTfLiteOk); + + TfLiteQuantizationParams quantized; + char data[] = {1, 0, 0, 0, 12, 0, 0, 0, 15, 0, 0, 0, 'A', 'B', 'C'}; + // Read only string tensor. + ASSERT_EQ(interpreter.SetTensorParametersReadOnly(0, kTfLiteString, "", {1}, + quantized, data, 15), + kTfLiteOk); + // Read-write string tensor. + ASSERT_EQ(interpreter.SetTensorParametersReadWrite(1, kTfLiteString, "", {1}, + quantized), + kTfLiteOk); + ASSERT_EQ(interpreter.SetTensorParametersReadWrite(2, kTfLiteInt32, "", {1}, + quantized), + kTfLiteOk); + ASSERT_EQ(interpreter.SetTensorParametersReadWrite(3, kTfLiteString, "", {1}, + quantized), + kTfLiteOk); + ASSERT_EQ(interpreter.SetTensorParametersReadWrite(4, kTfLiteInt32, "", {1}, + quantized), + kTfLiteOk); + + // String-in String-out node. + TfLiteRegistration reg_copy = {nullptr, nullptr, nullptr, nullptr}; + reg_copy.invoke = [](TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor* a0 = &context->tensors[node->inputs->data[0]]; + TfLiteTensor* a1 = &context->tensors[node->outputs->data[0]]; + DynamicBuffer buf; + StringRef str_ref = GetString(a0, 0); + buf.AddString(str_ref); + buf.WriteToTensor(a1); + return kTfLiteOk; + }; + + // String-in Int-out node. 
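The 15-byte read-only buffer above appears to follow the string-tensor layout that DynamicBuffer and GetString operate on: a little-endian int32 count of strings, then count + 1 int32 offsets into the buffer, then the character data. A small sketch that decodes exactly that buffer under this assumption:

    #include <cstdint>
    #include <cstdio>
    #include <cstring>

    int main() {
      // Same bytes as the test's read-only string tensor: one string, "ABC".
      const char data[] = {1, 0, 0, 0, 12, 0, 0, 0, 15, 0, 0, 0, 'A', 'B', 'C'};

      int32_t count;
      std::memcpy(&count, data, sizeof(count));  // number of strings (assumes little-endian)
      for (int32_t i = 0; i < count; ++i) {
        int32_t begin, end;
        std::memcpy(&begin, data + 4 * (i + 1), sizeof(begin));  // start of string i
        std::memcpy(&end, data + 4 * (i + 2), sizeof(end));      // start of string i+1 (or total size)
        std::printf("string %d: %.*s (bytes %d..%d)\n", i,
                    static_cast<int>(end - begin), data + begin, begin, end);
      }
      return 0;
    }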
+ TfLiteRegistration reg_len = {nullptr, nullptr, nullptr, nullptr}; + reg_len.prepare = [](TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor* output = &context->tensors[node->outputs->data[0]]; + TfLiteIntArray* outputSize = TfLiteIntArrayCreate(1); + outputSize->data[0] = 1; + return context->ResizeTensor(context, output, outputSize); + }; + reg_len.invoke = [](TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor* a0 = &context->tensors[node->inputs->data[0]]; + TfLiteTensor* a1 = &context->tensors[node->outputs->data[0]]; + a1->data.i32[0] = a0->bytes; + return kTfLiteOk; + }; + + ASSERT_EQ(interpreter.AddNodeWithParameters({0}, {1}, nullptr, 0, nullptr, + ®_copy), + kTfLiteOk); + ASSERT_EQ(interpreter.AddNodeWithParameters({1}, {2}, nullptr, 0, nullptr, + ®_len), + kTfLiteOk); + ASSERT_EQ(interpreter.AddNodeWithParameters({0}, {3}, nullptr, 0, nullptr, + ®_copy), + kTfLiteOk); + ASSERT_EQ(interpreter.AddNodeWithParameters({3}, {4}, nullptr, 0, nullptr, + ®_len), + kTfLiteOk); + + ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + ASSERT_EQ(interpreter.Invoke(), kTfLiteOk); + + ASSERT_EQ(interpreter.tensor(0)->bytes, 15); + ASSERT_NE(interpreter.tensor(0)->data.raw, nullptr); + ASSERT_EQ(interpreter.tensor(1)->bytes, 15); + ASSERT_NE(interpreter.tensor(1)->data.raw, nullptr); + ASSERT_EQ(interpreter.tensor(3)->bytes, 15); + ASSERT_NE(interpreter.tensor(4)->data.raw, nullptr); + ASSERT_EQ(interpreter.tensor(2)->bytes, 4); + ASSERT_EQ(interpreter.tensor(2)->data.i32[0], 15); + ASSERT_EQ(interpreter.tensor(4)->bytes, 4); + ASSERT_EQ(interpreter.tensor(4)->data.i32[0], 15); +} + +TEST(BasicInterpreter, AllocateTwice) { + Interpreter interpreter; + ASSERT_EQ(interpreter.AddTensors(2), kTfLiteOk); + ASSERT_EQ(interpreter.SetInputs({0}), kTfLiteOk); + ASSERT_EQ(interpreter.SetOutputs({1}), kTfLiteOk); + + TfLiteQuantizationParams quantized; + ASSERT_EQ(interpreter.SetTensorParametersReadWrite(0, kTfLiteFloat32, "", {3}, + quantized), + kTfLiteOk); + ASSERT_EQ(interpreter.SetTensorParametersReadWrite(1, kTfLiteFloat32, "", {3}, + quantized), + kTfLiteOk); + + TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr}; + reg.prepare = [](TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor* tensor0 = &context->tensors[node->inputs->data[0]]; + TfLiteTensor* tensor1 = &context->tensors[node->outputs->data[0]]; + TfLiteIntArray* newSize = TfLiteIntArrayCopy(tensor0->dims); + return context->ResizeTensor(context, tensor1, newSize); + }; + reg.invoke = [](TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor* a0 = &context->tensors[node->inputs->data[0]]; + TfLiteTensor* a1 = &context->tensors[node->outputs->data[0]]; + int num = a0->dims->data[0]; + for (int i = 0; i < num; i++) { + a1->data.f[i] = a0->data.f[i]; + } + return kTfLiteOk; + }; + ASSERT_EQ( + interpreter.AddNodeWithParameters({0}, {1}, nullptr, 0, nullptr, ®), + kTfLiteOk); + ASSERT_EQ(interpreter.ResizeInputTensor(0, {3}), kTfLiteOk); + ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + ASSERT_EQ(interpreter.Invoke(), kTfLiteOk); + char* old_tensor0_ptr = interpreter.tensor(0)->data.raw; + char* old_tensor1_ptr = interpreter.tensor(1)->data.raw; + + ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + ASSERT_EQ(interpreter.Invoke(), kTfLiteOk); + ASSERT_EQ(old_tensor0_ptr, interpreter.tensor(0)->data.raw); + ASSERT_EQ(old_tensor1_ptr, interpreter.tensor(1)->data.raw); +} + +struct TestErrorReporter : public ErrorReporter { + int Report(const char* format, va_list args) override { + char 
buffer[1024]; + int size = vsnprintf(buffer, sizeof(buffer), format, args); + all_reports += buffer; + calls++; + return size; + } + int calls = 0; + std::string all_reports; +}; + +TEST(BasicInterpreter, TestNullErrorReporter) { + TestErrorReporter reporter; + Interpreter interpreter; +} + +TEST(BasicInterpreter, TestCustomErrorReporter) { + TestErrorReporter reporter; + Interpreter interpreter(&reporter); + ASSERT_NE(interpreter.Invoke(), kTfLiteOk); + ASSERT_EQ(reporter.all_reports, "Invoke called on model that is not ready."); + ASSERT_EQ(reporter.calls, 1); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { +#ifdef OS_LINUX + FLAGS_logtostderr = true; +#endif + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/java/BUILD b/tensorflow/contrib/lite/java/BUILD new file mode 100644 index 0000000000..80e8bd435e --- /dev/null +++ b/tensorflow/contrib/lite/java/BUILD @@ -0,0 +1,164 @@ +# Description: +# TensorFlow Lite Java API. + +package(default_visibility = ["//visibility:private"]) + +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow/java:build_defs.bzl", "JAVACOPTS") +load("//tensorflow/contrib/lite:build_def.bzl", "tflite_jni_binary") + +android_library( + name = "tensorflowlite", + srcs = glob( + [ + "src/main/java/org/tensorflow/lite/*.java", + ], + ), + visibility = ["//visibility:public"], + deps = [ + ":tflite_runtime", + "@javax_validation", + ], +) + +android_library( + name = "tensorflowlite_java", + srcs = glob( + [ + "src/main/java/org/tensorflow/lite/*.java", + ], + ), + visibility = ["//visibility:public"], + deps = [ + "@javax_validation", + ], +) + +java_library( + name = "tensorflowlitelib", + srcs = glob( + [ + "src/main/java/org/tensorflow/lite/*.java", + ], + ), + javacopts = JAVACOPTS, + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/contrib/lite/java/src/main/native", + "@javax_validation", + ], +) + +java_test( + name = "TensorFlowLiteTest", + size = "small", + srcs = ["src/test/java/org/tensorflow/lite/TensorFlowLiteTest.java"], + javacopts = JAVACOPTS, + test_class = "org.tensorflow.lite.TensorFlowLiteTest", + deps = [ + ":tensorflowlitelib", + "@com_google_truth", + "@junit", + ], +) + +java_test( + name = "DataTypeTest", + size = "small", + srcs = ["src/test/java/org/tensorflow/lite/DataTypeTest.java"], + javacopts = JAVACOPTS, + test_class = "org.tensorflow.lite.DataTypeTest", + deps = [ + ":tensorflowlitelib", + "@com_google_truth", + "@junit", + ], +) + +java_test( + name = "NativeInterpreterWrapperTest", + size = "small", + srcs = ["src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java"], + data = [ + "src/testdata/add.bin", + "src/testdata/int32.bin", + "src/testdata/int64.bin", + "src/testdata/invalid.model.tflite", + "src/testdata/uint8.bin", + ], + javacopts = JAVACOPTS, + test_class = "org.tensorflow.lite.NativeInterpreterWrapperTest", + deps = [ + ":tensorflowlitelib", + "@com_google_truth", + "@junit", + ], +) + +# TODO: generate large models at runtime, instead of storing them. 
+java_test( + name = "InterpreterTest", + size = "small", + srcs = ["src/test/java/org/tensorflow/lite/InterpreterTest.java"], + data = [ + "src/testdata/add.bin", + "src/testdata/mobilenet.tflite.bin", + ], + javacopts = JAVACOPTS, + test_class = "org.tensorflow.lite.InterpreterTest", + deps = [ + ":tensorflowlitelib", + "@com_google_truth", + "@junit", + ], +) + +java_test( + name = "TensorTest", + size = "small", + srcs = ["src/test/java/org/tensorflow/lite/TensorTest.java"], + data = [ + "src/testdata/add.bin", + ], + javacopts = JAVACOPTS, + test_class = "org.tensorflow.lite.TensorTest", + deps = [ + ":tensorflowlitelib", + "@com_google_truth", + "@junit", + ], +) + +filegroup( + name = "libtensorflowlite_jni", + srcs = select({ + "//conditions:default": [":libtensorflowlite_jni.so"], + }), + visibility = ["//visibility:public"], +) + +cc_library( + name = "tflite_runtime", + srcs = ["libtensorflowlite_jni.so"], + visibility = ["//visibility:public"], +) + +tflite_jni_binary( + name = "libtensorflowlite_jni.so", + deps = [ + "//tensorflow/contrib/lite/java/src/main/native", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/contrib/lite/java/demo/.gitignore b/tensorflow/contrib/lite/java/demo/.gitignore new file mode 100644 index 0000000000..39fb081a42 --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/.gitignore @@ -0,0 +1,9 @@ +*.iml +.gradle +/local.properties +/.idea/workspace.xml +/.idea/libraries +.DS_Store +/build +/captures +.externalNativeBuild diff --git a/tensorflow/contrib/lite/java/demo/app/build.gradle b/tensorflow/contrib/lite/java/demo/app/build.gradle new file mode 100644 index 0000000000..e1470fe717 --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/app/build.gradle @@ -0,0 +1,58 @@ +apply plugin: 'com.android.application' + +android { + compileSdkVersion 26 + buildToolsVersion "26.0.1" + defaultConfig { + applicationId "android.example.com.tflitecamerademo" + minSdkVersion 15 + targetSdkVersion 26 + versionCode 1 + versionName "1.0" + testInstrumentationRunner "android.support.test.runner.AndroidJUnitRunner" + + // Remove this block. 
+ jackOptions { + enabled true + } + } + lintOptions { + abortOnError false + } + buildTypes { + release { + minifyEnabled false + proguardFiles getDefaultProguardFile('proguard-android.txt'), 'proguard-rules.pro' + } + } + aaptOptions { + noCompress "tflite" + } + + compileOptions { + sourceCompatibility JavaVersion.VERSION_1_8 + targetCompatibility JavaVersion.VERSION_1_8 + } +} + +repositories { + flatDir { + dirs 'libs' + } +} + +dependencies { + compile fileTree(dir: 'libs', include: ['*.jar']) + androidTestCompile('com.android.support.test.espresso:espresso-core:2.2.2', { + exclude group: 'com.android.support', module: 'support-annotations' + }) + compile 'com.android.support:appcompat-v7:25.2.0' + compile 'com.android.support.constraint:constraint-layout:1.0.2' + compile 'com.android.support:design:25.2.0' + compile 'com.android.support:support-annotations:25.3.1' + compile 'com.android.support:support-v13:25.2.0' + + compile 'org.tensorflow:tensorflow-lite:+' + + testCompile 'junit:junit:4.12' +} diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/AndroidManifest.xml b/tensorflow/contrib/lite/java/demo/app/src/main/AndroidManifest.xml new file mode 100644 index 0000000000..ba63dce5d9 --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/app/src/main/AndroidManifest.xml @@ -0,0 +1,42 @@ + + + + + + + + + + + + + + + + + + + + + + + diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/BUILD b/tensorflow/contrib/lite/java/demo/app/src/main/BUILD new file mode 100644 index 0000000000..512a86affe --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/app/src/main/BUILD @@ -0,0 +1,43 @@ +package(default_visibility = ["//visibility:private"]) + +licenses(["notice"]) # Apache 2.0 + +android_binary( + name = "TfLiteCameraDemo", + srcs = glob(["java/**/*.java"]), + assets = [ + ":assets", + ], + assets_dir = "", + custom_package = "com.example.android.tflitecamerademo", + manifest = "AndroidManifest.xml", + nocompress_extensions = [ + ".tflite", + ], + resource_files = glob(["res/**"]), + deps = [ + "//tensorflow/contrib/lite/java:tensorflowlite", + "//tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite:testhelper", + "@androidsdk//com.android.support:support-v13-25.2.0", + "@androidsdk//com.android.support:support-v4-25.2.0", + ], +) + +filegroup( + name = "assets", + srcs = [ + "@tflite_mobilenet//:model_files", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/assets/BUILD b/tensorflow/contrib/lite/java/demo/app/src/main/assets/BUILD new file mode 100644 index 0000000000..1a759f5652 --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/app/src/main/assets/BUILD @@ -0,0 +1,26 @@ +package(default_visibility = ["//visibility:private"]) + +licenses(["notice"]) # Apache 2.0 + +filegroup( + name = "assets_files", + srcs = glob( + ["**/*"], + exclude = [ + "BUILD", + ], + ), + visibility = ["//visibility:public"], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/assets/labels.txt b/tensorflow/contrib/lite/java/demo/app/src/main/assets/labels.txt new file mode 100644 index 0000000000..fe811239d8 --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/app/src/main/assets/labels.txt @@ -0,0 +1,1001 @@ 
+background +tench +goldfish +great white shark +tiger shark +hammerhead +electric ray +stingray +cock +hen +ostrich +brambling +goldfinch +house finch +junco +indigo bunting +robin +bulbul +jay +magpie +chickadee +water ouzel +kite +bald eagle +vulture +great grey owl +European fire salamander +common newt +eft +spotted salamander +axolotl +bullfrog +tree frog +tailed frog +loggerhead +leatherback turtle +mud turtle +terrapin +box turtle +banded gecko +common iguana +American chameleon +whiptail +agama +frilled lizard +alligator lizard +Gila monster +green lizard +African chameleon +Komodo dragon +African crocodile +American alligator +triceratops +thunder snake +ringneck snake +hognose snake +green snake +king snake +garter snake +water snake +vine snake +night snake +boa constrictor +rock python +Indian cobra +green mamba +sea snake +horned viper +diamondback +sidewinder +trilobite +harvestman +scorpion +black and gold garden spider +barn spider +garden spider +black widow +tarantula +wolf spider +tick +centipede +black grouse +ptarmigan +ruffed grouse +prairie chicken +peacock +quail +partridge +African grey +macaw +sulphur-crested cockatoo +lorikeet +coucal +bee eater +hornbill +hummingbird +jacamar +toucan +drake +red-breasted merganser +goose +black swan +tusker +echidna +platypus +wallaby +koala +wombat +jellyfish +sea anemone +brain coral +flatworm +nematode +conch +snail +slug +sea slug +chiton +chambered nautilus +Dungeness crab +rock crab +fiddler crab +king crab +American lobster +spiny lobster +crayfish +hermit crab +isopod +white stork +black stork +spoonbill +flamingo +little blue heron +American egret +bittern +crane +limpkin +European gallinule +American coot +bustard +ruddy turnstone +red-backed sandpiper +redshank +dowitcher +oystercatcher +pelican +king penguin +albatross +grey whale +killer whale +dugong +sea lion +Chihuahua +Japanese spaniel +Maltese dog +Pekinese +Shih-Tzu +Blenheim spaniel +papillon +toy terrier +Rhodesian ridgeback +Afghan hound +basset +beagle +bloodhound +bluetick +black-and-tan coonhound +Walker hound +English foxhound +redbone +borzoi +Irish wolfhound +Italian greyhound +whippet +Ibizan hound +Norwegian elkhound +otterhound +Saluki +Scottish deerhound +Weimaraner +Staffordshire bullterrier +American Staffordshire terrier +Bedlington terrier +Border terrier +Kerry blue terrier +Irish terrier +Norfolk terrier +Norwich terrier +Yorkshire terrier +wire-haired fox terrier +Lakeland terrier +Sealyham terrier +Airedale +cairn +Australian terrier +Dandie Dinmont +Boston bull +miniature schnauzer +giant schnauzer +standard schnauzer +Scotch terrier +Tibetan terrier +silky terrier +soft-coated wheaten terrier +West Highland white terrier +Lhasa +flat-coated retriever +curly-coated retriever +golden retriever +Labrador retriever +Chesapeake Bay retriever +German short-haired pointer +vizsla +English setter +Irish setter +Gordon setter +Brittany spaniel +clumber +English springer +Welsh springer spaniel +cocker spaniel +Sussex spaniel +Irish water spaniel +kuvasz +schipperke +groenendael +malinois +briard +kelpie +komondor +Old English sheepdog +Shetland sheepdog +collie +Border collie +Bouvier des Flandres +Rottweiler +German shepherd +Doberman +miniature pinscher +Greater Swiss Mountain dog +Bernese mountain dog +Appenzeller +EntleBucher +boxer +bull mastiff +Tibetan mastiff +French bulldog +Great Dane +Saint Bernard +Eskimo dog +malamute +Siberian husky +dalmatian +affenpinscher +basenji +pug +Leonberg +Newfoundland +Great Pyrenees +Samoyed +Pomeranian 
+chow +keeshond +Brabancon griffon +Pembroke +Cardigan +toy poodle +miniature poodle +standard poodle +Mexican hairless +timber wolf +white wolf +red wolf +coyote +dingo +dhole +African hunting dog +hyena +red fox +kit fox +Arctic fox +grey fox +tabby +tiger cat +Persian cat +Siamese cat +Egyptian cat +cougar +lynx +leopard +snow leopard +jaguar +lion +tiger +cheetah +brown bear +American black bear +ice bear +sloth bear +mongoose +meerkat +tiger beetle +ladybug +ground beetle +long-horned beetle +leaf beetle +dung beetle +rhinoceros beetle +weevil +fly +bee +ant +grasshopper +cricket +walking stick +cockroach +mantis +cicada +leafhopper +lacewing +dragonfly +damselfly +admiral +ringlet +monarch +cabbage butterfly +sulphur butterfly +lycaenid +starfish +sea urchin +sea cucumber +wood rabbit +hare +Angora +hamster +porcupine +fox squirrel +marmot +beaver +guinea pig +sorrel +zebra +hog +wild boar +warthog +hippopotamus +ox +water buffalo +bison +ram +bighorn +ibex +hartebeest +impala +gazelle +Arabian camel +llama +weasel +mink +polecat +black-footed ferret +otter +skunk +badger +armadillo +three-toed sloth +orangutan +gorilla +chimpanzee +gibbon +siamang +guenon +patas +baboon +macaque +langur +colobus +proboscis monkey +marmoset +capuchin +howler monkey +titi +spider monkey +squirrel monkey +Madagascar cat +indri +Indian elephant +African elephant +lesser panda +giant panda +barracouta +eel +coho +rock beauty +anemone fish +sturgeon +gar +lionfish +puffer +abacus +abaya +academic gown +accordion +acoustic guitar +aircraft carrier +airliner +airship +altar +ambulance +amphibian +analog clock +apiary +apron +ashcan +assault rifle +backpack +bakery +balance beam +balloon +ballpoint +Band Aid +banjo +bannister +barbell +barber chair +barbershop +barn +barometer +barrel +barrow +baseball +basketball +bassinet +bassoon +bathing cap +bath towel +bathtub +beach wagon +beacon +beaker +bearskin +beer bottle +beer glass +bell cote +bib +bicycle-built-for-two +bikini +binder +binoculars +birdhouse +boathouse +bobsled +bolo tie +bonnet +bookcase +bookshop +bottlecap +bow +bow tie +brass +brassiere +breakwater +breastplate +broom +bucket +buckle +bulletproof vest +bullet train +butcher shop +cab +caldron +candle +cannon +canoe +can opener +cardigan +car mirror +carousel +carpenter's kit +carton +car wheel +cash machine +cassette +cassette player +castle +catamaran +CD player +cello +cellular telephone +chain +chainlink fence +chain mail +chain saw +chest +chiffonier +chime +china cabinet +Christmas stocking +church +cinema +cleaver +cliff dwelling +cloak +clog +cocktail shaker +coffee mug +coffeepot +coil +combination lock +computer keyboard +confectionery +container ship +convertible +corkscrew +cornet +cowboy boot +cowboy hat +cradle +crane +crash helmet +crate +crib +Crock Pot +croquet ball +crutch +cuirass +dam +desk +desktop computer +dial telephone +diaper +digital clock +digital watch +dining table +dishrag +dishwasher +disk brake +dock +dogsled +dome +doormat +drilling platform +drum +drumstick +dumbbell +Dutch oven +electric fan +electric guitar +electric locomotive +entertainment center +envelope +espresso maker +face powder +feather boa +file +fireboat +fire engine +fire screen +flagpole +flute +folding chair +football helmet +forklift +fountain +fountain pen +four-poster +freight car +French horn +frying pan +fur coat +garbage truck +gasmask +gas pump +goblet +go-kart +golf ball +golfcart +gondola +gong +gown +grand piano +greenhouse +grille +grocery store +guillotine +hair slide +hair 
spray +half track +hammer +hamper +hand blower +hand-held computer +handkerchief +hard disc +harmonica +harp +harvester +hatchet +holster +home theater +honeycomb +hook +hoopskirt +horizontal bar +horse cart +hourglass +iPod +iron +jack-o'-lantern +jean +jeep +jersey +jigsaw puzzle +jinrikisha +joystick +kimono +knee pad +knot +lab coat +ladle +lampshade +laptop +lawn mower +lens cap +letter opener +library +lifeboat +lighter +limousine +liner +lipstick +Loafer +lotion +loudspeaker +loupe +lumbermill +magnetic compass +mailbag +mailbox +maillot +maillot +manhole cover +maraca +marimba +mask +matchstick +maypole +maze +measuring cup +medicine chest +megalith +microphone +microwave +military uniform +milk can +minibus +miniskirt +minivan +missile +mitten +mixing bowl +mobile home +Model T +modem +monastery +monitor +moped +mortar +mortarboard +mosque +mosquito net +motor scooter +mountain bike +mountain tent +mouse +mousetrap +moving van +muzzle +nail +neck brace +necklace +nipple +notebook +obelisk +oboe +ocarina +odometer +oil filter +organ +oscilloscope +overskirt +oxcart +oxygen mask +packet +paddle +paddlewheel +padlock +paintbrush +pajama +palace +panpipe +paper towel +parachute +parallel bars +park bench +parking meter +passenger car +patio +pay-phone +pedestal +pencil box +pencil sharpener +perfume +Petri dish +photocopier +pick +pickelhaube +picket fence +pickup +pier +piggy bank +pill bottle +pillow +ping-pong ball +pinwheel +pirate +pitcher +plane +planetarium +plastic bag +plate rack +plow +plunger +Polaroid camera +pole +police van +poncho +pool table +pop bottle +pot +potter's wheel +power drill +prayer rug +printer +prison +projectile +projector +puck +punching bag +purse +quill +quilt +racer +racket +radiator +radio +radio telescope +rain barrel +recreational vehicle +reel +reflex camera +refrigerator +remote control +restaurant +revolver +rifle +rocking chair +rotisserie +rubber eraser +rugby ball +rule +running shoe +safe +safety pin +saltshaker +sandal +sarong +sax +scabbard +scale +school bus +schooner +scoreboard +screen +screw +screwdriver +seat belt +sewing machine +shield +shoe shop +shoji +shopping basket +shopping cart +shovel +shower cap +shower curtain +ski +ski mask +sleeping bag +slide rule +sliding door +slot +snorkel +snowmobile +snowplow +soap dispenser +soccer ball +sock +solar dish +sombrero +soup bowl +space bar +space heater +space shuttle +spatula +speedboat +spider web +spindle +sports car +spotlight +stage +steam locomotive +steel arch bridge +steel drum +stethoscope +stole +stone wall +stopwatch +stove +strainer +streetcar +stretcher +studio couch +stupa +submarine +suit +sundial +sunglass +sunglasses +sunscreen +suspension bridge +swab +sweatshirt +swimming trunks +swing +switch +syringe +table lamp +tank +tape player +teapot +teddy +television +tennis ball +thatch +theater curtain +thimble +thresher +throne +tile roof +toaster +tobacco shop +toilet seat +torch +totem pole +tow truck +toyshop +tractor +trailer truck +tray +trench coat +tricycle +trimaran +tripod +triumphal arch +trolleybus +trombone +tub +turnstile +typewriter keyboard +umbrella +unicycle +upright +vacuum +vase +vault +velvet +vending machine +vestment +viaduct +violin +volleyball +waffle iron +wall clock +wallet +wardrobe +warplane +washbasin +washer +water bottle +water jug +water tower +whiskey jug +whistle +wig +window screen +window shade +Windsor tie +wine bottle +wing +wok +wooden spoon +wool +worm fence +wreck +yawl +yurt +web site +comic book +crossword puzzle +street sign 
+traffic light +book jacket +menu +plate +guacamole +consomme +hot pot +trifle +ice cream +ice lolly +French loaf +bagel +pretzel +cheeseburger +hotdog +mashed potato +head cabbage +broccoli +cauliflower +zucchini +spaghetti squash +acorn squash +butternut squash +cucumber +artichoke +bell pepper +cardoon +mushroom +Granny Smith +strawberry +orange +lemon +fig +pineapple +banana +jackfruit +custard apple +pomegranate +hay +carbonara +chocolate sauce +dough +meat loaf +pizza +potpie +burrito +red wine +espresso +cup +eggnog +alp +bubble +cliff +coral reef +geyser +lakeside +promontory +sandbar +seashore +valley +volcano +ballplayer +groom +scuba diver +rapeseed +daisy +yellow lady's slipper +corn +acorn +hip +buckeye +coral fungus +agaric +gyromitra +stinkhorn +earthstar +hen-of-the-woods +bolete +ear +toilet tissue diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/AutoFitTextureView.java b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/AutoFitTextureView.java new file mode 100644 index 0000000000..f204590659 --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/AutoFitTextureView.java @@ -0,0 +1,72 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package com.example.android.tflitecamerademo; + +import android.content.Context; +import android.util.AttributeSet; +import android.view.TextureView; + +/** A {@link TextureView} that can be adjusted to a specified aspect ratio. */ +public class AutoFitTextureView extends TextureView { + + private int mRatioWidth = 0; + private int mRatioHeight = 0; + + public AutoFitTextureView(Context context) { + this(context, null); + } + + public AutoFitTextureView(Context context, AttributeSet attrs) { + this(context, attrs, 0); + } + + public AutoFitTextureView(Context context, AttributeSet attrs, int defStyle) { + super(context, attrs, defStyle); + } + + /** + * Sets the aspect ratio for this view. The size of the view will be measured based on the ratio + * calculated from the parameters. Note that the actual sizes of parameters don't matter, that is, + * calling setAspectRatio(2, 3) and setAspectRatio(4, 6) make the same result. 
+ * + * @param width Relative horizontal size + * @param height Relative vertical size + */ + public void setAspectRatio(int width, int height) { + if (width < 0 || height < 0) { + throw new IllegalArgumentException("Size cannot be negative."); + } + mRatioWidth = width; + mRatioHeight = height; + requestLayout(); + } + + @Override + protected void onMeasure(int widthMeasureSpec, int heightMeasureSpec) { + super.onMeasure(widthMeasureSpec, heightMeasureSpec); + int width = MeasureSpec.getSize(widthMeasureSpec); + int height = MeasureSpec.getSize(heightMeasureSpec); + if (0 == mRatioWidth || 0 == mRatioHeight) { + setMeasuredDimension(width, height); + } else { + if (width < height * mRatioWidth / mRatioHeight) { + setMeasuredDimension(width, width * mRatioHeight / mRatioWidth); + } else { + setMeasuredDimension(height * mRatioWidth / mRatioHeight, height); + } + } + } +} diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/Camera2BasicFragment.java b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/Camera2BasicFragment.java new file mode 100644 index 0000000000..74737a8b88 --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/Camera2BasicFragment.java @@ -0,0 +1,708 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +package com.example.android.tflitecamerademo; + +import android.app.Activity; +import android.app.AlertDialog; +import android.app.Dialog; +import android.app.DialogFragment; +import android.app.Fragment; +import android.content.Context; +import android.content.DialogInterface; +import android.content.pm.PackageInfo; +import android.content.pm.PackageManager; +import android.content.res.Configuration; +import android.graphics.Bitmap; +import android.graphics.ImageFormat; +import android.graphics.Matrix; +import android.graphics.Point; +import android.graphics.RectF; +import android.graphics.SurfaceTexture; +import android.hardware.camera2.CameraAccessException; +import android.hardware.camera2.CameraCaptureSession; +import android.hardware.camera2.CameraCharacteristics; +import android.hardware.camera2.CameraDevice; +import android.hardware.camera2.CameraManager; +import android.hardware.camera2.CaptureRequest; +import android.hardware.camera2.CaptureResult; +import android.hardware.camera2.TotalCaptureResult; +import android.hardware.camera2.params.StreamConfigurationMap; +import android.media.ImageReader; +import android.os.Bundle; +import android.os.Handler; +import android.os.HandlerThread; +import android.support.annotation.NonNull; +import android.support.v13.app.FragmentCompat; +import android.support.v4.content.ContextCompat; +import android.util.Log; +import android.util.Size; +import android.view.LayoutInflater; +import android.view.Surface; +import android.view.TextureView; +import android.view.View; +import android.view.ViewGroup; +import android.widget.TextView; +import android.widget.Toast; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.Comparator; +import java.util.List; +import java.util.concurrent.Semaphore; +import java.util.concurrent.TimeUnit; + +/** Basic fragments for the Camera. */ +public class Camera2BasicFragment extends Fragment + implements FragmentCompat.OnRequestPermissionsResultCallback { + + /** Tag for the {@link Log}. */ + private static final String TAG = "TfLiteCameraDemo"; + + private static final String FRAGMENT_DIALOG = "dialog"; + + private static final String HANDLE_THREAD_NAME = "CameraBackground"; + + private static final int PERMISSIONS_REQUEST_CODE = 1; + + private final Object lock = new Object(); + private boolean runClassifier = false; + private boolean checkedPermissions = false; + private TextView textView; + private ImageClassifier classifier; + + /** Max preview width that is guaranteed by Camera2 API */ + private static final int MAX_PREVIEW_WIDTH = 1920; + + /** Max preview height that is guaranteed by Camera2 API */ + private static final int MAX_PREVIEW_HEIGHT = 1080; + + /** + * {@link TextureView.SurfaceTextureListener} handles several lifecycle events on a {@link + * TextureView}. 
+ */ + private final TextureView.SurfaceTextureListener surfaceTextureListener = + new TextureView.SurfaceTextureListener() { + + @Override + public void onSurfaceTextureAvailable(SurfaceTexture texture, int width, int height) { + openCamera(width, height); + } + + @Override + public void onSurfaceTextureSizeChanged(SurfaceTexture texture, int width, int height) { + configureTransform(width, height); + } + + @Override + public boolean onSurfaceTextureDestroyed(SurfaceTexture texture) { + return true; + } + + @Override + public void onSurfaceTextureUpdated(SurfaceTexture texture) {} + }; + + /** ID of the current {@link CameraDevice}. */ + private String cameraId; + + /** An {@link AutoFitTextureView} for camera preview. */ + private AutoFitTextureView textureView; + + /** A {@link CameraCaptureSession } for camera preview. */ + private CameraCaptureSession captureSession; + + /** A reference to the opened {@link CameraDevice}. */ + private CameraDevice cameraDevice; + + /** The {@link android.util.Size} of camera preview. */ + private Size previewSize; + + /** {@link CameraDevice.StateCallback} is called when {@link CameraDevice} changes its state. */ + private final CameraDevice.StateCallback stateCallback = + new CameraDevice.StateCallback() { + + @Override + public void onOpened(@NonNull CameraDevice currentCameraDevice) { + // This method is called when the camera is opened. We start camera preview here. + cameraOpenCloseLock.release(); + cameraDevice = currentCameraDevice; + createCameraPreviewSession(); + } + + @Override + public void onDisconnected(@NonNull CameraDevice currentCameraDevice) { + cameraOpenCloseLock.release(); + currentCameraDevice.close(); + cameraDevice = null; + } + + @Override + public void onError(@NonNull CameraDevice currentCameraDevice, int error) { + cameraOpenCloseLock.release(); + currentCameraDevice.close(); + cameraDevice = null; + Activity activity = getActivity(); + if (null != activity) { + activity.finish(); + } + } + }; + + /** An additional thread for running tasks that shouldn't block the UI. */ + private HandlerThread backgroundThread; + + /** A {@link Handler} for running tasks in the background. */ + private Handler backgroundHandler; + + /** An {@link ImageReader} that handles image capture. */ + private ImageReader imageReader; + + /** {@link CaptureRequest.Builder} for the camera preview */ + private CaptureRequest.Builder previewRequestBuilder; + + /** {@link CaptureRequest} generated by {@link #previewRequestBuilder} */ + private CaptureRequest previewRequest; + + /** A {@link Semaphore} to prevent the app from exiting before closing the camera. */ + private Semaphore cameraOpenCloseLock = new Semaphore(1); + + /** A {@link CameraCaptureSession.CaptureCallback} that handles events related to capture. */ + private CameraCaptureSession.CaptureCallback captureCallback = + new CameraCaptureSession.CaptureCallback() { + + @Override + public void onCaptureProgressed( + @NonNull CameraCaptureSession session, + @NonNull CaptureRequest request, + @NonNull CaptureResult partialResult) {} + + @Override + public void onCaptureCompleted( + @NonNull CameraCaptureSession session, + @NonNull CaptureRequest request, + @NonNull TotalCaptureResult result) {} + }; + + /** + * Shows a {@link Toast} on the UI thread for the classification results. 
+ * + * @param text The message to show + */ + private void showToast(final String text) { + final Activity activity = getActivity(); + if (activity != null) { + activity.runOnUiThread( + new Runnable() { + @Override + public void run() { + textView.setText(text); + } + }); + } + } + + /** + * Resizes image. + * + * Attempting to use too large a preview size could exceed the camera bus' bandwidth limitation, + * resulting in gorgeous previews but the storage of garbage capture data. + * + * Given {@code choices} of {@code Size}s supported by a camera, choose the smallest one that is + * at least as large as the respective texture view size, and that is at most as large as the + * respective max size, and whose aspect ratio matches with the specified value. If such size + * doesn't exist, choose the largest one that is at most as large as the respective max size, and + * whose aspect ratio matches with the specified value. + * + * @param choices The list of sizes that the camera supports for the intended output class + * @param textureViewWidth The width of the texture view relative to sensor coordinate + * @param textureViewHeight The height of the texture view relative to sensor coordinate + * @param maxWidth The maximum width that can be chosen + * @param maxHeight The maximum height that can be chosen + * @param aspectRatio The aspect ratio + * @return The optimal {@code Size}, or an arbitrary one if none were big enough + */ + private static Size chooseOptimalSize( + Size[] choices, + int textureViewWidth, + int textureViewHeight, + int maxWidth, + int maxHeight, + Size aspectRatio) { + + // Collect the supported resolutions that are at least as big as the preview Surface + List bigEnough = new ArrayList<>(); + // Collect the supported resolutions that are smaller than the preview Surface + List notBigEnough = new ArrayList<>(); + int w = aspectRatio.getWidth(); + int h = aspectRatio.getHeight(); + for (Size option : choices) { + if (option.getWidth() <= maxWidth + && option.getHeight() <= maxHeight + && option.getHeight() == option.getWidth() * h / w) { + if (option.getWidth() >= textureViewWidth && option.getHeight() >= textureViewHeight) { + bigEnough.add(option); + } else { + notBigEnough.add(option); + } + } + } + + // Pick the smallest of those big enough. If there is no one big enough, pick the + // largest of those not big enough. + if (bigEnough.size() > 0) { + return Collections.min(bigEnough, new CompareSizesByArea()); + } else if (notBigEnough.size() > 0) { + return Collections.max(notBigEnough, new CompareSizesByArea()); + } else { + Log.e(TAG, "Couldn't find any suitable preview size"); + return choices[0]; + } + } + + public static Camera2BasicFragment newInstance() { + return new Camera2BasicFragment(); + } + + /** Layout the preview and buttons. */ + @Override + public View onCreateView( + LayoutInflater inflater, ViewGroup container, Bundle savedInstanceState) { + return inflater.inflate(R.layout.fragment_camera2_basic, container, false); + } + + /** Connect the buttons to their event handler. */ + @Override + public void onViewCreated(final View view, Bundle savedInstanceState) { + textureView = (AutoFitTextureView) view.findViewById(R.id.texture); + textView = (TextView) view.findViewById(R.id.text); + } + + /** Load the model and labels. 
*/ + @Override + public void onActivityCreated(Bundle savedInstanceState) { + super.onActivityCreated(savedInstanceState); + try { + classifier = new ImageClassifier(getActivity()); + } catch (IOException e) { + Log.e(TAG, "Failed to initialize an image classifier."); + } + startBackgroundThread(); + } + + @Override + public void onResume() { + super.onResume(); + startBackgroundThread(); + + // When the screen is turned off and turned back on, the SurfaceTexture is already + // available, and "onSurfaceTextureAvailable" will not be called. In that case, we can open + // a camera and start preview from here (otherwise, we wait until the surface is ready in + // the SurfaceTextureListener). + if (textureView.isAvailable()) { + openCamera(textureView.getWidth(), textureView.getHeight()); + } else { + textureView.setSurfaceTextureListener(surfaceTextureListener); + } + } + + @Override + public void onPause() { + closeCamera(); + stopBackgroundThread(); + super.onPause(); + } + + @Override + public void onDestroy() { + classifier.close(); + super.onDestroy(); + } + + /** + * Sets up member variables related to camera. + * + * @param width The width of available size for camera preview + * @param height The height of available size for camera preview + */ + private void setUpCameraOutputs(int width, int height) { + Activity activity = getActivity(); + CameraManager manager = (CameraManager) activity.getSystemService(Context.CAMERA_SERVICE); + try { + for (String cameraId : manager.getCameraIdList()) { + CameraCharacteristics characteristics = manager.getCameraCharacteristics(cameraId); + + // We don't use a front facing camera in this sample. + Integer facing = characteristics.get(CameraCharacteristics.LENS_FACING); + if (facing != null && facing == CameraCharacteristics.LENS_FACING_FRONT) { + continue; + } + + StreamConfigurationMap map = + characteristics.get(CameraCharacteristics.SCALER_STREAM_CONFIGURATION_MAP); + if (map == null) { + continue; + } + + // // For still image captures, we use the largest available size. + Size largest = + Collections.max( + Arrays.asList(map.getOutputSizes(ImageFormat.JPEG)), new CompareSizesByArea()); + imageReader = + ImageReader.newInstance( + largest.getWidth(), largest.getHeight(), ImageFormat.JPEG, /*maxImages*/ 2); + + // Find out if we need to swap dimension to get the preview size relative to sensor + // coordinate. 
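// In short, the rotation check that follows swaps the preview width and height whenever the
// display rotation and the camera sensor orientation differ by 90 or 270 degrees. A
// hypothetical helper (not part of the demo code, same Surface.ROTATION_* constants) that
// captures the same rule might look like:
//
//   static boolean needSwappedDimensions(int displayRotation, int sensorOrientation) {
//     switch (displayRotation) {
//       case Surface.ROTATION_0:
//       case Surface.ROTATION_180:
//         // Natural or upside-down display: swap if the sensor is mounted sideways.
//         return sensorOrientation == 90 || sensorOrientation == 270;
//       case Surface.ROTATION_90:
//       case Surface.ROTATION_270:
//         // Landscape display: swap if the sensor is mounted upright.
//         return sensorOrientation == 0 || sensorOrientation == 180;
//       default:
//         return false;
//     }
//   }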
+ int displayRotation = activity.getWindowManager().getDefaultDisplay().getRotation(); + // noinspection ConstantConditions + /* Orientation of the camera sensor */ + int sensorOrientation = characteristics.get(CameraCharacteristics.SENSOR_ORIENTATION); + boolean swappedDimensions = false; + switch (displayRotation) { + case Surface.ROTATION_0: + case Surface.ROTATION_180: + if (sensorOrientation == 90 || sensorOrientation == 270) { + swappedDimensions = true; + } + break; + case Surface.ROTATION_90: + case Surface.ROTATION_270: + if (sensorOrientation == 0 || sensorOrientation == 180) { + swappedDimensions = true; + } + break; + default: + Log.e(TAG, "Display rotation is invalid: " + displayRotation); + } + + Point displaySize = new Point(); + activity.getWindowManager().getDefaultDisplay().getSize(displaySize); + int rotatedPreviewWidth = width; + int rotatedPreviewHeight = height; + int maxPreviewWidth = displaySize.x; + int maxPreviewHeight = displaySize.y; + + if (swappedDimensions) { + rotatedPreviewWidth = height; + rotatedPreviewHeight = width; + maxPreviewWidth = displaySize.y; + maxPreviewHeight = displaySize.x; + } + + if (maxPreviewWidth > MAX_PREVIEW_WIDTH) { + maxPreviewWidth = MAX_PREVIEW_WIDTH; + } + + if (maxPreviewHeight > MAX_PREVIEW_HEIGHT) { + maxPreviewHeight = MAX_PREVIEW_HEIGHT; + } + + previewSize = + chooseOptimalSize( + map.getOutputSizes(SurfaceTexture.class), + rotatedPreviewWidth, + rotatedPreviewHeight, + maxPreviewWidth, + maxPreviewHeight, + largest); + + // We fit the aspect ratio of TextureView to the size of preview we picked. + int orientation = getResources().getConfiguration().orientation; + if (orientation == Configuration.ORIENTATION_LANDSCAPE) { + textureView.setAspectRatio(previewSize.getWidth(), previewSize.getHeight()); + } else { + textureView.setAspectRatio(previewSize.getHeight(), previewSize.getWidth()); + } + + this.cameraId = cameraId; + return; + } + } catch (CameraAccessException e) { + e.printStackTrace(); + } catch (NullPointerException e) { + // Currently an NPE is thrown when the Camera2API is used but not supported on the + // device this code runs. + ErrorDialog.newInstance(getString(R.string.camera_error)) + .show(getChildFragmentManager(), FRAGMENT_DIALOG); + } + } + + private String[] getRequiredPermissions() { + Activity activity = getActivity(); + try { + PackageInfo info = + activity + .getPackageManager() + .getPackageInfo(activity.getPackageName(), PackageManager.GET_PERMISSIONS); + String[] ps = info.requestedPermissions; + if (ps != null && ps.length > 0) { + return ps; + } else { + return new String[0]; + } + } catch (Exception e) { + return new String[0]; + } + } + + /** Opens the camera specified by {@link Camera2BasicFragment#cameraId}. 
*/ + private void openCamera(int width, int height) { + if (!checkedPermissions && !allPermissionsGranted()) { + FragmentCompat.requestPermissions(this, getRequiredPermissions(), PERMISSIONS_REQUEST_CODE); + return; + } else { + checkedPermissions = true; + } + setUpCameraOutputs(width, height); + configureTransform(width, height); + Activity activity = getActivity(); + CameraManager manager = (CameraManager) activity.getSystemService(Context.CAMERA_SERVICE); + try { + if (!cameraOpenCloseLock.tryAcquire(2500, TimeUnit.MILLISECONDS)) { + throw new RuntimeException("Time out waiting to lock camera opening."); + } + manager.openCamera(cameraId, stateCallback, backgroundHandler); + } catch (CameraAccessException e) { + e.printStackTrace(); + } catch (InterruptedException e) { + throw new RuntimeException("Interrupted while trying to lock camera opening.", e); + } + } + + private boolean allPermissionsGranted() { + for (String permission : getRequiredPermissions()) { + if (ContextCompat.checkSelfPermission(getActivity(), permission) + != PackageManager.PERMISSION_GRANTED) { + return false; + } + } + return true; + } + + @Override + public void onRequestPermissionsResult( + int requestCode, @NonNull String[] permissions, @NonNull int[] grantResults) { + super.onRequestPermissionsResult(requestCode, permissions, grantResults); + } + + /** Closes the current {@link CameraDevice}. */ + private void closeCamera() { + try { + cameraOpenCloseLock.acquire(); + if (null != captureSession) { + captureSession.close(); + captureSession = null; + } + if (null != cameraDevice) { + cameraDevice.close(); + cameraDevice = null; + } + if (null != imageReader) { + imageReader.close(); + imageReader = null; + } + } catch (InterruptedException e) { + throw new RuntimeException("Interrupted while trying to lock camera closing.", e); + } finally { + cameraOpenCloseLock.release(); + } + } + + /** Starts a background thread and its {@link Handler}. */ + private void startBackgroundThread() { + backgroundThread = new HandlerThread(HANDLE_THREAD_NAME); + backgroundThread.start(); + backgroundHandler = new Handler(backgroundThread.getLooper()); + synchronized (lock) { + runClassifier = true; + } + backgroundHandler.post(periodicClassify); + } + + /** Stops the background thread and its {@link Handler}. */ + private void stopBackgroundThread() { + backgroundThread.quitSafely(); + try { + backgroundThread.join(); + backgroundThread = null; + backgroundHandler = null; + synchronized (lock) { + runClassifier = false; + } + } catch (InterruptedException e) { + e.printStackTrace(); + } + } + + /** Takes photos and classify them periodically. */ + private Runnable periodicClassify = + new Runnable() { + @Override + public void run() { + synchronized (lock) { + if (runClassifier) { + classifyFrame(); + } + } + backgroundHandler.post(periodicClassify); + } + }; + + /** Creates a new {@link CameraCaptureSession} for camera preview. */ + private void createCameraPreviewSession() { + try { + SurfaceTexture texture = textureView.getSurfaceTexture(); + assert texture != null; + + // We configure the size of default buffer to be the size of camera preview we want. + texture.setDefaultBufferSize(previewSize.getWidth(), previewSize.getHeight()); + + // This is the output Surface we need to start preview. + Surface surface = new Surface(texture); + + // We set up a CaptureRequest.Builder with the output Surface. 
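// The remainder of this method follows the standard Camera2 preview sequence: build a preview
// request that targets the Surface, create a capture session for that Surface, and set the
// request repeating once the session is configured. A condensed sketch of that flow (same
// fields as this fragment; checked CameraAccessException handling around the first two calls
// omitted for brevity):
//
//   CaptureRequest.Builder builder =
//       cameraDevice.createCaptureRequest(CameraDevice.TEMPLATE_PREVIEW);
//   builder.addTarget(surface);
//   cameraDevice.createCaptureSession(
//       Arrays.asList(surface),
//       new CameraCaptureSession.StateCallback() {
//         @Override
//         public void onConfigured(CameraCaptureSession session) {
//           try {
//             session.setRepeatingRequest(builder.build(), /*listener=*/ null, backgroundHandler);
//           } catch (CameraAccessException e) {
//             e.printStackTrace();
//           }
//         }
//
//         @Override
//         public void onConfigureFailed(CameraCaptureSession session) {}
//       },
//       null);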
+ previewRequestBuilder = cameraDevice.createCaptureRequest(CameraDevice.TEMPLATE_PREVIEW); + previewRequestBuilder.addTarget(surface); + + // Here, we create a CameraCaptureSession for camera preview. + cameraDevice.createCaptureSession( + Arrays.asList(surface), + new CameraCaptureSession.StateCallback() { + + @Override + public void onConfigured(@NonNull CameraCaptureSession cameraCaptureSession) { + // The camera is already closed + if (null == cameraDevice) { + return; + } + + // When the session is ready, we start displaying the preview. + captureSession = cameraCaptureSession; + try { + // Auto focus should be continuous for camera preview. + previewRequestBuilder.set( + CaptureRequest.CONTROL_AF_MODE, + CaptureRequest.CONTROL_AF_MODE_CONTINUOUS_PICTURE); + + // Finally, we start displaying the camera preview. + previewRequest = previewRequestBuilder.build(); + captureSession.setRepeatingRequest( + previewRequest, captureCallback, backgroundHandler); + } catch (CameraAccessException e) { + e.printStackTrace(); + } + } + + @Override + public void onConfigureFailed(@NonNull CameraCaptureSession cameraCaptureSession) { + showToast("Failed"); + } + }, + null); + } catch (CameraAccessException e) { + e.printStackTrace(); + } + } + + /** + * Configures the necessary {@link android.graphics.Matrix} transformation to `textureView`. This + * method should be called after the camera preview size is determined in setUpCameraOutputs and + * also the size of `textureView` is fixed. + * + * @param viewWidth The width of `textureView` + * @param viewHeight The height of `textureView` + */ + private void configureTransform(int viewWidth, int viewHeight) { + Activity activity = getActivity(); + if (null == textureView || null == previewSize || null == activity) { + return; + } + int rotation = activity.getWindowManager().getDefaultDisplay().getRotation(); + Matrix matrix = new Matrix(); + RectF viewRect = new RectF(0, 0, viewWidth, viewHeight); + RectF bufferRect = new RectF(0, 0, previewSize.getHeight(), previewSize.getWidth()); + float centerX = viewRect.centerX(); + float centerY = viewRect.centerY(); + if (Surface.ROTATION_90 == rotation || Surface.ROTATION_270 == rotation) { + bufferRect.offset(centerX - bufferRect.centerX(), centerY - bufferRect.centerY()); + matrix.setRectToRect(viewRect, bufferRect, Matrix.ScaleToFit.FILL); + float scale = + Math.max( + (float) viewHeight / previewSize.getHeight(), + (float) viewWidth / previewSize.getWidth()); + matrix.postScale(scale, scale, centerX, centerY); + matrix.postRotate(90 * (rotation - 2), centerX, centerY); + } else if (Surface.ROTATION_180 == rotation) { + matrix.postRotate(180, centerX, centerY); + } + textureView.setTransform(matrix); + } + + /** Classifies a frame from the preview stream. */ + private void classifyFrame() { + if (classifier == null || getActivity() == null || cameraDevice == null) { + showToast("Uninitialized Classifier or invalid context."); + return; + } + Bitmap bitmap = + textureView.getBitmap(ImageClassifier.DIM_IMG_SIZE_X, ImageClassifier.DIM_IMG_SIZE_Y); + String textToShow = classifier.classifyFrame(bitmap); + bitmap.recycle(); + showToast(textToShow); + } + + /** Compares two {@code Size}s based on their areas. 
*/ + private static class CompareSizesByArea implements Comparator { + + @Override + public int compare(Size lhs, Size rhs) { + // We cast here to ensure the multiplications won't overflow + return Long.signum( + (long) lhs.getWidth() * lhs.getHeight() - (long) rhs.getWidth() * rhs.getHeight()); + } + } + + /** Shows an error message dialog. */ + public static class ErrorDialog extends DialogFragment { + + private static final String ARG_MESSAGE = "message"; + + public static ErrorDialog newInstance(String message) { + ErrorDialog dialog = new ErrorDialog(); + Bundle args = new Bundle(); + args.putString(ARG_MESSAGE, message); + dialog.setArguments(args); + return dialog; + } + + @Override + public Dialog onCreateDialog(Bundle savedInstanceState) { + final Activity activity = getActivity(); + return new AlertDialog.Builder(activity) + .setMessage(getArguments().getString(ARG_MESSAGE)) + .setPositiveButton( + android.R.string.ok, + new DialogInterface.OnClickListener() { + @Override + public void onClick(DialogInterface dialogInterface, int i) { + activity.finish(); + } + }) + .create(); + } + } +} diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/CameraActivity.java b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/CameraActivity.java new file mode 100644 index 0000000000..e7161ddb26 --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/CameraActivity.java @@ -0,0 +1,35 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package com.example.android.tflitecamerademo; + +import android.app.Activity; +import android.os.Bundle; + +/** Main {@code Activity} class for the Camera app. */ +public class CameraActivity extends Activity { + + @Override + protected void onCreate(Bundle savedInstanceState) { + super.onCreate(savedInstanceState); + setContentView(R.layout.activity_camera); + if (null == savedInstanceState) { + getFragmentManager() + .beginTransaction() + .replace(R.id.container, Camera2BasicFragment.newInstance()) + .commit(); + } + } +} diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifier.java b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifier.java new file mode 100644 index 0000000000..e7bad46370 --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifier.java @@ -0,0 +1,184 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package com.example.android.tflitecamerademo;
+
+import android.app.Activity;
+import android.content.res.AssetFileDescriptor;
+import android.graphics.Bitmap;
+import android.os.SystemClock;
+import android.util.Log;
+import java.io.BufferedReader;
+import java.io.FileInputStream;
+import java.io.IOException;
+import java.io.InputStreamReader;
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+import java.nio.MappedByteBuffer;
+import java.nio.channels.FileChannel;
+import java.util.AbstractMap;
+import java.util.ArrayList;
+import java.util.Comparator;
+import java.util.List;
+import java.util.Map;
+import java.util.PriorityQueue;
+import org.tensorflow.lite.Interpreter;
+
+/** Classifies images with TensorFlow Lite. */
+public class ImageClassifier {
+
+  /** Tag for the {@link Log}. */
+  private static final String TAG = "TfLiteCameraDemo";
+
+  /** Name of the model file stored in Assets. */
+  private static final String MODEL_PATH = "mobilenet_quant_v1_224.tflite";
+
+  /** Name of the label file stored in Assets. */
+  private static final String LABEL_PATH = "labels.txt";
+
+  /** Number of results to show in the UI. */
+  private static final int RESULTS_TO_SHOW = 3;
+
+  /** Dimensions of inputs. */
+  private static final int DIM_BATCH_SIZE = 1;
+
+  private static final int DIM_PIXEL_SIZE = 3;
+
+  static final int DIM_IMG_SIZE_X = 224;
+  static final int DIM_IMG_SIZE_Y = 224;
+
+  /* Preallocated buffers for storing image data in. */
+  private int[] intValues = new int[DIM_IMG_SIZE_X * DIM_IMG_SIZE_Y];
+
+  /** An instance of the driver class to run model inference with TensorFlow Lite. */
+  private Interpreter tflite;
+
+  /** Labels corresponding to the output of the vision model. */
+  private List<String> labelList;
+
+  /** A ByteBuffer to hold image data, to be fed into TensorFlow Lite as inputs. */
+  private ByteBuffer imgData = null;
+
+  /** An array to hold inference results, to be fed into TensorFlow Lite as outputs. */
+  private byte[][] labelProbArray = null;
+
+  private PriorityQueue<Map.Entry<String, Float>> sortedLabels =
+      new PriorityQueue<>(
+          RESULTS_TO_SHOW,
+          new Comparator<Map.Entry<String, Float>>() {
+            @Override
+            public int compare(Map.Entry<String, Float> o1, Map.Entry<String, Float> o2) {
+              return (o1.getValue()).compareTo(o2.getValue());
+            }
+          });
+
+  /** Initializes an {@code ImageClassifier}. */
+  ImageClassifier(Activity activity) throws IOException {
+    tflite = new Interpreter(loadModelFile(activity));
+    labelList = loadLabelList(activity);
+    imgData =
+        ByteBuffer.allocateDirect(
+            DIM_BATCH_SIZE * DIM_IMG_SIZE_X * DIM_IMG_SIZE_Y * DIM_PIXEL_SIZE);
+    imgData.order(ByteOrder.nativeOrder());
+    labelProbArray = new byte[1][labelList.size()];
+    Log.d(TAG, "Created a TensorFlow Lite Image Classifier.");
+  }
+
+  /** Classifies a frame from the preview stream. */
+  String classifyFrame(Bitmap bitmap) {
+    if (tflite == null) {
+      Log.e(TAG, "Image classifier has not been initialized; Skipped.");
+      return "Uninitialized Classifier.";
+    }
+    convertBitmapToByteBuffer(bitmap);
+    // Here's where the magic happens!!!
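// For the quantized MobileNet bundled with this demo, the run() call below maps one direct
// input buffer of DIM_BATCH_SIZE * DIM_IMG_SIZE_X * DIM_IMG_SIZE_Y * DIM_PIXEL_SIZE uint8
// values onto the byte[1][labelList.size()] output array. An illustrative, stripped-down
// call (the 1001 comes from the labels.txt shipped with the app; fill the buffer with the
// 224x224 RGB bytes before running):
//
//   ByteBuffer input = ByteBuffer.allocateDirect(1 * 224 * 224 * 3);
//   input.order(ByteOrder.nativeOrder());
//   byte[][] output = new byte[1][1001];  // one score per label
//   tflite.run(input, output);            // each output byte is a score in [0, 255]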
+    long startTime = SystemClock.uptimeMillis();
+    tflite.run(imgData, labelProbArray);
+    long endTime = SystemClock.uptimeMillis();
+    Log.d(TAG, "Timecost to run model inference: " + Long.toString(endTime - startTime));
+    String textToShow = printTopKLabels();
+    textToShow = Long.toString(endTime - startTime) + "ms" + textToShow;
+    return textToShow;
+  }
+
+  /** Closes tflite to release resources. */
+  public void close() {
+    tflite.close();
+    tflite = null;
+  }
+
+  /** Reads the label list from Assets. */
+  private List<String> loadLabelList(Activity activity) throws IOException {
+    List<String> labelList = new ArrayList<>();
+    BufferedReader reader =
+        new BufferedReader(new InputStreamReader(activity.getAssets().open(LABEL_PATH)));
+    String line;
+    while ((line = reader.readLine()) != null) {
+      labelList.add(line);
+    }
+    reader.close();
+    return labelList;
+  }
+
+  /** Memory-maps the model file in Assets. */
+  private MappedByteBuffer loadModelFile(Activity activity) throws IOException {
+    AssetFileDescriptor fileDescriptor = activity.getAssets().openFd(MODEL_PATH);
+    FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
+    FileChannel fileChannel = inputStream.getChannel();
+    long startOffset = fileDescriptor.getStartOffset();
+    long declaredLength = fileDescriptor.getDeclaredLength();
+    return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
+  }
+
+  /** Writes image data into a {@code ByteBuffer}. */
+  private void convertBitmapToByteBuffer(Bitmap bitmap) {
+    if (imgData == null) {
+      return;
+    }
+    imgData.rewind();
+    bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());
+    // Copy each pixel's RGB channels as unsigned bytes; the quantized model takes uint8 input.
+    int pixel = 0;
+    long startTime = SystemClock.uptimeMillis();
+    for (int i = 0; i < DIM_IMG_SIZE_X; ++i) {
+      for (int j = 0; j < DIM_IMG_SIZE_Y; ++j) {
+        final int val = intValues[pixel++];
+        imgData.put((byte) ((val >> 16) & 0xFF));
+        imgData.put((byte) ((val >> 8) & 0xFF));
+        imgData.put((byte) (val & 0xFF));
+      }
+    }
+    long endTime = SystemClock.uptimeMillis();
+    Log.d(TAG, "Timecost to put values into ByteBuffer: " + Long.toString(endTime - startTime));
+  }
+
+  /** Prints top-K labels, to be shown in UI as the results.
*/ + private String printTopKLabels() { + for (int i = 0; i < labelList.size(); ++i) { + sortedLabels.add( + new AbstractMap.SimpleEntry<>(labelList.get(i), (labelProbArray[0][i] & 0xff) / 255.0f)); + if (sortedLabels.size() > RESULTS_TO_SHOW) { + sortedLabels.poll(); + } + } + String textToShow = ""; + final int size = sortedLabels.size(); + for (int i = 0; i < size; ++i) { + Map.Entry label = sortedLabels.poll(); + textToShow = "\n" + label.getKey() + ":" + Float.toString(label.getValue()) + textToShow; + } + return textToShow; + } +} diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-hdpi/ic_action_info.png b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-hdpi/ic_action_info.png new file mode 100644 index 0000000000..e0a70008b1 Binary files /dev/null and b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-hdpi/ic_action_info.png differ diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-hdpi/ic_launcher.png b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-hdpi/ic_launcher.png new file mode 100644 index 0000000000..c22509d8df Binary files /dev/null and b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-hdpi/ic_launcher.png differ diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-hdpi/tile.9.png b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-hdpi/tile.9.png new file mode 100644 index 0000000000..a84e3ef52c Binary files /dev/null and b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-hdpi/tile.9.png differ diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-mdpi/ic_action_info.png b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-mdpi/ic_action_info.png new file mode 100644 index 0000000000..520c2dd100 Binary files /dev/null and b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-mdpi/ic_action_info.png differ diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-mdpi/ic_launcher.png b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-mdpi/ic_launcher.png new file mode 100644 index 0000000000..d68af39186 Binary files /dev/null and b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-mdpi/ic_launcher.png differ diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xhdpi/ic_action_info.png b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xhdpi/ic_action_info.png new file mode 100644 index 0000000000..1347b09198 Binary files /dev/null and b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xhdpi/ic_action_info.png differ diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xhdpi/ic_launcher.png b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xhdpi/ic_launcher.png new file mode 100644 index 0000000000..15e419b7cc Binary files /dev/null and b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xhdpi/ic_launcher.png differ diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xxhdpi/ic_action_info.png b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xxhdpi/ic_action_info.png new file mode 100644 index 0000000000..fd933333b7 Binary files /dev/null and b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xxhdpi/ic_action_info.png differ diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xxhdpi/ic_launcher.png b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xxhdpi/ic_launcher.png new file mode 100644 index 
0000000000..342ce34e16 Binary files /dev/null and b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xxhdpi/ic_launcher.png differ diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/layout-land/fragment_camera2_basic.xml b/tensorflow/contrib/lite/java/demo/app/src/main/res/layout-land/fragment_camera2_basic.xml new file mode 100644 index 0000000000..a84f1bbfa0 --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/layout-land/fragment_camera2_basic.xml @@ -0,0 +1,50 @@ + + + + + + + + + + + + + diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/layout/activity_camera.xml b/tensorflow/contrib/lite/java/demo/app/src/main/res/layout/activity_camera.xml new file mode 100644 index 0000000000..286e549c65 --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/layout/activity_camera.xml @@ -0,0 +1,22 @@ + + diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/layout/fragment_camera2_basic.xml b/tensorflow/contrib/lite/java/demo/app/src/main/res/layout/fragment_camera2_basic.xml new file mode 100644 index 0000000000..15305c436e --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/layout/fragment_camera2_basic.xml @@ -0,0 +1,45 @@ + + + + + + + + + + + + diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/values-sw600dp/template-dimens.xml b/tensorflow/contrib/lite/java/demo/app/src/main/res/values-sw600dp/template-dimens.xml new file mode 100644 index 0000000000..22074a2bdb --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/values-sw600dp/template-dimens.xml @@ -0,0 +1,24 @@ + + + + + + + @dimen/margin_huge + @dimen/margin_medium + + diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/values-sw600dp/template-styles.xml b/tensorflow/contrib/lite/java/demo/app/src/main/res/values-sw600dp/template-styles.xml new file mode 100644 index 0000000000..03d1974183 --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/values-sw600dp/template-styles.xml @@ -0,0 +1,25 @@ + + + + + + + diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/values-v11/template-styles.xml b/tensorflow/contrib/lite/java/demo/app/src/main/res/values-v11/template-styles.xml new file mode 100644 index 0000000000..8c1ea66f28 --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/values-v11/template-styles.xml @@ -0,0 +1,22 @@ + + + + + + + + diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/values/base-strings.xml b/tensorflow/contrib/lite/java/demo/app/src/main/res/values/base-strings.xml new file mode 100644 index 0000000000..ab7d3fd496 --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/values/base-strings.xml @@ -0,0 +1,30 @@ + + + + + TfLiteCameraDemo + + + + diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/values/colors.xml b/tensorflow/contrib/lite/java/demo/app/src/main/res/values/colors.xml new file mode 100644 index 0000000000..4b75d2b2bd --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/values/colors.xml @@ -0,0 +1,19 @@ + + + + #cc4285f4 + diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/values/strings.xml b/tensorflow/contrib/lite/java/demo/app/src/main/res/values/strings.xml new file mode 100644 index 0000000000..a08ec3eb62 --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/values/strings.xml @@ -0,0 +1,24 @@ + + + Picture + Info + This sample needs camera permission. + This device doesn\'t support Camera2 API. 
+ NN:On + NN:Off + Use NNAPI + diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/values/styles.xml b/tensorflow/contrib/lite/java/demo/app/src/main/res/values/styles.xml new file mode 100644 index 0000000000..3f3bdfb494 --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/values/styles.xml @@ -0,0 +1,18 @@ + + + + + + + diff --git a/tensorflow/contrib/lite/java/demo/build.gradle b/tensorflow/contrib/lite/java/demo/build.gradle new file mode 100644 index 0000000000..b78a0b86c9 --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/build.gradle @@ -0,0 +1,23 @@ +// Top-level build file where you can add configuration options common to all sub-projects/modules. + +buildscript { + repositories { + jcenter() + } + dependencies { + classpath 'com.android.tools.build:gradle:2.3.1' + + // NOTE: Do not place your application dependencies here; they belong + // in the individual module build.gradle files + } +} + +allprojects { + repositories { + jcenter() + } +} + +task clean(type: Delete) { + delete rootProject.buildDir +} diff --git a/tensorflow/contrib/lite/java/demo/gradle.properties b/tensorflow/contrib/lite/java/demo/gradle.properties new file mode 100644 index 0000000000..aac7c9b461 --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/gradle.properties @@ -0,0 +1,17 @@ +# Project-wide Gradle settings. + +# IDE (e.g. Android Studio) users: +# Gradle settings configured through the IDE *will override* +# any settings specified in this file. + +# For more details on how to configure your build environment visit +# http://www.gradle.org/docs/current/userguide/build_environment.html + +# Specifies the JVM arguments used for the daemon process. +# The setting is particularly useful for tweaking memory settings. +org.gradle.jvmargs=-Xmx1536m + +# When configured, Gradle will run in incubating parallel mode. +# This option should only be used with decoupled projects. More details, visit +# http://www.gradle.org/docs/current/userguide/multi_project_builds.html#sec:decoupled_projects +# org.gradle.parallel=true diff --git a/tensorflow/contrib/lite/java/demo/gradle/wrapper/gradle-wrapper.jar b/tensorflow/contrib/lite/java/demo/gradle/wrapper/gradle-wrapper.jar new file mode 100644 index 0000000000..13372aef5e Binary files /dev/null and b/tensorflow/contrib/lite/java/demo/gradle/wrapper/gradle-wrapper.jar differ diff --git a/tensorflow/contrib/lite/java/demo/gradle/wrapper/gradle-wrapper.properties b/tensorflow/contrib/lite/java/demo/gradle/wrapper/gradle-wrapper.properties new file mode 100644 index 0000000000..fa7a38a0e4 --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/gradle/wrapper/gradle-wrapper.properties @@ -0,0 +1,6 @@ +#Thu Sep 28 09:01:41 PDT 2017 +distributionBase=GRADLE_USER_HOME +distributionPath=wrapper/dists +zipStoreBase=GRADLE_USER_HOME +zipStorePath=wrapper/dists +distributionUrl=https\://services.gradle.org/distributions/gradle-3.3-all.zip diff --git a/tensorflow/contrib/lite/java/demo/gradlew b/tensorflow/contrib/lite/java/demo/gradlew new file mode 100755 index 0000000000..9d82f78915 --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/gradlew @@ -0,0 +1,160 @@ +#!/usr/bin/env bash + +############################################################################## +## +## Gradle start up script for UN*X +## +############################################################################## + +# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. 
+DEFAULT_JVM_OPTS="" + +APP_NAME="Gradle" +APP_BASE_NAME=`basename "$0"` + +# Use the maximum available, or set MAX_FD != -1 to use that value. +MAX_FD="maximum" + +warn ( ) { + echo "$*" +} + +die ( ) { + echo + echo "$*" + echo + exit 1 +} + +# OS specific support (must be 'true' or 'false'). +cygwin=false +msys=false +darwin=false +case "`uname`" in + CYGWIN* ) + cygwin=true + ;; + Darwin* ) + darwin=true + ;; + MINGW* ) + msys=true + ;; +esac + +# Attempt to set APP_HOME +# Resolve links: $0 may be a link +PRG="$0" +# Need this for relative symlinks. +while [ -h "$PRG" ] ; do + ls=`ls -ld "$PRG"` + link=`expr "$ls" : '.*-> \(.*\)$'` + if expr "$link" : '/.*' > /dev/null; then + PRG="$link" + else + PRG=`dirname "$PRG"`"/$link" + fi +done +SAVED="`pwd`" +cd "`dirname \"$PRG\"`/" >/dev/null +APP_HOME="`pwd -P`" +cd "$SAVED" >/dev/null + +CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar + +# Determine the Java command to use to start the JVM. +if [ -n "$JAVA_HOME" ] ; then + if [ -x "$JAVA_HOME/jre/sh/java" ] ; then + # IBM's JDK on AIX uses strange locations for the executables + JAVACMD="$JAVA_HOME/jre/sh/java" + else + JAVACMD="$JAVA_HOME/bin/java" + fi + if [ ! -x "$JAVACMD" ] ; then + die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME + +Please set the JAVA_HOME variable in your environment to match the +location of your Java installation." + fi +else + JAVACMD="java" + which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. + +Please set the JAVA_HOME variable in your environment to match the +location of your Java installation." +fi + +# Increase the maximum file descriptors if we can. +if [ "$cygwin" = "false" -a "$darwin" = "false" ] ; then + MAX_FD_LIMIT=`ulimit -H -n` + if [ $? -eq 0 ] ; then + if [ "$MAX_FD" = "maximum" -o "$MAX_FD" = "max" ] ; then + MAX_FD="$MAX_FD_LIMIT" + fi + ulimit -n $MAX_FD + if [ $? 
-ne 0 ] ; then + warn "Could not set maximum file descriptor limit: $MAX_FD" + fi + else + warn "Could not query maximum file descriptor limit: $MAX_FD_LIMIT" + fi +fi + +# For Darwin, add options to specify how the application appears in the dock +if $darwin; then + GRADLE_OPTS="$GRADLE_OPTS \"-Xdock:name=$APP_NAME\" \"-Xdock:icon=$APP_HOME/media/gradle.icns\"" +fi + +# For Cygwin, switch paths to Windows format before running java +if $cygwin ; then + APP_HOME=`cygpath --path --mixed "$APP_HOME"` + CLASSPATH=`cygpath --path --mixed "$CLASSPATH"` + JAVACMD=`cygpath --unix "$JAVACMD"` + + # We build the pattern for arguments to be converted via cygpath + ROOTDIRSRAW=`find -L / -maxdepth 1 -mindepth 1 -type d 2>/dev/null` + SEP="" + for dir in $ROOTDIRSRAW ; do + ROOTDIRS="$ROOTDIRS$SEP$dir" + SEP="|" + done + OURCYGPATTERN="(^($ROOTDIRS))" + # Add a user-defined pattern to the cygpath arguments + if [ "$GRADLE_CYGPATTERN" != "" ] ; then + OURCYGPATTERN="$OURCYGPATTERN|($GRADLE_CYGPATTERN)" + fi + # Now convert the arguments - kludge to limit ourselves to /bin/sh + i=0 + for arg in "$@" ; do + CHECK=`echo "$arg"|egrep -c "$OURCYGPATTERN" -` + CHECK2=`echo "$arg"|egrep -c "^-"` ### Determine if an option + + if [ $CHECK -ne 0 ] && [ $CHECK2 -eq 0 ] ; then ### Added a condition + eval `echo args$i`=`cygpath --path --ignore --mixed "$arg"` + else + eval `echo args$i`="\"$arg\"" + fi + i=$((i+1)) + done + case $i in + (0) set -- ;; + (1) set -- "$args0" ;; + (2) set -- "$args0" "$args1" ;; + (3) set -- "$args0" "$args1" "$args2" ;; + (4) set -- "$args0" "$args1" "$args2" "$args3" ;; + (5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;; + (6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;; + (7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" ;; + (8) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;; + (9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" "$args8" ;; + esac +fi + +# Split up the JVM_OPTS And GRADLE_OPTS values into an array, following the shell quoting and substitution rules +function splitJvmOpts() { + JVM_OPTS=("$@") +} +eval splitJvmOpts $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS +JVM_OPTS[${#JVM_OPTS[*]}]="-Dorg.gradle.appname=$APP_BASE_NAME" + +exec "$JAVACMD" "${JVM_OPTS[@]}" -classpath "$CLASSPATH" org.gradle.wrapper.GradleWrapperMain "$@" diff --git a/tensorflow/contrib/lite/java/demo/gradlew.bat b/tensorflow/contrib/lite/java/demo/gradlew.bat new file mode 100644 index 0000000000..8a0b282aa6 --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/gradlew.bat @@ -0,0 +1,90 @@ +@if "%DEBUG%" == "" @echo off +@rem ########################################################################## +@rem +@rem Gradle startup script for Windows +@rem +@rem ########################################################################## + +@rem Set local scope for the variables with windows NT shell +if "%OS%"=="Windows_NT" setlocal + +@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. +set DEFAULT_JVM_OPTS= + +set DIRNAME=%~dp0 +if "%DIRNAME%" == "" set DIRNAME=. +set APP_BASE_NAME=%~n0 +set APP_HOME=%DIRNAME% + +@rem Find java.exe +if defined JAVA_HOME goto findJavaFromJavaHome + +set JAVA_EXE=java.exe +%JAVA_EXE% -version >NUL 2>&1 +if "%ERRORLEVEL%" == "0" goto init + +echo. +echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. +echo. 
+echo Please set the JAVA_HOME variable in your environment to match the +echo location of your Java installation. + +goto fail + +:findJavaFromJavaHome +set JAVA_HOME=%JAVA_HOME:"=% +set JAVA_EXE=%JAVA_HOME%/bin/java.exe + +if exist "%JAVA_EXE%" goto init + +echo. +echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% +echo. +echo Please set the JAVA_HOME variable in your environment to match the +echo location of your Java installation. + +goto fail + +:init +@rem Get command-line arguments, handling Windowz variants + +if not "%OS%" == "Windows_NT" goto win9xME_args +if "%@eval[2+2]" == "4" goto 4NT_args + +:win9xME_args +@rem Slurp the command line arguments. +set CMD_LINE_ARGS= +set _SKIP=2 + +:win9xME_args_slurp +if "x%~1" == "x" goto execute + +set CMD_LINE_ARGS=%* +goto execute + +:4NT_args +@rem Get arguments from the 4NT Shell from JP Software +set CMD_LINE_ARGS=%$ + +:execute +@rem Setup the command line + +set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar + +@rem Execute Gradle +"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %CMD_LINE_ARGS% + +:end +@rem End local scope for the variables with windows NT shell +if "%ERRORLEVEL%"=="0" goto mainEnd + +:fail +rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of +rem the _cmd.exe /c_ return code! +if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1 +exit /b 1 + +:mainEnd +if "%OS%"=="Windows_NT" endlocal + +:omega diff --git a/tensorflow/contrib/lite/java/demo/settings.gradle b/tensorflow/contrib/lite/java/demo/settings.gradle new file mode 100644 index 0000000000..e7b4def49c --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/settings.gradle @@ -0,0 +1 @@ +include ':app' diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/DataType.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/DataType.java new file mode 100644 index 0000000000..d63c299589 --- /dev/null +++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/DataType.java @@ -0,0 +1,76 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite; + +/** Type of elements in a {@link TfLiteTensor}. */ +enum DataType { + /** 32-bit single precision floating point. */ + FLOAT32(1), + + /** 32-bit signed integer. */ + INT32(2), + + /** 8-bit unsigned integer. */ + UINT8(3), + + /** 64-bit signed integer. */ + INT64(4), + + /** A {@link ByteBuffer}. */ + BYTEBUFFER(999); + + private final int value; + + DataType(int value) { + this.value = value; + } + + /** Corresponding value of the kTfLite* enum in the TensorFlow Lite CC API. */ + int getNumber() { + return value; + } + + /** Converts an integer to the corresponding type. 
*/ + static DataType fromNumber(int c) { + for (DataType t : values) { + if (t.value == c) { + return t; + } + } + throw new IllegalArgumentException( + "DataType " + c + " is not recognized in Java (version " + TensorFlowLite.version() + ")"); + } + + /** Returns byte size of the type. */ + int elemByteSize() { + switch (this) { + case FLOAT32: + return 4; + case INT32: + return 4; + case UINT8: + return 1; + case INT64: + return 8; + case BYTEBUFFER: + return 1; + } + throw new IllegalArgumentException("DataType " + this + " is not supported yet"); + } + + // Cached to avoid copying it + private static final DataType[] values = values(); +} diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java new file mode 100644 index 0000000000..dd883d69d2 --- /dev/null +++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java @@ -0,0 +1,172 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite; + +import java.io.File; +import java.nio.MappedByteBuffer; +import java.util.HashMap; +import java.util.Map; +import javax.validation.constraints.NotNull; + +/** + * Driver class to drive model inference with TensorFlow Lite. + * + *

<p>An {@code Interpreter} encapsulates a pre-trained TensorFlow Lite model, in which operations
+ * are executed for model inference.
+ *
+ * <p>For example, if a model takes only one input and returns only one output:
+ *
+ * <pre>{@code
+ * try (Interpreter interpreter = new Interpreter(file_of_a_tensorflowlite_model)) {
+ *   interpreter.run(input, output);
+ * }
+ * }</pre>
+ *
+ * <p>If a model takes multiple inputs or outputs:
+ *
+ * <pre>{@code
+ * Object[] inputs = {input0, input1, ...};
+ * Map<Integer, Object> map_of_indices_to_outputs = new HashMap<>();
+ * float[][][] ith_output = new float[3][2][4];
+ * map_of_indices_to_outputs.put(i, ith_output);
+ * try (Interpreter interpreter = new Interpreter(file_of_a_tensorflowlite_model)) {
+ *   interpreter.runForMultipleInputsOutputs(inputs, map_of_indices_to_outputs);
+ * }
+ * }</pre>
+ *
+ * <p>Orders of inputs and outputs are determined when converting a TensorFlow model to a
+ * TensorFlow Lite model with Toco.
+ *
+ * <p>
<b>WARNING:</b> Instances of an {@code Interpreter} are not thread-safe. An {@code
+ * Interpreter} owns resources that must be explicitly freed by invoking {@link #close()}.
+ */
+public final class Interpreter implements AutoCloseable {
+
+  /**
+   * Initializes an {@code Interpreter}.
+   *
+   * @param modelFile a File of a pre-trained TF Lite model.
+   */
+  public Interpreter(@NotNull File modelFile) {
+    if (modelFile == null) {
+      return;
+    }
+    wrapper = new NativeInterpreterWrapper(modelFile.getAbsolutePath());
+  }
+
+  /**
+   * Initializes an {@code Interpreter} with a {@code MappedByteBuffer} to the model file.
+   *
+   * <p>

The {@code MappedByteBuffer} should remain unchanged after the construction of a {@code + * Interpreter}. + */ + public Interpreter(@NotNull MappedByteBuffer mappedByteBuffer) { + wrapper = new NativeInterpreterWrapper(mappedByteBuffer); + } + + /** + * Runs model inference if the model takes only one input, and provides only one output. + * + * @param input an array or multidimensional array, or a {@link ByteBuffer} of primitive types + * including int, float, long, and byte. {@link ByteBuffer} is the preferred way to pass large + * input data. When {@link ByteBuffer} is used, its content should remain unchanged until + * model inference is done. + * @param output a multidimensional array of output data. + */ + public void run(@NotNull Object input, @NotNull Object output) { + Object[] inputs = {input}; + Map outputs = new HashMap<>(); + outputs.put(0, output); + runForMultipleInputsOutputs(inputs, outputs); + } + + /** + * Runs model inference if the model takes multiple inputs, or returns multiple outputs. + * + * @param inputs an array of input data. The inputs should be in the same order as inputs of the + * model. Each input can be an array or multidimensional array, or a {@link ByteBuffer} of + * primitive types including int, float, long, and byte. {@link ByteBuffer} is the preferred + * way to pass large input data. When {@link ByteBuffer} is used, its content should remain + * unchanged until model inference is done. + * @param outputs a map mapping output indices to multidimensional arrays of output data. It only + * needs to keep entries for the outputs to be used. + */ + public void runForMultipleInputsOutputs( + @NotNull Object[] inputs, @NotNull Map outputs) { + if (wrapper == null) { + throw new IllegalStateException("The Interpreter has already been closed."); + } + Tensor[] tensors = wrapper.run(inputs); + if (outputs == null || tensors == null || outputs.size() > tensors.length) { + throw new IllegalArgumentException("Outputs do not match with model outputs."); + } + final int size = tensors.length; + for (Integer idx : outputs.keySet()) { + if (idx == null || idx < 0 || idx >= size) { + throw new IllegalArgumentException( + String.format("Invalid index of output %d (should be in range [0, %d))", idx, size)); + } + tensors[idx].copyTo(outputs.get(idx)); + } + } + + /** + * Resizes idx-th input of the native model to the given dims. + * + *

IllegalArgumentException will be thrown if it fails to resize. + */ + public void resizeInput(int idx, @NotNull int[] dims) { + if (wrapper == null) { + throw new IllegalStateException("The Interpreter has already been closed."); + } + wrapper.resizeInput(idx, dims); + } + + /** + * Gets index of an input given the op name of the input. + * + *

IllegalArgumentException will be thrown if the op name does not exist in the model file used + * to initialize the {@link Interpreter}. + */ + public int getInputIndex(String opName) { + if (wrapper == null) { + throw new IllegalStateException("The Interpreter has already been closed."); + } + return wrapper.getInputIndex(opName); + } + + /** + * Gets index of an output given the op name of the output. + * + *

IllegalArgumentException will be thrown if the op name does not exist in the model file used + * to initialize the {@link Interpreter}. + */ + public int getOutputIndex(String opName) { + if (wrapper == null) { + throw new IllegalStateException("The Interpreter has already been closed."); + } + return wrapper.getOutputIndex(opName); + } + + /** Release resources associated with the {@code Interpreter}. */ + @Override + public void close() { + wrapper.close(); + wrapper = null; + } + + NativeInterpreterWrapper wrapper; +} diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java new file mode 100644 index 0000000000..1939a078ad --- /dev/null +++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java @@ -0,0 +1,276 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite; + +import java.lang.reflect.Array; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.MappedByteBuffer; +import java.util.HashMap; +import java.util.Map; + +/** + * A wrapper wraps native interpreter and controls model execution. + * + *

WARNING: Resources consumed by the {@code NativeInterpreterWrapper} object must be + * explicitly freed by invoking the {@link #close()} method when the {@code + * NativeInterpreterWrapper} object is no longer needed. + */ +final class NativeInterpreterWrapper implements AutoCloseable { + + NativeInterpreterWrapper(String modelPath) { + errorHandle = createErrorReporter(ERROR_BUFFER_SIZE); + modelHandle = createModel(modelPath, errorHandle); + interpreterHandle = createInterpreter(modelHandle); + } + + /** + * Initializes a {@code NativeInterpreterWrapper} with a {@code MappedByteBuffer}. The + * MappedByteBuffer should not be modified after the construction of a {@code + * NativeInterpreterWrapper}. + */ + NativeInterpreterWrapper(MappedByteBuffer mappedByteBuffer) { + modelByteBuffer = mappedByteBuffer; + errorHandle = createErrorReporter(ERROR_BUFFER_SIZE); + modelHandle = createModelWithBuffer(modelByteBuffer, errorHandle); + interpreterHandle = createInterpreter(modelHandle); + } + + /** Releases resources associated with this {@code NativeInterpreterWrapper}. */ + @Override + public void close() { + delete(errorHandle, modelHandle, interpreterHandle); + errorHandle = 0; + modelHandle = 0; + interpreterHandle = 0; + modelByteBuffer = null; + inputsIndexes = null; + outputsIndexes = null; + } + + /** Sets inputs, runs model inference and returns outputs. */ + Tensor[] run(Object[] inputs) { + if (inputs == null || inputs.length == 0) { + throw new IllegalArgumentException("Invalid inputs. Inputs should not be null or empty."); + } + int[] dataTypes = new int[inputs.length]; + Object[] sizes = new Object[inputs.length]; + int[] numsOfBytes = new int[inputs.length]; + for (int i = 0; i < inputs.length; ++i) { + DataType dataType = dataTypeOf(inputs[i]); + dataTypes[i] = dataType.getNumber(); + if (dataType == DataType.BYTEBUFFER) { + ByteBuffer buffer = (ByteBuffer) inputs[i]; + if (buffer.order() != ByteOrder.nativeOrder()) { + throw new IllegalArgumentException( + "Invalid ByteBuffer. It shoud use ByteOrder.nativeOrder()."); + } + numsOfBytes[i] = buffer.limit(); + sizes[i] = getInputDims(interpreterHandle, i, numsOfBytes[i]); + } else if (isNonEmptyArray(inputs[i])) { + int[] dims = shapeOf(inputs[i]); + sizes[i] = dims; + numsOfBytes[i] = dataType.elemByteSize() * numElements(dims); + } else { + throw new IllegalArgumentException( + String.format( + "%d-th element of the %d inputs is not an array or a ByteBuffer.", + i, inputs.length)); + } + } + long[] outputsHandles = + run(interpreterHandle, errorHandle, sizes, dataTypes, numsOfBytes, inputs); + if (outputsHandles == null || outputsHandles.length == 0) { + throw new IllegalStateException("Interpreter has no outputs."); + } + Tensor[] outputs = new Tensor[outputsHandles.length]; + for (int i = 0; i < outputsHandles.length; ++i) { + outputs[i] = Tensor.fromHandle(outputsHandles[i]); + } + return outputs; + } + + /** Resizes dimensions of a specific input. */ + void resizeInput(int idx, int[] dims) { + resizeInput(interpreterHandle, errorHandle, idx, dims); + } + + void setUseNNAPI(boolean useNNAPI) { + useNNAPI(interpreterHandle, useNNAPI); + } + + /** Gets index of an input given its name. 
*/ + int getInputIndex(String name) { + if (inputsIndexes == null) { + String[] names = getInputNames(interpreterHandle); + inputsIndexes = new HashMap<>(); + if (names != null) { + for (int i = 0; i < names.length; ++i) { + inputsIndexes.put(names[i], i); + } + } + } + if (inputsIndexes.containsKey(name)) { + return inputsIndexes.get(name); + } else { + throw new IllegalArgumentException( + String.format( + "%s is not a valid name for any input. The indexes of the inputs are %s", + name, inputsIndexes.toString())); + } + } + + /** Gets index of an output given its name. */ + int getOutputIndex(String name) { + if (outputsIndexes == null) { + String[] names = getOutputNames(interpreterHandle); + outputsIndexes = new HashMap<>(); + if (names != null) { + for (int i = 0; i < names.length; ++i) { + outputsIndexes.put(names[i], i); + } + } + } + if (outputsIndexes.containsKey(name)) { + return outputsIndexes.get(name); + } else { + throw new IllegalArgumentException( + String.format( + "%s is not a valid name for any output. The indexes of the outputs are %s", + name, outputsIndexes.toString())); + } + } + + static int numElements(int[] shape) { + if (shape == null) { + return 0; + } + int n = 1; + for (int i = 0; i < shape.length; i++) { + n *= shape[i]; + } + return n; + } + + static boolean isNonEmptyArray(Object o) { + return (o != null && o.getClass().isArray() && Array.getLength(o) != 0); + } + + /** Returns the type of the data. */ + static DataType dataTypeOf(Object o) { + if (o != null) { + Class c = o.getClass(); + while (c.isArray()) { + c = c.getComponentType(); + } + if (float.class.equals(c)) { + return DataType.FLOAT32; + } else if (int.class.equals(c)) { + return DataType.INT32; + } else if (byte.class.equals(c)) { + return DataType.UINT8; + } else if (long.class.equals(c)) { + return DataType.INT64; + } else if (ByteBuffer.class.isInstance(o)) { + return DataType.BYTEBUFFER; + } + } + throw new IllegalArgumentException("cannot resolve DataType of " + o.getClass().getName()); + } + + /** Returns the shape of an object as an int array. 
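+   *
+   * <p>For example, {@code shapeOf(new float[2][3][4])} returns {@code {2, 3, 4}}; a
+   * non-array object yields an empty array.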
*/ + static int[] shapeOf(Object o) { + int size = numDimensions(o); + int[] dimensions = new int[size]; + fillShape(o, 0, dimensions); + return dimensions; + } + + static int numDimensions(Object o) { + if (o == null || !o.getClass().isArray()) { + return 0; + } + if (Array.getLength(o) == 0) { + throw new IllegalArgumentException("array lengths cannot be 0."); + } + return 1 + numDimensions(Array.get(o, 0)); + } + + static void fillShape(Object o, int dim, int[] shape) { + if (shape == null || dim == shape.length) { + return; + } + final int len = Array.getLength(o); + if (shape[dim] == 0) { + shape[dim] = len; + } else if (shape[dim] != len) { + throw new IllegalArgumentException( + String.format("mismatched lengths (%d and %d) in dimension %d", shape[dim], len, dim)); + } + for (int i = 0; i < len; ++i) { + fillShape(Array.get(o, i), dim + 1, shape); + } + } + + private static final int ERROR_BUFFER_SIZE = 512; + + private long errorHandle; + + private long interpreterHandle; + + private long modelHandle; + + private int inputSize; + + private MappedByteBuffer modelByteBuffer; + + private Map inputsIndexes; + + private Map outputsIndexes; + + private static native String[] getInputNames(long interpreterHandle); + + private static native String[] getOutputNames(long interpreterHandle); + + private static native void resizeInput( + long interpreterHandle, long errorHandle, int inputIdx, int[] dims); + + private static native void useNNAPI(long interpreterHandle, boolean state); + + private static native long createErrorReporter(int size); + + private static native long createModel(String modelPathOrBuffer, long errorHandle); + + private static native long createModelWithBuffer(MappedByteBuffer modelBuffer, long errorHandle); + + private static native long createInterpreter(long modelHandle); + + private static native long[] run( + long interpreterHandle, + long errorHandle, + Object[] sizes, + int[] dtypes, + int[] numsOfBytes, + Object[] values); + + private static native void delete(long errorHandle, long modelHandle, long interpreterHandle); + + private static native int[] getInputDims(long interpreterHandle, int inputIdx, int numBytes); + + static { + TensorFlowLite.init(); + } +} diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Tensor.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Tensor.java new file mode 100644 index 0000000000..54ace6c63c --- /dev/null +++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Tensor.java @@ -0,0 +1,71 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite; + +import java.util.Arrays; + +/** + * A typed multi-dimensional array used in Tensorflow Lite. + * + *
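+ * <p>Instances are created internally via {@code fromHandle} and their contents are read
+ * back into Java arrays with {@code copyTo}.
+ *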

The native handle of a {@code Tensor} belongs to {@code NativeInterpreterWrapper}, thus not + * needed to be closed here. + */ +final class Tensor { + + static Tensor fromHandle(long nativeHandle) { + return new Tensor(nativeHandle); + } + + /** Reads Tensor content into an array. */ + T copyTo(T dst) { + if (NativeInterpreterWrapper.dataTypeOf(dst) != dtype) { + throw new IllegalArgumentException( + String.format( + "Cannot convert an TensorFlowLite tensor with type %s to a Java object of " + + "type %s (which is compatible with the TensorFlowLite type %s)", + dtype, dst.getClass().getName(), NativeInterpreterWrapper.dataTypeOf(dst))); + } + int[] dstShape = NativeInterpreterWrapper.shapeOf(dst); + if (!Arrays.equals(dstShape, shapeCopy)) { + throw new IllegalArgumentException( + String.format( + "Shape of output target %s does not match with the shape of the Tensor %s.", + Arrays.toString(dstShape), Arrays.toString(shapeCopy))); + } + readMultiDimensionalArray(nativeHandle, dst); + return dst; + } + + final long nativeHandle; + final DataType dtype; + final int[] shapeCopy; + + private Tensor(long nativeHandle) { + this.nativeHandle = nativeHandle; + this.dtype = DataType.fromNumber(dtype(nativeHandle)); + this.shapeCopy = shape(nativeHandle); + } + + private static native int dtype(long handle); + + private static native int[] shape(long handle); + + private static native void readMultiDimensionalArray(long handle, Object value); + + static { + TensorFlowLite.init(); + } +} diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/TensorFlowLite.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/TensorFlowLite.java new file mode 100644 index 0000000000..711638a9f9 --- /dev/null +++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/TensorFlowLite.java @@ -0,0 +1,44 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite; + +/** Static utility methods loading the TensorFlowLite runtime. */ +public final class TensorFlowLite { + + private static final String LIBNAME = "tensorflowlite_jni"; + + private TensorFlowLite() {} + + /** Returns the version of the underlying TensorFlowLite runtime. */ + public static native String version(); + + /** + * Load the TensorFlowLite runtime C library. 
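+   *
+   * @return true if the native library was loaded successfully, false if the
+   *     {@link UnsatisfiedLinkError} was caught and logged instead.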
+ */ + static boolean init() { + try { + System.loadLibrary(LIBNAME); + return true; + } catch (UnsatisfiedLinkError e) { + System.err.println("TensorFlowLite: failed to load native library: " + e.getMessage()); + return false; + } + } + + static { + init(); + } +} diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/package-info.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/package-info.java new file mode 100644 index 0000000000..68e6a0f578 --- /dev/null +++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/package-info.java @@ -0,0 +1,17 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +/** Defines classes to load and execute TensorFlowLite models. */ +package org.tensorflow.lite; diff --git a/tensorflow/contrib/lite/java/src/main/native/BUILD b/tensorflow/contrib/lite/java/src/main/native/BUILD new file mode 100644 index 0000000000..9c172a1f68 --- /dev/null +++ b/tensorflow/contrib/lite/java/src/main/native/BUILD @@ -0,0 +1,70 @@ +# Description: +# Java Native Interface (JNI) library intended for implementing the +# TensorFlow Lite Java API using the TensorFlow Lite CC library. + +package(default_visibility = ["//visibility:public"]) + +load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts") + +licenses(["notice"]) # Apache 2.0 + +cc_library( + name = "native_framework_only", + srcs = [ + "exception_jni.cc", + "nativeinterpreterwrapper_jni.cc", + "tensor_jni.cc", + "tensorflow_lite_jni.cc", + ], + hdrs = [ + "exception_jni.h", + "nativeinterpreterwrapper_jni.h", + "tensor_jni.h", + "tensorflow_lite_jni.h", + ], + copts = tflite_copts(), + linkopts = [ + "-lm", + "-ldl", + ], + deps = [ + "//tensorflow/contrib/lite:context", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:schema_fbs_version", + ], + alwayslink = 1, +) + +# This includes all ops. If you want a smaller binary, you should copy and +# modify builtin_ops_jni.cc. You should then link your binary against both +# ":native_framework_only" and your own version of ":native_builtin_ops". +cc_library( + name = "native", + srcs = [ + "builtin_ops_jni.cc", + ], + copts = tflite_copts(), + deps = [ + ":native_framework_only", + "//tensorflow/contrib/lite/kernels:builtin_ops", + ], + alwayslink = 1, +) + +exports_files( + [ + "version_script.lds", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/contrib/lite/java/src/main/native/builtin_ops_jni.cc b/tensorflow/contrib/lite/java/src/main/native/builtin_ops_jni.cc new file mode 100644 index 0000000000..cce356370f --- /dev/null +++ b/tensorflow/contrib/lite/java/src/main/native/builtin_ops_jni.cc @@ -0,0 +1,29 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/kernels/register.h" + +namespace tflite { + +// The JNI code in interpreter_jni.cc expects a CreateOpResolver() function in +// the tflite namespace. This one instantiates a BuiltinOpResolver, with all the +// builtin ops. For smaller binary sizes users should avoid linking this in, and +// should provide a custom make CreateOpResolver() instead. +std::unique_ptr CreateOpResolver() { // NOLINT + return std::unique_ptr( + new tflite::ops::builtin::BuiltinOpResolver()); +} + +} // namespace tflite diff --git a/tensorflow/contrib/lite/java/src/main/native/exception_jni.cc b/tensorflow/contrib/lite/java/src/main/native/exception_jni.cc new file mode 100644 index 0000000000..1578c9e3dd --- /dev/null +++ b/tensorflow/contrib/lite/java/src/main/native/exception_jni.cc @@ -0,0 +1,66 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include + +#include "tensorflow/contrib/lite/java/src/main/native/exception_jni.h" + +const char kIllegalArgumentException[] = "java/lang/IllegalArgumentException"; +const char kIllegalStateException[] = "java/lang/IllegalStateException"; +const char kNullPointerException[] = "java/lang/NullPointerException"; +const char kIndexOutOfBoundsException[] = "java/lang/IndexOutOfBoundsException"; +const char kUnsupportedOperationException[] = + "java/lang/UnsupportedOperationException"; + +void throwException(JNIEnv* env, const char* clazz, const char* fmt, ...) 
{ + va_list args; + va_start(args, fmt); + const size_t max_msg_len = 512; + auto* message = static_cast(malloc(max_msg_len)); + if (vsnprintf(message, max_msg_len, fmt, args) >= 0) { + env->ThrowNew(env->FindClass(clazz), message); + } else { + env->ThrowNew(env->FindClass(clazz), ""); + } + free(message); + va_end(args); +} + +BufferErrorReporter::BufferErrorReporter(JNIEnv* env, int limit) { + buffer_ = new char[limit]; + if (!buffer_) { + throwException(env, kNullPointerException, + "Malloc of BufferErrorReporter to hold %d char failed.", + limit); + return; + } + start_idx_ = 0; + end_idx_ = limit - 1; +} + +BufferErrorReporter::~BufferErrorReporter() { delete[] buffer_; } + +int BufferErrorReporter::Report(const char* format, va_list args) { + int size = 0; + if (start_idx_ < end_idx_) { + size = vsnprintf(buffer_ + start_idx_, end_idx_ - start_idx_, format, args); + } + start_idx_ += size; + return size; +} + +const char* BufferErrorReporter::CachedErrorMessage() { return buffer_; } diff --git a/tensorflow/contrib/lite/java/src/main/native/exception_jni.h b/tensorflow/contrib/lite/java/src/main/native/exception_jni.h new file mode 100644 index 0000000000..3ffff052df --- /dev/null +++ b/tensorflow/contrib/lite/java/src/main/native/exception_jni.h @@ -0,0 +1,50 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_LITE_JAVA_EXCEPTION_JNI_H_ +#define TENSORFLOW_CONTRIB_LITE_JAVA_EXCEPTION_JNI_H_ + +#include +#include "tensorflow/contrib/lite/error_reporter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +extern const char kIllegalArgumentException[]; +extern const char kIllegalStateException[]; +extern const char kNullPointerException[]; +extern const char kIndexOutOfBoundsException[]; +extern const char kUnsupportedOperationException[]; + +void throwException(JNIEnv* env, const char* clazz, const char* fmt, ...); + +class BufferErrorReporter : public tflite::ErrorReporter { + public: + BufferErrorReporter(JNIEnv* env, int limit); + virtual ~BufferErrorReporter(); + int Report(const char* format, va_list args) override; + const char* CachedErrorMessage(); + + private: + char* buffer_; + int start_idx_ = 0; + int end_idx_ = 0; +}; + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus +#endif // TENSORFLOW_CONTRIB_LITE_JAVA_EXCEPTION_JNI_H_ diff --git a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc new file mode 100644 index 0000000000..bc6462eb54 --- /dev/null +++ b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc @@ -0,0 +1,446 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h" + +namespace { + +const int kByteBufferValue = 999; +const int kBufferSize = 256; + +tflite::Interpreter* convertLongToInterpreter(JNIEnv* env, jlong handle) { + if (handle == 0) { + throwException(env, kIllegalArgumentException, + "Invalid handle to Interpreter."); + return nullptr; + } + return reinterpret_cast(handle); +} + +tflite::FlatBufferModel* convertLongToModel(JNIEnv* env, jlong handle) { + if (handle == 0) { + throwException(env, kIllegalArgumentException, "Invalid handle to model."); + return nullptr; + } + return reinterpret_cast(handle); +} + +BufferErrorReporter* convertLongToErrorReporter(JNIEnv* env, jlong handle) { + if (handle == 0) { + throwException(env, kIllegalArgumentException, + "Invalid handle to ErrorReporter."); + return nullptr; + } + return reinterpret_cast(handle); +} + +std::vector convertJIntArrayToVector(JNIEnv* env, jintArray inputs) { + int size = static_cast(env->GetArrayLength(inputs)); + std::vector outputs(size, 0); + jint* ptr = env->GetIntArrayElements(inputs, nullptr); + if (ptr == nullptr) { + throwException(env, kIllegalArgumentException, + "Empty dimensions of input array."); + return {}; + } + for (int i = 0; i < size; ++i) { + outputs[i] = ptr[i]; + } + env->ReleaseIntArrayElements(inputs, ptr, JNI_ABORT); + return outputs; +} + +bool isByteBuffer(jint data_type) { return data_type == kByteBufferValue; } + +TfLiteType resolveDataType(jint data_type) { + switch (data_type) { + case 1: + return kTfLiteFloat32; + case 2: + return kTfLiteInt32; + case 3: + return kTfLiteUInt8; + case 4: + return kTfLiteInt64; + default: + return kTfLiteNoType; + } +} + +void printDims(char* buffer, int max_size, int* dims, int num_dims) { + if (max_size <= 0) return; + buffer[0] = '?'; + int size = 1; + for (int i = 1; i < num_dims; ++i) { + if (max_size > size) { + int written_size = + snprintf(buffer + size, max_size - size, ",%d", dims[i]); + if (written_size < 0) return; + size += written_size; + } + } +} + +TfLiteStatus checkInputs(JNIEnv* env, tflite::Interpreter* interpreter, + const int input_size, jintArray data_types, + jintArray nums_of_bytes, jobjectArray values, + jobjectArray sizes) { + if (input_size != interpreter->inputs().size()) { + throwException(env, kIllegalArgumentException, + "Expected num of inputs is %d but got %d", + interpreter->inputs().size(), input_size); + return kTfLiteError; + } + if (input_size != env->GetArrayLength(data_types) || + input_size != env->GetArrayLength(nums_of_bytes) || + input_size != env->GetArrayLength(values)) { + throwException(env, kIllegalArgumentException, + "Arrays in arguments should be of the same length, but got " + "%d sizes, %d data_types, %d nums_of_bytes, and %d values", + input_size, env->GetArrayLength(data_types), + env->GetArrayLength(nums_of_bytes), + env->GetArrayLength(values)); + return kTfLiteError; + } + for (int i = 0; i < input_size; ++i) { + int input_idx = interpreter->inputs()[i]; + TfLiteTensor* target = 
interpreter->tensor(input_idx); + jintArray dims = + static_cast(env->GetObjectArrayElement(sizes, i)); + int num_dims = static_cast(env->GetArrayLength(dims)); + if (target->dims->size != num_dims) { + throwException(env, kIllegalArgumentException, + "%d-th input should have %d dimensions, but found %d " + "dimensions", + i, target->dims->size, num_dims); + return kTfLiteError; + } + jint* ptr = env->GetIntArrayElements(dims, nullptr); + for (int j = 1; j < num_dims; ++j) { + if (target->dims->data[j] != ptr[j]) { + std::unique_ptr expected_dims(new char[kBufferSize]); + std::unique_ptr obtained_dims(new char[kBufferSize]); + printDims(expected_dims.get(), kBufferSize, target->dims->data, + num_dims); + printDims(obtained_dims.get(), kBufferSize, ptr, num_dims); + throwException(env, kIllegalArgumentException, + "%d-th input dimension should be [%s], but found [%s]", + i, expected_dims.get(), obtained_dims.get()); + env->ReleaseIntArrayElements(dims, ptr, JNI_ABORT); + return kTfLiteError; + } + } + env->ReleaseIntArrayElements(dims, ptr, JNI_ABORT); + env->DeleteLocalRef(dims); + if (env->ExceptionCheck()) return kTfLiteError; + } + return kTfLiteOk; +} + +TfLiteStatus resizeInputs(JNIEnv* env, tflite::Interpreter* interpreter, + int input_size, jobjectArray sizes) { + for (int i = 0; i < input_size; ++i) { + int input_idx = interpreter->inputs()[i]; + jintArray dims = + static_cast(env->GetObjectArrayElement(sizes, i)); + TfLiteStatus status = interpreter->ResizeInputTensor( + input_idx, convertJIntArrayToVector(env, dims)); + if (status != kTfLiteOk) { + return status; + } + env->DeleteLocalRef(dims); + if (env->ExceptionCheck()) return kTfLiteError; + } + return kTfLiteOk; +} + +TfLiteStatus setInputs(JNIEnv* env, tflite::Interpreter* interpreter, + int input_size, jintArray data_types, + jintArray nums_of_bytes, jobjectArray values) { + jint* data_type = env->GetIntArrayElements(data_types, nullptr); + jint* num_bytes = env->GetIntArrayElements(nums_of_bytes, nullptr); + for (int i = 0; i < input_size; ++i) { + int input_idx = interpreter->inputs()[i]; + TfLiteTensor* target = interpreter->tensor(input_idx); + jobject value = env->GetObjectArrayElement(values, i); + bool is_byte_buffer = isByteBuffer(data_type[i]); + if (is_byte_buffer) { + writeByteBuffer(env, value, &(target->data.raw), + static_cast(num_bytes[i])); + } else { + TfLiteType type = resolveDataType(data_type[i]); + if (type != target->type) { + throwException(env, kIllegalArgumentException, + "DataType (%d) of input data does not match with the " + "DataType (%d) of model inputs.", + type, target->type); + return kTfLiteError; + } + writeMultiDimensionalArray(env, value, target->type, target->dims->size, + &(target->data.raw), + static_cast(num_bytes[i])); + } + env->DeleteLocalRef(value); + if (env->ExceptionCheck()) return kTfLiteError; + } + env->ReleaseIntArrayElements(data_types, data_type, JNI_ABORT); + env->ReleaseIntArrayElements(nums_of_bytes, num_bytes, JNI_ABORT); + return kTfLiteOk; +} + +} // namespace + +JNIEXPORT jobjectArray JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputNames(JNIEnv* env, + jclass clazz, + jlong handle) { + tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle); + if (interpreter == nullptr) return nullptr; + jclass string_class = env->FindClass("java/lang/String"); + if (string_class == nullptr) { + throwException(env, kUnsupportedOperationException, + "Can not find java/lang/String class to get input names."); + return nullptr; + } + size_t size 
= interpreter->inputs().size(); + jobjectArray names = static_cast( + env->NewObjectArray(size, string_class, env->NewStringUTF(""))); + for (int i = 0; i < size; ++i) { + env->SetObjectArrayElement(names, i, + env->NewStringUTF(interpreter->GetInputName(i))); + } + return names; +} + +JNIEXPORT jobjectArray JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputNames(JNIEnv* env, + jclass clazz, + jlong handle) { + tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle); + if (interpreter == nullptr) return nullptr; + jclass string_class = env->FindClass("java/lang/String"); + if (string_class == nullptr) { + throwException(env, kUnsupportedOperationException, + "Can not find java/lang/String class to get output names."); + return nullptr; + } + size_t size = interpreter->outputs().size(); + jobjectArray names = static_cast( + env->NewObjectArray(size, string_class, env->NewStringUTF(""))); + for (int i = 0; i < size; ++i) { + env->SetObjectArrayElement( + names, i, env->NewStringUTF(interpreter->GetOutputName(i))); + } + return names; +} + +JNIEXPORT void JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_useNNAPI(JNIEnv* env, + jclass clazz, + jlong handle, + jboolean state) { + tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle); + if (interpreter == nullptr) return; + interpreter->UseNNAPI(static_cast(state)); +} + +JNIEXPORT jlong JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_createErrorReporter( + JNIEnv* env, jclass clazz, jint size) { + BufferErrorReporter* error_reporter = + new BufferErrorReporter(env, static_cast(size)); + return reinterpret_cast(error_reporter); +} + +JNIEXPORT jlong JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_createModel( + JNIEnv* env, jclass clazz, jstring model_file, jlong error_handle) { + BufferErrorReporter* error_reporter = + convertLongToErrorReporter(env, error_handle); + if (error_reporter == nullptr) return 0; + const char* path = env->GetStringUTFChars(model_file, nullptr); + auto model = tflite::FlatBufferModel::BuildFromFile(path, error_reporter); + if (!model) { + throwException(env, kIllegalArgumentException, + "Contents of %s does not encode a valid TensorFlowLite " + "model: %s", + path, error_reporter->CachedErrorMessage()); + env->ReleaseStringUTFChars(model_file, path); + return 0; + } + env->ReleaseStringUTFChars(model_file, path); + return reinterpret_cast(model.release()); +} + +JNIEXPORT jlong JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_createModelWithBuffer( + JNIEnv* env, jclass /*clazz*/, jobject model_buffer, jlong error_handle) { + BufferErrorReporter* error_reporter = + convertLongToErrorReporter(env, error_handle); + if (error_reporter == nullptr) return 0; + const char* buf = + static_cast(env->GetDirectBufferAddress(model_buffer)); + jlong capacity = env->GetDirectBufferCapacity(model_buffer); + auto model = tflite::FlatBufferModel::BuildFromBuffer( + buf, static_cast(capacity), error_reporter); + if (!model) { + throwException(env, kIllegalArgumentException, + "MappedByteBuffer does not encode a valid TensorFlowLite " + "model: %s", + error_reporter->CachedErrorMessage()); + return 0; + } + return reinterpret_cast(model.release()); +} + +JNIEXPORT jlong JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_createInterpreter( + JNIEnv* env, jclass clazz, jlong model_handle) { + tflite::FlatBufferModel* model = convertLongToModel(env, model_handle); + if (model == nullptr) return 0; + auto resolver = 
::tflite::CreateOpResolver(); + std::unique_ptr interpreter; + tflite::InterpreterBuilder(*model, *(resolver.get()))(&interpreter); + return reinterpret_cast(interpreter.release()); +} + +// Sets inputs, runs inference, and returns outputs as long handles. +JNIEXPORT jlongArray JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_run( + JNIEnv* env, jclass clazz, jlong interpreter_handle, jlong error_handle, + jobjectArray sizes, jintArray data_types, jintArray nums_of_bytes, + jobjectArray values) { + tflite::Interpreter* interpreter = + convertLongToInterpreter(env, interpreter_handle); + if (interpreter == nullptr) return nullptr; + BufferErrorReporter* error_reporter = + convertLongToErrorReporter(env, error_handle); + if (error_reporter == nullptr) return nullptr; + const int input_size = env->GetArrayLength(sizes); + // validates inputs + TfLiteStatus status = checkInputs(env, interpreter, input_size, data_types, + nums_of_bytes, values, sizes); + if (status != kTfLiteOk) return nullptr; + // resizes inputs + status = resizeInputs(env, interpreter, input_size, sizes); + if (status != kTfLiteOk) { + throwException(env, kNullPointerException, "Can not resize the input: %s", + error_reporter->CachedErrorMessage()); + return nullptr; + } + // allocates memory + status = interpreter->AllocateTensors(); + if (status != kTfLiteOk) { + throwException(env, kNullPointerException, + "Can not allocate memory for the given inputs: %s", + error_reporter->CachedErrorMessage()); + return nullptr; + } + // sets inputs + status = setInputs(env, interpreter, input_size, data_types, nums_of_bytes, + values); + if (status != kTfLiteOk) return nullptr; + // runs inference + if (interpreter->Invoke() != kTfLiteOk) { + throwException(env, kIllegalArgumentException, + "Failed to run on the given Interpreter: %s", + error_reporter->CachedErrorMessage()); + return nullptr; + } + // returns outputs + const std::vector& results = interpreter->outputs(); + if (results.empty()) { + throwException(env, kIllegalArgumentException, + "The Interpreter does not have any outputs."); + return nullptr; + } + jlongArray outputs = env->NewLongArray(results.size()); + size_t size = results.size(); + for (int i = 0; i < size; ++i) { + TfLiteTensor* source = interpreter->tensor(results[i]); + jlong output = reinterpret_cast(source); + env->SetLongArrayRegion(outputs, i, 1, &output); + } + return outputs; +} + +JNIEXPORT jintArray JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputDims( + JNIEnv* env, jclass clazz, jlong handle, jint input_idx, jint num_bytes) { + tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle); + if (interpreter == nullptr) return nullptr; + const int idx = static_cast(input_idx); + if (input_idx >= interpreter->inputs().size()) { + throwException(env, kIllegalArgumentException, + "Out of range: Failed to get %d-th input out of %d inputs", + input_idx, interpreter->inputs().size()); + return nullptr; + } + TfLiteTensor* target = interpreter->tensor(interpreter->inputs()[idx]); + int size = target->dims->size; + int expected_num_bytes = elementByteSize(target->type); + for (int i = 0; i < size; ++i) { + expected_num_bytes *= target->dims->data[i]; + } + if (num_bytes != expected_num_bytes) { + throwException(env, kIllegalArgumentException, + "Failed to get input dimensions. 
%d-th input should have" + " %d bytes, but found %d bytes.", + idx, expected_num_bytes, num_bytes); + return nullptr; + } + jintArray outputs = env->NewIntArray(size); + env->SetIntArrayRegion(outputs, 0, size, &(target->dims->data[0])); + return outputs; +} + +JNIEXPORT void JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_resizeInput( + JNIEnv* env, jclass clazz, jlong interpreter_handle, jlong error_handle, + jint input_idx, jintArray dims) { + BufferErrorReporter* error_reporter = + convertLongToErrorReporter(env, error_handle); + if (error_reporter == nullptr) return; + tflite::Interpreter* interpreter = + convertLongToInterpreter(env, interpreter_handle); + if (interpreter == nullptr) return; + const int idx = static_cast(input_idx); + if (idx < 0 || idx >= interpreter->inputs().size()) { + throwException(env, kIllegalArgumentException, + "Can not resize %d-th input for a model having %d inputs.", + idx, interpreter->inputs().size()); + } + TfLiteStatus status = interpreter->ResizeInputTensor( + interpreter->inputs()[idx], convertJIntArrayToVector(env, dims)); + if (status != kTfLiteOk) { + throwException(env, kIllegalArgumentException, + "Failed to resize %d-th input: %s", idx, + error_reporter->CachedErrorMessage()); + } +} + +JNIEXPORT void JNICALL Java_org_tensorflow_lite_NativeInterpreterWrapper_delete( + JNIEnv* env, jclass clazz, jlong error_handle, jlong model_handle, + jlong interpreter_handle) { + if (interpreter_handle != 0) { + delete convertLongToInterpreter(env, interpreter_handle); + } + if (model_handle != 0) { + delete convertLongToModel(env, model_handle); + } + if (error_handle != 0) { + delete convertLongToErrorReporter(env, error_handle); + } +} diff --git a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h new file mode 100644 index 0000000000..430886b7cc --- /dev/null +++ b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h @@ -0,0 +1,151 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_LITE_JAVA_NATIVEINTERPRETERWRAPPER_JNI_H_ +#define TENSORFLOW_CONTRIB_LITE_JAVA_NATIVEINTERPRETERWRAPPER_JNI_H_ + +#include +#include +#include +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/java/src/main/native/exception_jni.h" +#include "tensorflow/contrib/lite/java/src/main/native/tensor_jni.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +// This is to be provided at link-time by a library. 
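+// Linking against the ":native" JNI target pulls in builtin_ops_jni.cc, which defines
+// CreateOpResolver() to return a BuiltinOpResolver covering all builtin ops; binaries
+// that need a smaller footprint can link ":native_framework_only" and provide their own
+// definition instead.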
+extern std::unique_ptr CreateOpResolver(); +} // namespace tflite + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +/* + * Class: org_tensorflow_lite_NativeInterpreterWrapper + * Method: + * Signature: (J)[Ljava/lang/Object; + */ +JNIEXPORT jobjectArray JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputNames(JNIEnv* env, + jclass clazz, + jlong handle); + +/* + * Class: org_tensorflow_lite_NativeInterpreterWrapper + * Method: + * Signature: (J)[Ljava/lang/Object; + */ +JNIEXPORT jobjectArray JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputNames(JNIEnv* env, + jclass clazz, + jlong handle); + +/* + * Class: org_tensorflow_lite_NativeInterpreterWrapper + * Method: + * Signature: (JZ) + */ +JNIEXPORT void JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_useNNAPI(JNIEnv* env, + jclass clazz, + jlong handle, + jboolean state); + +/* + * Class: org_tensorflow_lite_NativeInterpreterWrapper + * Method: + * Signature: (I)J + */ +JNIEXPORT jlong JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_createErrorReporter( + JNIEnv* env, jclass clazz, jint size); + +/* + * Class: org_tensorflow_lite_NativeInterpreterWrapper + * Method: + * Signature: (Ljava/lang/String;J)J + */ +JNIEXPORT jlong JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_createModel( + JNIEnv* env, jclass clazz, jstring model_file, jlong error_handle); + +/* + * Class: org_tensorflow_lite_NativeInterpreterWrapper + * Method: + * Signature: (Ljava/lang/Object;J)J + */ +JNIEXPORT jlong JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_createModelWithBuffer( + JNIEnv* env, jclass clazz, jobject model_buffer, jlong error_handle); + +/* + * Class: org_tensorflow_lite_NativeInterpreterWrapper + * Method: + * Signature: (J)J + */ +JNIEXPORT jlong JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_createInterpreter( + JNIEnv* env, jclass clazz, jlong model_handle); + +/* + * Class: org_tensorflow_lite_NativeInterpreterWrapper + * Method: + * Signature: (JJ[Ljava/lang/Object;[I[I[Ljava/lang/Object;)[J + */ +JNIEXPORT jlongArray JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_run( + JNIEnv* env, jclass clazz, jlong interpreter_handle, jlong error_handle, + jobjectArray sizes, jintArray data_types, jintArray nums_of_bytes, + jobjectArray values); + +/* + * Class: org_tensorflow_lite_NativeInterpreterWrapper + * Method: + * Signature: (JII)[I + * + * It gets input dimensions if num_bytes matches number of bytes required by + * the input, else returns null and throws IllegalArgumentException. + */ +JNIEXPORT jintArray JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputDims( + JNIEnv* env, jclass clazz, jlong handle, jint input_idx, jint num_bytes); + +/* + * Class: org_tensorflow_lite_NativeInterpreterWrapper + * Method: + * Signature: (JJI[I) + * + * It resizes dimensions of a input. 
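+ * Throws IllegalArgumentException on the Java side if input_idx is out of range or if
+ * the underlying ResizeInputTensor call fails.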
+ */ +JNIEXPORT void JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_resizeInput( + JNIEnv* env, jclass clazz, jlong interpreter_handle, jlong error_handle, + jint input_idx, jintArray dims); + +/* + * Class: org_tensorflow_lite_NativeInterpreterWrapper + * Method: + * Signature: (JJJ) + */ +JNIEXPORT void JNICALL Java_org_tensorflow_lite_NativeInterpreterWrapper_delete( + JNIEnv* env, jclass clazz, jlong error_handle, jlong model_handle, + jlong interpreter_handle); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus +#endif // TENSORFLOW_CONTRIB_LITE_JAVA_NATIVEINTERPRETERWRAPPER_JNI_H_ diff --git a/tensorflow/contrib/lite/java/src/main/native/tensor_jni.cc b/tensorflow/contrib/lite/java/src/main/native/tensor_jni.cc new file mode 100644 index 0000000000..65126e78a3 --- /dev/null +++ b/tensorflow/contrib/lite/java/src/main/native/tensor_jni.cc @@ -0,0 +1,242 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/java/src/main/native/tensor_jni.h" +#include +#include +#include "tensorflow/contrib/lite/java/src/main/native/exception_jni.h" + +namespace { + +TfLiteTensor* convertLongToTensor(JNIEnv* env, jlong handle) { + if (handle == 0) { + throwException(env, kIllegalArgumentException, + "Invalid handle to TfLiteTensor."); + return nullptr; + } + return reinterpret_cast(handle); +} + +size_t writeOneDimensionalArray(JNIEnv* env, jobject object, TfLiteType type, + void* dst, size_t dst_size) { + jarray array = static_cast(object); + const int num_elements = env->GetArrayLength(array); + size_t to_copy = num_elements * elementByteSize(type); + if (to_copy > dst_size) { + throwException(env, kIllegalStateException, + "cannot write Java array of %d bytes to Tensor of %d bytes", + to_copy, dst_size); + return 0; + } + switch (type) { + case kTfLiteFloat32: { + jfloatArray a = static_cast(array); + jfloat* values = env->GetFloatArrayElements(a, nullptr); + memcpy(dst, values, to_copy); + env->ReleaseFloatArrayElements(a, values, JNI_ABORT); + return to_copy; + } + case kTfLiteInt32: { + jintArray a = static_cast(array); + jint* values = env->GetIntArrayElements(a, nullptr); + memcpy(dst, values, to_copy); + env->ReleaseIntArrayElements(a, values, JNI_ABORT); + return to_copy; + } + case kTfLiteInt64: { + jlongArray a = static_cast(array); + jlong* values = env->GetLongArrayElements(a, nullptr); + memcpy(dst, values, to_copy); + env->ReleaseLongArrayElements(a, values, JNI_ABORT); + return to_copy; + } + case kTfLiteUInt8: { + jbyteArray a = static_cast(array); + jbyte* values = env->GetByteArrayElements(a, nullptr); + memcpy(dst, values, to_copy); + env->ReleaseByteArrayElements(a, values, JNI_ABORT); + return to_copy; + } + default: { + throwException(env, kUnsupportedOperationException, + "TensorFlowLite currently supports float (32 bits), " + "int (32 bits), byte (8 bits), and long (64 bits), " + "support for other 
types (DataType %d in this case) will " + "be added in the future", + kTfLiteFloat32, type); + return 0; + } + } +} + +size_t readOneDimensionalArray(JNIEnv* env, TfLiteType data_type, + const void* src, size_t src_size, jarray dst) { + const int len = env->GetArrayLength(dst); + const size_t size = len * elementByteSize(data_type); + if (size > src_size) { + throwException( + env, kIllegalStateException, + "cannot fill a Java array of %d bytes with a Tensor of %d bytes", size, + src_size); + return 0; + } + switch (data_type) { + case kTfLiteFloat32: { + jfloatArray float_array = static_cast(dst); + env->SetFloatArrayRegion(float_array, 0, len, + static_cast(src)); + return size; + } + case kTfLiteInt32: { + jintArray int_array = static_cast(dst); + env->SetIntArrayRegion(int_array, 0, len, static_cast(src)); + return size; + } + case kTfLiteInt64: { + jlongArray long_array = static_cast(dst); + env->SetLongArrayRegion(long_array, 0, len, + static_cast(src)); + return size; + } + case kTfLiteUInt8: { + jbyteArray byte_array = static_cast(dst); + env->SetByteArrayRegion(byte_array, 0, len, + static_cast(src)); + return size; + } + default: { + throwException(env, kIllegalStateException, "invalid DataType(%d)", + data_type); + } + } + return 0; +} + +size_t readMultiDimensionalArray(JNIEnv* env, TfLiteType data_type, char* src, + size_t src_size, int dims_left, jarray dst) { + if (dims_left == 1) { + return readOneDimensionalArray(env, data_type, src, src_size, dst); + } else { + jobjectArray ndarray = static_cast(dst); + int len = env->GetArrayLength(ndarray); + size_t size = 0; + for (int i = 0; i < len; ++i) { + jarray row = static_cast(env->GetObjectArrayElement(ndarray, i)); + size += readMultiDimensionalArray(env, data_type, src + size, + src_size - size, dims_left - 1, row); + env->DeleteLocalRef(row); + if (env->ExceptionCheck()) return size; + } + return size; + } +} + +} // namespace + +size_t elementByteSize(TfLiteType data_type) { + // The code in this file makes the assumption that the + // TensorFlow TF_DataTypes and the Java primitive types + // have the same byte sizes. 
Validate that: + switch (data_type) { + case kTfLiteFloat32: + static_assert(sizeof(jfloat) == 4, + "Java float not compatible with kTfLiteFloat"); + return 4; + case kTfLiteInt32: + static_assert(sizeof(jint) == 4, + "Java int not compatible with kTfLiteInt"); + return 4; + case kTfLiteUInt8: + static_assert(sizeof(jbyte) == 1, + "Java byte not compatible with kTfLiteUInt8"); + return 1; + case kTfLiteInt64: + static_assert(sizeof(jlong) == 8, + "Java long not compatible with kTfLiteInt64"); + return 8; + default: + return 0; + } +} + +size_t writeByteBuffer(JNIEnv* env, jobject object, char** dst, int dst_size) { + char* buf = static_cast(env->GetDirectBufferAddress(object)); + if (!buf) { + throwException(env, kIllegalArgumentException, + "Input ByteBuffer is not a direct buffer"); + return 0; + } + *dst = buf; + return dst_size; +} + +size_t writeMultiDimensionalArray(JNIEnv* env, jobject src, TfLiteType type, + int dims_left, char** dst, int dst_size) { + if (dims_left <= 1) { + return writeOneDimensionalArray(env, src, type, *dst, dst_size); + } else { + jobjectArray ndarray = static_cast(src); + int len = env->GetArrayLength(ndarray); + size_t sz = 0; + for (int i = 0; i < len; ++i) { + jobject row = env->GetObjectArrayElement(ndarray, i); + char* next_dst = *dst + sz; + sz += writeMultiDimensionalArray(env, row, type, dims_left - 1, &next_dst, + dst_size - sz); + env->DeleteLocalRef(row); + if (env->ExceptionCheck()) return sz; + } + return sz; + } +} + +JNIEXPORT void JNICALL +Java_org_tensorflow_lite_Tensor_readMultiDimensionalArray(JNIEnv* env, + jclass clazz, + jlong handle, + jobject value) { + TfLiteTensor* tensor = convertLongToTensor(env, handle); + if (tensor == nullptr) return; + int num_dims = tensor->dims->size; + if (num_dims == 0) { + throwException(env, kIllegalArgumentException, + "copyTo() is not meant for scalar Tensors."); + return; + } + readMultiDimensionalArray(env, tensor->type, tensor->data.raw, tensor->bytes, + num_dims, static_cast(value)); +} + +JNIEXPORT jint JNICALL Java_org_tensorflow_lite_Tensor_dtype(JNIEnv* env, + jclass clazz, + jlong handle) { + TfLiteTensor* tensor = convertLongToTensor(env, handle); + if (tensor == nullptr) return 0; + return static_cast(tensor->type); +} + +JNIEXPORT jintArray JNICALL +Java_org_tensorflow_lite_Tensor_shape(JNIEnv* env, jclass clazz, jlong handle) { + TfLiteTensor* tensor = convertLongToTensor(env, handle); + if (tensor == nullptr) return nullptr; + int num_dims = tensor->dims->size; + jintArray result = env->NewIntArray(num_dims); + jint* dims = env->GetIntArrayElements(result, nullptr); + for (int i = 0; i < num_dims; ++i) { + dims[i] = static_cast(tensor->dims->data[i]); + } + env->ReleaseIntArrayElements(result, dims, 0); + return result; +} diff --git a/tensorflow/contrib/lite/java/src/main/native/tensor_jni.h b/tensorflow/contrib/lite/java/src/main/native/tensor_jni.h new file mode 100644 index 0000000000..3a4910dcc3 --- /dev/null +++ b/tensorflow/contrib/lite/java/src/main/native/tensor_jni.h @@ -0,0 +1,74 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_LITE_JAVA_TENSOR_JNI_H_ +#define TENSORFLOW_CONTRIB_LITE_JAVA_TENSOR_JNI_H_ + +#include +#include "tensorflow/contrib/lite/context.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +/* + * Class: org_tensorflow_lite_TfLiteTensor + * Method: + * Signature: (J)I + */ +JNIEXPORT jint JNICALL Java_org_tensorflow_lite_Tensor_dtype(JNIEnv* env, + jclass clazz, + jlong handle); + +/* + * Class: org_tensorflow_lite_TfLiteTensor + * Method: + * Signature: (J)[I + */ +JNIEXPORT jintArray JNICALL Java_org_tensorflow_lite_Tensor_shape(JNIEnv* env, + jclass clazz, + jlong handle); + +/* + * Class: org_tensorflow_lite_TfLiteTensor + * Method: + * Signature: (JLjava/lang/Object;) + */ +JNIEXPORT void JNICALL +Java_org_tensorflow_lite_Tensor_readMultiDimensionalArray(JNIEnv* env, + jclass clazz, + jlong handle, + jobject value); + +/* + * Finds the size of each data type. + */ +size_t elementByteSize(TfLiteType data_type); + +/* + * Writes data of a ByteBuffer into dest. + */ +size_t writeByteBuffer(JNIEnv* env, jobject object, char** dst, int dst_size); + +/* + * Writes a multi-dimensional array into dest. + */ +size_t writeMultiDimensionalArray(JNIEnv* env, jobject src, TfLiteType type, + int dims_left, char** dst, int dst_size); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus +#endif // TENSORFLOW_CONTRIB_LITE_JAVA_TENSOR_JNI_H_ diff --git a/tensorflow/contrib/lite/java/src/main/native/tensorflow_lite_jni.cc b/tensorflow/contrib/lite/java/src/main/native/tensorflow_lite_jni.cc new file mode 100644 index 0000000000..2e7f2f5692 --- /dev/null +++ b/tensorflow/contrib/lite/java/src/main/native/tensorflow_lite_jni.cc @@ -0,0 +1,26 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/contrib/lite/java/src/main/native/tensorflow_lite_jni.h" +#include "tensorflow/contrib/lite/version.h" + +JNIEXPORT jstring JNICALL +Java_org_tensorflow_lite_TensorFlowLite_version(JNIEnv* env, jclass /*clazz*/) { + char buf[64]; + snprintf(buf, sizeof(buf), "%d", TFLITE_SCHEMA_VERSION); + return env->NewStringUTF(buf); +} diff --git a/tensorflow/contrib/lite/java/src/main/native/tensorflow_lite_jni.h b/tensorflow/contrib/lite/java/src/main/native/tensorflow_lite_jni.h new file mode 100644 index 0000000000..65f8341149 --- /dev/null +++ b/tensorflow/contrib/lite/java/src/main/native/tensorflow_lite_jni.h @@ -0,0 +1,36 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_LITE_JAVA_TENSORFLOW_LITE_JNI_H_ +#define TENSORFLOW_CONTRIB_LITE_JAVA_TENSORFLOW_LITE_JNI_H_ + +#include + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +/* + * Class: org_tensorflow_lite_TensorFlowLite + * Method: version + * Signature: ()Ljava/lang/String; + */ +JNIEXPORT jstring JNICALL +Java_org_tensorflow_lite_TensorFlowLite_version(JNIEnv*, jclass); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus +#endif // TENSORFLOW_CONTRIB_LITE_JAVA_TENSORFLOW_LITE_JNI_H_ diff --git a/tensorflow/contrib/lite/java/src/main/native/version_script.lds b/tensorflow/contrib/lite/java/src/main/native/version_script.lds new file mode 100644 index 0000000000..38c93dda73 --- /dev/null +++ b/tensorflow/contrib/lite/java/src/main/native/version_script.lds @@ -0,0 +1,11 @@ +VERS_1.0 { + # Export JNI symbols. + global: + Java_*; + JNI_OnLoad; + JNI_OnUnload; + + # Hide everything else. + local: + *; +}; diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/DataTypeTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/DataTypeTest.java new file mode 100644 index 0000000000..cebc944200 --- /dev/null +++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/DataTypeTest.java @@ -0,0 +1,34 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +package org.tensorflow.lite; + +import static com.google.common.truth.Truth.assertThat; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link org.tensorflow.lite.DataType}. */ +@RunWith(JUnit4.class) +public final class DataTypeTest { + + @Test + public void testElemByteSize() { + assertThat(DataType.FLOAT32.elemByteSize()).isEqualTo(4); + assertThat(DataType.INT32.elemByteSize()).isEqualTo(4); + assertThat(DataType.UINT8.elemByteSize()).isEqualTo(1); + assertThat(DataType.INT64.elemByteSize()).isEqualTo(8); + } +} diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java new file mode 100644 index 0000000000..a60c63d4b8 --- /dev/null +++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java @@ -0,0 +1,221 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.fail; + +import java.io.File; +import java.nio.MappedByteBuffer; +import java.nio.channels.FileChannel; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.StandardOpenOption; +import java.util.EnumSet; +import java.util.HashMap; +import java.util.Map; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link org.tensorflow.lite.Interpreter}. */ +@RunWith(JUnit4.class) +public final class InterpreterTest { + + private static final File MODEL_FILE = + new File("third_party/tensorflow/contrib/lite/java/src/testdata/add.bin"); + + private static final File MOBILENET_MODEL_FILE = + new File("third_party/tensorflow/contrib/lite/java/src/testdata/mobilenet.tflite.bin"); + + @Test + public void testInterpreter() throws Exception { + Interpreter interpreter = new Interpreter(MODEL_FILE); + assertThat(interpreter).isNotNull(); + interpreter.close(); + } + + @Test + public void testRunWithMappedByteBufferModel() throws Exception { + Path path = MODEL_FILE.toPath(); + FileChannel fileChannel = + (FileChannel) Files.newByteChannel(path, EnumSet.of(StandardOpenOption.READ)); + MappedByteBuffer mappedByteBuffer = + fileChannel.map(FileChannel.MapMode.READ_ONLY, 0, fileChannel.size()); + Interpreter interpreter = new Interpreter(mappedByteBuffer); + float[] oneD = {1.23f, 6.54f, 7.81f}; + float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD}; + float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; + float[][][][] fourD = {threeD, threeD}; + float[][][][] parsedOutputs = new float[2][8][8][3]; + interpreter.run(fourD, parsedOutputs); + float[] outputOneD = parsedOutputs[0][0][0]; + float[] expected = {3.69f, 19.62f, 23.43f}; + assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder(); + interpreter.close(); + fileChannel.close(); + } + + @Test + public void testRun() { + Interpreter interpreter = new Interpreter(MODEL_FILE); + Float[] oneD = {1.23f, 6.54f, 7.81f}; + Float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD}; + Float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; + Float[][][][] fourD = {threeD, threeD}; + Float[][][][] parsedOutputs = new Float[2][8][8][3]; + try { + interpreter.run(fourD, parsedOutputs); + fail(); + } catch (IllegalArgumentException e) { + assertThat(e).hasMessageThat().contains("cannot resolve DataType of [[[[Ljava.lang.Float;"); + } + interpreter.close(); + } + + @Test + public void testRunWithBoxedInputs() { + Interpreter interpreter = new Interpreter(MODEL_FILE); + float[] oneD = {1.23f, 6.54f, 7.81f}; + float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD}; + float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; + float[][][][] fourD = {threeD, 
threeD}; + float[][][][] parsedOutputs = new float[2][8][8][3]; + interpreter.run(fourD, parsedOutputs); + float[] outputOneD = parsedOutputs[0][0][0]; + float[] expected = {3.69f, 19.62f, 23.43f}; + assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder(); + interpreter.close(); + } + + @Test + public void testRunForMultipleInputsOutputs() { + Interpreter interpreter = new Interpreter(MODEL_FILE); + float[] oneD = {1.23f, 6.54f, 7.81f}; + float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD}; + float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; + float[][][][] fourD = {threeD, threeD}; + Object[] inputs = {fourD}; + float[][][][] parsedOutputs = new float[2][8][8][3]; + Map outputs = new HashMap<>(); + outputs.put(0, parsedOutputs); + interpreter.runForMultipleInputsOutputs(inputs, outputs); + float[] outputOneD = parsedOutputs[0][0][0]; + float[] expected = {3.69f, 19.62f, 23.43f}; + assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder(); + interpreter.close(); + } + + @Test + public void testMobilenetRun() { + // Create a gray image. + float[][][][] img = new float[1][224][224][3]; + for (int i = 0; i < 224; ++i) { + for (int j = 0; j < 224; ++j) { + img[0][i][j][0] = 0.5f; + img[0][i][j][1] = 0.5f; + img[0][i][j][2] = 0.5f; + } + } + + // Allocate memory to receive the output values. + float[][] labels = new float[1][1001]; + + Interpreter interpreter = new Interpreter(MOBILENET_MODEL_FILE); + interpreter.run(img, labels); + interpreter.close(); + + assertThat(labels[0]) + .usingExactEquality() + .containsNoneOf(new float[] {Float.NaN, Float.NEGATIVE_INFINITY, Float.POSITIVE_INFINITY}); + } + + @Test + public void testRunWithWrongInputType() { + Interpreter interpreter = new Interpreter(MODEL_FILE); + int[] oneD = {4, 3, 9}; + int[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD}; + int[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; + int[][][][] fourD = {threeD, threeD}; + float[][][][] parsedOutputs = new float[2][8][8][3]; + try { + interpreter.run(fourD, parsedOutputs); + fail(); + } catch (IllegalArgumentException e) { + assertThat(e) + .hasMessageThat() + .contains( + "DataType (2) of input data does not match with the DataType (1) of model inputs."); + } + interpreter.close(); + } + + @Test + public void testRunWithWrongOutputType() { + Interpreter interpreter = new Interpreter(MODEL_FILE); + float[] oneD = {1.23f, 6.54f, 7.81f}; + float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD}; + float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; + float[][][][] fourD = {threeD, threeD}; + int[][][][] parsedOutputs = new int[2][8][8][3]; + try { + interpreter.run(fourD, parsedOutputs); + fail(); + } catch (IllegalArgumentException e) { + assertThat(e) + .hasMessageThat() + .contains( + "Cannot convert an TensorFlowLite tensor with type " + + "FLOAT32 to a Java object of type [[[[I (which is compatible with the" + + " TensorFlowLite type INT32)"); + } + interpreter.close(); + } + + @Test + public void testGetInputIndex() { + Interpreter interpreter = new Interpreter(MOBILENET_MODEL_FILE); + try { + interpreter.getInputIndex("WrongInputName"); + fail(); + } catch (IllegalArgumentException e) { + assertThat(e) + .hasMessageThat() + .contains( + "WrongInputName is not a valid name for any input. 
The indexes of the inputs" + + " are {input=0}"); + } + int index = interpreter.getInputIndex("input"); + assertThat(index).isEqualTo(0); + } + + @Test + public void testGetOutputIndex() { + Interpreter interpreter = new Interpreter(MOBILENET_MODEL_FILE); + try { + interpreter.getOutputIndex("WrongOutputName"); + fail(); + } catch (IllegalArgumentException e) { + assertThat(e) + .hasMessageThat() + .contains( + "WrongOutputName is not a valid name for any output. The indexes of the outputs" + + " are {MobilenetV1/Predictions/Softmax=0}"); + } + int index = interpreter.getOutputIndex("MobilenetV1/Predictions/Softmax"); + assertThat(index).isEqualTo(0); + } +} diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java new file mode 100644 index 0000000000..9e4724b8e9 --- /dev/null +++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java @@ -0,0 +1,406 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.fail; + +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link org.tensorflow.lite.NativeInterpreterWrapper}. 
*/ +@RunWith(JUnit4.class) +public final class NativeInterpreterWrapperTest { + + private static final String FLOAT_MODEL_PATH = + "third_party/tensorflow/contrib/lite/java/src/testdata/add.bin"; + + private static final String INT_MODEL_PATH = + "third_party/tensorflow/contrib/lite/java/src/testdata/int32.bin"; + + private static final String LONG_MODEL_PATH = + "third_party/tensorflow/contrib/lite/java/src/testdata/int64.bin"; + + private static final String BYTE_MODEL_PATH = + "third_party/tensorflow/contrib/lite/java/src/testdata/uint8.bin"; + + private static final String INVALID_MODEL_PATH = + "third_party/tensorflow/contrib/lite/java/src/testdata/invalid.model.tflite"; + + @Test + public void testConstructor() { + NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH); + assertThat(wrapper).isNotNull(); + wrapper.close(); + } + + @Test + public void testConstructorWithInvalidModel() { + try { + NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(INVALID_MODEL_PATH); + fail(); + } catch (IllegalArgumentException e) { + assertThat(e) + .hasMessageThat() + .contains("Model provided has model identifier ' is ', should be 'TFL3'"); + } + } + + @Test + public void testRunWithFloat() { + NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH); + float[] oneD = {1.23f, -6.54f, 7.81f}; + float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD}; + float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; + float[][][][] fourD = {threeD, threeD}; + Object[] inputs = {fourD}; + Tensor[] outputs = wrapper.run(inputs); + assertThat(outputs.length).isEqualTo(1); + float[][][][] parsedOutputs = new float[2][8][8][3]; + outputs[0].copyTo(parsedOutputs); + float[] outputOneD = parsedOutputs[0][0][0]; + float[] expected = {3.69f, -19.62f, 23.43f}; + assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder(); + wrapper.close(); + } + + @Test + public void testRunWithInt() { + NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(INT_MODEL_PATH); + int[] oneD = {3, 7, -4}; + int[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD}; + int[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; + int[][][][] fourD = {threeD, threeD}; + Object[] inputs = {fourD}; + Tensor[] outputs = wrapper.run(inputs); + assertThat(outputs.length).isEqualTo(1); + int[][][][] parsedOutputs = new int[2][4][4][12]; + outputs[0].copyTo(parsedOutputs); + int[] outputOneD = parsedOutputs[0][0][0]; + int[] expected = {3, 7, -4, 3, 7, -4, 3, 7, -4, 3, 7, -4}; + assertThat(outputOneD).isEqualTo(expected); + wrapper.close(); + } + + @Test + public void testRunWithLong() { + NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(LONG_MODEL_PATH); + long[] oneD = {-892834092L, 923423L, 2123918239018L}; + long[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD}; + long[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; + long[][][][] fourD = {threeD, threeD}; + Object[] inputs = {fourD}; + Tensor[] outputs = wrapper.run(inputs); + assertThat(outputs.length).isEqualTo(1); + long[][][][] parsedOutputs = new long[2][4][4][12]; + outputs[0].copyTo(parsedOutputs); + long[] outputOneD = parsedOutputs[0][0][0]; + long[] expected = {-892834092L, 923423L, 2123918239018L, -892834092L, 923423L, 2123918239018L, + -892834092L, 923423L, 2123918239018L, -892834092L, 923423L, 2123918239018L}; + assertThat(outputOneD).isEqualTo(expected); + wrapper.close(); + } + + @Test + public void 
testRunWithByte() { + NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(BYTE_MODEL_PATH); + byte[] oneD = {(byte) 0xe0, 0x4f, (byte) 0xd0}; + byte[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD}; + byte[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; + byte[][][][] fourD = {threeD, threeD}; + Object[] inputs = {fourD}; + int[] inputDims = {2, 8, 8, 3}; + wrapper.resizeInput(0, inputDims); + Tensor[] outputs = wrapper.run(inputs); + assertThat(outputs.length).isEqualTo(1); + byte[][][][] parsedOutputs = new byte[2][4][4][12]; + outputs[0].copyTo(parsedOutputs); + byte[] outputOneD = parsedOutputs[0][0][0]; + byte[] expected = {(byte) 0xe0, 0x4f, (byte) 0xd0, (byte) 0xe0, 0x4f, (byte) 0xd0, + (byte) 0xe0, 0x4f, (byte) 0xd0, (byte) 0xe0, 0x4f, (byte) 0xd0}; + assertThat(outputOneD).isEqualTo(expected); + wrapper.close(); + } + + @Test + public void testRunWithByteBufferHavingBytes() { + NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(BYTE_MODEL_PATH); + ByteBuffer bbuf = ByteBuffer.allocateDirect(2 * 8 * 8 * 3); + bbuf.order(ByteOrder.nativeOrder()); + bbuf.rewind(); + for (int i = 0; i < 2; ++i) { + for (int j = 0; j < 8; ++j) { + for (int k = 0; k < 8; ++k) { + bbuf.put((byte) 0xe0); + bbuf.put((byte) 0x4f); + bbuf.put((byte) 0xd0); + } + } + } + Object[] inputs = {bbuf}; + int[] inputDims = {2, 8, 8, 3}; + wrapper.resizeInput(0, inputDims); + Tensor[] outputs = wrapper.run(inputs); + assertThat(outputs.length).isEqualTo(1); + byte[][][][] parsedOutputs = new byte[2][4][4][12]; + outputs[0].copyTo(parsedOutputs); + byte[] outputOneD = parsedOutputs[0][0][0]; + byte[] expected = { + (byte) 0xe0, 0x4f, (byte) 0xd0, (byte) 0xe0, 0x4f, (byte) 0xd0, + (byte) 0xe0, 0x4f, (byte) 0xd0, (byte) 0xe0, 0x4f, (byte) 0xd0 + }; + assertThat(outputOneD).isEqualTo(expected); + wrapper.close(); + } + + @Test + public void testRunWithByteBufferHavingFloats() { + NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH); + ByteBuffer bbuf = ByteBuffer.allocateDirect(4 * 8 * 8 * 3 * 4); + bbuf.order(ByteOrder.nativeOrder()); + bbuf.rewind(); + for (int i = 0; i < 4; ++i) { + for (int j = 0; j < 8; ++j) { + for (int k = 0; k < 8; ++k) { + bbuf.putFloat(1.23f); + bbuf.putFloat(-6.54f); + bbuf.putFloat(7.81f); + } + } + } + Object[] inputs = {bbuf}; + try { + wrapper.run(inputs); + fail(); + } catch (IllegalArgumentException e) { + assertThat(e) + .hasMessageThat() + .contains( + "Failed to get input dimensions. 0-th input should have 768 bytes, but found 3072 bytes"); + } + int[] inputDims = {4, 8, 8, 3}; + wrapper.resizeInput(0, inputDims); + Tensor[] outputs = wrapper.run(inputs); + assertThat(outputs.length).isEqualTo(1); + float[][][][] parsedOutputs = new float[4][8][8][3]; + outputs[0].copyTo(parsedOutputs); + float[] outputOneD = parsedOutputs[0][0][0]; + float[] expected = {3.69f, -19.62f, 23.43f}; + assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder(); + wrapper.close(); + } + + @Test + public void testRunWithByteBufferHavingWrongSize() { + NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(BYTE_MODEL_PATH); + ByteBuffer bbuf = ByteBuffer.allocateDirect(2 * 7 * 8 * 3); + bbuf.order(ByteOrder.nativeOrder()); + Object[] inputs = {bbuf}; + try { + wrapper.run(inputs); + fail(); + } catch (IllegalArgumentException e) { + assertThat(e) + .hasMessageThat() + .contains( + "Failed to get input dimensions. 
0-th input should have 192 bytes, but found 336 bytes."); + } + wrapper.close(); + } + + @Test + public void testRunWithWrongInputType() { + NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH); + int[] oneD = {4, 3, 9}; + int[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD}; + int[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; + int[][][][] fourD = {threeD, threeD}; + Object[] inputs = {fourD}; + try { + wrapper.run(inputs); + fail(); + } catch (IllegalArgumentException e) { + assertThat(e) + .hasMessageThat() + .contains( + "DataType (2) of input data does not match with the DataType (1) of model inputs."); + } + wrapper.close(); + } + + @Test + public void testRunAfterClose() { + NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH); + wrapper.close(); + float[] oneD = {1.23f, 6.54f, 7.81f}; + float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD}; + float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; + float[][][][] fourD = {threeD, threeD}; + Object[] inputs = {fourD}; + try { + wrapper.run(inputs); + fail(); + } catch (IllegalArgumentException e) { + assertThat(e).hasMessageThat().contains("Invalid handle to Interpreter."); + } + } + + @Test + public void testRunWithEmptyInputs() { + NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH); + try { + Object[] inputs = {}; + wrapper.run(inputs); + fail(); + } catch (IllegalArgumentException e) { + assertThat(e) + .hasMessageThat() + .contains("Invalid inputs. Inputs should not be null or empty."); + } + wrapper.close(); + } + + @Test + public void testRunWithWrongInputSize() { + NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH); + float[] oneD = {1.23f, 6.54f, 7.81f}; + float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD}; + float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; + float[][][][] fourD = {threeD, threeD}; + Object[] inputs = {fourD, fourD}; + try { + wrapper.run(inputs); + fail(); + } catch (IllegalArgumentException e) { + assertThat(e).hasMessageThat().contains("Expected num of inputs is 1 but got 2"); + } + wrapper.close(); + } + + @Test + public void testRunWithWrongInputNumOfDims() { + NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH); + float[] oneD = {1.23f, 6.54f, 7.81f}; + float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD}; + float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; + Object[] inputs = {threeD}; + try { + wrapper.run(inputs); + fail(); + } catch (IllegalArgumentException e) { + assertThat(e) + .hasMessageThat() + .contains("0-th input should have 4 dimensions, but found 3 dimensions"); + } + wrapper.close(); + } + + @Test + public void testRunWithWrongInputDims() { + NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH); + float[] oneD = {1.23f, 6.54f, 7.81f}; + float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD}; + float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; + float[][][][] fourD = {threeD, threeD}; + Object[] inputs = {fourD}; + try { + wrapper.run(inputs); + fail(); + } catch (IllegalArgumentException e) { + assertThat(e) + .hasMessageThat() + .contains("0-th input dimension should be [?,8,8,3], but found [?,8,7,3]"); + } + wrapper.close(); + } + + @Test + public void testNumElements() { + int[] shape = {2, 3, 4}; + int num = NativeInterpreterWrapper.numElements(shape); + 
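+    // numElements should multiply the dimensions together: 2 * 3 * 4 = 24.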
assertThat(num).isEqualTo(24); + shape = null; + num = NativeInterpreterWrapper.numElements(shape); + assertThat(num).isEqualTo(0); + } + + @Test + public void testIsNonEmtpyArray() { + assertThat(NativeInterpreterWrapper.isNonEmptyArray(null)).isFalse(); + assertThat(NativeInterpreterWrapper.isNonEmptyArray(3.2)).isFalse(); + int[] emptyArray = {}; + assertThat(NativeInterpreterWrapper.isNonEmptyArray(emptyArray)).isFalse(); + int[] validArray = {9, 5, 2, 1}; + assertThat(NativeInterpreterWrapper.isNonEmptyArray(validArray)).isTrue(); + } + + @Test + public void testDataTypeOf() { + float[] testEmtpyArray = {}; + DataType dataType = NativeInterpreterWrapper.dataTypeOf(testEmtpyArray); + assertThat(dataType).isEqualTo(DataType.FLOAT32); + float[] testFloatArray = {0.783f, 0.251f}; + dataType = NativeInterpreterWrapper.dataTypeOf(testFloatArray); + assertThat(dataType).isEqualTo(DataType.FLOAT32); + float[][] testMultiDimArray = {testFloatArray, testFloatArray, testFloatArray}; + dataType = NativeInterpreterWrapper.dataTypeOf(testFloatArray); + assertThat(dataType).isEqualTo(DataType.FLOAT32); + try { + double[] testDoubleArray = {0.783, 0.251}; + NativeInterpreterWrapper.dataTypeOf(testDoubleArray); + fail(); + } catch (IllegalArgumentException e) { + assertThat(e).hasMessageThat().contains("cannot resolve DataType of"); + } + try { + Float[] testBoxedArray = {0.783f, 0.251f}; + NativeInterpreterWrapper.dataTypeOf(testBoxedArray); + fail(); + } catch (IllegalArgumentException e) { + assertThat(e).hasMessageThat().contains("cannot resolve DataType of [Ljava.lang.Float;"); + } + } + + @Test + public void testNumDimensions() { + int scalar = 1; + assertThat(NativeInterpreterWrapper.numDimensions(scalar)).isEqualTo(0); + int[][] array = {{2, 4}, {1, 9}}; + assertThat(NativeInterpreterWrapper.numDimensions(array)).isEqualTo(2); + try { + int[] emptyArray = {}; + NativeInterpreterWrapper.numDimensions(emptyArray); + fail(); + } catch (IllegalArgumentException e) { + assertThat(e).hasMessageThat().contains("array lengths cannot be 0."); + } + } + + @Test + public void testFillShape() { + int[][][] array = {{{23}, {14}, {87}}, {{12}, {42}, {31}}}; + int num = NativeInterpreterWrapper.numDimensions(array); + int[] shape = new int[num]; + NativeInterpreterWrapper.fillShape(array, 0, shape); + assertThat(num).isEqualTo(3); + assertThat(shape[0]).isEqualTo(2); + assertThat(shape[1]).isEqualTo(3); + assertThat(shape[2]).isEqualTo(1); + } +} diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorFlowLiteTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorFlowLiteTest.java new file mode 100644 index 0000000000..665c937cb6 --- /dev/null +++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorFlowLiteTest.java @@ -0,0 +1,32 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +package org.tensorflow.lite; + +import static com.google.common.truth.Truth.assertThat; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link org.tensorflow.lite.TensorFlowLite}. */ +@RunWith(JUnit4.class) +public final class TensorFlowLiteTest { + + @Test + public void testVersion() { + assertThat(TensorFlowLite.version()).isEqualTo("3"); + } +} diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java new file mode 100644 index 0000000000..e41e971159 --- /dev/null +++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java @@ -0,0 +1,105 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.fail; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link org.tensorflow.lite.Tensor}. 
*/ +@RunWith(JUnit4.class) +public final class TensorTest { + + private static final String MODEL_PATH = + "third_party/tensorflow/contrib/lite/java/src/testdata/add.bin"; + + private NativeInterpreterWrapper wrapper; + private long nativeHandle; + + @Before + public void setUp() { + wrapper = new NativeInterpreterWrapper(MODEL_PATH); + float[] oneD = {1.23f, 6.54f, 7.81f}; + float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD}; + float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; + float[][][][] fourD = {threeD, threeD}; + Object[] inputs = {fourD}; + Tensor[] outputs = wrapper.run(inputs); + nativeHandle = outputs[0].nativeHandle; + } + + @After + public void tearDown() { + wrapper.close(); + } + + @Test + public void testFromHandle() throws Exception { + Tensor tensor = Tensor.fromHandle(nativeHandle); + assertThat(tensor).isNotNull(); + int[] expectedShape = {2, 8, 8, 3}; + assertThat(tensor.shapeCopy).isEqualTo(expectedShape); + assertThat(tensor.dtype).isEqualTo(DataType.FLOAT32); + } + + @Test + public void testCopyTo() { + Tensor tensor = Tensor.fromHandle(nativeHandle); + float[][][][] parsedOutputs = new float[2][8][8][3]; + tensor.copyTo(parsedOutputs); + float[] outputOneD = parsedOutputs[0][0][0]; + float[] expected = {3.69f, 19.62f, 23.43f}; + assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder(); + } + + @Test + public void testCopyToWrongType() { + Tensor tensor = Tensor.fromHandle(nativeHandle); + int[][][][] parsedOutputs = new int[2][8][8][3]; + try { + tensor.copyTo(parsedOutputs); + fail(); + } catch (IllegalArgumentException e) { + assertThat(e) + .hasMessageThat() + .contains( + "Cannot convert an TensorFlowLite tensor with type " + + "FLOAT32 to a Java object of type [[[[I (which is compatible with the TensorFlowLite " + + "type INT32)"); + } + } + + @Test + public void testCopyToWrongShape() { + Tensor tensor = Tensor.fromHandle(nativeHandle); + float[][][][] parsedOutputs = new float[1][8][8][3]; + try { + tensor.copyTo(parsedOutputs); + fail(); + } catch (IllegalArgumentException e) { + assertThat(e) + .hasMessageThat() + .contains( + "Shape of output target [1, 8, 8, 3] does not match " + + "with the shape of the Tensor [2, 8, 8, 3]."); + } + } +} diff --git a/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/BUILD b/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/BUILD new file mode 100644 index 0000000000..2b4f37bc6c --- /dev/null +++ b/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/BUILD @@ -0,0 +1,30 @@ +# Description: +# Internal helper function to test TF Lite API. + +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Apache 2.0 + +android_library( + name = "testhelper", + srcs = glob( + [ + "*.java", + ], + ), + deps = [ + "//tensorflow/contrib/lite/java:tensorflowlite_java", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/TestHelper.java b/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/TestHelper.java new file mode 100644 index 0000000000..8660cabf70 --- /dev/null +++ b/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/TestHelper.java @@ -0,0 +1,35 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite; + +/** A helper class for internal tests. */ +public class TestHelper { + + /** + * Turns on/off NNAPI of an {@code Interpreter}. + * + * @param interpreter an instance of {@code Interpreter}. If it is not initialized, an {@code + * IllegalArgumentException} will be thrown. + * @param useNNAPI a boolean value indicating to turn on or off NNAPI. + */ + public static void setUseNNAPI(Interpreter interpreter, boolean useNNAPI) { + if (interpreter != null && interpreter.wrapper != null) { + interpreter.wrapper.setUseNNAPI(useNNAPI); + } else { + throw new IllegalArgumentException("Interpreter has not initialized; Failed to setUseNNAPI."); + } + } +} diff --git a/tensorflow/contrib/lite/kernels/BUILD b/tensorflow/contrib/lite/kernels/BUILD new file mode 100644 index 0000000000..bbbfa3e741 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/BUILD @@ -0,0 +1,408 @@ +package(default_visibility = [ + "//visibility:public", +]) + +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts") +load( + "//tensorflow:tensorflow.bzl", + "tf_cc_test", +) + +tf_cc_test( + name = "optional_tensor_test", + size = "small", + srcs = ["optional_tensor_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +cc_library( + name = "test_util", + testonly = 1, + srcs = ["test_util.cc"], + hdrs = ["test_util.h"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:schema_fbs_version", + "//tensorflow/contrib/lite:string_util", + "//tensorflow/core:lib", + "@com_google_googletest//:gtest", + ], +) + +cc_library( + name = "gemm_support", + srcs = [ + "gemm_support.cc", + ], + hdrs = [ + "gemm_support.h", + ], + copts = tflite_copts(), + deps = [ + ":op_macros", + "//tensorflow/contrib/lite:context", + "@gemmlowp//:gemmlowp", + ], +) + +cc_library( + name = "activation_functor", + hdrs = [ + "activation_functor.h", + ], + deps = [ + "//tensorflow/contrib/lite:builtin_op_data", + ], +) + +cc_library( + name = "op_macros", + hdrs = [ + "op_macros.h", + ], +) + +cc_library( + name = "builtin_ops", + srcs = [ + "activations.cc", + "add.cc", + "basic_rnn.cc", + "concatenation.cc", + "conv.cc", + "depthwise_conv.cc", + "embedding_lookup.cc", + "embedding_lookup_sparse.cc", + "fully_connected.cc", + "hashtable_lookup.cc", + "kernel_util.cc", + "l2norm.cc", + "local_response_norm.cc", + "lsh_projection.cc", + "lstm.cc", + "mul.cc", + "pooling.cc", + "register.cc", + "reshape.cc", + "resize_bilinear.cc", + "skip_gram.cc", + "space_to_depth.cc", + "svdf.cc", + ], + hdrs = [ + "kernel_util.h", + "padding.h", + "register.h", + ], + # Suppress warnings that are introduced by Eigen Tensor. 
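+    # -Wno-error=reorder is applied on every platform; the select() below adds
+    # -Wno-error=invalid-partial-specialization for iOS builds only.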
+ copts = tflite_copts() + [ + "-Wno-error=reorder", + ] + select({ + "//tensorflow:ios": ["-Wno-error=invalid-partial-specialization"], + "//conditions:default": [ + ], + }), + deps = [ + ":activation_functor", + ":op_macros", + "//tensorflow/contrib/lite:builtin_op_data", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:string_util", + "//tensorflow/contrib/lite/kernels:gemm_support", + "//tensorflow/contrib/lite/kernels/internal:optimized", + "//tensorflow/contrib/lite/kernels/internal:optimized_base", + "//tensorflow/contrib/lite/kernels/internal:quantization_util", + "//tensorflow/contrib/lite/kernels/internal:reference", + "//tensorflow/contrib/lite/kernels/internal:reference_base", + "//tensorflow/contrib/lite/kernels/internal:round", + "//tensorflow/contrib/lite/kernels/internal:tensor_utils", + "@farmhash_archive//:farmhash", + ], +) + +tf_cc_test( + name = "activations_test", + size = "small", + srcs = ["activations_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "add_test", + size = "small", + srcs = ["add_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "concatenation_test", + size = "small", + srcs = ["concatenation_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "conv_test", + size = "small", + srcs = ["conv_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "depthwise_conv_test", + size = "small", + srcs = ["depthwise_conv_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "basic_rnn_test", + size = "small", + srcs = ["basic_rnn_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "l2norm_test", + size = "small", + srcs = ["l2norm_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "mul_test", + size = "small", + srcs = ["mul_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "reshape_test", + size = "small", + srcs = ["reshape_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "resize_bilinear_test", + size = "small", + srcs = ["resize_bilinear_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "svdf_test", + size = "small", + srcs = ["svdf_test.cc"], + deps = [ + ":builtin_ops", + 
"//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "embedding_lookup_test", + size = "small", + srcs = ["embedding_lookup_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "embedding_lookup_sparse_test", + size = "small", + srcs = ["embedding_lookup_sparse_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "fully_connected_test", + size = "small", + srcs = ["fully_connected_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "local_response_norm_test", + size = "small", + srcs = ["local_response_norm_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "pooling_test", + size = "small", + srcs = ["pooling_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "softmax_test", + size = "small", + srcs = ["softmax_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "//tensorflow/contrib/lite/kernels/internal:reference_base", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "lsh_projection_test", + size = "small", + srcs = ["lsh_projection_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "hashtable_lookup_test", + size = "small", + srcs = ["hashtable_lookup_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:string_util", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "lstm_test", + size = "small", + srcs = ["lstm_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "skip_gram_test", + size = "small", + srcs = ["skip_gram_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:string_util", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "space_to_depth_test", + size = "small", + srcs = ["space_to_depth_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/contrib/lite/kernels/activation_functor.h b/tensorflow/contrib/lite/kernels/activation_functor.h new file mode 100644 index 0000000000..cfb3369e99 --- /dev/null +++ 
b/tensorflow/contrib/lite/kernels/activation_functor.h
@@ -0,0 +1,58 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_ACTIVATION_FUNCTOR_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_ACTIVATION_FUNCTOR_H_
+
+#include <algorithm>
+#include <cmath>
+#include <cstdlib>
+
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+
+namespace tflite {
+
+// Dynamic (non-fused) activation functor. Perhaps it is worth having
+// template instantiation?
+// TODO(aselle): Make this more efficient by pulling the switch to conv_eval
+// using template inlining.
+class ActivationFunctor {
+ public:
+  explicit ActivationFunctor(TfLiteFusedActivation act) : act_(act) {}
+
+  float operator()(float a) const {
+    switch (act_) {
+      case kTfLiteActNone:
+        return a;
+      case kTfLiteActRelu:
+        return a < 0.f ? 0.f : a;
+      case kTfLiteActRelu6:
+        return std::max(0.f, std::min(a, 6.f));
+      case kTfLiteActTanh:
+        return std::tanh(a);
+      case kTfLiteActSigmoid:
+        return 1.0f / (1.0f + std::exp(-a));
+      default:
+        // TODO(aselle): More informative fatal error!
+        exit(1);
+    }
+  }
+
+ private:
+  TfLiteFusedActivation act_;
+};
+
+}  // namespace tflite
+
+#endif  // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_ACTIVATION_FUNCTOR_H_
diff --git a/tensorflow/contrib/lite/kernels/activations.cc b/tensorflow/contrib/lite/kernels/activations.cc
new file mode 100644
index 0000000000..7ab60a33e5
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/activations.cc
@@ -0,0 +1,389 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <unistd.h>
+#include <cassert>
+#include <cmath>
+#include <cstdio>
+#include <cstdlib>
+#include <iostream>
+#include <limits>
+
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace activations {
+
+struct OpData {
+  int32_t input_multiplier = 0;
+  int input_left_shift = 0;
+  int32_t input_range_radius = 0;
+  int diff_min = 0;
+};
+
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+  // This is a builtin op, so we don't use the contents in 'buffer', if any.
+  // Instead, we allocate a new object to carry information from Prepare() to
+  // Eval().
+  return new OpData;
+}
+
+void Free(TfLiteContext* context, void* buffer) {
+  delete reinterpret_cast<OpData*>(buffer);
+}
+
+TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) {
+  TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
+  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+  TfLiteTensor* input = GetInput(context, node, 0);
+  TfLiteTensor* output = GetOutput(context, node, 0);
+  TF_LITE_ENSURE_EQ(context, input->type, output->type);
+
+  return context->ResizeTensor(context, output,
+                               TfLiteIntArrayCopy(input->dims));
+}
+
+TfLiteStatus SigmoidPrepare(TfLiteContext* context, TfLiteNode* node) {
+  OpData* data = reinterpret_cast<OpData*>(node->user_data);
+
+  TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
+  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+  TfLiteTensor* input = GetInput(context, node, 0);
+  TfLiteTensor* output = GetOutput(context, node, 0);
+  TF_LITE_ENSURE_EQ(context, input->type, output->type);
+
+  if (input->type == kTfLiteUInt8) {
+    TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
+    TF_LITE_ENSURE(context, output->params.scale == 1. / 256);
+
+    static constexpr int kInputIntegerBits = 4;
+
+    const double input_real_multiplier =
+        input->params.scale *
+        static_cast<double>(1 << (31 - kInputIntegerBits));
+
+    QuantizeMultiplierGreaterThanOne(input_real_multiplier,
+                                     &data->input_multiplier,
+                                     &data->input_left_shift);
+    data->input_range_radius =
+        CalculateInputRadius(kInputIntegerBits, data->input_left_shift);
+  }
+
+  return context->ResizeTensor(context, output,
+                               TfLiteIntArrayCopy(input->dims));
+}
+
+TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) {
+  auto* params = reinterpret_cast<TfLiteSoftmaxParams*>(node->builtin_data);
+  OpData* data = reinterpret_cast<OpData*>(node->user_data);
+
+  TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
+  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+  TfLiteTensor* input = GetInput(context, node, 0);
+  TfLiteTensor* output = GetOutput(context, node, 0);
+  TF_LITE_ENSURE_EQ(context, input->type, output->type);
+
+  TF_LITE_ENSURE(context,
+                 NumDimensions(input) == 2 || NumDimensions(input) == 4);
+
+  if (input->type == kTfLiteUInt8) {
+    TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
+    TF_LITE_ENSURE(context, output->params.scale == 1.
/ 256); + + static const int kScaledDiffIntegerBits = 5; + + tflite::PreprocessSoftmaxScaling( + params->beta, input->params.scale, kScaledDiffIntegerBits, + &data->input_multiplier, &data->input_left_shift); + data->diff_min = -1.0 * tflite::CalculateInputRadius( + kScaledDiffIntegerBits, data->input_left_shift); + } + + return context->ResizeTensor(context, output, + TfLiteIntArrayCopy(input->dims)); +} + +TfLiteStatus ReluEval(TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor* input = GetInput(context, node, 0); + TfLiteTensor* output = GetOutput(context, node, 0); + switch (input->type) { + case kTfLiteFloat32: { + size_t elements = input->bytes / sizeof(float); + float* in = input->data.f; + float* in_end = in + elements; + float* out = output->data.f; + for (; in < in_end; in++, out++) *out = std::max(0.f, *in); + return kTfLiteOk; + } + break; + default: + context->ReportError(context, "Only float32 supported currently."); + return kTfLiteError; + } +} + +TfLiteStatus Relu1Eval(TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor* input = GetInput(context, node, 0); + TfLiteTensor* output = GetOutput(context, node, 0); + switch (input->type) { + case kTfLiteFloat32: { + size_t elements = input->bytes / sizeof(float); + float* in = input->data.f; + float* in_end = in + elements; + float* out = output->data.f; + for (; in < in_end; in++, out++) { + *out = std::min(std::max(-1.f, *in), 1.f); + } + return kTfLiteOk; + } break; + default: + context->ReportError(context, "Only float32 supported currently."); + return kTfLiteError; + } +} + +TfLiteStatus Relu6Eval(TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor* input = GetInput(context, node, 0); + TfLiteTensor* output = GetOutput(context, node, 0); + switch (input->type) { + case kTfLiteFloat32: { + size_t elements = input->bytes / sizeof(float); + float* in = input->data.f; + float* in_end = in + elements; + float* out = output->data.f; + for (; in < in_end; in++, out++) *out = std::min(std::max(0.f, *in), 6.f); + return kTfLiteOk; + } + break; + default: + context->ReportError(context, "Only float32 supported currently."); + return kTfLiteError; + } +} + +TfLiteStatus TanhEval(TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor* input = GetInput(context, node, 0); + TfLiteTensor* output = GetOutput(context, node, 0); + switch (input->type) { + case kTfLiteFloat32: { + size_t elements = input->bytes / sizeof(float); + float* in = input->data.f; + float* in_end = in + elements; + float* out = output->data.f; + for (; in < in_end; in++, out++) *out = std::tanh(*in); + return kTfLiteOk; + } + break; + default: + context->ReportError(context, "Only float32 supported currently."); + return kTfLiteError; + } +} + +// Sigmoid is also know as "Logistic". 
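+// For float inputs this computes 1 / (1 + exp(-x)) element-wise (for example,
+// x = 0 maps to 0.5); quantized uint8 inputs go through optimized_ops::Logistic.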
+TfLiteStatus SigmoidEval(TfLiteContext* context, TfLiteNode* node) {
+  OpData* data = reinterpret_cast<OpData*>(node->user_data);
+
+  TfLiteTensor* input = GetInput(context, node, 0);
+  TfLiteTensor* output = GetOutput(context, node, 0);
+  switch (input->type) {
+    case kTfLiteFloat32: {
+      size_t elements = input->bytes / sizeof(float);
+      float* in = input->data.f;
+      float* in_end = in + elements;
+      float* out = output->data.f;
+      for (; in < in_end; in++, out++) *out = 1.f / (1.f + std::exp(-*in));
+      break;
+    }
+    case kTfLiteUInt8: {
+      optimized_ops::Logistic(
+          GetTensorData<uint8_t>(input), GetTensorDims(input),
+          input->params.zero_point, data->input_range_radius,
+          data->input_multiplier, data->input_left_shift,
+          GetTensorData<uint8_t>(output), GetTensorDims(output));
+      break;
+    }
+    default:
+      context->ReportError(context, "Only float32 supported currently.");
+      return kTfLiteError;
+  }
+  return kTfLiteOk;
+}
+
+// Takes a 2D tensor and performs softmax along the second dimension.
+void Softmax2DFloat(TfLiteTensor* input, TfLiteTensor* output,
+                    TfLiteSoftmaxParams* params) {
+  const int batch_size = input->dims->data[0];
+  const int input_size = input->dims->data[1];
+  float* in = input->data.f;
+  float* out = output->data.f;
+  TF_LITE_ASSERT(input_size > 0);
+
+  // For each batch
+  for (int b = 0; b < batch_size; b++) {
+    // Find the max coeff.
+    float max_coeff = in[0];
+    for (int i = 1; i < input_size; i++) {
+      if (in[i] > max_coeff) max_coeff = in[i];
+    }
+
+    // Compute the normalized sum of exps.
+    float exp_sum = 0.0;
+    for (int i = 0; i < input_size; i++) {
+      out[i] = std::exp((in[i] - max_coeff) * params->beta);
+      exp_sum += out[i];
+    }
+
+    // Divide by the sum of exps.
+    float reciprocal_sum_exp = 1.f / exp_sum;
+    for (int i = 0; i < input_size; i++) {
+      out[i] *= reciprocal_sum_exp;
+    }
+
+    // Advance in and out pointers for the next batch.
+    in += input_size;
+    out += input_size;
+  }
+}
+
+void Softmax2DQuantized(TfLiteTensor* input, TfLiteTensor* output,
+                        TfLiteSoftmaxParams* params, OpData* data) {
+  // TODO(ahentz): this is arguably a dirty trick. Since the implementation
+  // always traverses the last dimension of a 4D tensor, we will pretend our 2D
+  // tensor is 4D in a special way. We will convert a (X, Y) shape into a (X,
+  // 1, 1, Y) shape.
+  const int batch_size = input->dims->data[0];
+  const int input_size = input->dims->data[1];
+  optimized_ops::Softmax(GetTensorData<uint8_t>(input),
+                         GetTensorDims({batch_size, 1, 1, input_size}),
+                         data->input_multiplier, data->input_left_shift,
+                         data->diff_min, GetTensorData<uint8_t>(output),
+                         GetTensorDims({batch_size, 1, 1, input_size}));
+}
+
+// Takes a 4D tensor and performs softmax along the fourth dimension.
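+// Both 4D variants delegate to optimized_ops::Softmax over the last (depth)
+// dimension; the float path computes
+//   out[i] = exp(beta * (in[i] - max)) / sum_j exp(beta * (in[j] - max)).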
+void Softmax4DFloat(TfLiteTensor* input, TfLiteTensor* output,
+                    TfLiteSoftmaxParams* params) {
+  optimized_ops::Softmax(GetTensorData<float>(input), GetTensorDims(input),
+                         params->beta, GetTensorData<float>(output),
+                         GetTensorDims(output));
+}
+
+void Softmax4DQuantized(TfLiteTensor* input, TfLiteTensor* output,
+                        TfLiteSoftmaxParams* params, OpData* data) {
+  optimized_ops::Softmax(GetTensorData<uint8_t>(input), GetTensorDims(input),
+                         data->input_multiplier, data->input_left_shift,
+                         data->diff_min, GetTensorData<uint8_t>(output),
+                         GetTensorDims(output));
+}
+
+TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
+  auto* params = reinterpret_cast<TfLiteSoftmaxParams*>(node->builtin_data);
+  OpData* data = reinterpret_cast<OpData*>(node->user_data);
+
+  TfLiteTensor* input = GetInput(context, node, 0);
+  TfLiteTensor* output = GetOutput(context, node, 0);
+
+  // TODO(ahentz): consider an implementation that works for many (all?)
+  // dimensions.
+  switch (input->type) {
+    case kTfLiteFloat32: {
+      if (NumDimensions(input) == 2) {
+        Softmax2DFloat(input, output, params);
+        return kTfLiteOk;
+      }
+      if (NumDimensions(input) == 4) {
+        Softmax4DFloat(input, output, params);
+        return kTfLiteOk;
+      }
+      context->ReportError(context,
+                           "Only 2D and 4D tensors supported currently.");
+      return kTfLiteError;
+    }
+    case kTfLiteUInt8: {
+      if (NumDimensions(input) == 2) {
+        Softmax2DQuantized(input, output, params, data);
+        return kTfLiteOk;
+      }
+      if (NumDimensions(input) == 4) {
+        Softmax4DQuantized(input, output, params, data);
+        return kTfLiteOk;
+      }
+      context->ReportError(context,
+                           "Only 2D and 4D tensors supported currently.");
+      return kTfLiteError;
+    }
+    default:
+      context->ReportError(context,
+                           "Only float32 and uint8_t supported currently.");
+      return kTfLiteError;
+  }
+}
+
+}  // namespace activations
+
+TfLiteRegistration* Register_RELU() {
+  static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
+                                 activations::GenericPrepare,
+                                 activations::ReluEval};
+  return &r;
+}
+
+TfLiteRegistration* Register_RELU1() {
+  static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
+                                 activations::GenericPrepare,
+                                 activations::Relu1Eval};
+  return &r;
+}
+
+TfLiteRegistration* Register_RELU6() {
+  static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
+                                 activations::GenericPrepare,
+                                 activations::Relu6Eval};
+  return &r;
+}
+
+TfLiteRegistration* Register_TANH() {
+  static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
+                                 activations::GenericPrepare,
+                                 activations::TanhEval};
+  return &r;
+}
+
+TfLiteRegistration* Register_LOGISTIC() {
+  static TfLiteRegistration r = {activations::Init, activations::Free,
+                                 activations::SigmoidPrepare,
+                                 activations::SigmoidEval};
+  return &r;
+}
+
+TfLiteRegistration* Register_SOFTMAX() {
+  static TfLiteRegistration r = {activations::Init, activations::Free,
+                                 activations::SoftmaxPrepare,
+                                 activations::SoftmaxEval};
+  return &r;
+}
+
+}  // namespace builtin
+}  // namespace ops
+}  // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/activations_test.cc b/tensorflow/contrib/lite/kernels/activations_test.cc
new file mode 100644
index 0000000000..f10aee7017
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/activations_test.cc
@@ -0,0 +1,323 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <cstdarg>
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAreArray;
+
+class BaseActivationsOpModel : public SingleOpModel {
+ public:
+  // Most activations don't take any options, so this constructor works for
+  // them.
+  BaseActivationsOpModel(BuiltinOperator type, TensorData input) {
+    input_ = AddInput(input);
+    if (input.type == TensorType_UINT8) {
+      output_ = AddOutput({input.type, {}, 0, 0, 1. / 256});
+    } else {
+      output_ = AddOutput({input.type, {}});
+    }
+    SetBuiltinOp(type, BuiltinOptions_NONE, 0);
+    BuildInterpreter({GetShape(input_)});
+  }
+
+  // A dedicated constructor for SOFTMAX, which takes some options.
+  BaseActivationsOpModel(float softmax_beta, TensorData input) {
+    input_ = AddInput(input);
+    if (input.type == TensorType_UINT8) {
+      output_ = AddOutput({input.type, {}, 0, 0, 1. / 256});
+    } else {
+      output_ = AddOutput({input.type, {}});
+    }
+    SetBuiltinOp(BuiltinOperator_SOFTMAX, BuiltinOptions_SoftmaxOptions,
+                 CreateSoftmaxOptions(builder_, softmax_beta).Union());
+    BuildInterpreter({GetShape(input_)});
+  }
+
+ protected:
+  int input_;
+  int output_;
+};
+
+class FloatActivationsOpModel : public BaseActivationsOpModel {
+ public:
+  using BaseActivationsOpModel::BaseActivationsOpModel;
+
+  void SetInput(std::initializer_list<float> data) {
+    PopulateTensor(input_, data);
+  }
+  std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+};
+
+// TODO(ahentz): I don't quite understand the tradeoffs in the quantized
+// implementation of sigmoid and softmax, but a tolerance of twice the output
+// scale seems reasonable. We might want to change this if we have a better
+// theoretical bound.
+const float kQuantizedTolerance = 2 * (1.
/ 256); + +class QuantizedActivationsOpModel : public BaseActivationsOpModel { + public: + using BaseActivationsOpModel::BaseActivationsOpModel; + + void SetInput(std::initializer_list data) { + QuantizeAndPopulate(input_, data); + } + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetDequantizedOutput() { + return Dequantize(ExtractVector(output_), + GetScale(output_), GetZeroPoint(output_)); + } +}; + +TEST(FloatActivationsOpTest, Relu) { + FloatActivationsOpModel m(BuiltinOperator_RELU, + /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}}); + m.SetInput({ + 0, -6, 2, 4, // + 3, -2, 10, 1, // + }); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({ + 0, 0, 2, 4, // + 3, 0, 10, 1, // + })); +} + +TEST(FloatActivationsOpTest, Relu1) { + FloatActivationsOpModel m(BuiltinOperator_RELU1, + /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}}); + m.SetInput({ + 0.0, -0.6, 0.2, -0.4, // + 0.3, -2.0, 1.1, -0.1, // + }); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({ + 0.0, -0.6, 0.2, -0.4, // + 0.3, -1.0, 1.0, -0.1, // + })); +} + +TEST(FloatActivationsOpTest, Relu6) { + FloatActivationsOpModel m(BuiltinOperator_RELU6, + /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}}); + m.SetInput({ + 0, -6, 2, 4, // + 3, -2, 10, 1, // + }); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({ + 0, 0, 2, 4, // + 3, 0, 6, 1, // + })); +} + +TEST(FloatActivationsOpTest, Tanh) { + FloatActivationsOpModel m(BuiltinOperator_TANH, + /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}}); + m.SetInput({ + 0, -6, 2, 4, // + 3, -2, 10, 1, // + }); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 0, -0.9999877, 0.9640275, 0.999329, // + 0.99505475, -0.9640275, 1, 0.7615941, // + }))); +} + +TEST(FloatActivationsOpTest, Sigmoid) { + FloatActivationsOpModel m(BuiltinOperator_LOGISTIC, + /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}}); + m.SetInput({ + 0, -6, 2, 4, // + 3, -2, 10, 1, // + }); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 0.5, 0.002473, 0.880797, 0.982014, // + 0.952574, 0.119203, 0.999955, 0.731059, // + }))); +} + +TEST(QuantizedActivationsOpTest, Sigmoid) { + QuantizedActivationsOpModel m( + BuiltinOperator_LOGISTIC, + /*input=*/{TensorType_UINT8, {1, 2, 4, 1}, -10, 10}); + m.SetInput({ + 0, -6, 2, 4, // + 3, -2, 10, 1, // + }); + m.Invoke(); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear( + { + 0.5, 0.002473, 0.880797, 0.982014, // + 0.952574, 0.119203, 0.999955, 0.731059, // + }, + kQuantizedTolerance))); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray({128, 1, 227, 251, 244, 32, 255, 188})); +} + +TEST(FloatActivationsOpTest, Softmax4D) { + FloatActivationsOpModel m(0.1, + /*input=*/{TensorType_FLOAT32, {1, 2, 1, 4}}); + m.SetInput({ + 0, -6, 2, 4, // depth = 0 + 3, -2, 10, 1, // depth = 1 + }); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + .23463, .12877, .28658, .35003, // + .22528, .13664, .45365, .18443, // + }))); + + // Same input, but a different shape. 
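+  // With shape {4, 1, 1, 2} each pair is normalized independently; for {0, -6}
+  // and beta 0.1, exp(0) / (exp(0) + exp(-0.6)) ~= 0.6457, the first row below.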
+ FloatActivationsOpModel m2(0.1, + /*input=*/{TensorType_FLOAT32, {4, 1, 1, 2}}); + m2.SetInput({ + 0, -6, // + 2, 4, // + 3, -2, // + 10, 1, // + }); + m2.Invoke(); + EXPECT_THAT(m2.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 0.645656, 0.354344, // + 0.450166, 0.549834, // + 0.622459, 0.377541, // + 0.710949, 0.28905, // + }))); +} + +TEST(QuantizedActivationsOpTest, Softmax4D) { + QuantizedActivationsOpModel m( + 0.1, + /*input=*/{TensorType_UINT8, {1, 2, 1, 4}, -10, 10}); + m.SetInput({ + 0, -6, 2, 4, // depth = 0 + 3, -2, 10, 1, // depth = 1 + }); + m.Invoke(); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear( + { + .23463, .12877, .28658, .35003, // + .22528, .13664, .45365, .18443, // + }, + kQuantizedTolerance))); + + // Same input, but a different shape. + QuantizedActivationsOpModel m2( + 0.1, + /*input=*/{TensorType_UINT8, {4, 1, 1, 2}, -10, 10}); + m2.SetInput({ + 0, -6, // + 2, 4, // + 3, -2, // + 10, 1, // + }); + m2.Invoke(); + EXPECT_THAT(m2.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear( + { + 0.645656, 0.354344, // + 0.450166, 0.549834, // + 0.622459, 0.377541, // + 0.710949, 0.28905, // + }, + kQuantizedTolerance))); +} + +TEST(FloatActivationsOpTest, Softmax2D) { + FloatActivationsOpModel m(0.1, + /*input=*/{TensorType_FLOAT32, {2, 4}}); + m.SetInput({ + 0, -6, 2, 4, // + 3, -2, 10, 1, // + }); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + .23463, .12877, .28658, .35003, // + .22528, .13664, .45365, .18443, // + }))); + + // Same input, but a different shape. + FloatActivationsOpModel m2(0.1, + /*input=*/{TensorType_FLOAT32, {4, 2}}); + m2.SetInput({ + 0, -6, // + 2, 4, // + 3, -2, // + 10, 1, // + }); + m2.Invoke(); + EXPECT_THAT(m2.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 0.645656, 0.354344, // + 0.450166, 0.549834, // + 0.622459, 0.377541, // + 0.710949, 0.28905, // + }))); +} + +TEST(QuantizedActivationsOpTest, Softmax2D) { + QuantizedActivationsOpModel m(0.1, + /*input=*/{TensorType_UINT8, {2, 4}, -10, 10}); + m.SetInput({ + 0, -6, 2, 4, // + 3, -2, 10, 1, // + }); + m.Invoke(); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear( + { + .23463, .12877, .28658, .35003, // + .22528, .13664, .45365, .18443, // + }, + kQuantizedTolerance))); + + // Same input, but a different shape. + QuantizedActivationsOpModel m2(0.1, + /*input=*/{TensorType_UINT8, {4, 2}, -10, 10}); + m2.SetInput({ + 0, -6, // + 2, 4, // + 3, -2, // + 10, 1, // + }); + m2.Invoke(); + EXPECT_THAT(m2.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear( + { + 0.645656, 0.354344, // + 0.450166, 0.549834, // + 0.622459, 0.377541, // + 0.710949, 0.28905, // + }, + kQuantizedTolerance))); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + // On Linux, add: tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/add.cc b/tensorflow/contrib/lite/kernels/add.cc new file mode 100644 index 0000000000..0e10a249ab --- /dev/null +++ b/tensorflow/contrib/lite/kernels/add.cc @@ -0,0 +1,184 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace add { + +// This file has three implementation of Add. +enum KernelType { + kReference, + kGenericOptimized, // Neon-free + kNeonOptimized, +}; + +constexpr int kInputTensor1 = 0; +constexpr int kInputTensor2 = 1; +constexpr int kOutputTensor = 0; + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); + TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + TF_LITE_ENSURE_EQ(context, NumDimensions(input1), NumDimensions(input2)); + for (int i = 0; i < NumDimensions(input1); ++i) { + TF_LITE_ENSURE_EQ(context, SizeOfDimension(input1, i), + SizeOfDimension(input2, i)); + } + + TF_LITE_ENSURE_EQ(context, input1->type, output->type); + TF_LITE_ENSURE_EQ(context, input2->type, output->type); + + TfLiteIntArray* output_size = TfLiteIntArrayCopy(input1->dims); + return context->ResizeTensor(context, output, output_size); +} + +template +void EvalAddFloat(TfLiteContext* context, TfLiteNode* node, + TfLiteAddParams* params, TfLiteTensor* input1, + TfLiteTensor* input2, TfLiteTensor* output) { + float output_activation_min, output_activation_max; + CalculateActivationRangeFloat(params->activation, &output_activation_min, + &output_activation_max); +#define TF_LITE_ADD(type) \ + type::Add(GetTensorData(input1), GetTensorDims(input1), \ + GetTensorData(input2), GetTensorDims(input2), \ + output_activation_min, output_activation_max, \ + GetTensorData(output), GetTensorDims(output)) + if (kernel_type == kReference) { + TF_LITE_ADD(reference_ops); + } else { + TF_LITE_ADD(optimized_ops); + } +#undef TF_LITE_ADD +} + +template +void EvalAddQuantized(TfLiteContext* context, TfLiteNode* node, + TfLiteAddParams* params, TfLiteTensor* input1, + TfLiteTensor* input2, TfLiteTensor* output) { + auto input1_offset = -input1->params.zero_point; + auto input2_offset = -input2->params.zero_point; + auto output_offset = output->params.zero_point; + const int left_shift = 20; + const double twice_max_input_scale = + 2 * std::max(input1->params.scale, input2->params.scale); + const double real_input1_multiplier = + input1->params.scale / twice_max_input_scale; + const double real_input2_multiplier = + input2->params.scale / twice_max_input_scale; + const double real_output_multiplier = + twice_max_input_scale / ((1 << 
left_shift) * output->params.scale); + + int32 input1_multiplier; + int input1_shift; + QuantizeMultiplierSmallerThanOne(real_input1_multiplier, &input1_multiplier, + &input1_shift); + int32 input2_multiplier; + int input2_shift; + QuantizeMultiplierSmallerThanOne(real_input2_multiplier, &input2_multiplier, + &input2_shift); + int32 output_multiplier; + int output_shift; + QuantizeMultiplierSmallerThanOne(real_output_multiplier, &output_multiplier, + &output_shift); + + int32 output_activation_min, output_activation_max; + CalculateActivationRangeUint8(params->activation, output, + &output_activation_min, &output_activation_max); + +#define TF_LITE_ADD(type) \ + type::BroadcastAdd( \ + left_shift, GetTensorData(input1), GetTensorDims(input1), \ + input1_offset, input1_multiplier, input1_shift, \ + GetTensorData(input2), GetTensorDims(input2), input2_offset, \ + input2_multiplier, input2_shift, output_offset, output_multiplier, \ + output_shift, output_activation_min, output_activation_max, \ + GetTensorData(output), GetTensorDims(output)); + + if (kernel_type == kReference) { + TF_LITE_ADD(reference_ops); + } else { + TF_LITE_ADD(optimized_ops); + } +#undef TF_LITE_ADD +} + +template +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast(node->builtin_data); + + TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); + TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + if (output->type == kTfLiteFloat32) { + EvalAddFloat(context, node, params, input1, input2, output); + } else if (output->type == kTfLiteUInt8) { + EvalAddQuantized(context, node, params, input1, input2, + output); + } else { + context->ReportError(context, + "Inputs and outputs not all float|unit8 types."); + return kTfLiteError; + } + + return kTfLiteOk; +} + +} // namespace add + +TfLiteRegistration* Register_ADD_REF() { + static TfLiteRegistration r = {nullptr, nullptr, add::Prepare, + add::Eval}; + return &r; +} + +TfLiteRegistration* Register_ADD_GENERIC_OPT() { + static TfLiteRegistration r = {nullptr, nullptr, add::Prepare, + add::Eval}; + return &r; +} + +TfLiteRegistration* Register_ADD_NEON_OPT() { + static TfLiteRegistration r = {nullptr, nullptr, add::Prepare, + add::Eval}; + return &r; +} + +TfLiteRegistration* Register_ADD() { +#ifdef USE_NEON + return Register_ADD_NEON_OPT(); +#else + return Register_ADD_GENERIC_OPT(); +#endif +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/add_test.cc b/tensorflow/contrib/lite/kernels/add_test.cc new file mode 100644 index 0000000000..8e12a837c4 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/add_test.cc @@ -0,0 +1,171 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class BaseAddOpModel : public SingleOpModel { + public: + BaseAddOpModel(const TensorData& input, const TensorData& output, + ActivationFunctionType activation_type) { + input1_ = AddInput(input); + input2_ = AddInput(input); + output_ = AddOutput(output); + SetBuiltinOp(BuiltinOperator_ADD, BuiltinOptions_AddOptions, + CreateAddOptions(builder_, activation_type).Union()); + BuildInterpreter({GetShape(input1_), GetShape(input2_)}); + } + + int input1() { return input1_; } + int input2() { return input2_; } + + protected: + int input1_; + int input2_; + int output_; +}; + +class FloatAddOpModel : public BaseAddOpModel { + public: + using BaseAddOpModel::BaseAddOpModel; + + std::vector GetOutput() { return ExtractVector(output_); } +}; + +class QuantizedAddOpModel : public BaseAddOpModel { + public: + using BaseAddOpModel::BaseAddOpModel; + + std::vector GetDequantizedOutput() { + return Dequantize(ExtractVector(output_), + GetScale(output_), GetZeroPoint(output_)); + } +}; + +// for quantized Add, the error shouldn't exceed 2*step +float GetTolerance(int min, int max) { + float kQuantizedStep = (max - min) / 255.0; + float kQuantizedTolerance = 2.0 * kQuantizedStep; + return kQuantizedTolerance; +} + +TEST(FloatAddOpModel, NoActivation) { + FloatAddOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE); + m.PopulateTensor(m.input1(), {-2.0, 0.2, 0.7, 0.8}); + m.PopulateTensor(m.input2(), {0.1, 0.2, 0.3, 0.5}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({-1.9, 0.4, 1.0, 1.3})); +} + +TEST(FloatAddOpModel, ActivationRELU1) { + FloatAddOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {}}, ActivationFunctionType_RELU1); + m.PopulateTensor(m.input1(), {-2.0, 0.2, 0.7, 0.8}); + m.PopulateTensor(m.input2(), {0.1, 0.2, 0.3, 0.5}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({-1.0, 0.4, 1.0, 1.0})); +} + +TEST(FloatAddOpModel, VariousInputShapes) { + std::vector> test_shapes = { + {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}}; + for (int i = 0; i < test_shapes.size(); ++i) { + FloatAddOpModel m({TensorType_FLOAT32, test_shapes[i]}, + {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE); + m.PopulateTensor(m.input1(), {-2.0, 0.2, 0.7, 0.8, 1.1, 2.0}); + m.PopulateTensor(m.input2(), {0.1, 0.2, 0.3, 0.5, 1.1, 0.1}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray({-1.9, 0.4, 1.0, 1.3, 2.2, 2.1})) + << "With shape number " << i; + } +} + +TEST(QuantizedAddOpModel, QuantizedTestsNoActivation) { + float kQuantizedTolerance = GetTolerance(-1.0, 1.0); + std::vector> inputs1 = { + {0.1, 0.2, 0.3, 0.4}, {-0.8, 0.2, 0.4, 0.7}, {-0.8, 0.2, 0.7, 0.3}}; + std::vector> inputs2 = { + {0.6, 0.4, 0.3, 0.1}, {0.6, 0.4, 0.5, -0.8}, {0.6, 0.4, -0.8, 0.5}}; + std::vector> results = { + {0.7, 0.6, 0.6, 0.5}, {-0.2, 0.6, 0.9, -0.1}, {-0.2, 0.6, -0.1, 0.8}}; + for (int i = 0; i < inputs1.size(); ++i) { + QuantizedAddOpModel m({TensorType_UINT8, {1, 2, 2, 1}, -1.0, 1.0}, + {TensorType_UINT8, {}, -1.0, 1.0}, + ActivationFunctionType_NONE); + m.QuantizeAndPopulate(m.input1(), inputs1[i]); + m.QuantizeAndPopulate(m.input2(), inputs2[i]); + m.Invoke(); + 
EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear( + results[i], kQuantizedTolerance))) + << "With test number " << i; + } +} + +TEST(QuantizedAddOpModel, QuantizedTestsActivationRELU1) { + float kQuantizedTolerance = GetTolerance(-1.0, 1.0); + std::vector> inputs1 = {{-0.8, 0.2, 0.9, 0.7}, + {-0.8, 0.2, 0.7, 0.3}}; + std::vector> inputs2 = {{0.6, 0.4, 0.9, -0.8}, + {0.6, 0.4, -0.8, 0.5}}; + std::vector> results = {{-0.2, 0.6, 1.0, -0.1}, + {-0.2, 0.6, -0.1, 0.8}}; + for (int i = 0; i < inputs1.size(); ++i) { + QuantizedAddOpModel m({TensorType_UINT8, {1, 2, 2, 1}, -1.0, 1.0}, + {TensorType_UINT8, {}, -1.0, 1.0}, + ActivationFunctionType_RELU1); + m.QuantizeAndPopulate(m.input1(), inputs1[i]); + m.QuantizeAndPopulate(m.input2(), inputs2[i]); + m.Invoke(); + EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear( + results[i], kQuantizedTolerance))) + << "With test number " << i; + } +} + +TEST(QuantizedAddOpModel, QuantizedVariousInputShapes) { + float kQuantizedTolerance = GetTolerance(-3.0, 3.0); + std::vector> test_shapes = { + {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}}; + for (int i = 0; i < test_shapes.size(); ++i) { + QuantizedAddOpModel m({TensorType_UINT8, test_shapes[i], -3.0, 3.0}, + {TensorType_UINT8, {}, -3.0, 3.0}, + ActivationFunctionType_NONE); + m.QuantizeAndPopulate(m.input1(), {-2.0, 0.2, 0.7, 0.8, 1.1, 2.0}); + m.QuantizeAndPopulate(m.input2(), {0.1, 0.3, 0.3, 0.5, 1.1, 0.1}); + m.Invoke(); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear({-1.9, 0.5, 1.0, 1.3, 2.2, 2.1}, + kQuantizedTolerance))) + << "With shape number " << i; + } +} + +} // namespace +} // namespace tflite +int main(int argc, char** argv) { + // On Linux, add: tflite::LogToStderr(); + tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/basic_rnn.cc b/tensorflow/contrib/lite/kernels/basic_rnn.cc new file mode 100644 index 0000000000..3cee43c68b --- /dev/null +++ b/tensorflow/contrib/lite/kernels/basic_rnn.cc @@ -0,0 +1,161 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/activation_functor.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace rnn { + +constexpr int kInputTensor = 0; +constexpr int kWeightsTensor = 1; +constexpr int kRecurrentWeightsTensor = 2; +constexpr int kBiasTensor = 3; +constexpr int KHiddenStateTensor = 0; +constexpr int kOutputTensor = 1; + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + // Check we have all the inputs and outputs we need. 
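+  // The node is expected to carry four inputs -- the input activations, the
+  // input-to-hidden weights, the recurrent (hidden-to-hidden) weights and the
+  // bias -- plus two outputs: the persistent hidden state and the output.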
+ TF_LITE_ENSURE_EQ(context, node->inputs->size, 4); + TF_LITE_ENSURE_EQ(context, node->outputs->size, 2); + + TfLiteTensor* input = &context->tensors[node->inputs->data[kInputTensor]]; + TfLiteTensor* input_weights = + &context->tensors[node->inputs->data[kWeightsTensor]]; + TfLiteTensor* recurrent_weights = + &context->tensors[node->inputs->data[kRecurrentWeightsTensor]]; + TfLiteTensor* bias = &context->tensors[node->inputs->data[kBiasTensor]]; + + // Check all the parameters of tensor match within themselves and match the + // input configuration. + const int batch_size = input->dims->data[0]; + const int num_units = input_weights->dims->data[0]; + TF_LITE_ASSERT_EQ(input->dims->data[1], input_weights->dims->data[1]); + TF_LITE_ASSERT_EQ(input_weights->dims->data[0], bias->dims->data[0]); + TF_LITE_ASSERT_EQ(recurrent_weights->dims->data[0], bias->dims->data[0]); + TF_LITE_ASSERT_EQ(recurrent_weights->dims->data[1], bias->dims->data[0]); + + TfLiteTensor* hidden_state = + &context->tensors[node->outputs->data[KHiddenStateTensor]]; + TfLiteTensor* output = &context->tensors[node->outputs->data[kOutputTensor]]; + + // Resize state. + TfLiteIntArray* hidden_state_size_array = TfLiteIntArrayCreate(2); + hidden_state_size_array->data[0] = batch_size; + hidden_state_size_array->data[1] = num_units; + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, hidden_state, + hidden_state_size_array)); + + // Mark hidden state as a persistent tensor. + hidden_state->allocation_type = kTfLiteArenaRwPersistent; + + // Resize output. + TfLiteIntArray* output_size_array = TfLiteIntArrayCreate(2); + output_size_array->data[0] = batch_size; + output_size_array->data[1] = num_units; + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, output, + output_size_array)); + + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast(node->builtin_data); + + TfLiteTensor* input = &context->tensors[node->inputs->data[kInputTensor]]; + TfLiteTensor* input_weights = + &context->tensors[node->inputs->data[kWeightsTensor]]; + TfLiteTensor* recurrent_weights = + &context->tensors[node->inputs->data[kRecurrentWeightsTensor]]; + TfLiteTensor* bias = &context->tensors[node->inputs->data[kBiasTensor]]; + TfLiteTensor* hidden_state = + &context->tensors[node->outputs->data[KHiddenStateTensor]]; + TfLiteTensor* output = &context->tensors[node->outputs->data[kOutputTensor]]; + + // Initialize the pointer bias. + const float* bias_ptr = bias->data.f; + + const int batch_size = input->dims->data[0]; + const int num_units = input_weights->dims->data[0]; + const int input_size = input->dims->data[1]; + const int input_weights_stride = input_weights->dims->data[1]; + const int recurrent_weights_stride = recurrent_weights->dims->data[1]; + + // For each batch + for (int b = 0; b < batch_size; b++) { + // Initialize the pointer to input, output and bias. + const float* input_ptr_batch = input->data.f + b * input_size; + float* output_ptr_batch = output->data.f + b * num_units; + float* hidden_state_ptr_batch = hidden_state->data.f + b * num_units; + + // Initialize input_weights and recurrent_weights. 
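+    // input_weights is laid out as [num_units, input_size] and
+    // recurrent_weights as [num_units, num_units]; the per-row strides below
+    // advance one output unit's coefficients at a time.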
+ const float* input_weights_ptr = input_weights->data.f; + const float* recurrent_weights_ptr = recurrent_weights->data.f; + + // Output = bias + for (int o = 0; o < num_units; o++) { + output_ptr_batch[o] = bias_ptr[o]; + } + + // Output += input * input_weights + for (int o = 0; o < num_units; o++) { + for (int i = 0; i < input_size; i++) { + output_ptr_batch[o] += input_ptr_batch[i] * input_weights_ptr[i]; + } + input_weights_ptr += input_weights_stride; + } + + // Output += recurrent_weights * hidden_state + for (int o = 0; o < num_units; o++) { + for (int h = 0; h < num_units; h++) { + output_ptr_batch[o] += + hidden_state_ptr_batch[h] * recurrent_weights_ptr[h]; + } + recurrent_weights_ptr += recurrent_weights_stride; + } + + // Output = activation(Output) and update hidden_state + for (int o = 0; o < num_units; o++) { + output_ptr_batch[o] = + (ActivationFunctor(params->activation))(output_ptr_batch[o]); + hidden_state_ptr_batch[o] = output_ptr_batch[o]; + } + } + + return kTfLiteOk; +} + +} // namespace rnn + +TfLiteRegistration* Register_RNN() { + static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr, + rnn::Prepare, rnn::Eval}; + return &r; +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/basic_rnn_test.cc b/tensorflow/contrib/lite/kernels/basic_rnn_test.cc new file mode 100644 index 0000000000..dfa75655bc --- /dev/null +++ b/tensorflow/contrib/lite/kernels/basic_rnn_test.cc @@ -0,0 +1,267 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Unit test for TFLite RNN op. 
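+// In scalar form the kernel under test computes, for each batch row:
+//   output[o] = act(bias[o] + sum_i input[i] * W[o][i] + sum_h state[h] * U[o][h])
+//   state[o]  = output[o]
+// where act is the fused activation (RELU in the black-box test below).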
+ +#include +#include + +#include +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +static float rnn_input[] = { + 0.23689353, 0.285385, 0.037029743, -0.19858193, -0.27569133, + 0.43773448, 0.60379338, 0.35562468, -0.69424844, -0.93421471, + -0.87287879, 0.37144363, -0.62476718, 0.23791671, 0.40060222, + 0.1356622, -0.99774903, -0.98858172, -0.38952237, -0.47685933, + 0.31073618, 0.71511042, -0.63767755, -0.31729108, 0.33468103, + 0.75801885, 0.30660987, -0.37354088, 0.77002847, -0.62747043, + -0.68572164, 0.0069220066, 0.65791464, 0.35130811, 0.80834007, + -0.61777675, -0.21095741, 0.41213346, 0.73784804, 0.094794154, + 0.47791874, 0.86496925, -0.53376222, 0.85315156, 0.10288584, + 0.86684, -0.011186242, 0.10513687, 0.87825835, 0.59929144, + 0.62827742, 0.18899453, 0.31440187, 0.99059987, 0.87170351, + -0.35091716, 0.74861872, 0.17831337, 0.2755419, 0.51864719, + 0.55084288, 0.58982027, -0.47443086, 0.20875752, -0.058871567, + -0.66609079, 0.59098077, 0.73017097, 0.74604273, 0.32882881, + -0.17503482, 0.22396147, 0.19379807, 0.29120302, 0.077113032, + -0.70331609, 0.15804303, -0.93407321, 0.40182066, 0.036301374, + 0.66521823, 0.0300982, -0.7747041, -0.02038002, 0.020698071, + -0.90300065, 0.62870288, -0.23068321, 0.27531278, -0.095755219, + -0.712036, -0.17384434, -0.50593495, -0.18646687, -0.96508682, + 0.43519354, 0.14744234, 0.62589407, 0.1653645, -0.10651493, + -0.045277178, 0.99032974, -0.88255352, -0.85147917, 0.28153265, + 0.19455957, -0.55479527, -0.56042433, 0.26048636, 0.84702539, + 0.47587705, -0.074295521, -0.12287641, 0.70117295, 0.90532446, + 0.89782166, 0.79817224, 0.53402734, -0.33286154, 0.073485017, + -0.56172788, -0.044897556, 0.89964068, -0.067662835, 0.76863563, + 0.93455386, -0.6324693, -0.083922029}; + +static float rnn_golden_output[] = { + 0.496726, 0, 0.965996, 0, 0.0584254, 0, + 0, 0.12315, 0, 0, 0.612266, 0.456601, + 0, 0.52286, 1.16099, 0.0291232, + + 0, 0, 0.524901, 0, 0, 0, + 0, 1.02116, 0, 1.35762, 0, 0.356909, + 0.436415, 0.0355727, 0, 0, + + 0, 0, 0, 0.262335, 0, 0, + 0, 1.33992, 0, 2.9739, 0, 0, + 1.31914, 2.66147, 0, 0, + + 0.942568, 0, 0, 0, 0.025507, 0, + 0, 0, 0.321429, 0.569141, 1.25274, 1.57719, + 0.8158, 1.21805, 0.586239, 0.25427, + + 1.04436, 0, 0.630725, 0, 0.133801, 0.210693, + 0.363026, 0, 0.533426, 0, 1.25926, 0.722707, + 0, 1.22031, 1.30117, 0.495867, + + 0.222187, 0, 0.72725, 0, 0.767003, 0, + 0, 0.147835, 0, 0, 0, 0.608758, + 0.469394, 0.00720298, 0.927537, 0, + + 0.856974, 0.424257, 0, 0, 0.937329, 0, + 0, 0, 0.476425, 0, 0.566017, 0.418462, + 0.141911, 0.996214, 1.13063, 0, + + 0.967899, 0, 0, 0, 0.0831304, 0, + 0, 1.00378, 0, 0, 0, 1.44818, + 1.01768, 0.943891, 0.502745, 0, + + 0.940135, 0, 0, 0, 0, 0, + 0, 2.13243, 0, 0.71208, 0.123918, 1.53907, + 1.30225, 1.59644, 0.70222, 0, + + 0.804329, 0, 0.430576, 0, 0.505872, 0.509603, + 0.343448, 0, 0.107756, 0.614544, 1.44549, 1.52311, + 0.0454298, 0.300267, 0.562784, 0.395095, + + 0.228154, 0, 0.675323, 0, 1.70536, 0.766217, + 0, 0, 0, 0.735363, 0.0759267, 1.91017, + 0.941888, 0, 0, 0, + + 0, 0, 1.5909, 0, 0, 0, + 0, 0.5755, 0, 0.184687, 0, 1.56296, + 0.625285, 0, 0, 0, + + 0, 0, 0.0857888, 0, 0, 0, + 0, 0.488383, 0.252786, 0, 0, 0, + 1.02817, 1.85665, 0, 0, + + 0.00981836, 0, 1.06371, 0, 0, 0, + 0, 0, 0, 0.290445, 0.316406, 0, + 0.304161, 1.25079, 
0.0707152, 0, + + 0.986264, 0.309201, 0, 0, 0, 0, + 0, 1.64896, 0.346248, 0, 0.918175, 0.78884, + 0.524981, 1.92076, 2.07013, 0.333244, + + 0.415153, 0.210318, 0, 0, 0, 0, + 0, 2.02616, 0, 0.728256, 0.84183, 0.0907453, + 0.628881, 3.58099, 1.49974, 0 +}; + +class RNNOpModel : public SingleOpModel { + public: + RNNOpModel(int batches, int units, int size) + : batches_(batches), units_(units), input_size_(size) { + input_ = AddInput(TensorType_FLOAT32); + weights_ = AddInput(TensorType_FLOAT32); + recurrent_weights_ = AddInput(TensorType_FLOAT32); + bias_ = AddInput(TensorType_FLOAT32); + hidden_state_ = AddOutput(TensorType_FLOAT32); + output_ = AddOutput(TensorType_FLOAT32); + SetBuiltinOp( + BuiltinOperator_RNN, BuiltinOptions_RNNOptions, + CreateRNNOptions(builder_, ActivationFunctionType_RELU).Union()); + BuildInterpreter({{batches_, input_size_}, + {units_, input_size_}, + {units_, units_}, + {units_}}); + } + + void SetBias(std::initializer_list f) { PopulateTensor(bias_, f); } + + void SetWeights(std::initializer_list f) { + PopulateTensor(weights_, f); + } + + void SetRecurrentWeights(std::initializer_list f) { + PopulateTensor(recurrent_weights_, f); + } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + void SetInput(int offset, float* begin, float* end) { + PopulateTensor(input_, offset, begin, end); + } + + void ResetHiddenState() { + const int zero_buffer_size = units_ * batches_; + std::unique_ptr zero_buffer(new float[zero_buffer_size]); + memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float)); + PopulateTensor(hidden_state_, 0, zero_buffer.get(), + zero_buffer.get() + zero_buffer_size); + } + + std::vector GetOutput() { return ExtractVector(output_); } + + int input_size() { return input_size_; } + int num_units() { return units_; } + int num_batches() { return batches_; } + + private: + int input_; + int weights_; + int recurrent_weights_; + int bias_; + int hidden_state_; + int output_; + + int batches_; + int units_; + int input_size_; +}; + +TEST(FullyConnectedOpTest, BlackBoxTest) { + RNNOpModel rnn(2, 16, 8); + rnn.SetWeights( + {0.461459, 0.153381, 0.529743, -0.00371218, 0.676267, -0.211346, + 0.317493, 0.969689, -0.343251, 0.186423, 0.398151, 0.152399, + 0.448504, 0.317662, 0.523556, -0.323514, 0.480877, 0.333113, + -0.757714, -0.674487, -0.643585, 0.217766, -0.0251462, 0.79512, + -0.595574, -0.422444, 0.371572, -0.452178, -0.556069, -0.482188, + -0.685456, -0.727851, 0.841829, 0.551535, -0.232336, 0.729158, + -0.00294906, -0.69754, 0.766073, -0.178424, 0.369513, -0.423241, + 0.548547, -0.0152023, -0.757482, -0.85491, 0.251331, -0.989183, + 0.306261, -0.340716, 0.886103, -0.0726757, -0.723523, -0.784303, + 0.0354295, 0.566564, -0.485469, -0.620498, 0.832546, 0.697884, + -0.279115, 0.294415, -0.584313, 0.548772, 0.0648819, 0.968726, + 0.723834, -0.0080452, -0.350386, -0.272803, 0.115121, -0.412644, + -0.824713, -0.992843, -0.592904, -0.417893, 0.863791, -0.423461, + -0.147601, -0.770664, -0.479006, 0.654782, 0.587314, -0.639158, + 0.816969, -0.337228, 0.659878, 0.73107, 0.754768, -0.337042, + 0.0960841, 0.368357, 0.244191, -0.817703, -0.211223, 0.442012, + 0.37225, -0.623598, -0.405423, 0.455101, 0.673656, -0.145345, + -0.511346, -0.901675, -0.81252, -0.127006, 0.809865, -0.721884, + 0.636255, 0.868989, -0.347973, -0.10179, -0.777449, 0.917274, + 0.819286, 0.206218, -0.00785118, 0.167141, 0.45872, 0.972934, + -0.276798, 0.837861, 0.747958, -0.0151566, -0.330057, -0.469077, + 0.277308, 0.415818}); + + 
rnn.SetBias({0.065691948, -0.69055247, 0.1107955, -0.97084129, -0.23957068, + -0.23566568, -0.389184, 0.47481549, -0.4791103, 0.29931796, + 0.10463274, 0.83918178, 0.37197268, 0.61957061, 0.3956964, + -0.37609905}); + + rnn.SetRecurrentWeights({0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1}); + + rnn.ResetHiddenState(); + const int input_sequence_size = sizeof(rnn_input) / sizeof(float) / + (rnn.input_size() * rnn.num_batches()); + + for (int i = 0; i < input_sequence_size; i++) { + float* batch_start = rnn_input + i * rnn.input_size(); + float* batch_end = batch_start + rnn.input_size(); + rnn.SetInput(0, batch_start, batch_end); + rnn.SetInput(rnn.input_size(), batch_start, batch_end); + + rnn.Invoke(); + + float* golden_start = rnn_golden_output + i * rnn.num_units(); + float* golden_end = golden_start + rnn.num_units(); + std::vector expected; + expected.insert(expected.end(), golden_start, golden_end); + expected.insert(expected.end(), golden_start, golden_end); + + EXPECT_THAT(rnn.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); + } +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + // On Linux, add: tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/concatenation.cc b/tensorflow/contrib/lite/kernels/concatenation.cc new file mode 100644 index 0000000000..9e7a1233da --- /dev/null +++ b/tensorflow/contrib/lite/kernels/concatenation.cc @@ -0,0 +1,200 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace concatenation { + +// This file has two implementation of Concatenation. +enum KernelType { + kReference, + kGenericOptimized, +}; + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + auto* params = + reinterpret_cast(node->builtin_data); + int axis = params->axis; + int num_inputs = node->inputs->size; + + // The number of dimensions of the input tensors must match, and all + // dimensions except 'axis' must be equal. + TfLiteTensor* t0 = &context->tensors[node->inputs->data[0]]; + TfLiteType input_type = t0->type; + TF_LITE_ENSURE(context, axis >= 0); + TF_LITE_ENSURE(context, axis < t0->dims->size); + + // TODO(ahentz): These are limitations of our implementation that could be + // removed with a bit of effort. + TF_LITE_ENSURE(context, t0->dims->size <= 4); + TF_LITE_ENSURE_EQ(context, params->activation, kTfLiteActNone); + TF_LITE_ENSURE(context, + input_type == kTfLiteFloat32 || input_type == kTfLiteUInt8); + + // Output dimensions will match input dimensions, except 'axis', which + // will be the sum of inputs + int sum_axis = t0->dims->data[axis]; + for (int i = 1; i < num_inputs; ++i) { + TfLiteTensor* t = &context->tensors[node->inputs->data[i]]; + TF_LITE_ENSURE_EQ(context, t->dims->size, t0->dims->size); + TF_LITE_ENSURE_EQ(context, t->type, input_type); + if (input_type == kTfLiteUInt8) { + TF_LITE_ENSURE_EQ(context, t->params.zero_point, t0->params.zero_point); + TF_LITE_ENSURE_EQ(context, t->params.scale, t0->params.scale); + } + for (int d = 0; d < t0->dims->size; ++d) { + if (d == axis) { + sum_axis += t->dims->data[axis]; + } else { + TF_LITE_ENSURE_EQ(context, t->dims->data[d], t0->dims->data[d]); + } + } + } + + TfLiteIntArray* output_size = TfLiteIntArrayCreate(t0->dims->size); + for (int d = 0; d < t0->dims->size; ++d) { + output_size->data[d] = (d == axis) ? 
sum_axis : t0->dims->data[d]; + } + + TfLiteTensor* output = &context->tensors[node->outputs->data[0]]; + TF_LITE_ENSURE_EQ(context, output->type, input_type); + if (input_type == kTfLiteUInt8) { + TF_LITE_ENSURE_EQ(context, output->params.zero_point, + t0->params.zero_point); + TF_LITE_ENSURE_EQ(context, output->params.scale, t0->params.scale); + } + + return context->ResizeTensor(context, output, output_size); +} + +template +class VectorOfInputs { + public: + VectorOfInputs(const TfLiteContext& context, const TfLiteIntArray& inputs) { + int num_inputs = inputs.size; + + all_data_.reserve(num_inputs); + all_dims_.reserve(num_inputs); + all_dims_ptr_.reserve(num_inputs); + + for (int i = 0; i < num_inputs; ++i) { + TfLiteTensor* input = &context.tensors[inputs.data[i]]; + all_data_.push_back(GetTensorData(input)); + all_dims_.push_back(GetTensorDims(input)); + } + + // Taking the pointer from inside a std::vector is only OK if the vector is + // never modified, so we populate all_dims in the previous loop and then we + // are free to grab iterators here. + for (int i = 0; i < num_inputs; ++i) { + all_dims_ptr_.push_back(&all_dims_[i]); + } + } + const T* const* data() const { return all_data_.data(); } + const Dims<4>* const* dims() const { return all_dims_ptr_.data(); } + + private: + std::vector all_data_; + std::vector> all_dims_; + std::vector*> all_dims_ptr_; +}; + +template +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + auto* params = + reinterpret_cast(node->builtin_data); + + TfLiteTensor* output = &context->tensors[node->outputs->data[0]]; + +// TODO(ahentz): Creating 'all_inputs' below is not very efficient. We should +// allocate and populate these during Prepare(). +// TODO(ycling): Activation function parameter is ignored. For now we dont have +// a model with a Concatenation with fused activation function. +#define TF_LITE_CONCATENATION(type, scalar) \ + VectorOfInputs all_inputs(*context, *node->inputs); \ + type::Concatenation( \ + RemapDim(NumDimensions(output), params->axis), all_inputs.data(), \ + all_inputs.dims(), node->inputs->size, GetTensorData(output), \ + GetTensorDims(output)) + + switch (output->type) { // Already know in/outtypes are same. + case kTfLiteFloat32: + if (kernel_type == kReference) { + TF_LITE_CONCATENATION(reference_ops, float); + } else { + TF_LITE_CONCATENATION(optimized_ops, float); + } + break; + case kTfLiteUInt8: + if (kernel_type == kReference) { + TF_LITE_CONCATENATION(reference_ops, uint8_t); + } else { + TF_LITE_CONCATENATION(optimized_ops, uint8_t); + } + break; + default: + context->ReportError(context, + "Only float32 and uint8 are currently supported."); + return kTfLiteError; + } + +#undef TF_LITE_CONCATENATION + + return kTfLiteOk; +} + +#undef TF_LITE_MACRO_DISPATCH + +} // namespace concatenation + +TfLiteRegistration* Register_CONCATENATION_REF() { + static TfLiteRegistration r = { + nullptr, nullptr, concatenation::Prepare, + concatenation::Eval}; + return &r; +} + +TfLiteRegistration* Register_CONCATENATION_GENERIC_OPT() { + static TfLiteRegistration r = { + nullptr, nullptr, concatenation::Prepare, + concatenation::Eval}; + return &r; +} + +TfLiteRegistration* Register_CONCATENATION() { + // TODO(ahentz): It turns out the two versions of Concatenation are almost + // identical, so we should consider removing one. 
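+  // For now the generic optimized variant is the default registration.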
+ return Register_CONCATENATION_GENERIC_OPT(); +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/concatenation_test.cc b/tensorflow/contrib/lite/kernels/concatenation_test.cc new file mode 100644 index 0000000000..94e5b2acdc --- /dev/null +++ b/tensorflow/contrib/lite/kernels/concatenation_test.cc @@ -0,0 +1,162 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class BaseConcatenationOpModel : public SingleOpModel { + public: + // TODO(ahentz): Also test different activation types, axis, input + // dimensions. + BaseConcatenationOpModel(const TensorData& input_template, int axis, + int num_inputs) { + std::vector> all_input_shapes; + for (int i = 0; i < num_inputs; ++i) { + all_input_shapes.push_back(input_template.shape); + AddInput(input_template); + } + output_ = AddOutput({input_template.type, /*shape=*/{}, input_template.min, + input_template.max}); + SetBuiltinOp( + BuiltinOperator_CONCATENATION, BuiltinOptions_ConcatenationOptions, + CreateConcatenationOptions(builder_, axis, ActivationFunctionType_NONE) + .Union()); + BuildInterpreter(all_input_shapes); + } + + protected: + int output_; +}; + +class ConcatenationOpModel : public BaseConcatenationOpModel { + public: + using BaseConcatenationOpModel::BaseConcatenationOpModel; + void SetInput(int index, std::initializer_list data) { + PopulateTensor(index, data); + } + std::vector GetOutput() { return ExtractVector(output_); } +}; + +class QuantizedConcatenationOpModel : public BaseConcatenationOpModel { + public: + using BaseConcatenationOpModel::BaseConcatenationOpModel; + void SetInput(int index, std::initializer_list data) { + QuantizeAndPopulate(index, data); + } + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetDequantizedOutput() { + return Dequantize(ExtractVector(output_), + GetScale(output_), GetZeroPoint(output_)); + } +}; + +TEST(ConcatenationOpTest, ThreeDimensionalOneInput) { + ConcatenationOpModel m0({TensorType_FLOAT32, {2, 1, 2}}, /*axis=*/1, + /*num_inputs=*/1); + m0.SetInput(0, {1.0f, 3.0f, 4.0f, 7.0f}); + m0.Invoke(); + EXPECT_THAT(m0.GetOutput(), ElementsAreArray({1, 3, 4, 7})); +} + +TEST(ConcatenationOpTest, OneTrivialInput) { + ConcatenationOpModel m0({TensorType_FLOAT32, {1}}, /*axis=*/0, + /*num_inputs=*/1); + m0.SetInput(0, {5.0f}); + m0.Invoke(); + EXPECT_THAT(m0.GetOutput(), ::testing::ElementsAre(5)); +} + +TEST(ConcatenationOpTest, TwoDimensionalOneInput) { + ConcatenationOpModel m0({TensorType_FLOAT32, {2, 3}}, /*axis=*/0, + /*num_inputs=*/1); + m0.SetInput(0, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 
6.0f}); + m0.Invoke(); + EXPECT_THAT(m0.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6})); +} + +TEST(ConcatenationOpTest, TwoInputsTwoAxis) { + // We will concatenate two tensors along different dimensions. + auto tensor0 = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + auto tensor1 = {7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f}; + + ConcatenationOpModel m0({TensorType_FLOAT32, {2, 3}}, /*axis=*/0, + /*num_inputs=*/2); + m0.SetInput(0, tensor0); + m0.SetInput(1, tensor1); + m0.Invoke(); + EXPECT_THAT(m0.GetOutput(), + ElementsAreArray({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12})); + + ConcatenationOpModel m1({TensorType_FLOAT32, {2, 3}}, /*axis=*/1, + /*num_inputs=*/2); + m1.SetInput(0, tensor0); + m1.SetInput(1, tensor1); + m1.Invoke(); + EXPECT_THAT(m1.GetOutput(), + ElementsAreArray({1, 2, 3, 7, 8, 9, 4, 5, 6, 10, 11, 12})); +} + +TEST(ConcatenationOpTest, FourInputs) { + ConcatenationOpModel m0({TensorType_FLOAT32, {2, 1, 2}}, /*axis=*/2, + /*num_inputs=*/4); + m0.SetInput(0, {1.0f, 3.0f, 4.0f, 7.0f}); + m0.SetInput(1, {1.1f, 3.1f, 4.1f, 7.1f}); + m0.SetInput(2, {1.2f, 3.2f, 4.2f, 7.2f}); + m0.SetInput(3, {1.3f, 3.3f, 4.3f, 7.3f}); + m0.Invoke(); + EXPECT_THAT(m0.GetOutput(), + ElementsAreArray({ + 1.0f, 3.0f, 1.1f, 3.1f, 1.2f, 3.2f, 1.3f, 3.3f, // + 4.0f, 7.0f, 4.1f, 7.1f, 4.2f, 7.2f, 4.3f, 7.3f, // + })); +} + +TEST(ConcatenationOpTest, FourInputsQuantized) { + QuantizedConcatenationOpModel m0({TensorType_UINT8, {2, 1, 2}, -12.7, 12.8}, + /*axis=*/2, + /*num_inputs=*/4); + + m0.SetInput(0, {1.0f, 3.0f, 4.0f, 7.0f}); + m0.SetInput(1, {1.1f, 3.1f, 4.1f, 7.1f}); + m0.SetInput(2, {1.2f, 3.2f, 4.2f, 7.2f}); + m0.SetInput(3, {1.3f, 3.3f, 4.3f, 7.3f}); + m0.Invoke(); + EXPECT_THAT(m0.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear({ + 1.0f, 3.0f, 1.1f, 3.1f, 1.2f, 3.2f, 1.3f, 3.3f, // + 4.0f, 7.0f, 4.1f, 7.1f, 4.2f, 7.2f, 4.3f, 7.3f, // + }))); + EXPECT_THAT(m0.GetOutput(), ElementsAreArray({ + 137, 157, 138, 158, 139, 159, 140, 160, // + 167, 197, 168, 198, 169, 199, 170, 200, // + })); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + // On Linux, add: tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/conv.cc b/tensorflow/contrib/lite/kernels/conv.cc new file mode 100644 index 0000000000..c75c04baea --- /dev/null +++ b/tensorflow/contrib/lite/kernels/conv.cc @@ -0,0 +1,425 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/gemm_support.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" +#include "tensorflow/contrib/lite/kernels/padding.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace conv { + +// This file has three implementation of Conv. +enum KernelType { + kReference, + kGenericOptimized, // Neon-free + kNeonOptimized, +}; + +struct OpData { + // IDs are the arbitrary identifiers used by TF Lite to identify and access + // memory buffers. + int im2col_id; + int hwcn_weights_id; + + TfLitePaddingValues padding; + // The scaling factor from input to output (aka the 'real multiplier') can + // be represented as a fixed point multipler plus a left shift. + int32_t output_multiplier; + int output_shift; + // The range of the fused activation layer. For example for kNone and + // uint8_t these would be 0 and 255. + int32_t output_activation_min; + int32_t output_activation_max; + // Indexes are the offset to the memory buffer in the array used to keep track + // of the allocated temporaries. + int32_t im2col_index; + int32_t hwcn_weights_index; + bool need_hwcn_weights; + bool have_weights_been_transposed; + bool need_im2col; +}; + +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + // This is a builtin op, so we don't use the contents in 'buffer', if any. + // Instead, we allocate a new object to use as scratch space for im2col, and + // to carry information from Prepare() to Eval(). + auto* data = new OpData; + context->AddTensors(context, 1, &data->im2col_id); + context->AddTensors(context, 1, &data->hwcn_weights_id); + gemm_support::IncrementUsageCounter(context); + return data; +} + +void Free(TfLiteContext* context, void* buffer) { + gemm_support::DecrementUsageCounter(context); + delete reinterpret_cast(buffer); +} + +// Naive implementation of transpose for floats. Could be optimized to be more +// cache friendly, but for now it's a one-time cost on first run, and we would +// prefer to remove the need to do this at all eventually. 
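+// Concretely, the filter is treated as an [output_channels, filter_height *
+// filter_width * input_depth] matrix and written out transposed
+// (output[j][i] = input[i][j]), giving the flattened HWCN layout that the
+// multithreaded EigenTensor convolution expects.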
+void TransposeFloatTensor(TfLiteTensor* input, TfLiteTensor* output) { + const int rows = output->dims->data[1]; + const int cols = output->dims->data[0]; + const float* input_data = GetTensorData(input); + float* output_data = GetTensorData(output); + for (int i = 0; i < rows; ++i) { + for (int j = 0; j < cols; ++j) { + const float in_value = input_data[i * cols + j]; + output_data[j * rows + i] = in_value; + } + } +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast(node->builtin_data); + OpData* data = reinterpret_cast(node->user_data); + + bool hasBias = node->inputs->size == 3; + // Check number of inputs/outputs + TF_LITE_ENSURE(context, hasBias || node->inputs->size == 2); + TF_LITE_ENSURE_EQ(context, node->outputs->size, 1); + TfLiteTensor* output = &context->tensors[node->outputs->data[0]]; + TfLiteTensor* input = &context->tensors[node->inputs->data[0]]; + TfLiteTensor* filter = &context->tensors[node->inputs->data[1]]; + // Check dimensionality of input, filter + TF_LITE_ENSURE_EQ(context, input->dims->size, 4); + TF_LITE_ENSURE_EQ(context, filter->dims->size, 4); + // Check input channels matching filter + TF_LITE_ENSURE_EQ(context, input->dims->data[3], filter->dims->data[3]); + + // Check types. (We assume that UINT8 refers to quantized tensors) + TfLiteType data_type = input->type; + TF_LITE_ENSURE(context, + data_type == kTfLiteFloat32 || data_type == kTfLiteUInt8); + TF_LITE_ENSURE_EQ(context, output->type, data_type); + TF_LITE_ENSURE_EQ(context, filter->type, data_type); + + TfLiteTensor* bias = nullptr; + + // TODO(ahentz): At this point the optimized versions require 'bias'. We can + // either change that or document that convolution requires it. + TF_LITE_ENSURE(context, hasBias); + + if (hasBias) { + bias = &context->tensors[node->inputs->data[2]]; + if (data_type == kTfLiteUInt8) { + TF_LITE_ENSURE_EQ(context, bias->type, kTfLiteInt32); + TF_LITE_ENSURE_EQ(context, bias->params.zero_point, 0); + } else { + TF_LITE_ENSURE_EQ(context, bias->type, data_type); + } + TF_LITE_ENSURE_EQ(context, bias->dims->size, 1); + TF_LITE_ENSURE_EQ(context, bias->dims->data[0], filter->dims->data[0]); + } + + int channels_out = filter->dims->data[0]; + int width = input->dims->data[2]; + int height = input->dims->data[1]; + int filter_width = filter->dims->data[2]; + int filter_height = filter->dims->data[1]; + int batches = input->dims->data[0]; + + // Matching GetWindowedOutputSize in TensorFlow. + auto padding = params->padding; + auto computeOutSize = [padding](int imageSize, int filterSize, + int stride) -> int { + return padding == kTfLitePaddingSame + ? (imageSize + stride - 1) / stride + : padding == kTfLitePaddingValid + ? (imageSize - filterSize + stride) / stride + : 0; + }; + + int outWidth = computeOutSize(width, filter_width, params->stride_width); + int outHeight = computeOutSize(height, filter_height, params->stride_height); + + data->padding.height = + ComputePadding(params->stride_height, height, filter_height, outHeight); + data->padding.width = + ComputePadding(params->stride_width, width, filter_width, outWidth); + + TF_LITE_ENSURE(context, hasBias); + + // Note that quantized inference requires that all tensors have their + // parameters set. This is usually done during quantized training. 
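+  // The real multiplier below is the combined rescaling factor from the
+  // input/filter scales to the output scale (roughly input_scale *
+  // filter_scale / output_scale, kept below 1.0), which is then encoded as a
+  // fixed-point multiplier plus a shift for integer-only arithmetic.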
+ if (data_type != kTfLiteFloat32) { + double real_multiplier = 0.0; + TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler( + context, input, filter, bias, output, &real_multiplier)); + QuantizeMultiplierSmallerThanOne(real_multiplier, &data->output_multiplier, + &data->output_shift); + CalculateActivationRangeUint8(params->activation, output, + &data->output_activation_min, + &data->output_activation_max); + } + + TfLiteIntArray* output_size = TfLiteIntArrayCreate(4); + output_size->data[0] = batches; + output_size->data[1] = outHeight; + output_size->data[2] = outWidth; + output_size->data[3] = channels_out; + auto output_status = context->ResizeTensor(context, output, output_size); + + if (output_status != kTfLiteOk) return output_status; + + // We don't always need to allocate im2col. It is only used in some versions + // of the optimized Conv. This test just mimics something that happens inside + // optimized_ops.h, in order to avoid a DCHECK(!im2col_data). + data->need_im2col = + (params->stride_width != 1 || params->stride_height != 1 || + filter_width != 1 || filter_height != 1); + // If we're using the optimized multithreaded EigenTensor implementation of + // convolution, it expects the filter weights to be transposed compared to + // the normal TF Lite buffer format. Typical TF Lite weights are + // [filter_count, filter_height, filter_width, input_depth], but for the float + // implementation we need them as [filter_height, filter_width, input_depth, + // filter_count]. We get to that format by transposing, and create a temporary + // buffer to store the results. + // This path is only used for float processing, so only create the buffer if + // we're running with that data type. + data->need_hwcn_weights = (data_type == kTfLiteFloat32); + + int temporaries_count = 0; + if (data->need_im2col) { + data->im2col_index = temporaries_count; + ++temporaries_count; + } + if (data->need_hwcn_weights) { + data->hwcn_weights_index = temporaries_count; + ++temporaries_count; + } + + TfLiteIntArrayFree(node->temporaries); + node->temporaries = TfLiteIntArrayCreate(temporaries_count); + + if (data->need_im2col) { + node->temporaries->data[data->im2col_index] = data->im2col_id; + + TfLiteIntArray* im2col_size = TfLiteIntArrayCreate(4); + + int input_depth = input->dims->data[3]; + im2col_size->data[0] = output_size->data[0]; + im2col_size->data[1] = output_size->data[1]; + im2col_size->data[2] = output_size->data[2]; + im2col_size->data[3] = input_depth * filter_height * filter_width; + + TfLiteTensor* im2col = + &context->tensors[node->temporaries->data[data->im2col_index]]; + im2col->type = data_type; + im2col->allocation_type = kTfLiteArenaRw; + auto im2col_status = context->ResizeTensor(context, im2col, im2col_size); + if (im2col_status != kTfLiteOk) return im2col_status; + } + + if (data->need_hwcn_weights) { + node->temporaries->data[data->hwcn_weights_index] = data->hwcn_weights_id; + TfLiteIntArray* hwcn_weights_size = TfLiteIntArrayCreate(2); + + // Because we're treating the filter weights as a matrix when we do the + // transpose, we allocate the buffer with a two-dimensional shape, where one + // dimension is the number of elements in each filter, and the second is the + // total number of filters. 
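+    // i.e. hwcn_weights ends up with shape
+    // {filter_height * filter_width * input_depth, channels_out}.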
+ int input_depth = input->dims->data[3]; + hwcn_weights_size->data[0] = (filter_height * filter_width * input_depth); + hwcn_weights_size->data[1] = channels_out; + + TfLiteTensor* hwcn_weights = + &context->tensors[node->temporaries->data[data->hwcn_weights_index]]; + hwcn_weights->type = data_type; + hwcn_weights->allocation_type = kTfLiteDynamic; + // Make sure we release any previous allocations before we reallocate. + // TODO(petewarden): Persistent arenas would be a better fit for this, but + // they aren't fully implemented yet. + if (hwcn_weights->data.raw) { + free(hwcn_weights->data.raw); + hwcn_weights->data.raw = nullptr; + } + auto hwcn_weights_status = + context->ResizeTensor(context, hwcn_weights, hwcn_weights_size); + if (hwcn_weights_status != kTfLiteOk) return hwcn_weights_status; + hwcn_weights->data.raw = static_cast(malloc(hwcn_weights->bytes)); + + // TODO(petewarden): If Resize() is called when the size hasn't actually + // changed, this will do extra redundant work. + data->have_weights_been_transposed = false; + } + + return kTfLiteOk; +} + +template +void EvalQuantized(TfLiteContext* context, TfLiteNode* node, + TfLiteConvParams* params, OpData* data, TfLiteTensor* input, + TfLiteTensor* filter, TfLiteTensor* bias, + TfLiteTensor* im2col, TfLiteTensor* hwcn_weights, + TfLiteTensor* output) { + gemmlowp::GemmContext* gemm_context = gemm_support::GetFromContext(context); + + auto input_offset = -input->params.zero_point; + auto filter_offset = -filter->params.zero_point; + auto output_offset = output->params.zero_point; + + if (kernel_type == kReference) { + reference_ops::Conv( + GetTensorData(input), GetTensorDims(input), input_offset, + GetTensorData(filter), GetTensorDims(filter), filter_offset, + GetTensorData(bias), GetTensorDims(bias), params->stride_width, + params->stride_height, data->padding.width, data->padding.height, + output_offset, data->output_multiplier, data->output_shift, + data->output_activation_min, data->output_activation_max, + GetTensorData(output), GetTensorDims(output), + GetTensorData(im2col), GetTensorDims(im2col), gemm_context); + } else { + optimized_ops::Conv( + GetTensorData(input), GetTensorDims(input), input_offset, + GetTensorData(filter), GetTensorDims(filter), filter_offset, + GetTensorData(bias), GetTensorDims(bias), params->stride_width, + params->stride_height, data->padding.width, data->padding.height, + output_offset, data->output_multiplier, data->output_shift, + data->output_activation_min, data->output_activation_max, + GetTensorData(output), GetTensorDims(output), + GetTensorData(im2col), GetTensorDims(im2col), gemm_context); + } +} + +template +void EvalFloat(TfLiteContext* context, TfLiteNode* node, + TfLiteConvParams* params, OpData* data, TfLiteTensor* input, + TfLiteTensor* filter, TfLiteTensor* bias, TfLiteTensor* im2col, + TfLiteTensor* hwcn_weights, TfLiteTensor* output) { + float output_activation_min, output_activation_max; + CalculateActivationRangeFloat(params->activation, &output_activation_min, + &output_activation_max); + + const float* filter_data; + if (data->need_hwcn_weights) { + filter_data = GetTensorData(hwcn_weights); + } else { + filter_data = GetTensorData(filter); + } + + if (kernel_type == kReference) { + reference_ops::Conv( + GetTensorData(input), GetTensorDims(input), filter_data, + GetTensorDims(filter), GetTensorData(bias), GetTensorDims(bias), + params->stride_width, params->stride_height, data->padding.width, + data->padding.height, output_activation_min, output_activation_max, + 
GetTensorData<float>(output), GetTensorDims(output),
+        GetTensorData<float>(im2col), GetTensorDims(im2col));
+  } else {
+    multithreaded_ops::Conv(
+        GetTensorData<float>(input), GetTensorDims(input), filter_data,
+        GetTensorDims(filter), GetTensorData<float>(bias), GetTensorDims(bias),
+        params->stride_width, params->stride_height, data->padding.width,
+        data->padding.height, params->padding, output_activation_min,
+        output_activation_max, GetTensorData<float>(output),
+        GetTensorDims(output), GetTensorData<float>(im2col),
+        GetTensorDims(im2col));
+  }
+}
+
+template <KernelType kernel_type>
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+  auto* params = reinterpret_cast<TfLiteConvParams*>(node->builtin_data);
+  OpData* data = reinterpret_cast<OpData*>(node->user_data);
+
+  TfLiteTensor* output = &context->tensors[node->outputs->data[0]];
+  TfLiteTensor* input = &context->tensors[node->inputs->data[0]];
+  TfLiteTensor* filter = &context->tensors[node->inputs->data[1]];
+  bool hasBias = node->inputs->size == 3;
+  TfLiteTensor* bias =
+      hasBias ? &context->tensors[node->inputs->data[2]] : nullptr;
+  TfLiteTensor* im2col =
+      data->need_im2col
+          ? &context->tensors[node->temporaries->data[data->im2col_index]]
+          : nullptr;
+  TfLiteTensor* hwcn_weights =
+      data->need_hwcn_weights
+          ? &context->tensors[node->temporaries->data[data->hwcn_weights_index]]
+          : nullptr;
+
+  if (data->need_hwcn_weights && !data->have_weights_been_transposed) {
+    TransposeFloatTensor(filter, hwcn_weights);
+    data->have_weights_been_transposed = true;
+  }
+
+  // TODO(aselle): Consider whether float conv and quantized conv should be
+  // separate ops to avoid dispatch overhead here.
+  switch (input->type) {  // Already know in/out types are the same.
+    case kTfLiteFloat32:
+      EvalFloat<kernel_type>(context, node, params, data, input, filter, bias,
+                             im2col, hwcn_weights, output);
+      break;
+    case kTfLiteUInt8:
+      EvalQuantized<kernel_type>(context, node, params, data, input, filter,
+                                 bias, im2col, hwcn_weights, output);
+      break;
+    default:
+      context->ReportError(context, "Type not currently supported.");
+      return kTfLiteError;
+  }
+  return kTfLiteOk;
+}
+
+}  // namespace conv
+
+TfLiteRegistration* Register_CONVOLUTION_REF() {
+  static TfLiteRegistration r = {conv::Init, conv::Free, conv::Prepare,
+                                 conv::Eval<conv::kReference>};
+  return &r;
+}
+
+TfLiteRegistration* Register_CONVOLUTION_GENERIC_OPT() {
+  static TfLiteRegistration r = {conv::Init, conv::Free, conv::Prepare,
+                                 conv::Eval<conv::kGenericOptimized>};
+  return &r;
+}
+
+TfLiteRegistration* Register_CONVOLUTION_NEON_OPT() {
+  static TfLiteRegistration r = {conv::Init, conv::Free, conv::Prepare,
+                                 conv::Eval<conv::kNeonOptimized>};
+  return &r;
+}
+
+TfLiteRegistration* Register_CONV_2D() {
+#ifdef USE_NEON
+  return Register_CONVOLUTION_NEON_OPT();
+#else
+  return Register_CONVOLUTION_GENERIC_OPT();
+#endif
+}
+
+}  // namespace builtin
+}  // namespace ops
+}  // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/conv_test.cc b/tensorflow/contrib/lite/kernels/conv_test.cc
new file mode 100644
index 0000000000..18d7a31d59
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/conv_test.cc
@@ -0,0 +1,440 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include + +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class BaseConvolutionOpModel : public SingleOpModel { + public: + // TODO(ahentz): Also test different activation types, bias, padding types, + // stride values. + BaseConvolutionOpModel( + const TensorData& input, const TensorData& filter, + const TensorData& output, int stride_width = 2, int stride_height = 2, + enum Padding padding = Padding_VALID, + enum ActivationFunctionType activation = ActivationFunctionType_NONE) { + input_ = AddInput(input); + filter_ = AddInput(filter); + + int bias_size = GetShape(filter_)[0]; + if (input.type == TensorType_FLOAT32) { + bias_ = AddInput({TensorType_FLOAT32, {bias_size}}); + } else { + // This is a quantized version. The scale of 'bias' depends on the scales + // of input and filter. Supposedly this is correctly set during quantized + // training. + auto bias_scale = GetScale(input_) * GetScale(filter_); + TensorData bias{TensorType_INT32, {bias_size}, 0, 0, bias_scale}; + bias_ = AddInput(bias); + } + + output_ = AddOutput(output); + if (input.type != TensorType_FLOAT32) { + // The following is required by quantized inference. It is the unittest's + // responsibility to make sure the output scale falls into the correct + // range. + CHECK_LT(GetScale(input_) * GetScale(filter_), GetScale(output_)); + } + + SetBuiltinOp(BuiltinOperator_CONV_2D, BuiltinOptions_Conv2DOptions, + CreateConv2DOptions(builder_, padding, stride_width, + stride_height, activation) + .Union()); + + BuildInterpreter({GetShape(input_), GetShape(filter_), GetShape(bias_)}); + } + + protected: + int input_; + int filter_; + int bias_; + int output_; +}; + +class ConvolutionOpModel : public BaseConvolutionOpModel { + public: + using BaseConvolutionOpModel::BaseConvolutionOpModel; + + void SetFilter(std::initializer_list f) { PopulateTensor(filter_, f); } + + void SetBias(std::initializer_list f) { PopulateTensor(bias_, f); } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + std::vector GetOutput() { return ExtractVector(output_); } +}; + +TEST(ConvolutionOpTest, SimpleTestFloat32) { + ConvolutionOpModel m({TensorType_FLOAT32, {2, 2, 4, 1}}, + {TensorType_FLOAT32, {3, 2, 2, 1}}, + {TensorType_FLOAT32, {}}); + + m.SetInput({ + // First batch + 1, 1, 1, 1, // row = 1 + 2, 2, 2, 2, // row = 2 + // Second batch + 1, 2, 3, 4, // row = 1 + 1, 2, 3, 4, // row = 2 + }); + m.SetFilter({ + 1, 2, 3, 4, // first 2x2 filter + -1, 1, -1, 1, // second 2x2 filter + -1, -1, 1, 1, // third 2x2 filter + }); + m.SetBias({1, 2, 3}); + + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAreArray({ + 18, 2, 5, // first batch, left + 18, 2, 5, // first batch, right + 17, 4, 3, // second batch, left + 37, 4, 3, // second batch, right + })); +} + +TEST(ConvolutionOpTest, SimpleTestFloat32WithAnisotropicStrides) { + ConvolutionOpModel m({TensorType_FLOAT32, {1, 3, 6, 1}}, + {TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {}}, + /*stride_width=*/3, /*stride_height=*/1); + m.SetInput({ + 3, 2, 1, -1, -2, -3, // + 4, 3, 2, -2, -3, -4, // + 5, 4, 3, -3, -4, -5, // + }); + 
m.SetFilter({ + 1, 2, // + 3, 4, // + }); + m.SetBias({-1}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({ + 30, -24, // + 40, -34, // + })); +} + +TEST(ConvolutionOpTest, HandCalculatedFloat32) { + const int depth = 1; + const int image_width = 4; + const int image_height = 3; + const int image_batch_count = 1; + const int filter_size = 3; + const int filter_count = 1; + const int stride_width = 1; + const int stride_height = 1; + const Padding padding = Padding_SAME; + ConvolutionOpModel m( + {TensorType_FLOAT32, + {image_batch_count, image_height, image_width, depth}}, + {TensorType_FLOAT32, {depth, filter_size, filter_size, filter_count}}, + {TensorType_FLOAT32, {}}, stride_width, stride_height, padding); + + // The image matrix is: + // | 1 | 2 | 3 | 4 | + // | 5 | 6 | 7 | 8 | + // | 9 | 10 | 11 | 12 | + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + // The filter matrix is: + // | 1 | 4 | 7 | + // | 2 | 5 | 8 | + // | 3 | 6 | 9 | + m.SetFilter({1, 4, 7, 2, 5, 8, 3, 6, 9}); + // No bias for this test. + m.SetBias({0}); + + m.Invoke(); + // We're sliding the 3x3 filter across the 3x4 image, with accesses outside + // the input set to zero because we're using the 'SAME' padding mode. + // The calculations behind the expected output are: + // (1*0)+(4*0)+(7*0)+(2*0)+(5*1)+(8*2)+(3*0)+(6*5)+(9*6)=105 + // (1*0)+(4*0)+(7*0)+(2*1)+(5*2)+(8*3)+(3*5)+(6*6)+(9*7)=150 + // (1*0)+(4*0)+(7*0)+(2*2)+(5*3)+(8*4)+(3*6)+(6*7)+(9*8)=183 + // (1*0)+(4*0)+(7*0)+(2*3)+(5*4)+(8*0)+(3*7)+(6*8)+(9*0)=95 + // (1*0)+(4*1)+(7*2)+(2*0)+(5*5)+(8*6)+(3*0)+(6*9)+(9*10)=235 + // (1*1)+(4*2)+(7*3)+(2*5)+(5*6)+(8*7)+(3*9)+(6*10)+(9*11)=312 + // (1*2)+(4*3)+(7*4)+(2*6)+(5*7)+(8*8)+(3*10)+(6*11)+(9*12)=357 + // (1*3)+(4*4)+(7*0)+(2*7)+(5*8)+(8*0)+(3*11)+(6*12)+(9*0)=178 + // (1*0)+(4*5)+(7*6)+(2*0)+(5*9)+(8*10)+(3*0)+(6*0)+(9*0)=187 + // (1*5)+(4*6)+(7*7)+(2*9)+(5*10)+(8*11)+(3*0)+(6*0)+(9*0)=234 + // (1*6)+(4*7)+(7*8)+(2*10)+(5*11)+(8*12)+(3*0)+(6*0)+(9*0)=261 + // (1*7)+(4*11)+(7*0)+(2*8)+(5*12)+(8*0)+(3*0)+(6*0)+(9*0)=121 + // This means we should end up with this matrix: + // | 105 | 150 | 183 | 95 | + // | 235 | 312 | 357 | 178 | + // | 187 | 234 | 261 | 121 | + EXPECT_THAT(m.GetOutput(), ElementsAreArray({105, 150, 183, 95, 235, 312, 357, + 178, 187, 234, 261, 121})); +} + +TEST(ConvolutionOpTest, HandCalculatedWithBiasFloat32) { + const int depth = 1; + const int image_width = 4; + const int image_height = 3; + const int image_batch_count = 1; + const int filter_size = 3; + const int filter_count = 1; + const int stride_width = 1; + const int stride_height = 1; + const Padding padding = Padding_SAME; + ConvolutionOpModel m( + {TensorType_FLOAT32, + {image_batch_count, image_height, image_width, depth}}, + {TensorType_FLOAT32, {depth, filter_size, filter_size, filter_count}}, + {TensorType_FLOAT32, {}}, stride_width, stride_height, padding); + + // The image matrix is: + // | 1 | 2 | 3 | 4 | + // | 5 | 6 | 7 | 8 | + // | 9 | 10 | 11 | 12 | + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + // The filter matrix is: + // | 1 | 4 | 7 | + // | 2 | 5 | 8 | + // | 3 | 6 | 9 | + m.SetFilter({1, 4, 7, 2, 5, 8, 3, 6, 9}); + // Bias is | 10 |. + m.SetBias({10}); + + m.Invoke(); + // We're sliding the 3x3 filter across the 3x4 image, with accesses outside + // the input set to zero because we're using the 'SAME' padding mode. 
+ // The calculations behind the expected output are: + // (1*0)+(4*0)+(7*0)+(2*0)+(5*1)+(8*2)+(3*0)+(6*5)+(9*6)+10=115 + // (1*0)+(4*0)+(7*0)+(2*1)+(5*2)+(8*3)+(3*5)+(6*6)+(9*7)+10=160 + // (1*0)+(4*0)+(7*0)+(2*2)+(5*3)+(8*4)+(3*6)+(6*7)+(9*8)+10=193 + // (1*0)+(4*0)+(7*0)+(2*3)+(5*4)+(8*0)+(3*7)+(6*8)+(9*0)+10=105 + // (1*0)+(4*1)+(7*2)+(2*0)+(5*5)+(8*6)+(3*0)+(6*9)+(9*10)+10=245 + // (1*1)+(4*2)+(7*3)+(2*5)+(5*6)+(8*7)+(3*9)+(6*10)+(9*11)+10=322 + // (1*2)+(4*3)+(7*4)+(2*6)+(5*7)+(8*8)+(3*10)+(6*11)+(9*12)+10=367 + // (1*3)+(4*4)+(7*0)+(2*7)+(5*8)+(8*0)+(3*11)+(6*12)+(9*0)+10=188 + // (1*0)+(4*5)+(7*6)+(2*0)+(5*9)+(8*10)+(3*0)+(6*0)+(9*0)+10=197 + // (1*5)+(4*6)+(7*7)+(2*9)+(5*10)+(8*11)+(3*0)+(6*0)+(9*0)+10=244 + // (1*6)+(4*7)+(7*8)+(2*10)+(5*11)+(8*12)+(3*0)+(6*0)+(9*0)+10=271 + // (1*7)+(4*11)+(7*0)+(2*8)+(5*12)+(8*0)+(3*0)+(6*0)+(9*0)+10=131 + // This means we should end up with this matrix: + // | 115 | 160 | 193 | 105 | + // | 245 | 322 | 367 | 188 | + // | 197 | 244 | 271 | 131 | + EXPECT_THAT(m.GetOutput(), ElementsAreArray({115, 160, 193, 105, 245, 322, + 367, 188, 197, 244, 271, 131})); +} + +TEST(ConvolutionOpTest, HandCalculatedWithReluFloat32) { + const int depth = 1; + const int image_width = 4; + const int image_height = 3; + const int image_batch_count = 1; + const int filter_size = 3; + const int filter_count = 1; + const int stride_width = 1; + const int stride_height = 1; + const Padding padding = Padding_SAME; + ConvolutionOpModel m( + {TensorType_FLOAT32, + {image_batch_count, image_height, image_width, depth}}, + {TensorType_FLOAT32, {depth, filter_size, filter_size, filter_count}}, + {TensorType_FLOAT32, {}}, stride_width, stride_height, padding, + ActivationFunctionType_RELU); + + // The image matrix is: + // | 1 | 2 | 3 | 4 | + // | 5 | 6 | 7 | 8 | + // | 9 | 10 | 11 | 12 | + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + // The filter matrix is: + // | 1 | 4 | 7 | + // | 2 | 5 | 8 | + // | 3 | 6 | 9 | + m.SetFilter({1, 4, 7, 2, 5, 8, 3, 6, 9}); + // Bias is | -200 |. + m.SetBias({-200}); + + m.Invoke(); + // We're sliding the 3x3 filter across the 3x4 image, with accesses outside + // the input set to zero because we're using the 'SAME' padding mode. + // The calculations behind the expected output are: + // (1*0)+(4*0)+(7*0)+(2*0)+(5*1)+(8*2)+(3*0)+(6*5)+(9*6)-200=-95 + // (1*0)+(4*0)+(7*0)+(2*1)+(5*2)+(8*3)+(3*5)+(6*6)+(9*7)-200=-50 + // (1*0)+(4*0)+(7*0)+(2*2)+(5*3)+(8*4)+(3*6)+(6*7)+(9*8)-200=-17 + // (1*0)+(4*0)+(7*0)+(2*3)+(5*4)+(8*0)+(3*7)+(6*8)+(9*0)-200=-105 + // (1*0)+(4*1)+(7*2)+(2*0)+(5*5)+(8*6)+(3*0)+(6*9)+(9*10)-200=35 + // (1*1)+(4*2)+(7*3)+(2*5)+(5*6)+(8*7)+(3*9)+(6*10)+(9*11)-200=112 + // (1*2)+(4*3)+(7*4)+(2*6)+(5*7)+(8*8)+(3*10)+(6*11)+(9*12)-200=157 + // (1*3)+(4*4)+(7*0)+(2*7)+(5*8)+(8*0)+(3*11)+(6*12)+(9*0)-200=-22 + // (1*0)+(4*5)+(7*6)+(2*0)+(5*9)+(8*10)+(3*0)+(6*0)+(9*0)-200=-13 + // (1*5)+(4*6)+(7*7)+(2*9)+(5*10)+(8*11)+(3*0)+(6*0)+(9*0)-200=34 + // (1*6)+(4*7)+(7*8)+(2*10)+(5*11)+(8*12)+(3*0)+(6*0)+(9*0)-200=61 + // (1*7)+(4*11)+(7*0)+(2*8)+(5*12)+(8*0)+(3*0)+(6*0)+(9*0)-200=-79 + // All negative values are gated to zero by the Relu activation function. 
+  // This means we should end up with this matrix:
+  // |   0 |   0 |   0 |   0 |
+  // |  35 | 112 | 157 |   0 |
+  // |   0 |  34 |  61 |   0 |
+  EXPECT_THAT(m.GetOutput(),
+              ElementsAreArray({0, 0, 0, 0, 35, 112, 157, 0, 0, 34, 61, 0}));
+}
+
+TEST(ConvolutionOpTest, HandCalculatedValidFloat32) {
+  const int depth = 1;
+  const int image_width = 4;
+  const int image_height = 3;
+  const int image_batch_count = 1;
+  const int filter_size = 3;
+  const int filter_count = 1;
+  const int stride_width = 1;
+  const int stride_height = 1;
+  const Padding padding = Padding_VALID;
+  ConvolutionOpModel m(
+      {TensorType_FLOAT32,
+       {image_batch_count, image_height, image_width, depth}},
+      {TensorType_FLOAT32, {depth, filter_size, filter_size, filter_count}},
+      {TensorType_FLOAT32, {}}, stride_width, stride_height, padding);
+
+  // The image matrix is:
+  // |  1 |  2 |  3 |  4 |
+  // |  5 |  6 |  7 |  8 |
+  // |  9 | 10 | 11 | 12 |
+  m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+  // The filter matrix is:
+  // | 1 | 4 | 7 |
+  // | 2 | 5 | 8 |
+  // | 3 | 6 | 9 |
+  m.SetFilter({1, 4, 7, 2, 5, 8, 3, 6, 9});
+  // No bias for this test.
+  m.SetBias({0});
+
+  m.Invoke();
+  // We're sliding the 3x3 filter across the 3x4 image, with no accesses
+  // outside the input because we're using the 'VALID' padding mode, giving a
+  // 2x1 output.
+  // The calculations behind the expected output are:
+  // (1*1)+(4*2)+(7*3)+(2*5)+(5*6)+(8*7)+(3*9)+(6*10)+(9*11)=312
+  // (1*2)+(4*3)+(7*4)+(2*6)+(5*7)+(8*8)+(3*10)+(6*11)+(9*12)=357
+  // This means we should end up with this matrix:
+  // | 312 | 357 |
+  EXPECT_THAT(m.GetOutput(), ElementsAreArray({312, 357}));
+}
+
+class QuantizedConvolutionOpModel : public BaseConvolutionOpModel {
+ public:
+  using BaseConvolutionOpModel::BaseConvolutionOpModel;
+
+  void SetInput(std::initializer_list<float> data) {
+    QuantizeAndPopulate<uint8_t>(input_, data);
+  }
+
+  void SetFilter(std::initializer_list<float> data) {
+    QuantizeAndPopulate<uint8_t>(filter_, data);
+  }
+
+  void SetBias(std::initializer_list<float> data) {
+    QuantizeAndPopulate<int32_t>(bias_, data);
+  }
+
+  std::vector<uint8_t> GetOutput() { return ExtractVector<uint8_t>(output_); }
+  std::vector<float> GetDequantizedOutput() {
+    return Dequantize<uint8_t>(ExtractVector<uint8_t>(output_),
+                               GetScale(output_), GetZeroPoint(output_));
+  }
+};
+
+// In these tests we set the input and output scales so that the results
+// match exactly the 'non-quantized' version.
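As a reading aid for the quantized tests that follow: these models use the affine mapping real_value = scale * (quantized_value - zero_point), with scale and zero point derived by the test harness from the min/max pair passed in each TensorData. The helper below is only an illustration of that mapping, not part of test_util; with the output range [-127, 128] the scale works out to 1.0 and the zero point to 127, which is why the float result 18 appears below as the raw quantized value 145.

#include <algorithm>
#include <cmath>
#include <cstdint>

// real_value = scale * (quantized_value - zero_point); for a uint8 range
// [min, max], scale = (max - min) / 255 and zero_point ~= round(-min / scale).
uint8_t QuantizeSketch(float real_value, float scale, int32_t zero_point) {
  const int32_t q =
      static_cast<int32_t>(std::round(real_value / scale)) + zero_point;
  return static_cast<uint8_t>(std::min<int32_t>(255, std::max<int32_t>(0, q)));
}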
+TEST(ConvolutionOpTest, SimpleTestQuantized) { + QuantizedConvolutionOpModel m({TensorType_UINT8, {2, 2, 4, 1}, -63.5, 64}, + {TensorType_UINT8, {3, 2, 2, 1}, -63.5, 64}, + {TensorType_UINT8, {}, -127, 128}); + m.SetInput({ + // First batch + 1, 1, 1, 1, // row = 1 + 2, 2, 2, 2, // row = 2 + // Second batch + 1, 2, 3, 4, // row = 1 + 1, 2, 3, 4, // row = 2 + }); + m.SetFilter({ + 1, 2, 3, 4, // first 2x2 filter + -1, 1, -1, 1, // second 2x2 filter + -1, -1, 1, 1, // third 2x2 filter + }); + m.SetBias({1, 2, 3}); + + m.Invoke(); + + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear( + { + 18, 2, 5, // first batch, left + 18, 2, 5, // first batch, right + 17, 4, 3, // second batch, left + 37, 4, 3, // second batch, right + }, + 1e-5))); + // For good measure, let's also verify the quantized values: + EXPECT_THAT(m.GetOutput(), ElementsAreArray({ + 145, 129, 132, // + 145, 129, 132, // + 144, 131, 130, // + 164, 131, 130, // + })); +} + +TEST(ConvolutionOpTest, SimpleTestQuantizedWithAnisotropicStrides) { + QuantizedConvolutionOpModel m({TensorType_UINT8, {1, 3, 6, 1}, -63.5, 64}, + {TensorType_UINT8, {1, 2, 2, 1}, -63.5, 64}, + {TensorType_UINT8, {}, -127, 128}, + /*stride_width=*/3, /*stride_height=*/1); + m.SetInput({ + 3, 2, 1, -1, -2, -3, // + 4, 3, 2, -2, -3, -4, // + 5, 4, 3, -3, -4, -5, // + }); + m.SetFilter({ + 1, 2, // + 3, 4, // + }); + m.SetBias({-1}); + m.Invoke(); + EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear({ + 30, -24, // + 40, -34, // + }))); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({ + 157, 103, // + 167, 93, // + })); +} +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + // On Linux, add: tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/depthwise_conv.cc b/tensorflow/contrib/lite/kernels/depthwise_conv.cc new file mode 100644 index 0000000000..15dbfe08c8 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/depthwise_conv.cc @@ -0,0 +1,289 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h" +#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" +#include "tensorflow/contrib/lite/kernels/padding.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace depthwise_conv { + +constexpr int kInputTensor = 0; +constexpr int kFilterTensor = 1; +constexpr int kBiasTensor = 2; +constexpr int kOutputTensor = 0; + +// This file has three implementation of DepthwiseConv. +enum KernelType { + kReference, + kGenericOptimized, // Neon-free + kNeonOptimized, +}; + +struct OpData { + TfLitePaddingValues padding; + // The scaling factor from input to output (aka the 'real multiplier') can + // be represented as a fixed point multipler plus a left shift. + int32_t output_multiplier; + int output_shift; + // The range of the fused activation layer. For example for kNone and + // uint8_t these would be 0 and 255. + int32_t output_activation_min; + int32_t output_activation_max; +}; + +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + // This is a builtin op, so we don't use the contents in 'buffer', if any. + // Instead, we allocate a new object to carry information from Prepare() to + // Eval(). + return new OpData; +} + +void Free(TfLiteContext* context, void* buffer) { + delete reinterpret_cast(buffer); +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + auto* params = + reinterpret_cast(node->builtin_data); + OpData* data = reinterpret_cast(node->user_data); + + // TODO(ahentz): use could use GetOptionalInputTensor() here, but we need to + // decide whether we are OK with optional tensors being completely absent, as + // opposed to having -1 as their index. + bool hasBias = NumInputs(node) == 3; + + TF_LITE_ENSURE(context, hasBias || NumInputs(node) == 2); + TfLiteTensor* input = GetInput(context, node, kInputTensor); + TfLiteTensor* filter = GetInput(context, node, kFilterTensor); + TfLiteTensor* bias = nullptr; + + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4); + TF_LITE_ENSURE_EQ(context, NumDimensions(filter), 4); + + // The parameter 'depth_multiplier' is redundant, so we check here to make + // sure it is consistent with the given dimensions. 
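The consistency enforced by the check that follows is that the filter's last dimension (the number of output channels) equals input_depth * depth_multiplier: output channel ic * depth_multiplier + m reads only input channel ic. The function below is an illustrative sketch of that per-channel wiring, not the kernel's API; its names and layout arguments are assumptions, but the mapping is consistent with the hand-checkable values in the depthwise tests later in this patch.

// input_patch: one filter_h x filter_w window of the input, laid out
// [y][x][input_channel]; filter: laid out [y][x][output_channel] with
// output_depth = input_depth * depth_multiplier (illustrative sketch only).
float DepthwiseAccumulateSketch(const float* input_patch, const float* filter,
                                int filter_h, int filter_w, int input_depth,
                                int depth_multiplier, int ic, int m) {
  const int output_depth = input_depth * depth_multiplier;
  const int oc = ic * depth_multiplier + m;  // fed only by input channel ic
  float acc = 0.0f;
  for (int y = 0; y < filter_h; ++y) {
    for (int x = 0; x < filter_w; ++x) {
      acc += input_patch[(y * filter_w + x) * input_depth + ic] *
             filter[(y * filter_w + x) * output_depth + oc];
    }
  }
  return acc;
}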
+ TF_LITE_ENSURE_EQ(context, + params->depth_multiplier * SizeOfDimension(input, 3), + SizeOfDimension(filter, 3)); + + const TfLiteType data_type = input->type; + TF_LITE_ENSURE(context, + data_type == kTfLiteFloat32 || data_type == kTfLiteUInt8); + TF_LITE_ENSURE_EQ(context, output->type, data_type); + TF_LITE_ENSURE_EQ(context, filter->type, data_type); + + if (hasBias) { + bias = GetInput(context, node, kBiasTensor); + if (data_type == kTfLiteUInt8) { + TF_LITE_ENSURE_EQ(context, bias->type, kTfLiteInt32); + TF_LITE_ENSURE_EQ(context, bias->params.zero_point, 0); + } else { + TF_LITE_ENSURE_EQ(context, bias->type, data_type); + } + TF_LITE_ENSURE_EQ(context, NumDimensions(bias), 1); + TF_LITE_ENSURE_EQ(context, SizeOfDimension(filter, 3), + SizeOfDimension(bias, 0)); + } + + int channels_out = SizeOfDimension(filter, 3); + int width = SizeOfDimension(input, 2); + int height = SizeOfDimension(input, 1); + int filter_width = SizeOfDimension(filter, 2); + int filter_height = SizeOfDimension(filter, 1); + int batches = SizeOfDimension(input, 0); + + // Matching GetWindowedOutputSize in TensorFlow. + auto padding = params->padding; + auto compute_out_size = [padding](int imageSize, int filterSize, + int stride) -> int { + return padding == kTfLitePaddingSame + ? (imageSize + stride - 1) / stride + : padding == kTfLitePaddingValid + ? (imageSize - filterSize + stride) / stride + : 0; + }; + + int out_width = compute_out_size(width, filter_width, params->stride_width); + int out_height = + compute_out_size(height, filter_height, params->stride_height); + + data->padding.height = + ComputePadding(params->stride_height, height, filter_height, out_height); + data->padding.width = + ComputePadding(params->stride_width, width, filter_width, out_width); + + // Note that quantized inference requires that all tensors have their + // parameters set. This is usually done during quantized training. 
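The quantized setup in the next block conventionally uses real_multiplier = input_scale * filter_scale / output_scale and then stores it as a 32-bit fixed-point multiplier plus a right shift, which is what output_multiplier and output_shift in OpData hold. The actual helpers live in quantization_util.h; the sketch below only illustrates the decomposition under that assumption.

#include <cmath>
#include <cstdint>

// Represent a multiplier in (0, 1) as q31 * 2^-31 * 2^-right_shift.
void QuantizeMultiplierSketch(double real_multiplier,
                              int32_t* quantized_multiplier, int* right_shift) {
  int exponent = 0;
  const double fraction = std::frexp(real_multiplier, &exponent);  // in [0.5, 1)
  *right_shift = -exponent;
  *quantized_multiplier =
      static_cast<int32_t>(std::round(fraction * (1LL << 31)));
  // real_multiplier ~= quantized_multiplier * 2^-31 * 2^-right_shift.
  // (The real helper also handles the corner case where rounding hits 2^31.)
}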
+ if (data_type != kTfLiteFloat32) { + double real_multiplier = 0.0; + TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler( + context, input, filter, bias, output, &real_multiplier)); + QuantizeMultiplierSmallerThanOne(real_multiplier, &data->output_multiplier, + &data->output_shift); + CalculateActivationRangeUint8(params->activation, output, + &data->output_activation_min, + &data->output_activation_max); + } + + TfLiteIntArray* outputSize = TfLiteIntArrayCreate(4); + outputSize->data[0] = batches; + outputSize->data[1] = out_height; + outputSize->data[2] = out_width; + outputSize->data[3] = channels_out; + return context->ResizeTensor(context, output, outputSize); +} + +template +void EvalFloat(TfLiteContext* context, TfLiteNode* node, + TfLiteDepthwiseConvParams* params, OpData* data, + TfLiteTensor* input, TfLiteTensor* filter, TfLiteTensor* bias, + TfLiteTensor* output) { + float output_activation_min, output_activation_max; + CalculateActivationRangeFloat(params->activation, &output_activation_min, + &output_activation_max); + + void (*depthwise_conv)(const float*, const Dims<4>&, const float*, + const Dims<4>&, const float*, const Dims<4>&, int, int, + int, int, int, float, float, float*, const Dims<4>&); + if (kernel_type == kReference) { + depthwise_conv = &reference_ops::DepthwiseConv; + } else { + depthwise_conv = &optimized_ops::DepthwiseConv; + } + + depthwise_conv( + GetTensorData(input), GetTensorDims(input), + GetTensorData(filter), GetTensorDims(filter), + GetTensorData(bias), GetTensorDims(bias), params->stride_width, + params->stride_height, data->padding.width, data->padding.height, + params->depth_multiplier, output_activation_min, output_activation_max, + GetTensorData(output), GetTensorDims(output)); +} + +template +void EvalQuantized(TfLiteContext* context, TfLiteNode* node, + TfLiteDepthwiseConvParams* params, OpData* data, + TfLiteTensor* input, TfLiteTensor* filter, + TfLiteTensor* bias, TfLiteTensor* output) { + auto input_offset = -input->params.zero_point; + auto filter_offset = -filter->params.zero_point; + auto output_offset = output->params.zero_point; + + void (*depthwise_conv)(const uint8*, const Dims<4>&, int32, const uint8*, + const Dims<4>&, int32, const int32*, const Dims<4>&, + int, int, int, int, int, int32, int32, int, int32, + int32, uint8*, const Dims<4>&); + if (kernel_type == kReference) { + depthwise_conv = &reference_ops::DepthwiseConv; + } else { + depthwise_conv = &optimized_ops::DepthwiseConv; + } + + depthwise_conv( + GetTensorData(input), GetTensorDims(input), input_offset, + GetTensorData(filter), GetTensorDims(filter), filter_offset, + GetTensorData(bias), GetTensorDims(bias), params->stride_width, + params->stride_height, data->padding.width, data->padding.height, + params->depth_multiplier, output_offset, data->output_multiplier, + data->output_shift, data->output_activation_min, + data->output_activation_max, GetTensorData(output), + GetTensorDims(output)); +} + +template +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + auto* params = + reinterpret_cast(node->builtin_data); + OpData* data = reinterpret_cast(node->user_data); + + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + TfLiteTensor* input = GetInput(context, node, kInputTensor); + TfLiteTensor* filter = GetInput(context, node, kFilterTensor); + TfLiteTensor* bias = + (NumInputs(node) == 3) ? 
GetInput(context, node, kBiasTensor) : nullptr; + + // TODO(aselle): Consider whether float conv and quantized conv should be + // separate ops to avoid dispatch overhead here. + switch (input->type) { // Already know in/out types are same. + case kTfLiteFloat32: + EvalFloat(context, node, params, data, input, filter, bias, + output); + break; + case kTfLiteUInt8: + EvalQuantized(context, node, params, data, input, filter, + bias, output); + break; + default: + context->ReportError(context, "Type not currently supported."); + return kTfLiteError; + } + return kTfLiteOk; +} + +} // namespace depthwise_conv + +TfLiteRegistration* Register_DEPTHWISE_CONVOLUTION_REF() { + static TfLiteRegistration r = { + depthwise_conv::Init, depthwise_conv::Free, depthwise_conv::Prepare, + depthwise_conv::Eval}; + return &r; +} + +TfLiteRegistration* Register_DEPTHWISE_CONVOLUTION_GENERIC_OPT() { + static TfLiteRegistration r = { + depthwise_conv::Init, depthwise_conv::Free, depthwise_conv::Prepare, + depthwise_conv::Eval}; + return &r; +} + +TfLiteRegistration* Register_DEPTHWISE_CONVOLUTION_NEON_OPT() { + static TfLiteRegistration r = { + depthwise_conv::Init, depthwise_conv::Free, depthwise_conv::Prepare, + depthwise_conv::Eval}; + return &r; +} + +TfLiteRegistration* Register_DEPTHWISE_CONV_2D() { +#ifdef USE_NEON + return Register_DEPTHWISE_CONVOLUTION_NEON_OPT(); +#else + return Register_DEPTHWISE_CONVOLUTION_GENERIC_OPT(); +#endif +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/depthwise_conv_test.cc b/tensorflow/contrib/lite/kernels/depthwise_conv_test.cc new file mode 100644 index 0000000000..39227b2811 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/depthwise_conv_test.cc @@ -0,0 +1,186 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class BaseDepthwiseConvolutionOpModel : public SingleOpModel { + public: + // TODO(ahentz): Also test different activation types, bias, padding types, + // stride values. + BaseDepthwiseConvolutionOpModel(const TensorData& input, + const TensorData& filter, + const TensorData& output) { + input_ = AddInput(input); + filter_ = AddInput(filter); + + int bias_size = GetShape(filter_)[3]; + if (input.type == TensorType_FLOAT32) { + bias_ = AddInput({TensorType_FLOAT32, {bias_size}}); + } else { + // This is a quantized version. The scale of 'bias' depends on the scales + // of input and filter. Supposedly this is correctly set during quantized + // training. 
+ auto bias_scale = GetScale(input_) * GetScale(filter_); + TensorData bias{TensorType_INT32, {bias_size}, 0, 0, bias_scale}; + bias_ = AddInput(bias); + } + + output_ = AddOutput(output); + if (input.type != TensorType_FLOAT32) { + // The following is required by quantized inference. It is the unittest's + // responsibility to make sure the output scale falls into the correct + // range. + CHECK_LT(GetScale(input_) * GetScale(filter_), GetScale(output_)); + } + + int input_depth = GetShape(input_)[3]; + int output_depth = GetShape(filter_)[3]; + int depth_mul = output_depth / input_depth; + + SetBuiltinOp( + BuiltinOperator_DEPTHWISE_CONV_2D, + BuiltinOptions_DepthwiseConv2DOptions, + CreateDepthwiseConv2DOptions(builder_, Padding_VALID, 1, 1, depth_mul, + ActivationFunctionType_NONE) + .Union()); + + BuildInterpreter({GetShape(input_), GetShape(filter_), GetShape(bias_)}); + } + + protected: + int input_; + int filter_; + int bias_; + int output_; +}; + +class DepthwiseConvolutionOpModel : public BaseDepthwiseConvolutionOpModel { + public: + using BaseDepthwiseConvolutionOpModel::BaseDepthwiseConvolutionOpModel; + + void SetFilter(std::initializer_list f) { PopulateTensor(filter_, f); } + + void SetBias(std::initializer_list f) { PopulateTensor(bias_, f); } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + std::vector GetOutput() { return ExtractVector(output_); } +}; + +TEST(DepthwiseConvolutionOpTest, SimpleTest) { + DepthwiseConvolutionOpModel m({TensorType_FLOAT32, {1, 3, 2, 2}}, + {TensorType_FLOAT32, {1, 2, 2, 4}}, + {TensorType_FLOAT32, {}}); + + m.SetInput({ + 1, 2, 7, 8, // column 1 + 3, 4, 9, 10, // column 2 + 5, 6, 11, 12, // column 3 + }); + m.SetFilter({ + 1, 2, 3, 4, // + -9, 10, -11, 12, // + 5, 6, 7, 8, // + 13, -14, 15, -16, // + }); + m.SetBias({1, 2, 3, 4}); + + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAreArray({ + 71, -34, 99, -20, // + 91, -26, 127, -4, // + })); +} + +class QuantizedDepthwiseConvolutionOpModel + : public BaseDepthwiseConvolutionOpModel { + public: + using BaseDepthwiseConvolutionOpModel::BaseDepthwiseConvolutionOpModel; + + void SetInput(std::initializer_list data) { + QuantizeAndPopulate(input_, data); + } + + void SetFilter(std::initializer_list data) { + QuantizeAndPopulate(filter_, data); + } + + void SetBias(std::initializer_list data) { + QuantizeAndPopulate(bias_, data); + } + + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetDequantizedOutput() { + return Dequantize(ExtractVector(output_), + GetScale(output_), GetZeroPoint(output_)); + } +}; + +// In this test we set the input and output scales so that the results match +// exactly the 'non-quantized' version. 
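A note on the int32 bias used by these quantized models: its scale is input_scale * filter_scale with a zero point of 0 (see the bias TensorData built in the constructor above), so each float bias value is presumably stored as round(value / bias_scale). With the [-63.5, 64] ranges used here, both scales are 0.5, so a bias of {1, 2, 3, 4} would be stored as {4, 8, 12, 16}. The helper below is only an illustration of that arithmetic, not part of the test harness.

#include <cmath>
#include <cstdint>

// bias_scale = input_scale * filter_scale, zero point 0 (illustration only).
int32_t QuantizeBiasSketch(float value, float input_scale, float filter_scale) {
  return static_cast<int32_t>(std::round(value / (input_scale * filter_scale)));
}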
+TEST(QuantizedDepthwiseConvolutionOpTest, SimpleTestQuantized) { + QuantizedDepthwiseConvolutionOpModel m( + {TensorType_UINT8, {1, 3, 2, 2}, -63.5, 64}, + {TensorType_UINT8, {1, 2, 2, 4}, -63.5, 64}, + {TensorType_UINT8, {}, -127, 128}); + + m.SetInput({ + 1, 2, 7, 8, // column 1 + 3, 4, 9, 10, // column 2 + 5, 6, 11, 12, // column 3 + }); + m.SetFilter({ + 1, 2, 3, 4, // + -9, 10, -11, 12, // + 5, 6, 7, 8, // + 13, -14, 15, -16, // + }); + m.SetBias({1, 2, 3, 4}); + + m.Invoke(); + + EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear( + { + 71, -34, 99, -20, // + 91, -26, 127, -4, // + }, + 1e-5))); + // For good measure, let's also verify the quantized values: + EXPECT_THAT(m.GetOutput(), ElementsAreArray({ + 198, 93, 226, 107, // + 218, 101, 254, 123, // + })); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + // On Linux, add: tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/embedding_lookup.cc b/tensorflow/contrib/lite/kernels/embedding_lookup.cc new file mode 100644 index 0000000000..4e8cb396d4 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/embedding_lookup.cc @@ -0,0 +1,104 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Ops that looks up items from matrix. +// +// Input: +// Tensor[0]: Row number to lookup, dim.size == 1, int32 +// Tensor[1]: 2-dimensional matrix of multi-dimensional items +// dim.size >= 2, any data type. +// first dimension is row, second dimension is column. +// +// Output: +// Output.dim[0] == Tensor[0].dim[0], num of lookups +// Output.dim[1] == Tensor[1].dim[1], num of items per row +// Each item in output is a raw bytes copy of corresponding item in input. +// When indices are out of bound, the ops will not succeed. 
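A sketch of the semantics spelled out above (illustrative code, not the TF Lite API): each output row is a raw byte copy of the selected value row, so a lookup of {1, 0, 2} into a 3-row table returns rows 1, 0 and 2 in that order, which is exactly what EmbeddingLookupOpTest.SimpleTest later in this patch expects.

#include <cstdint>
#include <cstring>

// output row i is a raw byte copy of value row lookup[i].
void EmbeddingLookupSketch(const int32_t* lookup, int num_lookups,
                           const char* value, int row_bytes, char* output) {
  for (int i = 0; i < num_lookups; ++i) {
    std::memcpy(output + i * row_bytes, value + lookup[i] * row_bytes,
                row_bytes);
  }
}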
+// + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace embedding_lookup { + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + TfLiteTensor* lookup = GetInput(context, node, 0); + TF_LITE_ENSURE_EQ(context, NumDimensions(lookup), 1); + TF_LITE_ENSURE_EQ(context, lookup->type, kTfLiteInt32); + + TfLiteTensor* value = GetInput(context, node, 1); + TF_LITE_ENSURE(context, NumDimensions(value) >= 2); + + TfLiteTensor* output = GetOutput(context, node, 0); + TfLiteIntArray* outputSize = TfLiteIntArrayCreate(NumDimensions(value)); + + outputSize->data[0] = SizeOfDimension(lookup, 0); + outputSize->data[1] = SizeOfDimension(value, 1); + for (int i = 2; i < NumDimensions(value); i++) { + outputSize->data[i] = SizeOfDimension(value, i); + } + return context->ResizeTensor(context, output, outputSize); +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor* output = GetOutput(context, node, 0); + TfLiteTensor* lookup = GetInput(context, node, 0); + TfLiteTensor* value = GetInput(context, node, 1); + + const int row_size = SizeOfDimension(value, 0); + const int row_bytes = value->bytes / row_size; + + for (int i = 0; i < SizeOfDimension(lookup, 0); i++) { + int idx = lookup->data.i32[i]; + if (idx >= row_size || idx < 0) { + context->ReportError(context, "Embedding Lookup: index out of bounds."); + return kTfLiteError; + } else { + memcpy(output->data.raw + i * row_bytes, + value->data.raw + idx * row_bytes, row_bytes); + } + } + + return kTfLiteOk; +} + +} // namespace embedding_lookup + +TfLiteRegistration* Register_EMBEDDING_LOOKUP() { + static TfLiteRegistration r = {nullptr, nullptr, embedding_lookup::Prepare, + embedding_lookup::Eval}; + return &r; +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/embedding_lookup_sparse.cc b/tensorflow/contrib/lite/kernels/embedding_lookup_sparse.cc new file mode 100644 index 0000000000..6c770e7f71 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/embedding_lookup_sparse.cc @@ -0,0 +1,248 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Op that looks up items from a sparse tensor in an embedding matrix. +// The sparse lookup tensor is represented by three individual tensors: lookup, +// indices, and dense_shape. 
The representation assume that the corresponding +// dense tensor would satisfy: +// * dense.shape = dense_shape +// * dense[tuple(indices[i])] = lookup[i] +// +// By convention, indices should be sorted. +// +// Options: +// combiner: The reduction op (SUM, MEAN, SQRTN). +// * SUM computes the weighted sum of the embedding results. +// * MEAN is the weighted sum divided by the total weight. +// * SQRTN is the weighted sum divided by the square root of the sum of the +// squares of the weights. +// +// Input: +// Tensor[0]: Ids to lookup, dim.size == 1, int32. +// Tensor[1]: Indices, int32. +// Tensor[2]: Dense shape, int32. +// Tensor[3]: Weights to use for aggregation, float. +// Tensor[4]: Params, a matrix of multi-dimensional items, +// dim.size >= 2, float. +// +// Output: +// A (dense) tensor representing the combined embeddings for the sparse ids. +// For each row in the sparse tensor represented by (lookup, indices, shape) +// the op looks up the embeddings for all ids in that row, multiplies them by +// the corresponding weight, and combines these embeddings as specified in the +// last dimension. +// +// Output.dim = [l0, ... , ln-1, e1, ..., em] +// Where dense_shape == [l0, ..., ln] and Tensor[4].dim == [e0, e1, ..., em] +// +// For instance, if params is a 10x20 matrix and ids, weights are: +// +// [0, 0]: id 1, weight 2.0 +// [0, 1]: id 3, weight 0.5 +// [1, 0]: id 0, weight 1.0 +// [2, 3]: id 1, weight 3.0 +// +// with combiner=MEAN, then the output will be a (3, 20) tensor where: +// +// output[0, :] = (params[1, :] * 2.0 + params[3, :] * 0.5) / (2.0 + 0.5) +// output[1, :] = (params[0, :] * 1.0) / 1.0 +// output[2, :] = (params[1, :] * 3.0) / 3.0 +// +// When indices are out of bound, the op will not succeed. + +#include +#include + +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { + +namespace { + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 5); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + TfLiteTensor* ids = GetInput(context, node, 0); + TF_LITE_ENSURE_EQ(context, NumDimensions(ids), 1); + TF_LITE_ENSURE_EQ(context, ids->type, kTfLiteInt32); + + TfLiteTensor* indices = GetInput(context, node, 1); + TF_LITE_ENSURE_EQ(context, NumDimensions(indices), 2); + TF_LITE_ENSURE_EQ(context, indices->type, kTfLiteInt32); + + TfLiteTensor* shape = GetInput(context, node, 2); + TF_LITE_ENSURE_EQ(context, NumDimensions(shape), 1); + TF_LITE_ENSURE_EQ(context, shape->type, kTfLiteInt32); + + TfLiteTensor* weights = GetInput(context, node, 3); + TF_LITE_ENSURE_EQ(context, NumDimensions(weights), 1); + TF_LITE_ENSURE_EQ(context, weights->type, kTfLiteFloat32); + + TF_LITE_ENSURE_EQ(context, SizeOfDimension(indices, 0), + SizeOfDimension(ids, 0)); + TF_LITE_ENSURE_EQ(context, SizeOfDimension(indices, 0), + SizeOfDimension(weights, 0)); + + TfLiteTensor* value = GetInput(context, node, 4); + TF_LITE_ENSURE(context, NumDimensions(value) >= 2); + + // Mark the output as a dynamic tensor. 
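The output is marked dynamic in the code that follows because its shape depends on the values in dense_shape, which are only available at Eval time. The sketch below shows the shape Eval ends up computing, following the convention described in the header comment: drop the last lookup dimension and append the embedding dimensions of the params tensor (the function name and vector-based signature here are illustrative only).

#include <vector>

std::vector<int> SparseLookupOutputShapeSketch(
    const std::vector<int>& dense_shape, const std::vector<int>& value_shape) {
  // Keep dense_shape[0 .. n-2], then append value_shape[1 ..].
  std::vector<int> out(dense_shape.begin(), dense_shape.end() - 1);
  out.insert(out.end(), value_shape.begin() + 1, value_shape.end());
  return out;
}
// E.g. dense_shape = {3, 2} with a {4, 3, 2} params tensor gives a {3, 3, 2}
// output, and the 10x20 example in the header comment gives (3, 20).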
+ TfLiteTensor* output = GetOutput(context, node, 0); + TF_LITE_ENSURE_EQ(context, output->type, kTfLiteFloat32); + output->allocation_type = kTfLiteDynamic; + + return kTfLiteOk; +} + +void FinalizeAggregation(TfLiteCombinerType combiner, int num_elements, + float current_total_weight, + float current_squares_weight, int embedding_size, + float* output) { + if (combiner != kTfLiteCombinerTypeSum && num_elements > 0) { + float multiplier = 1.0; + switch (combiner) { + case kTfLiteCombinerTypeMean: + multiplier = current_total_weight; + break; + case kTfLiteCombinerTypeSqrtn: + multiplier = std::sqrt(current_squares_weight); + break; + default: + break; + } + for (int k = 0; k < embedding_size; k++) { + output[k] /= multiplier; + } + } +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + auto* params = + reinterpret_cast(node->builtin_data); + TfLiteTensor* output = GetOutput(context, node, 0); + TfLiteTensor* ids = GetInput(context, node, 0); + TfLiteTensor* indices = GetInput(context, node, 1); + TfLiteTensor* dense_shape = GetInput(context, node, 2); + TfLiteTensor* weights = GetInput(context, node, 3); + TfLiteTensor* value = GetInput(context, node, 4); + + const int lookup_rank = SizeOfDimension(indices, 1); + const int embedding_rank = NumDimensions(value); + const int num_lookups = SizeOfDimension(ids, 0); + const int num_rows = SizeOfDimension(value, 0); + + // The last dimension gets replaced by the embedding. + const int output_rank = (lookup_rank - 1) + (embedding_rank - 1); + + // Make sure that the actual dense shape of the sparse tensor represented by + // (loopkup, indices, dense_shape) is consistent. + TF_LITE_ENSURE_EQ(context, SizeOfDimension(dense_shape, 0), lookup_rank); + + // Resize output tensor. + TfLiteIntArray* output_shape = TfLiteIntArrayCreate(output_rank); + int k = 0; + int embedding_size = 1; + int lookup_size = 1; + for (int i = 0; i < lookup_rank - 1; i++, k++) { + const int dim = dense_shape->data.i32[i]; + lookup_size *= dim; + output_shape->data[k] = dim; + } + for (int i = 1; i < embedding_rank; i++, k++) { + const int dim = SizeOfDimension(value, i); + embedding_size *= dim; + output_shape->data[k] = dim; + } + TF_LITE_ENSURE_STATUS(context->ResizeTensor(context, output, output_shape)); + const int output_size = lookup_size * embedding_size; + TfLiteTensorRealloc(output_size * sizeof(float), output); + + tensor_utils::ZeroVector(output->data.f, output_size); + + // Keep track of the current bucket for aggregation/combination. + int current_output_offset = 0; + float current_total_weight = 0.0; + float current_squares_weight = 0.0; + int num_elements = 0; + + for (int i = 0; i < num_lookups; i++) { + int idx = ids->data.i32[i]; + if (idx >= num_rows || idx < 0) { + context->ReportError(context, + "Embedding Lookup Sparse: index out of bounds."); + return kTfLiteError; + } + + // Check where we need to aggregate. + const int example_indices_offset = i * lookup_rank; + int output_bucket = 0; + int stride = 1; + for (int k = (lookup_rank - 1) - 1; k >= 0; k--) { + output_bucket += indices->data.i32[example_indices_offset + k] * stride; + stride *= dense_shape->data.i32[k]; + } + const int output_offset = output_bucket * embedding_size; + + // If we are in a new aggregation bucket and the combiner is not the sum, + // go back and finalize the result of the previous bucket. 
+ if (output_offset != current_output_offset) { + FinalizeAggregation(params->combiner, num_elements, current_total_weight, + current_squares_weight, embedding_size, + &output->data.f[current_output_offset]); + + // Track next bucket. + num_elements = 0; + current_total_weight = 0.0; + current_squares_weight = 0.0; + current_output_offset = output_offset; + } + + // Add element to aggregation. + ++num_elements; + const int example_embedding_offset = idx * embedding_size; + const float w = weights->data.f[i]; + current_squares_weight += w * w; + current_total_weight += w; + for (int k = 0; k < embedding_size; k++) { + output->data.f[current_output_offset + k] += + (value->data.f[example_embedding_offset + k] * w); + } + } + + // Finalize last bucket. + FinalizeAggregation(params->combiner, num_elements, current_total_weight, + current_squares_weight, embedding_size, + &output->data.f[current_output_offset]); + + return kTfLiteOk; +} + +} // namespace + +TfLiteRegistration* Register_EMBEDDING_LOOKUP_SPARSE() { + static TfLiteRegistration r = {nullptr, nullptr, Prepare, Eval}; + return &r; +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/embedding_lookup_sparse_test.cc b/tensorflow/contrib/lite/kernels/embedding_lookup_sparse_test.cc new file mode 100644 index 0000000000..69d9c5cc7d --- /dev/null +++ b/tensorflow/contrib/lite/kernels/embedding_lookup_sparse_test.cc @@ -0,0 +1,166 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Unit test for TFLite sparse lookup op. 
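For reference while reading the tests below: the three combiners differ only in the divisor that FinalizeAggregation (in the op implementation above) applies to the weighted sum. In these tests the third output row combines Row 3 with weight 2.0 and Row 0 with weight 4.0, so SUM keeps 2*Row3 + 4*Row0, MEAN divides it by 6.0, and SQRTN divides it by sqrt(2^2 + 4^2) = sqrt(20). The enum and helper below are illustrative only, not the TF Lite types.

#include <cmath>

enum class Combiner { kSum, kMean, kSqrtn };

float CombinerDivisor(Combiner c, float total_weight,
                      float total_squared_weight) {
  switch (c) {
    case Combiner::kSum:
      return 1.0f;
    case Combiner::kMean:
      return total_weight;
    case Combiner::kSqrtn:
      return std::sqrt(total_squared_weight);
  }
  return 1.0f;
}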
+ +#include +#include + +#include +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class EmbeddingLookupSparseOpModel : public SingleOpModel { + public: + EmbeddingLookupSparseOpModel(CombinerType type, + std::initializer_list lookup_shape, + std::initializer_list indices_shape, + std::initializer_list dense_shape_shape, + std::initializer_list value_shape) { + lookup_ = AddInput(TensorType_INT32); + indices_ = AddInput(TensorType_INT32); + dense_shape_ = AddInput(TensorType_INT32); + weights_ = AddInput(TensorType_FLOAT32); + value_ = AddInput(TensorType_FLOAT32); + output_ = AddOutput(TensorType_FLOAT32); + SetBuiltinOp(BuiltinOperator_EMBEDDING_LOOKUP_SPARSE, + BuiltinOptions_EmbeddingLookupSparseOptions, + CreateEmbeddingLookupSparseOptions(builder_, type).Union()); + BuildInterpreter({lookup_shape, indices_shape, dense_shape_shape, + lookup_shape, value_shape}); + } + + void SetInput(std::initializer_list lookup_data, + std::initializer_list indices_data, + std::initializer_list dense_shape_data, + std::initializer_list weights_data) { + PopulateTensor(lookup_, lookup_data); + PopulateTensor(indices_, indices_data); + PopulateTensor(dense_shape_, dense_shape_data); + PopulateTensor(weights_, weights_data); + } + + void Set3DWeightMatrix(const std::function& function) { + TfLiteTensor* tensor = interpreter_->tensor(value_); + int rows = tensor->dims->data[0]; + int columns = tensor->dims->data[1]; + int features = tensor->dims->data[2]; + for (int i = 0; i < rows; i++) { + for (int j = 0; j < columns; j++) { + for (int k = 0; k < features; k++) { + tensor->data.f[(i * columns + j) * features + k] = function(i, j, k); + } + } + } + } + + std::vector GetOutput() { return ExtractVector(output_); } + + private: + int lookup_; + int weights_; + int indices_; + int dense_shape_; + int value_; + int output_; +}; + +TEST(EmbeddingLookupOpTest, SimpleTest) { + EmbeddingLookupSparseOpModel m(CombinerType_SUM, {3}, {3, 2}, {2}, {4, 3, 2}); + m.SetInput({1, 3, 0}, {0, 0, 2, 0, 2, 1}, {3, 2}, {1.0, 2.0, 4.0}); + m.Set3DWeightMatrix( + [](int i, int j, int k) { return i + j / 10.0f + k / 100.0f; }); + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({ + 1.00, 1.01, 1.10, 1.11, 1.20, 1.21, // Row 1 + 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, // - + 6.00, 6.06, 6.60, 6.66, 7.20, 7.26, // 2 * Row 3 + 4 * Row 0 + }))); +} + +TEST(EmbeddingLookupOpTest, SimpleTestMean) { + EmbeddingLookupSparseOpModel m(CombinerType_MEAN, {3}, {3, 2}, {2}, + {4, 3, 2}); + m.SetInput({1, 3, 0}, {0, 0, 2, 0, 2, 1}, {3, 2}, {1.0, 2.0, 4.0}); + m.Set3DWeightMatrix( + [](int i, int j, int k) { return i + j / 10.0f + k / 100.0f; }); + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({ + 1.00, 1.01, 1.10, 1.11, 1.20, 1.21, // Row 1 + 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, // - + 1.00, 1.01, 1.10, 1.11, 1.20, 1.21, // 2 * Row 3 + 4 * Row 0 + }))); +} + +TEST(EmbeddingLookupOpTest, SimpleTestSqrtn) { + EmbeddingLookupSparseOpModel m(CombinerType_SQRTN, {3}, {3, 2}, {2}, + {4, 3, 2}); + m.SetInput({1, 3, 0}, {0, 0, 2, 0, 2, 1}, {3, 2}, {1.0, 2.0, 4.0}); + m.Set3DWeightMatrix( + [](int i, int j, int k) { return i + j / 10.0f + k / 100.0f; }); + m.Invoke(); + + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray(ArrayFloatNear({ + 1.00, 1.01, 1.10, 
1.11, 1.20, 1.21, // Row 1 + 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, // - + 6.00f / std::sqrt(20.0f), 6.06f / std::sqrt(20.0f), + 6.60f / std::sqrt(20.0f), 6.66f / std::sqrt(20.0f), + 7.20f / std::sqrt(20.0f), + 7.26f / + std::sqrt( + 20.0f), // 2 * Row 3 + 4 * Row 0, // 2 * Row 3 + 4 * Row 0 + }))); +} + +TEST(EmbeddingLookupOpTest, Indices3DTest) { + EmbeddingLookupSparseOpModel m(CombinerType_SUM, {3}, {3, 3}, {3}, {4, 3, 2}); + m.SetInput({1, 3, 0}, {0, 0, 0, 2, 0, 0, 2, 0, 1}, {3, 2, 2}, + {1.0, 2.0, 4.0}); + m.Set3DWeightMatrix( + [](int i, int j, int k) { return i + j / 10.0f + k / 100.0f; }); + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({ + 1.00, 1.01, 1.10, 1.11, 1.20, 1.21, 0.00, 0.00, 0.00, + 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, + 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 6.00, 6.06, 6.60, + 6.66, 7.20, 7.26, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, + }))); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { +#ifdef OS_LINUX + tflite::LogToStderr(); +#endif + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/embedding_lookup_test.cc b/tensorflow/contrib/lite/kernels/embedding_lookup_test.cc new file mode 100644 index 0000000000..8c030b0677 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/embedding_lookup_test.cc @@ -0,0 +1,94 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Unit test for TFLite Lookup op. 
+ +#include +#include + +#include +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class EmbeddingLookupOpModel : public SingleOpModel { + public: + EmbeddingLookupOpModel(std::initializer_list index_shape, + std::initializer_list weight_shape) { + input_ = AddInput(TensorType_INT32); + weight_ = AddInput(TensorType_FLOAT32); + output_ = AddOutput(TensorType_FLOAT32); + SetBuiltinOp(BuiltinOperator_EMBEDDING_LOOKUP, BuiltinOptions_NONE, 0); + BuildInterpreter({index_shape, weight_shape}); + } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + void Set3DWeightMatrix(const std::function& function) { + TfLiteTensor* tensor = interpreter_->tensor(weight_); + int rows = tensor->dims->data[0]; + int columns = tensor->dims->data[1]; + int features = tensor->dims->data[2]; + for (int i = 0; i < rows; i++) { + for (int j = 0; j < columns; j++) { + for (int k = 0; k < features; k++) { + tensor->data.f[(i * columns + j) * features + k] = function(i, j, k); + } + } + } + } + + std::vector GetOutput() { return ExtractVector(output_); } + + private: + int input_; + int weight_; + int output_; +}; + +// TODO(ahentz): write more tests that exercise the details of the op, such as +// lookup errors and variable input shapes. +TEST(EmbeddingLookupOpTest, SimpleTest) { + EmbeddingLookupOpModel m({3}, {3, 2, 4}); + m.PopulateTensor(0, {1, 0, 2}); + m.Set3DWeightMatrix( + [](int i, int j, int k) { return i + j / 10.0f + k / 100.0f; }); + + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({ + 1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1 + 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0 + 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2 + }))); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + // On Linux, add: tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/fully_connected.cc b/tensorflow/contrib/lite/kernels/fully_connected.cc new file mode 100644 index 0000000000..a77fe94e49 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/fully_connected.cc @@ -0,0 +1,307 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/activation_functor.h" +#include "tensorflow/contrib/lite/kernels/gemm_support.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace fully_connected { + +// This file has four implementations of FullyConnected +enum KernelType { + kReference, + kGenericOptimized, // Neon-free + kNeonOptimized, + kPie, // Used by the PIE team +}; + +struct OpData { + // The scaling factor from input to output (aka the 'real multiplier') can + // be represented as a fixed point multipler plus a left shift. + int32_t output_multiplier; + int output_shift; + // The range of the fused activation layer. For example for kNone and + // uint8_t these would be 0 and 255. + int32_t output_activation_min; + int32_t output_activation_max; +}; + +constexpr int kInputTensor = 0; +constexpr int kWeightsTensor = 1; +constexpr int kBiasTensor = 2; +constexpr int kOutputTensor = 0; + +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + // This is a builtin op, so we don't use the contents in 'buffer', if any. + // Instead, we allocate a new object to carry information from Prepare() to + // Eval(). + gemm_support::IncrementUsageCounter(context); + return new OpData; +} + +void Free(TfLiteContext* context, void* buffer) { + gemm_support::DecrementUsageCounter(context); + delete reinterpret_cast(buffer); +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + auto* params = + reinterpret_cast(node->builtin_data); + OpData* data = reinterpret_cast(node->user_data); + + // Check we have all the inputs and outputs we need. + TF_LITE_ENSURE_EQ(context, node->inputs->size, 3); + TF_LITE_ENSURE_EQ(context, node->outputs->size, 1); + + TfLiteTensor* input = GetInput(context, node, kInputTensor); + TfLiteTensor* filter = GetInput(context, node, kWeightsTensor); + TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + // Check all the parameters of tensor match within themselves and match the + // input configuration. + int input_size = 1; + for (int i = 0; i < input->dims->size; i++) { + input_size *= input->dims->data[i]; + } + + const int batch_size = input_size / filter->dims->data[1]; + const int num_units = filter->dims->data[0]; + + TF_LITE_ASSERT_EQ(input_size, batch_size * filter->dims->data[1]); + if (bias) { + TF_LITE_ASSERT_EQ(bias->dims->data[0], num_units); + } + + TF_LITE_ENSURE_EQ(context, NumDimensions(filter), 2); + TF_LITE_ENSURE_EQ(context, NumDimensions(bias), 1); + + // Note that quantized inference requires that all tensors have their + // parameters set. This is usually done during quantized training. 
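Stepping back to the shape bookkeeping a few lines up: whatever the input rank, the op treats the input as a [batch_size, input_dim] matrix, where input_dim is the filter's second dimension and batch_size is the total element count divided by input_dim; the output is then [batch_size, num_units]. The helper below is a standalone illustration of that arithmetic, not part of the kernel; for example a {2, 10} input (20 elements) against a {3, 10} weight matrix yields a {2, 3} output, which is how output_size_array is filled in just below.

struct FcShape {
  int batch_size;
  int num_units;
};

FcShape FullyConnectedOutputShapeSketch(int total_input_elements,
                                        int num_units /*filter rows*/,
                                        int input_dim /*filter cols*/) {
  return {total_input_elements / input_dim, num_units};
}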
+ TfLiteType data_type = input->type; + if (data_type != kTfLiteFloat32) { + double real_multiplier = 0.0; + TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler( + context, input, filter, bias, output, &real_multiplier)); + QuantizeMultiplierSmallerThanOne(real_multiplier, &data->output_multiplier, + &data->output_shift); + CalculateActivationRangeUint8(params->activation, output, + &data->output_activation_min, + &data->output_activation_max); + } + + // Resize output. + TfLiteIntArray* output_size_array = TfLiteIntArrayCreate(2); + output_size_array->data[0] = batch_size; + output_size_array->data[1] = num_units; + TF_LITE_ENSURE_OK(context, + context->ResizeTensor(context, output, output_size_array)); + return kTfLiteOk; +} + +TfLiteStatus EvalPie(TfLiteContext* context, TfLiteNode* node, + TfLiteFullyConnectedParams* params, OpData* data, + TfLiteTensor* input, TfLiteTensor* filter, + TfLiteTensor* bias, TfLiteTensor* output) { + int total_input_size = 1; + for (int i = 0; i < input->dims->size; i++) { + total_input_size *= input->dims->data[i]; + } + + int input_size = filter->dims->data[1]; + const int batch_size = total_input_size / filter->dims->data[1]; + const int num_units = filter->dims->data[0]; + + // Output = bias if bias tensor exists. + if (bias) { + tensor_utils::VectorBatchVectorAssign(bias->data.f, num_units, batch_size, + output->data.f); + } else { + tensor_utils::ZeroVector(output->data.f, batch_size * num_units); + } + + // Compute output += weight * input + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + filter->data.f, num_units, input_size, input->data.f, batch_size, + output->data.f, /*result_stride=*/1); + + // Apply activation function + tensor_utils::ApplyActivationToVector(output->data.f, batch_size * num_units, + params->activation, output->data.f); + + return kTfLiteOk; +} + +#define TF_LITE_MACRO_DISPATCH(macro_name, params, target_namespace) \ + if (params->activation == kTfLiteActNone) { \ + macro_name(target_namespace, kNone); \ + } \ + if (params->activation == kTfLiteActRelu) { \ + macro_name(target_namespace, kRelu); \ + } \ + if (params->activation == kTfLiteActRelu6) { \ + macro_name(target_namespace, kRelu6); \ + } + +template +TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node, + TfLiteFullyConnectedParams* params, OpData* data, + TfLiteTensor* input, TfLiteTensor* filter, + TfLiteTensor* bias, TfLiteTensor* output) { + gemmlowp::GemmContext* gemm_context = gemm_support::GetFromContext(context); + + int32_t input_offset = -input->params.zero_point; + int32_t filter_offset = -filter->params.zero_point; + int32_t output_offset = output->params.zero_point; +#define TF_LITE_FULLY_CONNECTED(type) \ + type::FullyConnected( \ + GetTensorData(input), GetTensorDims(input), input_offset, \ + GetTensorData(filter), GetTensorDims(filter), filter_offset, \ + GetTensorData(bias), GetTensorDims(bias), output_offset, \ + data->output_multiplier, data->output_shift, \ + data->output_activation_min, data->output_activation_max, \ + GetTensorData(output), GetTensorDims(output), gemm_context) + if (kernel_type == kReference) { + TF_LITE_FULLY_CONNECTED(reference_ops); + } else if (kernel_type == kPie) { + // TODO(ahentz): we don't have a quantized version of the PIE kernels, so + // we just defer to the MINI ones. 
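+  // (Falling through to the generic optimized kernel here means quantized
+  //  PIE requests take the same optimized_ops path as the default branch
+  //  below.  In either case the kernel accumulates
+  //  (input_q - input_zero_point) * (filter_q - filter_zero_point) terms in
+  //  int32 -- note that input_offset and filter_offset above are the
+  //  *negated* zero points -- then rescales by output_multiplier and
+  //  output_shift and adds output_offset to land back in uint8 range.)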
+ TF_LITE_FULLY_CONNECTED(optimized_ops); + } else { + TF_LITE_FULLY_CONNECTED(optimized_ops); + } +#undef TF_LITE_FULLY_CONNECTED + + return kTfLiteOk; +} + +template +TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node, + TfLiteFullyConnectedParams* params, OpData* data, + TfLiteTensor* input, TfLiteTensor* filter, + TfLiteTensor* bias, TfLiteTensor* output) { + float output_activation_min, output_activation_max; + CalculateActivationRangeFloat(params->activation, &output_activation_min, + &output_activation_max); +#define TF_LITE_FULLY_CONNECTED(type) \ + type::FullyConnected(GetTensorData(input), GetTensorDims(input), \ + GetTensorData(filter), GetTensorDims(filter), \ + GetTensorData(bias), GetTensorDims(bias), \ + output_activation_min, output_activation_max, \ + GetTensorData(output), GetTensorDims(output)) + if (kernel_type == kReference) { + TF_LITE_FULLY_CONNECTED(reference_ops); + } else if (kernel_type == kPie) { + return EvalPie(context, node, params, data, input, filter, bias, output); + } else { + TF_LITE_FULLY_CONNECTED(optimized_ops); + } +#undef TF_LITE_FULLY_CONNECTED + + return kTfLiteOk; +} + +#undef TF_LITE_MACRO_DISPATCH + +template +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + auto* params = + reinterpret_cast(node->builtin_data); + OpData* data = reinterpret_cast(node->user_data); + + TfLiteTensor* input = GetInput(context, node, kInputTensor); + TfLiteTensor* filter = GetInput(context, node, kWeightsTensor); + TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + switch (input->type) { // Already know in/out types are same. + case kTfLiteFloat32: + return EvalFloat(context, node, params, data, input, filter, + bias, output); + case kTfLiteUInt8: + return EvalQuantized(context, node, params, data, input, + filter, bias, output); + default: + context->ReportError(context, "Type not currently supported."); + return kTfLiteError; + } + return kTfLiteOk; +} + +} // namespace fully_connected + +TfLiteRegistration* Register_FULLY_CONNECTED_REF() { + static TfLiteRegistration r = { + fully_connected::Init, fully_connected::Free, fully_connected::Prepare, + fully_connected::Eval}; + return &r; +} + +TfLiteRegistration* Register_FULLY_CONNECTED_NEON_OPT() { + static TfLiteRegistration r = { + fully_connected::Init, fully_connected::Free, fully_connected::Prepare, + fully_connected::Eval}; + return &r; +} + +TfLiteRegistration* Register_FULLY_CONNECTED_GENERIC_OPT() { + static TfLiteRegistration r = { + fully_connected::Init, fully_connected::Free, fully_connected::Prepare, + fully_connected::Eval}; + return &r; +} + +TfLiteRegistration* Register_FULLY_CONNECTED_PIE() { + static TfLiteRegistration r = {fully_connected::Init, fully_connected::Free, + fully_connected::Prepare, + fully_connected::Eval}; + return &r; +} + +TfLiteRegistration* Register_FULLY_CONNECTED() { + // TODO(ahentz): We don't have a dedicated quantized version of the PIE + // kernel. For now, the quantized version just defer to the corresponding + // optimized MINI kernel. At some point we will allow different libraries to + // be built with different kernels, but for now we have to pick one here. 
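+  // Note: because of the unconditional return below, the #ifdef USE_NEON
+  // block that follows is currently dead code -- this entry point always
+  // hands back the PIE registration, whose quantized path in turn falls
+  // through to the generic optimized kernel (see the TODO above).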
+ return Register_FULLY_CONNECTED_PIE(); +#ifdef USE_NEON + return Register_FULLY_CONNECTED_NEON_OPT(); +#else + return Register_FULLY_CONNECTED_GENERIC_OPT(); +#endif +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/fully_connected_test.cc b/tensorflow/contrib/lite/kernels/fully_connected_test.cc new file mode 100644 index 0000000000..112e3f1ba0 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/fully_connected_test.cc @@ -0,0 +1,377 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Unit test for TFLite FULLY_CONNECTED op. + +#include +#include + +#include +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAre; +using ::testing::ElementsAreArray; + +static float fully_connected_input[] = { + 0.503691, 0.196961, 0.521017, 0.554248, 0.288678, 0.792476, 0.561653, + 0.462230, 0.650736, 0.163132, 0.029658, 0.411544, 0.470539, 0.572390, + 0.538755, 0.212030, 0.264309, 0.193908, 0.777480, 0.745661, 0.423314, + 0.470804, 0.175501, 0.492225, 0.192743, 0.540183, 0.372514, 0.446550, + 0.498173, 0.126472, 0.132706, 0.001864, 0.323433, 0.653723, 0.556112, + 0.612111, 0.446199, 0.117765, 0.074341, 0.096935, 0.280897, 0.103999, + 0.508479, 0.751437, 0.676389, 0.047234, 0.963467, 0.940698, 0.241142, + 0.740947, 0.686359, 0.664456, 0.211751, 0.861860, 0.156681, 0.404494, + 0.402043, 0.529195, 0.851044, 0.900216, 0.655667, 0.983750, 0.902081, + 0.979100, 0.637473, 0.458193, 0.591211, 0.083671, 0.575958, 0.665552, + 0.180606, 0.856856, 0.769551, 0.689086, 0.608293, 0.445940, 0.736320, + 0.571760, 0.386637, 0.977461, 0.312707, 0.072996, 0.641918, 0.524458, + 0.934856, 0.798598, 0.928951, 0.336899, 0.327793, 0.779995, 0.237115, + 0.983460, 0.763746, 0.139196, 0.962560, 0.401218, 0.597389, 0.553771, + 0.484890, 0.173347, 0.219322, 0.665496, 0.030203, 0.988873, 0.354582, + 0.638496, 0.434813, 0.090902, 0.210256, 0.821450, 0.068363, 0.522962, + 0.894446, 0.710280, 0.047420, 0.829302, 0.508879, 0.976371, 0.166202, + 0.836672, 0.756367, 0.403317, 0.820132, 0.520112, 0.542513, 0.782691, + 0.921330, 0.139902}; + +static float fully_connected_golden_output[] = { + 0, 0.0732134, 0, 0, 0, 0.280859, + 0, 0.128927, 0, 0.0777251, 0, 0.270268, + 0.271435, 0.0173503, 0.335465, 0.235562, + + 0, 0.0745866, 0, 0.051611, 0, 0.253876, + 0, 0.0814873, 0, 0.104104, 0, 0.248529, + 0.264194, 0, 0.302973, 0.166252, + + 0, 0.0170409, 0, 0.0509851, 0, 0.212834, + 0, 0.0208326, 0, 0.129932, 0.203978, 0.103428, + 0.298051, 0, 0.332233, 0.00445903, + + 0, 0.125246, 0, 0.0735336, 0, 0.0910256, + 0, 0, 0, 0.18933, 0.378111, 0.0712443, + 0.277298, 0.0123414, 0.267454, 0, + + 0, 0.14687, 0, 0.155495, 0.0300215, 0.147256, + 0, 
0, 0, 0.156412, 0.434914, 0.0461529, + 0.246508, 0, 0.363138, 0, + + 0, 0, 0, 0.0212949, 0, 0.301708, + 0, 0.35497, 0, 0.406223, 0.0260211, 0.049195, + 0.197161, 0, 0.37316, 0, + + 0, 0.221783, 0, 0, 0.0116515, 0.281945, + 0, 0, 0, 0, 0.285626, 0.181773, + 0.296401, 0.170452, 0.367135, 0.142597, + + 0, 0, 0, 0, 0, 0.418886, + 0, 0.291063, 0, 0.227541, 0.0424759, 0.27589, + 0.398286, 0.177146, 0.40359, 0.121452, + + 0, 0.0834884, 0, 0, 0, 0.287441, + 0, 0.0046838, 0, 0.0122087, 0, 0.217376, + 0.140183, 0.0948412, 0.436677, 0.0589876, + + 0, 0.0289969, 0, 0.0921397, 0, 0.396802, + 0, 0.0126157, 0, 0.0968433, 0, 0.172271, + 0.173295, 0.0664741, 0.53645, 0.00915603, + + 0, 0, 0, 0, 0, 0.147942, + 0, 0.263795, 0, 0.39782, 0, 0.382435, + 0.561072, 0.0579847, 0.145712, 0.13508, + + 0, 0, 0, 0.16382, 0, 0.322294, + 0, 0.163798, 0, 0.405211, 0.367953, 0.076852, + 0.342473, 0.0834118, 0.377537, 0, + + 0, 0.206, 0, 0, 0, 0.375769, + 0, 0, 0, 0, 0, 0.125165, + 0, 0.105591, 0.52055, 0.0536445, + + 0, 0.259261, 0, 0, 0, 0.247707, + 0, 0, 0, 0, 0, 0.215862, + 0.149153, 0.224678, 0.359519, 0.129419, + + 0, 0.17611, 0, 0.280895, 0, 0.576484, + 0, 0.000418848, 0, 0, 0, 0.151112, + 0.211902, 0, 0.566341, 0.106305, + + 0, 0.0246284, 0, 0, 0, 0.196267, + 0, 0.0248624, 0, 0.265635, 0, 0.436199, + 0.408079, 0.134514, 0.328489, 0.411368}; + +class BaseFullyConnectedOpModel : public SingleOpModel { + public: + // TODO(ahentz): test different activation types too. + BaseFullyConnectedOpModel(int units, int batches, const TensorData& input, + const TensorData& output = {TensorType_FLOAT32}) + : batches_(batches), units_(units) { + int total_input_size = 1; + for (int i = 0; i < input.shape.size(); ++i) { + total_input_size *= input.shape[i]; + } + input_size_ = total_input_size / batches_; + + input_ = AddInput(input); + weights_ = + AddInput({input.type, {units_, input_size_}, input.min, input.max}); + + if (input.type == TensorType_FLOAT32) { + bias_ = AddInput({TensorType_FLOAT32, {units_}}); + } else { + // This is a quantized version. The scale of 'bias' depends on the scales + // of input and filter. Supposedly this is correctly set during quantized + // training. 
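+      // The products accumulated by the quantized kernel have an effective
+      // scale of input_scale * weights_scale, so a bias that is added
+      // straight into the int32 accumulator must use that same scale (and a
+      // zero point of 0) -- hence bias_scale below and the TensorType_INT32
+      // bias tensor.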
+ auto bias_scale = GetScale(input_) * GetScale(weights_); + TensorData bias{TensorType_INT32, {units_}, 0, 0, bias_scale}; + bias_ = AddInput(bias); + } + + output_ = AddOutput(output); + + SetBuiltinOp( + BuiltinOperator_FULLY_CONNECTED, BuiltinOptions_FullyConnectedOptions, + CreateFullyConnectedOptions(builder_, ActivationFunctionType_RELU) + .Union()); + BuildInterpreter({GetShape(input_), GetShape(weights_), GetShape(bias_)}); + } + + int input_size() { return input_size_; } + int num_units() { return units_; } + int num_batches() { return batches_; } + + protected: + int input_; + int weights_; + int bias_; + int output_; + + int batches_; + int units_; + int input_size_; +}; + +class FloatFullyConnectedOpModel : public BaseFullyConnectedOpModel { + public: + using BaseFullyConnectedOpModel::BaseFullyConnectedOpModel; + + void SetBias(std::initializer_list f) { PopulateTensor(bias_, f); } + + void SetWeights(std::initializer_list f) { + PopulateTensor(weights_, f); + } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + void SetInput(int offset, float* begin, float* end) { + PopulateTensor(input_, offset, begin, end); + } + + std::vector GetOutput() { return ExtractVector(output_); } +}; + +class QuantizedFullyConnectedOpModel : public BaseFullyConnectedOpModel { + public: + using BaseFullyConnectedOpModel::BaseFullyConnectedOpModel; + + void SetBias(std::initializer_list data) { + QuantizeAndPopulate(bias_, data); + } + void SetWeights(std::initializer_list data) { + QuantizeAndPopulate(weights_, data); + } + void SetInput(std::initializer_list data) { + QuantizeAndPopulate(input_, data); + } + + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetDequantizedOutput() { + return Dequantize(ExtractVector(output_), + GetScale(output_), GetZeroPoint(output_)); + } +}; + +// TODO(ahentz): add more small tests like this one, focused on making sure the +// calculations are correct. +TEST(FullyConnectedOpTest, SimpleTest) { + FloatFullyConnectedOpModel m(3, 2, {TensorType_FLOAT32, {2, 10}}); + m.SetWeights({ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 0 + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1 + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1 + }); + m.SetBias({1, 2, 3}); + + m.SetInput({ + 1, 2, 3, 4, 5, 6, 7, 8, -9, -10, // b = 0 + 1, 2, 3, 4, 5, 6, 7, -8, 9, -10, // b = 1 + }); + + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAre(24, 25, 26, 58, 59, 60)); +} + +TEST(FullyConnectedOpTest, SimpleTestQuantized) { + QuantizedFullyConnectedOpModel m( + 3, 2, + /*input=*/{TensorType_UINT8, {2, 10}, -63.5, 64}, + /*output=*/{TensorType_UINT8, {}, -127, 128}); + + // input_product_scale < output_scale was not true. + m.SetWeights({ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 0 + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1 + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1 + }); + m.SetBias({1, 2, 3}); + + m.SetInput({ + 1, 2, 3, 4, 5, 6, 7, 8, -9, -10, // b = 0 + 1, 2, 3, 4, 5, 6, 7, -8, 9, -10, // b = 1 + }); + + m.Invoke(); + + EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear({ + 24, 25, 26, // + 58, 59, 60, // + }))); + EXPECT_THAT(m.GetOutput(), ElementsAre(151, 152, 153, 185, 186, 187)); +} + +TEST(FullyConnectedOpTest, SimpleTest4DInput) { + // Note that it is not required that the first dimension be the number of + // batches. All we care is that the input can be evenly distributed in + // batches. In this case, we need the input to have multiples of '2'. 
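+  // Worked out for this test: the input tensor has shape {4, 1, 5, 1}, i.e.
+  // 20 elements.  With batches = 2, BaseFullyConnectedOpModel computes
+  // input_size = 20 / 2 = 10, so the weights get shape {3, 10} and the
+  // output shape {2, 3} -- the same shapes as SimpleTest above, just with
+  // the 20 inputs reinterpreted as two batches of 10.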
+ FloatFullyConnectedOpModel m(/*units=*/3, + /*batches=*/2, + /*input=*/{TensorType_FLOAT32, {4, 1, 5, 1}}); + m.SetWeights({ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 0 + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1 + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1 + }); + m.SetBias({1, 2, 3}); + + m.SetInput({ + 1, 2, 3, 4, 5, 6, 7, 8, -9, -10, // first batch + 1, 2, 3, 4, 5, 6, 7, -8, 9, -10, // second batch + }); + + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAreArray({ + 24, 25, 26, // first batch + 58, 59, 60, // second batch + })); +} + +TEST(FullyConnectedOpTest, SimpleTest4dInputQuantized) { + QuantizedFullyConnectedOpModel m( + 3, 2, + /*input=*/{TensorType_UINT8, {4, 1, 5, 1}, -63.5, 64}, + /*output=*/{TensorType_UINT8, {}, -127, 128}); + + // input_product_scale < output_scale was not true. + m.SetWeights({ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 0 + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1 + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1 + }); + m.SetBias({1, 2, 3}); + + m.SetInput({ + 1, 2, 3, 4, 5, 6, 7, 8, -9, -10, // b = 0 + 1, 2, 3, 4, 5, 6, 7, -8, 9, -10, // b = 1 + }); + + m.Invoke(); + + EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear({ + 24, 25, 26, // + 58, 59, 60, // + }))); + EXPECT_THAT(m.GetOutput(), ElementsAre(151, 152, 153, 185, 186, 187)); +} + +// TODO(ahentz): Reconsider this test. Having arbitrary weights makes it hard +// to debug errors and doesn't necessarily test all the important details. +TEST(FullyConnectedOpTest, BlackBoxTest) { + FloatFullyConnectedOpModel m(16, 2, {TensorType_FLOAT32, {2, 8}}); + m.SetWeights( + {0.091327, 0.103366, -0.316505, -0.083120, 0.149366, -0.196636, + -0.123672, 0.062800, 0.063031, 0.191670, -0.062001, -0.061504, + -0.275581, 0.059388, -0.118497, -0.079224, 0.109758, 0.008307, + -0.062657, -0.060962, -0.049782, -0.106719, -0.319482, -0.103650, + 0.266455, 0.051517, -0.123448, 0.322464, 0.043282, -0.173782, + -0.190381, 0.002013, 0.096086, 0.131157, 0.031164, 0.100638, + -0.312191, -0.080923, -0.101318, -0.116614, 0.142238, 0.086540, + -0.139154, 0.174268, -0.073161, 0.080072, 0.006874, 0.229382, + -0.104321, -0.176035, -0.208587, -0.001019, -0.162032, 0.080824, + -0.025021, 0.074460, -0.252595, -0.161750, -0.136403, 0.008308, + 0.005710, 0.096600, 0.289839, 0.218816, -0.304651, -0.070958, + 0.054598, 0.147113, -0.139112, -0.072798, -0.163335, -0.167863, + -0.128762, -0.035780, 0.117262, 0.017177, 0.263335, -0.176612, + 0.262961, -0.093654, -0.339283, 0.333071, 0.180827, 0.287583, + 0.066350, -0.197947, -0.114449, -0.236035, 0.103532, -0.034284, + 0.093299, -0.145361, 0.054001, 0.250570, 0.157010, -0.143480, + -0.139061, -0.048873, 0.067557, 0.139038, 0.324106, 0.227041, + 0.037793, -0.225747, -0.241619, 0.357835, 0.135762, -0.306764, + -0.125982, 0.091916, 0.266587, 0.030135, 0.265148, 0.141627, + 0.020120, 0.083815, -0.124556, -0.100124, -0.048159, 0.181172, + 0.302309, -0.041084, 0.146334, -0.061511, -0.232605, 0.281324, + 0.145408, -0.221897}); + m.SetBias({-0.160594, 0.205770, -0.078307, -0.077984, 0.001937, 0.015860, + 0.036810, 0.012346, 0.001028, 0.038551, 0.075415, 0.020804, + 0.048478, -0.032270, 0.175688, -0.085662}); + + const int input_sequence_size = sizeof(fully_connected_input) / + sizeof(float) / + (m.input_size() * m.num_batches()); + for (int i = 0; i < input_sequence_size; i++) { + // TODO(ahentz): This is what the original test was doing: two equal + // batches per invocation. We could instead use two different batches. 
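+    // Each iteration feeds one input_size()-float slice (8 floats here) of
+    // fully_connected_input into both batches of the model and expects the
+    // corresponding num_units()-float slice (16 floats) of
+    // fully_connected_golden_output, duplicated once per batch, so
+    // input_sequence_size is simply the number of such slices available in
+    // the input array.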
+ float* batch_start = fully_connected_input + i * m.input_size(); + float* batch_end = batch_start + m.input_size(); + m.SetInput(0, batch_start, batch_end); + m.SetInput(m.input_size(), batch_start, batch_end); + + m.Invoke(); + + float* golden_start = fully_connected_golden_output + i * m.num_units(); + float* golden_end = golden_start + m.num_units(); + std::vector expected; + expected.insert(expected.end(), golden_start, golden_end); + expected.insert(expected.end(), golden_start, golden_end); + + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); + } +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + // On Linux, add: tflite::LogToStderr(); + tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/gemm_support.cc b/tensorflow/contrib/lite/kernels/gemm_support.cc new file mode 100644 index 0000000000..eb2b0aacf7 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/gemm_support.cc @@ -0,0 +1,68 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/kernels/gemm_support.h" + +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace gemm_support { + +struct RefCountedGemmContext { + gemmlowp::GemmContext* gemm_context_ = nullptr; + int num_references_ = 0; +}; + +void IncrementUsageCounter(TfLiteContext* context) { + auto* ptr = reinterpret_cast(context->gemm_context); + if (ptr == nullptr) { + ptr = new RefCountedGemmContext; + ptr->gemm_context_ = new gemmlowp::GemmContext(); + ptr->num_references_ = 0; + context->gemm_context = ptr; + } + ptr->num_references_++; +} + +void DecrementUsageCounter(TfLiteContext* context) { + auto* ptr = reinterpret_cast(context->gemm_context); + if (ptr == nullptr) { + TF_LITE_FATAL( + "Call to DecrementUsageCounter() not preceded by " + "IncrementUsageCounter()"); + } + if (--ptr->num_references_ == 0) { + delete ptr->gemm_context_; + delete ptr; + context->gemm_context = nullptr; + } +} + +gemmlowp::GemmContext* GetFromContext(TfLiteContext* context) { + auto* ptr = reinterpret_cast(context->gemm_context); + if (ptr == nullptr) { + TF_LITE_FATAL( + "Call to GetFromContext() not preceded by IncrementUsageCounter()"); + } + return ptr->gemm_context_; +} + +void SetMaxNumThreads(TfLiteContext* context, int num_threads) { + IncrementUsageCounter(context); + GetFromContext(context)->set_max_num_threads(num_threads); + DecrementUsageCounter(context); +} + +} // namespace gemm_support +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/gemm_support.h b/tensorflow/contrib/lite/kernels/gemm_support.h new file mode 100644 index 0000000000..b531959ffb --- /dev/null +++ b/tensorflow/contrib/lite/kernels/gemm_support.h @@ -0,0 +1,54 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_GEMM_SUPPORT_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_GEMM_SUPPORT_H_ + +#include "public/gemmlowp.h" +#include "tensorflow/contrib/lite/context.h" + +namespace tflite { +namespace gemm_support { + +// Returns the GemmContext stored in 'context', allowing multiple ops to +// share a single object, as long as they share a TfLiteContext. The caller +// must ensure that this is called between IncrementUsageCounter() and +// DecrementUsageCounter(). For example, in the implementation of an op: +// void* Init(TfLiteContext* context, const char*, size_t) { +// gemm_support::IncrementUsageCounter(context); +// return nullptr; +// } +// void Free(TfLiteContext* context, void*) { +// gemm_support::DecrementUsageCounter(context); +// } +// TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { +// auto* gemm_context = gemm_support::GetFromContext(context); +// } +gemmlowp::GemmContext* GetFromContext(TfLiteContext* context); + +// Let the framework know that the GemmContext stored in 'context' will be used +// by an op. If necessary a new GemmContext is created and placed in 'context'. +void IncrementUsageCounter(TfLiteContext* context); + +// Let the framework know that the op stopped using the GemmContext stored in +// 'context'. If there are no more usages the GemmContext will be deleted. +void DecrementUsageCounter(TfLiteContext* context); + +// Set the maximum number threads available for gemmlowp operations. +void SetMaxNumThreads(TfLiteContext* context, int num_threads); + +} // namespace gemm_support +} // namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_GEMM_SUPPORT_H_ diff --git a/tensorflow/contrib/lite/kernels/hashtable_lookup.cc b/tensorflow/contrib/lite/kernels/hashtable_lookup.cc new file mode 100644 index 0000000000..3b82601d11 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/hashtable_lookup.cc @@ -0,0 +1,155 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Op that looks up items from hashtable. +// +// Input: +// Tensor[0]: Hash key to lookup, dim.size == 1, int32 +// Tensor[1]: Key of hashtable, dim.size == 1, int32 +// *MUST* be sorted in ascending order. 
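+//            (The sorted-order requirement exists because Eval() below
+//            locates keys with bsearch(); the `greater` comparator, despite
+//            its name, returns *a - *b and therefore implements an
+//            ascending-order comparison.)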
+// Tensor[2]: Value of hashtable, dim.size >= 1 +// Tensor[1].Dim[0] == Tensor[2].Dim[0] +// +// Output: +// Output[0].dim[0] == Tensor[0].dim[0], num of lookups +// Each item in output is a raw bytes copy of corresponding item in input. +// When key does not exist in hashtable, the returned bytes are all 0s. +// +// Output[1].dim = { Tensor[0].dim[0] }, num of lookups +// Each item indicates whether the corresponding lookup has a returned value. +// 0 for missing key, 1 for found key. + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" +#include "tensorflow/contrib/lite/string_util.h" + +namespace tflite { +namespace ops { +namespace builtin { + +namespace { + +int greater(const void* a, const void* b) { + return *static_cast(a) - *static_cast(b); +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 3); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 2); + + TfLiteTensor* lookup = GetInput(context, node, 0); + TF_LITE_ENSURE_EQ(context, NumDimensions(lookup), 1); + TF_LITE_ENSURE_EQ(context, lookup->type, kTfLiteInt32); + + TfLiteTensor* key = GetInput(context, node, 1); + TF_LITE_ENSURE_EQ(context, NumDimensions(key), 1); + TF_LITE_ENSURE_EQ(context, key->type, kTfLiteInt32); + + TfLiteTensor* value = GetInput(context, node, 2); + TF_LITE_ENSURE(context, NumDimensions(value) >= 1); + TF_LITE_ENSURE_EQ(context, SizeOfDimension(key, 0), + SizeOfDimension(value, 0)); + if (value->type == kTfLiteString) { + TF_LITE_ENSURE_EQ(context, NumDimensions(value), 1); + } + + TfLiteTensor* hits = GetOutput(context, node, 1); + TF_LITE_ENSURE_EQ(context, hits->type, kTfLiteUInt8); + TfLiteIntArray* hitSize = TfLiteIntArrayCreate(1); + hitSize->data[0] = SizeOfDimension(lookup, 0); + + TfLiteTensor* output = GetOutput(context, node, 0); + TF_LITE_ENSURE_EQ(context, value->type, output->type); + + TfLiteStatus status = kTfLiteOk; + if (output->type != kTfLiteString) { + TfLiteIntArray* outputSize = TfLiteIntArrayCreate(NumDimensions(value)); + outputSize->data[0] = SizeOfDimension(lookup, 0); + for (int i = 1; i < NumDimensions(value); i++) { + outputSize->data[i] = SizeOfDimension(value, i); + } + status = context->ResizeTensor(context, output, outputSize); + } + if (context->ResizeTensor(context, hits, hitSize) == kTfLiteError) { + status = kTfLiteError; + } + return status; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor* output = GetOutput(context, node, 0); + TfLiteTensor* hits = GetOutput(context, node, 1); + TfLiteTensor* lookup = GetInput(context, node, 0); + TfLiteTensor* key = GetInput(context, node, 1); + TfLiteTensor* value = GetInput(context, node, 2); + + const int num_rows = SizeOfDimension(value, 0); + const int row_bytes = value->bytes / num_rows; + void* pointer = nullptr; + DynamicBuffer buf; + + for (int i = 0; i < SizeOfDimension(lookup, 0); i++) { + int idx = -1; + pointer = bsearch(&(lookup->data.i32[i]), key->data.i32, num_rows, + sizeof(int32_t), greater); + if (pointer != nullptr) { + idx = (reinterpret_cast(pointer) - (key->data.raw)) / + sizeof(int32_t); + } + + if (idx >= num_rows || idx < 0) { + if (output->type == kTfLiteString) { + buf.AddString(nullptr, 0); + } else { + memset(output->data.raw + i * row_bytes, 0, row_bytes); + } + 
hits->data.uint8[i] = 0; + } else { + if (output->type == kTfLiteString) { + buf.AddString(GetString(value, idx)); + } else { + memcpy(output->data.raw + i * row_bytes, + value->data.raw + idx * row_bytes, row_bytes); + } + hits->data.uint8[i] = 1; + } + } + if (output->type == kTfLiteString) { + buf.WriteToTensor(output); + } + + return kTfLiteOk; +} +} // namespace + +TfLiteRegistration* Register_HASHTABLE_LOOKUP() { + static TfLiteRegistration r = {nullptr, nullptr, Prepare, Eval}; + return &r; +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/hashtable_lookup_test.cc b/tensorflow/contrib/lite/kernels/hashtable_lookup_test.cc new file mode 100644 index 0000000000..916a23225e --- /dev/null +++ b/tensorflow/contrib/lite/kernels/hashtable_lookup_test.cc @@ -0,0 +1,176 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Unit test for TFLite Lookup op. + +#include +#include + +#include +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/contrib/lite/string_util.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class HashtableLookupOpModel : public SingleOpModel { + public: + HashtableLookupOpModel(std::initializer_list lookup_shape, + std::initializer_list key_shape, + std::initializer_list value_shape, + TensorType type) { + lookup_ = AddInput(TensorType_INT32); + key_ = AddInput(TensorType_INT32); + value_ = AddInput(type); + output_ = AddOutput(type); + hit_ = AddOutput(TensorType_UINT8); + SetBuiltinOp(BuiltinOperator_HASHTABLE_LOOKUP, BuiltinOptions_NONE, 0); + BuildInterpreter({lookup_shape, key_shape, value_shape}); + } + + void SetLookup(std::initializer_list data) { + PopulateTensor(lookup_, data); + } + + void SetHashtableKey(std::initializer_list data) { + PopulateTensor(key_, data); + } + + void SetHashtableValue(const std::vector& content) { + PopulateStringTensor(value_, content); + } + + void SetHashtableValue(const std::function& function) { + TfLiteTensor* tensor = interpreter_->tensor(value_); + int rows = tensor->dims->data[0]; + for (int i = 0; i < rows; i++) { + tensor->data.f[i] = function(i); + } + } + + void SetHashtableValue(const std::function& function) { + TfLiteTensor* tensor = interpreter_->tensor(value_); + int rows = tensor->dims->data[0]; + int features = tensor->dims->data[1]; + for (int i = 0; i < rows; i++) { + for (int j = 0; j < features; j++) { + tensor->data.f[i * features + j] = function(i, j); + } + } + } + + std::vector GetStringOutput() { + TfLiteTensor* output = interpreter_->tensor(output_); + int num = GetStringCount(output); + std::vector result(num); + for (int i = 0; i < num; i++) { + auto ref = GetString(output, i); + result[i] = 
string(ref.str, ref.len); + } + return result; + } + + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetHit() { return ExtractVector(hit_); } + + private: + int lookup_; + int key_; + int value_; + int output_; + int hit_; +}; + +// TODO(yichengfan): write more tests that exercise the details of the op, +// such as lookup errors and variable input shapes. +TEST(HashtableLookupOpTest, Test2DInput) { + HashtableLookupOpModel m({4}, {3}, {3, 2}, TensorType_FLOAT32); + + m.SetLookup({1234, -292, -11, 0}); + m.SetHashtableKey({-11, 0, 1234}); + m.SetHashtableValue([](int i, int j) { return i + j / 10.0f; }); + + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 2.0, 2.1, // 2-nd item + 0, 0, // Not found + 0.0, 0.1, // 0-th item + 1.0, 1.1, // 1-st item + }))); + EXPECT_THAT(m.GetHit(), ElementsAreArray({ + 1, 0, 1, 1, + })); +} + +TEST(HashtableLookupOpTest, Test1DInput) { + HashtableLookupOpModel m({4}, {3}, {3}, TensorType_FLOAT32); + + m.SetLookup({1234, -292, -11, 0}); + m.SetHashtableKey({-11, 0, 1234}); + m.SetHashtableValue([](int i) { return i * i / 10.0f; }); + + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 0.4, // 2-nd item + 0, // Not found + 0.0, // 0-th item + 0.1, // 1-st item + }))); + EXPECT_THAT(m.GetHit(), ElementsAreArray({ + 1, + 0, + 1, + 1, + })); +} + +TEST(HashtableLookupOpTest, TestString) { + HashtableLookupOpModel m({4}, {3}, {3}, TensorType_STRING); + + m.SetLookup({1234, -292, -11, 0}); + m.SetHashtableKey({-11, 0, 1234}); + m.SetHashtableValue({"Hello", "", "Hi"}); + + m.Invoke(); + + EXPECT_THAT(m.GetStringOutput(), ElementsAreArray({ + "Hi", // 2-nd item + "", // Not found + "Hello", // 0-th item + "", // 1-st item + })); + EXPECT_THAT(m.GetHit(), ElementsAreArray({ + 1, + 0, + 1, + 1, + })); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + // On Linux, add: tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/internal/BUILD b/tensorflow/contrib/lite/kernels/internal/BUILD new file mode 100644 index 0000000000..288534099b --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/BUILD @@ -0,0 +1,359 @@ +package(default_visibility = [ + "//visibility:public", +]) + +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts") + +tflite_deps_intel = [ + "@arm_neon_2_x86_sse", +] + +NEON_FLAGS_IF_APPLICABLE = select({ + ":arm": [ + "-O3", + "-mfpu=neon", + "-mfloat-abi=softfp", + ], + ":armeabi-v7a": [ + "-O3", + "-mfpu=neon", + "-mfloat-abi=softfp", + ], + ":armv7a": [ + "-O3", + "-mfpu=neon", + "-mfloat-abi=softfp", + ], + "//conditions:default": [ + "-O3", + ], +}) + +cc_library( + name = "types", + srcs = [], + hdrs = [ + "compatibility.h", + "types.h", + ], +) + +config_setting( + name = "arm", + values = { + "cpu": "arm", + }, +) + +config_setting( + name = "arm64-v8a", + values = { + "cpu": "arm64-v8a", + }, +) + +config_setting( + name = "armv7a", + values = { + "cpu": "armv7a", + }, +) + +config_setting( + name = "armeabi-v7a", + values = { + "cpu": "armeabi-v7a", + }, +) + +config_setting( + name = "haswell", + values = { + "cpu": "haswell", + }, +) + +config_setting( + name = "ios_x86_64", + values = { + "cpu": "ios_x86_64", + }, +) + +config_setting( + name = "ios_armv7", + values = { + "cpu": "ios_armv7", + }, +) + +config_setting( + name = "ios_arm64", + values = { + "cpu": "ios_arm64", + }, +) + 
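+# The config_setting blocks above and below exist only to drive the select()
+# expressions in this file: NEON_FLAGS_IF_APPLICABLE picks "-mfpu=neon
+# -mfloat-abi=softfp" for the 32-bit ARM CPU values, tflite_deps_intel is
+# appended to deps for the x86-family CPU values, and :tensor_utils below
+# swaps between :neon_tensor_utils and :portable_tensor_utils per CPU.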
+config_setting( + name = "k8", + values = { + "cpu": "k8", + }, +) + +config_setting( + name = "x86", + values = { + "cpu": "x86", + }, +) + +config_setting( + name = "x86_64", + values = { + "cpu": "x86_64", + }, +) + +config_setting( + name = "darwin", + values = { + "cpu": "darwin", + }, +) + +cc_library( + name = "optimized_base", + srcs = [], + hdrs = [ + "common.h", + "optimized/depthwiseconv_float.h", + "optimized/depthwiseconv_uint8.h", + "optimized/optimized_ops.h", + ], + copts = tflite_copts(), + deps = [ + ":types", + ":round", + "//third_party/eigen3", + "@gemmlowp//:gemmlowp", + "//tensorflow/contrib/lite:builtin_op_data", + ] + select({ + ":haswell": tflite_deps_intel, + ":ios_x86_64": tflite_deps_intel, + ":k8": tflite_deps_intel, + ":x86": tflite_deps_intel, + ":x86_64": tflite_deps_intel, + ":darwin": tflite_deps_intel, + "//conditions:default": [], + }), +) + +cc_library( + name = "optimized", + hdrs = [ + "optimized/eigen_spatial_convolutions.h", + "optimized/eigen_tensor_reduced_instantiations_oss.h", + "optimized/multithreaded_conv.h", + "tensor.h", + ], + deps = [ + ":optimized_base", + ":types", + "//tensorflow/contrib/lite:builtin_op_data", + "//tensorflow/contrib/lite:context", + "//third_party/eigen3", + ], +) + +cc_test( + name = "tensor_test", + srcs = ["tensor_test.cc"], + deps = [ + ":reference", + "@com_google_googletest//:gtest", + ], +) + +cc_library( + name = "round", + srcs = [], + hdrs = ["round.h"], +) + +cc_library( + name = "quantization_util", + srcs = ["quantization_util.cc"], + hdrs = [ + "compatibility.h", + "quantization_util.h", + ], + deps = [":round"], +) + +cc_test( + name = "quantization_util_test", + srcs = ["quantization_util_test.cc"], + deps = [ + ":quantization_util", + "@com_google_googletest//:gtest", + ], +) + +cc_library( + name = "reference_base", + srcs = [], + hdrs = [ + "common.h", + "reference/depthwiseconv_float.h", + "reference/depthwiseconv_uint8.h", + "reference/reference_ops.h", + ], + deps = [ + ":round", + ":types", + "//third_party/eigen3", + "@gemmlowp//:gemmlowp", + "//tensorflow/contrib/lite:builtin_op_data", + ] + select({ + ":haswell": tflite_deps_intel, + ":ios_x86_64": tflite_deps_intel, + ":k8": tflite_deps_intel, + ":x86": tflite_deps_intel, + ":x86_64": tflite_deps_intel, + ":darwin": tflite_deps_intel, + "//conditions:default": [], + }), +) + +cc_library( + name = "reference", + hdrs = ["tensor.h"], + deps = [ + ":types", + "//tensorflow/contrib/lite:context", + ], +) + +cc_library( + name = "portable_tensor_utils", + srcs = [ + "reference/portable_tensor_utils.cc", + ], + hdrs = [ + "reference/portable_tensor_utils.h", + ], + deps = [ + "//tensorflow/contrib/lite:builtin_op_data", + "//tensorflow/contrib/lite/kernels:activation_functor", + "//tensorflow/contrib/lite/kernels:op_macros", + ], +) + +cc_library( + name = "neon_tensor_utils", + srcs = [ + "optimized/neon_tensor_utils.cc", + ], + hdrs = [ + "optimized/neon_tensor_utils.h", + "optimized/tensor_utils_impl.h", + ], + copts = NEON_FLAGS_IF_APPLICABLE, + deps = [ + ":cpu_check", + ":portable_tensor_utils", + "//tensorflow/contrib/lite:builtin_op_data", + "//tensorflow/contrib/lite/kernels:activation_functor", + ], +) + +cc_library( + name = "tensor_utils", + srcs = [ + "tensor_utils.cc", + ], + hdrs = [ + "optimized/tensor_utils_impl.h", + "reference/portable_tensor_utils.h", + "tensor_utils.h", + ], + copts = NEON_FLAGS_IF_APPLICABLE, + deps = [ + "//tensorflow/contrib/lite/kernels:activation_functor", + "//tensorflow/contrib/lite:builtin_op_data", 
+ ] + select({ + ":arm": [ + ":neon_tensor_utils", + ], + ":arm64-v8a": [ + ":neon_tensor_utils", + ], + ":armeabi-v7a": [ + ":neon_tensor_utils", + ], + ":armv7a": [ + ":neon_tensor_utils", + ], + ":ios_armv7": [ + ":neon_tensor_utils", + ], + ":ios_arm64": [ + ":neon_tensor_utils", + ], + "//conditions:default": [ + ":portable_tensor_utils", + ], + }), +) + +cc_test( + name = "tensor_utils_test", + srcs = ["tensor_utils_test.cc"], + copts = NEON_FLAGS_IF_APPLICABLE, + linkopts = select({ + "//tensorflow:android": [ + "-fPIE -pie", + ], + "//conditions:default": [], + }), + linkstatic = 1, + deps = [ + ":tensor_utils", + "//tensorflow/contrib/lite:builtin_op_data", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "cpu_check", + hdrs = [ + "optimized/cpu_check.h", + ], + deps = [ + ] + select( + { + "//tensorflow:android": [ + "@androidndk//:cpufeatures", + ], + "//conditions:default": [], + }, + ), +) + +exports_files(["optimized/eigen_tensor_reduced_instantiations_oss.h"]) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/contrib/lite/kernels/internal/common.h b/tensorflow/contrib/lite/kernels/internal/common.h new file mode 100644 index 0000000000..28f19a2506 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/common.h @@ -0,0 +1,107 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_COMMON_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_COMMON_H_ + +#ifndef ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK +#ifdef GEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK +#define ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK +#endif +#endif + +#ifndef USE_NEON +#if defined(__ARM_NEON__) || defined(__ARM_NEON) +#define USE_NEON +#include +#endif + +#if defined __GNUC__ && defined __SSE4_1__ +#define USE_NEON + +#define OPTIMIZED_OPS_H__IGNORE_DEPRECATED_DECLARATIONS +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wdeprecated-declarations" +#pragma GCC diagnostic ignored "-Wattributes" + +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wnarrowing" +#pragma GCC diagnostic ignored "-Wsequence-point" + +#include "NEON_2_SSE.h" + +#pragma GCC diagnostic pop +#endif +#endif + +#include "public/gemmlowp.h" +#include "tensorflow/contrib/lite/kernels/internal/types.h" + +namespace tflite { + +inline void GetActivationMinMax(FusedActivationFunctionType ac, + float* output_activation_min, + float* output_activation_max) { + switch (ac) { + case FusedActivationFunctionType::kNone: + *output_activation_min = std::numeric_limits::lowest(); + *output_activation_max = std::numeric_limits::max(); + break; + case FusedActivationFunctionType::kRelu: + *output_activation_min = 0.f; + *output_activation_max = std::numeric_limits::max(); + break; + case FusedActivationFunctionType::kRelu1: + *output_activation_min = -1.f; + *output_activation_max = 1.f; + break; + case FusedActivationFunctionType::kRelu6: + *output_activation_min = 0.f; + *output_activation_max = 6.f; + break; + } +} + +inline float ActivationFunctionWithMinMax(float x, float output_activation_min, + float output_activation_max) { + return std::min(std::max(x, output_activation_min), output_activation_max); +} + +// Legacy function, left for compatibility only. +template +float ActivationFunction(float x) { + float output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + return ActivationFunctionWithMinMax(x, output_activation_min, + output_activation_max); +} + +inline int32 MultiplyByQuantizedMultiplierSmallerThanOne( + int32 x, int32 quantized_multiplier, int right_shift) { + using gemmlowp::RoundingDivideByPOT; + using gemmlowp::SaturatingRoundingDoublingHighMul; + return RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(x, quantized_multiplier), right_shift); +} + +inline int32 MultiplyByQuantizedMultiplierGreaterThanOne( + int32 x, int32 quantized_multiplier, int left_shift) { + using gemmlowp::SaturatingRoundingDoublingHighMul; + return SaturatingRoundingDoublingHighMul(x * (1 << left_shift), + quantized_multiplier); +} + +} // namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_COMMON_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/compatibility.h b/tensorflow/contrib/lite/kernels/internal/compatibility.h new file mode 100644 index 0000000000..796a03566a --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/compatibility.h @@ -0,0 +1,78 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_COMPATIBILITY_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_COMPATIBILITY_H_ + +#include +#include +#include + +#ifndef TFLITE_DCHECK +#define TFLITE_DCHECK(condition) (condition) ? (void)0 : assert(false) +#endif + +#ifndef TFLITE_DCHECK_EQ +#define TFLITE_DCHECK_EQ(x, y) ((x) == (y)) ? (void)0 : assert(false) +#endif + +#ifndef TFLITE_DCHECK_GE +#define TFLITE_DCHECK_GE(x, y) ((x) >= (y)) ? (void)0 : assert(false) +#endif + +#ifndef TFLITE_DCHECK_GT +#define TFLITE_DCHECK_GT(x, y) ((x) > (y)) ? (void)0 : assert(false) +#endif + +#ifndef TFLITE_DCHECK_LE +#define TFLITE_DCHECK_LE(x, y) ((x) <= (y)) ? (void)0 : assert(false) +#endif + +#ifndef TFLITE_DCHECK_LT +#define TFLITE_DCHECK_LT(x, y) ((x) < (y)) ? (void)0 : assert(false) +#endif + +// TODO(ahentz): Clean up: We should stick to the DCHECK versions. +#ifndef TFLITE_CHECK +#define TFLITE_CHECK(condition) (condition) ? (void)0 : abort() +#endif + +#ifndef TFLITE_CHECK_EQ +#define TFLITE_CHECK_EQ(x, y) ((x) == (y)) ? (void)0 : abort() +#endif + +#ifndef TFLITE_CHECK_GE +#define TFLITE_CHECK_GE(x, y) ((x) >= (y)) ? (void)0 : abort() +#endif + +#ifndef TFLITE_CHECK_GT +#define TFLITE_CHECK_GT(x, y) ((x) > (y)) ? (void)0 : abort() +#endif + +#ifndef TFLITE_CHECK_LE +#define TFLITE_CHECK_LE(x, y) ((x) <= (y)) ? (void)0 : abort() +#endif + +#ifndef TFLITE_CHECK_LT +#define TFLITE_CHECK_LT(x, y) ((x) < (y)) ? (void)0 : abort() +#endif + +// TODO(ahentz): Clean up. +using uint8 = std::uint8_t; +using int16 = std::int16_t; +using uint16 = std::uint16_t; +using int32 = std::int32_t; +using uint32 = std::uint32_t; + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_COMPATIBILITY_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/cpu_check.h b/tensorflow/contrib/lite/kernels/internal/optimized/cpu_check.h new file mode 100644 index 0000000000..dea46cc120 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/optimized/cpu_check.h @@ -0,0 +1,65 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_CPU_CHECK_ +#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_CPU_CHECK_ + +namespace tflite { + +#ifdef __ANDROID__ +#include "ndk/sources/android/cpufeatures/cpu-features.h" + +// Runtime check for Neon support on Android. 
+inline bool TestCPUFeatureNeon() { +#ifdef __aarch64__ + // ARM-64 always has NEON support. + return true; +#else + static bool kUseAndroidNeon = + (android_getCpuFamily() == ANDROID_CPU_FAMILY_ARM && + android_getCpuFeatures() & ANDROID_CPU_ARM_FEATURE_ARMv7 && + android_getCpuFeatures() & ANDROID_CPU_ARM_FEATURE_NEON); + return kUseAndroidNeon; +#endif // __aarch64__ +} + +#elif __ARM_NEON + +inline bool TestCPUFeatureNeon() { + return true; +} + +#else + +inline bool TestCPUFeatureNeon() { + return false; +} + +#endif + +} // namespace tflite + +// NEON_OR_PORTABLE(SomeFunc, arcs) calls NeonSomeFunc(args) if Neon is both +// enabled at build time and detected at runtime, or PortableSomeFunc(args) +// otherwise. +#ifdef __ARM_ARCH_5TE__ +// Neon isn't available at all on ARMv5. +#define NEON_OR_PORTABLE(funcname, ...) Portable##funcname(__VA_ARGS__) +#else +#define NEON_OR_PORTABLE(funcname, ...) \ + TestCPUFeatureNeon() ? Neon##funcname(__VA_ARGS__) \ + : Portable##funcname(__VA_ARGS__) +#endif + +#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_CPU_CHECK_ diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h new file mode 100644 index 0000000000..974611f52a --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h @@ -0,0 +1,987 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_FLOAT_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_FLOAT_H_ + +#include "public/gemmlowp.h" +#include "tensorflow/contrib/lite/kernels/internal/common.h" +#include "tensorflow/contrib/lite/kernels/internal/types.h" + +namespace tflite { +namespace optimized_ops { + +// Implementation of float DepthwiseConv + +template +struct FloatDepthwiseConvKernel {}; + +#ifdef USE_NEON + +template <> +struct FloatDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const float* input_ptr, int input_ptr_increment, + const float* filter_ptr, float* acc_buffer_ptr) { + // Load the filters + float32x4_t filter[2]; + for (int i = 0; i < 2; i++) { + filter[i] = vld1q_f32(filter_ptr + 4 * i); + } + int outp = 0; + // Handle 2 output pixels at a time. 
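+    // This specialization handles 8 channels per output pixel with a depth
+    // multiplier of 1: each pixel occupies two float32x4 accumulator
+    // registers, so the 2-pixel iteration below loads 16 inputs and 16
+    // accumulators while reusing the same two filter registers for both
+    // pixels.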
+ for (; outp <= num_output_pixels - 2; outp += 2) { + // Load the inputs + float32x4_t input[4]; + for (int i = 0; i < 4; i++) { + input[i] = vld1q_f32(input_ptr + 4 * i); + } + input_ptr += 16; + // Load the accumulators from acc_buffer + float32x4_t acc[4]; + for (int i = 0; i < 4; i++) { + acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i); + } + // Multiply-accumulate + acc[0] = vmlaq_f32(acc[0], input[0], filter[0]); + acc[1] = vmlaq_f32(acc[1], input[1], filter[1]); + acc[2] = vmlaq_f32(acc[2], input[2], filter[0]); + acc[3] = vmlaq_f32(acc[3], input[3], filter[1]); + // Store the accumulators back to acc_buffer + for (int i = 0; i < 4; i++) { + vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 16; + } + // Handle one output pixel at a time. + for (; outp < num_output_pixels; outp++) { + // Load the inputs + float32x4_t input[2]; + for (int i = 0; i < 2; i++) { + input[i] = vld1q_f32(input_ptr + 4 * i); + } + input_ptr += 8; + // Load the accumulators from acc_buffer + float32x4_t acc[2]; + for (int i = 0; i < 2; i++) { + acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i); + } + // Multiply-accumulate + for (int i = 0; i < 2; i++) { + acc[i] = vmlaq_f32(acc[i], input[i], filter[i]); + } + // Store the accumulators back to acc_buffer + for (int i = 0; i < 2; i++) { + vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 8; + } + } +}; + +template <> +struct FloatDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const float* input_ptr, int input_ptr_increment, + const float* filter_ptr, float* acc_buffer_ptr) { + const float32x2_t filters = vld1_f32(filter_ptr); + const float32x4_t filters_dup2 = vcombine_f32(filters, filters); + int outp = 0; + // Handle 8 output pixels at a time. + for (; outp <= num_output_pixels - 8; outp += 8) { + // Load the inputs + float32x4_t input[4]; + for (int i = 0; i < 4; i++) { + input[i] = vld1q_f32(input_ptr + 4 * i); + } + input_ptr += 16; + // Load the accumulators from acc_buffer + float32x4_t acc[4]; + for (int i = 0; i < 4; i++) { + acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i); + } + // Multiply-accumulate + for (int i = 0; i < 4; i++) { + acc[i] = vmlaq_f32(acc[i], input[i], filters_dup2); + } + // Store the accumulators back to acc_buffer + for (int i = 0; i < 4; i++) { + vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 16; + } + // Handle 4 output pixels at a time. + for (; outp <= num_output_pixels - 4; outp += 4) { + // Load the inputs + float32x4_t input[2]; + for (int i = 0; i < 2; i++) { + input[i] = vld1q_f32(input_ptr + 4 * i); + } + input_ptr += 8; + // Load the accumulators from acc_buffer + float32x4_t acc[2]; + for (int i = 0; i < 2; i++) { + acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i); + } + // Multiply-accumulate + for (int i = 0; i < 2; i++) { + acc[i] = vmlaq_f32(acc[i], input[i], filters_dup2); + } + // Store the accumulators back to acc_buffer + for (int i = 0; i < 2; i++) { + vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 8; + } + // Handle 2 output pixels at a time. 
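+    // With two channels per output pixel, two pixels fit in one float32x4,
+    // which is why the 2-element filter was duplicated into filters_dup2
+    // above; each iteration below consumes 4 input floats and 4
+    // accumulators.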
+ for (; outp <= num_output_pixels - 2; outp += 2) { + // Load the inputs + const float32x4_t input = vld1q_f32(input_ptr); + input_ptr += 4; + // Load the accumulators from acc_buffer + float32x4_t acc = vld1q_f32(acc_buffer_ptr); + // Multiply-accumulate + acc = vmlaq_f32(acc, input, filters_dup2); + // Store the accumulators back to acc_buffer + vst1q_f32(acc_buffer_ptr, acc); + acc_buffer_ptr += 4; + } + // Handle 1 output pixel at a time + for (; outp < num_output_pixels; outp++) { + // Load the inputs + const float32x2_t input = vld1_f32(input_ptr); + input_ptr += 2; + // Load the accumulators from acc_buffer + float32x2_t acc = vld1_f32(acc_buffer_ptr); + // Multiply-accumulate + acc = vmla_f32(acc, input, filters); + // Store the accumulators back to acc_buffer + vst1_f32(acc_buffer_ptr, acc); + acc_buffer_ptr += 2; + } + } +}; + +template <> +struct FloatDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const float* input_ptr, int input_ptr_increment, + const float* filter_ptr, float* acc_buffer_ptr) { + // Handle one output pixel at a time. + for (int outp = 0; outp < num_output_pixels; outp++) { + const float* local_filter_ptr = filter_ptr; + const float* local_input_ptr = input_ptr; + int ic = 0; + // Handle 16 input channels at a time. + for (; ic <= input_depth - 16; ic += 16) { + // Load the filters + float32x4_t filter_0 = vld1q_f32(local_filter_ptr + 4 * 0); + float32x4_t filter_1 = vld1q_f32(local_filter_ptr + 4 * 1); + float32x4_t filter_2 = vld1q_f32(local_filter_ptr + 4 * 2); + float32x4_t filter_3 = vld1q_f32(local_filter_ptr + 4 * 3); + local_filter_ptr += 16; + // Load the inputs + float32x4_t input_0 = vld1q_f32(local_input_ptr + 4 * 0); + float32x4_t input_1 = vld1q_f32(local_input_ptr + 4 * 1); + float32x4_t input_2 = vld1q_f32(local_input_ptr + 4 * 2); + float32x4_t input_3 = vld1q_f32(local_input_ptr + 4 * 3); + local_input_ptr += 16; + // Load the accumulators from acc_buffer + float32x4_t acc_0 = vld1q_f32(acc_buffer_ptr + 4 * 0); + float32x4_t acc_1 = vld1q_f32(acc_buffer_ptr + 4 * 1); + float32x4_t acc_2 = vld1q_f32(acc_buffer_ptr + 4 * 2); + float32x4_t acc_3 = vld1q_f32(acc_buffer_ptr + 4 * 3); + // Multiply-accumulate + acc_0 = vmlaq_f32(acc_0, input_0, filter_0); + acc_1 = vmlaq_f32(acc_1, input_1, filter_1); + acc_2 = vmlaq_f32(acc_2, input_2, filter_2); + acc_3 = vmlaq_f32(acc_3, input_3, filter_3); + // Store the accumulators back to acc_buffer + vst1q_f32(acc_buffer_ptr + 4 * 0, acc_0); + vst1q_f32(acc_buffer_ptr + 4 * 1, acc_1); + vst1q_f32(acc_buffer_ptr + 4 * 2, acc_2); + vst1q_f32(acc_buffer_ptr + 4 * 3, acc_3); + acc_buffer_ptr += 16; + } + // Handle 4 input channels at a time. + for (; ic <= input_depth - 4; ic += 4) { + // Load the filters + float32x4_t filter; + filter = vld1q_f32(local_filter_ptr); + local_filter_ptr += 4; + // Load the inputs + float32x4_t input; + input = vld1q_f32(local_input_ptr); + local_input_ptr += 4; + // Load the accumulators from acc_buffer + float32x4_t acc; + acc = vld1q_f32(acc_buffer_ptr); + // Multiply-accumulate + acc = vmlaq_f32(acc, input, filter); + // Store the accumulators back to acc_buffer + vst1q_f32(acc_buffer_ptr, acc); + acc_buffer_ptr += 4; + } + // Handle one input channel at a time. 
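+      // Scalar tail: when input_depth is not a multiple of 4, the remaining
+      // channels are handled one at a time with the same
+      // acc += filter * input update that the vector paths above perform
+      // per lane.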
+ for (; ic < input_depth; ic++) { + const float input_val = *local_input_ptr++; + const float filter_val = *local_filter_ptr++; + *acc_buffer_ptr++ += filter_val * input_val; + } + input_ptr += input_ptr_increment; + } + } +}; + +template <> +struct FloatDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const float* input_ptr, int input_ptr_increment, + const float* filter_ptr, float* acc_buffer_ptr) { + // Handle one output pixel at a time. + for (int outp = 0; outp < num_output_pixels; outp++) { + const float* local_filter_ptr = filter_ptr; + const float* local_input_ptr = input_ptr; + int ic = 0; + // Handle 2 input channels at a time. + for (; ic <= input_depth - 2; ic += 2) { + // Load the filters + float32x4_t filter[4]; + for (int i = 0; i < 4; i++) { + filter[i] = vld1q_f32(local_filter_ptr + 4 * i); + } + local_filter_ptr += 16; + // Load the inputs + const float32x2_t input = vld1_f32(local_input_ptr); + local_input_ptr += 2; + // Load the accumulators from acc_buffer + float32x4_t acc[4]; + for (int i = 0; i < 4; i++) { + acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i); + } + // Multiply-accumulate + acc[0] = vmlaq_lane_f32(acc[0], filter[0], input, 0); + acc[1] = vmlaq_lane_f32(acc[1], filter[1], input, 0); + acc[2] = vmlaq_lane_f32(acc[2], filter[2], input, 1); + acc[3] = vmlaq_lane_f32(acc[3], filter[3], input, 1); + // Store the accumulators back to acc_buffer + for (int i = 0; i < 4; i++) { + vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 16; + } + // Handle one input channel at a time. + for (; ic < input_depth; ic++) { + // Load the filters + float32x4_t filter[2]; + for (int i = 0; i < 2; i++) { + filter[i] = vld1q_f32(local_filter_ptr + 4 * i); + } + local_filter_ptr += 8; + // Load the inputs + const float input_val = *local_input_ptr++; + // Load the accumulators from acc_buffer + float32x4_t acc[2]; + for (int i = 0; i < 2; i++) { + acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i); + } + // Multiply-accumulate + for (int i = 0; i < 2; i++) { + acc[i] = vmlaq_n_f32(acc[i], filter[i], input_val); + } + // Store the accumulators back to acc_buffer + for (int i = 0; i < 2; i++) { + vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 8; + } + input_ptr += input_ptr_increment; + } + } +}; + +template <> +struct FloatDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const float* input_ptr, int input_ptr_increment, + const float* filter_ptr, float* acc_buffer_ptr) { + // Handle one output pixel at a time. + for (int outp = 0; outp < num_output_pixels; outp++) { + const float* local_filter_ptr = filter_ptr; + const float* local_input_ptr = input_ptr; + int ic = 0; + // Handle 8 input channels at a time. 
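+      // With a depth multiplier of 2, each input channel feeds two adjacent
+      // output channels, so vzipq_f32(input, input) below duplicates every
+      // input lane, e.g. [a, b, c, d] -> [a, a, b, b] and [c, c, d, d],
+      // lining the duplicated inputs up with 16 consecutive filter values.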
+ for (; ic <= input_depth - 8; ic += 8) { + // Load the filters + float32x4_t filter[4]; + for (int i = 0; i < 4; i++) { + filter[i] = vld1q_f32(local_filter_ptr + 4 * i); + } + local_filter_ptr += 16; + // Load the inputs + float32x4x2_t input_dup2[2]; + for (int i = 0; i < 2; i++) { + const float32x4_t input = vld1q_f32(local_input_ptr + 4 * i); + input_dup2[i] = vzipq_f32(input, input); + } + local_input_ptr += 8; + // Load the accumulators from acc_buffer + float32x4_t acc[4]; + for (int i = 0; i < 4; i++) { + acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i); + } + // Multiply-accumulate + acc[0] = vmlaq_f32(acc[0], filter[0], input_dup2[0].val[0]); + acc[1] = vmlaq_f32(acc[1], filter[1], input_dup2[0].val[1]); + acc[2] = vmlaq_f32(acc[2], filter[2], input_dup2[1].val[0]); + acc[3] = vmlaq_f32(acc[3], filter[3], input_dup2[1].val[1]); + // Store the accumulators back to acc_buffer + for (int i = 0; i < 4; i++) { + vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 16; + } + // Handle 4 input channels at a time. + for (; ic <= input_depth - 4; ic += 4) { + // Load the filters + float32x2_t filter[4]; + for (int i = 0; i < 4; i++) { + filter[i] = vld1_f32(local_filter_ptr + 2 * i); + } + local_filter_ptr += 8; + // Load the inputs + const float32x4_t input = vld1q_f32(local_input_ptr); + local_input_ptr += 4; + // Load the accumulators from acc_buffer + float32x2_t acc[4]; + for (int i = 0; i < 4; i++) { + acc[i] = vld1_f32(acc_buffer_ptr + 2 * i); + } + // Multiply-accumulate + acc[0] = vmla_lane_f32(acc[0], filter[0], vget_low_f32(input), 0); + acc[1] = vmla_lane_f32(acc[1], filter[1], vget_low_f32(input), 1); + acc[2] = vmla_lane_f32(acc[2], filter[2], vget_high_f32(input), 0); + acc[3] = vmla_lane_f32(acc[3], filter[3], vget_high_f32(input), 1); + // Store the accumulators back to acc_buffer + for (int i = 0; i < 4; i++) { + vst1_f32(acc_buffer_ptr + 2 * i, acc[i]); + } + acc_buffer_ptr += 8; + } + // Handle 2 input channels at a time. + for (; ic <= input_depth - 2; ic += 2) { + // Load the filters + const float32x4_t filter = vld1q_f32(local_filter_ptr); + local_filter_ptr += 4; + // Load the inputs + const float32x2_t input = vld1_f32(local_input_ptr); + local_input_ptr += 2; + // Load the accumulators from acc_buffer + float32x2_t acc[2]; + for (int i = 0; i < 2; i++) { + acc[i] = vld1_f32(acc_buffer_ptr + 2 * i); + } + // Multiply-accumulate + acc[0] = vmla_lane_f32(acc[0], vget_low_f32(filter), input, 0); + acc[1] = vmla_lane_f32(acc[1], vget_high_f32(filter), input, 1); + // Store the accumulators back to acc_buffer + for (int i = 0; i < 2; i++) { + vst1_f32(acc_buffer_ptr + 2 * i, acc[i]); + } + acc_buffer_ptr += 4; + } + // Handle one input channel at a time. + for (; ic < input_depth; ic++) { + // Load the inputs + const float input_val = *local_input_ptr++; + // Multiply-accumulate + for (int i = 0; i < 2; i++) { + acc_buffer_ptr[i] += local_filter_ptr[i] * input_val; + } + local_filter_ptr += 2; + acc_buffer_ptr += 2; + } + input_ptr += input_ptr_increment; + } + } +}; + +template <> +struct FloatDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const float* input_ptr, int input_ptr_increment, + const float* filter_ptr, float* acc_buffer_ptr) { + // Load the filters + float32x4_t filter[2]; + for (int i = 0; i < 2; i++) { + filter[i] = vld1q_f32(filter_ptr + 4 * i); + } + // Handle one output pixel at a time. 
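+    // Here the 8 filter values are kept in registers across the whole pixel
+    // loop; each output pixel contributes a single input value that is
+    // broadcast against them with vmlaq_n_f32, producing 8 accumulators per
+    // pixel (input depth 1, depth multiplier 8).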
+ for (int outp = 0; outp < num_output_pixels; outp++) { + // Load the inputs + const float input_val = *input_ptr; + input_ptr += input_ptr_increment; + // Load the accumulators from acc_buffer + float32x4_t acc[2]; + for (int i = 0; i < 2; i++) { + acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i); + } + // Multiply-accumulate + for (int i = 0; i < 2; i++) { + acc[i] = vmlaq_n_f32(acc[i], filter[i], input_val); + } + // Store the accumulators back to acc_buffer + for (int i = 0; i < 2; i++) { + vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 8; + } + } +}; + +template <> +struct FloatDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const float* input_ptr, int input_ptr_increment, + const float* filter_ptr, float* acc_buffer_ptr) { + // Load the filters + float32x4_t filter_0 = vld1q_f32(filter_ptr + 4 * 0); + float32x4_t filter_1 = vld1q_f32(filter_ptr + 4 * 1); + float32x4_t filter_2 = vld1q_f32(filter_ptr + 4 * 2); + float32x4_t filter_3 = vld1q_f32(filter_ptr + 4 * 3); + float32x4_t filter_4 = vld1q_f32(filter_ptr + 4 * 4); + float32x4_t filter_5 = vld1q_f32(filter_ptr + 4 * 5); + float32x4_t filter_6 = vld1q_f32(filter_ptr + 4 * 6); + float32x4_t filter_7 = vld1q_f32(filter_ptr + 4 * 7); + + // Handle one output pixel at a time. + for (int outp = 0; outp < num_output_pixels; outp++) { + // Load the inputs + const float input_val = *input_ptr; + input_ptr += input_ptr_increment; + // Load the accumulators from acc_buffer + float32x4_t acc_0 = vld1q_f32(acc_buffer_ptr + 4 * 0); + float32x4_t acc_1 = vld1q_f32(acc_buffer_ptr + 4 * 1); + float32x4_t acc_2 = vld1q_f32(acc_buffer_ptr + 4 * 2); + float32x4_t acc_3 = vld1q_f32(acc_buffer_ptr + 4 * 3); + float32x4_t acc_4 = vld1q_f32(acc_buffer_ptr + 4 * 4); + float32x4_t acc_5 = vld1q_f32(acc_buffer_ptr + 4 * 5); + float32x4_t acc_6 = vld1q_f32(acc_buffer_ptr + 4 * 6); + float32x4_t acc_7 = vld1q_f32(acc_buffer_ptr + 4 * 7); + // Multiply-accumulate + acc_0 = vmlaq_n_f32(acc_0, filter_0, input_val); + acc_1 = vmlaq_n_f32(acc_1, filter_1, input_val); + acc_2 = vmlaq_n_f32(acc_2, filter_2, input_val); + acc_3 = vmlaq_n_f32(acc_3, filter_3, input_val); + acc_4 = vmlaq_n_f32(acc_4, filter_4, input_val); + acc_5 = vmlaq_n_f32(acc_5, filter_5, input_val); + acc_6 = vmlaq_n_f32(acc_6, filter_6, input_val); + acc_7 = vmlaq_n_f32(acc_7, filter_7, input_val); + // Store the accumulators back to acc_buffer + vst1q_f32(acc_buffer_ptr + 4 * 0, acc_0); + vst1q_f32(acc_buffer_ptr + 4 * 1, acc_1); + vst1q_f32(acc_buffer_ptr + 4 * 2, acc_2); + vst1q_f32(acc_buffer_ptr + 4 * 3, acc_3); + vst1q_f32(acc_buffer_ptr + 4 * 4, acc_4); + vst1q_f32(acc_buffer_ptr + 4 * 5, acc_5); + vst1q_f32(acc_buffer_ptr + 4 * 6, acc_6); + vst1q_f32(acc_buffer_ptr + 4 * 7, acc_7); + acc_buffer_ptr += 32; + } + } +}; + +template <> +struct FloatDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const float* input_ptr, int input_ptr_increment, + const float* filter_ptr, float* acc_buffer_ptr) { + // Handle one output pixel at a time. 
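+    // Variable input depth with a 16x depth multiplier: every input channel
+    // expands into 16 outputs, so the inner loop below reloads four q
+    // registers of filter values per channel and broadcasts the scalar input
+    // with vmlaq_n_f32.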
+ for (int outp = 0; outp < num_output_pixels; outp++) { + const float* local_filter_ptr = filter_ptr; + const float* local_input_ptr = input_ptr; + for (int ic = 0; ic < input_depth; ic++) { + // Load the filters + float32x4_t filter[4]; + for (int i = 0; i < 4; i++) { + filter[i] = vld1q_f32(local_filter_ptr + 4 * i); + } + local_filter_ptr += 16; + // Load the inputs + const float input_val = *local_input_ptr++; + // Load the accumulators from acc_buffer + float32x4_t acc[4]; + for (int i = 0; i < 4; i++) { + acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i); + } + // Multiply-accumulate + for (int i = 0; i < 4; i++) { + acc[i] = vmlaq_n_f32(acc[i], filter[i], input_val); + } + // Store the accumulators back to acc_buffer + for (int i = 0; i < 4; i++) { + vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 16; + } + input_ptr += input_ptr_increment; + } + } +}; + +template <> +struct FloatDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const float* input_ptr, int input_ptr_increment, + const float* filter_ptr, float* acc_buffer_ptr) { + // Load the filters + float32x4_t filter[2]; + for (int i = 0; i < 2; i++) { + filter[i] = vld1q_f32(filter_ptr + 4 * i); + } + // Handle one output pixel at a time. + for (int outp = 0; outp < num_output_pixels; outp++) { + // Load the inputs + float32x4_t input[2]; + for (int i = 0; i < 2; i++) { + input[i] = vld1q_f32(input_ptr + 4 * i); + } + // Load the accumulators from acc_buffer + float32x4_t acc[2]; + for (int i = 0; i < 2; i++) { + acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i); + } + // Multiply-accumulate + for (int i = 0; i < 2; i++) { + acc[i] = vmlaq_f32(acc[i], input[i], filter[i]); + } + // Store the accumulators back to acc_buffer + for (int i = 0; i < 2; i++) { + vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 8; + input_ptr += input_ptr_increment; + } + } +}; + +template <> +struct FloatDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const float* input_ptr, int input_ptr_increment, + const float* filter_ptr, float* acc_buffer_ptr) { + float32x2_t filter = vld1_f32(filter_ptr); + float32x4_t filter_x4 = vcombine_f32(filter, filter); + int outp = 0; + + // Handle two output pixels at a time. + for (; outp <= num_output_pixels - 2; outp += 2) { + // Load the inputs + float32x2_t input_1 = vld1_f32(input_ptr); + input_ptr += input_ptr_increment; + float32x2_t input_2 = vld1_f32(input_ptr); + input_ptr += input_ptr_increment; + float32x4_t input = vcombine_f32(input_1, input_2); + + // Load the accumulators from acc_buffer + float32x4_t acc = vld1q_f32(acc_buffer_ptr); + + // Multiply-accumulate + acc = vmlaq_f32(acc, input, filter_x4); + + // Store the accumulators back to acc_buffer + vst1q_f32(acc_buffer_ptr, acc); + acc_buffer_ptr += 4; + } + // Handle one output pixel at a time. 
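+    // Strided tail: when fewer than two output pixels remain, the last pixel
+    // is processed with 64-bit d registers (vld1_f32 / vmla_f32 / vst1_f32),
+    // still advancing input_ptr by input_ptr_increment so stride > 1 works.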
+ for (; outp < num_output_pixels; outp++) { + // Load the inputs + float32x2_t input = vld1_f32(input_ptr); + input_ptr += input_ptr_increment; + + // Load the accumulators from acc_buffer + float32x2_t acc = vld1_f32(acc_buffer_ptr); + + // Multiply-accumulate + acc = vmla_f32(acc, input, filter); + + // Store the accumulators back to acc_buffer + vst1_f32(acc_buffer_ptr, acc); + acc_buffer_ptr += 2; + } + } +}; + +template <> +struct FloatDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const float* input_ptr, int input_ptr_increment, + const float* filter_ptr, float* acc_buffer_ptr) { + float32x4_t filter = vld1q_f32(filter_ptr); + + // Handle one output pixel at a time. + for (int outp = 0; outp < num_output_pixels; outp++) { + // Load the inputs + float32x4_t input = vld1q_f32(input_ptr); + // Load the accumulators from acc_buffer + float32x4_t acc = vld1q_f32(acc_buffer_ptr); + // Multiply-accumulate + acc = vmlaq_f32(acc, input, filter); + // Store the accumulators back to acc_buffer + vst1q_f32(acc_buffer_ptr, acc); + acc_buffer_ptr += 4; + input_ptr += input_ptr_increment; + } + } +}; +#endif + +// Accumulates the effect of one row of the filter, on a segment of one row +// of the output, accessing the corresponding one row of the input. +template +void FloatDepthwiseConvAccumRow(int stride, int input_depth, int input_width, + const float* input_data, int pad_width, + int depth_multiplier, int filter_width, + const float* filter_data, + int out_x_buffer_start, int out_x_buffer_end, + int output_depth, float* acc_buffer) { +#ifdef GEMMLOWP_PROFILING + gemmlowp::ScopedProfilingLabel label(__PRETTY_FUNCTION__); +#endif + // Sanity check parameters. This is important in particular to ensure + // that we keep the number of template instantiations minimal, so we don't + // increase binary size unnecessarily. + static_assert(kFixedDepthMultiplier || !kFixedInputDepth, ""); + static_assert(kFixedInputDepth || kAllowStrided, ""); + TFLITE_DCHECK(stride == 1 || kAllowStrided); + if (kFixedInputDepth) { + TFLITE_DCHECK_EQ(input_depth, kFixedInputDepth); + } + if (kFixedDepthMultiplier) { + TFLITE_DCHECK_EQ(depth_multiplier, kFixedDepthMultiplier); + } + TFLITE_DCHECK_EQ(output_depth, input_depth * depth_multiplier); + const int input_ptr_increment = stride * input_depth; + const float* filter_base_ptr = filter_data; + for (int filter_x = 0; filter_x < filter_width; ++filter_x) { + // For the current (filter_x, filter_y) point in the filter, + // compute the boundaries of the corresponding output row segment. + int out_x_loop_start_unclampled = 0; + int out_x_loop_end_unclampled = 0; + if (kAllowStrided) { + if (stride == 2) { + out_x_loop_start_unclampled = (pad_width - filter_x + 1) / 2; + out_x_loop_end_unclampled = + (pad_width + input_width - filter_x + 1) / 2; + } else if (stride == 4) { + out_x_loop_start_unclampled = (pad_width - filter_x + 3) / 4; + out_x_loop_end_unclampled = + (pad_width + input_width - filter_x + 3) / 4; + } else { + out_x_loop_start_unclampled = + (pad_width - filter_x + stride - 1) / stride; + out_x_loop_end_unclampled = + (pad_width + input_width - filter_x + stride - 1) / stride; + } + } else { + out_x_loop_start_unclampled = pad_width - filter_x; + out_x_loop_end_unclampled = pad_width + input_width - filter_x; + } + // The kernel will have to iterate on the segment of the + // output row that starts at out_x_loop_start and out_x_loop_end. 
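+    // For example, with stride == 2, pad_width == 1, filter_x == 0 and
+    // input_width == 5, the unclamped range is [(1+1)/2, (1+5+1)/2) = [1, 3),
+    // i.e. exactly the output x positions whose sampled input x,
+    // out_x * 2 - 1, stays inside [0, 5).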
+    const int out_x_loop_start =
+        std::max(out_x_buffer_start, out_x_loop_start_unclampled);
+    const int out_x_loop_end =
+        std::min(out_x_buffer_end, out_x_loop_end_unclampled);
+
+    float* acc_buffer_ptr =
+        acc_buffer + (out_x_loop_start - out_x_buffer_start) * output_depth;
+    const int in_x_origin = (out_x_loop_start * stride) - pad_width + filter_x;
+    const float* input_ptr = input_data + in_x_origin * input_depth;
+    const int num_output_pixels = out_x_loop_end - out_x_loop_start;
+    FloatDepthwiseConvKernel<kAllowStrided, kFixedInputDepth,
+                             kFixedDepthMultiplier>::Run(num_output_pixels,
+                                                         input_depth,
+                                                         depth_multiplier,
+                                                         input_ptr,
+                                                         input_ptr_increment,
+                                                         filter_base_ptr,
+                                                         acc_buffer_ptr);
+    filter_base_ptr += output_depth;
+  }
+}
+
+// Generic fallback of FloatDepthwiseConvAccumRow, portable, non-templatized.
+inline void FloatDepthwiseConvAccumRowGeneric(
+    int stride, int input_depth, int input_width, const float* input_data,
+    int pad_width, int depth_multiplier, int filter_width,
+    const float* filter_data, int out_x_buffer_start, int out_x_buffer_end,
+    int output_depth, float* acc_buffer) {
+  gemmlowp::ScopedProfilingLabel label("DepthwiseConvAccumRowGeneric (slow)");
+#ifdef TFLITE_PREVENT_SLOW_GENERIC_DEPTHWISECONV_FALLBACK
+#ifndef ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK
+  LOG(FATAL)
+      << "\n\n"
+      << "*****************************************************************\n"
+      << "* This tfmini inference code was about to use the slow generic\n"
+      << "* fallback implementation for a DepthwiseConv op, and we want you\n"
+      << "* to be aware of that so that you will know why you get terrible\n"
+      << "* performance.\n"
+      << "*\n"
+      << "* If you would like to carry on with the slow code, compile\n"
+      << "* with this preprocessor token defined:\n"
+      << "* ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK.\n"
+      << "*\n"
+      << "* The right thing to do, if you care about performance, is to add\n"
+      << "* a new DepthwiseConv kernel to tfmini to cover your case.\n"
+      << "* The relevant parameters defining your case are:\n"
+      << "* stride = " << stride << "\n"
+      << "* input_depth = " << input_depth << "\n"
+      << "* depth_multiplier = " << depth_multiplier << "\n"
+      << "*\n"
+      << "* Please do not hesitate to contact benoitjacob@ with this\n"
+      << "* information.\n"
+      << "*****************************************************************\n";
+#endif  // ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK
+#endif  // TFLITE_PREVENT_SLOW_GENERIC_DEPTHWISECONV_FALLBACK
+  const float* filter_base_ptr = filter_data;
+  for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
+    const int out_x_loop_start = std::max(
+        out_x_buffer_start, (pad_width - filter_x + stride - 1) / stride);
+    const int out_x_loop_end =
+        std::min(out_x_buffer_end,
+                 (pad_width + input_width - filter_x + stride - 1) / stride);
+
+    float* acc_buffer_ptr =
+        acc_buffer + (out_x_loop_start - out_x_buffer_start) * output_depth;
+    const int in_x_origin = (out_x_loop_start * stride) - pad_width + filter_x;
+    const float* input_ptr = input_data + in_x_origin * input_depth;
+    const int input_ptr_increment = (stride - 1) * input_depth;
+    for (int out_x = out_x_loop_start; out_x < out_x_loop_end; out_x++) {
+      const float* filter_ptr = filter_base_ptr;
+      for (int ic = 0; ic < input_depth; ++ic) {
+        const float input_val = *input_ptr++;
+        for (int m = 0; m < depth_multiplier; m++) {
+          const float filter_val = *filter_ptr++;
+          *acc_buffer_ptr++ += filter_val * input_val;
+        }
+      }
+      input_ptr += input_ptr_increment;
+    }
+    filter_base_ptr += output_depth;
+  }
+}
+
+// Initializes the accumulator buffer with
bias values. +inline void DepthwiseConvInitAccBuffer(int num_output_pixels, int output_depth, + const float* bias_data, + float* acc_buffer) { + // TODO(benoitjacob): This might need optimized specializations + // for small output_depth values, if that ever becomes an important + // case (like it was for some quantized DepthwiseConv cases). + for (int i = 0; i < num_output_pixels; i++) { + memcpy(acc_buffer + i * output_depth, bias_data, + sizeof(acc_buffer[0]) * output_depth); + } +} + +inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims, + const float* filter_data, const Dims<4>& filter_dims, + const float* bias_data, const Dims<4>& bias_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int depth_multiplier, + float output_activation_min, + float output_activation_max, float* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("DepthwiseConv"); + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int output_depth = MatchingArraySize(filter_dims, 0, output_dims, 0); + const int input_height = ArraySize(input_dims, 2); + const int input_width = ArraySize(input_dims, 1); + const int input_depth = ArraySize(input_dims, 0); + const int filter_height = ArraySize(filter_dims, 2); + const int filter_width = ArraySize(filter_dims, 1); + const int output_height = ArraySize(output_dims, 2); + const int output_width = ArraySize(output_dims, 1); + TFLITE_DCHECK(output_depth == input_depth * depth_multiplier); + + static const int kAccBufferMaxSize = 2048; + float acc_buffer[kAccBufferMaxSize]; + TFLITE_DCHECK_GE(kAccBufferMaxSize, output_depth); + const int kOutputPixelsInAccBuffer = kAccBufferMaxSize / output_depth; + const int kAccBufferActualSize = kOutputPixelsInAccBuffer * output_depth; + TFLITE_DCHECK_LE(kOutputPixelsInAccBuffer * output_depth, + kAccBufferActualSize); + TFLITE_DCHECK_LE(kAccBufferActualSize, kAccBufferMaxSize); + TFLITE_DCHECK_GE(kOutputPixelsInAccBuffer, 1); + + // row_accum_func will point to the core accumulation function to be used + // for this DepthwiseConv op. + using row_accum_func_t = decltype(&FloatDepthwiseConvAccumRowGeneric); + row_accum_func_t row_accum_func = nullptr; + +#define TFMINI_USE_DEPTHWISECONV_KERNEL(ALLOW_STRIDED, FIXED_INPUT_DEPTH, \ + FIXED_DEPTH_MULTIPLIER) \ + if (!row_accum_func && (stride_width == 1 || ALLOW_STRIDED) && \ + (input_depth == FIXED_INPUT_DEPTH || FIXED_INPUT_DEPTH == 0) && \ + depth_multiplier == FIXED_DEPTH_MULTIPLIER) { \ + row_accum_func = \ + FloatDepthwiseConvAccumRow; \ + } + +#ifdef USE_NEON + // We go over our list of kernels by decreasing order of preference + // for the cases where multiple kernels could apply. + + // Start with the fastest kernels: AllowStrided=false, fixed input depth. + + TFMINI_USE_DEPTHWISECONV_KERNEL(false, 8, 1) + TFMINI_USE_DEPTHWISECONV_KERNEL(false, 2, 1) + + // Next come the strided kernels: AllowStrided=true, fixed input depth. + // They are a bit less efficient, but allow stride!=1. + + TFMINI_USE_DEPTHWISECONV_KERNEL(true, 8, 1) + TFMINI_USE_DEPTHWISECONV_KERNEL(true, 1, 8) + TFMINI_USE_DEPTHWISECONV_KERNEL(true, 1, 32) + TFMINI_USE_DEPTHWISECONV_KERNEL(true, 2, 1) + TFMINI_USE_DEPTHWISECONV_KERNEL(true, 4, 1) + + // Finally, the kernels allowing a variable input depth, + // these are the least efficient but most general kernels. 
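+  // Each TFMINI_USE_DEPTHWISECONV_KERNEL line only takes effect if no earlier
+  // one matched, so this list is a priority order. For example, stride 1,
+  // input_depth 8, depth_multiplier 1 is claimed by the (false, 8, 1) kernel
+  // at the top rather than by the variable-depth (true, 0, 1) kernel below.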
+ + TFMINI_USE_DEPTHWISECONV_KERNEL(true, 0, 1) + TFMINI_USE_DEPTHWISECONV_KERNEL(true, 0, 2) + TFMINI_USE_DEPTHWISECONV_KERNEL(true, 0, 8) + TFMINI_USE_DEPTHWISECONV_KERNEL(true, 0, 16) + +#endif // USE_NEON + +#undef TFMINI_USE_DEPTHWISECONV_KERNEL + + // No matching fast kernel found, use slow fallback. + if (!row_accum_func) { + row_accum_func = FloatDepthwiseConvAccumRowGeneric; + } + + // Now that we have determined row_accum_func, we can start work. + float* output_ptr = output_data; + for (int b = 0; b < batches; ++b) { + for (int out_y = 0; out_y < output_height; ++out_y) { + const int in_y_origin = (out_y * stride_height) - pad_height; + const int filter_y_start = std::max(0, -in_y_origin); + const int filter_y_end = + std::min(filter_height, input_height - in_y_origin); + for (int out_x_buffer_start = 0; out_x_buffer_start < output_width; + out_x_buffer_start += kOutputPixelsInAccBuffer) { + const int out_x_buffer_end = std::min( + output_width, out_x_buffer_start + kOutputPixelsInAccBuffer); + // We call a 'pixel' a group of activation that share all but the + // 'depth'/'channel' coordinate. num_output_pixels is the number of + // output pixels that we will accumulate in this loop iteration. + const int num_output_pixels = out_x_buffer_end - out_x_buffer_start; + // Initialize our local accumulator with the bias values, so we don't + // have to add them later. + DepthwiseConvInitAccBuffer(num_output_pixels, output_depth, bias_data, + acc_buffer); + // Accumulation loop. Most of the time should be spent in here. + for (int filter_y = filter_y_start; filter_y < filter_y_end; + ++filter_y) { + const int in_y = in_y_origin + filter_y; + row_accum_func(stride_width, input_depth, input_width, + input_data + in_y * input_dims.strides[2] + + b * input_dims.strides[3], + pad_width, depth_multiplier, filter_width, + filter_data + filter_y * filter_dims.strides[2], + out_x_buffer_start, out_x_buffer_end, output_depth, + acc_buffer); + } + // Finished accumulating. Now store to destination. + const int num_output_values = output_depth * num_output_pixels; + int i = 0; +// TODO(benoitjacob) optimized code goes here +#ifdef USE_NEON + // Handle 16 values at a time + for (; i <= num_output_values - 16; i += 16) { + float32x4_t acc[4]; + for (int k = 0; k < 4; k++) { + acc[k] = vld1q_f32(acc_buffer + i + 4 * k); + } + for (int k = 0; k < 4; k++) { + acc[k] = vmaxq_f32( + vdupq_n_f32(output_activation_min), + vminq_f32(vdupq_n_f32(output_activation_max), acc[k])); + } + for (int k = 0; k < 4; k++) { + vst1q_f32(output_ptr + 4 * k, acc[k]); + } + output_ptr += 16; + } + // Handle 4 values at a time + for (; i <= num_output_values - 4; i += 4) { + float32x4_t acc = vld1q_f32(acc_buffer + i); + + acc = vmaxq_f32(vdupq_n_f32(output_activation_min), + vminq_f32(vdupq_n_f32(output_activation_max), acc)); + + vst1q_f32(output_ptr, acc); + output_ptr += 4; + } +#endif + // Handle leftover values, one by one. This is very slow. 
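+        // The clamp applied here (and in the NEON paths above) is
+        //   acc = max(output_activation_min, min(output_activation_max, acc))
+        // e.g. a fused ReLU6 arrives as min = 0.0f, max = 6.0f, while "no
+        // activation" uses limits wide enough to pass every value through.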
+ for (; i < num_output_values; i++) { + float acc = acc_buffer[i]; + acc = std::max(output_activation_min, + std::min(output_activation_max, acc)); + + *output_ptr++ = acc; + } + } + } + } +} + +// legacy, for compatibility with old checked-in code +template +void DepthwiseConv(const float* input_data, const Dims<4>& input_dims, + const float* filter_data, const Dims<4>& filter_dims, + const float* bias_data, const Dims<4>& bias_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int depth_multiplier, float* output_data, + const Dims<4>& output_dims) { + float output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + DepthwiseConv(input_data, input_dims, filter_data, filter_dims, bias_data, + bias_dims, stride_width, stride_height, pad_width, pad_height, + depth_multiplier, output_activation_min, output_activation_max, + output_data, output_dims); +} + +// legacy, for compatibility with old checked-in code +template +void DepthwiseConv(const float* input_data, const Dims<4>& input_dims, + const float* filter_data, const Dims<4>& filter_dims, + const float* bias_data, const Dims<4>& bias_dims, int stride, + int pad_width, int pad_height, int depth_multiplier, + float* output_data, const Dims<4>& output_dims) { + DepthwiseConv(input_data, input_dims, filter_data, filter_dims, bias_data, + bias_dims, stride, stride, pad_width, pad_height, + depth_multiplier, output_data, output_dims); +} + +} // namespace optimized_ops +} // namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_FLOAT_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h new file mode 100644 index 0000000000..051ed2a2c4 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h @@ -0,0 +1,1916 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_UINT8_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_UINT8_H_ + +#include "fixedpoint/fixedpoint.h" +#include "public/gemmlowp.h" +#include "tensorflow/contrib/lite/kernels/internal/common.h" +#include "tensorflow/contrib/lite/kernels/internal/types.h" + +namespace tflite { +namespace optimized_ops { + +// Implementation of quantized DepthwiseConv + +template +struct QuantizedDepthwiseConvKernel {}; + +#ifdef USE_NEON +template <> +struct QuantizedDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const uint8* input_ptr, int16 input_offset, + int input_ptr_increment, const uint8* filter_ptr, + int16 filter_offset, int32* acc_buffer_ptr) { + // Load the filters, add filter_offset. 
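+    // In the gemmlowp-style quantization used by these kernels, a uint8
+    // value q represents the real number scale * (q + offset), where the
+    // offset is typically the negated zero point. Widening to int16 and
+    // adding filter_offset / input_offset up front lets the inner loops
+    // accumulate exact int32 products (q_f + filter_offset) * (q_i + input_offset).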
+ uint8x8x2_t filter_u8; + filter_u8.val[0] = vld1_u8(filter_ptr); + filter_u8.val[1] = vld1_u8(filter_ptr + 8); + int16x8_t filter[2]; + for (int i = 0; i < 2; i++) { + filter[i] = vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(filter_u8.val[i])), + vdupq_n_s16(filter_offset)); + } + // Handle one output pixel at a time. + for (int outp = 0; outp < num_output_pixels; outp++) { + // Load the accumulators from acc_buffer + int32x4x2_t acc[2]; + for (int i = 0; i < 2; i++) { + acc[i].val[0] = vld1q_s32(acc_buffer_ptr + 4 * i); + acc[i].val[1] = vld1q_s32(acc_buffer_ptr + 4 * i + 8); + } + // Load the inputs, add input_offset. + const uint8x8_t input_u8 = vld1_u8(input_ptr); + input_ptr += input_ptr_increment; + const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8)); + const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset)); + // Duplicate the input values, 2-fold + const int16x8x2_t input_dup2 = vzipq_s16(input, input); + // Multiply-accumulate + for (int i = 0; i < 2; i++) { + acc[0].val[i] = vmlal_s16(acc[0].val[i], vget_low_s16(filter[i]), + vget_low_s16(input_dup2.val[i])); + acc[1].val[i] = vmlal_s16(acc[1].val[i], vget_high_s16(filter[i]), + vget_high_s16(input_dup2.val[i])); + } + // Store the accumulators back to acc_buffer + for (int i = 0; i < 2; i++) { + vst1q_s32(acc_buffer_ptr + 4 * i, acc[i].val[0]); + vst1q_s32(acc_buffer_ptr + 4 * i + 8, acc[i].val[1]); + } + acc_buffer_ptr += 16; + } + } +}; + +template <> +struct QuantizedDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const uint8* input_ptr, int16 input_offset, + int input_ptr_increment, const uint8* filter_ptr, + int16 filter_offset, int32* acc_buffer_ptr) { + // Load the filters, add filter_offset. + const uint8x8_t filter_u8 = vld1_u8(filter_ptr); + const int16x8_t filter_s16 = vreinterpretq_s16_u16(vmovl_u8(filter_u8)); + const int16x8_t filter = vaddq_s16(filter_s16, vdupq_n_s16(filter_offset)); + + int outp = 0; + // Handle 2 output pixels at a time. + for (; outp <= num_output_pixels - 2; outp += 2) { + // Load the accumulators from acc_buffer. + int32x4_t acc[4]; + for (int i = 0; i < 4; i++) { + acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i); + } + // Load the inputs, add input_offset. + uint8x8_t input_u8[2]; + for (int i = 0; i < 2; i++) { + input_u8[i] = vld1_u8(input_ptr + 8 * i); + } + input_ptr += 16; + int16x8_t input[2]; + for (int i = 0; i < 2; i++) { + input[i] = vreinterpretq_s16_u16(vmovl_u8(input_u8[i])); + } + for (int i = 0; i < 2; i++) { + input[i] = vaddq_s16(input[i], vdupq_n_s16(input_offset)); + } + // Multiply-accumulate. + acc[0] = vmlal_s16(acc[0], vget_low_s16(filter), vget_low_s16(input[0])); + acc[1] = + vmlal_s16(acc[1], vget_high_s16(filter), vget_high_s16(input[0])); + acc[2] = vmlal_s16(acc[2], vget_low_s16(filter), vget_low_s16(input[1])); + acc[3] = + vmlal_s16(acc[3], vget_high_s16(filter), vget_high_s16(input[1])); + // Store the accumulators back to acc_buffer + for (int i = 0; i < 4; i++) { + vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 16; + } + // Handle 1 output pixel at a time. + for (; outp < num_output_pixels; outp++) { + // Load the accumulators from acc_buffer. + int32x4_t acc[2]; + acc[0] = vld1q_s32(acc_buffer_ptr); + acc[1] = vld1q_s32(acc_buffer_ptr + 4); + + // Load the inputs, add input_offset. 
+ const uint8x8_t input_u8 = vld1_u8(input_ptr); + input_ptr += 8; + const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8)); + const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset)); + // Multiply-accumulate. + acc[0] = vmlal_s16(acc[0], vget_low_s16(filter), vget_low_s16(input)); + acc[1] = vmlal_s16(acc[1], vget_high_s16(filter), vget_high_s16(input)); + // Store the accumulators back to acc_buffer + vst1q_s32(acc_buffer_ptr, acc[0]); + vst1q_s32(acc_buffer_ptr + 4, acc[1]); + acc_buffer_ptr += 8; + } + } +}; + +template <> +struct QuantizedDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const uint8* input_ptr, int16 input_offset, + int input_ptr_increment, const uint8* filter_ptr, + int16 filter_offset, int32* acc_buffer_ptr) { + // Load the filters, add filter_offset. + const uint8x8_t filter_u8 = vld1_u8(filter_ptr); + const int16x8_t filter_s16 = vreinterpretq_s16_u16(vmovl_u8(filter_u8)); + const int16x8_t filter = vaddq_s16(filter_s16, vdupq_n_s16(filter_offset)); + + int outp = 0; + // Handle 2 output pixels at a time. + for (; outp <= num_output_pixels - 2; outp += 2) { + // Load the accumulators from acc_buffer + int32x4_t acc[4]; + for (int i = 0; i < 4; i++) { + acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i); + } + // Load the inputs, add input_offset. + const uint8x8_t input_u8 = vld1_u8(input_ptr); + input_ptr += 8; + const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8)); + const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset)); + // Duplicate the input values, 2-fold + const int16x8x2_t input_dup2 = vzipq_s16(input, input); + // Multiply-accumulate + for (int i = 0; i < 2; i++) { + acc[2 * i + 0] = vmlal_s16(acc[2 * i + 0], vget_low_s16(filter), + vget_low_s16(input_dup2.val[i])); + acc[2 * i + 1] = vmlal_s16(acc[2 * i + 1], vget_high_s16(filter), + vget_high_s16(input_dup2.val[i])); + } + // Store the accumulators back to acc_buffer + for (int i = 0; i < 4; i++) { + vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 16; + } + // Handle one output pixel at a time. + for (; outp < num_output_pixels; outp++) { + // Load the accumulators from acc_buffer + int32x4_t acc[2]; + for (int i = 0; i < 2; i++) { + acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i); + } + // Load the inputs, add input_offset. 
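+      // Only 4 input bytes are valid for this final pixel, so rather than a
+      // full 8-byte vld1_u8 (which could read past the end of the row) the
+      // bytes are inserted one lane at a time into a zeroed d register with
+      // vset_lane_u8 before being widened to int16.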
+ uint8x8_t input_u8 = vdup_n_u8(0); + input_u8 = vset_lane_u8(input_ptr[0], input_u8, 0); + input_u8 = vset_lane_u8(input_ptr[1], input_u8, 1); + input_u8 = vset_lane_u8(input_ptr[2], input_u8, 2); + input_u8 = vset_lane_u8(input_ptr[3], input_u8, 3); + input_ptr += 4; + const int16x4_t input_s16 = + vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8))); + const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset)); + // Duplicate the input values, 2-fold + const int16x4x2_t input_dup2 = vzip_s16(input, input); + // Multiply-accumulate + acc[0] = vmlal_s16(acc[0], vget_low_s16(filter), input_dup2.val[0]); + acc[1] = vmlal_s16(acc[1], vget_high_s16(filter), input_dup2.val[1]); + // Store the accumulators back to acc_buffer + for (int i = 0; i < 2; i++) { + vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 8; + } + } +}; + +template <> +struct QuantizedDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const uint8* input_ptr, int16 input_offset, + int input_ptr_increment, const uint8* filter_ptr, + int16 filter_offset, int32* acc_buffer_ptr) { + // Load the filters, add filter_offset. + int16x8_t filter[2]; + for (int i = 0; i < 2; i++) { + const uint8x8_t filter_u8 = vld1_u8(filter_ptr + 8 * i); + const int16x8_t filter_s16 = vreinterpretq_s16_u16(vmovl_u8(filter_u8)); + filter[i] = vaddq_s16(filter_s16, vdupq_n_s16(filter_offset)); + } + int outp = 0; + // Handle two output pixels at a time. + for (; outp <= num_output_pixels - 2; outp += 2) { + // Load the accumulators from acc_buffer. + int32x4_t acc[8]; + for (int i = 0; i < 8; i++) { + acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i); + } + // Load the inputs, add input_offset. + uint8x8_t input_u8 = vdup_n_u8(0); + input_u8 = vset_lane_u8(input_ptr[0], input_u8, 0); + input_u8 = vset_lane_u8(input_ptr[1], input_u8, 1); + input_u8 = vset_lane_u8(input_ptr[2], input_u8, 2); + input_u8 = vset_lane_u8(input_ptr[3], input_u8, 3); + input_ptr += 4; + const int16x4_t input_s16 = + vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8))); + const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset)); + // Multiply-accumulate. + acc[0] = vmlal_lane_s16(acc[0], vget_low_s16(filter[0]), input, 0); + acc[1] = vmlal_lane_s16(acc[1], vget_high_s16(filter[0]), input, 0); + acc[2] = vmlal_lane_s16(acc[2], vget_low_s16(filter[1]), input, 1); + acc[3] = vmlal_lane_s16(acc[3], vget_high_s16(filter[1]), input, 1); + acc[4] = vmlal_lane_s16(acc[4], vget_low_s16(filter[0]), input, 2); + acc[5] = vmlal_lane_s16(acc[5], vget_high_s16(filter[0]), input, 2); + acc[6] = vmlal_lane_s16(acc[6], vget_low_s16(filter[1]), input, 3); + acc[7] = vmlal_lane_s16(acc[7], vget_high_s16(filter[1]), input, 3); + // Store the accumulators back to acc_buffer. + for (int i = 0; i < 8; i++) { + vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 32; + } + // Handle one output pixel at a time. + for (; outp < num_output_pixels; outp++) { + // Load the accumulators from acc_buffer. + int32x4_t acc[4]; + for (int i = 0; i < 4; i++) { + acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i); + } + // Load the inputs, add input_offset. + uint8x8_t input_u8 = vdup_n_u8(0); + input_u8 = vset_lane_u8(input_ptr[0], input_u8, 0); + input_u8 = vset_lane_u8(input_ptr[1], input_u8, 1); + input_ptr += 2; + const int16x4_t input_s16 = + vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8))); + const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset)); + + // Multiply-accumulate. 
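+      // vmlal_lane_s16 widens a 4-lane int16 vector, multiplies it by one
+      // lane of another int16 vector, and accumulates into int32 lanes. Lane
+      // 0 holds the first input channel and lane 1 the second, so each
+      // channel's value is applied to its own block of 8 filter outputs.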
+ acc[0] = vmlal_lane_s16(acc[0], vget_low_s16(filter[0]), input, 0); + acc[1] = vmlal_lane_s16(acc[1], vget_high_s16(filter[0]), input, 0); + acc[2] = vmlal_lane_s16(acc[2], vget_low_s16(filter[1]), input, 1); + acc[3] = vmlal_lane_s16(acc[3], vget_high_s16(filter[1]), input, 1); + + // Store the accumulators back to acc_buffer. + for (int i = 0; i < 4; i++) { + vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 16; + } + } +}; + +template <> +struct QuantizedDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const uint8* input_ptr, int16 input_offset, + int input_ptr_increment, const uint8* filter_ptr, + int16 filter_offset, int32* acc_buffer_ptr) { + // Load the filters, add filter_offset. + uint8x8_t filter_u8 = vdup_n_u8(0); + filter_u8 = vset_lane_u8(filter_ptr[0], filter_u8, 0); + filter_u8 = vset_lane_u8(filter_ptr[1], filter_u8, 1); + filter_u8 = vset_lane_u8(filter_ptr[2], filter_u8, 2); + filter_u8 = vset_lane_u8(filter_ptr[3], filter_u8, 3); + const int16x4_t filter_s16 = + vreinterpret_s16_u16(vget_low_u16(vmovl_u8(filter_u8))); + const int16x4_t filter = vadd_s16(filter_s16, vdup_n_s16(filter_offset)); + + int outp = 0; + // Handle 4 output pixels at a time. + for (; outp <= num_output_pixels - 4; outp += 4) { + // Load the accumulators from acc_buffer + int32x4_t acc[4]; + for (int i = 0; i < 4; i++) { + acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i); + } + + // Load the inputs, add input_offset. + const uint8x8_t input_u8 = vld1_u8(input_ptr); + input_ptr += 8; + const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8)); + const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset)); + // Duplicate the input values, 2-fold + const int16x8x2_t input_dup2 = vzipq_s16(input, input); + // Multiply-accumulate + acc[0] = vmlal_s16(acc[0], filter, vget_low_s16(input_dup2.val[0])); + acc[1] = vmlal_s16(acc[1], filter, vget_high_s16(input_dup2.val[0])); + acc[2] = vmlal_s16(acc[2], filter, vget_low_s16(input_dup2.val[1])); + acc[3] = vmlal_s16(acc[3], filter, vget_high_s16(input_dup2.val[1])); + // Store the accumulators back to acc_buffer + for (int i = 0; i < 4; i++) { + vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 16; + } + // Handle one output pixel at a time. + for (; outp < num_output_pixels; outp++) { + // Load the accumulators from acc_buffer + int32x4_t acc = vld1q_s32(acc_buffer_ptr); + + uint8x8_t input_u8 = vdup_n_u8(0); + input_u8 = vset_lane_u8(input_ptr[0], input_u8, 0); + input_u8 = vset_lane_u8(input_ptr[1], input_u8, 1); + input_ptr += 2; + const int16x4_t input_s16 = + vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8))); + const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset)); + // Duplicate the input values, 2-fold + const int16x4_t input_dup2 = vzip_s16(input, input).val[0]; + // Multiply-accumulate + acc = vmlal_s16(acc, filter, input_dup2); + // Store the accumulators back to acc_buffer + vst1q_s32(acc_buffer_ptr, acc); + acc_buffer_ptr += 4; + } + } +}; + +template <> +struct QuantizedDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const uint8* input_ptr, int16 input_offset, + int input_ptr_increment, const uint8* filter_ptr, + int16 filter_offset, int32* acc_buffer_ptr) { + // Load the filters, add filter_offset. 
+ uint8x8_t filter_u8 = vdup_n_u8(0); + filter_u8 = vset_lane_u8(filter_ptr[0], filter_u8, 0); + filter_u8 = vset_lane_u8(filter_ptr[1], filter_u8, 1); + filter_u8 = vset_lane_u8(filter_ptr[0], filter_u8, 2); + filter_u8 = vset_lane_u8(filter_ptr[1], filter_u8, 3); + const int16x4_t filter_s16 = + vreinterpret_s16_u16(vget_low_u16(vmovl_u8(filter_u8))); + const int16x4_t filter = vadd_s16(filter_s16, vdup_n_s16(filter_offset)); + + int outp = 0; + // Handle 8 output pixels at a time. + for (; outp <= num_output_pixels - 8; outp += 8) { + // Load the accumulators from acc_buffer. + int32x4_t acc[4]; + for (int i = 0; i < 4; i++) { + acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i); + } + // Load the inputs, add input_offset. + uint8x8_t input_u8[2]; + for (int i = 0; i < 2; i++) { + input_u8[i] = vld1_u8(input_ptr + 8 * i); + } + input_ptr += 16; + int16x8_t input[2]; + for (int i = 0; i < 2; i++) { + input[i] = vreinterpretq_s16_u16(vmovl_u8(input_u8[i])); + } + for (int i = 0; i < 2; i++) { + input[i] = vaddq_s16(input[i], vdupq_n_s16(input_offset)); + } + + // Multiply-accumulate. + acc[0] = vmlal_s16(acc[0], filter, vget_low_s16(input[0])); + acc[1] = vmlal_s16(acc[1], filter, vget_high_s16(input[0])); + acc[2] = vmlal_s16(acc[2], filter, vget_low_s16(input[1])); + acc[3] = vmlal_s16(acc[3], filter, vget_high_s16(input[1])); + // Store the accumulators back to acc_buffer. + for (int i = 0; i < 4; i++) { + vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 16; + } + // Handle 4 output pixels at a time. + for (; outp <= num_output_pixels - 4; outp += 4) { + // Load the accumulators from acc_buffer. + int32x4_t acc[2]; + for (int i = 0; i < 2; i++) { + acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i); + } + // Load the inputs, add input_offset. + const uint8x8_t input_u8 = vld1_u8(input_ptr); + input_ptr += 8; + const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8)); + const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset)); + + // Multiply-accumulate. + acc[0] = vmlal_s16(acc[0], filter, vget_low_s16(input)); + acc[1] = vmlal_s16(acc[1], filter, vget_high_s16(input)); + // Store the accumulators back to acc_buffer. + for (int i = 0; i < 2; i++) { + vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 8; + } + // Handle 2 output pixels at a time. + for (; outp <= num_output_pixels - 2; outp += 2) { + // Load the accumulators from acc_buffer. + int32x4_t acc = vld1q_s32(acc_buffer_ptr); + // Load the inputs, add input_offset. + uint8x8_t input_u8 = vdup_n_u8(0); + input_u8 = vset_lane_u8(input_ptr[0], input_u8, 0); + input_u8 = vset_lane_u8(input_ptr[1], input_u8, 1); + input_u8 = vset_lane_u8(input_ptr[2], input_u8, 2); + input_u8 = vset_lane_u8(input_ptr[3], input_u8, 3); + input_ptr += 4; + const int16x4_t input_s16 = + vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8))); + const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset)); + + // Multiply-accumulate. + acc = vmlal_s16(acc, filter, input); + // Store the accumulators back to acc_buffer. + vst1q_s32(acc_buffer_ptr, acc); + acc_buffer_ptr += 4; + } + // Handle 1 output pixel at a time. + for (; outp < num_output_pixels; outp++) { + // Load the accumulators from acc_buffer. + int32x2_t acc = vld1_s32(acc_buffer_ptr); + // Load the inputs, add input_offset. 
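+      // Final 1-pixel tail of the 8/4/2/1 tiling: two bytes of input are
+      // widened, offset and multiplied against the 2-channel filter, with the
+      // resulting pair of int32 accumulators kept in a single d register.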
+ uint8x8_t input_u8 = vdup_n_u8(0); + input_u8 = vset_lane_u8(input_ptr[0], input_u8, 0); + input_u8 = vset_lane_u8(input_ptr[1], input_u8, 1); + input_ptr += 2; + const int16x4_t input_s16 = + vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8))); + const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset)); + + // Multiply-accumulate. + acc = vget_low_s32(vmlal_s16(vcombine_s32(acc, acc), filter, input)); + // Store the accumulators back to acc_buffer. + vst1_s32(acc_buffer_ptr, acc); + acc_buffer_ptr += 2; + } + } +}; + +template <> +struct QuantizedDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const uint8* input_ptr, int16 input_offset, + int input_ptr_increment, const uint8* filter_ptr, + int16 filter_offset, int32* acc_buffer_ptr) { + // Load the filters, add filter_offset. + uint8x8_t filter_u8 = vdup_n_u8(0); + filter_u8 = vset_lane_u8(filter_ptr[0], filter_u8, 0); + filter_u8 = vset_lane_u8(filter_ptr[1], filter_u8, 1); + filter_u8 = vset_lane_u8(filter_ptr[0], filter_u8, 2); + filter_u8 = vset_lane_u8(filter_ptr[1], filter_u8, 3); + const int16x4_t filter_s16 = + vreinterpret_s16_u16(vget_low_u16(vmovl_u8(filter_u8))); + const int16x4_t filter = vadd_s16(filter_s16, vdup_n_s16(filter_offset)); + + int outp = 0; + // Handle 8 output pixels at a time. + for (; outp <= num_output_pixels - 8; outp += 8) { + // Load the accumulators from acc_buffer + int32x4_t acc[4]; + for (int i = 0; i < 4; i++) { + acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i); + } + + // Load the inputs, add input_offset. + const uint8x8_t input_u8 = vld1_u8(input_ptr); + input_ptr += 8; + const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8)); + const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset)); + // Duplicate the input values, 2-fold + const int16x8x2_t input_dup2 = vzipq_s16(input, input); + // Multiply-accumulate + acc[0] = vmlal_s16(acc[0], filter, vget_low_s16(input_dup2.val[0])); + acc[1] = vmlal_s16(acc[1], filter, vget_high_s16(input_dup2.val[0])); + acc[2] = vmlal_s16(acc[2], filter, vget_low_s16(input_dup2.val[1])); + acc[3] = vmlal_s16(acc[3], filter, vget_high_s16(input_dup2.val[1])); + // Store the accumulators back to acc_buffer + for (int i = 0; i < 4; i++) { + vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 16; + } + // Handle one output pixel at a time. + for (; outp < num_output_pixels; outp++) { + // Load the accumulators from acc_buffer + int32x2_t acc = vld1_s32(acc_buffer_ptr); + + // Load the inputs, add input_offset. + const uint32 input = *input_ptr++ + input_offset; + + // Multiply-accumulate + acc = vget_low_s32(vmlal_n_s16(vcombine_s32(acc, acc), filter, input)); + // Store the accumulators back to acc_buffer + vst1_s32(acc_buffer_ptr, acc); + acc_buffer_ptr += 2; + } + } +}; + +template <> +struct QuantizedDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const uint8* input_ptr, int16 input_offset, + int input_ptr_increment, const uint8* filter_ptr, + int16 filter_offset, int32* acc_buffer_ptr) { + // Load the filters, add filter_offset. 
+ uint8x8_t filter_u8 = vdup_n_u8(0); + filter_u8 = vset_lane_u8(filter_ptr[0], filter_u8, 0); + filter_u8 = vset_lane_u8(filter_ptr[1], filter_u8, 1); + filter_u8 = vset_lane_u8(filter_ptr[2], filter_u8, 2); + filter_u8 = vset_lane_u8(filter_ptr[3], filter_u8, 3); + const int16x4_t filter_s16 = + vreinterpret_s16_u16(vget_low_u16(vmovl_u8(filter_u8))); + const int16x4_t filter = vadd_s16(filter_s16, vdup_n_s16(filter_offset)); + + int outp = 0; + // Handle 8 output pixels at a time. + for (; outp <= num_output_pixels - 8; outp += 8) { + // Load the accumulators from acc_buffer + int32x4_t acc[8]; + for (int i = 0; i < 8; i++) { + acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i); + } + + // Load the inputs, add input_offset. + uint8x8_t input_u8 = vld1_u8(input_ptr); + input_ptr += 8; + const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8)); + const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset)); + + // Multiply-accumulate + acc[0] = vmlal_lane_s16(acc[0], filter, vget_low_s16(input), 0); + acc[1] = vmlal_lane_s16(acc[1], filter, vget_low_s16(input), 1); + acc[2] = vmlal_lane_s16(acc[2], filter, vget_low_s16(input), 2); + acc[3] = vmlal_lane_s16(acc[3], filter, vget_low_s16(input), 3); + acc[4] = vmlal_lane_s16(acc[4], filter, vget_high_s16(input), 0); + acc[5] = vmlal_lane_s16(acc[5], filter, vget_high_s16(input), 1); + acc[6] = vmlal_lane_s16(acc[6], filter, vget_high_s16(input), 2); + acc[7] = vmlal_lane_s16(acc[7], filter, vget_high_s16(input), 3); + + // Store the accumulators back to acc_buffer + for (int i = 0; i < 8; i++) { + vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 32; + } + // Handle 4 output pixels at a time. + for (; outp <= num_output_pixels - 4; outp += 4) { + // Load the accumulators from acc_buffer + int32x4_t acc[4]; + for (int i = 0; i < 4; i++) { + acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i); + } + + // Load the inputs, add input_offset. + uint8x8_t input_u8 = vdup_n_u8(0); + input_u8 = vset_lane_u8(input_ptr[0], input_u8, 0); + input_u8 = vset_lane_u8(input_ptr[1], input_u8, 1); + input_u8 = vset_lane_u8(input_ptr[2], input_u8, 2); + input_u8 = vset_lane_u8(input_ptr[3], input_u8, 3); + input_ptr += 4; + const int16x4_t input_s16 = + vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8))); + const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset)); + + // Multiply-accumulate + acc[0] = vmlal_lane_s16(acc[0], filter, input, 0); + acc[1] = vmlal_lane_s16(acc[1], filter, input, 1); + acc[2] = vmlal_lane_s16(acc[2], filter, input, 2); + acc[3] = vmlal_lane_s16(acc[3], filter, input, 3); + + // Store the accumulators back to acc_buffer + for (int i = 0; i < 4; i++) { + vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 16; + } + // Handle one output pixel at a time. + for (; outp < num_output_pixels; outp++) { + // Load the accumulators from acc_buffer + int32x4_t acc = vld1q_s32(acc_buffer_ptr); + + // Load the inputs, add input_offset. + const uint32 input = *input_ptr++ + input_offset; + + // Multiply-accumulate + acc = vmlal_n_s16(acc, filter, input); + // Store the accumulators back to acc_buffer + vst1q_s32(acc_buffer_ptr, acc); + acc_buffer_ptr += 4; + } + } +}; + +template <> +struct QuantizedDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const uint8* input_ptr, int16 input_offset, + int input_ptr_increment, const uint8* filter_ptr, + int16 filter_offset, int32* acc_buffer_ptr) { + // Load the filters, add filter_offset. 
+ uint8x8_t filter_u8 = vdup_n_u8(0); + filter_u8 = vset_lane_u8(filter_ptr[0], filter_u8, 0); + filter_u8 = vset_lane_u8(filter_ptr[1], filter_u8, 1); + filter_u8 = vset_lane_u8(filter_ptr[2], filter_u8, 2); + filter_u8 = vset_lane_u8(filter_ptr[3], filter_u8, 3); + const int16x4_t filter_s16 = + vreinterpret_s16_u16(vget_low_u16(vmovl_u8(filter_u8))); + const int16x4_t filter = vadd_s16(filter_s16, vdup_n_s16(filter_offset)); + + int outp = 0; + // Handle 4 output pixels at a time. + for (; outp <= num_output_pixels - 4; outp += 4) { + // Load the accumulators from acc_buffer + int32x4_t acc[4]; + for (int i = 0; i < 4; i++) { + acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i); + } + // Load the inputs, add input_offset. + int16x8_t input[2]; + for (int i = 0; i < 2; i++) { + const uint8x8_t input_u8 = vld1_u8(input_ptr + 8 * i); + const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8)); + input[i] = vaddq_s16(input_s16, vdupq_n_s16(input_offset)); + } + input_ptr += 16; + // Multiply-accumulate + for (int i = 0; i < 2; i++) { + acc[2 * i + 0] = + vmlal_s16(acc[2 * i + 0], filter, vget_low_s16(input[i])); + acc[2 * i + 1] = + vmlal_s16(acc[2 * i + 1], filter, vget_high_s16(input[i])); + } + // Store the accumulators back to acc_buffer + for (int i = 0; i < 4; i++) { + vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 16; + } + // Handle one output pixel at a time. + for (; outp < num_output_pixels; outp++) { + // Load the accumulators from acc_buffer + int32x4_t acc; + acc = vld1q_s32(acc_buffer_ptr); + + // Load the inputs, add input_offset. + uint8x8_t input_u8 = vdup_n_u8(0); + input_u8 = vset_lane_u8(input_ptr[0], input_u8, 0); + input_u8 = vset_lane_u8(input_ptr[1], input_u8, 1); + input_u8 = vset_lane_u8(input_ptr[2], input_u8, 2); + input_u8 = vset_lane_u8(input_ptr[3], input_u8, 3); + input_ptr += 4; + const int16x4_t input_s16 = + vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8))); + const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset)); + // Multiply-accumulate + acc = vmlal_s16(acc, filter, input); + // Store the accumulators back to acc_buffer + vst1q_s32(acc_buffer_ptr, acc); + acc_buffer_ptr += 4; + } + } +}; + +template <> +struct QuantizedDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const uint8* input_ptr, int16 input_offset, + int input_ptr_increment, const uint8* filter_ptr, + int16 filter_offset, int32* acc_buffer_ptr) { + // Load the filters, add filter_offset. + int16x8_t filter[2]; + for (int i = 0; i < 2; i++) { + const uint8x8_t filter_u8 = vld1_u8(filter_ptr + 8 * i); + const int16x8_t filter_s16 = vreinterpretq_s16_u16(vmovl_u8(filter_u8)); + filter[i] = vaddq_s16(filter_s16, vdupq_n_s16(filter_offset)); + } + + int outp = 0; + // Handle 2 output pixels at a time. + for (; outp <= num_output_pixels - 2; outp += 2) { + // Load the accumulators from acc_buffer + int32x4_t acc[8]; + for (int i = 0; i < 8; i++) { + acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i); + } + + // Load the inputs, add input_offset. 
+ uint8x8_t input_u8 = vld1_u8(input_ptr); + input_ptr += 8; + const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8)); + const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset)); + + // Multiply-accumulate + acc[0] = vmlal_lane_s16(acc[0], vget_low_s16(filter[0]), + vget_low_s16(input), 0); + acc[1] = vmlal_lane_s16(acc[1], vget_high_s16(filter[0]), + vget_low_s16(input), 1); + acc[2] = vmlal_lane_s16(acc[2], vget_low_s16(filter[1]), + vget_low_s16(input), 2); + acc[3] = vmlal_lane_s16(acc[3], vget_high_s16(filter[1]), + vget_low_s16(input), 3); + acc[4] = vmlal_lane_s16(acc[4], vget_low_s16(filter[0]), + vget_high_s16(input), 0); + acc[5] = vmlal_lane_s16(acc[5], vget_high_s16(filter[0]), + vget_high_s16(input), 1); + acc[6] = vmlal_lane_s16(acc[6], vget_low_s16(filter[1]), + vget_high_s16(input), 2); + acc[7] = vmlal_lane_s16(acc[7], vget_high_s16(filter[1]), + vget_high_s16(input), 3); + // Store the accumulators back to acc_buffer + for (int i = 0; i < 8; i++) { + vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 32; + } + // Handle one output pixel at a time. + for (; outp < num_output_pixels; outp++) { + // Load the accumulators from acc_buffer + int32x4_t acc[4]; + for (int i = 0; i < 4; i++) { + acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i); + } + + // Load the inputs, add input_offset. + uint8x8_t input_u8 = vdup_n_u8(0); + input_u8 = vset_lane_u8(input_ptr[0], input_u8, 0); + input_u8 = vset_lane_u8(input_ptr[1], input_u8, 1); + input_u8 = vset_lane_u8(input_ptr[2], input_u8, 2); + input_u8 = vset_lane_u8(input_ptr[3], input_u8, 3); + input_ptr += 4; + const int16x4_t input_s16 = + vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8))); + const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset)); + + // Multiply-accumulate + acc[0] = vmlal_lane_s16(acc[0], vget_low_s16(filter[0]), input, 0); + acc[1] = vmlal_lane_s16(acc[1], vget_high_s16(filter[0]), input, 1); + acc[2] = vmlal_lane_s16(acc[2], vget_low_s16(filter[1]), input, 2); + acc[3] = vmlal_lane_s16(acc[3], vget_high_s16(filter[1]), input, 3); + // Store the accumulators back to acc_buffer + for (int i = 0; i < 4; i++) { + vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 16; + } + } +}; + +template <> +struct QuantizedDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const uint8* input_ptr, int16 input_offset, + int input_ptr_increment, const uint8* filter_ptr, + int16 filter_offset, int32* acc_buffer_ptr) { + // We will have to duplicate bytes in a NEON register, 3-fold. + // We will do that by register-level table-look-up using VTBL instructions. + // Here we prepare the registers containing the table-lookup indices. + static const uint8 dup3_indices_array[3][8] = {{0, 0, 0, 1, 1, 1, 2, 2}, + {2, 3, 3, 3, 4, 4, 4, 5}, + {5, 5, 6, 6, 6, 7, 7, 7}}; + uint8x8_t dup3_indices[3]; + for (int i = 0; i < 3; i++) { + dup3_indices[i] = vld1_u8(dup3_indices_array[i]); + } + + // Handle one output pixel at a time. + for (int outp = 0; outp < num_output_pixels; outp++) { + const uint8* local_filter_ptr = filter_ptr; + const uint8* local_input_ptr = input_ptr; + int ic = 0; + // Handle 8 input channels at a time. + for (; ic <= input_depth - 8; ic += 8) { + // Load the filters, add filter_offset. 
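+        // Depth multiplier 3: each of the 8 input bytes in this block must be
+        // repeated three times to line up with 24 consecutive filter values.
+        // vtbl1_u8 with the dup3_indices tables does that duplication in
+        // registers, e.g. [a b c d e f g h] becomes
+        // [a a a b b b c c], [c d d d e e e f], [f f g g g h h h].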
+ int16x8_t filter[3]; + uint8x8x3_t filter_u8; + filter_u8.val[0] = vld1_u8(local_filter_ptr); + filter_u8.val[1] = vld1_u8(local_filter_ptr + 8); + filter_u8.val[2] = vld1_u8(local_filter_ptr + 16); + local_filter_ptr += 24; + for (int i = 0; i < 3; i++) { + const int16x8_t filter_s16 = + vreinterpretq_s16_u16(vmovl_u8(filter_u8.val[i])); + filter[i] = vaddq_s16(filter_s16, vdupq_n_s16(filter_offset)); + } + // Load the inputs, duplicate 3-fold, add input_offset. + const uint8x8_t input_u8 = vld1_u8(local_input_ptr); + local_input_ptr += 8; + + uint8x8_t input_u8_dup3[3]; + for (int i = 0; i < 3; i++) { + input_u8_dup3[i] = vtbl1_u8(input_u8, dup3_indices[i]); + } + int16x8_t input_dup3[3]; + for (int i = 0; i < 3; i++) { + const int16x8_t input_s16_dup3 = + vreinterpretq_s16_u16(vmovl_u8(input_u8_dup3[i])); + input_dup3[i] = vaddq_s16(input_s16_dup3, vdupq_n_s16(input_offset)); + } + // Load the accumulators from acc_buffer + int32x4x3_t acc[2]; + for (int i = 0; i < 2; i++) { + acc[i].val[0] = vld1q_s32(acc_buffer_ptr + 4 * i); + acc[i].val[1] = vld1q_s32(acc_buffer_ptr + 4 * i + 8); + acc[i].val[2] = vld1q_s32(acc_buffer_ptr + 4 * i + 16); + } + // Multiply-accumulate + for (int j = 0; j < 3; j++) { + acc[0].val[j] = vmlal_s16(acc[0].val[j], vget_low_s16(input_dup3[j]), + vget_low_s16(filter[j])); + acc[1].val[j] = vmlal_s16(acc[1].val[j], vget_high_s16(input_dup3[j]), + vget_high_s16(filter[j])); + } + // Store the accumulators back to acc_buffer + for (int i = 0; i < 2; i++) { + vst1q_s32(acc_buffer_ptr + 4 * i, acc[i].val[0]); + vst1q_s32(acc_buffer_ptr + 4 * i + 8, acc[i].val[1]); + vst1q_s32(acc_buffer_ptr + 4 * i + 16, acc[i].val[2]); + } + acc_buffer_ptr += 24; + } + // Handle one input channel at a time. + for (; ic < input_depth; ic++) { + const int16 input_val = *local_input_ptr++ + input_offset; + for (int i = 0; i < 3; i++) { + const int16 filter_val = local_filter_ptr[i] + filter_offset; + *acc_buffer_ptr++ += static_cast(filter_val) * input_val; + } + local_filter_ptr += 3; + } + input_ptr += input_ptr_increment; + } + } +}; + +template <> +struct QuantizedDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const uint8* input_ptr, int16 input_offset, + int input_ptr_increment, const uint8* filter_ptr, + int16 filter_offset, int32* acc_buffer_ptr) { + // Handle one output pixel at a time. + for (int outp = 0; outp < num_output_pixels; outp++) { + const uint8* local_filter_ptr = filter_ptr; + const uint8* local_input_ptr = input_ptr; + int ic = 0; + // Handle 8 input channels at a time. + for (; ic <= input_depth - 8; ic += 8) { + // Load the filters, add filter_offset. + int16x8_t filter[2]; + uint8x8x2_t filter_u8; + filter_u8.val[0] = vld1_u8(local_filter_ptr); + filter_u8.val[1] = vld1_u8(local_filter_ptr + 8); + local_filter_ptr += 16; + for (int i = 0; i < 2; i++) { + const int16x8_t filter_s16 = + vreinterpretq_s16_u16(vmovl_u8(filter_u8.val[i])); + filter[i] = vaddq_s16(filter_s16, vdupq_n_s16(filter_offset)); + } + // Load the inputs, add input_offset, duplicate 2-fold. + const uint8x8_t input_u8 = vld1_u8(local_input_ptr); + local_input_ptr += 8; + const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8)); + const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset)); + const int16x8x2_t input_dup2 = vzipq_s16(input, input); + // Load the accumulators from acc_buffer. 
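+        // acc_buffer holds one int32 accumulator per output channel and per
+        // output pixel. With depth_multiplier == 2, the 8 input channels
+        // handled in this iteration map to 16 consecutive accumulators, which
+        // is why acc_buffer_ptr advances by 16 after the stores below.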
+ int32x4x2_t acc[2]; + for (int i = 0; i < 2; i++) { + acc[i].val[0] = vld1q_s32(acc_buffer_ptr + 4 * i); + acc[i].val[1] = vld1q_s32(acc_buffer_ptr + 4 * i + 8); + } + // Multiply-accumulate. + for (int j = 0; j < 2; j++) { + acc[0].val[j] = vmlal_s16(acc[0].val[j], vget_low_s16(filter[j]), + vget_low_s16(input_dup2.val[j])); + acc[1].val[j] = vmlal_s16(acc[1].val[j], vget_high_s16(filter[j]), + vget_high_s16(input_dup2.val[j])); + } + // Store the accumulators back to acc_buffer. + for (int i = 0; i < 2; i++) { + vst1q_s32(acc_buffer_ptr + 4 * i, acc[i].val[0]); + vst1q_s32(acc_buffer_ptr + 4 * i + 8, acc[i].val[1]); + } + acc_buffer_ptr += 16; + } + // Handle one input channel at a time. + for (; ic < input_depth; ic++) { + // Load the inputs. + const int16 input_val = *local_input_ptr++ + input_offset; + for (int i = 0; i < 2; i++) { + const int16 filter_val = local_filter_ptr[i] + filter_offset; + *acc_buffer_ptr++ += static_cast(filter_val) * input_val; + } + local_filter_ptr += 2; + } + input_ptr += input_ptr_increment; + } + } +}; + +template <> +struct QuantizedDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const uint8* input_ptr, int16 input_offset, + int input_ptr_increment, const uint8* filter_ptr, + int16 filter_offset, int32* acc_buffer_ptr) { + // Handle one output pixel at a time. + for (int outp = 0; outp < num_output_pixels; outp++) { + const uint8* local_filter_ptr = filter_ptr; + const uint8* local_input_ptr = input_ptr; + int ic = 0; + // Handle 16 input channels at a time. + for (; ic <= input_depth - 16; ic += 16) { + // Load the filters, add filter_offset. + uint8x8_t filter_u8_0 = vld1_u8(local_filter_ptr + 8 * 0); + uint8x8_t filter_u8_1 = vld1_u8(local_filter_ptr + 8 * 1); + local_filter_ptr += 16; + int16x8_t filter_0 = vreinterpretq_s16_u16(vmovl_u8(filter_u8_0)); + int16x8_t filter_1 = vreinterpretq_s16_u16(vmovl_u8(filter_u8_1)); + filter_0 = vaddq_s16(filter_0, vdupq_n_s16(filter_offset)); + filter_1 = vaddq_s16(filter_1, vdupq_n_s16(filter_offset)); + // Load the inputs, add input_offset. + uint8x8_t input_u8_0 = vld1_u8(local_input_ptr + 8 * 0); + uint8x8_t input_u8_1 = vld1_u8(local_input_ptr + 8 * 1); + local_input_ptr += 16; + int16x8_t input_0 = vreinterpretq_s16_u16(vmovl_u8(input_u8_0)); + int16x8_t input_1 = vreinterpretq_s16_u16(vmovl_u8(input_u8_1)); + input_0 = vaddq_s16(input_0, vdupq_n_s16(input_offset)); + input_1 = vaddq_s16(input_1, vdupq_n_s16(input_offset)); + // Load the accumulators from acc_buffer + int32x4_t acc_0 = vld1q_s32(acc_buffer_ptr + 4 * 0); + int32x4_t acc_1 = vld1q_s32(acc_buffer_ptr + 4 * 1); + int32x4_t acc_2 = vld1q_s32(acc_buffer_ptr + 4 * 2); + int32x4_t acc_3 = vld1q_s32(acc_buffer_ptr + 4 * 3); + acc_0 = vmlal_s16(acc_0, vget_low_s16(input_0), vget_low_s16(filter_0)); + acc_1 = + vmlal_s16(acc_1, vget_high_s16(input_0), vget_high_s16(filter_0)); + acc_2 = vmlal_s16(acc_2, vget_low_s16(input_1), vget_low_s16(filter_1)); + acc_3 = + vmlal_s16(acc_3, vget_high_s16(input_1), vget_high_s16(filter_1)); + // Store the accumulators back to acc_buffer + vst1q_s32(acc_buffer_ptr + 4 * 0, acc_0); + vst1q_s32(acc_buffer_ptr + 4 * 1, acc_1); + vst1q_s32(acc_buffer_ptr + 4 * 2, acc_2); + vst1q_s32(acc_buffer_ptr + 4 * 3, acc_3); + acc_buffer_ptr += 16; + } + // Handle 8 input channels at a time. + for (; ic <= input_depth - 8; ic += 8) { + // Load the filters, add filter_offset. 
+ const uint8x8_t filter_u8 = vld1_u8(local_filter_ptr); + local_filter_ptr += 8; + const int16x8_t filter_s16 = vreinterpretq_s16_u16(vmovl_u8(filter_u8)); + const int16x8_t filter = + vaddq_s16(filter_s16, vdupq_n_s16(filter_offset)); + // Load the inputs, add input_offset. + const uint8x8_t input_u8 = vld1_u8(local_input_ptr); + local_input_ptr += 8; + const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8)); + const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset)); + // Load the accumulators from acc_buffer + int32x4_t acc[2]; + for (int i = 0; i < 2; i++) { + acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i); + } + // Multiply-accumulate + acc[0] = vmlal_s16(acc[0], vget_low_s16(input), vget_low_s16(filter)); + acc[1] = vmlal_s16(acc[1], vget_high_s16(input), vget_high_s16(filter)); + // Store the accumulators back to acc_buffer + for (int i = 0; i < 2; i++) { + vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 8; + } + // Handle one input channel at a time. + for (; ic < input_depth; ic++) { + const int16 input_val = *local_input_ptr++ + input_offset; + const int16 filter_val = *local_filter_ptr++ + filter_offset; + *acc_buffer_ptr++ += static_cast(filter_val) * input_val; + } + input_ptr += input_ptr_increment; + } + } +}; + +template <> +struct QuantizedDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const uint8* input_ptr, int16 input_offset, + int input_ptr_increment, const uint8* filter_ptr, + int16 filter_offset, int32* acc_buffer_ptr) { + // Load the filters, add filter_offset. + uint8x8_t filter_u8[2]; + for (int i = 0; i < 2; i++) { + filter_u8[i] = vld1_u8(filter_ptr + 8 * i); + } + int16x8_t filter[2]; + for (int i = 0; i < 2; i++) { + filter[i] = vreinterpretq_s16_u16(vmovl_u8(filter_u8[i])); + } + for (int i = 0; i < 2; i++) { + filter[i] = vaddq_s16(filter[i], vdupq_n_s16(filter_offset)); + } + // Handle one output pixel at a time. + for (int outp = 0; outp < num_output_pixels; outp++) { + // Load the inputs, add input_offset. + uint8x8_t input_u8[2]; + for (int i = 0; i < 2; i++) { + input_u8[i] = vld1_u8(input_ptr + 8 * i); + } + input_ptr += input_ptr_increment; + int16x8_t input[2]; + for (int i = 0; i < 2; i++) { + input[i] = vreinterpretq_s16_u16(vmovl_u8(input_u8[i])); + } + for (int i = 0; i < 2; i++) { + input[i] = vaddq_s16(input[i], vdupq_n_s16(input_offset)); + } + // Load the accumulators from acc_buffer + int32x4_t acc[4]; + for (int i = 0; i < 4; i++) { + acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i); + } + // Multiply-accumulate + for (int i = 0; i < 2; i++) { + acc[2 * i + 0] = vmlal_s16(acc[2 * i + 0], vget_low_s16(input[i]), + vget_low_s16(filter[i])); + acc[2 * i + 1] = vmlal_s16(acc[2 * i + 1], vget_high_s16(input[i]), + vget_high_s16(filter[i])); + } + // Store the accumulators back to acc_buffer + for (int i = 0; i < 4; i++) { + vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 16; + } + } +}; + +template <> +struct QuantizedDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const uint8* input_ptr, int16 input_offset, + int input_ptr_increment, const uint8* filter_ptr, + int16 filter_offset, int32* acc_buffer_ptr) { + // Load the filters, add filter_offset. 
+ const uint8x8_t filter_u8 = vld1_u8(filter_ptr); + const int16x8_t filter_s16 = vreinterpretq_s16_u16(vmovl_u8(filter_u8)); + const int16x8_t filter = vaddq_s16(filter_s16, vdupq_n_s16(filter_offset)); + // Handle one output pixel at a time. + for (int outp = 0; outp < num_output_pixels; outp++) { + // Load the inputs, add input_offset. + const uint8x8_t input_u8 = vld1_u8(input_ptr); + const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8)); + const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset)); + // Load the accumulators from acc_buffer + int32x4_t acc[2]; + for (int i = 0; i < 2; i++) { + acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i); + } + // Multiply-accumulate + acc[0] = vmlal_s16(acc[0], vget_low_s16(input), vget_low_s16(filter)); + acc[1] = vmlal_s16(acc[1], vget_high_s16(input), vget_high_s16(filter)); + // Store the accumulators back to acc_buffer + for (int i = 0; i < 2; i++) { + vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 8; + input_ptr += input_ptr_increment; + } + } +}; + +template <> +struct QuantizedDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const uint8* input_ptr, int16 input_offset, + int input_ptr_increment, const uint8* filter_ptr, + int16 filter_offset, int32* acc_buffer_ptr) { + // Load the filters, add filter_offset. + uint8x8_t filter_u8[2]; + for (int i = 0; i < 2; i++) { + filter_u8[i] = vld1_u8(filter_ptr + 8 * i); + } + int16x8_t filter[2]; + for (int i = 0; i < 2; i++) { + filter[i] = vreinterpretq_s16_u16(vmovl_u8(filter_u8[i])); + } + for (int i = 0; i < 2; i++) { + filter[i] = vaddq_s16(filter[i], vdupq_n_s16(filter_offset)); + } + // Handle one output pixel at a time. + for (int outp = 0; outp < num_output_pixels; outp++) { + uint8 input_u8 = *input_ptr; + input_ptr += input_ptr_increment; + int16 input = static_cast(input_u8 + input_offset); + // Load the accumulators from acc_buffer + int32x4_t acc[4]; + for (int i = 0; i < 4; i++) { + acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i); + } + // Multiply-accumulate + for (int i = 0; i < 2; i++) { + acc[2 * i + 0] = + vmlal_n_s16(acc[2 * i + 0], vget_low_s16(filter[i]), input); + acc[2 * i + 1] = + vmlal_n_s16(acc[2 * i + 1], vget_high_s16(filter[i]), input); + } + // Store the accumulators back to acc_buffer + for (int i = 0; i < 4; i++) { + vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 16; + } + } +}; + +template <> +struct QuantizedDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const uint8* input_ptr, int16 input_offset, + int input_ptr_increment, const uint8* filter_ptr, + int16 filter_offset, int32* acc_buffer_ptr) { + // Load the filters, add filter_offset. 
+ uint8x8_t filter_u8_0 = vld1_u8(filter_ptr + 8 * 0); + uint8x8_t filter_u8_1 = vld1_u8(filter_ptr + 8 * 1); + uint8x8_t filter_u8_2 = vld1_u8(filter_ptr + 8 * 2); + uint8x8_t filter_u8_3 = vld1_u8(filter_ptr + 8 * 3); + int16x8_t filter_0 = vreinterpretq_s16_u16(vmovl_u8(filter_u8_0)); + int16x8_t filter_1 = vreinterpretq_s16_u16(vmovl_u8(filter_u8_1)); + int16x8_t filter_2 = vreinterpretq_s16_u16(vmovl_u8(filter_u8_2)); + int16x8_t filter_3 = vreinterpretq_s16_u16(vmovl_u8(filter_u8_3)); + filter_0 = vaddq_s16(filter_0, vdupq_n_s16(filter_offset)); + filter_1 = vaddq_s16(filter_1, vdupq_n_s16(filter_offset)); + filter_2 = vaddq_s16(filter_2, vdupq_n_s16(filter_offset)); + filter_3 = vaddq_s16(filter_3, vdupq_n_s16(filter_offset)); + // Handle one output pixel at a time. + for (int outp = 0; outp < num_output_pixels; outp++) { + uint8 input_u8 = *input_ptr; + input_ptr += input_ptr_increment; + int16 input = static_cast(input_u8 + input_offset); + // Load the accumulators from acc_buffer + int32x4_t acc_0 = vld1q_s32(acc_buffer_ptr + 4 * 0); + int32x4_t acc_1 = vld1q_s32(acc_buffer_ptr + 4 * 1); + int32x4_t acc_2 = vld1q_s32(acc_buffer_ptr + 4 * 2); + int32x4_t acc_3 = vld1q_s32(acc_buffer_ptr + 4 * 3); + int32x4_t acc_4 = vld1q_s32(acc_buffer_ptr + 4 * 4); + int32x4_t acc_5 = vld1q_s32(acc_buffer_ptr + 4 * 5); + int32x4_t acc_6 = vld1q_s32(acc_buffer_ptr + 4 * 6); + int32x4_t acc_7 = vld1q_s32(acc_buffer_ptr + 4 * 7); + // Multiply-accumulate + acc_0 = vmlal_n_s16(acc_0, vget_low_s16(filter_0), input); + acc_1 = vmlal_n_s16(acc_1, vget_high_s16(filter_0), input); + acc_2 = vmlal_n_s16(acc_2, vget_low_s16(filter_1), input); + acc_3 = vmlal_n_s16(acc_3, vget_high_s16(filter_1), input); + acc_4 = vmlal_n_s16(acc_4, vget_low_s16(filter_2), input); + acc_5 = vmlal_n_s16(acc_5, vget_high_s16(filter_2), input); + acc_6 = vmlal_n_s16(acc_6, vget_low_s16(filter_3), input); + acc_7 = vmlal_n_s16(acc_7, vget_high_s16(filter_3), input); + // Store the accumulators back to acc_buffer + vst1q_s32(acc_buffer_ptr + 4 * 0, acc_0); + vst1q_s32(acc_buffer_ptr + 4 * 1, acc_1); + vst1q_s32(acc_buffer_ptr + 4 * 2, acc_2); + vst1q_s32(acc_buffer_ptr + 4 * 3, acc_3); + vst1q_s32(acc_buffer_ptr + 4 * 4, acc_4); + vst1q_s32(acc_buffer_ptr + 4 * 5, acc_5); + vst1q_s32(acc_buffer_ptr + 4 * 6, acc_6); + vst1q_s32(acc_buffer_ptr + 4 * 7, acc_7); + acc_buffer_ptr += 32; + } + } +}; + +template <> +struct QuantizedDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const uint8* input_ptr, int16 input_offset, + int input_ptr_increment, const uint8* filter_ptr, + int16 filter_offset, int32* acc_buffer_ptr) { + // Load the filters, add filter_offset. + const uint8x8_t filter_u8 = vld1_u8(filter_ptr); + const int16x8_t filter = vaddq_s16( + vreinterpretq_s16_u16(vmovl_u8(filter_u8)), vdupq_n_s16(filter_offset)); + // Handle one output pixel at a time. 
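+    // This kernel reads a single input channel per output pixel and fans it
+    // out across the 8 filter values loaded above: vmlal_n_s16 multiplies
+    // every lane of the widened filter vector by that one scalar input and
+    // accumulates into the per-channel int32 accumulators.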
+ for (int outp = 0; outp < num_output_pixels; outp++) { + uint8 input_u8 = *input_ptr; + input_ptr += input_ptr_increment; + int16 input = static_cast(input_u8 + input_offset); + // Load the accumulators from acc_buffer + int32x4_t acc[2]; + for (int i = 0; i < 2; i++) { + acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i); + } + // Multiply-accumulate + acc[0] = vmlal_n_s16(acc[0], vget_low_s16(filter), input); + acc[1] = vmlal_n_s16(acc[1], vget_high_s16(filter), input); + // Store the accumulators back to acc_buffer + for (int i = 0; i < 2; i++) { + vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 8; + } + } +}; + +template <> +struct QuantizedDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const uint8* input_ptr, int16 input_offset, + int input_ptr_increment, const uint8* filter_ptr, + int16 filter_offset, int32* acc_buffer_ptr) { + // Load the filters, add filter_offset. + uint8x8_t filter_u8 = vdup_n_u8(0); + filter_u8 = vset_lane_u8(filter_ptr[0], filter_u8, 0); + filter_u8 = vset_lane_u8(filter_ptr[1], filter_u8, 1); + filter_u8 = vset_lane_u8(filter_ptr[0], filter_u8, 2); + filter_u8 = vset_lane_u8(filter_ptr[1], filter_u8, 3); + const int16x4_t filter_s16 = + vreinterpret_s16_u16(vget_low_u16(vmovl_u8(filter_u8))); + const int16x4_t filter = vadd_s16(filter_s16, vdup_n_s16(filter_offset)); + + int outp = 0; + + // Handle 2 output pixels at a time. + for (; outp <= num_output_pixels - 2; outp += 2) { + // Load the accumulators from acc_buffer. + int32x4_t acc = vld1q_s32(acc_buffer_ptr); + // Load the inputs, add input_offset. + uint16x4_t input_u16 = vdup_n_u16(0); + input_u16 = vset_lane_u16((reinterpret_cast(input_ptr))[0], + input_u16, 0); + input_ptr += input_ptr_increment; + input_u16 = vset_lane_u16((reinterpret_cast(input_ptr))[0], + input_u16, 1); + input_ptr += input_ptr_increment; + const int16x4_t input_s16 = vreinterpret_s16_u16( + vget_low_u16(vmovl_u8(vreinterpret_u8_u16(input_u16)))); + const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset)); + + // Multiply-accumulate. + acc = vmlal_s16(acc, filter, input); + // Store the accumulators back to acc_buffer. + vst1q_s32(acc_buffer_ptr, acc); + acc_buffer_ptr += 4; + } + + // Handle 1 output pixel at a time. + for (; outp < num_output_pixels; outp++) { + // Load the accumulators from acc_buffer. + int32x2_t acc = vld1_s32(acc_buffer_ptr); + // Load the inputs, add input_offset. + uint8x8_t input_u8 = vdup_n_u8(0); + input_u8 = vset_lane_u8(input_ptr[0], input_u8, 0); + input_u8 = vset_lane_u8(input_ptr[1], input_u8, 1); + input_ptr += input_ptr_increment; + const int16x4_t input_s16 = + vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8))); + const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset)); + + // Multiply-accumulate. + acc = vget_low_s32(vmlal_s16(vcombine_s32(acc, acc), filter, input)); + // Store the accumulators back to acc_buffer. + vst1_s32(acc_buffer_ptr, acc); + acc_buffer_ptr += 2; + } + } +}; + +template <> +struct QuantizedDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const uint8* input_ptr, int16 input_offset, + int input_ptr_increment, const uint8* filter_ptr, + int16 filter_offset, int32* acc_buffer_ptr) { + if (num_output_pixels <= 0) { + return; + } + + // Load the filters, add filter_offset. 
+ uint8x8_t filter_u8 = vdup_n_u8(0); + filter_u8 = vset_lane_u8(filter_ptr[0], filter_u8, 0); + filter_u8 = vset_lane_u8(filter_ptr[1], filter_u8, 1); + filter_u8 = vset_lane_u8(filter_ptr[2], filter_u8, 2); + filter_u8 = vset_lane_u8(filter_ptr[3], filter_u8, 3); + const int16x4_t filter_s16 = + vreinterpret_s16_u16(vget_low_u16(vmovl_u8(filter_u8))); + const int16x4_t filter = vadd_s16(filter_s16, vdup_n_s16(filter_offset)); + + int outp = 0; + + // Handle one output pixel at a time until second to the last pixel. Second + // to the last because we read eight input pixels while only processing + // four. + for (; outp < num_output_pixels - 1; outp++) { + // Load the accumulators from acc_buffer + int32x4_t acc; + acc = vld1q_s32(acc_buffer_ptr); + + // Load the inputs, add input_offset. + uint8x8_t input_u8 = vld1_u8(input_ptr); + input_ptr += input_ptr_increment; + const int16x4_t input_s16 = + vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8))); + const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset)); + // Multiply-accumulate + acc = vmlal_s16(acc, filter, input); + // Store the accumulators back to acc_buffer + vst1q_s32(acc_buffer_ptr, acc); + acc_buffer_ptr += 4; + } + + // Handle the last output pixel. + // Load the accumulators from acc_buffer + int32x4_t acc; + acc = vld1q_s32(acc_buffer_ptr); + + // Load the inputs, add input_offset. + uint8x8_t input_u8 = vdup_n_u8(0); + input_u8 = vset_lane_u8(input_ptr[0], input_u8, 0); + input_u8 = vset_lane_u8(input_ptr[1], input_u8, 1); + input_u8 = vset_lane_u8(input_ptr[2], input_u8, 2); + input_u8 = vset_lane_u8(input_ptr[3], input_u8, 3); + const int16x4_t input_s16 = + vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8))); + const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset)); + // Multiply-accumulate + acc = vmlal_s16(acc, filter, input); + // Store the accumulators back to acc_buffer + vst1q_s32(acc_buffer_ptr, acc); + } +}; + +template <> +struct QuantizedDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const uint8* input_ptr, int16 input_offset, + int input_ptr_increment, const uint8* filter_ptr, + int16 filter_offset, int32* acc_buffer_ptr) { + // Load the filters, add filter_offset. + uint8x8_t filter_u8_0 = vld1_u8(filter_ptr); + uint8x8_t filter_u8_1 = vld1_u8(filter_ptr + 4); + int16x8_t filter_s16_0 = vreinterpretq_s16_u16(vmovl_u8(filter_u8_0)); + int16x8_t filter_s16_1 = vreinterpretq_s16_u16(vmovl_u8(filter_u8_1)); + filter_s16_0 = vaddq_s16(filter_s16_0, vdupq_n_s16(filter_offset)); + filter_s16_1 = vaddq_s16(filter_s16_1, vdupq_n_s16(filter_offset)); + int16x4_t filter_0 = vget_low_s16(filter_s16_0); + int16x4_t filter_1 = vget_high_s16(filter_s16_0); + int16x4_t filter_2 = vget_high_s16(filter_s16_1); + + // Handle one output pixel at a time. + for (int outp = 0; outp < num_output_pixels; outp++) { + // Load the inputs, add input_offset. 
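+      // The two overlapping 8-byte loads below cover the 12 input channels:
+      // channels 0-7 come from the first load, channels 8-11 from the high
+      // half of the second load (its low half repeats channels 4-7 and is
+      // ignored), mirroring the filter_0/filter_1/filter_2 split above.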
+      uint8x8_t input_u8_0 = vld1_u8(input_ptr);
+      uint8x8_t input_u8_1 = vld1_u8(input_ptr + 4);
+      input_ptr += input_ptr_increment;
+      int16x8_t input_0 = vreinterpretq_s16_u16(vmovl_u8(input_u8_0));
+      int16x8_t input_1 = vreinterpretq_s16_u16(vmovl_u8(input_u8_1));
+      input_0 = vaddq_s16(input_0, vdupq_n_s16(input_offset));
+      input_1 = vaddq_s16(input_1, vdupq_n_s16(input_offset));
+
+      // Load the accumulators from acc_buffer
+      int32x4_t acc_0 = vld1q_s32(acc_buffer_ptr + 4 * 0);
+      int32x4_t acc_1 = vld1q_s32(acc_buffer_ptr + 4 * 1);
+      int32x4_t acc_2 = vld1q_s32(acc_buffer_ptr + 4 * 2);
+
+      // Multiply-accumulate
+      acc_0 = vmlal_s16(acc_0, vget_low_s16(input_0), filter_0);
+      acc_1 = vmlal_s16(acc_1, vget_high_s16(input_0), filter_1);
+      acc_2 = vmlal_s16(acc_2, vget_high_s16(input_1), filter_2);
+
+      // Store the accumulators back to acc_buffer
+      vst1q_s32(acc_buffer_ptr + 4 * 0, acc_0);
+      vst1q_s32(acc_buffer_ptr + 4 * 1, acc_1);
+      vst1q_s32(acc_buffer_ptr + 4 * 2, acc_2);
+
+      acc_buffer_ptr += 12;
+    }
+  }
+};
+#endif
+
+// Accumulates the effect of one row of the filter, on a segment of one row
+// of the output, accessing the corresponding one row of the input.
+template <bool kAllowStrided, int kFixedInputDepth, int kFixedDepthMultiplier>
+void QuantizedDepthwiseConvAccumRow(
+    int stride, int input_depth, int input_width, const uint8* input_data,
+    int16 input_offset, int pad_width, int depth_multiplier, int filter_width,
+    const uint8* filter_data, int16 filter_offset, int out_x_buffer_start,
+    int out_x_buffer_end, int output_depth, int32* acc_buffer) {
+#ifdef GEMMLOWP_PROFILING
+  gemmlowp::ScopedProfilingLabel label(__PRETTY_FUNCTION__);
+#endif
+  // Sanity check parameters. This is important in particular to ensure
+  // that we keep the number of template instantiations minimal, so we don't
+  // increase binary size unnecessarily.
+  static_assert(kFixedDepthMultiplier || !kFixedInputDepth, "");
+  static_assert(kFixedInputDepth || kAllowStrided, "");
+  TFLITE_DCHECK(stride == 1 || kAllowStrided);
+  if (kFixedInputDepth) {
+    TFLITE_DCHECK_EQ(input_depth, kFixedInputDepth);
+  }
+  if (kFixedDepthMultiplier) {
+    TFLITE_DCHECK_EQ(depth_multiplier, kFixedDepthMultiplier);
+  }
+  TFLITE_DCHECK_EQ(output_depth, input_depth * depth_multiplier);
+  const int input_ptr_increment = stride * input_depth;
+  const uint8* filter_base_ptr = filter_data;
+  for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
+    // For the current (filter_x, filter_y) point in the filter,
+    // compute the boundaries of the corresponding output row segment.
+    int out_x_loop_start_unclampled = 0;
+    int out_x_loop_end_unclampled = 0;
+    if (kAllowStrided) {
+      if (stride == 2) {
+        out_x_loop_start_unclampled = (pad_width - filter_x + 1) / 2;
+        out_x_loop_end_unclampled =
+            (pad_width + input_width - filter_x + 1) / 2;
+      } else if (stride == 4) {
+        out_x_loop_start_unclampled = (pad_width - filter_x + 3) / 4;
+        out_x_loop_end_unclampled =
+            (pad_width + input_width - filter_x + 3) / 4;
+      } else {
+        out_x_loop_start_unclampled =
+            (pad_width - filter_x + stride - 1) / stride;
+        out_x_loop_end_unclampled =
+            (pad_width + input_width - filter_x + stride - 1) / stride;
+      }
+    } else {
+      out_x_loop_start_unclampled = pad_width - filter_x;
+      out_x_loop_end_unclampled = pad_width + input_width - filter_x;
+    }
+    // The kernel will have to iterate on the segment of the
+    // output row that starts at out_x_loop_start and ends at out_x_loop_end.
+ const int out_x_loop_start = + std::max(out_x_buffer_start, out_x_loop_start_unclampled); + const int out_x_loop_end = + std::min(out_x_buffer_end, out_x_loop_end_unclampled); + + int32* acc_buffer_ptr = + acc_buffer + (out_x_loop_start - out_x_buffer_start) * output_depth; + const int in_x_origin = (out_x_loop_start * stride) - pad_width + filter_x; + const uint8* input_ptr = input_data + in_x_origin * input_depth; + const int num_output_pixels = out_x_loop_end - out_x_loop_start; + QuantizedDepthwiseConvKernel< + kAllowStrided, kFixedInputDepth, + kFixedDepthMultiplier>::Run(num_output_pixels, input_depth, + depth_multiplier, input_ptr, input_offset, + input_ptr_increment, filter_base_ptr, + filter_offset, acc_buffer_ptr); + filter_base_ptr += output_depth; + } +} + +// generic fallback of DepthwiseConvAccumRow, portable, non-templatized. +inline void QuantizedDepthwiseConvAccumRowGeneric( + int stride, int input_depth, int input_width, const uint8* input_data, + int16 input_offset, int pad_width, int depth_multiplier, int filter_width, + const uint8* filter_data, int16 filter_offset, int out_x_buffer_start, + int out_x_buffer_end, int output_depth, int32* acc_buffer) { + gemmlowp::ScopedProfilingLabel label("DepthwiseConvAccumRowGeneric (slow)"); +#ifdef TFLITE_PREVENT_SLOW_GENERIC_DEPTHWISECONV_FALLBACK +#ifndef ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK + LOG(FATAL) + << "\n\n" + << "*****************************************************************\n" + << "* This tfmini inference code was about to use the slow generic\n" + << "* fallback implementation for a DepthwiseConv op, and we want you\n" + << "* to be aware of that so that you will know why you get terrible\n" + << "* performance.\n" + << "*\n" + << "* If you would like to carry on with the slow code, compile\n" + << "* with this preprocessor token defined:\n" + << "* TFLITE_ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK.\n" + << "*\n" + << "* The right thing to do, if you care about performance, is to add\n" + << "* a new DepthwiseConv kernel to tfmini to cover your case.\n" + << "* The relevant parameters defining your case are:\n" + << "* stride = " << stride << "\n" + << "* input_depth = " << input_depth << "\n" + << "* depth_multiplier = " << depth_multiplier << "\n" + << "*\n" + << "* Please do not hesitate to contact benoitjacob@ with this\n" + << "* information.\n" + << "*****************************************************************\n"; +#endif // ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK +#endif // TFLITE_PREVENT_SLOW_GENERIC_DEPTHWISECONV_FALLBACK + const uint8* filter_base_ptr = filter_data; + for (int filter_x = 0; filter_x < filter_width; ++filter_x) { + const int out_x_loop_start = std::max( + out_x_buffer_start, (pad_width - filter_x + stride - 1) / stride); + const int out_x_loop_end = + std::min(out_x_buffer_end, + (pad_width + input_width - filter_x + stride - 1) / stride); + + int32* acc_buffer_ptr = + acc_buffer + (out_x_loop_start - out_x_buffer_start) * output_depth; + const int in_x_origin = (out_x_loop_start * stride) - pad_width + filter_x; + const uint8* input_ptr = input_data + in_x_origin * input_depth; + const int input_ptr_increment = (stride - 1) * input_depth; + for (int out_x = out_x_loop_start; out_x < out_x_loop_end; out_x++) { + const uint8* filter_ptr = filter_base_ptr; + for (int ic = 0; ic < input_depth; ++ic) { + const int16 input_val = *input_ptr++ + input_offset; + for (int m = 0; m < depth_multiplier; m++) { + const int16 filter_val = *filter_ptr++ + filter_offset; + *acc_buffer_ptr++ 
+= static_cast(filter_val) * input_val; + } + } + input_ptr += input_ptr_increment; + } + filter_base_ptr += output_depth; + } +} + +// Initializes the accumulator buffer with bias values. +inline void DepthwiseConvInitAccBuffer(int num_output_pixels, int output_depth, + const int32* bias_data, + int32* acc_buffer) { + int i = 0; +#ifdef USE_NEON + if (output_depth == 1) { + const int32x4_t b = vdupq_n_s32(bias_data[0]); + for (; i <= num_output_pixels - 16; i += 16) { + vst1q_s32(acc_buffer + i + 0, b); + vst1q_s32(acc_buffer + i + 4, b); + vst1q_s32(acc_buffer + i + 8, b); + vst1q_s32(acc_buffer + i + 12, b); + } + for (; i <= num_output_pixels - 4; i += 4) { + vst1q_s32(acc_buffer + i, b); + } + } else if (output_depth == 2) { + int32x4_t b = vdupq_n_s32(bias_data[0]); + b = vsetq_lane_s32(bias_data[1], b, 1); + b = vsetq_lane_s32(bias_data[1], b, 3); + for (; i <= num_output_pixels - 8; i += 8) { + vst1q_s32(acc_buffer + 2 * i + 0, b); + vst1q_s32(acc_buffer + 2 * i + 4, b); + vst1q_s32(acc_buffer + 2 * i + 8, b); + vst1q_s32(acc_buffer + 2 * i + 12, b); + } + for (; i <= num_output_pixels - 2; i += 2) { + vst1q_s32(acc_buffer + 2 * i, b); + } + } else if (output_depth == 4) { + const int32x4_t b = vld1q_s32(bias_data); + for (; i <= num_output_pixels - 4; i += 4) { + vst1q_s32(acc_buffer + 4 * i + 0, b); + vst1q_s32(acc_buffer + 4 * i + 4, b); + vst1q_s32(acc_buffer + 4 * i + 8, b); + vst1q_s32(acc_buffer + 4 * i + 12, b); + } + for (; i < num_output_pixels; i++) { + vst1q_s32(acc_buffer + 4 * i, b); + } + } else if (output_depth == 8) { + const int32x4_t b0 = vld1q_s32(bias_data); + const int32x4_t b1 = vld1q_s32(bias_data + 4); + for (; i <= num_output_pixels - 2; i += 2) { + vst1q_s32(acc_buffer + 8 * i + 0, b0); + vst1q_s32(acc_buffer + 8 * i + 4, b1); + vst1q_s32(acc_buffer + 8 * i + 8, b0); + vst1q_s32(acc_buffer + 8 * i + 12, b1); + } + for (; i < num_output_pixels; i++) { + vst1q_s32(acc_buffer + 8 * i + 0, b0); + vst1q_s32(acc_buffer + 8 * i + 4, b1); + } + } else if (output_depth == 16) { + const int32x4_t b0 = vld1q_s32(bias_data); + const int32x4_t b1 = vld1q_s32(bias_data + 4); + const int32x4_t b2 = vld1q_s32(bias_data + 8); + const int32x4_t b3 = vld1q_s32(bias_data + 12); + for (; i < num_output_pixels; i++) { + vst1q_s32(acc_buffer + 16 * i + 0, b0); + vst1q_s32(acc_buffer + 16 * i + 4, b1); + vst1q_s32(acc_buffer + 16 * i + 8, b2); + vst1q_s32(acc_buffer + 16 * i + 12, b3); + } + } +#endif + for (; i < num_output_pixels; i++) { + memcpy(acc_buffer + i * output_depth, bias_data, + sizeof(acc_buffer[0]) * output_depth); + } +} + +inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims, + int32 input_offset, const uint8* filter_data, + const Dims<4>& filter_dims, int32 filter_offset, + const int32* bias_data, const Dims<4>& bias_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int depth_multiplier, + int32 output_offset, int32 output_multiplier, + int output_shift, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("DepthwiseConv/8bit"); + TFLITE_DCHECK_LE(output_activation_min, output_activation_max); + + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int output_depth = MatchingArraySize(filter_dims, 0, output_dims, 0); + const int input_height = ArraySize(input_dims, 2); + const int input_width = ArraySize(input_dims, 1); + const int input_depth = ArraySize(input_dims, 0); + const int 
filter_height = ArraySize(filter_dims, 2);
+  const int filter_width = ArraySize(filter_dims, 1);
+  const int output_height = ArraySize(output_dims, 2);
+  const int output_width = ArraySize(output_dims, 1);
+  TFLITE_DCHECK(output_depth == input_depth * depth_multiplier);
+
+  static const int kAccBufferMaxSize = 2048;
+  int32 acc_buffer[kAccBufferMaxSize];
+  TFLITE_DCHECK_GE(kAccBufferMaxSize, output_depth);
+  const int kOutputPixelsInAccBuffer = kAccBufferMaxSize / output_depth;
+  const int kAccBufferActualSize = kOutputPixelsInAccBuffer * output_depth;
+  TFLITE_DCHECK_LE(kOutputPixelsInAccBuffer * output_depth,
+                   kAccBufferActualSize);
+  TFLITE_DCHECK_LE(kAccBufferActualSize, kAccBufferMaxSize);
+  TFLITE_DCHECK_GE(kOutputPixelsInAccBuffer, 1);
+
+  // row_accum_func will point to the core accumulation function to be used
+  // for this DepthwiseConv op.
+  using row_accum_func_t = decltype(&QuantizedDepthwiseConvAccumRowGeneric);
+  row_accum_func_t row_accum_func = nullptr;
+
+#define TFMINI_USE_DEPTHWISECONV_KERNEL(ALLOW_STRIDED, FIXED_INPUT_DEPTH,    \
+                                        FIXED_DEPTH_MULTIPLIER)              \
+  if (!row_accum_func && (stride_width == 1 || ALLOW_STRIDED) &&             \
+      (input_depth == FIXED_INPUT_DEPTH || FIXED_INPUT_DEPTH == 0) &&        \
+      depth_multiplier == FIXED_DEPTH_MULTIPLIER) {                          \
+    row_accum_func =                                                         \
+        QuantizedDepthwiseConvAccumRow<ALLOW_STRIDED, FIXED_INPUT_DEPTH,     \
+                                       FIXED_DEPTH_MULTIPLIER>;              \
+  }
+
+#ifdef USE_NEON
+  // We go over our list of kernels by decreasing order of preference
+  // for the cases where multiple kernels could apply.
+
+  // Start with the fastest kernels: AllowStrided=false, fixed input depth.
+
+  TFMINI_USE_DEPTHWISECONV_KERNEL(false, 1, 2)
+  TFMINI_USE_DEPTHWISECONV_KERNEL(false, 2, 2)
+  TFMINI_USE_DEPTHWISECONV_KERNEL(false, 4, 2)
+  TFMINI_USE_DEPTHWISECONV_KERNEL(false, 1, 4)
+  TFMINI_USE_DEPTHWISECONV_KERNEL(false, 4, 1)
+  TFMINI_USE_DEPTHWISECONV_KERNEL(false, 4, 4)
+  TFMINI_USE_DEPTHWISECONV_KERNEL(false, 8, 1)
+  TFMINI_USE_DEPTHWISECONV_KERNEL(false, 2, 8)
+  TFMINI_USE_DEPTHWISECONV_KERNEL(false, 2, 1)
+  TFMINI_USE_DEPTHWISECONV_KERNEL(false, 12, 1)
+
+  // Next come the strided kernels: AllowStrided=true, fixed input depth.
+  // They are a bit less efficient, but allow stride!=1.
+
+  TFMINI_USE_DEPTHWISECONV_KERNEL(true, 8, 2)
+  TFMINI_USE_DEPTHWISECONV_KERNEL(true, 16, 1)
+  TFMINI_USE_DEPTHWISECONV_KERNEL(true, 1, 16)
+  TFMINI_USE_DEPTHWISECONV_KERNEL(true, 1, 32)
+  TFMINI_USE_DEPTHWISECONV_KERNEL(true, 1, 8)
+  TFMINI_USE_DEPTHWISECONV_KERNEL(true, 8, 1)
+  TFMINI_USE_DEPTHWISECONV_KERNEL(true, 2, 1)
+  TFMINI_USE_DEPTHWISECONV_KERNEL(true, 4, 1)
+
+  // Finally, the kernels allowing a variable input depth,
+  // these are the least efficient but most general kernels.
+
+  TFMINI_USE_DEPTHWISECONV_KERNEL(true, 0, 1)
+  TFMINI_USE_DEPTHWISECONV_KERNEL(true, 0, 2)
+  TFMINI_USE_DEPTHWISECONV_KERNEL(true, 0, 3)
+#endif  // USE_NEON
+
+  // No matching fast kernel found, use slow fallback.
+  if (!row_accum_func) {
+    row_accum_func = QuantizedDepthwiseConvAccumRowGeneric;
+  }
+
+#undef TFMINI_USE_DEPTHWISECONV_KERNEL
+
+  // Now that we have determined row_accum_func, we can start work.
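+  // Each iteration of the loop below processes up to kOutputPixelsInAccBuffer
+  // output pixels: the accumulators are seeded with the bias values, every
+  // contributing filter row is accumulated through row_accum_func, and the
+  // int32 accumulators are then requantized to uint8. The requantization done
+  // by both the NEON and the scalar paths below is essentially
+  //   v = RoundingDivideByPOT(
+  //           SaturatingRoundingDoublingHighMul(acc, output_multiplier),
+  //           output_shift);
+  //   v += output_offset;
+  //   v = std::min(std::max(v, output_activation_min), output_activation_max);
+  //   *output_ptr++ = static_cast<uint8>(v);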
+ uint8* output_ptr = output_data; + for (int b = 0; b < batches; ++b) { + for (int out_y = 0; out_y < output_height; ++out_y) { + const int in_y_origin = (out_y * stride_height) - pad_height; + const int filter_y_start = std::max(0, -in_y_origin); + const int filter_y_end = + std::min(filter_height, input_height - in_y_origin); + for (int out_x_buffer_start = 0; out_x_buffer_start < output_width; + out_x_buffer_start += kOutputPixelsInAccBuffer) { + const int out_x_buffer_end = std::min( + output_width, out_x_buffer_start + kOutputPixelsInAccBuffer); + // We call a 'pixel' a group of activation that share all but the + // 'depth'/'channel' coordinate. num_output_pixels is the number of + // output pixels that we will accumulate in this loop iteration. + const int num_output_pixels = out_x_buffer_end - out_x_buffer_start; + // Initialize our local accumulator with the bias values, so we don't + // have to add them later. + DepthwiseConvInitAccBuffer(num_output_pixels, output_depth, bias_data, + acc_buffer); + // Accumulation loop. Most of the time should be spent in here. + for (int filter_y = filter_y_start; filter_y < filter_y_end; + ++filter_y) { + const int in_y = in_y_origin + filter_y; + row_accum_func( + stride_width, input_depth, input_width, + input_data + in_y * input_dims.strides[2] + + b * input_dims.strides[3], + input_offset, pad_width, depth_multiplier, filter_width, + filter_data + filter_y * filter_dims.strides[2], filter_offset, + out_x_buffer_start, out_x_buffer_end, output_depth, acc_buffer); + } + // Finished accumulating int32 values. Now need to convert them to + // the final 8bit form and store them. + gemmlowp::ScopedProfilingLabel label("downquantize+store"); + const int num_output_values = output_depth * num_output_pixels; + int i = 0; +#ifdef USE_NEON + using gemmlowp::RoundingDivideByPOT; + const int32x4_t output_offset_vec = vdupq_n_s32(output_offset); + const int32x4_t output_activation_min_vec = + vdupq_n_s32(output_activation_min); + const int32x4_t output_activation_max_vec = + vdupq_n_s32(output_activation_max); + // Handle 16 values at once. + // This allows us to issue 4 mutually independent int32 + // multiplications (vqrdmulh), which should alleviate most of their + // high latency. + for (; i <= num_output_values - 16; i += 16) { + int32x4_t acc[4]; + for (int j = 0; j < 4; j++) { + acc[j] = vld1q_s32(acc_buffer + i + 4 * j); + } + + // Fixed-point multiplication. + for (int j = 0; j < 4; j++) { + acc[j] = vqrdmulhq_n_s32(acc[j], output_multiplier); + } + for (int j = 0; j < 4; j++) { + acc[j] = RoundingDivideByPOT(acc[j], output_shift); + } + // Add the output offset. + for (int j = 0; j < 4; j++) { + acc[j] = vaddq_s32(acc[j], output_offset_vec); + } + // Apply the activation function. + for (int j = 0; j < 4; j++) { + acc[j] = vmaxq_s32(acc[j], output_activation_min_vec); + } + for (int j = 0; j < 4; j++) { + acc[j] = vminq_s32(acc[j], output_activation_max_vec); + } + // Saturating cast to uint8 and store to destination. + int16x4_t acc_s16[4]; + for (int j = 0; j < 4; j++) { + acc_s16[j] = vqmovn_s32(acc[j]); + } + const int16x8_t res_s16_0 = vcombine_s16(acc_s16[0], acc_s16[1]); + const int16x8_t res_s16_1 = vcombine_s16(acc_s16[2], acc_s16[3]); + const uint8x8_t res_u8_0 = vqmovun_s16(res_s16_0); + const uint8x8_t res_u8_1 = vqmovun_s16(res_s16_1); + vst1q_u8(output_ptr, vcombine_u8(res_u8_0, res_u8_1)); + output_ptr += 16; + } + // Handle 8 values at once. 
+ // Not as good as 16 (now we're only issuing 2 mutually independent + // vqrdmulh instructions, so we're probably paying for their high + // latency). + for (; i <= num_output_values - 8; i += 8) { + int32x4_t acc0 = vld1q_s32(acc_buffer + i); + int32x4_t acc1 = vld1q_s32(acc_buffer + i + 4); + // Fixed-point multiplication. + acc0 = vqrdmulhq_n_s32(acc0, output_multiplier); + acc1 = vqrdmulhq_n_s32(acc1, output_multiplier); + // Rounding right shift. + acc0 = RoundingDivideByPOT(acc0, output_shift); + acc1 = RoundingDivideByPOT(acc1, output_shift); + // Add the output offset. + acc0 = vaddq_s32(acc0, output_offset_vec); + acc1 = vaddq_s32(acc1, output_offset_vec); + // Apply the activation function. + acc0 = vmaxq_s32(acc0, output_activation_min_vec); + acc1 = vmaxq_s32(acc1, output_activation_min_vec); + acc0 = vminq_s32(acc0, output_activation_max_vec); + acc1 = vminq_s32(acc1, output_activation_max_vec); + // Saturating cast to uint8 and store to destination. + const int16x4_t acc0_s16 = vqmovn_s32(acc0); + const int16x4_t acc1_s16 = vqmovn_s32(acc1); + const int16x8_t res_s16 = vcombine_s16(acc0_s16, acc1_s16); + const uint8x8_t res_u8 = vqmovun_s16(res_s16); + vst1_u8(output_ptr, res_u8); + output_ptr += 8; + } + // Handle 4 values at once. Now we're paying the full price of the + // high latency of vqrdmulh. Also, storing only 4 bytes at the end + // (without any alignment) can only be done 1 byte at a time. + // Yet, that is still worth doing to minimize the amount of leftover + // that will have to go through the very slow scalar code. + for (; i <= num_output_values - 4; i += 4) { + int32x4_t acc = vld1q_s32(acc_buffer + i); + // Fixed-point multiplication. + acc = vqrdmulhq_n_s32(acc, output_multiplier); + // Rounding right shift. + acc = RoundingDivideByPOT(acc, output_shift); + // Add the output offset. + acc = vaddq_s32(acc, output_offset_vec); + // Apply the activation function. + acc = vmaxq_s32(acc, output_activation_min_vec); + acc = vminq_s32(acc, output_activation_max_vec); + // Saturating cast to uint8 and store to destination. + const int16x4_t acc_s16 = vqmovn_s32(acc); + const int16x8_t res_s16 = vcombine_s16(acc_s16, acc_s16); + const uint8x8_t res_u8 = vqmovun_s16(res_s16); + vst1_lane_u8(output_ptr + 0, res_u8, 0); + vst1_lane_u8(output_ptr + 1, res_u8, 1); + vst1_lane_u8(output_ptr + 2, res_u8, 2); + vst1_lane_u8(output_ptr + 3, res_u8, 3); + output_ptr += 4; + } +#endif // USE_NEON + + // Handle leftover values, one by one. This is very slow. + for (; i < num_output_values; i++) { + int32 acc = acc_buffer[i]; + acc = MultiplyByQuantizedMultiplierSmallerThanOne( + acc, output_multiplier, output_shift); + acc += output_offset; + acc = std::max(acc, output_activation_min); + acc = std::min(acc, output_activation_max); + *output_ptr++ = static_cast(acc); + } + } + } + } +} + +// Legacy, for compatibility with old checked-in code. 
+template <FusedActivationFunctionType Ac>
+void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
+                   int32 input_offset, const uint8* filter_data,
+                   const Dims<4>& filter_dims, int32 filter_offset,
+                   const int32* bias_data, const Dims<4>& bias_dims,
+                   int stride_width, int stride_height, int pad_width,
+                   int pad_height, int depth_multiplier, int32 output_offset,
+                   int32 output_multiplier, int output_shift,
+                   int32 output_activation_min, int32 output_activation_max,
+                   uint8* output_data, const Dims<4>& output_dims) {
+  if (Ac == FusedActivationFunctionType::kNone) {
+    TFLITE_DCHECK_EQ(output_activation_min, 0);
+    TFLITE_DCHECK_EQ(output_activation_max, 255);
+  }
+  DepthwiseConv(input_data, input_dims, input_offset, filter_data, filter_dims,
+                filter_offset, bias_data, bias_dims, stride_width,
+                stride_height, pad_width, pad_height, depth_multiplier,
+                output_offset, output_multiplier, output_shift,
+                output_activation_min, output_activation_max, output_data,
+                output_dims);
+}
+
+// Legacy, for compatibility with old checked-in code.
+template <FusedActivationFunctionType Ac>
+void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
+                   int32 input_offset, const uint8* filter_data,
+                   const Dims<4>& filter_dims, int32 filter_offset,
+                   const int32* bias_data, const Dims<4>& bias_dims,
+                   int stride, int pad_width, int pad_height,
+                   int depth_multiplier, int32 output_offset,
+                   int32 output_multiplier, int output_shift,
+                   int32 output_activation_min, int32 output_activation_max,
+                   uint8* output_data, const Dims<4>& output_dims) {
+  DepthwiseConv<Ac>(input_data, input_dims, input_offset, filter_data,
+                    filter_dims, filter_offset, bias_data, bias_dims, stride,
+                    stride, pad_width, pad_height, depth_multiplier,
+                    output_offset, output_multiplier, output_shift,
+                    output_activation_min, output_activation_max, output_data,
+                    output_dims);
+}
+
+}  // namespace optimized_ops
+}  // namespace tflite
+
+#endif  // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_UINT8_H_
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/eigen_spatial_convolutions.h b/tensorflow/contrib/lite/kernels/internal/optimized/eigen_spatial_convolutions.h
new file mode 100644
index 0000000000..8004c24a99
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/eigen_spatial_convolutions.h
@@ -0,0 +1,231 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Copied from tensorflow/core/kernels/eigen_spatial_convolutions.h.
+// TODO(petewarden) - move this to a common location in Eigen itself.
+
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_SPATIAL_CONVOLUTIONS_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_SPATIAL_CONVOLUTIONS_H_
+
+#define EIGEN_USE_CUSTOM_THREAD_POOL
+#define EIGEN_USE_THREADS
+
+// NOTE: Eigen is slightly different internally and externally.
We need to +// hack the unsupported/Eigen/CXX11/Tensor header instantiation macros at +// specific places, so we need two copies of the hacked file, one for +// internal and one for external. +// If you have trouble simply undef out the reducer macro e.g. +// TFLITE_REDUCE_INSTANTIATIONS_GOOGLE, but be aware this will make +// the binary much bigger! +#define TFLITE_REDUCE_INSTANTIATIONS_OPEN_SOURCE +#define Eigen EigenForTFLite +#if defined(TFLITE_REDUCE_INSTANTIATIONS_GOOGLE) +#include "tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_google.h" +#elif defined(TFLITE_REDUCE_INSTANTIATIONS_OPEN_SOURCE) +#include "tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_oss.h" +#else +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#endif + + +namespace Eigen { + +/** SpatialConvolution + * \ingroup CXX11_NeuralNetworks_Module + * + * \brief Applies a 2D convolution over a multichannel input image. + * + * The input parameter is expected to be a tensor with a rank of 3 or more + * (channels, height, width, and optionally others) + * The kernel parameter is expected to be a 4D tensor (filters, channels, + * kernel_height, kernel_width) + * The input and the kernel must both be in col-major layout. The result will + * also be in col-major layout. + * + * If col_in_stride, row_in_stride > 1, then applies convolution with holes + * (aka atrous convolution), sampling every col_in_stride, row_in_stride input + * pixels. + * + * The result can be assigned to a tensor of rank equal to the rank of the + * input. The dimensions of the result will be filters, height, width (and + * others if applicable). + * + * It is possible to swap the order of the width and height dimensions provided + * that the same order is used in the input, the kernel, and the output. + * + */ +template +EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE static const typename internal::conditional< + internal::traits::Layout == ColMajor, + TensorReshapingOp< + const DSizes::Index, + internal::traits::NumDimensions>, + const TensorContractionOp< + const array::Index>, + 1>, + const TensorReshapingOp< + const DSizes::Index, 2>, + const Kernel>, + const TensorReshapingOp< + const DSizes::Index, 2>, + const TensorImagePatchOp > > >, + TensorReshapingOp< + const DSizes::Index, + internal::traits::NumDimensions>, + const TensorContractionOp< + const array::Index>, + 1>, + const TensorReshapingOp< + const DSizes::Index, 2>, + const TensorImagePatchOp >, + const TensorReshapingOp< + const DSizes::Index, 2>, + const Kernel> > > >::type + SpatialConvolution(const Input& input, const Kernel& kernel, + const DenseIndex row_stride = 1, + const DenseIndex col_stride = 1, + const PaddingType padding_type = PADDING_SAME, + const DenseIndex row_in_stride = 1, + const DenseIndex col_in_stride = 1) { + typedef typename internal::traits::Index TensorIndex; + TensorRef::Scalar, + internal::traits::NumDimensions, + internal::traits::Layout, TensorIndex> > + in(input); + TensorRef::Scalar, + internal::traits::NumDimensions, + internal::traits::Layout, TensorIndex> > + kern(kernel); + + EIGEN_STATIC_ASSERT( + internal::traits::Layout == internal::traits::Layout, + YOU_MADE_A_PROGRAMMING_MISTAKE); + const bool isColMajor = (internal::traits::Layout == ColMajor); + + const int NumDims = internal::traits::NumDimensions; + + // Number of filters to apply. This is the same as the output depth of the + // result + const TensorIndex kernelFilters = + isColMajor ? 
kern.dimensions()[0] : kern.dimensions()[3]; + // Number of channels. This is the same as the input depth. + const TensorIndex kernelChannels = + isColMajor ? kern.dimensions()[1] : kern.dimensions()[2]; + const TensorIndex kernelRows = + isColMajor ? kern.dimensions()[2] : kern.dimensions()[1]; + const TensorIndex kernelCols = + isColMajor ? kern.dimensions()[3] : kern.dimensions()[0]; + + const DenseIndex kernelRowsEff = + kernelRows + (kernelRows - 1) * (row_in_stride - 1); + const DenseIndex kernelColsEff = + kernelCols + (kernelCols - 1) * (col_in_stride - 1); + + array, 1> contract_dims; + contract_dims[0] = IndexPair(1, 0); + + const TensorIndex InputRows = + isColMajor ? in.dimension(1) : in.dimension(NumDims - 2); + const TensorIndex InputCols = + isColMajor ? in.dimension(2) : in.dimension(NumDims - 3); + + TensorIndex out_height; + TensorIndex out_width; + switch (padding_type) { + case PADDING_VALID: + out_height = numext::ceil((InputRows - kernelRowsEff + 1.f) / + static_cast(row_stride)); + out_width = numext::ceil((InputCols - kernelColsEff + 1.f) / + static_cast(col_stride)); + break; + case PADDING_SAME: + out_height = numext::ceil(InputRows / static_cast(row_stride)); + out_width = numext::ceil(InputCols / static_cast(col_stride)); + break; + default: + // Initialize unused variables to avoid a compiler warning + out_height = 0; + out_width = 0; + eigen_assert(false && "unexpected padding"); + } + + // Molds the output of the patch extraction code into a 2d tensor: + // - the first dimension (dims[0]): the patch values to be multiplied with the + // kernels + // - the second dimension (dims[1]): everything else + DSizes pre_contract_dims; + if (isColMajor) { + pre_contract_dims[0] = kernelChannels * kernelRows * kernelCols; + pre_contract_dims[1] = out_height * out_width; + for (int i = 3; i < NumDims; ++i) { + pre_contract_dims[1] *= in.dimension(i); + } + } else { + pre_contract_dims[1] = kernelChannels * kernelRows * kernelCols; + pre_contract_dims[0] = out_height * out_width; + for (int i = 0; i < NumDims - 3; ++i) { + pre_contract_dims[0] *= in.dimension(i); + } + } + + // Molds the output of the contraction into the shape expected by the used + // (assuming this is ColMajor): + // - 1st dim: kernel filters + // - 2nd dim: output height + // - 3rd dim: output width + // - 4th dim and beyond: everything else including batch size + DSizes post_contract_dims; + if (isColMajor) { + post_contract_dims[0] = kernelFilters; + post_contract_dims[1] = out_height; + post_contract_dims[2] = out_width; + for (int i = 3; i < NumDims; ++i) { + post_contract_dims[i] = in.dimension(i); + } + } else { + post_contract_dims[NumDims - 1] = kernelFilters; + post_contract_dims[NumDims - 2] = out_height; + post_contract_dims[NumDims - 3] = out_width; + for (int i = 0; i < NumDims - 3; ++i) { + post_contract_dims[i] = in.dimension(i); + } + } + + DSizes kernel_dims; + if (isColMajor) { + kernel_dims[0] = kernelFilters; + kernel_dims[1] = kernelChannels * kernelRows * kernelCols; + } else { + kernel_dims[0] = kernelChannels * kernelRows * kernelCols; + kernel_dims[1] = kernelFilters; + } + // TODO(yangke): choose() is defined in TensorContraction.h -- consider + // moving it to somewhere more "common". 
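+  // The expression below is the im2col formulation of the convolution:
+  // extract_image_patches gathers one kernelRows x kernelCols x channels
+  // patch per output location, the first reshape flattens those patches into
+  // the 2-D shape pre_contract_dims, the contraction with the reshaped kernel
+  // is a single matrix multiply over contract_dims, and the final reshape
+  // restores the output dimensions computed in post_contract_dims.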
+ return + input + .extract_image_patches(kernelRows, kernelCols, row_stride, col_stride, + row_in_stride, col_in_stride, padding_type) + .reshape(pre_contract_dims) + .contract(kernel.reshape(kernel_dims), contract_dims) + .reshape(post_contract_dims); +} + +} // end namespace Eigen + +// clang-format on + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_SPATIAL_CONVOLUTIONS_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_google.h b/tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_google.h new file mode 100644 index 0000000000..7f78f69360 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_google.h @@ -0,0 +1,143 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_TENSOR_REDUCED_INSTANTIATIONS_GOOGLE_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_TENSOR_REDUCED_INSTANTIATIONS_GOOGLE_H_ + +#define EIGEN_USE_CUSTOM_THREAD_POOL +#define EIGEN_USE_THREADS + +// clang-format off + +#include + +#include +#include +#include +#include +#include +#include // NOLINT(build/c++11) +#include // NOLINT(build/c++11) +#include // NOLINT(build/c++11) +#include + +#ifdef _WIN32 +#include +#elif defined(__APPLE__) +#include +#else +#include +#endif + + +// Because some programs may link Eigen in through other frameworks with +// different flags, we can run into multiple definition issues if we don't have +// a private namespace for our versions. This is a nasty hack, but a similar +// approach is used elsewhere to handle the problem, so it should be stable. +#define Eigen EigenForTFLite + +#include "Eigen/src/Core/util/StaticAssert.h" +#include "unsupported/Eigen/CXX11/Core" +#include "unsupported/Eigen/SpecialFunctions" + +#include "Eigen/src/Core/util/DisableStupidWarnings.h" + +#include "Eigen/Core" + +// Beware: the order of the include matters to some compilers. For example +// TensorIndexList.h should be included before TensorDimensions.h in order to +// use index lists to encode tensor dimensions when compiling with llvm. +// We're defining this ourselves rather than using the Eigen Tensor header file +// so that we can alter the macro definition of TENSOR_CONTRACTION_DISPATCH to +// reduce binary size. 
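+// In particular, the TENSOR_CONTRACTION_DISPATCH override further down keeps
+// only the contraction path where both inner dimensions are contiguous and
+// not reordered, and asserts on anything else; the other contraction variants
+// are never instantiated, which is what saves the binary size.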
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorMeta.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorCostModel.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/ThreadPoolInterface.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorDeviceType.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorNonBlockingThreadPool.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorIndexList.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorDimensionList.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorDimensions.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorInitializer.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorTraits.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorRandom.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorUInt128.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorIntDiv.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorBlock.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorGlobalFunctions.h" + +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorStats.h" + +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h" + +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorReduction.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorArgMax.h" + +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorConcatenation.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorContractionMappers.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorContractionBlocking.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h" +#undef TENSOR_CONTRACTION_DISPATCH +#define TENSOR_CONTRACTION_DISPATCH(METHOD, ALIGNMENT, ARGS) \ + if (this->m_lhs_inner_dim_contiguous && \ + this->m_rhs_inner_dim_contiguous && \ + !this->m_rhs_inner_dim_reordered) { \ + METHOD ARGS; \ + } else { \ + eigen_assert(false && "Unsupported contraction formats"); \ + } + +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorContractionCuda.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorConversion.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorConvolution.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorFFT.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorPatch.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorImagePatch.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorVolumePatch.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorChipping.h" +#include 
"third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorInflation.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorLayoutSwap.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorMorphing.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorPadding.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorReverse.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorShuffling.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorStriding.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorCustomOp.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorEvalTo.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorForcedEval.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorGenerator.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorAssign.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorScan.h" + +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorDevice.h" + +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorStorage.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/Tensor.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorFixedSize.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorMap.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorRef.h" + +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorReductionCuda.h" + +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorIO.h" + +#include "Eigen/src/Core/util/ReenableStupidWarnings.h" +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_TENSOR_REDUCED_INSTANTIATIONS_H diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_oss.h b/tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_oss.h new file mode 100644 index 0000000000..1d5c316194 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_oss.h @@ -0,0 +1,167 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This is essentially unsupported/CXX11/Eigen/Tensor.h +// TODO(petewarden) - move this to a common location in Eigen itself. 
+ +// clang-format off + + +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_TENSOR_REDUCED_INSTANTIATIONS_OSS_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_TENSOR_REDUCED_INSTANTIATIONS_OSS_H_ + + +#include "Eigen/Core" + +#if defined(EIGEN_USE_SYCL) +#undef min +#undef max +#undef isnan +#undef isinf +#undef isfinite +#include +#include +#include +#include +#include +#endif +#include +#include +#include + + + + + +#ifdef _WIN32 +typedef __int16 int16_t; +typedef unsigned __int16 uint16_t; +typedef __int32 int32_t; +typedef unsigned __int32 uint32_t; +typedef __int64 int64_t; +typedef unsigned __int64 uint64_t; +#include +#else +#include +#include +#endif + +#if __cplusplus > 199711 || EIGEN_COMP_MSVC >= 1900 +#include +#endif + +#ifdef _WIN32 +#include +#elif defined(__APPLE__) +#include +#else +#include +#endif + +// #if defined(EIGEN_USE_LIBXSMM) +// #include "libxsmm.h" +// #endif + +#ifdef EIGEN_USE_THREADS +#include "unsupported/Eigen/CXX11/ThreadPool" +#endif + + +#include "Eigen/src/Core/util/DisableStupidWarnings.h" + +#include "unsupported/Eigen/SpecialFunctions" +#include "unsupported/Eigen/CXX11/src/util/CXX11Meta.h" +#include "unsupported/Eigen/CXX11/src/util/MaxSizeVector.h" + + +#include "unsupported/Eigen/CXX11/src/Tensor/TensorMacros.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorMeta.h" + +#include "unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorCostModel.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorDeviceDefault.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorDeviceThreadPool.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorDeviceCuda.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorDeviceSycl.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorIndexList.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorDimensionList.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorDimensions.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorInitializer.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorTraits.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorRandom.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorUInt128.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorIntDiv.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorGlobalFunctions.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorBase.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorReduction.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorReductionCuda.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorArgMax.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorConcatenation.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorContractionMapper.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorContractionBlocking.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h" + +#undef TENSOR_CONTRACTION_DISPATCH +#define TENSOR_CONTRACTION_DISPATCH(METHOD, ALIGNMENT, ARGS) \ + if (this->m_lhs_inner_dim_contiguous && \ + this->m_rhs_inner_dim_contiguous && \ + !this->m_rhs_inner_dim_reordered) { \ + METHOD ARGS; \ + } else { \ + eigen_assert(false && "Unsupported contraction formats"); \ + } + + +#include "unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h" +#include 
"unsupported/Eigen/CXX11/src/Tensor/TensorContractionCuda.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorConversion.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorConvolution.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorFFT.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorPatch.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorImagePatch.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorVolumePatch.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorChipping.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorInflation.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorLayoutSwap.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorMorphing.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorPadding.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorReverse.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorShuffling.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorStriding.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorCustomOp.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorEvalTo.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorForcedEval.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorGenerator.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorAssign.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorScan.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorTrace.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorSycl.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorDevice.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorStorage.h" +#include "unsupported/Eigen/CXX11/src/Tensor/Tensor.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorFixedSize.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorMap.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorRef.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorIO.h" + +#include "Eigen/src/Core/util/ReenableStupidWarnings.h" + + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_TENSOR_REDUCED_INSTANTIATIONS_OSS_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h b/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h new file mode 100644 index 0000000000..b3615f4658 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h @@ -0,0 +1,195 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_MULTITHREAD_CONV +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_MULTITHREAD_CONV + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/kernels/internal/common.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/eigen_spatial_convolutions.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/types.h" + +namespace tflite { +namespace multithreaded_ops { + +class EigenThreadPoolWrapper : public Eigen::ThreadPoolInterface { + public: + explicit EigenThreadPoolWrapper(Eigen::ThreadPool* pool) : pool_(pool) {} + ~EigenThreadPoolWrapper() override {} + + void Schedule(std::function fn) override { + pool_->Schedule(std::move(fn)); + } + int NumThreads() const override { return pool_->NumThreads(); } + int CurrentThreadId() const override { return pool_->CurrentThreadId(); } + + private: + Eigen::ThreadPool* pool_ = nullptr; +}; + +// We have a single global threadpool for all convolution operations. This means +// that inferences started from different threads may block each other, but +// since the underlying resource of CPU cores should be consumed by the +// operations anyway, it shouldn't affect overall performance. +const Eigen::ThreadPoolDevice& GetThreadPoolDevice() { + const int thread_count = 4; + static Eigen::ThreadPool* tp = new Eigen::ThreadPool(thread_count); + static EigenThreadPoolWrapper* thread_pool_wrapper = + new EigenThreadPoolWrapper(tp); + static Eigen::ThreadPoolDevice* device = + new Eigen::ThreadPoolDevice(thread_pool_wrapper, thread_count); + return *device; +} + +// Shorthands for the types we need when interfacing with the EigenTensor +// library. +typedef Eigen::TensorMap< + Eigen::Tensor, Eigen::Aligned> + EigenMatrix; +typedef Eigen::TensorMap< + Eigen::Tensor, + Eigen::Aligned> + ConstEigenMatrix; + +typedef Eigen::TensorMap< + Eigen::Tensor, Eigen::Aligned> + EigenTensor; +typedef Eigen::TensorMap< + Eigen::Tensor, + Eigen::Aligned> + ConstEigenTensor; + +// Utility functions we need for the EigenTensor API. +template +struct MatMulConvFunctor { + // Computes on device "d": out = in0 * in1, where * is matrix + // multiplication. + void operator()( + const Device& d, EigenMatrix out, ConstEigenMatrix in0, + ConstEigenMatrix in1, + const Eigen::array, 1>& dim_pair) { + out.device(d) = in0.contract(in1, dim_pair); + } +}; + +template +class EigenTensorConvFunctor { + private: + Eigen::PaddingType TfLitePadding2EigenPadding(TfLitePadding padding) { + switch (padding) { + case kTfLitePaddingValid: + return Eigen::PADDING_VALID; + case kTfLitePaddingSame: + return Eigen::PADDING_SAME; + case kTfLitePaddingUnknown: + assert(false); // should never get here. 
+ return Eigen::PADDING_VALID; + } + return Eigen::PADDING_SAME; // Prevent compiler warning about missing + // return + } + + public: + void operator()(const T* input_data, T* im2col_buffer, int input_batches, + int input_height, int input_width, int input_depth, + const T* filter_data, int filter_height, int filter_width, + int filter_count, int stride_rows, int stride_cols, + int pad_width, int pad_height, TfLitePadding padding, + T* output_data, int output_height, int output_width) { + const Eigen::ThreadPoolDevice& device = GetThreadPoolDevice(); + + const bool is_1x1_kernel = (filter_height == 1 && filter_width == 1 && + stride_rows == 1 && stride_cols == 1); + if (is_1x1_kernel) { + // For 1x1 kernel, the 2D convolution is reduced to matrix + // multiplication. + const int conv_width = output_height * output_width; + Eigen::array, 1> dim_pair; + dim_pair[0] = Eigen::IndexPair(1, 0); + EigenMatrix output(output_data, conv_width, filter_count); + ConstEigenMatrix input(input_data, conv_width, input_depth); + ConstEigenMatrix filter(filter_data, input_depth, filter_count); + MatMulConvFunctor()(device, output, input, + filter, dim_pair); + } else if (filter_height == input_height && filter_width == input_width && + pad_width == 0 && pad_height == 0) { + // If the input data and filter have the same height/width, + // the 2D convolution is reduced to matrix multiplication. + const int k = // Length of reduction dimension. + filter_width * filter_height * input_depth; + Eigen::array, 1> dim_pair; + dim_pair[0] = Eigen::IndexPair(1, 0); + EigenMatrix output(output_data, 1, filter_count); + ConstEigenMatrix input(input_data, 1, k); + ConstEigenMatrix filter(filter_data, k, filter_count); + MatMulConvFunctor()(device, output, input, + filter, dim_pair); + } else { + EigenTensor output(output_data, input_batches, output_height, + output_width, filter_count); + ConstEigenTensor input(input_data, input_batches, input_height, + input_width, input_depth); + ConstEigenTensor filter(filter_data, filter_height, filter_width, + input_depth, filter_count); + output.device(device) = + Eigen::SpatialConvolution(input, filter, stride_cols, stride_rows, + TfLitePadding2EigenPadding(padding)); + } + } +}; + +inline void Conv(const float* input_data, const Dims<4>& input_dims, + const float* filter_data, const Dims<4>& filter_dims, + const float* bias_data, const Dims<4>& bias_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, TfLitePadding padding, + float output_activation_min, float output_activation_max, + float* output_data, const Dims<4>& output_dims, + float* im2col_data, const Dims<4>& im2col_dims) { + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int input_depth = MatchingArraySize(input_dims, 0, filter_dims, 0); + const int output_depth = MatchingArraySize(filter_dims, 3, output_dims, 0); + const int input_height = ArraySize(input_dims, 2); + const int input_width = ArraySize(input_dims, 1); + const int filter_height = ArraySize(filter_dims, 2); + const int filter_width = ArraySize(filter_dims, 1); + const int output_height = ArraySize(output_dims, 2); + const int output_width = ArraySize(output_dims, 1); + EigenTensorConvFunctor conv_functor; + conv_functor(input_data, im2col_data, batches, input_height, input_width, + input_depth, filter_data, filter_height, filter_width, + output_depth, stride_height, stride_width, pad_height, pad_width, + padding, output_data, output_height, output_width); + + 
optimized_ops::AddBiasAndEvalActivationFunction( + bias_data, bias_dims, output_data, output_dims, output_activation_min, + output_activation_max); +} + +} // namespace multithreaded_ops +} // namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_MULTITHREAD_CONV diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc new file mode 100644 index 0000000000..bf0bdfb1fb --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc @@ -0,0 +1,337 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include + +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/kernels/activation_functor.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h" + +#ifdef USE_NEON + +#include +#define kFloatWeightsPerNeonLane 4 + +namespace tflite { +namespace tensor_utils { + +void NeonMatrixBatchVectorMultiplyAccumulate(const float* matrix, int m_rows, + int m_cols, const float* vector, + int n_batch, float* result, + int result_stride) { + // If v_size is not divisible by kWeightsPerNeonLane, we cannot use the main + // vectorized loop, and we need to process sequentially. postamble_start shows + // the start index where this should happen. + const int postamble_start = + m_cols - (m_cols & (kFloatWeightsPerNeonLane - 1)); + + // The arrays used to cache the vector. + float32x4_t* vector_cache_float32x4 = + new float32x4_t[(m_cols / kFloatWeightsPerNeonLane) * + sizeof(float32x4_t)]; + const int kUnrollSize = 2; + for (int b = 0; b < n_batch; b++) { + float* result_in_batch = result + b * m_rows * result_stride; + const float* vector_in_batch = vector + b * m_cols; + + const float* matrix_ptr0 = matrix; + // If there is only 1 row, we don't want to assign an illegal pointer. + const float* matrix_ptr1 = nullptr; + if (m_rows > 1) { + matrix_ptr1 = matrix + m_cols; + } + + // Cahce the vector. + for (int c = 0; c < postamble_start; c += kFloatWeightsPerNeonLane) { + vector_cache_float32x4[c >> 2] = vld1q_f32(vector_in_batch + c); + } + + // Main matrix by vector multiplication loop, which handles two rows of + // matrix by vector multiplication. + for (int r = 0; r < (m_rows & ~(kUnrollSize - 1)); r += kUnrollSize) { + float32x4_t acc0_32x4 = vmovq_n_f32(0.0); + float32x4_t acc1_32x4 = vmovq_n_f32(0.0); + for (int c = 0; c < postamble_start; c += kFloatWeightsPerNeonLane) { + float32x4_t temp = vector_cache_float32x4[c >> 2]; + // Load 4 float values from vector1 and vector2 and accumulator. 
+ float32x4_t v0_f32x4 = vld1q_f32(matrix_ptr0 + c); + float32x4_t v1_f32x4 = vld1q_f32(matrix_ptr1 + c); + // Vector multiply-accumulate 4 float + acc0_32x4 = vmlaq_f32(acc0_32x4, v0_f32x4, temp); + acc1_32x4 = vmlaq_f32(acc1_32x4, v1_f32x4, temp); + } + // Add the 4 intermediate sum values to get the final dot-prod value for + // this column. + *result_in_batch += + (vgetq_lane_f32(acc0_32x4, 0) + vgetq_lane_f32(acc0_32x4, 1) + + vgetq_lane_f32(acc0_32x4, 2) + vgetq_lane_f32(acc0_32x4, 3)); + *(result_in_batch + result_stride) += + (vgetq_lane_f32(acc1_32x4, 0) + vgetq_lane_f32(acc1_32x4, 1) + + vgetq_lane_f32(acc1_32x4, 2) + vgetq_lane_f32(acc1_32x4, 3)); + for (int c = postamble_start; c < m_cols; c++) { + *result_in_batch += matrix_ptr0[c] * vector_in_batch[c]; + *(result_in_batch + result_stride) += + matrix_ptr1[c] * vector_in_batch[c]; + } + matrix_ptr0 += kUnrollSize * m_cols; + matrix_ptr1 += kUnrollSize * m_cols; + result_in_batch += kUnrollSize * result_stride; + } + for (int r = (m_rows & ~(kUnrollSize - 1)); r < m_rows; r++) { + float32x4_t acc0_32x4 = vmovq_n_f32(0.0); + for (int c = 0; c < postamble_start; c += kFloatWeightsPerNeonLane) { + float32x4_t temp = vector_cache_float32x4[c >> 2]; + // Load 4 float values from vector1 and vector2 and accumulator. + float32x4_t v0_f32x4 = vld1q_f32(matrix_ptr0 + c); + // Vector multiply-accumulate 4 float + acc0_32x4 = vmlaq_f32(acc0_32x4, v0_f32x4, temp); + } + // Add the 4 intermediate sum values to get the final dot-prod value for + // this column. + *result_in_batch += + (vgetq_lane_f32(acc0_32x4, 0) + vgetq_lane_f32(acc0_32x4, 1) + + vgetq_lane_f32(acc0_32x4, 2) + vgetq_lane_f32(acc0_32x4, 3)); + for (int c = postamble_start; c < m_cols; c++) { + *result_in_batch += matrix_ptr0[c] * vector_in_batch[c]; + } + matrix_ptr0 += m_cols; + result_in_batch += result_stride; + } + } + delete[] vector_cache_float32x4; +} + +void NeonVectorVectorCwiseProduct(const float* vector1, const float* vector2, + int v_size, float* result) { + // If v_size is not divisible by kWeightsPerNeonLane, we cannot use the main + // vectorized loop, and we need to process sequentially. postamble_start shows + // the start index where this should happen. + const int postamble_start = + v_size - (v_size & (kFloatWeightsPerNeonLane - 1)); + for (int v = 0; v < postamble_start; v += kFloatWeightsPerNeonLane) { + // Load 4 float values from vector1 and vector2. + float32x4_t v1_f32x4 = vld1q_f32(vector1 + v); + float32x4_t v2_f32x4 = vld1q_f32(vector2 + v); + // Vector multiply 4 float + float32x4_t mul_32x4 = vmulq_f32(v1_f32x4, v2_f32x4); + // Save to result array. + vst1q_f32(&result[v], mul_32x4); + } + for (int v = postamble_start; v < v_size; v++) { + result[v] = vector1[v] * vector2[v]; + } +} + +void NeonVectorVectorCwiseProductAccumulate(const float* vector1, + const float* vector2, int v_size, + float* result) { + // If v_size is not divisible by kWeightsPerNeonLane, we cannot use the main + // vectorized loop, and we need to process sequentially. postamble_start shows + // the start index where this should happen. + const int postamble_start = + v_size - (v_size & (kFloatWeightsPerNeonLane - 1)); + for (int v = 0; v < postamble_start; v += kFloatWeightsPerNeonLane) { + // Load 4 float values from vector1 and vector2 and accumulator. 
+ float32x4_t v1_f32x4 = vld1q_f32(vector1 + v); + float32x4_t v2_f32x4 = vld1q_f32(vector2 + v); + float32x4_t acc_32x4 = vld1q_f32(result + v); + // Vector multiply-accumulate 4 float + acc_32x4 = vmlaq_f32(acc_32x4, v1_f32x4, v2_f32x4); + // Save to result array. + vst1q_f32(&result[v], acc_32x4); + } + for (int v = postamble_start; v < v_size; v++) { + result[v] += vector1[v] * vector2[v]; + } +} + +void NeonVectorBatchVectorCwiseProductAccumulate(const float* vector, + int v_size, + const float* batch_vector, + int n_batch, float* result) { + // If v_size is not divisible by kWeightsPerNeonLane, we cannot use the main + // vectorized loop, and we need to process sequentially. postamble_start shows + // the start index where this should happen. + const int postamble_start = + v_size - (v_size & (kFloatWeightsPerNeonLane - 1)); + + // The arrays used to cache the vector. + float32x4_t* vector_cache_float32x4 = + new float32x4_t[(v_size / kFloatWeightsPerNeonLane) * + sizeof(float32x4_t)]; + for (int v = 0; v < postamble_start; v += kFloatWeightsPerNeonLane) { + vector_cache_float32x4[v >> 2] = vld1q_f32(vector + v); + } + + float* result_ptr = result; + const float* batch_vector_ptr = batch_vector; + for (int b = 0; b < n_batch; b++) { + for (int v = 0; v < postamble_start; v += kFloatWeightsPerNeonLane) { + // Load from memory to vectors. + float32x4_t result_f32x4 = vld1q_f32(result_ptr + v); + float32x4_t batch_vector_f32x4 = vld1q_f32(batch_vector_ptr + v); + // Multiply-accumulate. + result_f32x4 = vmlaq_f32(result_f32x4, batch_vector_f32x4, + vector_cache_float32x4[v >> 2]); + // Store. + vst1q_f32(result_ptr + v, result_f32x4); + } + // Postamble loop + for (int v = postamble_start; v < v_size; v++) { + result_ptr[v] += vector[v] * batch_vector_ptr[v]; + } + // Update the pointers. + result_ptr += v_size; + batch_vector_ptr += v_size; + } + delete[] vector_cache_float32x4; +} + +void NeonSub1Vector(const float* vector, int v_size, float* result) { + // If v_size is not divisible by kWeightsPerNeonLane, we cannot use the main + // vectorized loop, and we need to process sequentially. postamble_start shows + // the start index where this should happen. + const int postamble_start = + v_size - (v_size & (kFloatWeightsPerNeonLane - 1)); + + float32x4_t one_f32x4 = vmovq_n_f32(1.0); + for (int v = 0; v < postamble_start; v += kFloatWeightsPerNeonLane) { + // Load 4 float values from the current pointers of the input column and + // subtract from 1. + float32x4_t v_f32x4 = vld1q_f32(vector + v); + float32x4_t result_f32x4 = vsubq_f32(one_f32x4, v_f32x4); + // Save to output. + vst1q_f32(result + v, result_f32x4); + } + for (int v = postamble_start; v < v_size; v++) { + result[v] = 1.0f - vector[v]; + } +} + +void NeonClipVector(const float* vector, int v_size, float abs_limit, + float* result) { + // If v_size is not divisible by kWeightsPerNeonLane, we cannot use the main + // vectorized loop, and we need to process sequentially. postamble_start shows + // the start index where this should happen. + const int postamble_start = + v_size - (v_size & (kFloatWeightsPerNeonLane - 1)); + + // Replicate abs_limit and -abs_limit in two vectors. + const float32x4_t abs_limit_f32x4 = vmovq_n_f32(abs_limit); + const float32x4_t neg_abs_limit_f32x4 = vmovq_n_f32(-abs_limit); + + for (int v = 0; v < postamble_start; v += kFloatWeightsPerNeonLane) { + // Load from memory to vector. + float32x4_t v_f32x4 = vld1q_f32(vector + v); + // Clip between abs_limit and -abs_limit. 
+ float32x4_t result_f32x4 = vminq_f32(abs_limit_f32x4, v_f32x4); + result_f32x4 = vmaxq_f32(neg_abs_limit_f32x4, result_f32x4); + // Save to output. + vst1q_f32(result + v, result_f32x4); + } + // Postamble loop. + for (int v = postamble_start; v < v_size; v++) { + result[v] = (abs_limit < vector[v]) ? abs_limit : vector[v]; + result[v] = (-abs_limit > result[v]) ? -abs_limit : result[v]; + } +} + +float NeonVectorVectorDotProduct(const float* vector1, const float* vector2, + int v_size) { + // If v_size is not divisible by kWeightsPerNeonLane, we cannot use the main + // vectorized loop, and we need to process sequentially. postamble_start shows + // the start index where this should happen. + const int postamble_start = + v_size - (v_size & (kFloatWeightsPerNeonLane - 1)); + float32x4_t acc_32x4 = vmovq_n_f32(0.0); + for (int v = 0; v < postamble_start; v += kFloatWeightsPerNeonLane) { + // Load 4 float values from vector1 and vector2 and accumulator. + float32x4_t v1_f32x4 = vld1q_f32(vector1 + v); + float32x4_t v2_f32x4 = vld1q_f32(vector2 + v); + // Vector multiply-accumulate 4 float + acc_32x4 = vmlaq_f32(acc_32x4, v1_f32x4, v2_f32x4); + } + + float result = (vgetq_lane_f32(acc_32x4, 0) + vgetq_lane_f32(acc_32x4, 1) + + vgetq_lane_f32(acc_32x4, 2) + vgetq_lane_f32(acc_32x4, 3)); + // Postamble loop. + for (int v = postamble_start; v < v_size; v++) { + result += vector1[v] * vector2[v]; + } + return result; +} + +void NeonBatchVectorBatchVectorDotProduct(const float* vector1, + const float* vector2, int v_size, + int n_batch, float* result, + int result_stride) { + float* result_ptr = result; + const float* vector1_ptr = vector1; + const float* vector2_ptr = vector2; + for (int b = 0; b < n_batch; b++) { + *result_ptr = NeonVectorVectorDotProduct(vector1_ptr, vector2_ptr, v_size); + vector1_ptr += v_size; + vector2_ptr += v_size; + result_ptr += result_stride; + } +} + +void NeonReductionSumVector(const float* input_vector, float* output_vector, + int output_size, int reduction_size) { + const float* input_vector_ptr = input_vector; + for (int o = 0; o < output_size; o++) { + // If reduction_size is not divisible by kWeightsPerNeonLane, we cannot use + // the main vectorized loop, and we need to process sequentially. + // postamble_start shows the start index where this should happen. + const int postamble_start = + reduction_size - (reduction_size & (kFloatWeightsPerNeonLane - 1)); + float32x4_t sum_f32x4 = vmovq_n_f32(0.0); + for (int r = 0; r < postamble_start; r += kFloatWeightsPerNeonLane) { + float32x4_t v1_f32x4 = vld1q_f32(input_vector_ptr + r); + sum_f32x4 = vaddq_f32(sum_f32x4, v1_f32x4); + } + output_vector[o] += + (vgetq_lane_f32(sum_f32x4, 0) + vgetq_lane_f32(sum_f32x4, 1) + + vgetq_lane_f32(sum_f32x4, 2) + vgetq_lane_f32(sum_f32x4, 3)); + input_vector_ptr += postamble_start; + + // Postamble loop. + for (int r = postamble_start; r < reduction_size; r++) { + output_vector[o] += *input_vector_ptr++; + } + } +} + +void NeonVectorShiftLeft(float* vector, int v_size, float shift_value) { + // This variable keeps track of the next to the last index which is being + // copied to make sure we are not out of the vector boundary. 
+ int last_index_copy = kFloatWeightsPerNeonLane; + int current_index_copy = 0; + while (last_index_copy < v_size) { + float32x4_t v_f32x4 = vld1q_f32(vector + current_index_copy + 1); + vst1q_f32(vector + current_index_copy, v_f32x4); + current_index_copy += kFloatWeightsPerNeonLane; + last_index_copy += kFloatWeightsPerNeonLane; + } + // Postamble loop. + for (int i = current_index_copy; i < v_size - 1; i++) { + vector[i] = vector[i + 1]; + } + vector[v_size - 1] = shift_value; +} + +} // namespace tensor_utils +} // namespace tflite + +#endif // USE_NEON diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h new file mode 100644 index 0000000000..3a4af87304 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h @@ -0,0 +1,113 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_NEON_TENSOR_UTILS_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_NEON_TENSOR_UTILS_H_ + +// TODO(ghodrat): Remove this header file and the dependency to internal data +// structure. 
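// Illustrative sketch, not from the TF Lite sources: the loop structure shared
// by the NEON kernels in neon_tensor_utils.cc above. Each kernel processes
// four floats per iteration with NEON intrinsics, then falls back to a scalar
// "postamble" loop for the remaining v_size % 4 elements. The function below
// shows that split in portable C++ without intrinsics; its name and the plain
// element-wise add are assumptions for exposition only.
#include <cstdio>

void AddVectorsWithPostamble(const float* a, const float* b, int v_size,
                             float* out) {
  constexpr int kLane = 4;  // kFloatWeightsPerNeonLane in the real kernels
  // Largest multiple of kLane not exceeding v_size; the vector loop stops here.
  const int postamble_start = v_size - (v_size & (kLane - 1));
  for (int v = 0; v < postamble_start; v += kLane) {
    // In the real kernels this body is vld1q_f32 / vaddq_f32 / vst1q_f32.
    for (int lane = 0; lane < kLane; ++lane) {
      out[v + lane] = a[v + lane] + b[v + lane];
    }
  }
  // Postamble: scalar handling of the tail that does not fill a full lane.
  for (int v = postamble_start; v < v_size; ++v) {
    out[v] = a[v] + b[v];
  }
}

int main() {
  float a[6] = {1, 2, 3, 4, 5, 6}, b[6] = {6, 5, 4, 3, 2, 1}, out[6];
  AddVectorsWithPostamble(a, b, 6, out);
  std::printf("%g %g\n", out[0], out[5]);  // prints "7 7"
  return 0;
}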
+#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/cpu_check.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h" + +namespace tflite { +namespace tensor_utils { + +void MatrixBatchVectorMultiplyAccumulate(const float* matrix, int m_rows, + int m_cols, const float* vector, + int n_batch, float* result, + int result_stride) { + NEON_OR_PORTABLE(MatrixBatchVectorMultiplyAccumulate, matrix, m_rows, m_cols, + vector, n_batch, result, result_stride); +} + +void VectorVectorCwiseProduct(const float* vector1, const float* vector2, + int v_size, float* result) { + NEON_OR_PORTABLE(VectorVectorCwiseProduct, vector1, vector2, v_size, result); +} + +void VectorVectorCwiseProductAccumulate(const float* vector1, + const float* vector2, int v_size, + float* result) { + NEON_OR_PORTABLE(VectorVectorCwiseProductAccumulate, vector1, vector2, v_size, + result); +} + +void VectorBatchVectorCwiseProductAccumulate(const float* vector, int v_size, + const float* batch_vector, + int n_batch, float* result) { + NEON_OR_PORTABLE(VectorBatchVectorCwiseProductAccumulate, vector, v_size, + batch_vector, n_batch, result); +} + +float VectorVectorDotProduct(const float* vector1, const float* vector2, + int v_size) { + return NEON_OR_PORTABLE(VectorVectorDotProduct, vector1, vector2, v_size); +} + +void BatchVectorBatchVectorDotProduct(const float* vector1, + const float* vector2, int v_size, + int n_batch, float* result, + int result_stride) { + NEON_OR_PORTABLE(BatchVectorBatchVectorDotProduct, vector1, vector2, v_size, + n_batch, result, result_stride); +} + +void VectorBatchVectorAssign(const float* vector, int v_size, int n_batch, + float* batch_vector) { + PortableVectorBatchVectorAssign(vector, v_size, n_batch, batch_vector); +} + +void ApplySigmoidToVector(const float* vector, int v_size, float* result) { + PortableApplySigmoidToVector(vector, v_size, result); +} + +void ApplyActivationToVector(const float* vector, int v_size, + TfLiteFusedActivation activation, float* result) { + PortableApplyActivationToVector(vector, v_size, activation, result); +} + +void CopyVector(const float* vector, int v_size, float* result) { + PortableCopyVector(vector, v_size, result); +} + +void Sub1Vector(const float* vector, int v_size, float* result) { + NEON_OR_PORTABLE(Sub1Vector, vector, v_size, result); +} + +void ZeroVector(float* vector, int v_size) { + PortableZeroVector(vector, v_size); +} + +float Clip(float f, float abs_limit) { return PortableClip(f, abs_limit); } + +void ClipVector(const float* vector, int v_size, float abs_limit, + float* result) { + NEON_OR_PORTABLE(ClipVector, vector, v_size, abs_limit, result); +} + +void VectorShiftLeft(float* vector, int v_size, float shift_value) { + NEON_OR_PORTABLE(VectorShiftLeft, vector, v_size, shift_value); +} + +void ReductionSumVector(const float* input_vector, float* output_vector, + int output_size, int reduction_size) { + NEON_OR_PORTABLE(ReductionSumVector, input_vector, output_vector, output_size, + reduction_size); +} + +} // namespace tensor_utils +} // namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_NEON_TENSOR_UTILS_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h new file mode 100644 index 0000000000..cd565c16a1 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h @@ -0,0 +1,3715 @@ 
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_OPS_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_OPS_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "third_party/eigen3/Eigen/Core" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "fixedpoint/fixedpoint.h" +#include "public/gemmlowp.h" +#include "tensorflow/contrib/lite/kernels/internal/common.h" +#include "tensorflow/contrib/lite/kernels/internal/round.h" +#include "tensorflow/contrib/lite/kernels/internal/types.h" + +namespace tflite { +namespace optimized_ops { + +// Make a local VectorMap typedef allowing to map a float array +// as a Eigen vector expression. The std::conditional here is to +// construct the suitable Eigen type for the constness of the +// data. Indeed, for const data, we need to produce +// Eigen::Map> +// and not the more straightforward +// Eigen::Map> +template +using VectorMap = typename std::conditional< + std::is_const::value, + Eigen::Map::type, + Eigen::Dynamic, 1>>, + Eigen::Map>>::type; + +template +VectorMap MapAsVector(Scalar* data, const Dims& dims) { + const int size = RequiredBufferSizeForDims(dims); + return VectorMap(data, size, 1); +} + +// Make a local VectorMap typedef allowing to map a float array +// as a Eigen matrix expression. The same explanation as for VectorMap +// above also applies here. +template +using MatrixMap = typename std::conditional< + std::is_const::value, + Eigen::Map::type, + Eigen::Dynamic, Eigen::Dynamic>>, + Eigen::Map>>::type; + +template +MatrixMap MapAsMatrixWithFirstDimAsRows(Scalar* data, + const Dims& dims) { + const int rows = dims.sizes[0]; + int cols = 1; + for (int d = 1; d < N; d++) { + cols *= dims.sizes[d]; + } + return MatrixMap(data, rows, cols); +} + +template +MatrixMap MapAsMatrixWithLastDimAsCols(Scalar* data, + const Dims& dims) { + const int cols = dims.sizes[N - 1]; + int rows = 1; + for (int d = 0; d < N - 1; d++) { + rows *= dims.sizes[d]; + } + return MatrixMap(data, rows, cols); +} + +template +using ArrayMap = typename std::conditional< + std::is_const::value, + Eigen::Map::type, + Eigen::Dynamic, Eigen::Dynamic>>, + Eigen::Map>>::type; + +template +ArrayMap MapAsArrayWithFirstDimAsRows(Scalar* data, + const Dims& dims) { + const int rows = dims.sizes[0]; + int cols = 1; + for (int d = 1; d < N; d++) { + cols *= dims.sizes[d]; + } + return ArrayMap(data, rows, cols); +} + +// TODO(b/62193649): this function is only needed as long +// as we have the --variable_batch hack. 
+template +MatrixMap MapAsMatrixWithGivenNumberOfRows(Scalar* data, + const Dims& dims, + int rows) { + int cols = 1; + bool matched_rows = false; + for (int d = 0; d < N; d++) { + cols *= dims.sizes[d]; + if (cols == rows) { + matched_rows = true; + cols = 1; + } + } + TFLITE_DCHECK(matched_rows); + return MatrixMap(data, rows, cols); +} + +// DO NOT USE THIS STRUCT FOR NEW FUNCTIONALITY BEYOND IMPLEMENTING ELEMENT-WISE +// BROADCASTING. +// +// NdArrayDesc describes the shape and memory layout of an N-dimensional +// rectangular array of numbers. +// +// NdArrayDesc is basically identical to Dims defined in types.h. +// However, as Dims is to be deprecated, this class exists as an adaptor +// to enable simple unoptimized implementations of element-wise broadcasting +// operations. +template +struct NdArrayDesc { + // The "extent" of each dimension. Indices along dimension d must be in the + // half-open interval [0, extents[d]). + int extents[N]; + + // The number of *elements* (not bytes) between consecutive indices of each + // dimension. + int strides[N]; +}; + +// DO NOT USE THIS FUNCTION FOR NEW FUNCTIONALITY BEYOND IMPLEMENTING +// ELEMENT-WISE BROADCASTING. +// +// Same as Offset(), except takes as NdArrayDesc instead of Dims. +inline int SubscriptToIndex(const NdArrayDesc<4>& desc, int i0, int i1, int i2, + int i3) { + TFLITE_DCHECK(i0 >= 0 && i0 < desc.extents[0]); + TFLITE_DCHECK(i1 >= 0 && i1 < desc.extents[1]); + TFLITE_DCHECK(i2 >= 0 && i2 < desc.extents[2]); + TFLITE_DCHECK(i3 >= 0 && i3 < desc.extents[3]); + return i0 * desc.strides[0] + i1 * desc.strides[1] + i2 * desc.strides[2] + + i3 * desc.strides[3]; +} + +// Given the dimensions of the operands for an element-wise binary broadcast, +// adjusts them so that they can be directly iterated over with simple loops. +// Returns the adjusted dims as instances of NdArrayDesc in 'desc0_out' and +// 'desc1_out'. 'desc0_out' and 'desc1_out' cannot be nullptr. +// +// This function assumes that the two input shapes are compatible up to +// broadcasting and the shorter one has already been prepended with 1s to be the +// same length. E.g., if shape0 is (1, 16, 16, 64) and shape1 is (1, 64), +// shape1 must already have been prepended to be (1, 1, 1, 64). Recall that +// Dims refer to shapes in reverse order. In this case, input0_dims will be +// (64, 16, 16, 1) and input1_dims will be (64, 1, 1, 1). +// +// When two shapes are compatible up to broadcasting, for each dimension d, +// the input extents are either equal, or one of them is 1. +// +// This function performs the following for each dimension d: +// - If the extents are equal, then do nothing since the loop that walks over +// both of the input arrays is correct. +// - Otherwise, one (and only one) of the extents must be 1. Say extent0 is 1 +// and extent1 is e1. Then set extent0 to e1 and stride0 *to 0*. This allows +// array0 to be referenced *at any index* in dimension d and still access the +// same slice. +template +inline void NdArrayDescsForElementwiseBroadcast(const Dims& input0_dims, + const Dims& input1_dims, + NdArrayDesc* desc0_out, + NdArrayDesc* desc1_out) { + TFLITE_DCHECK(desc0_out != nullptr); + TFLITE_DCHECK(desc1_out != nullptr); + + // Copy dims to desc. + for (int i = 0; i < N; ++i) { + desc0_out->extents[i] = input0_dims.sizes[i]; + desc0_out->strides[i] = input0_dims.strides[i]; + desc1_out->extents[i] = input1_dims.sizes[i]; + desc1_out->strides[i] = input1_dims.strides[i]; + } + + // Walk over each dimension. If the extents are equal do nothing. 
+ // Otherwise, set the desc with extent 1 to have extent equal to the other and + // stride 0. + for (int i = 0; i < N; ++i) { + const int extent0 = ArraySize(input0_dims, i); + const int extent1 = ArraySize(input1_dims, i); + if (extent0 != extent1) { + if (extent0 == 1) { + desc0_out->strides[i] = 0; + desc0_out->extents[i] = extent1; + } else { + TFLITE_DCHECK_EQ(extent1, 1); + desc1_out->strides[i] = 0; + desc1_out->extents[i] = extent0; + } + } + } +} + +inline bool AreSameDims(const Dims<4>& dims1, const Dims<4>& dims2) { + for (int i = 0; i < 4; i++) { + if (dims1.sizes[i] != dims2.sizes[i]) { + return false; + } + } + return true; +} + +inline void AddBiasAndEvalActivationFunction(const float* bias_data, + const Dims<4>& bias_dims, + float* array_data, + const Dims<4>& array_dims, + float output_activation_min, + float output_activation_max) { +#ifdef USE_NEON + gemmlowp::ScopedProfilingLabel label("AddBiasAndEvalActivationFunction"); + const int bias_size = bias_dims.sizes[3] * bias_dims.strides[3]; + const int array_size = array_dims.sizes[3] * array_dims.strides[3]; + TFLITE_DCHECK_EQ((array_size % bias_size), 0); + float* array_ptr = array_data; + float* array_end_ptr = array_ptr + array_size; + const auto activation_min = vdupq_n_f32(output_activation_min); + const auto activation_max = vdupq_n_f32(output_activation_max); + for (; array_ptr != array_end_ptr; array_ptr += bias_size) { + int i = 0; + for (; i <= bias_size - 16; i += 16) { + auto b0 = vld1q_f32(bias_data + i); + auto b1 = vld1q_f32(bias_data + i + 4); + auto b2 = vld1q_f32(bias_data + i + 8); + auto b3 = vld1q_f32(bias_data + i + 12); + auto a0 = vld1q_f32(array_ptr + i); + auto a1 = vld1q_f32(array_ptr + i + 4); + auto a2 = vld1q_f32(array_ptr + i + 8); + auto a3 = vld1q_f32(array_ptr + i + 12); + auto x0 = vaddq_f32(a0, b0); + auto x1 = vaddq_f32(a1, b1); + auto x2 = vaddq_f32(a2, b2); + auto x3 = vaddq_f32(a3, b3); + x0 = vmaxq_f32(activation_min, x0); + x1 = vmaxq_f32(activation_min, x1); + x2 = vmaxq_f32(activation_min, x2); + x3 = vmaxq_f32(activation_min, x3); + x0 = vminq_f32(activation_max, x0); + x1 = vminq_f32(activation_max, x1); + x2 = vminq_f32(activation_max, x2); + x3 = vminq_f32(activation_max, x3); + vst1q_f32(array_ptr + i, x0); + vst1q_f32(array_ptr + i + 4, x1); + vst1q_f32(array_ptr + i + 8, x2); + vst1q_f32(array_ptr + i + 12, x3); + } + for (; i <= bias_size - 4; i += 4) { + auto b = vld1q_f32(bias_data + i); + auto a = vld1q_f32(array_ptr + i); + auto x = vaddq_f32(a, b); + x = vmaxq_f32(activation_min, x); + x = vminq_f32(activation_max, x); + vst1q_f32(array_ptr + i, x); + } + for (; i < bias_size; i++) { + array_ptr[i] = ActivationFunctionWithMinMax(array_ptr[i] + bias_data[i], + output_activation_min, + output_activation_max); + } + } +#else // not NEON + gemmlowp::ScopedProfilingLabel label("AddBiasAndEvalActivationFunction"); + const int bias_size = bias_dims.sizes[3] * bias_dims.strides[3]; + const int array_size = array_dims.sizes[3] * array_dims.strides[3]; + TFLITE_DCHECK_EQ((array_size % bias_size), 0); + for (int array_offset = 0; array_offset < array_size; + array_offset += bias_size) { + for (int i = 0; i < bias_size; i++) { + array_data[array_offset + i] = ActivationFunctionWithMinMax( + array_data[array_offset + i] + bias_data[i], output_activation_min, + output_activation_max); + } + } +#endif +} + +// legacy, for compatibility with old checked-in code +template +void AddBiasAndEvalActivationFunction(const float* bias_data, + const Dims<4>& bias_dims, + float* 
array_data, + const Dims<4>& array_dims) { + float output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + AddBiasAndEvalActivationFunction(bias_data, bias_dims, array_data, array_dims, + output_activation_min, + output_activation_max); +} + +template +void Gemm(const Eigen::MatrixBase& lhs, const Eigen::MatrixBase& rhs, + Eigen::MatrixBase* result) { + if (rhs.cols() == 1) { + gemmlowp::ScopedProfilingLabel label("GEMV"); + result->col(0).noalias() = lhs * rhs.col(0); + } else { + gemmlowp::ScopedProfilingLabel label("GEMM"); + result->noalias() = lhs * rhs; + } +} + +inline void FullyConnected(const float* input_data, const Dims<4>& input_dims, + const float* weights_data, + const Dims<4>& weights_dims, const float* bias_data, + const Dims<4>& bias_dims, + float output_activation_min, + float output_activation_max, float* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("FullyConnected"); + // TODO(b/62193649): this convoluted shape computation (determining + // input_rows from the weights_dims, then MapAsMatrixWithGivenNumberOfRows) + // is because the current --variable_batch hack consists in overwriting the + // 3rd dimension with the runtime batch size, as we don't keep track for each + // array of which dimension is the batch dimension in it. + // When that is fixed, this should become: + // const auto input_matrix_map = + // MapAsMatrixWithFirstDimAsRows(input_data, input_dims); + const int input_rows = ArraySize(weights_dims, 0); + const auto input_matrix_map = + MapAsMatrixWithGivenNumberOfRows(input_data, input_dims, input_rows); + const auto filter_matrix_map = + MapAsMatrixWithFirstDimAsRows(weights_data, weights_dims); + auto output_matrix_map = + MapAsMatrixWithFirstDimAsRows(output_data, output_dims); + + Gemm(filter_matrix_map.transpose(), input_matrix_map, &output_matrix_map); + AddBiasAndEvalActivationFunction(bias_data, bias_dims, output_data, + output_dims, output_activation_min, + output_activation_max); +} + +// legacy, for compatibility with old checked-in code +template +void FullyConnected(const float* input_data, const Dims<4>& input_dims, + const float* weights_data, const Dims<4>& weights_dims, + const float* bias_data, const Dims<4>& bias_dims, + float* output_data, const Dims<4>& output_dims) { + float output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + FullyConnected(input_data, input_dims, weights_data, weights_dims, bias_data, + bias_dims, output_activation_min, output_activation_max, + output_data, output_dims); +} + +inline void preload_l1_stream(const uint8* ptr) { +#ifdef GEMMLOWP_ARM_64 + asm volatile("prfm pldl1strm, [%[ptr]]\n" ::[ptr] "r"(ptr) :); +#else + gemmlowp::Prefetch(ptr); +#endif +} + +#ifdef USE_NEON +inline void FullyConnectedAsGEMV( + const uint8* input_data, const Dims<4>& input_dims, int32 input_offset, + const uint8* filter_data, const Dims<4>& filter_dims, int32 filter_offset, + const int32* bias_data, const Dims<4>& bias_dims, int32 output_offset, + int32 output_multiplier, int output_shift, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("FullyConnectedAsGEMV/8bit"); + TFLITE_DCHECK(IsPackedWithoutStrides(input_dims)); + TFLITE_DCHECK(IsPackedWithoutStrides(filter_dims)); + TFLITE_DCHECK(IsPackedWithoutStrides(bias_dims)); + 
TFLITE_DCHECK(IsPackedWithoutStrides(output_dims)); + TFLITE_DCHECK_EQ(ArraySize(output_dims, 1) * ArraySize(output_dims, 2) * + ArraySize(output_dims, 3), + 1); + const int input_size = input_dims.strides[3]; + const int output_size = MatchingArraySize(filter_dims, 1, output_dims, 0); + static constexpr int kPeel = 4; + for (int k = 0; k < input_size; k += 64) { + preload_l1_stream(input_data + k); + } + for (int k = 0; k < kPeel * input_size; k += 64) { + preload_l1_stream(filter_data + k); + } + TFLITE_DCHECK(!(output_size % kPeel)); + const int32* bias_ptr = bias_data; + uint8* output_ptr = output_data; + for (int out = 0; out < output_size; out += kPeel) { + int32x4_t acc[kPeel]; + for (int k = 0; k < kPeel; k++) { + acc[k] = vdupq_n_s32(0); + } + const int16x8_t input_offset_vec = vdupq_n_s16(input_offset); + const int16x8_t filter_offset_vec = vdupq_n_s16(filter_offset); + int in = 0; + for (; in <= input_size - 16; in += 16) { + const uint8x16_t input_val_u8 = vld1q_u8(input_data + in); + uint8x16_t filter_val_u8[kPeel]; + for (int k = 0; k < kPeel; k++) { + const uint8* filter_ptr = filter_data + in + (out + k) * input_size; + filter_val_u8[k] = vld1q_u8(filter_ptr); + preload_l1_stream(filter_ptr + 64); + } + int16x8_t input_val[2]; + const uint8x8_t low = vget_low_u8(input_val_u8); + const uint8x8_t high = vget_high_u8(input_val_u8); + input_val[0] = vreinterpretq_s16_u16(vmovl_u8(low)); + input_val[1] = vreinterpretq_s16_u16(vmovl_u8(high)); + input_val[0] = vaddq_s16(input_val[0], input_offset_vec); + input_val[1] = vaddq_s16(input_val[1], input_offset_vec); + int16x8_t filter_val[kPeel][2]; + for (int k = 0; k < kPeel; k++) { + const uint8x8_t low = vget_low_u8(filter_val_u8[k]); + const uint8x8_t high = vget_high_u8(filter_val_u8[k]); + filter_val[k][0] = vreinterpretq_s16_u16(vmovl_u8(low)); + filter_val[k][1] = vreinterpretq_s16_u16(vmovl_u8(high)); + filter_val[k][0] = vaddq_s16(filter_val[k][0], filter_offset_vec); + filter_val[k][1] = vaddq_s16(filter_val[k][1], filter_offset_vec); + } + for (int p = 0; p < 2; p++) { + for (int k = 0; k < kPeel; k++) { + acc[k] = vmlal_s16(acc[k], vget_low_s16(filter_val[k][p]), + vget_low_s16(input_val[p])); + } + for (int k = 0; k < kPeel; k++) { + acc[k] = vmlal_s16(acc[k], vget_high_s16(filter_val[k][p]), + vget_high_s16(input_val[p])); + } + } + } + for (; in <= input_size - 8; in += 8) { + const uint8x8_t input_val_u8 = vld1_u8(input_data + in); + uint8x8_t filter_val_u8[kPeel]; + for (int k = 0; k < kPeel; k++) { + const uint8* filter_ptr = filter_data + in + (out + k) * input_size; + filter_val_u8[k] = vld1_u8(filter_ptr); + } + int16x8_t input_val; + input_val = vreinterpretq_s16_u16(vmovl_u8(input_val_u8)); + input_val = vaddq_s16(input_val, input_offset_vec); + int16x8_t filter_val[kPeel]; + for (int k = 0; k < kPeel; k++) { + filter_val[k] = vreinterpretq_s16_u16(vmovl_u8(filter_val_u8[k])); + filter_val[k] = vaddq_s16(filter_val[k], filter_offset_vec); + } + for (int k = 0; k < kPeel; k++) { + acc[k] = vmlal_s16(acc[k], vget_low_s16(filter_val[k]), + vget_low_s16(input_val)); + } + for (int k = 0; k < kPeel; k++) { + acc[k] = vmlal_s16(acc[k], vget_high_s16(filter_val[k]), + vget_high_s16(input_val)); + } + } + if (in < input_size) { + int32 buf[4 * kPeel]; + for (int k = 0; k < 4; k++) { + vst1q_s32(buf + 4 * k, acc[k]); + } + for (; in < input_size; in++) { + int lane = (in + 8 - input_size) % 4; + const int32 input_val = input_data[in] + input_offset; + for (int k = 0; k < kPeel; k++) { + int32 filter_val = + 
filter_data[in + (out + k) * input_size] + filter_offset; + buf[lane + 4 * k] += filter_val * input_val; + } + } + for (int k = 0; k < 4; k++) { + acc[k] = vld1q_s32(buf + 4 * k); + } + } + + // Horizontally reduce accumulators + int32x2_t pairwise_reduced_acc[kPeel]; + for (int k = 0; k < kPeel; k++) { + pairwise_reduced_acc[k] = + vpadd_s32(vget_low_s32(acc[k]), vget_high_s32(acc[k])); + } + static_assert(kPeel == 4, "the code below currently assumes kPeel = 4"); + const int32x2_t reduced_lo = + vpadd_s32(pairwise_reduced_acc[0], pairwise_reduced_acc[1]); + const int32x2_t reduced_hi = + vpadd_s32(pairwise_reduced_acc[2], pairwise_reduced_acc[3]); + int32x4_t reduced = vcombine_s32(reduced_lo, reduced_hi); + // Add bias values. + int32x4_t bias_vec = vld1q_s32(bias_ptr); + bias_ptr += 4; + reduced = vaddq_s32(reduced, bias_vec); + // Multiply by the fixed-point multiplier. + reduced = vqrdmulhq_n_s32(reduced, output_multiplier); + // Rounding-shift-right. + using gemmlowp::RoundingDivideByPOT; + reduced = RoundingDivideByPOT(reduced, output_shift); + // Add the output offset. + const int32x4_t output_offset_vec = vdupq_n_s32(output_offset); + reduced = vaddq_s32(reduced, output_offset_vec); + // Narrow values down to 16 bit signed. + const int16x4_t res16 = vqmovn_s32(reduced); + // Narrow values down to 8 bit unsigned, saturating. + uint8x8_t res8 = vqmovun_s16(vcombine_s16(res16, res16)); + // Apply the clamping from the activation function + res8 = vmax_u8(res8, vdup_n_u8(output_activation_min)); + res8 = vmin_u8(res8, vdup_n_u8(output_activation_max)); + // Store results to destination. Assumes 32bit alignment. + vst1_lane_u32(reinterpret_cast(output_ptr), + vreinterpret_u32_u8(res8), 0); + output_ptr += kPeel; + } +} +#endif // USE_NEON + +struct GemmlowpOutputPipeline { + typedef gemmlowp::VectorMap + ColVectorMap; + typedef std::tuple< + gemmlowp::OutputStageBiasAddition, + gemmlowp::OutputStageQuantizeDownInt32ToUint8ScaleByFixedPoint, + gemmlowp::OutputStageClamp, gemmlowp::OutputStageSaturatingCastToUint8> + Pipeline; + static Pipeline Make(const int32* bias_data, int output_rows, + int32 output_offset, int32 output_multiplier, + int output_shift, int32 output_activation_min, + int32 output_activation_max) { + ColVectorMap bias_vector(bias_data, output_rows); + gemmlowp::OutputStageBiasAddition bias_addition_stage; + bias_addition_stage.bias_vector = bias_vector; + gemmlowp::OutputStageQuantizeDownInt32ToUint8ScaleByFixedPoint + quantize_down_stage; + quantize_down_stage.result_offset_after_shift = output_offset; + quantize_down_stage.result_fixedpoint_multiplier = output_multiplier; + quantize_down_stage.result_shift = output_shift; + gemmlowp::OutputStageClamp clamp_stage; + clamp_stage.min = output_activation_min; + clamp_stage.max = output_activation_max; + gemmlowp::OutputStageSaturatingCastToUint8 saturating_cast_stage; + return std::make_tuple(bias_addition_stage, quantize_down_stage, + clamp_stage, saturating_cast_stage); + } +}; + +inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims, + int32 input_offset, const uint8* filter_data, + const Dims<4>& filter_dims, int32 filter_offset, + const int32* bias_data, const Dims<4>& bias_dims, + int32 output_offset, int32 output_multiplier, + int output_shift, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims, + gemmlowp::GemmContext* gemm_context) { + gemmlowp::ScopedProfilingLabel label("FullyConnected/8bit"); + // TODO(benoitjacob): This really 
should be: + // const int batches = ArraySize(output_dims, 1); + // but the current --variable_batch hack consists in overwriting the 3rd + // dimension with the runtime batch size, as we don't keep track for each + // array of which dimension is the batch dimension in it. + const int batches = ArraySize(output_dims, 1) * ArraySize(output_dims, 2) * + ArraySize(output_dims, 3); +#ifdef USE_NEON + const int output_size = MatchingArraySize(filter_dims, 1, output_dims, 0); + if (batches == 1 && !(output_size % 4)) { + return FullyConnectedAsGEMV( + input_data, input_dims, input_offset, filter_data, filter_dims, + filter_offset, bias_data, bias_dims, output_offset, output_multiplier, + output_shift, output_activation_min, output_activation_max, output_data, + output_dims); + } +#endif // USE_NEON + const int filter_rows = filter_dims.sizes[1]; + const int filter_cols = filter_dims.sizes[0]; + TFLITE_DCHECK_EQ(filter_dims.sizes[2], 1); + TFLITE_DCHECK_EQ(filter_dims.sizes[3], 1); + const int output_rows = output_dims.sizes[0]; + TFLITE_DCHECK_EQ(output_rows, filter_rows); + TFLITE_DCHECK_EQ(bias_dims.sizes[0], output_rows); + TFLITE_DCHECK_EQ(bias_dims.sizes[1], 1); + TFLITE_DCHECK_EQ(bias_dims.sizes[2], 1); + TFLITE_DCHECK_EQ(bias_dims.sizes[3], 1); + + gemmlowp::MatrixMap filter_matrix( + filter_data, output_rows, filter_cols, filter_cols); + gemmlowp::MatrixMap input_matrix( + input_data, filter_cols, batches, filter_cols); + gemmlowp::MatrixMap output_matrix( + output_data, output_rows, batches, output_rows); + const auto& output_pipeline = GemmlowpOutputPipeline::Make( + bias_data, output_rows, output_offset, output_multiplier, output_shift, + output_activation_min, output_activation_max); + gemmlowp::GemmWithOutputPipeline( + gemm_context, filter_matrix, input_matrix, &output_matrix, filter_offset, + input_offset, output_pipeline); +} + +// legacy, for compatibility with old checked-in code +template +void FullyConnected(const uint8* input_data, const Dims<4>& input_dims, + int32 input_offset, const uint8* filter_data, + const Dims<4>& filter_dims, int32 filter_offset, + const int32* bias_data, const Dims<4>& bias_dims, + int32 output_offset, int32 output_multiplier, + int output_shift, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims, + gemmlowp::GemmContext* gemm_context) { + static_assert(Ac == FusedActivationFunctionType::kNone || + Ac == FusedActivationFunctionType::kRelu || + Ac == FusedActivationFunctionType::kRelu6 || + Ac == FusedActivationFunctionType::kRelu1, + ""); + FullyConnected(input_data, input_dims, input_offset, filter_data, filter_dims, + filter_offset, bias_data, bias_dims, output_offset, + output_multiplier, output_shift, output_activation_min, + output_activation_max, output_data, output_dims, gemm_context); +} + +template +inline void ExtractPatchIntoBufferColumn( + const Dims<4>& input_dims, int w, int h, int b, int kheight, int kwidth, + int stride_width, int stride_height, int pad_width, int pad_height, + int in_width, int in_height, int in_depth, int single_buffer_length, + int buffer_id, const T* in_data, T* conv_buffer_data, uint8 byte_zero) { + gemmlowp::ScopedProfilingLabel label("ExtractPatchIntoBufferColumn"); + // This chunk of code reshapes all the inputs corresponding to + // output (b, h, w) to a column vector in conv_buffer(:, buffer_id). 
+ const int kwidth_times_indepth = kwidth * in_depth; + const int inwidth_times_indepth = in_width * in_depth; + const int ih_ungated_start = h * stride_height - pad_height; + const int ih_ungated_end = (ih_ungated_start + kheight); + const int ih_end = std::min(ih_ungated_end, in_height); + const int iw_ungated_start = w * stride_width - pad_width; + const int iw_ungated_end = (iw_ungated_start + kwidth); + const int iw_end = std::min(iw_ungated_end, in_width); + // If the patch is off the edge of the input image, skip writing those rows + // and columns from the patch into the output array. + const int h_offset = std::max(0, -ih_ungated_start); + const int w_offset = std::max(0, -iw_ungated_start); + const int ih_start = std::max(0, ih_ungated_start); + const int iw_start = std::max(0, iw_ungated_start); + const int single_row_num = + std::min(kwidth - w_offset, in_width - iw_start) * in_depth; + const int output_row_offset = (buffer_id * single_buffer_length); + int out_offset = + output_row_offset + (h_offset * kwidth + w_offset) * in_depth; + int in_offset = Offset(input_dims, 0, iw_start, ih_start, b); + + // Express all of the calculations as padding around the input patch. + const int top_padding = h_offset; + const int bottom_padding = (ih_ungated_end - ih_end); + const int left_padding = w_offset; + const int right_padding = (iw_ungated_end - iw_end); + assert(single_row_num == + ((kwidth - (left_padding + right_padding)) * in_depth)); + + // Write out zeroes to the elements representing the top rows of the input + // patch that are off the edge of the input image. + if (top_padding > 0) { + const int top_row_elements = (top_padding * kwidth * in_depth); + memset(conv_buffer_data + output_row_offset, byte_zero, + (top_row_elements * sizeof(T))); + } + + // If the patch is on the interior of the input image horizontally, just copy + // over the rows sequentially, otherwise add zero padding at the start or end. + if ((left_padding == 0) && (right_padding == 0)) { + for (int ih = ih_start; ih < ih_end; ++ih) { + memcpy(conv_buffer_data + out_offset, in_data + in_offset, + single_row_num * sizeof(T)); + out_offset += kwidth_times_indepth; + in_offset += inwidth_times_indepth; + } + } else { + for (int ih = ih_start; ih < ih_end; ++ih) { + if (left_padding > 0) { + const int left_start = (out_offset - (left_padding * in_depth)); + memset(conv_buffer_data + left_start, byte_zero, + (left_padding * in_depth * sizeof(T))); + } + memcpy(conv_buffer_data + out_offset, in_data + in_offset, + single_row_num * sizeof(T)); + if (right_padding > 0) { + const int right_start = (out_offset + single_row_num); + memset(conv_buffer_data + right_start, byte_zero, + (right_padding * in_depth * sizeof(T))); + } + out_offset += kwidth_times_indepth; + in_offset += inwidth_times_indepth; + } + } + + // If the bottom of the patch falls off the input image, pad the values + // representing those input rows with zeroes. 
+ if (bottom_padding > 0) { + const int bottom_row_elements = (bottom_padding * kwidth * in_depth); + const int bottom_start = + output_row_offset + + ((top_padding + (ih_end - ih_start)) * kwidth * in_depth); + memset(conv_buffer_data + bottom_start, byte_zero, + (bottom_row_elements * sizeof(T))); + } +} + +template +void Im2col(const T* input_data, const Dims<4>& input_dims, int stride_width, + int stride_height, int pad_width, int pad_height, int kheight, + int kwidth, uint8 byte_zero, T* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("Im2col"); + TFLITE_DCHECK(IsPackedWithoutStrides(input_dims)); + TFLITE_DCHECK(IsPackedWithoutStrides(output_dims)); + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int input_depth = ArraySize(input_dims, 0); + const int input_width = ArraySize(input_dims, 1); + const int input_height = ArraySize(input_dims, 2); + const int output_depth = ArraySize(output_dims, 0); + const int output_width = ArraySize(output_dims, 1); + const int output_height = ArraySize(output_dims, 2); + + int buffer_id = 0; + // Loop over the output nodes. + for (int b = 0; b < batches; ++b) { + for (int h = 0; h < output_height; ++h) { + for (int w = 0; w < output_width; ++w) { + ExtractPatchIntoBufferColumn( + input_dims, w, h, b, kheight, kwidth, stride_width, stride_height, + pad_width, pad_height, input_width, input_height, input_depth, + output_depth, buffer_id, input_data, output_data, byte_zero); + ++buffer_id; + } + } + } +} + +// legacy, for compatibility with old checked-in code +template +void Im2col(const T* input_data, const Dims<4>& input_dims, int stride, + int pad_width, int pad_height, int kheight, int kwidth, + uint8 byte_zero, T* output_data, const Dims<4>& output_dims) { + Im2col(input_data, input_dims, stride, stride, pad_width, pad_height, kheight, + kwidth, byte_zero, output_data, output_dims); +} + +inline void Conv(const float* input_data, const Dims<4>& input_dims, + const float* filter_data, const Dims<4>& filter_dims, + const float* bias_data, const Dims<4>& bias_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, float output_activation_min, + float output_activation_max, float* output_data, + const Dims<4>& output_dims, float* im2col_data, + const Dims<4>& im2col_dims) { + (void)im2col_data; + (void)im2col_dims; + gemmlowp::ScopedProfilingLabel label("Conv"); + + const float* gemm_input_data = nullptr; + const Dims<4>* gemm_input_dims = nullptr; + const int filter_width = ArraySize(filter_dims, 1); + const int filter_height = ArraySize(filter_dims, 2); + const bool need_im2col = stride_width != 1 || stride_height != 1 || + filter_width != 1 || filter_height != 1; + if (need_im2col) { + TFLITE_DCHECK(im2col_data); + Im2col(input_data, input_dims, stride_width, stride_height, pad_width, + pad_height, filter_height, filter_width, 0, im2col_data, + im2col_dims); + gemm_input_data = im2col_data; + gemm_input_dims = &im2col_dims; + } else { + // TODO(aselle): We need to make sure to not send im2col if it is not + // needed. 
+ TFLITE_DCHECK(!im2col_data); + gemm_input_data = input_data; + gemm_input_dims = &input_dims; + } + + const auto im2col_matrix_map = + MapAsMatrixWithFirstDimAsRows(gemm_input_data, *gemm_input_dims); + const auto filter_matrix_map = + MapAsMatrixWithLastDimAsCols(filter_data, filter_dims); + auto output_matrix_map = + MapAsMatrixWithFirstDimAsRows(output_data, output_dims); + + Gemm(filter_matrix_map.transpose(), im2col_matrix_map, &output_matrix_map); + + AddBiasAndEvalActivationFunction(bias_data, bias_dims, output_data, + output_dims, output_activation_min, + output_activation_max); +} + +// legacy, for compatibility with old checked-in code +template +void Conv(const float* input_data, const Dims<4>& input_dims, + const float* filter_data, const Dims<4>& filter_dims, + const float* bias_data, const Dims<4>& bias_dims, int stride_width, + int stride_height, int pad_width, int pad_height, float* output_data, + const Dims<4>& output_dims, float* im2col_data, + const Dims<4>& im2col_dims) { + float output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + Conv(input_data, input_dims, filter_data, filter_dims, bias_data, bias_dims, + stride_width, stride_height, pad_width, pad_height, + output_activation_min, output_activation_max, output_data, output_dims, + im2col_data, im2col_dims); +} + +// legacy, for compatibility with old checked-in code +template +void Conv(const float* input_data, const Dims<4>& input_dims, + const float* filter_data, const Dims<4>& filter_dims, + const float* bias_data, const Dims<4>& bias_dims, int stride, + int pad_width, int pad_height, float* output_data, + const Dims<4>& output_dims, float* im2col_data, + const Dims<4>& im2col_dims) { + Conv(input_data, input_dims, filter_data, filter_dims, bias_data, + bias_dims, stride, stride, pad_width, pad_height, output_data, + output_dims, im2col_data, im2col_dims); +} + +inline void Conv(const uint8* input_data, const Dims<4>& input_dims, + int32 input_offset, const uint8* filter_data, + const Dims<4>& filter_dims, int32 filter_offset, + const int32* bias_data, const Dims<4>& bias_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int32 output_offset, int32 output_multiplier, + int output_shift, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims, uint8* im2col_data, + const Dims<4>& im2col_dims, + gemmlowp::GemmContext* gemm_context) { + gemmlowp::ScopedProfilingLabel label("Conv/8bit"); + + TFLITE_DCHECK(IsPackedWithoutStrides(input_dims)); + TFLITE_DCHECK(IsPackedWithoutStrides(filter_dims)); + TFLITE_DCHECK(IsPackedWithoutStrides(output_dims)); + + const uint8* gemm_input_data = nullptr; + const Dims<4>* gemm_input_dims = nullptr; + const int filter_width = ArraySize(filter_dims, 1); + const int filter_height = ArraySize(filter_dims, 2); + const bool need_im2col = stride_width != 1 || stride_height != 1 || + filter_width != 1 || filter_height != 1; + if (need_im2col) { + TFLITE_DCHECK(im2col_data); + const int input_zero_point = -input_offset; + TFLITE_DCHECK_GE(input_zero_point, 0); + TFLITE_DCHECK_LE(input_zero_point, 255); + Im2col(input_data, input_dims, stride_width, stride_height, pad_width, + pad_height, filter_height, filter_width, input_zero_point, + im2col_data, im2col_dims); + gemm_input_data = im2col_data; + gemm_input_dims = &im2col_dims; + } else { + TFLITE_DCHECK(!im2col_data); + gemm_input_data = input_data; + gemm_input_dims = &input_dims; + } + + 
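At this point the kernel has either built an im2col buffer or decided the raw input can be used directly: with a 1x1 filter and unit strides, each output location's receptive field is a single depth vector, so the depth-major input already has the im2col layout. What remains is to view both operands as 2-D matrices and run one gemmlowp GEMM with the bias/rescale/clamp output pipeline. A rough sketch of the shapes involved (the struct and helper below are illustrative, not part of this header):

    #include <cassert>

    // Illustrative only: the matrix shapes that turn a convolution into a single
    // GEMM. The filter becomes an [output_depth x K] matrix and the im2col
    // buffer a [K x N] matrix, where K is the flattened patch size and N the
    // number of output locations; their product is the [output_depth x N] output.
    struct ConvAsGemmShapes {
      int lhs_rows, lhs_cols;  // filter matrix
      int rhs_rows, rhs_cols;  // im2col buffer (or raw input for 1x1, stride 1)
    };

    inline ConvAsGemmShapes MakeConvAsGemmShapes(int input_depth, int filter_height,
                                                 int filter_width, int output_depth,
                                                 int output_height, int output_width,
                                                 int batches) {
      const int k = filter_height * filter_width * input_depth;
      const int n = output_height * output_width * batches;
      ConvAsGemmShapes shapes{output_depth, k, k, n};
      assert(shapes.lhs_cols == shapes.rhs_rows);  // inner dimensions must agree
      return shapes;
    }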
const int gemm_input_rows = gemm_input_dims->sizes[0]; + const int gemm_input_cols = gemm_input_dims->sizes[1] * + gemm_input_dims->sizes[2] * + gemm_input_dims->sizes[3]; + const int filter_rows = filter_dims.sizes[3]; + const int filter_cols = + filter_dims.sizes[0] * filter_dims.sizes[1] * filter_dims.sizes[2]; + const int output_rows = output_dims.sizes[0]; + const int output_cols = + output_dims.sizes[1] * output_dims.sizes[2] * output_dims.sizes[3]; + TFLITE_DCHECK_EQ(output_rows, filter_rows); + TFLITE_DCHECK_EQ(output_cols, gemm_input_cols); + TFLITE_DCHECK_EQ(filter_cols, gemm_input_rows); + TFLITE_DCHECK_EQ(bias_dims.sizes[0], output_rows); + TFLITE_DCHECK_EQ(bias_dims.sizes[1], 1); + TFLITE_DCHECK_EQ(bias_dims.sizes[2], 1); + TFLITE_DCHECK_EQ(bias_dims.sizes[3], 1); + gemmlowp::MatrixMap filter_matrix( + filter_data, filter_rows, filter_cols); + gemmlowp::MatrixMap input_matrix( + gemm_input_data, gemm_input_rows, gemm_input_cols); + gemmlowp::MatrixMap output_matrix( + output_data, output_rows, output_cols); + const auto& output_pipeline = GemmlowpOutputPipeline::Make( + bias_data, output_rows, output_offset, output_multiplier, output_shift, + output_activation_min, output_activation_max); + gemmlowp::GemmWithOutputPipeline( + gemm_context, filter_matrix, input_matrix, &output_matrix, filter_offset, + input_offset, output_pipeline); +} + +// legacy, for compatibility with old checked-in code +template +inline void Conv(const uint8* input_data, const Dims<4>& input_dims, + int32 input_offset, const uint8* filter_data, + const Dims<4>& filter_dims, int32 filter_offset, + const int32* bias_data, const Dims<4>& bias_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int32 output_offset, int32 output_multiplier, + int output_shift, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims, uint8* im2col_data, + const Dims<4>& im2col_dims, + gemmlowp::GemmContext* gemm_context) { + static_assert(Ac == FusedActivationFunctionType::kNone || + Ac == FusedActivationFunctionType::kRelu || + Ac == FusedActivationFunctionType::kRelu6 || + Ac == FusedActivationFunctionType::kRelu1, + ""); + if (Ac == FusedActivationFunctionType::kNone) { + TFLITE_DCHECK_EQ(output_activation_min, 0); + TFLITE_DCHECK_EQ(output_activation_max, 255); + } + Conv(input_data, input_dims, input_offset, filter_data, filter_dims, + filter_offset, bias_data, bias_dims, stride_width, stride_height, + pad_width, pad_height, output_offset, output_multiplier, output_shift, + output_activation_min, output_activation_max, output_data, output_dims, + im2col_data, im2col_dims, gemm_context); +} + +// legacy, for compatibility with old checked-in code +template +void Conv(const uint8* input_data, const Dims<4>& input_dims, + int32 input_offset, const uint8* filter_data, + const Dims<4>& filter_dims, int32 filter_offset, + const int32* bias_data, const Dims<4>& bias_dims, int stride, + int pad_width, int pad_height, int32 output_offset, + int32 output_multiplier, int output_shift, + int32 output_activation_min, int32 output_activation_max, + uint8* output_data, const Dims<4>& output_dims, uint8* im2col_data, + const Dims<4>& im2col_dims, gemmlowp::GemmContext* gemm_context) { + static_assert(Ac == FusedActivationFunctionType::kNone || + Ac == FusedActivationFunctionType::kRelu || + Ac == FusedActivationFunctionType::kRelu6 || + Ac == FusedActivationFunctionType::kRelu1, + ""); + Conv(input_data, input_dims, input_offset, filter_data, filter_dims, + 
filter_offset, bias_data, bias_dims, stride, stride, pad_width, + pad_height, output_offset, output_multiplier, output_shift, + output_activation_min, output_activation_max, output_data, output_dims, + im2col_data, im2col_dims, gemm_context); +} + +template +inline void DepthToSpace(const T* input_data, const Dims<4>& input_dims, + int block_size, T* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("DepthToSpace"); + + const int input_depth = ArraySize(input_dims, 0); + const int input_width = ArraySize(input_dims, 1); + const int input_height = ArraySize(input_dims, 2); + + const int output_depth = ArraySize(output_dims, 0); + const int batch_size = ArraySize(output_dims, 3); + + // Number of continuous values that we can copy in one interation. + const int stride = block_size * output_depth; + + for (int batch = 0; batch < batch_size; ++batch) { + for (int in_h = 0; in_h < input_height; ++in_h) { + const T* input_ptr = input_data + Offset(input_dims, 0, 0, in_h, batch); + for (int offset_h = 0; offset_h < block_size; ++offset_h) { + const T* src = input_ptr; + for (int in_w = 0; in_w < input_width; ++in_w) { + memcpy(output_data, src, stride * sizeof(T)); + output_data += stride; + src += input_depth; + } + input_ptr += stride; + } + } + } +} + +// legacy, for compatibility with old checked-in code +template +void Im2col(const T* input_data, const Dims<4>& input_dims, int stride, + int pad_width, int pad_height, int kheight, int kwidth, + uint8 byte_zero, T* output_data, const Dims<4>& output_dims) { + Im2col(input_data, input_dims, stride, stride, pad_width, pad_height, kheight, + kwidth, byte_zero, output_data, output_dims); +} + +// legacy, for compatibility with old checked-in code +template +void ConvAsGemm(const float* input_data, const Dims<4>& input_dims, + const float* filter_data, const Dims<4>& filter_dims, + const float* bias_data, const Dims<4>& bias_dims, + float* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("ConvAsGemm"); + + const auto input_matrix_map = + MapAsMatrixWithFirstDimAsRows(input_data, input_dims); + const auto filter_matrix_map = + MapAsMatrixWithLastDimAsCols(filter_data, filter_dims); + auto output_matrix_map = + MapAsMatrixWithFirstDimAsRows(output_data, output_dims); + + Gemm(filter_matrix_map.transpose(), input_matrix_map, &output_matrix_map); + + AddBiasAndEvalActivationFunction(bias_data, bias_dims, output_data, + output_dims); +} + +// legacy, for compatibility with old checked-in code +template +void ConvAsGemm(const uint8* input_data, const Dims<4>& input_dims, + int32 input_offset, const uint8* filter_data, + const Dims<4>& filter_dims, int32 filter_offset, + const int32* bias_data, const Dims<4>& bias_dims, + int32 output_offset, int32 output_multiplier, int output_shift, + int32 output_activation_min, int32 output_activation_max, + uint8* output_data, const Dims<4>& output_dims, + gemmlowp::GemmContext* gemm_context) { + gemmlowp::ScopedProfilingLabel label("ConvAsGemm/8bit"); + static_assert(Ac == FusedActivationFunctionType::kNone || + Ac == FusedActivationFunctionType::kRelu || + Ac == FusedActivationFunctionType::kRelu6 || + Ac == FusedActivationFunctionType::kRelu1, + ""); + const int input_rows = input_dims.sizes[0]; + const int input_cols = + input_dims.sizes[1] * input_dims.sizes[2] * input_dims.sizes[3]; + const int filter_rows = filter_dims.sizes[3]; + const int filter_cols = + filter_dims.sizes[0] * filter_dims.sizes[1] * filter_dims.sizes[2]; + const int 
output_rows = output_dims.sizes[0]; + const int output_cols = + output_dims.sizes[1] * output_dims.sizes[2] * output_dims.sizes[3]; + TFLITE_DCHECK_EQ(output_rows, filter_rows); + TFLITE_DCHECK_EQ(output_cols, input_cols); + TFLITE_DCHECK_EQ(filter_cols, input_rows); + TFLITE_DCHECK_EQ(bias_dims.sizes[0], output_rows); + TFLITE_DCHECK_EQ(bias_dims.sizes[1], 1); + TFLITE_DCHECK_EQ(bias_dims.sizes[2], 1); + TFLITE_DCHECK_EQ(bias_dims.sizes[3], 1); + gemmlowp::MatrixMap filter_matrix( + filter_data, output_rows, filter_cols, filter_cols); + gemmlowp::MatrixMap input_matrix( + input_data, filter_cols, output_cols, filter_cols); + gemmlowp::MatrixMap output_matrix( + output_data, output_rows, output_cols, output_rows); + const auto& output_pipeline = GemmlowpOutputPipeline::Make( + bias_data, output_rows, output_offset, output_multiplier, output_shift, + output_activation_min, output_activation_max); + gemmlowp::GemmWithOutputPipeline( + gemm_context, filter_matrix, input_matrix, &output_matrix, filter_offset, + input_offset, output_pipeline); +} + +template +inline void SpaceToDepth(const T* input_data, const Dims<4>& input_dims, + int block_size, T* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("SpaceToDepth"); + + const int output_depth = ArraySize(output_dims, 0); + const int output_width = ArraySize(output_dims, 1); + const int output_height = ArraySize(output_dims, 2); + + const int input_depth = ArraySize(input_dims, 0); + const int batch_size = ArraySize(input_dims, 3); + + // Number of continuous values that we can copy in one interation. + const int stride = block_size * input_depth; + + for (int batch = 0; batch < batch_size; ++batch) { + for (int out_h = 0; out_h < output_height; ++out_h) { + T* output_ptr = output_data + Offset(output_dims, 0, 0, out_h, batch); + for (int offset_h = 0; offset_h < block_size; ++offset_h) { + T* dst = output_ptr; + for (int out_w = 0; out_w < output_width; ++out_w) { + memcpy(dst, input_data, stride * sizeof(T)); + input_data += stride; + dst += output_depth; + } + output_ptr += stride; + } + } + } +} + +template +void NonGlobalBatchNormalization( + const float* input_data, const Dims<4>& input_dims, const float* mean_data, + const Dims<4>& mean_dims, const float* multiplier_data, + const Dims<4>& multiplier_dims, const float* offset_data, + const Dims<4>& offset_dims, float* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("NonGlobalBatchNormalization"); + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = + MatchingArraySize(input_dims, 2, mean_dims, 2, multiplier_dims, 2, + offset_dims, 2, output_dims, 2); + const int width = + MatchingArraySize(input_dims, 1, mean_dims, 1, multiplier_dims, 1, + offset_dims, 1, output_dims, 1); + const int depth = + MatchingArraySize(input_dims, 0, mean_dims, 0, multiplier_dims, 0, + offset_dims, 0, output_dims, 0); + + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + for (int c = 0; c < depth; ++c) { + output_data[Offset(output_dims, c, x, y, b)] = ActivationFunction( + (input_data[Offset(input_dims, c, x, y, b)] - + mean_data[Offset(mean_dims, c, x, y, 0)]) * + multiplier_data[Offset(multiplier_dims, c, x, y, 0)] + + offset_data[Offset(offset_dims, c, x, y, 0)]); + } + } + } + } +} + +template +void GlobalBatchNormalization(const float* input_data, + const Dims<4>& input_dims, const float* mean_data, + const Dims<4>& mean_dims, + const 
float* multiplier_data, + const Dims<4>& multiplier_dims, + const float* offset_data, + const Dims<4>& offset_dims, float* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("GlobalBatchNormalization"); + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = MatchingArraySize(input_dims, 2, output_dims, 2); + const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int depth = + MatchingArraySize(input_dims, 0, mean_dims, 0, multiplier_dims, 0, + offset_dims, 0, output_dims, 0); + + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + for (int c = 0; c < depth; ++c) { + output_data[Offset(output_dims, c, x, y, b)] = ActivationFunction( + (input_data[Offset(input_dims, c, x, y, b)] - + mean_data[Offset(mean_dims, c, 0, 0, 0)]) * + multiplier_data[Offset(multiplier_dims, c, 0, 0, 0)] + + offset_data[Offset(offset_dims, c, 0, 0, 0)]); + } + } + } + } +} + +inline void Relu(const float* input_data, const Dims<4>& input_dims, + float* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("Relu (not fused)"); + + const auto input = MapAsVector(input_data, input_dims); + auto output = MapAsVector(output_data, output_dims); + output = input.cwiseMax(0.0f); +} + +inline void Relu1(const float* input_data, const Dims<4>& input_dims, + float* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("Relu1 (not fused)"); + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = MatchingArraySize(input_dims, 2, output_dims, 2); + const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + for (int c = 0; c < depth; ++c) { + float val = input_data[Offset(input_dims, c, x, y, b)]; + const float upper = 1; + const float lower = -1; + float clamped = val > upper ? upper : val < lower ? lower : val; + output_data[Offset(output_dims, c, x, y, b)] = clamped; + } + } + } + } +} + +inline void Relu6(const float* input_data, const Dims<4>& input_dims, + float* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("Relu6 (not fused)"); + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = MatchingArraySize(input_dims, 2, output_dims, 2); + const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + for (int c = 0; c < depth; ++c) { + float val = input_data[Offset(input_dims, c, x, y, b)]; + const float upper = 6; + const float lower = 0; + float clamped = val > upper ? upper : val < lower ? 
lower : val; + output_data[Offset(output_dims, c, x, y, b)] = clamped; + } + } + } + } +} + +template +void L2Normalization(const float* input_data, const Dims<4>& input_dims, + float* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("L2Normalization"); + static_assert(Ac == FusedActivationFunctionType::kNone, ""); + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = MatchingArraySize(input_dims, 2, output_dims, 2); + const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + float squared_l2_norm = 0; + for (int c = 0; c < depth; ++c) { + float val = input_data[Offset(input_dims, c, x, y, b)]; + squared_l2_norm += val * val; + } + float inverse_l2_norm = 1.0f / std::sqrt(squared_l2_norm); + for (int c = 0; c < depth; ++c) { + output_data[Offset(output_dims, c, x, y, b)] = + input_data[Offset(input_dims, c, x, y, b)] * inverse_l2_norm; + } + } + } + } +} + +inline void GetInvSqrtQuantizedMultiplier(int32 input, int32* output_inv_sqrt, + int* output_shift) { + *output_shift = 11; + while (input >= (1 << 29)) { + input /= 4; + ++*output_shift; + } + TFLITE_DCHECK_GT(input, 0); + const unsigned max_left_shift_bits = __builtin_clz(input) - 1; + const unsigned max_left_shift_bit_pairs = max_left_shift_bits / 2; + const unsigned left_shift_bit_pairs = max_left_shift_bit_pairs - 1; + *output_shift -= left_shift_bit_pairs; + input <<= 2 * left_shift_bit_pairs; + TFLITE_DCHECK_GE(input, (1 << 27)); + TFLITE_DCHECK_LT(input, (1 << 29)); + using gemmlowp::FixedPoint; + using gemmlowp::Rescale; + using gemmlowp::SaturatingRoundingMultiplyByPOT; + // Using 3 integer bits gives us enough room for the internal arithmetic in + // this Newton-Raphson iteration. + using F3 = FixedPoint; + using F0 = FixedPoint; + const F3 fixedpoint_input = F3::FromRaw(input >> 1); + const F3 fixedpoint_half_input = + SaturatingRoundingMultiplyByPOT<-1>(fixedpoint_input); + const F3 fixedpoint_half_three = + GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F3, (1 << 28) + (1 << 27), 1.5); + // Newton-Raphson iteration + // Naive unoptimized starting guess: x = 1 + F3 x = F3::One(); + // Naive unoptimized number of iterations: 5 + for (int i = 0; i < 5; i++) { + const F3 x3 = Rescale<3>(x * x * x); + x = Rescale<3>(fixedpoint_half_three * x - fixedpoint_half_input * x3); + } + const F0 fixedpoint_half_sqrt_2 = + GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F0, 1518500250, std::sqrt(2.) 
/ 2.); + x = x * fixedpoint_half_sqrt_2; + *output_inv_sqrt = x.raw(); + if (*output_shift < 0) { + *output_inv_sqrt <<= -*output_shift; + *output_shift = 0; + } +} + +inline void L2Normalization(const uint8* input_data, const Dims<4>& input_dims, + int32 input_zero_point, uint8* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("L2Normalization/8bit"); + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = MatchingArraySize(input_dims, 2, output_dims, 2); + const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + TFLITE_DCHECK(IsPackedWithoutStrides(input_dims)); + TFLITE_DCHECK(IsPackedWithoutStrides(output_dims)); + TFLITE_DCHECK_EQ(batches, 1); + TFLITE_DCHECK_EQ(height, 1); + TFLITE_DCHECK_EQ(width, 1); + int32 square_l2_norm = 0; + for (int i = 0; i < depth; i++) { + int32 diff = input_data[i] - input_zero_point; + square_l2_norm += diff * diff; + } + int32 inv_l2norm_multiplier; + int inv_l2norm_shift; + GetInvSqrtQuantizedMultiplier(square_l2_norm, &inv_l2norm_multiplier, + &inv_l2norm_shift); + + for (int i = 0; i < depth; i++) { + int32 diff = input_data[i] - input_zero_point; + int32 rescaled_diff = MultiplyByQuantizedMultiplierSmallerThanOne( + 128 * diff, inv_l2norm_multiplier, inv_l2norm_shift); + int32 unclamped_output_val = 128 + rescaled_diff; + int32 output_val = std::min(255, std::max(0, unclamped_output_val)); + output_data[i] = static_cast(output_val); + } +} + +inline void Add(const float* input1_data, const Dims<4>& input1_dims, + const float* input2_data, const Dims<4>& input2_dims, + float output_activation_min, float output_activation_max, + float* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("Add"); + /* const int batches = */ MatchingArraySize(input1_dims, 3, input2_dims, 3, + output_dims, 3); + /* const int height = */ MatchingArraySize(input1_dims, 2, input2_dims, 2, + output_dims, 2); + /* const int width = */ MatchingArraySize(input1_dims, 1, input2_dims, 1, + output_dims, 1); + /* const int depth = */ MatchingArraySize(input1_dims, 0, input2_dims, 0, + output_dims, 0); + TFLITE_DCHECK(IsPackedWithoutStrides(input1_dims)); + TFLITE_DCHECK(IsPackedWithoutStrides(input2_dims)); + TFLITE_DCHECK(IsPackedWithoutStrides(output_dims)); + + int i = 0; + const int size = input1_dims.sizes[3] * input1_dims.strides[3]; +#ifdef USE_NEON + const auto activation_min = vdupq_n_f32(output_activation_min); + const auto activation_max = vdupq_n_f32(output_activation_max); + for (; i <= size - 16; i += 16) { + auto a10 = vld1q_f32(input1_data + i); + auto a11 = vld1q_f32(input1_data + i + 4); + auto a12 = vld1q_f32(input1_data + i + 8); + auto a13 = vld1q_f32(input1_data + i + 12); + auto a20 = vld1q_f32(input2_data + i); + auto a21 = vld1q_f32(input2_data + i + 4); + auto a22 = vld1q_f32(input2_data + i + 8); + auto a23 = vld1q_f32(input2_data + i + 12); + auto x0 = vaddq_f32(a10, a20); + auto x1 = vaddq_f32(a11, a21); + auto x2 = vaddq_f32(a12, a22); + auto x3 = vaddq_f32(a13, a23); + x0 = vmaxq_f32(activation_min, x0); + x1 = vmaxq_f32(activation_min, x1); + x2 = vmaxq_f32(activation_min, x2); + x3 = vmaxq_f32(activation_min, x3); + x0 = vminq_f32(activation_max, x0); + x1 = vminq_f32(activation_max, x1); + x2 = vminq_f32(activation_max, x2); + x3 = vminq_f32(activation_max, x3); + vst1q_f32(output_data + i, x0); + vst1q_f32(output_data + i + 4, x1); + vst1q_f32(output_data + i + 8, 
x2); + vst1q_f32(output_data + i + 12, x3); + } + for (; i <= size - 4; i += 4) { + auto a1 = vld1q_f32(input1_data + i); + auto a2 = vld1q_f32(input2_data + i); + auto x = vaddq_f32(a1, a2); + x = vmaxq_f32(activation_min, x); + x = vminq_f32(activation_max, x); + vst1q_f32(output_data + i, x); + } +#endif // NEON + + for (; i < size; i++) { + auto x = input1_data[i] + input2_data[i]; + output_data[i] = ActivationFunctionWithMinMax(x, output_activation_min, + output_activation_max); + } +} + +// legacy, for compatibility with old checked-in code +template +void Add(const float* input1_data, const Dims<4>& input1_dims, + const float* input2_data, const Dims<4>& input2_dims, + float* output_data, const Dims<4>& output_dims) { + float output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + + Add(input1_data, input1_dims, input2_data, input2_dims, output_activation_min, + output_activation_max, output_data, output_dims); +} + +template +inline void Add(int left_shift, const uint8* input1_data, + const Dims<4>& input1_dims, int32 input1_offset, + int32 input1_multiplier, int input1_shift, + const uint8* input2_data, const Dims<4>& input2_dims, + int32 input2_offset, int32 input2_multiplier, int input2_shift, + int32 output_offset, int32 output_multiplier, int output_shift, + int32 output_activation_min, int32 output_activation_max, + uint8* output_data, const Dims<4>& output_dims) { + static_assert(Ac == FusedActivationFunctionType::kNone || + Ac == FusedActivationFunctionType::kRelu || + Ac == FusedActivationFunctionType::kRelu6 || + Ac == FusedActivationFunctionType::kRelu1, + ""); + TFLITE_DCHECK_LE(output_activation_min, output_activation_max); + if (Ac == FusedActivationFunctionType::kNone) { + TFLITE_DCHECK_EQ(output_activation_min, 0); + TFLITE_DCHECK_EQ(output_activation_max, 255); + } + gemmlowp::ScopedProfilingLabel label("Add/8bit"); + /* const int batches = */ MatchingArraySize(input1_dims, 3, input2_dims, 3, + output_dims, 3); + /* const int height = */ MatchingArraySize(input1_dims, 2, input2_dims, 2, + output_dims, 2); + /* const int width = */ MatchingArraySize(input1_dims, 1, input2_dims, 1, + output_dims, 1); + /* const int depth = */ MatchingArraySize(input1_dims, 0, input2_dims, 0, + output_dims, 0); + TFLITE_DCHECK(IsPackedWithoutStrides(input1_dims)); + TFLITE_DCHECK(IsPackedWithoutStrides(input2_dims)); + TFLITE_DCHECK(IsPackedWithoutStrides(output_dims)); + + int i = 0; + const int size = input1_dims.sizes[3] * input1_dims.strides[3]; + TFLITE_DCHECK_GT(input1_offset, -256); + TFLITE_DCHECK_GT(input2_offset, -256); + TFLITE_DCHECK_LT(input1_offset, 256); + TFLITE_DCHECK_LT(input2_offset, 256); +#ifdef USE_NEON + for (; i <= size - 8; i += 8) { + const auto input1_val_original = vld1_u8(input1_data + i); + const auto input2_val_original = vld1_u8(input2_data + i); + const auto input1_val_s16 = + vreinterpretq_s16_u16(vmovl_u8(input1_val_original)); + const auto input2_val_s16 = + vreinterpretq_s16_u16(vmovl_u8(input2_val_original)); + const auto input1_val = + vaddq_s16(input1_val_s16, vdupq_n_s16(input1_offset)); + const auto input2_val = + vaddq_s16(input2_val_s16, vdupq_n_s16(input2_offset)); + const auto input1_val_high = vget_high_s16(input1_val); + const auto input1_val_low = vget_low_s16(input1_val); + const auto input2_val_high = vget_high_s16(input2_val); + const auto input2_val_low = vget_low_s16(input2_val); + auto x11 = vmovl_s16(input1_val_low); + auto x12 = vmovl_s16(input1_val_high); + auto 
x21 = vmovl_s16(input2_val_low); + auto x22 = vmovl_s16(input2_val_high); + const auto left_shift_dup = vdupq_n_s32(left_shift); + x11 = vshlq_s32(x11, left_shift_dup); + x12 = vshlq_s32(x12, left_shift_dup); + x21 = vshlq_s32(x21, left_shift_dup); + x22 = vshlq_s32(x22, left_shift_dup); + x11 = vqrdmulhq_n_s32(x11, input1_multiplier); + x12 = vqrdmulhq_n_s32(x12, input1_multiplier); + x21 = vqrdmulhq_n_s32(x21, input2_multiplier); + x22 = vqrdmulhq_n_s32(x22, input2_multiplier); + const auto input1_shift_dup = vdupq_n_s32(-input1_shift); + const auto input2_shift_dup = vdupq_n_s32(-input2_shift); + x11 = vshlq_s32(x11, input1_shift_dup); + x12 = vshlq_s32(x12, input1_shift_dup); + x21 = vshlq_s32(x21, input2_shift_dup); + x22 = vshlq_s32(x22, input2_shift_dup); + auto s1 = vaddq_s32(x11, x21); + auto s2 = vaddq_s32(x12, x22); + s1 = vqrdmulhq_n_s32(s1, output_multiplier); + s2 = vqrdmulhq_n_s32(s2, output_multiplier); + using gemmlowp::RoundingDivideByPOT; + s1 = RoundingDivideByPOT(s1, output_shift); + s2 = RoundingDivideByPOT(s2, output_shift); + const auto s1_narrowed = vmovn_s32(s1); + const auto s2_narrowed = vmovn_s32(s2); + const auto s = vaddq_s16(vcombine_s16(s1_narrowed, s2_narrowed), + vdupq_n_s16(output_offset)); + vst1_u8(output_data + i, vqmovun_s16(s)); + } +#endif // NEON + + for (; i < size; i++) { + const int32 input1_val = input1_offset + input1_data[i]; + const int32 input2_val = input2_offset + input2_data[i]; + const int32 shifted_input1_val = input1_val * (1 << left_shift); + const int32 shifted_input2_val = input2_val * (1 << left_shift); + const int32 scaled_input1_val = MultiplyByQuantizedMultiplierSmallerThanOne( + shifted_input1_val, input1_multiplier, input1_shift); + const int32 scaled_input2_val = MultiplyByQuantizedMultiplierSmallerThanOne( + shifted_input2_val, input2_multiplier, input2_shift); + const int32 raw_sum = scaled_input1_val + scaled_input2_val; + const int32 raw_output = MultiplyByQuantizedMultiplierSmallerThanOne( + raw_sum, output_multiplier, output_shift) + + output_offset; + const int32 clamped_output = std::min( + output_activation_max, std::max(output_activation_min, raw_output)); + output_data[i] = static_cast(clamped_output); + } +} + +template +void Add(const int32* input1_data, const Dims<4>& input1_dims, + const int32* input2_data, const Dims<4>& input2_dims, + int32* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("Add/int32"); + TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone); + + auto input1_map = MapAsVector(input1_data, input1_dims); + auto input2_map = MapAsVector(input2_data, input2_dims); + auto output_map = MapAsVector(output_data, output_dims); + if (AreSameDims(input1_dims, input2_dims)) { + output_map.array() = input1_map.array() + input2_map.array(); + } else if (RequiredBufferSizeForDims(input2_dims) == 1) { + auto scalar = input2_data[0]; + output_map.array() = input1_map.array() + scalar; + } else if (RequiredBufferSizeForDims(input1_dims) == 1) { + auto scalar = input1_data[0]; + output_map.array() = scalar + input2_map.array(); + } else { + // Should not come here. + TFLITE_DCHECK(false); + } +} + +// TODO(jiawen): We can implement BroadcastAdd on buffers of arbitrary +// dimensionality if the runtime code does a single loop over one dimension +// that handles broadcasting as the base case. The code generator would then +// generate max(D1, D2) nested for loops. +// TODO(benoitjacob): BroadcastAdd is intentionally duplicated from +// reference_ops.h. 
Once an optimized version is implemented and NdArrayDesc +// is no longer referenced in this file, move NdArrayDesc from types.h to +// reference_ops.h. +template +void BroadcastAdd(const T* input1_data, const Dims<4>& input1_dims, + const T* input2_data, const Dims<4>& input2_dims, + T* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("BroadcastAdd"); + + NdArrayDesc<4> desc1; + NdArrayDesc<4> desc2; + NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2); + + // In Tensorflow, the dimensions are canonically named (batch_number, row, + // col, channel), with extents (batches, height, width, depth), with the + // trailing dimension changing most rapidly (channels has the smallest stride, + // typically 1 element). + // + // In generated C code, we store arrays with the dimensions reversed. The + // first dimension has smallest stride. + // + // We name our variables by their Tensorflow convention, but generate C code + // nesting loops such that the innermost loop has the smallest stride for the + // best cache behavior. + for (int b = 0; b < ArraySize(output_dims, 3); ++b) { + for (int y = 0; y < ArraySize(output_dims, 2); ++y) { + for (int x = 0; x < ArraySize(output_dims, 1); ++x) { + for (int c = 0; c < ArraySize(output_dims, 0); ++c) { + output_data[Offset(output_dims, c, x, y, b)] = ActivationFunction( + input1_data[SubscriptToIndex(desc1, c, x, y, b)] + + input2_data[SubscriptToIndex(desc2, c, x, y, b)]); + } + } + } + } +} + +inline void BroadcastAdd(int left_shift, const uint8* input1_data, + const Dims<4>& input1_dims, int32 input1_offset, + int32 input1_multiplier, int input1_shift, + const uint8* input2_data, const Dims<4>& input2_dims, + int32 input2_offset, int32 input2_multiplier, + int input2_shift, int32 output_offset, + int32 output_multiplier, int output_shift, + int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("BroadcastAdd/8bit"); + + NdArrayDesc<4> desc1; + NdArrayDesc<4> desc2; + NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2); + + // In Tensorflow, the dimensions are canonically named (batch_number, row, + // col, channel), with extents (batches, height, width, depth), with the + // trailing dimension changing most rapidly (channels has the smallest stride, + // typically 1 element). + // + // In generated C code, we store arrays with the dimensions reversed. The + // first dimension has smallest stride. + // + // We name our variables by their Tensorflow convention, but generate C code + // nesting loops such that the innermost loop has the smallest stride for the + // best cache behavior. 
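The broadcasting itself is carried entirely by the two NdArrayDesc subscripting descriptors: a dimension with extent 1 in one operand is walked with an effective stride of 0, so the same (c, x, y, b) loop nest can index both inputs. A minimal sketch of that idea, using simplified stand-ins rather than the NdArrayDesc<4> and SubscriptToIndex helpers used here:

    // Simplified stand-ins for NdArrayDesc<4> / SubscriptToIndex to show the
    // idea: a packed layout gets its normal strides, but any dimension being
    // broadcast (extent 1) gets stride 0, so one set of (c, x, y, b) loops can
    // index both operands.
    struct Desc4 {
      int strides[4];  // indexed as [channel, width, height, batch]
    };

    inline Desc4 MakeBroadcastDesc(const int sizes[4]) {
      Desc4 desc;
      int stride = 1;
      for (int i = 0; i < 4; ++i) {
        desc.strides[i] = (sizes[i] == 1) ? 0 : stride;
        stride *= sizes[i];
      }
      return desc;
    }

    inline int FlatIndex(const Desc4& desc, int c, int x, int y, int b) {
      return c * desc.strides[0] + x * desc.strides[1] + y * desc.strides[2] +
             b * desc.strides[3];
    }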
+ for (int b = 0; b < ArraySize(output_dims, 3); ++b) { + for (int y = 0; y < ArraySize(output_dims, 2); ++y) { + for (int x = 0; x < ArraySize(output_dims, 1); ++x) { + for (int c = 0; c < ArraySize(output_dims, 0); ++c) { + const int32 input1_val = + input1_offset + input1_data[SubscriptToIndex(desc1, c, x, y, b)]; + const int32 input2_val = + input2_offset + input2_data[SubscriptToIndex(desc2, c, x, y, b)]; + const int32 shifted_input1_val = input1_val * (1 << left_shift); + const int32 shifted_input2_val = input2_val * (1 << left_shift); + const int32 scaled_input1_val = + MultiplyByQuantizedMultiplierSmallerThanOne( + shifted_input1_val, input1_multiplier, input1_shift); + const int32 scaled_input2_val = + MultiplyByQuantizedMultiplierSmallerThanOne( + shifted_input2_val, input2_multiplier, input2_shift); + const int32 raw_sum = scaled_input1_val + scaled_input2_val; + const int32 raw_output = + MultiplyByQuantizedMultiplierSmallerThanOne( + raw_sum, output_multiplier, output_shift) + + output_offset; + const int32 clamped_output = + std::min(output_activation_max, + std::max(output_activation_min, raw_output)); + output_data[Offset(output_dims, c, x, y, b)] = + static_cast(clamped_output); + } + } + } + } +} + +template +inline void BroadcastAdd(int left_shift, const uint8* input1_data, + const Dims<4>& input1_dims, int32 input1_offset, + int32 input1_multiplier, int input1_shift, + const uint8* input2_data, const Dims<4>& input2_dims, + int32 input2_offset, int32 input2_multiplier, + int input2_shift, int32 output_offset, + int32 output_multiplier, int output_shift, + int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + static_assert(Ac == FusedActivationFunctionType::kNone || + Ac == FusedActivationFunctionType::kRelu || + Ac == FusedActivationFunctionType::kRelu6 || + Ac == FusedActivationFunctionType::kRelu1, + ""); + TFLITE_DCHECK_LE(output_activation_min, output_activation_max); + if (Ac == FusedActivationFunctionType::kNone) { + TFLITE_DCHECK_EQ(output_activation_min, 0); + TFLITE_DCHECK_EQ(output_activation_max, 255); + } + BroadcastAdd(left_shift, input1_data, input1_dims, input1_offset, + input1_multiplier, input1_shift, input2_data, input2_dims, + input2_offset, input2_multiplier, input2_shift, output_offset, + output_multiplier, output_shift, output_activation_min, + output_activation_max, output_data, output_dims); +} + +inline void Mul(const float* input1_data, const Dims<4>& input1_dims, + const float* input2_data, const Dims<4>& input2_dims, + float output_activation_min, float output_activation_max, + float* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("Mul"); + /* const int batches = */ MatchingArraySize(input1_dims, 3, input2_dims, 3, + output_dims, 3); + /* const int height = */ MatchingArraySize(input1_dims, 2, input2_dims, 2, + output_dims, 2); + /* const int width = */ MatchingArraySize(input1_dims, 1, input2_dims, 1, + output_dims, 1); + /* const int depth = */ MatchingArraySize(input1_dims, 0, input2_dims, 0, + output_dims, 0); + TFLITE_DCHECK(IsPackedWithoutStrides(input1_dims)); + TFLITE_DCHECK(IsPackedWithoutStrides(input2_dims)); + TFLITE_DCHECK(IsPackedWithoutStrides(output_dims)); + + int i = 0; + const int size = input1_dims.sizes[3] * input1_dims.strides[3]; +#ifdef USE_NEON + const auto activation_min = vdupq_n_f32(output_activation_min); + const auto activation_max = vdupq_n_f32(output_activation_max); + for (; i <= size - 16; i += 16) { + auto a10 = 
vld1q_f32(input1_data + i); + auto a11 = vld1q_f32(input1_data + i + 4); + auto a12 = vld1q_f32(input1_data + i + 8); + auto a13 = vld1q_f32(input1_data + i + 12); + auto a20 = vld1q_f32(input2_data + i); + auto a21 = vld1q_f32(input2_data + i + 4); + auto a22 = vld1q_f32(input2_data + i + 8); + auto a23 = vld1q_f32(input2_data + i + 12); + auto x0 = vmulq_f32(a10, a20); + auto x1 = vmulq_f32(a11, a21); + auto x2 = vmulq_f32(a12, a22); + auto x3 = vmulq_f32(a13, a23); + + x0 = vmaxq_f32(activation_min, x0); + x1 = vmaxq_f32(activation_min, x1); + x2 = vmaxq_f32(activation_min, x2); + x3 = vmaxq_f32(activation_min, x3); + x0 = vminq_f32(activation_max, x0); + x1 = vminq_f32(activation_max, x1); + x2 = vminq_f32(activation_max, x2); + x3 = vminq_f32(activation_max, x3); + + vst1q_f32(output_data + i, x0); + vst1q_f32(output_data + i + 4, x1); + vst1q_f32(output_data + i + 8, x2); + vst1q_f32(output_data + i + 12, x3); + } + for (; i <= size - 4; i += 4) { + auto a1 = vld1q_f32(input1_data + i); + auto a2 = vld1q_f32(input2_data + i); + auto x = vmulq_f32(a1, a2); + + x = vmaxq_f32(activation_min, x); + x = vminq_f32(activation_max, x); + + vst1q_f32(output_data + i, x); + } +#endif // NEON + + for (; i < size; i++) { + auto x = input1_data[i] * input2_data[i]; + output_data[i] = ActivationFunctionWithMinMax(x, output_activation_min, + output_activation_max); + } +} + +// legacy, for compatibility with old checked-in code +template +void Mul(const float* input1_data, const Dims<4>& input1_dims, + const float* input2_data, const Dims<4>& input2_dims, + float* output_data, const Dims<4>& output_dims) { + float output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + + Mul(input1_data, input1_dims, input2_data, input2_dims, output_activation_min, + output_activation_max, output_data, output_dims); +} + +template +void Mul(const int32* input1_data, const Dims<4>& input1_dims, + const int32* input2_data, const Dims<4>& input2_dims, + int32* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("Mul/int32"); + TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone); + + auto input1_map = MapAsVector(input1_data, input1_dims); + auto input2_map = MapAsVector(input2_data, input2_dims); + auto output_map = MapAsVector(output_data, output_dims); + if (AreSameDims(input1_dims, input2_dims)) { + output_map.array() = input1_map.array() * input2_map.array(); + } else if (RequiredBufferSizeForDims(input2_dims) == 1) { + auto scalar = input2_data[0]; + output_map.array() = input1_map.array() * scalar; + } else if (RequiredBufferSizeForDims(input1_dims) == 1) { + auto scalar = input1_data[0]; + output_map.array() = scalar * input2_map.array(); + } else { + // Should not come here. + TFLITE_DCHECK(false); + } +} + +// TODO(jiawen): We can implement BroadcastMul on buffers of arbitrary +// dimensionality if the runtime code does a single loop over one dimension +// that handles broadcasting as the base case. The code generator would then +// generate max(D1, D2) nested for loops. +// TODO(benoitjacob): BroadcastMul is intentionally duplicated from +// reference_ops.h. Once an optimized version is implemented and NdArrayDesc +// is no longer referenced in this file, move NdArrayDesc from types.h to +// reference_ops.h. 
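Throughout these quantized kernels the output_multiplier / output_shift pair encodes a positive real rescaling factor smaller than one (for a quantized multiply, input1_scale * input2_scale / output_scale) as a Q31 fixed-point multiplier plus a rounding right shift. The sketch below shows one plausible decomposition; it illustrates the representation and is not the helper TF Lite uses to produce these values.

    #include <cassert>
    #include <cmath>
    #include <cstdint>

    // Illustrative only: decompose a real multiplier in (0, 1) into a Q31
    // fixed-point multiplier and a right shift, so that
    //   real_multiplier ~= (quantized_multiplier / 2^31) * 2^-right_shift.
    inline void DecomposeMultiplier(double real_multiplier,
                                    std::int32_t* quantized_multiplier,
                                    int* right_shift) {
      assert(real_multiplier > 0.0 && real_multiplier < 1.0);
      int shift = 0;
      // Scale into [0.5, 1) so the Q31 value uses the full 31 fractional bits.
      while (real_multiplier < 0.5) {
        real_multiplier *= 2.0;
        ++shift;
      }
      std::int64_t q =
          static_cast<std::int64_t>(std::round(real_multiplier * (1ll << 31)));
      if (q == (1ll << 31)) {  // rounding pushed the value up to exactly 1.0
        q /= 2;
        --shift;
      }
      *quantized_multiplier = static_cast<std::int32_t>(q);
      *right_shift = shift;
    }

At runtime the pair is applied with a saturating rounding doubling multiply followed by a rounding shift, which is what MultiplyByQuantizedMultiplierSmallerThanOne and the vqrdmulhq_n_s32 / RoundingDivideByPOT sequences above implement.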
+template +void BroadcastMul(const T* input1_data, const Dims<4>& input1_dims, + const T* input2_data, const Dims<4>& input2_dims, + T* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("BroadcastMul"); + + NdArrayDesc<4> desc1; + NdArrayDesc<4> desc2; + NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2); + + // In Tensorflow, the dimensions are canonically named (batch_number, row, + // col, channel), with extents (batches, height, width, depth), with the + // trailing dimension changing most rapidly (channels has the smallest stride, + // typically 1 element). + // + // In generated C code, we store arrays with the dimensions reversed. The + // first dimension has smallest stride. + // + // We name our variables by their Tensorflow convention, but generate C code + // nesting loops such that the innermost loop has the smallest stride for the + // best cache behavior. + for (int b = 0; b < ArraySize(output_dims, 3); ++b) { + for (int y = 0; y < ArraySize(output_dims, 2); ++y) { + for (int x = 0; x < ArraySize(output_dims, 1); ++x) { + for (int c = 0; c < ArraySize(output_dims, 0); ++c) { + output_data[Offset(output_dims, c, x, y, b)] = ActivationFunction( + input1_data[SubscriptToIndex(desc1, c, x, y, b)] * + input2_data[SubscriptToIndex(desc2, c, x, y, b)]); + } + } + } + } +} + +inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims, + int32 input1_offset, const uint8* input2_data, + const Dims<4>& input2_dims, int32 input2_offset, + int32 output_offset, int32 output_multiplier, + int output_shift, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("BroadcastMul/8bit"); + + NdArrayDesc<4> desc1; + NdArrayDesc<4> desc2; + NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2); + + // In Tensorflow, the dimensions are canonically named (batch_number, row, + // col, channel), with extents (batches, height, width, depth), with the + // trailing dimension changing most rapidly (channels has the smallest stride, + // typically 1 element). + // + // In generated C code, we store arrays with the dimensions reversed. The + // first dimension has smallest stride. + // + // We name our variables by their Tensorflow convention, but generate C code + // nesting loops such that the innermost loop has the smallest stride for the + // best cache behavior. 
+ for (int b = 0; b < ArraySize(output_dims, 3); ++b) { + for (int y = 0; y < ArraySize(output_dims, 2); ++y) { + for (int x = 0; x < ArraySize(output_dims, 1); ++x) { + for (int c = 0; c < ArraySize(output_dims, 0); ++c) { + const int32 input1_val = + input1_offset + input1_data[SubscriptToIndex(desc1, c, x, y, b)]; + const int32 input2_val = + input2_offset + input2_data[SubscriptToIndex(desc2, c, x, y, b)]; + const int32 unclamped_result = + output_offset + + MultiplyByQuantizedMultiplierSmallerThanOne( + input1_val * input2_val, output_multiplier, output_shift); + const int32 clamped_output = + std::min(output_activation_max, + std::max(output_activation_min, unclamped_result)); + output_data[Offset(output_dims, c, x, y, b)] = + static_cast(clamped_output); + } + } + } + } +} + +// legacy, for compatibility with old checked-in code +template +inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims, + int32 input1_offset, const uint8* input2_data, + const Dims<4>& input2_dims, int32 input2_offset, + int32 output_offset, int32 output_multiplier, + int output_shift, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + BroadcastMul(input1_data, input1_dims, input1_offset, input2_data, + input2_dims, input2_offset, output_offset, output_multiplier, + output_shift, output_activation_min, output_activation_max, + output_data, output_dims); +} + +template +void Concatenation(int concat_dim, const Scalar* const* input_data, + const Dims<4>* const* input_dims, int inputs_count, + Scalar* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("Concatenation"); + int concat_size = 0; + for (int i = 0; i < inputs_count; i++) { + for (int j = 0; j < 4; j++) { + if (j != concat_dim) { + MatchingArraySize(*input_dims[i], j, output_dims, j); + } + } + concat_size += ArraySize(*input_dims[i], concat_dim); + } + TFLITE_DCHECK_EQ(concat_size, ArraySize(output_dims, concat_dim)); + TFLITE_DCHECK(IsPackedWithoutStrides(output_dims)); + // for now we dont have a model with a Concatenation + // with fused activation function. 
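The copy phase of Concatenation appends, for every outer index (the product of all dimensions above concat_dim), one contiguous run per input to the output, in input order. A minimal sketch for the two-input, innermost-axis case (the names and std::vector storage are illustrative, not how the kernel holds its data):

    #include <cstring>
    #include <vector>

    // Illustrative only: concatenating two packed tensors along their innermost
    // axis. For every outer index, one contiguous run from each input is copied
    // into the output in input order, mirroring the memcpy interleaving the
    // Concatenation kernel performs for an arbitrary concat_dim.
    template <typename Scalar>
    std::vector<Scalar> ConcatInnermost(const std::vector<Scalar>& a, int a_depth,
                                        const std::vector<Scalar>& b, int b_depth,
                                        int outer_size) {
      std::vector<Scalar> out(static_cast<std::size_t>(outer_size) *
                              (a_depth + b_depth));
      Scalar* dst = out.data();
      for (int k = 0; k < outer_size; ++k) {
        std::memcpy(dst, a.data() + k * a_depth, a_depth * sizeof(Scalar));
        dst += a_depth;
        std::memcpy(dst, b.data() + k * b_depth, b_depth * sizeof(Scalar));
        dst += b_depth;
      }
      return out;
    }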
+ TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone); + int outer_size = 1; + for (int i = concat_dim + 1; i < 4; i++) { + outer_size *= output_dims.sizes[i]; + } + Scalar* output_ptr = output_data; + for (int k = 0; k < outer_size; k++) { + for (int i = 0; i < inputs_count; ++i) { + const int copy_size = + input_dims[i]->sizes[concat_dim] * input_dims[i]->strides[concat_dim]; + memcpy(output_ptr, input_data[i] + k * copy_size, + copy_size * sizeof(Scalar)); + output_ptr += copy_size; + } + } +} + +template +void DepthConcatenation(const Scalar* const* input_data, + const Dims<4>* const* input_dims, int inputs_count, + Scalar* output_data, const Dims<4>& output_dims) { + Concatenation(0, input_data, input_dims, inputs_count, + output_data, output_dims); +} + +inline void LstmCell(const float* input_data, const Dims<4>& input_dims, + const float* prev_activ_data, + const Dims<4>& prev_activ_dims, const float* weights_data, + const Dims<4>& weights_dims, const float* bias_data, + const Dims<4>& bias_dims, const float* prev_state_data, + const Dims<4>& prev_state_dims, float* output_state_data, + const Dims<4>& output_state_dims, float* output_activ_data, + const Dims<4>& output_activ_dims, float* concat_temp_data, + const Dims<4>& concat_temp_dims, float* activ_temp_data, + const Dims<4>& activ_temp_dims) { + gemmlowp::ScopedProfilingLabel label("LstmCell"); + MatchingArraySize( // batches + input_dims, 3, prev_activ_dims, 3, prev_state_dims, 3, output_state_dims, + 3, output_activ_dims, 3); + MatchingArraySize( // height + input_dims, 2, prev_activ_dims, 2, prev_state_dims, 2, output_state_dims, + 2, output_activ_dims, 2); + MatchingArraySize( // width + input_dims, 1, prev_activ_dims, 1, prev_state_dims, 1, output_state_dims, + 1, output_activ_dims, 1); + TFLITE_CHECK_EQ(ArraySize(weights_dims, 2), 1); + TFLITE_CHECK_EQ(ArraySize(weights_dims, 3), 1); + const int input_depth = ArraySize(input_dims, 0); + const int prev_activ_depth = ArraySize(prev_activ_dims, 0); + const int total_input_depth = prev_activ_depth + input_depth; + TFLITE_CHECK_EQ(ArraySize(weights_dims, 0), total_input_depth); + TFLITE_CHECK_EQ(MatchingArraySize(bias_dims, 1, bias_dims, 2, bias_dims, 3), + 1); + const int intern_activ_depth = + MatchingArraySize(weights_dims, 1, bias_dims, 0); + TFLITE_CHECK_EQ(intern_activ_depth % 4, 0); + const int output_depth = + MatchingArraySize(prev_state_dims, 0, prev_activ_dims, 0, + output_state_dims, 0, output_activ_dims, 0); + TFLITE_CHECK_EQ(output_depth, intern_activ_depth / 4); + + // Concatenate prev_activ and input data together + std::vector concat_input_arrays_data; + std::vector const*> concat_input_arrays_dims; + concat_input_arrays_data.push_back(input_data); + concat_input_arrays_data.push_back(prev_activ_data); + concat_input_arrays_dims.push_back(&input_dims); + concat_input_arrays_dims.push_back(&prev_activ_dims); + Concatenation( + 0, &(concat_input_arrays_data[0]), &(concat_input_arrays_dims[0]), + concat_input_arrays_data.size(), concat_temp_data, concat_temp_dims); + + // Fully connected + FullyConnected( + concat_temp_data, concat_temp_dims, weights_data, weights_dims, bias_data, + bias_dims, activ_temp_data, activ_temp_dims); + + // Map raw arrays to Eigen arrays so we can use Eigen's optimized array + // operations. 
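Per element, the computation that follows is the standard LSTM update: the fully connected output is split into four equal blocks (input gate i, new input g, forget gate f, output gate o), the new cell state is sigmoid(i) * tanh(g) + sigmoid(f) * prev_state, and the emitted activation is sigmoid(o) * tanh(new_state). A scalar sketch of just that arithmetic (illustrative; the kernel applies it with the Eigen maps set up next):

    #include <cmath>

    // Illustrative scalar version of the gate math expressed below with Eigen
    // block operations. The fully connected output is split into four equal
    // blocks: input gate i, new input g, forget gate f, output gate o.
    inline float Sigmoid(float x) { return 1.0f / (1.0f + std::exp(-x)); }

    inline void LstmCellScalar(float i, float g, float f, float o,
                               float prev_state, float* new_state,
                               float* new_activ) {
      *new_state = Sigmoid(i) * std::tanh(g) + Sigmoid(f) * prev_state;
      *new_activ = Sigmoid(o) * std::tanh(*new_state);
    }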
+ ArrayMap activ_temp_map = + MapAsArrayWithFirstDimAsRows(activ_temp_data, activ_temp_dims); + auto input_gate_sm = activ_temp_map.block(0 * output_depth, 0, output_depth, + activ_temp_map.cols()); + auto new_input_sm = activ_temp_map.block(1 * output_depth, 0, output_depth, + activ_temp_map.cols()); + auto forget_gate_sm = activ_temp_map.block(2 * output_depth, 0, output_depth, + activ_temp_map.cols()); + auto output_gate_sm = activ_temp_map.block(3 * output_depth, 0, output_depth, + activ_temp_map.cols()); + ArrayMap prev_state_map = + MapAsArrayWithFirstDimAsRows(prev_state_data, prev_state_dims); + ArrayMap output_state_map = + MapAsArrayWithFirstDimAsRows(output_state_data, output_state_dims); + ArrayMap output_activ_map = + MapAsArrayWithFirstDimAsRows(output_activ_data, output_activ_dims); + + // Combined memory state and final output calculation + gemmlowp::ScopedProfilingLabel label2("MemoryStateAndFinalOutput"); + output_state_map = + input_gate_sm.unaryExpr(Eigen::internal::scalar_sigmoid_op()) * + new_input_sm.tanh() + + forget_gate_sm.unaryExpr(Eigen::internal::scalar_sigmoid_op()) * + prev_state_map; + output_activ_map = + output_gate_sm.unaryExpr(Eigen::internal::scalar_sigmoid_op()) * + output_state_map.tanh(); +} + +template +void TensorFlowSplit(const Scalar* input_data, const Dims<4>& input_dims, + int outputs_count, Scalar* const* output_data, + const Dims<4>* const* output_dims) { + gemmlowp::ScopedProfilingLabel label("TensorFlowSplit"); + TFLITE_DCHECK_GE(outputs_count, 1); + for (int i = 0; i < outputs_count; i++) { + /* batches = */ MatchingArraySize(*output_dims[i], 3, input_dims, 3); + /* height = */ MatchingArraySize(*output_dims[i], 2, input_dims, 2); + /* width = */ MatchingArraySize(*output_dims[i], 1, input_dims, 1); + } + const int batches = MatchingArraySize(*output_dims[0], 3, input_dims, 3); + const int height = MatchingArraySize(*output_dims[0], 2, input_dims, 2); + const int width = MatchingArraySize(*output_dims[0], 1, input_dims, 1); + TFLITE_DCHECK(IsPackedWithoutStrides(input_dims)); + // for now we dont have a model with a TensorFlowSplit + // with fused activation function. + TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone); + const int whb = width * height * batches; + const Scalar* input_ptr = input_data; + for (int k = 0; k < whb; k++) { + for (int i = 0; i < outputs_count; ++i) { + memcpy(output_data[i] + k * output_dims[i]->sizes[0], input_ptr, + output_dims[i]->sizes[0] * sizeof(Scalar)); + input_ptr += output_dims[i]->sizes[0]; + } + } +} + +inline int NodeOffset(int b, int h, int w, int height, int width) { + return (b * height + h) * width + w; +} + +inline void AveragePool(const float* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int kwidth, int kheight, + float output_activation_min, + float output_activation_max, float* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("AveragePool"); + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int input_height = ArraySize(input_dims, 2); + const int input_width = ArraySize(input_dims, 1); + const int output_height = ArraySize(output_dims, 2); + const int output_width = ArraySize(output_dims, 1); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + + // TODO(benoitjacob) make this a proper reference impl without Eigen! 
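This float AveragePool is written input-centrically: each input pixel is visited once and added to every output cell whose pooling window contains it, a per-cell count is kept, and the accumulated sums are divided by the counts at the end. The window-membership bounds come from inverting ph * stride - pad <= h < ph * stride - pad + kernel; a compact sketch of just that bound computation (the name is illustrative):

    #include <algorithm>

    // Illustrative only: for input row h, the half-open range [start, end) of
    // output rows whose pooling window (of size kernel, with the given stride
    // and pad) contains h. Derived by inverting
    //   ph * stride - pad <= h < ph * stride - pad + kernel.
    inline void OutputRowsCoveringInput(int h, int pad, int kernel, int stride,
                                        int output_size, int* start, int* end) {
      const int hpad = h + pad;
      *start = (hpad < kernel) ? 0 : (hpad - kernel) / stride + 1;
      *end = std::min(hpad / stride + 1, output_size);
    }

The quantized AveragePool further down takes the opposite, output-centric approach: it gathers each window into a small uint16 accumulator buffer and divides with rounding.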
+ const auto in_mat = MapAsMatrixWithFirstDimAsRows(input_data, input_dims); + auto out_mat = MapAsMatrixWithFirstDimAsRows(output_data, output_dims); + // TODO(benoitjacob) get rid of the dynamic memory allocation here! + Eigen::VectorXf out_count(out_mat.cols()); + out_count.setZero(); + // Prefill the output to 0. + out_mat.setZero(); + for (int b = 0; b < batches; ++b) { + for (int h = 0; h < input_height; ++h) { + for (int w = 0; w < input_width; ++w) { + // (h_start, h_end) * (w_start, w_end) is the range that the input + // vector projects to. + int hpad = h + pad_height; + int wpad = w + pad_width; + int h_start = + (hpad < kheight) ? 0 : (hpad - kheight) / stride_height + 1; + int h_end = std::min(hpad / stride_height + 1, output_height); + int w_start = (wpad < kwidth) ? 0 : (wpad - kwidth) / stride_width + 1; + int w_end = std::min(wpad / stride_width + 1, output_width); + // compute elementwise sum + for (int ph = h_start; ph < h_end; ++ph) { + for (int pw = w_start; pw < w_end; ++pw) { + int out_offset = NodeOffset(b, ph, pw, output_height, output_width); + out_mat.col(out_offset) += + in_mat.col(NodeOffset(b, h, w, input_height, input_width)); + out_count(out_offset)++; + } + } + } + } + } + // Divide the output by the actual number of elements being averaged over + TFLITE_DCHECK_GT(out_count.minCoeff(), 0); + out_mat.array().rowwise() /= out_count.transpose().array(); + + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < output_height; ++y) { + for (int x = 0; x < output_width; ++x) { + for (int c = 0; c < depth; ++c) { + output_data[Offset(output_dims, c, x, y, b)] = + ActivationFunctionWithMinMax( + output_data[Offset(output_dims, c, x, y, b)], + output_activation_min, output_activation_max); + } + } + } + } +} + +// legacy, for compatibility with old checked-in code +template +void AveragePool(const float* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int kwidth, int kheight, float* output_data, + const Dims<4>& output_dims) { + float output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + + AveragePool(input_data, input_dims, stride_width, stride_height, pad_width, + pad_height, kwidth, kheight, output_activation_min, + output_activation_max, output_data, output_dims); +} + +// legacy, for compatibility with old checked-in code +template +void AveragePool(const float* input_data, const Dims<4>& input_dims, int stride, + int pad_width, int pad_height, int filter_width, + int filter_height, float* output_data, + const Dims<4>& output_dims) { + AveragePool(input_data, input_dims, stride, stride, pad_width, pad_height, + filter_width, filter_height, output_data, output_dims); +} + +inline void AveragePool(const uint8* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int filter_width, int filter_height, + int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("AveragePool/8bit"); + TFLITE_DCHECK_LE(output_activation_min, output_activation_max); + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + const int input_height = ArraySize(input_dims, 2); + const int input_width = ArraySize(input_dims, 1); + const int output_height = ArraySize(output_dims, 2); + const int output_width = 
ArraySize(output_dims, 1); + for (int batch = 0; batch < batches; ++batch) { + for (int out_y = 0; out_y < output_height; ++out_y) { + for (int out_x = 0; out_x < output_width; ++out_x) { + const int in_x_origin = (out_x * stride_width) - pad_width; + const int in_y_origin = (out_y * stride_height) - pad_height; + const int filter_x_start = std::max(0, -in_x_origin); + const int filter_x_end = + std::min(filter_width, input_width - in_x_origin); + const int filter_y_start = std::max(0, -in_y_origin); + const int filter_y_end = + std::min(filter_height, input_height - in_y_origin); + const int filter_count = + (filter_x_end - filter_x_start) * (filter_y_end - filter_y_start); + // 1280 required by Inception v3 + static constexpr int kAccBufferMaxSize = 2048; + TFLITE_DCHECK_LE(depth, kAccBufferMaxSize); + uint16 acc[kAccBufferMaxSize]; + memset(acc, 0, depth * sizeof(acc[0])); + const uint8* input_ptr = + input_data + input_dims.strides[1] * in_x_origin + + input_dims.strides[2] * in_y_origin + input_dims.strides[3] * batch; + for (int fy = filter_y_start; fy < filter_y_end; fy++) { + const uint8* input_row_ptr = input_ptr + fy * input_dims.strides[2] + + filter_x_start * input_dims.strides[1]; + for (int fx = filter_x_start; fx < filter_x_end; fx++) { + int channel = 0; +#ifdef USE_NEON + for (; channel <= depth - 16; channel += 16) { + uint16x8_t acc_reg[2]; + for (int i = 0; i < 2; i++) { + acc_reg[i] = vld1q_u16(acc + channel + 8 * i); + } + uint8x16_t input_reg = vld1q_u8(input_row_ptr); + input_row_ptr += 16; + acc_reg[0] = vaddw_u8(acc_reg[0], vget_low_u8(input_reg)); + acc_reg[1] = vaddw_u8(acc_reg[1], vget_high_u8(input_reg)); + for (int i = 0; i < 2; i++) { + vst1q_u16(acc + channel + 8 * i, acc_reg[i]); + } + } + for (; channel <= depth - 8; channel += 8) { + uint16x8_t acc_reg = vld1q_u16(acc + channel); + uint8x8_t input_reg = vld1_u8(input_row_ptr); + input_row_ptr += 8; + acc_reg = vaddw_u8(acc_reg, input_reg); + vst1q_u16(acc + channel, acc_reg); + } +#endif + for (; channel < depth; ++channel) { + acc[channel] += *input_row_ptr++; + } + } + } + uint8* output_ptr = + output_data + Offset(output_dims, 0, out_x, out_y, batch); + int channel = 0; +#ifdef USE_NEON +#define AVGPOOL_DIVIDING_BY(FILTER_COUNT) \ + if (filter_count == FILTER_COUNT) { \ + for (; channel <= depth - 8; channel += 8) { \ + uint16 buf[8]; \ + for (int i = 0; i < 8; i++) { \ + buf[i] = (acc[channel + i] + FILTER_COUNT / 2) / FILTER_COUNT; \ + } \ + uint8x8_t buf8 = vqmovn_u16(vld1q_u16(buf)); \ + buf8 = vmin_u8(buf8, vdup_n_u8(output_activation_max)); \ + buf8 = vmax_u8(buf8, vdup_n_u8(output_activation_min)); \ + vst1_u8(output_ptr + channel, buf8); \ + } \ + } + AVGPOOL_DIVIDING_BY(9) + AVGPOOL_DIVIDING_BY(15) +#undef AVGPOOL_DIVIDING_BY + for (; channel <= depth - 8; channel += 8) { + uint16 buf[8]; + for (int i = 0; i < 8; i++) { + buf[i] = (acc[channel + i] + filter_count / 2) / filter_count; + } + uint8x8_t buf8 = vqmovn_u16(vld1q_u16(buf)); + buf8 = vmin_u8(buf8, vdup_n_u8(output_activation_max)); + buf8 = vmax_u8(buf8, vdup_n_u8(output_activation_min)); + vst1_u8(output_ptr + channel, buf8); + } +#endif + for (; channel < depth; ++channel) { + uint16 a = (acc[channel] + filter_count / 2) / filter_count; + a = std::max(a, output_activation_min); + a = std::min(a, output_activation_max); + output_ptr[channel] = static_cast(a); + } + } + } + } +} + +// legacy, for compatibility with old checked-in code +template +void AveragePool(const uint8* input_data, const Dims<4>& input_dims, + int stride_width, 
int stride_height, int pad_width, + int pad_height, int filter_width, int filter_height, + int32 output_activation_min, int32 output_activation_max, + uint8* output_data, const Dims<4>& output_dims) { + static_assert(Ac == FusedActivationFunctionType::kNone || + Ac == FusedActivationFunctionType::kRelu || + Ac == FusedActivationFunctionType::kRelu6 || + Ac == FusedActivationFunctionType::kRelu1, + ""); + if (Ac == FusedActivationFunctionType::kNone) { + TFLITE_DCHECK_EQ(output_activation_min, 0); + TFLITE_DCHECK_EQ(output_activation_max, 255); + } + AveragePool(input_data, input_dims, stride_width, stride_height, pad_width, + pad_height, filter_width, filter_height, output_activation_min, + output_activation_max, output_data, output_dims); +} + +// legacy, for compatibility with old checked-in code +template +void AveragePool(const uint8* input_data, const Dims<4>& input_dims, int stride, + int pad_width, int pad_height, int filter_width, + int filter_height, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + AveragePool(input_data, input_dims, stride, stride, pad_width, pad_height, + filter_width, filter_height, output_activation_min, + output_activation_max, output_data, output_dims); +} + +inline void MaxPool(const float* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int kwidth, int kheight, + float output_activation_min, float output_activation_max, + float* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("MaxPool"); + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int input_height = ArraySize(input_dims, 2); + const int input_width = ArraySize(input_dims, 1); + const int output_height = ArraySize(output_dims, 2); + const int output_width = ArraySize(output_dims, 1); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + + const auto in_mat = MapAsMatrixWithFirstDimAsRows(input_data, input_dims); + auto out_mat = MapAsMatrixWithFirstDimAsRows(output_data, output_dims); + // Prefill the output to minimum representable float value + out_mat.setConstant(std::numeric_limits::lowest()); + for (int b = 0; b < batches; ++b) { + for (int h = 0; h < input_height; ++h) { + for (int w = 0; w < input_width; ++w) { + // (h_start, h_end) * (w_start, w_end) is the range that the input + // vector projects to. + int hpad = h + pad_height; + int wpad = w + pad_width; + int h_start = + (hpad < kheight) ? 0 : (hpad - kheight) / stride_height + 1; + int h_end = std::min(hpad / stride_height + 1, output_height); + int w_start = (wpad < kwidth) ? 
0 : (wpad - kwidth) / stride_width + 1; + int w_end = std::min(wpad / stride_width + 1, output_width); + // compute elementwise sum + for (int ph = h_start; ph < h_end; ++ph) { + for (int pw = w_start; pw < w_end; ++pw) { + int out_offset = NodeOffset(b, ph, pw, output_height, output_width); + out_mat.col(out_offset) = + out_mat.col(out_offset) + .cwiseMax(in_mat.col( + NodeOffset(b, h, w, input_height, input_width))); + } + } + } + } + } + + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < output_height; ++y) { + for (int x = 0; x < output_width; ++x) { + for (int c = 0; c < depth; ++c) { + output_data[Offset(output_dims, c, x, y, b)] = + ActivationFunctionWithMinMax( + output_data[Offset(output_dims, c, x, y, b)], + output_activation_min, output_activation_max); + } + } + } + } +} + +// legacy, for compatibility with old checked-in code +template +void MaxPool(const float* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, int pad_height, + int kwidth, int kheight, float* output_data, + const Dims<4>& output_dims) { + float output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + MaxPool(input_data, input_dims, stride_width, stride_height, pad_width, + pad_height, kwidth, kheight, output_activation_min, + output_activation_max, output_data, output_dims); +} + +// legacy, for compatibility with old checked-in code +template +void MaxPool(const float* input_data, const Dims<4>& input_dims, int stride, + int pad_width, int pad_height, int filter_width, int filter_height, + float* output_data, const Dims<4>& output_dims) { + MaxPool(input_data, input_dims, stride, stride, pad_width, pad_height, + filter_width, filter_height, output_data, output_dims); +} + +inline void MaxPool(const uint8* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int filter_width, int filter_height, + int32 output_activation_min, int32 output_activation_max, + uint8* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("MaxPool/8bit"); + TFLITE_DCHECK_LE(output_activation_min, output_activation_max); + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + const int input_height = ArraySize(input_dims, 2); + const int input_width = ArraySize(input_dims, 1); + const int output_height = ArraySize(output_dims, 2); + const int output_width = ArraySize(output_dims, 1); + for (int batch = 0; batch < batches; ++batch) { + for (int out_y = 0; out_y < output_height; ++out_y) { + for (int out_x = 0; out_x < output_width; ++out_x) { + const int in_x_origin = (out_x * stride_width) - pad_width; + const int in_y_origin = (out_y * stride_height) - pad_height; + const int filter_x_start = std::max(0, -in_x_origin); + const int filter_x_end = + std::min(filter_width, input_width - in_x_origin); + const int filter_y_start = std::max(0, -in_y_origin); + const int filter_y_end = + std::min(filter_height, input_height - in_y_origin); + // 2048 required by Inception v3 + static constexpr int kAccBufferMaxSize = 2048; + TFLITE_DCHECK_LE(depth, kAccBufferMaxSize); + uint8 acc[kAccBufferMaxSize]; + memset(acc, 0, depth * sizeof(acc[0])); + const uint8* input_ptr = + input_data + input_dims.strides[1] * in_x_origin + + input_dims.strides[2] * in_y_origin + input_dims.strides[3] * batch; + for (int fy = filter_y_start; fy < filter_y_end; fy++) { + const 
uint8* input_row_ptr = input_ptr + fy * input_dims.strides[2] + + filter_x_start * input_dims.strides[1]; + for (int fx = filter_x_start; fx < filter_x_end; fx++) { + int channel = 0; +#ifdef USE_NEON + for (; channel <= depth - 16; channel += 16) { + uint8x16_t acc_reg = vld1q_u8(acc + channel); + uint8x16_t input_reg = vld1q_u8(input_row_ptr); + input_row_ptr += 16; + acc_reg = vmaxq_u8(acc_reg, input_reg); + vst1q_u8(acc + channel, acc_reg); + } + + for (; channel <= depth - 8; channel += 8) { + uint8x8_t acc_reg = vld1_u8(acc + channel); + uint8x8_t input_reg = vld1_u8(input_row_ptr); + input_row_ptr += 8; + acc_reg = vmax_u8(acc_reg, input_reg); + vst1_u8(acc + channel, acc_reg); + } +#endif + for (; channel < depth; ++channel) { + acc[channel] = std::max(acc[channel], *input_row_ptr++); + } + } + } + uint8* output_ptr = + output_data + Offset(output_dims, 0, out_x, out_y, batch); + int channel = 0; +#ifdef USE_NEON + for (; channel <= depth - 16; channel += 16) { + uint8x16_t a = vld1q_u8(acc + channel); + a = vminq_u8(a, vdupq_n_u8(output_activation_max)); + a = vmaxq_u8(a, vdupq_n_u8(output_activation_min)); + vst1q_u8(output_ptr + channel, a); + } + for (; channel <= depth - 8; channel += 8) { + uint8x8_t a = vld1_u8(acc + channel); + a = vmin_u8(a, vdup_n_u8(output_activation_max)); + a = vmax_u8(a, vdup_n_u8(output_activation_min)); + vst1_u8(output_ptr + channel, a); + } +#endif + for (; channel < depth; ++channel) { + uint8 a = acc[channel]; + a = std::max(a, output_activation_min); + a = std::min(a, output_activation_max); + output_ptr[channel] = static_cast(a); + } + } + } + } +} + +// legacy, for compatibility with old checked-in code +template +void MaxPool(const uint8* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, int pad_height, + int filter_width, int filter_height, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + static_assert(Ac == FusedActivationFunctionType::kNone || + Ac == FusedActivationFunctionType::kRelu || + Ac == FusedActivationFunctionType::kRelu6 || + Ac == FusedActivationFunctionType::kRelu1, + ""); + if (Ac == FusedActivationFunctionType::kNone) { + TFLITE_DCHECK_EQ(output_activation_min, 0); + TFLITE_DCHECK_EQ(output_activation_max, 255); + } + MaxPool(input_data, input_dims, stride_width, stride_height, pad_width, + pad_height, filter_width, filter_height, output_activation_min, + output_activation_max, output_data, output_dims); +} + +// legacy, for compatibility with old checked-in code +template +void MaxPool(const uint8* input_data, const Dims<4>& input_dims, int stride, + int pad_width, int pad_height, int filter_width, int filter_height, + int32 output_activation_min, int32 output_activation_max, + uint8* output_data, const Dims<4>& output_dims) { + MaxPool(input_data, input_dims, stride, stride, pad_width, pad_height, + filter_width, filter_height, output_activation_min, + output_activation_max, output_data, output_dims); +} + +inline void L2Pool(const float* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int filter_width, int filter_height, + float output_activation_min, float output_activation_max, + float* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("L2Pool"); + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int input_height = ArraySize(input_dims, 2); + const int input_width = ArraySize(input_dims, 1); 
+ const int output_height = ArraySize(output_dims, 2); + const int output_width = ArraySize(output_dims, 1); + // Actually carry out L2 Pool. Code is written in forward mode: we go through + // the input values once, and write to all the pooled regions that it maps to. + const auto in_mat = MapAsMatrixWithFirstDimAsRows(input_data, input_dims); + auto out_mat = MapAsMatrixWithFirstDimAsRows(output_data, output_dims); + Eigen::VectorXf in_square(in_mat.rows()); + Eigen::VectorXf out_count(out_mat.cols()); + out_count.setZero(); + // Prefill the output to 0. + out_mat.setZero(); + for (int b = 0; b < batches; ++b) { + for (int h = 0; h < input_height; ++h) { + for (int w = 0; w < input_width; ++w) { + // (h_start, h_end) * (w_start, w_end) is the range that the input + // vector projects to. + const int hpad = h + pad_height; + const int wpad = w + pad_width; + const int h_start = (hpad < filter_height) + ? 0 + : (hpad - filter_height) / stride_height + 1; + const int h_end = std::min(hpad / stride_height + 1, output_height); + const int w_start = (wpad < filter_width) + ? 0 + : (wpad - filter_width) / stride_width + 1; + const int w_end = std::min(wpad / stride_width + 1, output_width); + // pre-compute square + const int in_offset = w + input_width * (h + input_height * b); + in_square = + in_mat.col(in_offset).array() * in_mat.col(in_offset).array(); + // compute elementwise sum of squares + for (int ph = h_start; ph < h_end; ++ph) { + for (int pw = w_start; pw < w_end; ++pw) { + const int out_offset = pw + output_width * (ph + output_height * b); + out_mat.col(out_offset) += in_square; + out_count(out_offset)++; + } + } + } + } + } + + out_count = out_count.array().inverse(); + out_mat = + (out_mat.array().rowwise() * out_count.transpose().array()).cwiseSqrt(); +} + +// legacy, for compatibility with old checked-in code +template +void L2Pool(const float* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, int pad_height, + int filter_width, int filter_height, float* output_data, + const Dims<4>& output_dims) { + float output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + L2Pool(input_data, input_dims, stride_width, stride_height, pad_width, + pad_height, filter_width, filter_height, output_activation_min, + output_activation_max, output_data, output_dims); +} + +// legacy, for compatibility with old checked-in code +template +void L2Pool(const float* input_data, const Dims<4>& input_dims, int stride, + int pad_width, int pad_height, int filter_width, int filter_height, + float* output_data, const Dims<4>& output_dims) { + L2Pool(input_data, input_dims, stride, stride, pad_width, pad_height, + filter_width, filter_height, output_data, output_dims); +} + +inline void LocalResponseNormalization(const float* input_data, + const Dims<4>& input_dims, int range, + float bias, float alpha, float beta, + float* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("LocalResponseNormalization"); + /* const int batches = */ MatchingArraySize(input_dims, 3, output_dims, 3); + /* const int height = */ MatchingArraySize(input_dims, 2, output_dims, 2); + /* const int width = */ MatchingArraySize(input_dims, 1, output_dims, 1); + /* const int depth = */ MatchingArraySize(input_dims, 0, output_dims, 0); + + const auto data_in = MapAsMatrixWithFirstDimAsRows(input_data, input_dims); + auto data_out = MapAsMatrixWithFirstDimAsRows(output_data, output_dims); + + // 
Carry out local response normalization, vector by vector. + // Since the data are stored column major, making row-wise operation + // probably not memory efficient anyway, we do an explicit for loop over + // the columns. + const int double_range = range * 2; + Eigen::VectorXf padded_square(data_in.rows() + double_range); + padded_square.setZero(); + for (int r = 0; r < data_in.cols(); ++r) { + // Do local response normalization for data_in(:, r) + // first, compute the square and store them in buffer for repeated use + padded_square.block(range, 0, data_in.rows(), 1) = + data_in.col(r).cwiseProduct(data_in.col(r)) * alpha; + // Then, compute the scale and writes them to data_out + float accumulated_scale = 0; + for (int i = 0; i < double_range; ++i) { + accumulated_scale += padded_square(i); + } + for (int i = 0; i < data_in.rows(); ++i) { + accumulated_scale += padded_square(i + double_range); + data_out(i, r) = bias + accumulated_scale; + accumulated_scale -= padded_square(i); + } + } + + // In a few cases, the pow computation could benefit from speedups. + if (beta == 1) { + data_out.array() = data_in.array() * data_out.array().inverse(); + } else if (beta == 0.5) { + data_out.array() = data_in.array() * data_out.array().sqrt().inverse(); + } else { + data_out.array() = data_in.array() * data_out.array().pow(-beta); + } +} + +inline void Softmax(const float* input_data, const Dims<4>& input_dims, + float beta, float* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("Softmax"); + /* const int batches = */ MatchingArraySize(input_dims, 3, output_dims, 3); + /* const int height = */ MatchingArraySize(input_dims, 2, output_dims, 2); + /* const int width = */ MatchingArraySize(input_dims, 1, output_dims, 1); + /* const int depth = */ MatchingArraySize(input_dims, 0, output_dims, 0); + + const auto in_mat = MapAsMatrixWithFirstDimAsRows(input_data, input_dims); + auto out_mat = MapAsMatrixWithFirstDimAsRows(output_data, output_dims); + // Compute the exponential first, removing the max coefficient for numerical + // stability. + out_mat = (in_mat.rowwise() - in_mat.colwise().maxCoeff()).array() * beta; + // We are separating out the exp function so that exp can be vectorized. + out_mat = out_mat.array().exp(); + // Normalize to get the activations. + Eigen::Array scale = + out_mat.array().colwise().sum().inverse(); + out_mat.array().rowwise() *= scale; +} + +inline void Softmax(const uint8* input_data, const Dims<4>& input_dims, + int32 input_beta_multiplier, int32 input_beta_left_shift, + int diff_min, uint8* output_data, + const Dims<4>& output_dims) { + // The representation chosen for the input to the exp() function is Q5.26. + // We need to leave extra space since values that we skip might be as large as + // -32 before multiplying by input_beta_multiplier, and therefore as large as + // -16 afterwards. Note that exp(-8) is definitely not insignificant to + // accumulation, but exp(-16) definitely is. 
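+  // (Here Q5.26 denotes a 32-bit fixed-point value with 5 integer bits and
+  // 26 fractional bits, i.e. a range of roughly [-32, 32) with a resolution
+  // of 2^-26; kScaledDiffIntegerBits below selects this format.)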
+ static const int kScaledDiffIntegerBits = 5; + static const int kAccumulationIntegerBits = 12; + using FixedPointScaledDiff = + gemmlowp::FixedPoint; + using FixedPointAccum = gemmlowp::FixedPoint; + using FixedPoint0 = gemmlowp::FixedPoint; + + gemmlowp::ScopedProfilingLabel label("Softmax"); + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = MatchingArraySize(input_dims, 2, output_dims, 2); + const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + + for (int b = 0; b < batches; ++b) { + for (int x = 0; x < width; ++x) { + for (int y = 0; y < height; ++y) { + uint8 max_in_row = 0; + for (int c = 0; c < depth; ++c) { + max_in_row = + std::max(max_in_row, input_data[Offset(input_dims, c, x, y, b)]); + } + + FixedPointAccum sum_of_exps = FixedPointAccum::Zero(); + for (int c = 0; c < depth; ++c) { + int32 input_diff = + static_cast(input_data[Offset(input_dims, c, x, y, b)]) - + max_in_row; + if (input_diff >= diff_min) { + const int32 input_diff_rescaled = + MultiplyByQuantizedMultiplierGreaterThanOne( + input_diff, input_beta_multiplier, input_beta_left_shift); + const FixedPointScaledDiff scaled_diff_f8 = + FixedPointScaledDiff::FromRaw(input_diff_rescaled); + sum_of_exps = + sum_of_exps + gemmlowp::Rescale( + exp_on_negative_values(scaled_diff_f8)); + } + } + + int32 fixed_sum_of_exps = sum_of_exps.raw(); + // TODO(starka): Use a NEON intrinsic like vclzq_u32 instead. + int headroom_plus_one = + __builtin_clz(static_cast(fixed_sum_of_exps)); + // This is the number of bits to the left of the binary point above 1.0. + // Consider fixed_sum_of_exps=1.25. In that case shifted_scale=0.8 and + // no later adjustment will be needed. 
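+        // Worked example: 1.25 in the Q12.19 accumulator has raw value
+        // 1.25 * 2^19 = 655360, so headroom_plus_one = __builtin_clz(655360)
+        // = 12 and num_bits_over_unit = 12 - 12 = 0. Shifting the raw value
+        // left by 12 gives 0xA0000000; subtracting 2^31 leaves 2^29, which is
+        // 0.25 in Q0.31, and 1 / (1 + 0.25) = 0.8 as stated above.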
+ int num_bits_over_unit = kAccumulationIntegerBits - headroom_plus_one; + int32 shifted_sum_minus_one = static_cast( + (static_cast(fixed_sum_of_exps) << headroom_plus_one) - + (static_cast(1) << 31)); + + FixedPoint0 shifted_scale = gemmlowp::one_over_one_plus_x_for_x_in_0_1( + FixedPoint0::FromRaw(shifted_sum_minus_one)); + + for (int c = 0; c < depth; ++c) { + int32 input_diff = + static_cast(input_data[Offset(input_dims, c, x, y, b)]) - + max_in_row; + if (input_diff >= diff_min) { + const int32 input_diff_rescaled = + MultiplyByQuantizedMultiplierGreaterThanOne( + input_diff, input_beta_multiplier, input_beta_left_shift); + const FixedPointScaledDiff scaled_diff_f8 = + FixedPointScaledDiff::FromRaw(input_diff_rescaled); + + FixedPoint0 exp_in_0 = exp_on_negative_values(scaled_diff_f8); + int32 unsat_output = gemmlowp::RoundingDivideByPOT( + (shifted_scale * exp_in_0).raw(), num_bits_over_unit + 31 - 8); + + output_data[Offset(output_dims, c, x, y, b)] = + std::max(std::min(unsat_output, 255), 0); + + } else { + output_data[Offset(output_dims, c, x, y, b)] = 0; + } + } + } + } + } +} + +inline void Logistic(const float* input_data, const Dims<4>& input_dims, + float* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("Logistic"); + auto input_map = MapAsVector(input_data, input_dims); + auto output_map = MapAsVector(output_data, output_dims); + output_map.array() = + input_map.array().unaryExpr(Eigen::internal::scalar_sigmoid_op()); +} + +inline void Logistic(const uint8* input_data, const Dims<4>& input_dims, + int32 input_zero_point, int32 input_range_radius, + int32 input_multiplier, int input_left_shift, + uint8* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("Logistic"); + /* batches */ MatchingArraySize(input_dims, 3, output_dims, 3); + /* height */ MatchingArraySize(input_dims, 2, output_dims, 2); + /* width */ MatchingArraySize(input_dims, 1, output_dims, 1); + /* depth */ MatchingArraySize(input_dims, 0, output_dims, 0); + const int size = RequiredBufferSizeForDims(input_dims); + + int c = 0; +#ifdef USE_NEON + // Handle 16 values at a time + for (; c <= size - 16; c += 16) { + // Read input uint8 values, cast to int16 and subtract input_zero_point + uint8x16_t input_val_u8 = vld1q_u8(input_data + c); + int16x8_t input_val_centered_0 = + vsubq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(input_val_u8))), + vdupq_n_s16(input_zero_point)); + int16x8_t input_val_centered_1 = + vsubq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(input_val_u8))), + vdupq_n_s16(input_zero_point)); + + // Prepare the bit masks that we will use at the end to implement the logic + // that was expressed in the scalar code with branching: + // if (input_val_centered < -input_range_radius) { + // output_val = 0; + // } else if (input_val_centered > input_range_radius) { + // output_val = 255; + // } else { + // ... 
+ uint16x8_t mask_rightclamp_0 = + vcgtq_s16(input_val_centered_0, vdupq_n_s16(input_range_radius)); + uint16x8_t mask_rightclamp_1 = + vcgtq_s16(input_val_centered_1, vdupq_n_s16(input_range_radius)); + uint16x8_t mask_leftclamp_0 = + vcgeq_s16(input_val_centered_0, vdupq_n_s16(-input_range_radius)); + uint16x8_t mask_leftclamp_1 = + vcgeq_s16(input_val_centered_1, vdupq_n_s16(-input_range_radius)); + uint8x16_t mask_rightclamp = vcombine_u8(vshrn_n_u16(mask_rightclamp_0, 8), + vshrn_n_u16(mask_rightclamp_1, 8)); + uint8x16_t mask_leftclamp = vcombine_u8(vshrn_n_u16(mask_leftclamp_0, 8), + vshrn_n_u16(mask_leftclamp_1, 8)); + + // This performs what is expressed in the scalar code as + // const int32 input_val_rescaled = + // MultiplyByQuantizedMultiplierGreaterThanOne( + // input_val_centered, input_multiplier, input_left_shift); + int32x4_t input_val_rescaled_0 = + vshlq_s32(vmovl_s16(vget_low_s16(input_val_centered_0)), + vdupq_n_s32(input_left_shift)); + int32x4_t input_val_rescaled_1 = + vshlq_s32(vmovl_s16(vget_high_s16(input_val_centered_0)), + vdupq_n_s32(input_left_shift)); + int32x4_t input_val_rescaled_2 = + vshlq_s32(vmovl_s16(vget_low_s16(input_val_centered_1)), + vdupq_n_s32(input_left_shift)); + int32x4_t input_val_rescaled_3 = + vshlq_s32(vmovl_s16(vget_high_s16(input_val_centered_1)), + vdupq_n_s32(input_left_shift)); + input_val_rescaled_0 = + vqrdmulhq_n_s32(input_val_rescaled_0, input_multiplier); + input_val_rescaled_1 = + vqrdmulhq_n_s32(input_val_rescaled_1, input_multiplier); + input_val_rescaled_2 = + vqrdmulhq_n_s32(input_val_rescaled_2, input_multiplier); + input_val_rescaled_3 = + vqrdmulhq_n_s32(input_val_rescaled_3, input_multiplier); + + // Invoke gemmlowp::logistic on FixedPoint wrapping int32x4_t + using FixedPoint4 = gemmlowp::FixedPoint; + using FixedPoint0 = gemmlowp::FixedPoint; + const FixedPoint4 input_val_f4_0 = + FixedPoint4::FromRaw(input_val_rescaled_0); + const FixedPoint4 input_val_f4_1 = + FixedPoint4::FromRaw(input_val_rescaled_1); + const FixedPoint4 input_val_f4_2 = + FixedPoint4::FromRaw(input_val_rescaled_2); + const FixedPoint4 input_val_f4_3 = + FixedPoint4::FromRaw(input_val_rescaled_3); + const FixedPoint0 output_val_f0_0 = gemmlowp::logistic(input_val_f4_0); + const FixedPoint0 output_val_f0_1 = gemmlowp::logistic(input_val_f4_1); + const FixedPoint0 output_val_f0_2 = gemmlowp::logistic(input_val_f4_2); + const FixedPoint0 output_val_f0_3 = gemmlowp::logistic(input_val_f4_3); + + // Divide by 2^23 as in the scalar code + using gemmlowp::RoundingDivideByPOT; + int32x4_t output_val_s32_0 = RoundingDivideByPOT(output_val_f0_0.raw(), 23); + int32x4_t output_val_s32_1 = RoundingDivideByPOT(output_val_f0_1.raw(), 23); + int32x4_t output_val_s32_2 = RoundingDivideByPOT(output_val_f0_2.raw(), 23); + int32x4_t output_val_s32_3 = RoundingDivideByPOT(output_val_f0_3.raw(), 23); + + // Cast output values to uint8, saturating + int16x8_t output_val_s16_0 = vcombine_s16(vqmovn_s32(output_val_s32_0), + vqmovn_s32(output_val_s32_1)); + int16x8_t output_val_s16_1 = vcombine_s16(vqmovn_s32(output_val_s32_2), + vqmovn_s32(output_val_s32_3)); + uint8x16_t output_val_u8 = vcombine_u8(vqmovun_s16(output_val_s16_0), + vqmovun_s16(output_val_s16_1)); + + // Perform the bit-masking with the bit masks computed at the beginning, + // see the comment there. 
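+    // mask_rightclamp has all bits set in lanes whose input exceeded
+    // +input_range_radius, so OR-ing it in forces those outputs to 255;
+    // mask_leftclamp has all bits set only in lanes whose input was
+    // >= -input_range_radius, so AND-ing zeroes the outputs below that range.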
+ output_val_u8 = vorrq_u8(output_val_u8, mask_rightclamp); + output_val_u8 = vandq_u8(output_val_u8, mask_leftclamp); + + // Store back to memory + vst1q_u8(output_data + c, output_val_u8); + } +#endif + // Leftover loop: handle one value at a time with scalar code. + for (; c < size; ++c) { + const uint8 input_val_u8 = input_data[c]; + const int32 input_val_centered = + static_cast(input_val_u8) - input_zero_point; + uint8 output_val; + if (input_val_centered < -input_range_radius) { + output_val = 0; + } else if (input_val_centered > input_range_radius) { + output_val = 255; + } else { + const int32 input_val_rescaled = + MultiplyByQuantizedMultiplierGreaterThanOne( + input_val_centered, input_multiplier, input_left_shift); + using FixedPoint4 = gemmlowp::FixedPoint; + using FixedPoint0 = gemmlowp::FixedPoint; + const FixedPoint4 input_val_f4 = FixedPoint4::FromRaw(input_val_rescaled); + const FixedPoint0 output_val_f0 = gemmlowp::logistic(input_val_f4); + using gemmlowp::RoundingDivideByPOT; + int32 output_val_s32 = RoundingDivideByPOT(output_val_f0.raw(), 23); + if (output_val_s32 == 256) { + output_val_s32 = 255; + } + TFLITE_DCHECK_GE(output_val_s32, 0); + TFLITE_DCHECK_LE(output_val_s32, 255); + output_val = static_cast(output_val_s32); + } + output_data[c] = output_val; + } +} + +inline void Tanh(const float* input_data, const Dims<4>& input_dims, + float* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("Tanh"); + auto input_map = MapAsVector(input_data, input_dims); + auto output_map = MapAsVector(output_data, output_dims); + output_map.array() = input_map.array().tanh(); +} + +inline void Dequantize(const uint8* input_data, const Dims<4>& input_dims, + int32 zero_point, double scale, float* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("Dequantize"); + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = MatchingArraySize(input_dims, 2, output_dims, 2); + const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + for (int c = 0; c < depth; ++c) { + int32 val = input_data[Offset(input_dims, c, x, y, b)]; + float result = static_cast(scale * (val - zero_point)); + output_data[Offset(output_dims, c, x, y, b)] = result; + } + } + } + } +} + +inline void FakeQuant(const float* input_data, const Dims<4>& input_dims, + float rmin, float rmax, float* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("FakeQuant"); + + // 0 should always be a representable value. Let's assume that the initial + // min,max range contains 0. + TFLITE_DCHECK_LE(rmin, 0.); + TFLITE_DCHECK_GE(rmax, 0.); + + // Determine quantization parameters: zero_point, scale. + using Integer = uint8; + const Integer qmin = std::numeric_limits::min(); + const Integer qmax = std::numeric_limits::max(); + const float qmin_float = qmin; + const float qmax_float = qmax; + int32 zero_point = 0; + float scale = 0.f; + // If rmin==rmax, both must be zero per the above assertion, + // so we are done. + if (rmin != rmax) { + // First determine the scale. + scale = (rmax - rmin) / (qmax_float - qmin_float); + + // Zero-point computation. + // First the initial floating-point computation. 
The zero-point can be + // determined from solving an affine equation for any known pair + // (real value, corresponding quantized value). + // We know two such pairs: (rmin, qmin) and (rmax, qmax). + // The arithmetic error on the zero point computed from either pair + // will be roughly machine_epsilon * (sum of absolute values of terms) + // so we want to use the variant that adds the smaller terms. + const float zero_point_from_min = qmin_float - rmin / scale; + const float zero_point_from_max = qmax_float - rmax / scale; + const float zero_point_from_min_error = + std::abs(qmin_float) + std::abs(rmin / scale); + const float zero_point_from_max_error = + std::abs(qmax_float) + std::abs(rmax / scale); + + const float zero_point_float = + zero_point_from_min_error < zero_point_from_max_error + ? zero_point_from_min + : zero_point_from_max; + + // Now we need to nudge the zero point to be an integer + // (our zero points are integer, and this is motivated by the requirement + // to be able to represent the real value "0" exactly as a quantized value, + // which is required in multiple places, for example in Im2col with SAME + // padding). + if (zero_point_float < qmin_float) { + zero_point = qmin; + } else if (zero_point_float > qmax_float) { + zero_point = qmax; + } else { + zero_point = static_cast(TfLiteRound(zero_point_float)); + } + // The zero point should always be in the range of quantized value, + // [qmin, qmax]. + TFLITE_DCHECK_GE(zero_point, qmin); + TFLITE_DCHECK_LE(zero_point, qmax); + } + + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = MatchingArraySize(input_dims, 2, output_dims, 2); + const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + for (int c = 0; c < depth; ++c) { + const float src_val = input_data[Offset(input_dims, c, x, y, b)]; + const float unclamped_quantized_val = + TfLiteRound(zero_point + src_val / scale); + const float quantized_val = std::min( + qmax_float, std::max(qmin_float, unclamped_quantized_val)); + const float dst_val = scale * (quantized_val - zero_point); + output_data[Offset(output_dims, c, x, y, b)] = dst_val; + } + } + } + } +} + +template +inline void Cast(const SrcT* input_data, const Dims<4>& input_dims, + DstT* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("Cast"); + auto input_map = MapAsVector(input_data, input_dims); + auto output_map = MapAsVector(output_data, output_dims); + output_map.array() = input_map.array().template cast(); +} + +inline void Floor(const float* input_data, const Dims<4>& input_dims, + float* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("Floor"); + auto input_map = MapAsVector(input_data, input_dims); + auto output_map = MapAsVector(output_data, output_dims); + output_map.array() = Eigen::floor(input_map.array()); +} + +template +inline void Gather(const T* input_data, const Dims<4>& input_dims, + int input_rank, const int32* coords_data, + const Dims<4>& coords_dims, T* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("Gather"); + + TFLITE_DCHECK(coords_dims.sizes[0] == output_dims.sizes[input_rank - 1]); + int stride = input_dims.strides[input_rank - 1]; + T* out = output_data; + + for (int i = 0; i < coords_dims.sizes[0]; i++) { + TFLITE_DCHECK_GE(coords_data[i], 0); 
+ TFLITE_DCHECK_LT(coords_data[i], input_dims.sizes[input_rank - 1]); + const T* in = input_data + coords_data[i] * stride; + memcpy(out, in, sizeof(T) * stride); + out += stride; + } +} + +#ifdef USE_NEON +inline void ResizeBilinearKernel(const float* input_ptr, int32 depth, + float scale, float* output_ptr) { + int ic = 0; + // Handle 32 input channels at a time. + for (; ic <= depth - 32; ic += 32) { + float32x4x2_t input[4]; + for (int i = 0; i < 4; i++) { + input[i].val[0] = vld1q_f32(input_ptr + 8 * i); + input[i].val[1] = vld1q_f32(input_ptr + 8 * i + 4); + } + float32x4x2_t acc[4]; + for (int i = 0; i < 4; i++) { + acc[i].val[0] = vld1q_f32(output_ptr + 8 * i); + acc[i].val[1] = vld1q_f32(output_ptr + 8 * i + 4); + } + for (int i = 0; i < 4; i++) { + acc[i].val[0] = vmlaq_n_f32(acc[i].val[0], input[i].val[0], scale); + acc[i].val[1] = vmlaq_n_f32(acc[i].val[1], input[i].val[1], scale); + } + for (int i = 0; i < 4; i++) { + vst1q_f32(output_ptr, acc[i].val[0]); + vst1q_f32(output_ptr + 4, acc[i].val[1]); + output_ptr += 8; + } + input_ptr += 32; + } + // Handle 16 input channels at a time. + for (; ic <= depth - 16; ic += 16) { + float32x4x2_t input[2]; + for (int i = 0; i < 2; i++) { + input[i].val[0] = vld1q_f32(input_ptr + 8 * i); + input[i].val[1] = vld1q_f32(input_ptr + 8 * i + 4); + } + float32x4x2_t acc[2]; + for (int i = 0; i < 2; i++) { + acc[i].val[0] = vld1q_f32(output_ptr + 8 * i); + acc[i].val[1] = vld1q_f32(output_ptr + 8 * i + 4); + } + for (int i = 0; i < 2; i++) { + acc[i].val[0] = vmlaq_n_f32(acc[i].val[0], input[i].val[0], scale); + acc[i].val[1] = vmlaq_n_f32(acc[i].val[1], input[i].val[1], scale); + } + for (int i = 0; i < 2; i++) { + vst1q_f32(output_ptr, acc[i].val[0]); + vst1q_f32(output_ptr + 4, acc[i].val[1]); + output_ptr += 8; + } + input_ptr += 16; + } + // Handle 8 input channels at a time. + for (; ic <= depth - 8; ic += 8) { + float32x4x2_t input; + input.val[0] = vld1q_f32(input_ptr); + input.val[1] = vld1q_f32(input_ptr + 4); + + float32x4x2_t acc; + acc.val[0] = vld1q_f32(output_ptr); + acc.val[1] = vld1q_f32(output_ptr + 4); + acc.val[0] = vmlaq_n_f32(acc.val[0], input.val[0], scale); + acc.val[1] = vmlaq_n_f32(acc.val[1], input.val[1], scale); + + vst1q_f32(output_ptr, acc.val[0]); + vst1q_f32(output_ptr + 4, acc.val[1]); + + input_ptr += 8; + output_ptr += 8; + } + // Handle 4 input channels at a time. + for (; ic <= depth - 4; ic += 4) { + float32x4_t input = vld1q_f32(input_ptr); + float32x4_t acc = vld1q_f32(output_ptr); + + acc = vmlaq_n_f32(acc, input, scale); + vst1q_f32(output_ptr, acc); + + input_ptr += 4; + output_ptr += 4; + } + // Handle 1 input channel at a time. 
+ for (; ic < depth; ic++) { + *output_ptr += *input_ptr * scale; + output_ptr++; + input_ptr++; + } +} +#else +inline void ResizeBilinearKernel(const float* input_ptr, int32 depth, + float scale, float* output_ptr) { + for (int32 i = 0; i < depth; i++) { + *output_ptr += *input_ptr * scale; + output_ptr++; + input_ptr++; + } +} +#endif + +inline void ResizeBilinearKernel2x2(int32 x0, int32 x1, int32 y0, int32 y1, + int32 x, int32 y, int32 depth, int32 batch, + const float* input_data, + const Dims<4>& input_dims, + float* output_data, + const Dims<4>& output_dims) { + const int32 input_width = ArraySize(input_dims, 1); + const int32 output_width = ArraySize(output_dims, 1); + + const int32 input_x_offset = (x1 - x0) * depth; + const int32 input_y_offset = (y1 - y0) * depth * input_width; + const int32 output_x_offset = depth; + const int32 output_y_offset = depth * output_width; + +#ifdef USE_NEON + TFLITE_DCHECK(IsPackedWithoutStrides(input_dims)); + TFLITE_DCHECK(x1 >= x0); + TFLITE_DCHECK(y1 >= y0); + + int ic = 0; + // Handle 8 input channels at a time. + for (; ic <= depth - 8; ic += 8) { + const float* input_ptr = nullptr; + + float32x4x2_t x0y0; + input_ptr = &input_data[Offset(input_dims, ic, x0, y0, batch)]; + x0y0.val[0] = vld1q_f32(input_ptr); + x0y0.val[1] = vld1q_f32(input_ptr + 4); + + float32x4x2_t x1y0; + input_ptr += input_x_offset; + x1y0.val[0] = vld1q_f32(input_ptr); + x1y0.val[1] = vld1q_f32(input_ptr + 4); + + float32x4x2_t x0y1; + input_ptr += -input_x_offset + input_y_offset; + x0y1.val[0] = vld1q_f32(input_ptr); + x0y1.val[1] = vld1q_f32(input_ptr + 4); + + float32x4x2_t x1y1; + input_ptr += input_x_offset; + x1y1.val[0] = vld1q_f32(input_ptr); + x1y1.val[1] = vld1q_f32(input_ptr + 4); + + // Top left corner. + float* output_ptr = &output_data[Offset(output_dims, ic, x, y, batch)]; + vst1q_f32(output_ptr, x0y0.val[0]); + vst1q_f32(output_ptr + 4, x0y0.val[1]); + + // Top right corner. + output_ptr += output_x_offset; + float32x4x2_t tr; + tr.val[0] = vaddq_f32(x0y0.val[0], x1y0.val[0]); + tr.val[1] = vaddq_f32(x0y0.val[1], x1y0.val[1]); + tr.val[0] = vmulq_n_f32(tr.val[0], 0.5f); + tr.val[1] = vmulq_n_f32(tr.val[1], 0.5f); + + vst1q_f32(output_ptr, tr.val[0]); + vst1q_f32(output_ptr + 4, tr.val[1]); + + // Bottom left corner. + output_ptr += -output_x_offset + output_y_offset; + float32x4x2_t bl; + bl.val[0] = vaddq_f32(x0y0.val[0], x0y1.val[0]); + bl.val[1] = vaddq_f32(x0y0.val[1], x0y1.val[1]); + bl.val[0] = vmulq_n_f32(bl.val[0], 0.5f); + bl.val[1] = vmulq_n_f32(bl.val[1], 0.5f); + vst1q_f32(output_ptr, bl.val[0]); + vst1q_f32(output_ptr + 4, bl.val[1]); + + // Bottom right corner. + output_ptr += output_x_offset; + float32x4x2_t br; + br.val[0] = vaddq_f32(x1y0.val[0], x1y1.val[0]); + br.val[1] = vaddq_f32(x1y0.val[1], x1y1.val[1]); + br.val[0] = vmlaq_n_f32(bl.val[0], br.val[0], 0.5f); + br.val[1] = vmlaq_n_f32(bl.val[1], br.val[1], 0.5f); + br.val[0] = vmulq_n_f32(br.val[0], 0.5f); + br.val[1] = vmulq_n_f32(br.val[1], 0.5f); + vst1q_f32(output_ptr, br.val[0]); + vst1q_f32(output_ptr + 4, br.val[1]); + } + // Handle 4 input channels at a time. + for (; ic <= depth - 4; ic += 4) { + const float* input_ptr = &input_data[Offset(input_dims, ic, x0, y0, batch)]; + float32x4_t x0y0 = vld1q_f32(input_ptr); + float32x4_t x1y0 = vld1q_f32(input_ptr + input_x_offset); + float32x4_t x0y1 = vld1q_f32(input_ptr + input_y_offset); + float32x4_t x1y1 = vld1q_f32(input_ptr + input_x_offset + input_y_offset); + + // Top left corner. 
+ float* output_ptr = &output_data[Offset(output_dims, ic, x, y, batch)]; + vst1q_f32(output_ptr, x0y0); + + // Top right corner. + output_ptr += output_x_offset; + float32x4_t tr = vaddq_f32(x0y0, x1y0); + tr = vmulq_n_f32(tr, 0.5f); + vst1q_f32(output_ptr, tr); + + // Bottom left corner. + output_ptr += -output_x_offset + output_y_offset; + float32x4_t bl = vaddq_f32(x0y0, x0y1); + bl = vmulq_n_f32(bl, 0.5f); + vst1q_f32(output_ptr, bl); + + // Bottom right corner. + output_ptr += output_x_offset; + float32x4_t br = vaddq_f32(x1y0, x1y1); + br = vmlaq_n_f32(bl, br, 0.5f); + br = vmulq_n_f32(br, 0.5f); + vst1q_f32(output_ptr, br); + } + // Handle one input channel at a time. + for (; ic < depth; ic++) { + const int32 input_offset = Offset(input_dims, ic, x0, y0, batch); + + float x0y0 = input_data[input_offset]; + float x1y0 = input_data[input_offset + input_x_offset]; + float x0y1 = input_data[input_offset + input_y_offset]; + float x1y1 = input_data[input_offset + input_x_offset + input_y_offset]; + + // Top left corner. + const int32 output_offset = Offset(output_dims, ic, x, y, batch); + output_data[output_offset] = x0y0; + + // Top right corner. + output_data[output_offset + output_x_offset] = (x0y0 + x1y0) / 2; + + // Bottom left corner. + float output = (x0y0 + x0y1) / 2; + output_data[output_offset + output_y_offset] = output; + + // Bottom right corner. + output_data[output_offset + output_x_offset + output_y_offset] = + (output + ((x1y0 + x1y1) / 2)) / 2; + } +#else + for (int ch = 0; ch < depth; ch++) { + const int32 input_offset = Offset(input_dims, ch, x0, y0, batch); + + float x0y0 = input_data[input_offset]; + float x1y0 = input_data[input_offset + input_x_offset]; + float x0y1 = input_data[input_offset + input_y_offset]; + float x1y1 = input_data[input_offset + input_x_offset + input_y_offset]; + + // Top left corner. + const int32 output_offset = Offset(output_dims, ch, x, y, batch); + output_data[output_offset] = x0y0; + + // Top right corner. + output_data[output_offset + output_x_offset] = (x0y0 + x1y0) / 2; + + // Bottom left corner. + float output = (x0y0 + x0y1) / 2; + output_data[output_offset + output_y_offset] = output; + + // Bottom right corner. 
+ output_data[output_offset + output_x_offset + output_y_offset] = + (output + ((x1y0 + x1y1) / 2)) / 2; + } +#endif +} + +inline void ResizeBilinear2x2(const float* input_data, + const Dims<4>& input_dims, float* output_data, + const Dims<4>& output_dims, int32 batches, + int32 input_height, int32 input_width, + int32 depth, int32 output_height, + int32 output_width) { + for (int b = 0; b < batches; b++) { + for (int y0 = 0, y = 0; y <= output_height - 2; y += 2, y0++) { + for (int x0 = 0, x = 0; x <= output_width - 2; x += 2, x0++) { + int32 x1 = std::min(x0 + 1, input_width - 1); + int32 y1 = std::min(y0 + 1, input_height - 1); + ResizeBilinearKernel2x2(x0, x1, y0, y1, x, y, depth, b, input_data, + input_dims, output_data, output_dims); + } + } + } +} + +inline void ResizeBilinearGeneric(const float* input_data, + const Dims<4>& input_dims, float* output_data, + const Dims<4>& output_dims, int32 batches, + int32 input_height, int32 input_width, + int32 depth, int32 output_height, + int32 output_width, float height_scale, + float width_scale) { + memset(output_data, 0, + batches * output_height * output_width * depth * sizeof(float)); + + int32 output_offset = 0; + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < output_height; ++y) { + float input_y = y * height_scale; + int32 y0 = static_cast(std::floor(input_y)); + int32 y1 = std::min(y0 + 1, input_height - 1); + for (int x = 0; x < output_width; ++x) { + float input_x = x * width_scale; + int32 x0 = static_cast(input_x); + int32 x1 = std::min(x0 + 1, input_width - 1); + float* output_ptr = &output_data[output_offset]; + + // Run kernel on the 4 corners of the bilinear resize algorithm. + int32 input_offset = Offset(input_dims, 0, x0, y0, b); + float scale = (1 - (input_y - y0)) * (1 - (input_x - x0)); + const float* input_ptr = &input_data[input_offset]; + ResizeBilinearKernel(input_ptr, depth, scale, output_ptr); + + input_offset = Offset(input_dims, 0, x1, y0, b); + scale = (1 - (input_y - y0)) * (input_x - x0); + input_ptr = &input_data[input_offset]; + ResizeBilinearKernel(input_ptr, depth, scale, output_ptr); + + input_offset = Offset(input_dims, 0, x0, y1, b); + scale = (input_y - y0) * (1 - (input_x - x0)); + input_ptr = &input_data[input_offset]; + ResizeBilinearKernel(input_ptr, depth, scale, output_ptr); + + input_offset = Offset(input_dims, 0, x1, y1, b); + scale = (input_y - y0) * (input_x - x0); + input_ptr = &input_data[input_offset]; + ResizeBilinearKernel(input_ptr, depth, scale, output_ptr); + + output_offset += depth; + } + } + } +} + +inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims, + const int32* output_size_data, + const Dims<4>& output_size_dims, float* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("ResizeBilinear"); + int32 batches = MatchingArraySize(input_dims, 3, output_dims, 3); + int32 input_height = ArraySize(input_dims, 2); + int32 input_width = ArraySize(input_dims, 1); + int32 depth = MatchingArraySize(input_dims, 0, output_dims, 0); + + TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 3), 1); + TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 2), 1); + TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 1), 1); + TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 0), 2); + int32 output_height = output_size_data[Offset(output_size_dims, 0, 0, 0, 0)]; + int32 output_width = output_size_data[Offset(output_size_dims, 1, 0, 0, 0)]; + + // Specialize for 2x2 upsample. 
+ if (output_height == 2 * input_height && output_width == 2 * input_width) { + ResizeBilinear2x2(input_data, input_dims, output_data, output_dims, batches, + input_height, input_width, depth, output_height, + output_width); + } else { + float height_scale = static_cast(input_height) / output_height; + float width_scale = static_cast(input_width) / output_width; + + ResizeBilinearGeneric(input_data, input_dims, output_data, output_dims, + batches, input_height, input_width, depth, + output_height, output_width, height_scale, + width_scale); + } +} + +template +inline void SpaceToBatchND(const T* input_data, const Dims<4>& input_dims, + const int32* block_shape_data, + const Dims<4>& block_shape_dims, + const int32* paddings_data, + const Dims<4>& paddings_dims, T* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("SpaceToBatchND"); + + const int output_batch_size = ArraySize(output_dims, 3); + const int output_height = ArraySize(output_dims, 2); + const int output_width = ArraySize(output_dims, 1); + const int input_batch_size = ArraySize(input_dims, 3); + const int input_height = ArraySize(input_dims, 2); + const int input_width = ArraySize(input_dims, 1); + const int depth = ArraySize(input_dims, 0); + const int block_shape_height = block_shape_data[0]; + const int block_shape_width = block_shape_data[1]; + const int padding_top = paddings_data[0]; + const int padding_left = paddings_data[2]; + + for (int out_b = 0; out_b < output_batch_size; ++out_b) { + int input_batch = out_b % input_batch_size; + int shift_w = (out_b / input_batch_size) % block_shape_width; + int shift_h = (out_b / input_batch_size) / block_shape_width; + for (int out_h = 0; out_h < output_height; ++out_h) { + for (int out_w = 0; out_w < output_width; ++out_w) { + T* out = output_data + Offset(output_dims, 0, out_w, out_h, out_b); + if (out_h * block_shape_height < padding_top || + out_h * block_shape_height >= padding_top + input_height || + out_w * block_shape_width < padding_left || + out_w * block_shape_width >= padding_left + input_width) { + memset(out, 0, depth * sizeof(T)); + } else { + const T* in = + input_data + + Offset(input_dims, 0, + (out_w * block_shape_width + shift_w) - padding_left, + (out_h * block_shape_height + shift_h) - padding_top, + input_batch); + memcpy(out, in, depth * sizeof(T)); + } + } + } + } +} + +template +inline void BatchToSpaceND(const T* input_data, const Dims<4>& input_dims, + const int32* block_shape_data, + const Dims<4>& block_shape_dims, T* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("BatchToSpaceND"); + + const int output_batch_size = ArraySize(output_dims, 3); + const int input_batch_size = ArraySize(input_dims, 3); + const int input_height = ArraySize(input_dims, 2); + const int input_width = ArraySize(input_dims, 1); + const int depth = ArraySize(input_dims, 0); + const int block_shape_width = block_shape_data[1]; + const int block_shape_height = block_shape_data[0]; + + for (int in_batch = 0; in_batch < input_batch_size; ++in_batch) { + for (int in_h = 0; in_h < input_height; ++in_h) { + for (int in_w = 0; in_w < input_width; ++in_w) { + int out_batch = in_batch % output_batch_size; + int out_w = in_w * block_shape_width + + (in_batch / output_batch_size) % block_shape_width; + int out_h = in_h * block_shape_height + + (in_batch / output_batch_size) / block_shape_width; + T* out = output_data + Offset(output_dims, 0, out_w, out_h, out_batch); + const T* in = input_data + Offset(input_dims, 0, 
in_w, in_h, in_batch); + memcpy(out, in, depth * sizeof(T)); + } + } + } +} + +template +inline void Pad(const T* input_data, const Dims<4>& input_dims, + const std::vector& left_paddings, + const std::vector& right_paddings, T* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("Pad"); + const int output_batch = ArraySize(output_dims, 3); + const int output_height = ArraySize(output_dims, 2); + const int output_width = ArraySize(output_dims, 1); + const int output_depth = ArraySize(output_dims, 0); + + const int left_b_padding = left_paddings[3]; + const int left_h_padding = left_paddings[2]; + const int left_w_padding = left_paddings[1]; + const int left_d_padding = left_paddings[0]; + + const int right_b_padding = right_paddings[3]; + const int right_h_padding = right_paddings[2]; + const int right_w_padding = right_paddings[1]; + const int right_d_padding = right_paddings[0]; + + const int input_depth = ArraySize(input_dims, 0); + + if (left_b_padding != 0) { + memset(output_data, 0, + left_b_padding * output_height * output_width * output_depth * + sizeof(T)); + } + for (int out_b = left_b_padding; out_b < output_batch - right_b_padding; + ++out_b) { + if (left_h_padding != 0) { + memset(output_data + Offset(output_dims, 0, 0, 0, out_b), 0, + left_h_padding * output_width * output_depth * sizeof(T)); + } + for (int out_h = left_h_padding; out_h < output_height - right_h_padding; + ++out_h) { + if (left_w_padding != 0) { + memset(output_data + Offset(output_dims, 0, 0, out_h, out_b), 0, + left_w_padding * output_depth * sizeof(T)); + } + for (int out_w = left_w_padding; out_w < output_width - right_w_padding; + ++out_w) { + if (left_d_padding != 0) { + memset(output_data + Offset(output_dims, 0, out_w, out_h, out_b), 0, + left_d_padding * sizeof(T)); + } + + T* out = output_data + + Offset(output_dims, left_d_padding, out_w, out_h, out_b); + const T* in = + input_data + Offset(input_dims, 0, out_w - left_w_padding, + out_h - left_h_padding, out_b - left_b_padding); + memcpy(out, in, input_depth * sizeof(T)); + + if (right_d_padding != 0) { + memset( + output_data + Offset(output_dims, output_depth - right_d_padding, + out_w, out_h, out_b), + 0, right_d_padding * sizeof(T)); + } + } + if (right_w_padding != 0) { + memset( + output_data + Offset(output_dims, 0, output_width - right_w_padding, + out_h, out_b), + 0, right_w_padding * output_depth * sizeof(T)); + } + } + if (right_h_padding != 0) { + memset(output_data + Offset(output_dims, 0, 0, + output_height - right_h_padding, out_b), + 0, right_h_padding * output_width * output_depth * sizeof(T)); + } + } + if (right_b_padding != 0) { + memset(output_data + + Offset(output_dims, 0, 0, 0, output_batch - right_b_padding), + 0, + right_b_padding * output_height * output_width * output_depth * + sizeof(T)); + } +} + +template +inline void StridedSlice(const T* input_data, const Dims<4>& input_dims, + int begin_mask, int end_mask, + const std::vector& starts, + const std::vector& stops, + const std::vector& strides, T* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("StridedSlice"); + const int start_b = (begin_mask & 8) ? 0 : starts[3]; + const int stop_b = (end_mask & 8) ? input_dims.sizes[3] : stops[3]; + const int start_h = (begin_mask & 4) ? 0 : starts[2]; + const int stop_h = (end_mask & 4) ? input_dims.sizes[2] : stops[2]; + const int start_w = (begin_mask & 2) ? 0 : starts[1]; + const int stop_w = (end_mask & 2) ? 
input_dims.sizes[1] : stops[1]; + const int start_d = (begin_mask & 1) ? 0 : starts[0]; + const int stop_d = (end_mask & 1) ? input_dims.sizes[0] : stops[0]; + + T* out_ptr = output_data; + if (strides[0] == 0) { + for (int in_b = start_b; in_b < stop_b; in_b += strides[3]) { + for (int in_h = start_h; in_h < stop_h; in_h += strides[2]) { + for (int in_w = start_w; in_w < stop_w; in_w += strides[1]) { + const int len = stop_d - start_d; + memcpy(out_ptr, + input_data + Offset(input_dims, start_d, in_w, in_h, in_b), + len * sizeof(T)); + out_ptr += len; + } + } + } + } else { + for (int in_b = start_b; in_b < stop_b; in_b += strides[3]) { + for (int in_h = start_h; in_h < stop_h; in_h += strides[2]) { + for (int in_w = start_w; in_w < stop_w; in_w += strides[1]) { + for (int in_d = start_d; in_d < stop_d; in_d += strides[0]) { + *out_ptr++ = input_data[Offset(input_dims, in_d, in_w, in_h, in_b)]; + } + } + } + } + } +} + +template +inline void Slice(const T* input_data, const Dims<4>& input_dims, + const std::vector& begin, const std::vector& size, + T* output_data, const Dims<4>& output_dims) { + // TODO(dkalenichenko): This op only supports 4D tensors. + TFLITE_DCHECK_EQ(begin.size(), 4); + TFLITE_DCHECK_EQ(size.size(), 4); + const int start_b = begin[3]; + const int stop_b = + size[3] == -1 ? input_dims.sizes[3] - start_b : start_b + size[3]; + const int start_h = begin[2]; + const int stop_h = + size[2] == -1 ? input_dims.sizes[2] - start_b : start_b + size[2]; + const int start_w = begin[1]; + const int stop_w = + size[1] == -1 ? input_dims.sizes[1] - start_b : start_b + size[1]; + const int start_d = begin[0]; + const int stop_d = + size[0] == -1 ? input_dims.sizes[0] - start_d : start_d + size[0]; + + T* out_ptr = output_data; + for (int in_b = start_b; in_b < stop_b; ++in_b) { + for (int in_h = start_h; in_h < stop_h; ++in_h) { + for (int in_w = start_w; in_w < stop_w; ++in_w) { + const int len = stop_d - start_d; + memcpy(out_ptr, + input_data + Offset(input_dims, start_d, in_w, in_h, in_b), + len * sizeof(T)); + out_ptr += len; + } + } + } +} + +template +inline void Mean(const T* input_data, const Dims<4>& input_dims, + const std::vector& reduction_indices, T* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("Mean"); + const int output_batch = ArraySize(output_dims, 3); + const int output_height = ArraySize(output_dims, 2); + const int output_width = ArraySize(output_dims, 1); + const int output_depth = ArraySize(output_dims, 0); + + const int input_height = ArraySize(input_dims, 2); + const int input_width = ArraySize(input_dims, 1); + + // The current implementation only supports simultaneous reduction over + // width and height. 
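+  // Illustrative example (shapes assumed here only for concreteness): a
+  // 1 x 13 x 13 x 8 input reduced over axes {1, 2} yields a 1 x 1 x 1 x 8
+  // output in which each of the 8 channel values is the average of the
+  // 13 * 13 spatial positions of that batch/channel slice, which is exactly
+  // what the loops below compute.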
+ TFLITE_DCHECK_EQ(reduction_indices.size(), 2); + TFLITE_DCHECK((reduction_indices[0] == 1 && reduction_indices[1] == 2) || + (reduction_indices[0] == 2 && reduction_indices[1] == 1)); + TFLITE_DCHECK_EQ(output_height, 1); + TFLITE_DCHECK_EQ(output_width, 1); + + for (int out_b = 0; out_b < output_batch; ++out_b) { + for (int out_d = 0; out_d < output_depth; ++out_d) { + float value = 0; + for (int in_h = 0; in_h < input_height; ++in_h) { + for (int in_w = 0; in_w < input_width; ++in_w) { + value += input_data[Offset(input_dims, out_d, in_w, in_h, out_b)]; + } + } + output_data[Offset(output_dims, out_d, 0, 0, out_b)] = + value / (input_width * input_height); + } + } +} + +template +void GenericBroadcastSub(const T* input1_data, const Dims<4>& input1_dims, + const T* input2_data, const Dims<4>& input2_dims, + T* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("GenericBroadcastSub"); + + NdArrayDesc<4> desc1; + NdArrayDesc<4> desc2; + NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2); + + // In Tensorflow, the dimensions are canonically named (batch_number, row, + // col, channel), with extents (batches, height, width, depth), with the + // trailing dimension changing most rapidly (channels has the smallest stride, + // typically 1 element). + // + // In generated C code, we store arrays with the dimensions reversed. The + // first dimension has smallest stride. + // + // We name our variables by their Tensorflow convention, but generate C code + // nesting loops such that the innermost loop has the smallest stride for the + // best cache behavior. + for (int b = 0; b < ArraySize(output_dims, 3); ++b) { + for (int y = 0; y < ArraySize(output_dims, 2); ++y) { + for (int x = 0; x < ArraySize(output_dims, 1); ++x) { + for (int c = 0; c < ArraySize(output_dims, 0); ++c) { + output_data[Offset(output_dims, c, x, y, b)] = + input1_data[SubscriptToIndex(desc1, c, x, y, b)] - + input2_data[SubscriptToIndex(desc2, c, x, y, b)]; + } + } + } + } +} + +template +void Sub(const T* input1_data, const Dims<4>& input1_dims, const T* input2_data, + const Dims<4>& input2_dims, T* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("Sub"); + + auto input1_map = MapAsVector(input1_data, input1_dims); + auto input2_map = MapAsVector(input2_data, input2_dims); + auto output_map = MapAsVector(output_data, output_dims); + if (AreSameDims(input1_dims, input2_dims)) { + output_map.array() = input1_map.array() - input2_map.array(); + } else if (RequiredBufferSizeForDims(input1_dims) == 1) { + auto scalar = input1_data[0]; + output_map.array() = scalar - input2_map.array(); + } else if (RequiredBufferSizeForDims(input2_dims) == 1) { + auto scalar = input2_data[0]; + output_map.array() = input1_map.array() - scalar; + } else { + GenericBroadcastSub(input1_data, input1_dims, input2_data, input2_dims, + output_data, output_dims); + } +} + +template +void TensorFlowMinimum(const T* input1_data, const Dims<4>& input1_dims, + const T* input2_data, T* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("TensorFlowMinimum"); + auto input1_map = MapAsVector(input1_data, input1_dims); + auto output_map = MapAsVector(output_data, output_dims); + auto min_value = input2_data[0]; + output_map.array() = input1_map.array().min(min_value); +} + +template +void TensorFlowMaximum(const T* input1_data, const Dims<4>& input1_dims, + const T* input2_data, T* output_data, + const Dims<4>& output_dims) { + 
gemmlowp::ScopedProfilingLabel label("TensorFlowMaximum"); + auto input1_map = MapAsVector(input1_data, input1_dims); + auto output_map = MapAsVector(output_data, output_dims); + auto max_value = input2_data[0]; + output_map.array() = input1_map.array().max(max_value); +} +} // namespace optimized_ops +} // namespace tflite + +#if defined OPTIMIZED_OPS_H__IGNORE_DEPRECATED_DECLARATIONS +#undef OPTIMIZED_OPS_H__IGNORE_DEPRECATED_DECLARATIONS +#pragma GCC diagnostic pop +#endif + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_OPS_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h b/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h new file mode 100644 index 0000000000..f8be99e82f --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h @@ -0,0 +1,138 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TF_LITE_KERNELS_INTERNAL_OPTIMIZED_TENSOR_UTILS_IMPL_H_ +#define TF_LITE_KERNELS_INTERNAL_OPTIMIZED_TENSOR_UTILS_IMPL_H_ + +// TDOD(ghodrat): Remove this header file and the dependency to internal data +// structure. +#include "tensorflow/contrib/lite/builtin_op_data.h" + +#ifndef USE_NEON +#if defined(__ARM_NEON__) || defined(__ARM_NEON) +#define USE_NEON +#endif // defined(__ARM_NEON__) || defined(__ARM_NEON) +#endif // USE_NEON + +namespace tflite { +namespace tensor_utils { + +// Multiply a matrix by a batch vector, and store results in a batch-size +// vector. +void PortableMatrixBatchVectorMultiplyAccumulate(const float* matrix, + int m_rows, int m_cols, + const float* vector, + int n_batch, float* result, + int result_stride); +void NeonMatrixBatchVectorMultiplyAccumulate(const float* matrix, int m_rows, + int m_cols, const float* vector, + int n_batch, float* result, + int result_stride); + +// Cwise product of two vectors. +void PortableVectorVectorCwiseProduct(const float* vector1, + const float* vector2, int v_size, + float* result); +void NeonVectorVectorCwiseProduct(const float* vector1, const float* vector2, + int v_size, float* result); + +// Cwise product and accumulate of two vectors. Since it's a MAC operation, the +// assumption here is that result array is initialized to valid values. +void PortableVectorVectorCwiseProductAccumulate(const float* vector1, + const float* vector2, + int v_size, float* result); +void NeonVectorVectorCwiseProductAccumulate(const float* vector1, + const float* vector2, int v_size, + float* result); + +// Dot product of two vectors. +float PortableVectorVectorDotProduct(const float* vector1, const float* vector2, + int v_size); +float NeonVectorVectorDotProduct(const float* vector1, const float* vector2, + int v_size); + +// Dot product of two batch vectors. 
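+// For each batch b in [0, n_batch), result[b * result_stride] receives the
+// dot product of the b-th v_size-long slice of vector1 with the b-th
+// v_size-long slice of vector2 (this mirrors the portable implementation
+// below, which advances both vector pointers by v_size per batch).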
+void PortableBatchVectorBatchVectorDotProduct(const float* vector1, + const float* vector2, int v_size, + int n_batch, float* result, + int result_stride); +void NeonBatchVectorBatchVectorDotProduct(const float* vector1, + const float* vector2, int v_size, + int n_batch, float* result, + int result_stride); + +// Cwise product and accumulate of a vector and a batch-vector. Since it's a MAC +// operation, the assumption here is that result array is initialized to valid +// values. +void PortableVectorBatchVectorCwiseProductAccumulate(const float* vector, + int v_size, + const float* batch_vector, + int n_batch, + float* result); +void NeonVectorBatchVectorCwiseProductAccumulate(const float* vector, + int v_size, + const float* batch_vector, + int n_batch, float* result); + +// Compute "1.0f - elements of vector" (used in CIFG). +void PortableSub1Vector(const float* vector, int v_size, float* result); +void NeonSub1Vector(const float* vector, int v_size, float* result); + +// Clip elements of a vector using a abs_limit value. +void PortableClipVector(const float* vector, int v_size, float abs_limit, + float* result); +void NeonClipVector(const float* vector, int v_size, float abs_limit, + float* result); + +// Batch vector initialization with another vector. +void PortableVectorBatchVectorAssign(const float* vector, int v_size, + int n_batch, float* batch_vector); + +// Apply sigmoid to elements of a vector. +void PortableApplySigmoidToVector(const float* vector, int v_size, + float* result); + +// Apply activation function to elements of a vector. +void PortableApplyActivationToVector(const float* vector, int v_size, + TfLiteFusedActivation activation, + float* result); + +// Copy vector to another vector. +void PortableCopyVector(const float* vector, int v_size, float* result); + +// Fill vector with 0.f. +void PortableZeroVector(float* vector, int v_size); + +// Limit a float input f between +abs_limit and -abs_limit. +float PortableClip(float f, float abs_limit); + +// Shift left a vector in place with v_size size. +void PortableVectorShiftLeft(float* vector, int v_size, float shift_value); +void NeonVectorShiftLeft(float* vector, int v_size, float shift_value); + +// Reduce-sum on a float input vector: +// input_vector: float pointer to input vector. +// output_vector: float pointer to vector. +// output_size: output vector size. +// reduction_size: number of consecutive elements from input vector which are +// added to get one element of output. +void PortableReductionSumVector(const float* input_vector, float* output_vector, + int output_size, int reduction_size); +void NeonReductionSumVector(const float* input_vector, float* output_vector, + int output_size, int reduction_size); + +} // namespace tensor_utils +} // namespace tflite + +#endif // TF_LITE_KERNELS_INTERNAL_OPTIMIZED_TENSOR_UTILS_IMPL_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/quantization_util.cc b/tensorflow/contrib/lite/kernels/internal/quantization_util.cc new file mode 100644 index 0000000000..98f2e365c5 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/quantization_util.cc @@ -0,0 +1,95 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <algorithm>
+#include <cmath>
+#include <limits>
+
+#include "tensorflow/contrib/lite/kernels/internal/compatibility.h"
+#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
+#include "tensorflow/contrib/lite/kernels/internal/round.h"
+
+namespace tflite {
+
+void QuantizeMultiplierSmallerThanOne(double double_multiplier,
+                                      int32_t* quantized_multiplier,
+                                      int* right_shift) {
+  TFLITE_CHECK(double_multiplier >= 0.);
+  TFLITE_CHECK(double_multiplier < 1.);
+  if (double_multiplier == 0.) {
+    *quantized_multiplier = 0;
+    *right_shift = 0;
+    return;
+  }
+  TFLITE_CHECK(double_multiplier > 0.);
+  const double q = std::frexp(double_multiplier, right_shift);
+  *right_shift *= -1;
+
+  auto q_fixed = static_cast<int64_t>(TfLiteRound(q * (1ll << 31)));
+  TFLITE_CHECK(q_fixed <= (1ll << 31));
+  if (q_fixed == (1ll << 31)) {
+    q_fixed /= 2;
+    --*right_shift;
+  }
+  TFLITE_CHECK_GE(*right_shift, 0);
+  TFLITE_CHECK_LE(q_fixed, std::numeric_limits<int32_t>::max());
+  *quantized_multiplier = static_cast<int32_t>(q_fixed);
+}
+
+void QuantizeMultiplierGreaterThanOne(double double_multiplier,
+                                      int32_t* quantized_multiplier,
+                                      int* left_shift) {
+  TFLITE_CHECK(double_multiplier > 1.);
+  const double q = std::frexp(double_multiplier, left_shift);
+  auto q_fixed = static_cast<int64_t>(TfLiteRound(q * (1ll << 31)));
+  TFLITE_CHECK(q_fixed <= (1ll << 31));
+  if (q_fixed == (1ll << 31)) {
+    q_fixed /= 2;
+    ++*left_shift;
+  }
+  TFLITE_CHECK_GE(*left_shift, 0);
+  TFLITE_CHECK_LE(q_fixed, std::numeric_limits<int32_t>::max());
+  *quantized_multiplier = static_cast<int32_t>(q_fixed);
+}
+
+void PreprocessSoftmaxScaling(double beta, double input_scale,
+                              int input_integer_bits,
+                              int32_t* quantized_multiplier, int* left_shift) {
+  // If the overall multiplier (input and beta) is large, then exp() of an
+  // input difference of 1 scaled by this will be large. In other words, we
+  // can cap the multiplier and know that, when it is used, the output will be
+  // (round to) zero wherever the input is not at the maximum value.
+
+  // If the overall scale is less than one, and input_integer_bits=0, then the
+  // result is double equivalent of Q0.31 (actually with more precision). Thus
+  // this generates a Q(input_integer_bits).(31-input_integer_bits)
+  // representation.
+  const double input_beta_real_multiplier = std::min(
+      beta * input_scale * (1 << (31 - input_integer_bits)), (1ll << 31) - 1.0);
+
+  QuantizeMultiplierGreaterThanOne(input_beta_real_multiplier,
+                                   quantized_multiplier, left_shift);
+}
+
+int CalculateInputRadius(int input_integer_bits, int input_left_shift) {
+  const double max_input_rescaled = 1.0 * ((1 << input_integer_bits) - 1) *
+                                    (1ll << (31 - input_integer_bits)) /
+                                    (1ll << input_left_shift);
+  // Tighten bound using floor. Suppose that we could use the exact value.
+  // After scaling the difference, the result would be at the maximum. Thus we
+  // must ensure that our value has lower magnitude.
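+  // For example, input_integer_bits = 4 with input_left_shift = 27 gives
+  // max_input_rescaled = 15 * 2^27 / 2^27 = 15, i.e. a radius of 15, while
+  // input_integer_bits = 3 with input_left_shift = 28 gives 7 * 2^28 / 2^28 = 7
+  // (both values also appear in the accompanying unit test).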
+ return static_cast(std::floor(max_input_rescaled)); +} + +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/internal/quantization_util.h b/tensorflow/contrib/lite/kernels/internal/quantization_util.h new file mode 100644 index 0000000000..efb7191c8d --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/quantization_util.h @@ -0,0 +1,55 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef PHOTOS_VISION_LEARNING_TENSORFLOW_MINI_QUANTIZATION_UTIL_H_ +#define PHOTOS_VISION_LEARNING_TENSORFLOW_MINI_QUANTIZATION_UTIL_H_ + +#include + +namespace tflite { + +// Decompose a double multiplier into a Q0.31 int32 representation of its +// significand, and shift representation of its exponent. +// +// Restricted to the case where the multiplier < 1 (and non-negative). +void QuantizeMultiplierSmallerThanOne(double double_multiplier, + int32_t* quantized_multiplier, + int* right_shift); + +// Decompose a double multiplier into a Q0.31 int32 representation of its +// significand, and shift representation of its exponent. +// +// Restricted to the case where the multiplier > 1. +void QuantizeMultiplierGreaterThanOne(double double_multiplier, + int32_t* quantized_multiplier, + int* left_shift); + +// This first creates a multiplier in a double equivalent of +// Q(input_integer_bits).(31-input_integer_bits) representation, with extra +// precision in the double's fractional bits. It then splits the result into +// significand and exponent. +void PreprocessSoftmaxScaling(double beta, double input_scale, + int input_integer_bits, + int32_t* quantized_multiplier, int* left_shift); + +// Calculate the largest input that will result in a within-bounds intermediate +// result within MultiplyByQuantizedMultiplierGreaterThanOne. In other words, +// it must not overflow before we reduce the value by multiplication by the +// input multiplier. The negative radius is used as the minimum difference +// in Softmax. +int CalculateInputRadius(int input_integer_bits, int input_left_shift); + +} // namespace tflite + +#endif // PHOTOS_VISION_LEARNING_TENSORFLOW_MINI_QUANTIZATION_UTIL_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc b/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc new file mode 100644 index 0000000000..d6f306e2cb --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc @@ -0,0 +1,108 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h" + +#include +#include + +namespace tflite { +namespace { + +using ::testing::Pair; + +TEST(QuantizationUtilTest, QuantizeMultiplierSmallerThanOne) { + auto quantize = [](double d) { + int32_t q; + int s; + QuantizeMultiplierSmallerThanOne(d, &q, &s); + return std::pair{q, s}; + }; + + EXPECT_DEATH(quantize(-0.1), ""); + EXPECT_THAT(quantize(0.0), Pair(0, 0)); + EXPECT_THAT(quantize(0.25), Pair(1073741824, 1)); + + // Around 0.5 we can see the change in exponent and how we try hard to + // void hitting max int32. + EXPECT_THAT(quantize(0.50 - 5e-9), Pair(2147483627, 1)); + EXPECT_THAT(quantize(0.50 - 1e-10), Pair(1073741824, 0)); + EXPECT_THAT(quantize(0.50), Pair(1073741824, 0)); + + EXPECT_THAT(quantize(0.75), Pair(1610612736, 0)); + EXPECT_THAT(quantize(1 - 1e-9), Pair(2147483646, 0)); + + // If we get close enough to 1.0 it crashes and dies in one of two ways: + // Either the shift becomes negative or we trigger the 'less-than-one' CHECK. + EXPECT_DEATH(quantize(1 - 1e-15), ""); + EXPECT_DEATH(quantize(1 - 1e-17), ""); + EXPECT_DEATH(quantize(1.0), ""); +} + +TEST(QuantizationUtilTest, QuantizeMultiplierGreaterThanOne) { + auto quantize = [](double d) { + int32_t q; + int s; + QuantizeMultiplierGreaterThanOne(d, &q, &s); + return std::pair{q, s}; + }; + + // If we are close enough to 1.0 it crashes. + EXPECT_DEATH(quantize(1 + 1e-16), ""); + + EXPECT_THAT(quantize(1 + 1e-11), Pair(1073741824, 1)); + EXPECT_THAT(quantize(1.25), Pair(1342177280, 1)); + EXPECT_THAT(quantize(1.50), Pair(1610612736, 1)); + EXPECT_THAT(quantize(1.75), Pair(1879048192, 1)); + + // Around the powers of two we see the change in exponent. Also, + // we try hard to avoid hitting max int32. + EXPECT_THAT(quantize(2 - 1e-9), Pair(2147483647, 1)); + EXPECT_THAT(quantize(2 - 1e-11), Pair(1073741824, 2)); + EXPECT_THAT(quantize(2), Pair(1073741824, 2)); +} + +TEST(QuantizationUtilTest, PreprocessSoftmaxScaling) { + auto quantize = [](double beta, double scale, int integer_bits) { + int32_t q; + int s; + PreprocessSoftmaxScaling(beta, scale, integer_bits, &q, &s); + return std::pair{q, s}; + }; + + // If beta * scale is greater than fits in the number of integer bits, the + // result is move near the maximum. Otherwise they quantize as expected. + // With 4 integer bits we can represent up to 16.0. + EXPECT_THAT(quantize(1.0, 16.0, 4), Pair(2147483647, 31)); + EXPECT_THAT(quantize(1.0, 8.0, 4), Pair(1073741824, 31)); + // But with 5 bits we can go further. 
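+  // (Five integer bits can represent up to 32.0, so beta * scale = 2.0 * 16.0
+  // = 32.0 still saturates the multiplier, whereas 2.0 * 8.0 = 16.0 now
+  // quantizes to roughly half of the int32 maximum.)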
+ EXPECT_THAT(quantize(2.0, 16.0, 5), Pair(2147483647, 31)); + EXPECT_THAT(quantize(2.0, 8.0, 5), Pair(1073741824, 31)); +} + +TEST(QuantizationUtilTest, CalculateInputRadius) { + EXPECT_EQ(CalculateInputRadius(4, 27), 15); + EXPECT_EQ(CalculateInputRadius(3, 27), 14); + EXPECT_EQ(CalculateInputRadius(3, 28), 7); + EXPECT_EQ(CalculateInputRadius(4, 2), 503316480); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + // On Linux, add: tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h b/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h new file mode 100644 index 0000000000..8e0f234545 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h @@ -0,0 +1,115 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_DEPTHWISECONV_FLOAT_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_DEPTHWISECONV_FLOAT_H_ + +#include "tensorflow/contrib/lite/kernels/internal/common.h" +#include "tensorflow/contrib/lite/kernels/internal/compatibility.h" +#include "tensorflow/contrib/lite/kernels/internal/types.h" + +namespace tflite { +namespace reference_ops { + +inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims, + const float* filter_data, const Dims<4>& filter_dims, + const float* bias_data, const Dims<4>& bias_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int depth_multiplier, + float output_activation_min, + float output_activation_max, float* output_data, + const Dims<4>& output_dims) { + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int output_depth = MatchingArraySize(filter_dims, 0, output_dims, 0); + const int input_height = ArraySize(input_dims, 2); + const int input_width = ArraySize(input_dims, 1); + const int input_depth = ArraySize(input_dims, 0); + const int filter_height = ArraySize(filter_dims, 2); + const int filter_width = ArraySize(filter_dims, 1); + const int output_height = ArraySize(output_dims, 2); + const int output_width = ArraySize(output_dims, 1); + TFLITE_DCHECK(output_depth == input_depth * depth_multiplier); + + for (int b = 0; b < batches; ++b) { + for (int out_y = 0; out_y < output_height; ++out_y) { + for (int out_x = 0; out_x < output_width; ++out_x) { + for (int ic = 0; ic < input_depth; ++ic) { + for (int m = 0; m < depth_multiplier; m++) { + const int oc = m + ic * depth_multiplier; + const int in_x_origin = (out_x * stride_width) - pad_width; + const int in_y_origin = (out_y * stride_height) - pad_height; + float total = 0.f; + for (int filter_y = 0; filter_y < filter_height; ++filter_y) { + for (int filter_x = 0; filter_x < filter_width; ++filter_x) { + const 
int in_x = in_x_origin + filter_x; + const int in_y = in_y_origin + filter_y; + // If the location is outside the bounds of the input image, + // use zero as a default value. + if ((in_x >= 0) && (in_x < input_width) && (in_y >= 0) && + (in_y < input_height)) { + float input_value = + input_data[Offset(input_dims, ic, in_x, in_y, b)]; + float filter_value = filter_data[Offset( + filter_dims, oc, filter_x, filter_y, 0)]; + total += (input_value * filter_value); + } + } + } + float bias_value = 0.0f; + if (bias_data) { + bias_value = bias_data[Offset(bias_dims, oc, 0, 0, 0)]; + } + output_data[Offset(output_dims, oc, out_x, out_y, b)] = + ActivationFunctionWithMinMax(total + bias_value, + output_activation_min, + output_activation_max); + } + } + } + } + } +} + +// Legacy, for compatibility with old checked-in code. +template +void DepthwiseConv(const float* input_data, const Dims<4>& input_dims, + const float* filter_data, const Dims<4>& filter_dims, + const float* bias_data, const Dims<4>& bias_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int depth_multiplier, float* output_data, + const Dims<4>& output_dims) { + float output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + DepthwiseConv(input_data, input_dims, filter_data, filter_dims, bias_data, + bias_dims, stride_width, stride_height, pad_width, pad_height, + depth_multiplier, output_activation_min, output_activation_max, + output_data, output_dims); +} + +// Legacy, for compatibility with old checked-in code. +template +void DepthwiseConv(const float* input_data, const Dims<4>& input_dims, + const float* filter_data, const Dims<4>& filter_dims, + const float* bias_data, const Dims<4>& bias_dims, int stride, + int pad_width, int pad_height, int depth_multiplier, + float* output_data, const Dims<4>& output_dims) { + DepthwiseConv(input_data, input_dims, filter_data, filter_dims, bias_data, + bias_dims, stride, stride, pad_width, pad_height, + depth_multiplier, output_data, output_dims); +} + +} // end namespace reference_ops +} // end namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_DEPTHWISECONV_FLOAT_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h b/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h new file mode 100644 index 0000000000..8a80558b32 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h @@ -0,0 +1,138 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_DEPTHWISECONV_UINT8_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_DEPTHWISECONV_UINT8_H_ + +#include + +#include "fixedpoint/fixedpoint.h" +#include "public/gemmlowp.h" +#include "tensorflow/contrib/lite/kernels/internal/common.h" +#include "tensorflow/contrib/lite/kernels/internal/compatibility.h" +#include "tensorflow/contrib/lite/kernels/internal/types.h" + +namespace tflite { +namespace reference_ops { + +inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims, + int32 input_offset, const uint8* filter_data, + const Dims<4>& filter_dims, int32 filter_offset, + const int32* bias_data, const Dims<4>& bias_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int depth_multiplier, + int32 output_offset, int32 output_multiplier, + int output_shift, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int output_depth = MatchingArraySize(filter_dims, 0, output_dims, 0); + const int input_height = ArraySize(input_dims, 2); + const int input_width = ArraySize(input_dims, 1); + const int input_depth = ArraySize(input_dims, 0); + const int filter_height = ArraySize(filter_dims, 2); + const int filter_width = ArraySize(filter_dims, 1); + const int output_height = ArraySize(output_dims, 2); + const int output_width = ArraySize(output_dims, 1); + TFLITE_DCHECK(output_depth == input_depth * depth_multiplier); + + for (int b = 0; b < batches; ++b) { + for (int out_y = 0; out_y < output_height; ++out_y) { + for (int out_x = 0; out_x < output_width; ++out_x) { + for (int ic = 0; ic < input_depth; ++ic) { + for (int m = 0; m < depth_multiplier; m++) { + const int oc = m + ic * depth_multiplier; + const int in_x_origin = (out_x * stride_width) - pad_width; + const int in_y_origin = (out_y * stride_height) - pad_height; + int32 acc = 0; + for (int filter_y = 0; filter_y < filter_height; ++filter_y) { + for (int filter_x = 0; filter_x < filter_width; ++filter_x) { + const int in_x = in_x_origin + filter_x; + const int in_y = in_y_origin + filter_y; + // If the location is outside the bounds of the input image, + // use zero as a default value. + if ((in_x >= 0) && (in_x < input_width) && (in_y >= 0) && + (in_y < input_height)) { + int32 input_val = + input_data[Offset(input_dims, ic, in_x, in_y, b)]; + int32 filter_val = filter_data[Offset(filter_dims, oc, + filter_x, filter_y, 0)]; + acc += + (filter_val + filter_offset) * (input_val + input_offset); + } + } + } + if (bias_data) { + acc += bias_data[Offset(bias_dims, oc, 0, 0, 0)]; + } + acc = MultiplyByQuantizedMultiplierSmallerThanOne( + acc, output_multiplier, output_shift); + acc += output_offset; + acc = std::max(acc, output_activation_min); + acc = std::min(acc, output_activation_max); + output_data[Offset(output_dims, oc, out_x, out_y, b)] = + static_cast(acc); + } + } + } + } + } +} + +// Legacy, for compatibility with old checked-in code. 
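+// In this wrapper the FusedActivationFunctionType template parameter only
+// gates a sanity check on the activation bounds; the arithmetic itself is
+// delegated unchanged to the overload above that takes the bounds at runtime.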
+template +void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims, + int32 input_offset, const uint8* filter_data, + const Dims<4>& filter_dims, int32 filter_offset, + const int32* bias_data, const Dims<4>& bias_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int depth_multiplier, int32 output_offset, + int32 output_multiplier, int output_shift, + int32 output_activation_min, int32 output_activation_max, + uint8* output_data, const Dims<4>& output_dims) { + if (Ac == FusedActivationFunctionType::kNone) { + TFLITE_DCHECK_EQ(output_activation_min, 0); + TFLITE_DCHECK_EQ(output_activation_max, 255); + } + DepthwiseConv(input_data, input_dims, input_offset, filter_data, filter_dims, + filter_offset, bias_data, bias_dims, stride_width, + stride_height, pad_width, pad_height, depth_multiplier, + output_offset, output_multiplier, output_shift, + output_activation_min, output_activation_max, output_data, + output_dims); +} + +// Legacy, for compatibility with old checked-in code. +template +void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims, + int32 input_offset, const uint8* filter_data, + const Dims<4>& filter_dims, int32 filter_offset, + const int32* bias_data, const Dims<4>& bias_dims, int stride, + int pad_width, int pad_height, int depth_multiplier, + int32 output_offset, int32 output_multiplier, + int output_shift, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + DepthwiseConv(input_data, input_dims, input_offset, filter_data, + filter_dims, filter_offset, bias_data, bias_dims, stride, + stride, pad_width, pad_height, depth_multiplier, + output_offset, output_multiplier, output_shift, + output_activation_min, output_activation_max, output_data, + output_dims); +} + +} // end namespace reference_ops +} // end namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_DEPTHWISECONV_UINT8_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc new file mode 100644 index 0000000000..c5b0bccc9d --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc @@ -0,0 +1,165 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include + +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/kernels/activation_functor.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace tensor_utils { + +float PortableClip(float f, float abs_limit) { + float result = (abs_limit < f) ? abs_limit : f; + result = (-abs_limit > result) ? 
-abs_limit : result; + return result; +} + +void PortableMatrixBatchVectorMultiplyAccumulate(const float* matrix, + int m_rows, int m_cols, + const float* vector, + int n_batch, float* result, + int result_stride) { + float* result_in_batch = result; + for (int b = 0; b < n_batch; b++) { + const float* matrix_ptr = matrix; + for (int r = 0; r < m_rows; r++) { + const float* vector_in_batch = vector + b * m_cols; + for (int c = 0; c < m_cols; c++) { + *result_in_batch += *matrix_ptr++ * *vector_in_batch++; + } + result_in_batch += result_stride; + } + } +} + +void PortableVectorVectorCwiseProduct(const float* vector1, + const float* vector2, int v_size, + float* result) { + for (int v = 0; v < v_size; v++) { + *result++ = *vector1++ * *vector2++; + } +} + +float PortableVectorVectorDotProduct(const float* vector1, const float* vector2, + int v_size) { + float result = 0.0; + for (int v = 0; v < v_size; v++) { + result += *vector1++ * *vector2++; + } + return result; +} + +void PortableBatchVectorBatchVectorDotProduct(const float* vector1, + const float* vector2, int v_size, + int n_batch, float* result, + int result_stride) { + float* result_ptr = result; + const float* vector1_ptr = vector1; + const float* vector2_ptr = vector2; + for (int b = 0; b < n_batch; b++) { + *result_ptr = + PortableVectorVectorDotProduct(vector1_ptr, vector2_ptr, v_size); + vector1_ptr += v_size; + vector2_ptr += v_size; + result_ptr += result_stride; + } +} + +void PortableVectorVectorCwiseProductAccumulate(const float* vector1, + const float* vector2, + int v_size, float* result) { + for (int v = 0; v < v_size; v++) { + *result++ += *vector1++ * *vector2++; + } +} + +void PortableVectorBatchVectorCwiseProductAccumulate(const float* vector, + int v_size, + const float* batch_vector, + int n_batch, + float* result) { + for (int b = 0; b < n_batch; b++) { + for (int v = 0; v < v_size; v++) { + *result++ += vector[v] * *batch_vector++; + } + } +} + +void PortableVectorBatchVectorAssign(const float* vector, int v_size, + int n_batch, float* batch_vector) { + for (int b = 0; b < n_batch; b++) { + memcpy(batch_vector + b * v_size, vector, v_size * sizeof(float)); + } +} + +void PortableApplySigmoidToVector(const float* vector, int v_size, + float* result) { + auto sigmoid_func = ActivationFunctor(kTfLiteActSigmoid); + for (int v = 0; v < v_size; v++) { + *result++ = (sigmoid_func)(*vector++); + } +} + +void PortableApplyActivationToVector(const float* vector, int v_size, + TfLiteFusedActivation activation, + float* result) { + auto activation_func = ActivationFunctor(activation); + for (int v = 0; v < v_size; v++) { + *result++ = (activation_func)(*vector++); + } +} + +void PortableCopyVector(const float* vector, int v_size, float* result) { + memcpy(result, vector, v_size * sizeof(float)); +} + +void PortableSub1Vector(const float* vector, int v_size, float* result) { + for (int v = 0; v < v_size; v++) { + *result++ = 1.0f - *vector++; + } +} + +void PortableZeroVector(float* vector, int v_size) { + memset(vector, 0, v_size * sizeof(float)); +} + +void PortableClipVector(const float* vector, int v_size, float abs_limit, + float* result) { + for (int v = 0; v < v_size; v++) { + *result++ = PortableClip(*vector++, abs_limit); + } +} + +void PortableVectorShiftLeft(float* vector, int v_size, float shift_value) { + TF_LITE_ASSERT(v_size > 0); + for (int i = 0; i < v_size - 1; i++) { + vector[i] = vector[i + 1]; + } + vector[v_size - 1] = shift_value; +} + +void PortableReductionSumVector(const float* input_vector, float* 
output_vector,
+                                int output_size, int reduction_size) {
+  const float* input_vector_ptr = input_vector;
+  for (int o = 0; o < output_size; o++) {
+    for (int r = 0; r < reduction_size; r++) {
+      output_vector[o] += *input_vector_ptr++;
+    }
+  }
+}
+
+}  // namespace tensor_utils
+}  // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h
new file mode 100644
index 0000000000..c2ab78000b
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h
@@ -0,0 +1,189 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_PORTABLE_TENSOR_UTILS_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_PORTABLE_TENSOR_UTILS_H_
+
+// TODO(ghodrat): Remove this header file and the dependency to internal data
+// structure.
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+
+namespace tflite {
+namespace tensor_utils {
+
+// Limit a float input f between +abs_limit and -abs_limit.
+float PortableClip(float f, float abs_limit);
+
+// Multiply a matrix by a batch vector, and store results in a batch-size
+// vector.
+void PortableMatrixBatchVectorMultiplyAccumulate(const float* matrix,
+                                                 int m_rows, int m_cols,
+                                                 const float* vector,
+                                                 int n_batch, float* result,
+                                                 int result_stride);
+
+// Cwise product of two vectors.
+void PortableVectorVectorCwiseProduct(const float* vector1,
+                                      const float* vector2, int v_size,
+                                      float* result);
+
+// Cwise product and accumulate of two vectors. Since it's a MAC operation, the
+// assumption here is that result array is initialized to valid values.
+void PortableVectorVectorCwiseProductAccumulate(const float* vector1,
+                                                const float* vector2,
+                                                int v_size, float* result);
+
+// Dot product of two vectors.
+float PortableVectorVectorDotProduct(const float* vector1, const float* vector2,
+                                     int v_size);
+
+// Dot product of two batch vectors.
+void PortableBatchVectorBatchVectorDotProduct(const float* vector1,
+                                              const float* vector2, int v_size,
+                                              int n_batch, float* result,
+                                              int result_stride);
+
+// Cwise product and accumulate of a vector and a batch-vector. Since it's a MAC
+// operation, the assumption here is that result array is initialized to valid
+// values.
+void PortableVectorBatchVectorCwiseProductAccumulate(const float* vector,
+                                                     int v_size,
+                                                     const float* batch_vector,
+                                                     int n_batch,
+                                                     float* result);
+
+// Batch vector initialization with another vector.
+void PortableVectorBatchVectorAssign(const float* vector, int v_size,
+                                     int n_batch, float* batch_vector);
+
+// Apply sigmoid to elements of a vector.
+void PortableApplySigmoidToVector(const float* vector, int v_size,
+                                  float* result);
+
+// Apply activation function to elements of a vector.
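+// The activation is resolved through ActivationFunctor, so this one entry
+// point covers whichever fused activation kinds that functor supports.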
+void PortableApplyActivationToVector(const float* vector, int v_size, + TfLiteFusedActivation activation, + float* result); + +// Copy vector to another vector. +void PortableCopyVector(const float* vector, int v_size, float* result); + +// Compute "1.0f - elements of vector" (used in CIFG). +void PortableSub1Vector(const float* vector, int v_size, float* result); + +// Fill vector with 0.f. +void PortableZeroVector(float* vector, int v_size); + +// Clip elements of a vector using a abs_limit value. +void PortableClipVector(const float* vector, int v_size, float abs_limit, + float* result); + +// Shift left a vector in place with v_size size. +void PortableVectorShiftLeft(float* vector, int v_size, float shift_value); + +// Reduce-sum on a float input vector: +// input_vector: float pointer to input vector. +// output_vector: float pointer to vector. +// output_size: output vector size. +// reduction_size: number of consecutive elements from input vector which are +// added to get one element of output. +void PortableReductionSumVector(const float* input_vector, float* output_vector, + int output_size, int reduction_size); + +float Clip(float f, float abs_limit) { return PortableClip(f, abs_limit); } + +void MatrixBatchVectorMultiplyAccumulate(const float* matrix, int m_rows, + int m_cols, const float* vector, + int n_batch, float* result, + int result_stride) { + PortableMatrixBatchVectorMultiplyAccumulate(matrix, m_rows, m_cols, vector, + n_batch, result, result_stride); +} + +void VectorVectorCwiseProduct(const float* vector1, const float* vector2, + int v_size, float* result) { + PortableVectorVectorCwiseProduct(vector1, vector2, v_size, result); +} + +void VectorVectorCwiseProductAccumulate(const float* vector1, + const float* vector2, int v_size, + float* result) { + PortableVectorVectorCwiseProductAccumulate(vector1, vector2, v_size, result); +} + +void VectorBatchVectorCwiseProductAccumulate(const float* vector, int v_size, + const float* batch_vector, + int n_batch, float* result) { + PortableVectorBatchVectorCwiseProductAccumulate(vector, v_size, batch_vector, + n_batch, result); +} + +float VectorVectorDotProduct(const float* vector1, const float* vector2, + int v_size) { + return PortableVectorVectorDotProduct(vector1, vector2, v_size); +} + +void BatchVectorBatchVectorDotProduct(const float* vector1, + const float* vector2, int v_size, + int n_batch, float* result, + int result_stride) { + PortableBatchVectorBatchVectorDotProduct(vector1, vector2, v_size, n_batch, + result, result_stride); +} + +void VectorBatchVectorAssign(const float* vector, int v_size, int n_batch, + float* batch_vector) { + PortableVectorBatchVectorAssign(vector, v_size, n_batch, batch_vector); +} + +void ApplySigmoidToVector(const float* vector, int v_size, float* result) { + PortableApplySigmoidToVector(vector, v_size, result); +} + +void ApplyActivationToVector(const float* vector, int v_size, + TfLiteFusedActivation activation, float* result) { + PortableApplyActivationToVector(vector, v_size, activation, result); +} + +void CopyVector(const float* vector, int v_size, float* result) { + PortableCopyVector(vector, v_size, result); +} + +void Sub1Vector(const float* vector, int v_size, float* result) { + PortableSub1Vector(vector, v_size, result); +} + +void ZeroVector(float* vector, int v_size) { + PortableZeroVector(vector, v_size); +} + +void ClipVector(const float* vector, int v_size, float abs_limit, + float* result) { + PortableClipVector(vector, v_size, abs_limit, result); +} + +void 
VectorShiftLeft(float* vector, int v_size, float shift_value) { + PortableVectorShiftLeft(vector, v_size, shift_value); +} + +void ReductionSumVector(const float* input_vector, float* output_vector, + int output_size, int reduction_size) { + PortableReductionSumVector(input_vector, output_vector, output_size, + reduction_size); +} + +} // namespace tensor_utils +} // namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_PORTABLE_TENSOR_UTILS_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h new file mode 100644 index 0000000000..b9ca3d5c62 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h @@ -0,0 +1,2455 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_REFERENCE_OPS_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_REFERENCE_OPS_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "third_party/eigen3/Eigen/Core" +#include "fixedpoint/fixedpoint.h" +#include "public/gemmlowp.h" +#include "tensorflow/contrib/lite/kernels/internal/common.h" +#include "tensorflow/contrib/lite/kernels/internal/round.h" +#include "tensorflow/contrib/lite/kernels/internal/types.h" + +namespace tflite { +namespace reference_ops { + +inline int32 MultiplyByQuantizedMultiplierSmallerThanOne( + int32 x, int32 quantized_multiplier, int right_shift) { + using gemmlowp::RoundingDivideByPOT; + using gemmlowp::SaturatingRoundingDoublingHighMul; + return RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(x, quantized_multiplier), right_shift); +} + +inline int32 MultiplyByQuantizedMultiplierGreaterThanOne( + int32 x, int32 quantized_multiplier, int left_shift) { + using gemmlowp::SaturatingRoundingDoublingHighMul; + return SaturatingRoundingDoublingHighMul(x * (1 << left_shift), + quantized_multiplier); +} + +template +int CountLeadingZeros(T integer_input) { + static_assert(std::is_unsigned::value, + "Only unsigned integer types handled."); + const T one_in_leading_positive = static_cast(1) + << (std::numeric_limits::digits - 1); + int leading_zeros = 0; + while (integer_input < one_in_leading_positive) { + integer_input <<= 1; + ++leading_zeros; + } + return leading_zeros; +} + +// DO NOT USE THIS STRUCT FOR NEW FUNCTIONALITY BEYOND IMPLEMENTING ELEMENT-WISE +// BROADCASTING. +// +// NdArrayDesc describes the shape and memory layout of an N-dimensional +// rectangular array of numbers. +// +// NdArrayDesc is basically identical to Dims defined in types.h. +// However, as Dims is to be deprecated, this class exists as an adaptor +// to enable simple unoptimized implementations of element-wise broadcasting +// operations. +template +struct NdArrayDesc { + // The "extent" of each dimension. 
Indices along dimension d must be in the + // half-open interval [0, extents[d]). + int extents[N]; + + // The number of *elements* (not bytes) between consecutive indices of each + // dimension. + int strides[N]; +}; + +// DO NOT USE THIS FUNCTION FOR NEW FUNCTIONALITY BEYOND IMPLEMENTING +// ELEMENT-WISE BROADCASTING. +// +// Same as Offset(), except takes as NdArrayDesc instead of Dims. +inline int SubscriptToIndex(const NdArrayDesc<4>& desc, int i0, int i1, int i2, + int i3) { + TFLITE_DCHECK(i0 >= 0 && i0 < desc.extents[0]); + TFLITE_DCHECK(i1 >= 0 && i1 < desc.extents[1]); + TFLITE_DCHECK(i2 >= 0 && i2 < desc.extents[2]); + TFLITE_DCHECK(i3 >= 0 && i3 < desc.extents[3]); + return i0 * desc.strides[0] + i1 * desc.strides[1] + i2 * desc.strides[2] + + i3 * desc.strides[3]; +} + +// Given the dimensions of the operands for an element-wise binary broadcast, +// adjusts them so that they can be directly iterated over with simple loops. +// Returns the adjusted dims as instances of NdArrayDesc in 'desc0_out' and +// 'desc1_out'. 'desc0_out' and 'desc1_out' cannot be nullptr. +// +// This function assumes that the two input shapes are compatible up to +// broadcasting and the shorter one has already been prepended with 1s to be the +// same length. E.g., if shape0 is (1, 16, 16, 64) and shape1 is (1, 64), +// shape1 must already have been prepended to be (1, 1, 1, 64). Recall that +// Dims refer to shapes in reverse order. In this case, input0_dims will be +// (64, 16, 16, 1) and input1_dims will be (64, 1, 1, 1). +// +// When two shapes are compatible up to broadcasting, for each dimension d, +// the input extents are either equal, or one of them is 1. +// +// This function performs the following for each dimension d: +// - If the extents are equal, then do nothing since the loop that walks over +// both of the input arrays is correct. +// - Otherwise, one (and only one) of the extents must be 1. Say extent0 is 1 +// and extent1 is e1. Then set extent0 to e1 and stride0 *to 0*. This allows +// array0 to be referenced *at any index* in dimension d and still access the +// same slice. +template +inline void NdArrayDescsForElementwiseBroadcast(const Dims& input0_dims, + const Dims& input1_dims, + NdArrayDesc* desc0_out, + NdArrayDesc* desc1_out) { + TFLITE_DCHECK(desc0_out != nullptr); + TFLITE_DCHECK(desc1_out != nullptr); + + // Copy dims to desc. + for (int i = 0; i < N; ++i) { + desc0_out->extents[i] = input0_dims.sizes[i]; + desc0_out->strides[i] = input0_dims.strides[i]; + desc1_out->extents[i] = input1_dims.sizes[i]; + desc1_out->strides[i] = input1_dims.strides[i]; + } + + // Walk over each dimension. If the extents are equal do nothing. + // Otherwise, set the desc with extent 1 to have extent equal to the other and + // stride 0. 
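+  // With the shapes from the comment above (input0_dims = (64, 16, 16, 1),
+  // input1_dims = (64, 1, 1, 1)), dimensions 1 and 2 of desc1_out end up with
+  // extent 16 and stride 0, so the same 64-channel slice of input1 is read at
+  // every spatial position of input0.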
+ for (int i = 0; i < N; ++i) { + const int extent0 = ArraySize(input0_dims, i); + const int extent1 = ArraySize(input1_dims, i); + if (extent0 != extent1) { + if (extent0 == 1) { + desc0_out->strides[i] = 0; + desc0_out->extents[i] = extent1; + } else { + TFLITE_DCHECK_EQ(extent1, 1); + desc1_out->strides[i] = 0; + desc1_out->extents[i] = extent0; + } + } + } +} + +inline void Conv(const float* input_data, const Dims<4>& input_dims, + const float* filter_data, const Dims<4>& filter_dims, + const float* bias_data, const Dims<4>& bias_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, float output_activation_min, + float output_activation_max, float* output_data, + const Dims<4>& output_dims, float* im2col_data, + const Dims<4>& im2col_dims) { + (void)im2col_data; // only used in optimized code. + (void)im2col_dims; // only used in optimized code. + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int input_depth = MatchingArraySize(input_dims, 0, filter_dims, 0); + const int output_depth = MatchingArraySize(filter_dims, 3, output_dims, 0); + if (bias_data) { + TFLITE_DCHECK_EQ(ArraySize(filter_dims, 3), ArraySize(bias_dims, 0)); + } + const int input_height = ArraySize(input_dims, 2); + const int input_width = ArraySize(input_dims, 1); + const int filter_height = ArraySize(filter_dims, 2); + const int filter_width = ArraySize(filter_dims, 1); + const int output_height = ArraySize(output_dims, 2); + const int output_width = ArraySize(output_dims, 1); + for (int batch = 0; batch < batches; ++batch) { + for (int out_y = 0; out_y < output_height; ++out_y) { + for (int out_x = 0; out_x < output_width; ++out_x) { + for (int out_channel = 0; out_channel < output_depth; ++out_channel) { + const int in_x_origin = (out_x * stride_width) - pad_width; + const int in_y_origin = (out_y * stride_height) - pad_height; + float total = 0.f; + for (int filter_y = 0; filter_y < filter_height; ++filter_y) { + for (int filter_x = 0; filter_x < filter_width; ++filter_x) { + for (int in_channel = 0; in_channel < input_depth; ++in_channel) { + const int in_x = in_x_origin + filter_x; + const int in_y = in_y_origin + filter_y; + // If the location is outside the bounds of the input image, + // use zero as a default value. 
+ if ((in_x >= 0) && (in_x < input_width) && (in_y >= 0) && + (in_y < input_height)) { + float input_value = input_data[Offset(input_dims, in_channel, + in_x, in_y, batch)]; + float filter_value = + filter_data[Offset(filter_dims, in_channel, filter_x, + filter_y, out_channel)]; + total += (input_value * filter_value); + } + } + } + } + float bias_value = 0.0f; + if (bias_data) { + bias_value = bias_data[Offset(bias_dims, out_channel, 0, 0, 0)]; + } + output_data[Offset(output_dims, out_channel, out_x, out_y, batch)] = + ActivationFunctionWithMinMax(total + bias_value, + output_activation_min, + output_activation_max); + } + } + } + } +} + +// legacy, for compatibility with old checked-in code +template +void Conv(const float* input_data, const Dims<4>& input_dims, + const float* filter_data, const Dims<4>& filter_dims, + const float* bias_data, const Dims<4>& bias_dims, int stride_width, + int stride_height, int pad_width, int pad_height, float* output_data, + const Dims<4>& output_dims, float* im2col_data, + const Dims<4>& im2col_dims) { + float output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + Conv(input_data, input_dims, filter_data, filter_dims, bias_data, bias_dims, + stride_width, stride_height, pad_width, pad_height, + output_activation_min, output_activation_max, output_data, output_dims, + im2col_data, im2col_dims); +} + +// legacy, for compatibility with old checked-in code +template +void Conv(const float* input_data, const Dims<4>& input_dims, + const float* filter_data, const Dims<4>& filter_dims, + const float* bias_data, const Dims<4>& bias_dims, int stride, + int pad_width, int pad_height, float* output_data, + const Dims<4>& output_dims, float* im2col_data, + const Dims<4>& im2col_dims) { + Conv(input_data, input_dims, filter_data, filter_dims, bias_data, + bias_dims, stride, stride, pad_width, pad_height, output_data, + output_dims, im2col_data, im2col_dims); +} + +inline void Conv(const uint8* input_data, const Dims<4>& input_dims, + int32 input_offset, const uint8* filter_data, + const Dims<4>& filter_dims, int32 filter_offset, + const int32* bias_data, const Dims<4>& bias_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int32 output_offset, int32 output_multiplier, + int output_shift, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims, uint8* im2col_data, + const Dims<4>& im2col_dims, + gemmlowp::GemmContext* gemm_context) { + (void)im2col_data; // only used in optimized code. + (void)im2col_dims; // only used in optimized code. + (void)gemm_context; // only used in optimized code. 
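+  // Outline of the quantized path below: every tap accumulates
+  //   (input_val + input_offset) * (filter_val + filter_offset)
+  // into a 32-bit accumulator; the bias (if present) is added, the sum is
+  // rescaled via output_multiplier and output_shift, offset by output_offset,
+  // and finally clamped to [output_activation_min, output_activation_max]
+  // before being stored as uint8.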
+ TFLITE_DCHECK_LE(output_activation_min, output_activation_max); + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int input_depth = MatchingArraySize(input_dims, 0, filter_dims, 0); + const int output_depth = + MatchingArraySize(filter_dims, 3, bias_dims, 0, output_dims, 0); + const int input_height = ArraySize(input_dims, 2); + const int input_width = ArraySize(input_dims, 1); + const int filter_height = ArraySize(filter_dims, 2); + const int filter_width = ArraySize(filter_dims, 1); + const int output_height = ArraySize(output_dims, 2); + const int output_width = ArraySize(output_dims, 1); + for (int batch = 0; batch < batches; ++batch) { + for (int out_y = 0; out_y < output_height; ++out_y) { + for (int out_x = 0; out_x < output_width; ++out_x) { + for (int out_channel = 0; out_channel < output_depth; ++out_channel) { + const int in_x_origin = (out_x * stride_width) - pad_width; + const int in_y_origin = (out_y * stride_height) - pad_height; + int32 acc = 0; + for (int filter_y = 0; filter_y < filter_height; ++filter_y) { + for (int filter_x = 0; filter_x < filter_width; ++filter_x) { + for (int in_channel = 0; in_channel < input_depth; ++in_channel) { + const int in_x = in_x_origin + filter_x; + const int in_y = in_y_origin + filter_y; + // If the location is outside the bounds of the input image, + // use zero as a default value. + if ((in_x >= 0) && (in_x < input_width) && (in_y >= 0) && + (in_y < input_height)) { + int32 input_val = input_data[Offset(input_dims, in_channel, + in_x, in_y, batch)]; + int32 filter_val = + filter_data[Offset(filter_dims, in_channel, filter_x, + filter_y, out_channel)]; + acc += + (filter_val + filter_offset) * (input_val + input_offset); + } + } + } + } + if (bias_data) { + acc += bias_data[Offset(bias_dims, out_channel, 0, 0, 0)]; + } + acc = MultiplyByQuantizedMultiplierSmallerThanOne( + acc, output_multiplier, output_shift); + acc += output_offset; + acc = std::max(acc, output_activation_min); + acc = std::min(acc, output_activation_max); + output_data[Offset(output_dims, out_channel, out_x, out_y, batch)] = + static_cast(acc); + } + } + } + } +} + +// legacy, for compatibility with old checked-in code +template +inline void Conv(const uint8* input_data, const Dims<4>& input_dims, + int32 input_offset, const uint8* filter_data, + const Dims<4>& filter_dims, int32 filter_offset, + const int32* bias_data, const Dims<4>& bias_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int32 output_offset, int32 output_multiplier, + int output_shift, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims, uint8* im2col_data, + const Dims<4>& im2col_dims, + gemmlowp::GemmContext* gemm_context) { + static_assert(Ac == FusedActivationFunctionType::kNone || + Ac == FusedActivationFunctionType::kRelu || + Ac == FusedActivationFunctionType::kRelu6 || + Ac == FusedActivationFunctionType::kRelu1, + ""); + if (Ac == FusedActivationFunctionType::kNone) { + TFLITE_DCHECK_EQ(output_activation_min, 0); + TFLITE_DCHECK_EQ(output_activation_max, 255); + } + Conv(input_data, input_dims, input_offset, filter_data, filter_dims, + filter_offset, bias_data, bias_dims, stride_width, stride_height, + pad_width, pad_height, output_offset, output_multiplier, output_shift, + output_activation_min, output_activation_max, output_data, output_dims, + im2col_data, im2col_dims, gemm_context); +} + +// legacy, for compatibility with old checked-in code +template +void Conv(const 
uint8* input_data, const Dims<4>& input_dims, + int32 input_offset, const uint8* filter_data, + const Dims<4>& filter_dims, int32 filter_offset, + const int32* bias_data, const Dims<4>& bias_dims, int stride, + int pad_width, int pad_height, int32 output_offset, + int32 output_multiplier, int output_shift, + int32 output_activation_min, int32 output_activation_max, + uint8* output_data, const Dims<4>& output_dims, uint8* im2col_data, + const Dims<4>& im2col_dims, gemmlowp::GemmContext* gemm_context) { + Conv(input_data, input_dims, input_offset, filter_data, filter_dims, + filter_offset, bias_data, bias_dims, stride, stride, pad_width, + pad_height, output_offset, output_multiplier, output_shift, + output_activation_min, output_activation_max, output_data, + output_dims, im2col_data, im2col_dims, gemm_context); +} + +template +inline void DepthToSpace(const T* input_data, const Dims<4>& input_dims, + int block_size, T* output_data, + const Dims<4>& output_dims) { + const int input_depth = ArraySize(input_dims, 0); + const int input_width = ArraySize(input_dims, 1); + const int input_height = ArraySize(input_dims, 2); + const int input_batch = ArraySize(input_dims, 3); + + const int output_depth = ArraySize(output_dims, 0); + const int output_width = ArraySize(output_dims, 1); + const int output_height = ArraySize(output_dims, 2); + const int output_batch = ArraySize(output_dims, 3); + + TFLITE_DCHECK_EQ(input_width * block_size, output_width); + TFLITE_DCHECK_EQ(input_height * block_size, output_height); + TFLITE_DCHECK_EQ(input_depth, output_depth * block_size * block_size); + TFLITE_DCHECK_EQ(input_batch, output_batch); + + for (int out_b = 0; out_b < output_batch; ++out_b) { + for (int out_h = 0; out_h < output_height; ++out_h) { + for (int out_w = 0; out_w < output_width; ++out_w) { + for (int out_d = 0; out_d < output_depth; ++out_d) { + const int in_d = + out_d + ((out_h % block_size) * block_size + out_w % block_size) * + output_depth; + const int in_w = out_w / block_size; + const int in_h = out_h / block_size; + const int in_b = out_b; + + const int output_index = + Offset(output_dims, out_d, out_w, out_h, out_b); + const int input_index = Offset(input_dims, in_d, in_w, in_h, in_b); + + output_data[output_index] = input_data[input_index]; + } + } + } + } +} + +template +inline void SpaceToDepth(const T* input_data, const Dims<4>& input_dims, + int block_size, T* output_data, + const Dims<4>& output_dims) { + const int input_depth = ArraySize(input_dims, 0); + const int input_width = ArraySize(input_dims, 1); + const int input_height = ArraySize(input_dims, 2); + const int input_batch = ArraySize(input_dims, 3); + + const int output_depth = ArraySize(output_dims, 0); + const int output_width = ArraySize(output_dims, 1); + const int output_height = ArraySize(output_dims, 2); + const int output_batch = ArraySize(output_dims, 3); + + TFLITE_DCHECK_EQ(input_width, output_width * block_size); + TFLITE_DCHECK_EQ(input_height, output_height * block_size); + TFLITE_DCHECK_EQ(input_depth * block_size * block_size, output_depth); + TFLITE_DCHECK_EQ(input_batch, output_batch); + + for (int in_b = 0; in_b < input_batch; ++in_b) { + for (int in_h = 0; in_h < input_height; ++in_h) { + for (int in_w = 0; in_w < input_width; ++in_w) { + for (int in_d = 0; in_d < input_depth; ++in_d) { + const int out_d = + in_d + ((in_h % block_size) * block_size + in_w % block_size) * + input_depth; + const int out_w = in_w / block_size; + const int out_h = in_h / block_size; + const int out_b = in_b; + + const 
int output_index = + Offset(output_dims, out_d, out_w, out_h, out_b); + const int input_index = Offset(input_dims, in_d, in_w, in_h, in_b); + + output_data[output_index] = input_data[input_index]; + } + } + } + } +} + +inline void FullyConnected(const float* input_data, const Dims<4>& input_dims, + const float* weights_data, + const Dims<4>& weights_dims, const float* bias_data, + const Dims<4>& bias_dims, + float output_activation_min, + float output_activation_max, float* output_data, + const Dims<4>& output_dims) { + // TODO(benoitjacob): This really should be: + // const int batches = ArraySize(output_dims, 1); + // but the current --variable_batch hack consists in overwriting the 3rd + // dimension with the runtime batch size, as we don't keep track for each + // array of which dimension is the batch dimension in it. + const int batches = ArraySize(output_dims, 1) * ArraySize(output_dims, 2) * + ArraySize(output_dims, 3); + const int output_depth = MatchingArraySize(weights_dims, 1, output_dims, 0); + const int accum_depth = ArraySize(weights_dims, 0); + TFLITE_DCHECK(IsPackedWithoutStrides(input_dims)); + TFLITE_DCHECK(IsPackedWithoutStrides(weights_dims)); + for (int b = 0; b < batches; ++b) { + for (int out_c = 0; out_c < output_depth; ++out_c) { + float total = 0.f; + for (int d = 0; d < accum_depth; ++d) { + total += input_data[b * accum_depth + d] * + weights_data[out_c * accum_depth + d]; + } + float bias_value = 0.0f; + if (bias_data) { + bias_value = bias_data[Offset(bias_dims, out_c, 0, 0, 0)]; + } + output_data[out_c + output_depth * b] = ActivationFunctionWithMinMax( + total + bias_value, output_activation_min, output_activation_max); + } + } +} + +// legacy, for compatibility with old checked-in code +template +void FullyConnected(const float* input_data, const Dims<4>& input_dims, + const float* weights_data, const Dims<4>& weights_dims, + const float* bias_data, const Dims<4>& bias_dims, + float* output_data, const Dims<4>& output_dims) { + float output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + FullyConnected(input_data, input_dims, weights_data, weights_dims, bias_data, + bias_dims, output_activation_min, output_activation_max, + output_data, output_dims); +} + +inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims, + int32 input_offset, const uint8* filter_data, + const Dims<4>& filter_dims, int32 filter_offset, + const int32* bias_data, const Dims<4>& bias_dims, + int32 output_offset, int32 output_multiplier, + int output_shift, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims, + gemmlowp::GemmContext* gemm_context) { + (void)gemm_context; // only used in optimized code. + TFLITE_DCHECK_LE(output_activation_min, output_activation_max); + // TODO(benoitjacob): This really should be: + // const int batches = ArraySize(output_dims, 1); + // but the current --variable_batch hack consists in overwriting the 3rd + // dimension with the runtime batch size, as we don't keep track for each + // array of which dimension is the batch dimension in it. 
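+  // Concretely (illustrative sizes): an output of logical shape
+  // {output_depth, batches} may arrive here with sizes
+  // {output_depth, N1, N2, N3}; the product N1 * N2 * N3 below recovers the
+  // batch count regardless of which of the three trailing dimensions
+  // actually holds it, as long as the other two are 1.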
+ const int batches = ArraySize(output_dims, 1) * ArraySize(output_dims, 2) * + ArraySize(output_dims, 3); + const int output_depth = MatchingArraySize(filter_dims, 1, output_dims, 0); + const int accum_depth = ArraySize(filter_dims, 0); + TFLITE_DCHECK(IsPackedWithoutStrides(input_dims)); + TFLITE_DCHECK(IsPackedWithoutStrides(filter_dims)); + for (int b = 0; b < batches; ++b) { + for (int out_c = 0; out_c < output_depth; ++out_c) { + int32 acc = 0; + for (int d = 0; d < accum_depth; ++d) { + int32 input_val = input_data[b * accum_depth + d]; + int32 filter_val = filter_data[out_c * accum_depth + d]; + acc += (filter_val + filter_offset) * (input_val + input_offset); + } + if (bias_data) { + acc += bias_data[Offset(bias_dims, out_c, 0, 0, 0)]; + } + acc = MultiplyByQuantizedMultiplierSmallerThanOne(acc, output_multiplier, + output_shift); + acc += output_offset; + acc = std::max(acc, output_activation_min); + acc = std::min(acc, output_activation_max); + output_data[out_c + output_depth * b] = static_cast(acc); + } + } +} + +// legacy, for compatibility with old checked-in code +template +void FullyConnected(const uint8* input_data, const Dims<4>& input_dims, + int32 input_offset, const uint8* filter_data, + const Dims<4>& filter_dims, int32 filter_offset, + const int32* bias_data, const Dims<4>& bias_dims, + int32 output_offset, int32 output_multiplier, + int output_shift, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims, + gemmlowp::GemmContext* gemm_context) { + static_assert(Ac == FusedActivationFunctionType::kNone || + Ac == FusedActivationFunctionType::kRelu || + Ac == FusedActivationFunctionType::kRelu6 || + Ac == FusedActivationFunctionType::kRelu1, + ""); + if (Ac == FusedActivationFunctionType::kNone) { + TFLITE_DCHECK_EQ(output_activation_min, 0); + TFLITE_DCHECK_EQ(output_activation_max, 255); + } + FullyConnected(input_data, input_dims, input_offset, filter_data, filter_dims, + filter_offset, bias_data, bias_dims, output_offset, + output_multiplier, output_shift, output_activation_min, + output_activation_max, output_data, output_dims, gemm_context); +} + +template +void NonGlobalBatchNormalization( + const float* input_data, const Dims<4>& input_dims, const float* mean_data, + const Dims<4>& mean_dims, const float* multiplier_data, + const Dims<4>& multiplier_dims, const float* offset_data, + const Dims<4>& offset_dims, float* output_data, + const Dims<4>& output_dims) { + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = + MatchingArraySize(input_dims, 2, mean_dims, 2, multiplier_dims, 2, + offset_dims, 2, output_dims, 2); + const int width = + MatchingArraySize(input_dims, 1, mean_dims, 1, multiplier_dims, 1, + offset_dims, 1, output_dims, 1); + const int depth = + MatchingArraySize(input_dims, 0, mean_dims, 0, multiplier_dims, 0, + offset_dims, 0, output_dims, 0); + + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + for (int c = 0; c < depth; ++c) { + output_data[Offset(output_dims, c, x, y, b)] = ActivationFunction( + (input_data[Offset(input_dims, c, x, y, b)] - + mean_data[Offset(mean_dims, c, x, y, 0)]) * + multiplier_data[Offset(multiplier_dims, c, x, y, 0)] + + offset_data[Offset(offset_dims, c, x, y, 0)]); + } + } + } + } +} + +template +void GlobalBatchNormalization(const float* input_data, + const Dims<4>& input_dims, const float* mean_data, + const Dims<4>& mean_dims, + const float* multiplier_data, + 
const Dims<4>& multiplier_dims, + const float* offset_data, + const Dims<4>& offset_dims, float* output_data, + const Dims<4>& output_dims) { + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = MatchingArraySize(input_dims, 2, output_dims, 2); + const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int depth = + MatchingArraySize(input_dims, 0, mean_dims, 0, multiplier_dims, 0, + offset_dims, 0, output_dims, 0); + + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + for (int c = 0; c < depth; ++c) { + output_data[Offset(output_dims, c, x, y, b)] = ActivationFunction( + (input_data[Offset(input_dims, c, x, y, b)] - + mean_data[Offset(mean_dims, c, 0, 0, 0)]) * + multiplier_data[Offset(multiplier_dims, c, 0, 0, 0)] + + offset_data[Offset(offset_dims, c, 0, 0, 0)]); + } + } + } + } +} + +inline void Relu(const float* input_data, const Dims<4>& input_dims, + float* output_data, const Dims<4>& output_dims) { + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = MatchingArraySize(input_dims, 2, output_dims, 2); + const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + for (int c = 0; c < depth; ++c) { + float val = input_data[Offset(input_dims, c, x, y, b)]; + const float lower = 0; + float clamped = val < lower ? lower : val; + output_data[Offset(output_dims, c, x, y, b)] = clamped; + } + } + } + } +} + +inline void Relu1(const float* input_data, const Dims<4>& input_dims, + float* output_data, const Dims<4>& output_dims) { + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = MatchingArraySize(input_dims, 2, output_dims, 2); + const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + for (int c = 0; c < depth; ++c) { + float val = input_data[Offset(input_dims, c, x, y, b)]; + const float upper = 1; + const float lower = -1; + float clamped = val > upper ? upper : val < lower ? lower : val; + output_data[Offset(output_dims, c, x, y, b)] = clamped; + } + } + } + } +} + +inline void Relu6(const float* input_data, const Dims<4>& input_dims, + float* output_data, const Dims<4>& output_dims) { + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = MatchingArraySize(input_dims, 2, output_dims, 2); + const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + for (int c = 0; c < depth; ++c) { + float val = input_data[Offset(input_dims, c, x, y, b)]; + const float upper = 6; + const float lower = 0; + float clamped = val > upper ? upper : val < lower ? 
lower : val;
+          output_data[Offset(output_dims, c, x, y, b)] = clamped;
+        }
+      }
+    }
+  }
+}
+
+template <FusedActivationFunctionType Ac>
+void L2Normalization(const float* input_data, const Dims<4>& input_dims,
+                     float* output_data, const Dims<4>& output_dims) {
+  static_assert(Ac == FusedActivationFunctionType::kNone, "");
+  const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
+  const int height = MatchingArraySize(input_dims, 2, output_dims, 2);
+  const int width = MatchingArraySize(input_dims, 1, output_dims, 1);
+  const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
+  for (int b = 0; b < batches; ++b) {
+    for (int y = 0; y < height; ++y) {
+      for (int x = 0; x < width; ++x) {
+        float squared_l2_norm = 0;
+        for (int c = 0; c < depth; ++c) {
+          float val = input_data[Offset(input_dims, c, x, y, b)];
+          squared_l2_norm += val * val;
+        }
+        float l2_norm = std::sqrt(squared_l2_norm);
+        for (int c = 0; c < depth; ++c) {
+          output_data[Offset(output_dims, c, x, y, b)] =
+              input_data[Offset(input_dims, c, x, y, b)] / l2_norm;
+        }
+      }
+    }
+  }
+}
+
+inline void GetInvSqrtQuantizedMultiplier(int32 input, int32* output_inv_sqrt,
+                                          int* output_shift) {
+  *output_shift = 11;
+  while (input >= (1 << 29)) {
+    input /= 4;
+    ++*output_shift;
+  }
+  TFLITE_DCHECK_GT(input, 0);
+  const unsigned max_left_shift_bits = __builtin_clz(input) - 1;
+  const unsigned max_left_shift_bit_pairs = max_left_shift_bits / 2;
+  const unsigned left_shift_bit_pairs = max_left_shift_bit_pairs - 1;
+  *output_shift -= left_shift_bit_pairs;
+  input <<= 2 * left_shift_bit_pairs;
+  TFLITE_DCHECK_GE(input, (1 << 27));
+  TFLITE_DCHECK_LT(input, (1 << 29));
+  using gemmlowp::FixedPoint;
+  using gemmlowp::Rescale;
+  using gemmlowp::SaturatingRoundingMultiplyByPOT;
+  // Using 3 integer bits gives us enough room for the internal arithmetic in
+  // this Newton-Raphson iteration.
+  using F3 = FixedPoint<int32, 3>;
+  using F0 = FixedPoint<int32, 0>;
+  const F3 fixedpoint_input = F3::FromRaw(input >> 1);
+  const F3 fixedpoint_half_input =
+      SaturatingRoundingMultiplyByPOT<-1>(fixedpoint_input);
+  const F3 fixedpoint_half_three =
+      GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F3, (1 << 28) + (1 << 27), 1.5);
+  // Newton-Raphson iteration
+  // Naive unoptimized starting guess: x = 1
+  F3 x = F3::One();
+  // Naive unoptimized number of iterations: 5
+  for (int i = 0; i < 5; i++) {
+    const F3 x3 = Rescale<3>(x * x * x);
+    x = Rescale<3>(fixedpoint_half_three * x - fixedpoint_half_input * x3);
+  }
+  const F0 fixedpoint_half_sqrt_2 =
+      GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F0, 1518500250, std::sqrt(2.)
/ 2.); + x = x * fixedpoint_half_sqrt_2; + *output_inv_sqrt = x.raw(); + if (*output_shift < 0) { + *output_inv_sqrt <<= -*output_shift; + *output_shift = 0; + } +} + +inline void L2Normalization(const uint8* input_data, const Dims<4>& input_dims, + int32 input_zero_point, uint8* output_data, + const Dims<4>& output_dims) { + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = MatchingArraySize(input_dims, 2, output_dims, 2); + const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + TFLITE_DCHECK_EQ(batches, 1); + TFLITE_DCHECK_EQ(height, 1); + TFLITE_DCHECK_EQ(width, 1); + int32 square_l2_norm = 0; + for (int i = 0; i < depth; i++) { + int32 diff = input_data[Offset(input_dims, i, 0, 0, 0)] - input_zero_point; + square_l2_norm += diff * diff; + } + int32 inv_l2norm_multiplier; + int inv_l2norm_shift; + GetInvSqrtQuantizedMultiplier(square_l2_norm, &inv_l2norm_multiplier, + &inv_l2norm_shift); + + for (int i = 0; i < depth; i++) { + int32 diff = input_data[Offset(input_dims, i, 0, 0, 0)] - input_zero_point; + int32 rescaled_diff = MultiplyByQuantizedMultiplierSmallerThanOne( + 128 * diff, inv_l2norm_multiplier, inv_l2norm_shift); + int32 unclamped_output_val = 128 + rescaled_diff; + int32 output_val = std::min(255, std::max(0, unclamped_output_val)); + output_data[Offset(output_dims, i, 0, 0, 0)] = + static_cast(output_val); + } +} + +inline void Add(const float* input1_data, const Dims<4>& input1_dims, + const float* input2_data, const Dims<4>& input2_dims, + float output_activation_min, float output_activation_max, + float* output_data, const Dims<4>& output_dims) { + const int batches = + MatchingArraySize(input1_dims, 3, input2_dims, 3, output_dims, 3); + const int height = + MatchingArraySize(input1_dims, 2, input2_dims, 2, output_dims, 2); + const int width = + MatchingArraySize(input1_dims, 1, input2_dims, 1, output_dims, 1); + const int depth = + MatchingArraySize(input1_dims, 0, input2_dims, 0, output_dims, 0); + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + for (int c = 0; c < depth; ++c) { + output_data[Offset(output_dims, c, x, y, b)] = + ActivationFunctionWithMinMax( + input1_data[Offset(input1_dims, c, x, y, b)] + + input2_data[Offset(input2_dims, c, x, y, b)], + output_activation_min, output_activation_max); + } + } + } + } +} + +// legacy, for compatibility with old checked-in code +template +void Add(const float* input1_data, const Dims<4>& input1_dims, + const float* input2_data, const Dims<4>& input2_dims, + float* output_data, const Dims<4>& output_dims) { + float output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + + Add(input1_data, input1_dims, input2_data, input2_dims, output_activation_min, + output_activation_max, output_data, output_dims); +} + +template +inline void Add(int left_shift, const uint8* input1_data, + const Dims<4>& input1_dims, int32 input1_offset, + int32 input1_multiplier, int input1_shift, + const uint8* input2_data, const Dims<4>& input2_dims, + int32 input2_offset, int32 input2_multiplier, int input2_shift, + int32 output_offset, int32 output_multiplier, int output_shift, + int32 output_activation_min, int32 output_activation_max, + uint8* output_data, const Dims<4>& output_dims) { + static_assert(Ac == FusedActivationFunctionType::kNone || + Ac == FusedActivationFunctionType::kRelu || + Ac 
== FusedActivationFunctionType::kRelu6 || + Ac == FusedActivationFunctionType::kRelu1, + ""); + TFLITE_DCHECK_LE(output_activation_min, output_activation_max); + if (Ac == FusedActivationFunctionType::kNone) { + TFLITE_DCHECK_EQ(output_activation_min, 0); + TFLITE_DCHECK_EQ(output_activation_max, 255); + } + const int batches = + MatchingArraySize(input1_dims, 3, input2_dims, 3, output_dims, 3); + const int height = + MatchingArraySize(input1_dims, 2, input2_dims, 2, output_dims, 2); + const int width = + MatchingArraySize(input1_dims, 1, input2_dims, 1, output_dims, 1); + const int depth = + MatchingArraySize(input1_dims, 0, input2_dims, 0, output_dims, 0); + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + for (int c = 0; c < depth; ++c) { + const int32 input1_val = + input1_offset + input1_data[Offset(input1_dims, c, x, y, b)]; + const int32 input2_val = + input2_offset + input2_data[Offset(input2_dims, c, x, y, b)]; + const int32 shifted_input1_val = input1_val * (1 << left_shift); + const int32 shifted_input2_val = input2_val * (1 << left_shift); + const int32 scaled_input1_val = + MultiplyByQuantizedMultiplierSmallerThanOne( + shifted_input1_val, input1_multiplier, input1_shift); + const int32 scaled_input2_val = + MultiplyByQuantizedMultiplierSmallerThanOne( + shifted_input2_val, input2_multiplier, input2_shift); + const int32 raw_sum = scaled_input1_val + scaled_input2_val; + const int32 raw_output = + MultiplyByQuantizedMultiplierSmallerThanOne( + raw_sum, output_multiplier, output_shift) + + output_offset; + const int32 clamped_output = + std::min(output_activation_max, + std::max(output_activation_min, raw_output)); + output_data[Offset(output_dims, c, x, y, b)] = + static_cast(clamped_output); + } + } + } + } +} + +// TODO(jiawen): We can implement BroadcastAdd on buffers of arbitrary +// dimensionality if the runtime code does a single loop over one dimension +// that handles broadcasting as the base case. The code generator would then +// generate max(D1, D2) nested for loops. +template +void BroadcastAdd(const float* input1_data, const Dims<4>& input1_dims, + const float* input2_data, const Dims<4>& input2_dims, + float* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("BroadcastAdd"); + + NdArrayDesc<4> desc1; + NdArrayDesc<4> desc2; + NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2); + + // In Tensorflow, the dimensions are canonically named (batch_number, row, + // col, channel), with extents (batches, height, width, depth), with the + // trailing dimension changing most rapidly (channels has the smallest stride, + // typically 1 element). + // + // In generated C code, we store arrays with the dimensions reversed. The + // first dimension has smallest stride. + // + // We name our variables by their Tensorflow convention, but generate C code + // nesting loops such that the innermost loop has the smallest stride for the + // best cache behavior. 
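+  // Broadcasting itself is handled by the descriptors: for any dimension
+  // where one input has extent 1, NdArrayDescsForElementwiseBroadcast sets
+  // its stride to 0, so SubscriptToIndex returns the same element for every
+  // index along that dimension. E.g. (illustrative shapes) adding a
+  // per-channel bias with extents {depth, 1, 1, 1} to an activation with
+  // extents {depth, width, height, batches} reuses bias element c for every
+  // (x, y, b).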
+ for (int b = 0; b < ArraySize(output_dims, 3); ++b) { + for (int y = 0; y < ArraySize(output_dims, 2); ++y) { + for (int x = 0; x < ArraySize(output_dims, 1); ++x) { + for (int c = 0; c < ArraySize(output_dims, 0); ++c) { + output_data[Offset(output_dims, c, x, y, b)] = ActivationFunction( + input1_data[SubscriptToIndex(desc1, c, x, y, b)] + + input2_data[SubscriptToIndex(desc2, c, x, y, b)]); + } + } + } + } +} + +inline void BroadcastAdd(int left_shift, const uint8* input1_data, + const Dims<4>& input1_dims, int32 input1_offset, + int32 input1_multiplier, int input1_shift, + const uint8* input2_data, const Dims<4>& input2_dims, + int32 input2_offset, int32 input2_multiplier, + int input2_shift, int32 output_offset, + int32 output_multiplier, int output_shift, + int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("BroadcastAdd/8bit"); + + NdArrayDesc<4> desc1; + NdArrayDesc<4> desc2; + NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2); + + // In Tensorflow, the dimensions are canonically named (batch_number, row, + // col, channel), with extents (batches, height, width, depth), with the + // trailing dimension changing most rapidly (channels has the smallest stride, + // typically 1 element). + // + // In generated C code, we store arrays with the dimensions reversed. The + // first dimension has smallest stride. + // + // We name our variables by their Tensorflow convention, but generate C code + // nesting loops such that the innermost loop has the smallest stride for the + // best cache behavior. + for (int b = 0; b < ArraySize(output_dims, 3); ++b) { + for (int y = 0; y < ArraySize(output_dims, 2); ++y) { + for (int x = 0; x < ArraySize(output_dims, 1); ++x) { + for (int c = 0; c < ArraySize(output_dims, 0); ++c) { + const int32 input1_val = + input1_offset + input1_data[SubscriptToIndex(desc1, c, x, y, b)]; + const int32 input2_val = + input2_offset + input2_data[SubscriptToIndex(desc2, c, x, y, b)]; + const int32 shifted_input1_val = input1_val * (1 << left_shift); + const int32 shifted_input2_val = input2_val * (1 << left_shift); + const int32 scaled_input1_val = + MultiplyByQuantizedMultiplierSmallerThanOne( + shifted_input1_val, input1_multiplier, input1_shift); + const int32 scaled_input2_val = + MultiplyByQuantizedMultiplierSmallerThanOne( + shifted_input2_val, input2_multiplier, input2_shift); + const int32 raw_sum = scaled_input1_val + scaled_input2_val; + const int32 raw_output = + MultiplyByQuantizedMultiplierSmallerThanOne( + raw_sum, output_multiplier, output_shift) + + output_offset; + const int32 clamped_output = + std::min(output_activation_max, + std::max(output_activation_min, raw_output)); + output_data[Offset(output_dims, c, x, y, b)] = + static_cast(clamped_output); + } + } + } + } +} + +template +inline void BroadcastAdd(int left_shift, const uint8* input1_data, + const Dims<4>& input1_dims, int32 input1_offset, + int32 input1_multiplier, int input1_shift, + const uint8* input2_data, const Dims<4>& input2_dims, + int32 input2_offset, int32 input2_multiplier, + int input2_shift, int32 output_offset, + int32 output_multiplier, int output_shift, + int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + static_assert(Ac == FusedActivationFunctionType::kNone || + Ac == FusedActivationFunctionType::kRelu || + Ac == FusedActivationFunctionType::kRelu6 || + Ac == 
FusedActivationFunctionType::kRelu1, + ""); + TFLITE_DCHECK_LE(output_activation_min, output_activation_max); + if (Ac == FusedActivationFunctionType::kNone) { + TFLITE_DCHECK_EQ(output_activation_min, 0); + TFLITE_DCHECK_EQ(output_activation_max, 255); + } + BroadcastAdd(left_shift, input1_data, input1_dims, input1_offset, + input1_multiplier, input1_shift, input2_data, input2_dims, + input2_offset, input2_multiplier, input2_shift, output_offset, + output_multiplier, output_shift, output_activation_min, + output_activation_max, output_data, output_dims); +} + +inline void Mul(const float* input1_data, const Dims<4>& input1_dims, + const float* input2_data, const Dims<4>& input2_dims, + float output_activation_min, float output_activation_max, + float* output_data, const Dims<4>& output_dims) { + const int batches = + MatchingArraySize(input1_dims, 3, input2_dims, 3, output_dims, 3); + const int height = + MatchingArraySize(input1_dims, 2, input2_dims, 2, output_dims, 2); + const int width = + MatchingArraySize(input1_dims, 1, input2_dims, 1, output_dims, 1); + const int depth = + MatchingArraySize(input1_dims, 0, input2_dims, 0, output_dims, 0); + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + for (int c = 0; c < depth; ++c) { + output_data[Offset(output_dims, c, x, y, b)] = + ActivationFunctionWithMinMax( + input1_data[Offset(input1_dims, c, x, y, b)] * + input2_data[Offset(input2_dims, c, x, y, b)], + output_activation_min, output_activation_max); + } + } + } + } +} + +// legacy, for compatibility with old checked-in code +template +void Mul(const float* input1_data, const Dims<4>& input1_dims, + const float* input2_data, const Dims<4>& input2_dims, + float* output_data, const Dims<4>& output_dims) { + float output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + + Mul(input1_data, input1_dims, input2_data, input2_dims, output_activation_min, + output_activation_max, output_data, output_dims); +} + +// TODO(jiawen): We can implement BroadcastMul on buffers of arbitrary +// dimensionality if the runtime code does a single loop over one dimension +// that handles broadcasting as the base case. The code generator would then +// generate max(D1, D2) nested for loops. +template +void BroadcastMul(const float* input1_data, const Dims<4>& input1_dims, + const float* input2_data, const Dims<4>& input2_dims, + float* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("BroadcastMul"); + + NdArrayDesc<4> desc1; + NdArrayDesc<4> desc2; + NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2); + + // In Tensorflow, the dimensions are canonically named (batch_number, row, + // col, channel), with extents (batches, height, width, depth), with the + // trailing dimension changing most rapidly (channels has the smallest + // stride, typically 1 element). + // + // In generated C code, we store arrays with the dimensions reversed. The + // first dimension has smallest stride. + // + // We name our variables by their Tensorflow convention, but generate C code + // nesting loops such that the innermost loop has the smallest stride for + // the best cache behavior. 
+ for (int b = 0; b < ArraySize(output_dims, 3); ++b) { + for (int y = 0; y < ArraySize(output_dims, 2); ++y) { + for (int x = 0; x < ArraySize(output_dims, 1); ++x) { + for (int c = 0; c < ArraySize(output_dims, 0); ++c) { + output_data[Offset(output_dims, c, x, y, b)] = ActivationFunction( + input1_data[SubscriptToIndex(desc1, c, x, y, b)] * + input2_data[SubscriptToIndex(desc2, c, x, y, b)]); + } + } + } + } +} + +inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims, + int32 input1_offset, const uint8* input2_data, + const Dims<4>& input2_dims, int32 input2_offset, + int32 output_offset, int32 output_multiplier, + int output_shift, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("BroadcastMul/8bit"); + + NdArrayDesc<4> desc1; + NdArrayDesc<4> desc2; + NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2); + + // In Tensorflow, the dimensions are canonically named (batch_number, row, + // col, channel), with extents (batches, height, width, depth), with the + // trailing dimension changing most rapidly (channels has the smallest + // stride, typically 1 element). + // + // In generated C code, we store arrays with the dimensions reversed. The + // first dimension has smallest stride. + // + // We name our variables by their Tensorflow convention, but generate C code + // nesting loops such that the innermost loop has the smallest stride for + // the best cache behavior. + for (int b = 0; b < ArraySize(output_dims, 3); ++b) { + for (int y = 0; y < ArraySize(output_dims, 2); ++y) { + for (int x = 0; x < ArraySize(output_dims, 1); ++x) { + for (int c = 0; c < ArraySize(output_dims, 0); ++c) { + const int32 input1_val = + input1_offset + input1_data[SubscriptToIndex(desc1, c, x, y, b)]; + const int32 input2_val = + input2_offset + input2_data[SubscriptToIndex(desc2, c, x, y, b)]; + const int32 unclamped_result = + output_offset + + MultiplyByQuantizedMultiplierSmallerThanOne( + input1_val * input2_val, output_multiplier, output_shift); + const int32 clamped_output = + std::min(output_activation_max, + std::max(output_activation_min, unclamped_result)); + output_data[Offset(output_dims, c, x, y, b)] = + static_cast(clamped_output); + } + } + } + } +} + +// legacy, for compatibility with old checked-in code +template +inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims, + int32 input1_offset, const uint8* input2_data, + const Dims<4>& input2_dims, int32 input2_offset, + int32 output_offset, int32 output_multiplier, + int output_shift, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + BroadcastMul(input1_data, input1_dims, input1_offset, input2_data, + input2_dims, input2_offset, output_offset, output_multiplier, + output_shift, output_activation_min, output_activation_max, + output_data, output_dims); +} + +template +void Concatenation(int concat_dim, const Scalar* const* input_data, + const Dims<4>* const* input_dims, int inputs_count, + Scalar* output_data, const Dims<4>& output_dims) { + TFLITE_DCHECK_GT(inputs_count, 1); + int concat_size = 0; + for (int i = 0; i < inputs_count; i++) { + for (int j = 0; j < 4; j++) { + if (j != concat_dim) { + MatchingArraySize(*input_dims[i], j, output_dims, j); + } + } + concat_size += ArraySize(*input_dims[i], concat_dim); + } + TFLITE_DCHECK_EQ(concat_size, ArraySize(output_dims, concat_dim)); + TFLITE_DCHECK(Ac == 
FusedActivationFunctionType::kNone); + int outer_size = 1; + for (int i = concat_dim + 1; i < 4; i++) { + outer_size *= output_dims.sizes[i]; + } + Scalar* output_ptr = output_data; + for (int k = 0; k < outer_size; k++) { + for (int i = 0; i < inputs_count; ++i) { + const int copy_size = + input_dims[i]->sizes[concat_dim] * input_dims[i]->strides[concat_dim]; + memcpy(output_ptr, input_data[i] + k * copy_size, + copy_size * sizeof(Scalar)); + output_ptr += copy_size; + } + } +} + +template +void DepthConcatenation(const Scalar* const* input_data, + const Dims<4>* const* input_dims, int inputs_count, + Scalar* output_data, const Dims<4>& output_dims) { + Concatenation(0, input_data, input_dims, inputs_count, + output_data, output_dims); +} + +inline void LstmCell(const float* input_data, const Dims<4>& input_dims, + const float* prev_activ_data, + const Dims<4>& prev_activ_dims, const float* weights_data, + const Dims<4>& weights_dims, const float* bias_data, + const Dims<4>& bias_dims, const float* prev_state_data, + const Dims<4>& prev_state_dims, float* output_state_data, + const Dims<4>& output_state_dims, float* output_activ_data, + const Dims<4>& output_activ_dims, float* concat_temp_data, + const Dims<4>& concat_temp_dims, float* activ_temp_data, + const Dims<4>& activ_temp_dims) { + const int batches = + MatchingArraySize(input_dims, 3, prev_activ_dims, 3, prev_state_dims, 3, + output_state_dims, 3, output_activ_dims, 3); + const int height = + MatchingArraySize(input_dims, 2, prev_activ_dims, 2, prev_state_dims, 2, + output_state_dims, 2, output_activ_dims, 2); + const int width = + MatchingArraySize(input_dims, 1, prev_activ_dims, 1, prev_state_dims, 1, + output_state_dims, 1, output_activ_dims, 1); + TFLITE_CHECK_EQ(ArraySize(weights_dims, 2), 1); + TFLITE_CHECK_EQ(ArraySize(weights_dims, 3), 1); + const int input_depth = ArraySize(input_dims, 0); + const int prev_activ_depth = ArraySize(prev_activ_dims, 0); + const int total_input_depth = prev_activ_depth + input_depth; + TFLITE_CHECK_EQ(ArraySize(weights_dims, 0), total_input_depth); + TFLITE_CHECK_EQ(MatchingArraySize(bias_dims, 1, bias_dims, 2, bias_dims, 3), + 1); + const int intern_activ_depth = + MatchingArraySize(weights_dims, 1, bias_dims, 0); + TFLITE_CHECK_EQ(intern_activ_depth % 4, 0); + const int output_depth = + MatchingArraySize(prev_state_dims, 0, prev_activ_dims, 0, + output_state_dims, 0, output_activ_dims, 0); + TFLITE_CHECK_EQ(output_depth, intern_activ_depth / 4); + + // Concatenate prev_activ and input data together + std::vector concat_input_arrays_data; + std::vector const*> concat_input_arrays_dims; + concat_input_arrays_data.push_back(input_data); + concat_input_arrays_data.push_back(prev_activ_data); + concat_input_arrays_dims.push_back(&input_dims); + concat_input_arrays_dims.push_back(&prev_activ_dims); + Concatenation( + 0, &(concat_input_arrays_data[0]), &(concat_input_arrays_dims[0]), + concat_input_arrays_data.size(), concat_temp_data, concat_temp_dims); + + // Fully connected + FullyConnected( + concat_temp_data, concat_temp_dims, weights_data, weights_dims, bias_data, + bias_dims, activ_temp_data, activ_temp_dims); + + // Memory state update (the LSTM "guts") + for (int b = 0; b < batches; ++b) { + for (int w = 0; w < width; ++w) { + for (int h = 0; h < height; ++h) { + for (int c = 0; c < output_depth; ++c) { + const float input_gate = + 1.f / + (1.f + std::exp(-activ_temp_data[Offset( + activ_temp_dims, 0 * output_depth + c, w, h, b)])); + const float new_input = 
std::tanh(activ_temp_data[Offset( + activ_temp_dims, 1 * output_depth + c, w, h, b)]); + const float forget_gate = + 1.f / + (1.f + std::exp(-activ_temp_data[Offset( + activ_temp_dims, 2 * output_depth + c, w, h, b)])); + const float output_gate = + 1.f / + (1.f + std::exp(-activ_temp_data[Offset( + activ_temp_dims, 3 * output_depth + c, w, h, b)])); + const float new_state = + input_gate * new_input + + forget_gate * + prev_state_data[Offset(prev_state_dims, c, w, h, b)]; + output_state_data[Offset(output_state_dims, c, w, h, b)] = new_state; + output_activ_data[Offset(output_activ_dims, c, w, h, b)] = + output_gate * std::tanh(new_state); + } + } + } + } +} + +template +void TensorFlowSplit(const Scalar* input_data, const Dims<4>& input_dims, + int outputs_count, Scalar* const* output_data, + const Dims<4>* const* output_dims) { + TFLITE_DCHECK_GE(outputs_count, 1); + for (int i = 0; i < outputs_count; i++) { + /* batches = */ MatchingArraySize(*output_dims[i], 3, input_dims, 3); + /* height = */ MatchingArraySize(*output_dims[i], 2, input_dims, 2); + /* width = */ MatchingArraySize(*output_dims[i], 1, input_dims, 1); + } + const int batches = MatchingArraySize(*output_dims[0], 3, input_dims, 3); + const int height = MatchingArraySize(*output_dims[0], 2, input_dims, 2); + const int width = MatchingArraySize(*output_dims[0], 1, input_dims, 1); + // for now we dont have a model with a TensorFlowSplit + // with fused activation function. + TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone); + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + int in_c = 0; + for (int i = 0; i < outputs_count; ++i) { + const int depth = ArraySize(*output_dims[i], 0); + for (int c = 0; c < depth; ++c) { + output_data[i][Offset(*output_dims[i], c, x, y, b)] = + input_data[Offset(input_dims, in_c, x, y, b)]; + in_c++; + } + } + TFLITE_DCHECK(in_c == ArraySize(input_dims, 0)); + } + } + } +} + +// TODO(benoitjacob) make this a proper reference impl without Eigen! 
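+// The helpers below view a flat buffer described by a Dims<N> as a
+// column-major Eigen matrix: MapAsMatrixWithFirstDimAsRows uses dimension 0
+// as the rows and folds the remaining dimensions into the columns (e.g.,
+// illustrative sizes {8, 4, 4, 1} yield an 8 x 16 map), while
+// MapAsMatrixWithLastDimAsCols does the converse.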
+template <typename Scalar>
+using MatrixMap = typename std::conditional<
+    std::is_const<Scalar>::value,
+    Eigen::Map<const Eigen::Matrix<typename std::remove_const<Scalar>::type,
+                                   Eigen::Dynamic, Eigen::Dynamic>>,
+    Eigen::Map<Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic>>>::type;
+
+template <typename Scalar, int N>
+MatrixMap<Scalar> MapAsMatrixWithFirstDimAsRows(Scalar* data,
+                                                const Dims<N>& dims) {
+  const int rows = dims.sizes[0];
+  int cols = 1;
+  for (int d = 1; d < N; d++) {
+    cols *= dims.sizes[d];
+  }
+  return MatrixMap<Scalar>(data, rows, cols);
+}
+
+template <typename Scalar, int N>
+MatrixMap<Scalar> MapAsMatrixWithLastDimAsCols(Scalar* data,
+                                               const Dims<N>& dims) {
+  const int cols = dims.sizes[N - 1];
+  int rows = 1;
+  for (int d = 0; d < N - 1; d++) {
+    rows *= dims.sizes[d];
+  }
+  return MatrixMap<Scalar>(data, rows, cols);
+}
+
+inline int NodeOffset(int b, int h, int w, int height, int width) {
+  return (b * height + h) * width + w;
+}
+
+inline void AveragePool(const float* input_data, const Dims<4>& input_dims,
+                        int stride_width, int stride_height, int pad_width,
+                        int pad_height, int filter_width, int filter_height,
+                        float output_activation_min,
+                        float output_activation_max, float* output_data,
+                        const Dims<4>& output_dims) {
+  const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
+  const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
+  const int input_height = ArraySize(input_dims, 2);
+  const int input_width = ArraySize(input_dims, 1);
+  const int output_height = ArraySize(output_dims, 2);
+  const int output_width = ArraySize(output_dims, 1);
+  for (int batch = 0; batch < batches; ++batch) {
+    for (int out_y = 0; out_y < output_height; ++out_y) {
+      for (int out_x = 0; out_x < output_width; ++out_x) {
+        for (int channel = 0; channel < depth; ++channel) {
+          const int in_x_origin = (out_x * stride_width) - pad_width;
+          const int in_y_origin = (out_y * stride_height) - pad_height;
+          // Compute the boundaries of the filter region clamped so as to
+          // ensure that the filter window fits in the input array.
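+          // For instance (illustrative values): filter_width = 3,
+          // pad_width = 1 and out_x = 0 give in_x_origin = -1, hence
+          // filter_x_start = 1 and filter_x_end = min(3, input_width + 1),
+          // so only in-bounds taps are visited and filter_count counts just
+          // those taps rather than the full window.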
+ const int filter_x_start = std::max(0, -in_x_origin); + const int filter_x_end = + std::min(filter_width, input_width - in_x_origin); + const int filter_y_start = std::max(0, -in_y_origin); + const int filter_y_end = + std::min(filter_height, input_height - in_y_origin); + float total = 0.f; + float filter_count = 0; + for (int filter_y = filter_y_start; filter_y < filter_y_end; + ++filter_y) { + for (int filter_x = filter_x_start; filter_x < filter_x_end; + ++filter_x) { + const int in_x = in_x_origin + filter_x; + const int in_y = in_y_origin + filter_y; + total += + input_data[Offset(input_dims, channel, in_x, in_y, batch)]; + filter_count++; + } + } + const float average = total / filter_count; + output_data[Offset(output_dims, channel, out_x, out_y, batch)] = + ActivationFunctionWithMinMax(average, output_activation_min, + output_activation_max); + } + } + } + } +} + +// legacy, for compatibility with old checked-in code +template +void AveragePool(const float* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int filter_width, int filter_height, + float* output_data, const Dims<4>& output_dims) { + float output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + AveragePool(input_data, input_dims, stride_width, stride_height, pad_width, + pad_height, filter_width, filter_height, output_activation_min, + output_activation_max, output_data, output_dims); +} + +// legacy, for compatibility with old checked-in code +template +void AveragePool(const float* input_data, const Dims<4>& input_dims, int stride, + int pad_width, int pad_height, int filter_width, + int filter_height, float* output_data, + const Dims<4>& output_dims) { + AveragePool(input_data, input_dims, stride, stride, pad_width, pad_height, + filter_width, filter_height, output_data, output_dims); +} + +inline void AveragePool(const uint8* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int filter_width, int filter_height, + int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + TFLITE_DCHECK_LE(output_activation_min, output_activation_max); + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + const int input_height = ArraySize(input_dims, 2); + const int input_width = ArraySize(input_dims, 1); + const int output_height = ArraySize(output_dims, 2); + const int output_width = ArraySize(output_dims, 1); + for (int batch = 0; batch < batches; ++batch) { + for (int out_y = 0; out_y < output_height; ++out_y) { + for (int out_x = 0; out_x < output_width; ++out_x) { + for (int channel = 0; channel < depth; ++channel) { + const int in_x_origin = (out_x * stride_width) - pad_width; + const int in_y_origin = (out_y * stride_height) - pad_height; + // Compute the boundaries of the filter region clamped so as to + // ensure that the filter window fits in the input array. 
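+          // Note the rounded integer average further below:
+          // (acc + filter_count / 2) / filter_count rounds to nearest,
+          // e.g. acc = 41 over filter_count = 9 taps gives (41 + 4) / 9 = 5,
+          // whereas truncating division would give 4 (illustrative numbers).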
+ const int filter_x_start = std::max(0, -in_x_origin); + const int filter_x_end = + std::min(filter_width, input_width - in_x_origin); + const int filter_y_start = std::max(0, -in_y_origin); + const int filter_y_end = + std::min(filter_height, input_height - in_y_origin); + int32 acc = 0; + int filter_count = 0; + for (int filter_y = filter_y_start; filter_y < filter_y_end; + ++filter_y) { + for (int filter_x = filter_x_start; filter_x < filter_x_end; + ++filter_x) { + const int in_x = in_x_origin + filter_x; + const int in_y = in_y_origin + filter_y; + acc += input_data[Offset(input_dims, channel, in_x, in_y, batch)]; + filter_count++; + } + } + acc = (acc + filter_count / 2) / filter_count; + acc = std::max(acc, output_activation_min); + acc = std::min(acc, output_activation_max); + output_data[Offset(output_dims, channel, out_x, out_y, batch)] = + static_cast(acc); + } + } + } + } +} + +// legacy, for compatibility with old checked-in code +template +void AveragePool(const uint8* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int filter_width, int filter_height, + int32 output_activation_min, int32 output_activation_max, + uint8* output_data, const Dims<4>& output_dims) { + static_assert(Ac == FusedActivationFunctionType::kNone || + Ac == FusedActivationFunctionType::kRelu || + Ac == FusedActivationFunctionType::kRelu6 || + Ac == FusedActivationFunctionType::kRelu1, + ""); + if (Ac == FusedActivationFunctionType::kNone) { + TFLITE_DCHECK_EQ(output_activation_min, 0); + TFLITE_DCHECK_EQ(output_activation_max, 255); + } + AveragePool(input_data, input_dims, stride_width, stride_height, pad_width, + pad_height, filter_width, filter_height, output_activation_min, + output_activation_max, output_data, output_dims); +} + +// legacy, for compatibility with old checked-in code +template +void AveragePool(const uint8* input_data, const Dims<4>& input_dims, int stride, + int pad_width, int pad_height, int filter_width, + int filter_height, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + AveragePool(input_data, input_dims, stride, stride, pad_width, pad_height, + filter_width, filter_height, output_activation_min, + output_activation_max, output_data, output_dims); +} + +inline void L2Pool(const float* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int filter_width, int filter_height, + float output_activation_min, float output_activation_max, + float* output_data, const Dims<4>& output_dims) { + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + const int input_height = ArraySize(input_dims, 2); + const int input_width = ArraySize(input_dims, 1); + const int output_height = ArraySize(output_dims, 2); + const int output_width = ArraySize(output_dims, 1); + for (int batch = 0; batch < batches; ++batch) { + for (int out_y = 0; out_y < output_height; ++out_y) { + for (int out_x = 0; out_x < output_width; ++out_x) { + for (int channel = 0; channel < depth; ++channel) { + const int in_x_origin = (out_x * stride_width) - pad_width; + const int in_y_origin = (out_y * stride_height) - pad_height; + // Compute the boundaries of the filter region clamped so as to + // ensure that the filter window fits in the input array. 
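+          // L2 pooling takes sqrt(sum_squares / filter_count) over the
+          // clamped window, i.e. the root-mean-square of the covered inputs;
+          // e.g. a window covering {1, 1, 1, 5} gives sqrt(28 / 4) ~= 2.65
+          // (illustrative values).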
+ const int filter_x_start = std::max(0, -in_x_origin); + const int filter_x_end = + std::min(filter_width, input_width - in_x_origin); + const int filter_y_start = std::max(0, -in_y_origin); + const int filter_y_end = + std::min(filter_height, input_height - in_y_origin); + float sum_squares = 0.f; + int filter_count = 0; + for (int filter_y = filter_y_start; filter_y < filter_y_end; + ++filter_y) { + for (int filter_x = filter_x_start; filter_x < filter_x_end; + ++filter_x) { + const int in_x = in_x_origin + filter_x; + const int in_y = in_y_origin + filter_y; + const float val = + input_data[Offset(input_dims, channel, in_x, in_y, batch)]; + sum_squares += val * val; + filter_count++; + } + } + const float l2pool_result = std::sqrt(sum_squares / filter_count); + output_data[Offset(output_dims, channel, out_x, out_y, batch)] = + ActivationFunctionWithMinMax(l2pool_result, output_activation_min, + output_activation_max); + } + } + } + } +} + +// legacy, for compatibility with old checked-in code +template +void L2Pool(const float* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, int pad_height, + int filter_width, int filter_height, float* output_data, + const Dims<4>& output_dims) { + float output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + + L2Pool(input_data, input_dims, stride_width, stride_height, pad_width, + pad_height, filter_width, filter_height, output_activation_min, + output_activation_max, output_data, output_dims); +} + +// legacy, for compatibility with old checked-in code +template +void L2Pool(const float* input_data, const Dims<4>& input_dims, int stride, + int pad_width, int pad_height, int filter_width, int filter_height, + float* output_data, const Dims<4>& output_dims) { + L2Pool(input_data, input_dims, stride, stride, pad_width, pad_height, + filter_width, filter_height, output_data, output_dims); +} + +inline void MaxPool(const float* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int filter_width, int filter_height, + float output_activation_min, float output_activation_max, + float* output_data, const Dims<4>& output_dims) { + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + const int input_height = ArraySize(input_dims, 2); + const int input_width = ArraySize(input_dims, 1); + const int output_height = ArraySize(output_dims, 2); + const int output_width = ArraySize(output_dims, 1); + for (int batch = 0; batch < batches; ++batch) { + for (int out_y = 0; out_y < output_height; ++out_y) { + for (int out_x = 0; out_x < output_width; ++out_x) { + for (int channel = 0; channel < depth; ++channel) { + const int in_x_origin = (out_x * stride_width) - pad_width; + const int in_y_origin = (out_y * stride_height) - pad_height; + // Compute the boundaries of the filter region clamped so as to + // ensure that the filter window fits in the input array. 
+ const int filter_x_start = std::max(0, -in_x_origin); + const int filter_x_end = + std::min(filter_width, input_width - in_x_origin); + const int filter_y_start = std::max(0, -in_y_origin); + const int filter_y_end = + std::min(filter_height, input_height - in_y_origin); + float max = std::numeric_limits::lowest(); + for (int filter_y = filter_y_start; filter_y < filter_y_end; + ++filter_y) { + for (int filter_x = filter_x_start; filter_x < filter_x_end; + ++filter_x) { + const int in_x = in_x_origin + filter_x; + const int in_y = in_y_origin + filter_y; + max = std::max( + max, + input_data[Offset(input_dims, channel, in_x, in_y, batch)]); + } + } + output_data[Offset(output_dims, channel, out_x, out_y, batch)] = + ActivationFunctionWithMinMax(max, output_activation_min, + output_activation_max); + } + } + } + } +} + +// legacy, for compatibility with old checked-in code +template +void MaxPool(const float* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, int pad_height, + int filter_width, int filter_height, float* output_data, + const Dims<4>& output_dims) { + float output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + MaxPool(input_data, input_dims, stride_width, stride_height, pad_width, + pad_height, filter_width, filter_height, output_activation_min, + output_activation_max, output_data, output_dims); +} + +// legacy, for compatibility with old checked-in code +template +void MaxPool(const float* input_data, const Dims<4>& input_dims, int stride, + int pad_width, int pad_height, int filter_width, int filter_height, + float* output_data, const Dims<4>& output_dims) { + MaxPool(input_data, input_dims, stride, stride, pad_width, pad_height, + filter_width, filter_height, output_data, output_dims); +} + +inline void MaxPool(const uint8* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int filter_width, int filter_height, + int32 output_activation_min, int32 output_activation_max, + uint8* output_data, const Dims<4>& output_dims) { + TFLITE_DCHECK_LE(output_activation_min, output_activation_max); + TFLITE_DCHECK_GE(output_activation_min, 0); + TFLITE_DCHECK_LE(output_activation_max, 255); + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + const int input_height = ArraySize(input_dims, 2); + const int input_width = ArraySize(input_dims, 1); + const int output_height = ArraySize(output_dims, 2); + const int output_width = ArraySize(output_dims, 1); + for (int batch = 0; batch < batches; ++batch) { + for (int out_y = 0; out_y < output_height; ++out_y) { + for (int out_x = 0; out_x < output_width; ++out_x) { + for (int channel = 0; channel < depth; ++channel) { + const int in_x_origin = (out_x * stride_width) - pad_width; + const int in_y_origin = (out_y * stride_height) - pad_height; + // Compute the boundaries of the filter region clamped so as to + // ensure that the filter window fits in the input array. 
+ const int filter_x_start = std::max(0, -in_x_origin); + const int filter_x_end = + std::min(filter_width, input_width - in_x_origin); + const int filter_y_start = std::max(0, -in_y_origin); + const int filter_y_end = + std::min(filter_height, input_height - in_y_origin); + uint8 max = 0; + for (int filter_y = filter_y_start; filter_y < filter_y_end; + ++filter_y) { + for (int filter_x = filter_x_start; filter_x < filter_x_end; + ++filter_x) { + const int in_x = in_x_origin + filter_x; + const int in_y = in_y_origin + filter_y; + max = std::max( + max, + input_data[Offset(input_dims, channel, in_x, in_y, batch)]); + } + } + max = std::max(max, output_activation_min); + max = std::min(max, output_activation_max); + output_data[Offset(output_dims, channel, out_x, out_y, batch)] = + static_cast(max); + } + } + } + } +} + +// legacy, for compatibility with old checked-in code +template +void MaxPool(const uint8* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, int pad_height, + int filter_width, int filter_height, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + static_assert(Ac == FusedActivationFunctionType::kNone || + Ac == FusedActivationFunctionType::kRelu || + Ac == FusedActivationFunctionType::kRelu6 || + Ac == FusedActivationFunctionType::kRelu1, + ""); + if (Ac == FusedActivationFunctionType::kNone) { + TFLITE_DCHECK_EQ(output_activation_min, 0); + TFLITE_DCHECK_EQ(output_activation_max, 255); + } + MaxPool(input_data, input_dims, stride_width, stride_height, pad_width, + pad_height, filter_width, filter_height, output_activation_min, + output_activation_max, output_data, output_dims); +} + +// legacy, for compatibility with old checked-in code +template +void MaxPool(const uint8* input_data, const Dims<4>& input_dims, int stride, + int pad_width, int pad_height, int filter_width, int filter_height, + int32 output_activation_min, int32 output_activation_max, + uint8* output_data, const Dims<4>& output_dims) { + MaxPool(input_data, input_dims, stride, stride, pad_width, pad_height, + filter_width, filter_height, output_activation_min, + output_activation_max, output_data, output_dims); +} + +inline void LocalResponseNormalization(const float* input_data, + const Dims<4>& input_dims, int range, + float bias, float alpha, float beta, + float* output_data, + const Dims<4>& output_dims) { + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = MatchingArraySize(input_dims, 2, output_dims, 2); + const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + for (int c = 0; c < depth; ++c) { + const int begin_input_c = std::max(0, c - range); + const int end_input_c = std::min(depth, c + range); + float accum = 0.f; + for (int input_c = begin_input_c; input_c < end_input_c; ++input_c) { + const float input_val = + input_data[Offset(input_dims, input_c, x, y, b)]; + accum += input_val * input_val; + } + const float multiplier = std::pow(bias + alpha * accum, -beta); + output_data[Offset(output_dims, c, x, y, b)] = + input_data[Offset(input_dims, c, x, y, b)] * multiplier; + } + } + } + } +} + +inline void Softmax(const float* input_data, const Dims<4>& input_dims, + float beta, float* output_data, + const Dims<4>& output_dims) { + const int batches = 
MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = MatchingArraySize(input_dims, 2, output_dims, 2); + const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + // Find max element value which we'll use to ensure numerical stability + // taking advantage of the following equality: + // exp(x[i])/sum(exp(x[i])) == exp(x[i]+C)/sum(exp(x[i]+C)) + float max = std::numeric_limits::lowest(); + for (int c = 0; c < depth; ++c) { + max = std::max(max, input_data[Offset(input_dims, c, x, y, b)]); + } + + // Compute sum. + float sum = 0.f; + for (int c = 0; c < depth; ++c) { + sum += std::exp((input_data[Offset(input_dims, c, x, y, b)] - max) * + beta); + } + + // Compute result. + for (int c = 0; c < depth; ++c) { + output_data[Offset(output_dims, c, x, y, b)] = + std::exp((input_data[Offset(input_dims, c, x, y, b)] - max) * + beta) / + sum; + } + } + } + } +} + +inline void Softmax(const uint8* input_data, const Dims<4>& input_dims, + int32 input_beta_multiplier, int32 input_beta_left_shift, + int diff_min, uint8* output_data, + const Dims<4>& output_dims) { + // The representation chosen for the input to the exp() function is Q5.26. + // We need to leave extra space since values that we skip might be as large as + // -32 before multiplying by input_beta_multiplier, and therefore as large as + // -16 afterwards. Note that exp(-8) is definitely not insignificant to + // accumulation, but exp(-16) definitely is. + static const int kScaledDiffIntegerBits = 5; + static const int kAccumulationIntegerBits = 12; + using FixedPointScaledDiff = + gemmlowp::FixedPoint; + using FixedPointAccum = gemmlowp::FixedPoint; + using FixedPoint0 = gemmlowp::FixedPoint; + + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = MatchingArraySize(input_dims, 2, output_dims, 2); + const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + + for (int b = 0; b < batches; ++b) { + for (int x = 0; x < width; ++x) { + for (int y = 0; y < height; ++y) { + uint8 max_in_row = 0; + for (int c = 0; c < depth; ++c) { + max_in_row = + std::max(max_in_row, input_data[Offset(input_dims, c, x, y, b)]); + } + + FixedPointAccum sum_of_exps = FixedPointAccum::Zero(); + for (int c = 0; c < depth; ++c) { + int32 input_diff = + static_cast(input_data[Offset(input_dims, c, x, y, b)]) - + max_in_row; + if (input_diff >= diff_min) { + const int32 input_diff_rescaled = + MultiplyByQuantizedMultiplierGreaterThanOne( + input_diff, input_beta_multiplier, input_beta_left_shift); + const FixedPointScaledDiff scaled_diff_f8 = + FixedPointScaledDiff::FromRaw(input_diff_rescaled); + sum_of_exps = + sum_of_exps + gemmlowp::Rescale( + exp_on_negative_values(scaled_diff_f8)); + } + } + + int32 fixed_sum_of_exps = sum_of_exps.raw(); + int headroom_plus_one = + CountLeadingZeros(static_cast(fixed_sum_of_exps)); + // This is the number of bits to the left of the binary point above 1.0. + // Consider fixed_sum_of_exps=1.25. In that case shifted_scale=0.8 and + // no later adjustment will be needed. 
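The float Softmax above relies on the identity exp(x_i - C) / sum_j exp(x_j - C) = exp(x_i) / sum_j exp(x_j) with C = max_j x_j, which keeps every exponent non-positive. A small self-contained check of why that matters for large logits; this is an illustrative sketch, not TFLite code.

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Softmax with the max-subtraction used by the reference kernel. Without it,
// std::exp(1000.f) overflows to +inf and every output becomes NaN.
std::vector<float> StableSoftmax(const std::vector<float>& x, float beta) {
  const float max = *std::max_element(x.begin(), x.end());
  float sum = 0.f;
  for (float v : x) sum += std::exp((v - max) * beta);
  std::vector<float> out;
  for (float v : x) out.push_back(std::exp((v - max) * beta) / sum);
  return out;
}

int main() {
  // Large logits that would overflow a naive implementation.
  for (float p : StableSoftmax({1000.f, 1001.f, 1002.f}, /*beta=*/1.f))
    std::printf("%f ", p);  // ~0.0900 0.2447 0.6652
  std::printf("\n");
  return 0;
}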
+ int num_bits_over_unit = kAccumulationIntegerBits - headroom_plus_one; + int32 shifted_sum_minus_one = static_cast( + (static_cast(fixed_sum_of_exps) << headroom_plus_one) - + (static_cast(1) << 31)); + + FixedPoint0 shifted_scale = gemmlowp::one_over_one_plus_x_for_x_in_0_1( + FixedPoint0::FromRaw(shifted_sum_minus_one)); + + for (int c = 0; c < depth; ++c) { + int32 input_diff = + static_cast(input_data[Offset(input_dims, c, x, y, b)]) - + max_in_row; + if (input_diff >= diff_min) { + const int32 input_diff_rescaled = + MultiplyByQuantizedMultiplierGreaterThanOne( + input_diff, input_beta_multiplier, input_beta_left_shift); + const FixedPointScaledDiff scaled_diff_f8 = + FixedPointScaledDiff::FromRaw(input_diff_rescaled); + + FixedPoint0 exp_in_0 = exp_on_negative_values(scaled_diff_f8); + int32 unsat_output = gemmlowp::RoundingDivideByPOT( + (shifted_scale * exp_in_0).raw(), num_bits_over_unit + 31 - 8); + + output_data[Offset(output_dims, c, x, y, b)] = static_cast( + std::max(std::min(unsat_output, static_cast(255)), 0)); + + } else { + output_data[Offset(output_dims, c, x, y, b)] = 0; + } + } + } + } + } +} + +inline void Logistic(const float* input_data, const Dims<4>& input_dims, + float* output_data, const Dims<4>& output_dims) { + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = MatchingArraySize(input_dims, 2, output_dims, 2); + const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + for (int c = 0; c < depth; ++c) { + float val = input_data[Offset(input_dims, c, x, y, b)]; + float result = 1.f / (1.f + std::exp(-val)); + output_data[Offset(output_dims, c, x, y, b)] = result; + } + } + } + } +} + +inline void Logistic(const uint8* input_data, const Dims<4>& input_dims, + int32 input_zero_point, int32 input_range_radius, + int32 input_multiplier, int input_left_shift, + uint8* output_data, const Dims<4>& output_dims) { + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = MatchingArraySize(input_dims, 2, output_dims, 2); + const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + for (int c = 0; c < depth; ++c) { + const uint8 input_val_u8 = input_data[Offset(input_dims, c, x, y, b)]; + const int32 input_val_centered = + static_cast(input_val_u8) - input_zero_point; + uint8 output_val; + if (input_val_centered <= -input_range_radius) { + output_val = 0; + } else if (input_val_centered >= input_range_radius) { + output_val = 255; + } else { + const int32 input_val_rescaled = + MultiplyByQuantizedMultiplierGreaterThanOne( + input_val_centered, input_multiplier, input_left_shift); + using FixedPoint4 = gemmlowp::FixedPoint; + using FixedPoint0 = gemmlowp::FixedPoint; + const FixedPoint4 input_val_f4 = + FixedPoint4::FromRaw(input_val_rescaled); + const FixedPoint0 output_val_f0 = gemmlowp::logistic(input_val_f4); + using gemmlowp::RoundingDivideByPOT; + int32 output_val_s32 = RoundingDivideByPOT(output_val_f0.raw(), 23); + if (output_val_s32 == 256) { + output_val_s32 = 255; + } + TFLITE_DCHECK_GE(output_val_s32, 0); + TFLITE_DCHECK_LE(output_val_s32, 255); + output_val = static_cast(output_val_s32); + } + 
output_data[Offset(output_dims, c, x, y, b)] = output_val; + } + } + } + } +} + +inline void Tanh(const float* input_data, const Dims<4>& input_dims, + float* output_data, const Dims<4>& output_dims) { + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = MatchingArraySize(input_dims, 2, output_dims, 2); + const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + for (int c = 0; c < depth; ++c) { + float val = input_data[Offset(input_dims, c, x, y, b)]; + float result = std::tanh(val); + output_data[Offset(output_dims, c, x, y, b)] = result; + } + } + } + } +} + +inline void Dequantize(const uint8* input_data, const Dims<4>& input_dims, + int32 zero_point, double scale, float* output_data, + const Dims<4>& output_dims) { + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = MatchingArraySize(input_dims, 2, output_dims, 2); + const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + for (int c = 0; c < depth; ++c) { + int32 val = input_data[Offset(input_dims, c, x, y, b)]; + float result = static_cast(scale * (val - zero_point)); + output_data[Offset(output_dims, c, x, y, b)] = result; + } + } + } + } +} + +inline void FakeQuant(const float* input_data, const Dims<4>& input_dims, + float rmin, float rmax, float* output_data, + const Dims<4>& output_dims) { + // 0 should always be a representable value. Let's assume that the initial + // min,max range contains 0. + TFLITE_DCHECK_LE(rmin, 0.); + TFLITE_DCHECK_GE(rmax, 0.); + + // Determine quantization parameters: zero_point, scale. + using Integer = uint8; + const Integer qmin = std::numeric_limits::min(); + const Integer qmax = std::numeric_limits::max(); + const float qmin_float = qmin; + const float qmax_float = qmax; + int32 zero_point = 0; + float scale = 0.f; + // If rmin==rmax, both must be zero per the above assertion, + // so we are done. + if (rmin != rmax) { + // First determine the scale. + scale = (rmax - rmin) / (qmax_float - qmin_float); + + // Zero-point computation. + // First the initial floating-point computation. The zero-point can be + // determined from solving an affine equation for any known pair + // (real value, corresponding quantized value). + // We know two such pairs: (rmin, qmin) and (rmax, qmax). + // The arithmetic error on the zero point computed from either pair + // will be roughly machine_epsilon * (sum of absolute values of terms) + // so we want to use the variant that adds the smaller terms. + const float zero_point_from_min = qmin_float - rmin / scale; + const float zero_point_from_max = qmax_float - rmax / scale; + const float zero_point_from_min_error = + std::abs(qmin_float) + std::abs(rmin / scale); + const float zero_point_from_max_error = + std::abs(qmax_float) + std::abs(rmax / scale); + + const float zero_point_float = + zero_point_from_min_error < zero_point_from_max_error + ? 
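Dequantize applies the usual affine mapping real = scale * (quantized - zero_point). A tiny illustrative check with made-up quantization parameters:

#include <cstdint>
#include <cstdio>

int main() {
  // Hypothetical quantization parameters: scale = 0.5, zero_point = 128.
  const double scale = 0.5;
  const int32_t zero_point = 128;
  const uint8_t quantized[] = {0, 127, 128, 255};
  for (uint8_t q : quantized) {
    // real = scale * (q - zero_point): 0 -> -64, 127 -> -0.5, 128 -> 0, 255 -> 63.5
    std::printf("q=%3d  real=%g\n", q, scale * (q - zero_point));
  }
  return 0;
}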
zero_point_from_min + : zero_point_from_max; + + // Now we need to nudge the zero point to be an integer + // (our zero points are integer, and this is motivated by the requirement + // to be able to represent the real value "0" exactly as a quantized value, + // which is required in multiple places, for example in Im2col with SAME + // padding). + if (zero_point_float < qmin_float) { + zero_point = qmin; + } else if (zero_point_float > qmax_float) { + zero_point = qmax; + } else { + zero_point = static_cast(TfLiteRound(zero_point_float)); + } + // The zero point should always be in the range of quantized value, + // [qmin, qmax]. + TFLITE_DCHECK_GE(zero_point, qmin); + TFLITE_DCHECK_LE(zero_point, qmax); + } + + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = MatchingArraySize(input_dims, 2, output_dims, 2); + const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + for (int c = 0; c < depth; ++c) { + const float src_val = input_data[Offset(input_dims, c, x, y, b)]; + const float unclamped_quantized_val = + TfLiteRound(zero_point + src_val / scale); + const float quantized_val = std::min( + qmax_float, std::max(qmin_float, unclamped_quantized_val)); + const float dst_val = scale * (quantized_val - zero_point); + output_data[Offset(output_dims, c, x, y, b)] = dst_val; + } + } + } + } +} + +template +inline void Cast(const SrcT* input_data, const Dims<4>& input_dims, + DstT* output_data, const Dims<4>& output_dims) { + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = MatchingArraySize(input_dims, 2, output_dims, 2); + const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + for (int c = 0; c < depth; ++c) { + int offset = Offset(input_dims, c, x, y, b); + output_data[offset] = static_cast(input_data[offset]); + } + } + } + } +} + +inline void Floor(const float* input_data, const Dims<4>& input_dims, + float* output_data, const Dims<4>& output_dims) { + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = MatchingArraySize(input_dims, 2, output_dims, 2); + const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + for (int c = 0; c < depth; ++c) { + int offset = Offset(input_dims, c, x, y, b); + output_data[offset] = std::floor(input_data[offset]); + } + } + } + } +} + +template +inline void Gather(const T* input_data, const Dims<4>& input_dims, + int input_rank, const int32* coords_data, + const Dims<4>& coords_dims, T* output_data, + const Dims<4>& output_dims) { + TFLITE_DCHECK(coords_dims.sizes[0] == output_dims.sizes[input_rank - 1]); + int stride = input_dims.strides[input_rank - 1]; + T* out = output_data; + + for (int i = 0; i < coords_dims.sizes[0]; i++) { + TFLITE_DCHECK_GE(coords_data[i], 0); + TFLITE_DCHECK_LT(coords_data[i], input_dims.sizes[input_rank - 1]); + const T* in = input_data + coords_data[i] * stride; + memcpy(out, in, sizeof(T) * stride); + out += stride; + } +} + +inline void 
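FakeQuant above derives the scale and a nudged integer zero point from the real range [rmin, rmax] so that the real value 0.0 is exactly representable, then rounds inputs onto that grid. A compact worked example follows; the range and values are hypothetical and the snippet only sketches the min-pair branch of the zero-point computation.

#include <algorithm>
#include <cmath>
#include <cstdio>

int main() {
  // Hypothetical real range straddling zero, mapped onto uint8 [0, 255].
  const float rmin = -1.f, rmax = 2.f;
  const float qmin = 0.f, qmax = 255.f;
  const float scale = (rmax - rmin) / (qmax - qmin);   // 3/255 ~= 0.01176
  // Zero point from the (rmin, qmin) pair, clamped and rounded to an integer
  // so that 0.0 lies exactly on the quantization grid.
  const float zero_point_float = qmin - rmin / scale;  // 85.0
  const int zero_point = static_cast<int>(
      std::round(std::min(qmax, std::max(qmin, zero_point_float))));
  // Round-trip a sample value through the grid.
  const float x = 0.25f;
  const float q = std::min(qmax, std::max(qmin, std::round(zero_point + x / scale)));
  std::printf("scale=%f zero_point=%d  %.2f -> q=%g -> %f\n", scale, zero_point, x,
              q, scale * (q - zero_point));  // 0.25 -> q=106 -> ~0.2471
  return 0;
}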
ResizeBilinear(const float* input_data, const Dims<4>& input_dims, + const int32* output_size_data, + const Dims<4>& output_size_dims, float* output_data, + const Dims<4>& output_dims) { + int32 batches = MatchingArraySize(input_dims, 3, output_dims, 3); + int32 input_height = ArraySize(input_dims, 2); + int32 input_width = ArraySize(input_dims, 1); + int32 depth = MatchingArraySize(input_dims, 0, output_dims, 0); + + TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 3), 1); + TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 2), 1); + TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 1), 1); + TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 0), 2); + int32 output_height = output_size_data[Offset(output_size_dims, 0, 0, 0, 0)]; + int32 output_width = output_size_data[Offset(output_size_dims, 1, 0, 0, 0)]; + float height_scale = static_cast(input_height) / output_height; + float width_scale = static_cast(input_width) / output_width; + + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < output_height; ++y) { + float input_y = y * height_scale; + int32 y0 = static_cast(std::floor(input_y)); + int32 y1 = std::min(y0 + 1, input_height - 1); + for (int x = 0; x < output_width; ++x) { + float input_x = x * width_scale; + int32 x0 = static_cast(std::floor(input_x)); + int32 x1 = std::min(x0 + 1, input_width - 1); + for (int c = 0; c < depth; ++c) { + float interpolation = input_data[Offset(input_dims, c, x0, y0, b)] * + (1 - (input_y - y0)) * + (1 - (input_x - x0)) + + input_data[Offset(input_dims, c, x0, y1, b)] * + (input_y - y0) * (1 - (input_x - x0)) + + input_data[Offset(input_dims, c, x1, y0, b)] * + (1 - (input_y - y0)) * (input_x - x0) + + input_data[Offset(input_dims, c, x1, y1, b)] * + (input_y - y0) * (input_x - x0); + output_data[Offset(output_dims, c, x, y, b)] = interpolation; + } + } + } + } +} + +template +inline void SpaceToBatchND(const T* input_data, const Dims<4>& input_dims, + const int32* block_shape_data, + const Dims<4>& block_shape_dims, + const int32* paddings_data, + const Dims<4>& paddings_dims, T* output_data, + const Dims<4>& output_dims) { + const int output_batch_size = ArraySize(output_dims, 3); + const int output_height = ArraySize(output_dims, 2); + const int output_width = ArraySize(output_dims, 1); + const int input_batch_size = ArraySize(input_dims, 3); + const int input_height = ArraySize(input_dims, 2); + const int input_width = ArraySize(input_dims, 1); + const int depth = ArraySize(input_dims, 0); + const int block_shape_height = block_shape_data[0]; + const int block_shape_width = block_shape_data[1]; + const int padding_top = paddings_data[0]; + const int padding_left = paddings_data[2]; + + for (int out_b = 0; out_b < output_batch_size; ++out_b) { + int input_batch = out_b % input_batch_size; + int shift_w = (out_b / input_batch_size) % block_shape_width; + int shift_h = (out_b / input_batch_size) / block_shape_width; + for (int out_h = 0; out_h < output_height; ++out_h) { + for (int out_w = 0; out_w < output_width; ++out_w) { + T* out = output_data + Offset(output_dims, 0, out_w, out_h, out_b); + if (out_h * block_shape_height < padding_top || + out_h * block_shape_height >= padding_top + input_height || + out_w * block_shape_width < padding_left || + out_w * block_shape_width >= padding_left + input_width) { + memset(out, 0, depth * sizeof(T)); + } else { + const T* in = + input_data + + Offset(input_dims, 0, + (out_w * block_shape_width + shift_w) - padding_left, + (out_h * block_shape_height + shift_h) - padding_top, + input_batch); + memcpy(out, in, depth * 
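ResizeBilinear weights each output sample by the fractional distances to its four input neighbours. A one-pixel worked example, standalone and illustrative only:

#include <cstdio>

// Bilinear interpolation of a single sample, mirroring the weighting used in
// the ResizeBilinear kernel: (1-dy)(1-dx), dy(1-dx), (1-dy)dx and dy*dx.
float Bilinear(float v00, float v01, float v10, float v11, float dx, float dy) {
  return v00 * (1 - dy) * (1 - dx) + v10 * dy * (1 - dx) +
         v01 * (1 - dy) * dx + v11 * dy * dx;
}

int main() {
  // 2x2 input {{1, 2}, {3, 4}}, sampled at the centre (dx = dy = 0.5) -> 2.5.
  std::printf("%g\n", Bilinear(1, 2, 3, 4, 0.5f, 0.5f));
  return 0;
}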
sizeof(T)); + } + } + } + } +} + +template +inline void BatchToSpaceND(const T* input_data, const Dims<4>& input_dims, + const int32* block_shape_data, + const Dims<4>& block_shape_dims, T* output_data, + const Dims<4>& output_dims) { + const int output_batch_size = ArraySize(output_dims, 3); + const int input_batch_size = ArraySize(input_dims, 3); + const int input_height = ArraySize(input_dims, 2); + const int input_width = ArraySize(input_dims, 1); + const int depth = ArraySize(input_dims, 0); + const int block_shape_width = block_shape_data[1]; + const int block_shape_height = block_shape_data[0]; + + for (int in_batch = 0; in_batch < input_batch_size; ++in_batch) { + for (int in_h = 0; in_h < input_height; ++in_h) { + for (int in_w = 0; in_w < input_width; ++in_w) { + int out_batch = in_batch % output_batch_size; + int out_w = in_w * block_shape_width + + (in_batch / output_batch_size) % block_shape_width; + int out_h = in_h * block_shape_height + + (in_batch / output_batch_size) / block_shape_width; + T* out = output_data + Offset(output_dims, 0, out_w, out_h, out_batch); + const T* in = input_data + Offset(input_dims, 0, in_w, in_h, in_batch); + memcpy(out, in, depth * sizeof(T)); + } + } + } +} + +template +inline void Pad(const T* input_data, const Dims<4>& input_dims, + const std::vector& left_paddings, + const std::vector& right_paddings, T* output_data, + const Dims<4>& output_dims) { + const int output_batch = ArraySize(output_dims, 3); + const int output_height = ArraySize(output_dims, 2); + const int output_width = ArraySize(output_dims, 1); + const int output_depth = ArraySize(output_dims, 0); + + const int left_b_padding = left_paddings[3]; + const int left_h_padding = left_paddings[2]; + const int left_w_padding = left_paddings[1]; + const int left_d_padding = left_paddings[0]; + + const int right_b_padding = right_paddings[3]; + const int right_h_padding = right_paddings[2]; + const int right_w_padding = right_paddings[1]; + const int right_d_padding = right_paddings[0]; + + const T* in_ptr = input_data; + T* out_ptr = output_data; + for (int out_b = 0; out_b < output_batch; ++out_b) { + for (int out_h = 0; out_h < output_height; ++out_h) { + for (int out_w = 0; out_w < output_width; ++out_w) { + for (int out_d = 0; out_d < output_depth; ++out_d) { + if (out_b < left_b_padding || + out_b >= output_batch - right_b_padding || + out_h < left_h_padding || + out_h >= output_height - right_h_padding || + out_w < left_w_padding || + out_w >= output_width - right_w_padding || + out_d < left_d_padding || + out_d >= output_depth - right_d_padding) { + *out_ptr++ = 0; + } else { + *out_ptr++ = *in_ptr++; + } + } + } + } + } +} + +template +inline void StridedSlice(const T* input_data, const Dims<4>& input_dims, + int begin_mask, int end_mask, + const std::vector& starts, + const std::vector& stops, + const std::vector& strides, T* output_data, + const Dims<4>& output_dims) { + const int start_b = (begin_mask & 8) ? 0 : starts[3]; + const int stop_b = (end_mask & 8) ? input_dims.sizes[3] : stops[3]; + const int start_h = (begin_mask & 4) ? 0 : starts[2]; + const int stop_h = (end_mask & 4) ? input_dims.sizes[2] : stops[2]; + const int start_w = (begin_mask & 2) ? 0 : starts[1]; + const int stop_w = (end_mask & 2) ? input_dims.sizes[1] : stops[1]; + const int start_d = (begin_mask & 1) ? 0 : starts[0]; + const int stop_d = (end_mask & 1) ? 
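Pad above walks the output index space and copies from the input only inside the unpadded region, writing zeros elsewhere. The same strategy in a minimal 2-D standalone sketch (names are illustrative):

#include <cstdio>
#include <vector>

// Zero-pads a rows x cols matrix by iterating over the *output* index space,
// exactly as the 4-D Pad reference kernel does.
std::vector<int> Pad2D(const std::vector<int>& in, int rows, int cols,
                       int top, int bottom, int left, int right) {
  const int out_rows = rows + top + bottom;
  const int out_cols = cols + left + right;
  std::vector<int> out(out_rows * out_cols, 0);
  for (int r = 0; r < out_rows; ++r)
    for (int c = 0; c < out_cols; ++c)
      if (r >= top && r < top + rows && c >= left && c < left + cols)
        out[r * out_cols + c] = in[(r - top) * cols + (c - left)];
  return out;
}

int main() {
  // A 2x3 matrix padded by 1 row top/bottom and 2 columns left/right -> 4x7.
  std::vector<int> m = {1, 2, 3, 4, 5, 6};
  std::vector<int> p = Pad2D(m, 2, 3, /*top=*/1, /*bottom=*/1, /*left=*/2, /*right=*/2);
  for (int r = 0; r < 4; ++r) {
    for (int c = 0; c < 7; ++c) std::printf("%d ", p[r * 7 + c]);
    std::printf("\n");
  }
  return 0;
}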
input_dims.sizes[0] : stops[0];
+
+  T* out_ptr = output_data;
+  for (int in_b = start_b; in_b < stop_b; in_b += strides[3]) {
+    for (int in_h = start_h; in_h < stop_h; in_h += strides[2]) {
+      for (int in_w = start_w; in_w < stop_w; in_w += strides[1]) {
+        for (int in_d = start_d; in_d < stop_d; in_d += strides[0]) {
+          *out_ptr++ = input_data[Offset(input_dims, in_d, in_w, in_h, in_b)];
+        }
+      }
+    }
+  }
+}
+
+template <typename T>
+inline void Slice(const T* input_data, const Dims<4>& input_dims,
+                  const std::vector<int>& begin, const std::vector<int>& size,
+                  T* output_data, const Dims<4>& output_dims) {
+  // TODO(dkalenichenko): This op only supports 4D tensors.
+  TFLITE_DCHECK_EQ(begin.size(), 4);
+  TFLITE_DCHECK_EQ(size.size(), 4);
+  const int start_b = begin[3];
+  const int stop_b =
+      size[3] == -1 ? input_dims.sizes[3] - start_b : start_b + size[3];
+  const int start_h = begin[2];
+  const int stop_h =
+      size[2] == -1 ? input_dims.sizes[2] - start_h : start_h + size[2];
+  const int start_w = begin[1];
+  const int stop_w =
+      size[1] == -1 ? input_dims.sizes[1] - start_w : start_w + size[1];
+  const int start_d = begin[0];
+  const int stop_d =
+      size[0] == -1 ? input_dims.sizes[0] - start_d : start_d + size[0];
+
+  T* out_ptr = output_data;
+  for (int in_b = start_b; in_b < stop_b; ++in_b) {
+    for (int in_h = start_h; in_h < stop_h; ++in_h) {
+      for (int in_w = start_w; in_w < stop_w; ++in_w) {
+        for (int in_d = start_d; in_d < stop_d; ++in_d) {
+          *out_ptr++ = input_data[Offset(input_dims, in_d, in_w, in_h, in_b)];
+        }
+      }
+    }
+  }
+}
+
+template <typename T>
+inline void Mean(const T* input_data, const Dims<4>& input_dims,
+                 const std::vector<int>& reduction_indices, T* output_data,
+                 const Dims<4>& output_dims) {
+  const int output_batch = ArraySize(output_dims, 3);
+  const int output_height = ArraySize(output_dims, 2);
+  const int output_width = ArraySize(output_dims, 1);
+  const int output_depth = ArraySize(output_dims, 0);
+
+  const int input_height = ArraySize(input_dims, 2);
+  const int input_width = ArraySize(input_dims, 1);
+
+  // The current implementation only supports simultaneous reduction over
+  // width and height.
+  TFLITE_DCHECK_EQ(reduction_indices.size(), 2);
+  TFLITE_DCHECK((reduction_indices[0] == 1 && reduction_indices[1] == 2) ||
+                (reduction_indices[0] == 2 && reduction_indices[1] == 1));
+  TFLITE_DCHECK_EQ(output_height, 1);
+  TFLITE_DCHECK_EQ(output_width, 1);
+
+  for (int out_b = 0; out_b < output_batch; ++out_b) {
+    for (int out_d = 0; out_d < output_depth; ++out_d) {
+      float value = 0;
+      for (int in_h = 0; in_h < input_height; ++in_h) {
+        for (int in_w = 0; in_w < input_width; ++in_w) {
+          value += input_data[Offset(input_dims, out_d, in_w, in_h, out_b)];
+        }
+      }
+      output_data[Offset(output_dims, out_d, 0, 0, out_b)] =
+          value / (input_width * input_height);
+    }
+  }
+}
+
+template <typename T>
+void Sub(const T* input1_data, const Dims<4>& input1_dims, const T* input2_data,
+         const Dims<4>& input2_dims, T* output_data,
+         const Dims<4>& output_dims) {
+  NdArrayDesc<4> desc1;
+  NdArrayDesc<4> desc2;
+  NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
+
+  // In Tensorflow, the dimensions are canonically named (batch_number, row,
+  // col, channel), with extents (batches, height, width, depth), with the
+  // trailing dimension changing most rapidly (channels has the smallest stride,
+  // typically 1 element).
+  //
+  // In generated C code, we store arrays with the dimensions reversed. The
+  // first dimension has the smallest stride.
+ // + // We name our variables by their Tensorflow convention, but generate C code + // nesting loops such that the innermost loop has the smallest stride for the + // best cache behavior. + for (int b = 0; b < ArraySize(output_dims, 3); ++b) { + for (int y = 0; y < ArraySize(output_dims, 2); ++y) { + for (int x = 0; x < ArraySize(output_dims, 1); ++x) { + for (int c = 0; c < ArraySize(output_dims, 0); ++c) { + output_data[Offset(output_dims, c, x, y, b)] = + input1_data[SubscriptToIndex(desc1, c, x, y, b)] - + input2_data[SubscriptToIndex(desc2, c, x, y, b)]; + } + } + } + } +} + +template +void TensorFlowMinimum(const T* input1_data, const Dims<4>& input1_dims, + const T* input2_data, T* output_data, + const Dims<4>& output_dims) { + int batches = MatchingArraySize(input1_dims, 3, output_dims, 3); + int input_height = MatchingArraySize(input1_dims, 2, output_dims, 2); + int input_width = MatchingArraySize(input1_dims, 1, output_dims, 1); + int depth = MatchingArraySize(input1_dims, 0, output_dims, 0); + + auto min_value = input2_data[0]; + + for (int b = 0; b < batches; b++) { + for (int y = 0; y < input_height; y++) { + for (int x = 0; x < input_width; x++) { + for (int c = 0; c < depth; c++) { + int offset = Offset(input1_dims, c, x, y, b); + output_data[offset] = + input1_data[offset] > min_value ? min_value : input1_data[offset]; + } + } + } + } +} + +template +void TensorFlowMaximum(const T* input1_data, const Dims<4>& input1_dims, + const T* input2_data, T* output_data, + const Dims<4>& output_dims) { + int batches = MatchingArraySize(input1_dims, 3, output_dims, 3); + int input_height = MatchingArraySize(input1_dims, 2, output_dims, 2); + int input_width = MatchingArraySize(input1_dims, 1, output_dims, 1); + int depth = MatchingArraySize(input1_dims, 0, output_dims, 0); + + auto max_value = input2_data[0]; + + for (int b = 0; b < batches; b++) { + for (int y = 0; y < input_height; y++) { + for (int x = 0; x < input_width; x++) { + for (int c = 0; c < depth; c++) { + int offset = Offset(input1_dims, c, x, y, b); + output_data[offset] = + input1_data[offset] < max_value ? max_value : input1_data[offset]; + } + } + } + } +} + +} // namespace reference_ops +} // namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_REFERENCE_OPS_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/round.h b/tensorflow/contrib/lite/kernels/internal/round.h new file mode 100644 index 0000000000..38525b0e20 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/round.h @@ -0,0 +1,39 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_ROUND_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_ROUND_H_ + +#include + +namespace tflite { + +// TODO(aselle): See if we can do this only on jdk. Also mikecase, check +// if you need this for java host build. 
+#if defined(__ANDROID__) && !defined(__NDK_MAJOR__) +template +inline float TfLiteRound(const float x) { + return ::round(x); +} +inline double TfLiteRound(const double x) { return ::round(x); } +#else +template +inline T TfLiteRound(const T x) { + return std::round(x); +} +#endif + +} // namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_ROUND_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/tensor.h b/tensorflow/contrib/lite/kernels/internal/tensor.h new file mode 100644 index 0000000000..ee4111e041 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/tensor.h @@ -0,0 +1,87 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_H_ + +#include +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/types.h" + +namespace tflite { + +template +inline T* GetTensorData(TfLiteTensor* tensor); + +template <> +inline float* GetTensorData(TfLiteTensor* tensor) { + return tensor != nullptr ? tensor->data.f : nullptr; +} + +template <> +inline uint8_t* GetTensorData(TfLiteTensor* tensor) { + return tensor != nullptr ? tensor->data.uint8 : nullptr; +} + +template <> +inline int32_t* GetTensorData(TfLiteTensor* tensor) { + return tensor != nullptr ? tensor->data.i32 : nullptr; +} + +template <> +inline int64_t* GetTensorData(TfLiteTensor* tensor) { + return tensor != nullptr ? reinterpret_cast(tensor->data.raw) + : nullptr; +} + +inline int RemapDim(int max_dimensions, int d) { + return max_dimensions - d - 1; +} + +// TODO(ahentz): the implementations in kernels/internal/ take a Dims<4> object +// even if the original tensors were not 4D. We should consider rewriting them +// to take a more generic 'shape' object. +inline Dims<4> GetTensorDims(const int data[], const int size) { + Dims<4> d; + for (int i = 0; i < 4; ++i) { + int src = size - i - 1; + if (src >= 0) { + d.sizes[i] = data[src]; + } else { + d.sizes[i] = 1; + } + } + d.strides[0] = 1; + for (int i = 1; i < 4; i++) { + d.strides[i] = d.strides[i - 1] * d.sizes[i - 1]; + } + return d; +} + +inline Dims<4> GetTensorDims(std::vector data) { + return GetTensorDims(data.data(), data.size()); +} + +inline Dims<4> GetTensorDims(const TfLiteTensor* tensor) { + if (tensor == nullptr) { + return Dims<4>(); + } + + auto* dims = tensor->dims; + return GetTensorDims(dims->data, dims->size); +} + +} // namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/tensor_test.cc b/tensorflow/contrib/lite/kernels/internal/tensor_test.cc new file mode 100644 index 0000000000..bf2068d320 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/tensor_test.cc @@ -0,0 +1,55 @@ +/* Copyright 2017 The TensorFlow Authors. 
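GetTensorDims reverses the TfLite shape (outermost dimension first) into the innermost-first sizes/strides layout that Dims<4> uses. Below is a standalone sketch of that mapping, mirroring the expectations of the unit tests that follow; it is illustrative only and the names are made up.

#include <cstdio>
#include <vector>

// Standalone illustration of the dimension reversal: a shape given as
// {batches, height, width, depth} becomes sizes {depth, width, height, batches}
// with the matching row-major strides, and missing leading dims become 1.
struct Dims4 { int sizes[4]; int strides[4]; };

Dims4 ToDims4(const std::vector<int>& shape) {
  Dims4 d;
  for (int i = 0; i < 4; ++i) {
    const int src = static_cast<int>(shape.size()) - i - 1;
    d.sizes[i] = src >= 0 ? shape[src] : 1;
  }
  d.strides[0] = 1;
  for (int i = 1; i < 4; ++i) d.strides[i] = d.strides[i - 1] * d.sizes[i - 1];
  return d;
}

int main() {
  Dims4 d = ToDims4({2, 3, 4, 5});
  // Expect sizes 5 4 3 2 and strides 1 5 20 60, as in the tests below.
  for (int i = 0; i < 4; ++i) std::printf("size %d stride %d\n", d.sizes[i], d.strides[i]);
  return 0;
}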
All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include +#include + +namespace tflite { +namespace { + +using ::testing::ElementsAre; + +TEST(TensorTest, GetTensorDims4D) { + Dims<4> d = GetTensorDims({2, 3, 4, 5}); + EXPECT_THAT(d.sizes, ElementsAre(5, 4, 3, 2)); + EXPECT_THAT(d.strides, ElementsAre(1, 5, 20, 60)); +} + +TEST(TensorTest, GetTensorDims3D) { + Dims<4> d = GetTensorDims({3, 4, 5}); + EXPECT_THAT(d.sizes, ElementsAre(5, 4, 3, 1)); + EXPECT_THAT(d.strides, ElementsAre(1, 5, 20, 60)); +} + +TEST(TensorTest, GetTensorDims2D) { + Dims<4> d = GetTensorDims({4, 5}); + EXPECT_THAT(d.sizes, ElementsAre(5, 4, 1, 1)); + EXPECT_THAT(d.strides, ElementsAre(1, 5, 20, 20)); +} + +TEST(TensorTest, GetTensorDims1D) { + Dims<4> d = GetTensorDims({5}); + EXPECT_THAT(d.sizes, ElementsAre(5, 1, 1, 1)); + EXPECT_THAT(d.strides, ElementsAre(1, 5, 5, 5)); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + // On Linux, add: tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/internal/tensor_utils.cc b/tensorflow/contrib/lite/kernels/internal/tensor_utils.cc new file mode 100644 index 0000000000..904a97803a --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/tensor_utils.cc @@ -0,0 +1,27 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h" + +#ifndef USE_NEON +#if defined(__ARM_NEON__) || defined(__ARM_NEON) +#define USE_NEON +#endif // defined(__ARM_NEON__) || defined(__ARM_NEON) +#endif // USE_NEON + +#ifdef USE_NEON +#include "tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h" +#else +#include "tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h" +#endif // USE_NEON diff --git a/tensorflow/contrib/lite/kernels/internal/tensor_utils.h b/tensorflow/contrib/lite/kernels/internal/tensor_utils.h new file mode 100644 index 0000000000..0e69ef5982 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/tensor_utils.h @@ -0,0 +1,116 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_UTILS_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_UTILS_H_
+
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+
+namespace tflite {
+namespace tensor_utils {
+
+// Limit a float input f to be between -abs_limit and +abs_limit.
+float Clip(float f, float abs_limit);
+
+// Multiply a matrix by a batch vector, and store results in a batch-size
+// vector using a stride value provided in result_stride. 'result_stride' gives
+// the number of elements between consecutive result values. For example,
+// result_stride = 1 will cause the output to look like this:
+// [O_1, O_2, ..., O_rows] in memory, while result_stride = 3 will cause it to
+// be arranged like this in memory: [O_1, x, x, O_2, x, x, ..., O_rows]
+void MatrixBatchVectorMultiplyAccumulate(const float* matrix, int m_rows,
+                                         int m_cols, const float* vector,
+                                         int n_batch, float* result,
+                                         int result_stride);
+
+// Cwise product of two vectors.
+void VectorVectorCwiseProduct(const float* vector1, const float* vector2,
+                              int v_size, float* result);
+
+// Cwise product and accumulate of two vectors. Since it's a MAC
+// (multiply-accumulate) operation, the assumption here is that the result
+// array is initialized to valid values.
+void VectorVectorCwiseProductAccumulate(const float* vector1,
+                                        const float* vector2, int v_size,
+                                        float* result);
+
+// Dot product of two vectors.
+float VectorVectorDotProduct(const float* vector1, const float* vector2,
+                             int v_size);
+
+// Dot product of two batch vectors of size n_batch * v_size:
+// vector1 = [x_1_1, x_1_2, ..., x_1_vsize,
+//            x_2_1, x_2_2, ..., x_2_vsize,
+//            ...
+//            x_nbatch_1, ..., x_nbatch_vsize]
+// vector2 = [y_1_1, y_1_2, ..., y_1_vsize,
+//            y_2_1, y_2_2, ..., y_2_vsize,
+//            ...
+//            y_nbatch_1, ..., y_nbatch_vsize]
+// Then result will be a vector of size n_batch, saved with a stride of
+// result_stride in memory starting from 'result':
+// [x_1_1 * y_1_1 + x_1_2 * y_1_2 + ... + x_1_vsize * y_1_vsize,
+//  x_2_1 * y_2_1 + x_2_2 * y_2_2 + ... + x_2_vsize * y_2_vsize,
+//  ...
+//  x_nbatch_1 * y_nbatch_1 + ... + x_nbatch_vsize * y_nbatch_vsize]
+void BatchVectorBatchVectorDotProduct(const float* vector1,
+                                      const float* vector2, int v_size,
+                                      int n_batch, float* result,
+                                      int result_stride);
+
+// Cwise product and accumulate of a vector and a batch-vector. Since it's a
+// MAC operation, the assumption here is that the result array is initialized
+// to valid values.
+void VectorBatchVectorCwiseProductAccumulate(const float* vector, int v_size,
+                                             const float* batch_vector,
+                                             int n_batch, float* result);
+
+// Batch vector initialization with another vector.
+void VectorBatchVectorAssign(const float* vector, int v_size, int n_batch,
+                             float* batch_vector);
+
+// Apply sigmoid to elements of a vector.
+void ApplySigmoidToVector(const float* vector, int v_size, float* result); + +// Apply activation function to elements of a vector. +void ApplyActivationToVector(const float* vector, int v_size, + TfLiteFusedActivation activation, float* result); + +// Copy vector to another vector. +void CopyVector(const float* vector, int v_size, float* result); + +// Compute "1.0f - elements of vector" (used in CIFG). +void Sub1Vector(const float* vector, int v_size, float* result); + +// Fill vector with 0.f. +void ZeroVector(float* vector, int v_size); + +// Clip elements of a vector using a abs_limit value. +void ClipVector(const float* vector, int v_size, float abs_limit, + float* result); + +// Shift left a vector in place with v_size size. +void VectorShiftLeft(float* vector, int v_size, float shift_value); + +// Reduce-sum on a float input vector: +// input_vector: float pointer to input vector. +// output_vector: float pointer to vector. +// output_size: output vector size. +// reduction_size: number of consecutive elements from input vector which are +// added to get one element of output. +void ReductionSumVector(const float* input_vector, float* output_vector, + int output_size, int reduction_size); +} // namespace tensor_utils +} // namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_UTILS_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc b/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc new file mode 100644 index 0000000000..588f1a428b --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc @@ -0,0 +1,192 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
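For reference while reading these declarations, here is a portable sketch of what MatrixBatchVectorMultiplyAccumulate promises: a row-major matrix times each vector in the batch, accumulated into 'result' with result_stride elements between consecutive outputs. This is an illustration only, with made-up names; the real implementation may be the NEON-optimized path selected in tensor_utils.cc.

#include <cstdio>
#include <vector>

// 'matrix' is m_rows x m_cols (row-major), 'vectors' holds n_batch vectors of
// m_cols each, and results are *accumulated* into 'result'.
void MatVecBatchAccumulate(const float* matrix, int m_rows, int m_cols,
                           const float* vectors, int n_batch, float* result,
                           int result_stride) {
  for (int b = 0; b < n_batch; ++b) {
    for (int r = 0; r < m_rows; ++r) {
      float acc = 0.f;
      for (int c = 0; c < m_cols; ++c) {
        acc += matrix[r * m_cols + c] * vectors[b * m_cols + c];
      }
      result[(b * m_rows + r) * result_stride] += acc;
    }
  }
}

int main() {
  // A 2x3 matrix times a batch of two 3-vectors, accumulating into 1.0.
  const float matrix[] = {1, 2, 3, 4, 5, 6};
  const float vectors[] = {1, 0, 1, /* batch 0 */ 0, 1, 0 /* batch 1 */};
  std::vector<float> result(4, 1.f);
  MatVecBatchAccumulate(matrix, 2, 3, vectors, 2, result.data(), /*result_stride=*/1);
  for (float v : result) std::printf("%g ", v);  // 5 11 3 6
  std::printf("\n");
  return 0;
}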
+==============================================================================*/ +#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h" +#include +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" + +namespace tflite { +namespace tensor_utils { + +TEST(uKernels, ClipTest) { + constexpr int kVectorSize = 10; + constexpr float kAbsLimit = 2.0; + static float input[kVectorSize] = {0.0, -0.5, 1.0, -1.5, 2.0, + -2.5, 3.0, -3.5, 4.0, -4.5}; + std::vector output(kVectorSize); + ClipVector(input, kVectorSize, kAbsLimit, output.data()); + EXPECT_THAT(output, + ElementsAreArray(ArrayFloatNear( + {0.0, -0.5, 1.0, -1.5, 2.0, -2.0, 2.0, -2.0, 2.0, -2.0}))); +} + +TEST(uKernels, MatrixBatchVectorMultiplyAccumulateTest) { + constexpr int kRow = 3; + constexpr int kCol = 4; + constexpr int kBatch = 2; + static float matrix[kRow * kCol] = {1.0, 2.0, 3.0, 4.0, // + -1.0, -2.0, -3.0, -4.0, // + 1.0, -2.0, 3.0, -4.0}; + static float vector[kCol * kBatch] = {1.0, -1.0, 1.0, -1.0, // + 2.0, -2.0, 2.0, -2.0}; + std::vector output(kRow * kBatch); + std::fill(output.begin(), output.end(), 3.0); + MatrixBatchVectorMultiplyAccumulate(matrix, kRow, kCol, vector, kBatch, + output.data(), /*result_stride=*/1); + EXPECT_THAT(output, ElementsAreArray(ArrayFloatNear({1., 5., 13., // + -1., 7., 23.}))); + + std::vector output_with_stride2(kRow * kBatch * 2); + std::fill(output_with_stride2.begin(), output_with_stride2.end(), 3.0); + MatrixBatchVectorMultiplyAccumulate(matrix, kRow, kCol, vector, kBatch, + output_with_stride2.data(), + /*result_stride=*/2); + EXPECT_THAT(output_with_stride2, + ElementsAreArray(ArrayFloatNear({1., 3., 5., 3., 13., 3., // + -1., 3., 7., 3., 23., 3.}))); +} + +TEST(uKernels, VectorVectorCwiseProductTest) { + constexpr int kVectorSize = 10; + static float input1[kVectorSize] = {0.0, -0.5, 1.0, -1.5, 2.0, + -2.5, 3.0, -3.5, 4.0, -4.5}; + static float input2[kVectorSize] = {0.1, -0.1, 0.1, -0.1, 0.1, + -0.1, 0.1, -0.1, 0.1, -0.1}; + std::vector output(kVectorSize); + VectorVectorCwiseProduct(input1, input2, kVectorSize, output.data()); + EXPECT_THAT(output, + ElementsAreArray(ArrayFloatNear( + {0.0, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45}))); +} + +TEST(uKernels, VectorVectorCwiseProductAccumulateTest) { + constexpr int kVectorSize = 10; + static float input1[kVectorSize] = {0.0, -0.5, 1.0, -1.5, 2.0, + -2.5, 3.0, -3.5, 4.0, -4.5}; + static float input2[kVectorSize] = {0.1, -0.1, 0.1, -0.1, 0.1, + -0.1, 0.1, -0.1, 0.1, -0.1}; + std::vector output(kVectorSize); + std::fill(output.begin(), output.end(), 1.0); + VectorVectorCwiseProductAccumulate(input1, input2, kVectorSize, + output.data()); + EXPECT_THAT(output, + ElementsAreArray(ArrayFloatNear( + {1.0, 1.05, 1.1, 1.15, 1.2, 1.25, 1.3, 1.35, 1.4, 1.45}))); +} + +TEST(uKernels, VectorBatchVectorAssignTest) { + constexpr int kVectorSize = 5; + constexpr int kBatchSize = 3; + static float input[kVectorSize] = {0.0, -0.5, 1.0, -1.5, 2.0}; + std::vector output(kVectorSize * kBatchSize); + VectorBatchVectorAssign(input, kVectorSize, kBatchSize, output.data()); + EXPECT_THAT(output, ElementsAreArray(ArrayFloatNear( + {0.0, -0.5, 1.0, -1.5, 2.0, 0.0, -0.5, 1.0, -1.5, 2.0, + 0.0, -0.5, 1.0, -1.5, 2.0}))); +} + +TEST(uKernels, ApplySigmoidToVectorTest) { + constexpr int kVectorSize = 5; + static float input[kVectorSize] = {0.0, -0.5, 1.0, -1.5, 2.0}; + std::vector output(kVectorSize); + ApplySigmoidToVector(input, kVectorSize, output.data()); + EXPECT_THAT(output, 
ElementsAreArray(ArrayFloatNear( + {0.5, 0.377541, 0.731059, 0.182426, 0.880797}))); +} + +TEST(uKernels, ApplyActivationToVectorTest) { + constexpr int kVectorSize = 5; + static float input[kVectorSize] = {0.0, -0.5, 1.0, -1.5, 2.0}; + std::vector output(kVectorSize); + ApplyActivationToVector(input, kVectorSize, kTfLiteActRelu, output.data()); + EXPECT_THAT(output, + ElementsAreArray(ArrayFloatNear({0.0, 0.0, 1.0, 0.0, 2.0}))); + + ApplyActivationToVector(input, kVectorSize, kTfLiteActTanh, output.data()); + EXPECT_THAT(output, ElementsAreArray(ArrayFloatNear( + {0.0, -0.462117, 0.761594, -0.905148, 0.964028}))); +} + +TEST(uKernels, CopyVectorTest) { + constexpr int kVectorSize = 5; + static float input[kVectorSize] = {0.0, -0.5, 1.0, -1.5, 2.0}; + std::vector output(kVectorSize); + CopyVector(input, kVectorSize, output.data()); + EXPECT_THAT(output, + ElementsAreArray(ArrayFloatNear({0.0, -0.5, 1.0, -1.5, 2.0}))); +} + +TEST(uKernels, Sub1VectorTest) { + constexpr int kVectorSize = 5; + static float input[kVectorSize] = {0.0, -0.5, 1.0, -1.5, 2.0}; + std::vector output(kVectorSize); + Sub1Vector(input, kVectorSize, output.data()); + EXPECT_THAT(output, + ElementsAreArray(ArrayFloatNear({1.0, 1.5, 0.0, 2.5, -1.0}))); +} + +TEST(uKernels, ZeroVectorTest) { + constexpr int kVectorSize = 5; + std::vector output(kVectorSize); + ZeroVector(output.data(), kVectorSize); + EXPECT_THAT(output, + ElementsAreArray(ArrayFloatNear({0.0, 0.0, 0.0, 0.0, 0.0}))); +} + +TEST(uKernels, BatchVectorBatchVectorDotProductTest) { + constexpr int kVectorSize = 5; + constexpr int kBatch = 2; + static float input1[kVectorSize * kBatch] = {0.0, -0.5, 1.0, -1.5, 2.0, + -2.5, 3.0, -3.5, 4.0, -4.5}; + static float input2[kVectorSize * kBatch] = {0.1, -0.1, 0.1, -0.1, 0.1, + -0.1, 0.1, -0.1, 0.1, -0.1}; + std::vector output(kBatch); + BatchVectorBatchVectorDotProduct(input1, input2, kVectorSize, kBatch, + output.data(), /*result_stride=*/1); + EXPECT_THAT(output, ElementsAreArray(ArrayFloatNear({0.5, 1.75}))); +} + +TEST(uKernels, VectorShiftLeftTest) { + constexpr int kVectorSize = 5; + static float input[kVectorSize] = {0.0, -0.5, 1.0, -1.5, 2.0}; + std::vector result(kVectorSize); + VectorShiftLeft(input, kVectorSize, 3.0); + result.assign(input, input + kVectorSize); + EXPECT_THAT(result, + ElementsAreArray(ArrayFloatNear({-0.5, 1.0, -1.5, 2.0, 3.0}))); +} + +TEST(uKernels, ReductionSumVectorTest) { + constexpr int kInputVectorSize = 10; + constexpr int kOutputVectorSize1 = 5; + constexpr int kReductionSize1 = 2; + static float input[kInputVectorSize] = {0.0, -0.5, 1.0, -1.5, 2.0, + 0.0, -0.5, 1.0, 1.0, 2.0}; + std::vector result1(kOutputVectorSize1); + ReductionSumVector(input, result1.data(), kOutputVectorSize1, + kReductionSize1); + EXPECT_THAT(result1, + ElementsAreArray(ArrayFloatNear({-0.5, -0.5, 2.0, 0.5, 3.0}))); + + constexpr int kOutputVectorSize2 = 2; + constexpr int kReductionSize2 = 5; + std::vector result2(kOutputVectorSize2); + ReductionSumVector(input, result2.data(), kOutputVectorSize2, + kReductionSize2); + EXPECT_THAT(result2, ElementsAreArray(ArrayFloatNear({1.0, 3.5}))); +} + +} // namespace tensor_utils +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/internal/types.h b/tensorflow/contrib/lite/kernels/internal/types.h new file mode 100644 index 0000000000..07f1cb4004 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/types.h @@ -0,0 +1,81 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TYPES_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TYPES_H_
+
+#include "tensorflow/contrib/lite/kernels/internal/compatibility.h"
+
+namespace tflite {
+
+enum class FusedActivationFunctionType : uint8 { kNone, kRelu6, kRelu1, kRelu };
+
+template <int N>
+struct Dims {
+  int sizes[N];
+  int strides[N];
+};
+
+inline int Offset(const Dims<4>& dims, int i0, int i1, int i2, int i3) {
+  TFLITE_DCHECK(i0 >= 0 && i0 < dims.sizes[0]);
+  TFLITE_DCHECK(i1 >= 0 && i1 < dims.sizes[1]);
+  TFLITE_DCHECK(i2 >= 0 && i2 < dims.sizes[2]);
+  TFLITE_DCHECK(i3 >= 0 && i3 < dims.sizes[3]);
+  return i0 * dims.strides[0] + i1 * dims.strides[1] + i2 * dims.strides[2] +
+         i3 * dims.strides[3];
+}
+
+// Get array size, DCHECKing that the dim index is in range.
+template <int N>
+int ArraySize(const Dims<N>& array, int index) {
+  TFLITE_DCHECK(index >= 0 && index < N);
+  return array.sizes[index];
+}
+
+// Get common array size, DCHECKing that they all agree.
+template <typename ArrayType1, typename ArrayType2>
+int MatchingArraySize(const ArrayType1& array1, int index1,
+                      const ArrayType2& array2, int index2) {
+  TFLITE_DCHECK_EQ(ArraySize(array1, index1), ArraySize(array2, index2));
+  return ArraySize(array1, index1);
+}
+
+template <typename ArrayType1, typename ArrayType2, typename... Args>
+int MatchingArraySize(const ArrayType1& array1, int index1,
+                      const ArrayType2& array2, int index2, Args... args) {
+  TFLITE_DCHECK_EQ(ArraySize(array1, index1), ArraySize(array2, index2));
+  return MatchingArraySize(array1, index1, args...);
+}
+
+inline int RequiredBufferSizeForDims(const Dims<4>& dims) {
+  int max_offset = 0;
+  for (int i = 0; i < 4; i++) {
+    max_offset += (dims.sizes[i] - 1) * dims.strides[i];
+  }
+  return max_offset + 1;
+}
+
+template <int N>
+bool IsPackedWithoutStrides(const Dims<N>& dims) {
+  int expected_stride = 1;
+  for (int d = 0; d < N; d++) {
+    if (dims.strides[d] != expected_stride) return false;
+    expected_stride *= dims.sizes[d];
+  }
+  return true;
+}
+
+}  // namespace tflite
+
+#endif  // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TYPES_H_
diff --git a/tensorflow/contrib/lite/kernels/kernel_util.cc b/tensorflow/contrib/lite/kernels/kernel_util.cc
new file mode 100644
index 0000000000..b0546c00cf
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/kernel_util.cc
@@ -0,0 +1,87 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include +#include +#include "tensorflow/contrib/lite/kernels/internal/round.h" + +namespace tflite { + +TfLiteStatus GetQuantizedConvolutionMultipler( + TfLiteContext* context, TfLiteTensor* input, TfLiteTensor* filter, + TfLiteTensor* bias, TfLiteTensor* output, double* multiplier) { + const double input_product_scale = input->params.scale * filter->params.scale; + const double bias_scale = bias->params.scale; + const double output_scale = output->params.scale; + + // TODO(ahentz): The following conditions must be guaranteed by the training + // pipeline. + TF_LITE_ENSURE(context, std::abs(input_product_scale - bias_scale) <= + 1e-6 * std::min(input_product_scale, bias_scale)); + TF_LITE_ENSURE(context, input_product_scale >= 0); + TF_LITE_ENSURE(context, input_product_scale < output_scale); + + *multiplier = input_product_scale / output_scale; + + return kTfLiteOk; +} + +void CalculateActivationRangeUint8(TfLiteFusedActivation activation, + TfLiteTensor* output, int32_t* act_min, + int32_t* act_max) { + const int32_t qmin = std::numeric_limits::min(); + const int32_t qmax = std::numeric_limits::max(); + + const auto scale = output->params.scale; + const auto zero_point = output->params.zero_point; + + auto quantize = [scale, zero_point](float f) { + return zero_point + static_cast(TfLiteRound(f / scale)); + }; + + if (activation == kTfLiteActRelu) { + *act_min = std::max(qmin, quantize(0.0)); + *act_max = qmax; + } else if (activation == kTfLiteActRelu6) { + *act_min = std::max(qmin, quantize(0.0)); + *act_max = std::min(qmax, quantize(6.0)); + } else if (activation == kTfLiteActRelu1) { + *act_min = std::max(qmin, quantize(-1.0)); + *act_max = std::min(qmax, quantize(1.0)); + } else { + *act_min = qmin; + *act_max = qmax; + } +} + +void CalculateActivationRangeFloat(TfLiteFusedActivation activation, + float* activation_min, + float* activation_max) { + if (activation == kTfLiteActRelu) { + *activation_min = 0.f; + *activation_max = std::numeric_limits::max(); + } else if (activation == kTfLiteActRelu6) { + *activation_min = 0.f; + *activation_max = 6.f; + } else if (activation == kTfLiteActRelu1) { + *activation_min = -1.f; + *activation_max = 1.f; + } else { + *activation_min = std::numeric_limits::lowest(); + *activation_max = std::numeric_limits::max(); + } +} + +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/kernel_util.h b/tensorflow/contrib/lite/kernels/kernel_util.h new file mode 100644 index 0000000000..25556ae456 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/kernel_util.h @@ -0,0 +1,65 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
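CalculateActivationRangeUint8 above converts the fused-activation bounds into the quantized units of the output tensor. A small worked example with made-up quantization parameters (sketch only, not the library code):

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

int main() {
  // Hypothetical output quantization: scale = 0.05, zero_point = 10.
  const float scale = 0.05f;
  const int32_t zero_point = 10;
  auto quantize = [&](float f) {
    return zero_point + static_cast<int32_t>(std::round(f / scale));
  };
  // Fused ReLU6 clamps the quantized output to [quantize(0), quantize(6)],
  // further clamped to the uint8 range [0, 255].
  const int32_t act_min = std::max<int32_t>(0, quantize(0.f));    // 10
  const int32_t act_max = std::min<int32_t>(255, quantize(6.f));  // 130
  std::printf("relu6 range: [%d, %d]\n", act_min, act_max);
  return 0;
}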
+==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_KERNEL_UTIL_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_KERNEL_UTIL_H_ + +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" + +namespace tflite { + +inline int NumDimensions(const TfLiteTensor* t) { return t->dims->size; } +inline int SizeOfDimension(const TfLiteTensor* t, int dim) { + return t->dims->data[dim]; +} +inline TfLiteTensor* GetInput(TfLiteContext* context, TfLiteNode* node, + int index) { + return &context->tensors[node->inputs->data[index]]; +} +inline TfLiteTensor* GetOutput(TfLiteContext* context, TfLiteNode* node, + int index) { + return &context->tensors[node->outputs->data[index]]; +} +inline int NumInputs(const TfLiteNode* node) { return node->inputs->size; } +inline int NumOutputs(const TfLiteNode* node) { return node->outputs->size; } + +inline TfLiteTensor* GetOptionalInputTensor(TfLiteContext* context, + const TfLiteNode* node, int index) { + const bool use_tensor = node->inputs->data[index] != kOptionalTensor; + if (use_tensor) { + return &context->tensors[node->inputs->data[index]]; + } + return nullptr; +} + +// Calculates the multiplication factor for a quantized convolution (or +// quantized depthwise convolution) involving the given tensors. Returns an +// error if the scales of the tensors are not compatible. +TfLiteStatus GetQuantizedConvolutionMultipler( + TfLiteContext* context, TfLiteTensor* input, TfLiteTensor* filter, + TfLiteTensor* bias, TfLiteTensor* output, double* multiplier); + +// Calculates the useful range of an activation layer given its activation +// tensor. +void CalculateActivationRangeUint8(TfLiteFusedActivation activation, + TfLiteTensor* output, int32_t* act_min, + int32_t* act_max); +void CalculateActivationRangeFloat(TfLiteFusedActivation activation, + float* activation_min, + float* activation_max); + +} // namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_KERNEL_UTIL_H_ diff --git a/tensorflow/contrib/lite/kernels/l2norm.cc b/tensorflow/contrib/lite/kernels/l2norm.cc new file mode 100644 index 0000000000..f43aa372b6 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/l2norm.cc @@ -0,0 +1,112 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace l2norm { + +// This file has two implementation of L2Norm. 
+enum KernelType {
+  kReference,
+  kGenericOptimized,
+};
+
+constexpr int kInputTensor = 0;
+constexpr int kOutputTensor = 0;
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+  auto* params = reinterpret_cast<TfLiteL2NormParams*>(node->builtin_data);
+
+  TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
+  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+
+  TfLiteTensor* input = GetInput(context, node, kInputTensor);
+  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+  // TODO(ahentz): Our current implementations rely on the inputs being 4D.
+  TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4);
+
+  // TODO(ahentz): Our current implementations only support float32.
+  TF_LITE_ENSURE_EQ(context, output->type, kTfLiteFloat32);
+  TF_LITE_ENSURE_EQ(context, input->type, output->type);
+
+  // TODO(ahentz): For some reason our implementations don't support
+  // activations.
+  TF_LITE_ENSURE_EQ(context, params->activation, kTfLiteActNone);
+
+  TfLiteIntArray* output_size = TfLiteIntArrayCreate(4);
+  output_size->data[0] = input->dims->data[0];
+  output_size->data[1] = input->dims->data[1];
+  output_size->data[2] = input->dims->data[2];
+  output_size->data[3] = input->dims->data[3];
+
+  return context->ResizeTensor(context, output, output_size);
+}
+
+template <KernelType kernel_type>
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+  TfLiteTensor* input = GetInput(context, node, kInputTensor);
+  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+  if (output->type == kTfLiteFloat32) {
+#define TF_LITE_L2NORM(type)                                 \
+  type::L2Normalization<FusedActivationFunctionType::kNone>( \
+      GetTensorData<float>(input), GetTensorDims(input),     \
+      GetTensorData<float>(output), GetTensorDims(output))
+
+    if (kernel_type == kReference) {
+      TF_LITE_L2NORM(reference_ops);
+    }
+    if (kernel_type == kGenericOptimized) {
+      TF_LITE_L2NORM(optimized_ops);
+    }
+#undef TF_LITE_L2NORM
+  } else {
+    context->ReportError(context, "Inputs and outputs not all float types.");
+    return kTfLiteError;
+  }
+
+  return kTfLiteOk;
+}
+
+}  // namespace l2norm
+
+TfLiteRegistration* Register_L2NORM_REF() {
+  static TfLiteRegistration r = {nullptr, nullptr, l2norm::Prepare,
+                                 l2norm::Eval<l2norm::kReference>};
+  return &r;
+}
+
+TfLiteRegistration* Register_L2NORM_GENERIC_OPT() {
+  static TfLiteRegistration r = {nullptr, nullptr, l2norm::Prepare,
+                                 l2norm::Eval<l2norm::kGenericOptimized>};
+  return &r;
+}
+
+TfLiteRegistration* Register_L2_NORMALIZATION() {
+  return Register_L2NORM_GENERIC_OPT();
+}
+
+}  // namespace builtin
+}  // namespace ops
+}  // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/l2norm_test.cc b/tensorflow/contrib/lite/kernels/l2norm_test.cc
new file mode 100644
index 0000000000..b1db89b8bd
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/l2norm_test.cc
@@ -0,0 +1,63 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class L2NormOpModel : public SingleOpModel { + public: + L2NormOpModel(std::initializer_list input_shape, + ActivationFunctionType activation_type) { + input_ = AddInput(TensorType_FLOAT32); + output_ = AddOutput(TensorType_FLOAT32); + SetBuiltinOp(BuiltinOperator_L2_NORMALIZATION, BuiltinOptions_L2NormOptions, + CreateL2NormOptions(builder_, activation_type).Union()); + BuildInterpreter({input_shape}); + } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + std::vector GetOutput() { return ExtractVector(output_); } + + private: + int input_; + int output_; +}; + +TEST(L2NormOpTest, SimpleTest) { + L2NormOpModel m({1, 1, 1, 6}, ActivationFunctionType_NONE); + m.SetInput({-1.1, 0.6, 0.7, 1.2, -0.7, 0.1}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray({-0.55, 0.3, 0.35, 0.6, -0.35, 0.05})); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + // On Linux, add: tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/local_response_norm.cc b/tensorflow/contrib/lite/kernels/local_response_norm.cc new file mode 100644 index 0000000000..c1c70d0dfa --- /dev/null +++ b/tensorflow/contrib/lite/kernels/local_response_norm.cc @@ -0,0 +1,109 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace local_response_norm { + +// This file has two implementation of LocalResponseNorm. 
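// Editorial note (informal summary, not part of the TF Lite sources): the op
// computes, per element i,
//   out[i] = in[i] / (bias + alpha * sum_{j in [i-radius, i+radius]} in[j]^2) ^ beta.
// The SameAsL2Norm test in local_response_norm_test.cc uses a radius large
// enough to cover the whole row with bias = 0, alpha = 1, beta = 0.5, which
// reduces to plain L2 normalization; the WithBias test (bias = 9, alpha = 4,
// squared sum 4) divides every input by sqrt(9 + 4 * 4) = 5.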
+enum KernelType {
+  kReference,
+  kGenericOptimized,
+};
+
+constexpr int kInputTensor = 0;
+constexpr int kOutputTensor = 0;
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+  TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
+  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+
+  TfLiteTensor* input = GetInput(context, node, kInputTensor);
+  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+  TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4);
+
+  TF_LITE_ENSURE_EQ(context, output->type, kTfLiteFloat32);
+  TF_LITE_ENSURE_EQ(context, input->type, output->type);
+
+  TfLiteIntArray* output_size = TfLiteIntArrayCreate(4);
+  output_size->data[0] = input->dims->data[0];
+  output_size->data[1] = input->dims->data[1];
+  output_size->data[2] = input->dims->data[2];
+  output_size->data[3] = input->dims->data[3];
+
+  return context->ResizeTensor(context, output, output_size);
+}
+
+template <KernelType kernel_type>
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+  auto* params =
+      reinterpret_cast<TfLiteLocalResponseNormParams*>(node->builtin_data);
+
+  TfLiteTensor* input = GetInput(context, node, kInputTensor);
+  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+  if (output->type == kTfLiteFloat32) {
+#define TF_LITE_LOCAL_RESPONSE_NORM(type)                                   \
+  type::LocalResponseNormalization(                                         \
+      GetTensorData<float>(input), GetTensorDims(input), params->radius,    \
+      params->bias, params->alpha, params->beta,                            \
+      GetTensorData<float>(output), GetTensorDims(output))
+    if (kernel_type == kReference) {
+      TF_LITE_LOCAL_RESPONSE_NORM(reference_ops);
+    }
+    if (kernel_type == kGenericOptimized) {
+      TF_LITE_LOCAL_RESPONSE_NORM(optimized_ops);
+    }
+#undef TF_LITE_LOCAL_RESPONSE_NORM
+  } else {
+    context->ReportError(context, "Inputs and outputs not all float types.");
+    return kTfLiteError;
+  }
+
+  return kTfLiteOk;
+}
+
+}  // namespace local_response_norm
+
+TfLiteRegistration* Register_LOCAL_RESPONSE_NORM_REF() {
+  static TfLiteRegistration r = {
+      nullptr, nullptr, local_response_norm::Prepare,
+      local_response_norm::Eval<local_response_norm::kReference>};
+  return &r;
+}
+
+TfLiteRegistration* Register_LOCAL_RESPONSE_NORM_GENERIC_OPT() {
+  static TfLiteRegistration r = {
+      nullptr, nullptr, local_response_norm::Prepare,
+      local_response_norm::Eval<local_response_norm::kGenericOptimized>};
+  return &r;
+}
+
+TfLiteRegistration* Register_LOCAL_RESPONSE_NORMALIZATION() {
+  return Register_LOCAL_RESPONSE_NORM_GENERIC_OPT();
+}
+
+}  // namespace builtin
+}  // namespace ops
+}  // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/local_response_norm_test.cc b/tensorflow/contrib/lite/kernels/local_response_norm_test.cc
new file mode 100644
index 0000000000..63a8b0a3d0
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/local_response_norm_test.cc
@@ -0,0 +1,101 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class LocalResponseNormOpModel : public SingleOpModel { + public: + LocalResponseNormOpModel(std::initializer_list input_shape, int radius, + float bias, float alpha, float beta) { + input_ = AddInput(TensorType_FLOAT32); + output_ = AddOutput(TensorType_FLOAT32); + SetBuiltinOp(BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION, + BuiltinOptions_LocalResponseNormalizationOptions, + CreateLocalResponseNormalizationOptions(builder_, radius, bias, + alpha, beta) + .Union()); + BuildInterpreter({input_shape}); + } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + std::vector GetOutput() { return ExtractVector(output_); } + + private: + int input_; + int output_; +}; + +TEST(LocalResponseNormOpTest, SameAsL2Norm) { + LocalResponseNormOpModel m({1, 1, 1, 6}, /*radius=*/20, /*bias=*/0.0, + /*alpha=*/1.0, /*beta=*/0.5); + m.SetInput({-1.1, 0.6, 0.7, 1.2, -0.7, 0.1}); + m.Invoke(); + // The result is every input divided by 2. + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray(ArrayFloatNear({-0.55, 0.3, 0.35, 0.6, -0.35, 0.05}))); +} + +TEST(LocalResponseNormOpTest, WithAlpha) { + LocalResponseNormOpModel m({1, 1, 1, 6}, /*radius=*/20, /*bias=*/0.0, + /*alpha=*/4.0, /*beta=*/0.5); + m.SetInput({-1.1, 0.6, 0.7, 1.2, -0.7, 0.1}); + m.Invoke(); + // The result is every input divided by 3. + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear( + {-0.275, 0.15, 0.175, 0.3, -0.175, 0.025}))); +} + +TEST(LocalResponseNormOpTest, WithBias) { + LocalResponseNormOpModel m({1, 1, 1, 6}, /*radius=*/20, /*bias=*/9.0, + /*alpha=*/4.0, /*beta=*/0.5); + m.SetInput({-1.1, 0.6, 0.7, 1.2, -0.7, 0.1}); + m.Invoke(); + // The result is every input divided by 5. + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray(ArrayFloatNear({-0.22, 0.12, 0.14, 0.24, -0.14, 0.02}))); +} + +TEST(LocalResponseNormOpTest, SmallRadius) { + LocalResponseNormOpModel m({1, 1, 1, 6}, /*radius=*/2, /*bias=*/9.0, + /*alpha=*/4.0, /*beta=*/0.5); + m.SetInput({-1.1, 0.6, 0.7, 1.2, -0.7, 0.1}); + m.Invoke(); + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray(ArrayFloatNear( + {-0.264926, 0.125109, 0.140112, 0.267261, -0.161788, 0.0244266}))); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + // On Linux, add: tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/lsh_projection.cc b/tensorflow/contrib/lite/kernels/lsh_projection.cc new file mode 100644 index 0000000000..5f73b56ed9 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/lsh_projection.cc @@ -0,0 +1,204 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// LSH Projection projects an input to a bit vector via locality senstive +// hashing. +// +// Options: +// Sparse: +// Computed bit vector is considered to be sparse. +// Each output element is an int32 made up by multiple bits computed from +// hash functions. +// +// Dense: +// Computed bit vector is considered to be dense. Each output element is +// either 0 or 1 that represents a bit. +// +// Input: +// Tensor[0]: Hash functions. Dim.size == 2, DataType: Float. +// Tensor[0].Dim[0]: Num of hash functions. +// Tensor[0].Dim[1]: Num of projected output bits generated by +// each hash function. +// In sparse case, Tensor[0].Dim[1] + ceil( log2(Tensor[0].Dim[0] )) <= 32. +// +// Tensor[1]: Input. Dim.size >= 1, No restriction on DataType. +// Tensor[2]: Optional, Weight. Dim.size == 1, DataType: Float. +// If not set, each element of input is considered to have same +// weight of 1.0 Tensor[1].Dim[0] == Tensor[2].Dim[0] +// +// Output: +// Sparse: +// Output.Dim == { Tensor[0].Dim[0] } +// A tensor of int32 that represents hash signatures, +// +// NOTE: To avoid collisions across hash functions, an offset value of +// k * (1 << Tensor[0].Dim[1]) will be added to each signature, +// k is the index of the hash function. +// Dense: +// Output.Dim == { Tensor[0].Dim[0] * Tensor[0].Dim[1] } +// A flattened tensor represents projected bit vectors. + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" +#include + +namespace tflite { +namespace ops { +namespace builtin { +namespace lsh_projection { + +TfLiteStatus Resize(TfLiteContext* context, TfLiteNode* node) { + auto* params = + reinterpret_cast(node->builtin_data); + TF_LITE_ENSURE(context, NumInputs(node) == 2 || NumInputs(node) == 3); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + TfLiteTensor* hash = GetInput(context, node, 0); + TF_LITE_ENSURE_EQ(context, NumDimensions(hash), 2); + // Support up to 32 bits. + TF_LITE_ENSURE(context, SizeOfDimension(hash, 1) <= 32); + + TfLiteTensor* input = GetInput(context, node, 1); + TF_LITE_ENSURE(context, NumDimensions(input) >= 1); + + if (NumInputs(node) == 3) { + TfLiteTensor* weight = GetInput(context, node, 2); + TF_LITE_ENSURE_EQ(context, NumDimensions(weight), 1); + TF_LITE_ENSURE_EQ(context, SizeOfDimension(weight, 0), + SizeOfDimension(input, 0)); + } + + TfLiteTensor* output = GetOutput(context, node, 0); + TfLiteIntArray* outputSize = TfLiteIntArrayCreate(1); + switch (params->type) { + case kTfLiteLshProjectionSparse: + outputSize->data[0] = SizeOfDimension(hash, 0); + break; + case kTfLiteLshProjectionDense: + outputSize->data[0] = SizeOfDimension(hash, 0) * SizeOfDimension(hash, 1); + break; + default: + return kTfLiteError; + } + return context->ResizeTensor(context, output, outputSize); +} + +// Compute sign bit of dot product of hash(seed, input) and weight. +// NOTE: use float as seed, and convert it to double as a temporary solution +// to match the trained model. This is going to be changed once the new +// model is trained in an optimized method. 
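// Editorial worked example (not part of the TF Lite sources), following the
// sparse output description above and the Sparse3DInputs test in
// lsh_projection_test.cc: the hash tensor has shape {3, 2}, i.e. three hash
// functions of two bits each. Hash function i packs its two sign bits into a
// value s_i in [0, 3] and then adds the offset i * (1 << 2) = 4 * i, so the
// three outputs take the form {0 + s0, 4 + s1, 8 + s2} and signatures from
// different hash functions can never collide.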
+// +int RunningSignBit(const TfLiteTensor* input, const TfLiteTensor* weight, + float seed) { + double score = 0.0; + int input_item_bytes = input->bytes / SizeOfDimension(input, 0); + char* input_ptr = input->data.raw; + + const size_t seed_size = sizeof(float); + const size_t key_bytes = sizeof(float) + input_item_bytes; + std::unique_ptr key(new char[key_bytes]); + + for (int i = 0; i < SizeOfDimension(input, 0); ++i) { + // Create running hash id and value for current dimension. + memcpy(key.get(), &seed, seed_size); + memcpy(key.get() + seed_size, input_ptr, input_item_bytes); + + int64_t hash_signature = ::util::Fingerprint64(key.get(), key_bytes); + double running_value = static_cast(hash_signature); + input_ptr += input_item_bytes; + if (weight == nullptr) { + score += running_value; + } else { + score += weight->data.f[i] * running_value; + } + } + + return (score > 0) ? 1 : 0; +} + +void SparseLshProjection(const TfLiteTensor* hash, const TfLiteTensor* input, + const TfLiteTensor* weight, int32_t* out_buf) { + int num_hash = SizeOfDimension(hash, 0); + int num_bits = SizeOfDimension(hash, 1); + for (int i = 0; i < num_hash; i++) { + int32_t hash_signature = 0; + for (int j = 0; j < num_bits; j++) { + float seed = hash->data.f[i * num_bits + j]; + int bit = RunningSignBit(input, weight, seed); + hash_signature = (hash_signature << 1) | bit; + } + *out_buf++ = hash_signature + i * (1 << num_bits); + } +} + +void DenseLshProjection(const TfLiteTensor* hash, const TfLiteTensor* input, + const TfLiteTensor* weight, int32_t* out_buf) { + int num_hash = SizeOfDimension(hash, 0); + int num_bits = SizeOfDimension(hash, 1); + for (int i = 0; i < num_hash; i++) { + for (int j = 0; j < num_bits; j++) { + float seed = hash->data.f[i * num_bits + j]; + int bit = RunningSignBit(input, weight, seed); + *out_buf++ = bit; + } + } +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + auto* params = + reinterpret_cast(node->builtin_data); + + int32_t* out_buf = GetOutput(context, node, 0)->data.i32; + TfLiteTensor* hash = GetInput(context, node, 0); + TfLiteTensor* input = GetInput(context, node, 1); + TfLiteTensor* weight = + NumInputs(node) == 2 ? nullptr : GetInput(context, node, 2); + + switch (params->type) { + case kTfLiteLshProjectionDense: + DenseLshProjection(hash, input, weight, out_buf); + break; + case kTfLiteLshProjectionSparse: + SparseLshProjection(hash, input, weight, out_buf); + break; + default: + return kTfLiteError; + } + + return kTfLiteOk; +} +} // namespace lsh_projection + +TfLiteRegistration* Register_LSH_PROJECTION() { + static TfLiteRegistration r = {nullptr, nullptr, lsh_projection::Resize, + lsh_projection::Eval}; + return &r; +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/lsh_projection_test.cc b/tensorflow/contrib/lite/kernels/lsh_projection_test.cc new file mode 100644 index 0000000000..1011927848 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/lsh_projection_test.cc @@ -0,0 +1,123 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAre; + +class LSHProjectionOpModel : public SingleOpModel { + public: + LSHProjectionOpModel(LSHProjectionType type, + std::initializer_list hash_shape, + std::initializer_list input_shape, + std::initializer_list weight_shape) { + hash_ = AddInput(TensorType_FLOAT32); + input_ = AddInput(TensorType_INT32); + if (weight_shape.size() > 0) { + weight_ = AddInput(TensorType_FLOAT32); + } + output_ = AddOutput(TensorType_INT32); + + SetBuiltinOp(BuiltinOperator_LSH_PROJECTION, + BuiltinOptions_LSHProjectionOptions, + CreateLSHProjectionOptions(builder_, type).Union()); + if (weight_shape.size() > 0) { + BuildInterpreter({hash_shape, input_shape, weight_shape}); + } else { + BuildInterpreter({hash_shape, input_shape}); + } + + output_size_ = 1; + for (int i : hash_shape) { + output_size_ *= i; + if (type == LSHProjectionType_SPARSE) { + break; + } + } + } + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + void SetHash(std::initializer_list data) { + PopulateTensor(hash_, data); + } + + void SetWeight(std::initializer_list f) { PopulateTensor(weight_, f); } + + std::vector GetOutput() { return ExtractVector(output_); } + + private: + int input_; + int hash_; + int weight_; + int output_; + + int output_size_; +}; + +TEST(LSHProjectionOpTest2, Dense1DInputs) { + LSHProjectionOpModel m(LSHProjectionType_DENSE, {3, 2}, {5}, {5}); + + m.SetInput({12345, 54321, 67890, 9876, -12345678}); + m.SetHash({0.123, 0.456, -0.321, 1.234, 5.678, -4.321}); + m.SetWeight({1.0, 1.0, 1.0, 1.0, 1.0}); + + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAre(0, 0, 0, 1, 0, 0)); +} + +TEST(LSHProjectionOpTest2, Sparse1DInputs) { + LSHProjectionOpModel m(LSHProjectionType_SPARSE, {3, 2}, {5}, {}); + + m.SetInput({12345, 54321, 67890, 9876, -12345678}); + m.SetHash({0.123, 0.456, -0.321, 1.234, 5.678, -4.321}); + + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAre(0 + 0, 4 + 1, 8 + 0)); +} + +TEST(LSHProjectionOpTest2, Sparse3DInputs) { + LSHProjectionOpModel m(LSHProjectionType_SPARSE, {3, 2}, {5, 2, 2}, {5}); + + m.SetInput({1234, 2345, 3456, 1234, 4567, 5678, 6789, 4567, 7891, 8912, + 9123, 7890, -987, -876, -765, -987, -543, -432, -321, -543}); + m.SetHash({0.123, 0.456, -0.321, 1.234, 5.678, -4.321}); + m.SetWeight({0.12, 0.34, 0.56, 0.67, 0.78}); + + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAre(0 + 2, 4 + 1, 8 + 1)); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + // On Linux, add: tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/lstm.cc b/tensorflow/contrib/lite/kernels/lstm.cc new file mode 100644 index 0000000000..6c06264d84 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/lstm.cc @@ -0,0 +1,515 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/activation_functor.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace lstm { + +// Input Tensors of size {n_batch, n_input} +constexpr int kInputTensor = 0; + +// Input weight tensors of size: {n_cell, n_input} +constexpr int kInputToInputWeightsTensor = 1; // Optional +constexpr int kInputToForgetWeightsTensor = 2; +constexpr int kInputToCellWeightsTensor = 3; +constexpr int kInputToOutputWeightsTensor = 4; + +// Recurrent weight tensors of size {n_cell, n_output} +constexpr int kRecurrentToInputWeightsTensor = 5; // Optional +constexpr int kRecurrentToForgetWeightsTensor = 6; +constexpr int kRecurrentToCellWeightsTensor = 7; +constexpr int kRecurrentToOutputWeightsTensor = 8; + +// Peephole weights tensors of size {n_cell}, representing a diagonal matrix. +constexpr int kCellToInputWeightsTensor = 9; // Optional +constexpr int kCellToForgetWeightsTensor = 10; // Optional +constexpr int kCellToOutputWeightsTensor = 11; // Optional + +// Gates bias tensors of size {n_cell} +constexpr int kInputGateBiasTensor = 12; // Optional +constexpr int kForgetGateBiasTensor = 13; +constexpr int kCellGateBiasTensor = 14; +constexpr int kOutputGateBiasTensor = 15; + +// Projection weight tensor of size {n_output, n_cell} +constexpr int kProjectionWeightsTensor = 16; // Optional +// Projection bias tensor of size {n_output} +constexpr int kProjectionBiasTensor = 17; // Optional + +// Output tensors. +constexpr int kScratchBufferTensor = 0; +constexpr int kOutputStateTensor = 1; +constexpr int kCellStateTensor = 2; +constexpr int kOutputTensor = 3; + +// Check that input tensor dimensions matches with each other. +TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context, + TfLiteNode* node, int n_input, + int n_output, int n_cell) { + auto* params = reinterpret_cast(node->builtin_data); + + // Making sure clipping parameters have valid values. 
+ // == 0 means no clipping + // > 0 means clipping + TF_LITE_ENSURE(context, params->cell_clip >= 0); + TF_LITE_ENSURE(context, params->proj_clip >= 0); + + TfLiteTensor* input_to_input_weights = + GetOptionalInputTensor(context, node, kInputToInputWeightsTensor); + if (input_to_input_weights) { + TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->size, 2); + TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[0], n_cell); + TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[1], n_input); + } + + TfLiteTensor* input_to_forget_weights = + GetInput(context, node, kInputToForgetWeightsTensor); + TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->size, 2); + TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[0], n_cell); + TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[1], n_input); + + TfLiteTensor* input_to_cell_weights = + GetInput(context, node, kInputToCellWeightsTensor); + TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->size, 2); + TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[0], n_cell); + TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[1], n_input); + + TfLiteTensor* recurrent_to_input_weights = + GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor); + if (recurrent_to_input_weights) { + TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->size, 2); + TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->data[0], + n_cell); + TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->data[1], + n_output); + } + + TfLiteTensor* recurrent_to_forget_weights = + GetInput(context, node, kRecurrentToForgetWeightsTensor); + TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->size, 2); + TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[0], + n_cell); + TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[1], + n_output); + + TfLiteTensor* recurrent_to_cell_weights = + GetInput(context, node, kRecurrentToCellWeightsTensor); + TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->size, 2); + TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[0], n_cell); + TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[1], + n_output); + + // We make sure the input-gate's parameters are either both present (regular + // LSTM) or not at all (CIFG-LSTM). 
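  // Editorial note (not part of the TF Lite sources): in the CIFG variant the
  // input gate is not parameterized separately at all; it is coupled to the
  // forget gate and derived at runtime as (1 - forget_gate), see the
  // Sub1Vector call in Eval below.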
+ const bool cifg_weights_all_or_none = + ((input_to_input_weights != nullptr) && + (recurrent_to_input_weights != nullptr)) || + ((input_to_input_weights == nullptr) && + (recurrent_to_input_weights == nullptr)); + TF_LITE_ENSURE(context, cifg_weights_all_or_none == true); + + TfLiteTensor* cell_to_input_weights = + GetOptionalInputTensor(context, node, kCellToInputWeightsTensor); + if (cell_to_input_weights) { + TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->size, 1); + TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->data[0], n_cell); + } + + TfLiteTensor* cell_to_forget_weights = + GetOptionalInputTensor(context, node, kCellToForgetWeightsTensor); + if (cell_to_forget_weights) { + TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->size, 1); + TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->data[0], n_cell); + } + + TfLiteTensor* cell_to_output_weights = + GetOptionalInputTensor(context, node, kCellToOutputWeightsTensor); + if (cell_to_output_weights) { + TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->size, 1); + TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->data[0], n_cell); + } + + // Making sure the peephole weights are there all or none. + const bool use_cifg = (input_to_input_weights == nullptr); + const bool peephole_weights_all_or_none = + ((cell_to_input_weights != nullptr || use_cifg) && + (cell_to_forget_weights != nullptr) && + (cell_to_output_weights != nullptr)) || + ((cell_to_input_weights == nullptr) && + (cell_to_forget_weights == nullptr) && + (cell_to_output_weights == nullptr)); + TF_LITE_ENSURE(context, peephole_weights_all_or_none == true); + + // Make sure the input gate bias is present only when not a CIFG-LSTM. + TfLiteTensor* input_gate_bias = + GetOptionalInputTensor(context, node, kInputGateBiasTensor); + if (use_cifg) { + TF_LITE_ENSURE_EQ(context, input_gate_bias, nullptr); + } else { + TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->size, 1); + TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->data[0], n_cell); + } + + TfLiteTensor* forget_gate_bias = + GetInput(context, node, kForgetGateBiasTensor); + TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->size, 1); + TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->data[0], n_cell); + + TfLiteTensor* cell_bias = GetInput(context, node, kCellGateBiasTensor); + TF_LITE_ENSURE_EQ(context, cell_bias->dims->size, 1); + TF_LITE_ENSURE_EQ(context, cell_bias->dims->data[0], n_cell); + + TfLiteTensor* output_gate_bias = + GetInput(context, node, kOutputGateBiasTensor); + TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->size, 1); + TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->data[0], n_cell); + + TfLiteTensor* projection_weights = + GetOptionalInputTensor(context, node, kProjectionWeightsTensor); + if (projection_weights) { + TF_LITE_ENSURE_EQ(context, projection_weights->dims->size, 2); + TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[0], n_output); + TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[1], n_cell); + } + + TfLiteTensor* projection_bias = + GetOptionalInputTensor(context, node, kProjectionBiasTensor); + if (projection_bias) { + TF_LITE_ENSURE_EQ(context, projection_bias->dims->size, 1); + TF_LITE_ENSURE_EQ(context, projection_bias->dims->data[0], n_output); + } + + // Making sure the projection tensors are consistent: + // 1) If projection weight is not present, then projection bias should not be + // present. + // 2) If projection weight is present, then projection bias is optional. 
+ // TODO(ghodrat): make sure this is correct. + const bool projecton_tensors_consistent = + ((projection_weights != nullptr) || (projection_bias == nullptr)); + TF_LITE_ENSURE(context, projecton_tensors_consistent == true); + + return kTfLiteOk; +} + +// Resize the output, state and scratch tensors based on the sizes of the input +// tensors. Also check that the size of the input tensors match each other. +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + // Check we have all the inputs and outputs we need. + TF_LITE_ENSURE_EQ(context, node->inputs->size, 18); + TF_LITE_ENSURE_EQ(context, node->outputs->size, 4); + + // Inferring batch size, number of outputs and number of cells from the + // input tensors. + TfLiteTensor* input = GetInput(context, node, kInputTensor); + TF_LITE_ENSURE(context, input->dims->size > 1); + const int n_batch = input->dims->data[0]; + const int n_input = input->dims->data[1]; + + TfLiteTensor* input_to_output_weights = + GetInput(context, node, kInputToOutputWeightsTensor); + const int n_cell = input_to_output_weights->dims->data[0]; + TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->size, 2); + TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->data[1], n_input); + + TfLiteTensor* recurrent_to_output_weights = + GetInput(context, node, kRecurrentToOutputWeightsTensor); + TF_LITE_ENSURE_EQ(context, recurrent_to_output_weights->dims->size, 2); + TF_LITE_ENSURE_EQ(context, recurrent_to_output_weights->dims->data[0], + n_cell); + const int n_output = recurrent_to_output_weights->dims->data[1]; + + // Check that input tensor dimensions matches with each other. + CheckInputTensorDimensions(context, node, n_input, n_output, n_cell); + + // Get the pointer to output, state and scratch buffer tensors. + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + TfLiteTensor* output_state = GetOutput(context, node, kOutputStateTensor); + TfLiteTensor* cell_state = GetOutput(context, node, kCellStateTensor); + // TODO(ghodrat): Modify this as soon as we have a finalized method for + // scratch buffers. + TfLiteTensor* scratch_buffer = GetOutput(context, node, kScratchBufferTensor); + + // Resize the output and output_state tensors. + TfLiteIntArray* output_size = TfLiteIntArrayCreate(2); + output_size->data[0] = n_batch; + output_size->data[1] = n_output; + TF_LITE_ENSURE_OK(context, + context->ResizeTensor(context, output, output_size)); + + TfLiteIntArray* output_state_size = TfLiteIntArrayCreate(2); + output_state_size->data[0] = n_batch; + output_state_size->data[1] = n_output; + TF_LITE_ENSURE_OK( + context, context->ResizeTensor(context, output_state, output_state_size)); + + // Resize the output, state and scratch buffer tensors. + TfLiteIntArray* cell_size = TfLiteIntArrayCreate(2); + cell_size->data[0] = n_batch; + cell_size->data[1] = n_cell; + TF_LITE_ENSURE_OK(context, + context->ResizeTensor(context, cell_state, cell_size)); + + // Mark state tensors as persistent tensors. 
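  // Editorial note (not part of the TF Lite sources): kTfLiteArenaRwPersistent
  // keeps output_state and cell_state alive across calls to Invoke(), which is
  // how the op carries its recurrent state from one time step to the next; the
  // unit tests zero these tensors explicitly (ResetOutputState/ResetCellState)
  // before feeding a new sequence.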
+ output_state->allocation_type = kTfLiteArenaRwPersistent; + cell_state->allocation_type = kTfLiteArenaRwPersistent; + + TfLiteTensor* input_to_input_weights = + GetOptionalInputTensor(context, node, kInputToInputWeightsTensor); + const bool use_cifg = (input_to_input_weights == nullptr); + if (use_cifg) { + TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2); + scratch_buffer_size->data[0] = n_batch; + // Reserving space for Cell, Forget, Output gates + scratch_buffer_size->data[1] = n_cell * 3; + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_buffer, + scratch_buffer_size)); + } else { + TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2); + scratch_buffer_size->data[0] = n_batch; + // Reserving space for Input, Cell, Forget, Output gates + scratch_buffer_size->data[1] = n_cell * 4; + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_buffer, + scratch_buffer_size)); + } + return kTfLiteOk; +} + +// The LSTM Op engine. +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast(node->builtin_data); + TfLiteTensor* input = GetInput(context, node, kInputTensor); + + TfLiteTensor* input_to_input_weights = + GetOptionalInputTensor(context, node, kInputToInputWeightsTensor); + TfLiteTensor* input_to_forget_weights = + GetInput(context, node, kInputToForgetWeightsTensor); + TfLiteTensor* input_to_cell_weights = + GetInput(context, node, kInputToCellWeightsTensor); + TfLiteTensor* input_to_output_weights = + GetInput(context, node, kInputToOutputWeightsTensor); + + TfLiteTensor* recurrent_to_input_weights = + GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor); + TfLiteTensor* recurrent_to_forget_weights = + GetInput(context, node, kRecurrentToForgetWeightsTensor); + TfLiteTensor* recurrent_to_cell_weights = + GetInput(context, node, kRecurrentToCellWeightsTensor); + TfLiteTensor* recurrent_to_output_weights = + GetInput(context, node, kRecurrentToOutputWeightsTensor); + + TfLiteTensor* cell_to_input_weights = + GetOptionalInputTensor(context, node, kCellToInputWeightsTensor); + TfLiteTensor* cell_to_forget_weights = + GetOptionalInputTensor(context, node, kCellToForgetWeightsTensor); + TfLiteTensor* cell_to_output_weights = + GetOptionalInputTensor(context, node, kCellToOutputWeightsTensor); + + TfLiteTensor* input_gate_bias = + GetOptionalInputTensor(context, node, kInputGateBiasTensor); + TfLiteTensor* forget_gate_bias = + GetInput(context, node, kForgetGateBiasTensor); + TfLiteTensor* cell_bias = GetInput(context, node, kCellGateBiasTensor); + TfLiteTensor* output_gate_bias = + GetInput(context, node, kOutputGateBiasTensor); + + TfLiteTensor* projection_weights = + GetOptionalInputTensor(context, node, kProjectionWeightsTensor); + TfLiteTensor* projection_bias = + GetOptionalInputTensor(context, node, kProjectionBiasTensor); + + TfLiteTensor* output_state = GetOutput(context, node, kOutputStateTensor); + TfLiteTensor* cell_state = GetOutput(context, node, kCellStateTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + const int n_batch = input->dims->data[0]; + const int n_input = input->dims->data[1]; + // n_cell and n_output will be the same size when there is no projection. + const int n_cell = input_to_output_weights->dims->data[0]; + const int n_output = recurrent_to_output_weights->dims->data[1]; + + // Since we have already checked that weights are all there or none, we can + // check the existense of only one to the get the condition. 
+ const bool use_cifg = (input_to_input_weights == nullptr); + const bool use_peephole = (cell_to_output_weights != nullptr); + + // Index the scratch buffers pointers to the global scratch buffer. + TfLiteTensor* scratch_buffer = GetOutput(context, node, kScratchBufferTensor); + float* input_gate_scratch = nullptr; + float* cell_scratch = nullptr; + float* forget_gate_scratch = nullptr; + float* output_gate_scratch = nullptr; + if (use_cifg) { + cell_scratch = scratch_buffer->data.f; + forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch; + output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch; + } else { + input_gate_scratch = scratch_buffer->data.f; + cell_scratch = scratch_buffer->data.f + n_cell * n_batch; + forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch; + output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch; + } + + // Initialize scratch buffers with bias. + if (!use_cifg) { + tensor_utils::VectorBatchVectorAssign(input_gate_bias->data.f, n_cell, + n_batch, input_gate_scratch); + } + tensor_utils::VectorBatchVectorAssign(forget_gate_bias->data.f, n_cell, + n_batch, forget_gate_scratch); + tensor_utils::VectorBatchVectorAssign(cell_bias->data.f, n_cell, n_batch, + cell_scratch); + tensor_utils::VectorBatchVectorAssign(output_gate_bias->data.f, n_cell, + n_batch, output_gate_scratch); + + // For each batch and cell: compute input_weight * input. + if (!use_cifg) { + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + input_to_input_weights->data.f, n_cell, n_input, input->data.f, n_batch, + input_gate_scratch, /*result_stride=*/1); + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + input_to_forget_weights->data.f, n_cell, n_input, input->data.f, n_batch, + forget_gate_scratch, /*result_stride=*/1); + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + input_to_cell_weights->data.f, n_cell, n_input, input->data.f, n_batch, + cell_scratch, /*result_stride=*/1); + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + input_to_output_weights->data.f, n_cell, n_input, input->data.f, n_batch, + output_gate_scratch, /*result_stride=*/1); + + // For each batch and cell: compute recurrent_weight * output_state. + if (!use_cifg) { + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + recurrent_to_input_weights->data.f, n_cell, n_output, + output_state->data.f, n_batch, input_gate_scratch, /*result_stride=*/1); + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + recurrent_to_forget_weights->data.f, n_cell, n_output, + output_state->data.f, n_batch, forget_gate_scratch, /*result_stride=*/1); + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + recurrent_to_cell_weights->data.f, n_cell, n_output, output_state->data.f, + n_batch, cell_scratch, /*result_stride=*/1); + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + recurrent_to_output_weights->data.f, n_cell, n_output, + output_state->data.f, n_batch, output_gate_scratch, /*result_stride=*/1); + + // For each batch and cell: update input gate. + if (!use_cifg) { + if (use_peephole) { + tensor_utils::VectorBatchVectorCwiseProductAccumulate( + cell_to_input_weights->data.f, n_cell, cell_state->data.f, n_batch, + input_gate_scratch); + } + tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch, + input_gate_scratch); + } + + // For each batch and cell: update forget gate. 
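  // Editorial note (not part of the TF Lite sources): at this point the
  // forget-gate scratch already holds b_f + W_f * x_t + R_f * h_{t-1}; the
  // code below adds the optional peephole term w_cf (*) c_{t-1} and applies
  // the sigmoid, i.e.
  //   f_t = sigmoid(W_f x_t + R_f h_{t-1} + w_cf (*) c_{t-1} + b_f).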
+ if (use_peephole) { + tensor_utils::VectorBatchVectorCwiseProductAccumulate( + cell_to_forget_weights->data.f, n_cell, cell_state->data.f, n_batch, + forget_gate_scratch); + } + tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch, + forget_gate_scratch); + + // For each batch and cell: update the cell. + tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, + cell_state->data.f, n_batch * n_cell, + cell_state->data.f); + tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell, + params->activation, cell_scratch); + if (use_cifg) { + tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell, + forget_gate_scratch); + tensor_utils::VectorVectorCwiseProductAccumulate( + cell_scratch, forget_gate_scratch, n_batch * n_cell, + cell_state->data.f); + } else { + tensor_utils::VectorVectorCwiseProductAccumulate( + cell_scratch, input_gate_scratch, n_batch * n_cell, cell_state->data.f); + } + if (params->cell_clip > 0.0) { + tensor_utils::ClipVector(cell_state->data.f, n_batch * n_cell, + params->cell_clip, cell_state->data.f); + } + + // For each batch and cell: update the output gate. + if (use_peephole) { + tensor_utils::VectorBatchVectorCwiseProductAccumulate( + cell_to_output_weights->data.f, n_cell, cell_state->data.f, n_batch, + output_gate_scratch); + } + tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell, + output_gate_scratch); + tensor_utils::ApplyActivationToVector(cell_state->data.f, n_batch * n_cell, + params->activation, cell_scratch); + tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch, + n_batch * n_cell, output_gate_scratch); + + // For each batch: update the projection and output_state. + const bool use_projection_weight = (projection_weights != nullptr); + const bool use_projection_bias = (projection_bias != nullptr); + if (use_projection_weight) { + if (use_projection_bias) { + tensor_utils::VectorBatchVectorAssign(projection_bias->data.f, n_output, + n_batch, output->data.f); + } else { + tensor_utils::ZeroVector(output->data.f, n_batch * n_output); + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + projection_weights->data.f, n_output, n_cell, output_gate_scratch, + n_batch, output->data.f, /*result_stride=*/1); + if (params->proj_clip > 0.0) { + tensor_utils::ClipVector(output->data.f, n_batch * n_output, + params->proj_clip, output->data.f); + } + } else { + tensor_utils::CopyVector(output_gate_scratch, n_batch * n_output, + output->data.f); + } + tensor_utils::CopyVector(output->data.f, n_batch * n_output, + output_state->data.f); + + return kTfLiteOk; +} + +} // namespace lstm + +TfLiteRegistration* Register_LSTM() { + static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr, + lstm::Prepare, lstm::Eval}; + return &r; +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/lstm_test.cc b/tensorflow/contrib/lite/kernels/lstm_test.cc new file mode 100644 index 0000000000..be4c7ddbf8 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/lstm_test.cc @@ -0,0 +1,1088 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Unit test for TFLite LSTM op. + +#include +#include +#include + +#include +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class LSTMOpModel : public SingleOpModel { + public: + LSTMOpModel(int n_batch, int n_input, int n_cell, int n_output, bool use_cifg, + bool use_peephole, bool use_projection_weights, + bool use_projection_bias, float cell_clip, float proj_clip, + const std::vector>& input_shapes) + : n_batch_(n_batch), + n_input_(n_input), + n_cell_(n_cell), + n_output_(n_output) { + input_ = AddInput(TensorType_FLOAT32); + + if (use_cifg) { + input_to_input_weights_ = AddNullInput(); + } else { + input_to_input_weights_ = AddInput(TensorType_FLOAT32); + } + + input_to_forget_weights_ = AddInput(TensorType_FLOAT32); + input_to_cell_weights_ = AddInput(TensorType_FLOAT32); + input_to_output_weights_ = AddInput(TensorType_FLOAT32); + + if (use_cifg) { + recurrent_to_input_weights_ = AddNullInput(); + } else { + recurrent_to_input_weights_ = AddInput(TensorType_FLOAT32); + } + + recurrent_to_forget_weights_ = AddInput(TensorType_FLOAT32); + recurrent_to_cell_weights_ = AddInput(TensorType_FLOAT32); + recurrent_to_output_weights_ = AddInput(TensorType_FLOAT32); + + if (use_peephole) { + if (use_cifg) { + cell_to_input_weights_ = AddNullInput(); + } else { + cell_to_input_weights_ = AddInput(TensorType_FLOAT32); + } + cell_to_forget_weights_ = AddInput(TensorType_FLOAT32); + cell_to_output_weights_ = AddInput(TensorType_FLOAT32); + } else { + cell_to_input_weights_ = AddNullInput(); + cell_to_forget_weights_ = AddNullInput(); + cell_to_output_weights_ = AddNullInput(); + } + + if (use_cifg) { + input_gate_bias_ = AddNullInput(); + } else { + input_gate_bias_ = AddInput(TensorType_FLOAT32); + } + forget_gate_bias_ = AddInput(TensorType_FLOAT32); + cell_bias_ = AddInput(TensorType_FLOAT32); + output_gate_bias_ = AddInput(TensorType_FLOAT32); + + if (use_projection_weights) { + projection_weights_ = AddInput(TensorType_FLOAT32); + if (use_projection_bias) { + projection_bias_ = AddInput(TensorType_FLOAT32); + } else { + projection_bias_ = AddNullInput(); + } + } else { + projection_weights_ = AddNullInput(); + projection_bias_ = AddNullInput(); + } + + scratch_buffer_ = AddOutput(TensorType_FLOAT32); + // TODO(ghodrat): Modify these states when we have a permanent solution for + // persistent buffer. 
+ output_state_ = AddOutput(TensorType_FLOAT32); + cell_state_ = AddOutput(TensorType_FLOAT32); + output_ = AddOutput(TensorType_FLOAT32); + + SetBuiltinOp(BuiltinOperator_LSTM, BuiltinOptions_LSTMOptions, + CreateLSTMOptions(builder_, ActivationFunctionType_TANH, + cell_clip, proj_clip) + .Union()); + BuildInterpreter(input_shapes); + } + + void SetInputToInputWeights(std::initializer_list f) { + PopulateTensor(input_to_input_weights_, f); + } + + void SetInputToForgetWeights(std::initializer_list f) { + PopulateTensor(input_to_forget_weights_, f); + } + + void SetInputToCellWeights(std::initializer_list f) { + PopulateTensor(input_to_cell_weights_, f); + } + + void SetInputToOutputWeights(std::initializer_list f) { + PopulateTensor(input_to_output_weights_, f); + } + + void SetRecurrentToInputWeights(std::initializer_list f) { + PopulateTensor(recurrent_to_input_weights_, f); + } + + void SetRecurrentToForgetWeights(std::initializer_list f) { + PopulateTensor(recurrent_to_forget_weights_, f); + } + + void SetRecurrentToCellWeights(std::initializer_list f) { + PopulateTensor(recurrent_to_cell_weights_, f); + } + + void SetRecurrentToOutputWeights(std::initializer_list f) { + PopulateTensor(recurrent_to_output_weights_, f); + } + + void SetCellToInputWeights(std::initializer_list f) { + PopulateTensor(cell_to_input_weights_, f); + } + + void SetCellToForgetWeights(std::initializer_list f) { + PopulateTensor(cell_to_forget_weights_, f); + } + + void SetCellToOutputWeights(std::initializer_list f) { + PopulateTensor(cell_to_output_weights_, f); + } + + void SetInputGateBias(std::initializer_list f) { + PopulateTensor(input_gate_bias_, f); + } + + void SetForgetGateBias(std::initializer_list f) { + PopulateTensor(forget_gate_bias_, f); + } + + void SetCellBias(std::initializer_list f) { + PopulateTensor(cell_bias_, f); + } + + void SetOutputGateBias(std::initializer_list f) { + PopulateTensor(output_gate_bias_, f); + } + + void SetProjectionWeights(std::initializer_list f) { + PopulateTensor(projection_weights_, f); + } + + void SetProjectionBias(std::initializer_list f) { + PopulateTensor(projection_bias_, f); + } + + void ResetOutputState() { + const int zero_buffer_size = n_cell_ * n_batch_; + std::unique_ptr zero_buffer(new float[zero_buffer_size]); + memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float)); + PopulateTensor(output_state_, 0, zero_buffer.get(), + zero_buffer.get() + zero_buffer_size); + } + + void ResetCellState() { + const int zero_buffer_size = n_cell_ * n_batch_; + std::unique_ptr zero_buffer(new float[zero_buffer_size]); + memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float)); + PopulateTensor(cell_state_, 0, zero_buffer.get(), + zero_buffer.get() + zero_buffer_size); + } + + void SetInput(int offset, float* begin, float* end) { + PopulateTensor(input_, offset, begin, end); + } + + std::vector GetOutput() { return ExtractVector(output_); } + + int num_inputs() { return n_input_; } + int num_outputs() { return n_output_; } + int num_cells() { return n_cell_; } + int num_batches() { return n_batch_; } + + private: + int input_; + int input_to_input_weights_; + int input_to_forget_weights_; + int input_to_cell_weights_; + int input_to_output_weights_; + + int recurrent_to_input_weights_; + int recurrent_to_forget_weights_; + int recurrent_to_cell_weights_; + int recurrent_to_output_weights_; + + int cell_to_input_weights_; + int cell_to_forget_weights_; + int cell_to_output_weights_; + + int input_gate_bias_; + int forget_gate_bias_; + int cell_bias_; + 
int output_gate_bias_; + + int projection_weights_; + int projection_bias_; + + int output_; + int output_state_; + int cell_state_; + int scratch_buffer_; + + int n_batch_; + int n_input_; + int n_cell_; + int n_output_; +}; + +TEST(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping) { + const int n_batch = 1; + const int n_input = 2; + // n_cell and n_output have the same size when there is no projection. + const int n_cell = 4; + const int n_output = 4; + + LSTMOpModel lstm(n_batch, n_input, n_cell, n_output, + /*use_cifg=*/false, /*use_peephole=*/false, + /*use_projection_weights=*/false, + /*use_projection_bias=*/false, + /*cell_clip=*/0.0, /*proj_clip=*/0.0, + { + {n_batch, n_input}, // input tensor + + {n_cell, n_input}, // input_to_input_weight tensor + {n_cell, n_input}, // input_to_forget_weight tensor + {n_cell, n_input}, // input_to_cell_weight tensor + {n_cell, n_input}, // input_to_output_weight tensor + + {n_cell, n_output}, // recurrent_to_input_weight tensor + {n_cell, n_output}, // recurrent_to_forget_weight tensor + {n_cell, n_output}, // recurrent_to_cell_weight tensor + {n_cell, n_output}, // recurrent_to_output_weight tensor + + {0}, // cell_to_input_weight tensor + {0}, // cell_to_forget_weight tensor + {0}, // cell_to_output_weight tensor + + {n_cell}, // input_gate_bias tensor + {n_cell}, // forget_gate_bias tensor + {n_cell}, // cell_bias tensor + {n_cell}, // output_gate_bias tensor + + {0, 0}, // projection_weight tensor + {0}, // projection_bias tensor + }); + + lstm.SetInputToInputWeights({-0.45018822, -0.02338299, -0.0870589, + -0.34550029, 0.04266912, -0.15680569, + -0.34856534, 0.43890524}); + + lstm.SetInputToCellWeights({-0.50013041, 0.1370284, 0.11810488, 0.2013163, + -0.20583314, 0.44344562, 0.22077113, + -0.29909778}); + + lstm.SetInputToForgetWeights({0.09701663, 0.20334584, -0.50592935, + -0.31343272, -0.40032279, 0.44781327, + 0.01387155, -0.35593212}); + + lstm.SetInputToOutputWeights({-0.25065863, -0.28290087, 0.04613829, + 0.40525138, 0.44272184, 0.03897077, -0.1556896, + 0.19487578}); + + lstm.SetInputGateBias({0., 0., 0., 0.}); + + lstm.SetCellBias({0., 0., 0., 0.}); + + lstm.SetForgetGateBias({1., 1., 1., 1.}); + + lstm.SetOutputGateBias({0., 0., 0., 0.}); + + lstm.SetRecurrentToInputWeights( + {-0.0063535, -0.2042388, 0.31454784, -0.35746509, 0.28902304, 0.08183324, + -0.16555229, 0.02286911, -0.13566875, 0.03034258, 0.48091322, + -0.12528998, 0.24077177, -0.51332325, -0.33502164, 0.10629296}); + + lstm.SetRecurrentToCellWeights( + {-0.3407414, 0.24443203, -0.2078532, 0.26320225, 0.05695659, -0.00123841, + -0.4744786, -0.35869038, -0.06418842, -0.13502428, -0.501764, 0.22830659, + -0.46367589, 0.26016325, -0.03894562, -0.16368064}); + + lstm.SetRecurrentToForgetWeights( + {-0.48684245, -0.06655136, 0.42224967, 0.2112639, 0.27654213, 0.20864892, + -0.07646349, 0.45877004, 0.00141793, -0.14609534, 0.36447752, 0.09196436, + 0.28053468, 0.01560611, -0.20127171, -0.01140004}); + + lstm.SetRecurrentToOutputWeights( + {0.43385774, -0.17194885, 0.2718237, 0.09215671, 0.24107647, -0.39835793, + 0.18212086, 0.01301402, 0.48572797, -0.50656658, 0.20047462, -0.20607421, + -0.51818722, -0.15390486, 0.0468148, 0.39922136}); + + static float lstm_input[] = {2., 3., 3., 4., 1., 1.}; + static float lstm_golden_output[] = {-0.02973187, 0.1229473, 0.20885126, + -0.15358765, -0.03716109, 0.12507336, + 0.41193449, -0.20860538, -0.15053082, + 0.09120187, 0.24278517, -0.12222792}; + + // Resetting cell_state and output_state + lstm.ResetCellState(); + 
lstm.ResetOutputState(); + + const int input_sequence_size = + sizeof(lstm_input) / sizeof(float) / (lstm.num_inputs()); + for (int i = 0; i < input_sequence_size; i++) { + float* batch0_start = lstm_input + i * lstm.num_inputs(); + float* batch0_end = batch0_start + lstm.num_inputs(); + + lstm.SetInput(0, batch0_start, batch0_end); + + lstm.Invoke(); + + float* golden_start = lstm_golden_output + i * lstm.num_outputs(); + float* golden_end = golden_start + lstm.num_outputs(); + std::vector expected; + expected.insert(expected.end(), golden_start, golden_end); + EXPECT_THAT(lstm.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); + } +} + +TEST(LSTMOpTest, BlackBoxTestWithCifgWithPeepholeNoProjectionNoClipping) { + const int n_batch = 1; + const int n_input = 2; + // n_cell and n_output have the same size when there is no projection. + const int n_cell = 4; + const int n_output = 4; + + LSTMOpModel lstm(n_batch, n_input, n_cell, n_output, + /*use_cifg=*/true, /*use_peephole=*/true, + /*use_projection_weights=*/false, + /*use_projection_bias=*/false, + /*cell_clip=*/0.0, /*proj_clip=*/0.0, + { + {n_batch, n_input}, // input tensor + + {0, 0}, // input_to_input_weight tensor + {n_cell, n_input}, // input_to_forget_weight tensor + {n_cell, n_input}, // input_to_cell_weight tensor + {n_cell, n_input}, // input_to_output_weight tensor + + {0, 0}, // recurrent_to_input_weight tensor + {n_cell, n_output}, // recurrent_to_forget_weight tensor + {n_cell, n_output}, // recurrent_to_cell_weight tensor + {n_cell, n_output}, // recurrent_to_output_weight tensor + + {0}, // cell_to_input_weight tensor + {n_cell}, // cell_to_forget_weight tensor + {n_cell}, // cell_to_output_weight tensor + + {0}, // input_gate_bias tensor + {n_cell}, // forget_gate_bias tensor + {n_cell}, // cell_bias tensor + {n_cell}, // output_gate_bias tensor + + {0, 0}, // projection_weight tensor + {0}, // projection_bias tensor + }); + + lstm.SetInputToCellWeights({-0.49770179, -0.27711356, -0.09624726, 0.05100781, + 0.04717243, 0.48944736, -0.38535351, + -0.17212132}); + + lstm.SetInputToForgetWeights({-0.55291498, -0.42866567, 0.13056988, + -0.3633365, -0.22755712, 0.28253698, 0.24407166, + 0.33826375}); + + lstm.SetInputToOutputWeights({0.10725588, -0.02335852, -0.55932593, + -0.09426838, -0.44257352, 0.54939759, + 0.01533556, 0.42751634}); + + lstm.SetCellBias({0., 0., 0., 0.}); + + lstm.SetForgetGateBias({1., 1., 1., 1.}); + + lstm.SetOutputGateBias({0., 0., 0., 0.}); + + lstm.SetRecurrentToCellWeights( + {0.54066205, -0.32668582, -0.43562764, -0.56094903, 0.42957711, + 0.01841056, -0.32764608, -0.33027974, -0.10826075, 0.20675004, + 0.19069612, -0.03026325, -0.54532051, 0.33003211, 0.44901288, + 0.21193194}); + + lstm.SetRecurrentToForgetWeights( + {-0.13832897, -0.0515101, -0.2359007, -0.16661474, -0.14340827, + 0.36986142, 0.23414481, 0.55899, 0.10798943, -0.41174671, 0.17751795, + -0.34484994, -0.35874045, -0.11352962, 0.27268326, 0.54058349}); + + lstm.SetRecurrentToOutputWeights( + {0.41613156, 0.42610586, -0.16495961, -0.5663873, 0.30579174, -0.05115908, + -0.33941799, 0.23364776, 0.11178309, 0.09481031, -0.26424935, 0.46261835, + 0.50248802, 0.26114327, -0.43736315, 0.33149987}); + + lstm.SetCellToForgetWeights( + {0.47485286, -0.51955009, -0.24458408, 0.31544167}); + lstm.SetCellToOutputWeights( + {-0.17135078, 0.82760304, 0.85573703, -0.77109635}); + + static float lstm_input[] = {2., 3., 3., 4., 1., 1.}; + static float lstm_golden_output[] = {-0.36444446, -0.00352185, 0.12886585, + -0.05163646, -0.42312205, 
-0.01218222, + 0.24201041, -0.08124574, -0.358325, + -0.04621704, 0.21641694, -0.06471302}; + + // Resetting cell_state and output_state + lstm.ResetCellState(); + lstm.ResetOutputState(); + + const int input_sequence_size = + sizeof(lstm_input) / sizeof(float) / (lstm.num_inputs()); + for (int i = 0; i < input_sequence_size; i++) { + float* batch0_start = lstm_input + i * lstm.num_inputs(); + float* batch0_end = batch0_start + lstm.num_inputs(); + + lstm.SetInput(0, batch0_start, batch0_end); + + lstm.Invoke(); + + float* golden_start = lstm_golden_output + i * lstm.num_outputs(); + float* golden_end = golden_start + lstm.num_outputs(); + std::vector expected; + expected.insert(expected.end(), golden_start, golden_end); + EXPECT_THAT(lstm.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); + } +} + +TEST(LSTMOpTest, BlackBoxTestWithPeepholeWithProjectionNoClipping) { + const int n_batch = 2; + const int n_input = 5; + const int n_cell = 20; + const int n_output = 16; + + LSTMOpModel lstm(n_batch, n_input, n_cell, n_output, + /*use_cifg=*/false, /*use_peephole=*/true, + /*use_projection_weights=*/true, + /*use_projection_bias=*/false, + /*cell_clip=*/0.0, /*proj_clip=*/0.0, + { + {n_batch, n_input}, // input tensor + + {n_cell, n_input}, // input_to_input_weight tensor + {n_cell, n_input}, // input_to_forget_weight tensor + {n_cell, n_input}, // input_to_cell_weight tensor + {n_cell, n_input}, // input_to_output_weight tensor + + {n_cell, n_output}, // recurrent_to_input_weight tensor + {n_cell, n_output}, // recurrent_to_forget_weight tensor + {n_cell, n_output}, // recurrent_to_cell_weight tensor + {n_cell, n_output}, // recurrent_to_output_weight tensor + + {n_cell}, // cell_to_input_weight tensor + {n_cell}, // cell_to_forget_weight tensor + {n_cell}, // cell_to_output_weight tensor + + {n_cell}, // input_gate_bias tensor + {n_cell}, // forget_gate_bias tensor + {n_cell}, // cell_bias tensor + {n_cell}, // output_gate_bias tensor + + {n_output, n_cell}, // projection_weight tensor + {0}, // projection_bias tensor + }); + + lstm.SetInputToInputWeights( + {0.021393683, 0.06124551, 0.046905167, -0.014657677, -0.03149463, + 0.09171803, 0.14647801, 0.10797193, -0.0057968358, 0.0019193048, + -0.2726754, 0.10154029, -0.018539885, 0.080349885, -0.10262385, + -0.022599787, -0.09121155, -0.008675967, -0.045206103, -0.0821282, + -0.008045952, 0.015478081, 0.055217247, 0.038719587, 0.044153627, + -0.06453243, 0.05031825, -0.046935108, -0.008164439, 0.014574226, + -0.1671009, -0.15519552, -0.16819797, -0.13971269, -0.11953059, + 0.25005487, -0.22790983, 0.009855087, -0.028140958, -0.11200698, + 0.11295408, -0.0035217577, 0.054485075, 0.05184695, 0.064711206, + 0.10989193, 0.11674786, 0.03490607, 0.07727357, 0.11390585, + -0.1863375, -0.1034451, -0.13945189, -0.049401227, -0.18767063, + 0.042483903, 0.14233552, 0.13832581, 0.18350165, 0.14545603, + -0.028545704, 0.024939531, 0.050929718, 0.0076203286, -0.0029723682, + -0.042484224, -0.11827596, -0.09171104, -0.10808628, -0.16327988, + -0.2273378, -0.0993647, -0.017155107, 0.0023917493, 0.049272764, + 0.0038534778, 0.054764505, 0.089753784, 0.06947234, 0.08014476, + -0.04544234, -0.0497073, -0.07135631, -0.048929106, -0.004042012, + -0.009284026, 0.018042054, 0.0036860977, -0.07427302, -0.11434604, + -0.018995456, 0.031487543, 0.012834908, 0.019977754, 0.044256654, + -0.39292613, -0.18519334, -0.11651281, -0.06809892, 0.011373677}); + + lstm.SetInputToForgetWeights( + {-0.0018401089, -0.004852237, 0.03698424, 0.014181704, 0.028273236, + 
-0.016726194, -0.05249759, -0.10204261, 0.00861066, -0.040979505, + -0.009899187, 0.01923892, -0.028177269, -0.08535103, -0.14585495, + 0.10662567, -0.01909731, -0.017883534, -0.0047269356, -0.045103323, + 0.0030784295, 0.076784775, 0.07463696, 0.094531395, 0.0814421, + -0.12257899, -0.033945758, -0.031303465, 0.045630626, 0.06843887, + -0.13492945, -0.012480007, -0.0811829, -0.07224499, -0.09628791, + 0.045100946, 0.0012300825, 0.013964662, 0.099372394, 0.02543059, + 0.06958324, 0.034257296, 0.0482646, 0.06267997, 0.052625068, + 0.12784666, 0.07077897, 0.025725935, 0.04165009, 0.07241905, + 0.018668644, -0.037377294, -0.06277783, -0.08833636, -0.040120605, + -0.011405586, -0.007808335, -0.010301386, -0.005102167, 0.027717464, + 0.05483423, 0.11449111, 0.11289652, 0.10939839, 0.13396506, + -0.08402166, -0.01901462, -0.044678304, -0.07720565, 0.014350063, + -0.11757958, -0.0652038, -0.08185733, -0.076754324, -0.092614375, + 0.10405491, 0.052960336, 0.035755895, 0.035839386, -0.012540553, + 0.036881298, 0.02913376, 0.03420159, 0.05448447, -0.054523353, + 0.02582715, 0.02327355, -0.011857179, -0.0011980024, -0.034641717, + -0.026125094, -0.17582615, -0.15923657, -0.27486774, -0.0006143371, + 0.0001771948, -8.470171e-05, 0.02651807, 0.045790765, 0.06956496}); + + lstm.SetInputToCellWeights( + {-0.04580283, -0.09549462, -0.032418985, -0.06454633, + -0.043528453, 0.043018587, -0.049152344, -0.12418144, + -0.078985475, -0.07596889, 0.019484362, -0.11434962, + -0.0074034138, -0.06314844, -0.092981495, 0.0062155537, + -0.025034338, -0.0028890965, 0.048929527, 0.06235075, + 0.10665918, -0.032036792, -0.08505916, -0.10843358, + -0.13002433, -0.036816437, -0.02130134, -0.016518239, + 0.0047691227, -0.0025825808, 0.066017866, 0.029991534, + -0.10652836, -0.1037554, -0.13056071, -0.03266643, + -0.033702414, -0.006473424, -0.04611692, 0.014419339, + -0.025174323, 0.0396852, 0.081777506, 0.06157468, + 0.10210095, -0.009658194, 0.046511717, 0.03603906, + 0.0069369148, 0.015960095, -0.06507666, 0.09551598, + 0.053568836, 0.06408714, 0.12835667, -0.008714329, + -0.20211966, -0.12093674, 0.029450472, 0.2849013, + -0.029227901, 0.1164364, -0.08560263, 0.09941786, + -0.036999565, -0.028842626, -0.0033637602, -0.017012902, + -0.09720865, -0.11193351, -0.029155117, -0.017936034, + -0.009768936, -0.04223324, -0.036159635, 0.06505112, + -0.021742892, -0.023377212, -0.07221364, -0.06430552, + 0.05453865, 0.091149814, 0.06387331, 0.007518393, + 0.055960953, 0.069779344, 0.046411168, 0.10509911, + 0.07463894, 0.0075130584, 0.012850982, 0.04555431, + 0.056955688, 0.06555285, 0.050801456, -0.009862683, + 0.00826772, -0.026555609, -0.0073611983, -0.0014897042}); + + lstm.SetInputToOutputWeights( + {-0.0998932, -0.07201956, -0.052803773, -0.15629593, -0.15001918, + -0.07650751, 0.02359855, -0.075155355, -0.08037709, -0.15093534, + 0.029517552, -0.04751393, 0.010350531, -0.02664851, -0.016839722, + -0.023121163, 0.0077019283, 0.012851257, -0.05040649, -0.0129761, + -0.021737747, -0.038305793, -0.06870586, -0.01481247, -0.001285394, + 0.10124236, 0.083122835, 0.053313006, -0.062235646, -0.075637154, + -0.027833903, 0.029774971, 0.1130802, 0.09218906, 0.09506135, + -0.086665764, -0.037162706, -0.038880914, -0.035832845, -0.014481564, + -0.09825003, -0.12048569, -0.097665586, -0.05287633, -0.0964047, + -0.11366429, 0.035777505, 0.13568819, 0.052451383, 0.050649304, + 0.05798951, -0.021852335, -0.099848844, 0.014740475, -0.078897946, + 0.04974699, 0.014160473, 0.06973932, 0.04964942, 0.033364646, + 0.08190124, 0.025535367, 
0.050893165, 0.048514254, 0.06945813, + -0.078907564, -0.06707616, -0.11844508, -0.09986688, -0.07509403, + 0.06263226, 0.14925587, 0.20188436, 0.12098451, 0.14639415, + 0.0015017595, -0.014267382, -0.03417257, 0.012711468, 0.0028300495, + -0.024758482, -0.05098548, -0.0821182, 0.014225672, 0.021544158, + 0.08949725, 0.07505268, -0.0020780868, 0.04908258, 0.06476295, + -0.022907063, 0.027562456, 0.040185735, 0.019567577, -0.015598739, + -0.049097303, -0.017121866, -0.083368234, -0.02332002, -0.0840956}); + + lstm.SetInputGateBias( + {0.02234832, 0.14757581, 0.18176508, 0.10380666, 0.053110216, + -0.06928846, -0.13942584, -0.11816189, 0.19483899, 0.03652339, + -0.10250295, 0.036714908, -0.18426876, 0.036065217, 0.21810818, + 0.02383196, -0.043370757, 0.08690144, -0.04444982, 0.00030581196}); + + lstm.SetForgetGateBias({0.035185695, -0.042891346, -0.03032477, 0.23027696, + 0.11098921, 0.15378423, 0.09263801, 0.09790885, + 0.09508917, 0.061199076, 0.07665568, -0.015443159, + -0.03499149, 0.046190713, 0.08895977, 0.10899629, + 0.40694186, 0.06030037, 0.012413437, -0.06108739}); + + lstm.SetCellBias({-0.024379363, 0.0055531194, 0.23377132, 0.033463873, + -0.1483596, -0.10639995, -0.091433935, 0.058573797, + -0.06809782, -0.07889636, -0.043246906, -0.09829136, + -0.4279842, 0.034901652, 0.18797937, 0.0075234566, + 0.016178843, 0.1749513, 0.13975595, 0.92058027}); + + lstm.SetOutputGateBias( + {0.046159424, -0.0012809046, 0.03563469, 0.12648113, 0.027195795, + 0.35373217, -0.018957434, 0.008907322, -0.0762701, 0.12018895, + 0.04216877, 0.0022856654, 0.040952638, 0.3147856, 0.08225149, + -0.057416286, -0.14995944, -0.008040261, 0.13208859, 0.029760877}); + + lstm.SetRecurrentToInputWeights( + {-0.001374326, -0.078856036, 0.10672688, 0.029162422, + -0.11585556, 0.02557986, -0.13446963, -0.035785314, + -0.01244275, 0.025961924, -0.02337298, -0.044228926, + -0.055839065, -0.046598054, -0.010546039, -0.06900766, + 0.027239809, 0.022582639, -0.013296484, -0.05459212, + 0.08981, -0.045407712, 0.08682226, -0.06867011, + -0.14390695, -0.02916037, 0.000996957, 0.091420636, + 0.14283475, -0.07390571, -0.06402044, 0.062524505, + -0.093129106, 0.04860203, -0.08364217, -0.08119002, + 0.009352075, 0.22920375, 0.0016303885, 0.11583097, + -0.13732095, 0.012405723, -0.07551853, 0.06343048, + 0.12162708, -0.031923793, -0.014335606, 0.01790974, + -0.10650317, -0.0724401, 0.08554849, -0.05727212, + 0.06556731, -0.042729504, -0.043227166, 0.011683251, + -0.013082158, -0.029302018, -0.010899579, -0.062036745, + -0.022509435, -0.00964907, -0.01567329, 0.04260106, + -0.07787477, -0.11576462, 0.017356863, 0.048673786, + -0.017577527, -0.05527947, -0.082487635, -0.040137455, + -0.10820036, -0.04666372, 0.022746278, -0.07851417, + 0.01068115, 0.032956902, 0.022433773, 0.0026891115, + 0.08944216, -0.0685835, 0.010513544, 0.07228705, + 0.02032331, -0.059686817, -0.0005566496, -0.086984694, + 0.040414046, -0.1380399, 0.094208956, -0.05722982, + 0.012092817, -0.04989123, -0.086576, -0.003399834, + -0.04696032, -0.045747425, 0.10091314, 0.048676282, + -0.029037097, 0.031399418, -0.0040285117, 0.047237843, + 0.09504992, 0.041799378, -0.049185462, -0.031518843, + -0.10516937, 0.026374253, 0.10058866, -0.0033195973, + -0.041975245, 0.0073591834, 0.0033782164, -0.004325073, + -0.10167381, 0.042500053, -0.01447153, 0.06464186, + -0.017142897, 0.03312627, 0.009205989, 0.024138335, + -0.011337001, 0.035530265, -0.010912711, 0.0706555, + -0.005894094, 0.051841937, -0.1401738, -0.02351249, + 0.0365468, 0.07590991, 0.08838724, 
0.021681072, + -0.10086113, 0.019608743, -0.06195883, 0.077335775, + 0.023646897, -0.095322326, 0.02233014, 0.09756986, + -0.048691444, -0.009579111, 0.07595467, 0.11480546, + -0.09801813, 0.019894179, 0.08502348, 0.004032281, + 0.037211012, 0.068537936, -0.048005626, -0.091520436, + -0.028379958, -0.01556313, 0.06554592, -0.045599163, + -0.01672207, -0.020169014, -0.011877351, -0.20212261, + 0.010889619, 0.0047078193, 0.038385306, 0.08540671, + -0.017140968, -0.0035865551, 0.016678626, 0.005633034, + 0.015963363, 0.00871737, 0.060130805, 0.028611384, + 0.10109069, -0.015060172, -0.07894427, 0.06401885, + 0.011584063, -0.024466386, 0.0047652307, -0.09041358, + 0.030737216, -0.0046374933, 0.14215417, -0.11823516, + 0.019899689, 0.006106124, -0.027092824, 0.0786356, + 0.05052217, -0.058925, -0.011402121, -0.024987547, + -0.0013661642, -0.06832946, -0.015667673, -0.1083353, + -0.00096863037, -0.06988685, -0.053350925, -0.027275559, + -0.033664223, -0.07978348, -0.025200296, -0.017207067, + -0.058403496, -0.055697463, 0.005798788, 0.12965427, + -0.062582195, 0.0013350133, -0.10482091, 0.0379771, + 0.072521195, -0.0029455067, -0.13797039, -0.03628521, + 0.013806405, -0.017858358, -0.01008298, -0.07700066, + -0.017081132, 0.019358726, 0.0027079724, 0.004635139, + 0.062634714, -0.02338735, -0.039547626, -0.02050681, + 0.03385117, -0.083611414, 0.002862572, -0.09421313, + 0.058618143, -0.08598433, 0.00972939, 0.023867095, + -0.053934585, -0.023203006, 0.07452513, -0.048767887, + -0.07314807, -0.056307215, -0.10433547, -0.06440842, + 0.04328182, 0.04389765, -0.020006588, -0.09076438, + -0.11652589, -0.021705797, 0.03345259, -0.010329105, + -0.025767034, 0.013057034, -0.07316461, -0.10145612, + 0.06358255, 0.18531723, 0.07759293, 0.12006465, + 0.1305557, 0.058638252, -0.03393652, 0.09622831, + -0.16253184, -2.4580743e-06, 0.079869635, -0.070196845, + -0.005644518, 0.06857898, -0.12598175, -0.035084512, + 0.03156317, -0.12794146, -0.031963028, 0.04692781, + 0.030070418, 0.0071660685, -0.095516115, -0.004643372, + 0.040170413, -0.062104587, -0.0037324072, 0.0554317, + 0.08184801, -0.019164372, 0.06791302, 0.034257166, + -0.10307039, 0.021943003, 0.046745934, 0.0790918, + -0.0265588, -0.007824208, 0.042546265, -0.00977924, + -0.0002440307, -0.017384544, -0.017990116, 0.12252321, + -0.014512694, -0.08251313, 0.08861942, 0.13589665, + 0.026351685, 0.012641483, 0.07466548, 0.044301085, + -0.045414884, -0.051112458, 0.03444247, -0.08502782, + -0.04106223, -0.028126027, 0.028473156, 0.10467447}); + + lstm.SetRecurrentToForgetWeights( + {-0.057784554, -0.026057621, -0.068447545, -0.022581743, + 0.14811787, 0.10826372, 0.09471067, 0.03987225, + -0.0039523416, 0.00030638507, 0.053185795, 0.10572994, + 0.08414449, -0.022036452, -0.00066928595, -0.09203576, + 0.032950465, -0.10985798, -0.023809856, 0.0021431844, + -0.02196096, -0.00326074, 0.00058621005, -0.074678116, + -0.06193199, 0.055729095, 0.03736828, 0.020123724, + 0.061878487, -0.04729229, 0.034919553, -0.07585433, + -0.04421272, -0.044019096, 0.085488975, 0.04058006, + -0.06890133, -0.030951202, -0.024628663, -0.07672815, + 0.034293607, 0.08556707, -0.05293577, -0.033561368, + -0.04899627, 0.0241671, 0.015736353, -0.095442444, + -0.029564252, 0.016493602, -0.035026584, 0.022337519, + -0.026871363, 0.004780428, 0.0077918363, -0.03601621, + 0.016435321, -0.03263031, -0.09543275, -0.047392778, + 0.013454138, 0.028934088, 0.01685226, -0.086110644, + -0.046250615, -0.01847454, 0.047608484, 0.07339695, + 0.034546845, -0.04881143, 0.009128804, -0.08802852, + 
0.03761666, 0.008096139, -0.014454086, 0.014361001, + -0.023502491, -0.0011840804, -0.07607001, 0.001856849, + -0.06509276, -0.006021153, -0.08570962, -0.1451793, + 0.060212336, 0.055259194, 0.06974018, 0.049454916, + -0.027794661, -0.08077226, -0.016179763, 0.1169753, + 0.17213494, -0.0056326236, -0.053934924, -0.0124349, + -0.11520337, 0.05409887, 0.088759385, 0.0019655675, + 0.0042065294, 0.03881498, 0.019844765, 0.041858196, + -0.05695512, 0.047233116, 0.038937137, -0.06542224, + 0.014429736, -0.09719407, 0.13908425, -0.05379757, + 0.012321099, 0.082840554, -0.029899208, 0.044217527, + 0.059855383, 0.07711018, -0.045319796, 0.0948846, + -0.011724666, -0.0033288454, -0.033542685, -0.04764985, + -0.13873616, 0.040668588, 0.034832682, -0.015319203, + -0.018715994, 0.046002675, 0.0599172, -0.043107376, + 0.0294216, -0.002314414, -0.022424703, 0.0030315618, + 0.0014641669, 0.0029166266, -0.11878115, 0.013738511, + 0.12375372, -0.0006038222, 0.029104086, 0.087442465, + 0.052958444, 0.07558703, 0.04817258, 0.044462286, + -0.015213451, -0.08783778, -0.0561384, -0.003008196, + 0.047060397, -0.002058388, 0.03429439, -0.018839769, + 0.024734668, 0.024614193, -0.042046934, 0.09597743, + -0.0043254104, 0.04320769, 0.0064070094, -0.0019131786, + -0.02558259, -0.022822596, -0.023273505, -0.02464396, + -0.10991725, -0.006240552, 0.0074488563, 0.024044557, + 0.04383914, -0.046476185, 0.028658995, 0.060410924, + 0.050786525, 0.009452605, -0.0073054377, -0.024810238, + 0.0052906186, 0.0066939713, -0.0020913032, 0.014515517, + 0.015898481, 0.021362653, -0.030262267, 0.016587038, + -0.011442813, 0.041154444, -0.007631438, -0.03423484, + -0.010977775, 0.036152758, 0.0066366293, 0.11915515, + 0.02318443, -0.041350313, 0.021485701, -0.10906167, + -0.028218046, -0.00954771, 0.020531068, -0.11995105, + -0.03672871, 0.024019798, 0.014255957, -0.05221243, + -0.00661567, -0.04630967, 0.033188973, 0.10107534, + -0.014027541, 0.030796422, -0.10270911, -0.035999842, + 0.15443139, 0.07684145, 0.036571592, -0.035900835, + -0.0034699554, 0.06209149, 0.015920248, -0.031122351, + -0.03858649, 0.01849943, 0.13872518, 0.01503974, + 0.069941424, -0.06948533, -0.0088794185, 0.061282158, + -0.047401894, 0.03100163, -0.041533746, -0.10430945, + 0.044574402, -0.01425562, -0.024290353, 0.034563623, + 0.05866852, 0.023947537, -0.09445152, 0.035450947, + 0.02247216, -0.0042998926, 0.061146557, -0.10250651, + 0.020881841, -0.06747029, 0.10062043, -0.0023941975, + 0.03532124, -0.016341697, 0.09685456, -0.016764693, + 0.051808182, 0.05875331, -0.04536488, 0.001626336, + -0.028892258, -0.01048663, -0.009793449, -0.017093895, + 0.010987891, 0.02357273, -0.00010856845, 0.0099760275, + -0.001845119, -0.03551521, 0.0018358806, 0.05763657, + -0.01769146, 0.040995963, 0.02235177, -0.060430344, + 0.11475477, -0.023854522, 0.10071741, 0.0686208, + -0.014250481, 0.034261297, 0.047418304, 0.08562733, + -0.030519066, 0.0060542435, 0.014653856, -0.038836084, + 0.04096551, 0.032249358, -0.08355519, -0.026823482, + 0.056386515, -0.010401743, -0.028396193, 0.08507674, + 0.014410365, 0.020995233, 0.17040324, 0.11511526, + 0.02459721, 0.0066619175, 0.025853224, -0.023133837, + -0.081302024, 0.017264642, -0.009585969, 0.09491168, + -0.051313367, 0.054532815, -0.014298593, 0.10657464, + 0.007076659, 0.10964551, 0.0409152, 0.008275321, + -0.07283536, 0.07937492, 0.04192024, -0.1075027}); + + lstm.SetRecurrentToCellWeights( + {-0.037322544, 0.018592842, 0.0056175636, -0.06253426, + 0.055647098, -0.05713207, -0.05626563, 0.005559383, + 0.03375411, 
-0.025757805, -0.088049285, 0.06017052, + -0.06570978, 0.007384076, 0.035123326, -0.07920549, + 0.053676967, 0.044480428, -0.07663568, 0.0071805613, + 0.08089997, 0.05143358, 0.038261272, 0.03339287, + -0.027673481, 0.044746667, 0.028349208, 0.020090483, + -0.019443132, -0.030755889, -0.0040000007, 0.04465846, + -0.021585021, 0.0031670958, 0.0053199246, -0.056117613, + -0.10893326, 0.076739706, -0.08509834, -0.027997585, + 0.037871376, 0.01449768, -0.09002357, -0.06111149, + -0.046195522, 0.0422062, -0.005683705, -0.1253618, + -0.012925729, -0.04890792, 0.06985068, 0.037654128, + 0.03398274, -0.004781977, 0.007032333, -0.031787455, + 0.010868644, -0.031489216, 0.09525667, 0.013939797, + 0.0058680447, 0.0167067, 0.02668468, -0.04797466, + -0.048885044, -0.12722108, 0.035304096, 0.06554885, + 0.00972396, -0.039238118, -0.05159735, -0.11329045, + 0.1613692, -0.03750952, 0.06529313, -0.071974665, + -0.11769596, 0.015524369, -0.0013754242, -0.12446318, + 0.02786344, -0.014179351, 0.005264273, 0.14376344, + 0.015983658, 0.03406988, -0.06939408, 0.040699873, + 0.02111075, 0.09669095, 0.041345075, -0.08316494, + -0.07684199, -0.045768797, 0.032298047, -0.041805092, + 0.0119405, 0.0061010392, 0.12652606, 0.0064572375, + -0.024950314, 0.11574242, 0.04508852, -0.04335324, + 0.06760663, -0.027437469, 0.07216407, 0.06977076, + -0.05438599, 0.034033038, -0.028602652, 0.05346137, + 0.043184172, -0.037189785, 0.10420091, 0.00882477, + -0.054019816, -0.074273005, -0.030617684, -0.0028467078, + 0.024302477, -0.0038869337, 0.005332455, 0.0013399826, + 0.04361412, -0.007001822, 0.09631092, -0.06702025, + -0.042049985, -0.035070654, -0.04103342, -0.10273396, + 0.0544271, 0.037184782, -0.13150354, -0.0058036847, + -0.008264958, 0.042035464, 0.05891794, 0.029673764, + 0.0063542654, 0.044788733, 0.054816857, 0.062257513, + -0.00093483756, 0.048938446, -0.004952862, -0.007730018, + -0.04043371, -0.017094059, 0.07229206, -0.023670016, + -0.052195564, -0.025616996, -0.01520939, 0.045104615, + -0.007376126, 0.003533447, 0.006570588, 0.056037236, + 0.12436656, 0.051817212, 0.028532185, -0.08686856, + 0.11868599, 0.07663395, -0.07323171, 0.03463402, + -0.050708205, -0.04458982, -0.11590894, 0.021273347, + 0.1251325, -0.15313013, -0.12224372, 0.17228661, + 0.023029093, 0.086124025, 0.006445803, -0.03496501, + 0.028332196, 0.04449512, -0.042436164, -0.026587414, + -0.006041347, -0.09292539, -0.05678812, 0.03897832, + 0.09465633, 0.008115513, -0.02171956, 0.08304309, + 0.071401566, 0.019622514, 0.032163795, -0.004167056, + 0.02295182, 0.030739572, 0.056506045, 0.004612461, + 0.06524936, 0.059999723, 0.046395954, -0.0045512207, + -0.1335546, -0.030136576, 0.11584653, -0.014678886, + 0.0020118146, -0.09688814, -0.0790206, 0.039770417, + -0.0329582, 0.07922767, 0.029322514, 0.026405897, + 0.04207835, -0.07073373, 0.063781224, 0.0859677, + -0.10925287, -0.07011058, 0.048005477, 0.03438226, + -0.09606514, -0.006669445, -0.043381985, 0.04240257, + -0.06955775, -0.06769346, 0.043903265, -0.026784198, + -0.017840602, 0.024307009, -0.040079936, -0.019946516, + 0.045318738, -0.12233574, 0.026170589, 0.0074471775, + 0.15978073, 0.10185836, 0.10298046, -0.015476589, + -0.039390966, -0.072174534, 0.0739445, -0.1211869, + -0.0347889, -0.07943156, 0.014809798, -0.12412325, + -0.0030663363, 0.039695457, 0.0647603, -0.08291318, + -0.018529687, -0.004423833, 0.0037507233, 0.084633216, + -0.01514876, -0.056505352, -0.012800942, -0.06994386, + 0.012962922, -0.031234352, 0.07029052, 0.016418684, + 0.03618972, 0.055686004, -0.08663945, 
-0.017404709, + -0.054761406, 0.029065743, 0.052404847, 0.020238016, + 0.0048197987, -0.0214882, 0.07078733, 0.013016777, + 0.06262858, 0.009184685, 0.020785125, -0.043904778, + -0.0270329, -0.03299152, -0.060088247, -0.015162964, + -0.001828936, 0.12642565, -0.056757294, 0.013586685, + 0.09232601, -0.035886683, 0.06000002, 0.05229691, + -0.052580316, -0.082029596, -0.010794592, 0.012947712, + -0.036429964, -0.085508935, -0.13127148, -0.017744139, + 0.031502828, 0.036232427, -0.031581745, 0.023051167, + -0.05325106, -0.03421577, 0.028793324, -0.034633752, + -0.009881397, -0.043551125, -0.018609839, 0.0019097115, + -0.008799762, 0.056595087, 0.0022273948, 0.055752404}); + + lstm.SetRecurrentToOutputWeights({ + 0.025825322, -0.05813119, 0.09495884, -0.045984812, -0.01255415, + -0.0026479573, -0.08196161, -0.054914974, -0.0046604523, -0.029587349, + -0.044576716, -0.07480124, -0.082868785, 0.023254942, 0.027502948, + -0.0039728214, -0.08683098, -0.08116779, -0.014675607, -0.037924774, + -0.023314456, -0.007401714, -0.09255757, 0.029460307, -0.08829125, + -0.005139627, -0.08989442, -0.0555066, 0.13596267, -0.025062224, + -0.048351806, -0.03850004, 0.07266485, -0.022414139, 0.05940088, + 0.075114764, 0.09597592, -0.010211725, -0.0049794707, -0.011523867, + -0.025980417, 0.072999895, 0.11091378, -0.081685916, 0.014416728, + 0.043229222, 0.034178585, -0.07530371, 0.035837382, -0.085607, + -0.007721233, -0.03287832, -0.043848954, -0.06404588, -0.06632928, + -0.073643476, 0.008214239, -0.045984086, 0.039764922, 0.03474462, + 0.060612556, -0.080590084, 0.049127717, 0.04151091, -0.030063879, + 0.008801774, -0.023021035, -0.019558564, 0.05158114, -0.010947698, + -0.011825728, 0.0075720972, 0.0699727, -0.0039981045, 0.069350146, + 0.08799282, 0.016156472, 0.035502106, 0.11695009, 0.006217345, + 0.13392477, -0.037875112, 0.025745004, 0.08940699, -0.00924166, + 0.0046702605, -0.036598757, -0.08811812, 0.10522024, -0.032441203, + 0.008176899, -0.04454919, 0.07058152, 0.0067963637, 0.039206743, + 0.03259838, 0.03725492, -0.09515802, 0.013326398, -0.052055415, + -0.025676316, 0.03198509, -0.015951829, -0.058556724, 0.036879618, + 0.043357447, 0.028362012, -0.05908629, 0.0059240665, -0.04995891, + -0.019187413, 0.0276265, -0.01628143, 0.0025863599, 0.08800015, + 0.035250366, -0.022165963, -0.07328642, -0.009415526, -0.07455109, + 0.11690406, 0.0363299, 0.07411125, 0.042103454, -0.009660886, + 0.019076364, 0.018299393, -0.046004917, 0.08891175, 0.0431396, + -0.026327137, -0.051502608, 0.08979574, -0.051670972, 0.04940282, + -0.07491107, -0.021240504, 0.022596184, -0.034280192, 0.060163025, + -0.058211457, -0.051837247, -0.01349775, -0.04639988, -0.035936575, + -0.011681591, 0.064818054, 0.0073146066, -0.021745546, -0.043124277, + -0.06471268, -0.07053354, -0.029321948, -0.05330136, 0.016933719, + -0.053782392, 0.13747959, -0.1361751, -0.11569455, 0.0033329215, + 0.05693899, -0.053219706, 0.063698, 0.07977434, -0.07924483, + 0.06936997, 0.0034815092, -0.007305279, -0.037325785, -0.07251102, + -0.033633437, -0.08677009, 0.091591336, -0.14165086, 0.021752775, + 0.019683983, 0.0011612234, -0.058154266, 0.049996935, 0.0288841, + -0.0024567875, -0.14345716, 0.010955264, -0.10234828, 0.1183656, + -0.0010731248, -0.023590032, -0.072285876, -0.0724771, -0.026382286, + -0.0014920527, 0.042667855, 0.0018776858, 0.02986552, 0.009814309, + 0.0733756, 0.12289186, 0.018043943, -0.0458958, 0.049412545, + 0.033632483, 0.05495232, 0.036686596, -0.013781798, -0.010036754, + 0.02576849, -0.08307328, 0.010112348, 0.042521734, 
-0.05869831, + -0.071689695, 0.03876447, -0.13275425, -0.0352966, -0.023077697, + 0.10285965, 0.084736146, 0.15568255, -0.00040734606, 0.027835453, + -0.10292561, -0.032401145, 0.10053256, -0.026142767, -0.08271222, + -0.0030240538, -0.016368777, 0.1070414, 0.042672627, 0.013456989, + -0.0437609, -0.022309763, 0.11576483, 0.04108048, 0.061026827, + -0.0190714, -0.0869359, 0.037901703, 0.0610107, 0.07202949, + 0.01675338, 0.086139716, -0.08795751, -0.014898893, -0.023771819, + -0.01965048, 0.007955471, -0.043740474, 0.03346837, -0.10549954, + 0.090567775, 0.042013682, -0.03176985, 0.12569028, -0.02421228, + -0.029526481, 0.023851605, 0.031539805, 0.05292009, -0.02344001, + -0.07811758, -0.08834428, 0.10094801, 0.16594367, -0.06861939, + -0.021256343, -0.041093912, -0.06669611, 0.035498552, 0.021757556, + -0.09302526, -0.015403468, -0.06614931, -0.051798206, -0.013874718, + 0.03630673, 0.010412845, -0.08077351, 0.046185967, 0.0035662893, + 0.03541868, -0.094149634, -0.034814864, 0.003128424, -0.020674974, + -0.03944324, -0.008110165, -0.11113267, 0.08484226, 0.043586485, + 0.040582247, 0.0968012, -0.065249965, -0.028036479, 0.0050708856, + 0.0017462453, 0.0326779, 0.041296225, 0.09164146, -0.047743853, + -0.015952192, -0.034451712, 0.084197424, -0.05347844, -0.11768019, + 0.085926116, -0.08251791, -0.045081906, 0.0948852, 0.068401024, + 0.024856757, 0.06978981, -0.057309967, -0.012775832, -0.0032452994, + 0.01977615, -0.041040014, -0.024264973, 0.063464895, 0.05431621, + }); + + lstm.SetCellToInputWeights( + {0.040369894, 0.030746894, 0.24704495, 0.018586371, -0.037586458, + -0.15312155, -0.11812848, -0.11465643, 0.20259799, 0.11418174, + -0.10116027, -0.011334949, 0.12411352, -0.076769054, -0.052169047, + 0.21198851, -0.38871562, -0.09061183, -0.09683246, -0.21929175}); + + lstm.SetCellToForgetWeights( + {-0.01998659, -0.15568835, -0.24248174, -0.012770197, 0.041331276, + -0.072311886, -0.052123554, -0.0066330447, -0.043891653, 0.036225766, + -0.047248036, 0.021479502, 0.033189066, 0.11952997, -0.020432774, + 0.64658105, -0.06650122, -0.03467612, 0.095340036, 0.23647355}); + + lstm.SetCellToOutputWeights( + {0.08286371, -0.08261836, -0.51210177, 0.002913762, 0.17764764, + -0.5495371, -0.08460716, -0.24552552, 0.030037103, 0.04123544, + -0.11940523, 0.007358328, 0.1890978, 0.4833202, -0.34441817, + 0.36312827, -0.26375428, 0.1457655, -0.19724406, 0.15548733}); + + lstm.SetProjectionWeights( + {-0.009802181, 0.09401916, 0.0717386, -0.13895074, 0.09641832, + 0.060420845, 0.08539281, 0.054285463, 0.061395317, 0.034448683, + -0.042991187, 0.019801661, -0.16840284, -0.015726732, -0.23041931, + -0.024478018, -0.10959692, -0.013875541, 0.18600968, -0.061274476, + 0.0138165, -0.08160894, -0.07661644, 0.032372914, 0.16169067, + 0.22465782, -0.03993472, -0.004017731, 0.08633481, -0.28869787, + 0.08682067, 0.17240396, 0.014975425, 0.056431185, 0.031037588, + 0.16702051, 0.0077946745, 0.15140012, 0.29405436, 0.120285, + -0.188994, -0.027265169, 0.043389652, -0.022061434, 0.014777949, + -0.20203483, 0.094781205, 0.19100232, 0.13987629, -0.036132768, + -0.06426278, -0.05108664, 0.13221376, 0.009441198, -0.16715929, + 0.15859416, -0.040437475, 0.050779544, -0.022187516, 0.012166504, + 0.027685808, -0.07675938, -0.0055694645, -0.09444123, 0.0046453946, + 0.050794356, 0.10770313, -0.20790008, -0.07149004, -0.11425117, + 0.008225835, -0.035802525, 0.14374903, 0.15262283, 0.048710253, + 0.1847461, -0.007487823, 0.11000021, -0.09542012, 0.22619456, + -0.029149994, 0.08527916, 0.009043713, 0.0042746216, 
0.016261552, + 0.022461696, 0.12689082, -0.043589946, -0.12035478, -0.08361797, + -0.050666027, -0.1248618, -0.1275799, -0.071875185, 0.07377272, + 0.09944291, -0.18897448, -0.1593054, -0.06526116, -0.040107165, + -0.004618631, -0.067624845, -0.007576253, 0.10727444, 0.041546922, + -0.20424393, 0.06907816, 0.050412357, 0.00724631, 0.039827548, + 0.12449835, 0.10747581, 0.13708383, 0.09134148, -0.12617786, + -0.06428341, 0.09956831, 0.1208086, -0.14676677, -0.0727722, + 0.1126304, 0.010139365, 0.015571211, -0.038128063, 0.022913318, + -0.042050496, 0.16842307, -0.060597885, 0.10531834, -0.06411776, + -0.07451711, -0.03410368, -0.13393489, 0.06534304, 0.003620307, + 0.04490757, 0.05970546, 0.05197996, 0.02839995, 0.10434969, + -0.013699693, -0.028353551, -0.07260381, 0.047201227, -0.024575593, + -0.036445823, 0.07155557, 0.009672501, -0.02328883, 0.009533515, + -0.03606021, -0.07421458, -0.028082801, -0.2678904, -0.13221288, + 0.18419984, -0.13012612, -0.014588381, -0.035059117, -0.04824723, + 0.07830115, -0.056184657, 0.03277091, 0.025466874, 0.14494097, + -0.12522776, -0.098633975, -0.10766018, -0.08317623, 0.08594209, + 0.07749552, 0.039474737, 0.1776665, -0.07409566, -0.0477268, + 0.29323658, 0.10801441, 0.1154011, 0.013952499, 0.10739139, + 0.10708251, -0.051456142, 0.0074137426, -0.10430189, 0.10034707, + 0.045594677, 0.0635285, -0.0715442, -0.089667566, -0.10811871, + 0.00026344223, 0.08298446, -0.009525053, 0.006585689, -0.24567553, + -0.09450807, 0.09648481, 0.026996298, -0.06419476, -0.04752702, + -0.11063944, -0.23441927, -0.17608605, -0.052156363, 0.067035615, + 0.19271925, -0.0032889997, -0.043264326, 0.09663576, -0.057112187, + -0.10100678, 0.0628376, 0.04447668, 0.017961001, -0.10094388, + -0.10190601, 0.18335468, 0.10494553, -0.052095775, -0.0026118709, + 0.10539724, -0.04383912, -0.042349473, 0.08438151, -0.1947263, + 0.02251204, 0.11216432, -0.10307853, 0.17351969, -0.039091777, + 0.08066188, -0.00561982, 0.12633002, 0.11335965, -0.0088127935, + -0.019777594, 0.06864014, -0.059751723, 0.016233567, -0.06894641, + -0.28651384, -0.004228674, 0.019708522, -0.16305895, -0.07468996, + -0.0855457, 0.099339016, -0.07580735, -0.13775392, 0.08434318, + 0.08330512, -0.12131499, 0.031935584, 0.09180414, -0.08876437, + -0.08049874, 0.008753825, 0.03498998, 0.030215185, 0.03907079, + 0.089751154, 0.029194152, -0.03337423, -0.019092513, 0.04331237, + 0.04299654, -0.036394123, -0.12915532, 0.09793732, 0.07512415, + -0.11319543, -0.032502122, 0.15661901, 0.07671967, -0.005491124, + -0.19379048, -0.218606, 0.21448623, 0.017840758, 0.1416943, + -0.07051762, 0.19488361, 0.02664691, -0.18104725, -0.09334311, + 0.15026465, -0.15493552, -0.057762887, -0.11604192, -0.262013, + -0.01391798, 0.012185008, 0.11156489, -0.07483202, 0.06693364, + -0.26151478, 0.046425626, 0.036540434, -0.16435726, 0.17338543, + -0.21401681, -0.11385144, -0.08283257, -0.069031075, 0.030635102, + 0.010969227, 0.11109743, 0.010919218, 0.027526086, 0.13519906, + 0.01891392, -0.046839405, -0.040167913, 0.017953383, -0.09700955, + 0.0061885654, -0.07000971, 0.026893595, -0.038844477, 0.14543656}); + + static float lstm_input[][20] = { + {// Batch0: 4 (input_sequence_size) * 5 (n_input) + 0.787926, 0.151646, 0.071352, 0.118426, 0.458058, 0.596268, 0.998386, + 0.568695, 0.864524, 0.571277, 0.073204, 0.296072, 0.743333, 0.069199, + 0.045348, 0.867394, 0.291279, 0.013714, 0.482521, 0.626339}, + + {// Batch1: 4 (input_sequence_size) * 5 (n_input) + 0.295743, 0.544053, 0.690064, 0.858138, 0.497181, 0.642421, 0.524260, + 0.134799, 
0.003639, 0.162482, 0.640394, 0.930399, 0.050782, 0.432485, + 0.988078, 0.082922, 0.563329, 0.865614, 0.333232, 0.259916}}; + + static float lstm_golden_output[][64] = { + {// Batch0: 4 (input_sequence_size) * 16 (n_output) + -0.00396806, 0.029352, -0.00279226, 0.0159977, -0.00835576, + -0.0211779, 0.0283512, -0.0114597, 0.00907307, -0.0244004, + -0.0152191, -0.0259063, 0.00914318, 0.00415118, 0.017147, + 0.0134203, -0.0166936, 0.0381209, 0.000889694, 0.0143363, + -0.0328911, -0.0234288, 0.0333051, -0.012229, 0.0110322, + -0.0457725, -0.000832209, -0.0202817, 0.0327257, 0.0121308, + 0.0155969, 0.0312091, -0.0213783, 0.0350169, 0.000324794, + 0.0276012, -0.0263374, -0.0371449, 0.0446149, -0.0205474, + 0.0103729, -0.0576349, -0.0150052, -0.0292043, 0.0376827, + 0.0136115, 0.0243435, 0.0354492, -0.0189322, 0.0464512, + -0.00251373, 0.0225745, -0.0308346, -0.0317124, 0.0460407, + -0.0189395, 0.0149363, -0.0530162, -0.0150767, -0.0340193, + 0.0286833, 0.00824207, 0.0264887, 0.0305169}, + {// Batch1: 4 (input_sequence_size) * 16 (n_output) + -0.013869, 0.0287268, -0.00334693, 0.00733398, -0.0287926, + -0.0186926, 0.0193662, -0.0115437, 0.00422612, -0.0345232, + 0.00223253, -0.00957321, 0.0210624, 0.013331, 0.0150954, + 0.02168, -0.0141913, 0.0322082, 0.00227024, 0.0260507, + -0.0188721, -0.0296489, 0.0399134, -0.0160509, 0.0116039, + -0.0447318, -0.0150515, -0.0277406, 0.0316596, 0.0118233, + 0.0214762, 0.0293641, -0.0204549, 0.0450315, -0.00117378, + 0.0167673, -0.0375007, -0.0238314, 0.038784, -0.0174034, + 0.0131743, -0.0506589, -0.0048447, -0.0240239, 0.0325789, + 0.00790065, 0.0220157, 0.0333314, -0.0264787, 0.0387855, + -0.000764675, 0.0217599, -0.037537, -0.0335206, 0.0431679, + -0.0211424, 0.010203, -0.062785, -0.00832363, -0.025181, + 0.0412031, 0.0118723, 0.0239643, 0.0394009}}; + + // Resetting cell_state and output_state + lstm.ResetCellState(); + lstm.ResetOutputState(); + + const int input_sequence_size = + sizeof(lstm_input[0]) / sizeof(float) / (lstm.num_inputs()); + for (int i = 0; i < input_sequence_size; i++) { + float* batch0_start = lstm_input[0] + i * lstm.num_inputs(); + float* batch0_end = batch0_start + lstm.num_inputs(); + + lstm.SetInput(0, batch0_start, batch0_end); + + float* batch1_start = lstm_input[1] + i * lstm.num_inputs(); + float* batch1_end = batch1_start + lstm.num_inputs(); + lstm.SetInput(lstm.num_inputs(), batch1_start, batch1_end); + + lstm.Invoke(); + + float* golden_start_batch0 = lstm_golden_output[0] + i * lstm.num_outputs(); + float* golden_end_batch0 = golden_start_batch0 + lstm.num_outputs(); + float* golden_start_batch1 = lstm_golden_output[1] + i * lstm.num_outputs(); + float* golden_end_batch1 = golden_start_batch1 + lstm.num_outputs(); + std::vector expected; + expected.insert(expected.end(), golden_start_batch0, golden_end_batch0); + expected.insert(expected.end(), golden_start_batch1, golden_end_batch1); + EXPECT_THAT(lstm.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); + } +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + // On Linux, add: tflite::LogToStderr(); + tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/mul.cc b/tensorflow/contrib/lite/kernels/mul.cc new file mode 100644 index 0000000000..81c73f2523 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/mul.cc @@ -0,0 +1,167 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
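The projection test above drives both batches through a single input tensor of shape {n_batch, n_input}: batch 0 is written at offset 0, batch 1 at offset num_inputs(), and the expected vector for each step is the concatenation of the two golden slices. A condensed sketch of one timestep of that loop, reusing the names defined in the test above (illustration only, not code from the patch):

// One timestep of the two-batch comparison loop above (sketch only).
for (int b = 0; b < 2; ++b) {
  float* in = lstm_input[b] + step * lstm.num_inputs();
  // Batch b starts at element offset b * num_inputs() of the input tensor.
  lstm.SetInput(b * lstm.num_inputs(), in, in + lstm.num_inputs());
}
lstm.Invoke();
std::vector<float> expected;
for (int b = 0; b < 2; ++b) {
  float* gold = lstm_golden_output[b] + step * lstm.num_outputs();
  expected.insert(expected.end(), gold, gold + lstm.num_outputs());
}
EXPECT_THAT(lstm.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));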
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace mul {
+
+// This file has three implementations of Mul.
+enum KernelType {
+  kReference,
+  kGenericOptimized,  // Neon-free
+  kNeonOptimized,
+};
+
+constexpr int kInputTensor1 = 0;
+constexpr int kInputTensor2 = 1;
+constexpr int kOutputTensor = 0;
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+  TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
+  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+
+  TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
+  TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
+  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+  TF_LITE_ENSURE_EQ(context, NumDimensions(input1), NumDimensions(input2));
+  for (int i = 0; i < NumDimensions(input1); ++i) {
+    TF_LITE_ENSURE_EQ(context, SizeOfDimension(input1, i),
+                      SizeOfDimension(input2, i));
+  }
+
+  TF_LITE_ENSURE_EQ(context, input1->type, output->type);
+  TF_LITE_ENSURE_EQ(context, input2->type, output->type);
+
+  TfLiteIntArray* output_size = TfLiteIntArrayCopy(input1->dims);
+  return context->ResizeTensor(context, output, output_size);
+}
+
+template <KernelType kernel_type>
+void EvalFloat(TfLiteContext* context, TfLiteNode* node,
+               TfLiteMulParams* params, TfLiteTensor* input1,
+               TfLiteTensor* input2, TfLiteTensor* output) {
+  float output_activation_min, output_activation_max;
+  CalculateActivationRangeFloat(params->activation, &output_activation_min,
+                                &output_activation_max);
+#define TF_LITE_MUL(type)                                         \
+  type::Mul(GetTensorData<float>(input1), GetTensorDims(input1),  \
+            GetTensorData<float>(input2), GetTensorDims(input2),  \
+            output_activation_min, output_activation_max,         \
+            GetTensorData<float>(output), GetTensorDims(output))
+  if (kernel_type == kReference) {
+    TF_LITE_MUL(reference_ops);
+  } else {
+    TF_LITE_MUL(optimized_ops);
+  }
+#undef TF_LITE_MUL
+}
+
+template <KernelType kernel_type>
+void EvalQuantized(TfLiteContext* context, TfLiteNode* node,
+                   TfLiteMulParams* params, TfLiteTensor* input1,
+                   TfLiteTensor* input2, TfLiteTensor* output) {
+  auto input1_offset = -input1->params.zero_point;
+  auto input2_offset = -input2->params.zero_point;
+  auto output_offset = output->params.zero_point;
+
+  int32_t output_multiplier;
+  int output_shift;
+
+  // Fold the two input scales and the output scale into a single real
+  // multiplier, then express it as a fixed-point multiplier plus right shift.
+  double real_multiplier =
+      input1->params.scale * input2->params.scale / output->params.scale;
+  QuantizeMultiplierSmallerThanOne(real_multiplier, &output_multiplier,
+                                   &output_shift);
+
+  int32 output_activation_min, output_activation_max;
+  CalculateActivationRangeUint8(params->activation, output,
+                                &output_activation_min, &output_activation_max);
+
+#define TF_LITE_MUL(type)                                                    \
+  type::BroadcastMul(GetTensorData<uint8_t>(input1), GetTensorDims(input1), \
+                     input1_offset, GetTensorData<uint8_t>(input2),         \
+                     GetTensorDims(input2), input2_offset, output_offset,   \
+                     output_multiplier, output_shift, output_activation_min, \
+                     output_activation_max, GetTensorData<uint8_t>(output), \
+                     GetTensorDims(output));
+  if (kernel_type == kReference) {
+    TF_LITE_MUL(reference_ops);
+  } else {
+    TF_LITE_MUL(optimized_ops);
+  }
+#undef TF_LITE_MUL
+}
+
+template <KernelType kernel_type>
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+  auto* params = reinterpret_cast<TfLiteMulParams*>(node->builtin_data);
+
+  TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
+  TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
+  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+  if (output->type == kTfLiteFloat32) {
+    EvalFloat<kernel_type>(context, node, params, input1, input2, output);
+  } else if (output->type == kTfLiteUInt8) {
+    EvalQuantized<kernel_type>(context, node, params, input1, input2, output);
+  } else {
+    context->ReportError(context,
+                         "Mul only supports FLOAT32 and quantized UINT8 now.");
+    return kTfLiteError;
+  }
+
+  return kTfLiteOk;
+}
+
+}  // namespace mul
+
+TfLiteRegistration* Register_MUL_REF() {
+  static TfLiteRegistration r = {nullptr, nullptr, mul::Prepare,
+                                 mul::Eval<mul::kReference>};
+  return &r;
+}
+
+TfLiteRegistration* Register_MUL_GENERIC_OPT() {
+  static TfLiteRegistration r = {nullptr, nullptr, mul::Prepare,
+                                 mul::Eval<mul::kGenericOptimized>};
+  return &r;
+}
+
+TfLiteRegistration* Register_MUL_NEON_OPT() {
+  static TfLiteRegistration r = {nullptr, nullptr, mul::Prepare,
+                                 mul::Eval<mul::kNeonOptimized>};
+  return &r;
+}
+
+TfLiteRegistration* Register_MUL() {
+#ifdef USE_NEON
+  return Register_MUL_NEON_OPT();
+#else
+  return Register_MUL_GENERIC_OPT();
+#endif
+}
+
+}  // namespace builtin
+}  // namespace ops
+}  // namespace tflite
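EvalQuantized() above folds the input and output scales into real_multiplier = scale1 * scale2 / scale_out and hands it to QuantizeMultiplierSmallerThanOne(). A rough sketch of what such a decomposition looks like, assuming the helper behaves like a frexp-based split into a Q31 mantissa plus a right shift (the function below is illustrative only, not the library's implementation):

// Illustrative sketch: decompose a real multiplier in (0, 1) so that
//   real_multiplier ~= quantized_multiplier * 2^-31 * 2^-right_shift.
#include <cmath>
#include <cstdint>

void DecomposeMultiplier(double real_multiplier, int32_t* quantized_multiplier,
                         int* right_shift) {
  int exponent;
  // frexp() returns a mantissa in [0.5, 1) with
  // real_multiplier == mantissa * 2^exponent; below 1.0 the exponent is negative.
  const double mantissa = std::frexp(real_multiplier, &exponent);
  *right_shift = -exponent;
  *quantized_multiplier =
      static_cast<int32_t>(std::lround(mantissa * (1LL << 31)));
  // The real helper presumably also guards the case where the mantissa
  // rounds up to exactly 1.0.
}

// Example: with min=-1, max=1 tensors (scale ~= 2/255 for both inputs and the
// output), real_multiplier = (2/255 * 2/255) / (2/255) ~= 0.00784, which
// splits into right_shift = 6 and quantized_multiplier ~= 0.502 * 2^31.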
diff --git a/tensorflow/contrib/lite/kernels/mul_test.cc b/tensorflow/contrib/lite/kernels/mul_test.cc
new file mode 100644
index 0000000000..4b858e1f39
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/mul_test.cc
@@ -0,0 +1,127 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAreArray;
+
+class BaseMulOpModel : public SingleOpModel {
+ public:
+  BaseMulOpModel(TensorData input, TensorData output,
+                 ActivationFunctionType activation_type) {
+    input1_ = AddInput(input);
+    input2_ = AddInput(input);
+    output_ = AddOutput(output);
+    SetBuiltinOp(BuiltinOperator_MUL, BuiltinOptions_MulOptions,
+                 CreateMulOptions(builder_, activation_type).Union());
+    BuildInterpreter({GetShape(input1_), GetShape(input2_)});
+  }
+
+  int input1() { return input1_; }
+  int input2() { return input2_; }
+
+ protected:
+  int input1_;
+  int input2_;
+  int output_;
+};
+
+class FloatMulOpModel : public BaseMulOpModel {
+ public:
+  using BaseMulOpModel::BaseMulOpModel;
+
+  std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+};
+
+// For quantized Mul, the error shouldn't exceed (2*step + step^2).
+// The param min=-1.0 & max=1.0 is used in the following tests.
+// The tolerance value is ~0.0157.
+const float kQuantizedStep = 2.0 / 255.0;
+const float kQuantizedTolerance =
+    2.0 * kQuantizedStep + kQuantizedStep * kQuantizedStep;
+
+class QuantizedMulOpModel : public BaseMulOpModel {
+ public:
+  using BaseMulOpModel::BaseMulOpModel;
+
+  std::vector<float> GetDequantizedOutput() {
+    return Dequantize<uint8_t>(ExtractVector<uint8_t>(output_),
+                               GetScale(output_), GetZeroPoint(output_));
+  }
+};
+
+TEST(FloatMulOpTest, NoActivation) {
+  FloatMulOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}},
+                    {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE);
+  m.PopulateTensor<float>(m.input1(), {-2.0, 0.2, 0.7, 0.8});
+  m.PopulateTensor<float>(m.input2(), {0.1, 0.2, 0.3, 0.5});
+  m.Invoke();
+  EXPECT_THAT(m.GetOutput(),
+              ElementsAreArray(ArrayFloatNear({-0.2, 0.04, 0.21, 0.4})));
+}
+
+TEST(FloatMulOpTest, ActivationRELU1) {
+  FloatMulOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}},
+                    {TensorType_FLOAT32, {}}, ActivationFunctionType_RELU1);
+  m.PopulateTensor<float>(m.input1(), {-2.0, 0.2, 0.7, 0.8});
+  m.PopulateTensor<float>(m.input2(), {0.1, 0.2, 0.3, 5});
+  m.Invoke();
+  EXPECT_THAT(m.GetOutput(),
+              ElementsAreArray(ArrayFloatNear({-0.2, 0.04, 0.21, 1.0})));
+}
+
+TEST(FloatMulOpTest, VariousInputShapes) {
+  std::vector<std::vector<int>> test_shapes = {
+      {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}};
+  for (int i = 0; i < test_shapes.size(); ++i) {
+    FloatMulOpModel m({TensorType_FLOAT32, test_shapes[i]},
+                      {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE);
+    m.PopulateTensor<float>(m.input1(), {-2.0, 0.2, 0.7, 0.8, 1.1, 2.0});
+    m.PopulateTensor<float>(m.input2(), {0.1, 0.2, 0.3, 0.5, 1.1, 0.1});
+    m.Invoke();
+    EXPECT_THAT(
+        m.GetOutput(),
+        ElementsAreArray(ArrayFloatNear({-0.2, 0.04, 0.21, 0.4, 1.21, 0.2})))
+        << "With shape number " << i;
+  }
+}
+
+TEST(QuantizedMulOpTest, NoActivation) {
+  QuantizedMulOpModel m({TensorType_UINT8, {1, 2, 2, 1}, -1.0, 1.0},
+                        {TensorType_UINT8, {}, -1.0, 1.0},
+                        ActivationFunctionType_NONE);
+  m.QuantizeAndPopulate<uint8_t>(m.input1(), {-0.8, 0.2, 0.9, 0.7});
+  m.QuantizeAndPopulate<uint8_t>(m.input2(), {0.6, 0.4, 0.9, 0.8});
+  m.Invoke();
+  EXPECT_THAT(m.GetDequantizedOutput(),
+              ElementsAreArray(ArrayFloatNear({-0.48, 0.08, 0.81, 0.56},
+                                              kQuantizedTolerance)));
+}
+
+}  // namespace
+}  // namespace tflite
+
+int main(int argc, char** argv) {
+  // On Linux, add: tflite::LogToStderr();
tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/op_macros.h b/tensorflow/contrib/lite/kernels/op_macros.h new file mode 100644 index 0000000000..7535afaf8e --- /dev/null +++ b/tensorflow/contrib/lite/kernels/op_macros.h @@ -0,0 +1,32 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_OP_UTIL_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_OP_UTIL_H_ + +#define TF_LITE_FATAL(msg) \ + do { \ + fprintf(stderr, "%s\n", (msg)); \ + exit(1); \ + } while (0) +#define TF_LITE_ASSERT(x) \ + do { \ + if (!(x)) TF_LITE_FATAL(#x); \ + } while (0) +#define TF_LITE_ASSERT_EQ(x, y) \ + do { \ + if ((x) != (y)) TF_LITE_FATAL(#x " didn't equal " #y); \ + } while (0) + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_OP_UTIL_H_ diff --git a/tensorflow/contrib/lite/kernels/optional_tensor_test.cc b/tensorflow/contrib/lite/kernels/optional_tensor_test.cc new file mode 100644 index 0000000000..8977d27f73 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/optional_tensor_test.cc @@ -0,0 +1,343 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Unit test for TFLite LSTM op. 
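The test below builds LSTMs in which several weight tensors are deliberately absent; each omitted tensor still occupies an input slot, it is just registered as a null input. A condensed sketch of the pattern used by the constructor that follows, written as a hypothetical helper on top of the SingleOpModel methods AddInput() and AddNullInput() shown below (not code from the patch):

// Hypothetical helper (illustration only): give a weight tensor a slot
// whether or not the current LSTM configuration actually uses it.
int AddOptionalWeight(bool present) {
  return present ? AddInput(TensorType_FLOAT32) : AddNullInput();
}
// e.g. input_to_input_weights_ = AddOptionalWeight(!use_cifg);
//      cell_to_forget_weights_ = AddOptionalWeight(use_peephole);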
+ +#include +#include +#include + +#include +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class LSTMOpModel : public SingleOpModel { + public: + LSTMOpModel(int n_batch, int n_input, int n_cell, int n_output, bool use_cifg, + bool use_peephole, bool use_projection_weights, + bool use_projection_bias, float cell_clip, float proj_clip, + const std::vector>& input_shapes) + : n_batch_(n_batch), + n_input_(n_input), + n_cell_(n_cell), + n_output_(n_output) { + input_ = AddInput(TensorType_FLOAT32); + + if (use_cifg) { + input_to_input_weights_ = AddNullInput(); + } else { + input_to_input_weights_ = AddInput(TensorType_FLOAT32); + } + + input_to_forget_weights_ = AddInput(TensorType_FLOAT32); + input_to_cell_weights_ = AddInput(TensorType_FLOAT32); + input_to_output_weights_ = AddInput(TensorType_FLOAT32); + + if (use_cifg) { + recurrent_to_input_weights_ = AddNullInput(); + } else { + recurrent_to_input_weights_ = AddInput(TensorType_FLOAT32); + } + + recurrent_to_forget_weights_ = AddInput(TensorType_FLOAT32); + recurrent_to_cell_weights_ = AddInput(TensorType_FLOAT32); + recurrent_to_output_weights_ = AddInput(TensorType_FLOAT32); + + if (use_peephole) { + if (use_cifg) { + cell_to_input_weights_ = AddNullInput(); + } else { + cell_to_input_weights_ = AddInput(TensorType_FLOAT32); + } + cell_to_forget_weights_ = AddInput(TensorType_FLOAT32); + cell_to_output_weights_ = AddInput(TensorType_FLOAT32); + } else { + cell_to_input_weights_ = AddNullInput(); + cell_to_forget_weights_ = AddNullInput(); + cell_to_output_weights_ = AddNullInput(); + } + + if (use_cifg) { + input_gate_bias_ = AddNullInput(); + } else { + input_gate_bias_ = AddInput(TensorType_FLOAT32); + } + forget_gate_bias_ = AddInput(TensorType_FLOAT32); + cell_bias_ = AddInput(TensorType_FLOAT32); + output_gate_bias_ = AddInput(TensorType_FLOAT32); + + if (use_projection_weights) { + projection_weights_ = AddInput(TensorType_FLOAT32); + if (use_projection_bias) { + projection_bias_ = AddInput(TensorType_FLOAT32); + } else { + projection_bias_ = AddNullInput(); + } + } else { + projection_weights_ = AddNullInput(); + projection_bias_ = AddNullInput(); + } + + scratch_buffer_ = AddOutput(TensorType_FLOAT32); + // TODO(ghodrat): Modify these states when we have a permanent solution for + // persistent buffer. 
+ output_state_ = AddOutput(TensorType_FLOAT32); + cell_state_ = AddOutput(TensorType_FLOAT32); + output_ = AddOutput(TensorType_FLOAT32); + + SetBuiltinOp(BuiltinOperator_LSTM, BuiltinOptions_LSTMOptions, + CreateLSTMOptions(builder_, ActivationFunctionType_TANH, + cell_clip, proj_clip) + .Union()); + BuildInterpreter(input_shapes); + } + + void SetInputToInputWeights(std::initializer_list f) { + PopulateTensor(input_to_input_weights_, f); + } + + void SetInputToForgetWeights(std::initializer_list f) { + PopulateTensor(input_to_forget_weights_, f); + } + + void SetInputToCellWeights(std::initializer_list f) { + PopulateTensor(input_to_cell_weights_, f); + } + + void SetInputToOutputWeights(std::initializer_list f) { + PopulateTensor(input_to_output_weights_, f); + } + + void SetRecurrentToInputWeights(std::initializer_list f) { + PopulateTensor(recurrent_to_input_weights_, f); + } + + void SetRecurrentToForgetWeights(std::initializer_list f) { + PopulateTensor(recurrent_to_forget_weights_, f); + } + + void SetRecurrentToCellWeights(std::initializer_list f) { + PopulateTensor(recurrent_to_cell_weights_, f); + } + + void SetRecurrentToOutputWeights(std::initializer_list f) { + PopulateTensor(recurrent_to_output_weights_, f); + } + + void SetCellToInputWeights(std::initializer_list f) { + PopulateTensor(cell_to_input_weights_, f); + } + + void SetCellToForgetWeights(std::initializer_list f) { + PopulateTensor(cell_to_forget_weights_, f); + } + + void SetCellToOutputWeights(std::initializer_list f) { + PopulateTensor(cell_to_output_weights_, f); + } + + void SetInputGateBias(std::initializer_list f) { + PopulateTensor(input_gate_bias_, f); + } + + void SetForgetGateBias(std::initializer_list f) { + PopulateTensor(forget_gate_bias_, f); + } + + void SetCellBias(std::initializer_list f) { + PopulateTensor(cell_bias_, f); + } + + void SetOutputGateBias(std::initializer_list f) { + PopulateTensor(output_gate_bias_, f); + } + + void SetProjectionWeights(std::initializer_list f) { + PopulateTensor(projection_weights_, f); + } + + void SetProjectionBias(std::initializer_list f) { + PopulateTensor(projection_bias_, f); + } + + void ResetOutputState() { + const int zero_buffer_size = n_cell_ * n_batch_; + std::unique_ptr zero_buffer(new float[zero_buffer_size]); + memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float)); + PopulateTensor(output_state_, 0, zero_buffer.get(), + zero_buffer.get() + zero_buffer_size); + } + + void ResetCellState() { + const int zero_buffer_size = n_cell_ * n_batch_; + std::unique_ptr zero_buffer(new float[zero_buffer_size]); + memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float)); + PopulateTensor(cell_state_, 0, zero_buffer.get(), + zero_buffer.get() + zero_buffer_size); + } + + void SetInput(int offset, float* begin, float* end) { + PopulateTensor(input_, offset, begin, end); + } + + std::vector GetOutput() { return ExtractVector(output_); } + void Verify() { + auto model = tflite::UnPackModel(builder_.GetBufferPointer()); + EXPECT_NE(model, nullptr); + } + + int num_inputs() { return n_input_; } + int num_outputs() { return n_output_; } + int num_cells() { return n_cell_; } + int num_batches() { return n_batch_; } + + private: + int input_; + int input_to_input_weights_; + int input_to_forget_weights_; + int input_to_cell_weights_; + int input_to_output_weights_; + + int recurrent_to_input_weights_; + int recurrent_to_forget_weights_; + int recurrent_to_cell_weights_; + int recurrent_to_output_weights_; + + int cell_to_input_weights_; + int 
cell_to_forget_weights_; + int cell_to_output_weights_; + + int input_gate_bias_; + int forget_gate_bias_; + int cell_bias_; + int output_gate_bias_; + + int projection_weights_; + int projection_bias_; + + int output_; + int output_state_; + int cell_state_; + int scratch_buffer_; + + int n_batch_; + int n_input_; + int n_cell_; + int n_output_; +}; + + +TEST(LSTMOpTest, BlackBoxTestWithCifgWithPeepholeNoProjectionNoClipping) { + const int n_batch = 1; + const int n_input = 2; + // n_cell and n_output have the same size when there is no projection. + const int n_cell = 4; + const int n_output = 4; + + LSTMOpModel lstm(n_batch, n_input, n_cell, n_output, + /*use_cifg=*/true, /*use_peephole=*/true, + /*use_projection_weights=*/false, + /*use_projection_bias=*/false, + /*cell_clip=*/0.0, /*proj_clip=*/0.0, + { + {n_batch, n_input}, // input tensor + + {0, 0}, // input_to_input_weight tensor + {n_cell, n_input}, // input_to_forget_weight tensor + {n_cell, n_input}, // input_to_cell_weight tensor + {n_cell, n_input}, // input_to_output_weight tensor + + {0, 0}, // recurrent_to_input_weight tensor + {n_cell, n_output}, // recurrent_to_forget_weight tensor + {n_cell, n_output}, // recurrent_to_cell_weight tensor + {n_cell, n_output}, // recurrent_to_output_weight tensor + + {0}, // cell_to_input_weight tensor + {n_cell}, // cell_to_forget_weight tensor + {n_cell}, // cell_to_output_weight tensor + + {0}, // input_gate_bias tensor + {n_cell}, // forget_gate_bias tensor + {n_cell}, // cell_bias tensor + {n_cell}, // output_gate_bias tensor + + {0, 0}, // projection_weight tensor + {0}, // projection_bias tensor + }); + + + lstm.SetInputToCellWeights({-0.49770179, -0.27711356, -0.09624726, 0.05100781, + 0.04717243, 0.48944736, -0.38535351, + -0.17212132}); + + lstm.SetInputToForgetWeights({-0.55291498, -0.42866567, 0.13056988, + -0.3633365, -0.22755712, 0.28253698, 0.24407166, + 0.33826375}); + + lstm.SetInputToOutputWeights({0.10725588, -0.02335852, -0.55932593, + -0.09426838, -0.44257352, 0.54939759, + 0.01533556, 0.42751634}); + + lstm.SetCellBias({0., 0., 0., 0.}); + + lstm.SetForgetGateBias({1., 1., 1., 1.}); + + lstm.SetOutputGateBias({0., 0., 0., 0.}); + + lstm.SetRecurrentToCellWeights( + {0.54066205, -0.32668582, -0.43562764, -0.56094903, 0.42957711, + 0.01841056, -0.32764608, -0.33027974, -0.10826075, 0.20675004, + 0.19069612, -0.03026325, -0.54532051, 0.33003211, 0.44901288, + 0.21193194}); + + lstm.SetRecurrentToForgetWeights( + {-0.13832897, -0.0515101, -0.2359007, -0.16661474, -0.14340827, + 0.36986142, 0.23414481, 0.55899, 0.10798943, -0.41174671, 0.17751795, + -0.34484994, -0.35874045, -0.11352962, 0.27268326, 0.54058349}); + + lstm.SetRecurrentToOutputWeights( + {0.41613156, 0.42610586, -0.16495961, -0.5663873, 0.30579174, -0.05115908, + -0.33941799, 0.23364776, 0.11178309, 0.09481031, -0.26424935, 0.46261835, + 0.50248802, 0.26114327, -0.43736315, 0.33149987}); + + lstm.SetCellToForgetWeights( + {0.47485286, -0.51955009, -0.24458408, 0.31544167}); + lstm.SetCellToOutputWeights( + {-0.17135078, 0.82760304, 0.85573703, -0.77109635}); + + // Resetting cell_state and output_state + lstm.ResetCellState(); + lstm.ResetOutputState(); + + // Verify the model by unpacking it. 
+ lstm.Verify(); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + // On Linux, add: tflite::LogToStderr(); + tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/padding.h b/tensorflow/contrib/lite/kernels/padding.h new file mode 100644 index 0000000000..3a60274524 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/padding.h @@ -0,0 +1,28 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_PADDING_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_PADDING_H_ + +namespace tflite { + +inline int ComputePadding(int stride, int in_size, int filter_size, + int out_size) { + int padding = ((out_size - 1) * stride + filter_size - in_size) / 2; + return padding > 0 ? padding : 0; +} + +} // namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_PADDING_H_ diff --git a/tensorflow/contrib/lite/kernels/pooling.cc b/tensorflow/contrib/lite/kernels/pooling.cc new file mode 100644 index 0000000000..b798801108 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/pooling.cc @@ -0,0 +1,355 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" +#include "tensorflow/contrib/lite/kernels/padding.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace pooling { + +// This file has two implementation of each pooling op. +enum KernelType { + kReference, + kGenericOptimized, +}; + +enum PoolType { + kAverage, + kMax, + kL2, +}; + +struct OpData { + TfLitePaddingValues padding; +}; + +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + // This is a builtin op, so we don't use the contents in 'buffer', if any. 
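+  // (For a custom op, 'buffer' and 'length' would typically carry the op's
+  // serialized custom options; builtin ops get their parameters through
+  // node->builtin_data, as read in GenericPrepare() below.)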
+ // Instead, we allocate a new object to carry information from Prepare() to + // Eval(). + return new OpData; +} + +void Free(TfLiteContext* context, void* buffer) { + delete reinterpret_cast(buffer); +} + +template +TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast(node->builtin_data); + OpData* data = reinterpret_cast(node->user_data); + + TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + TfLiteTensor* output = GetOutput(context, node, 0); + TfLiteTensor* input = GetInput(context, node, 0); + TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4); + TF_LITE_ENSURE_EQ(context, input->type, output->type); + + int batches = input->dims->data[0]; + int height = input->dims->data[1]; + int width = input->dims->data[2]; + int channels_out = input->dims->data[3]; + + // Matching GetWindowedOutputSize in TensorFlow. + auto padding = params->padding; + auto computeOutSize = [padding](int imageSize, int filterSize, + int stride) -> int { + return padding == kTfLitePaddingSame + ? (imageSize + stride - 1) / stride + : padding == kTfLitePaddingValid + ? (imageSize - filterSize + stride) / stride + : 0; + }; + + int outWidth = + computeOutSize(width, params->filter_width, params->stride_width); + int outHeight = + computeOutSize(height, params->filter_height, params->stride_height); + + data->padding.height = ComputePadding(params->stride_height, height, + params->filter_height, outHeight); + data->padding.width = ComputePadding(params->stride_width, width, + params->filter_width, outWidth); + + if (input->type == kTfLiteUInt8) { + if (pool_type == kAverage || pool_type == kMax) { + TF_LITE_ENSURE_EQ(context, input->params.scale, output->params.scale); + TF_LITE_ENSURE_EQ(context, input->params.zero_point, + output->params.zero_point); + } + if (pool_type == kL2) { + // We currently don't have a quantized implementation of L2Pool + TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32); + } + } + + TfLiteIntArray* outputSize = TfLiteIntArrayCreate(4); + outputSize->data[0] = batches; + outputSize->data[1] = outHeight; + outputSize->data[2] = outWidth; + outputSize->data[3] = channels_out; + return context->ResizeTensor(context, output, outputSize); +} + +template +void AverageEvalFloat(TfLiteContext* context, TfLiteNode* node, + TfLitePoolParams* params, OpData* data, + TfLiteTensor* input, TfLiteTensor* output) { + float activation_min, activation_max; + CalculateActivationRangeFloat(params->activation, &activation_min, + &activation_max); +#define TF_LITE_AVERAGE_POOL(type) \ + type::AveragePool( \ + GetTensorData(input), GetTensorDims(input), params->stride_width, \ + params->stride_height, data->padding.width, data->padding.height, \ + params->filter_width, params->filter_height, activation_min, \ + activation_max, GetTensorData(output), GetTensorDims(output)) + if (kernel_type == kReference) { + TF_LITE_AVERAGE_POOL(reference_ops); + } else { + TF_LITE_AVERAGE_POOL(optimized_ops); + } +#undef TF_LITE_AVERAGE_POOL +} + +template +void AverageEvalQuantized(TfLiteContext* context, TfLiteNode* node, + TfLitePoolParams* params, OpData* data, + TfLiteTensor* input, TfLiteTensor* output) { + int32_t activation_min; + int32_t activation_max; + CalculateActivationRangeUint8(params->activation, output, &activation_min, + &activation_max); +#define TF_LITE_AVERAGE_POOL(type) \ + type::AveragePool(GetTensorData(input), GetTensorDims(input), \ + params->stride_width, params->stride_height, \ + 
data->padding.width, data->padding.height, \ + params->filter_width, params->filter_height, \ + activation_min, activation_max, \ + GetTensorData(output), GetTensorDims(output)) + if (kernel_type == kReference) { + TF_LITE_AVERAGE_POOL(reference_ops); + } else { + TF_LITE_AVERAGE_POOL(optimized_ops); + } +#undef TF_LITE_AVERAGE_POOL +} + +template +void MaxEvalFloat(TfLiteContext* context, TfLiteNode* node, + TfLitePoolParams* params, OpData* data, TfLiteTensor* input, + TfLiteTensor* output) { + float activation_min, activation_max; + CalculateActivationRangeFloat(params->activation, &activation_min, + &activation_max); +#define TF_LITE_MAX_POOL(type) \ + type::MaxPool( \ + GetTensorData(input), GetTensorDims(input), params->stride_width, \ + params->stride_height, data->padding.width, data->padding.height, \ + params->filter_width, params->filter_height, activation_min, \ + activation_max, GetTensorData(output), GetTensorDims(output)) + if (kernel_type == kReference) { + TF_LITE_MAX_POOL(reference_ops); + } else { + TF_LITE_MAX_POOL(optimized_ops); + } +#undef TF_LITE_MAX_POOL +} + +template +void MaxEvalQuantized(TfLiteContext* context, TfLiteNode* node, + TfLitePoolParams* params, OpData* data, + TfLiteTensor* input, TfLiteTensor* output) { + int32_t activation_min; + int32_t activation_max; + CalculateActivationRangeUint8(params->activation, output, &activation_min, + &activation_max); +#define TF_LITE_MAX_POOL(type) \ + type::MaxPool(GetTensorData(input), GetTensorDims(input), \ + params->stride_width, params->stride_height, \ + data->padding.width, data->padding.height, \ + params->filter_width, params->filter_height, activation_min, \ + activation_max, GetTensorData(output), \ + GetTensorDims(output)) + if (kernel_type == kReference) { + TF_LITE_MAX_POOL(reference_ops); + } else { + TF_LITE_MAX_POOL(optimized_ops); + } +#undef TF_LITE_MAX_POOL +} + +template +void L2EvalFloat(TfLiteContext* context, TfLiteNode* node, + TfLitePoolParams* params, OpData* data, TfLiteTensor* input, + TfLiteTensor* output) { + float activation_min, activation_max; + CalculateActivationRangeFloat(params->activation, &activation_min, + &activation_max); +#define TF_LITE_L2_POOL(type) \ + type::L2Pool( \ + GetTensorData(input), GetTensorDims(input), params->stride_width, \ + params->stride_height, data->padding.width, data->padding.height, \ + params->filter_width, params->filter_height, activation_min, \ + activation_max, GetTensorData(output), GetTensorDims(output)) + if (kernel_type == kReference) { + TF_LITE_L2_POOL(reference_ops); + } else { + TF_LITE_L2_POOL(optimized_ops); + } +#undef TF_LITE_L2_POOL +} + +#undef TF_LITE_KERNEL_TYPE_DISPATCH + +template +TfLiteStatus AverageEval(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast(node->builtin_data); + OpData* data = reinterpret_cast(node->user_data); + + TfLiteTensor* output = GetOutput(context, node, 0); + TfLiteTensor* input = GetInput(context, node, 0); + switch (input->type) { // Already know in/out types are same. 
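+    // Two levels of dispatch: the kernel_type template argument (fixed at
+    // registration time by the Register_*_REF / Register_*_GENERIC_OPT
+    // factories below) selects reference_ops vs optimized_ops, while this
+    // switch picks the float or quantized path from the input tensor type.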
+ case kTfLiteFloat32: + AverageEvalFloat(context, node, params, data, input, output); + break; + case kTfLiteUInt8: + AverageEvalQuantized(context, node, params, data, input, + output); + break; + default: + context->ReportError(context, "Type not currently supported."); + return kTfLiteError; + } + return kTfLiteOk; +} + +template +TfLiteStatus MaxEval(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast(node->builtin_data); + OpData* data = reinterpret_cast(node->user_data); + + TfLiteTensor* output = GetOutput(context, node, 0); + TfLiteTensor* input = GetInput(context, node, 0); + switch (input->type) { // Already know in/out types are same. + case kTfLiteFloat32: + MaxEvalFloat(context, node, params, data, input, output); + break; + case kTfLiteUInt8: + MaxEvalQuantized(context, node, params, data, input, output); + break; + default: + context->ReportError(context, "Type not currently supported."); + return kTfLiteError; + } + return kTfLiteOk; +} + +template +TfLiteStatus L2Eval(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast(node->builtin_data); + OpData* data = reinterpret_cast(node->user_data); + + TfLiteTensor* output = GetOutput(context, node, 0); + TfLiteTensor* input = GetInput(context, node, 0); + switch (input->type) { // Already know in/out types are same. + case kTfLiteFloat32: + L2EvalFloat(context, node, params, data, input, output); + break; + case kTfLiteUInt8: + // We don't have a quantized implementation, so just fall through to the + // 'default' case. + default: + context->ReportError(context, "Type not currently supported."); + return kTfLiteError; + } + return kTfLiteOk; +} + +} // namespace pooling + +TfLiteRegistration* Register_AVERAGE_POOL_REF() { + static TfLiteRegistration r = {pooling::Init, pooling::Free, + pooling::GenericPrepare, + pooling::AverageEval}; + return &r; +} + +TfLiteRegistration* Register_MAX_POOL_REF() { + static TfLiteRegistration r = {pooling::Init, pooling::Free, + pooling::GenericPrepare, + pooling::MaxEval}; + return &r; +} + +TfLiteRegistration* Register_L2_POOL_REF() { + static TfLiteRegistration r = {pooling::Init, pooling::Free, + pooling::GenericPrepare, + pooling::L2Eval}; + return &r; +} + +TfLiteRegistration* Register_AVERAGE_POOL_GENERIC_OPT() { + static TfLiteRegistration r = { + pooling::Init, pooling::Free, pooling::GenericPrepare, + pooling::AverageEval}; + return &r; +} + +TfLiteRegistration* Register_MAX_POOL_GENERIC_OPT() { + static TfLiteRegistration r = {pooling::Init, pooling::Free, + pooling::GenericPrepare, + pooling::MaxEval}; + return &r; +} + +TfLiteRegistration* Register_L2_POOL_GENERIC_OPT() { + static TfLiteRegistration r = {pooling::Init, pooling::Free, + pooling::GenericPrepare, + pooling::L2Eval}; + return &r; +} + +TfLiteRegistration* Register_AVERAGE_POOL_2D() { + return Register_AVERAGE_POOL_GENERIC_OPT(); +} + +TfLiteRegistration* Register_MAX_POOL_2D() { + return Register_MAX_POOL_GENERIC_OPT(); +} + +TfLiteRegistration* Register_L2_POOL_2D() { + return Register_L2_POOL_GENERIC_OPT(); +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/pooling_test.cc b/tensorflow/contrib/lite/kernels/pooling_test.cc new file mode 100644 index 0000000000..e1b51ec7d5 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/pooling_test.cc @@ -0,0 +1,161 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class BasePoolingOpModel : public SingleOpModel { + public: + // TODO(ahentz): Also test different activation types, bias, padding types, + // stride values. + BasePoolingOpModel(BuiltinOperator type, const TensorData& input, + int filter_width, int filter_height, + const TensorData& output) { + input_ = AddInput(input); + output_ = AddOutput(output); + + SetBuiltinOp( + type, BuiltinOptions_Pool2DOptions, + CreatePool2DOptions(builder_, Padding_VALID, 2, 2, filter_width, + filter_height, ActivationFunctionType_NONE) + .Union()); + + BuildInterpreter({GetShape(input_)}); + } + + protected: + int input_; + int output_; +}; + +class FloatPoolingOpModel : public BasePoolingOpModel { + public: + using BasePoolingOpModel::BasePoolingOpModel; + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + std::vector GetOutput() { return ExtractVector(output_); } +}; + +class QuantizedPoolingOpModel : public BasePoolingOpModel { + public: + using BasePoolingOpModel::BasePoolingOpModel; + + void SetInput(std::initializer_list data) { + QuantizeAndPopulate(input_, data); + } + + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetDequantizedOutput() { + return Dequantize(ExtractVector(output_), + GetScale(output_), GetZeroPoint(output_)); + } +}; + +TEST(FloatPoolingOpTest, AveragePool) { + FloatPoolingOpModel m(BuiltinOperator_AVERAGE_POOL_2D, + /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}}, + /*filter_width=*/2, /*filter_height=*/2, + /*output=*/{TensorType_FLOAT32, {}}); + m.SetInput({ + 0, 6, 2, 4, // + 3, 2, 10, 7, // + }); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({2.75, 5.75})); +} + +TEST(QuantizedPoolingOpTest, AveragePool) { + // Choose the input ranges carefully so that the dequantized output matches + // the results of the float model above. 
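+  // Worked numbers for the range chosen below (assuming the usual asymmetric
+  // uint8 quantization in these tests): scale = (15.9375 - 0) / 255 = 0.0625
+  // and zero_point = 0, so the float averages {2.75, 5.75} map to
+  // {2.75 / 0.0625, 5.75 / 0.0625} = {44, 92}, the raw values expected below.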
+ QuantizedPoolingOpModel m( + BuiltinOperator_AVERAGE_POOL_2D, + /*input=*/{TensorType_UINT8, {1, 2, 4, 1}, 0, 15.9375}, + /*filter_width=*/2, /*filter_height=*/2, + /*output=*/{TensorType_UINT8, {}, 0, 15.9375}); + m.SetInput({ + 0, 6, 2, 4, // + 3, 2, 10, 7, // + }); + m.Invoke(); + + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear({2.75, 5.75}))); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({44, 92})); +} + +TEST(FloatPoolingOpTest, MaxPool) { + FloatPoolingOpModel m(BuiltinOperator_MAX_POOL_2D, + /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}}, + /*filter_width=*/2, /*filter_height=*/2, + /*output=*/{TensorType_FLOAT32, {}}); + m.SetInput({ + 0, 6, 2, 4, // + 3, 2, 10, 7, // + }); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({6, 10})); +} + +TEST(QuantizedPoolingOpTest, MaxPool) { + // Choose the input ranges carefully so that the dequantized output matches + // the results of the float model above. + QuantizedPoolingOpModel m( + BuiltinOperator_MAX_POOL_2D, + /*input=*/{TensorType_UINT8, {1, 2, 4, 1}, 0, 15.9375}, + /*filter_width=*/2, /*filter_height=*/2, + /*output=*/{TensorType_UINT8, {}, 0, 15.9375}); + m.SetInput({ + 0, 6, 2, 4, // + 3, 2, 10, 7, // + }); + m.Invoke(); + + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear({6, 10}))); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({96, 160})); +} + +TEST(FloatPoolingOpTest, L2Pool) { + FloatPoolingOpModel m(BuiltinOperator_L2_POOL_2D, + /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}}, + /*filter_width=*/2, /*filter_height=*/2, + /*output=*/{TensorType_FLOAT32, {}}); + m.SetInput({ + 0, 6, 2, 4, // + 3, 2, 10, 7, // + }); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({3.5, 6.5})); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + // On Linux, add: tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/register.cc b/tensorflow/contrib/lite/kernels/register.cc new file mode 100644 index 0000000000..ca7a0dd194 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/register.cc @@ -0,0 +1,109 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/contrib/lite/kernels/register.h" + +namespace tflite { +namespace ops { +namespace builtin { + +TfLiteRegistration* Register_RELU(); +TfLiteRegistration* Register_RELU1(); +TfLiteRegistration* Register_RELU6(); +TfLiteRegistration* Register_TANH(); +TfLiteRegistration* Register_LOGISTIC(); +TfLiteRegistration* Register_AVERAGE_POOL_2D(); +TfLiteRegistration* Register_MAX_POOL_2D(); +TfLiteRegistration* Register_L2_POOL_2D(); +TfLiteRegistration* Register_CONV_2D(); +TfLiteRegistration* Register_DEPTHWISE_CONV_2D(); +TfLiteRegistration* Register_SVDF(); +TfLiteRegistration* Register_RNN(); +TfLiteRegistration* Register_EMBEDDING_LOOKUP(); +TfLiteRegistration* Register_EMBEDDING_LOOKUP_SPARSE(); +TfLiteRegistration* Register_FULLY_CONNECTED(); +TfLiteRegistration* Register_LSH_PROJECTION(); +TfLiteRegistration* Register_HASHTABLE_LOOKUP(); +TfLiteRegistration* Register_SOFTMAX(); +TfLiteRegistration* Register_CONCATENATION(); +TfLiteRegistration* Register_ADD(); +TfLiteRegistration* Register_MUL(); +TfLiteRegistration* Register_L2_NORMALIZATION(); +TfLiteRegistration* Register_LOCAL_RESPONSE_NORMALIZATION(); +TfLiteRegistration* Register_LSTM(); +TfLiteRegistration* Register_RESHAPE(); +TfLiteRegistration* Register_RESIZE_BILINEAR(); +TfLiteRegistration* Register_SKIP_GRAM(); +TfLiteRegistration* Register_SPACE_TO_DEPTH(); + +BuiltinOpResolver::BuiltinOpResolver() { + AddBuiltin(BuiltinOperator_RELU, Register_RELU()); + AddBuiltin(BuiltinOperator_RELU1, Register_RELU1()); + AddBuiltin(BuiltinOperator_RELU6, Register_RELU6()); + AddBuiltin(BuiltinOperator_TANH, Register_TANH()); + AddBuiltin(BuiltinOperator_LOGISTIC, Register_LOGISTIC()); + AddBuiltin(BuiltinOperator_AVERAGE_POOL_2D, Register_AVERAGE_POOL_2D()); + AddBuiltin(BuiltinOperator_MAX_POOL_2D, Register_MAX_POOL_2D()); + AddBuiltin(BuiltinOperator_L2_POOL_2D, Register_L2_POOL_2D()); + AddBuiltin(BuiltinOperator_CONV_2D, Register_CONV_2D()); + AddBuiltin(BuiltinOperator_DEPTHWISE_CONV_2D, Register_DEPTHWISE_CONV_2D()); + AddBuiltin(BuiltinOperator_SVDF, Register_SVDF()); + AddBuiltin(BuiltinOperator_RNN, Register_RNN()); + AddBuiltin(BuiltinOperator_EMBEDDING_LOOKUP, Register_EMBEDDING_LOOKUP()); + AddBuiltin(BuiltinOperator_EMBEDDING_LOOKUP_SPARSE, + Register_EMBEDDING_LOOKUP_SPARSE()); + AddBuiltin(BuiltinOperator_FULLY_CONNECTED, Register_FULLY_CONNECTED()); + AddBuiltin(BuiltinOperator_LSH_PROJECTION, Register_LSH_PROJECTION()); + AddBuiltin(BuiltinOperator_HASHTABLE_LOOKUP, Register_HASHTABLE_LOOKUP()); + AddBuiltin(BuiltinOperator_SOFTMAX, Register_SOFTMAX()); + AddBuiltin(BuiltinOperator_CONCATENATION, Register_CONCATENATION()); + AddBuiltin(BuiltinOperator_ADD, Register_ADD()); + AddBuiltin(BuiltinOperator_MUL, Register_MUL()); + AddBuiltin(BuiltinOperator_L2_NORMALIZATION, Register_L2_NORMALIZATION()); + AddBuiltin(BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION, + Register_LOCAL_RESPONSE_NORMALIZATION()); + AddBuiltin(BuiltinOperator_LSTM, Register_LSTM()); + AddBuiltin(BuiltinOperator_RESHAPE, Register_RESHAPE()); + AddBuiltin(BuiltinOperator_RESIZE_BILINEAR, Register_RESIZE_BILINEAR()); + AddBuiltin(BuiltinOperator_SKIP_GRAM, Register_SKIP_GRAM()); + AddBuiltin(BuiltinOperator_SPACE_TO_DEPTH, Register_SPACE_TO_DEPTH()); +} + +TfLiteRegistration* BuiltinOpResolver::FindOp( + tflite::BuiltinOperator op) const { + auto it = builtins_.find(op); + return it != builtins_.end() ? 
it->second : nullptr;
+}
+
+TfLiteRegistration* BuiltinOpResolver::FindOp(const char* op) const {
+  auto it = custom_ops_.find(op);
+  return it != custom_ops_.end() ? it->second : nullptr;
+}
+
+void BuiltinOpResolver::AddBuiltin(tflite::BuiltinOperator op,
+                                   TfLiteRegistration* registration) {
+  registration->builtin_code = op;
+  builtins_.insert(std::make_pair(op, registration));
+}
+
+void BuiltinOpResolver::AddCustom(const char* name,
+                                  TfLiteRegistration* registration) {
+  registration->builtin_code = BuiltinOperator_CUSTOM;
+  custom_ops_.insert(std::make_pair(std::string(name), registration));
+}
+
+}  // namespace builtin
+}  // namespace ops
+}  // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/register.h b/tensorflow/contrib/lite/kernels/register.h
new file mode 100644
index 0000000000..28f5e0fcc8
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/register.h
@@ -0,0 +1,50 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_REGISTER_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_REGISTER_H_
+
+#include <unordered_map>
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+
+class BuiltinOpResolver : public OpResolver {
+ public:
+  BuiltinOpResolver();
+  TfLiteRegistration* FindOp(tflite::BuiltinOperator op) const override;
+  TfLiteRegistration* FindOp(const char* op) const override;
+  void AddBuiltin(tflite::BuiltinOperator op, TfLiteRegistration* registration);
+  void AddCustom(const char* name, TfLiteRegistration* registration);
+
+ private:
+  struct BuiltinOperatorHasher {
+    size_t operator()(const tflite::BuiltinOperator& x) const {
+      return std::hash<size_t>()(static_cast<size_t>(x));
+    }
+  };
+  std::unordered_map<tflite::BuiltinOperator, TfLiteRegistration*,
+                     BuiltinOperatorHasher>
+      builtins_;
+  std::unordered_map<std::string, TfLiteRegistration*> custom_ops_;
+};
+
+}  // namespace builtin
+}  // namespace ops
+}  // namespace tflite
+
+#endif  // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_REGISTER_H_
diff --git a/tensorflow/contrib/lite/kernels/reshape.cc b/tensorflow/contrib/lite/kernels/reshape.cc
new file mode 100644
index 0000000000..f3e6ddc9f4
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/reshape.cc
@@ -0,0 +1,91 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ +#include +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace reshape { + +constexpr int kInputTensor = 0; +constexpr int kOutputTensor = 0; + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast(node->builtin_data); + + // TODO(ahentz): we are often given a tensor with the shape but we only pay + // attention to what the shape specified in 'params'. + TF_LITE_ENSURE(context, NumInputs(node) == 1 || NumInputs(node) == 2); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + TfLiteTensor* input = GetInput(context, node, kInputTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + // Tensorflow's Reshape allows one of the shape components to have the + // special -1 value, meaning it will be calculated automatically based on the + // input. Here we calculate what that dimension should be so that the number + // of output elements in the same as the number of input elements. + int num_input_elements = 1; + for (int i = 0; i < NumDimensions(input); ++i) { + num_input_elements *= SizeOfDimension(input, i); + } + + TfLiteIntArray* output_size = TfLiteIntArrayCreate(params->num_dimensions); + int num_output_elements = 1; + int strech_dim = -1; + for (int i = 0; i < params->num_dimensions; ++i) { + int value = params->shape[i]; + if (value == -1) { + TF_LITE_ENSURE_EQ(context, strech_dim, -1); + strech_dim = i; + } else { + num_output_elements *= value; + output_size->data[i] = value; + } + } + if (strech_dim != -1) { + output_size->data[strech_dim] = num_input_elements / num_output_elements; + num_output_elements *= output_size->data[strech_dim]; + } + + TF_LITE_ENSURE_EQ(context, num_input_elements, num_output_elements); + return context->ResizeTensor(context, output, output_size); +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor* input = GetInput(context, node, kInputTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + memcpy(output->data.raw, input->data.raw, input->bytes); + + return kTfLiteOk; +} + +} // namespace reshape + +TfLiteRegistration* Register_RESHAPE() { + static TfLiteRegistration r = {nullptr, nullptr, reshape::Prepare, + reshape::Eval}; + return &r; +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/reshape_test.cc b/tensorflow/contrib/lite/kernels/reshape_test.cc new file mode 100644 index 0000000000..59ce7d5648 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/reshape_test.cc @@ -0,0 +1,90 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class ReshapeOpModel : public SingleOpModel { + public: + ReshapeOpModel(std::initializer_list input_shape, + std::initializer_list new_shape) { + input_ = AddInput(TensorType_FLOAT32); + output_ = AddOutput(TensorType_FLOAT32); + SetBuiltinOp( + BuiltinOperator_RESHAPE, BuiltinOptions_ReshapeOptions, + CreateReshapeOptions(builder_, builder_.CreateVector(new_shape)) + .Union()); + BuildInterpreter({input_shape}); + } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetOutputShape() { return GetTensorShape(output_); } + + private: + int input_; + int output_; +}; + +TEST(ReshapeOpTest, MismatchedDimensions) { + EXPECT_DEATH(ReshapeOpModel({1, 2, 4, 1}, {2, 1}), + "num_input_elements != num_output_elements"); +} + +TEST(ReshapeOpTest, TooManyDimensions) { + EXPECT_DEATH( + ReshapeOpModel({1, 2, 3, 4, 5, 6, 7, 8, 9}, {1, 2, 3, 4, 5, 6, 7, 8, 9}), + "Found too many dimensions"); +} + +TEST(ReshapeOpTest, TooManySpecialDimensions) { + EXPECT_DEATH(ReshapeOpModel({1, 2, 4, 1}, {-1, -1, 2, 4}), + "strech_dim != -1"); +} + +TEST(ReshapeOpTest, SimpleTest) { + ReshapeOpModel m({1, 2, 4, 1}, {2, 2, 2}); + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6, 7, 8})); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2, 2})); +} + +TEST(ReshapeOpTest, WithStretchDimension) { + ReshapeOpModel m({1, 2, 4, 1}, {2, 1, -1}); + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6, 7, 8})); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 1, 4})); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + // On Linux, add: tflite::LogToStderr(); + tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/resize_bilinear.cc b/tensorflow/contrib/lite/kernels/resize_bilinear.cc new file mode 100644 index 0000000000..1613c9a89f --- /dev/null +++ b/tensorflow/contrib/lite/kernels/resize_bilinear.cc @@ -0,0 +1,129 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace resize_bilinear { + +// This file has three implementation of RESIZE_BILINEAR. +enum KernelType { + kReference, + kGenericOptimized, // Neon-free + kNeonOptimized, +}; + +constexpr int kInputTensor = 0; +constexpr int kOutputTensor = 0; + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + auto* params = + reinterpret_cast(node->builtin_data); + + TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + TfLiteTensor* input = GetInput(context, node, kInputTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + // TODO(ahentz): Our current implementations rely on the inputs being 4D. + TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4); + + // TODO(ahentz): Our current implementations only support float32. + TF_LITE_ENSURE_EQ(context, output->type, kTfLiteFloat32); + TF_LITE_ENSURE_EQ(context, input->type, output->type); + + TfLiteIntArray* output_size = TfLiteIntArrayCreate(4); + output_size->data[0] = input->dims->data[0]; + output_size->data[1] = params->new_height; + output_size->data[2] = params->new_width; + output_size->data[3] = input->dims->data[3]; + + return context->ResizeTensor(context, output, output_size); +} + +template +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + auto* params = + reinterpret_cast(node->builtin_data); + + TfLiteTensor* input = GetInput(context, node, kInputTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + // We have to fake a tensor here, to satisfy ResizeBilinear(). 
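+  // The requested size is passed as a bare 2-element array paired with a
+  // fabricated {1, 1, 1, 2} shape rather than a real TfLiteTensor. As a rough
+  // worked example of the interpolation (assuming input_x = output_x *
+  // in_width / out_width scaling): resizing a 1x2 row {3, 6} to width 3
+  // samples at x = 0, 2/3, 4/3 and yields {3, 3 + (2/3) * 3, 6} = {3, 5, 6}.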
+ int32 output_size_data[2] = {params->new_height, params->new_width}; + + if (output->type == kTfLiteFloat32) { +#define TF_LITE_RESIZE_BILINEAR(type) \ + type::ResizeBilinear(GetTensorData(input), GetTensorDims(input), \ + output_size_data, GetTensorDims({1, 1, 1, 2}), \ + GetTensorData(output), GetTensorDims(output)) + + if (kernel_type == kReference) { + TF_LITE_RESIZE_BILINEAR(reference_ops); + } + if (kernel_type == kGenericOptimized || kernel_type == kNeonOptimized) { + TF_LITE_RESIZE_BILINEAR(optimized_ops); + } +#undef TF_LITE_RESIZE_BILINEAR + } else { + context->ReportError(context, "Inputs and outputs not all float types."); + return kTfLiteError; + } + + return kTfLiteOk; +} + +} // namespace resize_bilinear + +TfLiteRegistration* Register_RESIZE_BILINEAR_REF() { + static TfLiteRegistration r = { + nullptr, nullptr, resize_bilinear::Prepare, + resize_bilinear::Eval}; + return &r; +} + +TfLiteRegistration* Register_RESIZE_BILINEAR_GENERIC_OPT() { + static TfLiteRegistration r = { + nullptr, nullptr, resize_bilinear::Prepare, + resize_bilinear::Eval}; + return &r; +} + +TfLiteRegistration* Register_RESIZE_BILINEAR_NEON_OPT() { + static TfLiteRegistration r = { + nullptr, nullptr, resize_bilinear::Prepare, + resize_bilinear::Eval}; + return &r; +} + +TfLiteRegistration* Register_RESIZE_BILINEAR() { +#ifdef USE_NEON + return Register_RESIZE_BILINEAR_NEON_OPT(); +#else + return Register_RESIZE_BILINEAR_GENERIC_OPT(); +#endif +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/resize_bilinear_test.cc b/tensorflow/contrib/lite/kernels/resize_bilinear_test.cc new file mode 100644 index 0000000000..0257c0b557 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/resize_bilinear_test.cc @@ -0,0 +1,117 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class ResizeBilinearOpModel : public SingleOpModel { + public: + ResizeBilinearOpModel(std::initializer_list input_shape, int new_height, + int new_width) { + input_ = AddInput(TensorType_FLOAT32); + output_ = AddOutput(TensorType_FLOAT32); + SetBuiltinOp( + BuiltinOperator_RESIZE_BILINEAR, BuiltinOptions_ResizeBilinearOptions, + CreateResizeBilinearOptions(builder_, new_height, new_width).Union()); + BuildInterpreter({input_shape}); + } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + std::vector GetOutput() { return ExtractVector(output_); } + + private: + int input_; + int output_; +}; + +TEST(ResizeBilinearOpTest, HorizontalResize) { + ResizeBilinearOpModel m({1, 1, 2, 1}, 1, 3); + m.SetInput({3, 6}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({3, 5, 6}))); +} + +TEST(ResizeBilinearOpTest, VerticalResize) { + ResizeBilinearOpModel m({1, 2, 1, 1}, 3, 1); + m.SetInput({3, 9}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({3, 7, 9}))); +} + +TEST(ResizeBilinearOpTest, TwoDimensionalResize) { + ResizeBilinearOpModel m({1, 2, 2, 1}, 3, 3); + m.SetInput({ + 3, 6, // + 9, 12 // + }); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 3, 5, 6, // + 7, 9, 10, // + 9, 11, 12, // + }))); +} + +TEST(ResizeBilinearOpTest, TwoDimensionalResizeWithTwoBatches) { + ResizeBilinearOpModel m({2, 2, 2, 1}, 3, 3); + m.SetInput({ + 3, 6, // + 9, 12, // + 4, 10, // + 10, 16 // + }); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 3, 5, 6, // + 7, 9, 10, // + 9, 11, 12, // + 4, 8, 10, // + 8, 12, 14, // + 10, 14, 16, // + }))); +} + +TEST(ResizeBilinearOpTest, ThreeDimensionalResize) { + ResizeBilinearOpModel m({1, 2, 2, 2}, 3, 3); + m.SetInput({ + 3, 4, 6, 10, // + 9, 10, 12, 16, // + }); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 3, 4, 5, 8, 6, 10, // + 7, 8, 9, 12, 10, 14, // + 9, 10, 11, 14, 12, 16, // + }))); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + // On Linux, add: tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/skip_gram.cc b/tensorflow/contrib/lite/kernels/skip_gram.cc new file mode 100644 index 0000000000..c90a15b3a2 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/skip_gram.cc @@ -0,0 +1,160 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +// Generate a list of skip grams from an input. +// +// Options: +// ngram_size: num of words for each output item. +// max_skip_size: max num of words to skip. +// The op generates ngrams when it is 0. +// include_all_ngrams: include all ngrams with size up to ngram_size. +// +// Input: +// A string tensor to generate n-grams. +// Dim = {1} +// +// Output: +// A list of strings, each of which contains ngram_size words. +// Dim = {num_ngram} + +#include +#include +#include + +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" +#include "tensorflow/contrib/lite/string_util.h" + +namespace tflite { +namespace ops { +namespace builtin { + +namespace { + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + TF_LITE_ENSURE_EQ(context, GetInput(context, node, 0)->type, kTfLiteString); + TF_LITE_ENSURE_EQ(context, GetOutput(context, node, 0)->type, kTfLiteString); + return kTfLiteOk; +} + +bool ShouldIncludeCurrentNgram(const TfLiteSkipGramParams* params, int size) { + if (size <= 0) { + return false; + } + if (params->include_all_ngrams) { + return size <= params->ngram_size; + } else { + return size == params->ngram_size; + } +} + +bool ShouldStepInRecursion(const TfLiteSkipGramParams* params, + const std::vector& stack, int stack_idx, + int num_words) { + // If current stack size and next word enumeration are within valid range. + if (stack_idx < params->ngram_size && stack[stack_idx] + 1 < num_words) { + // If this stack is empty, step in for first word enumeration. + if (stack_idx == 0) { + return true; + } + // If next word enumeration are within the range of max_skip_size. + // NOTE: equivalent to + // next_word_idx = stack[stack_idx] + 1 + // next_word_idx - stack[stack_idx-1] <= max_skip_size + 1 + if (stack[stack_idx] - stack[stack_idx - 1] <= params->max_skip_size) { + return true; + } + } + return false; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast(node->builtin_data); + + // Split sentence to words. + std::vector words; + tflite::StringRef strref = tflite::GetString(GetInput(context, node, 0), 0); + int prev_idx = 0; + for (int i = 1; i < strref.len; i++) { + if (isspace(*(strref.str + i))) { + if (i > prev_idx && !isspace(*(strref.str + prev_idx))) { + words.push_back({strref.str + prev_idx, i - prev_idx}); + } + prev_idx = i + 1; + } + } + if (strref.len > prev_idx) { + words.push_back({strref.str + prev_idx, strref.len - prev_idx}); + } + + // Generate n-grams recursively. + tflite::DynamicBuffer buf; + if (words.size() < params->ngram_size) { + buf.WriteToTensor(GetOutput(context, node, 0)); + return kTfLiteOk; + } + + // Stack stores the index of word used to generate ngram. + // The size of stack is the size of ngram. + std::vector stack(params->ngram_size, 0); + // Stack index that indicates which depth the recursion is operating at. + int stack_idx = 1; + int num_words = words.size(); + + while (stack_idx >= 0) { + if (ShouldStepInRecursion(params, stack, stack_idx, num_words)) { + // When current depth can fill with a new word + // and the new word is within the max range to skip, + // fill this word to stack, recurse into next depth. 
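+      // Illustrative trace (not from the original code): for the input
+      // "The quick brown" with ngram_size = 2 and max_skip_size = 1, this
+      // loop enumerates index pairs whose gap is at most max_skip_size + 1
+      // and emits "The quick", "The brown" and "quick brown".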
+ stack[stack_idx]++; + stack_idx++; + if (stack_idx < params->ngram_size) { + stack[stack_idx] = stack[stack_idx - 1]; + } + } else { + if (ShouldIncludeCurrentNgram(params, stack_idx)) { + // Add n-gram to tensor buffer when the stack has filled with enough + // words to generate the ngram. + std::vector gram(stack_idx); + for (int i = 0; i < stack_idx; i++) { + gram[i] = words[stack[i]]; + } + buf.AddJoinedString(gram, ' '); + } + // When current depth cannot fill with a valid new word, + // and not in last depth to generate ngram, + // step back to previous depth to iterate to next possible word. + stack_idx--; + } + } + + buf.WriteToTensor(GetOutput(context, node, 0)); + return kTfLiteOk; +} +} // namespace + +TfLiteRegistration* Register_SKIP_GRAM() { + static TfLiteRegistration r = {nullptr, nullptr, Prepare, Eval}; + return &r; +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/skip_gram_test.cc b/tensorflow/contrib/lite/kernels/skip_gram_test.cc new file mode 100644 index 0000000000..e7f6bc904b --- /dev/null +++ b/tensorflow/contrib/lite/kernels/skip_gram_test.cc @@ -0,0 +1,257 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/contrib/lite/string_util.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAre; + +static char kSentence[] = "The quick\t brown fox\n jumps over\n the lazy dog!"; + +class SkipGramOp : public SingleOpModel { + public: + SkipGramOp(int ngram_size, int max_skip_size, bool include_all_ngrams) { + input_ = AddInput(TensorType_STRING); + output_ = AddOutput(TensorType_STRING); + + SetBuiltinOp(BuiltinOperator_SKIP_GRAM, BuiltinOptions_SkipGramOptions, + CreateSkipGramOptions(builder_, ngram_size, max_skip_size, + include_all_ngrams) + .Union()); + BuildInterpreter({{1}}); + } + void SetInput(const string& content) { + PopulateStringTensor(input_, {content}); + } + + std::vector GetOutput() { + std::vector ans; + TfLiteTensor* tensor = interpreter_->tensor(output_); + + int num = GetStringCount(tensor); + for (int i = 0; i < num; i++) { + StringRef strref = GetString(tensor, i); + ans.push_back(string(strref.str, strref.len)); + } + return ans; + } + + private: + int input_; + int output_; +}; + +TEST(SkipGramTest, TestUnigram) { + SkipGramOp m(1, 0, false); + + m.SetInput(kSentence); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), testing::UnorderedElementsAreArray( + {"The", "quick", "brown", "fox", "jumps", + "over", "the", "lazy", "dog!"})); +} + +TEST(SkipGramTest, TestBigram) { + SkipGramOp m(2, 0, false); + m.SetInput(kSentence); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + testing::UnorderedElementsAreArray( + 
{"The quick", "quick brown", "brown fox", "fox jumps", + "jumps over", "over the", "the lazy", "lazy dog!"})); +} + +TEST(SkipGramTest, TestAllBigram) { + SkipGramOp m(2, 0, true); + m.SetInput(kSentence); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + testing::UnorderedElementsAreArray( + {// Unigram + "The", "quick", "brown", "fox", "jumps", "over", "the", + "lazy", "dog!", + // Bigram + "The quick", "quick brown", "brown fox", "fox jumps", + "jumps over", "over the", "the lazy", "lazy dog!"})); +} + +TEST(SkipGramTest, TestAllTrigram) { + SkipGramOp m(3, 0, true); + m.SetInput(kSentence); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + testing::UnorderedElementsAreArray( + {// Unigram + "The", "quick", "brown", "fox", "jumps", "over", "the", + "lazy", "dog!", + // Bigram + "The quick", "quick brown", "brown fox", "fox jumps", + "jumps over", "over the", "the lazy", "lazy dog!", + // Trigram + "The quick brown", "quick brown fox", "brown fox jumps", + "fox jumps over", "jumps over the", "over the lazy", + "the lazy dog!"})); +} + +TEST(SkipGramTest, TestSkip1Bigram) { + SkipGramOp m(2, 1, false); + m.SetInput(kSentence); + m.Invoke(); + EXPECT_THAT( + m.GetOutput(), + testing::UnorderedElementsAreArray( + {"The quick", "The brown", "quick brown", "quick fox", "brown fox", + "brown jumps", "fox jumps", "fox over", "jumps over", "jumps the", + "over the", "over lazy", "the lazy", "the dog!", "lazy dog!"})); +} + +TEST(SkipGramTest, TestSkip2Bigram) { + SkipGramOp m(2, 2, false); + m.SetInput(kSentence); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + testing::UnorderedElementsAreArray( + {"The quick", "The brown", "The fox", "quick brown", + "quick fox", "quick jumps", "brown fox", "brown jumps", + "brown over", "fox jumps", "fox over", "fox the", + "jumps over", "jumps the", "jumps lazy", "over the", + "over lazy", "over dog!", "the lazy", "the dog!", + "lazy dog!"})); +} + +TEST(SkipGramTest, TestSkip1Trigram) { + SkipGramOp m(3, 1, false); + m.SetInput(kSentence); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + testing::UnorderedElementsAreArray( + {"The quick brown", "The quick fox", "The brown fox", + "The brown jumps", "quick brown fox", "quick brown jumps", + "quick fox jumps", "quick fox over", "brown fox jumps", + "brown fox over", "brown jumps over", "brown jumps the", + "fox jumps over", "fox jumps the", "fox over the", + "fox over lazy", "jumps over the", "jumps over lazy", + "jumps the lazy", "jumps the dog!", "over the lazy", + "over the dog!", "over lazy dog!", "the lazy dog!"})); +} + +TEST(SkipGramTest, TestSkip2Trigram) { + SkipGramOp m(3, 2, false); + m.SetInput(kSentence); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + testing::UnorderedElementsAreArray( + {"The quick brown", "The quick fox", "The quick jumps", + "The brown fox", "The brown jumps", "The brown over", + "The fox jumps", "The fox over", "The fox the", + "quick brown fox", "quick brown jumps", "quick brown over", + "quick fox jumps", "quick fox over", "quick fox the", + "quick jumps over", "quick jumps the", "quick jumps lazy", + "brown fox jumps", "brown fox over", "brown fox the", + "brown jumps over", "brown jumps the", "brown jumps lazy", + "brown over the", "brown over lazy", "brown over dog!", + "fox jumps over", "fox jumps the", "fox jumps lazy", + "fox over the", "fox over lazy", "fox over dog!", + "fox the lazy", "fox the dog!", "jumps over the", + "jumps over lazy", "jumps over dog!", "jumps the lazy", + "jumps the dog!", "jumps lazy dog!", "over the lazy", + "over the dog!", "over lazy dog!", "the lazy 
dog!"})); +} + +TEST(SkipGramTest, TestAllSkip2Trigram) { + SkipGramOp m(3, 2, true); + m.SetInput(kSentence); + m.Invoke(); + EXPECT_THAT( + m.GetOutput(), + testing::UnorderedElementsAreArray( + {// Unigram + "The", "quick", "brown", "fox", "jumps", "over", "the", "lazy", + "dog!", + // Bigram + "The quick", "The brown", "The fox", "quick brown", "quick fox", + "quick jumps", "brown fox", "brown jumps", "brown over", "fox jumps", + "fox over", "fox the", "jumps over", "jumps the", "jumps lazy", + "over the", "over lazy", "over dog!", "the lazy", "the dog!", + "lazy dog!", + // Trigram + "The quick brown", "The quick fox", "The quick jumps", + "The brown fox", "The brown jumps", "The brown over", + "The fox jumps", "The fox over", "The fox the", "quick brown fox", + "quick brown jumps", "quick brown over", "quick fox jumps", + "quick fox over", "quick fox the", "quick jumps over", + "quick jumps the", "quick jumps lazy", "brown fox jumps", + "brown fox over", "brown fox the", "brown jumps over", + "brown jumps the", "brown jumps lazy", "brown over the", + "brown over lazy", "brown over dog!", "fox jumps over", + "fox jumps the", "fox jumps lazy", "fox over the", "fox over lazy", + "fox over dog!", "fox the lazy", "fox the dog!", "jumps over the", + "jumps over lazy", "jumps over dog!", "jumps the lazy", + "jumps the dog!", "jumps lazy dog!", "over the lazy", + "over the dog!", "over lazy dog!", "the lazy dog!"})); +} + +TEST(SkipGramTest, TestSingleWord) { + SkipGramOp m(1, 1, false); + m.SetInput("Hi"); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAre("Hi")); +} + +TEST(SkipGramTest, TestWordsLessThanGram) { + SkipGramOp m(3, 1, false); + m.SetInput("Hi hi"); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), std::vector()); +} + +TEST(SkipGramTest, TestEmptyInput) { + SkipGramOp m(1, 1, false); + m.SetInput(""); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAre()); +} + +TEST(SkipGramTest, TestWhitespaceInput) { + SkipGramOp m(1, 1, false); + m.SetInput(" "); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAre()); +} + +TEST(SkipGramTest, TestInputWithExtraSpace) { + SkipGramOp m(1, 1, false); + m.SetInput(" Hello world ! "); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAre("Hello", "world", "!")); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + // On Linux, add: tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/softmax_test.cc b/tensorflow/contrib/lite/kernels/softmax_test.cc new file mode 100644 index 0000000000..ec8ec03b0d --- /dev/null +++ b/tensorflow/contrib/lite/kernels/softmax_test.cc @@ -0,0 +1,143 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Unit test for TFLite SOFTMAX op. 
+ +#include +#include +#include + +#include +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +class SoftmaxOpModel : public SingleOpModel { + public: + SoftmaxOpModel(int batches, int size, float beta) + : batches_(batches), input_size_(size), beta_(beta) { + input_ = AddInput(TensorType_FLOAT32); + output_ = AddOutput(TensorType_FLOAT32); + SetBuiltinOp(BuiltinOperator_SOFTMAX, BuiltinOptions_SoftmaxOptions, + CreateSoftmaxOptions(builder_, beta_).Union()); + BuildInterpreter({{batches_, input_size_}}); + } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + void SetInput(int offset, float* begin, float* end) { + PopulateTensor(input_, offset, begin, end); + } + + std::vector GetOutput() { return ExtractVector(output_); } + + private: + int input_; + int output_; + + int batches_; + int input_size_; + float beta_; +}; + +TEST(SoftmaxOpTest, SimpleTest) { + SoftmaxOpModel m(/*batches=*/2, /*size=*/5, /*beta=*/1.0); + m.SetInput({ + 1.0, 2.0, 3.0, 4.0, 5.0, // b = 0 + -1.0, -2.0, -3.0, -4.0, -5.0, // b = 0 + }); + + m.Invoke(); + + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray(ArrayFloatNear( + {0.011656231, 0.031684921, 0.086128544, 0.234121657, 0.636408647, + 0.636408647, 0.234121657, 0.086128544, 0.031684921, 0.011656231}, + 1e-6))); +} + +TEST(SoftmaxOpTest, CompareWithTFminiBetaEq1) { + const int batch_size = 2; + const int input_size = 5; + const float beta = 1.0; + static float input_buffer[] = { + 1.0, 2.0, 3.0, 4.0, 5.0, // b = 0 + -1.0, -2.0, -3.0, -4.0, -5.0, // b = 1 + }; + + SoftmaxOpModel m(batch_size, input_size, beta); + + m.SetInput(0, input_buffer, input_buffer + input_size * batch_size); + + m.Invoke(); + + std::unique_ptr output_buffer(new float[input_size * batch_size]); + static tflite::Dims<4> input_dims = {{input_size, 1, 1, batch_size}, + {1, 0, 0, input_size}}; + tflite::reference_ops::Softmax(input_buffer, input_dims, beta, + output_buffer.get(), input_dims); + + std::vector expected; + expected.insert(expected.end(), output_buffer.get(), + output_buffer.get() + input_size * batch_size); + + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear(expected, 1e-6))); +} + +TEST(SoftmaxOpTest, CompareWithTFminiBetaNotEq1) { + const int batch_size = 2; + const int input_size = 5; + const float beta = 0.5; + static float input_buffer[] = { + 1.0, 2.0, 3.0, 4.0, 5.0, // b = 0 + -1.0, -2.0, -3.0, -4.0, -5.0, // b = 1 + }; + + SoftmaxOpModel m(batch_size, input_size, beta); + + m.SetInput(0, input_buffer, input_buffer + input_size * batch_size); + + m.Invoke(); + + std::unique_ptr output_buffer(new float[input_size * batch_size]); + static tflite::Dims<4> input_dims = {{input_size, 1, 1, batch_size}, + {1, 0, 0, input_size}}; + tflite::reference_ops::Softmax(input_buffer, input_dims, beta, + output_buffer.get(), input_dims); + + std::vector expected; + expected.insert(expected.end(), output_buffer.get(), + output_buffer.get() + input_size * batch_size); + + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear(expected, 1e-6))); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + // On Linux, add: tflite::LogToStderr(); + tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff 
--git a/tensorflow/contrib/lite/kernels/space_to_depth.cc b/tensorflow/contrib/lite/kernels/space_to_depth.cc new file mode 100644 index 0000000000..cb2e509c98 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/space_to_depth.cc @@ -0,0 +1,146 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace space_to_depth { + +// This file has two implementation of SpaceToDepth. Note that SpaceToDepth +// only works on 4D tensors. +enum KernelType { + kReference, + kGenericOptimized, +}; + +constexpr int kInputTensor = 0; +constexpr int kOutputTensor = 0; + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + auto* params = + reinterpret_cast(node->builtin_data); + + TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + TfLiteTensor* input = GetInput(context, node, kInputTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4); + + auto data_type = output->type; + TF_LITE_ENSURE(context, + data_type == kTfLiteFloat32 || data_type == kTfLiteUInt8 || + data_type == kTfLiteInt32 || data_type == kTfLiteInt64); + TF_LITE_ENSURE_EQ(context, input->type, output->type); + + const int block_size = params->block_size; + const int input_height = input->dims->data[1]; + const int input_width = input->dims->data[2]; + int output_height = input_height / block_size; + int output_width = input_width / block_size; + + TF_LITE_ENSURE_EQ(context, input_height, output_height * block_size); + TF_LITE_ENSURE_EQ(context, input_width, output_width * block_size); + + TfLiteIntArray* output_size = TfLiteIntArrayCreate(4); + output_size->data[0] = input->dims->data[0]; + output_size->data[1] = output_height; + output_size->data[2] = output_width; + output_size->data[3] = input->dims->data[3] * block_size * block_size; + + return context->ResizeTensor(context, output, output_size); +} + +template +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + auto* params = + reinterpret_cast(node->builtin_data); + + TfLiteTensor* input = GetInput(context, node, kInputTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + +#define TF_LITE_SPACE_TO_DEPTH(type, scalar) \ + type::SpaceToDepth( \ + GetTensorData(input), GetTensorDims(input), params->block_size, \ + GetTensorData(output), GetTensorDims(output)) + switch (input->type) { // Already know in/out 
types are same. + case kTfLiteFloat32: + if (kernel_type == kReference) { + TF_LITE_SPACE_TO_DEPTH(reference_ops, float); + } else { + TF_LITE_SPACE_TO_DEPTH(optimized_ops, float); + } + break; + case kTfLiteUInt8: + if (kernel_type == kReference) { + TF_LITE_SPACE_TO_DEPTH(reference_ops, uint8_t); + } else { + TF_LITE_SPACE_TO_DEPTH(optimized_ops, uint8_t); + } + break; + case kTfLiteInt32: + if (kernel_type == kReference) { + TF_LITE_SPACE_TO_DEPTH(reference_ops, int32_t); + } else { + TF_LITE_SPACE_TO_DEPTH(optimized_ops, int32_t); + } + break; + case kTfLiteInt64: + if (kernel_type == kReference) { + TF_LITE_SPACE_TO_DEPTH(reference_ops, int64_t); + } else { + TF_LITE_SPACE_TO_DEPTH(optimized_ops, int64_t); + } + break; + default: + context->ReportError(context, "Type not currently supported."); + return kTfLiteError; + } +#undef TF_LITE_SPACE_TO_DEPTH + + return kTfLiteOk; +} + +} // namespace space_to_depth + +TfLiteRegistration* Register_SPACE_TO_DEPTH_REF() { + static TfLiteRegistration r = { + nullptr, nullptr, space_to_depth::Prepare, + space_to_depth::Eval}; + return &r; +} + +TfLiteRegistration* Register_SPACE_TO_DEPTH_GENERIC_OPT() { + static TfLiteRegistration r = { + nullptr, nullptr, space_to_depth::Prepare, + space_to_depth::Eval}; + return &r; +} + +TfLiteRegistration* Register_SPACE_TO_DEPTH() { + return Register_SPACE_TO_DEPTH_GENERIC_OPT(); +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/space_to_depth_test.cc b/tensorflow/contrib/lite/kernels/space_to_depth_test.cc new file mode 100644 index 0000000000..911f08a92c --- /dev/null +++ b/tensorflow/contrib/lite/kernels/space_to_depth_test.cc @@ -0,0 +1,102 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAre; +using ::testing::ElementsAreArray; + +class SpaceToDepthOpModel : public SingleOpModel { + public: + SpaceToDepthOpModel(const TensorData& tensor_data, int block_size) { + input_ = AddInput(tensor_data); + output_ = AddOutput(tensor_data); + SetBuiltinOp(BuiltinOperator_SPACE_TO_DEPTH, + BuiltinOptions_SpaceToDepthOptions, + CreateSpaceToDepthOptions(builder_, block_size).Union()); + BuildInterpreter({GetShape(input_)}); + } + + template + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + template + std::vector GetOutput() { + return ExtractVector(output_); + } + std::vector GetOutputShape() { return GetTensorShape(output_); } + + private: + int input_; + int output_; +}; + +TEST(SpaceToDepthOpModel, BadBlockSize) { + EXPECT_DEATH(SpaceToDepthOpModel({TensorType_FLOAT32, {1, 2, 2, 1}}, 3), + "Cannot allocate tensors"); +} + +TEST(SpaceToDepthOpModel, Float32) { + SpaceToDepthOpModel m({TensorType_FLOAT32, {1, 2, 2, 2}}, 2); + m.SetInput({1.4, 2.3, 3.2, 4.1, 5.4, 6.3, 7.2, 8.1}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray({1.4, 2.3, 3.2, 4.1, 5.4, 6.3, 7.2, 8.1})); + EXPECT_THAT(m.GetOutputShape(), ElementsAre(1, 1, 1, 8)); +} + +TEST(SpaceToDepthOpModel, Uint8) { + SpaceToDepthOpModel m({TensorType_UINT8, {1, 2, 2, 1}}, 2); + m.SetInput({1, 2, 3, 4}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4})); + EXPECT_THAT(m.GetOutputShape(), ElementsAre(1, 1, 1, 4)); +} + +TEST(SpaceToDepthOpModel, Int32) { + SpaceToDepthOpModel m({TensorType_INT32, {1, 2, 2, 3}}, 2); + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12})); + EXPECT_THAT(m.GetOutputShape(), ElementsAre(1, 1, 1, 12)); +} + +TEST(SpaceToDepthOpModel, Int64) { + SpaceToDepthOpModel m({TensorType_INT64, {1, 4, 4, 1}}, 2); + m.SetInput({1, 2, 5, 6, 3, 4, 7, 8, 9, 10, 13, 14, 11, 12, 15, 16}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray( + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16})); + EXPECT_THAT(m.GetOutputShape(), ElementsAre(1, 2, 2, 4)); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + // On Linux, add: tflite::LogToStderr(); + tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/svdf.cc b/tensorflow/contrib/lite/kernels/svdf.cc new file mode 100644 index 0000000000..dd414d53bd --- /dev/null +++ b/tensorflow/contrib/lite/kernels/svdf.cc @@ -0,0 +1,224 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/activation_functor.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace svdf { + +constexpr int kInputTensor = 0; +constexpr int kWeightsFeatureTensor = 1; +constexpr int kWeightsTimeTensor = 2; +constexpr int kBiasTensor = 3; +constexpr int kStateTensor = 0; +constexpr int KOutputTensor = 1; + +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + auto* scratch_tensor_index = new int; + context->AddTensors(context, 1, scratch_tensor_index); + return scratch_tensor_index; +} + +void Free(TfLiteContext* context, void* buffer) { + delete reinterpret_cast(buffer); +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast(node->builtin_data); + int* scratch_tensor_index = reinterpret_cast(node->user_data); + + // Check we have all the inputs and outputs we need. + TF_LITE_ENSURE_EQ(context, node->inputs->size, 4); + TF_LITE_ENSURE_EQ(context, node->outputs->size, 2); + + TfLiteTensor* input = &context->tensors[node->inputs->data[kInputTensor]]; + TfLiteTensor* weights_feature = + &context->tensors[node->inputs->data[kWeightsFeatureTensor]]; + TfLiteTensor* weights_time = + &context->tensors[node->inputs->data[kWeightsTimeTensor]]; + + // Check all the parameters of tensor match within themselves and match the + // input configuration. + const int rank = params->rank; + const int batch_size = input->dims->data[0]; + const int num_filters = weights_feature->dims->data[0]; + TF_LITE_ASSERT_EQ(num_filters % rank, 0); + const int num_units = num_filters / rank; + const int memory_size = weights_time->dims->data[1]; + TF_LITE_ASSERT_EQ(input->dims->data[1], weights_feature->dims->data[1]); + TF_LITE_ASSERT_EQ(weights_time->dims->data[0], num_filters); + + TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor); + if (bias) { + TF_LITE_ASSERT_EQ(bias->dims->data[0], num_units); + } + + TfLiteTensor* state = &context->tensors[node->outputs->data[kStateTensor]]; + TfLiteTensor* output = &context->tensors[node->outputs->data[KOutputTensor]]; + + // Resize state. + // For each batch, the state is a 2-D tensor: memory_size * num_filters + // The left most column is used to save current cycle activation. + // The right most column is used to save temporary output which will be + // reduced to num_units outputs. + TfLiteIntArray* state_size_array = TfLiteIntArrayCreate(2); + state_size_array->data[0] = batch_size; + state_size_array->data[1] = memory_size * num_filters; + TF_LITE_ENSURE_OK(context, + context->ResizeTensor(context, state, state_size_array)); + + // Mark state as a persistent tensor. + state->allocation_type = kTfLiteArenaRwPersistent; + + // Resize output. + TfLiteIntArray* output_size_array = TfLiteIntArrayCreate(2); + output_size_array->data[0] = batch_size; + output_size_array->data[1] = num_units; + TF_LITE_ENSURE_OK(context, + context->ResizeTensor(context, output, output_size_array)); + + // Resize scratch. 
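To make these sizes concrete, the rank-1 configuration exercised by the unit test further below (batches = 2, units = 4, input_size = 3, memory_size = 10) gives num_filters = 4, so the state tensor is {2, 40}, the output is {2, 4}, and the scratch buffer resized next is {2, 4}. A standalone sketch of that arithmetic (illustrative only, not part of the op):

#include <cstdio>

int main() {
  // Shape arithmetic implied by svdf::Prepare for the rank-1 test config.
  const int batch_size = 2, num_units = 4, rank = 1, memory_size = 10;
  const int num_filters = num_units * rank;           // 4
  const int state_cols = memory_size * num_filters;   // 40
  std::printf("state   = {%d, %d}\n", batch_size, state_cols);   // {2, 40}
  std::printf("output  = {%d, %d}\n", batch_size, num_units);    // {2, 4}
  std::printf("scratch = {%d, %d}\n", batch_size, num_filters);  // {2, 4}
  return 0;
}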
+ TfLiteIntArrayFree(node->temporaries); + node->temporaries = TfLiteIntArrayCreate(1); + node->temporaries->data[0] = *scratch_tensor_index; + + TfLiteIntArray* scratch_size_array = TfLiteIntArrayCreate(2); + scratch_size_array->data[0] = batch_size; + scratch_size_array->data[1] = num_filters; + + TfLiteTensor* scratch_tensor = &context->tensors[node->temporaries->data[0]]; + scratch_tensor->type = input->type; + scratch_tensor->allocation_type = kTfLiteArenaRw; + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_tensor, + scratch_size_array)); + + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast(node->builtin_data); + + TfLiteTensor* input = &context->tensors[node->inputs->data[kInputTensor]]; + TfLiteTensor* weights_feature = + &context->tensors[node->inputs->data[kWeightsFeatureTensor]]; + TfLiteTensor* weights_time = + &context->tensors[node->inputs->data[kWeightsTimeTensor]]; + + TfLiteTensor* state = &context->tensors[node->outputs->data[kStateTensor]]; + TfLiteTensor* output = &context->tensors[node->outputs->data[KOutputTensor]]; + TfLiteTensor* scratch = &context->tensors[node->temporaries->data[0]]; + + TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor); + + const int rank = params->rank; + const int batch_size = input->dims->data[0]; + const int input_size = input->dims->data[1]; + const int num_filters = weights_feature->dims->data[0]; + const int num_units = num_filters / rank; + const int memory_size = weights_time->dims->data[1]; + + // Clear the activation (state left most column). + // TODO(ghodrat): Add a test which initialize state with invalid values in + // left most column and make sure it passes. + for (int b = 0; b < batch_size; b++) { + float* state_ptr_batch = state->data.f + b * memory_size * num_filters; + for (int c = 0; c < num_filters; c++) { + float* state_ptr = state_ptr_batch + c * memory_size; + state_ptr[memory_size - 1] = 0.0; + } + } + + // Compute conv1d(inputs, weights_feature). + // The state left most column is used to save current cycle activation. This + // is achieved by starting at state->data.f[memory_size - 1] and having the + // stride equal to memory_size. + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + weights_feature->data.f, num_filters, input_size, input->data.f, + batch_size, &state->data.f[memory_size - 1], memory_size); + + // Compute matmul(state, weights_time). + // The right most column is used to save temporary output (with the size of + // num_filters). This is achieved by starting at state->data.f and having the + // stride equal to memory_size. + for (int b = 0; b < batch_size; b++) { + float* state_ptr_batch = state->data.f + b * memory_size * num_filters; + float* scratch_ptr_batch = scratch->data.f + b * num_filters; + tensor_utils::BatchVectorBatchVectorDotProduct( + weights_time->data.f, state_ptr_batch, memory_size, num_filters, + scratch_ptr_batch, /*result_stride=*/1); + } + + // Initialize output with bias if provided. + if (bias) { + tensor_utils::VectorBatchVectorAssign(bias->data.f, num_units, batch_size, + output->data.f); + } else { + tensor_utils::ZeroVector(output->data.f, batch_size * num_units); + } + + // Reduction sum + // TODO(ghodrat): Consider not reusing state for the temporary output, this + // way ReductionSum operates on row-vector instead of column vector. 
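The reduction loop below collapses each batch's num_filters scratch values down to num_units outputs by summing groups of rank consecutive filter results. A standalone stand-in for tensor_utils::ReductionSumVector (assuming, as the bias initialization above implies, that it accumulates onto whatever is already in the output):

#include <cstdio>

// Illustrative stand-in: sum groups of `reduction_size` consecutive inputs
// into one output each, accumulating on top of the existing output values.
void ReduceSum(const float* input, float* output, int output_size,
               int reduction_size) {
  for (int o = 0; o < output_size; ++o)
    for (int r = 0; r < reduction_size; ++r)
      output[o] += input[o * reduction_size + r];
}

int main() {
  const float scratch[4] = {1.f, 2.f, 3.f, 4.f};  // num_filters = 4
  float output[2] = {0.5f, 0.5f};                 // num_units = 2, bias = 0.5
  ReduceSum(scratch, output, /*output_size=*/2, /*reduction_size=*/2);
  std::printf("%g %g\n", output[0], output[1]);   // prints: 3.5 7.5
  return 0;
}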
+ for (int b = 0; b < batch_size; b++) { + float* output_ptr_batch = output->data.f + b * num_units; + float* scratch_ptr_batch = scratch->data.f + b * num_filters; + tensor_utils::ReductionSumVector(scratch_ptr_batch, output_ptr_batch, + num_units, rank); + } + + // Apply activation. + for (int b = 0; b < batch_size; b++) { + float* output_ptr_batch = output->data.f + b * num_units; + tensor_utils::ApplyActivationToVector(output_ptr_batch, num_units, + params->activation, output_ptr_batch); + } + + // Right shift the state. + for (int b = 0; b < batch_size; b++) { + float* state_ptr_batch = state->data.f + b * memory_size * num_filters; + for (int f = 0; f < num_filters; f++) { + tensor_utils::VectorShiftLeft(state_ptr_batch, memory_size, + /*shift_value=*/0.0); + state_ptr_batch += memory_size; + } + } + return kTfLiteOk; +} + +} // namespace svdf + +TfLiteRegistration* Register_SVDF() { + static TfLiteRegistration r = {svdf::Init, svdf::Free, svdf::Prepare, + svdf::Eval}; + return &r; +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/svdf_test.cc b/tensorflow/contrib/lite/kernels/svdf_test.cc new file mode 100644 index 0000000000..d956025e9d --- /dev/null +++ b/tensorflow/contrib/lite/kernels/svdf_test.cc @@ -0,0 +1,312 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Unit test for TFLite SVDF op. 
+ +#include +#include + +#include +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +static float svdf_input[] = { + 0.12609188, -0.46347019, -0.89598465, + 0.35867718, 0.36897406, 0.73463392, + + 0.14278367, -1.64410412, -0.75222826, + -0.57290924, 0.12729003, 0.7567004, + + 0.49837467, 0.19278903, 0.26584083, + 0.17660543, 0.52949083, -0.77931279, + + -0.11186574, 0.13164264, -0.05349274, + -0.72674477, -0.5683046, 0.55900657, + + -0.68892461, 0.37783599, 0.18263303, + -0.63690937, 0.44483393, -0.71817774, + + -0.81299269, -0.86831826, 1.43940818, + -0.95760226, 1.82078898, 0.71135032, + + -1.45006323, -0.82251364, -1.69082689, + -1.65087092, -1.89238167, 1.54172635, + + 0.03966608, -0.24936394, -0.77526885, + 2.06740379, -1.51439476, 1.43768692, + + 0.11771342, -0.23761693, -0.65898693, + 0.31088525, -1.55601168, -0.87661445, + + -0.89477462, 1.67204106, -0.53235275, + -0.6230064, 0.29819036, 1.06939757, +}; + +static float svdf_golden_output_rank_1[] = { + 0.014899, -0.0517661, -0.143725, -0.00271883, + -0.03004015, 0.09565311, 0.1587342, 0.00784263, + + 0.068281, -0.162217, -0.152268, 0.00323521, + 0.01582633, 0.03858774, -0.03001583, -0.02671271, + + -0.0317821, -0.0333089, 0.0609602, 0.0333759, + -0.01432795, 0.05524484, 0.1101355, -0.02382665, + + -0.00623099, -0.077701, -0.391193, -0.0136691, + -0.02333033, 0.02293761, 0.12338032, 0.04326871, + + 0.201551, -0.164607, -0.179462, -0.0592739, + 0.01064911, -0.17503069, 0.07821996, -0.00224009, + + 0.0886511, -0.0875401, -0.269283, 0.0281379, + -0.02282338, 0.09741908, 0.32973239, 0.12281385, + + -0.201174, -0.586145, -0.628624, -0.0330412, + 0.24780814, -0.39304617, -0.22473189, 0.02589256, + + -0.0839096, -0.299329, 0.108746, 0.109808, + 0.10084175, -0.06416984, 0.28936723, 0.0026358, + + 0.419114, -0.237824, -0.422627, 0.175115, + -0.2314795, -0.18584411, -0.4228974, -0.12928449, + + 0.36726, -0.522303, -0.456502, -0.175475, + 0.17012937, -0.34447709, 0.38505614, -0.28158101, +}; + +static float svdf_golden_output_rank_2[] = { + -0.09623547, -0.10193135, 0.11083051, -0.0347917, + 0.1141196, 0.12965347, -0.12652366, 0.01007236, + + -0.16396809, -0.21247184, 0.11259045, -0.04156673, + 0.10132131, -0.06143532, -0.00924693, 0.10084561, + + 0.01257364, 0.0506071, -0.19287863, -0.07162561, + -0.02033747, 0.22673416, 0.15487903, 0.02525555, + + -0.1411963, -0.37054959, 0.01774767, 0.05867489, + 0.09607603, -0.0141301, -0.08995658, 0.12867066, + + -0.27142537, -0.16955489, 0.18521598, -0.12528358, + 0.00331409, 0.11167502, 0.02218599, -0.07309391, + + 0.09593632, -0.28361851, -0.0773851, 0.17199151, + -0.00075242, 0.33691186, -0.1536046, 0.16572715, + + -0.27916506, -0.27626723, 0.42615682, 0.3225764, + -0.37472126, -0.55655634, -0.05013514, 0.289112, + + -0.24418658, 0.07540751, -0.1940318, -0.08911639, + 0.00732617, 0.46737891, 0.26449674, 0.24888524, + + -0.17225097, -0.54660404, -0.38795233, 0.08389944, + 0.07736043, -0.28260678, 0.15666828, 1.14949894, + + -0.57454878, -0.64704704, 0.73235172, -0.34616736, + 0.21120001, -0.22927976, 0.02455296, -0.35906726, +}; + +// Derived class of SingleOpModel, which is used to test SVDF TFLite op. 
+class SVDFOpModel : public SingleOpModel { + public: + SVDFOpModel(int batches, int units, int input_size, int memory_size, int rank) + : batches_(batches), + units_(units), + input_size_(input_size), + memory_size_(memory_size), + rank_(rank) { + input_ = AddInput(TensorType_FLOAT32); + weights_feature_ = AddInput(TensorType_FLOAT32); + weights_time_ = AddInput(TensorType_FLOAT32); + bias_ = AddNullInput(); + state_ = AddOutput(TensorType_FLOAT32); + output_ = AddOutput(TensorType_FLOAT32); + SetBuiltinOp( + BuiltinOperator_SVDF, BuiltinOptions_SVDFOptions, + CreateSVDFOptions(builder_, rank, ActivationFunctionType_NONE).Union()); + BuildInterpreter({ + {batches_, input_size_}, // Input tensor + {units_ * rank, input_size_}, // weights_feature tensor + {units_ * rank, memory_size_}, // weights_time tensor + {units_} // bias tensor + }); + } + + // Populates the weights_feature tensor. + void SetWeightsFeature(std::initializer_list f) { + PopulateTensor(weights_feature_, f); + } + + // Populates the weights_time tensor. + void SetWeightsTime(std::initializer_list f) { + PopulateTensor(weights_time_, f); + } + + // Populates the input tensor. + void SetInput(int offset, float* begin, float* end) { + PopulateTensor(input_, offset, begin, end); + } + + // Resets the state of SVDF op by filling it with 0's. + void ResetState() { + const int zero_buffer_size = rank_ * units_ * batches_ * memory_size_; + std::unique_ptr zero_buffer(new float[zero_buffer_size]); + memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float)); + PopulateTensor(state_, 0, zero_buffer.get(), + zero_buffer.get() + zero_buffer_size); + } + + // Extracts the output tensor from the SVDF op. + std::vector GetOutput() { return ExtractVector(output_); } + + int input_size() { return input_size_; } + int num_units() { return units_; } + int num_batches() { return batches_; } + + private: + int input_; + int weights_feature_; + int weights_time_; + int bias_; + int state_; + int output_; + + int batches_; + int units_; + int input_size_; + int memory_size_; + int rank_; +}; + +TEST(SVDFOpTest, BlackBoxTestRank1) { + SVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3, + /*memory_size=*/10, /*rank=*/1); + svdf.SetWeightsFeature({-0.31930989, -0.36118156, 0.0079667, 0.37613347, + 0.22197971, 0.12416199, 0.27901134, 0.27557442, + 0.3905206, -0.36137494, -0.06634006, -0.10640851}); + + svdf.SetWeightsTime( + {-0.31930989, 0.37613347, 0.27901134, -0.36137494, -0.36118156, + 0.22197971, 0.27557442, -0.06634006, 0.0079667, 0.12416199, + + 0.3905206, -0.10640851, -0.0976817, 0.15294972, 0.39635518, + -0.02702999, 0.39296314, 0.15785322, 0.21931258, 0.31053296, + + -0.36916667, 0.38031587, -0.21580373, 0.27072677, 0.23622236, + 0.34936687, 0.18174365, 0.35907319, -0.17493086, 0.324846, + + -0.10781813, 0.27201805, 0.14324132, -0.23681851, -0.27115166, + -0.01580888, -0.14943552, 0.15465137, 0.09784451, -0.0337657}); + + svdf.ResetState(); + const int svdf_num_batches = svdf.num_batches(); + const int svdf_input_size = svdf.input_size(); + const int svdf_num_units = svdf.num_units(); + const int input_sequence_size = + sizeof(svdf_input) / sizeof(float) / (svdf_input_size * svdf_num_batches); + // Going over each input batch, setting the input tensor, invoking the SVDF op + // and checking the output with the expected golden values. 
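Concretely, svdf_input above holds 60 floats, so with input_size = 3 and batches = 2 the loop below runs 60 / (3 * 2) = 10 time steps, feeding 6 input values and comparing 4 units * 2 batches = 8 golden values per step.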
+ for (int i = 0; i < input_sequence_size; i++) { + float* batch_start = svdf_input + i * svdf_input_size * svdf_num_batches; + float* batch_end = batch_start + svdf_input_size * svdf_num_batches; + svdf.SetInput(0, batch_start, batch_end); + + svdf.Invoke(); + + float* golden_start = + svdf_golden_output_rank_1 + i * svdf_num_units * svdf_num_batches; + float* golden_end = golden_start + svdf_num_units * svdf_num_batches; + std::vector expected; + expected.insert(expected.end(), golden_start, golden_end); + + EXPECT_THAT(svdf.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); + } +} + +TEST(SVDFOpTest, BlackBoxTestRank2) { + SVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3, + /*memory_size=*/10, /*rank=*/2); + svdf.SetWeightsFeature({-0.31930989, 0.0079667, 0.39296314, 0.37613347, + 0.12416199, 0.15785322, 0.27901134, 0.3905206, + 0.21931258, -0.36137494, -0.10640851, 0.31053296, + -0.36118156, -0.0976817, -0.36916667, 0.22197971, + 0.15294972, 0.38031587, 0.27557442, 0.39635518, + -0.21580373, -0.06634006, -0.02702999, 0.27072677}); + + svdf.SetWeightsTime( + {-0.31930989, 0.37613347, 0.27901134, -0.36137494, -0.36118156, + 0.22197971, 0.27557442, -0.06634006, 0.0079667, 0.12416199, + + 0.3905206, -0.10640851, -0.0976817, 0.15294972, 0.39635518, + -0.02702999, 0.39296314, 0.15785322, 0.21931258, 0.31053296, + + -0.36916667, 0.38031587, -0.21580373, 0.27072677, 0.23622236, + 0.34936687, 0.18174365, 0.35907319, -0.17493086, 0.324846, + + -0.10781813, 0.27201805, 0.14324132, -0.23681851, -0.27115166, + -0.01580888, -0.14943552, 0.15465137, 0.09784451, -0.0337657, + + -0.14884081, 0.19931212, -0.36002168, 0.34663299, -0.11405486, + 0.12672701, 0.39463779, -0.07886535, -0.06384811, 0.08249187, + + -0.26816407, -0.19905911, 0.29211238, 0.31264046, -0.28664589, + 0.05698794, 0.11613581, 0.14078894, 0.02187902, -0.21781836, + + -0.15567942, 0.08693647, -0.38256618, 0.36580828, -0.22922277, + -0.0226903, 0.12878349, -0.28122205, -0.10850525, -0.11955214, + + 0.27179423, -0.04710215, 0.31069002, 0.22672787, 0.09580326, + 0.08682203, 0.1258215, 0.1851041, 0.29228821, 0.12366763}); + + svdf.ResetState(); + const int svdf_num_batches = svdf.num_batches(); + const int svdf_input_size = svdf.input_size(); + const int svdf_num_units = svdf.num_units(); + const int input_sequence_size = + sizeof(svdf_input) / sizeof(float) / (svdf_input_size * svdf_num_batches); + // Going over each input batch, setting the input tensor, invoking the SVDF op + // and checking the output with the expected golden values. 
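The rank-2 walk below is identical; the difference is upstream: with rank = 2 the op keeps units * rank = 8 filters (hence the 24 feature weights and 80 time weights above), which the reduction step collapses back to 4 output units per batch.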
+ for (int i = 0; i < input_sequence_size; i++) { + float* batch_start = svdf_input + i * svdf_input_size * svdf_num_batches; + float* batch_end = batch_start + svdf_input_size * svdf_num_batches; + svdf.SetInput(0, batch_start, batch_end); + + svdf.Invoke(); + + float* golden_start = + svdf_golden_output_rank_2 + i * svdf_num_units * svdf_num_batches; + float* golden_end = golden_start + svdf_num_units * svdf_num_batches; + std::vector expected; + expected.insert(expected.end(), golden_start, golden_end); + + EXPECT_THAT(svdf.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); + } +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + // On Linux, add: tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/test_util.cc b/tensorflow/contrib/lite/kernels/test_util.cc new file mode 100644 index 0000000000..f716ba8741 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/test_util.cc @@ -0,0 +1,183 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/kernels/test_util.h" + +#include "tensorflow/contrib/lite/version.h" +#include "tensorflow/core/platform/logging.h" + +namespace tflite { + +using ::testing::FloatNear; +using ::testing::Matcher; + +namespace { +template +std::pair QuantizationParams(float f_min, float f_max) { + // These are required by many quantized operations. + CHECK_LE(f_min, 0); + CHECK_GE(f_max, 0); + T q_min = std::numeric_limits::min(); + T q_max = std::numeric_limits::max(); + float range = q_max - q_min; + float scale = (f_max - f_min) / range; + int32_t zero_point = std::min( + q_max, + std::max(q_min, static_cast(std::round(q_min - f_min / scale)))); + return {scale, zero_point}; +} +} // namespace + +std::vector> ArrayFloatNear(const std::vector& values, + float max_abs_error) { + std::vector> matchers; + matchers.reserve(values.size()); + for (const float& v : values) { + matchers.emplace_back(FloatNear(v, max_abs_error)); + } + return matchers; +} + +int SingleOpModel::AddTensor(TensorData t) { + int id = tensors_.size(); + + // This is slightly different depending on whether we are adding a + // quantized or a regular tensor. 
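Before the branching below, a worked example of the QuantizationParams formula defined above (standalone, same arithmetic): a uint8 tensor declared with min = -1 and max = 1 gets scale = 2 / 255 ≈ 0.00784 and zero_point = 128.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <limits>

int main() {
  // Same arithmetic as QuantizationParams<uint8_t> for min = -1, max = 1.
  const float f_min = -1.0f, f_max = 1.0f;
  const int32_t q_min = std::numeric_limits<uint8_t>::min();  // 0
  const int32_t q_max = std::numeric_limits<uint8_t>::max();  // 255
  const float scale = (f_max - f_min) / (q_max - q_min);      // 2 / 255
  const int32_t zero_point = std::min(
      q_max, std::max(q_min, static_cast<int32_t>(
                                 std::round(q_min - f_min / scale))));
  std::printf("scale = %f, zero_point = %d\n", scale, zero_point);  // 128
  return 0;
}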
+ bool is_quantized = (t.min != 0 || t.max != 0 || t.scale != 0); + + flatbuffers::Offset q_params = 0; + + if (is_quantized) { + if (t.min != 0 || t.max != 0) { + if (t.type == TensorType_UINT8) { + std::tie(t.scale, t.zero_point) = + QuantizationParams(t.min, t.max); + } else if (t.type == TensorType_INT32) { + std::tie(t.scale, t.zero_point) = + QuantizationParams(t.min, t.max); + } else { + LOG(FATAL) << "No support for the requested quantized type"; + } + t.min = 0; + t.max = 0; + } + + q_params = CreateQuantizationParameters( + builder_, /*min=*/0, /*max=*/0, builder_.CreateVector({t.scale}), + builder_.CreateVector({t.zero_point})); + } + + tensors_.push_back(CreateTensor(builder_, builder_.CreateVector({}), + t.type, /*buffer=*/0, + /*name=*/0, q_params)); + + tensor_data_[id] = t; + + return id; +} + +int SingleOpModel::AddInput(const TensorData& t) { + int id = AddTensor(t); + inputs_.push_back(id); + return id; +} + +int SingleOpModel::AddNullInput() { + int id = kOptionalTensor; + inputs_.push_back(id); + return id; +} + +int SingleOpModel::AddOutput(const TensorData& t) { + int id = AddTensor(t); + outputs_.push_back(id); + return id; +} + +void SingleOpModel::SetBuiltinOp(BuiltinOperator type, + BuiltinOptions builtin_options_type, + flatbuffers::Offset builtin_options) { + opcodes_.push_back(CreateOperatorCode(builder_, type, 0)); + operators_.push_back(CreateOperator( + builder_, /*opcode_index=*/0, builder_.CreateVector(inputs_), + builder_.CreateVector(outputs_), builtin_options_type, + builtin_options, + /*custom_options=*/0, CustomOptionsFormat_FLEXBUFFERS)); +} + +void SingleOpModel::SetCustomOp( + const string& name, const std::vector& custom_option, + const std::function& registeration) { + custom_registrations_[name] = registeration; + opcodes_.push_back( + CreateOperatorCodeDirect(builder_, BuiltinOperator_CUSTOM, name.data())); + operators_.push_back(CreateOperator( + builder_, /*opcode_index=*/0, builder_.CreateVector(inputs_), + builder_.CreateVector(outputs_), BuiltinOptions_NONE, 0, + builder_.CreateVector(custom_option), + CustomOptionsFormat_FLEXBUFFERS)); +} + +void SingleOpModel::BuildInterpreter( + std::vector> input_shapes) { + auto opcodes = builder_.CreateVector(opcodes_); + auto operators = builder_.CreateVector(operators_); + auto tensors = builder_.CreateVector(tensors_); + auto inputs = builder_.CreateVector(inputs_); + auto outputs = builder_.CreateVector(outputs_); + // Create a single subgraph + std::vector> subgraphs; + auto subgraph = CreateSubGraph(builder_, tensors, inputs, outputs, operators); + subgraphs.push_back(subgraph); + auto subgraphs_flatbuffer = builder_.CreateVector(subgraphs); + + std::vector> buffers_vec; + auto buffers = builder_.CreateVector(buffers_vec); + auto description = builder_.CreateString("programmatic model"); + builder_.Finish(CreateModel(builder_, TFLITE_SCHEMA_VERSION, opcodes, + subgraphs_flatbuffer, description, buffers)); + + auto* model = GetModel(builder_.GetBufferPointer()); + + ops::builtin::BuiltinOpResolver builtins; + for (const auto& reg : custom_registrations_) { + builtins.AddCustom(reg.first.data(), reg.second()); + } + InterpreterBuilder(model, builtins)(&interpreter_); + + CHECK(interpreter_ != nullptr); + + int i = 0; + for (const auto& shape : input_shapes) { + int input_idx = interpreter_->inputs()[i++]; + if (input_idx == kOptionalTensor) continue; + CHECK(interpreter_->ResizeInputTensor(input_idx, shape) == kTfLiteOk); + } + CHECK(interpreter_->AllocateTensors() == kTfLiteOk) + << "Cannot 
allocate tensors"; +} + +void SingleOpModel::Invoke() { CHECK(interpreter_->Invoke() == kTfLiteOk); } + +int32_t SingleOpModel::GetTensorSize(int index) const { + TfLiteTensor* t = interpreter_->tensor(index); + CHECK(t); + int total_size = 1; + for (int i = 0; i < t->dims->size; ++i) { + total_size *= t->dims->data[i]; + } + return total_size; +} + +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/test_util.h b/tensorflow/contrib/lite/kernels/test_util.h new file mode 100644 index 0000000000..e68e494661 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/test_util.h @@ -0,0 +1,202 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_TEST_UTIL_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_TEST_UTIL_H_ + +#include + +#include +#include + +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/contrib/lite/string_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace tflite { + +inline void LogToStderr() { +#ifdef PLATFORM_GOOGLE + FLAGS_logtostderr = true; +#endif +} + +// A gmock matcher that check that elements of a float vector match to a given +// tolerance. +std::vector<::testing::Matcher> ArrayFloatNear( + const std::vector& values, float max_abs_error = 1e-5); + +template +inline std::vector Quantize(const std::vector& data, float scale, + int32_t zero_point) { + std::vector q; + for (float f : data) { + q.push_back(std::max( + std::numeric_limits::min(), + std::min(std::numeric_limits::max(), + static_cast(std::round(zero_point + (f / scale)))))); + } + return q; +} + +template +inline std::vector Dequantize(const std::vector& data, float scale, + int32_t zero_point) { + std::vector f; + for (T q : data) { + f.push_back(scale * (q - zero_point)); + } + return f; +} + +// A test model that contains a single operator. All operator inputs and +// output are external to the model, so the tests can directly access them. +// Typical usage: +// SingleOpModel m; +// int a = m.AddInput({TensorType_FLOAT32, a_shape}); +// int b = m.AddInput({TensorType_FLOAT32, b_shape}); +// int c = m.AddOutput({TensorType_FLOAT32, {}}); +// m.SetBuiltinOp(...); +// m.BuildInterpreter({GetShape(a), GetShape(b)}); +// m.PopulateTensor(a, {...}); +// m.PopulateTensor(b, {...}); +// m.Invoke(); +// EXPECT_THAT(m.ExtractVector(c), ArrayFloatNear({...})); +// + +// A helper struct to construct test tensors. This is particularly useful for +// quantized tensor which must have their scale and zero_point defined before +// the actual data is known. This mimics what happens in practice: quantization +// parameters are calculate during training. 
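A short round trip through the Quantize/Dequantize helpers above, reusing the uint8 parameters from the earlier example (a sketch that assumes this header and its test dependencies are on the include path):

#include <cstdint>
#include <cstdio>
#include <vector>

#include "tensorflow/contrib/lite/kernels/test_util.h"

int main() {
  const float scale = 2.0f / 255.0f;  // ~0.00784, one quantization step
  const int32_t zero_point = 128;
  std::vector<float> real = {-0.5f, 0.0f, 0.5f};
  std::vector<uint8_t> q = tflite::Quantize<uint8_t>(real, scale, zero_point);
  std::vector<float> back = tflite::Dequantize<uint8_t>(q, scale, zero_point);
  for (size_t i = 0; i < real.size(); ++i) {
    // Each round-tripped value lands within one step of the original.
    std::printf("%g -> %d -> %g\n", real[i], q[i], back[i]);
  }
  return 0;
}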
+struct TensorData { + TensorType type; + std::vector shape; + float min; + float max; + float scale; + int32_t zero_point; +}; + +class SingleOpModel { + public: + SingleOpModel() {} + ~SingleOpModel() {} + + // Copying or assignment is disallowed to simplify ownership semantics. + SingleOpModel(const SingleOpModel&) = delete; + SingleOpModel& operator=(const SingleOpModel&) = delete; + + // Add a TensorType input tensor and return its index. + int AddInput(TensorType type) { return AddInput(TensorData{type}); } + int AddInput(const TensorData& t); + + // Add a null input tensor (optional input) and return kOptionalTensor. + int AddNullInput(); + + // Add a TensorType output tensor and return its index. + int AddOutput(TensorType type) { return AddOutput(TensorData{type}); } + int AddOutput(const TensorData& t); + + template + void QuantizeAndPopulate(int index, std::initializer_list data) { + TfLiteTensor* t = interpreter_->tensor(index); + auto q = Quantize(data, t->params.scale, t->params.zero_point); + PopulateTensor(index, 0, q.data(), q.data() + q.size()); + } + + const std::vector& GetShape(int id) { return tensor_data_.at(id).shape; } + + float GetScale(int id) { return tensor_data_.at(id).scale; } + int32_t GetZeroPoint(int id) { return tensor_data_.at(id).zero_point; } + + // Define the operator in this model. + void SetBuiltinOp(BuiltinOperator type, BuiltinOptions builtin_options_type, + flatbuffers::Offset builtin_options); + void SetCustomOp(const string& name, + const std::vector& custom_option, + const std::function& registeration); + + // Build the interpreter for this model. Also, resize and allocate all + // tensors given the shapes of the inputs. + void BuildInterpreter(std::vector> input_shapes); + + void Invoke(); + + void PopulateStringTensor(int index, const std::vector& content) { + auto tensor = interpreter_->tensor(index); + DynamicBuffer buf; + for (const string& s : content) { + buf.AddString(s.data(), s.length()); + } + buf.WriteToTensor(tensor); + } + + // Populate the tensor given its index. + template + void PopulateTensor(int index, std::initializer_list data) { + T* v = interpreter_->typed_tensor(index); + CHECK(v) << "No tensor with index '" << index << "'."; + for (T f : data) { + *v = f; + ++v; + } + } + + // Partially populate the tensor, starting at the given offset. + template + void PopulateTensor(int index, int offset, T* begin, T* end) { + T* v = interpreter_->typed_tensor(index); + memcpy(v + offset, begin, (end - begin) * sizeof(T)); + } + + // Return a vector with the flattened contents of a tensor. 
+ template + std::vector ExtractVector(int index) { + T* v = interpreter_->typed_tensor(index); + CHECK(v); + return std::vector(v, v + GetTensorSize(index)); + } + + std::vector GetTensorShape(int index) { + std::vector result; + TfLiteTensor* t = interpreter_->tensor(index); + for (int i = 0; i < t->dims->size; ++i) { + result.push_back(t->dims->data[i]); + } + return result; + } + + protected: + int32_t GetTensorSize(int index) const; + + flatbuffers::FlatBufferBuilder builder_; + std::unique_ptr interpreter_; + + private: + int AddTensor(TensorData t); + + std::map tensor_data_; + std::vector inputs_; + std::vector outputs_; + std::vector> tensors_; + std::vector> opcodes_; + std::vector> operators_; + std::map> custom_registrations_; +}; + +} // namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_TEST_UTIL_H_ diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc new file mode 100644 index 0000000000..f8208f6f98 --- /dev/null +++ b/tensorflow/contrib/lite/model.cc @@ -0,0 +1,673 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/allocation.h" +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/error_reporter.h" +#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/contrib/lite/nnapi_delegate.h" +#include "tensorflow/contrib/lite/version.h" + +namespace tflite { + +const char* kEmptyTensorName = ""; + +std::unique_ptr FlatBufferModel::BuildFromFile( + const char* filename, ErrorReporter* error_reporter) { + std::unique_ptr model; + model.reset(new FlatBufferModel(filename, /*mmap_file=*/true, error_reporter, + /*use_nnapi=*/true)); + if (!model->initialized()) model.reset(); + return model; +} + +std::unique_ptr FlatBufferModel::BuildFromBuffer( + const char* buffer, size_t buffer_size, ErrorReporter* error_reporter) { + std::unique_ptr model; + model.reset(new FlatBufferModel(buffer, buffer_size, error_reporter)); + if (!model->initialized()) model.reset(); + return model; +} + +FlatBufferModel::FlatBufferModel(const char* filename, bool mmap_file, + ErrorReporter* error_reporter, bool use_nnapi) + : error_reporter_(error_reporter ? 
error_reporter + : DefaultErrorReporter()) { + if (mmap_file) { + if (use_nnapi && NNAPIExists()) + allocation_ = new NNAPIAllocation(filename, error_reporter); + else + allocation_ = new MMAPAllocation(filename, error_reporter); + } else { + allocation_ = new FileCopyAllocation(filename, error_reporter); + } + if (!allocation_->valid()) return; + if (!CheckModelIdentifier()) return; + + model_ = ::tflite::GetModel(allocation_->base()); +} + +bool FlatBufferModel::CheckModelIdentifier() const { + if (!tflite::ModelBufferHasIdentifier(allocation_->base())) { + const char* ident = flatbuffers::GetBufferIdentifier(allocation_->base()); + error_reporter_->Report( + "Model provided has model identifier '%c%c%c%c', should be '%s'\n", + ident[0], ident[1], ident[2], ident[3], tflite::ModelIdentifier()); + return false; + } + return true; +} + +FlatBufferModel::FlatBufferModel(const char* ptr, size_t num_bytes, + ErrorReporter* error_reporter) + : error_reporter_(error_reporter ? error_reporter + : DefaultErrorReporter()) { + allocation_ = new MemoryAllocation(ptr, num_bytes, error_reporter); + if (!allocation_->valid()) return; + model_ = ::tflite::GetModel(allocation_->base()); +} + +FlatBufferModel::~FlatBufferModel() { delete allocation_; } + +InterpreterBuilder::InterpreterBuilder(const FlatBufferModel& model, + const OpResolver& op_resolver) + : model_(model.GetModel()), + op_resolver_(op_resolver), + error_reporter_(model.error_reporter()), + allocation_(model.allocation()) {} + +InterpreterBuilder::InterpreterBuilder(const ::tflite::Model* model, + const OpResolver& op_resolver, + ErrorReporter* error_reporter) + : model_(model), + op_resolver_(op_resolver), + error_reporter_(error_reporter ? error_reporter + : DefaultErrorReporter()) {} + +TfLiteStatus InterpreterBuilder::BuildLocalIndexToRegistrationMapping() { + TfLiteStatus status = kTfLiteOk; + auto opcodes = model_->operator_codes(); + for (const OperatorCode* opcode : *opcodes) { + TfLiteRegistration* registration = nullptr; + + if (opcode->builtin_code() != BuiltinOperator_CUSTOM) { + auto x = opcode->builtin_code(); + flatbuffer_op_index_to_registration_types_.push_back(x); + registration = op_resolver_.FindOp(x); + if (registration == nullptr) { + error_reporter_->Report("Didn't find op for builtin opcode '%s'\n", + EnumNameBuiltinOperator(x)); + status = kTfLiteError; + } + } else if (!opcode->custom_code()) { + error_reporter_->Report( + "Operator with builtin_code==0 has no custom_code.\n"); + status = kTfLiteError; + } else { + const char* name = opcode->custom_code()->c_str(); + registration = op_resolver_.FindOp(name); + flatbuffer_op_index_to_registration_types_.push_back( + BuiltinOperator_CUSTOM); + if (registration == nullptr) { + error_reporter_->Report("Didn't find custom op for name '%s'\n", name); + status = kTfLiteError; + } + } + flatbuffer_op_index_to_registration_.push_back(registration); + } + return status; +} + +namespace { +template +std::vector FlatBufferIntArrayToVector(T* flat_array) { + std::vector ret(flat_array->Length()); + for (int i = 0; i < flat_array->Length(); i++) { + ret[i] = flat_array->Get(i); + } + return ret; +} + +// Allocate a structure using C malloc, but make sure the structure is a +// POD structure that doesn't require constructors to run. The reason we do +// this, is that Interpreter's C extension part will take ownership and wants +// to use malloc() and free(). 
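The template that follows enforces this contract with a static_assert; a standalone sketch of the same pattern, where ExampleParams is a made-up stand-in for a real params struct such as TfLiteSoftmaxParams:

#include <cstdlib>
#include <type_traits>

struct ExampleParams {  // hypothetical stand-in for a builtin params struct
  float beta;
};

template <typename T>
T* MallocPOD() {
  static_assert(std::is_pod<T>::value, "Builtin data structure must be POD.");
  return static_cast<T*>(malloc(sizeof(T)));
}

int main() {
  ExampleParams* params = MallocPOD<ExampleParams>();
  params->beta = 1.0f;
  // In the real flow the interpreter node takes ownership and releases the
  // struct with free(), never delete, which is why it must stay POD.
  free(params);
  return 0;
}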
+template +T* MallocPOD() { + static_assert(std::is_pod::value, "Builtin data structure must be POD."); + return static_cast(malloc(sizeof(T))); +} + +// Parse the appropriate data out of the op. +// +// This handles builtin data explicitly as there are flatbuffer schemas. +// +// Returns memory that must be feed. +void* ParseOpData(const Operator* op, BuiltinOperator op_type, + ErrorReporter* error_reporter) { + auto parse_padding = [](Padding padding) { + switch (padding) { + case Padding_SAME: + return kTfLitePaddingSame; + case Padding_VALID: + return kTfLitePaddingValid; + } + return kTfLitePaddingUnknown; + }; + auto parse_activation = [](ActivationFunctionType activation) { + switch (activation) { + case ActivationFunctionType_NONE: + return kTfLiteActNone; + case ActivationFunctionType_RELU: + return kTfLiteActRelu; + case ActivationFunctionType_RELU1: + return kTfLiteActRelu1; + case ActivationFunctionType_RELU6: + return kTfLiteActRelu6; + case ActivationFunctionType_TANH: + return kTfLiteActTanh; + case ActivationFunctionType_SIGN_BIT: + return kTfLiteActSignBit; + } + return kTfLiteActNone; + }; + auto parseLSHProjectionType = [](LSHProjectionType type) { + switch (type) { + case LSHProjectionType_SPARSE: + return kTfLiteLshProjectionSparse; + case LSHProjectionType_DENSE: + return kTfLiteLshProjectionDense; + default: + return kTfLiteLshProjectionUnknown; + } + }; + auto parseCombinerType = [](CombinerType type) { + switch (type) { + case CombinerType_MEAN: + return kTfLiteCombinerTypeMean; + case CombinerType_SQRTN: + return kTfLiteCombinerTypeSqrtn; + case CombinerType_SUM: + default: + return kTfLiteCombinerTypeSum; + } + }; + + void* builtin_data = nullptr; + switch (op_type) { + case BuiltinOperator_CALL: + // TODO(aselle): Implement call in BuiltinOptions, but nullptrs are + // ok for now, since there is no call implementation either. 
+ break; + case BuiltinOperator_CUSTOM: + break; + case BuiltinOperator_CONV_2D: { + TfLiteConvParams* params = MallocPOD(); + if (auto* conv_params = op->builtin_options_as_Conv2DOptions()) { + params->padding = parse_padding(conv_params->padding()); + params->stride_width = conv_params->stride_w(); + params->stride_height = conv_params->stride_h(); + params->activation = + parse_activation(conv_params->fused_activation_function()); + } + builtin_data = reinterpret_cast(params); + break; + } + case BuiltinOperator_TANH: + case BuiltinOperator_LOGISTIC: + case BuiltinOperator_RELU: + case BuiltinOperator_RELU1: + case BuiltinOperator_RELU6: + case BuiltinOperator_CONCAT_EMBEDDINGS: + break; + case BuiltinOperator_LSH_PROJECTION: { + TfLiteLSHProjectionParams* params = + MallocPOD(); + if (auto* lshParams = op->builtin_options_as_LSHProjectionOptions()) { + params->type = parseLSHProjectionType(lshParams->type()); + } + builtin_data = reinterpret_cast(params); + break; + } + case BuiltinOperator_AVERAGE_POOL_2D: + case BuiltinOperator_MAX_POOL_2D: + case BuiltinOperator_L2_POOL_2D: { + TfLitePoolParams* params = MallocPOD(); + if (auto* pool_params = op->builtin_options_as_Pool2DOptions()) { + params->padding = parse_padding(pool_params->padding()); + params->stride_width = pool_params->stride_w(); + params->stride_height = pool_params->stride_h(); + params->filter_width = pool_params->filter_width(); + params->filter_height = pool_params->filter_height(); + params->activation = + parse_activation(pool_params->fused_activation_function()); + } + builtin_data = reinterpret_cast(params); + break; + } + case BuiltinOperator_DEPTHWISE_CONV_2D: { + TfLiteDepthwiseConvParams* params = + MallocPOD(); + if (auto* conv_params = op->builtin_options_as_DepthwiseConv2DOptions()) { + params->padding = parse_padding(conv_params->padding()); + params->stride_width = conv_params->stride_w(); + params->stride_height = conv_params->stride_h(); + params->depth_multiplier = conv_params->depth_multiplier(); + params->activation = + parse_activation(conv_params->fused_activation_function()); + } + builtin_data = reinterpret_cast(params); + break; + } + case BuiltinOperator_SVDF: { + TfLiteSVDFParams* params = MallocPOD(); + if (auto* svdf_params = op->builtin_options_as_SVDFOptions()) { + params->rank = svdf_params->rank(); + params->activation = + parse_activation(svdf_params->fused_activation_function()); + } + builtin_data = reinterpret_cast(params); + break; + } + case BuiltinOperator_RNN: { + TfLiteRNNParams* params = MallocPOD(); + if (auto* rnn_params = op->builtin_options_as_RNNOptions()) { + params->activation = + parse_activation(rnn_params->fused_activation_function()); + } + builtin_data = reinterpret_cast(params); + break; + } + case BuiltinOperator_EMBEDDING_LOOKUP: + // no-op. + break; + case BuiltinOperator_EMBEDDING_LOOKUP_SPARSE: { + TfLiteEmbeddingLookupSparseParams* params = + MallocPOD(); + if (auto* embedding_params = + op->builtin_options_as_EmbeddingLookupSparseOptions()) { + params->combiner = parseCombinerType(embedding_params->combiner()); + } + builtin_data = reinterpret_cast(params); + break; + } + case BuiltinOperator_FULLY_CONNECTED: { + TfLiteFullyConnectedParams* params = + MallocPOD(); + if (auto* fully_connected_params = + op->builtin_options_as_FullyConnectedOptions()) { + params->activation = parse_activation( + fully_connected_params->fused_activation_function()); + } + builtin_data = reinterpret_cast(params); + break; + } + case BuiltinOperator_HASHTABLE_LOOKUP: + // no-op. 
+ break; + case BuiltinOperator_SOFTMAX: { + TfLiteSoftmaxParams* params = MallocPOD(); + if (auto* softmax_params = op->builtin_options_as_SoftmaxOptions()) { + params->beta = softmax_params->beta(); + } + builtin_data = reinterpret_cast(params); + break; + } + case BuiltinOperator_CONCATENATION: { + TfLiteConcatenationParams* params = + MallocPOD(); + if (auto* concatenation_params = + op->builtin_options_as_ConcatenationOptions()) { + params->activation = + parse_activation(concatenation_params->fused_activation_function()); + params->axis = concatenation_params->axis(); + } + builtin_data = reinterpret_cast(params); + break; + } + case BuiltinOperator_MUL: { + auto* params = MallocPOD(); + if (auto* schema_params = op->builtin_options_as_MulOptions()) { + params->activation = + parse_activation(schema_params->fused_activation_function()); + } + builtin_data = reinterpret_cast(params); + break; + } + case BuiltinOperator_ADD: { + auto* params = MallocPOD(); + if (auto* schema_params = op->builtin_options_as_AddOptions()) { + params->activation = + parse_activation(schema_params->fused_activation_function()); + } + builtin_data = reinterpret_cast(params); + break; + } + case BuiltinOperator_L2_NORMALIZATION: { + auto* params = MallocPOD(); + if (auto* schema_params = op->builtin_options_as_L2NormOptions()) { + params->activation = + parse_activation(schema_params->fused_activation_function()); + } + builtin_data = reinterpret_cast(params); + break; + } + case BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION: { + auto* params = MallocPOD(); + if (auto* schema_params = + op->builtin_options_as_LocalResponseNormalizationOptions()) { + params->radius = schema_params->radius(); + params->bias = schema_params->bias(); + params->alpha = schema_params->alpha(); + params->beta = schema_params->beta(); + } + builtin_data = reinterpret_cast(params); + break; + } + case BuiltinOperator_LSTM: { + TfLiteLSTMParams* params = MallocPOD(); + if (auto* lstm_params = op->builtin_options_as_LSTMOptions()) { + params->activation = + parse_activation(lstm_params->fused_activation_function()); + params->cell_clip = lstm_params->cell_clip(); + params->proj_clip = lstm_params->proj_clip(); + } + builtin_data = reinterpret_cast(params); + break; + } + case BuiltinOperator_RESIZE_BILINEAR: { + auto* params = MallocPOD(); + if (auto* schema_params = + op->builtin_options_as_ResizeBilinearOptions()) { + params->new_height = schema_params->new_height(); + params->new_width = schema_params->new_width(); + } + builtin_data = reinterpret_cast(params); + break; + } + case BuiltinOperator_RESHAPE: { + auto* params = MallocPOD(); + if (auto* schema_params = op->builtin_options_as_ReshapeOptions()) { + auto* new_shape = schema_params->new_shape(); + if (!new_shape) { + error_reporter->Report("No new_shape provided for Reshape\n"); + } else { + params->num_dimensions = new_shape->Length(); + if (params->num_dimensions > sizeof(params->shape) / sizeof(int)) { + error_reporter->Report( + "Found too many dimensions in Reshape's new_shape\n"); + } else { + for (int i = 0; i < params->num_dimensions; ++i) { + params->shape[i] = new_shape->Get(i); + } + } + } + } + builtin_data = reinterpret_cast(params); + break; + } + case BuiltinOperator_SKIP_GRAM: { + TfLiteSkipGramParams* params = MallocPOD(); + if (auto* skip_gram_params = op->builtin_options_as_SkipGramOptions()) { + params->ngram_size = skip_gram_params->ngram_size(); + params->max_skip_size = skip_gram_params->max_skip_size(); + params->include_all_ngrams = 
skip_gram_params->include_all_ngrams(); + } + builtin_data = reinterpret_cast(params); + break; + } + case BuiltinOperator_SPACE_TO_DEPTH: { + auto* params = MallocPOD(); + if (auto* schema_params = op->builtin_options_as_SpaceToDepthOptions()) { + params->block_size = schema_params->block_size(); + } + builtin_data = reinterpret_cast(params); + break; + } + } + return builtin_data; +} + +} // namespace + +TfLiteStatus InterpreterBuilder::ParseNodes( + const flatbuffers::Vector>* operators, + Interpreter* interpreter) { + TfLiteStatus status = kTfLiteOk; + for (int i = 0; i < operators->Length(); ++i) { + const auto* op = operators->Get(i); + int index = op->opcode_index(); + if (index < 0 || index >= flatbuffer_op_index_to_registration_.size()) { + error_reporter_->Report("Missing registration for opcode_index %d\n", + index); + status = kTfLiteError; + continue; + } + const TfLiteRegistration* reg = + flatbuffer_op_index_to_registration_[op->opcode_index()]; + if (reg == nullptr) { + error_reporter_->Report("Skipping op for opcode_index %d\n", index); + status = kTfLiteError; + continue; + } + + auto op_type = + flatbuffer_op_index_to_registration_types_[op->opcode_index()]; + if (op_type != BuiltinOperator_CUSTOM && op->custom_options()) { + error_reporter_->Report( + "Found builtin operator %s with custom options.\n", + EnumNameBuiltinOperator(op_type)); + } + if (op->custom_options()) { + interpreter->AddNodeWithParameters( + FlatBufferIntArrayToVector(op->inputs()), + FlatBufferIntArrayToVector(op->outputs()), + reinterpret_cast(op->custom_options()->data()), + op->custom_options()->size(), nullptr, reg); + } else { + interpreter->AddNodeWithParameters( + FlatBufferIntArrayToVector(op->inputs()), + FlatBufferIntArrayToVector(op->outputs()), nullptr, 0, + ParseOpData(op, op_type, error_reporter_), reg); + } + } + + return status; +} + +TfLiteStatus InterpreterBuilder::ParseTensors( + const flatbuffers::Vector>* buffers, + const flatbuffers::Vector>* tensors, + Interpreter* interpreter) { + TfLiteStatus status = kTfLiteOk; + + // A little helper to get the names of inputs and outputs. Note that they + // must outlive the interpreter. + auto get_name = [](const tflite::Tensor* t) -> const char* { + auto name = t->name(); + if (name) return name->c_str(); + return kEmptyTensorName; + }; + + for (int i = 0; i < tensors->Length(); ++i) { + const auto* tensor = tensors->Get(i); + std::vector dims = FlatBufferIntArrayToVector(tensor->shape()); + + TfLiteQuantizationParams quantization; + quantization.scale = 0; + quantization.zero_point = 0; + auto* q_params = tensor->quantization(); + if (q_params) { + // Note that the schema could hold per-channel quantization parameters + // but we really only support one value for the whole tensor. + // TODO(aselle): This breaks as well if these are nullptr's. + // TODO(aselle): This assumes non per-channel quantization. 
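In other words, each tensor carries a single affine mapping, real_value = scale * (quantized_value - zero_point), the same formula the Dequantize helper in test_util.h applies; as the lines below show, only element 0 of the schema's per-channel arrays is read.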
+ if (q_params->scale()) quantization.scale = q_params->scale()->Get(0); + if (q_params->zero_point()) + quantization.zero_point = q_params->zero_point()->Get(0); + } + + TfLiteType type; + switch (tensor->type()) { + case TensorType_FLOAT32: + type = kTfLiteFloat32; + break; + case TensorType_INT32: + type = kTfLiteInt32; + break; + case TensorType_UINT8: + type = kTfLiteUInt8; + break; + case TensorType_INT64: + type = kTfLiteInt64; + break; + case TensorType_STRING: + type = kTfLiteString; + break; + default: + // tensorType = ArrayType::NONE; + error_reporter_->Report("Unimplemented data type %s (%d) in tensor\n", + EnumNameTensorType(tensor->type()), + tensor->type()); + status = kTfLiteError; + continue; + } + auto get_readonly_data = [&](const char** buffer_data, + size_t* buffer_size) { + // TODO(aselle): Check what happens if we have an unspecified size + // constant. + *buffer_data = nullptr; + if (tensor->buffer() == 0) return kTfLiteOk; + if (tensor->buffer() >= buffers->size()) { + error_reporter_->Report( + "Tensor %d specifies out of range buffer %d (only %d buffers).\n", + i, tensor->buffer(), buffers->size()); + return kTfLiteError; + } + if (auto* buffer = (*buffers)[tensor->buffer()]) { + if (auto* array = buffer->data()) { + if (size_t size = array->size()) { + *buffer_size = size; + *buffer_data = reinterpret_cast(array->data()); + return kTfLiteOk; + } + } + } + return kTfLiteOk; + }; + size_t buffer_size = 0; + const char* buffer_ptr; + TF_LITE_ENSURE_STATUS(get_readonly_data(&buffer_ptr, &buffer_size)); + + if (buffer_ptr) { + if (interpreter->SetTensorParametersReadOnly( + i, type, get_name(tensor), dims, quantization, buffer_ptr, + buffer_size, allocation_) != kTfLiteOk) { + error_reporter_->Report("Tensor %d is invalidly specified in schema.\n", + i); + status = kTfLiteError; + } + } else { + if (interpreter->SetTensorParametersReadWrite( + i, type, get_name(tensor), dims, quantization) != kTfLiteOk) { + error_reporter_->Report("Tensor %d is invalidly specified in schema.\n", + i); + status = kTfLiteError; + } + } + } + + return status; +} + +TfLiteStatus InterpreterBuilder::operator()( + std::unique_ptr* interpreter) { + if (!interpreter) { + error_reporter_->Report( + "Null output pointer passed to InterpreterBuilder."); + return kTfLiteError; + } + + // Safe exit by deleting partially created interpreter, to reduce verbosity + // on error conditions. Use by return cleanup_on_error(); + auto cleanup_and_error = [&interpreter]() { + interpreter->reset(); + return kTfLiteError; + }; + + if (!model_) { + error_reporter_->Report("Null pointer passed in as model."); + return cleanup_and_error(); + } + + if (model_->version() != TFLITE_SCHEMA_VERSION) { + error_reporter_->Report( + "Model provided is schema version %d not equal " + "to supported version %d.\n", + model_->version(), TFLITE_SCHEMA_VERSION); + return cleanup_and_error(); + } + + if (BuildLocalIndexToRegistrationMapping() != kTfLiteOk) { + error_reporter_->Report("Registration failed.\n"); + return cleanup_and_error(); + } + + // Flatbuffer model schemas define a list of opcodes independent of the graph. + // We first map those to registrations. This reduces string lookups for custom + // ops since we only do it once per custom op rather than once per custom op + // invocation in the model graph. + // Construct interpreter with correct number of tensors and operators. 
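For reference, the tensor registration performed by ParseTensors above reduces to Interpreter calls of the following shape. This is a minimal sketch with made-up values for the tensor index, name, and dims:

#include "tensorflow/contrib/lite/interpreter.h"

void RegisterOneFloatTensor(tflite::Interpreter* interpreter) {
  interpreter->AddTensors(1);
  TfLiteQuantizationParams quant;
  quant.scale = 0.0f;  // float tensors carry no quantization here
  quant.zero_point = 0;
  // Read-write tensors live in the interpreter's arena; read-only tensors
  // would instead go through SetTensorParametersReadOnly with a backing
  // buffer (typically the mmapped model file).
  interpreter->SetTensorParametersReadWrite(/*tensor_index=*/0, kTfLiteFloat32,
                                            "input", {1, 4}, quant);
}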
+ auto* subgraphs = model_->subgraphs(); + auto* buffers = model_->buffers(); + if (subgraphs->size() != 1) { + error_reporter_->Report("Only 1 subgraph is currently supported.\n"); + return cleanup_and_error(); + } + const tflite::SubGraph* subgraph = (*subgraphs)[0]; + auto operators = subgraph->operators(); + auto tensors = subgraph->tensors(); + if (!operators || !tensors || !buffers) { + error_reporter_->Report( + "Did not get operators, tensors, or buffers in input flat buffer.\n"); + return cleanup_and_error(); + } + interpreter->reset(new Interpreter(error_reporter_)); + if ((**interpreter).AddTensors(tensors->Length()) != kTfLiteOk) { + return cleanup_and_error(); + } + + // Parse inputs/outputs + (**interpreter).SetInputs(FlatBufferIntArrayToVector(subgraph->inputs())); + (**interpreter).SetOutputs(FlatBufferIntArrayToVector(subgraph->outputs())); + + // Finally setup nodes and tensors + if (ParseNodes(operators, interpreter->get()) != kTfLiteOk) + return cleanup_and_error(); + if (ParseTensors(buffers, tensors, interpreter->get()) != kTfLiteOk) + return cleanup_and_error(); + + return kTfLiteOk; +} + +} // namespace tflite diff --git a/tensorflow/contrib/lite/model.h b/tensorflow/contrib/lite/model.h new file mode 100644 index 0000000000..15659d33f3 --- /dev/null +++ b/tensorflow/contrib/lite/model.h @@ -0,0 +1,165 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Deserialization infrastructure for tflite. Provides functionality +// to go from a serialized tflite model in flatbuffer format to an +// interpreter. +// +// using namespace tflite; +// StderrReporter error_reporter; +// auto model = FlatBufferModel::BuildFromFile("interesting_model.tflite", +// &error_reporter); +// MyOpResolver resolver; // You need to subclass OpResolver to provide +// // implementations. +// InterpreterBuilder builder(*model, resolver); +// std::unique_ptr interpreter; +// if(builder(&interpreter) == kTfLiteOk) { +// .. run model inference with interpreter +// } +// +// OpResolver must be defined to provide your kernel implementations to the +// interpreter. This is environment specific and may consist of just the builtin +// ops, or some custom operators you defined to extend tflite. +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_MODEL_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_MODEL_H_ + +#include +#include "tensorflow/contrib/lite/error_reporter.h" +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/schema/schema_generated.h" + +namespace tflite { + +// An RAII object that represents a read-only tflite model, copied from disk, +// or mmapped. This uses flatbuffers as the serialization format. +class FlatBufferModel { + public: + // Build a model based on a file. Return a nullptr in case of failure. 
+  static std::unique_ptr<FlatBufferModel> BuildFromFile(
+      const char* filename,
+      ErrorReporter* error_reporter = DefaultErrorReporter());
+
+  // Build a model based on a pre-loaded flatbuffer. The caller retains
+  // ownership of the buffer and should keep it alive until the returned object
+  // is destroyed. Return a nullptr in case of failure.
+  static std::unique_ptr<FlatBufferModel> BuildFromBuffer(
+      const char* buffer, size_t buffer_size,
+      ErrorReporter* error_reporter = DefaultErrorReporter());
+
+  // Releases memory or unmaps mmapped memory.
+  ~FlatBufferModel();
+
+  // Copying or assignment is disallowed to simplify ownership semantics.
+  FlatBufferModel(const FlatBufferModel&) = delete;
+  FlatBufferModel& operator=(const FlatBufferModel&) = delete;
+
+  bool initialized() const { return model_ != nullptr; }
+  const tflite::Model* operator->() const { return model_; }
+  const tflite::Model* GetModel() const { return model_; }
+  ErrorReporter* error_reporter() const { return error_reporter_; }
+  const Allocation* allocation() const { return allocation_; }
+
+  // Returns true if the model identifier is correct (otherwise false and
+  // reports an error).
+  bool CheckModelIdentifier() const;
+
+ private:
+  // Load a model from `filename`. If `mmap_file` is true then use mmap,
+  // otherwise make a copy of the model in a buffer.
+  //
+  // Note, if `error_reporter` is null, then a DefaultErrorReporter() will be
+  // used.
+  explicit FlatBufferModel(
+      const char* filename, bool mmap_file = true,
+      ErrorReporter* error_reporter = DefaultErrorReporter(),
+      bool use_nnapi = false);
+
+  // Load a model from `ptr` and `num_bytes` of the model file. The `ptr` has
+  // to remain alive and unchanged until the end of this FlatBufferModel's
+  // lifetime.
+  //
+  // Note, if `error_reporter` is null, then a DefaultErrorReporter() will be
+  // used.
+  FlatBufferModel(const char* ptr, size_t num_bytes,
+                  ErrorReporter* error_reporter = DefaultErrorReporter());
+
+  // Flatbuffer traverser pointer. (The Model* points into memory owned by the
+  // allocation's internals.)
+  const tflite::Model* model_ = nullptr;
+  ErrorReporter* error_reporter_;
+  Allocation* allocation_ = nullptr;
+};
+
+// Abstract interface that returns TfLiteRegistrations given op codes or custom
+// op names. This is the mechanism by which ops referenced in the flatbuffer
+// model are mapped to executable function pointers (TfLiteRegistrations).
+class OpResolver {
+ public:
+  // Find the op registration for a builtin operator by enum code.
+  virtual TfLiteRegistration* FindOp(tflite::BuiltinOperator op) const = 0;
+  // Find the op registration of a custom operator by op name.
+  virtual TfLiteRegistration* FindOp(const char* op) const = 0;
+  virtual ~OpResolver() {}
+};
+
+// Build an interpreter capable of interpreting `model`.
+//
+// model: a scoped model whose lifetime must be at least as long as
+//   the interpreter. In principle multiple interpreters can be made from
+//   a single model.
+// op_resolver: An instance that implements the OpResolver interface, which
+//   maps custom op names and builtin op codes to op registrations.
+// reportError: a functor that is called to report errors and that handles
+//   printf-style varargs semantics. The lifetime of the reportError object
+//   must be greater than or equal to that of the Interpreter created by
+//   operator().
+//
+// Returns kTfLiteOk when successful and sets interpreter to a valid
+// Interpreter.
Note: the user must ensure the model lifetime is at least as +// long as interpreter's lifetime. +class InterpreterBuilder { + public: + InterpreterBuilder(const FlatBufferModel& model, + const OpResolver& op_resolver); + // Build an interpreter given only the raw flatbuffer Model object (instead + // of a FlatBufferModel). Mostly used for testing. + // If `error_reporter` is null, then DefaultErrorReporter() is used. + InterpreterBuilder(const ::tflite::Model* model, + const OpResolver& op_resolver, + ErrorReporter* error_reporter = DefaultErrorReporter()); + InterpreterBuilder(const InterpreterBuilder&) = delete; + InterpreterBuilder& operator=(const InterpreterBuilder&) = delete; + TfLiteStatus operator()(std::unique_ptr* interpreter); + + private: + TfLiteStatus BuildLocalIndexToRegistrationMapping(); + TfLiteStatus ParseNodes( + const flatbuffers::Vector>* operators, + Interpreter* interpreter); + TfLiteStatus ParseTensors( + const flatbuffers::Vector>* buffers, + const flatbuffers::Vector>* tensors, + Interpreter* interpreter); + + const ::tflite::Model* model_; + const OpResolver& op_resolver_; + ErrorReporter* error_reporter_; + + std::vector flatbuffer_op_index_to_registration_; + std::vector flatbuffer_op_index_to_registration_types_; + const Allocation* allocation_ = nullptr; +}; + +} // namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_MODEL_H_ diff --git a/tensorflow/contrib/lite/model_test.cc b/tensorflow/contrib/lite/model_test.cc new file mode 100644 index 0000000000..ae823650d6 --- /dev/null +++ b/tensorflow/contrib/lite/model_test.cc @@ -0,0 +1,258 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/model.h" + +#include +#include "tensorflow/contrib/lite/error_reporter.h" + +// Comparison for TfLiteRegistration. Since TfLiteRegistration is a C object, +// we must declare this in global namespace, so argument-dependent operator +// lookup works. +inline bool operator==(const TfLiteRegistration& a, + const TfLiteRegistration& b) { + return a.invoke == b.invoke && a.init == b.init && a.prepare == b.prepare && + a.free == b.free; +} + +namespace tflite { + +// Provide a dummy operation that does nothing. +namespace { +void* dummy_init(TfLiteContext*, const char*, size_t) { return nullptr; } +void dummy_free(TfLiteContext*, void*) {} +TfLiteStatus dummy_resize(TfLiteContext*, TfLiteNode*) { return kTfLiteOk; } +TfLiteStatus dummy_invoke(TfLiteContext*, TfLiteNode*) { return kTfLiteOk; } +TfLiteRegistration dummy_reg = {dummy_init, dummy_free, dummy_resize, + dummy_invoke}; +} // namespace + +// Provide a trivial resolver that returns a constant value no matter what +// op is asked for. 
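Taken together, the classes declared in model.h above are typically driven end to end as in the following minimal sketch; the file path is hypothetical and the builtin resolver from kernels/register.h is assumed:

#include <memory>

#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/model.h"

void RunModelOnce() {
  auto model = tflite::FlatBufferModel::BuildFromFile("/tmp/model.tflite");
  if (!model) return;  // BuildFromFile returns nullptr on failure.
  tflite::ops::builtin::BuiltinOpResolver resolver;
  std::unique_ptr<tflite::Interpreter> interpreter;
  if (tflite::InterpreterBuilder(*model, resolver)(&interpreter) != kTfLiteOk) {
    return;  // On failure the builder resets `interpreter`.
  }
  interpreter->AllocateTensors();
  // Fill the first input, run the graph, then read the first output.
  float* in = interpreter->tensor(interpreter->inputs()[0])->data.f;
  in[0] = 1.0f;
  interpreter->Invoke();
  float* out = interpreter->tensor(interpreter->outputs()[0])->data.f;
  (void)out;
}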
+class TrivialResolver : public OpResolver { + public: + explicit TrivialResolver(TfLiteRegistration* constant_return = nullptr) + : constant_return_(constant_return) {} + // Find the op registration of a custom operator by op name. + TfLiteRegistration* FindOp(tflite::BuiltinOperator op) const override { + return constant_return_; + } + // Find the op registration of a custom operator by op name. + TfLiteRegistration* FindOp(const char* op) const override { + return constant_return_; + } + + private: + TfLiteRegistration* constant_return_; +}; + +TEST(BasicFlatBufferModel, TestNonExistantFiles) { + ASSERT_TRUE(!FlatBufferModel::BuildFromFile("/tmp/tflite_model_1234")); +} + +// Make sure a model with nothing in it loads properly. +TEST(BasicFlatBufferModel, TestEmptyModelsAndNullDestination) { + auto model = FlatBufferModel::BuildFromFile( + "tensorflow/contrib/lite/testdata/empty_model.bin"); + ASSERT_TRUE(model); + // Now try to build it into a model. + std::unique_ptr interpreter; + ASSERT_EQ(InterpreterBuilder(*model, TrivialResolver())(&interpreter), + kTfLiteOk); + ASSERT_NE(interpreter, nullptr); + ASSERT_NE(InterpreterBuilder(*model, TrivialResolver())(nullptr), kTfLiteOk); +} + +// Make sure currently unsupported # of subgraphs are checked +// TODO(aselle): Replace this test when multiple subgraphs are supported. +TEST(BasicFlatBufferModel, TestZeroAndMultipleSubgraphs) { + auto m1 = FlatBufferModel::BuildFromFile( + "tensorflow/contrib/lite/testdata/0_subgraphs.bin"); + ASSERT_TRUE(m1); + std::unique_ptr interpreter1; + ASSERT_NE(InterpreterBuilder(*m1, TrivialResolver())(&interpreter1), + kTfLiteOk); + + auto m2 = FlatBufferModel::BuildFromFile( + "tensorflow/contrib/lite/testdata/2_subgraphs.bin"); + ASSERT_TRUE(m2); + std::unique_ptr interpreter2; + ASSERT_NE(InterpreterBuilder(*m2, TrivialResolver())(&interpreter2), + kTfLiteOk); +} + +// Test what happens if we cannot bind any of the ops. +TEST(BasicFlatBufferModel, TestModelWithoutNullRegistrations) { + auto model = FlatBufferModel::BuildFromFile( + "tensorflow/contrib/lite/testdata/test_model.bin"); + ASSERT_TRUE(model); + // Check that we get an error code and interpreter pointer is reset. + std::unique_ptr interpreter(new Interpreter); + ASSERT_NE(InterpreterBuilder(*model, TrivialResolver(nullptr))(&interpreter), + kTfLiteOk); + ASSERT_EQ(interpreter, nullptr); +} + +// Make sure model is read to interpreter propelrly +TEST(BasicFlatBufferModel, TestModelInInterpreter) { + auto model = FlatBufferModel::BuildFromFile( + "tensorflow/contrib/lite/testdata/test_model.bin"); + ASSERT_TRUE(model); + // Check that we get an error code and interpreter pointer is reset. 
+ std::unique_ptr interpreter(new Interpreter); + ASSERT_EQ( + InterpreterBuilder(*model, TrivialResolver(&dummy_reg))(&interpreter), + kTfLiteOk); + ASSERT_NE(interpreter, nullptr); + ASSERT_EQ(interpreter->tensors_size(), 4); + ASSERT_EQ(interpreter->nodes_size(), 2); + std::vector inputs = {0, 1}; + std::vector outputs = {2, 3}; + ASSERT_EQ(interpreter->inputs(), inputs); + ASSERT_EQ(interpreter->outputs(), outputs); + + EXPECT_EQ(std::string(interpreter->GetInputName(0)), "input0"); + EXPECT_EQ(std::string(interpreter->GetInputName(1)), "input1"); + EXPECT_EQ(std::string(interpreter->GetOutputName(0)), "out1"); + EXPECT_EQ(std::string(interpreter->GetOutputName(1)), "out2"); + + // Make sure all input tensors are correct + TfLiteTensor* i0 = interpreter->tensor(0); + ASSERT_EQ(i0->type, kTfLiteFloat32); + ASSERT_NE(i0->data.raw, nullptr); // mmapped + ASSERT_EQ(i0->allocation_type, kTfLiteMmapRo); + TfLiteTensor* i1 = interpreter->tensor(1); + ASSERT_EQ(i1->type, kTfLiteFloat32); + ASSERT_EQ(i1->data.raw, nullptr); + ASSERT_EQ(i1->allocation_type, kTfLiteArenaRw); + TfLiteTensor* o0 = interpreter->tensor(2); + ASSERT_EQ(o0->type, kTfLiteFloat32); + ASSERT_EQ(o0->data.raw, nullptr); + ASSERT_EQ(o0->allocation_type, kTfLiteArenaRw); + TfLiteTensor* o1 = interpreter->tensor(3); + ASSERT_EQ(o1->type, kTfLiteFloat32); + ASSERT_EQ(o1->data.raw, nullptr); + ASSERT_EQ(o1->allocation_type, kTfLiteArenaRw); + + // Check op 0 which has inputs {0, 1} outputs {2}. + { + const std::pair* node_and_reg0 = + interpreter->node_and_registration(0); + ASSERT_NE(node_and_reg0, nullptr); + const TfLiteNode& node0 = node_and_reg0->first; + const TfLiteRegistration& reg0 = node_and_reg0->second; + TfLiteIntArray* desired_inputs = TfLiteIntArrayCreate(2); + desired_inputs->data[0] = 0; + desired_inputs->data[1] = 1; + TfLiteIntArray* desired_outputs = TfLiteIntArrayCreate(1); + desired_outputs->data[0] = 2; + ASSERT_TRUE(TfLiteIntArrayEqual(node0.inputs, desired_inputs)); + ASSERT_TRUE(TfLiteIntArrayEqual(node0.outputs, desired_outputs)); + TfLiteIntArrayFree(desired_inputs); + TfLiteIntArrayFree(desired_outputs); + ASSERT_EQ(reg0, dummy_reg); + } + + // Check op 1 which has inputs {2} outputs {3}. + { + const std::pair* node_and_reg1 = + interpreter->node_and_registration(1); + ASSERT_NE(node_and_reg1, nullptr); + const TfLiteNode& node1 = node_and_reg1->first; + const TfLiteRegistration& reg1 = node_and_reg1->second; + TfLiteIntArray* desired_inputs = TfLiteIntArrayCreate(1); + TfLiteIntArray* desired_outputs = TfLiteIntArrayCreate(1); + desired_inputs->data[0] = 2; + desired_outputs->data[0] = 3; + ASSERT_TRUE(TfLiteIntArrayEqual(node1.inputs, desired_inputs)); + ASSERT_TRUE(TfLiteIntArrayEqual(node1.outputs, desired_outputs)); + TfLiteIntArrayFree(desired_inputs); + TfLiteIntArrayFree(desired_outputs); + ASSERT_EQ(reg1, dummy_reg); + } +} + +// This tests on a flatbuffer that defines a shape of 2 to be a memory mapped +// buffer. But the buffer is provided to be only 1 element. +TEST(BasicFlatBufferModel, TestBrokenMmap) { + ASSERT_FALSE(FlatBufferModel::BuildFromFile( + "tensorflow/contrib/lite/testdata/test_model_broken.bin")); +} + +TEST(BasicFlatBufferModel, TestNullModel) { + // Check that we get an error code and interpreter pointer is reset. 
+ std::unique_ptr interpreter(new Interpreter); + ASSERT_NE( + InterpreterBuilder(nullptr, TrivialResolver(&dummy_reg))(&interpreter), + kTfLiteOk); + ASSERT_EQ(interpreter.get(), nullptr); +} + +struct TestErrorReporter : public ErrorReporter { + int Report(const char* format, va_list args) override { + calls++; + return 0; + } + int calls = 0; +}; + +// This makes sure the ErrorReporter is marshalled from FlatBufferModel to +// the Interpreter. +TEST(BasicFlatBufferModel, TestCustomErrorReporter) { + TestErrorReporter reporter; + auto model = FlatBufferModel::BuildFromFile( + "tensorflow/contrib/lite/testdata/empty_model.bin", + &reporter); + ASSERT_TRUE(model); + + std::unique_ptr interpreter; + TrivialResolver resolver; + InterpreterBuilder(*model, resolver)(&interpreter); + ASSERT_NE(interpreter->Invoke(), kTfLiteOk); + ASSERT_EQ(reporter.calls, 1); +} + +// This makes sure the ErrorReporter is marshalled from FlatBufferModel to +// the Interpreter. +TEST(BasicFlatBufferModel, TestNullErrorReporter) { + auto model = FlatBufferModel::BuildFromFile( + "tensorflow/contrib/lite/testdata/empty_model.bin", nullptr); + ASSERT_TRUE(model); + + std::unique_ptr interpreter; + TrivialResolver resolver; + InterpreterBuilder(*model, resolver)(&interpreter); + ASSERT_NE(interpreter->Invoke(), kTfLiteOk); +} + +// TODO(aselle): Add tests for serialization of builtin op data types. +// These tests will occur with the evaluation tests of individual operators, +// not here. + +} // namespace tflite + +int main(int argc, char** argv) { + // On Linux, add: tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/models/smartreply/BUILD b/tensorflow/contrib/lite/models/smartreply/BUILD new file mode 100644 index 0000000000..fbdf19f205 --- /dev/null +++ b/tensorflow/contrib/lite/models/smartreply/BUILD @@ -0,0 +1,15 @@ +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Apache 2.0 + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/contrib/lite/models/smartreply/ops/extract_feature.cc b/tensorflow/contrib/lite/models/smartreply/ops/extract_feature.cc new file mode 100644 index 0000000000..1c422b659a --- /dev/null +++ b/tensorflow/contrib/lite/models/smartreply/ops/extract_feature.cc @@ -0,0 +1,119 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Convert a list of strings to integers via hashing. +// Input: +// Input[0]: A list of ngrams. string[num of input] +// +// Output: +// Output[0]: Hashed features. int32[num of input] +// Output[1]: Weights. 
float[num of input] + +#include +#include +#include "re2/re2.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/string_util.h" +#include + +namespace tflite { +namespace ops { +namespace custom { + +namespace extract { + +static const int kMaxDimension = 1000000; +static const std::vector kBlacklistNgram = {"", "", " "}; + +bool Equals(const string& x, const tflite::StringRef& strref) { + if (strref.len != x.length()) { + return false; + } + if (strref.len > 0) { + int r = memcmp(strref.str, x.data(), strref.len); + return r == 0; + } + return true; +} + +bool IsValidNgram(const tflite::StringRef& strref) { + for (const auto& s : kBlacklistNgram) { + if (Equals(s, strref)) { + return false; + } + } + return true; +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TfLiteIntArray* outputSize1 = TfLiteIntArrayCreate(1); + TfLiteIntArray* outputSize2 = TfLiteIntArrayCreate(1); + TfLiteTensor* input = GetInput(context, node, 0); + int dim = input->dims->data[0]; + if (dim == 0) { + // TFLite non-string output should have size greater than 0. + dim = 1; + } + TF_LITE_ENSURE_EQ(context, input->type, kTfLiteString); + outputSize1->data[0] = dim; + outputSize2->data[0] = dim; + context->ResizeTensor(context, GetOutput(context, node, 0), outputSize1); + context->ResizeTensor(context, GetOutput(context, node, 1), outputSize2); + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor* input = GetInput(context, node, 0); + int num_strings = tflite::GetStringCount(input); + TfLiteTensor* label = GetOutput(context, node, 0); + TfLiteTensor* weight = GetOutput(context, node, 1); + + std::map feature_id_counts; + for (int i = 0; i < num_strings; i++) { + // Use fingerprint of feature name as id. + auto strref = tflite::GetString(input, i); + if (!IsValidNgram(strref)) { + label->data.i32[i] = 0; + weight->data.i32[i] = 0; + continue; + } + + int64 feature_id = + ::util::Fingerprint64(strref.str, strref.len) % kMaxDimension; + + label->data.i32[i] = static_cast(feature_id); + weight->data.f[i] = + std::count(strref.str, strref.str + strref.len, ' ') + 1; + } + // Explicitly set an empty result to make preceding ops run. + if (num_strings == 0) { + label->data.i32[0] = 0; + weight->data.i32[0] = 0; + } + return kTfLiteOk; +} + +} // namespace extract + +TfLiteRegistration* Register_EXTRACT_FEATURES() { + static TfLiteRegistration r = {nullptr, nullptr, extract::Prepare, + extract::Eval}; + return &r; +} + +} // namespace custom +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/models/smartreply/ops/extract_feature_test.cc b/tensorflow/contrib/lite/models/smartreply/ops/extract_feature_test.cc new file mode 100644 index 0000000000..9b8676bab6 --- /dev/null +++ b/tensorflow/contrib/lite/models/smartreply/ops/extract_feature_test.cc @@ -0,0 +1,100 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" +#include + +namespace tflite { + +namespace ops { +namespace custom { +TfLiteRegistration* Register_EXTRACT_FEATURES(); + +namespace { + +using ::testing::ElementsAre; + +class ExtractFeatureOpModel : public SingleOpModel { + public: + explicit ExtractFeatureOpModel(const std::vector& input) { + input_ = AddInput(TensorType_STRING); + signature_ = AddOutput(TensorType_INT32); + weight_ = AddOutput(TensorType_FLOAT32); + + SetCustomOp("ExtractFeatures", {}, Register_EXTRACT_FEATURES); + BuildInterpreter({{static_cast(input.size())}}); + PopulateStringTensor(input_, input); + } + + std::vector GetSignature() { return ExtractVector(signature_); } + std::vector GetWeight() { return ExtractVector(weight_); } + + private: + int input_; + int signature_; + int weight_; +}; + +int CalcFeature(const string& str) { + return ::util::Fingerprint64(str) % 1000000; +} + +TEST(ExtractFeatureOpTest, RegularInput) { + ExtractFeatureOpModel m({"", " Hi", "Hi", "Hi !", "!", "! ", ""}); + m.Invoke(); + EXPECT_THAT(m.GetSignature(), + ElementsAre(0, CalcFeature(" Hi"), CalcFeature("Hi"), + CalcFeature("Hi !"), CalcFeature("!"), + CalcFeature("! "), 0)); + EXPECT_THAT(m.GetWeight(), ElementsAre(0, 2, 1, 2, 1, 2, 0)); +} + +TEST(ExtractFeatureOpTest, OneInput) { + ExtractFeatureOpModel m({"Hi"}); + m.Invoke(); + EXPECT_THAT(m.GetSignature(), ElementsAre(CalcFeature("Hi"))); + EXPECT_THAT(m.GetWeight(), ElementsAre(1)); +} + +TEST(ExtractFeatureOpTest, ZeroInput) { + ExtractFeatureOpModel m({}); + m.Invoke(); + EXPECT_THAT(m.GetSignature(), ElementsAre(0)); + EXPECT_THAT(m.GetWeight(), ElementsAre(0)); +} + +TEST(ExtractFeatureOpTest, AllBlacklistInput) { + ExtractFeatureOpModel m({"", ""}); + m.Invoke(); + EXPECT_THAT(m.GetSignature(), ElementsAre(0, 0)); + EXPECT_THAT(m.GetWeight(), ElementsAre(0, 0)); +} + +} // namespace +} // namespace custom +} // namespace ops +} // namespace tflite + +int main(int argc, char** argv) { + // On Linux, add: tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/models/smartreply/ops/normalize.cc b/tensorflow/contrib/lite/models/smartreply/ops/normalize.cc new file mode 100644 index 0000000000..d0dc2a35a7 --- /dev/null +++ b/tensorflow/contrib/lite/models/smartreply/ops/normalize.cc @@ -0,0 +1,105 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Normalize the string input. +// +// Input: +// Input[0]: One sentence. 
string[1] +// +// Output: +// Output[0]: Normalized sentence. string[1] +// +#include "absl/strings/ascii.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/strip.h" +#include "re2/re2.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/string_util.h" + +namespace tflite { +namespace ops { +namespace custom { + +namespace normalize { + +// Predictor transforms. +const char kPunctuationsRegex[] = "[.*()\"]"; + +const std::map* kRegexTransforms = + new std::map({ + {"([^\\s]+)n't", "\\1 not"}, + {"([^\\s]+)'nt", "\\1 not"}, + {"([^\\s]+)'ll", "\\1 will"}, + {"([^\\s]+)'re", "\\1 are"}, + {"([^\\s]+)'ve", "\\1 have"}, + {"i'm", "i am"}, + }); + +static const char kStartToken[] = ""; +static const char kEndToken[] = ""; +static const int32 kMaxInputChars = 300; + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + tflite::StringRef input = tflite::GetString(GetInput(context, node, 0), 0); + + string result(absl::AsciiStrToLower(absl::string_view(input.str, input.len))); + absl::StripAsciiWhitespace(&result); + // Do not remove commas, semi-colons or colons from the sentences as they can + // indicate the beginning of a new clause. + RE2::GlobalReplace(&result, kPunctuationsRegex, ""); + RE2::GlobalReplace(&result, "\\s('t|'nt|n't|'d|'ll|'s|'m|'ve|'re)([\\s,;:/])", + "\\1\\2"); + RE2::GlobalReplace(&result, "\\s('t|'nt|n't|'d|'ll|'s|'m|'ve|'re)$", "\\1"); + for (auto iter = kRegexTransforms->begin(); iter != kRegexTransforms->end(); + iter++) { + RE2::GlobalReplace(&result, iter->first, iter->second); + } + + // Treat questions & interjections as special cases. + RE2::GlobalReplace(&result, "([?])+", "\\1"); + RE2::GlobalReplace(&result, "([!])+", "\\1"); + RE2::GlobalReplace(&result, "([^?!]+)([?!])", "\\1 \\2 "); + RE2::GlobalReplace(&result, "([?!])([?!])", "\\1 \\2"); + + RE2::GlobalReplace(&result, "[\\s,:;\\-&'\"]+$", ""); + RE2::GlobalReplace(&result, "^[\\s,:;\\-&'\"]+", ""); + absl::StripAsciiWhitespace(&result); + + // Add start and end token. + // Truncate input to maximum allowed size. + if (result.length() <= kMaxInputChars) { + absl::StrAppend(&result, " ", kEndToken); + } else { + result = result.substr(0, kMaxInputChars); + } + result = absl::StrCat(kStartToken, " ", result); + + tflite::DynamicBuffer buf; + buf.AddString(result.data(), result.length()); + buf.WriteToTensor(GetOutput(context, node, 0)); + return kTfLiteOk; +} + +} // namespace normalize + +TfLiteRegistration* Register_NORMALIZE() { + static TfLiteRegistration r = {nullptr, nullptr, nullptr, normalize::Eval}; + return &r; +} + +} // namespace custom +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/models/smartreply/ops/normalize_test.cc b/tensorflow/contrib/lite/models/smartreply/ops/normalize_test.cc new file mode 100644 index 0000000000..4d35dba9a6 --- /dev/null +++ b/tensorflow/contrib/lite/models/smartreply/ops/normalize_test.cc @@ -0,0 +1,90 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/contrib/lite/string_util.h" + +namespace tflite { + +namespace ops { +namespace custom { +TfLiteRegistration* Register_NORMALIZE(); + +namespace { + +using ::testing::ElementsAreArray; + +class NormalizeOpModel : public SingleOpModel { + public: + explicit NormalizeOpModel(const string& input) { + input_ = AddInput(TensorType_STRING); + output_ = AddOutput(TensorType_STRING); + + SetCustomOp("Normalize", {}, Register_NORMALIZE); + BuildInterpreter({{static_cast(input.size())}}); + PopulateStringTensor(input_, {input}); + } + + std::vector GetStringOutput() { + TfLiteTensor* output = interpreter_->tensor(output_); + int num = GetStringCount(output); + std::vector result(num); + for (int i = 0; i < num; i++) { + auto ref = GetString(output, i); + result[i] = string(ref.str, ref.len); + } + return result; + } + + private: + int input_; + int output_; +}; + +TEST(NormalizeOpTest, RegularInput) { + NormalizeOpModel m("I'm good; you're welcome"); + m.Invoke(); + EXPECT_THAT(m.GetStringOutput(), + ElementsAreArray({" i am good; you are welcome "})); +} + +TEST(NormalizeOpTest, OneInput) { + NormalizeOpModel m("Hi!!!!"); + m.Invoke(); + EXPECT_THAT(m.GetStringOutput(), ElementsAreArray({" hi ! "})); +} + +TEST(NormalizeOpTest, EmptyInput) { + NormalizeOpModel m(""); + m.Invoke(); + EXPECT_THAT(m.GetStringOutput(), ElementsAreArray({" "})); +} + +} // namespace +} // namespace custom +} // namespace ops +} // namespace tflite + +int main(int argc, char** argv) { + // On Linux, add: tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/models/smartreply/ops/predict.cc b/tensorflow/contrib/lite/models/smartreply/ops/predict.cc new file mode 100644 index 0000000000..7b23adb990 --- /dev/null +++ b/tensorflow/contrib/lite/models/smartreply/ops/predict.cc @@ -0,0 +1,174 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Lookup projected hash signatures in Predictor model, +// output predicted labels and weights in decreasing order. +// +// Input: +// Input[0]: A list of hash signatures. int32[num of input] +// Input[1]: Hash signature keys in the model. int32[keys of model] +// Input[2]: Labels in the model. int32[keys of model, item per entry] +// Input[3]: Weights in the model. float[keys of model, item per entry] +// +// Output: +// Output[0]: Predicted labels. int32[num of output] +// Output[1]: Predicted weights. 
float[num of output] +// + +#include +#include +#include + +#include "tensorflow/contrib/lite/context.h" + +namespace tflite { +namespace ops { +namespace custom { + +namespace predict { + +struct PredictOption { + int32_t num_output; + float weight_threshold; + + static PredictOption* Cast(void* ptr) { + return reinterpret_cast(ptr); + } +}; + +bool WeightGreater(const std::pair& a, + const std::pair& b) { + return a.second > b.second; +} + +void* Init(TfLiteContext* context, const char* custom_option, size_t length) { + if (custom_option == nullptr || length != sizeof(PredictOption)) { + fprintf(stderr, "No Custom option set\n"); + exit(1); + } + PredictOption* option = new PredictOption; + int offset = 0; + option->num_output = + *reinterpret_cast(custom_option + offset); + offset += sizeof(int32_t); + option->weight_threshold = + *reinterpret_cast(custom_option + offset); + return reinterpret_cast(option); +} + +void Free(TfLiteContext* context, void* buffer) { + delete PredictOption::Cast(buffer); +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, node->inputs->size, 4); + TF_LITE_ENSURE_EQ(context, node->outputs->size, 2); + + TfLiteTensor* lookup = &context->tensors[node->inputs->data[0]]; + TfLiteTensor* model_key = &context->tensors[node->inputs->data[1]]; + TfLiteTensor* model_label = &context->tensors[node->inputs->data[2]]; + TfLiteTensor* model_weight = &context->tensors[node->inputs->data[3]]; + TF_LITE_ENSURE_EQ(context, lookup->type, kTfLiteInt32); + TF_LITE_ENSURE_EQ(context, model_key->type, kTfLiteInt32); + TF_LITE_ENSURE_EQ(context, model_label->type, kTfLiteInt32); + TF_LITE_ENSURE_EQ(context, model_weight->type, kTfLiteFloat32); + TF_LITE_ENSURE_EQ(context, lookup->dims->size, 1); + TF_LITE_ENSURE_EQ(context, model_key->dims->size, 1); + TF_LITE_ENSURE_EQ(context, model_label->dims->size, 2); + TF_LITE_ENSURE_EQ(context, model_weight->dims->size, 2); + TF_LITE_ENSURE_EQ(context, model_key->dims->data[0], + model_label->dims->data[0]); + TF_LITE_ENSURE_EQ(context, model_key->dims->data[0], + model_weight->dims->data[0]); + TF_LITE_ENSURE_EQ(context, model_label->dims->data[1], + model_weight->dims->data[1]); + + PredictOption* option = PredictOption::Cast(node->user_data); + TfLiteTensor* output_label = &context->tensors[node->outputs->data[0]]; + TfLiteTensor* output_weight = &context->tensors[node->outputs->data[1]]; + TF_LITE_ENSURE_EQ(context, output_label->type, kTfLiteInt32); + TF_LITE_ENSURE_EQ(context, output_weight->type, kTfLiteFloat32); + + TfLiteIntArray* label_size = TfLiteIntArrayCreate(1); + label_size->data[0] = option->num_output; + TfLiteIntArray* weight_size = TfLiteIntArrayCreate(1); + weight_size->data[0] = option->num_output; + TfLiteStatus status = + context->ResizeTensor(context, output_label, label_size); + if (status != kTfLiteOk) { + return status; + } + return context->ResizeTensor(context, output_weight, weight_size); +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor* lookup = &context->tensors[node->inputs->data[0]]; + TfLiteTensor* model_key = &context->tensors[node->inputs->data[1]]; + TfLiteTensor* model_label = &context->tensors[node->inputs->data[2]]; + TfLiteTensor* model_weight = &context->tensors[node->inputs->data[3]]; + + // Aggregate by key + std::unordered_map aggregation; + const int num_input = lookup->dims->data[0]; + const int num_rows = model_key->dims->data[0]; + const int items = model_label->dims->data[1]; + int* model_key_end = 
model_key->data.i32 + num_rows; + + for (int i = 0; i < num_input; i++) { + int* ptr = std::lower_bound(model_key->data.i32, model_key_end, + lookup->data.i32[i]); + if (ptr != nullptr && ptr != model_key_end && *ptr == lookup->data.i32[i]) { + int idx = ptr - model_key->data.i32; + for (int j = 0; j < items; j++) { + aggregation[model_label->data.i32[idx * items + j]] += + model_weight->data.f[idx * items + j] / num_input; + } + } + } + + // Sort by value + std::vector> sorted_labels(aggregation.begin(), + aggregation.end()); + std::sort(sorted_labels.begin(), sorted_labels.end(), WeightGreater); + + PredictOption* option = PredictOption::Cast(node->user_data); + TfLiteTensor* output_label = &context->tensors[node->outputs->data[0]]; + TfLiteTensor* output_weight = &context->tensors[node->outputs->data[1]]; + for (int i = 0; i < output_label->dims->data[0]; i++) { + if (i >= sorted_labels.size() || + sorted_labels[i].second < option->weight_threshold) { + // Set -1 to avoid lookup message with id 0, which is set for backoff. + output_label->data.i32[i] = -1; + output_weight->data.f[i] = 0.0f; + } else { + output_label->data.i32[i] = sorted_labels[i].first; + output_weight->data.f[i] = sorted_labels[i].second; + } + } + + return kTfLiteOk; +} + +} // namespace predict + +TfLiteRegistration* Register_PREDICT() { + static TfLiteRegistration r = {predict::Init, predict::Free, predict::Prepare, + predict::Eval}; + return &r; +} + +} // namespace custom +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/models/smartreply/ops/predict_test.cc b/tensorflow/contrib/lite/models/smartreply/ops/predict_test.cc new file mode 100644 index 0000000000..e97c58cbd1 --- /dev/null +++ b/tensorflow/contrib/lite/models/smartreply/ops/predict_test.cc @@ -0,0 +1,183 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include + +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/contrib/lite/string_util.h" + +namespace tflite { + +namespace ops { +namespace custom { +TfLiteRegistration* Register_PREDICT(); + +namespace { + +using ::testing::ElementsAreArray; + +class PredictOpModel : public SingleOpModel { + public: + PredictOpModel(std::initializer_list input_signature_shape, + std::initializer_list key_shape, + std::initializer_list labelweight_shape, int num_output, + float threshold) { + input_signature_ = AddInput(TensorType_INT32); + model_key_ = AddInput(TensorType_INT32); + model_label_ = AddInput(TensorType_INT32); + model_weight_ = AddInput(TensorType_FLOAT32); + output_label_ = AddOutput(TensorType_INT32); + output_weight_ = AddOutput(TensorType_FLOAT32); + + std::vector predict_option; + writeInt32(num_output, &predict_option); + writeFloat32(threshold, &predict_option); + SetCustomOp("Predict", predict_option, Register_PREDICT); + BuildInterpreter({{input_signature_shape, key_shape, labelweight_shape, + labelweight_shape}}); + } + + void SetInputSignature(std::initializer_list data) { + PopulateTensor(input_signature_, data); + } + + void SetModelKey(std::initializer_list data) { + PopulateTensor(model_key_, data); + } + + void SetModelLabel(std::initializer_list data) { + PopulateTensor(model_label_, data); + } + + void SetModelWeight(std::initializer_list data) { + PopulateTensor(model_weight_, data); + } + + std::vector GetLabel() { return ExtractVector(output_label_); } + std::vector GetWeight() { + return ExtractVector(output_weight_); + } + + void writeFloat32(float value, std::vector* data) { + union { + float v; + uint8_t r[4]; + } float_to_raw; + float_to_raw.v = value; + for (unsigned char i : float_to_raw.r) { + data->push_back(i); + } + } + + void writeInt32(int32_t value, std::vector* data) { + union { + int32_t v; + uint8_t r[4]; + } int32_to_raw; + int32_to_raw.v = value; + for (unsigned char i : int32_to_raw.r) { + data->push_back(i); + } + } + + private: + int input_signature_; + int model_key_; + int model_label_; + int model_weight_; + int output_label_; + int output_weight_; +}; + +TEST(PredictOpTest, AllLabelsAreValid) { + PredictOpModel m({4}, {5}, {5, 2}, 2, 0.0001); + m.SetInputSignature({1, 3, 7, 9}); + m.SetModelKey({1, 2, 4, 6, 7}); + m.SetModelLabel({11, 12, 11, 12, 11, 12, 11, 12, 11, 12}); + m.SetModelWeight({0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2}); + m.Invoke(); + EXPECT_THAT(m.GetLabel(), ElementsAreArray({12, 11})); + EXPECT_THAT(m.GetWeight(), ElementsAreArray(ArrayFloatNear({0.1, 0.05}))); +} + +TEST(PredictOpTest, MoreLabelsThanRequired) { + PredictOpModel m({4}, {5}, {5, 2}, 1, 0.0001); + m.SetInputSignature({1, 3, 7, 9}); + m.SetModelKey({1, 2, 4, 6, 7}); + m.SetModelLabel({11, 12, 11, 12, 11, 12, 11, 12, 11, 12}); + m.SetModelWeight({0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2}); + m.Invoke(); + EXPECT_THAT(m.GetLabel(), ElementsAreArray({12})); + EXPECT_THAT(m.GetWeight(), ElementsAreArray(ArrayFloatNear({0.1}))); +} + +TEST(PredictOpTest, OneLabelDoesNotPassThreshold) { + PredictOpModel m({4}, {5}, {5, 2}, 2, 0.07); + m.SetInputSignature({1, 3, 7, 9}); + m.SetModelKey({1, 2, 4, 6, 7}); + m.SetModelLabel({11, 12, 11, 12, 11, 12, 11, 12, 11, 12}); + 
m.SetModelWeight({0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2}); + m.Invoke(); + EXPECT_THAT(m.GetLabel(), ElementsAreArray({12, -1})); + EXPECT_THAT(m.GetWeight(), ElementsAreArray(ArrayFloatNear({0.1, 0}))); +} + +TEST(PredictOpTest, NoneLabelPassThreshold) { + PredictOpModel m({4}, {5}, {5, 2}, 2, 0.6); + m.SetInputSignature({1, 3, 7, 9}); + m.SetModelKey({1, 2, 4, 6, 7}); + m.SetModelLabel({11, 12, 11, 12, 11, 12, 11, 12, 11, 12}); + m.SetModelWeight({0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2}); + m.Invoke(); + EXPECT_THAT(m.GetLabel(), ElementsAreArray({-1, -1})); + EXPECT_THAT(m.GetWeight(), ElementsAreArray(ArrayFloatNear({0, 0}))); +} + +TEST(PredictOpTest, OnlyOneLabelGenerated) { + PredictOpModel m({4}, {5}, {5, 2}, 2, 0.0001); + m.SetInputSignature({1, 3, 7, 9}); + m.SetModelKey({1, 2, 4, 6, 7}); + m.SetModelLabel({11, 0, 11, 0, 11, 0, 11, 0, 11, 0}); + m.SetModelWeight({0.1, 0, 0.1, 0, 0.1, 0, 0.1, 0, 0.1, 0}); + m.Invoke(); + EXPECT_THAT(m.GetLabel(), ElementsAreArray({11, -1})); + EXPECT_THAT(m.GetWeight(), ElementsAreArray(ArrayFloatNear({0.05, 0}))); +} + +TEST(PredictOpTest, NoLabelGenerated) { + PredictOpModel m({4}, {5}, {5, 2}, 2, 0.0001); + m.SetInputSignature({5, 3, 7, 9}); + m.SetModelKey({1, 2, 4, 6, 7}); + m.SetModelLabel({11, 0, 11, 0, 11, 0, 11, 0, 0, 0}); + m.SetModelWeight({0.1, 0, 0.1, 0, 0.1, 0, 0.1, 0, 0, 0}); + m.Invoke(); + EXPECT_THAT(m.GetLabel(), ElementsAreArray({-1, -1})); + EXPECT_THAT(m.GetWeight(), ElementsAreArray(ArrayFloatNear({0, 0}))); +} + +} // namespace +} // namespace custom +} // namespace ops +} // namespace tflite + +int main(int argc, char** argv) { + // On Linux, add: tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/models/smartreply/predictor.cc b/tensorflow/contrib/lite/models/smartreply/predictor.cc new file mode 100644 index 0000000000..a28222213e --- /dev/null +++ b/tensorflow/contrib/lite/models/smartreply/predictor.cc @@ -0,0 +1,116 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/models/smartreply/predictor.h" + +#include "absl/strings/str_split.h" +#include "re2/re2.h" +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/contrib/lite/string_util.h" +#include "tensorflow/contrib/lite/tools/mutable_op_resolver.h" + +void RegisterSelectedOps(::tflite::MutableOpResolver* resolver); + +namespace tflite { +namespace custom { +namespace smartreply { + +// Split sentence into segments (using punctuation). 
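As an illustration of the regex pipeline below, with a hypothetical input that is not taken from the test data: SplitSentence("Hi! How are you?") is first rewritten to "Hi !\tHow are you ?" and then split on tabs into {"Hi !", "How are you ?"}.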
+std::vector SplitSentence(const string& input) { + string result(input); + + RE2::GlobalReplace(&result, "([?.!,])+", " \\1"); + RE2::GlobalReplace(&result, "([?.!,])+\\s+", "\\1\t"); + RE2::GlobalReplace(&result, "[ ]+", " "); + RE2::GlobalReplace(&result, "\t+$", ""); + + return strings::Split(result, '\t'); +} + +// Predict with TfLite model. +void ExecuteTfLite(const string& sentence, ::tflite::Interpreter* interpreter, + std::map* response_map) { + { + TfLiteTensor* input = interpreter->tensor(interpreter->inputs()[0]); + tflite::DynamicBuffer buf; + buf.AddString(sentence.data(), sentence.length()); + buf.WriteToTensor(input); + interpreter->AllocateTensors(); + + interpreter->Invoke(); + + TfLiteTensor* messages = interpreter->tensor(interpreter->outputs()[0]); + TfLiteTensor* confidence = interpreter->tensor(interpreter->outputs()[1]); + + for (int i = 0; i < confidence->dims->data[0]; i++) { + float weight = confidence->data.f[i]; + auto response_text = tflite::GetString(messages, i); + if (response_text.len > 0) { + (*response_map)[string(response_text.str, response_text.len)] += weight; + } + } + } +} + +void GetSegmentPredictions( + const std::vector& input, const ::tflite::FlatBufferModel& model, + const SmartReplyConfig& config, + std::vector* predictor_responses) { + // Initialize interpreter + std::unique_ptr<::tflite::Interpreter> interpreter; + ::tflite::MutableOpResolver resolver; + RegisterSelectedOps(&resolver); + ::tflite::InterpreterBuilder(model, resolver)(&interpreter); + + if (!model.initialized()) { + fprintf(stderr, "Failed to mmap model \n"); + return; + } + + // Execute Tflite Model + std::map response_map; + std::vector sentences; + for (const string& str : input) { + std::vector splitted_str = SplitSentence(str); + sentences.insert(sentences.end(), splitted_str.begin(), splitted_str.end()); + } + for (const auto& sentence : sentences) { + ExecuteTfLite(sentence, interpreter.get(), &response_map); + } + + // Generate the result. + for (const auto& iter : response_map) { + PredictorResponse prediction(iter.first, iter.second); + predictor_responses->emplace_back(prediction); + } + std::sort(predictor_responses->begin(), predictor_responses->end(), + [](const PredictorResponse& a, const PredictorResponse& b) { + return a.GetScore() > b.GetScore(); + }); + + // Add backoff response. + for (const string& backoff : config.backoff_responses) { + if (predictor_responses->size() >= config.num_response) { + break; + } + predictor_responses->push_back({backoff, config.backoff_confidence}); + } +} + +} // namespace smartreply +} // namespace custom +} // namespace tflite diff --git a/tensorflow/contrib/lite/models/smartreply/predictor.h b/tensorflow/contrib/lite/models/smartreply/predictor.h new file mode 100644 index 0000000000..3b9a2b32e1 --- /dev/null +++ b/tensorflow/contrib/lite/models/smartreply/predictor.h @@ -0,0 +1,80 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_MODELS_SMARTREPLY_PREDICTOR_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_MODELS_SMARTREPLY_PREDICTOR_H_ + +#include +#include + +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace custom { +namespace smartreply { + +const int kDefaultNumResponse = 10; +const float kDefaultBackoffConfidence = 1e-4; + +class PredictorResponse; +struct SmartReplyConfig; + +// With a given string as input, predict the response with a Tflite model. +// When config.backoff_response is not empty, predictor_responses will be filled +// with messagees from backoff response. +void GetSegmentPredictions(const std::vector& input, + const ::tflite::FlatBufferModel& model, + const SmartReplyConfig& config, + std::vector* predictor_responses); + +// Data object used to hold a single predictor response. +// It includes messages, and confidence. +class PredictorResponse { + public: + PredictorResponse(const string& response_text, float score) { + response_text_ = response_text; + prediction_score_ = score; + } + + // Accessor methods. + const string& GetText() const { return response_text_; } + float GetScore() const { return prediction_score_; } + + private: + string response_text_ = ""; + float prediction_score_ = 0.0; +}; + +// Configurations for SmartReply. +struct SmartReplyConfig { + // Maximum responses to return. + int num_response; + // Default confidence for backoff responses. + float backoff_confidence; + // Backoff responses are used when predicted responses cannot fulfill the + // list. + const std::vector& backoff_responses; + + SmartReplyConfig(std::vector backoff_responses) + : num_response(kDefaultNumResponse), + backoff_confidence(kDefaultBackoffConfidence), + backoff_responses(backoff_responses) {} +}; + +} // namespace smartreply +} // namespace custom +} // namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_MODELS_SMARTREPLY_PREDICTOR_H_ diff --git a/tensorflow/contrib/lite/models/smartreply/predictor_test.cc b/tensorflow/contrib/lite/models/smartreply/predictor_test.cc new file mode 100644 index 0000000000..2fa9923bc9 --- /dev/null +++ b/tensorflow/contrib/lite/models/smartreply/predictor_test.cc @@ -0,0 +1,150 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/contrib/lite/models/smartreply/predictor.h" + +#include +#include + +#include "base/logging.h" +#include +#include +#include "absl/strings/str_cat.h" +#include "absl/strings/str_split.h" +#include "tensorflow/contrib/lite/models/test_utils.h" + +namespace tflite { +namespace custom { +namespace smartreply { +namespace { + +const char kModelName[] = "smartreply_ondevice_model.bin"; +const char kSamples[] = "smartreply_samples.tsv"; + +MATCHER_P(IncludeAnyResponesIn, expected_response, "contains the response") { + bool has_expected_response = false; + for (const auto &item : *arg) { + const string &response = item.GetText(); + if (expected_response.find(response) != expected_response.end()) { + has_expected_response = true; + break; + } + } + return has_expected_response; +} + +class PredictorTest : public ::testing::Test { + protected: + PredictorTest() { + model_ = tflite::FlatBufferModel::BuildFromFile( + StrCat(TestDataPath(), "/", kModelName).c_str()); + CHECK(model_); + } + ~PredictorTest() override {} + + std::unique_ptr<::tflite::FlatBufferModel> model_; +}; + +TEST_F(PredictorTest, GetSegmentPredictions) { + std::vector predictions; + + GetSegmentPredictions({"Welcome"}, *model_, /*config=*/{{}}, &predictions); + EXPECT_GT(predictions.size(), 0); + + float max = 0; + for (const auto &item : predictions) { + LOG(INFO) << "Response: " << item.GetText(); + if (item.GetScore() > max) { + max = item.GetScore(); + } + } + + EXPECT_GT(max, 0.3); + EXPECT_THAT( + &predictions, + IncludeAnyResponesIn(std::unordered_set({"Thanks very much"}))); +} + +TEST_F(PredictorTest, TestTwoSentences) { + std::vector predictions; + + GetSegmentPredictions({"Hello", "How are you?"}, *model_, /*config=*/{{}}, + &predictions); + EXPECT_GT(predictions.size(), 0); + + float max = 0; + for (const auto &item : predictions) { + LOG(INFO) << "Response: " << item.GetText(); + if (item.GetScore() > max) { + max = item.GetScore(); + } + } + + EXPECT_GT(max, 0.3); + EXPECT_THAT(&predictions, IncludeAnyResponesIn(std::unordered_set( + {"Hi, how are you doing?"}))); +} + +TEST_F(PredictorTest, TestBackoff) { + std::vector predictions; + + GetSegmentPredictions({"你好"}, *model_, /*config=*/{{}}, &predictions); + EXPECT_EQ(predictions.size(), 0); + + // Backoff responses are returned in order. + GetSegmentPredictions({"你好"}, *model_, /*config=*/{{"Yes", "Ok"}}, + &predictions); + EXPECT_EQ(predictions.size(), 2); + EXPECT_EQ(predictions[0].GetText(), "Yes"); + EXPECT_EQ(predictions[1].GetText(), "Ok"); +} + +TEST_F(PredictorTest, BatchTest) { + int total_items = 0; + int total_responses = 0; + int total_triggers = 0; + + string line; + std::ifstream fin(StrCat(TestDataPath(), "/", kSamples)); + while (std::getline(fin, line)) { + const std::vector &fields = strings::Split(line, '\t'); + if (fields.empty()) { + continue; + } + + // Parse sample file and predict + const string &msg = fields[0]; + std::vector predictions; + GetSegmentPredictions({msg}, *model_, /*config=*/{{}}, &predictions); + + // Validate response and generate stats. 
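For context, each line of the samples file is expected to hold a message followed by its acceptable responses, tab-separated; a hypothetical line could look like "How was your weekend?\tIt was great\tPretty good".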
+ total_items++; + total_responses += predictions.size(); + if (!predictions.empty()) { + total_triggers++; + } + EXPECT_THAT(&predictions, IncludeAnyResponesIn(std::unordered_set( + fields.begin() + 1, fields.end()))); + } + + LOG(INFO) << "Responses: " << total_responses << " / " << total_items; + LOG(INFO) << "Triggers: " << total_triggers << " / " << total_items; + EXPECT_EQ(total_triggers, total_items); +} + +} // namespace +} // namespace smartreply +} // namespace custom +} // namespace tflite diff --git a/tensorflow/contrib/lite/models/speech_hotword_model_test.cc b/tensorflow/contrib/lite/models/speech_hotword_model_test.cc new file mode 100644 index 0000000000..f5d1f436bc --- /dev/null +++ b/tensorflow/contrib/lite/models/speech_hotword_model_test.cc @@ -0,0 +1,115 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Unit test for speech Hotword model using TFLite Ops. + +#include + +#include +#include + +#include "base/logging.h" +#include "file/base/path.h" +#include "testing/base/public/googletest.h" +#include +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/contrib/lite/models/test_utils.h" + +namespace tflite { +namespace models { + +void RunTest(int model_input_tensor, int svdf_layer_state_tensor, + int model_output_tensor, const string& model_name, + const string& golden_in_name, const string& golden_out_name) { + // Read the model. + string tflite_file_path = file::JoinPath(TestDataPath(), model_name); + auto model = FlatBufferModel::BuildFromFile(tflite_file_path.c_str()); + CHECK(model) << "Failed to read model from file " << tflite_file_path; + + // Initialize the interpreter. + ops::builtin::BuiltinOpResolver builtins; + std::unique_ptr interpreter; + InterpreterBuilder(*model, builtins)(&interpreter); + CHECK(interpreter != nullptr); + interpreter->AllocateTensors(); + + // Reset the SVDF layer state. + memset(interpreter->tensor(svdf_layer_state_tensor)->data.raw, 0, + interpreter->tensor(svdf_layer_state_tensor)->bytes); + + // Load the input frames. + Frames input_frames; + const string input_file_path = file::JoinPath(TestDataPath(), golden_in_name); + ReadFrames(input_file_path, &input_frames); + + // Load the golden output results. 
+ Frames output_frames; + const string output_file_path = + file::JoinPath(TestDataPath(), golden_out_name); + ReadFrames(output_file_path, &output_frames); + + const int speech_batch_size = + interpreter->tensor(model_input_tensor)->dims->data[0]; + const int speech_input_size = + interpreter->tensor(model_input_tensor)->dims->data[1]; + const int speech_output_size = + interpreter->tensor(model_output_tensor)->dims->data[1]; + const int input_sequence_size = + input_frames[0].size() / (speech_input_size * speech_batch_size); + float* input_ptr = interpreter->tensor(model_input_tensor)->data.f; + float* output_ptr = interpreter->tensor(model_output_tensor)->data.f; + + // The first layer (SVDF) input size is 40 (speech_input_size). Each speech + // input frames for this model is 1280 floats, which can be fed to input in a + // sequence of size 32 (input_sequence_size). + for (int i = 0; i < TestInputSize(input_frames); i++) { + int frame_ptr = 0; + for (int s = 0; s < input_sequence_size; s++) { + for (int k = 0; k < speech_input_size * speech_batch_size; k++) { + input_ptr[k] = input_frames[i][frame_ptr++]; + } + interpreter->Invoke(); + } + // After the whole frame (1280 floats) is fed, we can check the output frame + // matches with the golden output frame. + for (int k = 0; k < speech_output_size; k++) { + ASSERT_NEAR(output_ptr[k], output_frames[i][k], 1e-5); + } + } +} + +TEST(SpeechHotword, OkGoogleTestRank1) { + constexpr int kModelInputTensor = 0; + constexpr int kSvdfLayerStateTensor = 4; + constexpr int kModelOutputTensor = 18; + + RunTest(kModelInputTensor, kSvdfLayerStateTensor, kModelOutputTensor, + "speech_hotword_model_rank1.tflite", "speech_hotword_model_in.csv", + "speech_hotword_model_out_rank1.csv"); +} + +TEST(SpeechHotword, OkGoogleTestRank2) { + constexpr int kModelInputTensor = 17; + constexpr int kSvdfLayerStateTensor = 1; + constexpr int kModelOutputTensor = 18; + RunTest(kModelInputTensor, kSvdfLayerStateTensor, kModelOutputTensor, + "speech_hotword_model_rank2.tflite", "speech_hotword_model_in.csv", + "speech_hotword_model_out_rank2.csv"); +} + +} // namespace models +} // namespace tflite diff --git a/tensorflow/contrib/lite/models/speech_speakerid_model_test.cc b/tensorflow/contrib/lite/models/speech_speakerid_model_test.cc new file mode 100644 index 0000000000..687cfab0b2 --- /dev/null +++ b/tensorflow/contrib/lite/models/speech_speakerid_model_test.cc @@ -0,0 +1,114 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Unit test for speech SpeakerId model using TFLite Ops. 
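The speaker-id test below follows the same harness pattern as the hotword test above: build a FlatBufferModel from a .tflite file, construct an Interpreter, zero the recurrent state tensors, then stream input frames through Invoke() and compare each output frame against golden values. A minimal, self-contained sketch of that pattern is shown here; the model path and tensor indices are placeholders, not values taken from these tests.

// Minimal sketch of the shared test pattern; "example_model.tflite" and the
// tensor indices below are placeholders.
#include <cstring>
#include <memory>

#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/model.h"

void RunOneFrame(const float* frame, int frame_size, float* out, int out_size) {
  // Build the model and the interpreter.
  auto model = tflite::FlatBufferModel::BuildFromFile("example_model.tflite");
  tflite::ops::builtin::BuiltinOpResolver resolver;
  std::unique_ptr<tflite::Interpreter> interpreter;
  tflite::InterpreterBuilder(*model, resolver)(&interpreter);
  interpreter->AllocateTensors();

  constexpr int kInputTensor = 0;   // placeholder index
  constexpr int kStateTensor = 1;   // placeholder index
  constexpr int kOutputTensor = 2;  // placeholder index

  // Zero the recurrent state before feeding the first frame.
  std::memset(interpreter->tensor(kStateTensor)->data.raw, 0,
              interpreter->tensor(kStateTensor)->bytes);

  // Copy one input frame in, run the graph, and copy the output frame out.
  std::memcpy(interpreter->tensor(kInputTensor)->data.f, frame,
              frame_size * sizeof(float));
  interpreter->Invoke();
  std::memcpy(out, interpreter->tensor(kOutputTensor)->data.f,
              out_size * sizeof(float));
}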
+ +#include + +#include +#include + +#include "base/logging.h" +#include "file/base/path.h" +#include "testing/base/public/googletest.h" +#include +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/contrib/lite/models/test_utils.h" +#include "tensorflow/contrib/lite/tools/mutable_op_resolver.h" + +void RegisterSelectedOps(::tflite::MutableOpResolver* resolver); + +namespace tflite { +namespace models { + +constexpr int kModelInputTensor = 0; +constexpr int kLstmLayer1OutputStateTensor = 19; +constexpr int kLstmLayer1CellStateTensor = 20; +constexpr int kLstmLayer2OutputStateTensor = 40; +constexpr int kLstmLayer2CellStateTensor = 41; +constexpr int kLstmLayer3OutputStateTensor = 61; +constexpr int kLstmLayer3CellStateTensor = 62; +constexpr int kModelOutputTensor = 66; + +TEST(SpeechSpeakerId, OkGoogleTest) { + // Read the model. + string tflite_file_path = + file::JoinPath(TestDataPath(), "speech_speakerid_model.tflite"); + auto model = FlatBufferModel::BuildFromFile(tflite_file_path.c_str()); + CHECK(model) << "Failed to read model from file " << tflite_file_path; + + // Initialize the interpreter. + ::tflite::MutableOpResolver resolver; + RegisterSelectedOps(&resolver); + std::unique_ptr interpreter; + InterpreterBuilder(*model, resolver)(&interpreter); + CHECK(interpreter != nullptr); + interpreter->AllocateTensors(); + + // Load the input frames. + Frames input_frames; + const string input_file_path = + file::JoinPath(TestDataPath(), "speech_speakerid_model_in.csv"); + ReadFrames(input_file_path, &input_frames); + + // Load the golden output results. + Frames output_frames; + const string output_file_path = + file::JoinPath(TestDataPath(), "speech_speakerid_model_out.csv"); + ReadFrames(output_file_path, &output_frames); + + const int speech_batch_size = + interpreter->tensor(kModelInputTensor)->dims->data[0]; + const int speech_input_size = + interpreter->tensor(kModelInputTensor)->dims->data[1]; + const int speech_output_size = + interpreter->tensor(kModelOutputTensor)->dims->data[1]; + + float* input_ptr = interpreter->tensor(kModelInputTensor)->data.f; + float* output_ptr = interpreter->tensor(kModelOutputTensor)->data.f; + + // Clear the LSTM state for layers. + memset(interpreter->tensor(kLstmLayer1OutputStateTensor)->data.raw, 0, + interpreter->tensor(kLstmLayer1OutputStateTensor)->bytes); + memset(interpreter->tensor(kLstmLayer1CellStateTensor)->data.raw, 0, + interpreter->tensor(kLstmLayer1CellStateTensor)->bytes); + + memset(interpreter->tensor(kLstmLayer2OutputStateTensor)->data.raw, 0, + interpreter->tensor(kLstmLayer2OutputStateTensor)->bytes); + memset(interpreter->tensor(kLstmLayer2CellStateTensor)->data.raw, 0, + interpreter->tensor(kLstmLayer2CellStateTensor)->bytes); + + memset(interpreter->tensor(kLstmLayer3OutputStateTensor)->data.raw, 0, + interpreter->tensor(kLstmLayer3OutputStateTensor)->bytes); + memset(interpreter->tensor(kLstmLayer3CellStateTensor)->data.raw, 0, + interpreter->tensor(kLstmLayer3CellStateTensor)->bytes); + for (int i = 0; i < input_frames.size(); i++) { + // Feed the input to model. + int frame_ptr = 0; + for (int k = 0; k < speech_input_size * speech_batch_size; k++) { + input_ptr[k] = input_frames[i][frame_ptr++]; + } + // Run the model. + interpreter->Invoke(); + // Validate the output. 
+ for (int k = 0; k < speech_output_size; k++) { + ASSERT_NEAR(output_ptr[k], output_frames[i][k], 1e-5); + } + } +} + +} // namespace models +} // namespace tflite diff --git a/tensorflow/contrib/lite/models/speech_terse_am_model_test.cc b/tensorflow/contrib/lite/models/speech_terse_am_model_test.cc new file mode 100644 index 0000000000..30d89a1354 --- /dev/null +++ b/tensorflow/contrib/lite/models/speech_terse_am_model_test.cc @@ -0,0 +1,127 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Unit test for speech TERSE AM model using TFLite Ops. + +#include + +#include +#include + +#include "base/logging.h" +#include "file/base/path.h" +#include "testing/base/public/googletest.h" +#include +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/contrib/lite/models/test_utils.h" + +namespace tflite { +namespace models { + +constexpr int kModelInputTensor = 0; +constexpr int kLstmLayer1OutputStateTensor = 19; +constexpr int kLstmLayer1CellStateTensor = 20; +constexpr int kLstmLayer2OutputStateTensor = 40; +constexpr int kLstmLayer2CellStateTensor = 41; +constexpr int kLstmLayer3OutputStateTensor = 61; +constexpr int kLstmLayer3CellStateTensor = 62; +constexpr int kLstmLayer4OutputStateTensor = 82; +constexpr int kLstmLayer4CellStateTensor = 83; +constexpr int kLstmLayer5OutputStateTensor = 103; +constexpr int kLstmLayer5CellStateTensor = 104; +constexpr int kModelOutputTensor = 109; + +TEST(SpeechTerseAm, RandomIOTest) { + // Read the model. + string tflite_file_path = + file::JoinPath(TestDataPath(), "speech_terse_am_model.tflite"); + auto model = FlatBufferModel::BuildFromFile(tflite_file_path.c_str()); + CHECK(model) << "Failed to mmap model " << tflite_file_path; + + // Initialize the interpreter. + ops::builtin::BuiltinOpResolver builtins; + std::unique_ptr interpreter; + InterpreterBuilder(*model, builtins)(&interpreter); + CHECK(interpreter != nullptr); + interpreter->AllocateTensors(); + + // Load the input frames. + Frames input_frames; + const string input_file_path = + file::JoinPath(TestDataPath(), "speech_terse_am_model_in.csv"); + ReadFrames(input_file_path, &input_frames); + + // Load the golden output results. 
+ Frames output_frames; + const string output_file_path = + file::JoinPath(TestDataPath(), "speech_terse_am_model_out.csv"); + ReadFrames(output_file_path, &output_frames); + + const int speech_batch_size = + interpreter->tensor(kModelInputTensor)->dims->data[0]; + const int speech_input_size = + interpreter->tensor(kModelInputTensor)->dims->data[1]; + const int speech_output_size = + interpreter->tensor(kModelOutputTensor)->dims->data[1]; + + float* input_ptr = interpreter->tensor(kModelInputTensor)->data.f; + float* output_ptr = interpreter->tensor(kModelOutputTensor)->data.f; + + // Clear the LSTM state for layers. + memset(interpreter->tensor(kLstmLayer1OutputStateTensor)->data.raw, 0, + interpreter->tensor(kLstmLayer1OutputStateTensor)->bytes); + memset(interpreter->tensor(kLstmLayer1CellStateTensor)->data.raw, 0, + interpreter->tensor(kLstmLayer1CellStateTensor)->bytes); + + memset(interpreter->tensor(kLstmLayer2OutputStateTensor)->data.raw, 0, + interpreter->tensor(kLstmLayer2OutputStateTensor)->bytes); + memset(interpreter->tensor(kLstmLayer2CellStateTensor)->data.raw, 0, + interpreter->tensor(kLstmLayer2CellStateTensor)->bytes); + + memset(interpreter->tensor(kLstmLayer3OutputStateTensor)->data.raw, 0, + interpreter->tensor(kLstmLayer3OutputStateTensor)->bytes); + memset(interpreter->tensor(kLstmLayer3CellStateTensor)->data.raw, 0, + interpreter->tensor(kLstmLayer3CellStateTensor)->bytes); + + memset(interpreter->tensor(kLstmLayer4OutputStateTensor)->data.raw, 0, + interpreter->tensor(kLstmLayer4OutputStateTensor)->bytes); + memset(interpreter->tensor(kLstmLayer4CellStateTensor)->data.raw, 0, + interpreter->tensor(kLstmLayer4CellStateTensor)->bytes); + + memset(interpreter->tensor(kLstmLayer5OutputStateTensor)->data.raw, 0, + interpreter->tensor(kLstmLayer5OutputStateTensor)->bytes); + memset(interpreter->tensor(kLstmLayer5CellStateTensor)->data.raw, 0, + interpreter->tensor(kLstmLayer5CellStateTensor)->bytes); + + + for (int i = 0; i < input_frames.size(); i++) { + // Feed the input to model. + int frame_ptr = 0; + for (int k = 0; k < speech_input_size * speech_batch_size; k++) { + input_ptr[k] = input_frames[i][frame_ptr++]; + } + // Run the model. + interpreter->Invoke(); + // Validate the output. + for (int k = 0; k < speech_output_size; k++) { + ASSERT_NEAR(output_ptr[k], output_frames[i][k], 5.2e-4); + } + } +} + +} // namespace models +} // namespace tflite diff --git a/tensorflow/contrib/lite/models/speech_tts_model_test.cc b/tensorflow/contrib/lite/models/speech_tts_model_test.cc new file mode 100644 index 0000000000..e6f2673a42 --- /dev/null +++ b/tensorflow/contrib/lite/models/speech_tts_model_test.cc @@ -0,0 +1,116 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Unit test for speech TTS model using TFLite Ops. 
+ +#include + +#include +#include + +#include "base/logging.h" +#include "file/base/path.h" +#include "testing/base/public/googletest.h" +#include +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/contrib/lite/models/test_utils.h" + +namespace tflite { +namespace models { + +constexpr int kModelInputTensor = 0; +constexpr int kLstmLayer1OutputStateTensor = 25; +constexpr int kLstmLayer1CellStateTensor = 26; +constexpr int kLstmLayer2OutputStateTensor = 46; +constexpr int kLstmLayer2CellStateTensor = 47; +constexpr int kLstmLayer3OutputStateTensor = 67; +constexpr int kLstmLayer3CellStateTensor = 68; +constexpr int kRnnLayerHiddenStateTensor = 73; +constexpr int kModelOutputTensor = 74; + +TEST(SpeechTTS, RandomIOTest) { + // Read the model. + string tflite_file_path = + file::JoinPath(TestDataPath(), "speech_tts_model.tflite"); + auto model = FlatBufferModel::BuildFromFile(tflite_file_path.c_str()); + CHECK(model) << "Failed to mmap model " << tflite_file_path; + + // Initialize the interpreter. + ops::builtin::BuiltinOpResolver builtins; + std::unique_ptr interpreter; + InterpreterBuilder(*model, builtins)(&interpreter); + CHECK(interpreter != nullptr); + interpreter->AllocateTensors(); + + // Load the input frames. + Frames input_frames; + const string input_file_path = + file::JoinPath(TestDataPath(), "speech_tts_model_in.csv"); + ReadFrames(input_file_path, &input_frames); + + // Load the golden output results. + Frames output_frames; + const string output_file_path = + file::JoinPath(TestDataPath(), "speech_tts_model_out.csv"); + ReadFrames(output_file_path, &output_frames); + + const int speech_batch_size = + interpreter->tensor(kModelInputTensor)->dims->data[0]; + const int speech_input_size = + interpreter->tensor(kModelInputTensor)->dims->data[1]; + const int speech_output_size = + interpreter->tensor(kModelOutputTensor)->dims->data[1]; + + float* input_ptr = interpreter->tensor(kModelInputTensor)->data.f; + float* output_ptr = interpreter->tensor(kModelOutputTensor)->data.f; + + // Clear the LSTM state for layers. + memset(interpreter->tensor(kLstmLayer1OutputStateTensor)->data.raw, 0, + interpreter->tensor(kLstmLayer1OutputStateTensor)->bytes); + memset(interpreter->tensor(kLstmLayer1CellStateTensor)->data.raw, 0, + interpreter->tensor(kLstmLayer1CellStateTensor)->bytes); + + memset(interpreter->tensor(kLstmLayer2OutputStateTensor)->data.raw, 0, + interpreter->tensor(kLstmLayer2OutputStateTensor)->bytes); + memset(interpreter->tensor(kLstmLayer2CellStateTensor)->data.raw, 0, + interpreter->tensor(kLstmLayer2CellStateTensor)->bytes); + + memset(interpreter->tensor(kLstmLayer3OutputStateTensor)->data.raw, 0, + interpreter->tensor(kLstmLayer3OutputStateTensor)->bytes); + memset(interpreter->tensor(kLstmLayer3CellStateTensor)->data.raw, 0, + interpreter->tensor(kLstmLayer3CellStateTensor)->bytes); + + memset(interpreter->tensor(kRnnLayerHiddenStateTensor)->data.raw, 0, + interpreter->tensor(kRnnLayerHiddenStateTensor)->bytes); + + for (int i = 0; i < input_frames.size(); i++) { + // Feed the input to model. + int frame_ptr = 0; + for (int k = 0; k < speech_input_size * speech_batch_size; k++) { + input_ptr[k] = input_frames[i][frame_ptr++]; + } + // Run the model. + interpreter->Invoke(); + // Validate the output. 
+ for (int k = 0; k < speech_output_size; k++) { + ASSERT_NEAR(output_ptr[k], output_frames[i][k], 1e-5); + } + } +} + +} // namespace models +} // namespace tflite diff --git a/tensorflow/contrib/lite/models/test_utils.h b/tensorflow/contrib/lite/models/test_utils.h new file mode 100644 index 0000000000..b2596babd0 --- /dev/null +++ b/tensorflow/contrib/lite/models/test_utils.h @@ -0,0 +1,84 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_MODELS_TEST_UTILS_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_MODELS_TEST_UTILS_H_ + +#include +#include + +#include +#include +#include +#include + +namespace tflite { +namespace models { +using Frames = std::vector>; +} // namespace models +} // namespace tflite + +#ifndef __ANDROID__ +#include "file/base/path.h" +#include "tensorflow/core/platform/test.h" + +inline string TestDataPath() { + return string(file::JoinPath(tensorflow::testing::TensorFlowSrcRoot(), + "contrib/lite/models/testdata/")); +} +inline int TestInputSize(const tflite::models::Frames& input_frames) { + return input_frames.size(); +} +#else +inline string TestDataPath() { + return string("third_party/tensorflow/contrib/lite/models/testdata/"); +} + +inline int TestInputSize(const tflite::models::Frames& input_frames) { + // Android TAP is very slow, we only test the first 20 frames. + return 20; +} +#endif + +namespace tflite { +namespace models { + +// Read float data from a comma-separated file: +// Each line will be read into a float vector. +// The return result will be a vector of float vectors. +void ReadFrames(const string& csv_file_path, Frames* frames) { + std::ifstream csv_file(csv_file_path); + string line; + while (std::getline(csv_file, line, '\n')) { + std::vector fields; + // Used by strtok_r internaly for successive calls on the same string. + char* save_ptr = nullptr; + + // Tokenize the line. 
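+    // Note: strtok_r tokenizes the buffer in place, writing a NUL terminator
+    // over each delimiter it finds, which is why the line buffer must be
+    // mutable (hence the const_cast below). The save_ptr argument carries the
+    // scan position between successive calls, making strtok_r reentrant,
+    // unlike strtok.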
+ char* next_token = + strtok_r(const_cast(line.c_str()), ",", &save_ptr); + while (next_token != nullptr) { + float f = strtod(next_token, nullptr); + fields.push_back(f); + next_token = strtok_r(nullptr, ",", &save_ptr); + } + frames->push_back(fields); + } + csv_file.close(); +} + +} // namespace models +} // namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_MODELS_TEST_UTILS_H_ diff --git a/tensorflow/contrib/lite/nnapi/BUILD b/tensorflow/contrib/lite/nnapi/BUILD new file mode 100644 index 0000000000..402f1e949b --- /dev/null +++ b/tensorflow/contrib/lite/nnapi/BUILD @@ -0,0 +1,25 @@ +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = [ + "//visibility:public", +]) + +cc_library( + name = "nnapi_lib", + hdrs = [ + "NeuralNetworksShim.h", + ], + linkopts = ["-ldl"], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h b/tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h new file mode 100644 index 0000000000..5d06165772 --- /dev/null +++ b/tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h @@ -0,0 +1,1916 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef NN_API_SHIM_H0 +#define NN_API_SHIM_H0 + +#include +#include +#include +#include + +// helpers + +#define NNAPI_LOG(format, ...) printf(format "\n", __VA_ARGS__); +#define LOAD_FUNCTION(name) \ + static name##_fn fn = reinterpret_cast(loadFunction(#name)); +#define EXECUTE_FUNCTION(...) \ + if (fn != nullptr) { \ + fn(__VA_ARGS__); \ + } +#define EXECUTE_FUNCTION_RETURN(...) return fn != nullptr ? fn(__VA_ARGS__) : 0; + +inline void* loadLibrary(const char* name) { + // TODO: change RTLD_LOCAL? Assumes there can be multiple instances of nn + // api RT + void* handle = dlopen(name, RTLD_LAZY | RTLD_LOCAL); + if (handle == nullptr) { + NNAPI_LOG("nnapi error: unable to open library %s", name); + } + return handle; +} + +inline void* getLibraryHandle() { + static void* handle = loadLibrary("libneuralnetworks.so"); + return handle; +} + +inline void* loadFunction(const char* name) { + void* fn = nullptr; + if (getLibraryHandle() != nullptr) { + fn = dlsym(getLibraryHandle(), name); + } + if (fn == nullptr) { + NNAPI_LOG("nnapi error: unable to open function %s", name); + } + return fn; +} + +inline bool NNAPIExists() { + static bool nnapi_is_available = getLibraryHandle(); + return nnapi_is_available; +} + +// nn api types + +/** + * Operand types. + * + * The type of operands that can be added to a model. + * + * Although we define many types, most operators accept just a few + * types. Most used are ANEURALNETWORKS_TENSOR_FLOAT32, + * ANEURALNETWORKS_TENSOR_QUANT8_ASYMM, and ANEURALNETWORKS_INT32. + */ +enum { + /** The following entries are used to declare scalars. 
*/ + + /** A 32 bit floating point scalar value. */ + ANEURALNETWORKS_FLOAT32 = 0, + /** A signed 32 bit integer scalar value. */ + ANEURALNETWORKS_INT32 = 1, + /** An unsigned 32 bit integer scalar value. */ + ANEURALNETWORKS_UINT32 = 2, + + /** The following entries are used to declare tensors. */ + + /** A tensor of 32 bit floating point values. */ + ANEURALNETWORKS_TENSOR_FLOAT32 = 3, + /** A tensor of 32 bit integer values. */ + ANEURALNETWORKS_TENSOR_INT32 = 4, + /** A tensor of 8 bit integers that represent real numbers. + * + * Attached to this tensor are two numbers that can be used to convert + * the 8 bit integer to the real value and vice versa. These two numbers are: + * - scale: a 32 bit floating point value + * - zero_value: an 32 bit integer + * + * The formula is: + * real_value = (integer_value - zero_value) * scale. + */ + ANEURALNETWORKS_TENSOR_QUANT8_ASYMM = 5, +}; + +/** + * Operation types. + * + * The type of operations that can be added to a model. + */ +enum { + /** Adds two tensors, elment-wise. + * + * Takes two input tensors of identical type and compatible dimensions. The + * output is the sum of both input tensors, optionally modified by an + * activation function. + * + * Two dimensions are compatible when: + * 1. they are equal, or + * 2. one of them is 1 + * + * The size of the output is the maximum size along each dimension of the + * input operands. It starts with the trailing dimensions, and works its way + * forward. + * + * Example: + * + * input1.dimension = {4, 1, 2} + * input2.dimension = {5, 4, 3, 1} + * output.dimension = {5, 4, 3, 2} + * + * Supported tensor types: + * * {@link ANEURALNETWORKS_TENSOR_FLOAT32} + * + * Supported tensor rank: up to 4 + * + * Inputs: + * * 0: A tensor. + * * 1: A tensor of the same type, and compatible dimensions as input0. + * * 2: An INT32 value, and has to be one of the {@link FuseCode} values. + * Specifies the activation to invoke on the result of each addition. + * + * Outputs: + * * 0: The sum, a tensor of the same type as input0. + */ + ANEURALNETWORKS_ADD = 0, + /** Performs a 2-D average pooling operation. + * + * The output dimensions are functions of the filter dimensions, stride, and + * padding. + * + * The values in the output tensor are computed as: + * + * output[batch, row, col, channel] = + * sum_{i, j}(input[batch, row + i, col + j, channel]) / sum(1) + * + * Supported tensor types: + * * {@link ANEURALNETWORKS_TENSOR_FLOAT32} + * * {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM} + * + * Supported tensor rank: 4, with "NHWC" data layout. + * + * Inputs: + * * 0: A 4-D tensor, of shape [batches, height, width, depth], specifying the + * input. + * * 1: An INT32 value, specifying the padding on the left, in the ‘width’ + * dimension. + * * 2: An INT32 value, specifying the padding on the right,in the ‘width’ + * dimension. + * * 3: An INT32 value, specifying the padding on the top, in the ‘height’ + * dimension. + * * 4: An INT32 value, specifying the padding on the bottom, in the ‘height’ + * dimension. + * * 5: An INT32 value, specifying the output stride in the ‘width’ dimension. + * * 6: An INT32 value, specifying the output stride in the ‘height’ + * dimension. + * * 7: An INT32 value, specifying the filter width. + * * 8: An INT32 value, specifying the filter height. + * * 9: An INT32 value, and has to be one of the {@link FuseCode} values. + * Specifies the activation to invoke on the result of each addition. 
+ * + * Outputs: + * * 0: The output 4-D tensor, of shape [batches, out_height, out_width, + * depth]. + */ + ANEURALNETWORKS_AVERAGE_POOL_2D = 1, + /** Concatenates the input tensors along the given dimension. + * + * The input tensors must have identical type and the same dimensions except + * the dimension along the concatenation axis. + * + * Supported tensor types: + * * {@link ANEURALNETWORKS_TENSOR_FLOAT32} + * * {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM} + * + * Supported tensor rank: up to 4 + * + * Inputs: + * 0 ~ n: The list on n input tensors, of shape [D0, D1, ..., Daxis(i), ..., + * Dm] n+1: An INT32 value, specifying the concatenation axis. n+2: An INT32 + * value, and has to be one of the {@link FuseCode} values. Specifies the + * activation to invoke on the result of each addition. + * + * Outputs: + * * 0: The output, a tensor of the same type as the input tensors. + * The output shape is [D0, D1, ..., sum(Daxis(i)), ..., Dm]. + */ + ANEURALNETWORKS_CONCATENATION = 2, + /** Performs an 2-D convolution operation. + * + * The CONV_2D op sweeps a 2-D filter that can mix channels together over a + * batch of images, applying the filter to each window of each image of the + * appropriate size. + * + * The output dimensions are functions of the filter dimensions, stride, and + * padding. + * + * The values in the output tensor are computed as: + * + * output[batch, row, col, channel] = + * sum_{i, j} ( + * input[batch, row + i, col + j, k] * + * filter[channel, row + i, col + j, k] + + * bias[channel] + * ) + * + * Supported tensor types: + * * {@link ANEURALNETWORKS_TENSOR_FLOAT32} + * * {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM} + * + * Supported tensor rank: 4, with "NHWC" data layout. + * + * Inputs: + * * 0: A 4-D tensor, of shape [batches, height, width, depth_in], specifying + * the input. + * * 1: A 4-D tensor, of shape [depth_out, filter_height, filter_width, + * depth_in], specifying the filter. + * * 2: A 1-D tensor, of shape [depth_out], specifying the bias. + * For input tensor of {@link ANEURALNETWORKS_TENSOR_FLOAT32} type, the + * bias should also be of {@link ANEURALNETWORKS_TENSOR_FLOAT32}. For input + * tensor of {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM} type, the bias should + * be of {@link ANEURALNETWORKS_TENSOR_INT32}. + * * 3: An INT32 value, specifying the padding on the left, in the ‘width’ + * dimension. + * * 4: An INT32 value, specifying the padding on the right,in the ‘width’ + * dimension. + * * 5: An INT32 value, specifying the padding on the top, in the ‘height’ + * dimension. + * * 6: An INT32 value, specifying the padding on the bottom, in the ‘height’ + * dimension. + * * 7: An INT32 value, specifying the output stride in the ‘width’ dimension. + * * 8: An INT32 value, specifying the output stride in the ‘height’ + * dimension. + * * 9: An INT32 value, and has to be one of the {@link FuseCode} values. + * Specifies the activation to invoke on the result of each addition. + * + * Outputs: + * * 0: The output 4-D tensor, of shape [batches, out_height, out_width, + * depth_out]. + */ + ANEURALNETWORKS_CONV_2D = 3, + /** Performs a depthwise 2-D convolution operation. 
+ * + * Given an input tensor of shape [batches, height, width, depth_in] and a + * filter tensor of shape [depth_out, filter_height, filter_width, depth_in] + * containing in_channels convolutional filters of depth 1, DEPTHWISE_CONV + * applies a different filter to each input channel (expanding from 1 channel + * to channel_multiplier channels for each), then concatenates the results + * together. + * + * The output has depth_out = depth_in * depth_multiplier channels. + * The output dimensions are functions of the filter dimensions, stride, and + * padding. + * + * The values in the output tensor are computed as: + * + * output[b, i, j, k * channel_multiplier + q] = + * sum_{di, dj} ( + * input[b, strides[1] * i + di, strides[2] * j + dj, k] * + * filter[di, dj, k, q] + * ) + * + * Supported tensor types: + * * {@link ANEURALNETWORKS_TENSOR_FLOAT32} + * * {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM} + * + * Supported tensor rank: 4, with "NHWC" data layout. + * + * Inputs: + * * 0: A 4-D tensor, of shape [batches, height, width, depth_in], specifying + * the input. + * * 1: A 4-D tensor, of shape [depth_out, filter_height, filter_width, + * depth_in], specifying the filter. + * * 2: A 1-D tensor, of shape [depth_out], specifying the bias. + * For input tensor of {@link ANEURALNETWORKS_TENSOR_FLOAT32} type, the + * bias should also be of {@link ANEURALNETWORKS_TENSOR_FLOAT32}. For input + * tensor of {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM} type, the bias should + * be of {@link ANEURALNETWORKS_TENSOR_INT32}. + * * 3: An INT32 value, specifying the padding on the left, in the ‘width’ + * dimension. + * * 4: An INT32 value, specifying the padding on the right,in the ‘width’ + * dimension. + * * 5: An INT32 value, specifying the padding on the top, in the ‘height’ + * dimension. + * * 6: An INT32 value, specifying the padding on the bottom, in the ‘height’ + * dimension. + * * 7: An INT32 value, specifying the output stride in the ‘width’ dimension. + * * 8: An INT32 value, specifying the output stride in the ‘height’ + * dimension. + * * 9: An INT32 value, specifying the depthwise multiplier. + * * 10: An INT32 value, and has to be one of the {@link FuseCode} values. + * Specifies the activation to invoke on the result of each addition. + * + * Outputs: + * * 0: The output 4-D tensor, of shape [batches, out_height, out_width, + * depth_out]. + */ + ANEURALNETWORKS_DEPTHWISE_CONV_2D = 4, + /** Rearranges data from depth into blocks of spatial data. + * + * More specifically, this op outputs a copy of the input tensor where values + * from the depth dimension are moved in spatial blocks to the height and + * width dimensions. The value block_size indicates the input block size and + * how the data is moved. + * + * Chunks of data of size block_size * block_size from depth are rearranged + * into non-overlapping blocks of size block_size x block_size. + * + * The width of the output tensor is input_depth * block_size, whereas the + * height is input_height * block_size. The depth of the input tensor must be + * divisible by block_size * block_size + * + * Supported tensor types: + * * {@link ANEURALNETWORKS_TENSOR_FLOAT32} + * * {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM} + * + * Supported tensor rank: 4, with "NHWC" data layout. + * + * Inputs: + * * 0: A 4-D tensor, of shape [batches, height, width, depth_in], specifying + * the input. + * * 1: An INT32 value, specifying the block_size. block_size must be >=1 and + * block_size * block_size must be a divisor of the input depth. 
+ * + * Outputs: + * * 0: The output 4-D tensor, of shape [batch, height*block_size, + * width*block_size, depth/(block_size*block_size)]. + */ + ANEURALNETWORKS_DEPTH_TO_SPACE = 5, + /** Dequantizes the input tensor. + * + * The formula is: + * + * output = (input - zero_value) * scale. + * + * Supported tensor types: + * * {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM} + * + * Supported tensor rank: up to 4 + * + * Inputs: + * * 0: A tensor of type {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM}. + * + * Outputs: + * * 0: The output tensor of same shape as input0, but with type + * {@link ANEURALNETWORKS_TENSOR_FLOAT32}. + */ + ANEURALNETWORKS_DEQUANTIZE = 6, + + /** + * Looks up items from a given tensor. + * + * Each item in the output is a raw copy of the corresponding item in + * the input “values”. If the the given “lookup” indices are out of bounds, + * the op will fail and an error will be reported. + * + * Inputs: + * * 0: Values. An n-D tensor of any type X (where n >= 2). E.g., if n is 2, + * then the shape would be [lookup_dimension, values_dimension], where + * “lookup_dimension” corresponds to the indexing dimension in the lookup + * table, and “values_dimension” to the contents. + * * 1: Lookups. An 1-D tensor of type T, of shape [lookup_size], where + * “lookup_size” is the number of elements to look for, and each entry + * corresponds to the first dimension of the “values” tensor. + * + * Output: + * * 0: A n-D tensor of type X and the same rank and shape as the “values” + * tensor, except for the first dimension which has size “lookup_size”. + */ + ANEURALNETWORKS_EMBEDDING_LOOKUP = 7, + + /** Computes element-wise floor() on the input tensor. + * + * Supported tensor types: + * * {@link ANEURALNETWORKS_TENSOR_FLOAT32} + * + * Supported tensor rank: up to 4 + * + * Inputs: + * * 0: A tensor. + * + * Outputs: + * * 0: The output, a tensor of the same type and dimensions as input0. + */ + ANEURALNETWORKS_FLOOR = 8, + /** Denotes a fully (densely) connected layer, which connects all elements in + * the input tensor with each element in the output tensor. + * + * This layer implements the operation: + * + * outputs = activation(inputs * weights’ + bias) + * + * Supported tensor types: + * * {@link ANEURALNETWORKS_TENSOR_FLOAT32} + * * {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM} + * + * Supported tensor rank: up to 4. + * + * Inputs: + * * 0: A tensor, specifying the input. If rank is greater than 2, then it + * gets flattened to a 2-D Tensor. The 2-D Tensor is handled as if dimensions + * corresponded to shape [batch_size, input_size], where “batch_size” + * corresponds to the batching dimension, and “input_size” is the size of the + * input. + * * 1: A 2-D tensor, specifying the weights, of shape [num_units, + * input_size], where "num_units" corresponds to the number of output nodes. + * * 2: A 1-D tensor, of shape [num_units], specifying the bias. + * For input tensor of {@link ANEURALNETWORKS_TENSOR_FLOAT32} type, the + * bias should also be of {@link ANEURALNETWORKS_TENSOR_FLOAT32}. For input + * tensor of {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM} type, the bias should + * be of {@link ANEURALNETWORKS_TENSOR_INT32}. + * * 3: An INT32 value, and has to be one of the {@link FuseCode} values. + * Specifies the activation to invoke on the result of each addition. + * + * Outputs: + * * 0: The output tensor, of shape [batch_size, num_units]. + */ + ANEURALNETWORKS_FULLY_CONNECTED = 9, + + /** + * Looks up values of a hash table with given keys. + * + * Inputs: + * * 0: Lookups. 
A 1-D int32 tensor with shape [ k ]. + * * 1: Keys. A 1-D int32 tensor with shape [ n ], *MUST* be sorted in + * ascending order. + * * 2: Values. A tensor with shape [ n … ]. + * + * Outputs: + * * 0: Output. A tensor with shape [ k …]. + * * 1: Hits. A uint8 tensor with shape [ k ] indicates whether the lookup + * hits or not. + */ + ANEURALNETWORKS_HASHTABLE_LOOKUP = 10, + + /** Applies L2 normalization along the depth dimension. + * + * The values in the output tensor are computed as: + * + * output[batch, row, col, channel] = + * input[batch, row, col, channel] / + * sqrt(sum_{c} pow(input[batch, row, col, c], 2)) + * + * For x with more dimensions, independently normalizes each 1-D slice along + * dimension dim. + * + * Supported tensor types: + * * {@link ANEURALNETWORKS_TENSOR_FLOAT32} + * + * Supported tensor rank: 4, with "NHWC" data layout. + * + * Inputs: + * * 0: A 4-D tensor, of shape [batches, height, width, depth], specifying the + * input. + * + * Outputs: + * * 0: The output 4-D tensor, of shape [batches, out_height, out_width, + * depth]. + */ + ANEURALNETWORKS_L2_NORMALIZATION = 11, + + /** Performs an 2-D L2 pooling operation. + * + * The output dimensions are functions of the filter dimensions, stride, and + * padding. + * + * The values in the output tensor are computed as: + * + * output[batch, row, col, channel] = + * sqrt(sum_{i, j} pow(input[batch, row + i, col + j, channel], 2) / + * sum(1)) + * + * Supported tensor types: + * * {@link ANEURALNETWORKS_TENSOR_FLOAT32} + * + * Supported tensor rank: 4, with "NHWC" data layout. + * + * Inputs: + * * 0: A 4-D tensor, of shape [batches, height, width, depth], specifying the + * input. + * * 1: An INT32 value, specifying the padding on the left, in the ‘width’ + * dimension. + * * 2: An INT32 value, specifying the padding on the right,in the ‘width’ + * dimension. + * * 3: An INT32 value, specifying the padding on the top, in the ‘height’ + * dimension. + * * 4: An INT32 value, specifying the padding on the bottom, in the ‘height’ + * dimension. + * * 5: An INT32 value, specifying the output stride in the ‘width’ dimension. + * * 6: An INT32 value, specifying the output stride in the ‘height’ + * dimension. + * * 7: An INT32 value, specifying the filter width. + * * 8: An INT32 value, specifying the filter height. + * * 9: An INT32 value, and has to be one of the {@link FuseCode} values. + * Specifies the activation to invoke on the result of each addition. + * + * Outputs: + * * 0: The output 4-D tensor, of shape [batches, out_height, out_width, + * depth]. + */ + ANEURALNETWORKS_L2_POOL_2D = 12, + /** Applies Local Response Normalization along the depth dimension. + * + * The 4-D input tensor is treated as a 3-D array of 1-D vectors (along the + * last dimension), and each vector is normalized independently. Within a + * given vector, each component is divided by the weighted, squared sum of + * inputs within depth_radius. + * + * The output is calculated using this formula: + * + * sqr_sum[a, b, c, d] = + * sum(pow(input[a, b, c, d - depth_radius : d + depth_radius + 1], 2) + * output = input / pow((bias + alpha * sqr_sum), beta) + * + * Supported tensor types: + * * {@link ANEURALNETWORKS_TENSOR_FLOAT32} + * + * Supported tensor rank: 4, with "NHWC" data layout. + * + * Inputs: + * * 0: A 4-D tensor, of shape [batches, height, width, depth], specifying the + * input. + * * 1: An INT32 value, specifying the radius of the normalization window. + * * 2: A FLOAT32 value, specifying the bias, must not be zero. 
+ * * 3: A FLOAT32 value, specifying the scale factor, alpha. + * * 4: A FLOAT32 value, specifying the exponent, beta. + * + * Outputs: + * * 0: The output tensor of same shape as input0. + */ + ANEURALNETWORKS_LOCAL_RESPONSE_NORMALIZATION = 13, + /** Computes sigmoid activation on the input tensor element-wise. + * + * The output is calculated using this formula: + * + * output = 1 / (1 + exp(-input)) + * + * Supported tensor types: + * * {@link ANEURALNETWORKS_TENSOR_FLOAT32} + * * {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM} + * + * Supported tensor rank: up to 4. + * + * Inputs: + * * 0: A tensor, specifying the input. + * + * Outputs: + * * 0: The output tensor of same shape as input0. + */ + ANEURALNETWORKS_LOGISTIC = 14, + + /** + * Projects an input to a bit vector via locality senstive hashing. + * + * Inputs: + * * 0: Hash functions. Dim.size == 2, DataType: Float. + * Tensor[0].Dim[0]: Number of hash functions. + * Tensor[0].Dim[1]: Number of seeds per hash functions. + * Tensor[0].Dim[1] <= 32 in sparse case. + * + * * 1: Input. Dim.size >= 1, no restriction on DataType. + * * 2: Weight. Optional. Dim.size == 1, DataType: Float. + * If not set, each input element is considered to have the same weight of + * 1.0. + * Tensor[1].Dim[0] == Tensor[2].Dim[0] + * * 3: Type: + * Sparse: Value LSHProjectionType_SPARSE(=1). + * Computed bit vector is considered to be sparse. + * Each output element is an int32 made up of multiple bits computed + * from hash functions. + * + * Dense: Value LSHProjectionType_DENSE(=2). + * Computed bit vector is considered to be dense. Each output element + * represents a bit and can take the value of either 0 or 1. + * + * Outputs: + * * 0: If the projection type is sparse: + * Output.Dim == { Tensor[0].Dim[0] } + * A tensor of int32 that represents hash signatures. + * If the projection type is Dense: + * Output.Dim == { Tensor[0].Dim[0] * Tensor[0].Dim[1] } + * A flattened tensor that represents projected bit vectors. + */ + ANEURALNETWORKS_LSH_PROJECTION = 15, + + /** + * Long short-term memory unit (LSTM) recurrent network layer. + * + * The default non-peephole implementation is based on: + * http://deeplearning.cs.cmu.edu/pdfs/Hochreiter97_lstm.pdf + * S. Hochreiter and J. Schmidhuber. "Long Short-Term Memory". Neural + * Computation, 9(8):1735-1780, 1997. + * + * The peephole implementation is based on: + * https://research.google.com/pubs/archive/43905.pdf + * Hasim Sak, Andrew Senior, and Francoise Beaufays. "Long short-term memory + * recurrent neural network architectures for large scale acoustic modeling." + * INTERSPEECH, 2014. + * + * The coupling of input and forget gate (CIFG) is based on: + * http://arxiv.org/pdf/1503.04069.pdf + * Greff et al. "LSTM: A Search Space Odyssey" + * + * The class has the following independently optional inputs: + * * If input gate (if CIFG): “input_to_forget_weights”, + * “recurrent_to_input_weights”, “cell_to_input_weights”, “input_gate_bias”. + * * If no peephole connections: “cell_to_input_weights”, + * “cell_to_forget_weights”, “cell_to_output_weights”. + * * If no projection layer: “projection_weights” and “projection_bias”. + * * If no projection bias: “projection_bias”. + * + * Supported tensor types: + * * {@link ANEURALNETWORKS_TENSOR_FLOAT32} + * + * Inputs: + * * 0: Input. + * A 2-D tensor of type T, of shape [batch_size, input_size], where + * “batch_size” corresponds to the batching dimension, and “input_size” + * is the size of the input. + * * 1: input_to_input_weights. 
+ * A 2-D tensor of type T, of shape [num_units, input_size], where + * “num_units” corresponds to the number of cell units. + * * 2: input_to_forget_weights. + * A 2-D tensor of type T, of shape [num_units, input_size]. + * * 3: input_to_cell_weights. + * A 2-D tensor of type T, of shape [num_units, input_size]. + * * 4: input_to_output_weights. + * A 2-D tensor of type T, of shape [num_units, input_size]. + * * 5: recurrent_to_input_weights. + * A 2-D tensor of type T, of shape [num_units, output_size], where + * “output_size” corresponds to either the number of cell units (i.e., + * “num_units”), or the second dimension of the “projection_weights”, if + * defined. + * * 6: recurrent_to_forget_weights. + * A 2-D tensor of type T, of shape [num_units, output_size]. + * * 7: recurrent_to_cell_weights. + * A 2-D tensor of type T, of shape [num_units, output_size]. + * * 8: recurrent_to_output_weights. + * A 2-D tensor of type T, of shape [num_units, output_size]. + * * 9: cell_to_input_weights. + * A 1-D tensor of type T, of shape [num_units]. + * * 10:cell_to_forget_weights. + * A 1-D tensor of type T, of shape [num_units]. + * * 11:cell_to_output_weights. + * A 1-D tensor of type T, of shape [num_units]. + * * 12:input_gate_bias. + * A 1-D tensor of type T, of shape [num_units]. + * * 13:forget_gate_bias. + * A 1-D tensor of type T, of shape [num_units]. + * * 14:cell_bias. + * A 1-D tensor of type T, of shape [num_units]. + * * 15:output_gate_bias. + * A 1-D tensor of type T, of shape [num_units]. + * * 16:projection_weights. + * A 2-D tensor of type T, of shape [output_size, num_units]. + * * 17:projection_bias. + * A 1-D tensor of type T, of shape [output_size]. + * + * Parameters: + * * 18:fused_activation_function. + * An (optional) ActivationFunctionType indicating the activation + * function. + * If “NONE” is specified then it results in a linear activation. + * * 19:cell_clip. + * A clipping threshold for the cell state, such that values are bound + * within [-cell_clip, cell_clip]. If set to 0.0 then clipping is + * disabled. + * * 20:proj_clip. + * A clipping threshold for the output from the projection layer, such + * that values are bound within [-proj_clip, proj_clip]. If set to 0.0 + * then clipping is disabled. + * + * Outputs: + * * 0: scratch_buffer. + * A 3-D tensor of type T, of shape [batch_size, num_cell, 4]. + * * 1: output_state. + * A 2-D tensor of type T, of shape [batch_size, output_size]. + * * 2: cell_state. + * A 2-D tensor of type T, of shape [batch_size, num_units]. + * * 3: output. + * A 2-D tensor of type T, of shape [batch_size, output_size]. This is + * effectively the same as the current “output_state” value. + */ + ANEURALNETWORKS_LSTM = 16, + + /** Performs an 2-D max pooling operation. + * + * The output dimensions are functions of the filter dimensions, stride, and + * padding. + * + * The values in the output tensor are computed as: + * + * output[batch, row, col, channel] = + * max_{i, j} (input[batch, row + i, col + j, channel]) + * + * Supported tensor types: + * * {@link ANEURALNETWORKS_TENSOR_FLOAT32} + * * {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM} + * + * Supported tensor rank: 4, with "NHWC" data layout. + * + * Inputs: + * * 0: A 4-D tensor, of shape [batches, height, width, depth], specifying the + * input. + * * 1: An INT32 value, specifying the padding on the left, in the ‘width’ + * dimension. + * * 2: An INT32 value, specifying the padding on the right,in the ‘width’ + * dimension. 
+ * * 3: An INT32 value, specifying the padding on the top, in the ‘height’ + * dimension. + * * 4: An INT32 value, specifying the padding on the bottom, in the ‘height’ + * dimension. + * * 5: An INT32 value, specifying the output stride in the ‘width’ dimension. + * * 6: An INT32 value, specifying the output stride in the ‘height’ + * dimension. + * * 7: An INT32 value, specifying the filter width. + * * 8: An INT32 value, specifying the filter height. + * * 9: An INT32 value, and has to be one of the {@link FuseCode} values. + * Specifies the activation to invoke on the result of each addition. + * + * Outputs: + * * 0: The output 4-D tensor, of shape [batches, out_height, out_width, + * depth]. + */ + ANEURALNETWORKS_MAX_POOL_2D = 17, + + /** Multiplies two tensors, elment-wise. + * + * Takes two input tensors of identical type and compatible dimensions. The + * output is the product of both input tensors, optionally modified by an + * activation function. + * + * Two dimensions are compatible when: + * 1. they are equal, or + * 2. one of them is 1 + * + * The size of the resulting output is the maximum size along each dimension + * of the input operands. It starts with the trailing dimensions, and works + * its way forward. + * + * Supported tensor types: + * * {@link ANEURALNETWORKS_TENSOR_FLOAT32} + * + * Supported tensor rank: up to 4 + * + * Inputs: + * * 0: A tensor. + * * 1: A tensor of the same type, and compatible dimensions as input0. + * * 2: An INT32 value, and has to be one of the {@link FuseCode} values. + * Specifies the activation to invoke on the result of each addition. + * + * Outputs: + * * 0: The product, a tensor of the same type as input0. + */ + ANEURALNETWORKS_MUL = 18, + /** Computes rectified linear activation on the input tensor element-wise. + * + * The output is calculated using this formula: + * + * output = max(0, input) + * + * Supported tensor types: + * * {@link ANEURALNETWORKS_TENSOR_FLOAT32} + * * {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM} + * + * Supported tensor rank: up to 4. + * + * Inputs: + * * 0: A tensor, specifying the input. + * + * Outputs: + * * 0: The output tensor of same shape as input0. + */ + ANEURALNETWORKS_RELU = 19, + /** Computes rectified linear 1 activation on the input tensor element-wise. + * + * The output is calculated using this formula: + * + * output = min(1.f, max(-1.f, input)) + * + * Supported tensor types: + * * {@link ANEURALNETWORKS_TENSOR_FLOAT32} + * * {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM} + * + * Supported tensor rank: up to 4. + * + * Inputs: + * * 0: A tensor, specifying the input. + * + * Outputs: + * * 0: The output tensor of same shape as input0. + */ + ANEURALNETWORKS_RELU1 = 20, + /** Computes rectified linear 6 activation on the input tensor element-wise. + * + * The output is calculated using this formula: + * + * output = min(6, max(0, input)) + * + * Supported tensor types: + * * {@link ANEURALNETWORKS_TENSOR_FLOAT32} + * * {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM} + * + * Supported tensor rank: up to 4. + * + * Inputs: + * * 0: A tensor, specifying the input. + * + * Outputs: + * * 0: The output tensor of same shape as input0. + */ + ANEURALNETWORKS_RELU6 = 21, + /** Reshapes a tensor. + * + * Given tensor, this operation returns a tensor that has the same values as + * tensor, but with a newly specified shape. + * + * Supported tensor types: + * * {@link ANEURALNETWORKS_TENSOR_FLOAT32} + * * {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM} + * + * Supported tensor rank: up to 4. 
+ * + * Inputs: + * * 0: A tensor, specifying the tensor to be reshaped. + * * 1: A 1-D tensor of type {@link ANEURALNETWORKS_TENSOR_INT32}, defining + * the shape of the output tensor. The number of elements implied by shape + * must be the same as the number of elements in the input tensor. + * + * Outputs: + * * 0: The output tensor, of shape specified by the input shape. + */ + ANEURALNETWORKS_RESHAPE = 22, + /** Resizes images to given size using the bilinear interpretation. + * + * Resized images will be distorted if their original aspect ratio is not the + * same as input. + * + * Supported tensor types: + * * {@link ANEURALNETWORKS_TENSOR_FLOAT32} + * + * Supported tensor rank: 4, with "NHWC" data layout. + * + * Inputs: + * * 0: A 4-D tensor, of shape [batches, height, width, depth], specifying the + * input. + * * 1: An INT32 value, specifying the output width of the output tensor. + * * 2: An INT32 value, specifying the output height of the output tensor. + * + * Outputs: + * * 0: The output 4-D tensor, of shape [batches, new_height, new_width, + * depth]. + */ + ANEURALNETWORKS_RESIZE_BILINEAR = 23, + + /** + * A basic recurrent neural network layer. + * + * This layer implements the operation: + * outputs = state = activation(inputs * input_weights + state * + * recurrent_weights + bias) + * + * Where: + * * “input_weights” is a weight matrix that multiplies the inputs; + * * “recurrent_weights” is a weight matrix that multiplies the current + * “state” which itself is the output from the previous time step + * computation; + * * “bias” is a bias vector (added to each output vector in the batch); + * * “activation” is the function passed as the “fused_activation_function” + * argument (if not “NONE”). + * + * Supported tensor types: + * * {@link ANEURALNETWORKS_TENSOR_FLOAT32} + * + * Inputs: + * * 0: input. + * A 2-D tensor of type T, of shape [batch_size, input_size], where + * “batch_size” corresponds to the batching dimension, and “input_size” + * is the size of the input. + * * 1: weights. + * A 2-D tensor of type T, of shape [num_units, input_size], where + * “num_units” corresponds to the number of units. + * * 2: recurrent_weights. + * A 2-D tensor of type T, of shape [num_units, num_units], with columns + * corresponding to the weights from each unit. + * * 3: bias. + * A 1-D tensor of type T, of shape [num_units]. + * + * For FLOAT32 input tensor, bias must also be FLOAT32. + * For UINT8 input tensor, bias must be INT32. + * + * Parameters + * * 4: fused_activation_function. + * An (optional) ActivationFunctionType indicating the activation + * function. If “NONE” is specified then it results in a linear + * activation. + * + * * 5: Hidden state. + * A 2-D tensor of type T, of shape [batch_size, num_units]. + * + * Outputs: + * * 0: output. + * A 2-D tensor of type T, of shape [batch_size, num_units]. This is + * effectively the same as the current state value. + */ + ANEURALNETWORKS_RNN = 24, + + /** Computes the softmax activation on the input tensor element-wise, per + * batch, by normalizing the input vector so the maximum coefficient is zero. + * + * The output is calculated using this formula: + * + * output[batch, i] = + * exp((input[batch, i] - max(input[batch, :])) * beta) / + * sum_{k}{exp((input[batch, k] - max(input[batch, :])) * beta)} + * + * Supported tensor types: + * * {@link ANEURALNETWORKS_TENSOR_FLOAT32} + * * {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM} + * + * Supported tensor rank: 2 or 4. 
+ * + * Inputs: + * * 0: A 2-D or 4-D tensor, specifying the tensor to be reshaped. + * * 1: A FLOAT32 value, specifying the scaling factor for the exponent, beta. + * + * Outputs: + * * 0: The output tensor of same shape as input0. + */ + ANEURALNETWORKS_SOFTMAX = 25, + + /** Rearranges blocks of spatial data, into depth. + * + * More specifically, this op outputs a copy of the input tensor where values + * from the height and width dimensions are moved to the depth dimension. The + * value block_size indicates the input block size and how the data is moved. + * + * Chunks of data of size block_size * block_size from depth are rearranged + * into non-overlapping blocks of size block_size x block_size. + * + * The depth of the output tensor is input_depth * block_size * block_size. + * The input tensor's height and width must be divisible by block_size. + * + * Supported tensor types: + * * {@link ANEURALNETWORKS_TENSOR_FLOAT32} + * * {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM} + * + * Supported tensor rank: 4, with "NHWC" data layout. + * + * Inputs: + * * 0: A 4-D tensor, of shape [batches, height, width, depth_in], specifying + * the input. + * * 1: An INT32 value, specifying the block_size. block_size must be >=1 and + * block_size must be a divisor of both the input height and width. + * + * Outputs: + * * 0: The output 4-D tensor, of shape [batch, height/block_size, + * width/block_size, depth*block_size*block_size]. + */ + ANEURALNETWORKS_SPACE_TO_DEPTH = 26, + + /** + * SVDF op is a kind of stateful layer derived from the notion that a + * densely connected layer that's processing a sequence of input frames can + * be approximated by using a singular value decomposition of each of its + * nodes. The implementation is based on: + * + * https://research.google.com/pubs/archive/43813.pdf + * + * P. Nakkiran, R. Alvarez, R. Prabhavalkar, C. Parada. + * “Compressing Deep Neural Networks using a Rank-Constrained Topology”. + * INTERSPEECH, 2015. + * + * It processes the incoming input using a 2-stage filtering mechanism: + * * stage 1 performs filtering on the "features" dimension, whose outputs get + * pushed into a memory of fixed-size memory_size. + * * stage 2 performs filtering on the "time" dimension of the memory_size + * memoized outputs of stage 1. + * + * Specifically, for rank 1, this layer implements the operation: + * + * memory = push(conv1d(inputs, weights_feature, feature_dim, "VALID")); + * outputs = activation(memory * weights_time + bias); + * + * Where: + * * “weights_feature” is a weights matrix that processes the inputs (by + * convolving the input with every “feature filter”), and whose outputs get + * pushed, stacked in order, into the fixed-size “memory” (the oldest entry + * gets dropped); + * * “weights_time” is a weights matrix that processes the “memory” (by a + * batched matrix multiplication on the num_units); + * * “bias” is an optional bias vector (added to each output vector in the + * batch); and + * * “activation” is the function passed as the “fused_activation_function” + * argument (if not “NONE”). + * + * Each rank adds a dimension to the weights matrices by means of stacking + * the filters. + * + * Supported tensor types: + * * {@link ANEURALNETWORKS_TENSOR_FLOAT32} + * + * Inputs: + * * 0: input. + * A 2-D tensor of type T, of shape [batch_size, input_size], where + * “batch_size” corresponds to the batching dimension, and “input_size” + * is the size of the input. + * * 1: weights_feature. 
+ * A 2-D tensor of type T, of shape [num_units, input_size], where + * “num_units” corresponds to the number of units. + * * 2: weights_time. + * A 2-D tensor of type T, of shape [num_units, memory_size], where + * “memory_size” corresponds to the fixed-size of the memory. + * * 3: bias. + * A optional 1-D tensor of type T, of shape [num_units]. + * + * For FLOAT32 input tensor, bias must also be FLOAT32. + * For UINT8 input tensor, bias must be INT32. + * + * Parameters: + * * 4: rank. + * The rank of the SVD approximation. + * * 5: fused_activation_function. + * An (optional) ActivationFunctionType indicating the activation + * function. If “NONE” is specified then it results in a linear activation. + * + * Outputs: + * * 0: state. + * A 2-D tensor of type T, of shape [batch_size, (memory_size - 1) * + * num_units * rank]. + * * 1: output. + * A 2-D tensor of type T, of shape [batch_size, num_units]. + */ + ANEURALNETWORKS_SVDF = 27, + + /** Computes hyperbolic tangent of input tensor element-wise. + * + * The output is calculated using this formula: + * + * output = tanh(input) + * + * Supported tensor types: + * * {@link ANEURALNETWORKS_TENSOR_FLOAT32} + * + * Supported tensor rank: up to 4. + * + * Inputs: + * * 0: A tensor, specifying the input. + * + * Outputs: + * * 0: The output tensor of same shape as input0. + */ + ANEURALNETWORKS_TANH = 28, +}; + +/** + * Fused activation function types. + * + */ +enum { + /** NO fused activation function. */ + ANEURALNETWORKS_FUSED_NONE = 0, + /** Fused ReLU activation function. */ + ANEURALNETWORKS_FUSED_RELU = 1, + /** Fused ReLU1 activation function. */ + ANEURALNETWORKS_FUSED_RELU1 = 2, + /** Fused ReLU6 activation function. */ + ANEURALNETWORKS_FUSED_RELU6 = 3, +}; + +/** + * Execution preferences. + */ +enum { + /** + * Prefer executing in a way that minimizes battery drain. + * This is desirable for compilations that will be executed often. + */ + ANEURALNETWORKS_PREFER_LOW_POWER = 0, + /** + * Prefer returning a single answer as fast as possible, even if this causes + * more power consumption. + */ + ANEURALNETWORKS_PREFER_FAST_SINGLE_ANSWER = 1, + /** + * Prefer maximizing the throughput of successive frames, for example when + * processing successive frames coming from the camera. + */ + ANEURALNETWORKS_PREFER_SUSTAINED_SPEED = 2, +}; + +/** + * Result codes. + */ +enum { + ANEURALNETWORKS_NO_ERROR = 0, + ANEURALNETWORKS_OUT_OF_MEMORY = 1, + ANEURALNETWORKS_INCOMPLETE = 2, + ANEURALNETWORKS_UNEXPECTED_NULL = 3, + ANEURALNETWORKS_BAD_DATA = 4, + ANEURALNETWORKS_OP_FAILED = 5, + ANEURALNETWORKS_UNMAPPABLE = 5, + ANEURALNETWORKS_BAD_STATE = 6, +}; + +/** + * ANeuralNetworksMemory is an opaque type that represents memory. + * + * This type is used to represent shared memory, memory mapped files, + * and similar memories. + * + * By using shared memory, a program can efficiently communicate to the + * runtime and drivers the tensors that define a model. See + * {@link ANeuralNetworksModel_setOperandValueFromMemory}. An application + * should typically create one shared memory object that contains every tensor + * needed to define a model. {@link ANeuralNetworksMemory_createFromFd} can be + * used to create shared memory from a file handle. {@link + * ANeuralNetworksMemory_createShared} can be used to directly created shared + * memory. + * + * Memory objects can also be used to specify the input and output arguments of + * an execution. 
See {@link ANeuralNetworksExecution_setInputFromMemory} + * and {@link ANeuralNetworksExecution_setOutputFromMemory}. + */ +typedef struct ANeuralNetworksMemory ANeuralNetworksMemory; + +/** + * ANeuralNetworksModel is an opaque type that contains a description of the + * mathematical operations that constitute the model. + * + *

The model will be built by calling
+ * - {@link ANeuralNetworksModel_create},
+ * - {@link ANeuralNetworksModel_addOperation},
+ * - {@link ANeuralNetworksModel_addOperand},
+ * + * A model is completed by calling {@link ANeuralNetworksModel_finish}. + * A model is destroyed by calling {@link ANeuralNetworksModel_free}. + * + *

It is the application's responsibility to make sure that only one thread
+ * modifies a model at a given time. It is however safe for more than one
+ * thread to use the model once {@link ANeuralNetworksModel_finish} has
+ * returned.
+ *
+ * It is also the application's responsibility to ensure that there are no
+ * other uses of the model after calling {@link ANeuralNetworksModel_free}.
+ * This includes any compilation or execution object created using the model.
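The build sequence described above can be sketched as follows. This is an illustrative example only, not part of the original header: it assumes this shim and <cstdint> are included, uses an ANEURALNETWORKS_ADD operation with a fused-activation scalar operand (mirroring what AddOpsAndParams in nnapi_delegate.cc later in this change appends for BuiltinOperator_ADD), and the NN_CHECK helper and tensor shapes are arbitrary choices.

// Illustrative sketch: build a one-operation model that adds two float
// tensors of shape [2]. Operand indices follow the order of addOperand calls.
#define NN_CHECK(call)                                    \
  do {                                                    \
    if ((call) != ANEURALNETWORKS_NO_ERROR) return false; \
  } while (0)

bool BuildTinyAddModel(ANeuralNetworksModel** out_model) {
  NN_CHECK(ANeuralNetworksModel_create(out_model));
  ANeuralNetworksModel* model = *out_model;

  // Operands 0 and 1: the two float input tensors of shape [2].
  uint32_t dims[1] = {2};
  ANeuralNetworksOperandType tensor_type{ANEURALNETWORKS_TENSOR_FLOAT32, 1,
                                         dims, 0.0f, 0};
  NN_CHECK(ANeuralNetworksModel_addOperand(model, &tensor_type));  // index 0
  NN_CHECK(ANeuralNetworksModel_addOperand(model, &tensor_type));  // index 1

  // Operand 2: the fused-activation scalar consumed by ANEURALNETWORKS_ADD.
  ANeuralNetworksOperandType scalar_type{ANEURALNETWORKS_INT32, 0, nullptr,
                                         0.0f, 0};
  NN_CHECK(ANeuralNetworksModel_addOperand(model, &scalar_type));  // index 2
  int32_t fuse_none = ANEURALNETWORKS_FUSED_NONE;
  NN_CHECK(ANeuralNetworksModel_setOperandValue(model, 2, &fuse_none,
                                                sizeof(fuse_none)));

  // Operand 3: the output tensor.
  NN_CHECK(ANeuralNetworksModel_addOperand(model, &tensor_type));  // index 3

  // Wire one ADD operation: inputs {0, 1, 2} -> output {3}.
  uint32_t op_inputs[] = {0, 1, 2};
  uint32_t op_outputs[] = {3};
  NN_CHECK(ANeuralNetworksModel_addOperation(model, ANEURALNETWORKS_ADD, 3,
                                             op_inputs, 1, op_outputs));

  // Declare the model's external inputs/outputs, then freeze it.
  uint32_t model_inputs[] = {0, 1};
  uint32_t model_outputs[] = {3};
  NN_CHECK(ANeuralNetworksModel_identifyInputsAndOutputs(
      model, 2, model_inputs, 1, model_outputs));
  NN_CHECK(ANeuralNetworksModel_finish(model));
  return true;
}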

+ */ +typedef struct ANeuralNetworksModel ANeuralNetworksModel; + +/** + * ANeuralNetworksCompilation is an opaque type that can be used to compile + * a machine learning model. + * + *

To use:
+ * - Create a new compilation instance by calling the
+ *   {@link ANeuralNetworksCompilation_create} function.
+ * - Perform the compilation with
+ *   {@link ANeuralNetworksCompilation_start}.
+ * - Wait for the compilation to complete with
+ *   {@link ANeuralNetworksCompilation_wait}.
+ * - Use the compilation as many times as needed with
+ *   {@link ANeuralNetworksExecution_create}.
+ * - Destroy the compilation with
+ *   {@link ANeuralNetworksCompilation_free} once all executions using the
+ *   compilation have completed.

+ * + *

A compilation cannot be modified once
+ * {@link ANeuralNetworksCompilation_start} has been called on it.
+ *
+ * It is the application's responsibility to make sure that only one thread
+ * modifies a compilation at a given time. It is however safe for more than
+ * one thread to use {@link ANeuralNetworksCompilation_wait} at the same time.
+ * It is also safe for multiple threads to use a compilation object once
+ * {@link ANeuralNetworksCompilation_wait} has completed.
+ *
+ * It is also the application's responsibility to ensure that there are no
+ * other uses of the compilation after calling
+ * {@link ANeuralNetworksCompilation_free}. This includes any execution
+ * object created using the compilation.
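A brief sketch of how these steps map onto the wrappers this shim actually declares may help; it is illustrative only and not part of the original header. Since the shim exposes ANeuralNetworksCompilation_setPreference and ANeuralNetworksCompilation_finish rather than the start/wait pair mentioned above, the sketch uses those, and the preference value is an arbitrary choice.

// Illustrative sketch: compile a model that has already been finished with
// ANeuralNetworksModel_finish. Returns nullptr on failure; the caller owns
// (and must eventually free) the returned compilation.
ANeuralNetworksCompilation* CompileModel(ANeuralNetworksModel* finished_model) {
  ANeuralNetworksCompilation* compilation = nullptr;
  if (ANeuralNetworksCompilation_create(finished_model, &compilation) !=
      ANEURALNETWORKS_NO_ERROR) {
    return nullptr;
  }
  // Optional hint; any of the three execution preferences is acceptable here.
  ANeuralNetworksCompilation_setPreference(
      compilation, ANEURALNETWORKS_PREFER_FAST_SINGLE_ANSWER);
  if (ANeuralNetworksCompilation_finish(compilation) !=
      ANEURALNETWORKS_NO_ERROR) {
    ANeuralNetworksCompilation_free(compilation);
    return nullptr;
  }
  return compilation;
}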

+ */ +typedef struct ANeuralNetworksCompilation ANeuralNetworksCompilation; + +/** + * ANeuralNetworksExecution is an opaque type that can be used to apply a + * machine learning model to a set of inputs. + * + *

To use:
+ * - Create a new execution instance by calling the
+ *   {@link ANeuralNetworksExecution_create} function.
+ * - Associate data to the model inputs with
+ *   {@link ANeuralNetworksExecution_setInput} or
+ *   {@link ANeuralNetworksExecution_setInputFromMemory}.
+ * - Associate output buffers to the model outputs with
+ *   {@link ANeuralNetworksExecution_setOutput} or
+ *   {@link ANeuralNetworksExecution_setOutputFromMemory}.
+ * - Apply the model with
+ *   {@link ANeuralNetworksExecution_startCompute}.
+ * - Wait for the execution to complete with
+ *   {@link ANeuralNetworksExecution_wait}.
+ * - Destroy the execution with
+ *   {@link ANeuralNetworksExecution_free}.

+ * + *

An execution cannot be modified once
+ * {@link ANeuralNetworksExecution_start} has been called on it.
+ *
+ * An execution can be applied to a model with
+ * {@link ANeuralNetworksExecution_startCompute} only once. Create new
+ * executions to do new evaluations of the model.
+ *
+ * It is the application's responsibility to make sure that only one thread
+ * modifies an execution at a given time. It is however safe for more than
+ * one thread to use {@link ANeuralNetworksExecution_wait} at the same time.
+ *
+ * It is also the application's responsibility to ensure that there are no
+ * other uses of the execution after calling
+ * {@link ANeuralNetworksExecution_free}.
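The per-run flow described above is sketched below with the wrappers declared in this shim; it is illustrative only, assumes <cstddef> for size_t, and the operand indices, buffer sizes, and float element type are assumptions rather than anything mandated by the header.

// Illustrative sketch: run one blocking evaluation of a compiled model.
// 'in0', 'in1' and 'out' are caller-owned buffers of 'n' floats that must
// remain valid until the computation has completed.
int RunOnce(ANeuralNetworksCompilation* compilation, const float* in0,
            const float* in1, float* out, size_t n) {
  ANeuralNetworksExecution* execution = nullptr;
  int status = ANeuralNetworksExecution_create(compilation, &execution);
  if (status != ANEURALNETWORKS_NO_ERROR) return status;

  // Passing nullptr for the type reuses the operand types from the model.
  ANeuralNetworksExecution_setInput(execution, 0, nullptr, in0,
                                    n * sizeof(float));
  ANeuralNetworksExecution_setInput(execution, 1, nullptr, in1,
                                    n * sizeof(float));
  ANeuralNetworksExecution_setOutput(execution, 0, nullptr, out,
                                     n * sizeof(float));

  // Schedule the evaluation, then block on the returned event.
  ANeuralNetworksEvent* event = nullptr;
  status = ANeuralNetworksExecution_startCompute(execution, &event);
  if (status == ANEURALNETWORKS_NO_ERROR) {
    status = ANeuralNetworksEvent_wait(event);
    ANeuralNetworksEvent_free(event);
  }
  ANeuralNetworksExecution_free(execution);
  return status;
}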

+ */ +typedef struct ANeuralNetworksExecution ANeuralNetworksExecution; + +/** + * ANeuralNetworksOperandType describes the type of an operand. + * This structure is used to describe both scalars and tensors. + */ +typedef struct ANeuralNetworksOperandType { + /** The data type, e.g ANEURALNETWORKS_INT8. */ + int32_t type; + /** The number of dimensions. It should be 0 for scalars. */ + uint32_t dimensionCount; + /** The dimensions of the tensor. It should be nullptr for scalars. */ + const uint32_t* dimensions; + /** These two fields are only used for quantized tensors. + * They should be zero for scalars and non-fixed point tensors. + * The dequantized value of each entry is (value - offset) * scale. + */ + float scale; + int32_t zeroPoint; +} ANeuralNetworksOperandType; + +/** + * ANeuralNetworksEvent is an opaque type that represents an event + * that will be signaled once an execution completes. + */ +typedef struct ANeuralNetworksEvent ANeuralNetworksEvent; + +typedef int32_t ANeuralNetworksOperationType; + +// nn api function types + +typedef int (*ANeuralNetworksMemory_createFromFd_fn)( + size_t size, int protect, int fd, size_t offset, + ANeuralNetworksMemory** memory); + +typedef void (*ANeuralNetworksMemory_free_fn)(ANeuralNetworksMemory* memory); + +typedef int (*ANeuralNetworksModel_create_fn)(ANeuralNetworksModel** model); + +typedef int (*ANeuralNetworksModel_finish_fn)(ANeuralNetworksModel* model); + +typedef void (*ANeuralNetworksModel_free_fn)(ANeuralNetworksModel* model); + +typedef int (*ANeuralNetworksCompilation_create_fn)( + ANeuralNetworksModel* model, ANeuralNetworksCompilation** compilation); + +typedef void (*ANeuralNetworksCompilation_free_fn)( + ANeuralNetworksCompilation* compilation); + +typedef int (*ANeuralNetworksCompilation_setPreference_fn)( + ANeuralNetworksCompilation* compilation, int32_t preference); + +typedef int (*ANeuralNetworksCompilation_finish_fn)( + ANeuralNetworksCompilation* compilation); + +typedef int (*ANeuralNetworksModel_addOperand_fn)( + ANeuralNetworksModel* model, const ANeuralNetworksOperandType* type); + +typedef int (*ANeuralNetworksModel_setOperandValue_fn)( + ANeuralNetworksModel* model, int32_t index, const void* buffer, + size_t length); + +typedef int (*ANeuralNetworksModel_setOperandValueFromMemory_fn)( + ANeuralNetworksModel* model, int32_t index, + const ANeuralNetworksMemory* memory, size_t offset, size_t length); + +typedef int (*ANeuralNetworksModel_addOperation_fn)( + ANeuralNetworksModel* model, ANeuralNetworksOperationType type, + uint32_t inputCount, const uint32_t* inputs, uint32_t outputCount, + const uint32_t* outputs); + +typedef int (*ANeuralNetworksModel_identifyInputsAndOutputs_fn)( + ANeuralNetworksModel* model, uint32_t inputCount, const uint32_t* inputs, + uint32_t outputCount, const uint32_t* outputs); + +typedef int (*ANeuralNetworksExecution_create_fn)( + ANeuralNetworksCompilation* compilation, + ANeuralNetworksExecution** execution); + +typedef void (*ANeuralNetworksExecution_free_fn)( + ANeuralNetworksExecution* execution); + +typedef int (*ANeuralNetworksExecution_setInput_fn)( + ANeuralNetworksExecution* execution, int32_t index, + const ANeuralNetworksOperandType* type, const void* buffer, size_t length); + +typedef int (*ANeuralNetworksExecution_setInputFromMemory_fn)( + ANeuralNetworksExecution* execution, int32_t index, + const ANeuralNetworksOperandType* type, const ANeuralNetworksMemory* memory, + size_t offset, size_t length); + +typedef int (*ANeuralNetworksExecution_setOutput_fn)( + 
ANeuralNetworksExecution* execution, int32_t index, + const ANeuralNetworksOperandType* type, void* buffer, size_t length); + +typedef int (*ANeuralNetworksExecution_setOutputFromMemory_fn)( + ANeuralNetworksExecution* execution, int32_t index, + const ANeuralNetworksOperandType* type, const ANeuralNetworksMemory* memory, + size_t offset, size_t length); + +typedef int (*ANeuralNetworksExecution_startCompute_fn)( + ANeuralNetworksExecution* execution, ANeuralNetworksEvent** event); + +typedef int (*ANeuralNetworksEvent_wait_fn)(ANeuralNetworksEvent* event); + +typedef void (*ANeuralNetworksEvent_free_fn)(ANeuralNetworksEvent* event); + +/** + * Creates a shared memory object from a file descriptor. + * + * The shared memory is backed by a file descriptor via mmap. + * See {@link ANeuralNetworksMemory} for a description on how to use + * this shared memory. + * + * @param size The requested size in bytes. + * Must not be larger than the file size. + * @param prot The desired memory protection for the mapping. + * It is either PROT_NONE or the bitwise OR of one or + * more of the following flags: PROT_READ, PROT_WRITE. + * @param fd The requested file descriptor. + * The file descriptor has to be mmap-able. The file + * descriptor will be duplicated. + * @param offset The offset to the beginning of the file of the area to map. + * The offset has to be aligned to a page size. + * @param memory The memory object to be created. + * Set to NULL if unsuccessful. + * + * @return ANEURALNETWORKS_NO_ERROR if the request completed normally. + */ +inline int ANeuralNetworksMemory_createFromFd(size_t size, int protect, int fd, + size_t offset, + ANeuralNetworksMemory** memory) { + LOAD_FUNCTION(ANeuralNetworksMemory_createFromFd); + EXECUTE_FUNCTION_RETURN(size, protect, fd, offset, memory); +} + +/** + * Delete a memory object. + * + * Destroys the object used by the run time to keep track of the memory. + * This will free the underlying actual memory if no other code has open + * handles to this memory. + * + * @param memory The memory object to be freed. + */ +inline void ANeuralNetworksMemory_free(ANeuralNetworksMemory* memory) { + LOAD_FUNCTION(ANeuralNetworksMemory_free); + EXECUTE_FUNCTION(memory); +} + +/** + * Create an empty {@link ANeuralNetworksModel}. + * + *

This only creates the object. Computation is performed once
+ * {@link ANeuralNetworksExecution_startCompute} is invoked.
+ *
+ * The model should be constructed with calls to
+ * {@link ANeuralNetworksModel_addOperation} and
+ * {@link ANeuralNetworksModel_addOperand}.
+ *
+ * {@link ANeuralNetworksModel_finish} should be called once the model
+ * has been fully constructed.
+ *
+ * {@link ANeuralNetworksModel_free} should be called once the model
+ * is no longer needed.

+ * + * @param model The {@link ANeuralNetworksModel} to be created. + * Set to NULL if unsuccessful. + * + * @return ANEURALNETWORKS_NO_ERROR if successful. + */ +inline int ANeuralNetworksModel_create(ANeuralNetworksModel** model) { + LOAD_FUNCTION(ANeuralNetworksModel_create); + EXECUTE_FUNCTION_RETURN(model); +} + +/** + * Destroy a model. + * + * The model need not have been finished by a call to + * {@link ANeuralNetworksModel_finish}. + * + * See {@link ANeuralNetworksModel} for information on multithreaded usage. + * + * @param model The model to be destroyed. Passing NULL is acceptable and + * results in no operation. + */ +inline void ANeuralNetworksModel_free(ANeuralNetworksModel* model) { + LOAD_FUNCTION(ANeuralNetworksModel_free); + EXECUTE_FUNCTION(model); +} + +/** + * Indicate that we have finished modifying a model. Required before + * calling {@link ANeuralNetworksCompilation_compile}. + * + * An application is responsible to make sure that no other thread uses + * the model at the same time. + * + * See {@link ANeuralNetworksModel} for information on multithreaded usage. + * + * @param model The model to be finished. + * + * @return ANEURALNETWORKS_NO_ERROR if successful. + */ +inline int ANeuralNetworksModel_finish(ANeuralNetworksModel* model) { + LOAD_FUNCTION(ANeuralNetworksModel_finish); + EXECUTE_FUNCTION_RETURN(model); +} + +/** + * Add an operand to a model. + * + * The order in which the operands are added is important. The first one added + * to a model will have the index value 0, the second 1, etc. These indexes are + * used as operand identifiers in {@link ANeuralNetworksModel_addOperation}, + * {@link ANeuralNetworksExecution_setInput}, + * {@link ANeuralNetworksExecution_setInputFromMemory}, + * {@link ANeuralNetworksExecution_setOutput}, + * {@link ANeuralNetworksExecution_setOutputFromMemory} and + * {@link ANeuralNetworksExecution_setOperandValue}. + * + * To build a model that can accomodate inputs of various sizes, as you may want + * to do for a CNN, set the size of the dimensions that will vary at run time to + * 0. If you do so, provide the full dimensions when calling + * {@link ANeuralNetworksExecution_setInput} or {@link + * ANeuralNetworksExecution_setInputFromMemory}. + * + * Attempting to modify a model once {@link ANeuralNetworksModel_finish} has + * been called will return an error. + * + * See {@link ANeuralNetworksModel} for information on multithreaded usage. + * + * @param model The model to be modified. + * @param type The {@link ANeuralNetworksOperandType} that describes the shape + * of the operand. + * + * @return ANEURALNETWORKS_NO_ERROR if successful. + */ +inline int ANeuralNetworksModel_addOperand( + ANeuralNetworksModel* model, const ANeuralNetworksOperandType* type) { + LOAD_FUNCTION(ANeuralNetworksModel_addOperand); + EXECUTE_FUNCTION_RETURN(model, type); +} + +/** + * Sets an operand to a constant value. + * + * For scalar values, the content of buffer is copied into the model. + * + * For tensor values, a pointer to the buffer is stored within the model. + * The application is responsible for not changing the content of this region + * until all executions using this model have completed. As the data may + * be copied during processing, modifying the data after this call yields + * undefined results. + * + * Attempting to modify a model once {@link ANeuralNetworksModel_finish} has + * been called will return an error. + * + * See {@link ANeuralNetworksModel} for information on multithreaded usage. 
+ * + * @param model The model to be modified. + * @param index The index of the model operand we're setting. + * @param buffer A pointer to the data to use. + * @param length The size in bytes of the data value. + * + * @return ANEURALNETWORKS_NO_ERROR if successful. + */ +inline int ANeuralNetworksModel_setOperandValue(ANeuralNetworksModel* model, + int32_t index, + const void* buffer, + size_t length) { + LOAD_FUNCTION(ANeuralNetworksModel_setOperandValue); + EXECUTE_FUNCTION_RETURN(model, index, buffer, length); +} + +/** + * Sets an operand to a value stored in a memory object. + * + * The content of the memory is not copied. A reference to that memory is stored + * inside the model. The application is responsible for not changing the content + * of the memory region until all executions using this model have completed. + * As the data may be copied during processing, modifying the data after this + * call yields undefined results. + * + * Attempting to modify a model once {@link ANeuralNetworksModel_finish} has + * been called will return an error. + * + * See {@link ANeuralNetworksModel} for information on multithreaded usage. + * + * @param model The model to be modified. + * @param index The index of the model operand we're setting. + * @param buffer A pointer to the data to use. + * @param memory The memory containing the data. + * @param offset This specifies the location of the data within the memory. + * The offset is in bytes from the start of memory. + * @param length The size in bytes of the data value. + * + * @return ANEURALNETWORKS_NO_ERROR if successful. + */ +inline int ANeuralNetworksModel_setOperandValueFromMemory( + ANeuralNetworksModel* model, int32_t index, + const ANeuralNetworksMemory* memory, size_t offset, size_t length) { + LOAD_FUNCTION(ANeuralNetworksModel_setOperandValueFromMemory); + EXECUTE_FUNCTION_RETURN(model, index, memory, offset, length); +} + +/** + * Add an operation to a model. + * + * @param model The model to be modified. + * @param type The type of the operation. + * @param inputCount The number of entries in the inputs array. + * @param inputs An array of indexes identifying each operand. + * @param outputCount The number of entries in the outputs array. + * @param outputs An array of indexes identifying each operand. + * + * The operands specified by inputs and outputs must have been + * previously added by calls to {@link ANeuralNetworksModel_addOperand}. + * + * Attempting to modify a model once {@link ANeuralNetworksModel_finish} has + * been called will return an error. + * + * See {@link ANeuralNetworksModel} for information on multithreaded usage. + * + * @return ANEURALNETWORKS_NO_ERROR if successful. + */ +inline int ANeuralNetworksModel_addOperation(ANeuralNetworksModel* model, + ANeuralNetworksOperationType type, + uint32_t inputCount, + const uint32_t* inputs, + uint32_t outputCount, + const uint32_t* outputs) { + LOAD_FUNCTION(ANeuralNetworksModel_addOperation); + EXECUTE_FUNCTION_RETURN(model, type, inputCount, inputs, outputCount, + outputs); +} + +/** + * Specfifies which operands will be the model's inputs and outputs. + * + * An operand cannot be used for both input and output. Doing so will + * return an error. + * + * @param model The model to be modified. + * @param inputCount The number of entries in the inputs array. + * @param inputs An array of indexes identifying the input operands. + * @param outputCount The number of entries in the outputs array. 
+ * @param outputs An array of indexes identifying the output operands. + * + * The operands specified by inputs and outputs must have been + * previously added by calls to {@link ANeuralNetworksModel_addOperand}. + * + * Attempting to modify a model once {@link ANeuralNetworksModel_finish} has + * been called will return an error. + * + * See {@link ANeuralNetworksModel} for information on multithreaded usage. + * + */ +inline int ANeuralNetworksModel_identifyInputsAndOutputs( + ANeuralNetworksModel* model, uint32_t inputCount, const uint32_t* inputs, + uint32_t outputCount, const uint32_t* outputs) { + LOAD_FUNCTION(ANeuralNetworksModel_identifyInputsAndOutputs); + EXECUTE_FUNCTION_RETURN(model, inputCount, inputs, outputCount, outputs); +} + +/** + * Create a {@link ANeuralNetworksCompilation} to compile the given model. + * This only creates the object. Compilation is only performed once + * {@link ANeuralNetworksCompilation_start} is invoked. + * + *

The provided model must outlive the compilation.

+ * + * The model must already have been finished by a call to + * {@link ANeuralNetworksModel_finish}. + * + * See {@link ANeuralNetworksCompilation} for information on multithreaded + * usage. + * + * @param model The {@link ANeuralNetworksModel} to be compiled. + * @param compilation The newly created object or NULL if unsuccessful. + * + * @return ANEURALNETWORKS_NO_ERROR if successful, ANEURALNETWORKS_BAD_DATA + * if the model is invalid. + */ +inline int ANeuralNetworksCompilation_create( + ANeuralNetworksModel* model, ANeuralNetworksCompilation** compilation) { + LOAD_FUNCTION(ANeuralNetworksCompilation_create); + EXECUTE_FUNCTION_RETURN(model, compilation); +} + +/** + * Destroy a compilation. + * + *

If called on a compilation for which + * {@link ANeuralNetworksCompilation_start} has been called, the + * function will return immediately but will mark the compilation to be deleted + * once the compilation completes. The {@link ANeuralNetworksCompilation_wait} + * will return ERROR_DELETED. + * + * See {@link ANeuralNetworksCompilation} for information on multithreaded + * usage. + * + * @param compilation The compilation to be destroyed. Passing NULL is + * acceptable and results in no operation. + */ +inline void ANeuralNetworksCompilation_free( + ANeuralNetworksCompilation* compilation) { + LOAD_FUNCTION(ANeuralNetworksCompilation_free); + EXECUTE_FUNCTION(compilation); +} + +/** + * Sets the execution preference. + * + *

Provides guidance to the runtime when trade-offs are possible.

+ * + * See {@link ANeuralNetworksCompilation} for information on multithreaded + * usage. + * + * @param compilation The compilation to be modified. + * @param preference Either {@link PREFER_LOW_POWER}, + * {@link PREFER_SINGLE_FAST_ANSWER}, or + * {@link PREFER_SUSTAINED_SPEED}. + * + * @return ANEURALNETWORKS_NO_ERROR if successful. + */ +inline int ANeuralNetworksCompilation_setPreference( + ANeuralNetworksCompilation* compilation, int32_t preference) { + LOAD_FUNCTION(ANeuralNetworksCompilation_setPreference); + EXECUTE_FUNCTION_RETURN(compilation, preference); +} + +/** + * Waits until the compilation completes. + * + * More than one thread can wait on a compilation. When the compilation + * completes, all threads will be released. + * + * See {@link ANeuralNetworksCompilation} for information on multithreaded + * usage. + * + * @return ANEURALNETWORKS_NO_ERROR if the compilation completed normally. + */ +inline int ANeuralNetworksCompilation_finish( + ANeuralNetworksCompilation* compilation) { + LOAD_FUNCTION(ANeuralNetworksCompilation_finish); + EXECUTE_FUNCTION_RETURN(compilation); +} +/** + * Create a {@link ANeuralNetworksExecution} to apply the given compilation. + * This only creates the object. Computation is only performed once + * {@link ANeuralNetworksExecution_startCompute} is invoked. + * + *

The provided compilation must outlive the execution.

+ * + * See {@link ANeuralNetworksExecution} for information on multithreaded usage. + * + * @param compilation The {@link ANeuralNetworksCompilation} to be evaluated. + * @param execution The newly created object or NULL if unsuccessful. + * + * @return ANEURALNETWORKS_NO_ERROR if successful, ANEURALNETWORKS_BAD_DATA + * if the compilation is invalid. + */ +inline int ANeuralNetworksExecution_create( + ANeuralNetworksCompilation* compilation, + ANeuralNetworksExecution** execution) { + LOAD_FUNCTION(ANeuralNetworksExecution_create); + EXECUTE_FUNCTION_RETURN(compilation, execution); +} + +/** + * Destroy an execution. + * + *

If called on an execution for which + * {@link ANeuralNetworksExecution_startCompute} has been called, the + * function will return immediately but will mark the execution to be deleted + * once the computation completes. The {link ANeuralNetworksExecution_wait} + * will return ANEURALNETWORKS_ERROR_DELETED. + * + * See {@link ANeuralNetworksExecution} for information on multithreaded usage. + * + * @param execution The execution to be destroyed. Passing NULL is acceptable + * and results in no operation. + */ +inline void ANeuralNetworksExecution_free(ANeuralNetworksExecution* execution) { + LOAD_FUNCTION(ANeuralNetworksExecution_free); + EXECUTE_FUNCTION(execution); +} + +/** + * Associate a user buffer with an input of the model of the + * {@link ANeuralNetworksExecution}. + * + *

The provided buffer must outlive the execution.

+ * + * See {@link ANeuralNetworksExecution} for information on multithreaded usage. + * + * @param execution The execution to be modified. + * @param index The index of the input argument we are setting. It is + * an index into the lists passed to + * {@link ANeuralNetworksModel_identifyInputsAndOutputs}. It is not + * the index associated with {@link + * ANeuralNetworksModel_addOperand}. + * @param type The type of the operand. This should be used to specify the + * dimensions that were set to 0 when the operand was added to the + * model. All other properties of the type must be the same as + * specified in the model. If the type is the same as specified + * when the model was built, NULL can be passed. + * @param buffer The buffer containing the data. + * @param length The length in bytes of the buffer. + * + * @return ANEURALNETWORKS_NO_ERROR if successful, ANEURALNETWORKS_BAD_DATA if + * the name is not recognized or the buffer is too small for the input. + */ +inline int ANeuralNetworksExecution_setInput( + ANeuralNetworksExecution* execution, int32_t index, + const ANeuralNetworksOperandType* type, const void* buffer, size_t length) { + LOAD_FUNCTION(ANeuralNetworksExecution_setInput); + EXECUTE_FUNCTION_RETURN(execution, index, type, buffer, length); +} + +/** + * Associate part of a memory object with an input of the model of the + * {@link ANeuralNetworksExecution}. + * + *

The provided memory must outlive the execution.

+ * + * See {@link ANeuralNetworksExecution} for information on multithreaded usage. + * + * @param execution The execution to be modified. + * @param index The index of the input argument we are setting. It is + * an index into the lists passed to + * {@link ANeuralNetworksModel_identifyInputsAndOutputs}. It is not + * the index associated with {@link + * ANeuralNetworksModel_addOperand}. + * @param type The type of the operand. This can be used to specify the + * dimensions that were set to 0 when the operand was added to the + * model. All other values must be the same as specified in the + * model. If the type is the same as specified when the model + * was built, NULL can be passed. + * @param memory The memory containing the data. + * @param offset This specifies the location of the data whithin the memory. + * The offset is in bytes from the start of memory. + * @param length The size in bytes of the data value. + * + * @return ANEURALNETWORKS_NO_ERROR if successful, ANEURALNETWORKS_BAD_DATA if + * the name is not recognized or the buffer is too small for the input. + */ +inline int ANeuralNetworksExecution_setInputFromMemory( + ANeuralNetworksExecution* execution, int32_t index, + const ANeuralNetworksOperandType* type, const ANeuralNetworksMemory* memory, + size_t offset, size_t length) { + LOAD_FUNCTION(ANeuralNetworksExecution_setInputFromMemory); + EXECUTE_FUNCTION_RETURN(execution, index, type, memory, offset, length); +} + +/** + * Associate a user buffer with an output of the model of the + * {@link ANeuralNetworksExecution}. + * + *

The provided buffer must outlive the execution.

+ * + * See {@link ANeuralNetworksExecution} for information on multithreaded usage. + * + * @param execution The execution to be modified. + * @param index The index of the output argument we are setting. It is + * an index into the lists passed to + * {@link ANeuralNetworksModel_identifyInputsAndOutputs}. It is not + * the index associated with {@link + * ANeuralNetworksModel_addOperand}. + * @param type The type of the operand. This can be used to specify the + * dimensions that were set to 0 when the operand was added to the + * model. All other values must be the same as specified in the + * model. If the type is the same as specified when the model + * was built, NULL can be passed. + * @param buffer The buffer where the data is to be written. + * @param length The length in bytes of the buffer. + * + * @return ANEURALNETWORKS_NO_ERROR if successful, ANEURALNETWORKS_BAD_DATA if + * the name is not recognized or the buffer is too small for the output. + */ +inline int ANeuralNetworksExecution_setOutput( + ANeuralNetworksExecution* execution, int32_t index, + const ANeuralNetworksOperandType* type, void* buffer, size_t length) { + LOAD_FUNCTION(ANeuralNetworksExecution_setOutput); + EXECUTE_FUNCTION_RETURN(execution, index, type, buffer, length); +} + +/** + * Associate part of a memory object with an output of the model of the + * {@link ANeuralNetworksExecution}. + * + *

The provided memory must outlive the execution.

+ * + * See {@link ANeuralNetworksExecution} for information on multithreaded usage. + * + * @param execution The execution to be modified. + * @param index The index of the output argument we are setting. It is + * an index into the lists passed to + * {@link ANeuralNetworksModel_identifyInputsAndOutputs}. It is not + * the index associated with {@link + * ANeuralNetworksModel_addOperand}. + * @param type The type of the operand. This can be used to specify the + * dimensions that were set to 0 when the operand was added to the + * model. All other values must be the same as specified in the + * model. If the type is the same as specified when the model + * was built, NULL can be passed. + * @param memory The memory where the data is to be stored. + * @param offset This specifies the location of the data whithin the memory. + * The offset is in bytes from the start of memory. + * @param length The length in bytes of the data value. + * + * @return ANEURALNETWORKS_NO_ERROR if successful, ANEURALNETWORKS_BAD_DATA if + * the name is not recognized or the buffer is too small for the output. + */ +inline int ANeuralNetworksExecution_setOutputFromMemory( + ANeuralNetworksExecution* execution, int32_t index, + const ANeuralNetworksOperandType* type, const ANeuralNetworksMemory* memory, + size_t offset, size_t length) { + LOAD_FUNCTION(ANeuralNetworksExecution_setOutputFromMemory); + EXECUTE_FUNCTION_RETURN(execution, index, type, memory, offset, length); +} + +/** + * Schedule evaluation of the execution. + * + *

Schedules evaluation of the execution. Once the model has been + * applied and the outputs are ready to be consumed, the execution will be + * signaled. Use {@link ANeuralNetworksExecution_wait} to wait for that signal. + *

+ * + * Multiple executions can be scheduled and evaluated concurrently, and + * compilations can be performed concurrently with executions. The runtime makes + * no guarantee on the ordering of the completion of compilations and + * executions. If it's important to the application, the application should + * enforce the ordering by using {@link ANeuralNetworksCompilation_wait} and + * {@link ANeuralNetworksExecution_wait}. + * + * ANeuralNetworksExecution_wait must be called to recuperate the resources used + * by the execution. + * + * See {@link ANeuralNetworksExecution} for information on multithreaded usage. + * + * @param execution The execution to be scheduled and executed. + * + * @return ANEURALNETWORKS_NO_ERROR if successful. + */ +inline int ANeuralNetworksExecution_startCompute( + ANeuralNetworksExecution* execution, ANeuralNetworksEvent** event) { + LOAD_FUNCTION(ANeuralNetworksExecution_startCompute); + EXECUTE_FUNCTION_RETURN(execution, event); +} + +/** + * Waits until the execution completes. + * + * More than one thread can wait on an event. When the execution completes, + * all threads will be released. + * + * See {@link ANeuralNetworksExecution} for information on multithreaded usage. + * + * @return ANEURALNETWORKS_NO_ERROR if the execution completed normally. + */ +inline int ANeuralNetworksEvent_wait(ANeuralNetworksEvent* event) { + LOAD_FUNCTION(ANeuralNetworksEvent_wait); + EXECUTE_FUNCTION_RETURN(event); +} + +/** + * Destroys the event. + * + * See {@link ANeuralNetworksExecution} for information on multithreaded usage. + */ +inline void ANeuralNetworksEvent_free(ANeuralNetworksEvent* event) { + LOAD_FUNCTION(ANeuralNetworksEvent_free); + EXECUTE_FUNCTION(event); +} + +/**/ + +#endif // NN_API_SHIM_H0 diff --git a/tensorflow/contrib/lite/nnapi_delegate.cc b/tensorflow/contrib/lite/nnapi_delegate.cc new file mode 100644 index 0000000000..6a199cc840 --- /dev/null +++ b/tensorflow/contrib/lite/nnapi_delegate.cc @@ -0,0 +1,386 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/nnapi_delegate.h" +#include +#include +#include +#include +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/error_reporter.h" +#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h" + +namespace tflite { + +// TODO(aselle): FATAL leaves resources hanging. +void FATAL(const char* format, ...) { + va_list args; + va_start(args, format); + vfprintf(stderr, format, args); + va_end(args); + fflush(stderr); + exit(1); +} + +// TODO(aselle): Change the error model to use status codes. 
+#define CHECK_TFLITE_SUCCESS(x) \ + if (x != kTfLiteOk) { \ + FATAL("Aborting since tflite returned failure."); \ + } + +#define CHECK_NN(x) \ + if (x != ANEURALNETWORKS_NO_ERROR) { \ + FATAL("Aborting since tflite returned failure."); \ + } + +NNAPIAllocation::NNAPIAllocation(const char* filename, + ErrorReporter* error_reporter) + : MMAPAllocation(filename, error_reporter) { + if (mmapped_buffer_ != MAP_FAILED) + CHECK_NN(ANeuralNetworksMemory_createFromFd(buffer_size_bytes_, PROT_READ, + mmap_fd_, 0, &handle_)); +} + +NNAPIAllocation::~NNAPIAllocation() { + if (handle_) { + ANeuralNetworksMemory_free(handle_); + } +} + +NNAPIDelegate::~NNAPIDelegate() { + if (nn_model_) { + ANeuralNetworksModel_free(nn_model_); + nn_model_ = nullptr; + // TODO(aselle): Is this thread-safe and callable multiple times? + } + // ANeuralNetworksShutdown(); +} + +// Adds the tensors of the interpreter to the NN API model. +// Returns the number of operands added. +uint32_t addTensorOperands(tflite::Interpreter* interpreter, + ANeuralNetworksModel* nn_model) { + uint32_t next_id = 0; + for (size_t i = 0; i < interpreter->tensors_size(); i++) { + int32_t nn_type = 0; + float scale = 1.0f; + int32_t zeroPoint = 0; + TfLiteTensor* tensor = interpreter->tensor(i); + switch (tensor->type) { + case kTfLiteNoType: + // Tensors added during initialization of Ops don't have a type yet and + // should not be registered with the NNAPI. + continue; + case kTfLiteFloat32: + nn_type = ANEURALNETWORKS_TENSOR_FLOAT32; + break; + case kTfLiteUInt8: + nn_type = ANEURALNETWORKS_TENSOR_QUANT8_ASYMM; + scale = tensor->params.scale; + zeroPoint = tensor->params.zero_point; + break; + case kTfLiteInt32: + nn_type = ANEURALNETWORKS_TENSOR_INT32; + scale = tensor->params.scale; + zeroPoint = tensor->params.zero_point; + break; + default: + FATAL("Unsupported type."); + } + // TODO(aselle): Note, many of these are intermediate results. Do I need + // to ever specify these sizes. I am currently below doing setValue + // on all of them, but I shouldn't in the future. + // Answer(jeanluc): If all the operators can set the dimension correctly, + // you won't need to. + ANeuralNetworksOperandType operand_type{ + nn_type, static_cast(tensor->dims->size), + reinterpret_cast(tensor->dims->data), scale, zeroPoint}; + CHECK_NN(ANeuralNetworksModel_addOperand(nn_model, &operand_type)); + + // TODO(aselle): Based on Michael's suggestion, limiting this to read + // only memory + if (tensor->allocation_type == kTfLiteMmapRo) { + if (const NNAPIAllocation* alloc = dynamic_cast( + static_cast(tensor->allocation))) { + CHECK_NN(ANeuralNetworksModel_setOperandValueFromMemory( + nn_model, i, alloc->memory(), alloc->offset(tensor->data.raw), + tensor->bytes)); + } else { + CHECK_NN(ANeuralNetworksModel_setOperandValue( + nn_model, i, tensor->data.raw, tensor->bytes)); + } + } + ++next_id; + } + return next_id; +} + +// Adds the operations and their parameters to the NN API model. +// 'next-id' is the operand ID of the next operand of the model. +void AddOpsAndParams(tflite::Interpreter* interpreter, + ANeuralNetworksModel* nn_model, uint32_t next_id) { + for (size_t i = 0; i < interpreter->nodes_size(); i++) { + const auto* node_and_registration = interpreter->node_and_registration(i); + const TfLiteNode& node = node_and_registration->first; + const TfLiteRegistration& registration = node_and_registration->second; + tflite::BuiltinOperator builtin = + static_cast(registration.builtin_code); + + // Add the parameters. 
+ std::vector augmented_inputs( + node.inputs->data, node.inputs->data + node.inputs->size); + + auto add_scalar_int32 = [&nn_model, &augmented_inputs, + &next_id](int value) { + ANeuralNetworksOperandType operand_type{.type = ANEURALNETWORKS_INT32}; + CHECK_NN(ANeuralNetworksModel_addOperand(nn_model, &operand_type)) + CHECK_NN(ANeuralNetworksModel_setOperandValue(nn_model, next_id, &value, + sizeof(int32_t))) + augmented_inputs.push_back(next_id++); + }; + + auto add_scalar_float32 = [&nn_model, &augmented_inputs, + &next_id](float value) { + ANeuralNetworksOperandType operand_type{.type = ANEURALNETWORKS_FLOAT32}; + CHECK_NN(ANeuralNetworksModel_addOperand(nn_model, &operand_type)) + CHECK_NN(ANeuralNetworksModel_setOperandValue(nn_model, next_id, &value, + sizeof(float))) + augmented_inputs.push_back(next_id++); + }; + + auto add_add_params = [&add_scalar_int32]() { add_scalar_int32(0); }; + + auto add_pooling_params = [&add_scalar_int32](void* data) { + auto builtin = reinterpret_cast(data); + add_scalar_int32(builtin->padding); + add_scalar_int32(builtin->stride_width); + add_scalar_int32(builtin->stride_height); + add_scalar_int32(builtin->filter_width); + add_scalar_int32(builtin->filter_height); + add_scalar_int32(builtin->activation); + }; + + auto add_convolution_params = [&add_scalar_int32](void* data) { + auto builtin = reinterpret_cast(data); + add_scalar_int32(builtin->padding); + add_scalar_int32(builtin->stride_width); + add_scalar_int32(builtin->stride_height); + add_scalar_int32(builtin->activation); + }; + + auto add_depthwise_conv_params = [&add_scalar_int32](void* data) { + auto builtin = reinterpret_cast(data); + add_scalar_int32(builtin->padding); + add_scalar_int32(builtin->stride_width); + add_scalar_int32(builtin->stride_height); + add_scalar_int32(builtin->depth_multiplier); + add_scalar_int32(builtin->activation); + }; + + auto add_fully_connected_params = [&add_scalar_int32](void* data) { + auto builtin = reinterpret_cast(data); + add_scalar_int32(builtin->activation); + }; + + auto add_concatenation_params = [&add_scalar_int32](void* data) { + auto builtin = reinterpret_cast(data); + add_scalar_int32(builtin->axis); + if (builtin->activation != kTfLiteActNone) { + FATAL("Concatenation does not support fused activation in NNAPI"); + } + }; + + auto add_softmax_params = [&add_scalar_float32](void* data) { + auto builtin = reinterpret_cast(data); + add_scalar_float32(builtin->beta); + }; + +#if 0 + auto add_reshape_params = [&](void* data) { + auto builtin = reinterpret_cast(data); + uint32_t tensor_size_shape = builtin->num_dimensions; + ANeuralNetworksOperandType operand_type{ + ANEURALNETWORKS_TENSOR_INT32, + {static_cast(1), + reinterpret_cast(&tensor_size_shape)}, + 0, + 0}; + CHECK_NN(ANeuralNetworksModel_addOperand(nn_model, &operand_type)) + CHECK_NN(ANeuralNetworksModel_setOperandValue( + nn_model, next_id, builtin->shape, + sizeof(int) * builtin->num_dimensions)); + augmented_inputs.push_back(next_id++); + }; +#endif + + ANeuralNetworksOperationType nn_op_type; + switch (builtin) { + case tflite::BuiltinOperator_ADD: + nn_op_type = ANEURALNETWORKS_ADD; + add_add_params(); + break; + case tflite::BuiltinOperator_AVERAGE_POOL_2D: + add_pooling_params(node.builtin_data); + nn_op_type = ANEURALNETWORKS_AVERAGE_POOL_2D; + break; + case tflite::BuiltinOperator_MAX_POOL_2D: + add_pooling_params(node.builtin_data); + nn_op_type = ANEURALNETWORKS_MAX_POOL_2D; + break; + case tflite::BuiltinOperator_L2_POOL_2D: + add_pooling_params(node.builtin_data); + 
nn_op_type = ANEURALNETWORKS_L2_POOL_2D; + break; + case tflite::BuiltinOperator_CONV_2D: + add_convolution_params(node.builtin_data); + nn_op_type = ANEURALNETWORKS_CONV_2D; + break; + case tflite::BuiltinOperator_RELU: + nn_op_type = ANEURALNETWORKS_RELU; + break; + case tflite::BuiltinOperator_RELU6: + nn_op_type = ANEURALNETWORKS_RELU6; + break; + case tflite::BuiltinOperator_TANH: + nn_op_type = ANEURALNETWORKS_TANH; + break; + case tflite::BuiltinOperator_LOGISTIC: + nn_op_type = ANEURALNETWORKS_LOGISTIC; + break; + case tflite::BuiltinOperator_DEPTHWISE_CONV_2D: + add_depthwise_conv_params(node.builtin_data); + nn_op_type = ANEURALNETWORKS_DEPTHWISE_CONV_2D; + break; + case tflite::BuiltinOperator_CONCATENATION: + add_concatenation_params(node.builtin_data); + nn_op_type = ANEURALNETWORKS_CONCATENATION; + break; + case tflite::BuiltinOperator_SOFTMAX: + add_softmax_params(node.builtin_data); + nn_op_type = ANEURALNETWORKS_SOFTMAX; + break; + case tflite::BuiltinOperator_FULLY_CONNECTED: + add_fully_connected_params(node.builtin_data); + nn_op_type = ANEURALNETWORKS_FULLY_CONNECTED; + break; + case tflite::BuiltinOperator_RESHAPE: + nn_op_type = ANEURALNETWORKS_RESHAPE; + // add_reshape_params(node.builtin_data); + break; + case tflite::BuiltinOperator_CONCAT_EMBEDDINGS: + case tflite::BuiltinOperator_LSH_PROJECTION: + case tflite::BuiltinOperator_SVDF: + case tflite::BuiltinOperator_HASHTABLE_LOOKUP: + case tflite::BuiltinOperator_RNN: + case tflite::BuiltinOperator_EMBEDDING_LOOKUP: + case tflite::BuiltinOperator_EMBEDDING_LOOKUP_SPARSE: + case tflite::BuiltinOperator_LSTM: + case tflite::BuiltinOperator_L2_NORMALIZATION: + case tflite::BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION: + case tflite::BuiltinOperator_MUL: + case tflite::BuiltinOperator_RESIZE_BILINEAR: + case tflite::BuiltinOperator_CALL: + case tflite::BuiltinOperator_SKIP_GRAM: + case tflite::BuiltinOperator_RELU1: + case tflite::BuiltinOperator_SPACE_TO_DEPTH: + FATAL("Op code %d is currently not delegated to NNAPI", builtin); + nn_op_type = -1; // set to invalid + break; + case tflite::BuiltinOperator_CUSTOM: + FATAL("Custom operations are not supported when using NNAPI."); + nn_op_type = -1; // set to invalid + break; + } + + // Add the operation. + CHECK_NN(ANeuralNetworksModel_addOperation( + nn_model, nn_op_type, static_cast(augmented_inputs.size()), + augmented_inputs.data(), static_cast(node.outputs->size), + reinterpret_cast(node.outputs->data))); + } +} + +TfLiteStatus NNAPIDelegate::BuildGraph(Interpreter* interpreter) { + // TODO(aselle): This is not correct. need to handle resize invalidation. 
+ if (nn_model_ && nn_compiled_model_) return kTfLiteOk; + + if (!nn_model_) { + CHECK_NN(ANeuralNetworksModel_create(&nn_model_)); + + uint32_t next_id = addTensorOperands(interpreter, nn_model_); + AddOpsAndParams(interpreter, nn_model_, next_id); + CHECK_NN(ANeuralNetworksModel_identifyInputsAndOutputs( + nn_model_, static_cast(interpreter->inputs().size()), + reinterpret_cast(interpreter->inputs().data()), + static_cast(interpreter->outputs().size()), + reinterpret_cast(interpreter->outputs().data()))); + CHECK_NN(ANeuralNetworksModel_finish(nn_model_)); + } + if (!nn_compiled_model_) { + CHECK_NN(ANeuralNetworksCompilation_create(nn_model_, &nn_compiled_model_)); + CHECK_NN(ANeuralNetworksCompilation_finish(nn_compiled_model_)); + } + return kTfLiteOk; +} + +TfLiteStatus NNAPIDelegate::Invoke(Interpreter* interpreter) { + if (!nn_model_) { + TF_LITE_ENSURE_STATUS(BuildGraph(interpreter)); + } + + ANeuralNetworksExecution* execution = nullptr; + CHECK_NN(ANeuralNetworksExecution_create(nn_compiled_model_, &execution)); + + // Currently perform deep copy of input buffer + for (size_t i = 0; i < interpreter->inputs().size(); i++) { + int input = interpreter->inputs()[i]; + // TODO(aselle): Is this what we want or do we want input instead? + // TODO(aselle): This should be called setInputValue maybe to be cons. + TfLiteTensor* tensor = interpreter->tensor(input); + CHECK_NN(ANeuralNetworksExecution_setInput( + execution, i, nullptr, tensor->data.raw, tensor->bytes)); + } + // Tell nn api where to place final data. + for (size_t i = 0; i < interpreter->outputs().size(); i++) { + int output = interpreter->outputs()[i]; + TfLiteTensor* tensor = interpreter->tensor(output); + CHECK_NN(ANeuralNetworksExecution_setOutput( + execution, i, nullptr, tensor->data.raw, tensor->bytes)); + } + // Currently use blocking compute. + ANeuralNetworksEvent* event = nullptr; + CHECK_NN(ANeuralNetworksExecution_startCompute(execution, &event)); + CHECK_NN(ANeuralNetworksEvent_wait(event)); + ANeuralNetworksEvent_free(event); + ANeuralNetworksExecution_free(execution); + +#if 0 + printf("From the NN API:\n"); + TfLiteTensor* tensor = interpreter->tensor(interpreter->outputs()[0]); + if (float* data = + interpreter->typed_tensor(interpreter->outputs()[0])) { + size_t num = tensor->bytes / sizeof(float); + for (float* p = data; p < data + num; p++) { + printf(" %f", *p); + } + printf("\n"); + } +#endif + + return kTfLiteOk; +} + +} // namespace tflite diff --git a/tensorflow/contrib/lite/nnapi_delegate.h b/tensorflow/contrib/lite/nnapi_delegate.h new file mode 100644 index 0000000000..f29aa9e18e --- /dev/null +++ b/tensorflow/contrib/lite/nnapi_delegate.h @@ -0,0 +1,66 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_NNAPI_DELEGATE_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_NNAPI_DELEGATE_H_ + +#include "tensorflow/contrib/lite/allocation.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/error_reporter.h" +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h" + +class ANeuralNetworsModel; + +namespace tflite { + +class NNAPIAllocation : public MMAPAllocation { + public: + NNAPIAllocation(const char* filename, ErrorReporter* error_reporter); + ~NNAPIAllocation(); + + size_t offset(const void* ptr) const { + auto signed_offset = reinterpret_cast(ptr) - + reinterpret_cast(mmapped_buffer_); + + return static_cast(signed_offset); + } + + ANeuralNetworksMemory* memory() const { return handle_; } + bool valid() const override { return handle_ != nullptr; } + + private: + mutable ANeuralNetworksMemory* handle_ = nullptr; +}; + +class NNAPIDelegate { + public: + ~NNAPIDelegate(); + + // Convert a tflite graph to NNAPI + TfLiteStatus BuildGraph(Interpreter* interpreter); + + // Run + TfLiteStatus Invoke(Interpreter* interpreter); + + private: + // The NN API model handle + ANeuralNetworksModel* nn_model_ = nullptr; + // The NN API compilation handle + ANeuralNetworksCompilation* nn_compiled_model_ = nullptr; +}; + +} // namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_NNAPI_DELEGATE_H_ diff --git a/tensorflow/contrib/lite/optional_debug_tools.cc b/tensorflow/contrib/lite/optional_debug_tools.cc new file mode 100644 index 0000000000..1f762e6688 --- /dev/null +++ b/tensorflow/contrib/lite/optional_debug_tools.cc @@ -0,0 +1,108 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/contrib/lite/optional_debug_tools.h" + +namespace tflite { + +void PrintIntVector(const std::vector& v) { + for (const auto& it : v) { + printf(" %d", it); + } + printf("\n"); +} + +void PrintTfLiteIntVector(const TfLiteIntArray* v) { + if (!v) { + printf(" (null)"); + return; + } + for (int k = 0; k < v->size; k++) { + printf(" %d", v->data[k]); + } + printf("\n"); +} + +const char* TensorTypeName(TfLiteType type) { + switch (type) { + case kTfLiteNoType: + return "kTfLiteNoType"; + case kTfLiteFloat32: + return "kTfLiteFloat32"; + case kTfLiteInt32: + return "kTfLiteInt32"; + case kTfLiteUInt8: + return "kTfLiteUInt8"; + case kTfLiteInt64: + return "kTfLiteInt64"; + case kTfLiteString: + return "kTfLiteString"; + } + return "(invalid)"; +} + +const char* AllocTypeName(TfLiteAllocationType type) { + switch (type) { + case kTfLiteMemNone: + return "kTfLiteMemNone"; + case kTfLiteMmapRo: + return "kTfLiteMmapRo"; + case kTfLiteDynamic: + return "kTfLiteDynamic"; + case kTfLiteArenaRw: + return "kTfLiteArenaRw"; + case kTfLiteArenaRwPersistent: + return "kTfLiteArenaRwPersistent"; + } + return "(invalid)"; +} + +// Prints a dump of what tensors and what nodes are in the interpreter. +void PrintInterpreterState(Interpreter* interpreter) { + printf("Interpreter has %d tensors and %d nodes\n", + interpreter->tensors_size(), interpreter->nodes_size()); + printf("Inputs:"); + PrintIntVector(interpreter->inputs()); + printf("Outputs:"); + PrintIntVector(interpreter->outputs()); + printf("\n"); + for (int tensor_index = 0; tensor_index < interpreter->tensors_size(); + tensor_index++) { + TfLiteTensor* tensor = interpreter->tensor(tensor_index); + printf("Tensor %3d %10s %15s %10zu bytes (%4.1f MB) ", tensor_index, + TensorTypeName(tensor->type), AllocTypeName(tensor->allocation_type), + tensor->bytes, float(tensor->bytes) / float(1 << 20)); + PrintTfLiteIntVector(tensor->dims); + printf("\n"); + } + + for (int node_index = 0; node_index < interpreter->nodes_size(); + node_index++) { + const std::pair* node_and_reg = + interpreter->node_and_registration(node_index); + const TfLiteNode& node = node_and_reg->first; + const TfLiteRegistration& reg = node_and_reg->second; + printf("Node %3d Operator Builtin Code %3d\n", node_index, + reg.builtin_code); + printf(" Inputs:"); + PrintTfLiteIntVector(node.inputs); + printf(" Outputs:"); + PrintTfLiteIntVector(node.outputs); + } +} + +// Prints a dump of what tensors and what nodes are in the interpreter. +TfLiteStatus ValidateInterpreterState(const Interpreter* interpreter); + +} // namespace tflite diff --git a/tensorflow/contrib/lite/optional_debug_tools.h b/tensorflow/contrib/lite/optional_debug_tools.h new file mode 100644 index 0000000000..54d4876095 --- /dev/null +++ b/tensorflow/contrib/lite/optional_debug_tools.h @@ -0,0 +1,32 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +// Optional debugging functionality. For small sized binaries, these are not +// needed. +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_DEBUG_TOOLS_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_DEBUG_TOOLS_H_ + +#include "tensorflow/contrib/lite/interpreter.h" + +namespace tflite { + +// Prints a dump of what tensors and what nodes are in the interpreter. +void PrintInterpreterState(Interpreter* interpreter); + +// Prints a dump of what tensors and what nodes are in the interpreter. +TfLiteStatus ValidateInterpreterState(const Interpreter* interpreter); + +} // namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_DEBUG_TOOLS_H_ diff --git a/tensorflow/contrib/lite/python/BUILD b/tensorflow/contrib/lite/python/BUILD new file mode 100644 index 0000000000..b4aa032ff8 --- /dev/null +++ b/tensorflow/contrib/lite/python/BUILD @@ -0,0 +1,46 @@ +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//tensorflow:internal"]) + +load("//tensorflow:tensorflow.bzl", "py_test") + +py_library( + name = "lite", + srcs = ["lite.py"], + # data = [ + # "//tensorflow/contrib/lite/toco/python:toco_from_protos", + # ], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/contrib/lite/toco:model_flags_proto_py", + "//tensorflow/contrib/lite/toco:toco_flags_proto_py", + "//tensorflow/contrib/lite/toco/python:tensorflow_wrap_toco", + "//tensorflow/python:platform", + ], +) + +py_test( + name = "lite_test", + srcs = ["lite_test.py"], + deps = [ + ":lite", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + "//tensorflow/python:platform_test", + "//tensorflow/python:session", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/contrib/lite/python/lite.py b/tensorflow/contrib/lite/python/lite.py new file mode 100644 index 0000000000..5e8edbb937 --- /dev/null +++ b/tensorflow/contrib/lite/python/lite.py @@ -0,0 +1,199 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""TensorFlow Lite tooling helper functionality. + +EXPERIMENTAL: APIs here are unstable and likely to change without notice. 
+
+@@toco_convert
+@@toco_convert_protos
+
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import subprocess
+import tempfile
+
+from tensorflow.contrib.lite.toco import model_flags_pb2 as _model_flags_pb2
+from tensorflow.contrib.lite.toco import toco_flags_pb2 as _toco_flags_pb2
+from tensorflow.contrib.lite.toco.python.tensorflow_wrap_toco import TocoConvert as _toco_convert_protos
+from tensorflow.python.framework import dtypes as _dtypes
+# from tensorflow.python.platform import
+# resource_loader as _resource_loader
+
+# Enum types from the protobuf promoted to the API
+FLOAT = _toco_flags_pb2.FLOAT
+INT32 = _toco_flags_pb2.INT32
+INT64 = _toco_flags_pb2.INT64
+STRING = _toco_flags_pb2.STRING
+QUANTIZED_UINT8 = _toco_flags_pb2.QUANTIZED_UINT8
+TENSORFLOW_GRAPHDEF = _toco_flags_pb2.TENSORFLOW_GRAPHDEF
+TFLITE = _toco_flags_pb2.TFLITE
+GRAPHVIZ_DOT = _toco_flags_pb2.GRAPHVIZ_DOT
+
+# When True, TOCO is invoked directly in-process. The alternative is to shell
+# out to a separate python process, which protects against crashes but is
+# currently disabled (see the commented-out code in `toco_convert_protos`).
+EXPERIMENTAL_USE_TOCO_API_DIRECTLY = True
+
+# Find the toco_from_protos binary using the resource loader if running from
+# bazel; otherwise assume a pip installation, where console_scripts already
+# provides the toco_from_protos tool.
+# toco_from_proto_bin = _resource_loader.get_path_to_datafile(
+#     "../toco/python/toco_from_protos")
+# if not os.path.exists(toco_from_proto_bin):
+#   toco_from_proto_bin = "toco_from_protos"
+
+
+def toco_convert_protos(model_flags_str, toco_flags_str, input_data_str):
+  """Convert `input_data_str` according to model and toco parameters.
+
+  Unless you know what you are doing, consider using
+  the more friendly @{tf.contrib.lite.toco_convert}.
+
+  Args:
+    model_flags_str: Serialized proto describing model properties, see
+      `toco/model_flags.proto`.
+    toco_flags_str: Serialized proto describing conversion properties, see
+      `toco/toco_flags.proto`.
+    input_data_str: Input data in serialized form (e.g. a graphdef is common).
+  Returns:
+    Converted model in serialized form (e.g. a TFLITE model is common).
+  Raises:
+    RuntimeError: When conversion fails, an exception is raised with the error
+      message embedded.
+  """
+  # TODO(aselle): When toco does not use fatal errors for failure, we can
+  # switch this on.
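+  # With EXPERIMENTAL_USE_TOCO_API_DIRECTLY enabled, conversion runs in-process
+  # through the native wrapper imported above (tensorflow_wrap_toco). The
+  # commented-out block further down is the alternative out-of-process path
+  # that shells out to the toco_from_protos binary; it is kept for reference.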
+ if EXPERIMENTAL_USE_TOCO_API_DIRECTLY: + return _toco_convert_protos(model_flags_str, toco_flags_str, input_data_str) + + # with tempfile.NamedTemporaryFile() as fp_toco, \ + # tempfile.NamedTemporaryFile() as fp_model, \ + # tempfile.NamedTemporaryFile() as fp_input, \ + # tempfile.NamedTemporaryFile() as fp_output: + # fp_model.write(model_flags_str) + # fp_toco.write(toco_flags_str) + # fp_input.write(input_data_str) + # fp_model.flush() + # fp_toco.flush() + # fp_input.flush() + + # cmd = [ + # toco_from_proto_bin, fp_model.name, fp_toco.name, fp_input.name, + # fp_output.name + # ] + # cmdline = " ".join(cmd) + # proc = subprocess.Popen( + # cmdline, + # shell=True, + # stdout=subprocess.PIPE, + # stderr=subprocess.STDOUT, + # close_fds=True) + # stdout, stderr = proc.communicate() + # exitcode = proc.returncode + # if exitcode == 0: + # stuff = fp_output.read() + # return stuff + # else: + # raise RuntimeError("TOCO failed see console for info.\n%s\n%s\n" % + # (stdout, stderr)) + + +def _tensor_name(x): + return x.name.split(":")[0] + + +def toco_convert(input_data, + input_tensors, + output_tensors, + inference_type=FLOAT, + input_format=TENSORFLOW_GRAPHDEF, + output_format=TFLITE, + quantized_input_stats=None, + drop_control_dependency=True): + """Convert a model using TOCO from `input_format` to `output_format`. + + Typically this is to convert from TensorFlow GraphDef to TFLite, in which + case the default `input_format` and `output_format` are sufficient. + + Args: + input_data: Input data (i.e. often `sess.graph_def`). + input_tensors: List of input tensors. Type and shape are computed using + `foo.get_shape()` and `foo.dtype`. + output_tensors: List of output tensors (only .name is used from this). + inference_type: Currently must be `{FLOAT, QUANTIZED_UINT8}`. + input_format: Type of data to read (currently must be TENSORFLOW_GRAPHDEF). + output_format: Type of data to write (currently must be TFLITE or + GRAPHVIZ_DOT) + quantized_input_stats: For each member of input_tensors the mean and + std deviation of training data. Only needed if `inference_type` is + `QUANTIZED_UINT8`. + drop_control_dependency: Drops control dependencies silently. This is due + to tf lite not supporting control dependencies. + + Returns: + The converted data. For example if tflite was the destination, then + this will be a tflite flatbuffer in a bytes array. 
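+
+  Example usage (a minimal sketch; the toy graph and output filename are
+  illustrative assumptions, not part of this API):
+
+    img = tf.placeholder(name="img", dtype=tf.float32, shape=(1, 64, 64, 3))
+    out = tf.identity(img, name="out")
+    with tf.Session() as sess:
+      tflite_model = toco_convert(sess.graph_def, [img], [out])
+    open("converted_model.tflite", "wb").write(tflite_model)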
+ + Raises: + ValueError: If the input tensor type is unknown + RuntimeError: If TOCO fails to convert (in which case the runtime error's + error text will contain the TOCO error log) + """ + toco = _toco_flags_pb2.TocoFlags() + toco.input_format = input_format + toco.output_format = output_format + model = _model_flags_pb2.ModelFlags() + model.drop_control_dependency = drop_control_dependency + toco.inference_type = inference_type + for idx, input_tensor in enumerate(input_tensors): + if input_tensor.dtype == _dtypes.float32: + tflite_input_type = FLOAT + elif input_tensor.dtype == _dtypes.int32: + tflite_input_type = INT32 + elif input_tensor.dtype == _dtypes.int64: + tflite_input_type = INT64 + # TODO(aselle): Insert strings when they are available + else: + raise ValueError("Tensors %s not known type %r" % (input_tensor.name, + input_tensor.dtype)) + + input_array = model.input_arrays.add() + + if inference_type == QUANTIZED_UINT8: + if tflite_input_type == FLOAT: + tflite_input_type = QUANTIZED_UINT8 + input_array.mean, input_array.std = quantized_input_stats[idx] + + input_array.name = _tensor_name(input_tensor) + input_array.shape.extend(map(int, input_tensor.get_shape())) + toco.input_types.append(tflite_input_type) + + for output_tensor in output_tensors: + model.output_arrays.append(_tensor_name(output_tensor)) + + data = toco_convert_protos(model.SerializeToString(), + toco.SerializeToString(), + input_data.SerializeToString()) + return data + + +# remove_undocumented(__name__) + +del os +del subprocess +del tempfile diff --git a/tensorflow/contrib/lite/python/lite_test.py b/tensorflow/contrib/lite/python/lite_test.py new file mode 100644 index 0000000000..da360aeb34 --- /dev/null +++ b/tensorflow/contrib/lite/python/lite_test.py @@ -0,0 +1,45 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""TensorFlow Lite Python Interface: Sanity check.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.lite.python import lite +from tensorflow.python.client import session +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +class LiteTest(test_util.TensorFlowTestCase): + + def testBasic(self): + in_tensor = array_ops.placeholder(shape=[1, 16, 16, 3], + dtype=dtypes.float32) + out_tensor = in_tensor + in_tensor + sess = session.Session() + # Try running on valid graph + result = lite.toco_convert(sess.graph_def, [in_tensor], [out_tensor]) + self.assertTrue(result) + # TODO(aselle): remove tests that fail. 
+ # Try running on identity graph (known fail) + # with self.assertRaisesRegexp(RuntimeError, "!model->operators.empty()"): + # result = lite.toco_convert(sess.graph_def, [in_tensor], [in_tensor]) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/lite/schema/BUILD b/tensorflow/contrib/lite/schema/BUILD new file mode 100644 index 0000000000..3e04d6f34f --- /dev/null +++ b/tensorflow/contrib/lite/schema/BUILD @@ -0,0 +1,82 @@ +package(default_visibility = [ + "//visibility:public", +]) + +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow:tensorflow.bzl", "py_test") + +py_binary( + name = "upgrade_schema", + srcs = [ + "upgrade_schema.py", + ], + data = [ + "schema_v0.fbs", + "schema_v1.fbs", + "schema_v2.fbs", + "schema_v3.fbs", + ], + deps = [ + "//tensorflow:tensorflow_py", + "//tensorflow/python:platform", + "@flatbuffers//:flatc", + ], +) + +py_test( + name = "upgrade_schema_test", + size = "small", + srcs = ["upgrade_schema_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":upgrade_schema", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_test_lib", + ], +) + +exports_files([ + "schema_v0.fbs", + "schema_v1.fbs", + "schema_v2.fbs", + "schema_v3.fbs", +]) + +load("//third_party/flatbuffers:build_defs.bzl", "flatbuffer_cc_library") + +# Generic schema for inference on device. +flatbuffer_cc_library( + name = "schema_fbs", + srcs = ["schema.fbs"], +) + +# Schema test to make sure we don't introduce backward incompatible changes +# to schemas. +cc_test( + name = "flatbuffer_compatibility_test", + size = "small", + srcs = ["flatbuffer_compatibility_test.cc"], + data = [ + "schema.fbs", + "schema_v3.fbs", + ], + deps = [ + "//tensorflow/core:lib_platform", + "@com_google_googletest//:gtest", + "@flatbuffers//:flatc_library", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/contrib/lite/schema/flatbuffer_compatibility_test.cc b/tensorflow/contrib/lite/schema/flatbuffer_compatibility_test.cc new file mode 100644 index 0000000000..17ee0af8dd --- /dev/null +++ b/tensorflow/contrib/lite/schema/flatbuffer_compatibility_test.cc @@ -0,0 +1,91 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#include <fstream>
+#include <gtest/gtest.h>
+#include "third_party/flatbuffers/include/flatbuffers/flatc.h"
+#include "tensorflow/core/platform/platform.h"
+
+#ifdef PLATFORM_GOOGLE
+#define TFLITE_TF_PREFIX "third_party/tensorflow/"
+#else
+#define TFLITE_TF_PREFIX "tensorflow/"
+#endif
+/// Load filename `name`
+bool LoadFileRaw(const char *name, std::string *buf) {
+  std::ifstream fp(name, std::ios::binary);
+  if (!fp) {
+    fprintf(stderr, "Failed to read '%s'\n", name);
+    return false;
+  }
+  std::string s((std::istreambuf_iterator<char>(fp)),
+                std::istreambuf_iterator<char>());
+  if (s.empty()) {
+    fprintf(stderr, "Read '%s' resulted in empty\n", name);
+    return false;
+  }
+  *buf = s;
+  return true;
+}
+
+bool ParseFile(flatbuffers::Parser *parser, const std::string &filename,
+               const std::string &contents) {
+  std::vector<const char *> include_directories;
+  auto local_include_directory = flatbuffers::StripFileName(filename);
+  include_directories.push_back(local_include_directory.c_str());
+  include_directories.push_back(nullptr);
+  if (!parser->Parse(contents.c_str(), include_directories.data(),
+                     filename.c_str())) {
+    fprintf(stderr, "Failed to parse flatbuffer schema '%s'\n",
+            contents.c_str());
+    return false;
+  }
+  return true;
+}
+
+// Checks to make sure current schema in current code does not cause an
+// incompatibility.
+TEST(SchemaTest, TestCompatibility) {
+  // Read file contents of schemas into strings
+  // TODO(aselle): Need a reliable way to load files.
+  std::string base_contents, current_contents;
+  const char *base_filename =
+      TFLITE_TF_PREFIX "contrib/lite/schema/schema_v3.fbs";
+  const char *current_filename =
+      TFLITE_TF_PREFIX "contrib/lite/schema/schema.fbs";
+
+  ASSERT_TRUE(LoadFileRaw(base_filename, &base_contents));
+  ASSERT_TRUE(LoadFileRaw(current_filename, &current_contents));
+  // Parse the schemas
+  flatbuffers::Parser base_parser, current_parser;
+  std::vector<const char *> include_directories;
+  ASSERT_TRUE(ParseFile(&base_parser, base_filename, base_contents));
+  ASSERT_TRUE(ParseFile(&current_parser, current_filename, current_contents));
+  // Check that the schemas conform and fail if they don't
+  auto err = current_parser.ConformTo(base_parser);
+  if (!err.empty()) {
+    fprintf(stderr,
+            "Schemas don't conform:\n%s\n"
+            "In other words, some change you made means that new parsers can't "
+            "parse old files.\n",
+            err.c_str());
+    FAIL();
+  }
+}
+
+int main(int argc, char **argv) {
+  ::testing::InitGoogleTest(&argc, argv);
+  return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/schema/schema.fbs b/tensorflow/contrib/lite/schema/schema.fbs
new file mode 100644
index 0000000000..ddb2ab792c
--- /dev/null
+++ b/tensorflow/contrib/lite/schema/schema.fbs
@@ -0,0 +1,346 @@
+// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Revision History
+// Version 0: Initial version.
+// Version 1: Add subgraphs to schema.
+// Version 2: Rename operators to conform to NN API.
+// Version 3: Move buffer data from Model.Subgraph.Tensors to Model.Buffers.
+
+namespace tflite;
+
+// This corresponds to the version.
+file_identifier "TFL3";
+// File extension of any written files.
+file_extension "tflite";
+
+// The type of data stored in a tensor.
+enum TensorType : byte {
+  FLOAT32 = 0,
+  FLOAT16 = 1,
+  INT32 = 2,
+  UINT8 = 3,
+  INT64 = 4,
+  STRING = 5,
+}
+
+// Parameters for converting a quantized tensor back to float. Given a
+// quantized value q, the corresponding float value f should be:
+//   f = scale * (q - zero_point)
+table QuantizationParameters {
+  min:[float]; // For importing back into tensorflow.
+  max:[float]; // For importing back into tensorflow.
+  scale:[float];
+  zero_point:[long];
+}
+
+table Tensor {
+  // The tensor shape. The meaning of each entry is operator-specific but
+  // builtin ops use: [batch size, number of channels, height, width] (That's
+  // Tensorflow's NCHW).
+  shape:[int];
+  type:TensorType;
+  // An index that refers to the buffers table at the root of the model. Or,
+  // if there is no data buffer associated (i.e. intermediate results), then
+  // this is 0 (which refers to an always existent empty buffer).
+  //
+  // The data_buffer itself is an opaque container, with the assumption that the
+  // target device is little-endian. In addition, all builtin operators assume
+  // the memory is ordered such that if `shape` is [4, 3, 2], then index
+  // [i, j, k] maps to data_buffer[i*3*2 + j*2 + k].
+  buffer:uint;
+  name:string; // For debugging and importing back into tensorflow.
+  quantization:QuantizationParameters; // Optional.
+}
+
+// A list of builtin operators. Builtin operators are slightly faster than
+// custom ones, but not by much. Moreover, while custom operators accept an
+// opaque object containing configuration parameters, builtins have a
+// predetermined set of acceptable options.
+enum BuiltinOperator : byte {
+  ADD = 0,
+  AVERAGE_POOL_2D = 1,
+  CONCATENATION = 2,
+  CONV_2D = 3,
+  DEPTHWISE_CONV_2D = 4,
+  // DEPTH_TO_SPACE = 5,
+  // DEQUANTIZE = 6,
+  EMBEDDING_LOOKUP = 7,
+  // FLOOR = 8,
+  FULLY_CONNECTED = 9,
+  HASHTABLE_LOOKUP = 10,
+  L2_NORMALIZATION = 11,
+  L2_POOL_2D = 12,
+  LOCAL_RESPONSE_NORMALIZATION = 13,
+  LOGISTIC = 14,
+  LSH_PROJECTION = 15,
+  LSTM = 16,
+  MAX_POOL_2D = 17,
+  MUL = 18,
+  RELU = 19,
+  RELU1 = 20,
+  RELU6 = 21,
+  RESHAPE = 22,
+  RESIZE_BILINEAR = 23,
+  RNN = 24,
+  SOFTMAX = 25,
+  SPACE_TO_DEPTH = 26,
+  SVDF = 27,
+  TANH = 28,
+  // TODO(aselle): Consider rename to CONCATENATE_EMBEDDINGS
+  CONCAT_EMBEDDINGS = 29,
+  SKIP_GRAM = 30,
+  CALL = 31,
+  CUSTOM = 32,
+  EMBEDDING_LOOKUP_SPARSE = 33,
+}
+
+// Options for the builtin operators.
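+// Exactly one of these option tables is stored per operator; FlatBuffers
+// records which member is populated in the union's generated type field
+// (surfaced as `builtin_options_type` in the JSON form used by the schema
+// upgrade tooling).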
+union BuiltinOptions { + Conv2DOptions, + DepthwiseConv2DOptions, + ConcatEmbeddingsOptions, + LSHProjectionOptions, + Pool2DOptions, + SVDFOptions, + RNNOptions, + FullyConnectedOptions, + SoftmaxOptions, + ConcatenationOptions, + AddOptions, + L2NormOptions, + LocalResponseNormalizationOptions, + LSTMOptions, + ResizeBilinearOptions, + CallOptions, + ReshapeOptions, + SkipGramOptions, + SpaceToDepthOptions, + EmbeddingLookupSparseOptions, + MulOptions, +} + +enum Padding : byte { SAME, VALID } + +enum ActivationFunctionType : byte { + NONE = 0, + RELU = 1, + RELU1 = 2, + RELU6 = 3, + TANH = 4, + SIGN_BIT = 5, +} + +table Conv2DOptions { + padding:Padding; + stride_w:int; + stride_h:int; + fused_activation_function:ActivationFunctionType; +} + +table Pool2DOptions { + padding:Padding; + stride_w:int; + stride_h:int; + filter_width:int; + filter_height:int; + fused_activation_function:ActivationFunctionType; +} + +table DepthwiseConv2DOptions { + padding:Padding; + stride_w:int; + stride_h:int; + depth_multiplier:int; + fused_activation_function:ActivationFunctionType; +} + +table ConcatEmbeddingsOptions { + num_channels:int; + num_columns_per_channel:[int]; + embedding_dim_per_channel:[int]; // This could be inferred from parameters. +} + +enum LSHProjectionType: byte { + UNKNOWN = 0, + SPARSE = 1, + DENSE = 2, +} + +table LSHProjectionOptions { + type: LSHProjectionType; +} + +table SVDFOptions { + rank:int; + fused_activation_function:ActivationFunctionType; +} + +// An implementation of TensorFlow RNNCell. +table RNNOptions { + fused_activation_function:ActivationFunctionType; +} + +// An implementation of TensorFlow fully_connected (a.k.a Dense) layer. +table FullyConnectedOptions { + fused_activation_function:ActivationFunctionType; +} + +table SoftmaxOptions { + beta: float; +} + +// An implementation of TensorFlow concat. +table ConcatenationOptions { + axis:int; + fused_activation_function:ActivationFunctionType; +} + +table AddOptions { + fused_activation_function:ActivationFunctionType; +} + +table MulOptions { + fused_activation_function:ActivationFunctionType; +} + +table L2NormOptions { + fused_activation_function:ActivationFunctionType; +} + +table LocalResponseNormalizationOptions { + radius:int; + bias:float; + alpha:float; + beta:float; +} + +// An implementation of TensorFlow LSTMCell and CoupledInputForgetGateLSTMCell +table LSTMOptions { + fused_activation_function:ActivationFunctionType; + cell_clip: float; // Optional, 0.0 means no clipping + proj_clip: float; // Optional, 0.0 means no clipping +} + +table ResizeBilinearOptions { + new_height:int; + new_width:int; +} + +// A call operation options +table CallOptions { + // The subgraph index that needs to be called. + subgraph:uint; +} + +table ReshapeOptions { + new_shape:[int]; +} + +table SkipGramOptions { + ngram_size: int; + max_skip_size: int; + include_all_ngrams: bool; +} + +table SpaceToDepthOptions { + block_size: int; +} + +enum CombinerType : byte { + SUM = 0, + MEAN = 1, + SQRTN = 2, +} + +table EmbeddingLookupSparseOptions { + combiner:CombinerType; +} + +// An OperatorCode can be an enum value (BuiltinOperator) if the operator is a +// builtin, or a string if the operator is custom. +table OperatorCode { + builtin_code:BuiltinOperator; + custom_code:string; +} + +enum CustomOptionsFormat : byte { + FLEXBUFFERS = 0, +} + +// An operator takes tensors as inputs and outputs. 
The type of operation being +// performed is determined by an index into the list of valid OperatorCodes, +// while the specifics of each operations is configured using builtin_options +// or custom_options. +table Operator { + // Index into the operator_codes array. Using an integer here avoids + // complicate map lookups. + opcode_index:uint; + + // Optional input and output tensors are indicated by -1. + inputs:[int]; + outputs:[int]; + + builtin_options:BuiltinOptions; + custom_options:[ubyte]; + custom_options_format:CustomOptionsFormat; +} + +// The root type, defining a model. +table SubGraph { + // A list of all tensors used in this model. + tensors:[Tensor]; + + // Indices of the input tensors. + inputs:[int]; + + // Indices of the output tensors. + outputs:[int]; + + // All operators, in execution order. + operators:[Operator]; + + // Name of subgraph (used for debugging). + name:string; +} + +// Table of raw data buffers (used for constant tensors). Referenced by tensors +// by index. +table Buffer { + data:[ubyte]; +} + +table Model { + // Version of the schema. + version:uint; + + // A list of all operator codes used in this model. This is + // kept in order because operators carry an index into this + // vector. + operator_codes:[OperatorCode]; + + // All the subgraphs of the model. The 0th is assumed to be the main + // model. + subgraphs:[SubGraph]; + + // A description of the model. + description:string; + + // Buffers of the model + buffers:[Buffer]; + +} + +root_type Model; + diff --git a/tensorflow/contrib/lite/schema/schema_v0.fbs b/tensorflow/contrib/lite/schema/schema_v0.fbs new file mode 100644 index 0000000000..852ea988f3 --- /dev/null +++ b/tensorflow/contrib/lite/schema/schema_v0.fbs @@ -0,0 +1,247 @@ +// Copyright 2017 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +namespace tflite; + +// The type of data stored in a tensor. +enum TensorType : byte { + FLOAT32 = 0, + FLOAT16 = 1, + INT32 = 2, + UINT8 = 3, + INT64 = 4, +} + +// Parameters for converting a quantized tensor back to float. Given a +// quantized value q, the corresponding float value f should be: +// f = scale * (q - zero_point) +table QuantizationParameters { + min:[float]; // For importing back into tensorflow. + max:[float]; // For importing back into tensorflow. + scale:[float]; + zero_point:[long]; +} + +table Tensor { + // The tensor shape. The meaning of each entry is operator-specific but + // builtin ops use: [batch size, number of channels, height, width] (That's + // Tensorflow's NCHW). + shape:[int]; + type:TensorType; + // The data_buffer is an opaque container, with the assumption that the + // target device is little-endian. In addition, all builtin operators assume + // the memory is ordered such that if `shape` is [4, 3, 2], then index + // [i, j, k] maps to data_buffer[i*4*3 + j*3 + k]. + data_buffer:[ubyte]; + name:string; // For debugging and importing back into tensorflow. + quantization:QuantizationParameters; // Optional. 
+} + +// A list of builtin operators. Builtin operators a slighlty faster than custom +// ones, but not by much. Moreover, while custom operators accept an opaque +// object containing configuration parameters, builtins have a predetermined +// set of acceptable options. +enum BuiltinOperator : byte { + CUSTOM = 0, + CONVOLUTION = 1, + DEPTHWISE_CONVOLUTION = 2, + CONCAT_EMBEDDINGS = 3, + LSH_PROJECTION = 4, + TANH = 5, + RELU = 6, + AVERAGE_POOL = 7, + MAX_POOL = 8, + L2_POOL = 9, + SIGMOID = 10, + SVDF = 11, + BasicRNN = 12, + RELU6 = 13, + EMBEDDING_LOOKUP = 14, + FULLY_CONNECTED = 15, + HASHTABLE_LOOKUP = 16, + SOFTMAX = 17, + CONCATENATION = 18, + LSTM = 19, + ADD = 20, + L2NORM = 21, + LOCAL_RESPONSE_NORM = 22, + RESIZE_BILINEAR = 23, +} + +// Options for the builtin operators. +union BuiltinOptions { + ConvolutionOptions, + DepthwiseConvolutionOptions, + ConcatEmbeddingsOptions, + LSHProjectionOptions, + PoolOptions, + SVDFOptions, + BasicRNNOptions, + FullyConnectedOptions, + SoftmaxOptions, + ConcatenationOptions, + AddOptions, + L2NormOptions, + LocalResponseNormOptions, + LSTMOptions, + ResizeBilinearOptions, +} + +enum Padding : byte { SAME, VALID } + +enum ActivationFunctionType : byte { + NONE = 0, + RELU = 1, + RELU1 = 2, + RELU6 = 3, + TANH = 4, + SIGN_BIT = 5, +} + +table ConvolutionOptions { + padding:Padding; + stride_w:int; + stride_h:int; + fused_activation_function:ActivationFunctionType; +} + +table PoolOptions { + padding:Padding; + stride_w:int; + stride_h:int; + filter_width:int; + filter_height:int; + fused_activation_function:ActivationFunctionType; +} + +table DepthwiseConvolutionOptions { + padding:Padding; + stride_w:int; + stride_h:int; + depth_multiplier:int; + fused_activation_function:ActivationFunctionType; +} + +table ConcatEmbeddingsOptions { + num_channels:int; + num_columns_per_channel:[int]; + embedding_dim_per_channel:[int]; // This could be inferred from parameters. +} + +enum LSHProjectionType: byte { + UNKNOWN = 0, + SPARSE = 1, + DENSE = 2, +} + +table LSHProjectionOptions { + type: LSHProjectionType; +} + +table SVDFOptions { + rank:int; + fused_activation_function:ActivationFunctionType; +} + +// An implementation of TensorFlow BasicRNNCell. +table BasicRNNOptions { + fused_activation_function:ActivationFunctionType; +} + +// An implementation of TensorFlow fully_connected (a.k.a Dense) layer. +table FullyConnectedOptions { + fused_activation_function:ActivationFunctionType; +} + +table SoftmaxOptions { + beta: float; +} + +// An implementation of TensorFlow concat. +table ConcatenationOptions { + axis:int; + fused_activation_function:ActivationFunctionType; +} + +table AddOptions { + fused_activation_function:ActivationFunctionType; +} + +table L2NormOptions { + fused_activation_function:ActivationFunctionType; +} + +table LocalResponseNormOptions { + radius:int; + bias:float; + alpha:float; + beta:float; +} + +// An implementation of TensorFlow LSTMCell and CoupledInputForgetGateLSTMCell +table LSTMOptions { + fused_activation_function:ActivationFunctionType; + cell_clip: float; // Optional, 0.0 means no clipping + proj_clip: float; // Optional, 0.0 means no clipping +} + +table ResizeBilinearOptions { + new_height:int; + new_width:int; +} + +// An OperatorCode can be an enum value (BuiltinOperator) if the operator is a +// builtin, or a string if the operator is custom. +table OperatorCode { + builtin_code:BuiltinOperator; + custom_code:string; +} + +// An operator takes tensors as inputs and outputs. 
The type of operation being +// performed is determined by an index into the list of valid OperatorCodes, +// while the specifics of each operations is configured using builtin_options +// or custom_options. +table Operator { + // Index into the operator_codes array. Using an integer here avoids + // complicate map lookups. + opcode_index:int; + + inputs:[int]; + outputs:[int]; + + builtin_options:BuiltinOptions; + custom_options:[ubyte]; +} + +// The root type, defining a model. +table Model { + // A list of all tensors used in this model. + tensors:[Tensor]; + + // Indices of the input tensors. + inputs:[int]; + + // Indices of the output tensors. + outputs:[int]; + + // A list of all operator codes used in this model. This is + // kept in order because operators carry an index into this + // vector. + operator_codes:[OperatorCode]; + + // All operators, in execution order. + operators:[Operator]; +} + +root_type Model; diff --git a/tensorflow/contrib/lite/schema/schema_v1.fbs b/tensorflow/contrib/lite/schema/schema_v1.fbs new file mode 100644 index 0000000000..06cd9408ed --- /dev/null +++ b/tensorflow/contrib/lite/schema/schema_v1.fbs @@ -0,0 +1,295 @@ +// Copyright 2017 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Revision History +// Version 0: Initial version. +// Version 1: Add subgraphs to schema. + +namespace tflite; + +// The type of data stored in a tensor. +enum TensorType : byte { + FLOAT32 = 0, + FLOAT16 = 1, + INT32 = 2, + UINT8 = 3, + INT64 = 4, + STRING = 5, +} + +// Parameters for converting a quantized tensor back to float. Given a +// quantized value q, the corresponding float value f should be: +// f = scale * (q - zero_point) +table QuantizationParameters { + min:[float]; // For importing back into tensorflow. + max:[float]; // For importing back into tensorflow. + scale:[float]; + zero_point:[long]; +} + +table Tensor { + // The tensor shape. The meaning of each entry is operator-specific but + // builtin ops use: [batch size, number of channels, height, width] (That's + // Tensorflow's NCHW). + shape:[int]; + type:TensorType; + // The data_buffer is an opaque container, with the assumption that the + // target device is little-endian. In addition, all builtin operators assume + // the memory is ordered such that if `shape` is [4, 3, 2], then index + // [i, j, k] maps to data_buffer[i*3*2 + j*3 + k]. + data_buffer:[ubyte]; + name:string; // For debugging and importing back into tensorflow. + quantization:QuantizationParameters; // Optional. +} + +// A list of builtin operators. Builtin operators a slighlty faster than custom +// ones, but not by much. Moreover, while custom operators accept an opaque +// object containing configuration parameters, builtins have a predetermined +// set of acceptable options. 
+enum BuiltinOperator : byte { + CUSTOM = 0, + CONVOLUTION = 1, + DEPTHWISE_CONVOLUTION = 2, + CONCAT_EMBEDDINGS = 3, + LSH_PROJECTION = 4, + TANH = 5, + RELU = 6, + AVERAGE_POOL = 7, + MAX_POOL = 8, + L2_POOL = 9, + SIGMOID = 10, + SVDF = 11, + BasicRNN = 12, + RELU6 = 13, + EMBEDDING_LOOKUP = 14, + FULLY_CONNECTED = 15, + HASHTABLE_LOOKUP = 16, + SOFTMAX = 17, + CONCATENATION = 18, + LSTM = 19, + ADD = 20, + L2NORM = 21, + LOCAL_RESPONSE_NORM = 22, + RESIZE_BILINEAR = 23, + CALL = 24, + RESHAPE = 25, + SKIP_GRAM = 26, + SPACE_TO_DEPTH = 27, +} + +// Options for the builtin operators. +union BuiltinOptions { + ConvolutionOptions, + DepthwiseConvolutionOptions, + ConcatEmbeddingsOptions, + LSHProjectionOptions, + PoolOptions, + SVDFOptions, + BasicRNNOptions, + FullyConnectedOptions, + SoftmaxOptions, + ConcatenationOptions, + AddOptions, + L2NormOptions, + LocalResponseNormOptions, + LSTMOptions, + ResizeBilinearOptions, + CallOptions, + ReshapeOptions, + SkipGramOptions, + SpaceToDepthOptions, +} + +enum Padding : byte { SAME, VALID } + +enum ActivationFunctionType : byte { + NONE = 0, + RELU = 1, + RELU1 = 2, + RELU6 = 3, + TANH = 4, + SIGN_BIT = 5, +} + +table ConvolutionOptions { + padding:Padding; + stride_w:int; + stride_h:int; + fused_activation_function:ActivationFunctionType; +} + +table PoolOptions { + padding:Padding; + stride_w:int; + stride_h:int; + filter_width:int; + filter_height:int; + fused_activation_function:ActivationFunctionType; +} + +table DepthwiseConvolutionOptions { + padding:Padding; + stride_w:int; + stride_h:int; + depth_multiplier:int; + fused_activation_function:ActivationFunctionType; +} + +table ConcatEmbeddingsOptions { + num_channels:int; + num_columns_per_channel:[int]; + embedding_dim_per_channel:[int]; // This could be inferred from parameters. +} + +enum LSHProjectionType: byte { + UNKNOWN = 0, + SPARSE = 1, + DENSE = 2, +} + +table LSHProjectionOptions { + type: LSHProjectionType; +} + +table SVDFOptions { + rank:int; + fused_activation_function:ActivationFunctionType; +} + +// An implementation of TensorFlow BasicRNNCell. +table BasicRNNOptions { + fused_activation_function:ActivationFunctionType; +} + +// An implementation of TensorFlow fully_connected (a.k.a Dense) layer. +table FullyConnectedOptions { + fused_activation_function:ActivationFunctionType; +} + +table SoftmaxOptions { + beta: float; +} + +// An implementation of TensorFlow concat. +table ConcatenationOptions { + axis:int; + fused_activation_function:ActivationFunctionType; +} + +table AddOptions { + fused_activation_function:ActivationFunctionType; +} + +table L2NormOptions { + fused_activation_function:ActivationFunctionType; +} + +table LocalResponseNormOptions { + radius:int; + bias:float; + alpha:float; + beta:float; +} + +// An implementation of TensorFlow LSTMCell and CoupledInputForgetGateLSTMCell +table LSTMOptions { + fused_activation_function:ActivationFunctionType; + cell_clip: float; // Optional, 0.0 means no clipping + proj_clip: float; // Optional, 0.0 means no clipping +} + +table ResizeBilinearOptions { + new_height:int; + new_width:int; +} + +// A call operation options +table CallOptions { + // The subgraph index that needs to be called. 
+ subgraph:int; +} + +table ReshapeOptions { + new_shape:[int]; +} + +table SkipGramOptions { + ngram_size: int; + max_skip_size: int; + include_all_ngrams: bool; +} + +table SpaceToDepthOptions { + block_size: int; +} + +// An OperatorCode can be an enum value (BuiltinOperator) if the operator is a +// builtin, or a string if the operator is custom. +table OperatorCode { + builtin_code:BuiltinOperator; + custom_code:string; +} + +// An operator takes tensors as inputs and outputs. The type of operation being +// performed is determined by an index into the list of valid OperatorCodes, +// while the specifics of each operations is configured using builtin_options +// or custom_options. +table Operator { + // Index into the operator_codes array. Using an integer here avoids + // complicate map lookups. + opcode_index:int; + + inputs:[int]; + outputs:[int]; + + builtin_options:BuiltinOptions; + custom_options:[ubyte]; +} + +// The root type, defining a model. +table SubGraph { + // A list of all tensors used in this model. + tensors:[Tensor]; + + // Indices of the input tensors. + inputs:[int]; + + // Indices of the output tensors. + outputs:[int]; + + // All operators, in execution order. + operators:[Operator]; + + // Name of subgraph (used for debugging). + name:string; +} + +table Model { + // Version of the schema. + version:int; + + // A list of all operator codes used in this model. This is + // kept in order because operators carry an index into this + // vector. + operator_codes:[OperatorCode]; + + // All the subgraphs of the model. The 0th is assumed to be the main + // model. + subgraphs:[SubGraph]; + + // A description of the model. + description:string; +} + +root_type Model; diff --git a/tensorflow/contrib/lite/schema/schema_v2.fbs b/tensorflow/contrib/lite/schema/schema_v2.fbs new file mode 100644 index 0000000000..96731c8aae --- /dev/null +++ b/tensorflow/contrib/lite/schema/schema_v2.fbs @@ -0,0 +1,303 @@ +// Copyright 2017 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Revision History +// Version 0: Initial version. +// Version 1: Add subgraphs to schema. +// Version 2: Rename operators to conform to NN API. + +namespace tflite; + +// The type of data stored in a tensor. +enum TensorType : byte { + FLOAT32 = 0, + FLOAT16 = 1, + INT32 = 2, + UINT8 = 3, + INT64 = 4, + STRING = 5, +} + +// Parameters for converting a quantized tensor back to float. Given a +// quantized value q, the corresponding float value f should be: +// f = scale * (q - zero_point) +table QuantizationParameters { + min:[float]; // For importing back into tensorflow. + max:[float]; // For importing back into tensorflow. + scale:[float]; + zero_point:[long]; +} + +table Tensor { + // The tensor shape. The meaning of each entry is operator-specific but + // builtin ops use: [batch size, number of channels, height, width] (That's + // Tensorflow's NCHW). 
+ shape:[int]; + type:TensorType; + // The data_buffer is an opaque container, with the assumption that the + // target device is little-endian. In addition, all builtin operators assume + // the memory is ordered such that if `shape` is [4, 3, 2], then index + // [i, j, k] maps to data_buffer[i*3*2 + j*3 + k]. + data_buffer:[ubyte]; + name:string; // For debugging and importing back into tensorflow. + quantization:QuantizationParameters; // Optional. +} + +// A list of builtin operators. Builtin operators a slighlty faster than custom +// ones, but not by much. Moreover, while custom operators accept an opaque +// object containing configuration parameters, builtins have a predetermined +// set of acceptable options. +enum BuiltinOperator : byte { + ADD = 0, + AVERAGE_POOL_2D = 1, + CONCATENATION = 2, + CONV_2D = 3, + DEPTHWISE_CONV_2D = 4, + // DEPTH_TO_SPACE = 5, + // DEQUANTIZE = 6, + EMBEDDING_LOOKUP = 7, + // FLOOR = 8, + FULLY_CONNECTED = 9, + HASHTABLE_LOOKUP = 10, + L2_NORMALIZATION = 11, + L2_POOL_2D = 12, + LOCAL_RESPONSE_NORMALIZATION = 13, + LOGISTIC = 14, + LSH_PROJECTION = 15, + LSTM = 16, + MAX_POOL_2D = 17, + // MUL = 18, + RELU = 19, + // RELU1=20, + RELU6 = 21, + RESHAPE = 22, + RESIZE_BILINEAR = 23, + RNN = 24, + SOFTMAX = 25, + SPACE_TO_DEPTH = 26, + SVDF = 27, + TANH = 28, + // TODO(aselle): Consider rename to CONCATENATE_EMBEDDINGS + CONCAT_EMBEDDINGS = 29, + SKIP_GRAM = 30, + CALL = 31, + CUSTOM = 32, + +} + +// Options for the builtin operators. +union BuiltinOptions { + Conv2DOptions, + DepthwiseConv2DOptions, + ConcatEmbeddingsOptions, + LSHProjectionOptions, + Pool2DOptions, + SVDFOptions, + RNNOptions, + FullyConnectedOptions, + SoftmaxOptions, + ConcatenationOptions, + AddOptions, + L2NormOptions, + LocalResponseNormalizationOptions, + LSTMOptions, + ResizeBilinearOptions, + CallOptions, + ReshapeOptions, + SkipGramOptions, + SpaceToDepthOptions, +} + +enum Padding : byte { SAME, VALID } + +enum ActivationFunctionType : byte { + NONE = 0, + RELU = 1, + RELU1 = 2, + RELU6 = 3, + TANH = 4, + SIGN_BIT = 5, +} + +table Conv2DOptions { + padding:Padding; + stride_w:int; + stride_h:int; + fused_activation_function:ActivationFunctionType; +} + +table Pool2DOptions { + padding:Padding; + stride_w:int; + stride_h:int; + filter_width:int; + filter_height:int; + fused_activation_function:ActivationFunctionType; +} + +table DepthwiseConv2DOptions { + padding:Padding; + stride_w:int; + stride_h:int; + depth_multiplier:int; + fused_activation_function:ActivationFunctionType; +} + +table ConcatEmbeddingsOptions { + num_channels:int; + num_columns_per_channel:[int]; + embedding_dim_per_channel:[int]; // This could be inferred from parameters. +} + +enum LSHProjectionType: byte { + UNKNOWN = 0, + SPARSE = 1, + DENSE = 2, +} + +table LSHProjectionOptions { + type: LSHProjectionType; +} + +table SVDFOptions { + rank:int; + fused_activation_function:ActivationFunctionType; +} + +// An implementation of TensorFlow RNNCell. +table RNNOptions { + fused_activation_function:ActivationFunctionType; +} + +// An implementation of TensorFlow fully_connected (a.k.a Dense) layer. +table FullyConnectedOptions { + fused_activation_function:ActivationFunctionType; +} + +table SoftmaxOptions { + beta: float; +} + +// An implementation of TensorFlow concat. 
+table ConcatenationOptions { + axis:int; + fused_activation_function:ActivationFunctionType; +} + +table AddOptions { + fused_activation_function:ActivationFunctionType; +} + +table L2NormOptions { + fused_activation_function:ActivationFunctionType; +} + +table LocalResponseNormalizationOptions { + radius:int; + bias:float; + alpha:float; + beta:float; +} + +// An implementation of TensorFlow LSTMCell and CoupledInputForgetGateLSTMCell +table LSTMOptions { + fused_activation_function:ActivationFunctionType; + cell_clip: float; // Optional, 0.0 means no clipping + proj_clip: float; // Optional, 0.0 means no clipping +} + +table ResizeBilinearOptions { + new_height:int; + new_width:int; +} + +// A call operation options +table CallOptions { + // The subgraph index that needs to be called. + subgraph:int; +} + +table ReshapeOptions { + new_shape:[int]; +} + +table SkipGramOptions { + ngram_size: int; + max_skip_size: int; + include_all_ngrams: bool; +} + +table SpaceToDepthOptions { + block_size: int; +} + +// An OperatorCode can be an enum value (BuiltinOperator) if the operator is a +// builtin, or a string if the operator is custom. +table OperatorCode { + builtin_code:BuiltinOperator; + custom_code:string; +} + +// An operator takes tensors as inputs and outputs. The type of operation being +// performed is determined by an index into the list of valid OperatorCodes, +// while the specifics of each operations is configured using builtin_options +// or custom_options. +table Operator { + // Index into the operator_codes array. Using an integer here avoids + // complicate map lookups. + opcode_index:int; + + inputs:[int]; + outputs:[int]; + + builtin_options:BuiltinOptions; + custom_options:[ubyte]; +} + +// The root type, defining a model. +table SubGraph { + // A list of all tensors used in this model. + tensors:[Tensor]; + + // Indices of the input tensors. + inputs:[int]; + + // Indices of the output tensors. + outputs:[int]; + + // All operators, in execution order. + operators:[Operator]; + + // Name of subgraph (used for debugging). + name:string; +} + +table Model { + // Version of the schema. + version:int; + + // A list of all operator codes used in this model. This is + // kept in order because operators carry an index into this + // vector. + operator_codes:[OperatorCode]; + + // All the subgraphs of the model. The 0th is assumed to be the main + // model. + subgraphs:[SubGraph]; + + // A description of the model. + description:string; +} + +root_type Model; diff --git a/tensorflow/contrib/lite/schema/schema_v3.fbs b/tensorflow/contrib/lite/schema/schema_v3.fbs new file mode 100644 index 0000000000..cedefe08f3 --- /dev/null +++ b/tensorflow/contrib/lite/schema/schema_v3.fbs @@ -0,0 +1,326 @@ +// Copyright 2017 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Revision History +// Version 0: Initial version. +// Version 1: Add subgraphs to schema. +// Version 2: Rename operators to conform to NN API. 
+// Version 3: Move buffer data from Model.Subgraph.Tensors to Model.Buffers. + +namespace tflite; + +// This corresponds to the version (4). +file_identifier "TFL3"; +// File extension of any written files. +file_extension "tflite"; + +// The type of data stored in a tensor. +enum TensorType : byte { + FLOAT32 = 0, + FLOAT16 = 1, + INT32 = 2, + UINT8 = 3, + INT64 = 4, + STRING = 5, +} + +// Parameters for converting a quantized tensor back to float. Given a +// quantized value q, the corresponding float value f should be: +// f = scale * (q - zero_point) +table QuantizationParameters { + min:[float]; // For importing back into tensorflow. + max:[float]; // For importing back into tensorflow. + scale:[float]; + zero_point:[long]; +} + +table Tensor { + // The tensor shape. The meaning of each entry is operator-specific but + // builtin ops use: [batch size, number of channels, height, width] (That's + // Tensorflow's NCHW). + shape:[int]; + type:TensorType; + // An index that refers to the buffers table at the root of the model. Or, + // if there is no data buffer associated (i.e. intermediate results), then + // this is 0 (which refers to an always existant empty buffer). + // + // The data_buffer itself is an opaque container, with the assumption that the + // target device is little-endian. In addition, all builtin operators assume + // the memory is ordered such that if `shape` is [4, 3, 2], then index + // [i, j, k] maps to data_buffer[i*3*2 + j*3 + k]. + buffer:uint; + name:string; // For debugging and importing back into tensorflow. + quantization:QuantizationParameters; // Optional. +} + +// A list of builtin operators. Builtin operators a slighlty faster than custom +// ones, but not by much. Moreover, while custom operators accept an opaque +// object containing configuration parameters, builtins have a predetermined +// set of acceptable options. +enum BuiltinOperator : byte { + ADD = 0, + AVERAGE_POOL_2D = 1, + CONCATENATION = 2, + CONV_2D = 3, + DEPTHWISE_CONV_2D = 4, + // DEPTH_TO_SPACE = 5, + // DEQUANTIZE = 6, + EMBEDDING_LOOKUP = 7, + // FLOOR = 8, + FULLY_CONNECTED = 9, + HASHTABLE_LOOKUP = 10, + L2_NORMALIZATION = 11, + L2_POOL_2D = 12, + LOCAL_RESPONSE_NORMALIZATION = 13, + LOGISTIC = 14, + LSH_PROJECTION = 15, + LSTM = 16, + MAX_POOL_2D = 17, + // MUL = 18, + RELU = 19, + // RELU1=20, + RELU6 = 21, + RESHAPE = 22, + RESIZE_BILINEAR = 23, + RNN = 24, + SOFTMAX = 25, + SPACE_TO_DEPTH = 26, + SVDF = 27, + TANH = 28, + // TODO(aselle): Consider rename to CONCATENATE_EMBEDDINGS + CONCAT_EMBEDDINGS = 29, + SKIP_GRAM = 30, + CALL = 31, + CUSTOM = 32, + +} + +// Options for the builtin operators. 
+union BuiltinOptions { + Conv2DOptions, + DepthwiseConv2DOptions, + ConcatEmbeddingsOptions, + LSHProjectionOptions, + Pool2DOptions, + SVDFOptions, + RNNOptions, + FullyConnectedOptions, + SoftmaxOptions, + ConcatenationOptions, + AddOptions, + L2NormOptions, + LocalResponseNormalizationOptions, + LSTMOptions, + ResizeBilinearOptions, + CallOptions, + ReshapeOptions, + SkipGramOptions, + SpaceToDepthOptions, +} + +enum Padding : byte { SAME, VALID } + +enum ActivationFunctionType : byte { + NONE = 0, + RELU = 1, + RELU1 = 2, + RELU6 = 3, + TANH = 4, + SIGN_BIT = 5, +} + +table Conv2DOptions { + padding:Padding; + stride_w:int; + stride_h:int; + fused_activation_function:ActivationFunctionType; +} + +table Pool2DOptions { + padding:Padding; + stride_w:int; + stride_h:int; + filter_width:int; + filter_height:int; + fused_activation_function:ActivationFunctionType; +} + +table DepthwiseConv2DOptions { + padding:Padding; + stride_w:int; + stride_h:int; + depth_multiplier:int; + fused_activation_function:ActivationFunctionType; +} + +table ConcatEmbeddingsOptions { + num_channels:int; + num_columns_per_channel:[int]; + embedding_dim_per_channel:[int]; // This could be inferred from parameters. +} + +enum LSHProjectionType: byte { + UNKNOWN = 0, + SPARSE = 1, + DENSE = 2, +} + +table LSHProjectionOptions { + type: LSHProjectionType; +} + +table SVDFOptions { + rank:int; + fused_activation_function:ActivationFunctionType; +} + +// An implementation of TensorFlow RNNCell. +table RNNOptions { + fused_activation_function:ActivationFunctionType; +} + +// An implementation of TensorFlow fully_connected (a.k.a Dense) layer. +table FullyConnectedOptions { + fused_activation_function:ActivationFunctionType; +} + +table SoftmaxOptions { + beta: float; +} + +// An implementation of TensorFlow concat. +table ConcatenationOptions { + axis:int; + fused_activation_function:ActivationFunctionType; +} + +table AddOptions { + fused_activation_function:ActivationFunctionType; +} + +table L2NormOptions { + fused_activation_function:ActivationFunctionType; +} + +table LocalResponseNormalizationOptions { + radius:int; + bias:float; + alpha:float; + beta:float; +} + +// An implementation of TensorFlow LSTMCell and CoupledInputForgetGateLSTMCell +table LSTMOptions { + fused_activation_function:ActivationFunctionType; + cell_clip: float; // Optional, 0.0 means no clipping + proj_clip: float; // Optional, 0.0 means no clipping +} + +table ResizeBilinearOptions { + new_height:int; + new_width:int; +} + +// A call operation options +table CallOptions { + // The subgraph index that needs to be called. + subgraph:uint; +} + +table ReshapeOptions { + new_shape:[int]; +} + +table SkipGramOptions { + ngram_size: int; + max_skip_size: int; + include_all_ngrams: bool; +} + +table SpaceToDepthOptions { + block_size: int; +} + +// An OperatorCode can be an enum value (BuiltinOperator) if the operator is a +// builtin, or a string if the operator is custom. +table OperatorCode { + builtin_code:BuiltinOperator; + custom_code:string; +} + +// An operator takes tensors as inputs and outputs. The type of operation being +// performed is determined by an index into the list of valid OperatorCodes, +// while the specifics of each operations is configured using builtin_options +// or custom_options. +table Operator { + // Index into the operator_codes array. Using an integer here avoids + // complicate map lookups. 
+ opcode_index:uint; + + inputs:[int]; + outputs:[int]; + + builtin_options:BuiltinOptions; + custom_options:[ubyte]; +} + +// The root type, defining a model. +table SubGraph { + // A list of all tensors used in this model. + tensors:[Tensor]; + + // Indices of the input tensors. + inputs:[int]; + + // Indices of the output tensors. + outputs:[int]; + + // All operators, in execution order. + operators:[Operator]; + + // Name of subgraph (used for debugging). + name:string; +} + +// Table of raw data buffers (used for constant tensors). Referenced by tensors +// by index. +table Buffer { + data:[ubyte]; +} + +table Model { + // Version of the schema. + version:uint; + + // A list of all operator codes used in this model. This is + // kept in order because operators carry an index into this + // vector. + operator_codes:[OperatorCode]; + + // All the subgraphs of the model. The 0th is assumed to be the main + // model. + subgraphs:[SubGraph]; + + // A description of the model. + description:string; + + // Buffers of the model. + // NOTE: It is required that the first entry in here is always an empty + // buffer. This is so that the default buffer index of zero in Tensor + // will always refer to a valid empty buffer. + buffers:[Buffer]; + +} + +root_type Model; diff --git a/tensorflow/contrib/lite/schema/upgrade_schema.py b/tensorflow/contrib/lite/schema/upgrade_schema.py new file mode 100644 index 0000000000..320c7138d2 --- /dev/null +++ b/tensorflow/contrib/lite/schema/upgrade_schema.py @@ -0,0 +1,341 @@ +# ============================================================================== +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Upgrade script to move from pre-release schema to new schema. + +Usage examples: + +bazel run tensorflow/contrib/lite/schema/upgrade_schema -- in.json out.json +bazel run tensorflow/contrib/lite/schema/upgrade_schema -- in.bin out.bin +bazel run tensorflow/contrib/lite/schema/upgrade_schema -- in.bin out.json +bazel run tensorflow/contrib/lite/schema/upgrade_schema -- in.json out.bin +bazel run tensorflow/contrib/lite/schema/upgrade_schema -- in.tflite out.tflite +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import contextlib +import json +import os +import shutil +import subprocess +import sys +import tempfile + +import tensorflow as tf +from tensorflow.python.platform import resource_loader + +parser = argparse.ArgumentParser( + description="Script to move TFLite models from pre-release schema to" + " new schema.") +parser.add_argument( + "input", + type=str, + help="Input TensorFlow lite file in `.json`, `.bin` or `.tflite` format.") +parser.add_argument( + "output", + type=str, + help="Output json or bin TensorFlow lite model compliant with" + "the new schema. Extension must be `.json`, `.bin` or `.tflite`.") + + +# RAII Temporary Directory, because flatc doesn't allow direct use of tempfiles. 
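+# The context manager below creates the directory on entry and always removes
+# it on exit, even if an exception is raised, so intermediate files handed to
+# flatc are never leaked.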
+@contextlib.contextmanager +def TemporaryDirectoryResource(): + temporary = tempfile.mkdtemp() + try: + yield temporary + finally: + shutil.rmtree(temporary) + + +class Converter(object): + """Converts TensorFlow flatbuffer models from old to new version of schema. + + This can convert between any version to the latest version. It uses + an incremental upgrade strategy to go from version to version. + + Usage: + converter = Converter() + converter.Convert("a.tflite", "a.json") + converter.Convert("b.json", "b.tflite") + """ + + def __init__(self): + # TODO(aselle): make this work in the open source version with better + # path. + self._flatc_path = resource_loader.get_path_to_datafile( + "../../../../flatbuffers/flatc") + + def FindSchema(base_name): + return resource_loader.get_path_to_datafile("%s" % base_name) + + # Supported schemas for upgrade. + self._schemas = [ + (0, FindSchema("schema_v0.fbs"), True, self._Upgrade0To1), + (1, FindSchema("schema_v1.fbs"), True, self._Upgrade1To2), + (2, FindSchema("schema_v2.fbs"), True, self._Upgrade2To3), + (3, FindSchema("schema_v3.fbs"), False, None) # Non-callable by design. + ] + # Ensure schemas are sorted, and extract latest version and upgrade + # dispatch function table. + self._schemas.sort() + self._new_version, self._new_schema = self._schemas[-1][:2] + self._upgrade_dispatch = dict( + (version, dispatch) + for version, unused1, unused2, dispatch in self._schemas) + + def _Read(self, input_file, schema, raw_binary=False): + """Read a tflite model assuming the given flatbuffer schema. + + If `input_file` is in bin, then we must use flatc to convert the schema + from binary to json. + + Args: + input_file: a binary (flatbuffer) or json file to read from. Extension + must be `.tflite`, `.bin`, or `.json` for FlatBuffer Binary or + FlatBuffer JSON. + schema: which schema to use for reading + raw_binary: whether to assume raw_binary (versions previous to v3) + that lacked file_identifier require this. + + Raises: + RuntimeError: When flatc cannot be invoked. + ValueError: When the extension is not json or bin. + + Returns: + A dictionary representing the read tflite model. + """ + raw_binary = ["--raw-binary"] if raw_binary else [] + with TemporaryDirectoryResource() as tempdir: + basename = os.path.basename(input_file) + basename_no_extension, extension = os.path.splitext(basename) + if extension in [".bin", ".tflite"]: + # Convert to json using flatc + returncode = subprocess.call([ + self._flatc_path, + "-t", + "--strict-json", + "--defaults-json", + ] + raw_binary + ["-o", tempdir, schema, "--", input_file]) + if returncode != 0: + raise RuntimeError("flatc failed to convert from binary to json.") + json_file = os.path.join(tempdir, basename_no_extension + ".json") + if not os.path.exists(json_file): + raise RuntimeError("Could not find %r" % json_file) + elif extension == ".json": + json_file = input_file + else: + raise ValueError("Invalid extension on input file %r" % input_file) + return json.load(open(json_file)) + + def _Write(self, data, output_file): + """Output a json or bin version of the flatbuffer model. + + Args: + data: Dict representing the TensorFlow Lite model to write. + output_file: filename to write the converted flatbuffer to. (json, + tflite, or bin extension is required). + Raises: + ValueError: When the extension is not json or bin + RuntimeError: When flatc fails to convert json data to binary. 
+ """ + _, extension = os.path.splitext(output_file) + with TemporaryDirectoryResource() as tempdir: + if extension == ".json": + json.dump(data, open(output_file, "w"), sort_keys=True, indent=2) + elif extension in [".tflite", ".bin"]: + input_json = os.path.join(tempdir, "temp.json") + with open(input_json, "w") as fp: + json.dump(data, fp, sort_keys=True, indent=2) + returncode = subprocess.call([ + self._flatc_path, "-b", "--defaults-json", "--strict-json", "-o", + tempdir, self._new_schema, input_json + ]) + if returncode != 0: + raise RuntimeError("flatc failed to convert upgraded json to binary.") + + shutil.copy(os.path.join(tempdir, "temp.tflite"), output_file) + else: + raise ValueError("Invalid extension on output file %r" % output_file) + + def _Upgrade0To1(self, data): + """Upgrade data from Version 0 to Version 1. + + Changes: Added subgraphs (which contains a subset of formally global + entries). + + Args: + data: Dictionary representing the TensorFlow lite data to be upgraded. + This will be modified in-place to be an upgraded version. + """ + subgraph = {} + for key_to_promote in ["tensors", "operators", "inputs", "outputs"]: + subgraph[key_to_promote] = data[key_to_promote] + del data[key_to_promote] + data["subgraphs"] = [subgraph] + + def _Upgrade1To2(self, data): + """Upgrade data from Version 1 to Version 2. + + Changes: Rename operators to Conform to NN API. + + Args: + data: Dictionary representing the TensorFlow lite data to be upgraded. + This will be modified in-place to be an upgraded version. + Raises: + ValueError: Throws when model builtins are numeric rather than symbols. + """ + + def RemapOperator(opcode_name): + """Go from old schema op name to new schema op name. + + Args: + opcode_name: String representing the ops (see :schema.fbs). + Returns: + Converted opcode_name from V1 to V2. + """ + old_name_to_new_name = { + "CONVOLUTION": "CONV_2D", + "DEPTHWISE_CONVOLUTION": "DEPTHWISE_CONV_2D", + "AVERAGE_POOL": "AVERAGE_POOL_2D", + "MAX_POOL": "MAX_POOL_2D", + "L2_POOL": "L2_POOL_2D", + "SIGMOID": "LOGISTIC", + "L2NORM": "L2_NORMALIZATION", + "LOCAL_RESPONSE_NORM": "LOCAL_RESPONSE_NORMALIZATION", + "Basic_RNN": "RNN", + } + + return (old_name_to_new_name[opcode_name] + if opcode_name in old_name_to_new_name else opcode_name) + + def RemapOperatorType(operator_type): + """Remap operator structs from old names to new names. + + Args: + operator_type: String representing the builtin operator data type + string. + (see :schema.fbs). + Returns: + Upgraded builtin operator data type as a string. + """ + old_to_new = { + "PoolOptions": "Pool2DOptions", + "DepthwiseConvolutionOptions": "DepthwiseConv2DOptions", + "ConvolutionOptions": "Conv2DOptions", + "LocalResponseNormOptions": "LocalResponseNormalizationOptions", + "BasicRNNOptions": "RNNOptions", + } + return (old_to_new[operator_type] + if operator_type in old_to_new else operator_type) + + for subgraph in data["subgraphs"]: + for ops in subgraph["operators"]: + ops["builtin_options_type"] = RemapOperatorType( + ops["builtin_options_type"]) + + # Upgrade the operator codes + for operator_code in data["operator_codes"]: + if not isinstance(operator_code["builtin_code"], unicode): + raise ValueError("builtin_code %r is non-string. this usually means" + "your model has consistency problems." % + (operator_code["builtin_code"])) + operator_code["builtin_code"] = (RemapOperator( + operator_code["builtin_code"])) + + def _Upgrade2To3(self, data): + """Upgrade data from Version 2 to Version 3. 
+ + Changed actual read-only tensor data to be in a buffers table instead + of inline with the tensor. + + Args: + data: Dictionary representing the TensorFlow lite data to be upgraded. + This will be modified in-place to be an upgraded version. + """ + buffers = [{"data": []}] # Start with 1 empty buffer + for subgraph in data["subgraphs"]: + if "tensors" not in subgraph: + continue + for tensor in subgraph["tensors"]: + if "data_buffer" not in tensor: + tensor["buffer"] = 0 + else: + if tensor["data_buffer"]: + tensor[u"buffer"] = len(buffers) + buffers.append({"data": tensor["data_buffer"]}) + else: + tensor["buffer"] = 0 + del tensor["data_buffer"] + data["buffers"] = buffers + + def _PerformUpgrade(self, data): + """Manipulate the `data` (parsed JSON) based on changes in format. + + This incrementally will upgrade from version to version within data. + + Args: + data: Dictionary representing the TensorFlow data. This will be upgraded + in place. + """ + while data["version"] < self._new_version: + self._upgrade_dispatch[data["version"]](data) + data["version"] += 1 + + def Convert(self, input_file, output_file): + """Perform schema conversion from input_file to output_file. + + Args: + input_file: Filename of TensorFlow Lite data to convert from. Must + be `.json` or `.bin` extension files for JSON or Binary forms of + the TensorFlow FlatBuffer schema. + output_file: Filename to write to. Extension also must be `.json` + or `.bin`. + + Raises: + RuntimeError: Generated when none of the upgrader supported schemas + matche the `input_file` data. + """ + # Read data in each schema (since they are incompatible). Version is + # always present. Use the read data that matches the version of the + # schema. + for version, schema, raw_binary, _ in self._schemas: + try: + data_candidate = self._Read(input_file, schema, raw_binary) + except RuntimeError: + continue # Skip and hope another schema works + if "version" not in data_candidate: # Assume version 1 if not present. + data_candidate["version"] = 1 + elif data_candidate["version"] == 0: # Version 0 doesn't exist in wild. + data_candidate["version"] = 1 + + if data_candidate["version"] == version: + self._PerformUpgrade(data_candidate) + self._Write(data_candidate, output_file) + return + raise RuntimeError("No schema that the converter understands worked with " + "the data file you provided.") + + +def main(argv): + del argv + Converter().Convert(FLAGS.input, FLAGS.output) + + +if __name__ == "__main__": + FLAGS, unparsed = parser.parse_known_args() + tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/tensorflow/contrib/lite/schema/upgrade_schema_test.py b/tensorflow/contrib/lite/schema/upgrade_schema_test.py new file mode 100644 index 0000000000..475cdb9d8b --- /dev/null +++ b/tensorflow/contrib/lite/schema/upgrade_schema_test.py @@ -0,0 +1,317 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
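
The V2-to-V3 step above can be illustrated on a toy dictionary; this standalone sketch mirrors the buffer promotion, keeping index 0 reserved for the required empty buffer:

    # Sketch of the V2 -> V3 buffer promotion: inline "data_buffer" contents
    # move into a model-level "buffers" table (illustrative toy model only).
    model = {
        "version": 2,
        "subgraphs": [{"tensors": [{"data_buffer": [1, 2, 3, 4]},
                                   {"data_buffer": []}]}],
    }
    buffers = [{"data": []}]  # index 0: the mandatory empty buffer
    for subgraph in model["subgraphs"]:
        for tensor in subgraph["tensors"]:
            data = tensor.pop("data_buffer")
            if data:
                tensor["buffer"] = len(buffers)
                buffers.append({"data": data})
            else:
                tensor["buffer"] = 0
    model["buffers"] = buffers
    model["version"] = 3
    print(model["buffers"])  # [{'data': []}, {'data': [1, 2, 3, 4]}]
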
+# ============================================================================== +"""Testing for updating TensorFlow lite schema.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import json +import tempfile +from tensorflow.contrib.lite.schema import upgrade_schema as upgrade_schema_lib +from tensorflow.python.framework import test_util +from tensorflow.python.platform import test as test_lib + +EMPTY_TEST_SCHEMA_V1 = { + "version": 1, + "operator_codes": [], + "subgraphs": [], +} + +EMPTY_TEST_SCHEMA_V3 = { + "version": 3, + "operator_codes": [], + "subgraphs": [], + "buffers": [{ + "data": [] + }] +} + +TEST_SCHEMA_V0 = { + "operator_codes": [], + "tensors": [], + "inputs": [], + "outputs": [], + "operators": [], + "version": 0 +} + +TEST_SCHEMA_V3 = { + "operator_codes": [], + "buffers": [{ + "data": [] + }], + "subgraphs": [{ + "tensors": [], + "inputs": [], + "outputs": [], + "operators": [], + }], + "version": + 3 +} + +FULL_TEST_SCHEMA_V1 = { + "version": + 1, + "operator_codes": [ + { + "builtin_code": "CONVOLUTION" + }, + { + "builtin_code": "DEPTHWISE_CONVOLUTION" + }, + { + "builtin_code": "AVERAGE_POOL" + }, + { + "builtin_code": "MAX_POOL" + }, + { + "builtin_code": "L2_POOL" + }, + { + "builtin_code": "SIGMOID" + }, + { + "builtin_code": "L2NORM" + }, + { + "builtin_code": "LOCAL_RESPONSE_NORM" + }, + { + "builtin_code": "ADD" + }, + { + "builtin_code": "Basic_RNN" + }, + ], + "subgraphs": [{ + "operators": [ + { + "builtin_options_type": "PoolOptions" + }, + { + "builtin_options_type": "DepthwiseConvolutionOptions" + }, + { + "builtin_options_type": "ConvolutionOptions" + }, + { + "builtin_options_type": "LocalResponseNormOptions" + }, + { + "builtin_options_type": "BasicRNNOptions" + }, + ], + }], + "description": + "", +} + +FULL_TEST_SCHEMA_V3 = { + "version": + 3, + "operator_codes": [ + { + "builtin_code": "CONV_2D" + }, + { + "builtin_code": "DEPTHWISE_CONV_2D" + }, + { + "builtin_code": "AVERAGE_POOL_2D" + }, + { + "builtin_code": "MAX_POOL_2D" + }, + { + "builtin_code": "L2_POOL_2D" + }, + { + "builtin_code": "LOGISTIC" + }, + { + "builtin_code": "L2_NORMALIZATION" + }, + { + "builtin_code": "LOCAL_RESPONSE_NORMALIZATION" + }, + { + "builtin_code": "ADD" + }, + { + "builtin_code": "RNN" + }, + ], + "subgraphs": [{ + "operators": [ + { + "builtin_options_type": "Pool2DOptions" + }, + { + "builtin_options_type": "DepthwiseConv2DOptions" + }, + { + "builtin_options_type": "Conv2DOptions" + }, + { + "builtin_options_type": "LocalResponseNormalizationOptions" + }, + { + "builtin_options_type": "RNNOptions" + }, + ], + }], + "description": + "", + "buffers": [{ + "data": [] + }] +} + +BUFFER_TEST_V2 = { + "operator_codes": [], + "buffers": [], + "subgraphs": [{ + "tensors": [ + { + "data_buffer": [1, 2, 3, 4] + }, + { + "data_buffer": [1, 2, 3, 4, 5, 6, 7, 8] + }, + { + "data_buffer": [] + }, + ], + "inputs": [], + "outputs": [], + "operators": [], + }], + "version": + 2 +} + +BUFFER_TEST_V3 = { + "operator_codes": [], + "subgraphs": [{ + "tensors": [ + { + "buffer": 1 + }, + { + "buffer": 2 + }, + { + "buffer": 0 + }, + ], + "inputs": [], + "outputs": [], + "operators": [], + }], + "buffers": [ + { + "data": [] + }, + { + "data": [1, 2, 3, 4] + }, + { + "data": [1, 2, 3, 4, 5, 6, 7, 8] + }, + ], + "version": + 3 +} + + +def JsonDumpAndFlush(data, fp): + """Write the dictionary `data` to a JSON file `fp` (and flush). + + Args: + data: in a dictionary that is JSON serializable. 
+ fp: File-like object + """ + json.dump(data, fp) + fp.flush() + + +class TestSchemaUpgrade(test_util.TensorFlowTestCase): + + def testNonExistantFile(self): + converter = upgrade_schema_lib.Converter() + non_existent = tempfile.mktemp(suffix=".json") + with self.assertRaisesRegexp(IOError, "No such file or directory"): + converter.Convert(non_existent, non_existent) + + def testInvalidExtension(self): + converter = upgrade_schema_lib.Converter() + invalid_extension = tempfile.mktemp(suffix=".foo") + with self.assertRaisesRegexp(ValueError, "Invalid extension on input"): + converter.Convert(invalid_extension, invalid_extension) + with tempfile.NamedTemporaryFile(suffix=".json") as in_json: + JsonDumpAndFlush(EMPTY_TEST_SCHEMA_V1, in_json) + with self.assertRaisesRegexp(ValueError, "Invalid extension on output"): + converter.Convert(in_json.name, invalid_extension) + + def CheckConversion(self, data_old, data_expected): + """Given a data dictionary, test upgrading to current version. + + Args: + data_old: TFLite model as a dictionary (arbitrary version). + data_expected: TFLite model as a dictionary (upgraded). + """ + converter = upgrade_schema_lib.Converter() + with tempfile.NamedTemporaryFile(suffix=".json") as in_json, \ + tempfile.NamedTemporaryFile(suffix=".json") as out_json, \ + tempfile.NamedTemporaryFile(suffix=".bin") as out_bin, \ + tempfile.NamedTemporaryFile(suffix=".tflite") as out_tflite: + JsonDumpAndFlush(data_old, in_json) + # Test JSON output + converter.Convert(in_json.name, out_json.name) + # Test binary output + # Convert to .tflite and then to .bin and check if binary is equal + converter.Convert(in_json.name, out_tflite.name) + converter.Convert(out_tflite.name, out_bin.name) + self.assertEqual(open(out_bin.name).read(), open(out_tflite.name).read()) + # Test that conversion actually produced successful new json. + converted_schema = json.load(out_json) + self.assertEqual(converted_schema, data_expected) + + def testAlreadyUpgraded(self): + """A file already at version 3 should stay at version 3.""" + self.CheckConversion(EMPTY_TEST_SCHEMA_V3, EMPTY_TEST_SCHEMA_V3) + self.CheckConversion(TEST_SCHEMA_V3, TEST_SCHEMA_V3) + self.CheckConversion(BUFFER_TEST_V3, BUFFER_TEST_V3) + + # Disable this while we have incorrectly versioned structures around. + # def testV0Upgrade_IntroducesSubgraphs(self): + # """V0 did not have subgraphs; check to make sure they get introduced.""" + # self.CheckConversion(TEST_SCHEMA_V0, TEST_SCHEMA_V3) + + def testV1Upgrade_RenameOps(self): + """V1 had many different names for ops; check to make sure they rename.""" + self.CheckConversion(EMPTY_TEST_SCHEMA_V1, EMPTY_TEST_SCHEMA_V3) + self.CheckConversion(FULL_TEST_SCHEMA_V1, FULL_TEST_SCHEMA_V3) + + def testV2Upgrade_CreateBuffers(self): + """V2 did not have buffers; check to make sure they are created.""" + self.CheckConversion(BUFFER_TEST_V2, BUFFER_TEST_V3) + + +if __name__ == "__main__": + test_lib.main() diff --git a/tensorflow/contrib/lite/simple_memory_arena.cc b/tensorflow/contrib/lite/simple_memory_arena.cc new file mode 100644 index 0000000000..4aab244989 --- /dev/null +++ b/tensorflow/contrib/lite/simple_memory_arena.cc @@ -0,0 +1,136 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/simple_memory_arena.h"
+
+#include <algorithm>
+#include <cstring>
+#include <limits>
+
+namespace {
+
+template <typename T>
+T AlignTo(size_t alignment, T offset) {
+  return offset % alignment == 0 ? offset
+                                 : offset + (alignment - offset % alignment);
+}
+
+}  // namespace
+
+namespace tflite {
+
+TfLiteStatus SimpleMemoryArena::Allocate(TfLiteContext* context,
+                                         size_t alignment, size_t size,
+                                         ArenaAlloc* new_alloc) {
+  TF_LITE_ENSURE(context, alignment < arena_alignment_);
+
+  size_t current_top = 0;
+
+  if (!allocs_.empty()) {
+    auto last = allocs_.rbegin();
+    current_top = last->offset + last->size;
+  }
+
+  // If we don't find a better gap just allocate at the end of the buffer.
+  size_t best_offset = AlignTo(alignment, current_top);
+  size_t best_offset_fit = std::numeric_limits<size_t>::max();
+  auto best_insertion_it = allocs_.end();
+
+  // Go through the sorted allocs and look at the gaps between them.
+  size_t current_offset = 0;
+  for (auto it = allocs_.begin(); it != allocs_.end(); ++it) {
+    size_t aligned_current_offset = AlignTo(alignment, current_offset);
+    // If we found a gap larger than required size, and smaller than previous
+    // best fit, take it.
+    if (aligned_current_offset + size <= it->offset &&
+        it->offset - current_offset < best_offset_fit) {
+      best_offset = aligned_current_offset;
+      best_offset_fit = it->offset - current_offset;
+      best_insertion_it = it;
+    }
+    current_offset = it->offset + it->size;
+  }
+
+  // Update the required buffer size.
+  high_water_mark_ = std::max(high_water_mark_, best_offset + size);
+
+  new_alloc->offset = best_offset;
+  new_alloc->size = size;
+  allocs_.insert(best_insertion_it, *new_alloc);
+
+  return kTfLiteOk;
+}
+
+TfLiteStatus SimpleMemoryArena::Deallocate(TfLiteContext* context,
+                                           const ArenaAlloc& alloc) {
+  int erased_allocs_count = 0;
+  auto it = allocs_.begin();
+  while (it != allocs_.end()) {
+    if (it->offset == alloc.offset) {
+      TF_LITE_ENSURE_EQ(context, it->size, alloc.size);
+      erased_allocs_count++;
+      it = allocs_.erase(it);
+    } else {
+      ++it;
+    }
+  }
+  TF_LITE_ENSURE_EQ(context, erased_allocs_count, 1);
+  return kTfLiteOk;
+}
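
The gap search in Allocate() keeps `allocs_` sorted by offset and prefers the tightest aligned gap that fits, falling back to the end of the buffer. A rough standalone Python sketch of the same idea (simplified to take the first gap that fits rather than the tightest one; not TF Lite code):

    # First-fit variant of the gap search over allocations sorted by offset.
    def align_to(alignment, offset):
        return offset if offset % alignment == 0 else offset + (alignment - offset % alignment)

    def allocate(allocs, alignment, size):
        """allocs: list of (offset, size) sorted by offset. Returns the new offset."""
        current_top = allocs[-1][0] + allocs[-1][1] if allocs else 0
        best_offset = align_to(alignment, current_top)
        current = 0
        for i, (offset, existing_size) in enumerate(allocs):
            aligned = align_to(alignment, current)
            if aligned + size <= offset:       # the new block fits in this gap
                allocs.insert(i, (aligned, size))
                return aligned
            current = offset + existing_size
        allocs.append((best_offset, size))     # no gap found: grow at the end
        return best_offset

    allocs = []
    print(allocate(allocs, 32, 2047))  # 0
    print(allocate(allocs, 32, 2047))  # 2048
    allocs.pop(0)                      # "deallocate" the first block
    print(allocate(allocs, 32, 1023))  # 0 (reuses the freed gap)
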
+
+TfLiteStatus SimpleMemoryArena::Commit(TfLiteContext* context) {
+  size_t required_size = RequiredBufferSize();
+  if (required_size > underlying_buffer_size_) {
+    char* new_alloc = new char[required_size];
+    char* new_underlying_buffer_aligned_ptr = reinterpret_cast<char*>(
+        AlignTo(arena_alignment_, reinterpret_cast<intptr_t>(new_alloc)));
+
+    // If the arena had been previously allocated, copy over the old memory.
+    // Since Alloc pointers are offset based, they will remain valid in the new
+    // memory block.
+    if (high_water_mark_ > 0 && underlying_buffer_size_ > 0) {
+      size_t copy_amount = std::min(
+          underlying_buffer_.get() + underlying_buffer_size_ -
+              underlying_buffer_aligned_ptr_,
+          new_alloc + required_size - new_underlying_buffer_aligned_ptr);
+      memcpy(new_underlying_buffer_aligned_ptr, underlying_buffer_aligned_ptr_,
+             copy_amount);
+    }
+
+    underlying_buffer_.reset(new_alloc);
+    underlying_buffer_size_ = required_size;
+    underlying_buffer_aligned_ptr_ = new_underlying_buffer_aligned_ptr;
+  }
+  commited_ = true;
+  return underlying_buffer_ != nullptr ? kTfLiteOk : kTfLiteError;
+}
+
+TfLiteStatus SimpleMemoryArena::ResolveAlloc(TfLiteContext* context,
+                                             const ArenaAlloc& alloc,
+                                             char** output_ptr) {
+  TF_LITE_ENSURE(context, commited_);
+  TF_LITE_ENSURE(context, output_ptr != nullptr);
+  *output_ptr = underlying_buffer_aligned_ptr_ + alloc.offset;
+  return kTfLiteOk;
+}
+
+TfLiteStatus SimpleMemoryArena::Clear() {
+  commited_ = false;
+  high_water_mark_ = 0;
+  allocs_.clear();
+  return kTfLiteOk;
+}
+
+}  // namespace tflite
diff --git a/tensorflow/contrib/lite/simple_memory_arena.h b/tensorflow/contrib/lite/simple_memory_arena.h
new file mode 100644
index 0000000000..0d0b7f9ff7
--- /dev/null
+++ b/tensorflow/contrib/lite/simple_memory_arena.h
@@ -0,0 +1,84 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_SIMPLE_MEMORY_ARENA_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_SIMPLE_MEMORY_ARENA_H_
+
+#include <list>
+#include <memory>
+#include "tensorflow/contrib/lite/context.h"
+
+namespace tflite {
+
+// This little structure holds the offset and the size for a dynamic memory
+// allocation in the memory arena. When the arena is committed and the
+// underlying buffer is set, the alloc can be resolved into an actual memory
+// pointer.
+struct ArenaAlloc {
+  ArenaAlloc() : offset(0), size(0) {}
+
+  size_t offset;
+  size_t size;
+
+  inline bool operator<(const ArenaAlloc& other) const {
+    return offset < other.offset;
+  }
+};
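
ArenaAlloc records are plain offsets rather than pointers, which is why they survive the buffer reallocation done in Commit(); translation to a real address happens only in ResolveAlloc(), against whatever aligned base is current. A toy illustration of that scheme (not TF Lite code):

    # Callers hold (offset, size) records; only commit/resolve touch the buffer,
    # so growing the buffer does not invalidate anything already handed out.
    buffer = bytearray(64)
    alloc = {"offset": 0, "size": 16}          # stand-in for an ArenaAlloc

    def resolve(buf, alloc):
        return memoryview(buf)[alloc["offset"]:alloc["offset"] + alloc["size"]]

    resolve(buffer, alloc)[:3] = b"abc"
    bigger = bytearray(256)
    bigger[:len(buffer)] = buffer              # what Commit() does when it grows
    buffer = bigger
    print(bytes(resolve(buffer, alloc)[:3]))   # b'abc' -- the offset still works
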
+
+// This small class is responsible for allocating, deallocating and reusing
+// dynamic memory from a common underlying buffer. The arena can be used in
+// scenarios when the pattern of memory allocations and deallocations is
+// repetitive, e.g. running NN inference in multiple iterations.
+class SimpleMemoryArena {
+ public:
+  explicit SimpleMemoryArena(size_t arena_alignment)
+      : commited_(false),
+        arena_alignment_(arena_alignment),
+        high_water_mark_(0),
+        underlying_buffer_size_(0),
+        allocs_() {}
+
+  TfLiteStatus Allocate(TfLiteContext* context, size_t alignment, size_t size,
+                        ArenaAlloc* new_alloc);
+
+  TfLiteStatus Deallocate(TfLiteContext* context, const ArenaAlloc& alloc);
+
+  inline size_t RequiredBufferSize() {
+    // Add in a small amount of padding to reduce the chance of resize events
+    // for small allocations.
+    size_t padding = arena_alignment_;
+    return arena_alignment_ + high_water_mark_ + padding;
+  }
+
+  TfLiteStatus Commit(TfLiteContext* context);
+
+  TfLiteStatus ResolveAlloc(TfLiteContext* context, const ArenaAlloc& alloc,
+                            char** output_ptr);
+
+  TfLiteStatus Clear();
+
+ private:
+  bool commited_;
+  size_t arena_alignment_;
+  size_t high_water_mark_;
+  std::unique_ptr<char[]> underlying_buffer_;
+  size_t underlying_buffer_size_;
+  char* underlying_buffer_aligned_ptr_;
+  // TODO(maciekc): add list iterator to the ArenaAlloc to lookup quickly.
+  std::list<ArenaAlloc> allocs_;
+};
+
+}  // namespace tflite
+
+#endif  // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_SIMPLE_MEMORY_ARENA_H_
diff --git a/tensorflow/contrib/lite/simple_memory_arena_test.cc b/tensorflow/contrib/lite/simple_memory_arena_test.cc
new file mode 100644
index 0000000000..ac676092c6
--- /dev/null
+++ b/tensorflow/contrib/lite/simple_memory_arena_test.cc
@@ -0,0 +1,91 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/simple_memory_arena.h"
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+namespace tflite {
+namespace {
+
+TEST(SimpleMemoryArenaTest, BasicArenaOperations) {
+  TfLiteContext context;
+  SimpleMemoryArena arena(64);
+  ArenaAlloc allocs[6];
+
+  arena.Allocate(&context, 32, 2047, &allocs[0]);
+  arena.Allocate(&context, 32, 2047, &allocs[1]);
+  arena.Allocate(&context, 32, 2047, &allocs[2]);
+  arena.Deallocate(&context, allocs[0]);
+  arena.Allocate(&context, 32, 1023, &allocs[3]);
+  arena.Allocate(&context, 32, 2047, &allocs[4]);
+  arena.Deallocate(&context, allocs[1]);
+  arena.Allocate(&context, 32, 1023, &allocs[5]);
+
+  EXPECT_EQ(allocs[0].offset, 0);
+  EXPECT_EQ(allocs[1].offset, 2048);
+  EXPECT_EQ(allocs[2].offset, 4096);
+  EXPECT_EQ(allocs[3].offset, 0);
+  EXPECT_EQ(allocs[4].offset, 6144);
+  EXPECT_EQ(allocs[5].offset, 1024);
+}
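
The offsets expected above follow directly from the alignment rule: a 2047-byte request at 32-byte alignment occupies [0, 2047), so the next aligned offset is 2048, and so on. A quick standalone check of that arithmetic:

    # Round the previous end position up to the next multiple of the alignment.
    def round_up(offset, alignment=32):
        return -(-offset // alignment) * alignment

    print(round_up(2047))         # 2048 -> where the second alloc lands
    print(round_up(2048 + 2047))  # 4096 -> where the third alloc lands
    print(round_up(1023))         # 1024 -> offset of the later 1023-byte alloc
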
+
+TEST(SimpleMemoryArenaTest, TestAfterClear) {
+  TfLiteContext context;
+  SimpleMemoryArena arena(64);
+  ArenaAlloc allocs[9];
+
+  arena.Allocate(&context, 32, 2047, &allocs[0]);
+  arena.Allocate(&context, 32, 2047, &allocs[1]);
+  arena.Allocate(&context, 32, 2047, &allocs[2]);
+  arena.Commit(&context);
+
+  EXPECT_EQ(allocs[0].offset, 0);
+  EXPECT_EQ(allocs[1].offset, 2048);
+  EXPECT_EQ(allocs[2].offset, 4096);
+
+  arena.Clear();
+
+  // Test with smaller allocs.
+  arena.Allocate(&context, 32, 1023, &allocs[3]);
+  arena.Allocate(&context, 32, 1023, &allocs[4]);
+  arena.Allocate(&context, 32, 1023, &allocs[5]);
+  arena.Commit(&context);
+
+  EXPECT_EQ(allocs[3].offset, 0);
+  EXPECT_EQ(allocs[4].offset, 1024);
+  EXPECT_EQ(allocs[5].offset, 2048);
+
+  arena.Clear();
+
+  // Test larger allocs which should require a reallocation.
+  arena.Allocate(&context, 32, 4095, &allocs[6]);
+  arena.Allocate(&context, 32, 4095, &allocs[7]);
+  arena.Allocate(&context, 32, 4095, &allocs[8]);
+  arena.Commit(&context);
+
+  EXPECT_EQ(allocs[6].offset, 0);
+  EXPECT_EQ(allocs[7].offset, 4096);
+  EXPECT_EQ(allocs[8].offset, 8192);
+}
+
+}  // namespace
+}  // namespace tflite
+
+int main(int argc, char** argv) {
+  // On Linux, add: tflite::LogToStderr();
+  ::testing::InitGoogleTest(&argc, argv);
+  return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/string.h b/tensorflow/contrib/lite/string.h
new file mode 100644
index 0000000000..ecd6f04ec2
--- /dev/null
+++ b/tensorflow/contrib/lite/string.h
@@ -0,0 +1,30 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// Abstract string. We don't want even absl at this level.
+#ifndef _THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_STRING_H_
+#define _THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_STRING_H_
+
+#include <string>
+#include "tensorflow/core/platform/platform.h"
+
+namespace tflite {
+
+#ifndef PLATFORM_GOOGLE
+using std::string;
+#endif
+
+}  // namespace tflite
+
+#endif  // _THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_STRING_H_
diff --git a/tensorflow/contrib/lite/string_util.cc b/tensorflow/contrib/lite/string_util.cc
new file mode 100644
index 0000000000..cd41299d38
--- /dev/null
+++ b/tensorflow/contrib/lite/string_util.cc
@@ -0,0 +1,117 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/string_util.h"
+
+#include <cstring>
+#include <vector>
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/interpreter.h"
+
+namespace tflite {
+namespace {
+
+// Convenient method to get pointer to int32_t.
+int32_t* GetIntPtr(char* ptr) { return reinterpret_cast<int32_t*>(ptr); }
+}  // namespace
+
+void DynamicBuffer::AddString(const char* str, size_t len) {
+  data_.resize(data_.size() + len);
+  memcpy(data_.data() + offset_.back(), str, len);
+  offset_.push_back(offset_.back() + len);
+}
+
+void DynamicBuffer::AddString(const StringRef& string) {
+  AddString(string.str, string.len);
+}
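
DynamicBuffer keeps the raw characters in `data_` and a running list of end positions in `offset_`, which starts as {0}; string i therefore spans offset_[i] up to offset_[i+1]. A minimal Python sketch of that bookkeeping (illustrative only):

    # offset starts at [0]; each added string appends the new running end.
    data = bytearray()
    offsets = [0]

    def add_string(s):
        data.extend(s.encode())
        offsets.append(offsets[-1] + len(s))

    add_string("ABC")
    add_string("")
    print(bytes(data), offsets)   # b'ABC' [0, 3, 3]
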
+
+void DynamicBuffer::AddJoinedString(const std::vector<StringRef>& strings,
+                                    char separator) {
+  // Resize the data buffer.
+  int total_len = strings.size() - 1;
+  for (StringRef ref : strings) {
+    total_len += ref.len;
+  }
+  data_.resize(data_.size() + total_len);
+
+  int current_idx = 0;
+  for (StringRef ref : strings) {
+    char* dst = data_.data() + offset_.back() + current_idx;
+
+    // Fill separator if not first string.
+    if (current_idx != 0) {
+      *dst = separator;
+      ++dst;
+      ++current_idx;
+    }
+
+    // Fill content of the string.
+    memcpy(dst, ref.str, ref.len);
+    current_idx += ref.len;
+  }
+  offset_.push_back(offset_.back() + total_len);
+}
+
+void DynamicBuffer::WriteToTensor(TfLiteTensor* tensor) {
+  // Allocate sufficient memory to tensor buffer.
+  int32_t num_strings = offset_.size() - 1;
+  // Total bytes include:
+  //   * size of content (data_.size)
+  //   * offset of each tensor (sizeof(int32_t) * num_strings)
+  //   * length of whole buffer (int32_t)
+  //   * num of strings (int32_t).
+  int32_t bytes = data_.size()                            // size of content
+                  + sizeof(int32_t) * (num_strings + 2);  // size of header
+
+  // Output tensor will take over the ownership of tensor_buffer, and free it
+  // during Interpreter destruction.
+  char* tensor_buffer = static_cast<char*>(malloc(bytes));
+
+  // Set num of string
+  memcpy(tensor_buffer, &num_strings, sizeof(int32_t));
+
+  // Set offset of strings.
+  int32_t start = sizeof(int32_t) * (num_strings + 2);
+  for (int i = 0; i < offset_.size(); i++) {
+    int32_t offset = start + offset_[i];
+    memcpy(tensor_buffer + sizeof(int32_t) * (i + 1), &offset, sizeof(int32_t));
+  }
+
+  // Copy data of strings.
+  memcpy(tensor_buffer + start, data_.data(), data_.size());
+
+  // Set tensor content pointer to tensor_buffer, and release original data.
+  auto dims = TfLiteIntArrayCreate(1);
+  dims->data[0] = num_strings;
+  TfLiteTensorReset(tensor->type, tensor->name, dims, tensor->params,
+                    tensor_buffer, bytes, kTfLiteDynamic, tensor->allocation,
+                    tensor);
+}
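
WriteToTensor() sizes the serialized tensor as the character payload plus a 4-byte count, one 4-byte offset per string, and a 4-byte total length, i.e. data size + 4 * (N + 2). A standalone check of that formula against the byte counts asserted in string_util_test.cc further below:

    # header = 4 * (num_strings + 2); total = header + character payload.
    def string_tensor_bytes(strings):
        payload = sum(len(s) for s in strings)
        return payload + 4 * (len(strings) + 2)

    print(string_tensor_bytes(["ABC", ""]))   # 19
    print(string_tensor_bytes(["DEFG"]))      # 16
    print(string_tensor_bytes(["XYZ"]))       # 15
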
+
+int GetStringCount(const TfLiteTensor* tensor) {
+  // The first integer in the raw buffer is the number of strings.
+  return *GetIntPtr(tensor->data.raw);
+}
+
+StringRef GetString(const TfLiteTensor* tensor, int string_index) {
+  int32_t* offset =
+      GetIntPtr(tensor->data.raw + sizeof(int32_t) * (string_index + 1));
+  return {
+      tensor->data.raw + (*offset),
+      (*(offset + 1)) - (*offset),
+  };
+}
+
+}  // namespace tflite
diff --git a/tensorflow/contrib/lite/string_util.h b/tensorflow/contrib/lite/string_util.h
new file mode 100644
index 0000000000..12872d1123
--- /dev/null
+++ b/tensorflow/contrib/lite/string_util.h
@@ -0,0 +1,91 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Util methods to read and write String tensors.
+// String tensors are considered to be char tensor with protocol.
+//   [0, 3] 4 bytes: N, num of strings in the tensor in little endian.
+//   [(i+1)*4, (i+1)*4+3] 4 bytes: offset of i-th string in little endian.
+//   [(N+2)*4, (N+2)*4+3] 4 bytes: length of the whole char buffer.
+//   [offset(i), offset(i+1) - 1] : content of i-th string.
+// Example of a string tensor:
+// [
+//   2, 0, 0, 0,     # 2 strings.
+//   16, 0, 0, 0,    # 0-th string starts from index 16.
+//   18, 0, 0, 0,    # 1-st string starts from index 18.
+//   18, 0, 0, 0,    # total length of array.
+//   'A', 'B',       # 0-th string [16..17]: "AB"
+// ]                 # 1-st string, empty
+//
+// A typical usage:
+// In op.Eval(context, node):
+//   DynamicBuffer buf;
+//   # Add string "AB" to tensor, string is stored in dynamic buffer.
+//   buf.AddString("AB", 2);
+//   # Write content of DynamicBuffer to tensor in format of string tensor
+//   # described above.
+//   buf.WriteToTensor(tensor)
+
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_STRING_UTIL_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_STRING_UTIL_H_
+
+#include <vector>
+
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/string.h"
+
+namespace tflite {
+
+// Convenient structure to store string pointer and length.
+typedef struct {
+  char* str;
+  int len;
+} StringRef;
+
+// DynamicBuffer holds temporary buffer that will be used to create a dynamic
+// tensor. A typical usage is to initialize a DynamicBuffer object, fill in
+// content and call CreateStringTensor in op.Eval().
+class DynamicBuffer {
+ public:
+  DynamicBuffer() : offset_({0}) {}
+
+  // Add string to dynamic buffer by resizing the buffer and copying the data.
+  void AddString(const StringRef& string);
+
+  // Add string to dynamic buffer by resizing the buffer and copying the data.
+  void AddString(const char* str, size_t len);
+
+  // Join a list of string with separator, and add as a single string to the
+  // buffer.
+  void AddJoinedString(const std::vector<StringRef>& strings, char separator);
+
+  // Fill content into a string tensor.
+  void WriteToTensor(TfLiteTensor* tensor);
+
+ private:
+  // Data buffer to store contents of strings, not including headers.
+  std::vector<char> data_;
+  // Offset of the starting index of each string in data buffer.
+  std::vector<int32_t> offset_;
+};
+
+// Return num of strings in a String tensor.
+int GetStringCount(const TfLiteTensor* tensor);
+
+// Get String pointer and length of index-th string in tensor.
+// NOTE: This will not create a copy of string data.
+StringRef GetString(const TfLiteTensor* tensor, int string_index);
+}  // namespace tflite
+
+#endif  // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_STRING_UTIL_H_
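
The layout described in the header comment above can be reproduced byte for byte with a short standalone script (illustrative, not part of the sources); it packs the same "AB" plus empty-string example:

    # count, per-string start offsets, total length, then the character data.
    import struct

    strings = [b"AB", b""]
    data = b"".join(strings)
    header_len = 4 * (len(strings) + 2)
    offsets = []
    pos = header_len
    for s in strings:
        offsets.append(pos)
        pos += len(s)
    buf = struct.pack("<%di" % (len(strings) + 2),
                      len(strings), *offsets, pos) + data
    print(list(buf))  # [2,0,0,0, 16,0,0,0, 18,0,0,0, 18,0,0,0, 65,66]
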
diff --git a/tensorflow/contrib/lite/string_util_test.cc b/tensorflow/contrib/lite/string_util_test.cc
new file mode 100644
index 0000000000..5c351638dc
--- /dev/null
+++ b/tensorflow/contrib/lite/string_util_test.cc
@@ -0,0 +1,117 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/string_util.h"
+
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/interpreter.h"
+
+namespace tflite {
+
+TEST(StringUtil, TestStringUtil) {
+  Interpreter interpreter;
+  interpreter.AddTensors(3);
+
+  TfLiteTensor* t0 = interpreter.tensor(0);
+  t0->type = kTfLiteString;
+  t0->allocation_type = kTfLiteDynamic;
+
+  TfLiteTensor* t1 = interpreter.tensor(1);
+  t1->type = kTfLiteString;
+  t1->allocation_type = kTfLiteDynamic;
+
+  char data[] = {1, 0, 0, 0, 12, 0, 0, 0, 15, 0, 0, 0, 'X', 'Y', 'Z'};
+
+  interpreter.SetTensorParametersReadOnly(2, kTfLiteString, "", {1}, {}, data,
+                                          15);
+  TfLiteTensor* t2 = interpreter.tensor(2);
+  interpreter.AllocateTensors();
+
+  char s0[] = "ABC";
+  string s1 = "DEFG";
+  char s2[] = "";
+
+  // Write strings to tensors
+  DynamicBuffer buf0;
+  buf0.AddString(s0, 3);
+  DynamicBuffer buf1;
+  buf1.AddString(s1.data(), s1.length());
+  buf0.AddString(s2, 0);
+  buf0.WriteToTensor(t0);
+  buf1.WriteToTensor(t1);
+
+  // Read strings from tensors.
+  ASSERT_EQ(GetStringCount(t0), 2);
+  StringRef str_ref;
+  str_ref = GetString(t0, 0);
+  ASSERT_EQ(string(str_ref.str, str_ref.len), "ABC");
+  str_ref = GetString(t0, 1);
+  ASSERT_EQ(string(str_ref.str, str_ref.len), "");
+  ASSERT_EQ(t0->bytes, 19);
+
+  ASSERT_EQ(GetStringCount(t1), 1);
+  str_ref = GetString(t1, 0);
+  ASSERT_EQ(string(str_ref.str, str_ref.len), "DEFG");
+  ASSERT_EQ(t1->bytes, 16);
+
+  ASSERT_EQ(GetStringCount(t2), 1);
+  str_ref = GetString(t2, 0);
+  ASSERT_EQ(string(str_ref.str, str_ref.len), "XYZ");
+  ASSERT_EQ(t2->bytes, 15);
+}
+
+TEST(StringUtil, TestAddJoinedString) {
+  Interpreter interpreter;
+  interpreter.AddTensors(1);
+  TfLiteTensor* t0 = interpreter.tensor(0);
+  t0->type = kTfLiteString;
+  t0->allocation_type = kTfLiteDynamic;
+
+  char s0[] = "ABC";
+  char s1[] = "DEFG";
+  char s2[] = "";
+  char s3[] = "XYZ";
+
+  DynamicBuffer buf;
+  buf.AddJoinedString({{s0, 3}, {s1, 4}, {s2, 0}, {s3, 3}}, ' ');
+  buf.WriteToTensor(t0);
+
+  ASSERT_EQ(GetStringCount(t0), 1);
+  StringRef str_ref;
+  str_ref = GetString(t0, 0);
+  ASSERT_EQ(string(str_ref.str, str_ref.len), "ABC DEFG  XYZ");
+  ASSERT_EQ(t0->bytes, 25);
+}
+
+TEST(StringUtil, TestEmptyList) {
+  Interpreter interpreter;
+  interpreter.AddTensors(1);
+  TfLiteTensor* t0 = interpreter.tensor(0);
+  t0->type = kTfLiteString;
+  t0->allocation_type = kTfLiteDynamic;
+  DynamicBuffer buf;
+  buf.WriteToTensor(t0);
+
+  ASSERT_EQ(GetStringCount(t0), 0);
+  ASSERT_EQ(t0->bytes, 8);
+}
+
+}  // namespace tflite
+
+int main(int argc, char** argv) {
+  // On Linux, add: tflite::LogToStderr();
+  ::testing::InitGoogleTest(&argc, argv);
+  return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/testdata/0_subgraphs.bin b/tensorflow/contrib/lite/testdata/0_subgraphs.bin
new file mode 100644
index 0000000000..5606898d7f
Binary files /dev/null and b/tensorflow/contrib/lite/testdata/0_subgraphs.bin differ
diff --git a/tensorflow/contrib/lite/testdata/2_subgraphs.bin b/tensorflow/contrib/lite/testdata/2_subgraphs.bin
new file mode 100644
index 0000000000..07308ba62b
Binary files /dev/null and b/tensorflow/contrib/lite/testdata/2_subgraphs.bin differ
diff --git a/tensorflow/contrib/lite/testdata/empty_model.bin b/tensorflow/contrib/lite/testdata/empty_model.bin
new file mode 100644
index 0000000000..1762ca3938
Binary files /dev/null and b/tensorflow/contrib/lite/testdata/empty_model.bin differ
diff --git
a/tensorflow/contrib/lite/testdata/multi_add.bin b/tensorflow/contrib/lite/testdata/multi_add.bin new file mode 100644 index 0000000000..e5048a3281 Binary files /dev/null and b/tensorflow/contrib/lite/testdata/multi_add.bin differ diff --git a/tensorflow/contrib/lite/testdata/multi_add.json b/tensorflow/contrib/lite/testdata/multi_add.json new file mode 100644 index 0000000000..97b931dba8 --- /dev/null +++ b/tensorflow/contrib/lite/testdata/multi_add.json @@ -0,0 +1,46 @@ +{ + "version": 1, + "operator_codes": [ + { + "builtin_code": "ADD" + } + ], + "subgraphs": [ + { + "tensors": [ + { "shape": [ 1, 8, 8, 3 ], "name": "a" }, + { "shape": [ 1, 8, 8, 3 ], "name": "b" }, + { "shape": [ 1, 8, 8, 3 ], "name": "c" }, + { "shape": [ 1, 8, 8, 3 ], "name": "d" }, + { "shape": [ 1, 8, 8, 3 ], "name": "i" }, + { "shape": [ 1, 8, 8, 3 ], "name": "x" }, + { "shape": [ 1, 8, 8, 3 ], "name": "y" } + ], + "inputs": [ 0, 1, 2, 3 ], + "outputs": [ 5, 6 ], + "operators": [ + { + "inputs": [ 1, 2 ], + "outputs": [ 4 ], + "builtin_options_type": "AddOptions", + "builtin_options": { + } + }, + { + "inputs": [ 0, 4 ], + "outputs": [ 5 ], + "builtin_options_type": "AddOptions", + "builtin_options": { + } + }, + { + "inputs": [ 3, 4 ], + "outputs": [ 6 ], + "builtin_options_type": "AddOptions", + "builtin_options": { + } + } + ] + } + ] +} diff --git a/tensorflow/contrib/lite/testdata/no_subgraphs.bin b/tensorflow/contrib/lite/testdata/no_subgraphs.bin new file mode 100644 index 0000000000..5606898d7f Binary files /dev/null and b/tensorflow/contrib/lite/testdata/no_subgraphs.bin differ diff --git a/tensorflow/contrib/lite/testdata/test_model.bin b/tensorflow/contrib/lite/testdata/test_model.bin new file mode 100644 index 0000000000..2878b1f96e Binary files /dev/null and b/tensorflow/contrib/lite/testdata/test_model.bin differ diff --git a/tensorflow/contrib/lite/testdata/test_model_broken.bin b/tensorflow/contrib/lite/testdata/test_model_broken.bin new file mode 100644 index 0000000000..9fd050cd4a Binary files /dev/null and b/tensorflow/contrib/lite/testdata/test_model_broken.bin differ diff --git a/tensorflow/contrib/lite/testdata/test_model_broken.json b/tensorflow/contrib/lite/testdata/test_model_broken.json new file mode 100644 index 0000000000..b701eb9a25 --- /dev/null +++ b/tensorflow/contrib/lite/testdata/test_model_broken.json @@ -0,0 +1,62 @@ +{ + "subgraphs": [ + { + "inputs": [0, 1], + "outputs": [2, 3], + "operators": [ + { + "opcode_index": 0, + "inputs": [0,1], + "outputs": [2] + }, + { + "opcode_index": 1, + "inputs": [2], + "outputs": [3] + } + ], + "tensors": [ + { + "shape" : [ + 2 + ], + "type" : "FLOAT32", + "name" : "input0", + "data_buffer" : [1,0,0,0] + }, + { + "shape" : [ + 3 + ], + "type" : "FLOAT32", + "name" : "input1", + "data_buffer" : [] + }, + { + "shape" : [ + 3 + ], + "type" : "FLOAT32", + "name" : "out1", + "data_buffer" : [] + }, + { + "shape" : [ + 3 + ], + "type" : "FLOAT32", + "name" : "out2", + "data_buffer" : [] + } + ], + } + ], + "operator_codes": [ + { + "builtin_code": 0 + }, + { + "custom_code": "testing_op" + } + ] +} diff --git a/tensorflow/contrib/lite/testdata/two_subgraphs.bin b/tensorflow/contrib/lite/testdata/two_subgraphs.bin new file mode 100644 index 0000000000..07308ba62b Binary files /dev/null and b/tensorflow/contrib/lite/testdata/two_subgraphs.bin differ diff --git a/tensorflow/contrib/lite/testing/BUILD b/tensorflow/contrib/lite/testing/BUILD new file mode 100644 index 0000000000..5e40a13d3c --- /dev/null +++ b/tensorflow/contrib/lite/testing/BUILD @@ 
-0,0 +1,213 @@ +package(default_visibility = [ + "//visibility:public", +]) + +licenses(["notice"]) # Apache 2.0 + +load( + "//tensorflow/contrib/lite:build_def.bzl", + "gen_zipped_test_files", +) +load( + "//tensorflow:tensorflow.bzl", + "tf_cc_test", +) + +gen_zipped_test_files( + name = "optest", + files = [ + "add.zip", + "avg_pool.zip", + "concat.zip", + "constant.zip", + "control_dep.zip", + "conv.zip", + "depthwiseconv.zip", + "fully_connected.zip", + "fused_batch_norm.zip", + "global_batch_norm.zip", + "l2_pool.zip", + "l2norm.zip", + "local_response_norm.zip", + "max_pool.zip", + "mul.zip", + "relu.zip", + "relu1.zip", + "relu6.zip", + "reshape.zip", + "resize_bilinear.zip", + "sigmoid.zip", + "softmax.zip", + "space_to_depth.zip", + ], +) + +py_binary( + name = "generate_examples", + srcs = ["generate_examples.py"], + data = [ + "//tensorflow/contrib/lite/toco", + ], + srcs_version = "PY2AND3", + deps = [ + ":generate_examples_report", + "//tensorflow:tensorflow_py", + "//tensorflow/python:graph_util", + "//third_party/py/numpy", + "@six_archive//:six", + ], +) + +py_library( + name = "generate_examples_report", + srcs = ["generate_examples_report.py"], + srcs_version = "PY2AND3", +) + +cc_library( + name = "parse_testdata_lib", + srcs = ["parse_testdata.cc"], + hdrs = ["parse_testdata.h"], + deps = [ + ":message", + ":split", + ":test_runner", + "//tensorflow/contrib/lite:framework", + ], +) + +cc_library( + name = "message", + srcs = ["message.cc"], + hdrs = ["message.h"], + deps = [":tokenize"], +) + +cc_test( + name = "message_test", + srcs = ["message_test.cc"], + deps = [ + ":message", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "split", + srcs = ["split.cc"], + hdrs = ["split.h"], + deps = [ + "//tensorflow/contrib/lite:string", + ], +) + +cc_test( + name = "split_test", + size = "small", + srcs = ["split_test.cc"], + deps = [ + ":split", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "tflite_driver", + srcs = ["tflite_driver.cc"], + hdrs = ["tflite_driver.h"], + deps = [ + ":split", + ":test_runner", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:builtin_ops", + ], +) + +cc_test( + name = "tflite_driver_test", + size = "small", + srcs = ["tflite_driver_test.cc"], + data = ["//tensorflow/contrib/lite:testdata/multi_add.bin"], + deps = [ + ":tflite_driver", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "tokenize", + srcs = ["tokenize.cc"], + hdrs = ["tokenize.h"], + deps = [ + "//tensorflow/contrib/lite:string", + ], +) + +cc_test( + name = "tokenize_test", + srcs = ["tokenize_test.cc"], + deps = [ + ":tokenize", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "test_runner", + hdrs = ["test_runner.h"], + deps = [ + "//tensorflow/contrib/lite:string", + ], +) + +cc_test( + name = "test_runner_test", + srcs = ["test_runner_test.cc"], + deps = [ + ":test_runner", + "@com_google_googletest//:gtest_main", + ], +) + +cc_binary( + name = "nnapi_example", + srcs = ["nnapi_example.cc"], + deps = [ + ":parse_testdata_lib", + "//tensorflow/contrib/lite:builtin_op_data", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:builtin_ops", + "//tensorflow/contrib/lite/nnapi:nnapi_lib", + ], +) + +tf_cc_test( + name = "generated_examples_zip_test", + size = "medium", + srcs = ["generated_examples_zip_test.cc"], + data = [":optest"], + shard_count = 10, + deps = [ + ":parse_testdata_lib", + 
"//tensorflow/contrib/lite:builtin_op_data", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:builtin_ops", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "@com_google_googletest//:gtest", + "@com_googlesource_code_re2//:re2", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py new file mode 100644 index 0000000000..86540d58a6 --- /dev/null +++ b/tensorflow/contrib/lite/testing/generate_examples.py @@ -0,0 +1,1189 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Generate a series of TensorFlow graphs that become tflite test cases. + +Usage: + +generate_examples zipped + +bazel run //tensorflow/contrib/lite/testing:generate_examples + third_party/tensorflow/contrib/lite/testing/generated_examples zipped +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import itertools +import os +import re +import sys +import tempfile +import traceback +import zipfile +import numpy as np +from six import StringIO +import tensorflow as tf +from google.protobuf import text_format +# TODO(aselle): switch to TensorFlow's resource_loader +from tensorflow.contrib.lite.testing import generate_examples_report as report_lib +from tensorflow.python.framework import graph_util as tf_graph_util + +parser = argparse.ArgumentParser(description="Script to generate TFLite tests.") +parser.add_argument("output_path", + help="Directory where the outputs will be go.") +# TODO(ahentz): remove this flag +parser.add_argument("type", help="zipped") +parser.add_argument("--zip_to_output", + type=str, + help="Particular zip to output.", + required=False) +parser.add_argument("--toco", + type=str, + help="Path to toco tool.", + required=True) +parser.add_argument( + "--known_bugs_are_errors", + action="store_true", + help=("If a particular model is affected by a known bug," + " count it as a toco error.")) +parser.add_argument( + "--ignore_toco_errors", + action="store_true", + help="Raise an exception if any toco error is encountered.") +parser.add_argument( + "--save_graphdefs", + action="store_true", + help="Include intermediate graphdefs in the output zip files.") + + +RANDOM_SEED = 342 +TEST_INPUT_DEPTH = 3 + + +# A map from regular expression to bug number. Any test failure with label +# matching the expression will be considered due to the corresponding bug. +KNOWN_BUGS = { + # TOCO doesn't support scalars as input. 
+ r"relu.*input_shape=\[\]": "67587484", + r"sigmoid.*input_shape=\[\]": "67645668", + # Concat doesn't work with a single input tensor + r"concat.*num_tensors=1": "67378344", + # Transposition in MatMul is not supported. + r"fully_connected.*transpose_.=True": "67586970", + # Softmax graphs are too complex. + r"softmax.*dim=0": "67749831", + r"softmax.*input_shape=\[1,3,4,3\]": "67749831", + # SpaceToDepth only supports float32. + r"space_to_depth.*(float16|int32|uint8|int64)": "68018134", +} + + +def toco_options(data_types, + input_arrays, + output_arrays, + shapes, + drop_control_dependency): + """Create TOCO options to process a model. + + Args: + data_types: input and inference types used by TOCO. + input_arrays: names of the input tensors + output_arrays: name of the output tensors + shapes: shapes of the input tensors + drop_control_dependency: whether to ignore control dependency nodes. + + Returns: + the options in a string. + """ + shape_str = ":".join([",".join(str(y) for y in x) for x in shapes]) + inference_type = "FLOAT" + # TODO(ahentz): if we get multi-input quantization to work we need this + # to change + if data_types[0] == "QUANTIZED_UINT8": + inference_type = "QUANTIZED_UINT8" + s = (" --input_types=%s" % ",".join(data_types) + + " --inference_type=%s" % inference_type + + " --input_format=TENSORFLOW_GRAPHDEF" + " --output_format=TFLITE" + + " --input_arrays=%s" % ",".join(input_arrays) + + " --input_shapes=%s" % shape_str + + " --output_arrays=%s" % ",".join(output_arrays)) + if drop_control_dependency: + s += " --drop_control_dependency" + return s + + +def write_toco_options(filename, + data_types, + input_arrays, + output_arrays, + shapes, + drop_control_dependency=False): + """Create TOCO options to process a model. + + Args: + filename: Filename to write the options to. + data_types: input and inference types used by TOCO. + input_arrays: names of the input tensors + output_arrays: names of the output tensors + shapes: shapes of the input tensors + drop_control_dependency: whether to ignore control dependency nodes. + """ + with open(filename, "w") as fp: + fp.write( + toco_options( + data_types=data_types, + input_arrays=input_arrays, + output_arrays=output_arrays, + shapes=shapes, + drop_control_dependency=drop_control_dependency)) + + +def write_examples(fp, examples): + """Given a list `examples`, write a text format representation. + + The file format is csv like with a simple repeated pattern. We would ike + to use proto here, but we can't yet due to interfacing with the Android + team using this format. + + Args: + fp: File-like object to write to. + examples: Example dictionary consiting of keys "inputs" and "outputs" + """ + + def write_tensor(fp, x): + """Write tensor in file format supported by TFLITE example.""" + fp.write("dtype,%s\n" % x.dtype) + fp.write("shape," + ",".join(map(str, x.shape)) + "\n") + # Output 9 digits after the point to ensure the precision is good enough. + values = ["{:.9f}".format(value) for value in list(x.flatten())] + fp.write("values," + ",".join(values) + "\n") + + fp.write("test_cases,%d\n" % len(examples)) + for example in examples: + fp.write("inputs,%d\n" % len(example["inputs"])) + for i in example["inputs"]: + write_tensor(fp, i) + fp.write("outputs,%d\n" % len(example["outputs"])) + for i in example["outputs"]: + write_tensor(fp, i) + + +def write_test_cases(fp, model_name, examples): + """Given a dictionary of `examples`, write a text format representation. 
+ + The file format is protocol-buffer-like, even though we don't use proto due + to the needs of the Android team. + + Args: + fp: File-like object to write to. + model_name: Filename where the model was written to, relative to filename. + examples: Example dictionary consiting of keys "inputs" and "outputs" + """ + + fp.write("load_model: %s\n" % os.path.basename(model_name)) + for example in examples: + fp.write("reshape {\n") + for t in example["inputs"]: + fp.write(" input: \"" + ",".join(map(str, t.shape)) + "\"\n") + fp.write("}\n") + fp.write("invoke {\n") + + for t in example["inputs"]: + values = ["{:.9f}".format(value) for value in list(t.flatten())] + fp.write(" input: \"" + ",".join(values) + "\"\n") + for t in example["outputs"]: + values = ["{:.9f}".format(value) for value in list(t.flatten())] + fp.write(" output: \"" + ",".join(values) + "\"\n") + fp.write("}\n") + + +_TF_TYPE_INFO = { + tf.float32: (np.float32, "FLOAT"), + tf.float16: (np.float16, "FLOAT"), + tf.int32: (np.int32, "INT32"), + tf.uint8: (np.uint8, "QUANTIZED_UINT8"), + tf.int64: (np.int64, "INT64"), +} + + +def create_tensor_data(dtype, shape, min_value=-100, max_value=100): + """Build tensor data spreading the range [min_value, max_value).""" + + if dtype in _TF_TYPE_INFO: + dtype = _TF_TYPE_INFO[dtype][0] + + if dtype in (tf.float32, tf.float16): + value = (max_value-min_value)*np.random.random_sample(shape)+min_value + elif dtype in (tf.int32, tf.uint8, tf.int64): + value = np.random.random_integers(min_value, max_value, shape) + return value.astype(dtype) + + +def freeze_graph(session, outputs): + """Freeze the current graph. + + Args: + session: Tensorflow sessions containing the graph + outputs: List of output tensors + + Returns: + The frozen graph_def. + """ + return tf_graph_util.convert_variables_to_constants( + session, session.graph.as_graph_def(), [x.op.name for x in outputs]) + + +def make_control_dep_tests(zip_path): + """Make a set of tests that use control dependencies.""" + + test_parameters = [{ + "input_shape": [[], [1, 1, 1, 1], [1, 15, 14, 1], [3, 15, 14, 3]], + }] + + def build_graph(parameters): + input_tensor = tf.placeholder( + dtype=tf.float32, name="input", shape=parameters["input_shape"]) + filter_value = tf.zeros((3, 3, TEST_INPUT_DEPTH, 8), tf.float32) + assert_op = tf.assert_greater_equal(input_tensor, input_tensor - 1) + with tf.control_dependencies([assert_op]): + out = tf.nn.conv2d(input_tensor, filter_value, + strides=(1, 1, 1, 1), padding="SAME") + return [input_tensor], [out] + + def build_inputs(parameters, sess, inputs, outputs): + input_values = create_tensor_data(tf.float32, parameters["input_shape"]) + return [input_values], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_values]))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs, + drop_control_dependency=True) + + +def toco_convert(graph_def_str, input_tensors, output_tensors, + drop_control_dependency=False): + """Convert a model's graph def into a tflite model. + + NOTE: this currently shells out to the toco binary, but we would like + convert to Python API tooling in the future. + + Args: + graph_def_str: Graph def proto in serialized string format. + input_tensors: List of input tensor tuples `(name, shape, type)` + output_tensors: List of output tensors (names) + drop_control_dependency: whether to ignore control dependency nodes. + + Returns: + output tflite model, log_txt from conversion + or None, log_txt if it did not convert properly. 
+ """ + data_types = [_TF_TYPE_INFO[x[2]][1] for x in input_tensors] + opts = toco_options( + data_types=data_types, + input_arrays=[x[0] for x in input_tensors], + shapes=[x[1] for x in input_tensors], + output_arrays=output_tensors, + drop_control_dependency=drop_control_dependency) + + with tempfile.NamedTemporaryFile() as graphdef_file, \ + tempfile.NamedTemporaryFile() as output_file, \ + tempfile.NamedTemporaryFile("w+") as stdout_file: + graphdef_file.write(graph_def_str) + graphdef_file.flush() + + # TODO(aselle): Switch this to subprocess at some point. + cmd = ("%s --input_file=%s --output_file=%s %s > %s 2>&1" % + (bin_path, graphdef_file.name, output_file.name, opts, + stdout_file.name)) + exit_code = os.system(cmd) + log = ( + cmd + "exited with code %d" % exit_code + "\n------------------\n" + + stdout_file.read()) + return (None if exit_code != 0 else output_file.read()), log + + +def make_zip_of_tests(zip_path, + test_parameters, + make_graph, + make_test_inputs, + drop_control_dependency=False): + """Helper to make a zip file of a bunch of TensorFlow models. + + This does a cartestian product of the dictionary of test_parameters and + calls make_graph() for each item in the cartestian product set. + If the graph is built successfully, then make_test_inputs() is called to + build expected input/output value pairs. The model is then converted to tflite + with toco, and the examples are serialized with the tflite model into a zip + file (2 files per item in the cartesian product set). + + Args: + zip_path: Path of zip file to write + test_parameters: Dictionary mapping to lists for each parameter. + e.g. `{"strides": [[1,3,3,1], [1,2,2,1]], "foo": [1.2, 1.3]}` + make_graph: function that takes current parameters and returns tuple + `[input1, input2, ...], [output1, output2, ...]` + make_test_inputs: function taking `curr_params`, `session`, `input_tensors`, + `output_tensors` and returns tuple `(input_values, output_values)`. + drop_control_dependency: whether to ignore control dependency nodes. + Raises: + RuntimeError: if there are toco errors that can't be ignored. + """ + + # TODO(aselle): Make this allow multiple inputs outputs. + archive = zipfile.PyZipFile(zip_path, "w") + zip_manifest = [] + convert_report = [] + toco_errors = 0 + for parameters in test_parameters: + keys = parameters.keys() + for curr in itertools.product(*parameters.values()): + label = zip_path.replace(".zip", "") + (",".join( + "%s=%r" % z for z in sorted(zip(keys, curr))).replace(" ", "")) + if label[0] == "/": + label = label[1:] + param_dict = dict(zip(keys, curr)) + + def build_example(label, param_dict_real): + """Build the model with parameter values set in param_dict_real. + + Args: + label: Label of the model (i.e. the filename in the zip). + param_dict_real: Parameter dictionary (arguments to the factories + make_graph and make_test_inputs) + Returns: + (tflite_model_binary, report) where tflite_model_binary is the + serialized flatbuffer as a string and report is a dictionary with + keys `toco_log` (log of toco conversion), `tf_log` (log of tf + conversion), `toco` (a string of success status of the conversion), + `tf` (a string success status of the conversion). 
+ """ + + np.random.seed(RANDOM_SEED) + report = {"toco": report_lib.NOTRUN, "tf": report_lib.FAILED} + + # Build graph + report["tf_log"] = "" + report["toco_log"] = "" + tf.reset_default_graph() + + try: + inputs, outputs = make_graph(param_dict_real) + except (tf.errors.UnimplementedError, tf.errors.InvalidArgumentError, + ValueError): + report["tf_log"] += traceback.format_exc() + return None, report + + sess = tf.Session() + try: + baseline_inputs, baseline_outputs = (make_test_inputs( + param_dict_real, sess, inputs, outputs)) + except (tf.errors.UnimplementedError, tf.errors.InvalidArgumentError, + ValueError): + report["tf_log"] += traceback.format_exc() + return None, report + report["toco"] = report_lib.FAILED + report["tf"] = report_lib.SUCCESS + + # Convert graph to toco + tflite_model_binary, toco_log = toco_convert( + sess.graph_def.SerializeToString(), + [(input_tensor.name.split(":")[0], input_tensor.get_shape(), + input_tensor.dtype) for input_tensor in inputs], + [out.name.split(":")[0] + for out in outputs], drop_control_dependency) + report["toco"] = (report_lib.SUCCESS if tflite_model_binary is not None + else report_lib.FAILED) + report["toco_log"] = toco_log + + if FLAGS.save_graphdefs: + archive.writestr(label + ".pb", + text_format.MessageToString(sess.graph_def), + zipfile.ZIP_DEFLATED) + + if tflite_model_binary: + archive.writestr(label + ".bin", tflite_model_binary, + zipfile.ZIP_DEFLATED) + example = {"inputs": baseline_inputs, "outputs": baseline_outputs} + + example_fp = StringIO() + write_examples(example_fp, [example]) + archive.writestr(label + ".inputs", + example_fp.getvalue(), zipfile.ZIP_DEFLATED) + + example_fp2 = StringIO() + write_test_cases(example_fp2, label + ".bin", [example]) + archive.writestr(label + "_tests.txt", + example_fp2.getvalue(), zipfile.ZIP_DEFLATED) + + zip_manifest.append(label + "\n") + + return tflite_model_binary, report + + _, report = build_example(label, param_dict) + + if report["toco"] == report_lib.FAILED: + ignore_error = False + if not FLAGS.known_bugs_are_errors: + for pattern, bug_number in KNOWN_BUGS.items(): + if re.search(pattern, label): + print("Ignored TOCO error due to bug %s" % bug_number) + ignore_error = True + if not ignore_error: + toco_errors += 1 + print("-----------------\ntoco error!\n%s\n-----------------\n" % + report["toco_log"]) + + convert_report.append((param_dict, report)) + report_io = StringIO() + report_lib.make_report_table(report_io, zip_path, convert_report) + archive.writestr("report.html", report_io.getvalue()) + + archive.writestr("manifest.txt", "".join(zip_manifest), zipfile.ZIP_DEFLATED) + + # Log statistics of what succeeded + total_conversions = len(convert_report) + tf_success = sum(1 for x in convert_report + if x[1]["tf"] == report_lib.SUCCESS) + toco_success = sum(1 for x in convert_report + if x[1]["toco"] == report_lib.SUCCESS) + percent = 0 + if tf_success > 0: + percent = float(toco_success) / float(tf_success) * 100. + tf.logging.info(("Archive %s Considered %d graphs, %d TF evaluated graphs " + " and %d TOCO converted graphs (%.1f%%"), zip_path, + total_conversions, tf_success, toco_success, percent) + + if not FLAGS.ignore_toco_errors and toco_errors > 0: + raise RuntimeError( + "Found %d errors while generating toco models" % toco_errors) + + +def make_pool_tests(pool_op_in): + """Make a set of tests to do average pooling. + + Args: + pool_op_in: TensorFlow pooling operation to test i.e. `tf.nn.avg_pool`. 
+
+  Returns:
+    A function representing the true generator (after curried pool_op_in).
+  """
+
+  pool_op = pool_op_in
+
+  def f(zip_path):
+    """Actual function that generates examples.
+
+    Args:
+      zip_path: path to write zip to.
+    """
+
+    # Choose a set of parameters
+    test_parameters = [{
+        "ksize": [[2, 1, 1, 2], [1, 1, 1, 1], [1, 1, 2, 1], [1, 10, 11, 1]],
+        "strides": [[2, 1, 1, 2], [1, 1, 1, 1], [1, 1, 2, 1], [1, 10, 11, 1]],
+        # TODO(aselle): should add in a degenerate shape (e.g. [1, 0, 1, 1]).
+        "input_shape": [[], [1, 1, 1, 1], [1, 15, 14, 1], [3, 15, 14, 3]],
+        "padding": ["SAME", "VALID"],
+        "data_format": ["NHWC"],  # TODO(aselle): NCHW would be good
+    }]
+
+    def build_graph(parameters):
+      input_tensor = tf.placeholder(
+          dtype=tf.float32, name="input", shape=parameters["input_shape"])
+      out = pool_op(
+          input_tensor,
+          ksize=parameters["ksize"],
+          strides=parameters["strides"],
+          data_format=parameters["data_format"],
+          padding=parameters["padding"])
+      return [input_tensor], [out]
+
+    def build_inputs(parameters, sess, inputs, outputs):
+      input_values = create_tensor_data(tf.float32, parameters["input_shape"])
+      return [input_values], sess.run(
+          outputs, feed_dict=dict(zip(inputs, [input_values])))
+
+    make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+  return f
+
+
+def make_relu_tests(zip_path):
+  """Make a set of tests to do relu."""
+
+  # Choose a set of parameters
+  test_parameters = [{
+      "input_shape": [[], [1], [2, 3], [1, 1, 1, 1], [1, 3, 4, 3],
+                      [3, 15, 14, 3], [3, 1, 2, 4, 6], [2, 2, 3, 4, 5, 6]],
+  }]
+
+  def build_graph(parameters):
+    input_tensor = tf.placeholder(
+        dtype=tf.float32, name="input", shape=parameters["input_shape"])
+    out = tf.nn.relu(input_tensor)
+    return [input_tensor], [out]
+
+  def build_inputs(parameters, sess, inputs, outputs):
+    input_values = create_tensor_data(
+        np.float32, parameters["input_shape"], min_value=-4, max_value=10)
+    return [input_values], sess.run(
+        outputs, feed_dict=dict(zip(inputs, [input_values])))
+
+  make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+
+def make_relu1_tests(zip_path):
+  """Make a set of tests to do relu1."""
+
+  # Choose a set of parameters
+  test_parameters = [{
+      "input_shape": [[], [1, 1, 1, 1], [1, 3, 4, 3], [3, 15, 14, 3],
+                      [3, 1, 2, 4, 6], [2, 2, 3, 4, 5, 6]],
+  }]
+
+  def build_graph(parameters):
+    input_tensor = tf.placeholder(
+        dtype=tf.float32, name="input", shape=parameters["input_shape"])
+    # Note that the following is not supported:
+    # out = tf.maximum(-1.0, tf.minimum(input_tensor, 1.0))
+    out = tf.minimum(1.0, tf.maximum(input_tensor, -1.0))
+    return [input_tensor], [out]
+
+  def build_inputs(parameters, sess, inputs, outputs):
+    input_values = create_tensor_data(
+        np.float32, parameters["input_shape"], min_value=-3, max_value=10)
+    return [input_values], sess.run(
+        outputs, feed_dict=dict(zip(inputs, [input_values])))
+
+  make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+
+def make_relu6_tests(zip_path):
+  """Make a set of tests to do relu6."""
+
+  # Choose a set of parameters
+  test_parameters = [{
+      "input_shape": [[], [1, 1, 1, 1], [1, 3, 4, 3], [3, 15, 14, 3],
+                      [3, 1, 2, 4, 6], [2, 2, 3, 4, 5, 6]],
+  }]
+
+  def build_graph(parameters):
+    input_tensor = tf.placeholder(
+        dtype=tf.float32, name="input", shape=parameters["input_shape"])
+    out = tf.nn.relu6(input_tensor)
+    return [input_tensor], [out]
+
+  def build_inputs(parameters, sess, inputs, outputs):
+    input_values = create_tensor_data(
+        np.float32,
parameters["input_shape"], min_value=-3, max_value=10) + return [input_values], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_values]))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + +# This function tests various TensorFLow functions that generates Const op, +# including `tf.ones`, `tf.zeros` and random functions. +def make_constant_tests(zip_path): + """Make a set of tests to do constant ops.""" + + test_parameters = [{ + "dtype": [tf.float32, tf.int32], + "input_shape": [[1], [2], [1, 1, 1, 1], [2, 2, 2, 2]], + }] + + def build_graph(parameters): + # Since Toco & Tflite can't have a single constant op in the entire graph, + # this test adds a zero tesnor with a constant op tensor. + input1 = tf.placeholder(dtype=parameters["dtype"], name="input1", + shape=parameters["input_shape"]) + out = tf.ones(parameters["input_shape"], dtype=parameters["dtype"]) + input1 + return [input1], [out] + + def build_inputs(parameters, sess, inputs, outputs): + input1 = np.zeros(parameters["input_shape"], + dtype=_TF_TYPE_INFO[parameters["dtype"]][0]) + return [input1], sess.run(outputs, feed_dict={inputs[0]: input1}) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + +def make_add_tests(zip_path): + """Make a set of tests to do add with and without broadcast.""" + + # These parameters are split because we don't support broadcasting. + test_parameters = [{ + "dtype": [tf.float32, tf.int32], + "input_shape_1": [[1, 3, 4, 3]], + "input_shape_2": [[1, 3, 4, 3]], + }, { + "dtype": [tf.float32], + "input_shape_1": [[5]], + "input_shape_2": [[5]], + }, { + "dtype": [tf.float32], + "input_shape_1": [[1, 3, 4, 3]], + "input_shape_2": [[3]], + }] + + def build_graph(parameters): + input1 = tf.placeholder(dtype=parameters["dtype"], name="input1", + shape=parameters["input_shape_1"]) + input2 = tf.placeholder(dtype=parameters["dtype"], name="input2", + shape=parameters["input_shape_2"]) + out = tf.add(input1, input2) + return [input1, input2], [out] + + def build_inputs(parameters, sess, inputs, outputs): + input1 = create_tensor_data(parameters["dtype"], + parameters["input_shape_1"]) + input2 = create_tensor_data(parameters["dtype"], + parameters["input_shape_2"]) + return [input1, input2], sess.run( + outputs, feed_dict={ + inputs[0]: input1, + inputs[1]: input2 + }) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + +def make_mul_tests(zip_path): + """Make a set of tests to do mul with and without broadcast.""" + + # These parameters are split because we don't support broadcasting. 
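+  # As an illustration (hypothetical values), make_zip_of_tests expands each
+  # dictionary below into the cartesian product of its value lists, so an
+  # entry like
+  #   {"dtype": [tf.float32, tf.int32], "input_shape_1": [[5]], "input_shape_2": [[5]]}
+  # yields one model per combination, labeled roughly
+  #   mul,dtype=tf.float32,input_shape_1=[5],input_shape_2=[5]
+  # with a .bin/.inputs pair per label inside the generated zip.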
+ test_parameters = [{ + "dtype": [tf.float32, tf.int32], + "input_shape_1": [[1, 3, 4, 3]], + "input_shape_2": [[1, 3, 4, 3]], + }, { + "dtype": [tf.float32], + "input_shape_1": [[5]], + "input_shape_2": [[5]], + }, { + "dtype": [tf.float32], + "input_shape_1": [[1, 3, 4, 3]], + "input_shape_2": [[3]], + }] + + def build_graph(parameters): + input1 = tf.placeholder(dtype=parameters["dtype"], name="input1", + shape=parameters["input_shape_1"]) + input2 = tf.placeholder(dtype=parameters["dtype"], name="input2", + shape=parameters["input_shape_2"]) + out = tf.multiply(input1, input2) + return [input1, input2], [out] + + def build_inputs(parameters, sess, inputs, outputs): + input1 = create_tensor_data(parameters["dtype"], + parameters["input_shape_1"]) + input2 = create_tensor_data(parameters["dtype"], + parameters["input_shape_2"]) + return [input1, input2], sess.run( + outputs, feed_dict={inputs[0]: input1, + inputs[1]: input2}) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + +def make_global_batch_norm_tests(zip_path): + """Make a set of tests to do batch_norm_with_global_normalization.""" + + test_parameters = [{ + "dtype": [tf.float32], + "input_shape": [[1, 1, 6, 2], [3, 4, 5, 4]], + "epsilon": [0.1, 0.0001], + "scale_after": [True, False], + }] + + def build_graph(parameters): + """Build the global batch norm testing graph.""" + input_shape = parameters["input_shape"] + scale_shape = input_shape[3] + + scale = create_tensor_data(parameters["dtype"], scale_shape) + offset = create_tensor_data(parameters["dtype"], scale_shape) + mean = create_tensor_data(parameters["dtype"], scale_shape) + variance = create_tensor_data(parameters["dtype"], scale_shape) + + x = create_tensor_data(parameters["dtype"], parameters["input_shape"]) + x_norm = tf.nn.batch_norm_with_global_normalization( + x, mean, variance, scale, offset, + parameters["epsilon"], parameters["scale_after"]) + + input_tensor = tf.placeholder(dtype=parameters["dtype"], name="input", + shape=parameters["input_shape"]) + out = tf.add(input_tensor, x_norm) + return [input_tensor], [out] + + def build_inputs(parameters, sess, inputs, outputs): + input_value = create_tensor_data(parameters["dtype"], + parameters["input_shape"]) + return [input_value], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_value]))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + +def make_fused_batch_norm_tests(zip_path): + """Make a set of tests to do fused_batch_norm.""" + + test_parameters = [{ + "dtype": [tf.float32], + "input_shape": [[1, 1, 6, 2]], + "epsilon": [0.001, 0.1], + }] + + def build_graph(parameters): + """Build the testing graph for fused batch normalization.""" + input_shape = parameters["input_shape"] + scale_shape = input_shape[3] + + scale = create_tensor_data(parameters["dtype"], scale_shape) + offset = create_tensor_data(parameters["dtype"], scale_shape) + mean = create_tensor_data(parameters["dtype"], scale_shape) + variance = create_tensor_data(parameters["dtype"], scale_shape) + + x = create_tensor_data(parameters["dtype"], parameters["input_shape"]) + [x_norm, _, _] = tf.nn.fused_batch_norm( + x, scale, offset, mean, variance, + parameters["epsilon"], data_format="NHWC", is_training=False) + + input_tensor = tf.placeholder(dtype=parameters["dtype"], name="input", + shape=parameters["input_shape"]) + out = tf.add(input_tensor, x_norm) + return [input_tensor], [out] + + def build_inputs(parameters, sess, inputs, outputs): + input_value = 
create_tensor_data(parameters["dtype"], + parameters["input_shape"]) + return [input_value], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_value]))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + +def make_conv_tests(zip_path): + """Make a set of tests to do convolution.""" + + test_parameters = [{ + "input_shape": [[1, 3, 4, 3]], + "filter_shape": [[1, 1, 3, 2]], + "strides": [[1, 1, 1, 1], [1, 2, 3, 1]], + "padding": ["SAME", "VALID"], + "data_format": ["NHWC"], # TODO(aselle): NCHW would be good + }, { + "input_shape": [[2, 14, 14, 2]], + "filter_shape": [[6, 6, 2, 2]], + "strides": [[1, 1, 1, 1], [1, 2, 3, 1]], + "padding": ["SAME", "VALID"], + "data_format": ["NHWC"], # TODO(aselle): NCHW would be good + }] + + def build_graph(parameters): + input_tensor = tf.placeholder( + dtype=tf.float32, name="input", shape=parameters["input_shape"]) + filter_values = create_tensor_data(np.float32, parameters["filter_shape"]) + out = tf.nn.conv2d(input_tensor, filter_values, + strides=parameters["strides"], + padding=parameters["padding"], + data_format=parameters["data_format"]) + return [input_tensor], [out] + + def build_inputs(parameters, sess, inputs, outputs): + input_values = create_tensor_data(np.float32, parameters["input_shape"]) + return [input_values], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_values]))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + +def make_depthwiseconv_tests(zip_path): + """Make a set of tests to do convolution.""" + + # Tensorflow only supports equal strides + test_parameters = [{ + "input_shape": [[1, 3, 4, 3], [1, 10, 10, 3]], + "filter_size": [[1, 1], [1, 2], [3, 3]], + "strides": [[1, 1, 1, 1], [1, 3, 3, 1]], + "channel_multiplier": [1, 2], + "rate": [[1, 1]], + "padding": ["SAME", "VALID"], + "data_format": ["NHWC"], + }, { + "input_shape": [[1, 3, 4, 3]], + "filter_size": [[1, 1]], + "strides": [[1, 1, 2, 1]], # TF needs [1, x, x, 1] + "channel_multiplier": [2], + "rate": [[2, 2]], # Only [1, 1] is supported + "padding": ["SAME"], + "data_format": ["NHWC"], + }] + + def build_graph(parameters): + """Build a depthwise conv graph given `parameters`.""" + input_shape = parameters["input_shape"] + filter_size = parameters["filter_size"] + input_tensor = tf.placeholder( + dtype=tf.float32, name="input", shape=input_shape) + filter_shape = filter_size + [ + input_shape[3], parameters["channel_multiplier"]] + filter_values = create_tensor_data(np.float32, filter_shape) + out = tf.nn.depthwise_conv2d( + input_tensor, filter_values, + strides=parameters["strides"], + rate=parameters["rate"], + padding=parameters["padding"], + data_format=parameters["data_format"]) + return [input_tensor], [out] + + def build_inputs(parameters, sess, inputs, outputs): + input_values = create_tensor_data(np.float32, parameters["input_shape"]) + return [input_values], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_values]))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + +def make_concatenation_tests(zip_path): + """Make a set of tests to do concatenatinon.""" + + test_parameters = [{ + "base_shape": [[1, 3, 4, 3], [3, 4]], + "num_tensors": [1, 2, 3, 4, 5, 6], + "axis": [0, 1, 2, 3], + }] + + def get_shape(parameters, delta): + """Return a tweaked version of 'base_shape'.""" + axis = parameters["axis"] + shape = parameters["base_shape"][:] + if axis < len(shape): + shape[axis] += delta + return shape + + def build_graph(parameters): + all_tensors = [] + for n 
in range(0, parameters["num_tensors"]): + input_tensor = tf.placeholder(dtype=tf.float32, name=("input%d" % n), + shape=get_shape(parameters, n)) + all_tensors.append(input_tensor) + out = tf.concat(all_tensors, parameters["axis"]) + return all_tensors, [out] + + def build_inputs(parameters, sess, inputs, outputs): + all_values = [] + for n in range(0, parameters["num_tensors"]): + input_values = create_tensor_data(np.float32, + get_shape(parameters, n)) + all_values.append(input_values) + return all_values, sess.run( + outputs, feed_dict=dict(zip(inputs, all_values))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + +def make_fully_connected_tests(zip_path): + """Make a set of tests to do fully_connected.""" + + test_parameters = [{ + "shape1": [[3, 3]], + "shape2": [[3, 3]], + "transpose_a": [True, False], + "transpose_b": [True, False], + }, { + "shape1": [[4, 4], [1, 4], [4]], + "shape2": [[4, 4], [4, 1], [4]], + "transpose_a": [False], + "transpose_b": [False], + }, { + "shape1": [[40, 37]], + "shape2": [[37, 40]], + "transpose_a": [False], + "transpose_b": [False], + + }] + + def build_graph(parameters): + input_tensor1 = tf.placeholder(dtype=tf.float32, name="input1", + shape=parameters["shape1"]) + input_tensor2 = create_tensor_data(np.float32, parameters["shape2"]) + out = tf.matmul(input_tensor1, input_tensor2, + transpose_a=parameters["transpose_a"], + transpose_b=parameters["transpose_b"]) + return [input_tensor1], [out] + + def build_inputs(parameters, sess, inputs, outputs): + input_values1 = create_tensor_data(np.float32, shape=parameters["shape1"]) + return [input_values1], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_values1]))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + +def make_l2norm_tests(zip_path): + """Make a set of tests to do l2norm.""" + + # Chose a set of parameters + test_parameters = [{ + "input_shape": [[5, 7], [1, 1, 1, 1], [1, 3, 4, 3], [3, 15, 14, 3], + [3, 1, 2, 4, 6], [2, 2, 3, 4, 5, 6]], + "dim": [0, 1, 2, 3, [2, 3], -2], + "epsilon": [None, 1e-12, 1e-3], + }] + + def build_graph(parameters): + input_tensor = tf.placeholder( + dtype=tf.float32, name="input", shape=parameters["input_shape"]) + if parameters["epsilon"]: + out = tf.nn.l2_normalize( + input_tensor, parameters["dim"], epsilon=parameters["epsilon"]) + else: + out = tf.nn.l2_normalize(input_tensor, parameters["dim"]) + return [input_tensor], [out] + + def build_inputs(parameters, sess, inputs, outputs): + input_values = create_tensor_data( + np.float32, parameters["input_shape"], min_value=-4, max_value=10) + return [input_values], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_values]))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + +def make_local_response_norm_tests(zip_path): + """Make a set of tests to do local_response_norm.""" + + # Chose a set of parameters + test_parameters = [{ + "input_shape": [[1, 1, 1, 1], [1, 3, 4, 3], [3, 15, 14, 3]], + "depth_radius": [None, 0, 1, 3, 4, 5], + "bias": [None, 0.1, 0.3, -0.1], + "alpha": [None, 1, 2, -3], + "beta": [None, 0.5, 0.25, 2], + }] + + def build_graph(parameters): + input_tensor = tf.placeholder( + dtype=tf.float32, name="input", shape=parameters["input_shape"]) + out = tf.nn.local_response_normalization( + input_tensor, depth_radius=parameters["depth_radius"], + bias=parameters["bias"], alpha=parameters["alpha"], + beta=parameters["beta"]) + return [input_tensor], [out] + + def build_inputs(parameters, sess, inputs, 
outputs): + input_values = create_tensor_data( + np.float32, parameters["input_shape"], min_value=-4, max_value=10) + return [input_values], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_values]))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + +def make_reshape_tests(zip_path): + """Make a set of tests to do reshape.""" + + # Alll shapes below are suitable for tensors with 420 elements. + test_parameters = [{ + "dtype": [tf.float32, tf.int32], + "input_shape": [[3, 4, 5, 7], [4, 105], [21, 5, 2, 2], [420]], + "output_shape": [[15, 28], [420], [1, -1, 5, 7], [-1]], + }] + + def build_graph(parameters): + input_tensor = tf.placeholder(dtype=parameters["dtype"], name="input", + shape=parameters["input_shape"]) + out = tf.reshape(input_tensor, shape=parameters["output_shape"]) + return [input_tensor], [out] + + def build_inputs(parameters, sess, inputs, outputs): + input_values = create_tensor_data(parameters["dtype"], + parameters["input_shape"]) + return [input_values], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_values]))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + +def make_resize_bilinear_tests(zip_path): + """Make a set of tests to do resize_bilinear.""" + + test_parameters = [{ + "dtype": [tf.float32, tf.int32], + "input_shape": [[1, 3, 4, 3], [1, 10, 2, 1]], + "size": [[1, 1], [4, 3], [2, 2], [5, 6]], + "align_corners": [None, True, False], + }] + + def build_graph(parameters): + input_tensor = tf.placeholder(dtype=parameters["dtype"], name="input", + shape=parameters["input_shape"]) + out = tf.image.resize_bilinear(input_tensor, size=parameters["size"], + align_corners=parameters["align_corners"]) + return [input_tensor], [out] + + def build_inputs(parameters, sess, inputs, outputs): + input_values = create_tensor_data(parameters["dtype"], + parameters["input_shape"]) + return [input_values], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_values]))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + +def make_sigmoid_tests(zip_path): + """Make a set of tests to do sigmoid.""" + + test_parameters = [{ + "dtype": [tf.float32], + "input_shape": [[1, 3, 4, 3], [4], [], [1, 2, 3, 4, 5, 6]], + }] + + def build_graph(parameters): + input_tensor = tf.placeholder(dtype=parameters["dtype"], name="input", + shape=parameters["input_shape"]) + out = tf.sigmoid(input_tensor) + return [input_tensor], [out] + + def build_inputs(parameters, sess, inputs, outputs): + input_values = create_tensor_data(parameters["dtype"], + parameters["input_shape"]) + return [input_values], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_values]))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + +def make_softmax_tests(zip_path): + """Make a set of tests to do softmax.""" + + test_parameters = [{ + "dtype": [tf.float32], + "input_shape": [[1, 3, 4, 3], [2, 3]], + "dim": [-1, 0], + }, { + "dtype": [tf.float32], + "input_shape": [[4, 7]], + "dim": [-1, 1], + }] + + def build_graph(parameters): + input_tensor = tf.placeholder(dtype=parameters["dtype"], name="input", + shape=parameters["input_shape"]) + out = tf.nn.softmax(input_tensor, dim=parameters["dim"]) + return [input_tensor], [out] + + def build_inputs(parameters, sess, inputs, outputs): + input_values = create_tensor_data(parameters["dtype"], + parameters["input_shape"]) + return [input_values], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_values]))) + + make_zip_of_tests(zip_path, 
test_parameters, build_graph, build_inputs) + + +def make_space_to_depth_tests(zip_path): + """Make a set of tests to do space_to_depth.""" + + test_parameters = [{ + "dtype": [tf.float32, tf.float16, tf.int32, tf.uint8, tf.int64], + "input_shape": [[2, 12, 24, 1]], + "block_size": [2, 3, 4], + }] + + def build_graph(parameters): + input_tensor = tf.placeholder(dtype=parameters["dtype"], name="input", + shape=parameters["input_shape"]) + out = tf.space_to_depth(input_tensor, block_size=parameters["block_size"]) + return [input_tensor], [out] + + def build_inputs(parameters, sess, inputs, outputs): + input_values = create_tensor_data(parameters["dtype"], + parameters["input_shape"]) + return [input_values], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_values]))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + +def make_l2_pool(input_tensor, ksize, strides, padding, data_format): + """Given an input perform a sequence of TensorFlow ops to produce l2pool.""" + return tf.sqrt(tf.nn.avg_pool( + tf.square(input_tensor), ksize=ksize, strides=strides, + padding=padding, data_format=data_format)) + + +# Toco binary path provided by the generate rule. +bin_path = None + + +def main(unused_args): + global bin_path + def mkdir_if_not_exist(x): + if not os.path.isdir(x): + os.mkdir(x) + if not os.path.isdir(x): + raise RuntimeError("Failed to create dir %r" % x) + + if FLAGS.type == "zipped": + opstest_path = os.path.join(FLAGS.output_path) + mkdir_if_not_exist(opstest_path) + def _path(filename): + return os.path.join(opstest_path, filename) + + dispatch = { + "control_dep.zip": make_control_dep_tests, + "add.zip": make_add_tests, + "conv.zip": make_conv_tests, + "constant.zip": make_constant_tests, + "depthwiseconv.zip": make_depthwiseconv_tests, + "concat.zip": make_concatenation_tests, + "fully_connected.zip": make_fully_connected_tests, + "global_batch_norm.zip": make_global_batch_norm_tests, + "fused_batch_norm.zip": make_fused_batch_norm_tests, + "l2norm.zip": make_l2norm_tests, + "local_response_norm.zip": make_local_response_norm_tests, + "mul.zip": make_mul_tests, + "relu.zip": make_relu_tests, + "relu1.zip": make_relu1_tests, + "relu6.zip": make_relu6_tests, + "l2_pool.zip": make_pool_tests(make_l2_pool), + "avg_pool.zip": make_pool_tests(tf.nn.avg_pool), + "max_pool.zip": make_pool_tests(tf.nn.max_pool), + "reshape.zip": make_reshape_tests, + "resize_bilinear.zip": make_resize_bilinear_tests, + "sigmoid.zip": make_sigmoid_tests, + "softmax.zip": make_softmax_tests, + "space_to_depth.zip": make_space_to_depth_tests, + } + out = FLAGS.zip_to_output + bin_path = FLAGS.toco + if out in dispatch: + dispatch[out](_path(out)) + else: + raise RuntimeError("Invalid zip to output %r" % out) + + else: + raise RuntimeError("Invalid argument for type of generation.") + + +if __name__ == "__main__": + FLAGS, unparsed = parser.parse_known_args() + + if unparsed: + print("Usage: %s zipped ") + else: + tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/tensorflow/contrib/lite/testing/generate_examples_report.py b/tensorflow/contrib/lite/testing/generate_examples_report.py new file mode 100644 index 0000000000..7bcf8cd86a --- /dev/null +++ b/tensorflow/contrib/lite/testing/generate_examples_report.py @@ -0,0 +1,125 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Make HTML tables that report where TF and TOCO failed to convert models. + +This is primarily used by generate_examples.py. See it or +`make_report_table` for more details on usage. +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import cgi +import json + +FAILED = "FAILED" +SUCCESS = "SUCCESS" +NOTRUN = "NOTRUN" + + +def make_report_table(fp, title, reports): + """Make an HTML report of the success/failure reports. + + Args: + fp: File-like object in which to put the html. + title: "Title of the zip file this pertains to." + reports: a list of conversion attempts. (report_args, report_vals) i.e. + ({"shape": [1,2,3], "type": "tf.float32"}, + {"tf": "SUCCESS", "toco": "FAILURE", "toco_log": "Unsupported type.", + "tf_log": ""}) + """ + # sort reports by if TOCO failure and then TF failure (reversed) + reports.sort(key=lambda x: x[1]["toco"], reverse=False) + reports.sort(key=lambda x: x[1]["tf"], reverse=True) + def result_cell(x, row, col): + """Produce a cell with the condition string `x`.""" + s = cgi.escape(repr(x), quote=True) + color = "#44ff44" if x == SUCCESS else ( + "#ff4444" if x == FAILED else "#eeeeee") + handler = "ShowLog(%d, %d)" % (row, col) + fp.write("%s\n" % ( + color, handler, s)) + + fp.write(""" + +tflite report + + +""") + # Write the log data to a javascript variable and also make a function + # in javascript to show the log when an item is clicked. + fp.write("\n") + + # Write the main table and use onclick on the items that have log items. + fp.write(""" + +

+<body>
+<h1>TOCO Conversion</h1>
+<h2>%s</h2>
+""" % title)
+
+  # Get a list of keys that are in any of the records.
+  param_keys = {}
+  for params, _ in reports:
+    for k in params.keys():
+      param_keys[k] = True
+
+  fp.write("<table>\n")
+  fp.write("<tr>\n")
+  for p in param_keys:
+    fp.write("<th>%s</th>\n" % cgi.escape(p, quote=True))
+  fp.write("<th>TensorFlow</th>\n")
+  fp.write("<th>TOCO</th>\n")
+  fp.write("</tr>\n")
+  for idx, (params, vals) in enumerate(reports):
+    fp.write("<tr>\n")
+    for p in param_keys:
+      fp.write("  <td>%s</td>\n" % cgi.escape(repr(params[p]), quote=True))
+
+    result_cell(vals["tf"], idx, 0)
+    result_cell(vals["toco"], idx, 1)
+    fp.write("</tr>\n")
+  fp.write("</table>\n")
+  fp.write("
\n") + fp.write("\n") + fp.write(""" + + + """) diff --git a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc new file mode 100644 index 0000000000..e7df97ee54 --- /dev/null +++ b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc @@ -0,0 +1,279 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include +#include +#include +#include +#include "re2/re2.h" +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/contrib/lite/testing/parse_testdata.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/subprocess.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/util/command_line_flags.h" + +namespace { +bool FLAGS_ignore_known_bugs = true; +} // namespace + +namespace tflite { +namespace testing { + +// TensorFlow system environment for file system called. +tensorflow::Env* env = tensorflow::Env::Default(); + +// List of tests that are expected to fail when +// --test_arg=--ignore_known_bugs=false +// Key is a substring of the test name and value is a bug number. +// TODO(ahentz): make sure we clean this list up frequently. +std::map kBrokenTests = { + // Add doesn't support broadcasting. + {R"(addd.*input_shape_1=\[1,3,4,3\],input_shape_2=\[3\])", "68500195"}, + {R"(muld.*input_shape_1=\[1,3,4,3\],input_shape_2=\[3\])", "68500195"}, + + // Add only supports float32. (and "constant" tests use Add) + {R"(addd.*int32)", "68808744"}, + {R"(constant.*int32)", "68808744"}, + {R"(mul.*int32)", "68808744"}, + + // Toco or TFLite has a bug to deal with some constant functions with + // more than 1 element. + {R"(constant.*input_shape=\[(2|2,2,2,2)\])", "68721522"}, + + // L2Norm only supports 4D tensors. + {R"(l2normdim=.*,epsilon=.*,input_shape=\[.,.\])", "67963684"}, + {R"(l2normdim=.*,epsilon=.*,input_shape=\[.,.,.,.,.*\])", "67963684"}, + + // L2Norm only works for dim=-1. 
+ {R"(l2normdim=-2,epsilon=.*,input_shape=\[3,15,14,3\])", "67963812"}, + {R"(l2normdim=-2,epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"}, + {R"(l2normdim=2,epsilon=.*,input_shape=\[3,15,14,3\])", "67963812"}, + {R"(l2normdim=2,epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"}, + {R"(l2normdim=0,epsilon=.*,input_shape=\[3,15,14,3\])", "67963812"}, + {R"(l2normdim=0,epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"}, + {R"(l2normdim=1,epsilon=.*,input_shape=\[3,15,14,3\])", "67963812"}, + {R"(l2normdim=1,epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"}, + {R"(l2normdim=\[2,3\],epsilon=.*,input_shape=\[3,15,14,3\])", "67963812"}, + {R"(l2normdim=\[2,3\],epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"}, + + // ResizeBilinear looks completely incompatible with Tensorflow + {R"(resize_bilinear)", "67964336"}, +}; + +// Allows test data to be unzipped into a temporary directory and makes +// sure those temporary directories are removed later. +class ZipEnvironment : public ::testing::Environment { + public: + ~ZipEnvironment() override {} + + // Delete all temporary directories on teardown. + void TearDown() override { + for (const auto& dir : temporary_directories_) { + tensorflow::int64 undeleted_dirs, undeleted_files; + TF_CHECK_OK( + env->DeleteRecursively(dir, &undeleted_dirs, &undeleted_files)); + } + temporary_directories_.clear(); + } + + // Unzip `zip` file into a new temporary directory `out_dir`. + tensorflow::Status UnZip(const std::string& zip, std::string* out_dir) { + string dir; + TF_CHECK_OK(MakeTemporaryDirectory(&dir)); + tensorflow::SubProcess proc; + std::string unzip_binary = + "/usr/bin/unzip"; + proc.SetProgram(unzip_binary, {"unzip", "-d", dir, zip.c_str()}); + proc.SetChannelAction(tensorflow::CHAN_STDOUT, tensorflow::ACTION_PIPE); + proc.SetChannelAction(tensorflow::CHAN_STDERR, tensorflow::ACTION_PIPE); + if (!proc.Start()) + return tensorflow::Status(tensorflow::error::UNKNOWN, + "unzip couldn't start"); + string out, err; + int status = proc.Communicate(nullptr, &out, &err); + if (WEXITSTATUS(status) == 0) { + *out_dir = dir; + return tensorflow::Status::OK(); + } else { + return tensorflow::Status(tensorflow::error::UNKNOWN, "unzip failed"); + } + } + + private: + // Make a temporary directory and return its name in `temporary`. + tensorflow::Status MakeTemporaryDirectory(string* temporary) { + if (env->LocalTempFilename(temporary)) { + TF_CHECK_OK(env->CreateDir(*temporary)); + temporary_directories_.push_back(*temporary); + return tensorflow::Status::OK(); + } + return tensorflow::Status(tensorflow::error::UNKNOWN, + "make temporary directory failed"); + } + + std::vector temporary_directories_; +}; + +// Return the singleton zip_environment. +ZipEnvironment* zip_environment() { + static ZipEnvironment* env = new ZipEnvironment; + return env; +} + +// Read the manifest.txt out of the unarchived zip file. Specifically +// `original_file` is the original zip file for error messages. `dir` is +// the temporary directory where the zip file has been unarchived and +// `test_paths` is the list of test prefixes that were in the manifest. +// Note, it is an error for a manifest to contain no tests. +tensorflow::Status ReadManifest(const std::string& original_file, + const std::string& dir, + std::vector* test_paths) { + // Read the newline delimited list of entries in the manifest. 
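+  // As a sketch, the manifest has one test label per line (hypothetical,
+  // abbreviated labels):
+  //   avg_pool,ksize=[1,1,1,1],padding='SAME',strides=[1,1,1,1]
+  //   avg_pool,ksize=[1,1,2,1],padding='VALID',strides=[1,1,1,1]
+  // Each entry is resolved to "<dir>/<entry>" below; the test later appends
+  // ".bin" and ".inputs" to locate the model and its expected data.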
+ std::ifstream manifest_fp(dir + "/manifest.txt"); + std::string manifest((std::istreambuf_iterator(manifest_fp)), + std::istreambuf_iterator()); + size_t pos = 0; + int added = 0; + while (true) { + size_t end_pos = manifest.find("\n", pos); + if (end_pos == std::string::npos) break; + std::string filename = manifest.substr(pos, end_pos - pos); + test_paths->push_back(dir + "/" + filename); + pos = end_pos + 1; + added += 1; + } + if (!added) { + std::string message = "Test had no examples: " + original_file; + return tensorflow::Status(tensorflow::error::UNKNOWN, message.c_str()); + } + return tensorflow::Status::OK(); +} + +// Get a list of tests from a zip file `zip_file_name`. +std::vector UnarchiveZipAndFindTestNames( + const std::string& zip_file_name) { + std::string zip_file = ::tensorflow::testing::TensorFlowSrcRoot() + + "/contrib/lite/testing/optest/" + zip_file_name; + std::string decompress_tmp_dir; + TF_CHECK_OK(zip_environment()->UnZip(zip_file, &decompress_tmp_dir)); + std::vector stuff; + TF_CHECK_OK(ReadManifest(zip_file, decompress_tmp_dir, &stuff)); + return stuff; +} + +class OpsTest : public ::testing::TestWithParam {}; + +TEST_P(OpsTest, RunStuff) { + std::string test_path = GetParam(); + std::string tflite_file = test_path + ".bin"; + std::string tflite_examples = test_path + ".inputs"; + auto model = tflite::FlatBufferModel::BuildFromFile(tflite_file.c_str()); + std::unique_ptr interpreter; + + tflite::ops::builtin::BuiltinOpResolver builtins; + ASSERT_EQ(tflite::InterpreterBuilder(*model, builtins)(&interpreter), + kTfLiteOk); + + std::vector examples; + ASSERT_EQ(tflite::testing::ParseExamples(tflite_examples.c_str(), &examples), + kTfLiteOk); + + string bug_number; + for (const auto& p : kBrokenTests) { + if (RE2::PartialMatch(test_path, p.first)) { + bug_number = p.second; + } + } + + for (const auto& example : examples) { + ASSERT_EQ(interpreter->inputs().size(), example.inputs.size()); + auto result = [&]() { + TF_LITE_ENSURE_STATUS(FeedExample(interpreter.get(), example)); + TF_LITE_ENSURE_STATUS(interpreter->Invoke()); + TF_LITE_ENSURE_STATUS(CheckOutputs(interpreter.get(), example)); + return kTfLiteOk; + }(); + + if (bug_number.empty()) { + ASSERT_EQ(result, kTfLiteOk); + } else { + if (FLAGS_ignore_known_bugs) { + ASSERT_EQ(result, kTfLiteError) + << "Not failing as expected dut to http://b/" << bug_number; + } else { + ASSERT_EQ(result, kTfLiteOk) + << "Possibly due to http://b/" << bug_number; + } + } + } +} + +// Instantiate a test. This assumes `zip_base`.zip is a declared data file +// of this test. 
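+// As an illustration, INSTANTIATE_TESTS(add) expands to roughly:
+//   INSTANTIATE_TEST_CASE_P(
+//       add, OpsTest,
+//       ::testing::ValuesIn(UnarchiveZipAndFindTestNames("add.zip")));
+// i.e. one parameterized OpsTest case per entry in add.zip's manifest.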
+#define INSTANTIATE_TESTS(zip_base) \ + INSTANTIATE_TEST_CASE_P( \ + zip_base, OpsTest, \ + ::testing::ValuesIn(UnarchiveZipAndFindTestNames(#zip_base ".zip"))); + +INSTANTIATE_TESTS(add) +INSTANTIATE_TESTS(avg_pool) +INSTANTIATE_TESTS(concat) +INSTANTIATE_TESTS(constant) +INSTANTIATE_TESTS(control_dep) +INSTANTIATE_TESTS(conv) +INSTANTIATE_TESTS(depthwiseconv) +INSTANTIATE_TESTS(fully_connected) +INSTANTIATE_TESTS(fused_batch_norm) +INSTANTIATE_TESTS(global_batch_norm) +INSTANTIATE_TESTS(l2norm) +INSTANTIATE_TESTS(l2_pool) +INSTANTIATE_TESTS(local_response_norm) +INSTANTIATE_TESTS(max_pool) +INSTANTIATE_TESTS(mul) +INSTANTIATE_TESTS(relu) +INSTANTIATE_TESTS(relu1) +INSTANTIATE_TESTS(relu6) +INSTANTIATE_TESTS(reshape) +INSTANTIATE_TESTS(resize_bilinear) +INSTANTIATE_TESTS(sigmoid) +INSTANTIATE_TESTS(softmax) +INSTANTIATE_TESTS(space_to_depth) + +} // namespace testing +} // namespace tflite + +int main(int argc, char** argv) { + ::testing::AddGlobalTestEnvironment(tflite::testing::zip_environment()); + + std::vector flags = {tensorflow::Flag( + "ignore_known_bugs", &FLAGS_ignore_known_bugs, + "If a particular model is affected by a known bug, the " + "corresponding test should expect the outputs to not match.")}; + bool success = tensorflow::Flags::Parse(&argc, argv, flags); + if (!success || (argc == 2 && !strcmp(argv[1], "--helpfull"))) { + fprintf(stderr, "%s", tensorflow::Flags::Usage(argv[0], flags).c_str()); + return 1; + } + + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/testing/message.cc b/tensorflow/contrib/lite/testing/message.cc new file mode 100644 index 0000000000..03fae4bb86 --- /dev/null +++ b/tensorflow/contrib/lite/testing/message.cc @@ -0,0 +1,96 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/testing/message.h" + +#include + +#include "tensorflow/contrib/lite/testing/tokenize.h" + +namespace tflite { +namespace testing { + +// A token processor that builds messages and forward calls to the current +// message object. Place a new message at the top of the stack when it start +// and remove it when it is finished. +class MessageStack : public TokenProcessor { + public: + // Start a new MessageStack with the given first_node, which will be used to + // process freestanding fields and submessages. + explicit MessageStack(Message* first_node) { + nodes_.push(first_node); + valid_ = true; + } + + void ConsumeToken(std::string* token) override { + if (!valid_) return; + Message* current_node = nodes_.top(); + if (*token == "{") { + // This is the beginning of a new message, names after the previous token. + if (previous_token_.empty()) { + valid_ = false; + return; + } + nodes_.push(current_node ? current_node->AddChild(previous_token_) + : nullptr); + previous_token_.clear(); + } else if (*token == "}") { + // A message is being completed. 
There should be no previous token. Note + // that the top-level message never closes, so we should always have at + // least one entry in the stack. + if (nodes_.size() == 1 || !previous_token_.empty()) { + valid_ = false; + return; + } + if (current_node) { + current_node->Finish(); + } + nodes_.pop(); + } else if (*token == ":") { + // We reached the end of the 'key' portion of a field. Store the token + // until we have the 'value' portion. + if (previous_token_.empty()) { + valid_ = false; + return; + } + } else { + if (previous_token_.empty()) { + previous_token_.swap(*token); + } else { + // This is the 'value' portion of a field. The previous token is the + // 'key'. + if (current_node) { + current_node->SetField(previous_token_, *token); + } + previous_token_.clear(); + } + } + } + + bool valid() const { return valid_; } + + private: + std::stack nodes_; + std::string previous_token_; + bool valid_; +}; + +bool Message::Read(std::istream* input, Message* message) { + MessageStack stack(message); + Tokenize(input, &stack); + return stack.valid(); +} + +} // namespace testing +} // namespace tflite diff --git a/tensorflow/contrib/lite/testing/message.h b/tensorflow/contrib/lite/testing/message.h new file mode 100644 index 0000000000..78ef7e2cbe --- /dev/null +++ b/tensorflow/contrib/lite/testing/message.h @@ -0,0 +1,82 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_MESSAGE_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_MESSAGE_H_ + +#include +#include +#include + +namespace tflite { +namespace testing { + +// A Message is a textual protobuf-like structure that looks like: +// tag { +// f : "values" +// child { +// a : 1 +// } +// } +// This class provides the framework for processing message but does not +// associate any particular behavior to fields and submessage. In order +// to properly parse a stream this class must be derived. +class Message { + public: + // Reads a stream, tokenizes it and create a new message under the given + // top-level message. Returns true if the parsing succeeded. + static bool Read(std::istream* input, Message* message); + + Message() {} + virtual ~Message() {} + + // Called when a new field is found. For example, when: + // f : "values" + // is found, it triggers: + // SetField("f", "values"); + virtual void SetField(const std::string& name, const std::string& value) {} + + // Called when a submessage is started. For example, when: + // child { + // is found, it triggers + // AddChild("child"); + // If nullptr is returned, the contents of the submessage will be ignored. + // Otherwise, the returned Message will be used to handle new fields and new + // submessages. The caller should not take ownership of the returned pointer. 
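+  // A subclass that wants to keep its submessages can combine this with the
+  // Store() helper below, e.g. (sketch, MyChildMessage being a hypothetical
+  // subclass): return Store(new MyChildMessage(name));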
+ virtual Message* AddChild(const std::string& name) { return nullptr; } + + // Called when a submessage is completed, that is, whenever a '}' is found. + virtual void Finish() {} + + protected: + // Takes ownership of the given pointer. Subclasses can use this method if + // they don't want to implement their own ownership semantics. + Message* Store(Message* n) { + children_.emplace_back(n); + return n; + } + + // Returns a list of all owned submessages. + const std::vector>& Children() const { + return children_; + } + + private: + std::vector> children_; +}; + +} // namespace testing +} // namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_MESSAGE_H_ diff --git a/tensorflow/contrib/lite/testing/message_test.cc b/tensorflow/contrib/lite/testing/message_test.cc new file mode 100644 index 0000000000..fb6a49bd6f --- /dev/null +++ b/tensorflow/contrib/lite/testing/message_test.cc @@ -0,0 +1,121 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/testing/message.h" + +#include + +#include +#include + +namespace tflite { +namespace testing { +namespace { + +// A hierarchical, key-value store. 
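+// For example (a sketch of the API exercised by the tests below):
+//   TestMessage m("x{a:1 b:2} d:4");
+//   m.GetField("d");               // "4"
+//   m.GetChild(0)->GetField("a");  // "1"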
+class TestMessage : public Message { + public: + TestMessage() {} + explicit TestMessage(const std::string& text_to_parse) { + std::stringstream ss(text_to_parse); + finished_ = Message::Read(&ss, this); + } + void SetField(const std::string& name, const std::string& value) override { + fields_[name] = value; + } + Message* AddChild(const std::string& name) override { + TestMessage* m = new TestMessage; + m->name_ = name; + return Store(m); + } + void Finish() override { finished_ = true; } + + int NumChildren() const { return Children().size(); } + + const TestMessage* GetChild(int i) const { + return dynamic_cast(Children()[i].get()); + } + + int NumFields() const { return fields_.size(); } + const std::string& GetField(const std::string& key) const { + return fields_.at(key); + } + + const std::string& name() const { return name_; } + bool finished() const { return finished_; } + + protected: + std::string name_; + std::map fields_; + bool finished_ = false; +}; + +TEST(MessageTest, Simple) { + TestMessage message("x{a:1 b:2} y{} z{c:3} d:4"); + ASSERT_TRUE(message.finished()); + + ASSERT_EQ(message.NumFields(), 1); + EXPECT_EQ(message.GetField("d"), "4"); + + ASSERT_EQ(message.NumChildren(), 3); + + auto* x = message.GetChild(0); + EXPECT_EQ(x->name(), "x"); + ASSERT_EQ(x->NumFields(), 2); + EXPECT_EQ(x->GetField("a"), "1"); + EXPECT_EQ(x->GetField("b"), "2"); + + auto* y = message.GetChild(1); + EXPECT_EQ(y->name(), "y"); + ASSERT_EQ(y->NumFields(), 0); + + auto* z = message.GetChild(2); + EXPECT_EQ(z->name(), "z"); + ASSERT_EQ(z->NumFields(), 1); + EXPECT_EQ(z->GetField("c"), "3"); +} + +TEST(MessageTest, Unnamed) { + TestMessage message("x{c:3} {} y{d:4}"); + ASSERT_FALSE(message.finished()); + EXPECT_EQ(message.NumChildren(), 1); +} + +TEST(MessageTest, TooManyBraces) { + TestMessage message("x{c:3} } y{d:4}"); + ASSERT_FALSE(message.finished()); + EXPECT_EQ(message.NumChildren(), 1); +} + +TEST(MessageTest, LeftoverToken) { + TestMessage message("x{c:3} z{test} y{d:4}"); + ASSERT_FALSE(message.finished()); + EXPECT_EQ(message.NumChildren(), 2); +} + +TEST(MessageTest, MissingKey) { + TestMessage message("x{c:3} z{:test} y{d:4}"); + ASSERT_FALSE(message.finished()); + EXPECT_EQ(message.NumChildren(), 2); +} + +TEST(MessageTest, MissingValue) { + TestMessage message("x{c:3} z{test:} y{d:4}"); + ASSERT_FALSE(message.finished()); + EXPECT_EQ(message.NumChildren(), 2); +} + +} // namespace +} // namespace testing +} // namespace tflite diff --git a/tensorflow/contrib/lite/testing/nnapi_example.cc b/tensorflow/contrib/lite/testing/nnapi_example.cc new file mode 100644 index 0000000000..74f6cfc3de --- /dev/null +++ b/tensorflow/contrib/lite/testing/nnapi_example.cc @@ -0,0 +1,114 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// NOTE: this is an example driver that converts a tflite model to TensorFlow. 
+// This is an example that will be integrated more tightly into tflite in +// the future. +// +// Usage: bazel run -c opt \ +// tensorflow/contrib/lite/nnapi:nnapi_example -- +// +#include +#include +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h" +#include "tensorflow/contrib/lite/testing/parse_testdata.h" + +// TODO(aselle): FATAL leaves resources hanging. +void FATAL(const char* format, ...) { + va_list args; + va_start(args, format); + vfprintf(stderr, format, args); + va_end(args); + fflush(stderr); + exit(1); +} + +#define CHECK_TFLITE_SUCCESS(x) \ + if (x != kTfLiteOk) { \ + FATAL("Aborting since tflite returned failure."); \ + } + +void Interpret(const char* filename, const char* examples_filename, + bool use_nnapi) { + // TODO(aselle): Resize of input image should go here + // ... + // For now I am allocating all tensors. This means I am fixed size. + // So I am not using the variable size ability yet. + fprintf(stderr, "example file %s\n", examples_filename); + std::vector examples; + CHECK_TFLITE_SUCCESS( + tflite::testing::ParseExamples(examples_filename, &examples)); + + for (const tflite::testing::Example& example : examples) { + auto model = tflite::FlatBufferModel::BuildFromFile(filename); + if (!model) FATAL("Cannot read file %s\n", filename); + std::unique_ptr interpreter; + tflite::ops::builtin::BuiltinOpResolver builtins; + + CHECK_TFLITE_SUCCESS( + tflite::InterpreterBuilder(*model, builtins)(&interpreter)); + + printf("Use nnapi is set to: %d\n", use_nnapi); + interpreter->UseNNAPI(use_nnapi); + CHECK_TFLITE_SUCCESS( + tflite::testing::FeedExample(interpreter.get(), example)); + + { + TfLiteTensor* tensor = interpreter->tensor(interpreter->outputs()[0]); + if (float* data = + interpreter->typed_tensor(interpreter->outputs()[0])) { + size_t num = tensor->bytes / sizeof(float); + for (float* p = data; p < data + num; p++) { + *p = 0; + } + } + } + interpreter->Invoke(); + + CHECK_TFLITE_SUCCESS( + tflite::testing::CheckOutputs(interpreter.get(), example)); + + printf("Result:\n"); + TfLiteTensor* tensor = interpreter->tensor(interpreter->outputs()[0]); + if (float* data = + interpreter->typed_tensor(interpreter->outputs()[0])) { + size_t num = tensor->bytes / sizeof(float); + for (float* p = data; p < data + num; p++) { + printf(" %f", *p); + } + } + } +} + +int main(int argc, char* argv[]) { + bool use_nnapi = true; + if (argc == 4) { + use_nnapi = strcmp(argv[3], "1") == 0 ? true : false; + } + if (argc < 3) { + fprintf(stderr, + "Compiled " __DATE__ __TIME__ + "\n" + "Usage!!!: %s " + "{ use nn api i.e. 0,1}\n", + argv[0]); + return 1; + } + Interpret(argv[1], argv[2], use_nnapi); + return 0; +} diff --git a/tensorflow/contrib/lite/testing/parse_testdata.cc b/tensorflow/contrib/lite/testing/parse_testdata.cc new file mode 100644 index 0000000000..2b67052cad --- /dev/null +++ b/tensorflow/contrib/lite/testing/parse_testdata.cc @@ -0,0 +1,335 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Parses tflite example input data. +// Format is ASCII +// TODO(aselle): Switch to protobuf, but the android team requested a simple +// ASCII file. +#include "tensorflow/contrib/lite/testing/parse_testdata.h" + +#include +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/error_reporter.h" +#include "tensorflow/contrib/lite/testing/message.h" +#include "tensorflow/contrib/lite/testing/split.h" + +namespace tflite { +namespace testing { +namespace { + +// Fatal error if parse error occurs +#define PARSE_CHECK_EQ(filename, current_line, x, y) \ + if ((x) != (y)) { \ + fprintf(stderr, "Parse Error @ %s:%d\n File %s\n Line %d, %s != %s\n", \ + __FILE__, __LINE__, filename, current_line + 1, #x, #y); \ + return kTfLiteError; \ + } + +// Breakup a "," delimited line into a std::vector. +// This is extremely inefficient, and just used for testing code. +// TODO(aselle): replace with absl when we use it. +std::vector ParseLine(const std::string& line) { + size_t pos = 0; + std::vector elements; + while (true) { + size_t end = line.find(',', pos); + if (end == std::string::npos) { + elements.push_back(line.substr(pos)); + break; + } else { + elements.push_back(line.substr(pos, end - pos)); + } + pos = end + 1; + } + return elements; +} + +} // namespace + +// Given a `filename`, produce a vector of Examples corresopnding +// to test cases that can be applied to a tflite model. +TfLiteStatus ParseExamples(const char* filename, + std::vector* examples) { + std::ifstream fp(filename); + if (!fp.good()) { + fprintf(stderr, "Could not read '%s'\n", filename); + return kTfLiteError; + } + std::string str((std::istreambuf_iterator(fp)), + std::istreambuf_iterator()); + size_t pos = 0; + + // \n and , delimit parse a file. 
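+  // As a sketch, the parser below accepts input of roughly this shape (the
+  // first cell of each data row is a label that is skipped, and the labels
+  // shown here are illustrative):
+  //   test_cases,1
+  //   inputs,1
+  //   dtype,float
+  //   shape,1,2
+  //   values,0.1,0.2
+  //   outputs,1
+  //   dtype,float
+  //   shape,1
+  //   values,0.3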
+ std::vector> csv; + while (true) { + size_t end = str.find('\n', pos); + + if (end == std::string::npos) { + csv.emplace_back(ParseLine(str.substr(pos))); + break; + } + csv.emplace_back(ParseLine(str.substr(pos, end - pos))); + pos = end + 1; + } + + int current_line = 0; + PARSE_CHECK_EQ(filename, current_line, csv[0][0], "test_cases"); + int example_count = std::stoi(csv[0][1]); + current_line++; + + auto parse_tensor = [&filename, ¤t_line, + &csv](FloatTensor* tensor_ptr) { + PARSE_CHECK_EQ(filename, current_line, csv[current_line][0], "dtype"); + current_line++; + // parse shape + PARSE_CHECK_EQ(filename, current_line, csv[current_line][0], "shape"); + size_t elements = 1; + FloatTensor& tensor = *tensor_ptr; + + for (size_t i = 1; i < csv[current_line].size(); i++) { + const auto& shape_part_to_parse = csv[current_line][i]; + if (shape_part_to_parse.empty()) { + // Case of a 0-dimensional shape + break; + } + int shape_part = std::stoi(shape_part_to_parse); + elements *= shape_part; + tensor.shape.push_back(shape_part); + } + current_line++; + // parse data + PARSE_CHECK_EQ(filename, current_line, csv[current_line].size() - 1, + elements); + for (size_t i = 1; i < csv[current_line].size(); i++) { + tensor.flat_data.push_back(std::stof(csv[current_line][i])); + } + current_line++; + + return kTfLiteOk; + }; + + for (int example_idx = 0; example_idx < example_count; example_idx++) { + Example example; + PARSE_CHECK_EQ(filename, current_line, csv[current_line][0], "inputs"); + int inputs = std::stoi(csv[current_line][1]); + current_line++; + // parse dtype + for (int input_index = 0; input_index < inputs; input_index++) { + example.inputs.push_back(FloatTensor()); + TF_LITE_ENSURE_STATUS(parse_tensor(&example.inputs.back())); + } + + PARSE_CHECK_EQ(filename, current_line, csv[current_line][0], "outputs"); + int outputs = std::stoi(csv[current_line][1]); + current_line++; + for (int input_index = 0; input_index < outputs; input_index++) { + example.outputs.push_back(FloatTensor()); + TF_LITE_ENSURE_STATUS(parse_tensor(&example.outputs.back())); + } + examples->emplace_back(example); + } + return kTfLiteOk; +} + +TfLiteStatus FeedExample(tflite::Interpreter* interpreter, + const Example& example) { + // Resize inputs to match example & allocate. + for (size_t i = 0; i < interpreter->inputs().size(); i++) { + int input_index = interpreter->inputs()[i]; + + TF_LITE_ENSURE_STATUS( + interpreter->ResizeInputTensor(input_index, example.inputs[i].shape)); + } + TF_LITE_ENSURE_STATUS(interpreter->AllocateTensors()); + // Copy data into tensors. 
+ for (size_t i = 0; i < interpreter->inputs().size(); i++) { + int input_index = interpreter->inputs()[i]; + if (float* data = interpreter->typed_tensor(input_index)) { + for (size_t idx = 0; idx < example.inputs[i].flat_data.size(); idx++) { + data[idx] = example.inputs[i].flat_data[idx]; + } + } else if (int32_t* data = + interpreter->typed_tensor(input_index)) { + for (size_t idx = 0; idx < example.inputs[i].flat_data.size(); idx++) { + data[idx] = example.inputs[i].flat_data[idx]; + } + } else { + fprintf(stderr, "input[%zu] was not float or int data\n", i); + return kTfLiteError; + } + } + return kTfLiteOk; +} + +TfLiteStatus CheckOutputs(tflite::Interpreter* interpreter, + const Example& example) { + constexpr double kRelativeThreshold = 1e-2f; + constexpr double kAbsoluteThreshold = 1e-4f; + + ErrorReporter* context = DefaultErrorReporter(); + int model_outputs = interpreter->outputs().size(); + TF_LITE_ENSURE_EQ(context, model_outputs, example.outputs.size()); + for (size_t i = 0; i < interpreter->outputs().size(); i++) { + int output_index = interpreter->outputs()[i]; + if (const float* data = interpreter->typed_tensor(output_index)) { + for (size_t idx = 0; idx < example.outputs[i].flat_data.size(); idx++) { + float computed = data[idx]; + float reference = example.outputs[0].flat_data[idx]; + float diff = std::abs(computed - reference); + bool error_is_large = false; + // For very small numbers, try absolute error, otherwise go with + // relative. + if (std::abs(reference) < kRelativeThreshold) { + error_is_large = (diff > kAbsoluteThreshold); + } else { + error_is_large = (diff > kRelativeThreshold * std::abs(reference)); + } + if (error_is_large) { + fprintf(stdout, "output[%zu][%zu] did not match %f vs reference %f\n", + i, idx, data[idx], reference); + return kTfLiteError; + } + } + fprintf(stderr, "\n"); + } else if (const int32_t* data = + interpreter->typed_tensor(output_index)) { + for (size_t idx = 0; idx < example.outputs[i].flat_data.size(); idx++) { + int32_t computed = data[idx]; + int32_t reference = example.outputs[0].flat_data[idx]; + if (std::abs(computed - reference) > 0) { + fprintf(stderr, "output[%zu][%zu] did not match %d vs reference %f\n", + i, idx, data[idx], example.outputs[0].flat_data[idx]); + return kTfLiteError; + } + } + fprintf(stderr, "\n"); + } else { + fprintf(stderr, "output[%zu] was not float or int data\n", i); + return kTfLiteError; + } + } + return kTfLiteOk; +} + +// Process an 'invoke' message, triggering execution of the test runner, as +// well as verification of outputs. 
An 'invoke' message looks like: +// invoke { +// id: xyz +// input: 1,2,1,1,1,2,3,4 +// ouput: 4,5,6 +// } +class Invoke : public Message { + public: + explicit Invoke(TestRunner* test_runner) : test_runner_(test_runner) { + expected_inputs_ = test_runner->GetInputs(); + expected_outputs_ = test_runner->GetOutputs(); + } + + void SetField(const std::string& name, const std::string& value) override { + if (name == "id") { + test_runner_->SetInvocationId(value); + } else if (name == "input") { + if (expected_inputs_.empty()) { + return test_runner_->Invalidate("Too many inputs"); + } + test_runner_->SetInput(*expected_inputs_.begin(), value); + expected_inputs_.erase(expected_inputs_.begin()); + } else if (name == "output") { + if (expected_outputs_.empty()) { + return test_runner_->Invalidate("Too many outputs"); + } + test_runner_->SetExpectation(*expected_outputs_.begin(), value); + expected_outputs_.erase(expected_outputs_.begin()); + } + } + void Finish() override { + test_runner_->Invoke(); + test_runner_->CheckResults(); + } + + private: + std::vector expected_inputs_; + std::vector expected_outputs_; + + TestRunner* test_runner_; +}; + +// Process an 'reshape' message, triggering resizing of the input tensors via +// the test runner. A 'reshape' message looks like: +// reshape { +// input: 1,2,1,1,1,2,3,4 +// } +class Reshape : public Message { + public: + explicit Reshape(TestRunner* test_runner) : test_runner_(test_runner) { + expected_inputs_ = test_runner->GetInputs(); + } + + void SetField(const std::string& name, const std::string& value) override { + if (name == "input") { + if (expected_inputs_.empty()) { + return test_runner_->Invalidate("Too many inputs to reshape"); + } + test_runner_->ReshapeTensor(*expected_inputs_.begin(), value); + expected_inputs_.erase(expected_inputs_.begin()); + } + } + + private: + std::vector expected_inputs_; + TestRunner* test_runner_; +}; + +// This is the top-level message in a test file. +class TestData : public Message { + public: + explicit TestData(TestRunner* test_runner) : test_runner_(test_runner) {} + + void SetField(const std::string& name, const std::string& value) override { + if (name == "load_model") { + test_runner_->LoadModel(value); + } else if (name == "init_state") { + test_runner_->AllocateTensors(); + for (int id : Split(value, ",")) { + test_runner_->ResetTensor(id); + } + } + } + Message* AddChild(const std::string& s) override { + if (s == "invoke") { + test_runner_->AllocateTensors(); + return Store(new Invoke(test_runner_)); + } else if (s == "reshape") { + return Store(new Reshape(test_runner_)); + } + return nullptr; + } + + private: + TestRunner* test_runner_; +}; + +bool ParseAndRunTests(std::istream* input, TestRunner* test_runner) { + TestData test_data(test_runner); + Message::Read(input, &test_data); + return test_runner->IsValid() && test_runner->GetOverallSuccess(); +} + +} // namespace testing +} // namespace tflite diff --git a/tensorflow/contrib/lite/testing/parse_testdata.h b/tensorflow/contrib/lite/testing/parse_testdata.h new file mode 100644 index 0000000000..90839fe245 --- /dev/null +++ b/tensorflow/contrib/lite/testing/parse_testdata.h @@ -0,0 +1,74 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_NNAPI_PARSE_TESTDATA_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_NNAPI_PARSE_TESTDATA_H_ + +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/testing/test_runner.h" + +namespace tflite { +namespace testing { + +// Shape and data for a float tensor +struct FloatTensor { + std::vector shape; + std::vector flat_data; +}; + +// A prescribed input, output example +struct Example { + std::vector inputs; + std::vector outputs; +}; + +// Parses an example input and output file (used for unit tests) +TfLiteStatus ParseExamples(const char* filename, + std::vector* examples); + +// Inputs Tensors into a TensorFlow lite interpreter. Note, this will run +// interpreter.AllocateTensors(); +TfLiteStatus FeedExample(tflite::Interpreter* interpreter, const Example&); + +// Check outputs against (already) evaluated result. +TfLiteStatus CheckOutputs(tflite::Interpreter* interpreter, const Example&); + +// Parses a test description and feeds the given test runner with data. +// The input format is similar to an ASCII proto: +// // Loads model 'add.bin' from the TestRunner's model directory. +// load_model: "add.bin" +// // Changes the shape of inputs, provided in the same order they appear +// // in the model. +// reshape { +// input: "1,224,224,3" +// input: "1,3,4,1" +// } +// // Fills the given persistent tensors with zeros. +// init_state: 0,1,2,3 +// // Invokes the interpreter with the given input and checks that it +// // produces the expected output. Inputs and outputs should be specified in +// // the order they appear in the model. +// invoke { +// input: "1,2,3,4,56" +// input: "0.1,0.2,0.3,4.3,56.4" +// output: "12,3,4,545,3" +// output: "0.01,0.02" +// } +bool ParseAndRunTests(std::istream* input, TestRunner* test_runner); + +} // namespace testing +} // namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_NNAPI_PARSE_TESTDATA_H_ diff --git a/tensorflow/contrib/lite/testing/split.cc b/tensorflow/contrib/lite/testing/split.cc new file mode 100644 index 0000000000..5836f4ff04 --- /dev/null +++ b/tensorflow/contrib/lite/testing/split.cc @@ -0,0 +1,42 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/contrib/lite/testing/split.h" + +namespace tflite { +namespace testing { + +std::vector> SplitToPos(const string& s, + const string& delimiter) { + std::vector> fields; + if (delimiter.length() == 0) { + fields.emplace_back(0, s.length()); + return fields; + } + size_t pos = 0; + size_t start = 0; + while ((pos = s.find(delimiter, start)) != string::npos) { + if (pos != start) { + fields.emplace_back(start, pos); + } + start = pos + delimiter.length(); + } + if (start != s.length()) { + fields.emplace_back(start, s.length()); + } + return fields; +} + +} // namespace testing +} // namespace tflite diff --git a/tensorflow/contrib/lite/testing/split.h b/tensorflow/contrib/lite/testing/split.h new file mode 100644 index 0000000000..24071442e8 --- /dev/null +++ b/tensorflow/contrib/lite/testing/split.h @@ -0,0 +1,77 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_SPLIT_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_SPLIT_H_ + +#include +#include +#include +#include +#include "tensorflow/contrib/lite/string.h" + +namespace tflite { +namespace testing { + +// Splits a string based on the given delimiter string. Each pair in the +// returned vector has the start and past-the-end positions for each of the +// parts of the original string. Empty fields are not represented in the +// output. +std::vector> SplitToPos(const string& s, + const string& delimiter); + +// Splits the given string and converts each part to the given T. 
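+// For example (illustrative, mirroring the cases in split_test.cc below):
+//   Split<string>("A;B;C", ";") yields {"A", "B", "C"}
+//   Split<int>("1,-1,258", ",") yields {1, -1, 258}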
+template <typename T>
+std::vector<T> Split(const string& s, const string& delimiter);
+
+template <>
+inline std::vector<string> Split(const string& s, const string& delimiter) {
+  std::vector<string> fields;
+  for (const auto& p : SplitToPos(s, delimiter)) {
+    fields.push_back(s.substr(p.first, p.second - p.first));
+  }
+  return fields;
+}
+
+template <>
+inline std::vector<int> Split(const string& s, const string& delimiter) {
+  std::vector<int> fields;
+  for (const auto& p : SplitToPos(s, delimiter)) {
+    fields.push_back(strtol(s.data() + p.first, nullptr, 10));
+  }
+  return fields;
+}
+
+template <>
+inline std::vector<float> Split(const string& s, const string& delimiter) {
+  std::vector<float> fields;
+  for (const auto& p : SplitToPos(s, delimiter)) {
+    fields.push_back(strtod(s.data() + p.first, nullptr));
+  }
+  return fields;
+}
+
+template <>
+inline std::vector<uint8_t> Split(const string& s, const string& delimiter) {
+  std::vector<uint8_t> fields;
+  for (const auto& p : SplitToPos(s, delimiter)) {
+    fields.push_back(strtol(s.data() + p.first, nullptr, 10));
+  }
+  return fields;
+}
+
+}  // namespace testing
+}  // namespace tflite
+
+#endif  // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_SPLIT_H_
diff --git a/tensorflow/contrib/lite/testing/split_test.cc b/tensorflow/contrib/lite/testing/split_test.cc
new file mode 100644
index 0000000000..3d1e25d9c7
--- /dev/null
+++ b/tensorflow/contrib/lite/testing/split_test.cc
@@ -0,0 +1,57 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ +#include "tensorflow/contrib/lite/testing/split.h" + +#include +#include + +namespace tflite { +namespace testing { +namespace { + +using ::testing::ElementsAre; +using ::testing::Pair; + +TEST(SplitTest, SplitToPos) { + EXPECT_THAT(SplitToPos("test;:1-2-3 ;: test", ";:"), + ElementsAre(Pair(0, 4), Pair(6, 12), Pair(14, 19))); + EXPECT_THAT(SplitToPos("test;:1-2-3 ;: test", ":"), + ElementsAre(Pair(0, 5), Pair(6, 13), Pair(14, 19))); + EXPECT_THAT(SplitToPos("test", ":"), ElementsAre(Pair(0, 4))); + EXPECT_THAT(SplitToPos("test ", ":"), ElementsAre(Pair(0, 5))); + EXPECT_THAT(SplitToPos("", ":"), ElementsAre()); + EXPECT_THAT(SplitToPos("test ", ""), ElementsAre(Pair(0, 5))); + EXPECT_THAT(SplitToPos("::::", ":"), ElementsAre()); +} + +TEST(SplitTest, SplitString) { + EXPECT_THAT(Split("A;B;C", ";"), ElementsAre("A", "B", "C")); +} + +TEST(SplitTest, SplitFloat) { + EXPECT_THAT(Split("1.0 B 1e-5", " "), ElementsAre(1.0, 0.0, 1e-5)); +} + +TEST(SplitTest, SplitInt) { + EXPECT_THAT(Split("1,-1,258", ","), ElementsAre(1, -1, 258)); +} + +TEST(SplitTest, SplitUint8) { + EXPECT_THAT(Split("1,-1,258", ","), ElementsAre(1, 255, 2)); +} + +} // namespace +} // namespace testing +} // namespace tflite diff --git a/tensorflow/contrib/lite/testing/test_runner.h b/tensorflow/contrib/lite/testing/test_runner.h new file mode 100644 index 0000000000..04ee4d9f7d --- /dev/null +++ b/tensorflow/contrib/lite/testing/test_runner.h @@ -0,0 +1,124 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_TEST_RUNNER_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_TEST_RUNNER_H_ + +#include +#include +#include +#include "tensorflow/contrib/lite/string.h" + +namespace tflite { +namespace testing { + +// This is the base class for processing test data. Each one of the virtual +// methods must be implemented to forward the data to the appropriate executor +// (e.g. TF Lite's interpreter, or the NNAPI). +class TestRunner { + public: + TestRunner() {} + virtual ~TestRunner() {} + + // Load the given model, as a path relative to SetModelBaseDir(). + virtual void LoadModel(const string& bin_file_path) = 0; + + // Return the list of input tensors in the loaded model. + virtual const std::vector& GetInputs() = 0; + + // Return the list of output tensors in the loaded model. + virtual const std::vector& GetOutputs() = 0; + + // Prepare for a run by resize the given tensor. The given 'id' is + // guaranteed to be one of the ids returned by GetInputs(). + virtual void ReshapeTensor(int id, const string& csv_values) = 0; + + // Reserve memory for all tensors. + virtual void AllocateTensors() = 0; + + // Set the given tensor to some initial state, usually zero. This is + // used to reset persistent buffers in a model. 
+  virtual void ResetTensor(int id) = 0;
+
+  // Define the contents of the given input tensor. The given 'id' is
+  // guaranteed to be one of the ids returned by GetInputs().
+  virtual void SetInput(int id, const string& csv_values) = 0;
+
+  // Define what should be expected for an output tensor after Invoke() runs.
+  // The given 'id' is guaranteed to be one of the ids returned by
+  // GetOutputs().
+  virtual void SetExpectation(int id, const string& csv_values) = 0;
+
+  // Run the model.
+  virtual void Invoke() = 0;
+
+  // Verify that the contents of all outputs conform to the existing
+  // expectations. Return true if there are no expectations or they are all
+  // satisfied.
+  virtual bool CheckResults() = 0;
+
+  // Set the base path for loading models.
+  void SetModelBaseDir(const string& path) {
+    model_base_dir_ = path;
+    if (path[path.length() - 1] != '/') {
+      model_base_dir_ += "/";
+    }
+  }
+
+  // Return the full path of a model.
+  string GetFullPath(const string& path) { return model_base_dir_ + path; }
+
+  // Give an id to the next invocation to make error reporting more meaningful.
+  void SetInvocationId(const string& id) { invocation_id_ = id; }
+  const string& GetInvocationId() const { return invocation_id_; }
+
+  // Invalidate the test runner, preventing it from executing any further.
+  void Invalidate(const string& error_message) {
+    error_message_ = error_message;
+  }
+  bool IsValid() const { return error_message_.empty(); }
+  const string& GetErrorMessage() const { return error_message_; }
+
+  // Handle the overall success of this test runner. This will be true if all
+  // invocations were successful.
+  void SetOverallSuccess(bool value) { overall_success_ = value; }
+  bool GetOverallSuccess() const { return overall_success_; }
+
+ protected:
+  // A helper to check whether the given number of values is consistent with
+  // the number of bytes in a tensor of type T. When incompatible sizes are
+  // found, the test runner is invalidated and false is returned.
+  template <typename T>
+  bool CheckSizes(size_t tensor_bytes, size_t num_values) {
+    size_t num_tensor_elements = tensor_bytes / sizeof(T);
+    if (num_tensor_elements != num_values) {
+      Invalidate("Expected '" + std::to_string(num_tensor_elements) +
+                 "' elements for a tensor, but only got '" +
+                 std::to_string(num_values) + "'");
+      return false;
+    }
+    return true;
+  }
+
+ private:
+  string model_base_dir_;
+  string invocation_id_;
+  bool overall_success_ = true;
+
+  string error_message_;
+};
+
+}  // namespace testing
+}  // namespace tflite
+#endif  // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_TEST_RUNNER_H_
diff --git a/tensorflow/contrib/lite/testing/test_runner_test.cc b/tensorflow/contrib/lite/testing/test_runner_test.cc
new file mode 100644
index 0000000000..f712a5347a
--- /dev/null
+++ b/tensorflow/contrib/lite/testing/test_runner_test.cc
@@ -0,0 +1,84 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ +#include "tensorflow/contrib/lite/testing/test_runner.h" + +#include +#include + +namespace tflite { +namespace testing { +namespace { + +class ConcreteTestRunner : public TestRunner { + public: + void LoadModel(const string& bin_file_path) override {} + const std::vector& GetInputs() override { return ids_; } + const std::vector& GetOutputs() override { return ids_; } + void ReshapeTensor(int id, const string& csv_values) override {} + void AllocateTensors() override {} + void ResetTensor(int id) override {} + void SetInput(int id, const string& csv_values) override {} + void SetExpectation(int id, const string& csv_values) override {} + void Invoke() override {} + bool CheckResults() override { return true; } + bool CheckFloatSizes(size_t bytes, size_t values) { + return CheckSizes(bytes, values); + } + + private: + std::vector ids_; +}; + +TEST(TestRunner, ModelPath) { + ConcreteTestRunner runner; + EXPECT_EQ(runner.GetFullPath("test.bin"), "test.bin"); + runner.SetModelBaseDir("/tmp"); + EXPECT_EQ(runner.GetFullPath("test.bin"), "/tmp/test.bin"); +} + +TEST(TestRunner, InvocationId) { + ConcreteTestRunner runner; + EXPECT_EQ(runner.GetInvocationId(), ""); + runner.SetInvocationId("X"); + EXPECT_EQ(runner.GetInvocationId(), "X"); +} + +TEST(TestRunner, Invalidation) { + ConcreteTestRunner runner; + EXPECT_TRUE(runner.IsValid()); + EXPECT_EQ(runner.GetErrorMessage(), ""); + runner.Invalidate("Some Error"); + EXPECT_FALSE(runner.IsValid()); + EXPECT_EQ(runner.GetErrorMessage(), "Some Error"); +} + +TEST(TestRunner, OverallSuccess) { + ConcreteTestRunner runner; + EXPECT_TRUE(runner.GetOverallSuccess()); + runner.SetOverallSuccess(false); + EXPECT_FALSE(runner.GetOverallSuccess()); +} + +TEST(TestRunner, CheckSizes) { + ConcreteTestRunner runner; + EXPECT_TRUE(runner.CheckFloatSizes(16, 4)); + EXPECT_FALSE(runner.CheckFloatSizes(16, 2)); + EXPECT_EQ(runner.GetErrorMessage(), + "Expected '4' elements for a tensor, but only got '2'"); +} + +} // namespace +} // namespace testing +} // namespace tflite diff --git a/tensorflow/contrib/lite/testing/tflite_driver.cc b/tensorflow/contrib/lite/testing/tflite_driver.cc new file mode 100644 index 0000000000..cf9df2ec26 --- /dev/null +++ b/tensorflow/contrib/lite/testing/tflite_driver.cc @@ -0,0 +1,208 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/testing/tflite_driver.h" + +#include + +#include "tensorflow/contrib/lite/testing/split.h" + +namespace tflite { +namespace testing { + +namespace { + +// Returns the value in the given position in a tensor. 
+template +T Value(const TfLitePtrUnion& data, int index); +template <> +float Value(const TfLitePtrUnion& data, int index) { + return data.f[index]; +} +template <> +uint8_t Value(const TfLitePtrUnion& data, int index) { + return data.uint8[index]; +} + +template +void SetTensorData(const std::vector& values, TfLitePtrUnion* data) { + T* input_ptr = reinterpret_cast(data->raw); + for (const T& v : values) { + *input_ptr = v; + ++input_ptr; + } +} + +} // namespace + +class TfLiteDriver::Expectation { + public: + Expectation() { data_.raw = nullptr; } + ~Expectation() { delete[] data_.raw; } + template + void SetData(const string& csv_values) { + const auto& values = testing::Split(csv_values, ","); + data_.raw = new char[values.size() * sizeof(T)]; + SetTensorData(values, &data_); + } + + bool Check(bool verbose, const TfLiteTensor& tensor) { + switch (tensor.type) { + case kTfLiteFloat32: + return TypedCheck(verbose, tensor); + case kTfLiteUInt8: + return TypedCheck(verbose, tensor); + default: + return false; + } + } + + private: + template + bool TypedCheck(bool verbose, const TfLiteTensor& tensor) { + int tensor_size = tensor.bytes / sizeof(T); + + bool good_output = true; + for (int i = 0; i < tensor_size; ++i) { + if (std::abs(Value(data_, i) - Value(tensor.data, i)) > 1e-5) { + good_output = false; + if (verbose) { + std::cerr << " index " << i << ": " << Value(data_, i) + << " != " << Value(tensor.data, i) << std::endl; + } + } + } + return good_output; + } + + TfLitePtrUnion data_; +}; + +TfLiteDriver::TfLiteDriver(bool use_nnapi) : use_nnapi_(use_nnapi) {} +TfLiteDriver::~TfLiteDriver() {} + +void TfLiteDriver::AllocateTensors() { + if (must_allocate_tensors_) { + if (interpreter_->AllocateTensors() != kTfLiteOk) { + std::cerr << "Failed to allocate tensors" << std::endl; + abort(); + } + must_allocate_tensors_ = false; + } +} + +void TfLiteDriver::LoadModel(const string& bin_file_path) { + if (!IsValid()) return; + std::cout << std::endl << "Loading model: " << bin_file_path << std::endl; + + model_ = FlatBufferModel::BuildFromFile(GetFullPath(bin_file_path).c_str()); + if (!model_) { + Invalidate("Failed to mmap model " + bin_file_path); + return; + } + ops::builtin::BuiltinOpResolver builtins; + InterpreterBuilder(*model_, builtins)(&interpreter_); + if (!interpreter_) { + Invalidate("Failed build interpreter"); + return; + } + + must_allocate_tensors_ = true; +} + +void TfLiteDriver::ResetTensor(int id) { + if (!IsValid()) return; + auto* tensor = interpreter_->tensor(id); + memset(tensor->data.raw, 0, tensor->bytes); +} + +void TfLiteDriver::ReshapeTensor(int id, const string& csv_values) { + if (!IsValid()) return; + if (interpreter_->ResizeInputTensor( + id, testing::Split(csv_values, ",")) != kTfLiteOk) { + Invalidate("Failed to resize input tensor " + std::to_string(id)); + return; + } + must_allocate_tensors_ = true; +} + +void TfLiteDriver::SetInput(int id, const string& csv_values) { + if (!IsValid()) return; + auto* tensor = interpreter_->tensor(id); + switch (tensor->type) { + case kTfLiteFloat32: { + const auto& values = testing::Split(csv_values, ","); + if (!CheckSizes(tensor->bytes, values.size())) return; + SetTensorData(values, &tensor->data); + break; + } + case kTfLiteUInt8: { + const auto& values = testing::Split(csv_values, ","); + if (!CheckSizes(tensor->bytes, values.size())) return; + SetTensorData(values, &tensor->data); + break; + } + default: + Invalidate("Unsupported tensor data type"); + return; + } +} + +void TfLiteDriver::SetExpectation(int id, 
const string& csv_values) { + if (!IsValid()) return; + auto* tensor = interpreter_->tensor(id); + expected_output_[id].reset(new Expectation); + switch (tensor->type) { + case kTfLiteFloat32: + expected_output_[id]->SetData(csv_values); + break; + case kTfLiteUInt8: + expected_output_[id]->SetData(csv_values); + break; + default: + Invalidate("Unsupported tensor data type"); + return; + } +} + +void TfLiteDriver::Invoke() { + if (!IsValid()) return; + if (interpreter_->Invoke() != kTfLiteOk) { + Invalidate("Failed to invoke interpreter"); + } +} + +bool TfLiteDriver::CheckResults() { + if (!IsValid()) return false; + bool success = true; + for (const auto& p : expected_output_) { + int id = p.first; + auto* tensor = interpreter_->tensor(id); + if (!p.second->Check(/*verbose=*/false, *tensor)) { + // Do not invalidate anything here. Instead, simply output the + // differences and return false. Invalidating would prevent all + // subsequent invocations from running.. + std::cerr << "There were errors in invocation '" << GetInvocationId() + << "', output tensor '" << id << "':" << std::endl; + p.second->Check(/*verbose=*/true, *tensor); + success = false; + SetOverallSuccess(false); + } + } + expected_output_.clear(); + return success; +} + +} // namespace testing +} // namespace tflite diff --git a/tensorflow/contrib/lite/testing/tflite_driver.h b/tensorflow/contrib/lite/testing/tflite_driver.h new file mode 100644 index 0000000000..4440d4285e --- /dev/null +++ b/tensorflow/contrib/lite/testing/tflite_driver.h @@ -0,0 +1,62 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_TFLITE_DRIVER_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_TFLITE_DRIVER_H_ + +#include + +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/contrib/lite/testing/test_runner.h" + +namespace tflite { +namespace testing { + +// A test runner that feeds inputs into TF Lite and verifies its outputs. 
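+//
+// A minimal usage sketch (illustrative only; the model name and values are
+// hypothetical, see tflite_driver_test.cc below for a real example):
+//   TfLiteDriver driver(/*use_nnapi=*/false);
+//   driver.SetModelBaseDir("some/model/dir");
+//   driver.LoadModel("add.bin");
+//   driver.AllocateTensors();
+//   driver.SetInput(driver.GetInputs()[0], "1.0,2.0");
+//   driver.SetExpectation(driver.GetOutputs()[0], "2.0,4.0");
+//   driver.Invoke();
+//   bool ok = driver.CheckResults();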
+class TfLiteDriver : public TestRunner { + public: + explicit TfLiteDriver(bool use_nnapi); + ~TfLiteDriver() override; + + void LoadModel(const string& bin_file_path) override; + const std::vector& GetInputs() override { + return interpreter_->inputs(); + } + const std::vector& GetOutputs() override { + return interpreter_->outputs(); + } + void ReshapeTensor(int id, const string& csv_values) override; + void AllocateTensors() override; + void ResetTensor(int id) override; + void SetInput(int id, const string& csv_values) override; + void SetExpectation(int id, const string& csv_values) override; + void Invoke() override; + bool CheckResults() override; + + private: + class Expectation; + + bool use_nnapi_ = false; + std::unique_ptr model_; + std::unique_ptr interpreter_; + std::map> expected_output_; + bool must_allocate_tensors_ = true; +}; + +} // namespace testing +} // namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_TFLITE_DRIVER_H_ diff --git a/tensorflow/contrib/lite/testing/tflite_driver_test.cc b/tensorflow/contrib/lite/testing/tflite_driver_test.cc new file mode 100644 index 0000000000..79e8a86972 --- /dev/null +++ b/tensorflow/contrib/lite/testing/tflite_driver_test.cc @@ -0,0 +1,61 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/testing/tflite_driver.h" + +#include +#include + +namespace tflite { +namespace testing { +namespace { + +using ::testing::ElementsAre; + +TEST(TfliteDriverTest, SimpleTest) { + std::unique_ptr runner(new TfLiteDriver(/*use_nnapi=*/false)); + + runner->SetModelBaseDir("third_party/tensorflow/contrib/lite"); + runner->LoadModel("testdata/multi_add.bin"); + ASSERT_TRUE(runner->IsValid()); + + ASSERT_THAT(runner->GetInputs(), ElementsAre(0, 1, 2, 3)); + ASSERT_THAT(runner->GetOutputs(), ElementsAre(5, 6)); + + for (int i : {0, 1, 2, 3}) { + runner->ReshapeTensor(i, "1,2,2,1"); + } + ASSERT_TRUE(runner->IsValid()); + + runner->AllocateTensors(); + + runner->SetInput(0, "0.1,0.2,0.3,0.4"); + runner->SetInput(1, "0.001,0.002,0.003,0.004"); + runner->SetInput(2, "0.001,0.002,0.003,0.004"); + runner->SetInput(3, "0.01,0.02,0.03,0.04"); + + runner->ResetTensor(2); + + runner->SetExpectation(5, "0.101,0.202,0.303,0.404"); + runner->SetExpectation(6, "0.011,0.022,0.033,0.044"); + + runner->Invoke(); + ASSERT_TRUE(runner->IsValid()); + + ASSERT_TRUE(runner->CheckResults()); +} + +} // namespace +} // namespace testing +} // namespace tflite diff --git a/tensorflow/contrib/lite/testing/tokenize.cc b/tensorflow/contrib/lite/testing/tokenize.cc new file mode 100644 index 0000000000..2e84ea475c --- /dev/null +++ b/tensorflow/contrib/lite/testing/tokenize.cc @@ -0,0 +1,95 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/testing/tokenize.h"
+#include <istream>
+#include <string>
+#include "tensorflow/contrib/lite/string.h"
+
+namespace tflite {
+namespace testing {
+
+void Tokenize(std::istream* input, TokenProcessor* processor) {
+  enum State { kBuildQuotedToken, kBuildToken, kIdle };
+
+  std::string current_token;
+  State state = kIdle;
+  auto start_token = [&](char c) {
+    state = kBuildToken;
+    current_token.clear();
+    current_token = c;
+  };
+  auto issue_token = [&]() {
+    state = kIdle;
+    processor->ConsumeToken(&current_token);
+    current_token.clear();
+  };
+  auto start_quoted_token = [&]() {
+    state = kBuildQuotedToken;
+    current_token.clear();
+  };
+  auto issue_quoted_token = [&]() {
+    state = kIdle;
+    processor->ConsumeToken(&current_token);
+    current_token.clear();
+  };
+  auto issue_delim = [&](char d) {
+    current_token = string(1, d);
+    processor->ConsumeToken(&current_token);
+    current_token.clear();
+  };
+  auto is_delim = [](char c) { return c == '{' || c == '}' || c == ':'; };
+  auto is_quote = [](char c) { return c == '"'; };
+
+  for (auto it = std::istreambuf_iterator<char>(*input);
+       it != std::istreambuf_iterator<char>(); ++it) {
+    switch (state) {
+      case kIdle:
+        if (is_delim(*it)) {
+          issue_delim(*it);
+        } else if (is_quote(*it)) {
+          start_quoted_token();
+        } else if (!isspace(*it)) {
+          start_token(*it);
+        }
+        break;
+      case kBuildToken:
+        if (is_delim(*it)) {
+          issue_token();
+          issue_delim(*it);
+        } else if (is_quote(*it)) {
+          issue_token();
+          start_quoted_token();
+        } else if (isspace(*it)) {
+          issue_token();
+        } else {
+          current_token += *it;
+        }
+        break;
+      case kBuildQuotedToken:
+        if (is_quote(*it)) {
+          issue_quoted_token();
+        } else {
+          current_token += *it;
+        }
+        break;
+    }
+  }
+  if (state != kIdle) {
+    issue_token();
+  }
+}
+
+}  // namespace testing
+}  // namespace tflite
diff --git a/tensorflow/contrib/lite/testing/tokenize.h b/tensorflow/contrib/lite/testing/tokenize.h
new file mode 100644
index 0000000000..daccf0e84a
--- /dev/null
+++ b/tensorflow/contrib/lite/testing/tokenize.h
@@ -0,0 +1,42 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_TOKENIZER_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_TOKENIZER_H_
+
+#include <istream>
+#include <string>
+
+namespace tflite {
+namespace testing {
+
+// Process tokens coming from Tokenize().
+class TokenProcessor {
+ public:
+  virtual ~TokenProcessor() {}
+  // Process a single token.
The token won't be reused, so it is OK to call + // token.swap(). + virtual void ConsumeToken(std::string* token) = 0; +}; + +// Tokenize a stream on whitespaces, colons and curly braces. Whitespaces are +// removed from the tokens and double-quotes can be used to avoid that. Note +// that there is no way to escape double-quotes, so there's no way to have a +// double-quote inside a token. +void Tokenize(std::istream* input, TokenProcessor* processor); + +} // namespace testing +} // namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_TOKENIZER_H_ diff --git a/tensorflow/contrib/lite/testing/tokenize_test.cc b/tensorflow/contrib/lite/testing/tokenize_test.cc new file mode 100644 index 0000000000..80f44aacca --- /dev/null +++ b/tensorflow/contrib/lite/testing/tokenize_test.cc @@ -0,0 +1,105 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/testing/tokenize.h" + +#include +#include + +namespace tflite { +namespace testing { +namespace { + +using ::testing::ElementsAre; +using ::testing::ElementsAreArray; + +class TokenCollector : public TokenProcessor { + public: + void ConsumeToken(std::string* token) override { tokens_.push_back(*token); } + const std::vector& Tokens() { return tokens_; } + + private: + std::vector tokens_; +}; + +std::vector TokenizeString(const std::string& s) { + std::stringstream ss(s); + TokenCollector collector; + Tokenize(&ss, &collector); + return collector.Tokens(); +} + +TEST(TokenizeTest, TokenDetection) { + EXPECT_THAT(TokenizeString("x :1"), ElementsAre("x", ":", "1")); + EXPECT_THAT(TokenizeString("x:1"), ElementsAre("x", ":", "1")); + EXPECT_THAT(TokenizeString("x {1"), ElementsAre("x", "{", "1")); + EXPECT_THAT(TokenizeString("x{1"), ElementsAre("x", "{", "1")); + EXPECT_THAT(TokenizeString("x }1"), ElementsAre("x", "}", "1")); + EXPECT_THAT(TokenizeString("x}1"), ElementsAre("x", "}", "1")); + EXPECT_THAT(TokenizeString("x \"1"), ElementsAre("x", "1")); + EXPECT_THAT(TokenizeString("x\"1"), ElementsAre("x", "1")); +} + +TEST(TokenizeTest, QuotedTokenDetection) { + EXPECT_THAT(TokenizeString("\"w:x{y}z\"1"), ElementsAre("w:x{y}z", "1")); + EXPECT_THAT(TokenizeString("\"w:x{y}z\"\"1\""), ElementsAre("w:x{y}z", "1")); +} + +TEST(TokenizeTest, Delimiters) { + EXPECT_THAT(TokenizeString("}"), ElementsAre("}")); + EXPECT_THAT(TokenizeString("}}"), ElementsAre("}", "}")); + EXPECT_THAT(TokenizeString("{"), ElementsAre("{")); + EXPECT_THAT(TokenizeString("{{"), ElementsAre("{", "{")); + EXPECT_THAT(TokenizeString(":"), ElementsAre(":")); + EXPECT_THAT(TokenizeString("::"), ElementsAre(":", ":")); +} + +TEST(TokenizeTest, CornerCases) { + EXPECT_THAT(TokenizeString(" i { b:a } "), + ElementsAre("i", "{", "b", ":", "a", "}")); + EXPECT_THAT(TokenizeString(" }"), ElementsAre("}")); + EXPECT_THAT(TokenizeString(" } "), ElementsAre("}")); + EXPECT_THAT(TokenizeString(" {} "), ElementsAre("{", 
"}")); + EXPECT_THAT(TokenizeString(" x{} y{} "), + ElementsAre("x", "{", "}", "y", "{", "}")); + EXPECT_THAT(TokenizeString("x:1 y:2 "), + ElementsAre("x", ":", "1", "y", ":", "2")); + EXPECT_THAT(TokenizeString("x:\"1\" y:2 "), + ElementsAre("x", ":", "1", "y", ":", "2")); + EXPECT_THAT(TokenizeString("x:\"1, 2\" y:\"\" "), + ElementsAre("x", ":", "1, 2", "y", ":", "")); +} + +TEST(TokenizeTest, NewLines) { + EXPECT_THAT(TokenizeString("x:\n1,\n 2 \n y :\n3 \n"), + ElementsAre("x", ":", "1,", "2", "y", ":", "3")); +} + +TEST(TokenizeTest, LongString) { + EXPECT_THAT( + TokenizeString(" i { b:a } input {" + "a: \"1e-1, 2,3\" b:\"1,2,3\"\n c{ " + "id:1 x{d{a:" + "1}}} f:2 " + "\n}\n t:1"), + ElementsAreArray({"i", "{", "b", ":", "a", "}", "input", "{", + "a", ":", "1e-1, 2,3", "b", ":", "1,2,3", "c", "{", + "id", ":", "1", "x", "{", "d", "{", "a", + ":", "1", "}", "}", "}", "f", ":", "2", + "}", "t", ":", "1"})); +} + +} // namespace +} // namespace testing +} // namespace tflite diff --git a/tensorflow/contrib/lite/toco/BUILD b/tensorflow/contrib/lite/toco/BUILD new file mode 100644 index 0000000000..1c73ab8f4a --- /dev/null +++ b/tensorflow/contrib/lite/toco/BUILD @@ -0,0 +1,350 @@ +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Apache 2.0 + +load( + "//tensorflow/core:platform/default/build_config.bzl", + "tf_proto_library_cc", + "tf_proto_library_py", +) +load( + "//tensorflow:tensorflow.bzl", + "tf_cc_binary", + "tf_cc_test", +) + +tf_proto_library_cc( + name = "toco_flags_proto", + srcs = ["toco_flags.proto"], + visibility = ["//visibility:public"], +) + +tf_proto_library_cc( + name = "model_flags_proto", + srcs = ["model_flags.proto"], + visibility = ["//visibility:public"], +) + +tf_proto_library_py( + name = "toco_flags_proto", + srcs = [ + "toco_flags.proto", + ], + visibility = ["//visibility:public"], +) + +tf_proto_library_py( + name = "model_flags_proto", + srcs = [ + "model_flags.proto", + ], + visibility = ["//visibility:public"], +) + +cc_library( + name = "tensorflow_core_cc_protos_all", + deps = ["//tensorflow/core:protos_all_cc"], +) + +cc_library( + name = "runtime", + hdrs = [ + "runtime/common.h", + "runtime/types.h", + ], + linkstatic = 1, + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/contrib/lite/kernels/internal:reference_base", + "//tensorflow/contrib/lite/kernels/internal:types", + ], +) + +# :model offers the core data structures representing a model (a.k.a. "graph") +# for tooling purposes (not needed at inference runtime). +# That includes the top-level Model structure, and the lower-level Operator, +# Array, Buffer structures, etc. 
+cc_library( + name = "model", + hdrs = [ + "model.h", + ], + visibility = ["//visibility:public"], + deps = [ + ":model_flags_proto_cc", + ":runtime", + ":toco_port", + "//tensorflow/core:lib", + "@com_google_absl//absl/base:core_headers", + "@protobuf_archive//:protobuf_headers", + ], +) + +cc_library( + name = "toco_graphviz_dump_options", + srcs = [ + "toco_graphviz_dump_options.cc", + ], + hdrs = [ + "toco_graphviz_dump_options.h", + ], + visibility = ["//visibility:public"], + deps = [ + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "toco_cmdline_flags", + srcs = [ + "toco_cmdline_flags.cc", + ], + hdrs = [ + "toco_cmdline_flags.h", + ], + visibility = ["//visibility:public"], + deps = [ + ":model_cmdline_flags", + ":toco_flags_proto_cc", + ":toco_port", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "model_cmdline_flags", + srcs = [ + "model_cmdline_flags.cc", + ], + hdrs = [ + "args.h", + "model_cmdline_flags.h", + ], + visibility = ["//visibility:public"], + deps = [ + ":model_flags_proto_cc", + ":toco_graphviz_dump_options", + ":toco_port", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "toco_port", + srcs = [ + "toco_port.cc", + ], + hdrs = [ + "format_port.h", + "toco_port.h", + "toco_types.h", + ], + deps = [ + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + ] + select({ + "//tensorflow:android": [], + "//tensorflow:darwin": [], + "//tensorflow:ios": [], + "//conditions:default": [], + "//tensorflow:dummy_disabled_internal": [], + }), +) + +cc_library( + name = "graph_transformations", + srcs = [ + "graph_transformations/convert_pure_conv_to_depthwise.cc", + "graph_transformations/create_im2col_arrays.cc", + "graph_transformations/dequantize.cc", + "graph_transformations/drop_fake_quant.cc", + "graph_transformations/drop_im2col_arrays.cc", + "graph_transformations/ensure_bias_vectors.cc", + "graph_transformations/fuse_activation_functions.cc", + "graph_transformations/fuse_binary_into_following_affine.cc", + "graph_transformations/fuse_binary_into_preceding_affine.cc", + "graph_transformations/graph_transformations.cc", + "graph_transformations/hardcode_min_max.cc", + "graph_transformations/identify_l2_normalization.cc", + "graph_transformations/identify_l2_pool.cc", + "graph_transformations/identify_lstm.cc", + "graph_transformations/identify_relu1.cc", + "graph_transformations/make_initial_dequantize_operator.cc", + "graph_transformations/propagate_array_data_types.cc", + "graph_transformations/propagate_fixed_sizes.cc", + "graph_transformations/quantize.cc", + "graph_transformations/read_fake_quant_min_max.cc", + "graph_transformations/remove_final_dequantize_op.cc", + "graph_transformations/remove_tensorflow_assert.cc", + "graph_transformations/remove_tensorflow_identity.cc", + "graph_transformations/remove_trivial_binary.cc", + "graph_transformations/remove_trivial_concatenation.cc", + "graph_transformations/remove_trivial_concatenation_input.cc", + "graph_transformations/remove_trivial_passthrough.cc", + "graph_transformations/remove_trivial_passthrough.h", + "graph_transformations/remove_trivial_quantized_activation_func.cc", + "graph_transformations/remove_trivial_reshape.cc", + "graph_transformations/remove_unused_op.cc", + "graph_transformations/resolve_batch_normalization.cc", + "graph_transformations/resolve_constant_binary.cc", + 
"graph_transformations/resolve_constant_concatenation.cc", + "graph_transformations/resolve_constant_fake_quant.cc", + "graph_transformations/resolve_constant_tensorflow_shape.cc", + "graph_transformations/resolve_constant_unary.cc", + "graph_transformations/resolve_mean_attributes.cc", + "graph_transformations/resolve_pad_attributes.cc", + "graph_transformations/resolve_reorder_axes.cc", + "graph_transformations/resolve_reshape_attributes.cc", + "graph_transformations/resolve_slice_attributes.cc", + "graph_transformations/resolve_strided_slice_attributes.cc", + "graph_transformations/resolve_tensorflow_concat.cc", + "graph_transformations/resolve_tensorflow_matmul.cc", + "graph_transformations/resolve_tensorflow_merge.cc", + "graph_transformations/resolve_tensorflow_squeeze.cc", + "graph_transformations/resolve_tensorflow_switch.cc", + "graph_transformations/resolve_tensorflow_tile.cc", + "graph_transformations/unfuse_activation_functions.cc", + ], + hdrs = [ + "graph_transformations/graph_transformations.h", + ], + visibility = ["//visibility:public"], + deps = [ + ":model", + ":model_flags_proto_cc", + ":runtime", + ":toco_port", + ":tooling_util", + "//tensorflow/core:lib", + "@com_google_absl//absl/strings", + ], +) + +# :toco_tooling is the library providing the offline tooling functionality +# exposed by the :toco command-line tool. +cc_library( + name = "toco_tooling", + srcs = [ + "allocate_transient_arrays.cc", + "export_tensorflow.cc", + "import_tensorflow.cc", + "tensorflow_util.cc", + "toco_tooling.cc", + ], + hdrs = [ + "allocate_transient_arrays.h", + "export_tensorflow.h", + "import_tensorflow.h", + "tensorflow_util.h", + "toco_tooling.h", + ], + copts = select({ + "//tensorflow:darwin": ["-DTOCO_SUPPORT_PORTABLE_PROTOS=0"], + "//conditions:default": [], + }), + visibility = ["//visibility:public"], + deps = [ + ":graph_transformations", + ":model_flags_proto_cc", + ":toco_flags_proto_cc", + ":model", + ":runtime", + "//tensorflow/core:protos_all_cc", + ":toco_port", + ":tooling_util", + ":toco_graphviz_dump_options", + "@protobuf_archive//:protobuf_headers", + "//tensorflow/contrib/lite/toco/tflite:export", + "//tensorflow/contrib/lite/toco/tflite:import", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/memory", + "//tensorflow/core:lib", + "//tensorflow/contrib/lite/toco/tensorflow_graph_matching:resolve_cluster", + "//tensorflow/core:core_cpu_lib", + "//tensorflow/core:framework", + ] + select({ + # Placeholder for internal darwin rule. + "//conditions:default": [], + }), +) + +cc_library( + name = "tooling_util", + srcs = [ + "dump_graphviz.cc", + "tooling_util.cc", + ], + hdrs = [ + "dump_graphviz.h", + "tooling_util.h", + ], + visibility = ["//visibility:public"], + deps = [ + ":model", + ":model_flags_proto_cc", + ":runtime", + ":toco_flags_proto_cc", + ":toco_graphviz_dump_options", + ":toco_port", + "//tensorflow/core:lib", + "@com_google_absl//absl/strings", + "@protobuf_archive//:protobuf_headers", + ], +) + +tf_cc_test( + name = "tooling_util_test", + srcs = ["tooling_util_test.cc"], + deps = [ + ":model", + ":tooling_util", + "@com_google_googletest//:gtest_main", + ], +) + +# :toco is the main public command-line tool exposing the functionality +# of the :toco_tooling library. 
+tf_cc_binary( + name = "toco", + srcs = ["toco.cc"], + visibility = ["//visibility:public"], + deps = [ + ":model", + ":model_cmdline_flags", + ":model_flags_proto_cc", + ":toco_cmdline_flags", + ":toco_flags_proto_cc", + ":toco_port", + ":toco_tooling", + "//tensorflow/core:lib", + "@com_google_absl//absl/strings", + ], +) + +tf_cc_test( + name = "toco_port_test", + srcs = ["toco_port_test.cc"], + data = [ + "toco_port_test.cc", + ], + deps = [ + ":toco_port", + "@com_google_googletest//:gtest_main", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/contrib/lite/toco/allocate_transient_arrays.cc b/tensorflow/contrib/lite/toco/allocate_transient_arrays.cc new file mode 100644 index 0000000000..2f4454d7c8 --- /dev/null +++ b/tensorflow/contrib/lite/toco/allocate_transient_arrays.cc @@ -0,0 +1,318 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/allocate_transient_arrays.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/model_flags.pb.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { +namespace { + +// The life span of an array. +struct ArrayLifespan { + // If true, the array is persistent state (as in a RNN). In that case, + // its allocation is permanent and the first_op, last_op members are + // unused. (The term 'transient' is a misnomer and we should think in + // terms of 'workspace' instead). + bool persistent = false; + // Index of the first op addressing that array. The array must be allocated + // just before executing this op. + std::size_t first_op = 0; + // Index of the last op addressing that array. We want to deallocate the array + // immediately after executing this op. + std::size_t last_op = 0; +}; + +bool StartsAt(const ArrayLifespan& lifespan, std::size_t op_index) { + return !lifespan.persistent && lifespan.first_op == op_index; +} + +bool EndsAt(const ArrayLifespan& lifespan, std::size_t op_index) { + return !lifespan.persistent && lifespan.last_op == op_index; +} + +// Helper function for ComputeArrayLifespans: updates one ArrayLifespan for +// one array for one op. 
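+// For example, if a non-persistent array is addressed by the ops at indices 2
+// and 5, repeated calls leave its lifespan at first_op == 2 and last_op == 5,
+// so the array is allocated just before op 2 and deallocated just after op 5.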
+void UpdateArrayLifespan( + const string& array_name, std::size_t op_index, + std::unordered_map* array_lifespans) { + if (array_lifespans->count(array_name)) { + auto& lifespan = array_lifespans->at(array_name); + if (!lifespan.persistent) { + lifespan.first_op = std::min(lifespan.first_op, op_index); + lifespan.last_op = std::max(lifespan.last_op, op_index); + } + } else { + ArrayLifespan lifespan; + lifespan.first_op = op_index; + lifespan.last_op = op_index; + (*array_lifespans)[array_name] = lifespan; + } +} + +// Computes the ArrayLifespan for each array. +void ComputeArrayLifespans( + const Model& model, + std::unordered_map* array_lifespans) { + CHECK(array_lifespans->empty()); + for (const auto& rnn_state : model.flags.rnn_states()) { + ArrayLifespan lifespan; + lifespan.persistent = true; + (*array_lifespans)[rnn_state.state_array()] = lifespan; + } + for (std::size_t op_index = 0; op_index < model.operators.size(); + op_index++) { + const auto& op = model.operators[op_index]; + for (const auto& input : op->inputs) { + UpdateArrayLifespan(input, op_index, array_lifespans); + } + for (const auto& output : op->outputs) { + UpdateArrayLifespan(output, op_index, array_lifespans); + } + } +} + +inline bool operator==(const Alloc& a, const Alloc& b) { + CHECK(a.start != b.start || a.end == b.end); + return a.start == b.start; +} + +// Helper to keep track of total allocation size and of currently live +// allocations, and containing the core allocation routine. +class Allocator { + public: + Allocator() : total_size_(0) {} + + // Core allocation routine. + void Allocate(std::size_t size, Alloc* result) { + // Naive algorithm: pick the first gap between live allocations, + // that is wide enough for the new array. + std::size_t pos = 0; + for (const auto& a : live_allocs_) { + if (a.start >= pos + size) { + result->start = pos; + result->end = pos + size; + live_allocs_.insert(*result); + return; + } + pos = a.end; + } + // No sufficiently wide gap was found before an existing live allocation, + // so we allocate the new array at the end of the allocation space. + // We may then have to grow total_size_. + total_size_ = std::max(total_size_, pos + size); + result->start = pos; + result->end = pos + size; + live_allocs_.insert(*result); + } + + void Deallocate(const Alloc& a) { + auto iter = std::lower_bound(live_allocs_.begin(), live_allocs_.end(), a); + CHECK(iter != live_allocs_.end()); + CHECK(*iter == a); + live_allocs_.erase(iter); + } + + std::size_t total_size() const { return total_size_; } + + private: + std::size_t total_size_; + std::set live_allocs_; +}; + +// Returns the required transient allocation size (in bytes) for a given array, +// or 0 if it's not a transient array. +std::size_t TransientArraySize(const Model& model, const string& array_name, + std::size_t transient_data_alignment) { + if (!IsAllocatableTransientArray(model, array_name)) { + return 0; + } + const auto& array = model.arrays.at(array_name); + CHECK(array->has_shape()) + << "Array '" << array_name << "' doesn't have a shape"; + if (array->data_type == ArrayDataType::kNone) { + // Catch a typical issue at the moment with RNN states + for (const auto& rnn_state : model.flags.rnn_states()) { + if (rnn_state.state_array() == array_name) { + LOG(FATAL) + << "A RNN state array, " << array_name << ", still does not " + << "have a known data type after all graph transformations have " + << "run. That's mostly a toco bug --- sorry. 
For now, you can " + << "work around this issue by adding manually_create:true in the " + << "--rnn_state description of this RNN state."; + } + } + LOG(FATAL) << "An array, " << array_name << ", still does not " + << "have a known data type after all graph transformations have " + << "run."; + } + const std::size_t elem_size = ElementSize(array->data_type); + const std::size_t raw_size = + elem_size * RequiredBufferSizeForShape(array->shape()); + const std::size_t rounded_size = + RoundUpToNextMultipleOf(raw_size, transient_data_alignment); + return rounded_size; +} + +// Allocates an array: call this for every array just before the first +// op where it is used. +void AllocateTransientArray(const Model& model, const string& array_name, + Allocator* allocator, + std::size_t transient_data_alignment) { + if (!IsAllocatableTransientArray(model, array_name)) { + return; + } + const std::size_t size = + TransientArraySize(model, array_name, transient_data_alignment); + const auto& array = model.arrays.at(array_name); + CHECK(!array->alloc); + allocator->Allocate(size, &array->GetOrCreateAlloc()); +} + +// Deallocates an array: call this for every array just after the last +// op where it is used. +void DeallocateTransientArray(const Model& model, const string& array_name, + Allocator* allocator) { + if (!IsAllocatableTransientArray(model, array_name)) { + return; + } + const auto& array = model.arrays.at(array_name); + CHECK(!!array->alloc); + allocator->Deallocate(*array->alloc); +} + +} // namespace + +void AllocateTransientArrays(Model* model, + std::size_t transient_data_alignment) { + // Precompute the lifespans for all arrays. + std::unordered_map array_lifespans; + ComputeArrayLifespans(*model, &array_lifespans); + + // In case of variable batch, our convention will be to compute the + // allocations for batch==1, then let the inference code multiply all + // the offsets by the actual runtime batch size. Conveniently, + // the variable_batch and batch flags are mutually exclusive, and the default + // value of batch is 1, so we have nothing special to do here. Let us + // just guard this assumption with a CHECK: + bool batchless_input_shapes = true; + for (const auto& input_array : model->flags.input_arrays()) { + if (input_array.shape().empty() || input_array.shape(0) != 1) { + batchless_input_shapes = false; + break; + } + } + CHECK(!model->flags.variable_batch() || batchless_input_shapes); + + Allocator allocator; + + // Construct a sorted map of array names, so that other layout engines can + // match exactly. + std::map ordered_arrays_map; + for (const auto& pair : model->arrays) { + ordered_arrays_map[pair.first] = pair.second.get(); + } + + // Allocate persistent arrays (like RNN states). For them, 'transient' + // is a misnormer, should read 'workspace'. + for (const auto& array_pair : ordered_arrays_map) { + const string& array_name = array_pair.first; + const auto& array_lifespan = array_lifespans.find(array_name)->second; + if (array_lifespan.persistent) { + AllocateTransientArray(*model, array_name, &allocator, + transient_data_alignment); + } + } + + for (std::size_t op_index = 0; op_index < model->operators.size(); + op_index++) { + const auto& op = model->operators[op_index]; + // Allocate those arrays whose lifespan starts exactly here. 
+ for (const auto& input : op->inputs) { + if (StartsAt(array_lifespans[input], op_index)) { + AllocateTransientArray(*model, input, &allocator, + transient_data_alignment); + } + } + for (const auto& output : op->outputs) { + if (StartsAt(array_lifespans[output], op_index)) { + AllocateTransientArray(*model, output, &allocator, + transient_data_alignment); + } + } + // Deallocate those arrays whose lifespan ends exactly here. + for (const auto& input : op->inputs) { + if (EndsAt(array_lifespans[input], op_index)) { + DeallocateTransientArray(*model, input, &allocator); + } + } + for (const auto& output : op->outputs) { + if (EndsAt(array_lifespans[output], op_index)) { + DeallocateTransientArray(*model, output, &allocator); + } + } + } + + // Just out of curiosity (not used in the actual allocation process) + // evaluate the optimal total allocated size. + // First, compute the size of persistent arrays. + std::size_t optimal_transient_alloc_size = 0; + std::size_t persistent_alloc_size = 0; + for (const auto& array_pair : ordered_arrays_map) { + const string& array_name = array_pair.first; + const auto& array_lifespan = array_lifespans.find(array_name)->second; + if (array_lifespan.persistent) { + persistent_alloc_size += + TransientArraySize(*model, array_name, transient_data_alignment); + } + } + for (const auto& op : model->operators) { + // for each operator, compute the sum of the sizes of the array that must + // be live during the execution of this operator, plus the size of + // persistent arrays that must be live at all times. + std::size_t size = persistent_alloc_size; + for (const auto& input : op->inputs) { + if (!array_lifespans[input].persistent) { + size += TransientArraySize(*model, input, transient_data_alignment); + } + } + for (const auto& output : op->outputs) { + if (!array_lifespans[output].persistent) { + size += TransientArraySize(*model, output, transient_data_alignment); + } + } + // The optimal total size is the maximum of all operator-specific sizes. + optimal_transient_alloc_size = std::max(optimal_transient_alloc_size, size); + } + + model->transient_data_size = allocator.total_size(); + model->transient_data_alignment = transient_data_alignment; + CHECK_GE(model->transient_data_size, optimal_transient_alloc_size); + LOG(INFO) << "Total transient array allocated size: " + << model->transient_data_size << " bytes, " + << "theoretical optimal value: " << optimal_transient_alloc_size + << " bytes."; + CheckInvariants(*model); +} +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/allocate_transient_arrays.h b/tensorflow/contrib/lite/toco/allocate_transient_arrays.h new file mode 100644 index 0000000000..12d0d0498f --- /dev/null +++ b/tensorflow/contrib/lite/toco/allocate_transient_arrays.h @@ -0,0 +1,44 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_ALLOCATE_TRANSIENT_ARRAYS_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_ALLOCATE_TRANSIENT_ARRAYS_H_
+
+#include "tensorflow/contrib/lite/toco/model.h"
+
+namespace toco {
+
+// We align the allocated sizes to the next multiple of a cache line,
+// to get simple performance characteristics without side effects of
+// accesses to one buffer on accesses to another buffer.
+// That also takes care of data type alignment for any reasonable type
+// (no reasonable data type should have alignment greater than a cache line).
+// Here we make CPU-centric assumptions, in particular, we assume 64-byte cache
+// lines. Getting this wrong by a factor of 2x (if this ever changes) wouldn't
+// be terrible.
+// Embedded architectures may use a different value for alignment.
+constexpr std::size_t kDefaultTransientDataAlignment = 64;
+
+// Rounds up dividend to a value divisible by divisor.
+inline std::size_t RoundUpToNextMultipleOf(std::size_t dividend,
+                                           std::size_t divisor) {
+  return ((dividend + divisor - 1) / divisor) * divisor;
+}
+
+void AllocateTransientArrays(Model* model,
+                             std::size_t transient_data_alignment);
+
+}  // namespace toco
+
+#endif  // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_ALLOCATE_TRANSIENT_ARRAYS_H_
diff --git a/tensorflow/contrib/lite/toco/args.h b/tensorflow/contrib/lite/toco/args.h
new file mode 100644
index 0000000000..28661d4ff0
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/args.h
@@ -0,0 +1,225 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// This abstracts command line arguments in toco.
+// Arg is a parseable type that can register a default value, be able to
+// parse itself, and keep track of whether it was specified.
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_ARGS_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_ARGS_H_
+
+#include <functional>
+#include <unordered_map>
+#include <vector>
+#include "absl/strings/numbers.h"
+#include "absl/strings/str_split.h"
+#include "tensorflow/contrib/lite/toco/toco_port.h"
+#include "tensorflow/contrib/lite/toco/toco_types.h"
+
+namespace toco {
+
+// Since std::vector is in the std namespace, and we are not allowed
+// to add ParseFlag/UnparseFlag to std, we introduce a simple wrapper type
+// to use as the flag type:
+struct IntList {
+  std::vector<int32> elements;
+};
+struct StringMapList {
+  std::vector<std::unordered_map<string, string>> elements;
+};
+
+// command_line_flags.h doesn't track whether or not a flag is specified. Arg
+// contains the value (which will be default if not specified) and also
+// whether the flag is specified.
+// TODO(aselle): consider putting doc string and ability to construct the
+// tensorflow argument into this, so declaration of parameters can be less
+// distributed.
+// Every template specialization of Arg is required to implement
+// default_value(), specified(), value(), parse(), bind().
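+//
+// Illustrative usage of Arg, assuming only the members declared below
+// (a sketch, not part of this header's interface):
+//
+//   Arg<bool> verbose(false);
+//   std::function<bool(bool)> hook = verbose.bind();
+//   hook(true);  // this is what the flag parser would call
+//   // Now verbose.specified() is true and verbose.value() is true.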
+template <class T>
+class Arg final {
+ public:
+  explicit Arg(T default_ = T()) : value_(default_) {}
+  virtual ~Arg() {}
+
+  // Provide default_value() to arg list
+  T default_value() const { return value_; }
+  // Return true if the command line argument was specified on the command line.
+  bool specified() const { return specified_; }
+  // Const reference to parsed value.
+  const T& value() const { return value_; }
+
+  // Parsing callback for the tensorflow::Flags code
+  bool parse(T value_in) {
+    value_ = value_in;
+    specified_ = true;
+    return true;
+  }
+
+  // Bind the parse member function so tensorflow::Flags can call it.
+  std::function<bool(T)> bind() {
+    return std::bind(&Arg::parse, this, std::placeholders::_1);
+  }
+
+ private:
+  // Becomes true after parsing if the value was specified
+  bool specified_ = false;
+  // Value of the argument (initialized to the default in the constructor).
+  T value_;
+};
+
+template <>
+class Arg<toco::IntList> final {
+ public:
+  // Provide default_value() to arg list
+  string default_value() const { return ""; }
+  // Return true if the command line argument was specified on the command line.
+  bool specified() const { return specified_; }
+  // Bind the parse member function so tensorflow::Flags can call it.
+  bool parse(string text) {
+    parsed_value_.elements.clear();
+    specified_ = true;
+    // strings::Split("") produces {""}, but we need {} on empty input.
+    // TODO(aselle): Moved this from elsewhere, but ahentz recommends we could
+    // use absl::SplitLeadingDec32Values(text.c_str(), &parsed_values_.elements)
+    if (!text.empty()) {
+      int32 element;
+      for (absl::string_view part : absl::StrSplit(text, ',')) {
+        if (!SimpleAtoi(part, &element)) return false;
+        parsed_value_.elements.push_back(element);
+      }
+    }
+    return true;
+  }
+
+  std::function<bool(string)> bind() {
+    return std::bind(&Arg::parse, this, std::placeholders::_1);
+  }
+
+  const toco::IntList& value() const { return parsed_value_; }
+
+ private:
+  toco::IntList parsed_value_;
+  bool specified_ = false;
+};
+
+template <>
+class Arg<toco::StringMapList> final {
+ public:
+  // Provide default_value() to StringMapList
+  string default_value() const { return ""; }
+  // Return true if the command line argument was specified on the command line.
+  bool specified() const { return specified_; }
+  // Bind the parse member function so tensorflow::Flags can call it.
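+  // The expected flag text is roughly of the form
+  //   {key1:value1,key2:value2},{key1:value3}
+  // i.e. a comma-separated list of {}-delimited groups of key:value pairs,
+  // which parse() below turns into one unordered_map per group.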
+
+  bool parse(string text) {
+    parsed_value_.elements.clear();
+    specified_ = true;
+
+    if (text.empty()) {
+      return true;
+    }
+
+#if defined(PLATFORM_GOOGLE)
+    std::vector<absl::string_view> outer_vector;
+    absl::string_view text_disposable_copy = text;
+    SplitStructuredLine(text_disposable_copy, ',', "{}", &outer_vector);
+    for (const absl::string_view& outer_member_stringpiece : outer_vector) {
+      string outer_member(outer_member_stringpiece);
+      if (outer_member.empty()) {
+        continue;
+      }
+      string outer_member_copy = outer_member;
+      absl::StripAsciiWhitespace(&outer_member);
+      if (!TryStripPrefixString(outer_member, "{", &outer_member)) return false;
+      if (!TryStripSuffixString(outer_member, "}", &outer_member)) return false;
+      const std::vector<string> inner_fields_vector =
+          strings::Split(outer_member, ',');
+
+      std::unordered_map<string, string> element;
+      for (const string& member_field : inner_fields_vector) {
+        std::vector<string> outer_member_key_value =
+            strings::Split(member_field, ':');
+        if (outer_member_key_value.size() != 2) return false;
+        string& key = outer_member_key_value[0];
+        string& value = outer_member_key_value[1];
+        absl::StripAsciiWhitespace(&key);
+        absl::StripAsciiWhitespace(&value);
+        if (element.count(key) != 0) return false;
+        element[key] = value;
+      }
+      parsed_value_.elements.push_back(element);
+    }
+    return true;
+#else
+    // TODO(aselle): Fix argument parsing when absl supports structuredline
+    fprintf(stderr, "%s:%d StringMapList arguments not supported\n", __FILE__,
+            __LINE__);
+    abort();
+#endif
+  }
+
+  std::function<bool(string)> bind() {
+    return std::bind(&Arg::parse, this, std::placeholders::_1);
+  }
+
+  const toco::StringMapList& value() const { return parsed_value_; }
+
+ private:
+  toco::StringMapList parsed_value_;
+  bool specified_ = false;
+};
+
+// Flags that describe a model. See model_cmdline_flags.cc for details.
+struct ParsedModelFlags {
+  Arg<string> input_array;
+  Arg<string> input_arrays;
+  Arg<string> output_array;
+  Arg<string> output_arrays;
+  Arg<string> input_shapes;
+  Arg<float> mean_value = Arg<float>(0.f);
+  Arg<string> mean_values;
+  Arg<float> std_value = Arg<float>(1.f);
+  Arg<string> std_values;
+  Arg<bool> variable_batch = Arg<bool>(false);
+  Arg<bool> drop_control_dependency = Arg<bool>(false);
+  Arg<toco::IntList> input_shape;
+  Arg<toco::StringMapList> rnn_states;
+  Arg<toco::StringMapList> model_checks;
+  // Debugging output options
+  Arg<string> graphviz_first_array;
+  Arg<string> graphviz_last_array;
+  Arg<string> dump_graphviz;
+  Arg<bool> dump_graphviz_video = Arg<bool>(false);
+};
+
+// Flags that describe the operation you would like to do (what conversion
+// you want). See toco_cmdline_flags.cc for details.
+struct ParsedTocoFlags {
+  Arg<string> input_file;
+  Arg<string> output_file;
+  Arg<string> input_format;
+  Arg<string> output_format;
+  // TODO(aselle): command_line_flags doesn't support doubles
+  Arg<float> default_ranges_min = Arg<float>(0.);
+  Arg<float> default_ranges_max = Arg<float>(0.);
+  Arg<string> input_type;
+  Arg<string> input_types;
+  Arg<string> inference_type;
+  Arg<bool> drop_fake_quant = Arg<bool>(false);
+  Arg<bool> reorder_across_fake_quant = Arg<bool>(false);
+  Arg<bool> allow_custom_ops = Arg<bool>(false);
+};
+
+}  // namespace toco
+#endif  // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_ARGS_H_
diff --git a/tensorflow/contrib/lite/toco/dump_graphviz.cc b/tensorflow/contrib/lite/toco/dump_graphviz.cc
new file mode 100644
index 0000000000..f5e2868dc0
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/dump_graphviz.cc
@@ -0,0 +1,293 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/toco/dump_graphviz.h" + +#include +#include +#include +#include + +#include "absl/strings/str_replace.h" +#include "tensorflow/contrib/lite/toco/model_flags.pb.h" +#include "tensorflow/contrib/lite/toco/toco_graphviz_dump_options.h" +#include "tensorflow/contrib/lite/toco/toco_port.h" +#include "tensorflow/contrib/lite/toco/toco_types.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +using toco::port::AppendF; +using toco::port::StringF; + +namespace toco { +namespace { + +class Color { + public: + Color() {} + Color(uint8 r, uint8 g, uint8 b) : r_(r), g_(g), b_(b) {} + // Returns the string serialization of this color in graphviz format, + // for use as 'fillcolor' in boxes. + string FillColorString() const { return StringF("%.2X%.2X%.2X", r_, g_, b_); } + // Returns the serialization in graphviz format of a suitable color to use + // 'fontcolor' in the same boxes. It should black or white, whichever offers + // the better contrast from FillColorString(). + string TextColorString() const { + // https://en.wikipedia.org/wiki/Relative_luminance + const float luminance = 0.2126f * r_ + 0.7152f * g_ + 0.0722f * b_; + const uint8 l = luminance > 128.f ? 0 : 255; + return StringF("%.2X%.2X%.2X", l, l, l); + } + + private: + uint8 r_ = 0, g_ = 0, b_ = 0; +}; + +struct NodeProperties { + // The text to display inside the box for this node. + string label; + // The color to use for this node; will be used as 'fillcolor' + // for its box. See Color::FillColorString. A suitable, different + // color will be chosen for the 'fontcolor' for the inside text + // label, see Color::TextColorString. + Color color; +}; + +// All colors in this file are from: +// https://material.io/guidelines/style/color.html + +Color GetColorForArray(const Model& model, const string& array_name) { + // Arrays involved in RNN back-edges have a different color + for (const auto& rnn_state : model.flags.rnn_states()) { + // RNN state, fed by a back-edge. Bold color. + if (array_name == rnn_state.state_array()) { + return Color(0x0F, 0x9D, 0x58); + } + // RNN back-edge source, feeding a RNN state. + // Light tone of the same color as RNN states. + if (array_name == rnn_state.back_edge_source_array()) { + return Color(0xB7, 0xE1, 0xCD); + } + } + // Constant parameter arrays have their own bold color + if (model.GetArray(array_name).buffer) { + return Color(0x42, 0x85, 0xF4); + } + // Remaining arrays are activations. + // We use gray colors for them because they are the majority + // of arrays so we want to highlight other arrays instead of them. 
+ // First, we use a bolder gray for input/output arrays: + const auto& dump_options = *GraphVizDumpOptions::singleton(); + if (IsInputArray(model, array_name) || + array_name == dump_options.graphviz_first_array || + array_name == dump_options.graphviz_last_array) { + return Color(0x9E, 0x9E, 0x9E); + } + for (const string& output_array : model.flags.output_arrays()) { + if (array_name == output_array) { + return Color(0x9E, 0x9E, 0x9E); + } + } + // Remaining arrays are intermediate activation arrays. + // Lighter tone of the same grey as for input/output arrays: + // We want these to be very discrete. + return Color(0xF5, 0xF5, 0xF5); +} + +NodeProperties GetPropertiesForArray(const Model& model, + const string& array_name) { + NodeProperties node_properties; + node_properties.color = GetColorForArray(model, array_name); + node_properties.label = absl::StrReplaceAll(array_name, {{"/", "/\\n"}}); + + // Append array shape to the label. + auto& array = model.GetArray(array_name); + + if (array.data_type == ArrayDataType::kFloat) { + AppendF(&node_properties.label, "\\nType: float"); + } else if (array.data_type == ArrayDataType::kInt32) { + AppendF(&node_properties.label, "\\nType: int32"); + } else if (array.data_type == ArrayDataType::kUint8) { + AppendF(&node_properties.label, "\\nType: uint8"); + } + + if (array.has_shape()) { + auto& array_shape = array.shape(); + node_properties.label += "\\n["; + for (int id = 0; id < array_shape.dimensions_count(); id++) { + if (id == 0) { + AppendF(&node_properties.label, "%d", array_shape.dims(id)); + } else { + AppendF(&node_properties.label, "x%d", array_shape.dims(id)); + } + } + node_properties.label += "]"; + } + + if (array.minmax) { + AppendF(&node_properties.label, "\\nMinMax: [%.3g, %.3g]", + array.minmax->min, array.minmax->max); + } + + if (array.quantization_params) { + AppendF(&node_properties.label, "\\nQuantization: %.3g * (x - %d)", + array.quantization_params->scale, + array.quantization_params->zero_point); + } + + if (array.alloc) { + AppendF(&node_properties.label, "\\nTransient Alloc: [%d, %d)", + array.alloc->start, array.alloc->end); + } + + return node_properties; +} + +NodeProperties GetPropertiesForOperator(const Operator& op) { + NodeProperties node_properties; + if (op.type == OperatorType::kTensorFlowUnsupported) { + node_properties.label = + static_cast(op).tensorflow_op; + } else { + node_properties.label = OperatorTypeName(op.type); + } + // Additional information for some of the operators. + switch (op.type) { + case OperatorType::kConv: { + const auto& conv_op = static_cast(op); + node_properties.color = Color(0xC5, 0x39, 0x29); // Bolder color + AppendF(&node_properties.label, "\\n%dx%d/%s", conv_op.stride_width, + conv_op.stride_height, + conv_op.padding.type == PaddingType::kSame ? "S" : "V"); + break; + } + case OperatorType::kDepthwiseConv: { + const auto& conv_op = static_cast(op); + node_properties.color = Color(0xC5, 0x39, 0x29); // Bolder color + AppendF(&node_properties.label, "\\n%dx%d/%s", conv_op.stride_width, + conv_op.stride_height, + conv_op.padding.type == PaddingType::kSame ? 
"S" : "V"); + break; + } + case OperatorType::kFullyConnected: { + node_properties.color = Color(0xC5, 0x39, 0x29); // Bolder color + break; + } + default: + node_properties.color = Color(0xDB, 0x44, 0x37); + break; + } + + return node_properties; +} + +std::vector OperatorsToDump(const Model& model) { + const auto& dump_options = *GraphVizDumpOptions::singleton(); + bool first_specified = !dump_options.graphviz_first_array.empty(); + bool last_specified = !dump_options.graphviz_last_array.empty(); + CHECK_EQ(first_specified, last_specified); + std::vector ops_to_dump; + if (last_specified) { + // Return only the part of the graph between graphviz_first_array + // and graphviz_last_array. + CHECK(model.arrays.count(dump_options.graphviz_first_array)); + CHECK(model.arrays.count(dump_options.graphviz_last_array)); + std::unordered_set arrays_already_produced; + std::vector arrays_to_produce; + arrays_to_produce.push_back(dump_options.graphviz_last_array); + while (!arrays_to_produce.empty()) { + const string array = arrays_to_produce.back(); + arrays_to_produce.pop_back(); + CHECK(!arrays_already_produced.count(array)); + arrays_already_produced.insert(array); + const Operator* op = GetOpWithOutput(model, array); + if (!op) { + continue; + } + ops_to_dump.push_back(op); + for (const string& input : op->inputs) { + if (arrays_already_produced.count(input) || + input == dump_options.graphviz_first_array) { + continue; + } + arrays_to_produce.push_back(input); + } + } + } else { + // Return the whole graph. + for (const auto& op : model.operators) { + ops_to_dump.push_back(op.get()); + } + } + return ops_to_dump; +} + +} // namespace + +void DumpGraphviz(const Model& model, string* output_file_contents) { + AppendF(output_file_contents, "digraph Computegraph {\n"); + + constexpr char kNodeFormat[] = + "\t \"%s\" [label=\"%s\", shape=%s, style=filled, fillcolor=\"#%s\", " + "fontcolor = \"#%sDD\"];\n"; + + constexpr char kEdgeFormat[] = "\t \"%s\" -> \"%s\";\n"; + + constexpr char kRNNBackEdgeFormat[] = + "\t \"%s\" -> \"%s\" [color=\"#0F9D58\"];\n"; + + std::vector ops_to_dump = OperatorsToDump(model); + std::set already_added_arrays; + for (int op_index = 0; op_index < ops_to_dump.size(); op_index++) { + const Operator& op = *ops_to_dump[op_index]; + // Add node for operator. + auto op_properties = GetPropertiesForOperator(op); + string operator_id = StringF("op%05d", op_index); + AppendF(output_file_contents, kNodeFormat, operator_id, op_properties.label, + "box", op_properties.color.FillColorString().c_str(), + op_properties.color.TextColorString().c_str()); + // Add nodes and edges for all inputs of the operator. + for (const auto& input : op.inputs) { + auto array_properties = GetPropertiesForArray(model, input); + if (!already_added_arrays.count(input)) { + AppendF(output_file_contents, kNodeFormat, input, + array_properties.label, "octagon", + array_properties.color.FillColorString().c_str(), + array_properties.color.TextColorString().c_str()); + } + AppendF(output_file_contents, kEdgeFormat, input, operator_id); + already_added_arrays.insert(input); + } + // Add nodes and edges for all outputs of the operator. 
+ for (const auto& output : op.outputs) { + auto array_properties = GetPropertiesForArray(model, output); + if (!already_added_arrays.count(output)) { + AppendF(output_file_contents, kNodeFormat, output, + array_properties.label, "octagon", + array_properties.color.FillColorString().c_str(), + array_properties.color.TextColorString().c_str()); + } + AppendF(output_file_contents, kEdgeFormat, operator_id, output); + already_added_arrays.insert(output); + } + } + + for (const auto& rnn_state : model.flags.rnn_states()) { + AppendF(output_file_contents, kRNNBackEdgeFormat, + rnn_state.back_edge_source_array(), rnn_state.state_array()); + } + + AppendF(output_file_contents, "}\n"); +} +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/dump_graphviz.h b/tensorflow/contrib/lite/toco/dump_graphviz.h new file mode 100644 index 0000000000..0fb28e3de8 --- /dev/null +++ b/tensorflow/contrib/lite/toco/dump_graphviz.h @@ -0,0 +1,28 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_DUMP_GRAPHVIZ_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_DUMP_GRAPHVIZ_H_ + +#include + +#include "tensorflow/contrib/lite/toco/model.h" + +namespace toco { + +void DumpGraphviz(const Model& model, string* output_file_contents); + +} // namespace toco + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_DUMP_GRAPHVIZ_H_ diff --git a/tensorflow/contrib/lite/toco/export_tensorflow.cc b/tensorflow/contrib/lite/toco/export_tensorflow.cc new file mode 100644 index 0000000000..16b9fa2260 --- /dev/null +++ b/tensorflow/contrib/lite/toco/export_tensorflow.cc @@ -0,0 +1,1570 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include +#include +#include +#include +#include + +#include "google/protobuf/map.h" +#include "google/protobuf/text_format.h" +#include "absl/memory/memory.h" +#include "absl/strings/string_view.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/model_flags.pb.h" +#include "tensorflow/contrib/lite/toco/runtime/types.h" +#include "tensorflow/contrib/lite/toco/tensorflow_util.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/logging.h" + +using tensorflow::DT_FLOAT; +using tensorflow::DT_INT32; +using tensorflow::GraphDef; +using tensorflow::TensorProto; + +namespace toco { +namespace { + +// TensorFlow sometimes forbids what it calls "legacy scalars", +// which are 1-D shapes where the unique shape size is 1. +// See OpKernel::IsLegacyScalar and OpKernel::allow_legacy_scalars. +// For that reason, we generally avoid creating legacy scalars, +// by detecting the case where a 1-D shape would be of size 1 and +// replacing that by a 0-D shape. +// However, there is a special circumstance where we must not do that +// and must unconditionally create a 1-D shape even if it is going to +// be of size 1: that is the case of bias vectors, with BiasAdd nodes. +// Indeed, TensorFlow requires bias vectors to be 1-D; in the case of +// a depth of 1, that would be a legacy scalar, so in that case we +// must go ahead and keep the shape 1-D, letting it be a legacy scalar. 
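+// For example, a bias vector with a single element must keep the 1-D shape
+// [1] here instead of being reduced to a 0-D scalar.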
+enum class LegacyScalarPolicy { kAvoidLegacyScalars, kDoCreateLegacyScalars }; + +void ExportFloatArray(const Shape& input_shape, const float* input_data, + TensorProto* output_tensor, + LegacyScalarPolicy legacy_scalar_policy) { + output_tensor->set_dtype(DT_FLOAT); + const int input_flat_size = RequiredBufferSizeForShape(input_shape); + auto* shape = output_tensor->mutable_tensor_shape(); + + const int kDims = input_shape.dimensions_count(); + if (legacy_scalar_policy == LegacyScalarPolicy::kDoCreateLegacyScalars || + kDims > 1 || (kDims == 1 && input_shape.dims(0) > 1)) { + for (int i = 0; i < kDims; ++i) { + shape->add_dim()->set_size(input_shape.dims(i)); + } + } + output_tensor->set_tensor_content( + string(reinterpret_cast(input_data), + sizeof(*input_data) * input_flat_size)); +} + +void ExportFloatArray(AxesOrder input_axes_order, const Shape& input_shape, + const float* input_data, AxesOrder output_axes_order, + TensorProto* output_tensor, + LegacyScalarPolicy legacy_scalar_policy) { + CHECK_EQ(AxesCount(output_axes_order), AxesCount(input_axes_order)); + output_tensor->set_dtype(DT_FLOAT); + CHECK_EQ(input_shape.dimensions_count(), AxesCount(input_axes_order)); + const int input_flat_size = RequiredBufferSizeForShape(input_shape); + + Shape shuffled_shape; + ShuffleDims(input_shape, input_axes_order, output_axes_order, + &shuffled_shape); + std::vector shuffled_data(input_flat_size); + ShuffleArray(input_shape, input_axes_order, output_axes_order, shuffled_shape, + input_data, shuffled_data.data()); + + ExportFloatArray(shuffled_shape, shuffled_data.data(), output_tensor, + legacy_scalar_policy); +} + +bool HasAlreadyExportedConst(const string& name, + const GraphDef& tensorflow_graph) { + for (const auto& node : tensorflow_graph.node()) { + if (node.op() == "Const" && node.name() == name) { + return true; + } + } + return false; +} + +void ConvertFloatTensorConst(const string& name, const Shape& input_shape, + const float* input_data, + AxesOrder input_axes_order, + AxesOrder output_axes_order, + GraphDef* tensorflow_graph, + LegacyScalarPolicy legacy_scalar_policy) { + if (HasAlreadyExportedConst(name, *tensorflow_graph)) { + return; + } + auto* const_op = tensorflow_graph->add_node(); + const_op->set_op("Const"); + const_op->set_name(name); + (*const_op->mutable_attr())["dtype"].set_type(DT_FLOAT); + auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor(); + ExportFloatArray(input_axes_order, input_shape, input_data, output_axes_order, + tensor, legacy_scalar_policy); +} + +void ConvertFloatTensorConst(const string& name, const Shape& input_shape, + const float* input_data, + AxesOrder input_axes_order, + AxesOrder output_axes_order, + GraphDef* tensorflow_graph) { + if (HasAlreadyExportedConst(name, *tensorflow_graph)) { + return; + } + auto* const_op = tensorflow_graph->add_node(); + const_op->set_op("Const"); + const_op->set_name(name); + (*const_op->mutable_attr())["dtype"].set_type(DT_FLOAT); + auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor(); + ExportFloatArray(input_axes_order, input_shape, input_data, output_axes_order, + tensor, LegacyScalarPolicy::kAvoidLegacyScalars); +} + +void ConvertFloatTensorConst(const Model& model, const string& name, + AxesOrder input_axes_order, + AxesOrder output_axes_order, + GraphDef* tensorflow_graph) { + if (HasAlreadyExportedConst(name, *tensorflow_graph)) { + return; + } + auto* const_op = tensorflow_graph->add_node(); + const_op->set_op("Const"); + const_op->set_name(name); + 
(*const_op->mutable_attr())["dtype"].set_type(DT_FLOAT); + auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor(); + CHECK(model.arrays.count(name)); + const auto& input_array = *model.arrays.at(name); + const auto& input_shape = input_array.shape(); + CHECK(input_array.buffer); + CHECK(input_array.buffer->type == ArrayDataType::kFloat); + const float* input_data = + input_array.GetBuffer().data.data(); + ExportFloatArray(input_axes_order, input_shape, input_data, output_axes_order, + tensor, LegacyScalarPolicy::kAvoidLegacyScalars); +} + +void ConvertFloatTensorConst(const Model& model, const string& name, + GraphDef* tensorflow_graph) { + if (HasAlreadyExportedConst(name, *tensorflow_graph)) { + return; + } + auto* const_op = tensorflow_graph->add_node(); + const_op->set_op("Const"); + const_op->set_name(name); + (*const_op->mutable_attr())["dtype"].set_type(DT_FLOAT); + auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor(); + CHECK(model.arrays.count(name)); + const auto& input_array = *model.arrays.at(name); + const auto& input_shape = input_array.shape(); + CHECK(input_array.buffer); + CHECK(input_array.buffer->type == ArrayDataType::kFloat); + const float* input_data = + input_array.GetBuffer().data.data(); + ExportFloatArray(input_shape, input_data, tensor, + LegacyScalarPolicy::kAvoidLegacyScalars); +} + +void ConvertIntTensorConst(const Model& model, const string& name, + GraphDef* tensorflow_graph) { + if (HasAlreadyExportedConst(name, *tensorflow_graph)) { + return; + } + CHECK(model.arrays.count(name)); + const auto& array = *model.arrays.at(name); + auto* const_op = tensorflow_graph->add_node(); + const_op->set_op("Const"); + const_op->set_name(name); + (*const_op->mutable_attr())["dtype"].set_type(DT_INT32); + auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor(); + tensor->set_dtype(DT_INT32); + const auto& data = array.GetBuffer().data; + for (auto index : data) { + tensor->add_int_val(index); + } + const auto& array_shape = array.shape(); + auto* shape = tensor->mutable_tensor_shape(); + for (int i = 0; i < array_shape.dimensions_count(); i++) { + shape->add_dim()->set_size(array_shape.dims(i)); + } +} + +void CreateMatrixShapeTensorConst(const string& name, int rows, int cols, + GraphDef* tensorflow_graph) { + if (HasAlreadyExportedConst(name, *tensorflow_graph)) { + return; + } + auto* const_op = tensorflow_graph->add_node(); + const_op->set_op("Const"); + const_op->set_name(name); + (*const_op->mutable_attr())["dtype"].set_type(DT_INT32); + auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor(); + tensor->set_dtype(DT_INT32); + const int32 data[2] = {cols, rows}; + tensor->set_tensor_content( + string(reinterpret_cast(data), sizeof(data))); + auto* shape = tensor->mutable_tensor_shape(); + shape->add_dim()->set_size(2); +} + +void CreateDummyConcatDimTensorConst(const string& name, int dim, + GraphDef* tensorflow_graph) { + if (HasAlreadyExportedConst(name, *tensorflow_graph)) { + return; + } + auto* const_op = tensorflow_graph->add_node(); + const_op->set_op("Const"); + const_op->set_name(name); + (*const_op->mutable_attr())["dtype"].set_type(DT_INT32); + auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor(); + tensor->set_dtype(DT_INT32); + tensor->add_int_val(dim); +} + +void CreateReshapeShapeTensorConst(const string& name, + const std::vector& shape, + GraphDef* tensorflow_graph) { + if (HasAlreadyExportedConst(name, *tensorflow_graph)) { + return; + } + auto* const_op = 
tensorflow_graph->add_node(); + const_op->set_op("Const"); + const_op->set_name(name); + (*const_op->mutable_attr())["dtype"].set_type(DT_INT32); + auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor(); + tensor->set_dtype(DT_INT32); + for (auto s : shape) { + tensor->add_int_val(s); + } + // TensorFlow sometimes forbids what it calls "legacy scalars", + // which are shapes of size 1 where the unique shape size is 1. + // See OpKernel::IsLegacyScalar and OpKernel::allow_legacy_scalars. + if (shape.size() > 1) { + auto* tensor_shape = tensor->mutable_tensor_shape(); + tensor_shape->add_dim()->set_size(shape.size()); + } +} + +string WalkUpToConstantArray(const Model& model, const string& name) { + const Array& original_array = model.GetArray(name); + if (original_array.buffer) { + return name; + } + const auto* op = GetOpWithOutput(model, name); + CHECK(op); + CHECK(op->type == OperatorType::kFakeQuant); + const string& input_of_fakequant_name = op->inputs[0]; + const Array& input_of_fakequant = model.GetArray(input_of_fakequant_name); + CHECK(input_of_fakequant.buffer); + return input_of_fakequant_name; +} + +void ConvertConvOperator(const Model& model, const ConvOperator& src_op, + GraphDef* tensorflow_graph) { + const bool has_bias = src_op.inputs.size() >= 3; + string conv_output = src_op.outputs[0]; + if (has_bias) { + conv_output += "/conv"; + } + + auto* conv2d_op = tensorflow_graph->add_node(); + conv2d_op->set_op("Conv2D"); + conv2d_op->set_name(conv_output); + *conv2d_op->add_input() = src_op.inputs[0]; + *conv2d_op->add_input() = src_op.inputs[1]; + (*conv2d_op->mutable_attr())["T"].set_type(DT_FLOAT); + const string& weights_array_name = + WalkUpToConstantArray(model, src_op.inputs[1]); + const auto& weights_array = model.GetArray(weights_array_name); + CHECK(weights_array.buffer->type == ArrayDataType::kFloat); + ConvertFloatTensorConst(model, weights_array_name, AxesOrder::kOHWI, + AxesOrder::kHWIO, tensorflow_graph); + auto& strides = (*conv2d_op->mutable_attr())["strides"]; + strides.mutable_list()->add_i(1); + strides.mutable_list()->add_i(src_op.stride_height); + strides.mutable_list()->add_i(src_op.stride_width); + strides.mutable_list()->add_i(1); + string padding; + if (src_op.padding.type == PaddingType::kSame) { + padding = "SAME"; + } else if (src_op.padding.type == PaddingType::kValid) { + padding = "VALID"; + } else { + LOG(FATAL) << "Bad padding (only SAME and VALID are supported)"; + } + (*conv2d_op->mutable_attr())["padding"].set_s(padding); + + if (has_bias) { + auto* biasadd_op = tensorflow_graph->add_node(); + biasadd_op->set_op("BiasAdd"); + biasadd_op->set_name(src_op.outputs[0]); + biasadd_op->add_input(conv_output); + biasadd_op->add_input(src_op.inputs[2]); + (*biasadd_op->mutable_attr())["T"].set_type(DT_FLOAT); + CHECK(model.arrays.count(src_op.inputs[2])); + const string& bias_array_name = + WalkUpToConstantArray(model, src_op.inputs[2]); + const auto& bias_array = model.GetArray(bias_array_name); + // TODO(b/62904716) Bias arrays should be 1-D, and used directly. 
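+    // Reduce the stored bias shape to the 1-D vector that BiasAdd expects
+    // (see the legacy-scalar comment at the top of this file).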
+ Shape bias_shape_1d = bias_array.shape(); + UnextendShape(&bias_shape_1d, 1); + CHECK(bias_array.buffer->type == ArrayDataType::kFloat); + const float* bias_data = + bias_array.GetBuffer().data.data(); + ConvertFloatTensorConst(bias_array_name, bias_shape_1d, bias_data, + AxesOrder::kOneAxis, AxesOrder::kOneAxis, + tensorflow_graph, + LegacyScalarPolicy::kDoCreateLegacyScalars); + } +} + +void ConvertDepthwiseConvOperator(const Model& model, + const DepthwiseConvOperator& src_op, + GraphDef* tensorflow_graph) { + const bool has_bias = src_op.inputs.size() >= 3; + string conv_output = src_op.outputs[0]; + if (has_bias) { + conv_output += "/conv"; + } + + auto* dc2d_op = tensorflow_graph->add_node(); + dc2d_op->set_op("DepthwiseConv2dNative"); + dc2d_op->set_name(conv_output); + *dc2d_op->add_input() = src_op.inputs[0]; + *dc2d_op->add_input() = src_op.inputs[1]; + (*dc2d_op->mutable_attr())["T"].set_type(DT_FLOAT); + + // Our internal DepthwiseConv weights are 1 x H x W x OutputDepth. + // We need to convert that to H x W x InputDepth x Multiplier. + // That's only a matter of constructing a Dims object; the actual + // array layout is the same. + CHECK(model.arrays.count(src_op.inputs[1])); + const string& src_weights_name = + WalkUpToConstantArray(model, src_op.inputs[1]); + const auto& src_weights_array = model.GetArray(src_weights_name); + const auto& src_weights_shape = src_weights_array.shape(); + CHECK_EQ(src_weights_shape.dimensions_count(), 4); + const Shape dst_weights_shape = + Shape({src_weights_shape.dims(1), src_weights_shape.dims(2), + src_weights_shape.dims(3) / src_op.depth_multiplier, + src_op.depth_multiplier}); + CHECK_EQ(src_weights_shape.dims(3) % src_op.depth_multiplier, 0); + CHECK(dst_weights_shape.dims(2) * dst_weights_shape.dims(3) == + src_weights_shape.dims(3)); + CHECK_EQ(src_weights_shape.dims(0), 1); + + CHECK(src_weights_array.buffer->type == ArrayDataType::kFloat); + const float* src_weights_data = + src_weights_array.GetBuffer().data.data(); + ConvertFloatTensorConst(src_weights_name, dst_weights_shape, src_weights_data, + AxesOrder::kHWIM, AxesOrder::kHWIM, tensorflow_graph); + + auto& strides = (*dc2d_op->mutable_attr())["strides"]; + strides.mutable_list()->add_i(1); + strides.mutable_list()->add_i(src_op.stride_height); + strides.mutable_list()->add_i(src_op.stride_width); + strides.mutable_list()->add_i(1); + string padding; + if (src_op.padding.type == PaddingType::kSame) { + padding = "SAME"; + } else if (src_op.padding.type == PaddingType::kValid) { + padding = "VALID"; + } else { + LOG(FATAL) << "Bad padding (only SAME and VALID are supported)"; + } + (*dc2d_op->mutable_attr())["padding"].set_s(padding); + + if (has_bias) { + auto* biasadd_op = tensorflow_graph->add_node(); + biasadd_op->set_op("BiasAdd"); + biasadd_op->set_name(src_op.outputs[0]); + biasadd_op->add_input(conv_output); + biasadd_op->add_input(src_op.inputs[2]); + (*biasadd_op->mutable_attr())["T"].set_type(DT_FLOAT); + CHECK(model.arrays.count(src_op.inputs[2])); + const string& bias_name = WalkUpToConstantArray(model, src_op.inputs[2]); + const auto& bias_array = model.GetArray(bias_name); + // TODO(b/62904716) Bias arrays should be 1-D, and used directly. 
+ Shape bias_shape_1d = bias_array.shape(); + UnextendShape(&bias_shape_1d, 1); + CHECK(bias_array.buffer->type == ArrayDataType::kFloat); + const float* bias_data = + bias_array.GetBuffer().data.data(); + ConvertFloatTensorConst(bias_name, bias_shape_1d, bias_data, + AxesOrder::kOneAxis, AxesOrder::kOneAxis, + tensorflow_graph, + LegacyScalarPolicy::kDoCreateLegacyScalars); + } +} + +void ConvertDepthToSpaceOperator(const Model& model, + const DepthToSpaceOperator& src_op, + GraphDef* tensorflow_graph) { + auto* op = tensorflow_graph->add_node(); + op->set_op("DepthToSpace"); + op->set_name(src_op.outputs[0]); + *op->add_input() = src_op.inputs[0]; + (*op->mutable_attr())["T"].set_type(DT_FLOAT); + (*op->mutable_attr())["block_size"].set_i(src_op.block_size); +} + +void ConvertSpaceToDepthOperator(const Model& model, + const SpaceToDepthOperator& src_op, + GraphDef* tensorflow_graph) { + auto* op = tensorflow_graph->add_node(); + op->set_op("SpaceToDepth"); + op->set_name(src_op.outputs[0]); + *op->add_input() = src_op.inputs[0]; + (*op->mutable_attr())["T"].set_type(DT_FLOAT); + (*op->mutable_attr())["block_size"].set_i(src_op.block_size); +} + +void ConvertFullyConnectedOperator(const Model& model, + const FullyConnectedOperator& src_op, + GraphDef* tensorflow_graph) { + const string reshape_output = src_op.outputs[0] + "/reshape"; + const string reshape_shape = src_op.outputs[0] + "/reshape/shape"; + auto* reshape_op = tensorflow_graph->add_node(); + reshape_op->set_op("Reshape"); + reshape_op->set_name(reshape_output); + reshape_op->add_input(src_op.inputs[0]); + reshape_op->add_input(reshape_shape); + (*reshape_op->mutable_attr())["T"].set_type(DT_FLOAT); + + const bool has_bias = src_op.inputs.size() >= 3; + string matmul_output = src_op.outputs[0]; + if (has_bias) { + matmul_output += "/matmul"; + } + + auto* matmul_op = tensorflow_graph->add_node(); + matmul_op->set_op("MatMul"); + + matmul_op->set_name(matmul_output); + *matmul_op->add_input() = reshape_output; + *matmul_op->add_input() = src_op.inputs[1]; + (*matmul_op->mutable_attr())["T"].set_type(DT_FLOAT); + (*matmul_op->mutable_attr())["transpose_a"].set_b(false); + (*matmul_op->mutable_attr())["transpose_b"].set_b(false); + CHECK(model.arrays.count(src_op.inputs[1])); + const string& fc_weights_name = + WalkUpToConstantArray(model, src_op.inputs[1]); + const auto& fc_weights_array = *model.arrays.at(fc_weights_name); + const auto& fc_weights_shape = fc_weights_array.shape(); + CHECK_EQ(fc_weights_shape.dimensions_count(), 2); + CreateMatrixShapeTensorConst(reshape_shape, fc_weights_shape.dims(1), -1, + tensorflow_graph); + + CHECK(fc_weights_array.buffer); + CHECK(fc_weights_array.buffer->type == ArrayDataType::kFloat); + const float* fc_weights_data = + fc_weights_array.GetBuffer().data.data(); + ConvertFloatTensorConst(fc_weights_name, fc_weights_shape, fc_weights_data, + AxesOrder::kCR, AxesOrder::kRC, tensorflow_graph); + + if (has_bias) { + auto* biasadd_op = tensorflow_graph->add_node(); + biasadd_op->set_op("BiasAdd"); + biasadd_op->set_name(src_op.outputs[0]); + biasadd_op->add_input(matmul_output); + biasadd_op->add_input(src_op.inputs[2]); + (*biasadd_op->mutable_attr())["T"].set_type(DT_FLOAT); + CHECK(model.arrays.count(src_op.inputs[2])); + const auto& bias_array = *model.arrays.at(src_op.inputs[2]); + // TODO(b/62904716) Bias arrays should be 1-D, and used directly. 
+ Shape bias_shape_1d = bias_array.shape(); + UnextendShape(&bias_shape_1d, 1); + CHECK(bias_array.buffer); + CHECK(bias_array.buffer->type == ArrayDataType::kFloat); + const float* bias_data = + bias_array.GetBuffer().data.data(); + ConvertFloatTensorConst(WalkUpToConstantArray(model, src_op.inputs[2]), + bias_shape_1d, bias_data, AxesOrder::kOneAxis, + AxesOrder::kOneAxis, tensorflow_graph, + LegacyScalarPolicy::kDoCreateLegacyScalars); + } +} + +void ConvertAddOperator(const Model& model, const AddOperator& src_op, + GraphDef* tensorflow_graph) { + auto* add_op = tensorflow_graph->add_node(); + add_op->set_op("Add"); + add_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 2); + *add_op->add_input() = src_op.inputs[0]; + *add_op->add_input() = src_op.inputs[1]; + (*add_op->mutable_attr())["T"].set_type(DT_FLOAT); +} + +void ConvertMulOperator(const Model& model, const MulOperator& src_op, + GraphDef* tensorflow_graph) { + auto* add_op = tensorflow_graph->add_node(); + add_op->set_op("Mul"); + add_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 2); + *add_op->add_input() = src_op.inputs[0]; + *add_op->add_input() = src_op.inputs[1]; + (*add_op->mutable_attr())["T"].set_type(DT_FLOAT); +} + +void ConvertReluOperator(const ReluOperator& src_op, + GraphDef* tensorflow_graph) { + auto* relu_op = tensorflow_graph->add_node(); + relu_op->set_op("Relu"); + relu_op->set_name(src_op.outputs[0]); + *relu_op->add_input() = src_op.inputs[0]; + (*relu_op->mutable_attr())["T"].set_type(DT_FLOAT); +} + +void ConvertRelu1Operator(const Relu1Operator& src_op, + GraphDef* tensorflow_graph) { + const string max_bounds = src_op.outputs[0] + "/max_bounds"; + const string min_bounds = src_op.outputs[0] + "/min_bounds"; + const string max_output = src_op.outputs[0] + "/max_output"; + + auto* max_bounds_const_op = tensorflow_graph->add_node(); + max_bounds_const_op->set_op("Const"); + max_bounds_const_op->set_name(max_bounds); + (*max_bounds_const_op->mutable_attr())["dtype"].set_type(DT_FLOAT); + auto* max_bounds_const_op_tensor = + (*max_bounds_const_op->mutable_attr())["value"].mutable_tensor(); + max_bounds_const_op_tensor->set_dtype(DT_FLOAT); + max_bounds_const_op_tensor->add_float_val(-1.0f); + + auto* min_bounds_const_op = tensorflow_graph->add_node(); + min_bounds_const_op->set_op("Const"); + min_bounds_const_op->set_name(min_bounds); + (*min_bounds_const_op->mutable_attr())["dtype"].set_type(DT_FLOAT); + auto* min_bounds_const_op_tensor = + (*min_bounds_const_op->mutable_attr())["value"].mutable_tensor(); + min_bounds_const_op_tensor->set_dtype(DT_FLOAT); + min_bounds_const_op_tensor->add_float_val(1.0f); + + auto* max_op = tensorflow_graph->add_node(); + max_op->set_op("Maximum"); + max_op->set_name(max_output); + *max_op->add_input() = src_op.inputs[0]; + *max_op->add_input() = max_bounds; + (*max_op->mutable_attr())["T"].set_type(DT_FLOAT); + + auto* min_op = tensorflow_graph->add_node(); + min_op->set_op("Minimum"); + min_op->set_name(src_op.outputs[0]); + *min_op->add_input() = max_output; + *min_op->add_input() = min_bounds; + (*min_op->mutable_attr())["T"].set_type(DT_FLOAT); +} + +void ConvertRelu6Operator(const Relu6Operator& src_op, + GraphDef* tensorflow_graph) { + auto* relu_op = tensorflow_graph->add_node(); + relu_op->set_op("Relu6"); + relu_op->set_name(src_op.outputs[0]); + *relu_op->add_input() = src_op.inputs[0]; + (*relu_op->mutable_attr())["T"].set_type(DT_FLOAT); +} + +void ConvertLogisticOperator(const LogisticOperator& src_op, + GraphDef* 
tensorflow_graph) { + auto* relu_op = tensorflow_graph->add_node(); + relu_op->set_op("Sigmoid"); + relu_op->set_name(src_op.outputs[0]); + *relu_op->add_input() = src_op.inputs[0]; + (*relu_op->mutable_attr())["T"].set_type(DT_FLOAT); +} + +void ConvertTanhOperator(const TanhOperator& src_op, + GraphDef* tensorflow_graph) { + auto* tanh_op = tensorflow_graph->add_node(); + tanh_op->set_op("Tanh"); + tanh_op->set_name(src_op.outputs[0]); + *tanh_op->add_input() = src_op.inputs[0]; + (*tanh_op->mutable_attr())["T"].set_type(DT_FLOAT); +} + +void ConvertSoftmaxOperator(const Model& model, const SoftmaxOperator& src_op, + GraphDef* tensorflow_graph) { + string softmax_input; + Operator* providing_op = GetOpWithOutput(model, src_op.inputs[0]); + if (providing_op->type == OperatorType::kTensorFlowReshape) { + softmax_input = src_op.inputs[0]; + } else { + // Insert a reshape operator that reduces the dimensions down to the 2 that + // are required for TensorFlow Logits. + const string reshape_output = src_op.outputs[0] + "/softmax_insert_reshape"; + const string softmax_size = src_op.outputs[0] + "/softmax_insert_size"; + softmax_input = reshape_output; + + auto* reshape_op = tensorflow_graph->add_node(); + reshape_op->set_op("Reshape"); + reshape_op->set_name(reshape_output); + *reshape_op->add_input() = src_op.inputs[0]; + *reshape_op->add_input() = softmax_size; + (*reshape_op->mutable_attr())["T"].set_type(DT_FLOAT); + + const auto& input_shape = model.arrays.at(src_op.inputs[0])->shape(); + int32 flattened_size = 1; + for (int i = 0; i < input_shape.dimensions_count() - 1; ++i) { + flattened_size *= input_shape.dims(i); + } + const std::vector shape_data = { + flattened_size, input_shape.dims(input_shape.dimensions_count() - 1)}; + CreateReshapeShapeTensorConst(softmax_size, shape_data, tensorflow_graph); + } + + auto* softmax_op = tensorflow_graph->add_node(); + softmax_op->set_op("Softmax"); + softmax_op->set_name(src_op.outputs[0]); + *softmax_op->add_input() = softmax_input; + // TensorFlow's Softmax doesn't seem to admit a 'beta' parameter + CHECK_EQ(src_op.beta, 1.f); + (*softmax_op->mutable_attr())["T"].set_type(DT_FLOAT); +} + +void ConvertL2NormalizationOperator(const L2NormalizationOperator& src_op, + GraphDef* tensorflow_graph) { + const string square_output = src_op.outputs[0] + "/square"; + const string sum_reduction_indices = src_op.outputs[0] + "/reduction_indices"; + const string sum_output = src_op.outputs[0] + "/sum"; + const string rsqrt_output = src_op.outputs[0] + "/rsqrt"; + const string rsqrt_tiled_output = src_op.outputs[0] + "/rsqrt_tiled"; + + auto* sum_reduction_indices_op = tensorflow_graph->add_node(); + sum_reduction_indices_op->set_op("Const"); + sum_reduction_indices_op->set_name(sum_reduction_indices); + (*sum_reduction_indices_op->mutable_attr())["dtype"].set_type(DT_INT32); + auto* sum_reduction_indices_tensor = + (*sum_reduction_indices_op->mutable_attr())["value"].mutable_tensor(); + sum_reduction_indices_tensor->set_dtype(DT_INT32); + auto* sum_reduction_indices_shape = + sum_reduction_indices_tensor->mutable_tensor_shape(); + auto* sum_reduction_indices_dim = sum_reduction_indices_shape->add_dim(); + sum_reduction_indices_dim->set_size(2); + sum_reduction_indices_tensor->add_int_val(0); + sum_reduction_indices_tensor->add_int_val(1); + + auto* square_op = tensorflow_graph->add_node(); + square_op->set_op("Square"); + square_op->set_name(square_output); + *square_op->add_input() = src_op.inputs[0]; + 
(*square_op->mutable_attr())["T"].set_type(DT_FLOAT); + + auto* sum_op = tensorflow_graph->add_node(); + sum_op->set_op("Sum"); + sum_op->set_name(sum_output); + *sum_op->add_input() = square_output; + *sum_op->add_input() = sum_reduction_indices; + (*sum_op->mutable_attr())["T"].set_type(DT_FLOAT); + + auto* rsqrt_op = tensorflow_graph->add_node(); + rsqrt_op->set_op("Rsqrt"); + rsqrt_op->set_name(rsqrt_output); + *rsqrt_op->add_input() = sum_output; + (*rsqrt_op->mutable_attr())["T"].set_type(DT_FLOAT); + + auto* mul_op = tensorflow_graph->add_node(); + mul_op->set_op("Mul"); + mul_op->set_name(src_op.outputs[0]); + *mul_op->add_input() = src_op.inputs[0]; + *mul_op->add_input() = rsqrt_output; + (*mul_op->mutable_attr())["T"].set_type(DT_FLOAT); +} + +void ConvertLocalResponseNormalizationOperator( + const LocalResponseNormalizationOperator& src_op, + GraphDef* tensorflow_graph) { + auto* lrn_op = tensorflow_graph->add_node(); + lrn_op->set_op("LRN"); + lrn_op->set_name(src_op.outputs[0]); + *lrn_op->add_input() = src_op.inputs[0]; + (*lrn_op->mutable_attr())["depth_radius"].set_i(src_op.range); + (*lrn_op->mutable_attr())["bias"].set_f(src_op.bias); + (*lrn_op->mutable_attr())["alpha"].set_f(src_op.alpha); + (*lrn_op->mutable_attr())["beta"].set_f(src_op.beta); +} + +void ConvertFakeQuantOperator(const FakeQuantOperator& src_op, + GraphDef* tensorflow_graph) { + auto* fakequant_op = tensorflow_graph->add_node(); + fakequant_op->set_op("FakeQuantWithMinMaxArgs"); + fakequant_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 1); + *fakequant_op->add_input() = src_op.inputs[0]; + CHECK(src_op.minmax); + (*fakequant_op->mutable_attr())["min"].set_f(src_op.minmax->min); + (*fakequant_op->mutable_attr())["max"].set_f(src_op.minmax->max); +} + +void ConvertMaxPoolOperator(const MaxPoolOperator& src_op, + GraphDef* tensorflow_graph) { + auto* maxpool_op = tensorflow_graph->add_node(); + maxpool_op->set_op("MaxPool"); + maxpool_op->set_name(src_op.outputs[0]); + *maxpool_op->add_input() = src_op.inputs[0]; + auto& strides = (*maxpool_op->mutable_attr())["strides"]; + strides.mutable_list()->add_i(1); + strides.mutable_list()->add_i(src_op.stride_height); + strides.mutable_list()->add_i(src_op.stride_width); + strides.mutable_list()->add_i(1); + string padding; + if (src_op.padding.type == PaddingType::kSame) { + padding = "SAME"; + } else if (src_op.padding.type == PaddingType::kValid) { + padding = "VALID"; + } else { + LOG(FATAL) << "Bad padding (only SAME and VALID are supported)"; + } + (*maxpool_op->mutable_attr())["padding"].set_s(padding); + (*maxpool_op->mutable_attr())["T"].set_type(DT_FLOAT); + auto& ksize = (*maxpool_op->mutable_attr())["ksize"]; + ksize.mutable_list()->add_i(1); + ksize.mutable_list()->add_i(src_op.kheight); + ksize.mutable_list()->add_i(src_op.kwidth); + ksize.mutable_list()->add_i(1); +} + +void ConvertAveragePoolOperator(const AveragePoolOperator& src_op, + GraphDef* tensorflow_graph) { + auto* avgpool_op = tensorflow_graph->add_node(); + avgpool_op->set_op("AvgPool"); + avgpool_op->set_name(src_op.outputs[0]); + *avgpool_op->add_input() = src_op.inputs[0]; + auto& strides = (*avgpool_op->mutable_attr())["strides"]; + strides.mutable_list()->add_i(1); + strides.mutable_list()->add_i(src_op.stride_height); + strides.mutable_list()->add_i(src_op.stride_width); + strides.mutable_list()->add_i(1); + string padding; + if (src_op.padding.type == PaddingType::kSame) { + padding = "SAME"; + } else if (src_op.padding.type == PaddingType::kValid) { + 
padding = "VALID"; + } else { + LOG(FATAL) << "Bad padding (only SAME and VALID are supported)"; + } + (*avgpool_op->mutable_attr())["padding"].set_s(padding); + (*avgpool_op->mutable_attr())["T"].set_type(DT_FLOAT); + auto& ksize = (*avgpool_op->mutable_attr())["ksize"]; + ksize.mutable_list()->add_i(1); + ksize.mutable_list()->add_i(src_op.kheight); + ksize.mutable_list()->add_i(src_op.kwidth); + ksize.mutable_list()->add_i(1); +} + +void ConvertConcatenationOperator(const Model& model, + const ConcatenationOperator& src_op, + GraphDef* tensorflow_graph) { + auto* dc_op = tensorflow_graph->add_node(); + dc_op->set_op("ConcatV2"); + dc_op->set_name(src_op.outputs[0]); + const string dummy_concat_dim = src_op.outputs[0] + "/concat_dim"; + CreateDummyConcatDimTensorConst(dummy_concat_dim, src_op.concat_dim, + tensorflow_graph); + for (const auto& input : src_op.inputs) { + *dc_op->add_input() = input; + } + *dc_op->add_input() = dummy_concat_dim; + (*dc_op->mutable_attr())["T"].set_type(DT_FLOAT); + (*dc_op->mutable_attr())["Tidx"].set_type(DT_INT32); + (*dc_op->mutable_attr())["N"].set_i(src_op.inputs.size()); +} + +void ConvertTensorFlowReshapeOperator(const Model& model, + const TensorFlowReshapeOperator& src_op, + GraphDef* tensorflow_graph) { + auto* reshape_op = tensorflow_graph->add_node(); + reshape_op->set_op("Reshape"); + reshape_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 2); + *reshape_op->add_input() = src_op.inputs[0]; + *reshape_op->add_input() = src_op.inputs[1]; + (*reshape_op->mutable_attr())["T"].set_type(DT_FLOAT); + const auto& shape_array = model.GetArray(src_op.inputs[1]); + CHECK(shape_array.data_type == ArrayDataType::kInt32); + CHECK(shape_array.buffer != nullptr); + const auto& shape_data = shape_array.GetBuffer().data; + CreateReshapeShapeTensorConst(src_op.inputs[1], shape_data, tensorflow_graph); +} + +void ConvertL2PoolOperator(const L2PoolOperator& src_op, + GraphDef* tensorflow_graph) { + const string square_output = src_op.outputs[0] + "/square"; + const string avgpool_output = src_op.outputs[0] + "/avgpool"; + + auto* square_op = tensorflow_graph->add_node(); + square_op->set_op("Square"); + square_op->set_name(square_output); + *square_op->add_input() = src_op.inputs[0]; + (*square_op->mutable_attr())["T"].set_type(DT_FLOAT); + + string padding; + if (src_op.padding.type == PaddingType::kSame) { + padding = "SAME"; + } else if (src_op.padding.type == PaddingType::kValid) { + padding = "VALID"; + } else { + LOG(FATAL) << "Bad padding (only SAME and VALID are supported)"; + } + + auto* avgpool_op = tensorflow_graph->add_node(); + avgpool_op->set_op("AvgPool"); + avgpool_op->set_name(avgpool_output); + *avgpool_op->add_input() = square_output; + auto& strides = (*avgpool_op->mutable_attr())["strides"]; + strides.mutable_list()->add_i(1); + strides.mutable_list()->add_i(src_op.stride_height); + strides.mutable_list()->add_i(src_op.stride_width); + strides.mutable_list()->add_i(1); + + (*avgpool_op->mutable_attr())["padding"].set_s(padding); + (*avgpool_op->mutable_attr())["T"].set_type(DT_FLOAT); + auto& ksize = (*avgpool_op->mutable_attr())["ksize"]; + ksize.mutable_list()->add_i(1); + ksize.mutable_list()->add_i(src_op.kheight); + ksize.mutable_list()->add_i(src_op.kwidth); + ksize.mutable_list()->add_i(1); + + auto* sqrt_op = tensorflow_graph->add_node(); + sqrt_op->set_op("Sqrt"); + sqrt_op->set_name(src_op.outputs[0]); + *sqrt_op->add_input() = avgpool_output; + (*sqrt_op->mutable_attr())["T"].set_type(DT_FLOAT); +} + +void 
ConvertSquareOperator(const TensorFlowSquareOperator& src_op, + GraphDef* tensorflow_graph) { + auto* square_op = tensorflow_graph->add_node(); + square_op->set_op("Square"); + square_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 1); + *square_op->add_input() = src_op.inputs[0]; + (*square_op->mutable_attr())["T"].set_type(DT_FLOAT); +} + +void ConvertSqrtOperator(const TensorFlowSqrtOperator& src_op, + GraphDef* tensorflow_graph) { + auto* sqrt_op = tensorflow_graph->add_node(); + sqrt_op->set_op("Sqrt"); + sqrt_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 1); + *sqrt_op->add_input() = src_op.inputs[0]; + (*sqrt_op->mutable_attr())["T"].set_type(DT_FLOAT); +} + +void ConvertSplitOperator(const Model& model, + const TensorFlowSplitOperator& src_op, + GraphDef* tensorflow_graph) { + auto* split_op = tensorflow_graph->add_node(); + split_op->set_op("Split"); + split_op->set_name(src_op.outputs[0]); + for (const auto& input : src_op.inputs) { + *split_op->add_input() = input; + } + (*split_op->mutable_attr())["T"].set_type(DT_FLOAT); + (*split_op->mutable_attr())["num_split"].set_i(src_op.num_split); + const auto& split_dim_array = model.GetArray(src_op.inputs[0]); + CHECK(split_dim_array.buffer); + CHECK(split_dim_array.data_type == ArrayDataType::kInt32); + const auto& split_dim_data = + split_dim_array.GetBuffer().data; + CHECK_EQ(split_dim_data.size(), 1); + const int split_dim = split_dim_data[0]; + CreateDummyConcatDimTensorConst(src_op.inputs[0], split_dim, + tensorflow_graph); +} + +tensorflow::DataType GetTensorFlowDataType(const Model& model, + const string& array_name) { + auto& dtype = model.GetArray(array_name).data_type; + CHECK(dtype == ArrayDataType::kFloat || dtype == ArrayDataType::kInt32 || + dtype == ArrayDataType::kUint8); + if (dtype == ArrayDataType::kFloat) { + return tensorflow::DT_FLOAT; + } else if (dtype == ArrayDataType::kInt32) { + return tensorflow::DT_INT32; + } else if (dtype == ArrayDataType::kUint8) { + return tensorflow::DT_UINT8; + } else { + LOG(FATAL) << "Wrong data type"; + } +} + +void ConvertCastOperator(const Model& model, const CastOperator& src_op, + GraphDef* tensorflow_graph) { + auto* cast_op = tensorflow_graph->add_node(); + cast_op->set_op("Cast"); + cast_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 1); + *cast_op->add_input() = src_op.inputs[0]; + + (*cast_op->mutable_attr())["DstT"].set_type( + GetTensorFlowDataType(model, src_op.outputs[0])); + (*cast_op->mutable_attr())["SrcT"].set_type( + GetTensorFlowDataType(model, src_op.inputs[0])); +} + +void ConvertFloorOperator(const Model& model, const FloorOperator& src_op, + GraphDef* tensorflow_graph) { + auto* floor_op = tensorflow_graph->add_node(); + floor_op->set_op("Floor"); + floor_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 1); + *floor_op->add_input() = src_op.inputs[0]; + (*floor_op->mutable_attr())["T"].set_type(DT_FLOAT); +} + +void ConvertGatherOperator(const Model& model, const GatherOperator& src_op, + GraphDef* tensorflow_graph) { + auto* gather_op = tensorflow_graph->add_node(); + gather_op->set_op("Gather"); + gather_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 2); + *gather_op->add_input() = src_op.inputs[0]; + *gather_op->add_input() = src_op.inputs[1]; + + (*gather_op->mutable_attr())["Tindices"].set_type(DT_INT32); + const auto params_type = GetTensorFlowDataType(model, src_op.inputs[0]); + (*gather_op->mutable_attr())["Tparams"].set_type(params_type); +} + +void 
ConvertResizeBilinearOperator(const Model& model, + const ResizeBilinearOperator& src_op, + GraphDef* tensorflow_graph) { + auto* resize_op = tensorflow_graph->add_node(); + resize_op->set_op("ResizeBilinear"); + resize_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 2); + *resize_op->add_input() = src_op.inputs[0]; + *resize_op->add_input() = src_op.inputs[1]; + (*resize_op->mutable_attr())["T"].set_type(DT_FLOAT); +} + +namespace { +// TODO(aselle): Remove when available in absl +absl::string_view FindLongestCommonPrefix(absl::string_view a, + absl::string_view b) { + if (a.empty() || b.empty()) return absl::string_view(); + + const char* pa = a.data(); + const char* pb = b.data(); + string::difference_type count = 0; + const string::difference_type limit = std::min(a.size(), b.size()); + while (count < limit && *pa == *pb) { + ++pa; + ++pb; + ++count; + } + + return absl::string_view(a.data(), count); +} +} // namespace + +void ConvertLstmCellOperator(const Model& model, const LstmCellOperator& src_op, + GraphDef* tensorflow_graph) { + // Find the base name + const string base( + FindLongestCommonPrefix(src_op.outputs[LstmCellOperator::STATE_OUTPUT], + src_op.outputs[LstmCellOperator::ACTIV_OUTPUT])); + + // Concatenate inputs + const string concat_output = base + "basic_lstm_cell/concat"; + // Op names have been chosen to match the tf.slim LSTM naming + // as closely as possible. + const int concat_dim = + model.arrays.at(src_op.inputs[LstmCellOperator::PREV_ACTIV_INPUT]) + ->shape() + .dimensions_count() - + 1; + // Note that DATA_INPUT may have extra size 1 dimensions, but TF concat + // works the same since the tensor has the same underlying data layout. + const string concat_dim_output = concat_output + "/concat_dim"; + CreateDummyConcatDimTensorConst(concat_dim_output, concat_dim, + tensorflow_graph); + auto* concat_op = tensorflow_graph->add_node(); + concat_op->set_op("ConcatV2"); + concat_op->set_name(concat_output); + *concat_op->add_input() = src_op.inputs[LstmCellOperator::DATA_INPUT]; + *concat_op->add_input() = src_op.inputs[LstmCellOperator::PREV_ACTIV_INPUT]; + *concat_op->add_input() = concat_dim_output; + (*concat_op->mutable_attr())["T"].set_type(DT_FLOAT); + (*concat_op->mutable_attr())["Tidx"].set_type(DT_INT32); + (*concat_op->mutable_attr())["N"].set_i(2); // Number of inputs + + // Write weights + const string weights_output = base + "weights"; + CHECK(model.arrays.count(src_op.inputs[LstmCellOperator::WEIGHTS_INPUT])); + const auto& weights_array = + *model.arrays.at(src_op.inputs[LstmCellOperator::WEIGHTS_INPUT]); + // Convert 4D FullyConnected weights into 2D matrix + const auto& weights_shape = weights_array.shape(); + CHECK_EQ(weights_shape.dimensions_count(), 2); + CHECK(weights_array.buffer); + CHECK(weights_array.buffer->type == ArrayDataType::kFloat); + const float* weights_data = + weights_array.GetBuffer().data.data(); + ConvertFloatTensorConst(weights_output, weights_shape, weights_data, + AxesOrder::kCR, AxesOrder::kRC, tensorflow_graph); + + // Fully connected matrix multiply + const string matmul_output = base + "MatMul"; + auto* matmul_op = tensorflow_graph->add_node(); + matmul_op->set_op("MatMul"); + matmul_op->set_name(matmul_output); + *matmul_op->add_input() = concat_output; + *matmul_op->add_input() = weights_output; + (*matmul_op->mutable_attr())["transpose_a"].set_b(false); + (*matmul_op->mutable_attr())["transpose_b"].set_b(false); + (*matmul_op->mutable_attr())["T"].set_type(DT_FLOAT); + + // Write biases + const string 
biases_output = base + "biases"; + CHECK(model.arrays.count(src_op.inputs[LstmCellOperator::BIASES_INPUT])); + const auto& bias_array = + *model.arrays.at(src_op.inputs[LstmCellOperator::BIASES_INPUT]); + // TODO(b/62904716) Bias arrays should be 1-D, and used directly. + Shape bias_shape_1d = bias_array.shape(); + UnextendShape(&bias_shape_1d, 1); + CHECK(bias_array.buffer); + CHECK(bias_array.buffer->type == ArrayDataType::kFloat); + const float* bias_data = + bias_array.GetBuffer().data.data(); + ConvertFloatTensorConst(biases_output, bias_shape_1d, bias_data, + AxesOrder::kOneAxis, AxesOrder::kOneAxis, + tensorflow_graph, + LegacyScalarPolicy::kDoCreateLegacyScalars); + + // Add biases + string biasadd_output = base + "BiasAdd"; + auto* biasadd_op = tensorflow_graph->add_node(); + biasadd_op->set_op("BiasAdd"); + biasadd_op->set_name(biasadd_output); + biasadd_op->add_input(matmul_output); + biasadd_op->add_input(biases_output); + (*biasadd_op->mutable_attr())["data_format"].set_s("NHWC"); + (*biasadd_op->mutable_attr())["T"].set_type(DT_FLOAT); + + // Split + string split_dim_output = base + "split/split_dim"; + // The dimension is the same as the concatenation dimension + CreateDummyConcatDimTensorConst(split_dim_output, concat_dim, + tensorflow_graph); + string split_output = base + "split"; + auto* split_op = tensorflow_graph->add_node(); + split_op->set_op("Split"); + split_op->set_name(split_output); + *split_op->add_input() = split_dim_output; + *split_op->add_input() = biasadd_output; + (*split_op->mutable_attr())["T"].set_type(DT_FLOAT); + (*split_op->mutable_attr())["num_split"].set_i(4); // Split into four outputs + + // Activation functions and memory computations + const string tanh_0_output = base + "Tanh"; + auto* tanh_0_op = tensorflow_graph->add_node(); + tanh_0_op->set_op("Tanh"); + tanh_0_op->set_name(tanh_0_output); + *tanh_0_op->add_input() = split_output + ":1"; + (*tanh_0_op->mutable_attr())["T"].set_type(DT_FLOAT); + + const string sigmoid_1_output = base + "Sigmoid_1"; + auto* logistic_1_op = tensorflow_graph->add_node(); + logistic_1_op->set_op("Sigmoid"); + logistic_1_op->set_name(sigmoid_1_output); + *logistic_1_op->add_input() = split_output; + (*logistic_1_op->mutable_attr())["T"].set_type(DT_FLOAT); + + const string mul_1_output = base + "mul_1"; + auto* mul_1_op = tensorflow_graph->add_node(); + mul_1_op->set_op("Mul"); + mul_1_op->set_name(mul_1_output); + *mul_1_op->add_input() = sigmoid_1_output; + *mul_1_op->add_input() = tanh_0_output; + (*mul_1_op->mutable_attr())["T"].set_type(DT_FLOAT); + + const string sigmoid_0_output = base + "Sigmoid"; + auto* logistic_2_op = tensorflow_graph->add_node(); + logistic_2_op->set_op("Sigmoid"); + logistic_2_op->set_name(sigmoid_0_output); + *logistic_2_op->add_input() = split_output + ":2"; + (*logistic_2_op->mutable_attr())["T"].set_type(DT_FLOAT); + + const string sigmoid_2_output = base + "Sigmoid_2"; + auto* logistic_3_op = tensorflow_graph->add_node(); + logistic_3_op->set_op("Sigmoid"); + logistic_3_op->set_name(sigmoid_2_output); + *logistic_3_op->add_input() = split_output + ":3"; + (*logistic_3_op->mutable_attr())["T"].set_type(DT_FLOAT); + + const string mul_0_output = base + "mul"; + auto* mul_0_op = tensorflow_graph->add_node(); + mul_0_op->set_op("Mul"); + mul_0_op->set_name(mul_0_output); + *mul_0_op->add_input() = src_op.inputs[LstmCellOperator::PREV_STATE_INPUT]; + *mul_0_op->add_input() = sigmoid_0_output; + (*mul_0_op->mutable_attr())["T"].set_type(DT_FLOAT); + + const string add_1_output = 
src_op.outputs[LstmCellOperator::STATE_OUTPUT]; + auto* add_1_op = tensorflow_graph->add_node(); + add_1_op->set_op("Add"); + add_1_op->set_name(add_1_output); + *add_1_op->add_input() = mul_0_output; + *add_1_op->add_input() = mul_1_output; + (*add_1_op->mutable_attr())["T"].set_type(DT_FLOAT); + + const string tanh_1_output = base + "Tanh_1"; + auto* tanh_1_op = tensorflow_graph->add_node(); + tanh_1_op->set_op("Tanh"); + tanh_1_op->set_name(tanh_1_output); + *tanh_1_op->add_input() = add_1_output; + (*tanh_1_op->mutable_attr())["T"].set_type(DT_FLOAT); + + const string mul_2_output = src_op.outputs[LstmCellOperator::ACTIV_OUTPUT]; + auto* mul_2_op = tensorflow_graph->add_node(); + mul_2_op->set_op("Mul"); + mul_2_op->set_name(mul_2_output); + *mul_2_op->add_input() = tanh_1_output; + *mul_2_op->add_input() = sigmoid_2_output; + (*mul_2_op->mutable_attr())["T"].set_type(DT_FLOAT); +} + +void ConvertSpaceToBatchNDOperator(const Model& model, + const SpaceToBatchNDOperator& src_op, + GraphDef* tensorflow_graph) { + auto* new_op = tensorflow_graph->add_node(); + new_op->set_op("SpaceToBatchND"); + new_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 3); + *new_op->add_input() = src_op.inputs[0]; + *new_op->add_input() = src_op.inputs[1]; + *new_op->add_input() = src_op.inputs[2]; + const auto params_type = GetTensorFlowDataType(model, src_op.inputs[0]); + (*new_op->mutable_attr())["T"].set_type(params_type); + (*new_op->mutable_attr())["Tblock_shape"].set_type(DT_INT32); + (*new_op->mutable_attr())["Tpaddings"].set_type(DT_INT32); +} + +void ConvertBatchToSpaceNDOperator(const Model& model, + const BatchToSpaceNDOperator& src_op, + GraphDef* tensorflow_graph) { + auto* new_op = tensorflow_graph->add_node(); + new_op->set_op("BatchToSpaceND"); + new_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 3); + *new_op->add_input() = src_op.inputs[0]; + *new_op->add_input() = src_op.inputs[1]; + *new_op->add_input() = src_op.inputs[2]; + const auto params_type = GetTensorFlowDataType(model, src_op.inputs[0]); + (*new_op->mutable_attr())["T"].set_type(params_type); + (*new_op->mutable_attr())["Tblock_shape"].set_type(DT_INT32); + (*new_op->mutable_attr())["Tcrops"].set_type(DT_INT32); +} + +void ConvertPadOperator(const Model& model, const PadOperator& src_op, + GraphDef* tensorflow_graph) { + auto* new_op = tensorflow_graph->add_node(); + new_op->set_op("Pad"); + new_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 2); + *new_op->add_input() = src_op.inputs[0]; + *new_op->add_input() = src_op.inputs[1]; + + const auto params_type = GetTensorFlowDataType(model, src_op.inputs[0]); + (*new_op->mutable_attr())["T"].set_type(params_type); + + // Create the params tensor. 
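+  // Illustrative example (values are made up, not taken from any model): with
+  // left_padding = {1, 2} and right_padding = {3, 4}, the loop below emits the
+  // flat values {1, 3, 2, 4} into a DT_INT32 Const of shape [2, 2], so row i
+  // holds {left_padding[i], right_padding[i]} -- the layout the TensorFlow Pad
+  // op expects for its paddings input.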
+ auto* params_op = tensorflow_graph->add_node(); + params_op->set_op("Const"); + params_op->set_name(src_op.inputs[1]); + (*params_op->mutable_attr())["dtype"].set_type(DT_INT32); + auto* tensor = (*params_op->mutable_attr())["value"].mutable_tensor(); + tensor->set_dtype(DT_INT32); + + CHECK_EQ(src_op.left_padding.size(), src_op.right_padding.size()); + for (int i = 0; i < src_op.left_padding.size(); ++i) { + tensor->add_int_val(src_op.left_padding[i]); + tensor->add_int_val(src_op.right_padding[i]); + } + auto* shape = tensor->mutable_tensor_shape(); + shape->add_dim()->set_size(src_op.left_padding.size()); + shape->add_dim()->set_size(2); +} + +void CreateSliceInput(const string& input_name, const std::vector& values, + GraphDef* tensorflow_graph) { + auto* params_op = tensorflow_graph->add_node(); + params_op->set_op("Const"); + params_op->set_name(input_name); + (*params_op->mutable_attr())["dtype"].set_type(DT_INT32); + auto* tensor = (*params_op->mutable_attr())["value"].mutable_tensor(); + tensor->set_dtype(DT_INT32); + + for (int i = 0; i < values.size(); ++i) { + tensor->add_int_val(values[i]); + } + auto* shape = tensor->mutable_tensor_shape(); + shape->add_dim()->set_size(values.size()); +} + +void ConvertStridedSliceOperator(const Model& model, + const StridedSliceOperator& src_op, + GraphDef* tensorflow_graph) { + auto* new_op = tensorflow_graph->add_node(); + new_op->set_op("StridedSlice"); + new_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 4); + *new_op->add_input() = src_op.inputs[0]; + *new_op->add_input() = src_op.inputs[1]; + *new_op->add_input() = src_op.inputs[2]; + *new_op->add_input() = src_op.inputs[3]; + + const auto params_type = GetTensorFlowDataType(model, src_op.inputs[0]); + (*new_op->mutable_attr())["T"].set_type(params_type); + + (*new_op->mutable_attr())["Index"].set_type(DT_INT32); + (*new_op->mutable_attr())["begin_mask"].set_i(src_op.begin_mask); + (*new_op->mutable_attr())["ellipsis_mask"].set_i(src_op.ellipsis_mask); + (*new_op->mutable_attr())["end_mask"].set_i(src_op.end_mask); + (*new_op->mutable_attr())["new_axis_mask"].set_i(src_op.new_axis_mask); + (*new_op->mutable_attr())["shrink_axis_mask"].set_i(src_op.shrink_axis_mask); + + // Create tensors for start/stop indices and strides. + CreateSliceInput(src_op.inputs[1], src_op.start_indices, tensorflow_graph); + CreateSliceInput(src_op.inputs[2], src_op.stop_indices, tensorflow_graph); + CreateSliceInput(src_op.inputs[3], src_op.strides, tensorflow_graph); +} + +void ConvertSliceOperator(const Model& model, const SliceOperator& src_op, + GraphDef* tensorflow_graph) { + auto* new_op = tensorflow_graph->add_node(); + new_op->set_op("Slice"); + new_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 3); + *new_op->add_input() = src_op.inputs[0]; + *new_op->add_input() = src_op.inputs[1]; + *new_op->add_input() = src_op.inputs[2]; + + const auto params_type = GetTensorFlowDataType(model, src_op.inputs[0]); + (*new_op->mutable_attr())["T"].set_type(params_type); + (*new_op->mutable_attr())["Index"].set_type(DT_INT32); + + // Create tensors for begin and size inputs. 
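+  // Illustrative example (made-up values): for begin = {0, 2} and
+  // size = {1, -1}, the CreateSliceInput calls below emit two 1-D DT_INT32
+  // Const nodes, named after src_op.inputs[1] and src_op.inputs[2], holding
+  // [0, 2] and [1, -1]; the exported Slice node then consumes them as its
+  // begin and size inputs.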
+ CreateSliceInput(src_op.inputs[1], src_op.begin, tensorflow_graph); + CreateSliceInput(src_op.inputs[2], src_op.size, tensorflow_graph); +} + +void ConvertMeanOperator(const Model& model, const MeanOperator& src_op, + GraphDef* tensorflow_graph) { + auto* new_op = tensorflow_graph->add_node(); + new_op->set_op("Mean"); + new_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 2); + *new_op->add_input() = src_op.inputs[0]; + *new_op->add_input() = src_op.inputs[1]; + + const auto params_type = GetTensorFlowDataType(model, src_op.inputs[0]); + (*new_op->mutable_attr())["T"].set_type(params_type); + + // Create the params tensor. + auto* params_op = tensorflow_graph->add_node(); + params_op->set_op("Const"); + params_op->set_name(src_op.inputs[1]); + (*params_op->mutable_attr())["dtype"].set_type(DT_INT32); + auto* tensor = (*params_op->mutable_attr())["value"].mutable_tensor(); + tensor->set_dtype(DT_INT32); + + for (int i = 0; i < src_op.reduction_indices.size(); ++i) { + tensor->add_int_val(src_op.reduction_indices[i]); + } + auto* shape = tensor->mutable_tensor_shape(); + shape->add_dim()->set_size(src_op.reduction_indices.size()); +} + +void ConvertSqueezeOperator(const Model& model, const SqueezeOperator& src_op, + GraphDef* tensorflow_graph) { + auto* new_op = tensorflow_graph->add_node(); + new_op->set_op("Squeeze"); + new_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 1); + *new_op->add_input() = src_op.inputs[0]; + + const auto params_type = GetTensorFlowDataType(model, src_op.inputs[0]); + (*new_op->mutable_attr())["T"].set_type(params_type); + + auto& squeeze_dims = (*new_op->mutable_attr())["squeeze_dims"]; + for (int i : src_op.squeeze_dims) { + squeeze_dims.mutable_list()->add_i(i); + } +} + +void ConvertSubOperator(const Model& model, const SubOperator& src_op, + GraphDef* tensorflow_graph) { + auto* sub_op = tensorflow_graph->add_node(); + sub_op->set_op("Sub"); + sub_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 2); + *sub_op->add_input() = src_op.inputs[0]; + *sub_op->add_input() = src_op.inputs[1]; + const auto data_type = GetTensorFlowDataType(model, src_op.inputs[0]); + (*sub_op->mutable_attr())["T"].set_type(data_type); +} + +void ConvertTensorFlowMinimumOperator(const Model& model, + const TensorFlowMinimumOperator& src_op, + GraphDef* tensorflow_graph) { + auto* sub_op = tensorflow_graph->add_node(); + sub_op->set_op("Minimum"); + sub_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 2); + *sub_op->add_input() = src_op.inputs[0]; + *sub_op->add_input() = src_op.inputs[1]; + const auto data_type = GetTensorFlowDataType(model, src_op.inputs[0]); + (*sub_op->mutable_attr())["T"].set_type(data_type); +} + +void ConvertTensorFlowMaximumOperator(const Model& model, + const TensorFlowMaximumOperator& src_op, + GraphDef* tensorflow_graph) { + auto* sub_op = tensorflow_graph->add_node(); + sub_op->set_op("Maximum"); + sub_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 2); + *sub_op->add_input() = src_op.inputs[0]; + *sub_op->add_input() = src_op.inputs[1]; + const auto data_type = GetTensorFlowDataType(model, src_op.inputs[0]); + (*sub_op->mutable_attr())["T"].set_type(data_type); +} + +void ConvertOperator(const Model& model, const Operator& src_op, + GraphDef* tensorflow_graph) { + if (src_op.fused_activation_function != FusedActivationFunctionType::kNone) { + LOG(FATAL) + << "Unsupported: the input model has a fused activation function"; + } + + if (src_op.type == OperatorType::kConv) 
{
+    ConvertConvOperator(model, static_cast<const ConvOperator&>(src_op),
+                        tensorflow_graph);
+  } else if (src_op.type == OperatorType::kDepthwiseConv) {
+    ConvertDepthwiseConvOperator(
+        model, static_cast<const DepthwiseConvOperator&>(src_op),
+        tensorflow_graph);
+  } else if (src_op.type == OperatorType::kDepthToSpace) {
+    ConvertDepthToSpaceOperator(
+        model, static_cast<const DepthToSpaceOperator&>(src_op),
+        tensorflow_graph);
+  } else if (src_op.type == OperatorType::kSpaceToDepth) {
+    ConvertSpaceToDepthOperator(
+        model, static_cast<const SpaceToDepthOperator&>(src_op),
+        tensorflow_graph);
+  } else if (src_op.type == OperatorType::kFullyConnected) {
+    ConvertFullyConnectedOperator(
+        model, static_cast<const FullyConnectedOperator&>(src_op),
+        tensorflow_graph);
+  } else if (src_op.type == OperatorType::kAdd) {
+    ConvertAddOperator(model, static_cast<const AddOperator&>(src_op),
+                       tensorflow_graph);
+  } else if (src_op.type == OperatorType::kMul) {
+    ConvertMulOperator(model, static_cast<const MulOperator&>(src_op),
+                       tensorflow_graph);
+  } else if (src_op.type == OperatorType::kRelu) {
+    ConvertReluOperator(static_cast<const ReluOperator&>(src_op),
+                        tensorflow_graph);
+  } else if (src_op.type == OperatorType::kRelu1) {
+    ConvertRelu1Operator(static_cast<const Relu1Operator&>(src_op),
+                         tensorflow_graph);
+  } else if (src_op.type == OperatorType::kRelu6) {
+    ConvertRelu6Operator(static_cast<const Relu6Operator&>(src_op),
+                         tensorflow_graph);
+  } else if (src_op.type == OperatorType::kLogistic) {
+    ConvertLogisticOperator(static_cast<const LogisticOperator&>(src_op),
+                            tensorflow_graph);
+  } else if (src_op.type == OperatorType::kTanh) {
+    ConvertTanhOperator(static_cast<const TanhOperator&>(src_op),
+                        tensorflow_graph);
+  } else if (src_op.type == OperatorType::kL2Normalization) {
+    ConvertL2NormalizationOperator(
+        static_cast<const L2NormalizationOperator&>(src_op), tensorflow_graph);
+  } else if (src_op.type == OperatorType::kSoftmax) {
+    ConvertSoftmaxOperator(model, static_cast<const SoftmaxOperator&>(src_op),
+                           tensorflow_graph);
+  } else if (src_op.type == OperatorType::kLocalResponseNormalization) {
+    ConvertLocalResponseNormalizationOperator(
+        static_cast<const LocalResponseNormalizationOperator&>(src_op),
+        tensorflow_graph);
+  } else if (src_op.type == OperatorType::kLstmCell) {
+    ConvertLstmCellOperator(model,
+                            static_cast<const LstmCellOperator&>(src_op),
+                            tensorflow_graph);
+  } else if (src_op.type == OperatorType::kMaxPool) {
+    ConvertMaxPoolOperator(static_cast<const MaxPoolOperator&>(src_op),
+                           tensorflow_graph);
+  } else if (src_op.type == OperatorType::kAveragePool) {
+    ConvertAveragePoolOperator(static_cast<const AveragePoolOperator&>(src_op),
+                               tensorflow_graph);
+  } else if (src_op.type == OperatorType::kConcatenation) {
+    ConvertConcatenationOperator(
+        model, static_cast<const ConcatenationOperator&>(src_op),
+        tensorflow_graph);
+  } else if (src_op.type == OperatorType::kTensorFlowReshape) {
+    ConvertTensorFlowReshapeOperator(
+        model, static_cast<const TensorFlowReshapeOperator&>(src_op),
+        tensorflow_graph);
+  } else if (src_op.type == OperatorType::kL2Pool) {
+    ConvertL2PoolOperator(static_cast<const L2PoolOperator&>(src_op),
+                          tensorflow_graph);
+  } else if (src_op.type == OperatorType::kTensorFlowSquare) {
+    ConvertSquareOperator(static_cast<const TensorFlowSquareOperator&>(src_op),
+                          tensorflow_graph);
+  } else if (src_op.type == OperatorType::kTensorFlowSqrt) {
+    ConvertSqrtOperator(static_cast<const TensorFlowSqrtOperator&>(src_op),
+                        tensorflow_graph);
+  } else if (src_op.type == OperatorType::kTensorFlowSplit) {
+    ConvertSplitOperator(model,
+                         static_cast<const TensorFlowSplitOperator&>(src_op),
+                         tensorflow_graph);
+  } else if (src_op.type == OperatorType::kFakeQuant) {
+    ConvertFakeQuantOperator(static_cast<const FakeQuantOperator&>(src_op),
+                             tensorflow_graph);
+  } else if (src_op.type == OperatorType::kCast) {
+    ConvertCastOperator(model, static_cast<const CastOperator&>(src_op),
+                        tensorflow_graph);
+  } else if (src_op.type == OperatorType::kFloor) {
+    ConvertFloorOperator(model, static_cast<const FloorOperator&>(src_op),
+                         tensorflow_graph);
+  } else if (src_op.type == OperatorType::kGather) {
+    ConvertGatherOperator(model, static_cast<const GatherOperator&>(src_op),
+                          tensorflow_graph);
+  } else if (src_op.type == OperatorType::kResizeBilinear) {
+    ConvertResizeBilinearOperator(
+        model, static_cast<const ResizeBilinearOperator&>(src_op),
+        tensorflow_graph);
+  } else if (src_op.type == OperatorType::kSpaceToBatchND) {
+    ConvertSpaceToBatchNDOperator(
+        model, static_cast<const SpaceToBatchNDOperator&>(src_op),
+        tensorflow_graph);
+  } else if (src_op.type == OperatorType::kBatchToSpaceND) {
+    ConvertBatchToSpaceNDOperator(
+        model, static_cast<const BatchToSpaceNDOperator&>(src_op),
+        tensorflow_graph);
+  } else if (src_op.type == OperatorType::kPad) {
+    ConvertPadOperator(model, static_cast<const PadOperator&>(src_op),
+                       tensorflow_graph);
+  } else if (src_op.type == OperatorType::kStridedSlice) {
+    ConvertStridedSliceOperator(
+        model, static_cast<const StridedSliceOperator&>(src_op),
+        tensorflow_graph);
+  } else if (src_op.type == OperatorType::kMean) {
+    ConvertMeanOperator(model, static_cast<const MeanOperator&>(src_op),
+                        tensorflow_graph);
+  } else if (src_op.type == OperatorType::kSub) {
+    ConvertSubOperator(model, static_cast<const SubOperator&>(src_op),
+                       tensorflow_graph);
+  } else if (src_op.type == OperatorType::kTensorFlowMinimum) {
+    ConvertTensorFlowMinimumOperator(
+        model, static_cast<const TensorFlowMinimumOperator&>(src_op),
+        tensorflow_graph);
+  } else if (src_op.type == OperatorType::kTensorFlowMaximum) {
+    ConvertTensorFlowMaximumOperator(
+        model, static_cast<const TensorFlowMaximumOperator&>(src_op),
+        tensorflow_graph);
+  } else if (src_op.type == OperatorType::kSqueeze) {
+    ConvertSqueezeOperator(model, static_cast<const SqueezeOperator&>(src_op),
+                           tensorflow_graph);
+  } else if (src_op.type == OperatorType::kSlice) {
+    ConvertSliceOperator(model, static_cast<const SliceOperator&>(src_op),
+                         tensorflow_graph);
+  } else {
+    LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(src_op.type);
+  }
+}
+
+void AddPlaceholder(const string& name, GraphDef* tensorflow_graph) {
+  auto* placeholder = tensorflow_graph->add_node();
+  placeholder->set_op("Placeholder");
+  (*placeholder->mutable_attr())["dtype"].set_type(DT_FLOAT);
+  placeholder->set_name(name);
+}
+
+void AddPlaceholderForRNNState(const Model& model, const string& name, int size,
+                               GraphDef* tensorflow_graph) {
+  auto* placeholder = tensorflow_graph->add_node();
+  placeholder->set_op("Placeholder");
+  placeholder->set_name(name);
+  (*placeholder->mutable_attr())["dtype"].set_type(DT_FLOAT);
+
+  auto* shape = (*placeholder->mutable_attr())["shape"].mutable_shape();
+  const auto& state_array = *model.arrays.at(name);
+  if (state_array.has_shape()) {
+    const auto& state_shape = state_array.shape();
+    const int kDims = state_shape.dimensions_count();
+    for (int i = 0; i < kDims; ++i) {
+      shape->add_dim()->set_size(state_shape.dims(i));
+    }
+  } else {
+    shape->add_dim()->set_size(1);
+    shape->add_dim()->set_size(size);
+  }
+}
+
+void ExportTensorFlowGraphDefImplementation(const Model& model,
+                                            GraphDef* tensorflow_graph) {
+  for (const auto& input_array : model.flags.input_arrays()) {
+    AddPlaceholder(input_array.name(), tensorflow_graph);
+  }
+  for (const auto& rnn_state : model.flags.rnn_states()) {
+    AddPlaceholderForRNNState(model, rnn_state.state_array(), rnn_state.size(),
+                              tensorflow_graph);
+  }
+  for (const auto& op : model.operators) {
+    ConvertOperator(model, *op, tensorflow_graph);
+  }
+  // Generically export arrays that haven't been exported already
+  // by the above operators export. It's important that this comes
+  // after, as some operators need to export arrays that they reference
+  // in a specific way, rather than in the generic way done below.
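+  // For example, ConvertPadOperator, ConvertMeanOperator and
+  // ConvertStridedSliceOperator above already emit Const nodes for their
+  // paddings, reduction-indices and index inputs with the exact shapes those
+  // TensorFlow ops expect, so those arrays are not re-emitted by the generic
+  // export below.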
+ for (const auto& array_pair : model.arrays) { + const string& array_name = array_pair.first; + const auto& array = *array_pair.second; + if (array.buffer) { + switch (array.data_type) { + case ArrayDataType::kFloat: + ConvertFloatTensorConst(model, array_name, tensorflow_graph); + break; + case ArrayDataType::kInt32: + ConvertIntTensorConst(model, array_name, tensorflow_graph); + break; + default: + break; + } + } + } +} +} // namespace + +void ExportTensorFlowGraphDef(const Model& model, + string* output_file_contents) { + CHECK(output_file_contents->empty()); + GraphDef tensorflow_graph; + ExportTensorFlowGraphDefImplementation(model, &tensorflow_graph); + LogDumpGraphDef(kLogLevelModelChanged, "AT EXPORT", tensorflow_graph); + CHECK(tensorflow_graph.SerializeToString(output_file_contents)); +} +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/export_tensorflow.h b/tensorflow/contrib/lite/toco/export_tensorflow.h new file mode 100644 index 0000000000..eca9774576 --- /dev/null +++ b/tensorflow/contrib/lite/toco/export_tensorflow.h @@ -0,0 +1,27 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_EXPORT_TENSORFLOW_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_EXPORT_TENSORFLOW_H_ + +#include +#include "tensorflow/contrib/lite/toco/model.h" + +namespace toco { + +void ExportTensorFlowGraphDef(const Model& model, string* output_file_contents); + +} // namespace toco + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_EXPORT_TENSORFLOW_H_ diff --git a/tensorflow/contrib/lite/toco/format_port.h b/tensorflow/contrib/lite/toco/format_port.h new file mode 100644 index 0000000000..3bc3295d04 --- /dev/null +++ b/tensorflow/contrib/lite/toco/format_port.h @@ -0,0 +1,77 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// This file is used to provide equivalents of internal util::format::FormatF +// and util::format::AppendF. Unfortunately, type safety is not as good as a +// a full C++ example. +// TODO(aselle): When absl adds support for StrFormat, use that instead. 
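+// Illustrative usage (hypothetical variable names):
+//   string msg = toco::port::StringF("array %s has %d dims", name, dims);
+//   toco::port::AppendF(&msg, " and %d elements", num_elements);
+// Here %s may be driven by either string or std::string, thanks to the
+// IdentityOrConvertStringToRaw overloads below.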
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_FORMAT_PORT_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_FORMAT_PORT_H_
+
+#include "tensorflow/contrib/lite/toco/toco_types.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+
+namespace toco {
+namespace port {
+
+// Identity (default case)
+template <class T>
+T IdentityOrConvertStringToRaw(T foo) {
+  return foo;
+}
+
+// Overloaded case where we return std::string.
+inline const char* IdentityOrConvertStringToRaw(const std::string& foo) {
+  return foo.c_str();
+}
+
+#if defined(PLATFORM_GOOGLE)
+// Overloaded case where we return string.
+inline const char* IdentityOrConvertStringToRaw(const string& foo) {
+  return foo.c_str();
+}
+#endif  // PLATFORM_GOOGLE
+// Delegate to TensorFlow Appendf function until absl has an equivalent.
+template <typename... Args>
+inline void AppendFHelper(string* destination, const char* fmt,
+                          Args&&... args) {
+  tensorflow::strings::Appendf(destination, fmt, args...);
+}
+
+// Specialization for no argument format string (avoid security bug).
+inline void AppendFHelper(string* destination, const char* fmt) {
+  tensorflow::strings::Appendf(destination, "%s", fmt);
+}
+
+// Append formatted string (with format fmt and args args) to the string
+// pointed to by destination. fmt follows C printf semantics.
+// One departure is that %s can be driven by a std::string or string.
+template <typename... Args>
+inline void AppendF(string* destination, const char* fmt, Args&&... args) {
+  AppendFHelper(destination, fmt, IdentityOrConvertStringToRaw(args)...);
+}
+
+// Return formatted string (with format fmt and args args). fmt follows C
+// printf semantics. One departure is that %s can be driven by a std::string
+// or string.
+template <typename... Args>
+inline string StringF(const char* fmt, Args&&... args) {
+  string result;
+  AppendFHelper(&result, fmt, IdentityOrConvertStringToRaw(args)...);
+  return result;
+}
+
+}  // namespace port
+}  // namespace toco
+
+#endif  // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_FORMAT_PORT_H_
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc b/tensorflow/contrib/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc
new file mode 100644
index 0000000000..bf454c40c7
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc
@@ -0,0 +1,98 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +bool ConvertPureConvToDepthwise::Run(Model* model, std::size_t op_index) { + auto conv_it = model->operators.begin() + op_index; + if (conv_it->get()->type != OperatorType::kConv) { + return false; + } + const auto* conv_op = static_cast(conv_it->get()); + if (conv_op->stride_width != conv_op->stride_height) { + return false; + } + auto& weights_array = model->GetArray(conv_op->inputs[1]); + if (!weights_array.buffer) { + // Yield until the weights are resolved as a constant array. + return false; + } + if (weights_array.data_type != ArrayDataType::kFloat) { + return false; + } + if (weights_array.shape().dims(3) != 1) { + // Not a pure convolution: Conv does accumulation across the depth + // dimension. + return false; + } + // At this point we know we have a pure conv. Rewrite it as DepthwiseConv. + AddMessageF( + "%s is purely convolutional (input/weights depth is 1), replacing it by " + "a DepthwiseConv.", + LogName(*conv_op)); + auto* depthwiseconv_op = new DepthwiseConvOperator; + // Conv and DepthwiseConv take the same inputs + depthwiseconv_op->inputs = conv_op->inputs; + // Conv may have a 2nd output for im2col + depthwiseconv_op->outputs = {conv_op->outputs[0]}; + if (conv_op->outputs.size() > 1) { + // delete the im2col array. + model->arrays.erase(conv_op->outputs[1]); + } + depthwiseconv_op->fused_activation_function = + conv_op->fused_activation_function; + // Let PropagateFixedSizes recompute fixed padding, just in case some day it + // may be different for Conv vs DepthwiseConv. + depthwiseconv_op->padding.type = conv_op->padding.type; + depthwiseconv_op->stride_height = conv_op->stride_height; + depthwiseconv_op->stride_width = conv_op->stride_width; + depthwiseconv_op->depth_multiplier = weights_array.shape().dims(0); + // Replace the operator in the graph. + const auto depthwiseconv_it = + model->operators.emplace(conv_it, depthwiseconv_op); + conv_it = depthwiseconv_it + 1; + CHECK_EQ(conv_it->get(), conv_op); + model->operators.erase(conv_it); + // Shuffle the weights. + const auto& weights_shape = weights_array.shape(); + auto& weights_buffer = + weights_array.GetMutableBuffer(); + const std::vector& conv_weights_data = weights_buffer.data; + std::vector depthwise_conv_weights_data(conv_weights_data.size()); + const int depth = weights_shape.dims(0); + const int width = weights_shape.dims(1); + const int height = weights_shape.dims(2); + const int width_height = width * height; + for (int c = 0; c < depth; c++) { + for (int xy = 0; xy < width_height; xy++) { + depthwise_conv_weights_data[c + depth * xy] = + conv_weights_data[xy + width_height * c]; + } + } + *weights_array.mutable_shape()->mutable_dims() = {1, width, height, depth}; + weights_buffer.data = depthwise_conv_weights_data; + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/create_im2col_arrays.cc b/tensorflow/contrib/lite/toco/graph_transformations/create_im2col_arrays.cc new file mode 100644 index 0000000000..1735b51e5b --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/create_im2col_arrays.cc @@ -0,0 +1,69 @@ +/* Copyright 2017 The TensorFlow Authors. 
All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include + +#include "absl/strings/str_cat.h" +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +bool CreateIm2colArrays::Run(Model* model, std::size_t op_index) { + auto conv_it = model->operators.begin() + op_index; + if (conv_it->get()->type != OperatorType::kConv) { + return false; + } + auto* conv_op = static_cast(conv_it->get()); + if (conv_op->outputs.size() == 2) { + // We already have an im2col array + return false; + } + const auto& weights_array = *model->arrays[conv_op->inputs[1]]; + if (!weights_array.has_shape()) { + // We need to yield until weights dims have been resolved, because + // from the weights dims we determine whether an im2col array is + // needed. + return false; + } + const auto& weights_shape = weights_array.shape(); + const int kheight = weights_shape.dims(1); + const int kwidth = weights_shape.dims(2); + if (kwidth == 1 && kheight == 1 && conv_op->stride_width == 1 && + conv_op->stride_height == 1) { + // 1x1 unstrided conv does not need an im2col array. + return false; + } + + // Create the im2col array. + CHECK_EQ(conv_op->outputs.size(), 1); + const string& im2col_array_name = + AvailableArrayName(*model, conv_op->inputs[0] + "_im2col"); + model->GetOrCreateArray(im2col_array_name); + conv_op->outputs.push_back(im2col_array_name); + AddMessageF( + "Created an im2col array for %s, with %dx%d kernel and stride_width=%d, " + "stride_height=%d", + LogName(*conv_op), kwidth, kheight, conv_op->stride_width, + conv_op->stride_height); + + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/dequantize.cc b/tensorflow/contrib/lite/toco/graph_transformations/dequantize.cc new file mode 100644 index 0000000000..b89e3f5310 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/dequantize.cc @@ -0,0 +1,223 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +namespace { + +template +void DequantizeBuffer(Array* array) { + const auto old_data = array->GetBuffer().data; + array->buffer = nullptr; + array->data_type = ArrayDataType::kFloat; + auto& new_data = array->GetMutableBuffer().data; + new_data.resize(old_data.size()); + const auto& qparams = array->GetQuantizationParams(); + for (int i = 0; i < old_data.size(); i++) { + new_data[i] = qparams.scale * (old_data[i] - qparams.zero_point); + } +} + +std::vector>::iterator FindFirstOpWithInput( + Model* model, const string& array_name) { + for (auto it = model->operators.begin(); it != model->operators.end(); ++it) { + for (const auto& input : it->get()->inputs) { + if (input == array_name) { + return it; + } + } + } + return model->operators.end(); +} + +void ClearArrayQuantizationParams(const string& array_name, Model* model) { + auto* array = model->arrays.at(array_name).get(); + CHECK(array->quantization_params); + for (auto& input_array : *model->flags.mutable_input_arrays()) { + if (input_array.name() == array_name) { + auto& qparams = *array->quantization_params; + const double new_std_value = 1. / qparams.scale; + const double new_mean_value = qparams.zero_point; + if (input_array.has_std_value()) { + CHECK_LE(std::abs(new_std_value - input_array.std_value()), 0.001); + } else { + input_array.set_std_value(new_std_value); + } + if (input_array.has_mean_value()) { + CHECK_LE(std::abs(new_mean_value - input_array.mean_value()), 0.001); + } else { + input_array.set_mean_value(new_mean_value); + } + } + } + array->quantization_params = nullptr; +} + +bool DequantizeArray(const string& array_name, + GraphTransformation* transformation, Model* model) { + auto* array = model->arrays.at(array_name).get(); + if (!array->quantization_params) { + return false; + } + transformation->AddMessageF("Dequantizing array: %s", array_name); + + // Dequantize any buffer + if (array->buffer) { + if (array->data_type == ArrayDataType::kUint8) { + DequantizeBuffer(array); + } else if (array->data_type == ArrayDataType::kInt32) { + DequantizeBuffer(array); + } else { + LOG(FATAL) << "Unhandled data type"; + } + CHECK(array->data_type == ArrayDataType::kFloat); + CHECK(array->buffer->type == ArrayDataType::kFloat); + + // Clear quantization params, officially makes this a non-quantized array. + ClearArrayQuantizationParams(array_name, model); + return true; + } else { + array->data_type = ArrayDataType::kFloat; + } + + // Clear quantization params, officially makes this a non-quantized array. + ClearArrayQuantizationParams(array_name, model); + + if (array->buffer) { + return true; + } + + auto* op_outputting_array = GetOpWithOutput(*model, array_name); + if (op_outputting_array) { + if (op_outputting_array->type == OperatorType::kTensorFlowReshape) { + return true; + } + } + + // If there was no minmax info, we can return now. 
Indeed, + // the below only serves to create a FakeQuant node, but some arrays are + // quantized without MinMax (see the CHECK above) and that corresponds to + // places where a FakeQuant node is actually not wanted, because the + // quantization params are meant to be inferred in another way (e.g. bias + // vector for a Conv op, see their special-casing in quantize.cc). + if (!array->minmax) { + return true; + } + + // Determine whether to insert a FakeQuant before or after + // this array. + bool must_insert_fakequant_before = false; + bool must_insert_fakequant_after = false; + if (IsInputArray(*model, array_name)) { + must_insert_fakequant_after = true; + } + for (const string& output_array : model->flags.output_arrays()) { + if (array_name == output_array) { + must_insert_fakequant_before = true; + } + } + for (const auto& rnn_state : model->flags.rnn_states()) { + if (array_name == rnn_state.state_array()) { + must_insert_fakequant_after = true; + } + if (array_name == rnn_state.back_edge_source_array()) { + must_insert_fakequant_before = true; + } + } + CHECK(!(must_insert_fakequant_before && must_insert_fakequant_after)); + + // Create and insert the FakeQuant node + auto* fakequant_op = new FakeQuantOperator; + model->operators.emplace(FindFirstOpWithInput(model, array_name), + fakequant_op); + const string& new_array_name = AvailableArrayName(*model, array_name); + auto& new_array = model->GetOrCreateArray(new_array_name); + new_array.data_type = ArrayDataType::kFloat; + new_array.copy_shape(array->shape()); + new_array.GetOrCreateMinMax() = array->GetMinMax(); + fakequant_op->minmax.reset(new MinMax); + *fakequant_op->minmax = array->GetMinMax(); + if (must_insert_fakequant_before) { + for (const auto& op : model->operators) { + for (string& output : op->outputs) { + if (output == array_name) { + output = new_array_name; + } + } + } + fakequant_op->inputs = {new_array_name}; + fakequant_op->outputs = {array_name}; + } else { + for (const auto& op : model->operators) { + for (string& input : op->inputs) { + if (input == array_name) { + input = new_array_name; + } + } + } + fakequant_op->inputs = {array_name}; + fakequant_op->outputs = {new_array_name}; + } + return true; +} + +} // namespace + +bool Dequantize::Run(Model* model, std::size_t op_index) { + const auto op_it = model->operators.begin() + op_index; + auto* op = op_it->get(); + + if (op->type == OperatorType::kDequantize) { + auto& input_array = model->GetArray(op->inputs[0]); + if (input_array.data_type == ArrayDataType::kFloat) { + return false; + } + if (input_array.final_data_type != ArrayDataType::kFloat) { + return false; + } + input_array.data_type = ArrayDataType::kFloat; + input_array.quantization_params = nullptr; + auto& output_array = model->GetArray(op->outputs[0]); + output_array.data_type = ArrayDataType::kFloat; + output_array.quantization_params = nullptr; + return RemoveTrivialPassthroughOp(this, model, op_index); + } + + std::vector arrays; + for (const string& input : op->inputs) { + arrays.push_back(input); + } + for (const string& output : op->outputs) { + arrays.push_back(output); + } + bool changed = false; + for (const string& array : arrays) { + changed |= DequantizeArray(array, this, model); + } + + return changed; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/drop_fake_quant.cc b/tensorflow/contrib/lite/toco/graph_transformations/drop_fake_quant.cc new file mode 100644 index 0000000000..fea360740f --- /dev/null +++ 
b/tensorflow/contrib/lite/toco/graph_transformations/drop_fake_quant.cc @@ -0,0 +1,56 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +bool DropFakeQuant::Run(Model* model, std::size_t op_index) { + const auto fakequant_it = model->operators.begin() + op_index; + auto* fakequant_base_op = fakequant_it->get(); + if (fakequant_base_op->type != OperatorType::kFakeQuant) { + return false; + } + auto* fakequant_op = static_cast(fakequant_base_op); + + if (!fakequant_op->minmax) { + return false; + } + + const auto& output_array = model->GetArray(fakequant_op->outputs[0]); + if (!output_array.minmax) { + return false; + } + + // Drop min/max inputs + for (int i = 1; i < fakequant_op->inputs.size(); i++) { + if (CountOpsWithInput(*model, fakequant_op->inputs[i]) == 1) { + model->arrays.erase(fakequant_op->inputs[i]); + } + } + fakequant_op->inputs.resize(1); + + return RemoveTrivialPassthroughOp(this, model, op_index); +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/drop_im2col_arrays.cc b/tensorflow/contrib/lite/toco/graph_transformations/drop_im2col_arrays.cc new file mode 100644 index 0000000000..a3ed6663bc --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/drop_im2col_arrays.cc @@ -0,0 +1,42 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +bool DropIm2colArrays::Run(Model* model, std::size_t op_index) { + auto conv_it = model->operators.begin() + op_index; + if (conv_it->get()->type != OperatorType::kConv) { + return false; + } + auto* conv_op = static_cast(conv_it->get()); + if (conv_op->outputs.size() < 2) { + // Conv op does not have im2col. + return false; + } + + // Drop the im2col array. + CHECK_EQ(conv_op->outputs.size(), 2); + model->arrays.erase(conv_op->outputs[1]); + conv_op->outputs.resize(1); + AddMessageF("Dropped an im2col array for %s", LogName(*conv_op)); + + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/ensure_bias_vectors.cc b/tensorflow/contrib/lite/toco/graph_transformations/ensure_bias_vectors.cc new file mode 100644 index 0000000000..badefeca88 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/ensure_bias_vectors.cc @@ -0,0 +1,57 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +namespace { + +bool ProcessLinearOperator(Model* model, Operator* op) { + if (op->inputs.size() >= 3) { + return false; + } + const string& output_name = op->outputs[0]; + const string& bias_name = AvailableArrayName(*model, output_name + "_bias"); + op->inputs.push_back(bias_name); + DCHECK_EQ(op->inputs.size(), 3); + auto& bias_array = model->GetOrCreateArray(bias_name); + bias_array.data_type = ArrayDataType::kFloat; + + return true; +} +} // namespace + +bool EnsureBiasVectors::Run(Model* model, std::size_t op_index) { + auto* op = model->operators[op_index].get(); + if (op->type == OperatorType::kConv || + op->type == OperatorType::kDepthwiseConv || + op->type == OperatorType::kFullyConnected) { + if (ProcessLinearOperator(model, op)) { + AddMessageF("Added bias vector to %s", LogName(*op)); + return true; + } + } + return false; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/fuse_activation_functions.cc b/tensorflow/contrib/lite/toco/graph_transformations/fuse_activation_functions.cc new file mode 100644 index 0000000000..7a86510025 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/fuse_activation_functions.cc @@ -0,0 +1,98 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/runtime/types.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+bool FuseActivationFunctions::Run(Model* model, std::size_t op_index) {
+  const auto ac_it = model->operators.begin() + op_index;
+  const auto* ac_op = ac_it->get();
+
+  if (ac_op->type != OperatorType::kRelu6 &&
+      ac_op->type != OperatorType::kRelu1 &&
+      ac_op->type != OperatorType::kRelu) {
+    return false;
+  }
+
+  // Find the op producing the array passed to this activation function
+  Operator* op = GetOpWithOutput(*model, ac_op->inputs[0]);
+
+  if (!op) return false;
+
+  if (CountTrueOutputs(*model, *op) > 1) {
+    AddMessageF(
+        "Not fusing activation function into %s because it has more than one "
+        "consumed output",
+        LogName(*op));
+    return false;
+  }
+
+  CHECK_EQ(op->outputs[0], ac_op->inputs[0]);
+
+  int count_ops_consuming_output = CountOpsWithInput(*model, ac_op->inputs[0]);
+  DCHECK_GE(count_ops_consuming_output, 1);
+  if (count_ops_consuming_output > 1) {
+    AddMessageF(
+        "Not fusing activation function into %s because it is consumed by "
+        "more than 1 other operator",
+        LogName(*op));
+    return false;
+  }
+
+  if (op->fused_activation_function != FusedActivationFunctionType::kNone) {
+    AddMessageF(
+        "Not fusing activation function into %s because it already has a "
+        "fused activation function",
+        LogName(*op));
+    return false;
+  }
+
+  // TODO(dkalenichenko): Many ops don't support activation function fusing.
+  // Switch to the whitelist approach instead.
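+  // When fusing does proceed below, e.g. a Conv whose output feeds a Relu6
+  // becomes a single Conv with fused_activation_function = kRelu6, and both
+  // the activation op and the intermediate array between the two ops are
+  // removed from the model (illustrative example; Relu and Relu1 behave the
+  // same way).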
+ if (op->type == OperatorType::kConcatenation || + op->type == OperatorType::kSlice) { + AddMessageF( + "Not fusing activation function because the %s op doesn't support it", + LogName(*op)); + return false; + } + + AddMessageF("Fusing activation function %s into the preceding %s", + LogName(*ac_op), LogName(*op)); + if (ac_op->type == OperatorType::kRelu6) { + op->fused_activation_function = FusedActivationFunctionType::kRelu6; + } else if (ac_op->type == OperatorType::kRelu1) { + op->fused_activation_function = FusedActivationFunctionType::kRelu1; + } else if (ac_op->type == OperatorType::kRelu) { + op->fused_activation_function = FusedActivationFunctionType::kRelu; + } else { + LOG(FATAL) << "Unhandled activation function type"; + } + model->arrays.erase(ac_op->inputs[0]); + op->outputs[0] = ac_op->outputs[0]; + model->operators.erase(ac_it); + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_following_affine.cc b/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_following_affine.cc new file mode 100644 index 0000000000..4619d8bbee --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_following_affine.cc @@ -0,0 +1,300 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/runtime/types.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +namespace { + +void FuseAddOrSubParamsIntoFollowingAffine(Model* model, Operator* following_op, + const Operator* add_or_sub_op, + int index_of_constant_input) { + CHECK(add_or_sub_op->type == OperatorType::kAdd || + add_or_sub_op->type == OperatorType::kSub); + CHECK(index_of_constant_input == 0 || index_of_constant_input == 1); + // If the op is a subtraction, the constant input should be the right hand + // side. + // This should have been checked before this point. + CHECK(add_or_sub_op->type != OperatorType::kSub || + index_of_constant_input == 1); + if (following_op->inputs.size() < 3) { + LOG(FATAL) << "Missing bias parameter"; + } + const auto& weights = model->GetArray(following_op->inputs[1]); + auto& bias = model->GetArray(following_op->inputs[2]); + bias.minmax = nullptr; + const auto& operand = + model->GetArray(add_or_sub_op->inputs[index_of_constant_input]); + // We're only supporting the case of a scalar operand. Should have + // been checked earlier. + CHECK_EQ(RequiredBufferSizeForShape(operand.shape()), 1); + + const float scalar_operand = + operand.GetBuffer().data[0]; + // At this point we reduce the case of subtraction to that of addition + // by negating the operand. 
+ float add_scalar_operand = 0.f; + if (add_or_sub_op->type == OperatorType::kAdd) { + add_scalar_operand = scalar_operand; + } else if (add_or_sub_op->type == OperatorType::kSub && + index_of_constant_input == 1) { + add_scalar_operand = -scalar_operand; + } else { + LOG(FATAL) << "Should not get here"; + } + // From here on we are fusing an addition. add_or_sub_op->type does not + // matter anymore. + + const Shape& weights_shape = weights.shape(); + const Shape& bias_shape = bias.shape(); + const auto& weights_buffer = weights.GetBuffer(); + const float* const weights_data = weights_buffer.data.data(); + auto& bias_buffer = bias.GetMutableBuffer(); + float* const bias_data = bias_buffer.data.data(); + + if (following_op->type == OperatorType::kConv || + following_op->type == OperatorType::kFullyConnected) { + const int output_depth = weights_shape.dims(0); + // TODO(b/62904716): Bias array should become 1-D when padding removed. + CHECK_EQ(output_depth, bias_shape.dims(bias_shape.dimensions_count() - 1)); + const int weights_size = RequiredBufferSizeForShape(weights_shape); + const int weights_per_depth = weights_size / output_depth; + CHECK_EQ(weights_size, weights_per_depth * output_depth); + + for (int d = 0; d < output_depth; d++) { + float accumulation = 0; + for (int i = 0; i < weights_per_depth; i++) { + accumulation += + add_scalar_operand * weights_data[d * weights_per_depth + i]; + } + bias_data[d] += accumulation; + } + } else if (following_op->type == OperatorType::kDepthwiseConv) { + const int output_depth = + weights_shape.dims(weights_shape.dimensions_count() - 1); + const int weights_size = RequiredBufferSizeForShape(weights_shape); + const int weights_per_depth = weights_size / output_depth; + CHECK_EQ(weights_size, weights_per_depth * output_depth); + + for (int c = 0; c < output_depth; c++) { + float accumulation = 0; + for (int k = 0; k < weights_per_depth; k++) { + accumulation += add_scalar_operand * weights_data[k * output_depth + c]; + } + bias_data[c] += accumulation; + } + } else { + LOG(FATAL) << "Should not get here."; + } +} + +void FuseMulOrDivParamsIntoFollowingAffine(Model* model, Operator* following_op, + const Operator* mul_or_div_op, + int index_of_constant_input) { + CHECK(mul_or_div_op->type == OperatorType::kMul || + mul_or_div_op->type == OperatorType::kDiv); + CHECK(index_of_constant_input == 0 || index_of_constant_input == 1); + // If the op is a division, the constant input should be the right hand side. + // This should have been checked before this point. + CHECK(mul_or_div_op->type != OperatorType::kDiv || + index_of_constant_input == 1); + const auto& weights_name = following_op->inputs[1]; + const auto& bias_name = following_op->inputs[2]; + auto& weights = model->GetArray(weights_name); + DropMinMax(model, weights_name); + DropMinMax(model, bias_name); + const auto& operand = + model->GetArray(mul_or_div_op->inputs[index_of_constant_input]); + // We're only supporting the case of a scalar operand. Should have + // been checked earlier. 
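The two branches above differ only in weight layout: Conv and FullyConnected keep the output channel as the leading dimension, so one channel's weights are contiguous, while DepthwiseConv keeps the channel as the trailing dimension, so they are strided by the output depth. A small standalone sketch of the two per-channel sums, with invented buffers:

    // Per-output-channel weight sums under the two layouts (hypothetical data).
    #include <cstdio>
    #include <vector>

    int main() {
      const int output_depth = 2, weights_per_depth = 3;
      // Layout A (Conv/FullyConnected): [output_depth, weights_per_depth].
      const std::vector<float> w_conv = {1, 2, 3,   4, 5, 6};
      // Layout B (DepthwiseConv): [weights_per_depth, output_depth], channel last.
      const std::vector<float> w_dw = {1, 4,  2, 5,  3, 6};  // same values, transposed

      for (int d = 0; d < output_depth; ++d) {
        float sum_conv = 0, sum_dw = 0;
        for (int i = 0; i < weights_per_depth; ++i) {
          sum_conv += w_conv[d * weights_per_depth + i];  // contiguous row
          sum_dw += w_dw[i * output_depth + d];           // strided column
        }
        std::printf("channel %d: %g vs %g\n", d, sum_conv, sum_dw);  // identical sums
      }
    }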
+  CHECK_EQ(RequiredBufferSizeForShape(operand.shape()), 1);
+
+  const float scalar_operand =
+      operand.GetBuffer<ArrayDataType::kFloat>().data[0];
+
+  float* weights_data =
+      weights.GetMutableBuffer<ArrayDataType::kFloat>().data.data();
+  const int weights_size = RequiredBufferSizeForShape(weights.shape());
+  for (int i = 0; i < weights_size; i++) {
+    if (mul_or_div_op->type == OperatorType::kMul) {
+      weights_data[i] *= scalar_operand;
+    } else if (mul_or_div_op->type == OperatorType::kDiv) {
+      weights_data[i] /= scalar_operand;
+    } else {
+      LOG(FATAL) << "Should not get here";
+    }
+  }
+}
+
+}  // namespace
+
+bool FuseBinaryIntoFollowingAffine::Run(Model* model, std::size_t op_index) {
+  const auto binary_it = model->operators.begin() + op_index;
+  auto* binary_op = binary_it->get();
+  if (binary_op->type != OperatorType::kAdd &&
+      binary_op->type != OperatorType::kMul &&
+      binary_op->type != OperatorType::kSub &&
+      binary_op->type != OperatorType::kDiv) {
+    return false;
+  }
+
+  CHECK_EQ(binary_op->inputs.size(), 2);
+
+  // We can only fuse a binary op when its two operands break down as follows:
+  // 1. One operand is the (variable) output of a typical affine (linear plus
+  //    bias) op of a finite list of possible types: at the moment Conv,
+  //    DepthwiseConv and FullyConnected are supported.
+  // 2. The other operand is a constant param array.
+  const bool is_input_constant[2] = {
+      IsConstantParameterArray(*model, binary_op->inputs[0]),
+      IsConstantParameterArray(*model, binary_op->inputs[1]),
+  };
+  if (!is_input_constant[0] && !is_input_constant[1]) {
+    // Neither input is constant, so there is nothing to fuse into a constant.
+    return false;
+  }
+  if (is_input_constant[0] && is_input_constant[1]) {
+    // Both inputs are constants. That's a job for constants
+    // propagation, not for us to handle here.
+    return false;
+  }
+  const int index_of_constant_input = is_input_constant[0] ? 0 : 1;
+  const int index_of_variable_input = is_input_constant[0] ? 1 : 0;
+  CHECK(is_input_constant[index_of_constant_input]);
+  CHECK(!is_input_constant[index_of_variable_input]);
+
+  // For division, we can only fuse if the denominator is constant.
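Multiplication and division fold into the weights rather than the bias, since W*(c*x) + b = (c*W)*x + b, and x/c is just multiplication by 1/c, which is why only a constant denominator can be fused. A standalone numeric check with made-up values (not toco code):

    // Minimal check: W*(c*x) + b == (c*W)*x + b for a scalar c.
    #include <cassert>
    #include <cmath>

    int main() {
      const float W[2][2] = {{1.f, -2.f}, {0.5f, 3.f}};
      const float b[2] = {0.1f, -0.4f};
      const float x[2] = {2.f, 5.f};
      const float c = 0.5f;  // scalar operand of the Mul being fused

      for (int d = 0; d < 2; ++d) {
        float ref = b[d], fused = b[d];
        for (int i = 0; i < 2; ++i) {
          ref += W[d][i] * (c * x[i]);    // Mul applied to the input first
          fused += (c * W[d][i]) * x[i];  // scale folded into the weights
        }
        assert(std::fabs(ref - fused) < 1e-5f);
      }
      return 0;
    }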
+ if (binary_op->type == OperatorType::kDiv) { + if (index_of_constant_input != 1) { + AddMessageF("Not fusing %s because the denominator is not constant", + LogName(*binary_op)); + return false; + } + } + + const auto& operand_shape = + model->GetArray(binary_op->inputs[index_of_constant_input]).shape(); + for (const auto& dim : operand_shape.dims()) { + if (dim > 1) { + AddMessageF( + "Not fusing %s into the following affine op, because we only know " + "how to do so when the constant operand is a scalar", + LogName(*binary_op)); + return false; + } + } + + if (binary_op->fused_activation_function != + FusedActivationFunctionType::kNone) { + AddMessageF("Not fusing %s because it has a fused activation function", + LogName(*binary_op)); + return false; + } + + Operator* following_op = GetOpWithInput(*model, binary_op->outputs[0]); + + if (!following_op) { + AddMessageF( + "Not fusing %s because it is not consumed by exactly one other op", + LogName(*binary_op)); + return false; + } + + if (following_op->type != OperatorType::kConv && + following_op->type != OperatorType::kFullyConnected && + following_op->type != OperatorType::kDepthwiseConv) { + AddMessageF( + "Not fusing %s because the following %s is not of one of the supported " + "types", + LogName(*binary_op), LogName(*following_op)); + return false; + } + + if (following_op->inputs.size() < 3) { + AddMessageF( + "Not fusing %s because the following %s does not have a bias vector", + LogName(*following_op), LogName(*binary_op)); + return false; + } + + const auto& weights = model->GetArray(following_op->inputs[1]); + const auto& bias = model->GetArray(following_op->inputs[2]); + if (!weights.buffer || !bias.buffer) { + AddMessageF( + "Not fusing %s because the following %s has non-constant weights or " + "bias arrays", + LogName(*binary_op), LogName(*following_op)); + return false; + } + + // Try to fuse the binary params into the following op's params + if (binary_op->type == OperatorType::kAdd || + binary_op->type == OperatorType::kSub) { + if (following_op->type == OperatorType::kConv) { + if (static_cast(following_op)->padding.type != + PaddingType::kValid) { + AddMessageF( + "Not fusing %s because the following %s does not use VALID padding", + LogName(*binary_op), LogName(*following_op)); + return false; + } + } + if (following_op->type == OperatorType::kDepthwiseConv) { + if (static_cast(following_op)->padding.type != + PaddingType::kValid) { + AddMessageF( + "Not fusing %s because the following %s does not use VALID padding", + LogName(*binary_op), LogName(*following_op)); + return false; + } + } + FuseAddOrSubParamsIntoFollowingAffine(model, following_op, binary_op, + index_of_constant_input); + } else if (binary_op->type == OperatorType::kMul || + binary_op->type == OperatorType::kDiv) { + FuseMulOrDivParamsIntoFollowingAffine(model, following_op, binary_op, + index_of_constant_input); + } else { + LOG(FATAL) << "should not get here"; + } + + AddMessageF("Fusing %s into the following %s", LogName(*binary_op), + LogName(*following_op)); + + model->arrays.erase(binary_op->outputs[0]); + following_op->inputs[0] = binary_op->inputs[index_of_variable_input]; + const auto& old_constant_param_name = + binary_op->inputs[index_of_constant_input]; + CHECK(IsConstantParameterArray(*model, old_constant_param_name)); + if (CountOpsWithInput(*model, old_constant_param_name) == 1) { + model->arrays.erase(old_constant_param_name); + } + model->operators.erase(binary_it); + return true; +} + +} // namespace toco diff --git 
a/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc b/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc new file mode 100644 index 0000000000..8948653ec3 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc @@ -0,0 +1,326 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/runtime/types.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +namespace { + +void FuseAddOrSubParamsIntoPrecedingAffine(Model* model, Operator* preceding_op, + const Operator* add_or_sub_op, + int index_of_constant_input) { + CHECK(add_or_sub_op->type == OperatorType::kAdd || + add_or_sub_op->type == OperatorType::kSub); + CHECK(index_of_constant_input == 0 || index_of_constant_input == 1); + if (preceding_op->inputs.size() < 3) { + LOG(FATAL) << "Missing bias parameter"; + } + auto& bias = model->GetArray(preceding_op->inputs[2]); + bias.minmax = nullptr; + const auto& operand = + model->GetArray(add_or_sub_op->inputs[index_of_constant_input]); + + const Shape& bias_shape = bias.shape(); + const Shape& operand_shape = operand.shape(); + auto& bias_buffer = bias.GetMutableBuffer(); + float* const bias_data = bias_buffer.data.data(); + const auto& operand_buffer = operand.GetBuffer(); + const float* const operand_data = operand_buffer.data.data(); + + // TODO(b/62904716): Bias array should become 1-D when padding removed. + const int depth = bias_shape.dims(bias_shape.dimensions_count() - 1); + CHECK_EQ(depth, operand_shape.dims(operand_shape.dimensions_count() - 1)); + + enum class OpType { BiasPlusOperand, BiasMinusOperand, OperandMinusBias }; + + const OpType optype = (add_or_sub_op->type == OperatorType::kAdd) + ? OpType::BiasPlusOperand + : (index_of_constant_input == 1) + ? 
OpType::BiasMinusOperand + : OpType::OperandMinusBias; + + for (int i = 0; i < depth; i++) { + float& bias_val = bias_data[i]; + const float operand_val = operand_data[i]; + if (optype == OpType::BiasPlusOperand) { + bias_val += operand_val; + } else if (optype == OpType::BiasMinusOperand) { + bias_val -= operand_val; + } else if (optype == OpType::OperandMinusBias) { + bias_val = operand_val - bias_val; + } else { + LOG(FATAL) << "Should not get here."; + } + } +} + +void FuseMulOrDivParamsIntoPrecedingAffine(Model* model, Operator* preceding_op, + const Operator* mul_or_div_op, + int index_of_constant_input) { + CHECK(mul_or_div_op->type == OperatorType::kMul || + mul_or_div_op->type == OperatorType::kDiv); + CHECK(index_of_constant_input == 0 || index_of_constant_input == 1); + // If the op is a division, the constant input should be the right hand side. + // This should have been checked before this point. + CHECK(mul_or_div_op->type != OperatorType::kDiv || + index_of_constant_input == 1); + if (preceding_op->inputs.size() < 3) { + LOG(FATAL) << "Missing bias parameter"; + } + const auto& weights_name = preceding_op->inputs[1]; + const auto& bias_name = preceding_op->inputs[2]; + auto& weights = model->GetArray(weights_name); + DropMinMax(model, weights_name); + auto& bias = model->GetArray(bias_name); + DropMinMax(model, bias_name); + const auto& operand = + model->GetArray(mul_or_div_op->inputs[index_of_constant_input]); + + const Shape& weights_shape = weights.shape(); + const Shape& bias_shape = bias.shape(); + const Shape& operand_shape = operand.shape(); + auto& weights_buffer = weights.GetMutableBuffer(); + float* const weights_data = weights_buffer.data.data(); + auto& bias_buffer = bias.GetMutableBuffer(); + float* const bias_data = bias_buffer.data.data(); + const auto& operand_buffer = operand.GetBuffer(); + const float* const operand_data = operand_buffer.data.data(); + + // We support broadcasting the operand along the depth dimension, + // when the operand's depth is 1. 
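When the binary op follows the affine op, as handled here, a per-channel addition simply shifts the bias: (W*x + b) + c = W*x + (b + c) channel by channel. A minimal check of that bias update, using invented values rather than toco arrays:

    // Minimal check: (W*x + b)[d] + c[d] == (W*x + (b + c))[d] for each channel d.
    #include <cassert>
    #include <cmath>

    int main() {
      const float W[2][2] = {{1.f, -2.f}, {0.5f, 3.f}};
      const float b[2] = {0.1f, -0.4f};
      const float c[2] = {0.25f, -1.5f};  // per-channel operand of the Add being fused
      const float x[2] = {2.f, 5.f};

      for (int d = 0; d < 2; ++d) {
        float affine = b[d];
        for (int i = 0; i < 2; ++i) affine += W[d][i] * x[i];
        const float ref = affine + c[d];  // Add applied after the affine op
        float fused = b[d] + c[d];        // bias pre-shifted instead
        for (int i = 0; i < 2; ++i) fused += W[d][i] * x[i];
        assert(std::fabs(ref - fused) < 1e-5f);
      }
      return 0;
    }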
+ int operand_channel_increment = 0; + if (operand_shape.dimensions_count() >= 1 && + operand_shape.dims(operand_shape.dimensions_count() - 1) == + bias_shape.dims(bias_shape.dimensions_count() - 1)) { + operand_channel_increment = 1; + } else if (operand_shape.dimensions_count() == 0 || + operand_shape.dims(operand_shape.dimensions_count() - 1) == 1) { + operand_channel_increment = 0; + } else { + LOG(FATAL) << "Operand shape mismatch."; + } + + int output_depth; + + if (preceding_op->type == OperatorType::kConv || + preceding_op->type == OperatorType::kFullyConnected) { + output_depth = weights_shape.dims(0); + } else if (preceding_op->type == OperatorType::kDepthwiseConv) { + output_depth = weights_shape.dims(weights_shape.dimensions_count() - 1); + } else { + LOG(FATAL) << "Should not get here"; + } + + const int weights_size = RequiredBufferSizeForShape(weights_shape); + const int weights_per_depth = weights_size / output_depth; + CHECK_EQ(weights_size, weights_per_depth * output_depth); + + int operand_channel = 0; + for (int c = 0; c < output_depth; c++) { + if (mul_or_div_op->type == OperatorType::kMul) { + bias_data[c] *= operand_data[operand_channel]; + } else if (mul_or_div_op->type == OperatorType::kDiv) { + bias_data[c] /= operand_data[operand_channel]; + } else { + LOG(FATAL) << "Should not get here"; + } + if (preceding_op->type == OperatorType::kConv || + preceding_op->type == OperatorType::kFullyConnected) { + for (int i = 0; i < weights_per_depth; i++) { + if (mul_or_div_op->type == OperatorType::kMul) { + weights_data[c * weights_per_depth + i] *= + operand_data[operand_channel]; + } else if (mul_or_div_op->type == OperatorType::kDiv) { + weights_data[c * weights_per_depth + i] /= + operand_data[operand_channel]; + } else { + LOG(FATAL) << "Should not get here"; + } + } + } else if (preceding_op->type == OperatorType::kDepthwiseConv) { + for (int k = 0; k < weights_per_depth; k++) { + if (mul_or_div_op->type == OperatorType::kMul) { + weights_data[k * output_depth + c] *= operand_data[operand_channel]; + } else if (mul_or_div_op->type == OperatorType::kDiv) { + weights_data[k * output_depth + c] /= operand_data[operand_channel]; + } else { + LOG(FATAL) << "Should not get here"; + } + } + } else { + LOG(FATAL) << "Should not get here"; + } + operand_channel += operand_channel_increment; + } +} +} // namespace + +bool FuseBinaryIntoPrecedingAffine::Run(Model* model, std::size_t op_index) { + const auto binary_it = model->operators.begin() + op_index; + const auto* binary_op = binary_it->get(); + if (binary_op->type != OperatorType::kAdd && + binary_op->type != OperatorType::kMul && + binary_op->type != OperatorType::kSub && + binary_op->type != OperatorType::kDiv) { + return false; + } + + CHECK_EQ(binary_op->inputs.size(), 2); + + // We only can fuse an binary when the two operands break down as follows: + // 1. One operand is the (variable) output of a typical affine (linear plus + // bias) + // op of a finite list of possible types: at the moment Conv, + // DepthwiseConv and + // FullyConnected are supported. + // 2. The other operand is a constant param array. + const bool is_input_constant[2] = { + IsConstantParameterArray(*model, binary_op->inputs[0]), + IsConstantParameterArray(*model, binary_op->inputs[1]), + }; + if (!is_input_constant[0] && !is_input_constant[1]) { + // Neither input is constant, so nothing we can fuse into a constant. + return false; + } + if (is_input_constant[0] && is_input_constant[1]) { + // Both inputs are constants. 
That's a job for constants + // propagation, not for us to handle here. + return false; + } + const int index_of_constant_input = is_input_constant[0] ? 0 : 1; + const int index_of_variable_input = is_input_constant[0] ? 1 : 0; + CHECK(is_input_constant[index_of_constant_input]); + CHECK(!is_input_constant[index_of_variable_input]); + + // For division, we can only fuse if the denominator is constant. + if (binary_op->type == OperatorType::kDiv) { + if (index_of_constant_input != 1) { + AddMessageF("Not fusing %s because the denominator is not constant", + LogName(*binary_op)); + return false; + } + } + + Operator* preceding_op = + GetOpWithOutput(*model, binary_op->inputs[index_of_variable_input]); + if (!preceding_op) { + AddMessageF("Not fusing %s because it is not the output of another op", + LogName(*binary_op)); + return false; + } + + for (const string& output_array : model->flags.output_arrays()) { + if (preceding_op->outputs[0] == output_array) { + return false; + } + } + + if (preceding_op->type != OperatorType::kConv && + preceding_op->type != OperatorType::kFullyConnected && + preceding_op->type != OperatorType::kDepthwiseConv) { + AddMessageF( + "Not fusing %s because the preceding %s is not of one of the supported " + "types", + LogName(*binary_op), LogName(*preceding_op)); + return false; + } + + if (preceding_op->fused_activation_function != + FusedActivationFunctionType::kNone) { + AddMessageF( + "Not fusing %s because the preceding %s has a fused activation " + "function", + LogName(*binary_op), LogName(*preceding_op)); + return false; + } + + if (preceding_op->inputs.size() < 3) { + AddMessageF( + "Not fusing %s because the preceding %s does not have a bias vector", + LogName(*binary_op), LogName(*preceding_op)); + return false; + } + + const auto& weights = model->GetArray(preceding_op->inputs[1]); + const auto& bias = model->GetArray(preceding_op->inputs[2]); + if (binary_op->type == OperatorType::kAdd || + binary_op->type == OperatorType::kSub) { + if (!bias.buffer) { + AddMessageF( + "Not fusing %s because the preceding %s has a non-constant bias " + "array", + LogName(*binary_op), LogName(*preceding_op)); + return false; + } + } else { + if (!weights.buffer || !bias.buffer) { + AddMessageF( + "Not fusing %s because the preceding %s has non-constant weights or " + "bias arrays", + LogName(*binary_op), LogName(*preceding_op)); + return false; + } + } + + int count_ops_consuming_output = + CountOpsWithInput(*model, preceding_op->outputs[0]); + DCHECK_GE(count_ops_consuming_output, 1); + if (count_ops_consuming_output > 1) { + AddMessageF( + "Not fusing %s because the output of the preceding %s is consumed by " + "another op", + LogName(*binary_op), LogName(*preceding_op)); + return false; + } + + AddMessageF("Fusing %s into the preceding %s", LogName(*binary_op), + LogName(*preceding_op)); + + if (binary_op->type == OperatorType::kAdd || + binary_op->type == OperatorType::kSub) { + FuseAddOrSubParamsIntoPrecedingAffine(model, preceding_op, binary_op, + index_of_constant_input); + } else if (binary_op->type == OperatorType::kMul || + binary_op->type == OperatorType::kDiv) { + FuseMulOrDivParamsIntoPrecedingAffine(model, preceding_op, binary_op, + index_of_constant_input); + } else { + LOG(FATAL) << "should not get here"; + } + + model->arrays.erase(preceding_op->outputs[0]); + preceding_op->outputs[0] = binary_op->outputs[0]; + preceding_op->fused_activation_function = + binary_op->fused_activation_function; + const auto& old_constant_param_name = + 
binary_op->inputs[index_of_constant_input]; + CHECK(IsConstantParameterArray(*model, old_constant_param_name)); + if (CountOpsWithInput(*model, old_constant_param_name) == 1) { + model->arrays.erase(old_constant_param_name); + } + model->operators.erase(binary_it); + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.cc b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.cc new file mode 100644 index 0000000000..323fec6cf8 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.cc @@ -0,0 +1,108 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" + +#include +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/toco_port.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +namespace { + +void PrintModelStats(const string& label, const Model& model) { + int quantized_arrays = 0; + for (const auto& array : model.arrays) { + if (array.second->quantization_params) { + quantized_arrays++; + } + } + LOG(INFO) << label << ": " << model.operators.size() << " operators, " + << model.arrays.size() << " arrays (" << quantized_arrays + << " quantized)"; +} + +bool GraphTransformationsPass(int increment, Model* model, + const GraphTransformationsSet& transformations) { + CHECK(increment == 1 || increment == -1); + bool changed = false; + CHECK(!model->operators.empty()); + int op_index = increment == 1 ? 0 : model->operators.size() - 1; + while (true) { + bool changed_now = false; + // Loop over all transformations at the current position in the graph. + for (const auto& transformation : transformations) { + CHECK(!changed_now); + CHECK(transformation->Messages().empty()); + changed_now = transformation->Run(model, op_index); + if (changed_now) { + DumpGraphvizVideoFrame(*model); + CHECK(!model->operators.empty()); + op_index = std::min(op_index, model->operators.size() - 1); + // Uncomment for debugging + // CheckInvariants(*model); + } + const char* made_a_change_msg = + changed_now ? "made a change" : "did NOT make a change"; + const int log_level = + changed_now ? kLogLevelModelChanged : kLogLevelModelUnchanged; + for (const string& message : transformation->Messages()) { + VLOG(log_level) << transformation->Name() << " " << made_a_change_msg + << " at op_index=" << op_index << "/" + << model->operators.size() - 1 << ": " << message; + } + transformation->ClearMessages(); + if (changed_now) { + break; + } + } + if (changed_now) { + changed = true; + } else { + const int op_index_last = + increment == 1 ? 
model->operators.size() - 1 : 0; + if (op_index == op_index_last) { + break; + } + op_index += increment; + } + } + return changed; +} + +} // namespace + +void RunGraphTransformations(Model* model, const string& msg, + const GraphTransformationsSet& transformations) { + PrintModelStats(toco::port::StringF("Before %s", msg), *model); + int pass_index = 0; + while (GraphTransformationsPass((pass_index % 2) ? -1 : 1, model, + transformations)) { + pass_index++; + const auto& label = + toco::port::StringF("After %s pass %d", msg, pass_index); + PrintModelStats(label, *model); + CheckInvariants(*model); + } +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h new file mode 100644 index 0000000000..2cc24ff361 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h @@ -0,0 +1,186 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_GRAPH_TRANSFORMATIONS_GRAPH_TRANSFORMATIONS_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_GRAPH_TRANSFORMATIONS_GRAPH_TRANSFORMATIONS_H_ + +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/toco_port.h" + +namespace toco { + +class GraphTransformation { + public: + virtual bool Run(Model* model, std::size_t op_index) = 0; + virtual const char* Name() const = 0; + virtual ~GraphTransformation() {} + // Returns the list of messages that this graph transformation + // generated since ClearMessages() was called. + const std::vector& Messages() const { return messages_; } + // Clears the list of messages; should be called after every + // run of this graph transformation. + void ClearMessages() { return messages_.clear(); } + // Adds a message; normally only called by the graph transformation + // itself during its run (this function could be protected). + template + void AddMessageF(const char* format, const Args&... args) { + return messages_.push_back(toco::port::StringF(format, args...)); + } + + protected: + GraphTransformation() {} + + // List of messages generated by this graph transformation. + std::vector messages_; + + private: + GraphTransformation(const GraphTransformation& other) = delete; + GraphTransformation(const GraphTransformation&& other) = delete; +}; + +class GraphTransformationsSet { + public: + // The choice of a container with fully-specified iteration order + // ensures that graph transformations are always run in the same order, + // which avoids having toco randomly fail or produce different results + // depending on the toolchain. Ideally success/results should be independent + // of the order in which graph transformations are run, but that's + // unfortunately not currently guaranteed to be the case. 
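Stripped of the toco types, the pass driver defined earlier in this patch (GraphTransformationsPass and RunGraphTransformations) is a local-rewrite fixed point: at each operator position try every transformation, retry the same position after any change, and keep making whole passes, alternating direction, until a pass changes nothing. The sketch below mirrors that control flow over a toy "graph" of integers; the Rule type and the example rule are inventions for illustration only.

    // Sketch of a pass-until-fixed-point rewrite driver (illustrative only).
    #include <cstdio>
    #include <functional>
    #include <vector>

    using Rule = std::function<bool(std::vector<int>&, size_t)>;  // true if it changed the graph

    bool OnePass(int increment, std::vector<int>& ops, const std::vector<Rule>& rules) {
      bool changed = false;
      size_t i = increment == 1 ? 0 : ops.size() - 1;
      while (true) {
        bool changed_now = false;
        for (const Rule& rule : rules) {
          if (rule(ops, i)) {
            changed_now = changed = true;
            if (i >= ops.size()) i = ops.size() - 1;  // a rule may have erased ops
            break;  // re-run all rules at this position
          }
        }
        if (!changed_now) {
          const size_t last = increment == 1 ? ops.size() - 1 : 0;
          if (i == last) break;
          i += increment;
        }
      }
      return changed;
    }

    int main() {
      // One toy rule: erase any "op" equal to 0 (think: remove a trivial operator).
      // Assumes the graph never becomes empty, as the real driver also CHECKs.
      std::vector<Rule> rules = {[](std::vector<int>& ops, size_t i) {
        if (ops[i] != 0) return false;
        ops.erase(ops.begin() + i);
        return true;
      }};
      std::vector<int> ops = {1, 0, 2, 0, 3};
      int pass = 0;
      while (OnePass((pass++ % 2) ? -1 : 1, ops, rules)) {}  // alternate direction each pass
      std::printf("%zu ops remain\n", ops.size());  // 3
    }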
+ using TransformationsContainer = + std::vector>; + + GraphTransformationsSet() {} + GraphTransformationsSet( + const std::initializer_list transformations) { + for (GraphTransformation* t : transformations) { + Add(t); + } + } + void Add(GraphTransformation* transformation) { + const string& name = transformation->Name(); + CHECK(!names_.count(name)); + names_.insert(name); + transformations_.emplace_back(transformation); + } + TransformationsContainer::const_iterator begin() const { + return transformations_.begin(); + } + TransformationsContainer::const_iterator end() const { + return transformations_.end(); + } + bool empty() const { return transformations_.empty(); } + + private: + GraphTransformationsSet(const GraphTransformationsSet& other) = delete; + GraphTransformationsSet(const GraphTransformationsSet&& other) = delete; + std::vector> transformations_; + // Names of transformations in the set. Only used to guard against dupes. + std::unordered_set names_; +}; + +// Run the given list of graph transformations on the model. +// The message is only for logging purposes. +// The transformations is a rvalue reference, indicating that +// nothing else will use these pointers. The user is supposed to +// construct GraphTransformation objects by using 'new', pass us +// the resulting raw pointers, and this RunGraphTransformations +// takes care of delete'ing these pointers. +void RunGraphTransformations(Model* model, const string& message, + const GraphTransformationsSet& transformations); + +#define DECLARE_GRAPH_TRANSFORMATION(GTName) \ + class GTName : public GraphTransformation { \ + public: \ + bool Run(Model* model, std::size_t op_index) override; \ + const char* Name() const { return #GTName; } \ + }; + +// List of all graph transformations +DECLARE_GRAPH_TRANSFORMATION(ConvertPureConvToDepthwise) +DECLARE_GRAPH_TRANSFORMATION(EnsureBiasVectors) +DECLARE_GRAPH_TRANSFORMATION(FuseActivationFunctions) +DECLARE_GRAPH_TRANSFORMATION(FuseBinaryIntoFollowingAffine) +DECLARE_GRAPH_TRANSFORMATION(FuseBinaryIntoPrecedingAffine) +DECLARE_GRAPH_TRANSFORMATION(IdentifyL2Normalization) +DECLARE_GRAPH_TRANSFORMATION(IdentifyL2Pool) +DECLARE_GRAPH_TRANSFORMATION(IdentifyLstmCell) +DECLARE_GRAPH_TRANSFORMATION(IdentifyRelu1) +DECLARE_GRAPH_TRANSFORMATION(MakeInitialDequantizeOperator) +DECLARE_GRAPH_TRANSFORMATION(PropagateArrayDataTypes) +DECLARE_GRAPH_TRANSFORMATION(PropagateFixedSizes) +DECLARE_GRAPH_TRANSFORMATION(HardcodeMinMax) +DECLARE_GRAPH_TRANSFORMATION(Quantize) +DECLARE_GRAPH_TRANSFORMATION(RemoveFinalDequantizeOp) +DECLARE_GRAPH_TRANSFORMATION(RemoveTensorFlowAssert) +DECLARE_GRAPH_TRANSFORMATION(RemoveTensorFlowIdentity) +DECLARE_GRAPH_TRANSFORMATION(RemoveTrivialBinaryOperator) +DECLARE_GRAPH_TRANSFORMATION(RemoveTrivialConcatenation) +DECLARE_GRAPH_TRANSFORMATION(RemoveTrivialConcatenationInput) +DECLARE_GRAPH_TRANSFORMATION(RemoveTrivialQuantizedActivationFunc) +DECLARE_GRAPH_TRANSFORMATION(RemoveUnusedOp) +DECLARE_GRAPH_TRANSFORMATION(ResolveBatchNormalization) +DECLARE_GRAPH_TRANSFORMATION(ResolveConstantBinaryOperator) +DECLARE_GRAPH_TRANSFORMATION(ResolveConstantUnaryOperator) +DECLARE_GRAPH_TRANSFORMATION(CreateIm2colArrays) +DECLARE_GRAPH_TRANSFORMATION(DropIm2colArrays) +DECLARE_GRAPH_TRANSFORMATION(ReadFakeQuantMinMax) +DECLARE_GRAPH_TRANSFORMATION(ResolveReorderAxes) +DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowConcat) +DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowMatMul) +DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowMerge) 
+DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowSqueeze) +DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowSwitch) +DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowTile) +DECLARE_GRAPH_TRANSFORMATION(ResolveConstantFakeQuant) +DECLARE_GRAPH_TRANSFORMATION(ResolveConstantConcatenation) +DECLARE_GRAPH_TRANSFORMATION(DropFakeQuant) +DECLARE_GRAPH_TRANSFORMATION(UnfuseActivationFunctions) +DECLARE_GRAPH_TRANSFORMATION(ResolvePadAttributes) +DECLARE_GRAPH_TRANSFORMATION(ResolveStridedSliceAttributes) +DECLARE_GRAPH_TRANSFORMATION(ResolveSliceAttributes) +DECLARE_GRAPH_TRANSFORMATION(ResolveMeanAttributes) +DECLARE_GRAPH_TRANSFORMATION(ResolveConstantTensorFlowShape) +DECLARE_GRAPH_TRANSFORMATION(Dequantize) + +class ResolveReshapeAttributes : public GraphTransformation { + public: + bool Run(Model* model, std::size_t op_index) override; + const char* Name() const override { return "ResolveReshapeAttributes"; } +}; + +class RemoveTrivialReshape : public GraphTransformation { + public: + bool Run(Model* model, std::size_t op_index) override; + const char* Name() const override { return "RemoveTrivialReshape"; } + bool treat_expand_dims_as_trivial() const { + return treat_expand_dims_as_trivial_; + } + void set_treat_expand_dims_as_trivial(bool val) { + treat_expand_dims_as_trivial_ = val; + } + + private: + bool treat_expand_dims_as_trivial_ = false; +}; + +#undef DECLARE_GRAPH_TRANSFORMATION + +} // end namespace toco + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_GRAPH_TRANSFORMATIONS_GRAPH_TRANSFORMATIONS_H_ diff --git a/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc b/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc new file mode 100644 index 0000000000..d44b5dc7b0 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc @@ -0,0 +1,229 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +namespace { + +bool HardcodeMinMaxForIm2colArray(Model* model, Operator* op) { + if (op->outputs.size() != 2) { + return false; + } + auto& im2col_array = model->GetArray(op->outputs[1]); + if (im2col_array.minmax) { + return false; + } + const auto& input_array = model->GetArray(op->inputs[0]); + if (!input_array.minmax) { + return false; + } + const auto& input_minmax = input_array.GetMinMax(); + CHECK(!im2col_array.minmax); + auto& im2col_minmax = im2col_array.GetOrCreateMinMax(); + im2col_minmax.min = input_minmax.min; + im2col_minmax.max = input_minmax.max; + return true; +} + +bool HardcodeMinMaxForL2Normalization(Model* model, Operator* op) { + auto& output_array = model->GetArray(op->outputs[0]); + if (output_array.minmax) { + return false; + } + const auto& input_array = model->GetArray(op->inputs[0]); + if (!input_array.minmax) { + return false; + } + const auto& input_minmax = input_array.GetMinMax(); + CHECK(!output_array.minmax); + auto& output_minmax = output_array.GetOrCreateMinMax(); + output_minmax.min = input_minmax.min >= 0. ? 0. : -1.; + output_minmax.max = input_minmax.max <= 0. ? 0. : 1.; + return true; +} + +bool HardcodeMinMaxForConcatenation(Model* model, Operator* op) { + // Do not early return if the output already has min/max: + // we may still need to adjust the inputs min/max. + bool has_minmax = false; + double overall_min = std::numeric_limits::infinity(); + double overall_max = -std::numeric_limits::infinity(); + for (const auto& input : op->inputs) { + if (model->GetArray(input).minmax) { + has_minmax = true; + const auto* minmax = model->GetArray(input).minmax.get(); + if (minmax) { + overall_min = std::min(overall_min, minmax->min); + overall_max = std::max(overall_max, minmax->max); + } + } + } + auto& output = model->GetArray(op->outputs[0]); + if (output.minmax) { + has_minmax = true; + const auto* minmax = model->GetArray(op->outputs[0]).minmax.get(); + if (minmax) { + overall_min = std::min(overall_min, minmax->min); + overall_max = std::max(overall_max, minmax->max); + } + } + if (!has_minmax) { + return false; + } + MinMax overall_minmax; + overall_minmax.min = overall_min; + overall_minmax.max = overall_max; + bool changed = false; + for (const auto& input : op->inputs) { + auto& array = model->GetArray(input); + if (!array.minmax) { + changed = true; + } else if (!(overall_minmax == array.GetMinMax())) { + changed = true; + LOG(WARNING) + << "Tweaking the MinMax of array " << input << ", which is " + << "an input to " << LogName(*op) << ", because we want all inputs " + << "and outputs of a Concatenation operator to have the same MinMax " + << "so that it can be implemented as a pure byte-copy, no " + "arithmetic."; + } + array.GetOrCreateMinMax() = overall_minmax; + } + if (!output.minmax) { + changed = true; + } else if (!(overall_minmax == output.GetMinMax())) { + changed = true; + LOG(WARNING) + << "Tweaking the MinMax of the output array of " << LogName(*op) + << ", because we want all inputs " + << "and outputs of a Concatenation operator to have the same MinMax " + << "so that it can be implemented as a pure byte-copy, no arithmetic."; + } + 
output.GetOrCreateMinMax() = overall_minmax; + + return changed; +} + +// The output of average or max pooling is within the same range as its input. +bool HardcodeMinMaxForAverageOrMaxPool(Model* model, Operator* op) { + auto& output_array = model->GetArray(op->outputs[0]); + if (output_array.minmax) { + return false; + } + const auto& input_array = model->GetArray(op->inputs[0]); + if (!input_array.minmax) { + return false; + } + const auto& input_minmax = input_array.GetMinMax(); + CHECK(!output_array.minmax); + auto& output_minmax = output_array.GetOrCreateMinMax(); + output_minmax.min = std::min(input_minmax.min, 0.); + output_minmax.max = std::max(input_minmax.max, 0.); + return true; +} + +bool HardcodeMinMaxForReshape(Model* model, Operator* op) { + auto& output_array = model->GetArray(op->outputs[0]); + if (output_array.minmax) { + return false; + } + const auto& input_array = model->GetArray(op->inputs[0]); + if (!input_array.minmax) { + return false; + } + const auto& input_minmax = input_array.GetMinMax(); + CHECK(!output_array.minmax); + auto& output_minmax = output_array.GetOrCreateMinMax(); + output_minmax.min = input_minmax.min; + output_minmax.max = input_minmax.max; + return true; +} + +bool HardcodeMinMaxForOutput(Model* model, Operator* op, double min, + double max) { + CHECK_EQ(op->outputs.size(), 1); + auto& output_array = model->GetArray(op->outputs[0]); + if (output_array.minmax) { + return false; + } + const auto& input_array = model->GetArray(op->inputs[0]); + if (!input_array.minmax) { + return false; + } + CHECK(!output_array.minmax); + auto& output_minmax = output_array.GetOrCreateMinMax(); + output_minmax.min = min; + output_minmax.max = max; + return true; +} +} // namespace + +bool HardcodeMinMax::Run(Model* model, std::size_t op_index) { + auto it = model->operators.begin() + op_index; + auto* op = it->get(); + bool changed = false; + switch (op->type) { + case OperatorType::kConv: + changed = HardcodeMinMaxForIm2colArray(model, op); + break; + + case OperatorType::kL2Normalization: + changed = HardcodeMinMaxForL2Normalization(model, op); + break; + + case OperatorType::kConcatenation: + changed = HardcodeMinMaxForConcatenation(model, op); + break; + + case OperatorType::kAveragePool: + case OperatorType::kMaxPool: + changed = HardcodeMinMaxForAverageOrMaxPool(model, op); + break; + + case OperatorType::kTensorFlowReshape: + changed = HardcodeMinMaxForReshape(model, op); + break; + + case OperatorType::kLogistic: + // We hardcode quantization_params to: zero_point=0, scale=1/256. + // This choice of minmax is the one that is equivalent to that. + changed = HardcodeMinMaxForOutput(model, op, 0, 255. / 256.); + break; + + case OperatorType::kSoftmax: + // We hardcode quantization_params to: zero_point=0, scale=1/256. + // This choice of minmax is the one that is equivalent to that. + changed = HardcodeMinMaxForOutput(model, op, 0, 255. / 256.); + break; + + default: + break; + } + if (changed) { + AddMessageF("Hardcoded min-max through %s", LogName(*op)); + } + return changed; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_normalization.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_normalization.cc new file mode 100644 index 0000000000..01b75e37c6 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_normalization.cc @@ -0,0 +1,170 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +namespace { + +std::vector>::iterator FindOperator( + Model* model, const Operator* op) { + auto it = model->operators.begin(); + for (; it != model->operators.end(); ++it) { + if (it->get() == op) { + break; + } + } + return it; +} +} // namespace + +bool IdentifyL2Normalization::Run(Model* model, std::size_t op_index) { + const auto div_it = model->operators.begin() + op_index; + const auto* div_or_mul_op = div_it->get(); + OperatorType expected_op_type_producing_div_or_mul_input; + if (div_or_mul_op->type == OperatorType::kDiv) { + expected_op_type_producing_div_or_mul_input = OperatorType::kTensorFlowSqrt; + } else if (div_or_mul_op->type == OperatorType::kMul) { + expected_op_type_producing_div_or_mul_input = + OperatorType::kTensorFlowRsqrt; + } else { + return false; + } + CHECK_EQ(div_or_mul_op->inputs.size(), 2); + Operator* op_producing_div_or_mul_input[2] = { + GetOpWithOutput(*model, div_or_mul_op->inputs[0]), + GetOpWithOutput(*model, div_or_mul_op->inputs[1]), + }; + if (!op_producing_div_or_mul_input[1] || + op_producing_div_or_mul_input[1]->type != + expected_op_type_producing_div_or_mul_input) { + return false; + } + Operator* sqrt_or_rsqrt_op = op_producing_div_or_mul_input[1]; + CHECK_EQ(sqrt_or_rsqrt_op->inputs.size(), 1); + Operator* op_producing_sqrt_or_rsqrt_input = + GetOpWithOutput(*model, sqrt_or_rsqrt_op->inputs[0]); + if (!op_producing_sqrt_or_rsqrt_input) { + return false; + } + + // There may be an Add or a Maximum here, adding or clamping to a "small" + // constant scalar. 
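For reference, the expression this pass is trying to recognize is y = x / sqrt(sum(x^2) + eps), or equivalently x * rsqrt(...), where eps is the small constant that motivates tolerating an Add or Maximum at this point in the subgraph. A plain C++ reference computation (not toco code, values invented):

    // Reference computation of the matched pattern:
    // y[i] = x[i] / sqrt(sum_j x[j]^2 + eps), eps being the small Add/Maximum operand.
    #include <cmath>
    #include <cstdio>
    #include <vector>

    int main() {
      const std::vector<float> x = {3.f, 4.f};
      const float eps = 1e-6f;  // the "small" constant scalar mentioned above
      float sum_sq = 0.f;
      for (float v : x) sum_sq += v * v;
      const float inv_norm = 1.f / std::sqrt(sum_sq + eps);  // Rsqrt path; the Div path divides instead
      for (float v : x) std::printf("%f ", v * inv_norm);    // ~0.6 0.8
      std::printf("\n");
    }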
+ // Reported bug: b/29395854 + Operator* add_op = nullptr; + Operator* op_producing_add_input = nullptr; + if (op_producing_sqrt_or_rsqrt_input->type == OperatorType::kAdd || + op_producing_sqrt_or_rsqrt_input->type == + OperatorType::kTensorFlowMaximum) { + add_op = op_producing_sqrt_or_rsqrt_input; + bool add_can_be_removed = false; + CHECK_EQ(op_producing_sqrt_or_rsqrt_input->inputs.size(), 2); + for (int i = 0; i < 2; i++) { + const auto& input_array = + model->GetArray(op_producing_sqrt_or_rsqrt_input->inputs[i]); + if (!input_array.buffer) { + continue; + } + if (input_array.buffer->type != ArrayDataType::kFloat) { + continue; + } + if (RequiredBufferSizeForShape(input_array.shape()) != 1) { + continue; + } + const auto& input_float_data = + input_array.GetBuffer().data; + if (std::abs(input_float_data[0]) > 1e-3f) { + continue; + } + add_can_be_removed = true; + op_producing_add_input = GetOpWithOutput(*model, add_op->inputs[1 - i]); + break; + } + if (!add_can_be_removed) { + AddMessageF( + "Giving up trying to identify L2Normalization subgraph " + " because the operator producing the input to the square root, %s," + ", does not match the expected pattern", + LogName(*op_producing_sqrt_or_rsqrt_input)); + return false; + } + } + + Operator* sum_op = + add_op ? op_producing_add_input : op_producing_sqrt_or_rsqrt_input; + if (sum_op->type != OperatorType::kTensorFlowSum) { + AddMessageF( + "Giving up trying to identify L2Normalization subgraph: " + "expected Sum op, got %s", + LogName(*sum_op)); + return false; + } + + Operator* square_op = GetOpWithOutput(*model, sum_op->inputs[0]); + if (square_op->type != OperatorType::kTensorFlowSquare) { + AddMessageF( + "Giving up trying to identify L2Normalization subgraph: " + "expected Square op, got %s", + LogName(*square_op)); + return false; + } + + CHECK_EQ(square_op->inputs.size(), 1); + + if (square_op->inputs[0] != div_or_mul_op->inputs[0]) { + AddMessageF( + "Giving up trying to identify L2Normalization subgraph: %s does not " + "take the same input as the Mul/Div node", + LogName(*square_op)); + return false; + } + + // Create and emplace the new L2Normalization + auto* l2norm_op = new L2NormalizationOperator; + l2norm_op->inputs = {div_or_mul_op->inputs[0]}; + l2norm_op->outputs = div_or_mul_op->outputs; + model->operators.emplace(div_it, l2norm_op); + + AddMessageF("Creating %s replacing equivalent subgraph", LogName(*l2norm_op)); + + // Erase the subgraph that is now replaced by L2Normalization + model->operators.erase(FindOperator(model, square_op)); + model->arrays.erase(sum_op->inputs[0]); + if (sum_op->inputs.size() > 1) { + model->arrays.erase(sum_op->inputs[1]); + } + model->operators.erase(FindOperator(model, sum_op)); + if (add_op) { + model->arrays.erase(add_op->inputs[0]); + model->arrays.erase(add_op->inputs[1]); + model->operators.erase(FindOperator(model, add_op)); + } + model->arrays.erase(sqrt_or_rsqrt_op->inputs[0]); + model->operators.erase(FindOperator(model, sqrt_or_rsqrt_op)); + model->arrays.erase(div_or_mul_op->inputs[1]); + model->operators.erase(FindOperator(model, div_or_mul_op)); + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_pool.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_pool.cc new file mode 100644 index 0000000000..1865416fc2 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_pool.cc @@ -0,0 +1,106 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +namespace { + +std::vector>::iterator FindOperator( + Model* model, const Operator* op) { + auto it = model->operators.begin(); + for (; it != model->operators.end(); ++it) { + if (it->get() == op) { + break; + } + } + return it; +} +} // namespace + +bool IdentifyL2Pool::Run(Model* model, std::size_t op_index) { + const auto sqrt_it = model->operators.begin() + op_index; + const auto* sqrt_op = sqrt_it->get(); + if (sqrt_op->type != OperatorType::kTensorFlowSqrt) { + return false; + } + + CHECK_EQ(sqrt_op->inputs.size(), 1); + CHECK_EQ(sqrt_op->outputs.size(), 1); + + const AveragePoolOperator* avpool_op; + const Operator* square_op; + + Operator* prev_to_sqrt_op = GetOpWithOutput(*model, sqrt_op->inputs[0]); + if (prev_to_sqrt_op->type != OperatorType::kAveragePool) { + AddMessageF( + "Giving up trying to identify L2Pool subgraph: " + "expected AveragePool op, got %s", + LogName(*prev_to_sqrt_op)); + return false; + } + + avpool_op = static_cast(prev_to_sqrt_op); + CHECK_EQ(avpool_op->inputs.size(), 1); + + square_op = GetOpWithOutput(*model, avpool_op->inputs[0]); + CHECK_EQ(square_op->inputs.size(), 1); + if (square_op->type != OperatorType::kTensorFlowSquare) { + AddMessageF( + "Giving up trying to identify L2Pool subgraph: " + "expected Square op, got %s", + LogName(*square_op)); + return false; + } + + // Create and emplace L2Pool node. + auto* l2pool_op = new L2PoolOperator; + + l2pool_op->inputs = {square_op->inputs[0]}; + l2pool_op->outputs = sqrt_op->outputs; + + l2pool_op->padding.type = avpool_op->padding.type; + // Note that we do not setup avpool_op->padding.fixed here. This is done by + // the PropagateFixedSizes graph transformation. + + l2pool_op->stride_height = avpool_op->stride_height; + l2pool_op->stride_width = avpool_op->stride_width; + l2pool_op->kheight = avpool_op->kheight; + l2pool_op->kwidth = avpool_op->kwidth; + model->operators.emplace(sqrt_it, l2pool_op); + + AddMessageF("Creating %s replacing equivalent subgraph", LogName(*l2pool_op)); + + // Erase intermediate arrays, keeping input to square op. + model->arrays.erase(avpool_op->inputs[0]); + model->arrays.erase(sqrt_op->inputs[0]); + + // Erase three operators being replaced. 
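The collapsed subgraph computes, per pooling window, sqrt(mean(x^2)), i.e. the root-mean-square of the window, which is what a dedicated L2Pool kernel evaluates directly. A one-dimensional reference sketch with invented sizes (not toco code):

    // 1-D reference for L2 pooling: out[w] = sqrt(average of x^2 over the window).
    #include <cmath>
    #include <cstdio>
    #include <vector>

    int main() {
      const std::vector<float> x = {1.f, -2.f, 3.f, -4.f, 5.f, -6.f};
      const int kwidth = 2, stride = 2;
      for (int start = 0; start + kwidth <= static_cast<int>(x.size()); start += stride) {
        float sum_sq = 0.f;
        for (int k = 0; k < kwidth; ++k) sum_sq += x[start + k] * x[start + k];
        std::printf("%f ", std::sqrt(sum_sq / kwidth));  // sqrt(AveragePool(Square(x)))
      }
      std::printf("\n");
    }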
+ model->operators.erase(FindOperator(model, square_op)); + model->operators.erase(FindOperator(model, avpool_op)); + model->operators.erase(FindOperator(model, sqrt_op)); + + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm.cc new file mode 100644 index 0000000000..082820fddc --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm.cc @@ -0,0 +1,396 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" + +namespace toco { + +namespace { + +std::vector>::iterator FindOperator( + Model* model, const Operator& op) { + auto it = model->operators.begin(); + for (; it != model->operators.end(); ++it) { + if (it->get() == &op) { + break; + } + } + return it; +} + +bool GetStateArrayForBackEdge(const Model& model, + const string& back_edge_source_array, + string* state_array = nullptr) { + for (const auto& rnn_state : model.flags.rnn_states()) { + if (back_edge_source_array == rnn_state.back_edge_source_array()) { + // Found LSTM cell output + if (state_array) { + *state_array = rnn_state.state_array(); + } + return true; + } + } + return false; +} + +// Returns true if the given operator has exactly 1 input, and is connected to +// the given op_type. +// We use kNone to indicate an input unattached to an operator output. Usually +// these are the static input arrays. +bool MatchOperatorInputs(const Operator& op, const Model& model, + OperatorType op_type, Operator** connected_op) { + // Check for required number of inputs + if (op.inputs.size() != 1) { + return false; + } + + // Check if first input is disconnected/connected to an operator + Operator* x = GetOpWithOutput(model, op.inputs[0]); + if ((op_type == OperatorType::kNone) && (x != nullptr)) { + return false; + } + if ((op_type != OperatorType::kNone) && (x == nullptr)) { + return false; + } + + // Check that first operator, if connected, is of correct type + if ((x != nullptr) && (x->type != op_type)) { + return false; + } + + // Successfully matched. Optionally return matching input operators. + if (connected_op) { + *connected_op = x; + } + + return true; +} + +// Returns true if the given operator has exactly 2 inputs, which are connected +// to the given op_types. +// We use kNone to indicate an input unattached to an operator output. Usually +// these are the static input arrays. 
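All of these overloads apply the same idea: find the operator, if any, that produces each input array, treat kNone as "this input must have no producer" (typically a static weights array), and require the producer's type to match. A stripped-down version of that check over a toy producer lookup, with invented types, might look like this:

    // Toy version of "does input i of this op come from a producer of type T?".
    #include <cassert>
    #include <string>
    #include <vector>

    enum class OpType { kNone, kMul, kTanh };

    struct Op {
      OpType type;
      std::vector<std::string> inputs;
      std::vector<std::string> outputs;
    };

    // Returns the op producing `array`, or nullptr if no op does (a static input).
    const Op* GetProducer(const std::vector<Op>& graph, const std::string& array) {
      for (const Op& op : graph)
        for (const std::string& out : op.outputs)
          if (out == array) return &op;
      return nullptr;
    }

    bool InputMatches(const std::vector<Op>& graph, const Op& op, int i, OpType expected) {
      const Op* producer = GetProducer(graph, op.inputs[i]);
      if (expected == OpType::kNone) return producer == nullptr;  // must be unattached
      return producer != nullptr && producer->type == expected;
    }

    int main() {
      std::vector<Op> graph = {{OpType::kTanh, {"x"}, {"t"}},
                               {OpType::kMul, {"t", "weights"}, {"y"}}};
      assert(InputMatches(graph, graph[1], 0, OpType::kTanh));  // produced by the Tanh
      assert(InputMatches(graph, graph[1], 1, OpType::kNone));  // static array, no producer
      return 0;
    }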
+bool MatchOperatorInputs(const Operator& op, const Model& model, + OperatorType a_op_type, Operator** a_op, + OperatorType b_op_type, Operator** b_op) { + // Check for required number of inputs + if (op.inputs.size() != 2) { + return false; + } + + // Check if first input is disconnected/connected to an operator + Operator* x = GetOpWithOutput(model, op.inputs[0]); + if ((a_op_type == OperatorType::kNone) && (x != nullptr)) { + return false; + } + if ((a_op_type != OperatorType::kNone) && (x == nullptr)) { + return false; + } + + // Check that first operator, if connected, is of correct type + if ((x != nullptr) && (x->type != a_op_type)) { + return false; + } + + // Check if second input is disconnected/connected to an operator + Operator* y = GetOpWithOutput(model, op.inputs[1]); + if ((b_op_type == OperatorType::kNone) && (y != nullptr)) { + return false; + } + if ((b_op_type != OperatorType::kNone) && (y == nullptr)) { + return false; + } + + // Check that second operator, if connected, is of correct type + if ((y != nullptr) && (y->type != b_op_type)) { + return false; + } + + // Successfully matched. Optionally return matching input operators. + if (a_op != nullptr) { + *a_op = x; + } + if (b_op != nullptr) { + *b_op = y; + } + return true; +} + +// Returns true if the given operator has exactly 3 inputs, which are connected +// to the given op_types. +// We use kNone to indicate an input unattached to an operator output. Usually +// these are the static input arrays. +bool MatchOperatorInputs(const Operator& op, const Model& model, + OperatorType a_op_type, Operator** a_op, + OperatorType b_op_type, Operator** b_op, + OperatorType c_op_type, Operator** c_op) { + // Check for required number of inputs + if (op.inputs.size() != 3) { + return false; + } + + // Check if first input is disconnected/connected to an operator + Operator* x = GetOpWithOutput(model, op.inputs[0]); + if ((a_op_type == OperatorType::kNone) && (x != nullptr)) { + return false; + } + if ((a_op_type != OperatorType::kNone) && (x == nullptr)) { + return false; + } + + // Check that first operator, if connected, is of correct type + if ((x != nullptr) && (x->type != a_op_type)) { + return false; + } + + // Check if second input is disconnected/connected to an operator + Operator* y = GetOpWithOutput(model, op.inputs[1]); + if ((b_op_type == OperatorType::kNone) && (y != nullptr)) { + return false; + } + if ((b_op_type != OperatorType::kNone) && (y == nullptr)) { + return false; + } + + // Check that second operator, if connected, is of correct type + if ((y != nullptr) && (y->type != b_op_type)) { + return false; + } + + // Check if third input is disconnected/connected to an operator + Operator* z = GetOpWithOutput(model, op.inputs[2]); + if ((c_op_type == OperatorType::kNone) && (z != nullptr)) { + return false; + } + if ((c_op_type != OperatorType::kNone) && (z == nullptr)) { + return false; + } + + // Check that third operator, if connected, is of correct type + if ((z != nullptr) && (z->type != c_op_type)) { + return false; + } + + // Successfully matched. Optionally return matching input operators. 
+ if (a_op != nullptr) { + *a_op = x; + } + if (b_op != nullptr) { + *b_op = y; + } + if (c_op != nullptr) { + *c_op = z; + } + return true; +} + +absl::string_view FindLongestCommonPrefix(absl::string_view a, + absl::string_view b) { + if (a.empty() || b.empty()) return absl::string_view(); + + const char* pa = a.data(); + const char* pb = b.data(); + size_t count = 0; + const ssize_t limit = std::min(a.size(), b.size()); + while (count < limit && *pa == *pb) { + ++pa; + ++pb; + ++count; + } + + return absl::string_view(a.data(), count); +} + +} // namespace + +bool IdentifyLstmCell::Run(Model* model, std::size_t op_index) { + // This LSTM cell identification method is not invariant to commutation of + // commutative operator inputs. For example, if input[0] and input[1] of the + // final output multiplication were swapped, this method would not identify it + // as an LSTM cell. This is OK in most cases, because + // tf.rnn.contrib.BasicLSTMCell always generates LSTM cells the same way. + + // Final output multiply + auto op_it = model->operators.begin() + op_index; + Operator* final_output_mul = op_it->get(); + if (final_output_mul->type != OperatorType::kMul) { + return false; + } + Operator *state_output_tanh, *fc_output_sig; + if (!MatchOperatorInputs(*final_output_mul, *model, OperatorType::kTanh, + &state_output_tanh, OperatorType::kLogistic, + &fc_output_sig)) { + return false; + } + + // State output TanH + // (We don't count an operator as ID'd until we verify it has the correct + // operator types feeding into it.) + Operator* state_combine_add; + if (!MatchOperatorInputs(*state_output_tanh, *model, OperatorType::kAdd, + &state_combine_add)) { + return false; + } + string prev_state; + if (!GetStateArrayForBackEdge(*model, state_output_tanh->inputs[0], + &prev_state)) { + return false; + } + + // State forget & remember addition + Operator *state_forget_mul, *state_remember_mul; + if (!MatchOperatorInputs(*state_combine_add, *model, OperatorType::kMul, + &state_forget_mul, OperatorType::kMul, + &state_remember_mul)) { + return false; + } + if (state_forget_mul->inputs[0] != prev_state) { + return false; + } + + // State forget gate + Operator* state_forget_sig; + if (!MatchOperatorInputs(*state_forget_mul, *model, OperatorType::kNone, + nullptr, OperatorType::kLogistic, + &state_forget_sig)) { + return false; + } + + // State remember gate + Operator *state_remember_sig, *state_info_tanh; + if (!MatchOperatorInputs(*state_remember_mul, *model, OperatorType::kLogistic, + &state_remember_sig, OperatorType::kTanh, + &state_info_tanh)) { + return false; + } + + // State remember "information" activation function + Operator* fc_output_split; + if (!MatchOperatorInputs(*state_info_tanh, *model, + OperatorType::kTensorFlowSplit, &fc_output_split)) { + return false; + } + // State remember gate activation function + Operator* tmp; + if (!MatchOperatorInputs(*state_remember_sig, *model, + OperatorType::kTensorFlowSplit, &tmp) || + (tmp != fc_output_split)) { + return false; + } + // State forget gate activation function + if (!MatchOperatorInputs(*state_forget_sig, *model, + OperatorType::kTensorFlowSplit, &tmp) || + (tmp != fc_output_split)) { + return false; + } + // Fully connected output activation function + if (!MatchOperatorInputs(*fc_output_sig, *model, + OperatorType::kTensorFlowSplit, &tmp) || + (tmp != fc_output_split)) { + return false; + } + // Fully connected output split + Operator* fully_connected; + if (!MatchOperatorInputs(*fc_output_split, *model, OperatorType::kNone, 
+ nullptr, OperatorType::kFullyConnected, + &fully_connected)) { + return false; + } + + // Fully connected op + Operator* concat_inputs; + if (!MatchOperatorInputs(*fully_connected, *model, + OperatorType::kConcatenation, &concat_inputs, + OperatorType::kNone, nullptr, OperatorType::kNone, + nullptr)) { + return false; + } + + // Emplace a new LSTM cell operator + auto* lstm_cell_op = new LstmCellOperator; + lstm_cell_op->inputs.resize(LstmCellOperator::NUM_INPUTS); + lstm_cell_op->inputs[LstmCellOperator::DATA_INPUT] = concat_inputs->inputs[0]; + lstm_cell_op->inputs[LstmCellOperator::PREV_ACTIV_INPUT] = + concat_inputs->inputs[1]; + lstm_cell_op->inputs[LstmCellOperator::WEIGHTS_INPUT] = + fully_connected->inputs[1]; + lstm_cell_op->inputs[LstmCellOperator::BIASES_INPUT] = + fully_connected->inputs[2]; + lstm_cell_op->inputs[LstmCellOperator::PREV_STATE_INPUT] = prev_state; + lstm_cell_op->outputs.resize(LstmCellOperator::NUM_OUTPUTS); + lstm_cell_op->outputs[LstmCellOperator::STATE_OUTPUT] = + state_output_tanh->inputs[0]; + lstm_cell_op->outputs[LstmCellOperator::ACTIV_OUTPUT] = + final_output_mul->outputs[0]; + model->operators.emplace(op_it, lstm_cell_op); + AddMessageF("Creating %s replacing equivalent subgraph", + LogName(*lstm_cell_op)); + + // Create temp arrays used internally during runtime. + const string base_name(FindLongestCommonPrefix( + lstm_cell_op->outputs[LstmCellOperator::STATE_OUTPUT], + lstm_cell_op->outputs[LstmCellOperator::ACTIV_OUTPUT])); + const string& concat_temp_array_name = + AvailableArrayName(*model, base_name + "concat_temp"); + model->GetOrCreateArray(concat_temp_array_name); + lstm_cell_op->outputs[LstmCellOperator::CONCAT_TEMP] = concat_temp_array_name; + const string& activ_temp_array_name = + AvailableArrayName(*model, base_name + "activ_temp"); + model->GetOrCreateArray(activ_temp_array_name); + lstm_cell_op->outputs[LstmCellOperator::ACTIV_TEMP] = activ_temp_array_name; + AddMessageF("Created temp outputs %s and %s on operator %s", + concat_temp_array_name, activ_temp_array_name, + LogName(*lstm_cell_op)); + + // Delete arrays and operators replaced by the LSTM cell operator. Order is + // important - DeleteArrayIfUnused() only succeeds if dependent operators + // have been removed first. Start at the output and work towards the input. 
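+ // Note: DeleteArrayIfUnused(name, model) leaves an array in place while any
+ // remaining operator still lists it as an input or output, which is why each
+ // consumer operator below is erased before its arrays are considered.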
+ model->operators.erase(FindOperator(model, *final_output_mul)); + DeleteArrayIfUnused(state_output_tanh->outputs[0], model); + DeleteArrayIfUnused(fc_output_sig->outputs[0], model); + model->operators.erase(FindOperator(model, *state_output_tanh)); + model->operators.erase(FindOperator(model, *fc_output_sig)); + model->operators.erase(FindOperator(model, *state_combine_add)); + DeleteArrayIfUnused(state_forget_mul->outputs[0], model); + DeleteArrayIfUnused(state_remember_mul->outputs[0], model); + model->operators.erase(FindOperator(model, *state_forget_mul)); + model->operators.erase(FindOperator(model, *state_remember_mul)); + DeleteArrayIfUnused(state_forget_sig->outputs[0], model); + DeleteArrayIfUnused(state_info_tanh->outputs[0], model); + DeleteArrayIfUnused(state_remember_sig->outputs[0], model); + model->operators.erase(FindOperator(model, *state_forget_sig)); + model->operators.erase(FindOperator(model, *state_info_tanh)); + model->operators.erase(FindOperator(model, *state_remember_sig)); + DeleteArrayIfUnused(fc_output_split->outputs[0], model); + DeleteArrayIfUnused(fc_output_split->outputs[1], model); + DeleteArrayIfUnused(fc_output_split->outputs[2], model); + DeleteArrayIfUnused(fc_output_split->outputs[3], model); + string dims_array = fc_output_split->inputs[0]; + model->operators.erase(FindOperator(model, *fc_output_split)); + DeleteArrayIfUnused(dims_array, model); + DeleteArrayIfUnused(fully_connected->outputs[0], model); + model->operators.erase(FindOperator(model, *fully_connected)); + DeleteArrayIfUnused(concat_inputs->outputs[0], model); + model->operators.erase(FindOperator(model, *concat_inputs)); + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_relu1.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_relu1.cc new file mode 100644 index 0000000000..cfc77024e7 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_relu1.cc @@ -0,0 +1,103 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+namespace {
+
+std::vector<std::unique_ptr<Operator>>::iterator FindOperator(
+    Model* model, const Operator* op) {
+  auto it = model->operators.begin();
+  for (; it != model->operators.end(); ++it) {
+    if (it->get() == op) {
+      break;
+    }
+  }
+  return it;
+}
+
+bool CheckArrayIsScalarFloat(Model* model, const std::string& name, float val) {
+  const auto& op_array = model->GetArray(name);
+  if (!op_array.buffer || op_array.buffer->type != ArrayDataType::kFloat ||
+      RequiredBufferSizeForShape(op_array.shape()) != 1) {
+    return false;
+  }
+  const auto& op_data = op_array.GetBuffer<ArrayDataType::kFloat>().data;
+  return op_data[0] == val;
+}
+
+// Returns index of scalar input when there is exactly one scalar, -1 otherwise
+int GetSingleScalarInputIndexOfBinaryOp(Model* model, const Operator* op,
+                                        float val) {
+  bool input0_is_scalar = CheckArrayIsScalarFloat(model, op->inputs[0], val);
+  bool input1_is_scalar = CheckArrayIsScalarFloat(model, op->inputs[1], val);
+  return input0_is_scalar == input1_is_scalar ? -1 : input0_is_scalar ? 0 : 1;
+}
+}  // namespace
+
+bool IdentifyRelu1::Run(Model* model, std::size_t op_index) {
+  const auto maximum_it = model->operators.begin() + op_index;
+  const auto* maximum_op = maximum_it->get();
+  if (maximum_op->type != OperatorType::kTensorFlowMaximum) {
+    return false;
+  }
+  CHECK_EQ(maximum_op->inputs.size(), 2);
+  if (maximum_op->outputs.size() != 1) {
+    return false;
+  }
+  int scalar_input_index =
+      GetSingleScalarInputIndexOfBinaryOp(model, maximum_op, -1.0f);
+  if (scalar_input_index == -1) {
+    return false;
+  }
+  const auto* minimum_op = GetOpWithInput(*model, maximum_op->outputs[0]);
+  if (!minimum_op || minimum_op->type != OperatorType::kTensorFlowMinimum) {
+    return false;
+  }
+  if (GetSingleScalarInputIndexOfBinaryOp(model, minimum_op, 1.0f) == -1) {
+    return false;
+  }
+  CHECK_EQ(minimum_op->inputs.size(), 2);
+
+  // Create and emplace Relu1 node
+  auto* relu1_op = new Relu1Operator;
+  relu1_op->inputs = {maximum_op->inputs[!scalar_input_index]};
+  relu1_op->outputs = minimum_op->outputs;
+  model->operators.emplace(maximum_it, relu1_op);
+
+  AddMessageF("Creating %s replacing equivalent subgraph", LogName(*relu1_op));
+
+  // Erase Maximum scalar input & operator
+  model->arrays.erase(maximum_op->inputs[scalar_input_index]);
+  model->operators.erase(FindOperator(model, maximum_op));
+
+  // Erase Minimum inputs & operator
+  model->arrays.erase(minimum_op->inputs[0]);
+  model->arrays.erase(minimum_op->inputs[1]);
+  model->operators.erase(FindOperator(model, minimum_op));
+
+  return true;
+}
+
+}  // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/make_initial_dequantize_operator.cc b/tensorflow/contrib/lite/toco/graph_transformations/make_initial_dequantize_operator.cc
new file mode 100644
index 0000000000..d83603e9a2
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/make_initial_dequantize_operator.cc
@@ -0,0 +1,120 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/model_flags.pb.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+// This inserts an operator whose output is a float array (name:
+// flags.input_array()). It has to wait for any existing operators that
+// generate this output to be removed by graph transformations. Note that there
+// may be more than one operator that takes the input_array as their input, and
+// that some of these may be removed by graph transformations.
+bool AddDequantizeOperatorToInput(const string& input_name, const Operator* op,
+                                  GraphTransformation* transformation,
+                                  Model* model) {
+  // An operator with the required output may be a dequantize operator already
+  // created. Alternatively it may be an operator that needs to be removed
+  // because it is unused, in which case we wait for RemoveUnusedOp to do its
+  // work.
+  if (GetOpWithOutput(*model, input_name)) {
+    return false;
+  }
+
+  // We only apply for the first operator if there is more than one. This is
+  // not strictly necessary for ordering correctness, since we insert the
+  // dequant operator at the beginning of the op sequence, but it makes the
+  // insertion more predictable (eg forward vs backwards operator sweep).
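+  // Note: when more than one operator consumes input_name, only the first such
+  // operator (GetFirstOpWithInput) triggers this rewrite; every consumer is
+  // then redirected to the new dequantized array further below.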
+ if (CountOpsWithInput(*model, input_name) > 1) { + if (op != GetFirstOpWithInput(*model, input_name)) { + return false; + } + } + + auto& input_array = model->GetArray(input_name); + if (input_array.data_type != ArrayDataType::kFloat) { + return false; + } + + if (input_array.final_data_type == input_array.data_type || + input_array.final_data_type == ArrayDataType::kNone) { + return false; + } + + const auto& dequantized_input_name = + AvailableArrayName(*model, input_name + "_dequantized"); + for (auto& other_op : model->operators) { + for (string& other_op_input : other_op->inputs) { + if (other_op_input == input_name) { + other_op_input = dequantized_input_name; + } + } + } + + auto& dequantized_input_array = + model->GetOrCreateArray(dequantized_input_name); + auto* image_input_op = new DequantizeOperator; + image_input_op->inputs = {input_name}; + image_input_op->outputs = {dequantized_input_name}; + model->operators.emplace(model->operators.begin(), image_input_op); + + CHECK(input_array.final_data_type == ArrayDataType::kUint8); + input_array.data_type = ArrayDataType::kUint8; + dequantized_input_array.data_type = ArrayDataType::kFloat; + const auto& input_minmax = input_array.GetMinMax(); + auto& dequantized_input_minmax = dequantized_input_array.GetOrCreateMinMax(); + dequantized_input_minmax = input_minmax; + auto& input_qparams = input_array.GetOrCreateQuantizationParams(); + GetQuantizationParamsFromMinMax( + model->flags, input_minmax, &input_qparams); + + transformation->AddMessageF( + "Created %s" + " to handle quantized input image data, taking over existing" + " mean_value and std_value flags. Cleared those flags.", + LogName(*image_input_op)); + + return true; +} + +bool MakeInitialDequantizeOperator::Run(Model* model, std::size_t op_index) { + // This is effectively a transformation applied to edges. We iterate over the + // specified node (op) and proceed for input edges. + const auto it = model->operators.begin() + op_index; + const auto* op = it->get(); + bool change_made = false; + for (auto& input : op->inputs) { + for (auto& input_array : *model->flags.mutable_input_arrays()) { + if (input_array.name() == input) { + if (AddDequantizeOperatorToInput(input_array.name(), op, this, model)) { + change_made = true; + input_array.clear_mean_value(); + input_array.clear_std_value(); + } + } + } + } + return change_made; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc new file mode 100644 index 0000000000..1ff4e827aa --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc @@ -0,0 +1,142 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+namespace {
+
+ArrayDataType CommonDataTypeOfAllInputs(const Model& model,
+                                        const Operator& op) {
+  CHECK_GT(op.inputs.size(), 0);
+  const ArrayDataType data_type = model.GetArray(op.inputs[0]).data_type;
+  for (const auto& input : op.inputs) {
+    const auto& array = model.GetArray(input);
+    CHECK(array.data_type == data_type)
+        << " Unexpected: this operator has inputs with different data types.";
+  }
+  return data_type;
+}
+
+void SetDataTypeForAllOutputs(Model* model, Operator* op,
+                              ArrayDataType data_type) {
+  for (const auto& output : op->outputs) {
+    model->arrays[output]->data_type = data_type;
+  }
+}
+}  // namespace
+
+bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) {
+  auto it = model->operators.begin() + op_index;
+  auto* op = it->get();
+
+  // If the data type of some input is unknown, we need to yield.
+  for (const auto& input : op->inputs) {
+    if (model->arrays[input]->data_type == ArrayDataType::kNone) {
+      return false;
+    }
+  }
+  // Record data types of output before processing, so we can see at the
+  // end if we changed anything, and return the correct boolean value.
+  std::unordered_map<string, ArrayDataType> old_output_data_types;
+  for (const auto& output : op->outputs) {
+    old_output_data_types[output] = model->arrays[output]->data_type;
+  }
+  // Do the actual output data types propagation.
+  if (op->type == OperatorType::kDequantize ||
+      op->type == OperatorType::kResizeBilinear) {
+    // These operators unconditionally produce float outputs
+    SetDataTypeForAllOutputs(model, op, ArrayDataType::kFloat);
+  } else if (op->type == OperatorType::kTensorFlowLess ||
+             op->type == OperatorType::kTensorFlowLessEqual ||
+             op->type == OperatorType::kTensorFlowGreater ||
+             op->type == OperatorType::kTensorFlowGreaterEqual) {
+    // These operators unconditionally produce bool outputs
+    SetDataTypeForAllOutputs(model, op, ArrayDataType::kBool);
+  } else if (op->type == OperatorType::kTensorFlowShape) {
+    // These operators are assumed to produce int32 outputs.
+ SetDataTypeForAllOutputs(model, op, ArrayDataType::kInt32); + } else if (op->type == OperatorType::kAveragePool || + op->type == OperatorType::kMaxPool || + op->type == OperatorType::kL2Pool || + op->type == OperatorType::kConv || + op->type == OperatorType::kDepthwiseConv || + op->type == OperatorType::kFullyConnected || + op->type == OperatorType::kTensorFlowMax || + op->type == OperatorType::kTensorFlowMin || + op->type == OperatorType::kPad || + op->type == OperatorType::kStridedSlice || + op->type == OperatorType::kTensorFlowReshape || + op->type == OperatorType::kSlice || + op->type == OperatorType::kSqueeze || + op->type == OperatorType::kTensorFlowSum || + op->type == OperatorType::kTensorFlowSwitch || + op->type == OperatorType::kTensorFlowTile || + op->type == OperatorType::kTensorFlowAll || + op->type == OperatorType::kReorderAxes || + op->type == OperatorType::kTensorFlowConcatV2 || + op->type == OperatorType::kFloor || + op->type == OperatorType::kGather || + op->type == OperatorType::kSpaceToBatchND || + op->type == OperatorType::kBatchToSpaceND || + op->type == OperatorType::kMean) { + // These operators produce outputs with the same type as their 1st input + CHECK_GT(op->inputs.size(), 0); + const ArrayDataType data_type = model->arrays[op->inputs[0]]->data_type; + SetDataTypeForAllOutputs(model, op, data_type); + } else if (op->type == OperatorType::kTensorFlowSplit || + op->type == OperatorType::kTensorFlowConcat) { + // These operators produce an output with the same type as their 2nd input + CHECK_GT(op->inputs.size(), 1); + const ArrayDataType data_type = model->arrays[op->inputs[1]]->data_type; + SetDataTypeForAllOutputs(model, op, data_type); + } else if (op->type == OperatorType::kCast) { + // Data type of the Cast op is specified. + CHECK_EQ(op->outputs.size(), 1); + auto* cast_op = static_cast(op); + model->arrays[op->outputs[0]]->data_type = cast_op->dst_data_type; + } else if (op->type == OperatorType::kTensorFlowUnsupported) { + auto* unsupported_op = static_cast(op); + if (unsupported_op->output_data_types.size() != op->outputs.size()) { + return false; + } + for (int i = 0; i < unsupported_op->output_data_types.size(); ++i) { + auto output = op->outputs[i]; + auto data_type = unsupported_op->output_data_types[i]; + model->arrays[output]->data_type = data_type; + } + } else { + // These operators produce an output with the same type as any of their + // inputs, which must always have the same type. + const ArrayDataType data_type = CommonDataTypeOfAllInputs(*model, *op); + SetDataTypeForAllOutputs(model, op, data_type); + } + // Return true if any output data type changed, false if none changed. + for (const auto& output : op->outputs) { + if (old_output_data_types[output] != model->arrays[output]->data_type) { + return true; + } + } + return false; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc new file mode 100644 index 0000000000..82a43bc2ce --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc @@ -0,0 +1,1129 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +namespace { + +void ComputeConvSizes(const Shape& input_shape, int output_depth, int kwidth, + int kheight, int stride_width, int stride_height, + PaddingType padding_type, Shape* output_shape, + FixedPadding* fixed_padding) { + const int input_width = input_shape.dims(2); + const int input_height = input_shape.dims(1); + const int batch = input_shape.dims(0); + + int output_height = 0; + int output_width = 0; + if (padding_type == PaddingType::kValid) { + output_height = (input_height + stride_height - kheight) / stride_height; + output_width = (input_width + stride_width - kwidth) / stride_width; + } else if (padding_type == PaddingType::kSame) { + output_height = (input_height + stride_height - 1) / stride_height; + output_width = (input_width + stride_width - 1) / stride_width; + } else { + LOG(FATAL) << "Only supporting SAME or VALID padding"; + } + + fixed_padding->height = + ((output_height - 1) * stride_height + kheight - input_height) / 2; + fixed_padding->width = + ((output_width - 1) * stride_width + kwidth - input_width) / 2; + + // Actually had to debug a situation where those were negative due to bad + // propagation of placeholder -1 sizes in TensorFlowReshape. + CHECK_GT(output_width, 0); + CHECK_GT(output_height, 0); + output_shape->ReplaceDims({batch, output_height, output_width, output_depth}); +} + +void ComputeBinaryOperatorOutputSize(const Shape& input_shape1, + const Shape& input_shape2, + Array* output_array) { + const int size1 = RequiredBufferSizeForShape(input_shape1); + const int size2 = RequiredBufferSizeForShape(input_shape2); + if (size1 > size2) { + output_array->copy_shape(input_shape1); + } else if (size2 > size1) { + output_array->copy_shape(input_shape2); + } else { + CHECK_EQ(size1, size2); + const int dims1 = input_shape1.dimensions_count(); + const int dims2 = input_shape2.dimensions_count(); + if (dims1 >= dims2) { + output_array->copy_shape(input_shape1); + } else { + output_array->copy_shape(input_shape2); + } + } + CHECK(output_array->has_shape()); +} + +int GetOutputDepthFromWeights(const Model& model, const Operator& op) { + const string& weights_name = op.inputs[1]; + const auto& weights_shape = model.arrays.at(weights_name)->shape(); + if (op.type == OperatorType::kConv || + op.type == OperatorType::kFullyConnected) { + return weights_shape.dims(0); + } else if (op.type == OperatorType::kDepthwiseConv) { + return weights_shape.dims(3); + } else { + LOG(FATAL) << "Unhandled operator type"; + } +} + +bool EnsureBiasVectorShape(Model* model, Operator* op) { + const string& weights_name = op->inputs[1]; + const auto& weights_array = *model->arrays[weights_name]; + // Yield until weights shape has been resolved. 
+ if (!weights_array.has_shape()) { + return false; + } + + if (op->inputs.size() < 3) { + return false; + } + auto& bias_array = *model->arrays[op->inputs[2]]; + if (bias_array.has_shape()) { + return true; + } + + const int output_depth = GetOutputDepthFromWeights(*model, *op); + bias_array.copy_shape(Shape({output_depth})); + + auto& float_buffer = bias_array.GetMutableBuffer(); + float_buffer.data.resize(output_depth, 0); + + return true; +} + +void ProcessConvOperator(Model* model, ConvOperator* op) { + if (!EnsureBiasVectorShape(model, op)) { + return; + } + + const auto& input_array = *model->arrays[op->inputs[0]]; + // Yield until input dims have been resolved. + if (!input_array.has_shape()) { + return; + } + const auto& input_shape = input_array.shape(); + CHECK_EQ(input_shape.dimensions_count(), 4); + + const auto& weights_array = *model->arrays[op->inputs[1]]; + // Yield until weights dims have been resolved. + if (!weights_array.has_shape()) { + return; + } + const auto& weights_shape = weights_array.shape(); + CHECK_EQ(weights_shape.dimensions_count(), 4); + + auto& output_array = model->GetArray(op->outputs[0]); + const int output_depth = weights_shape.dims(0); + const int kheight = weights_shape.dims(1); + const int kwidth = weights_shape.dims(2); + ComputeConvSizes(input_shape, output_depth, kwidth, kheight, op->stride_width, + op->stride_height, op->padding.type, + output_array.mutable_shape(), + &op->padding.GetOrCreateFixedPadding()); + CHECK_EQ(output_array.shape().dimensions_count(), 4); + + // Set im2col array dimensions if there is one. + if (op->outputs.size() == 2) { + const auto& output_shape = output_array.shape(); + const int input_depth = weights_shape.dims(3); + auto& im2col_array = *model->arrays[op->outputs[1]]; + im2col_array.copy_shape(Shape{output_shape.dims(0), output_shape.dims(1), + output_shape.dims(2), + input_depth * kheight * kwidth}); + } +} + +void ProcessDepthwiseConvOperator(Model* model, DepthwiseConvOperator* op) { + if (!EnsureBiasVectorShape(model, op)) { + return; + } + + const auto& input_array = *model->arrays[op->inputs[0]]; + // Yield until input dims have been resolved. + if (!input_array.has_shape()) { + return; + } + const auto& input_shape = input_array.shape(); + CHECK_EQ(input_shape.dimensions_count(), 4); + + const auto& weights_array = *model->arrays[op->inputs[1]]; + // Yield until weights dims have been resolved. + if (!weights_array.has_shape()) { + return; + } + const auto& weights_shape = weights_array.shape(); + CHECK_EQ(weights_shape.dimensions_count(), 4); + + const string& output_name = op->outputs[0]; + const int input_depth = input_shape.dims(3); + const int output_depth = weights_shape.dims(3); + // TensorFlow doesn't define the depth_multiplier value on DepthwiseConv ops, + // instead it has to be inferred from the weights dims. However, once we are + // here, weights dims have already been converted to our own internal format, + // where the multiplier is no longer readily apparent. So instead we get it + // as the quotient of output and input depths. We only want to do that when + // depth_multiplier had the zero value: any other value should be checked + // as done by the next if() below. 
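+ // For example (illustrative): input_depth = 8 and output_depth = 32 yield an
+ // inferred depth_multiplier of 32 / 8 = 4.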
+ if (!op->depth_multiplier) { + op->depth_multiplier = output_depth / input_depth; + } + QCHECK_EQ(output_depth, input_depth * op->depth_multiplier) + << "input/output depths and depth_multiplier don't match"; + + const int kheight = weights_shape.dims(1); + const int kwidth = weights_shape.dims(2); + ComputeConvSizes(input_shape, output_depth, kwidth, kheight, op->stride_width, + op->stride_height, op->padding.type, + model->GetArray(output_name).mutable_shape(), + &op->padding.GetOrCreateFixedPadding()); +} + +void ProcessDepthToSpaceOperator(Model* model, DepthToSpaceOperator* op) { + const auto& input_array = *model->arrays[op->inputs[0]]; + // Yield until input dims have been resolved. + if (!input_array.has_shape()) { + return; + } + const auto& input_shape = input_array.shape(); + CHECK_EQ(input_shape.dimensions_count(), 4); + + const string& output_name = op->outputs[0]; + const int block_size = op->block_size; + CHECK_NE(block_size, 0) << "Invalid block_size in " << output_name; + const int batch = input_shape.dims(0); + const int height = input_shape.dims(1); + const int width = input_shape.dims(2); + const int depth = input_shape.dims(3); + QCHECK_EQ(depth % (block_size * block_size), 0); + + model->GetArray(output_name) + .copy_shape(Shape({batch, height * block_size, width * block_size, + depth / block_size / block_size})); +} + +void ProcessSpaceToDepthOperator(Model* model, SpaceToDepthOperator* op) { + const auto& input_array = *model->arrays[op->inputs[0]]; + // Yield until input dims have been resolved. + if (!input_array.has_shape()) { + return; + } + const auto& input_shape = input_array.shape(); + CHECK_EQ(input_shape.dimensions_count(), 4); + + const string& output_name = op->outputs[0]; + const int block_size = op->block_size; + CHECK_NE(block_size, 0) << "Invalid block_size in " << output_name; + const int batch = input_shape.dims(0); + const int height = input_shape.dims(1); + const int width = input_shape.dims(2); + const int depth = input_shape.dims(3); + QCHECK_EQ(width % block_size, 0); + QCHECK_EQ(height % block_size, 0); + + model->GetArray(output_name) + .copy_shape(Shape({batch, height / block_size, width / block_size, + depth * block_size * block_size})); +} + +void ProcessFullyConnectedOperator(Model* model, FullyConnectedOperator* op) { + if (!EnsureBiasVectorShape(model, op)) { + return; + } + + const auto& input_array = *model->arrays[op->inputs[0]]; + // Yield until input dims have been resolved. + if (!input_array.has_shape()) { + return; + } + const auto& input_shape = input_array.shape(); + CHECK_GE(input_shape.dimensions_count(), 1); + + const auto& weights_array = *model->arrays[op->inputs[1]]; + // Yield until weights dims have been resolved. 
+ if (!weights_array.has_shape()) { + return; + } + const auto& weights_shape = weights_array.shape(); + + const int weights_output_depth = weights_shape.dims(0); + CHECK_EQ(weights_shape.dimensions_count(), 2); + + const int input_overall_size = RequiredBufferSizeForShape(input_shape); + const int matmul_repeats = input_overall_size / weights_shape.dims(1); + CHECK_EQ(matmul_repeats * weights_shape.dims(1), input_overall_size); + + auto& output_array = model->GetArray(op->outputs[0]); + output_array.copy_shape(Shape({matmul_repeats, weights_output_depth})); +} + +void ProcessTensorFlowReshapeOperator(Model* model, + TensorFlowReshapeOperator* op) { + auto& output_array = *model->arrays[op->outputs[0]]; + // Bail if we already have output dims + if (output_array.has_shape()) { + return; + } + + const auto& input_array = *model->arrays[op->inputs[0]]; + // Yield until input dims have been resolved. + if (!input_array.has_shape()) { + return; + } + const auto& input_shape = input_array.shape(); + + const string& shape_name = op->inputs[1]; + auto& shape_array = model->GetArray(shape_name); + // Yield until the shape is resolved as a constant array + if (!shape_array.buffer) { + return; + } + CHECK(shape_array.data_type == ArrayDataType::kInt32); + // shape_data is the raw array of ints describing the shape + // in the TensorFlow node. We intentionally make a copy here, rather than + // modify wildcards in-place below, because in some graphs, the same shape + // array with a wildcard may be referenced from multiple Reshape nodes, where + // the wildcard needs to resolved to distinct values. + std::vector shape_data = + shape_array.GetBuffer().data; + // The Reshape shape may have a wildcard dim, encoded as -1. + bool has_wildcard = false; + int wildcard_index = 0; + int product_non_wildcard_dims = 1; + for (int i = 0; i < shape_data.size(); i++) { + if (shape_data[i] == -1) { + CHECK(!has_wildcard); + has_wildcard = true; + wildcard_index = i; + } else { + product_non_wildcard_dims *= shape_data[i]; + } + } + const int input_flat_size = RequiredBufferSizeForShape(input_shape); + if (has_wildcard) { + shape_data[wildcard_index] = input_flat_size / product_non_wildcard_dims; + } + auto& output_shape = *output_array.mutable_shape(); + *output_shape.mutable_dims() = shape_data; + const int output_flat_size = RequiredBufferSizeForShape(output_shape); + CHECK_EQ(output_flat_size, input_flat_size); +} + +void ProcessSimpleOperator(Model* model, Operator* op) { + const auto& input_array = *model->arrays[op->inputs[0]]; + // Yield until input dims have been resolved. + if (!input_array.has_shape()) { + return; + } + + const string& output_name = op->outputs[0]; + auto& output_array = *model->arrays[output_name]; + if (output_array.has_shape()) { + return; + } + + output_array.copy_shape(input_array.shape()); +} + +void ProcessSimpleBinaryOperator(Model* model, Operator* op) { + CHECK_EQ(op->inputs.size(), 2); + const auto& input0_array = *model->arrays[op->inputs[0]]; + const auto& input1_array = *model->arrays[op->inputs[1]]; + // Yield until input dims have been resolved. 
+ if (!input0_array.has_shape() || !input1_array.has_shape()) { + return; + } + const string& output_name = op->outputs[0]; + auto& output_array = *model->arrays[output_name]; + ComputeBinaryOperatorOutputSize(input0_array.shape(), input1_array.shape(), + &output_array); +} + +void ProcessTensorFlowReductionOperator(Model* model, Operator* op) { + CHECK_LE(op->inputs.size(), 2); + auto& output_array = *model->arrays[op->outputs[0]]; + if (output_array.has_shape()) { + return; + } + if (op->inputs.size() == 2) { + // There is a reduction_indices input. + const auto& input_array = *model->arrays[op->inputs[0]]; + const auto& reduction_array = *model->arrays[op->inputs[1]]; + if (!reduction_array.buffer) { + return; + } + if (!input_array.has_shape()) { + return; + } + auto& input_shape = input_array.shape(); + CHECK(reduction_array.buffer->type == ArrayDataType::kInt32); + const auto& reduction_array_vals = + reduction_array.GetBuffer().data; + auto& output_dims = *output_array.mutable_shape()->mutable_dims(); + output_dims.clear(); + for (int i = 0; i < input_shape.dimensions_count(); i++) { + bool is_reduction_dim = false; + for (int r : reduction_array_vals) { + if (i == r) { + is_reduction_dim = true; + } + } + if (!is_reduction_dim) { + output_dims.push_back(input_shape.dims(i)); + } + } + } else { + // No reduction_indices means complete reduction to a single scalar. + output_array.copy_shape(Shape({})); + } +} + +void ProcessSliceOperator(Model* model, SliceOperator* op) { + CHECK_EQ(op->inputs.size(), 3); + CHECK_EQ(op->outputs.size(), 1); + + // Yield until the Slice params have been resolved. + if (op->begin.empty()) return; + + // Yield until input dims have been resolved. + const auto& input_array = *model->arrays[op->inputs[0]]; + if (!input_array.has_shape()) return; + const Shape& input_shape = input_array.shape(); + + auto& output_array = *model->arrays[op->outputs[0]]; + if (output_array.has_shape()) return; + + CHECK_EQ(input_shape.dims().size(), op->size.size()); + CHECK_EQ(op->begin.size(), op->size.size()); + + std::vector output_dims; + for (int i = 0; i < op->begin.size(); ++i) { + int size = op->size[i]; + if (size == -1) { + size = input_array.shape().dims(i) - op->begin[i]; + } + output_dims.push_back(size); + } + + *output_array.mutable_shape()->mutable_dims() = output_dims; +} + +void ProcessReorderAxesOperator(Model* model, ReorderAxesOperator* op) { + const string& input_name = op->inputs[0]; + const auto& input_array = *model->arrays[input_name]; + // Yield until input dims have been resolved. + if (!input_array.has_shape()) { + return; + } + const auto& input_shape = input_array.shape(); + const string& output_name = op->outputs[0]; + Shape* output_shape = model->GetArray(output_name).mutable_shape(); + ShuffleDims(input_shape, op->input_axes_order, op->output_axes_order, + output_shape); +} + +void ProcessConcatenationOperator(Model* model, ConcatenationOperator* op) { + // Yield until input dims have been resolved. + for (const auto& input_name : op->inputs) { + auto& input_array = *model->arrays[input_name]; + if (!input_array.has_shape()) { + return; + } + } + auto& output_array = model->GetArray(op->outputs[0]); + // Use 0 input as basis for output dimensions. + const auto& first_input_array = *model->arrays[op->inputs[0]]; + output_array.copy_shape(first_input_array.shape()); + // Determine the concat size, and enfore that all inputs have + // the same dimensions count. 
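+ // For example (illustrative): concatenating inputs of shape [1, 4, 4, 3] and
+ // [1, 4, 4, 5] along concat_dim = 3 gives an output shape of [1, 4, 4, 8].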
+ int concat_size = 0; + for (const auto& input_name : op->inputs) { + auto& input_array = *model->arrays[input_name]; + CHECK(input_array.has_shape()); + if (input_array.shape().dimensions_count() == 0) { + continue; + } + CHECK_EQ(input_array.shape().dimensions_count(), + output_array.shape().dimensions_count()); + const std::vector& input_dims = input_array.shape().dims(); + CHECK_LT(op->concat_dim, input_dims.size()); + concat_size += input_dims[op->concat_dim]; + } + // Write out the concat_size on the output array shape. + auto& output_shape = *output_array.mutable_shape(); + auto& output_dims = *output_shape.mutable_dims(); + CHECK_LT(op->concat_dim, output_shape.dimensions_count()); + output_dims[op->concat_dim] = concat_size; +} + +void ProcessTensorFlowSplitOperator(Model* model, TensorFlowSplitOperator* op) { + CHECK_EQ(op->inputs.size(), 2); + const string& input_name = op->inputs[1]; + const auto& input_array = *model->arrays[input_name]; + // Yield until input dims have been resolved. + if (!input_array.has_shape()) { + return; + } + const Shape& input_shape = input_array.shape(); + + // This code is slightly suspect. The TensorFlow docs say that the axis + // selection defaults to 0, but we are splitting across the final axis. + const int input_dims_count = input_shape.dimensions_count(); + const int input_depth = input_shape.dims(input_dims_count - 1); + CHECK_EQ(input_depth % op->num_split, 0); + const int split_depth = input_depth / op->num_split; + + Shape output_shape = input_shape; + (*output_shape.mutable_dims())[input_dims_count - 1] = split_depth; + + CHECK_EQ(op->outputs.size(), op->num_split); + for (const auto& output : op->outputs) { + model->arrays[output]->copy_shape(output_shape); + } +} + +void ProcessAveragePoolOperator(Model* model, AveragePoolOperator* op) { + const string& input_name = op->inputs[0]; + const auto& input_array = *model->arrays[input_name]; + // Yield until input dims have been resolved. + if (!input_array.has_shape()) { + return; + } + const auto& input_shape = input_array.shape(); + CHECK_EQ(input_shape.dimensions_count(), 4); + const string& output_name = op->outputs[0]; + const int output_depth = input_shape.dims(3); + ComputeConvSizes(input_shape, output_depth, op->kwidth, op->kheight, + op->stride_width, op->stride_height, op->padding.type, + model->GetArray(output_name).mutable_shape(), + &op->padding.GetOrCreateFixedPadding()); +} + +void ProcessMaxPoolOperator(Model* model, MaxPoolOperator* op) { + const string& input_name = op->inputs[0]; + const auto& input_array = *model->arrays[input_name]; + // Yield until input dims have been resolved. + if (!input_array.has_shape()) { + return; + } + const auto& input_shape = input_array.shape(); + CHECK_EQ(input_shape.dimensions_count(), 4); + const string& output_name = op->outputs[0]; + const int output_depth = input_shape.dims(3); + ComputeConvSizes(input_shape, output_depth, op->kwidth, op->kheight, + op->stride_width, op->stride_height, op->padding.type, + model->GetArray(output_name).mutable_shape(), + &op->padding.GetOrCreateFixedPadding()); +} + +void ProcessL2PoolOperator(Model* model, L2PoolOperator* op) { + const string& input_name = op->inputs[0]; + const auto& input_array = *model->arrays[input_name]; + // Yield until input dims have been resolved. 
+ if (!input_array.has_shape()) { + return; + } + const auto& input_shape = input_array.shape(); + if (input_shape.dimensions_count() < 4) { + LOG(FATAL) << "missing dimensions for " << input_name; + } + const string& output_name = op->outputs[0]; + const int output_depth = input_shape.dims(3); + ComputeConvSizes(input_shape, output_depth, op->kwidth, op->kheight, + op->stride_width, op->stride_height, op->padding.type, + model->GetArray(output_name).mutable_shape(), + &op->padding.GetOrCreateFixedPadding()); +} + +void ProcessResizeBilinearOperator(Model* model, ResizeBilinearOperator* op) { + CHECK_EQ(op->inputs.size(), 2); + CHECK_EQ(op->outputs.size(), 1); + + if (!model->arrays[op->inputs[0]]->has_shape() || + !model->arrays[op->inputs[1]]->has_shape()) { + return; + } + const auto& input_data_shape = model->arrays[op->inputs[0]]->shape(); + + const string& output_size_name = op->inputs[1]; + const auto& output_size_array = *model->arrays[output_size_name]; + CHECK(output_size_array.data_type == ArrayDataType::kInt32); + CHECK(output_size_array.has_shape()); + const auto& output_size_shape = output_size_array.shape(); + CHECK_EQ(output_size_shape.dimensions_count(), 1); + CHECK_EQ(output_size_shape.dims(0), 2); + std::vector output_shape = + output_size_array.GetBuffer().data; + model->arrays[op->outputs[0]]->copy_shape( + Shape({input_data_shape.dims(0), output_shape[0], output_shape[1], + input_data_shape.dims(3)})); +} + +void ProcessLstmCellOperator(Model* model, LstmCellOperator* op) { + // I/O arrays should be allocated on creation of op. + QCHECK_EQ(op->inputs.size(), LstmCellOperator::NUM_INPUTS); + QCHECK_EQ(op->outputs.size(), LstmCellOperator::NUM_OUTPUTS); + + const auto& input_array = + *model->arrays[op->inputs[LstmCellOperator::DATA_INPUT]]; + // Yield until all input dims have been resolved. + if (!input_array.has_shape()) { + return; + } + const auto& input_shape = input_array.shape(); + CHECK_GE(input_shape.dimensions_count(), 2); + + const auto& prev_activ_array = + *model->arrays[op->inputs[LstmCellOperator::PREV_ACTIV_INPUT]]; + // Yield until all input dims have been resolved. + if (!prev_activ_array.has_shape()) { + return; + } + const auto& prev_activ_shape = prev_activ_array.shape(); + CHECK_GE(prev_activ_shape.dimensions_count(), 2); + + const auto& weights_array = + *model->arrays[op->inputs[LstmCellOperator::WEIGHTS_INPUT]]; + // Yield until weights dims have been resolved. + if (!weights_array.has_shape()) { + return; + } + const auto& weights_shape = weights_array.shape(); + CHECK_EQ(weights_shape.dimensions_count(), 2); + + const auto& bias_array = + *model->arrays[op->inputs[LstmCellOperator::BIASES_INPUT]]; + // Yield until bias dims have been resolved. + if (!bias_array.has_shape()) { + return; + } + const auto& bias_shape = bias_array.shape(); + CHECK_GE(bias_shape.dimensions_count(), 1); + + const auto& prev_state_array = + *model->arrays[op->inputs[LstmCellOperator::PREV_STATE_INPUT]]; + // Yield until all input dims have been resolved. 
+ if (!prev_state_array.has_shape()) { + return; + } + const auto& prev_state_shape = prev_state_array.shape(); + CHECK_GE(prev_state_shape.dimensions_count(), 2); + + const int fc_output_depth = weights_shape.dims(0); + CHECK_EQ(fc_output_depth, bias_shape.dims(0)); + CHECK_EQ(fc_output_depth % 4, 0); + const int depth = fc_output_depth / 4; + + const int input_depth = input_shape.dims(input_shape.dimensions_count() - 1); + const int fc_input_depth = weights_shape.dims(1); + CHECK_EQ(input_depth + depth, fc_input_depth); + Shape output_shape(input_shape); + (*output_shape.mutable_dims())[output_shape.dimensions_count() - 1] = depth; + + // Set output dimensions + model->GetArray(op->outputs[LstmCellOperator::STATE_OUTPUT]) + .copy_shape(output_shape); + model->GetArray(op->outputs[LstmCellOperator::ACTIV_OUTPUT]) + .copy_shape(output_shape); + + Shape concat_temp_shape(input_shape); + (*concat_temp_shape + .mutable_dims())[concat_temp_shape.dimensions_count() - 1] = + fc_input_depth; + model->GetArray(op->outputs[LstmCellOperator::CONCAT_TEMP]) + .copy_shape(concat_temp_shape); + + Shape activ_temp_shape(input_shape); + (*activ_temp_shape.mutable_dims())[activ_temp_shape.dimensions_count() - 1] = + fc_output_depth; + model->GetArray(op->outputs[LstmCellOperator::ACTIV_TEMP]) + .copy_shape(activ_temp_shape); +} + +void ProcessSpaceToBatchNDOperator(Model* model, SpaceToBatchNDOperator* op) { + const auto& input_array = *model->arrays[op->inputs[0]]; + // Yield until input dims have been resolved. + if (!input_array.has_shape()) { + return; + } + const auto& input_shape = input_array.shape(); + CHECK_EQ(input_shape.dimensions_count(), 4); + const auto input_height = input_shape.dims(1); + const auto input_width = input_shape.dims(2); + + const auto& block_shape_array = *model->arrays[op->inputs[1]]; + const auto& paddings_array = *model->arrays[op->inputs[2]]; + const auto& block_shape_array_shape = block_shape_array.shape(); + const auto& paddings_array_shape = paddings_array.shape(); + QCHECK_EQ(block_shape_array_shape.dimensions_count(), 1); + QCHECK_EQ(paddings_array_shape.dimensions_count(), 2); + + // We only support two dimensions. + QCHECK_EQ(block_shape_array_shape.dims(0), 2); + if (!block_shape_array.buffer) { + return; + } + QCHECK(block_shape_array.data_type == ArrayDataType::kInt32); + const auto& block_shape_data = + block_shape_array.GetBuffer().data; + auto block_height = block_shape_data[0]; + auto block_width = block_shape_data[1]; + + QCHECK_EQ(paddings_array_shape.dims(0), 2); // Number of block dimensions + QCHECK_EQ(paddings_array_shape.dims(1), 2); // Two parameters per dimension. + if (!paddings_array.buffer) { + return; + } + QCHECK(paddings_array.data_type == ArrayDataType::kInt32); + const auto& paddings_data = + paddings_array.GetBuffer().data; + int height_with_paddings = input_height + paddings_data[0] + paddings_data[1]; + int width_with_paddings = input_width + paddings_data[2] + paddings_data[3]; + QCHECK_EQ(height_with_paddings % block_height, 0); + QCHECK_EQ(width_with_paddings % block_width, 0); + int output_height = height_with_paddings / block_height; + int output_width = width_with_paddings / block_width; + + model->arrays[op->outputs[0]]->copy_shape( + Shape({input_shape.dims(0) * block_height * block_width, output_height, + output_width, input_shape.dims(3)})); +} + +void ProcessBatchToSpaceNDOperator(Model* model, BatchToSpaceNDOperator* op) { + const auto& input_array = *model->arrays[op->inputs[0]]; + // Yield until input dims have been resolved. 
+ if (!input_array.has_shape()) { + return; + } + const auto& input_shape = input_array.shape(); + CHECK_EQ(input_shape.dimensions_count(), 4); + const auto input_height = input_shape.dims(1); + const auto input_width = input_shape.dims(2); + + const auto& block_shape_array = *model->arrays[op->inputs[1]]; + const auto& crops_array = *model->arrays[op->inputs[2]]; + const auto& block_shape_array_shape = block_shape_array.shape(); + const auto& crops_array_shape = crops_array.shape(); + QCHECK_EQ(block_shape_array_shape.dimensions_count(), 1); + QCHECK_EQ(crops_array_shape.dimensions_count(), 2); + + // We only support two dimensions. + QCHECK_EQ(block_shape_array_shape.dims(0), 2); + if (!block_shape_array.buffer) { + return; + } + QCHECK(block_shape_array.data_type == ArrayDataType::kInt32); + const auto& block_shape_data = + block_shape_array.GetBuffer().data; + auto block_height = block_shape_data[0]; + auto block_width = block_shape_data[1]; + + QCHECK_EQ(crops_array_shape.dims(0), 2); // Number of block dimensions + QCHECK_EQ(crops_array_shape.dims(1), 2); // Two parameters per dimension. + if (!crops_array.buffer) { + return; + } + QCHECK(crops_array.data_type == ArrayDataType::kInt32); + const auto& crops_data = crops_array.GetBuffer().data; + // We don't support crops now. + QCHECK_EQ(crops_data[0], 0); + QCHECK_EQ(crops_data[1], 0); + QCHECK_EQ(crops_data[2], 0); + QCHECK_EQ(crops_data[3], 0); + + QCHECK_EQ(input_shape.dims(0) % (block_height * block_width), 0); + + int output_height = input_height * block_height; + int output_width = input_width * block_width; + + model->arrays[op->outputs[0]]->copy_shape( + Shape({input_shape.dims(0) / (block_height * block_width), output_height, + output_width, input_shape.dims(3)})); +} + +void ProcessGatherOperator(Model* model, GatherOperator* op) { + const auto& input_array = *model->arrays[op->inputs[0]]; + const auto& indices_array = *model->arrays[op->inputs[1]]; + auto& output_array = *model->arrays[op->outputs[0]]; + + // Bail if we already know the output shape. + if (output_array.has_shape()) { + return; + } + + // Yield until input dims have been resolved. + if (!input_array.has_shape() || !indices_array.has_shape()) { + return; + } + + const auto& input_shape = input_array.shape(); + const auto& indices_shape = indices_array.shape(); + QCHECK_GE(input_shape.dimensions_count(), 1); + op->input_rank = input_shape.dimensions_count(); + + // We only support 1-D indices. + QCHECK_EQ(indices_shape.dimensions_count(), 1); + + // Copy the input dimensions to the output except for dimension 0, + // where the dimension of indices_shape is used. + auto output_dims = output_array.mutable_shape()->mutable_dims(); + output_dims->push_back(indices_shape.dims(0)); + for (int dim = 1; dim < input_shape.dimensions_count(); dim++) { + output_dims->push_back(input_shape.dims(dim)); + } +} + +void ProcessPadOperator(Model* model, PadOperator* op) { + CHECK_EQ(op->inputs.size(), 2); + CHECK_EQ(op->outputs.size(), 1); + + const auto& input_array = *model->arrays[op->inputs[0]]; + + // Yield until input dims have been resolved. 
+ if (!input_array.has_shape()) return; + + if (op->left_padding.empty()) return; + CHECK_EQ(op->left_padding.size(), op->right_padding.size()); + + auto& output_array = *model->arrays[op->outputs[0]]; + if (output_array.has_shape()) return; + + Shape output_shape = input_array.shape(); + std::vector& dims = *output_shape.mutable_dims(); + CHECK_EQ(op->left_padding.size(), dims.size()); + + for (int i = 0; i < op->left_padding.size(); ++i) { + dims[i] += op->left_padding[i] + op->right_padding[i]; + } + + output_array.copy_shape(output_shape); +} + +void ProcessMeanOperator(Model* model, MeanOperator* op) { + CHECK_EQ(op->inputs.size(), 2); + CHECK_EQ(op->outputs.size(), 1); + + const auto& input_array = *model->arrays[op->inputs[0]]; + + // Yield until input dims have been resolved. + if (!input_array.has_shape()) return; + const std::vector& indices = op->reduction_indices; + if (indices.empty()) return; + + auto& output_array = *model->arrays[op->outputs[0]]; + if (output_array.has_shape()) return; + + const std::vector& input_dims = input_array.shape().dims(); + std::vector output_dims; + for (int i = 0; i < input_dims.size(); ++i) { + if (std::find(indices.begin(), indices.end(), i) == indices.end()) { + output_dims.push_back(input_dims[i]); + } + } + CHECK(!output_dims.empty()); + CHECK_EQ(output_dims.size(), 2); + + *output_array.mutable_shape()->mutable_dims() = output_dims; +} + +void ProcessStridedSliceOperator(Model* model, StridedSliceOperator* op) { + CHECK_EQ(op->inputs.size(), 4); + CHECK_EQ(op->outputs.size(), 1); + + const auto& input_array = *model->arrays[op->inputs[0]]; + + // Yield until input dims have been resolved. + if (!input_array.has_shape()) return; + + if (op->start_indices.empty()) return; + CHECK_EQ(op->start_indices.size(), op->stop_indices.size()); + CHECK_EQ(op->start_indices.size(), op->strides.size()); + + auto& output_array = *model->arrays[op->outputs[0]]; + if (output_array.has_shape()) return; + + Shape output_shape = input_array.shape(); + std::vector& dims = *output_shape.mutable_dims(); + CHECK_EQ(op->start_indices.size(), dims.size()); + + for (int i = 0; i < op->start_indices.size(); ++i) { + const int mask = 1 << i; + const int start = (op->begin_mask & mask) ? 0 : op->start_indices[i]; + const int stop = (op->end_mask & mask) ? input_array.shape().dims()[i] + : op->stop_indices[i]; + dims[i] = (stop - start) / op->strides[i]; + } + + output_array.copy_shape(output_shape); +} + +void ProcessSqueezeOperator(Model* model, SqueezeOperator* op) { + CHECK_EQ(op->inputs.size(), 1); + CHECK_EQ(op->outputs.size(), 1); + + const auto& input_array = *model->arrays[op->inputs[0]]; + + // Yield until input dims have been resolved. 
+ if (!input_array.has_shape()) return; + + auto& output_array = *model->arrays[op->outputs[0]]; + if (output_array.has_shape()) return; + + const std::vector& input_dims = input_array.shape().dims(); + std::vector output_dims; + + for (int i = 0; i < input_dims.size(); ++i) { + if (input_dims[i] != 1 || + (!op->squeeze_dims.empty() && + std::find(op->squeeze_dims.begin(), op->squeeze_dims.end(), i) == + op->squeeze_dims.end())) { + output_dims.push_back(input_dims[i]); + } + } + *output_array.mutable_shape()->mutable_dims() = output_dims; +} + +void ProcessSvdfOperator(Model* model, SvdfOperator* op) { + CHECK(op->inputs.size() == 3 || op->inputs.size() == 4); + const auto& input_array = *model->arrays[op->inputs[0]]; + if (!input_array.has_shape()) return; + + auto& weights_feature_array = *model->arrays[op->inputs[1]]; + if (!weights_feature_array.has_shape()) return; + + const auto& weights_time_array = *model->arrays[op->inputs[2]]; + if (!weights_time_array.has_shape()) return; + + const bool has_bias = (op->inputs.size() == 4); + if (has_bias) { + const auto& bias_array = *model->arrays[op->inputs[3]]; + if (!bias_array.has_shape()) return; + } + + const int batch_size = input_array.shape().dims()[0]; + const int num_units = weights_feature_array.shape().dims()[0]; + const int memory_size = weights_time_array.shape().dims()[1]; + + auto& state_array = model->GetArray(op->outputs[0]); + state_array.mutable_shape()->ReplaceDims( + {batch_size, memory_size * num_units}); + + auto& output_array = model->GetArray(op->outputs[1]); + output_array.mutable_shape()->ReplaceDims({batch_size, num_units}); +} +} // namespace + +bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { + auto it = model->operators.begin() + op_index; + auto* op = it->get(); + std::unordered_map> old_output_dims; + for (const auto& output : op->outputs) { + if (model->arrays[output]->has_shape()) { + old_output_dims[output] = model->arrays[output]->shape().dims(); + } + } + + switch (op->type) { + case OperatorType::kBatchNormalization: + case OperatorType::kL2Normalization: + case OperatorType::kDequantize: + case OperatorType::kRelu: + case OperatorType::kRelu1: + case OperatorType::kRelu6: + case OperatorType::kSoftmax: + case OperatorType::kLogistic: + case OperatorType::kTanh: + case OperatorType::kLocalResponseNormalization: + case OperatorType::kTensorFlowIdentity: + case OperatorType::kFakeQuant: + case OperatorType::kTensorFlowRsqrt: + case OperatorType::kTensorFlowSqrt: + case OperatorType::kTensorFlowSquare: + case OperatorType::kTensorFlowAll: + case OperatorType::kTensorFlowAssert: + case OperatorType::kCast: + case OperatorType::kFloor: + ProcessSimpleOperator(model, op); + break; + case OperatorType::kGather: + ProcessGatherOperator(model, static_cast(op)); + break; + + case OperatorType::kAdd: + case OperatorType::kSub: + case OperatorType::kMul: + case OperatorType::kDiv: + case OperatorType::kTensorFlowLess: + case OperatorType::kTensorFlowLessEqual: + case OperatorType::kTensorFlowGreater: + case OperatorType::kTensorFlowMaximum: + case OperatorType::kTensorFlowMinimum: + case OperatorType::kTensorFlowGreaterEqual: + ProcessSimpleBinaryOperator(model, op); + break; + case OperatorType::kConv: + ProcessConvOperator(model, static_cast(op)); + break; + case OperatorType::kDepthwiseConv: + ProcessDepthwiseConvOperator(model, + static_cast(op)); + break; + case OperatorType::kDepthToSpace: + ProcessDepthToSpaceOperator(model, + static_cast(op)); + break; + case OperatorType::kSpaceToDepth: 
+ ProcessSpaceToDepthOperator(model, + static_cast(op)); + break; + case OperatorType::kFullyConnected: + ProcessFullyConnectedOperator(model, + static_cast(op)); + break; + case OperatorType::kTensorFlowReshape: + ProcessTensorFlowReshapeOperator( + model, static_cast(op)); + break; + case OperatorType::kAveragePool: + ProcessAveragePoolOperator(model, static_cast(op)); + break; + case OperatorType::kMaxPool: + ProcessMaxPoolOperator(model, static_cast(op)); + break; + case OperatorType::kL2Pool: + ProcessL2PoolOperator(model, static_cast(op)); + break; + case OperatorType::kTensorFlowMin: + case OperatorType::kTensorFlowMax: + case OperatorType::kTensorFlowSum: + ProcessTensorFlowReductionOperator(model, op); + break; + + case OperatorType::kSlice: + ProcessSliceOperator(model, static_cast(op)); + break; + + case OperatorType::kTensorFlowTile: + // We don't currently implement the propagation of fixed sizes through + // a TensorFlow Tile. + // + // Fortunately, we don't need to: so far, we have only dealt with Tile + // or Slice ops in subgraphs that are identified as L2Normalization. + // See IdentifyL2Normalization. + break; + case OperatorType::kTensorFlowSwitch: + // We can't know the sizes of the outputs until we have resolved the + // predicate, and once we have resolved the predicate, the whole + // Switch node will get resolved away. + // See ResolveTensorFlowSwitch. + break; + case OperatorType::kTensorFlowMerge: + // No need to bother resolving TensorFlow Merge ops: other graph + // transformations will remove them anyway. + // See ResolveTensorFlowMerge. + break; + case OperatorType::kTensorFlowSplit: + ProcessTensorFlowSplitOperator(model, + static_cast(op)); + break; + case OperatorType::kSqueeze: + ProcessSqueezeOperator(model, static_cast(op)); + break; + case OperatorType::kTensorFlowConcat: + case OperatorType::kTensorFlowConcatV2: + // Unimplemented, hopefully another graph transformation will + // drop it or rewrite it. Concretely, either ResolveTensorFlowConcat + // will resolve this node to a DepthConcatenation, or else we have + // a more general non-depth concatenation that will hopefully be dropped, + // or else at the moment we will abort. + break; + case OperatorType::kTensorFlowShape: + // Unimplemented, hopefully another graph transformation will drop it or + // rewrite it. + break; + case OperatorType::kReorderAxes: + ProcessReorderAxesOperator(model, static_cast(op)); + break; + case OperatorType::kConcatenation: + ProcessConcatenationOperator(model, + static_cast(op)); + break; + case OperatorType::kResizeBilinear: + ProcessResizeBilinearOperator(model, + static_cast(op)); + break; + case OperatorType::kLstmCell: + ProcessLstmCellOperator(model, static_cast(op)); + break; + case OperatorType::kTensorFlowMatMul: + // MatMul operators are converted to FullyConnected, after which their + // shapes are propagated. 
+ break; + case OperatorType::kSpaceToBatchND: + ProcessSpaceToBatchNDOperator(model, + static_cast(op)); + break; + case OperatorType::kBatchToSpaceND: + ProcessBatchToSpaceNDOperator(model, + static_cast(op)); + break; + case OperatorType::kPad: + ProcessPadOperator(model, static_cast(op)); + break; + case OperatorType::kMean: + ProcessMeanOperator(model, static_cast(op)); + break; + case OperatorType::kStridedSlice: + ProcessStridedSliceOperator(model, + static_cast(op)); + break; + case OperatorType::kTensorFlowUnsupported: + break; + case OperatorType::kSvdf: + ProcessSvdfOperator(model, static_cast(op)); + break; + default: + // Unimplemented, another graph transformation should drop it. + LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(op->type); + } + + // Return true if any output dim changed, false if none changed. + // Assumption: no transformation clears an output shape, they only add shapes. + for (const auto& output : op->outputs) { + if (model->arrays[output]->has_shape() && + (old_output_dims[output] != model->arrays[output]->shape().dims())) { + return true; + } + } + return false; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc b/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc new file mode 100644 index 0000000000..5551755ea7 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc @@ -0,0 +1,467 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/model_flags.pb.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +namespace { + +bool SupportsQuantization(const Operator& op) { + auto type = op.type; + if (type == OperatorType::kTensorFlowUnsupported) { + auto* unsupported = static_cast(&op); + return unsupported->quantized; + } + return type == OperatorType::kConv || type == OperatorType::kDepthwiseConv || + type == OperatorType::kFullyConnected || + type == OperatorType::kConcatenation || + type == OperatorType::kL2Normalization || type == OperatorType::kAdd || + type == OperatorType::kAveragePool || type == OperatorType::kMaxPool || + type == OperatorType::kLogistic || type == OperatorType::kSoftmax || + type == OperatorType::kTensorFlowReshape || + type == OperatorType::kMul || type == OperatorType::kSpaceToDepth || + type == OperatorType::kDepthToSpace; +} + +template +std::unique_ptr QuantizeBuffer( + const GenericBuffer& buffer, + const QuantizationParams& quantization_params) { + const auto inverse_scale = 1. 
/ quantization_params.scale; + CHECK(buffer.type == ArrayDataType::kFloat); + const auto& float_buffer = + static_cast&>(buffer); + auto* quantized_buffer = new Buffer; + quantized_buffer->data.resize(float_buffer.data.size()); + const auto qmin = static_cast(std::numeric_limits>::min()); + const auto qmax = static_cast(std::numeric_limits>::max()); + for (std::size_t i = 0; i < float_buffer.data.size(); i++) { + const float src_val = float_buffer.data[i]; + double scaled_val; // Astonishingly, using 'float' degrades accuracy just + // enough to make a few tests fail! + if (quantization_params.scale == 0) { + CHECK_EQ(src_val, 0) << "The quantization scale for this array is 0, " + << "so all its values should be 0."; + scaled_val = quantization_params.zero_point; + } else { + scaled_val = quantization_params.zero_point + inverse_scale * src_val; + } + const auto rounded_val = static_cast(std::round(scaled_val)); + const auto clamped_val = std::min(qmax, std::max(qmin, rounded_val)); + quantized_buffer->data[i] = static_cast>(clamped_val); + } + return std::unique_ptr(quantized_buffer); +} + +template +void QuantizeArray(GraphTransformation* transformation, Model* model, + const string& name, + const QuantizationParams& quantization_params) { + auto& array = model->GetArray(name); + CHECK(array.data_type == ArrayDataType::kFloat); + CHECK(!array.quantization_params); + array.GetOrCreateQuantizationParams() = quantization_params; + if (array.buffer) { + array.buffer = QuantizeBuffer(*array.buffer, quantization_params); + } + array.data_type = A; + transformation->AddMessageF("Quantized array %s", name); +} + +void QuantizeArray(GraphTransformation* transformation, Model* model, + const string& name, ArrayDataType quantized_data_type, + const QuantizationParams& quantization_params) { + switch (quantized_data_type) { + case ArrayDataType::kUint8: + return QuantizeArray(transformation, model, name, + quantization_params); + case ArrayDataType::kInt32: + return QuantizeArray(transformation, model, name, + quantization_params); + default: + LOG(FATAL) << "Unhandled case."; + } +} + +const MinMax& GetOrComputeMinMax(Model* model, const string& array_name) { + auto& array = model->GetArray(array_name); + // Normally we should have a MinMax recorded on this Array, + // so we just use it. + if (array.minmax != nullptr) { + return *array.minmax; + } + + // We don't have a MinMax. That's bad news: we need + // the graph to provide MinMax info for all arrays in order + // for inference to reproduce faithfully the same quantization + // error as the training process had. + // + // But we still want to support a fallback for constant arrays, + // just using the plain min and max computed from array elements. + // We should hopefully never rely on that in production, as that + // will not give very good accuracy as that typically won't be + // exactly what the training process used. But it will be useful + // to allow easily trying out quantization even if the graph + // lacks some minmax information. + if (array.buffer != nullptr) { + LOG(WARNING) + << "Constant array " << array_name + << " lacks MinMax information. To make up for that, we will now compute" + << " the MinMax from actual array elements. 
That will result in" + << " quantization parameters that probably do not match whichever " + "arithmetic" + << " was used during training, and thus will probably be a cause of " + "poor" + << " inference accuracy."; + CHECK(array.buffer->type == ArrayDataType::kFloat); + const auto& data = array.GetBuffer().data; + // We always want [min, max] to contain 0. + float min = 0.f; + float max = 0.f; + for (auto val : data) { + min = std::min(min, val); + max = std::max(max, val); + } + auto& minmax = array.GetOrCreateMinMax(); + minmax.min = min; + minmax.max = max; + return minmax; + } + + LOG(FATAL) << "Array " << array_name + << " does not have MinMax information, " + "and is not a constant array. Cannot " + "proceed with quantization."; +} + +bool ChooseQuantizationForOperatorInput( + GraphTransformation* transformation, Model* model, const Operator& op, + std::size_t input_index, ArrayDataType* quantized_data_type, + QuantizationParams* quantization_params) { + const auto& input = op.inputs[input_index]; + auto& array = model->GetArray(input); + if (array.data_type != ArrayDataType::kFloat) { + return false; + } + if (op.type == OperatorType::kConv || + op.type == OperatorType::kDepthwiseConv || + op.type == OperatorType::kFullyConnected) { + if (input_index == 2) { + // Quantization of bias vector. + // We need both of the mandatory inputs (input activations and weights) to + // have + // been already quantized. + const auto& input_activations = model->GetArray(op.inputs[0]); + const auto& input_weights = model->GetArray(op.inputs[1]); + if (!input_activations.quantization_params || + !input_weights.quantization_params) { + return false; + } + const auto input_activations_scale = + input_activations.quantization_params->scale; + const auto input_weights_scale = input_weights.quantization_params->scale; + quantization_params->scale = + input_activations_scale * input_weights_scale; + quantization_params->zero_point = 0; + *quantized_data_type = ArrayDataType::kInt32; + transformation->AddMessageF( + "Input array %s is a bias vector. Choosing quantization params " + "accordingly.", + input); + return true; + } + } + + const MinMax& minmax = GetOrComputeMinMax(model, input); + GetQuantizationParamsFromMinMax(model->flags, minmax, + quantization_params); + transformation->AddMessageF( + "For input array %s with min=%g" + ", max=%g" + ", chose to quantize as uint8 with zero_point=%d" + ", scale=%g", + input, minmax.min, minmax.max, quantization_params->zero_point, + quantization_params->scale); + *quantized_data_type = ArrayDataType::kUint8; + return true; +} + +bool IsExactlyRepresentable(double real_value, ArrayDataType data_type, + const QuantizationParams& quantization_params) { + const double scaled_value = + quantization_params.zero_point + real_value / quantization_params.scale; + const double fractional_scaled_value = + scaled_value - std::round(scaled_value); + if (std::abs(fractional_scaled_value) > 1e-12) { + return false; + } + const double rounded_scaled_value = std::round(scaled_value); + if (data_type == ArrayDataType::kUint8) { + if (rounded_scaled_value < 0 || rounded_scaled_value > 255) { + return false; + } + } + return true; +} + +bool ChooseHardcodedQuantizationForOperatorOutput( + const Operator& op, ArrayDataType* quantized_data_type, + QuantizationParams* quantization_params) { + if (op.type == OperatorType::kL2Normalization) { + // L2Normalization has range: [-1, 1]. 
+ // 0 should be exactly representable, as values will typically be centered + // around 0, with many values near 0. + *quantized_data_type = ArrayDataType::kUint8; + quantization_params->zero_point = 128; + quantization_params->scale = 1. / 128.; + CHECK( + IsExactlyRepresentable(0., *quantized_data_type, *quantization_params)); + return true; + } + if ((op.type == OperatorType::kLogistic) || + (op.type == OperatorType::kSoftmax)) { + // Logistic and Softmax have range: [0, 1]. + // + // For Logistic, 0.5 should be exactly representable, as implementations + // will typically exploit the symmetry logistic(-x) = 1 - logistic(x), and + // the glueing of the two halves of the graph will only be seamless if we + // are accurately representing logistic(0) == 0.5. + *quantized_data_type = ArrayDataType::kUint8; + quantization_params->zero_point = 0; + quantization_params->scale = 1. / 256.; + CHECK(IsExactlyRepresentable(0.5, *quantized_data_type, + *quantization_params)); + return true; + } + return false; +} + +bool ChooseQuantizationForOperatorOutput( + GraphTransformation* transformation, Model* model, const Operator& op, + std::size_t output_index, ArrayDataType* quantized_data_type, + QuantizationParams* quantization_params) { + const auto& output = op.outputs[output_index]; + auto& array = model->GetArray(output); + if (array.data_type != ArrayDataType::kFloat) { + return false; + } + if (ChooseHardcodedQuantizationForOperatorOutput(op, quantized_data_type, + quantization_params)) { + transformation->AddMessageF( + "Output array %s is produced by a %s operator. Choosing fixed " + "quantization params accordingly.", + output, OperatorTypeName(op.type)); + return true; + } + if ((op.type == OperatorType::kDepthToSpace) || + (op.type == OperatorType::kSpaceToDepth)) { + // DepthToSpace and SpaceToDepth should preserve the quantization parameters + // of the input array, as these are simple reshape operations. + const auto& input_quantization_params = + model->GetArray(op.inputs[0]).GetQuantizationParams(); + *quantized_data_type = ArrayDataType::kUint8; + quantization_params->zero_point = input_quantization_params.zero_point; + quantization_params->scale = input_quantization_params.scale; + + transformation->AddMessageF( + "Output array %s is produced by a %s operator. Copying quantization " + "params from input array.", + output, OperatorTypeName(op.type)); + return true; + } + const MinMax& minmax = GetOrComputeMinMax(model, output); + GetQuantizationParamsFromMinMax(model->flags, minmax, + quantization_params); + *quantized_data_type = ArrayDataType::kUint8; + transformation->AddMessageF( + "For output array %s with min=%g, max=%g" + ", chose to quantize as uint8 with zero_point=%d" + ", scale=%g", + output, minmax.min, minmax.max, quantization_params->zero_point, + quantization_params->scale); + + return true; +} +} // namespace + +bool Quantize::Run(Model* model, std::size_t op_index) { + // Our general "quantization" graph transformation consists in replacing + // QuantizedInputArrays[] -> + // DequantizeOperators[] -> + // FloatInputArrays[] -> + // Operator -> + // FloatOutputArray + // by + // QuantizedInputArrays[] -> + // Operator -> + // QuantizedOutputArray -> + // DequantizeOperator -> + // FloatOutputArray + // + // In other words, this is pushing Dequantize operators to the right of + // other operators. 
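To put numbers on the affine mapping used by QuantizeBuffer above, together with the hard-coded L2Normalization output parameters (zero_point = 128, scale = 1/128): the quantized value is round(zero_point + real / scale) clamped to [0, 255], so 0 maps exactly to 128, -1 maps to 0, and +1 saturates from 256 down to 255, which is also why the IsExactlyRepresentable checks above hold. A minimal standalone sketch, not toco code:

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

// Standalone sketch of the uint8 affine quantization used by QuantizeBuffer:
// scaled = zero_point + real / scale, then round and clamp to [0, 255].
static std::uint8_t QuantizeUint8(float real, double scale, int zero_point) {
  const double scaled = zero_point + real / scale;
  const double clamped = std::min(255.0, std::max(0.0, std::round(scaled)));
  return static_cast<std::uint8_t>(clamped);
}

int main() {
  const double scale = 1.0 / 128.0;  // hard-coded L2Normalization output scale
  const int zero_point = 128;        // and zero point
  std::printf("%d %d %d\n",
              static_cast<int>(QuantizeUint8(0.f, scale, zero_point)),   // 128
              static_cast<int>(QuantizeUint8(-1.f, scale, zero_point)),  // 0
              static_cast<int>(QuantizeUint8(1.f, scale, zero_point)));  // 255
  return 0;
}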
+ // + + auto& op = *model->operators[op_index]; + if (op.type == OperatorType::kDequantize || + op.type == OperatorType::kFakeQuant) { + return false; + } + + // Our assumption here is that the input arrays are already quantized - + // that is typically the case in models operating on an input bitmap + // image, and MakeInitialDequantizeOp should have already resolved + // the handling of the input image as an initial Dequantize op. + // + // Thus we are building around the assumption that the graph always starts + // with a quantized input array, and only after some Dequantize op do we have + // float arrays. The problem of quantizing the graph thus becomes a problem of + // pushing Dequantize ops to the right of other ops. + // + // Let us just guard this assumption by the following assertion: + for (const auto& input : op.inputs) { + if (IsInputArray(*model, input)) { + const auto& input_array = model->GetArray(input); + CHECK(input_array.quantization_params); + } + } + if (!SupportsQuantization(op)) { + LOG(FATAL) << "Unimplemented: this graph contains an operator of type " + << HelpfulOperatorTypeName(op) + << " for which the quantized form is not yet implemented. " + "Sorry, and patches welcome (that's a relatively fun patch " + "to write, mostly providing the actual quantized arithmetic " + "code for this op)."; + } + + for (const auto& input : op.inputs) { + const auto& array = model->GetArray(input); + if (array.data_type == ArrayDataType::kFloat) { + if (!array.minmax && !array.buffer) { + LOG(ERROR) << "Can't quantize input array " << input + << " because it lacks min/max info"; + return false; + } + const auto* other_op = GetOpWithOutput(*model, input); + if (other_op && other_op->type != OperatorType::kDequantize) { + AddMessageF( + "Not quantizing %s for now, because its input array %s is not " + "produced by a Dequantize op, " + "which means that we should yield and let other ops " + "get quantized first", + LogName(op), input); + return false; + } + } + } + + bool changed = false; + + // Quantize inputs, remove any Dequantize op on the inputs side + for (std::size_t input_index = 0; input_index < op.inputs.size(); + input_index++) { + ArrayDataType quantized_data_type; + QuantizationParams quantization_params; + if (ChooseQuantizationForOperatorInput(this, model, op, input_index, + &quantized_data_type, + &quantization_params)) { + changed = true; + const auto& input = op.inputs[input_index]; + if (IsConstantParameterArray(*model, input)) { + QuantizeArray(this, model, input, quantized_data_type, + quantization_params); + } else { + auto dequantize_it = FindOpWithOutput(*model, input); + CHECK(dequantize_it != model->operators.end()); + auto* dequantize_op = dequantize_it->get(); + CHECK(dequantize_op->type == OperatorType::kDequantize); + op.inputs[input_index] = dequantize_op->inputs[0]; + // Check if the output of that Dequantize op was not used by any + // other operator. We will then erase that Dequantize op. + if (!CountOpsWithInput(*model, dequantize_op->outputs[0])) { + // If any of the model's output_arrays was pointing to the + // Dequantize op's output, let it point to the Dequantize op's + // input instead. 
+ for (int i = 0; i < model->flags.output_arrays_size(); i++) { + if (model->flags.output_arrays(i) == dequantize_op->outputs[0]) { + model->flags.set_output_arrays(i, dequantize_op->inputs[0]); + } + } + model->arrays.erase(dequantize_op->outputs[0]); + model->operators.erase(dequantize_it); + } + } + } + } + + // Quantize outputs, add Dequantize ops as needed on the outputs side + for (std::size_t output_index = 0; output_index < op.outputs.size(); + output_index++) { + ArrayDataType quantized_data_type; + QuantizationParams quantization_params; + if (ChooseQuantizationForOperatorOutput(this, model, op, output_index, + &quantized_data_type, + &quantization_params)) { + changed = true; + const auto& output = op.outputs[output_index]; + QuantizeArray(this, model, output, quantized_data_type, + quantization_params); + const auto& dequantized_output = + AvailableArrayName(*model, output + "_dequantized"); + const auto& output_array = model->GetArray(output); + const auto& output_minmax = output_array.GetMinMax(); + auto& dequantized_output_array = + model->GetOrCreateArray(dequantized_output); + dequantized_output_array.data_type = ArrayDataType::kFloat; + auto& dequantized_output_minmax = + dequantized_output_array.GetOrCreateMinMax(); + dequantized_output_minmax.min = output_minmax.min; + dequantized_output_minmax.max = output_minmax.max; + for (const auto& other_op : model->operators) { + for (auto& other_op_input : other_op->inputs) { + if (other_op_input == output) { + other_op_input = dequantized_output; + } + } + } + auto* dequantize_op = new DequantizeOperator; + dequantize_op->inputs = {output}; + dequantize_op->outputs = {dequantized_output}; + for (int i = 0; i < model->flags.output_arrays_size(); i++) { + if (model->flags.output_arrays(i) == output) { + model->flags.set_output_arrays(i, dequantized_output); + } + } + const auto op_it = FindOp(*model, &op); + model->operators.emplace(op_it + 1, dequantize_op); + } + } + + return changed; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/read_fake_quant_min_max.cc b/tensorflow/contrib/lite/toco/graph_transformations/read_fake_quant_min_max.cc new file mode 100644 index 0000000000..371ced388a --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/read_fake_quant_min_max.cc @@ -0,0 +1,105 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +namespace { + +bool ApplyMinMaxToArray(GraphTransformation* transformation, Model* model, + const MinMax& minmax, const string& array_name) { + auto& annotated_array = model->GetArray(array_name); + if (annotated_array.minmax) { + return false; + } + annotated_array.GetOrCreateMinMax() = minmax; + transformation->AddMessageF( + "Read min/max annotation for array %s: min=%g, max=%g", array_name, + minmax.min, minmax.max); + return true; +} + +} // end namespace + +bool ReadFakeQuantMinMax::Run(Model* model, std::size_t op_index) { + const auto fakequant_it = model->operators.begin() + op_index; + auto* fakequant_base_op = fakequant_it->get(); + if (fakequant_base_op->type != OperatorType::kFakeQuant) { + return false; + } + auto* fakequant_op = static_cast(fakequant_base_op); + + bool changed = false; + + if (!fakequant_op->minmax) { + CHECK_EQ(fakequant_op->inputs.size(), 3); + // We need to yield until the min and max parameters have been + // resolved to constant arrays. + for (int i = 1; i <= 2; i++) { + if (!IsConstantParameterArray(*model, fakequant_op->inputs[1])) { + return false; + } + } + + // Obtain the final min/max values + const auto& min_array = model->GetArray(fakequant_op->inputs[1]); + const auto& max_array = model->GetArray(fakequant_op->inputs[2]); + CHECK_EQ(RequiredBufferSizeForShape(min_array.shape()), 1); + CHECK_EQ(RequiredBufferSizeForShape(max_array.shape()), 1); + fakequant_op->minmax.reset(new MinMax); + MinMax& minmax = *fakequant_op->minmax; + minmax.min = min_array.GetBuffer().data[0]; + minmax.max = max_array.GetBuffer().data[0]; + // We always want [min, max] to contain 0. + minmax.min = std::min(minmax.min, 0.); + minmax.max = std::max(minmax.max, 0.); + + // We won't use the input arrays that provided these min and max + // values, anymore. Delete them unless they are used by something + // else. + for (int i = 1; i <= 2; i++) { + if (CountOpsWithInput(*model, fakequant_op->inputs[i]) == 1) { + model->arrays.erase(fakequant_op->inputs[i]); + } + } + fakequant_op->inputs.resize(1); + changed = true; + } + + // At this point, this FakeQuantOperator should have a MinMax + // attached to it, and should only have 1 input (it should not have + // 2nd and 3rd input arrays giving min and max anymore). + CHECK(fakequant_op->minmax); + CHECK_EQ(1, fakequant_op->inputs.size()); + + const MinMax& minmax = *fakequant_op->minmax; + + // Record the MinMax info on the input and output arrays + changed |= ApplyMinMaxToArray(this, model, minmax, fakequant_op->inputs[0]); + changed |= ApplyMinMaxToArray(this, model, minmax, fakequant_op->outputs[0]); + + return changed; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_final_dequantize_op.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_final_dequantize_op.cc new file mode 100644 index 0000000000..3992e7d1ef --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_final_dequantize_op.cc @@ -0,0 +1,59 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/model_flags.pb.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +bool RemoveFinalDequantizeOp::Run(Model* model, std::size_t op_index) { + const auto dequantize_it = model->operators.begin() + op_index; + const auto* dequantize_op = dequantize_it->get(); + if (dequantize_op->type != OperatorType::kDequantize) { + return false; + } + const auto& output = dequantize_op->outputs[0]; + // We can remove any dequantize op whose output is not consumed by + // any op. This is not necessarily equivalent to the output being + // one of the model's output arrays, as some intermediate array + // in the middle of the graph might be designated as an output + // array. + if (CountOpsWithInput(*model, output)) { + return false; + } + + // If one of the model's output arrays was actually the Dequantize op's + // output, then we need to update it to point to the Dequantize op's input. + for (int i = 0; i < model->flags.output_arrays_size(); i++) { + if (output == model->flags.output_arrays(i)) { + model->flags.set_output_arrays(i, dequantize_op->inputs[0]); + } + } + + // Remove the node and its output array. + AddMessageF("Removed final %s", LogName(*dequantize_op)); + model->arrays.erase(output); + model->operators.erase(dequantize_it); + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_assert.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_assert.cc new file mode 100644 index 0000000000..35a0c46532 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_assert.cc @@ -0,0 +1,60 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +bool RemoveTensorFlowAssert::Run(Model* model, std::size_t op_index) { + const auto assert_it = model->operators.begin() + op_index; + const auto* assert_op = assert_it->get(); + if (assert_op->type != OperatorType::kTensorFlowAssert) { + return false; + } + + bool changed = false; + // Remove any other node's dependency on this assert node + for (const auto& op : model->operators) { + auto it = op->inputs.begin(); + while (it != op->inputs.end()) { + if (*it == assert_op->outputs[0]) { + op->inputs.erase(it); + changed = true; + } else { + ++it; + } + } + } + CHECK(!CountOpsWithInput(*model, assert_op->outputs[0])); + + if (changed) { + AddMessageF( + "Prepared for the removal of %s by removing any other op's dependency " + "on it", + LogName(*assert_op)); + } + + // That's it. We can stop here, no need to duplicate the work that + // RemoveUnusedOp will do removing this now-unused node. + return changed; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_identity.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_identity.cc new file mode 100644 index 0000000000..404269bbfd --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_identity.cc @@ -0,0 +1,38 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +bool RemoveTensorFlowIdentity::Run(Model* model, std::size_t op_index) { + const auto passthru_it = model->operators.begin() + op_index; + const auto* passthru_op = passthru_it->get(); + if (passthru_op->type != OperatorType::kTensorFlowIdentity) { + return false; + } + + return RemoveTrivialPassthroughOp(this, model, op_index); +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_binary.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_binary.cc new file mode 100644 index 0000000000..6add443f2d --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_binary.cc @@ -0,0 +1,113 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +namespace { + +template +bool AreAllBufferElementsEqualTo(const std::vector& buffer_data, + Scalar value) { + for (auto x : buffer_data) { + if (x != value) { + return false; + } + } + return true; +} +} // namespace + +// A binary operator is called trivial when exactly one of its operands is +// a constant and is such that the binary operation is equivalent to +// the identity operation on its other input. +// For example, an Add operator is trivial if +// one of its operands is constant 0, a Mul operator is trivial +// if one of its operands is constant 1, etc. +bool RemoveTrivialBinaryOperator::Run(Model* model, std::size_t op_index) { + const auto binary_it = model->operators.begin() + op_index; + auto* binary_op = binary_it->get(); + if (binary_op->type != OperatorType::kAdd && + binary_op->type != OperatorType::kMul && + binary_op->type != OperatorType::kSub && + binary_op->type != OperatorType::kDiv) { + return false; + } + + CHECK_EQ(binary_op->inputs.size(), 2); + + // This graph transformation is only concerned with the case + // when one input is constant and the other is not constant. + const bool is_input_constant[2] = { + IsConstantParameterArray(*model, binary_op->inputs[0]), + IsConstantParameterArray(*model, binary_op->inputs[1]), + }; + if (!is_input_constant[0] && !is_input_constant[1]) { + // Neither input is constant, so nothing we can resolve here. + return false; + } + if (is_input_constant[0] && is_input_constant[1]) { + // Both inputs are constants. That's a job for constants + // propagation, not for us to handle here. + return false; + } + const int index_of_constant_input = is_input_constant[0] ? 0 : 1; + const int index_of_variable_input = is_input_constant[0] ? 1 : 0; + CHECK(is_input_constant[index_of_constant_input]); + CHECK(!is_input_constant[index_of_variable_input]); + + // Now check if the constant operand makes this binary + // operator trivial. + const auto& constant_input_array = + *model->arrays[binary_op->inputs[index_of_constant_input]]; + // For now, we only handle floats here. 
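The dispatch that follows amounts to a neutral-element test: Add and Sub are trivial against a constant array of zeros, Mul and Div against a constant array of ones, and for the non-commutative Sub and Div the constant must be the second operand. A standalone sketch of that rule (illustrative names, not toco code):

#include <algorithm>
#include <vector>

enum class BinaryKind { kAdd, kSub, kMul, kDiv };

// Returns true when combining the other operand with this constant operand is
// an identity operation (x + 0, x - 0, x * 1, x / 1).
static bool IsNeutralConstant(BinaryKind kind, const std::vector<float>& data,
                              int index_of_constant_input) {
  const auto all_equal_to = [&data](float v) {
    return std::all_of(data.begin(), data.end(),
                       [v](float x) { return x == v; });
  };
  switch (kind) {
    case BinaryKind::kAdd:
      return all_equal_to(0.f);
    case BinaryKind::kSub:  // only x - 0 is trivial, not 0 - x
      return index_of_constant_input == 1 && all_equal_to(0.f);
    case BinaryKind::kMul:
      return all_equal_to(1.f);
    case BinaryKind::kDiv:  // only x / 1 is trivial, not 1 / x
      return index_of_constant_input == 1 && all_equal_to(1.f);
  }
  return false;
}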
+ if (constant_input_array.data_type != ArrayDataType::kFloat) { + return false; + } + const auto& constant_input_float_data = + constant_input_array.GetBuffer().data; + bool is_trivial = false; + if (binary_op->type != OperatorType::kAdd) { + is_trivial = AreAllBufferElementsEqualTo(constant_input_float_data, 0.f); + } else if (binary_op->type != OperatorType::kSub) { + is_trivial = index_of_constant_input == 1 && + AreAllBufferElementsEqualTo(constant_input_float_data, 0.f); + } else if (binary_op->type != OperatorType::kMul) { + is_trivial = AreAllBufferElementsEqualTo(constant_input_float_data, 1.f); + } else if (binary_op->type != OperatorType::kDiv) { + is_trivial = index_of_constant_input == 1 && + AreAllBufferElementsEqualTo(constant_input_float_data, 1.f); + } + + if (!is_trivial) { + return false; + } + + // Now we know that this node is trivial, so we can remove it. + AddMessageF("Removing trivial %s", LogName(*binary_op)); + return RemoveTrivialPassthroughOp(this, model, op_index); +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_concatenation.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_concatenation.cc new file mode 100644 index 0000000000..3ceb93d8ee --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_concatenation.cc @@ -0,0 +1,40 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +bool RemoveTrivialConcatenation::Run(Model* model, std::size_t op_index) { + const auto concat_it = model->operators.begin() + op_index; + auto* concat_op = concat_it->get(); + if (concat_op->type != OperatorType::kConcatenation) { + return false; + } + if (concat_op->inputs.size() != 1) { + return false; + } + return RemoveTrivialPassthroughOp(this, model, op_index); +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_concatenation_input.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_concatenation_input.cc new file mode 100644 index 0000000000..b603735704 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_concatenation_input.cc @@ -0,0 +1,68 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +bool RemoveTrivialConcatenationInput::Run(Model* model, std::size_t op_index) { + // TensorFlow allows Concatenation nodes to have 0-D inputs, + // and they are then treated as empty i.e. omitted from concatenation, + // in violation of the notion that 0-D is equivalent to 1x1x1x1. + // Thus we have to drop these 0-D inputs from Concatenation nodes. + // Sometimes, there will remain only one non-trivial input, and + // the other graph transformation RemoveTrivialConcatenation will then drop + // it. + const auto concat_it = model->operators.begin() + op_index; + auto* concat_op = concat_it->get(); + if (concat_op->type != OperatorType::kConcatenation) { + return false; + } + std::vector trivial_inputs; + std::vector nontrivial_inputs; + for (const string& input : concat_op->inputs) { + const auto& input_array = model->GetArray(input); + const bool is_trivial = + input_array.has_shape() && input_array.shape().dimensions_count() == 0; + if (is_trivial) { + trivial_inputs.push_back(input); + } else { + nontrivial_inputs.push_back(input); + } + } + + if (trivial_inputs.empty()) { + return false; + } + + // Drop trivial inputs. + for (const string& input : trivial_inputs) { + if (CountOpsWithInput(*model, input) == 1) { + model->arrays.erase(input); + } + } + concat_op->inputs = nontrivial_inputs; + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.cc new file mode 100644 index 0000000000..a0d1338298 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.cc @@ -0,0 +1,107 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/model_flags.pb.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +namespace { +// Reroute all edges involving a given discardable array to another +// array instead. from_array is assumed to be discardable, and consequently +// this only updates operator edges (since discardable arrays only +// appear there, and not e.g. in model flags). +void RerouteEdges(const string& from_array, const string& to_array, + Model* model) { + for (const auto& op : model->operators) { + for (auto& output : op->outputs) { + if (output == from_array) { + output = to_array; + } + } + for (auto& input : op->inputs) { + if (input == from_array) { + input = to_array; + } + } + } +} + +} // end anonymous namespace + +bool RemoveTrivialPassthroughOp(GraphTransformation* transformation, + Model* model, std::size_t op_index) { + const auto passthru_it = model->operators.begin() + op_index; + auto* passthru_op = passthru_it->get(); + CHECK_EQ(passthru_op->outputs.size(), 1); + CHECK_GE(passthru_op->inputs.size(), 1); + int count_nonconstant_input_arrays = 0; + // We call 'main input' the unique nonconstant input array if there is one, + // or else the 0-th input. + int main_input_array_index = 0; + for (int i = 0; i < passthru_op->inputs.size(); i++) { + if (!model->GetArray(passthru_op->inputs[i]).buffer) { + count_nonconstant_input_arrays++; + main_input_array_index = i; + } + } + CHECK_LE(count_nonconstant_input_arrays, 1); + + const string main_input_name = passthru_op->inputs[main_input_array_index]; + const string output_name = passthru_op->outputs[0]; + if (IsDiscardableArray(*model, output_name)) { + transformation->AddMessageF( + "Removing %s, keeping its non-constant input array", + LogName(*passthru_op)); + model->arrays.erase(output_name); + for (const string& input : passthru_op->inputs) { + if (IsDiscardableArray(*model, input) && input != main_input_name && + CountOpsWithInput(*model, input) == 1) { + model->arrays.erase(input); + } + } + RerouteEdges(output_name, main_input_name, model); + } else if (IsDiscardableArray(*model, main_input_name)) { + transformation->AddMessageF("Removing %s, keeping its output array", + LogName(*passthru_op)); + for (const string& input : passthru_op->inputs) { + if (IsDiscardableArray(*model, input) && + (input == main_input_name || CountOpsWithInput(*model, input) == 1)) { + model->arrays.erase(input); + } + } + RerouteEdges(main_input_name, output_name, model); + } else { + transformation->AddMessageF( + "Cannot remove %s, neither its nonconstant input nor its output may be " + "discarded", + LogName(*passthru_op)); + return false; + } + + // Remove the pass-through node. + model->operators.erase(passthru_it); + + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h new file mode 100644 index 0000000000..b72c85c0e5 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h @@ -0,0 +1,55 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_GRAPH_TRANSFORMATIONS_REMOVE_TRIVIAL_PASSTHROUGH_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_GRAPH_TRANSFORMATIONS_REMOVE_TRIVIAL_PASSTHROUGH_H_ + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" + +namespace toco { + +// A "passthrough op" is an op that satisfies the following conditions: +// 1. It has at most one non-constant input (it may have other constant +// inputs). +// 2. It has exactly one output. +// 3. It forwards exactly its single non-constant input to its single output. +// +// Examples include: +// 1. TensorFlow Identity ops. (Have one input). +// 2. TensorFlow Reshape ops when the input and output shapes agree. +// 3. Any binary operator, one of whose two inputs is a constant and is the +// neutral value for that operation. For example, a binary Add operator +// where one of its inputs is a constant array filled with zeros. +// +// A passthrough op is "trivial" and can be removed when it is possible to +// discard either its single non-constant input or output array, rerouting any +// edge involving it to the other of these two arrays. +// +// It is only possible to discard such an array if it is not explicitly +// designated as a global input/output array of the graph, e.g. the model's +// input arrays, output arrays, and any array involved in a RNN back-edge +// specified by the model. +// +// This function does not check that the given operator is a passthrough op: +// that's the responsibility of the caller. +// Given that it is a passthrough op, this function checks whether it is trivial +// and then discards it and returns true, or, if it's not trivial (if neither +// the input nor the output may be discarded), returns false. +bool RemoveTrivialPassthroughOp(GraphTransformation* transformation, + Model* model, std::size_t op_index); + +} // namespace toco + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_GRAPH_TRANSFORMATIONS_REMOVE_TRIVIAL_PASSTHROUGH_H_ diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_quantized_activation_func.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_quantized_activation_func.cc new file mode 100644 index 0000000000..28f76c9d36 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_quantized_activation_func.cc @@ -0,0 +1,87 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/runtime/types.h" +#include "tensorflow/contrib/lite/toco/toco_types.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +bool RemoveTrivialQuantizedActivationFunc::Run(Model* model, + std::size_t op_index) { + const auto it = model->operators.begin() + op_index; + auto* op = it->get(); + if (op->fused_activation_function != FusedActivationFunctionType::kRelu && + op->fused_activation_function != FusedActivationFunctionType::kRelu6) { + return false; + } + const auto& output_array = model->GetArray(op->outputs[0]); + if (!output_array.quantization_params) { + return false; + } + if (output_array.data_type != ArrayDataType::kUint8) { + return false; + } + const auto& quantization_params = output_array.GetQuantizationParams(); + + bool has_nontrivial_min_bound = false; + bool has_nontrivial_max_bound = false; + + if (op->fused_activation_function == FusedActivationFunctionType::kRelu || + op->fused_activation_function == FusedActivationFunctionType::kRelu6) { + double lowest_representable_output = + (0. - quantization_params.zero_point) * quantization_params.scale; + if (lowest_representable_output < 0.) { + has_nontrivial_min_bound = true; + AddMessageF( + "Quantized activation function is not trivial: " + "the lowest representable output value %g" + " less than the clamp min bound.", + lowest_representable_output); + } + } + if (op->fused_activation_function == FusedActivationFunctionType::kRelu6) { + double highest_representable_output = + (255. - quantization_params.zero_point) * quantization_params.scale; + if (highest_representable_output > 6.) { + has_nontrivial_max_bound = true; + AddMessageF( + "Quantized activation function is not trivial: " + "the highest representable output value %g" + " is greater than the clamp max bound.", + highest_representable_output); + } + } + + if (has_nontrivial_min_bound || has_nontrivial_max_bound) { + return false; + } + + op->fused_activation_function = FusedActivationFunctionType::kNone; + AddMessageF( + "Removing trivial quantized activation function on %s" + " because the output quantization parameters imply at least as tight" + " a clamp anyway.", + LogName(*op)); + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_reshape.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_reshape.cc new file mode 100644 index 0000000000..90f9381ec1 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_reshape.cc @@ -0,0 +1,92 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +namespace { + +bool IsReshapeTrivial(const Model& model, const Operator& op, + RemoveTrivialReshape* transformation) { + CHECK(op.type == OperatorType::kTensorFlowReshape); + + // One way in which a reshape can be trivial is if its + // output shape is == its input shape + const auto& input_array = model.GetArray(op.inputs[0]); + const auto& output_array = model.GetArray(op.outputs[0]); + if (input_array.has_shape() && output_array.has_shape()) { + if (transformation->treat_expand_dims_as_trivial() && + ShapesAgreeUpToExtending(input_array.shape(), output_array.shape())) { + transformation->AddMessageF( + "%s is trivial because its input and output shapes are equal up to " + "extending " + "by 1's, and we are told to aggressively discard such Reshape ops.", + LogName(op)); + return true; + } + if (input_array.shape().dims() == output_array.shape().dims()) { + transformation->AddMessageF( + "%s is trivial because its input and output shapes are equal", + LogName(op)); + return true; + } + } + + // Another way in which a reshape can be trivial is if its output + // is only consumed by another reshape. + if (CountOpsWithInput(model, op.outputs[0]) == 1) { + const auto* next_op = GetOpWithInput(model, op.outputs[0]); + if (next_op->type == OperatorType::kTensorFlowReshape) { + transformation->AddMessageF( + "%s is trivial because its output is only consumed by another " + "Reshape op", + LogName(op)); + return true; + } + } + + return false; +} + +} // namespace + +bool RemoveTrivialReshape::Run(Model* model, std::size_t op_index) { + const auto reshape_it = model->operators.begin() + op_index; + auto* reshape_op = reshape_it->get(); + if (reshape_op->type != OperatorType::kTensorFlowReshape) { + return false; + } + + if (!IsReshapeTrivial(*model, *reshape_op, this)) { + return false; + } + + AddMessageF("Removing trivial %s", LogName(*reshape_op)); + + CHECK_EQ(reshape_op->inputs.size(), 2); + return RemoveTrivialPassthroughOp(this, model, op_index); +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_unused_op.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_unused_op.cc new file mode 100644 index 0000000000..1f1f1f6948 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_unused_op.cc @@ -0,0 +1,122 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/model_flags.pb.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +bool RemoveUnusedOp::Run(Model* model, std::size_t op_index) { + const auto it = model->operators.begin() + op_index; + const auto* op = it->get(); + + // Bail if any output is used, and is not an input_array of + // the model. We allow specifying an arbitrary input_array, + // treating the part of the graph leading up to it as unused. + for (const auto& output : op->outputs) { + CHECK(model->arrays.count(output)); + // If this output is provided as the model's input array, + // then we don't need this operator to produce its contents. + if (IsInputArray(*model, output)) { + continue; + } + // If this output is provided as a RNN's state array, + // then we don't need this operator to produce its contents. + // So far this case has only been encountered with TensorFlow + // Fill ops used to zero-initialize RNN states, which is + // redundant for us as we zero-initialize RNN states anyway. + bool found_output_as_rnn_state_array = false; + for (const auto& rnn_state : model->flags.rnn_states()) { + if (output == rnn_state.state_array()) { + CHECK(op->type == OperatorType::kTensorFlowUnsupported); + CHECK_EQ(static_cast(op) + ->tensorflow_op, + "Fill"); + found_output_as_rnn_state_array = true; + break; + } + } + if (found_output_as_rnn_state_array) { + continue; + } + for (const string& output_array : model->flags.output_arrays()) { + if (output == output_array) { + return false; + } + } + for (const auto& rnn_state : model->flags.rnn_states()) { + if (output == rnn_state.back_edge_source_array()) { + return false; + } + } + if (CountOpsWithInput(*model, output)) { + return false; + } + } + + if (op->unresolved_outputs) { + AddMessageF("Not discarding %s because it has unresolved outputs.", + LogName(*op)); + return false; + } + + AddMessageF("Discarding %s because none of its outputs is used.", + LogName(*op)); + + // At that point we know that none of the outputs is used, so we will + // definitely remove the node and all its outputs. + + // Remove any input array that is not used by anything else, + // and that is not the output of some other operator. + for (const auto& input : op->inputs) { + if (CountOpsWithInput(*model, input) == 1 && + !GetOpWithOutput(*model, input)) { + model->arrays.erase(input); + } + } + + // Remove the node and its now-unused output arrays. + for (const auto& output : op->outputs) { + // If the output array is the model's input array, don't remove that. + // That's the case when cropping a model at a given --input_array. + if (IsInputArray(*model, output)) { + continue; + } + // Likewise, if the output array is a RNN state array, don't remove that. 
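Summarizing the usage test above: an output keeps its producing op alive only if it is actually read downstream, i.e. it is one of the model's output arrays, an RNN back-edge source, or an input of some other op; outputs that the model supplies or re-initializes itself (designated input arrays and RNN state arrays) do not count. A standalone sketch with plain containers standing in for the Model and flags types (illustrative only, not toco code):

#include <set>
#include <string>

// Plain-container stand-in for the pieces of Model/ModelFlags consulted above.
struct GraphInfo {
  std::set<std::string> model_input_arrays;
  std::set<std::string> model_output_arrays;
  std::set<std::string> rnn_state_arrays;        // zero-initialized anyway
  std::set<std::string> rnn_back_edge_sources;
  std::set<std::string> arrays_consumed_by_some_op;
};

// Returns true if this output requires its producing op to stay in the graph.
static bool OutputKeepsOpAlive(const GraphInfo& g, const std::string& output) {
  // Outputs the model provides or re-initializes itself do not need a producer.
  if (g.model_input_arrays.count(output) || g.rnn_state_arrays.count(output)) {
    return false;
  }
  // Anything read downstream keeps the op alive.
  return g.model_output_arrays.count(output) ||
         g.rnn_back_edge_sources.count(output) ||
         g.arrays_consumed_by_some_op.count(output);
}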
+ bool found_output_as_rnn_state_array = false; + for (const auto& rnn_state : model->flags.rnn_states()) { + if (output == rnn_state.state_array()) { + found_output_as_rnn_state_array = true; + break; + } + } + if (found_output_as_rnn_state_array) { + continue; + } + // Generic case: do delete this output array. + model->arrays.erase(output); + } + model->operators.erase(it); + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_normalization.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_normalization.cc new file mode 100644 index 0000000000..3eb7fa3896 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_normalization.cc @@ -0,0 +1,135 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/runtime/types.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +bool ResolveBatchNormalization::Run(Model* model, std::size_t op_index) { + auto bn_it = model->operators.begin() + op_index; + if (bn_it->get()->type != OperatorType::kBatchNormalization) { + return false; + } + const auto* bn_op = + static_cast(bn_it->get()); + + const auto& mean_array = model->GetArray(bn_op->inputs[1]); + const auto& multiplier_array = model->GetArray(bn_op->inputs[2]); + const auto& offset_array = model->GetArray(bn_op->inputs[3]); + + CHECK(IsConstantParameterArray(*model, bn_op->inputs[1]) && + IsConstantParameterArray(*model, bn_op->inputs[2]) && + IsConstantParameterArray(*model, bn_op->inputs[3])) + << "Batch normalization resolution requires that mean, multiplier and " + "offset arrays be constant."; + + // We should only have *float* BatchNormalizations... let's guard this + // assumption by CHECK's. 
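The split performed below relies on the algebraic identity (x - mean) * multiplier + offset == x * multiplier + (offset - mean * multiplier), which is why the Add parameter computed further down is offset - mean * multiplier while the Mul parameter is simply the multiplier. A standalone numeric check of that identity (the values are made up for illustration):

#include <cassert>
#include <cmath>

int main() {
  // Illustrative values only.
  const float x = 3.0f, mean = 0.5f, multiplier = 2.0f, offset = 1.0f;
  const float batch_norm = (x - mean) * multiplier + offset;  // 6.0
  // Folded form used by the transformation: one Mul then one Add.
  const float mul_param = multiplier;
  const float add_param = offset - mean * multiplier;
  const float folded = x * mul_param + add_param;             // 6.0
  assert(std::fabs(batch_norm - folded) < 1e-6f);
  return 0;
}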
+ CHECK(mean_array.data_type == ArrayDataType::kFloat); + CHECK(multiplier_array.data_type == ArrayDataType::kFloat); + CHECK(offset_array.data_type == ArrayDataType::kFloat); + + // Create the new Mul, Add operators + auto* mul_op = new MulOperator; + auto* add_op = new AddOperator; + const string mul_name = + AvailableArrayName(*model, bn_op->outputs[0] + "_mul"); + const string add_name = + AvailableArrayName(*model, bn_op->outputs[0] + "_add"); + const string mul_param_name = AvailableArrayName(*model, mul_name + "_param"); + const string add_param_name = AvailableArrayName(*model, add_name + "_param"); + mul_op->inputs = {bn_op->inputs[0], mul_param_name}; + mul_op->outputs = {mul_name}; + add_op->inputs = {mul_name, add_param_name}; + add_op->outputs = {bn_op->outputs[0]}; + AddMessageF("Splitting %s into %s and %s", LogName(*bn_op), LogName(*mul_op), + LogName(*add_op)); + + // Create the intermediate activation array (output of mul, input of add) + auto& intermediate_array = model->GetOrCreateArray(mul_op->outputs[0]); + intermediate_array.data_type = model->GetArray(bn_op->inputs[0]).data_type; + + // Insert the new operators in the graph + auto add_it = model->operators.emplace(bn_it, add_op); + auto mul_it = model->operators.emplace(add_it, mul_op); + // update invalidated iterators. + DCHECK_EQ(mul_it->get(), mul_op); + add_it = mul_it + 1; + DCHECK_EQ(add_it->get(), add_op); + bn_it = add_it + 1; + DCHECK_EQ(bn_it->get(), bn_op); + + // Create the new param arrays + const auto& mean_shape = mean_array.shape(); + const auto& multiplier_shape = multiplier_array.shape(); + const auto& offset_shape = offset_array.shape(); + CHECK(mean_shape.dims() == multiplier_shape.dims()); + CHECK(mean_shape.dims() == offset_shape.dims()); + const auto& param_shape = mean_shape; + const int buffer_size = RequiredBufferSizeForShape(param_shape); + auto& mul_param_array = model->GetOrCreateArray(mul_param_name); + auto& add_param_array = model->GetOrCreateArray(add_param_name); + DropMinMax(model, mul_param_name); + DropMinMax(model, add_param_name); + mul_param_array.copy_shape(param_shape); + add_param_array.copy_shape(param_shape); + mul_param_array.data_type = ArrayDataType::kFloat; + add_param_array.data_type = ArrayDataType::kFloat; + auto& mul_float_data = + mul_param_array.GetMutableBuffer().data; + auto& add_float_data = + add_param_array.GetMutableBuffer().data; + mul_float_data.resize(buffer_size); + add_float_data.resize(buffer_size); + const auto& mean_float_data = + mean_array.GetBuffer().data; + const auto& multiplier_float_data = + multiplier_array.GetBuffer().data; + const auto& offset_float_data = + offset_array.GetBuffer().data; + + CHECK(mul_float_data.size() == buffer_size); + CHECK(add_float_data.size() == buffer_size); + CHECK(mean_float_data.size() == buffer_size); + CHECK(multiplier_float_data.size() == buffer_size); + CHECK(offset_float_data.size() == buffer_size); + + for (int i = 0; i < buffer_size; i++) { + mul_float_data[i] = multiplier_float_data[i]; + add_float_data[i] = + offset_float_data[i] - mean_float_data[i] * multiplier_float_data[i]; + } + + // Remove the old param arrays + model->arrays.erase(bn_op->inputs[1]); + model->arrays.erase(bn_op->inputs[2]); + model->arrays.erase(bn_op->inputs[3]); + + // Remove the old operator + DCHECK_EQ(bn_it->get(), bn_op); + model->operators.erase(bn_it); + + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc 
b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc
new file mode 100644
index 0000000000..53e1be7a05
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc
@@ -0,0 +1,247 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <algorithm>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/runtime/types.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+namespace {
+
+std::vector<bool> VectorGreaterThan(const std::vector<int>& a,
+                                    const std::vector<int>& b) {
+  DCHECK_EQ(a.size(), b.size());
+  const int size = a.size();
+  std::vector<bool> result(size);
+  for (int i = 0; i < size; i++) {
+    result[i] = a[i] > b[i];
+  }
+  return result;
+}
+
+void PairwiseVectorSelect(const std::vector<bool>& selector,
+                          const std::vector<int>& input_a,
+                          const std::vector<int>& input_b,
+                          std::vector<int>* output_a,
+                          std::vector<int>* output_b) {
+  DCHECK_EQ(input_a.size(), input_b.size());
+  DCHECK_EQ(output_a->size(), output_b->size());
+  DCHECK_EQ(input_a.size(), output_a->size());
+  DCHECK_EQ(selector.size(), input_a.size());
+  const int size = input_a.size();
+  for (int i = 0; i < size; i++) {
+    if (selector[i]) {
+      (*output_a)[i] = input_a[i];
+      (*output_b)[i] = input_b[i];
+    } else {
+      (*output_a)[i] = input_b[i];
+      (*output_b)[i] = input_a[i];
+    }
+  }
+}
+
+template <ArrayDataType InputsDataType, ArrayDataType OutputDataType>
+void EvaluateBinaryOperatorOnConstantInputs(Model* model,
+                                            const Operator* binary_op) {
+  CHECK(IsConstantParameterArray(*model, binary_op->inputs[0]));
+  CHECK(IsConstantParameterArray(*model, binary_op->inputs[1]));
+  CHECK(binary_op->fused_activation_function ==
+        FusedActivationFunctionType::kNone);
+  const auto& input0_array = model->GetArray(binary_op->inputs[0]);
+  const auto& input1_array = model->GetArray(binary_op->inputs[1]);
+  const auto& output_name = binary_op->outputs[0];
+  auto& output_array = model->GetArray(output_name);
+  CHECK(input0_array.data_type == InputsDataType);
+  CHECK(input1_array.data_type == InputsDataType);
+  CHECK(output_array.data_type == OutputDataType);
+
+  // We have already tested above for existence of input buffers
+  // (synonymous to being a constant param).
+  CHECK(input0_array.buffer);
+  CHECK(input1_array.buffer);
+  // On the other hand, the output should not already have a buffer.
+  CHECK(!output_array.buffer);
+
+  const auto& input0_data = input0_array.GetBuffer<InputsDataType>().data;
+  const auto& input1_data = input1_array.GetBuffer<InputsDataType>().data;
+  // Create the buffer on the output array, effectively turning it into
+  // a constant parameter.
+
+  const Shape& output_shape = output_array.shape();
+  auto& output_data = output_array.GetMutableBuffer<OutputDataType>().data;
+  const int output_buffer_size = RequiredBufferSizeForShape(output_shape);
+  output_data.resize(output_buffer_size);
+  const int dims_count = output_shape.dimensions_count();
+
+  // It will be convenient here to have copies of the operands' shapes
+  // extended to match the number of dimensions of the output shape.
+  Shape input0_shape = input0_array.shape();
+  Shape input1_shape = input1_array.shape();
+  ExtendShape(&input0_shape, dims_count);
+  ExtendShape(&input1_shape, dims_count);
+  // Now we may still have operands of different sizes, which would indicate
+  // that we have to "broadcast" the smaller dimension. We do this using a
+  // vector of booleans indicating which input is the larger in each
+  // dimension.
+  CHECK_EQ(input0_shape.dimensions_count(), input1_shape.dimensions_count());
+  CHECK_EQ(input0_shape.dimensions_count(), dims_count);
+  const std::vector<bool> input0_larger =
+      VectorGreaterThan(input0_shape.dims(), input1_shape.dims());
+
+  std::vector<int> big_sizes(dims_count);
+  std::vector<int> small_sizes(dims_count);
+  PairwiseVectorSelect(input0_larger, input0_shape.dims(), input1_shape.dims(),
+                       &big_sizes, &small_sizes);
+
+  // The output should already be correctly sized to match the big dimensions.
+  for (int i = 0; i < dims_count; i++) {
+    CHECK_EQ(output_shape.dims(i), big_sizes[i]);
+  }
+
+  std::vector<int> input0_indices(dims_count);
+  std::vector<int> input1_indices(dims_count);
+  std::vector<int> modulo_indices(dims_count);
+
+  for (int k = 0; k < output_buffer_size; k++) {
+    const std::vector<int> output_indices = ReverseOffset(output_shape, k);
+    for (int i = 0; i < dims_count; i++) {
+      modulo_indices[i] = output_indices[i] % small_sizes[i];
+    }
+    PairwiseVectorSelect(input0_larger, output_indices, modulo_indices,
+                         &input0_indices, &input1_indices);
+    const auto val0 = input0_data[Offset(input0_shape, input0_indices)];
+    const auto val1 = input1_data[Offset(input1_shape, input1_indices)];
+
+    DataType<OutputDataType> outval;
+    if (binary_op->type == OperatorType::kAdd) {
+      outval = val0 + val1;
+    } else if (binary_op->type == OperatorType::kMul) {
+      outval = val0 * val1;
+    } else if (binary_op->type == OperatorType::kSub) {
+      outval = val0 - val1;
+    } else if (binary_op->type == OperatorType::kDiv) {
+      outval = val0 / val1;
+    } else if (binary_op->type == OperatorType::kTensorFlowMinimum) {
+      outval = std::min(val0, val1);
+    } else if (binary_op->type == OperatorType::kTensorFlowMaximum) {
+      outval = std::max(val0, val1);
+    } else if (binary_op->type == OperatorType::kTensorFlowLess) {
+      outval = val0 < val1;
+    } else if (binary_op->type == OperatorType::kTensorFlowLessEqual) {
+      outval = val0 <= val1;
+    } else if (binary_op->type == OperatorType::kTensorFlowGreater) {
+      outval = val0 > val1;
+    } else if (binary_op->type == OperatorType::kTensorFlowGreaterEqual) {
+      outval = val0 >= val1;
+    } else {
+      LOG(FATAL) << "should not get here";
+    }
+    output_data[Offset(output_shape, output_indices)] = outval;
+  }
+}
+
+void EvaluateBinaryOperatorOnConstantInputs(Model* model,
+                                            const Operator* binary_op) {
+  const auto inputs_data_type = model->arrays[binary_op->inputs[0]]->data_type;
+  const auto output_data_type =
model->arrays[binary_op->outputs[0]]->data_type; +#define TOCO_HANDLE_CASE(InputsDataType, OutputDataType) \ + if (inputs_data_type == InputsDataType && \ + output_data_type == OutputDataType) { \ + EvaluateBinaryOperatorOnConstantInputs( \ + model, binary_op); \ + return; \ + } + TOCO_HANDLE_CASE(ArrayDataType::kFloat, ArrayDataType::kFloat) + TOCO_HANDLE_CASE(ArrayDataType::kFloat, ArrayDataType::kBool) + TOCO_HANDLE_CASE(ArrayDataType::kInt32, ArrayDataType::kInt32) + TOCO_HANDLE_CASE(ArrayDataType::kInt32, ArrayDataType::kBool) + TOCO_HANDLE_CASE(ArrayDataType::kInt64, ArrayDataType::kInt64) + TOCO_HANDLE_CASE(ArrayDataType::kInt64, ArrayDataType::kBool) + LOG(FATAL) << "Unimplemented: don't know how to resolve a constant " + << "binary operator for these data types."; +#undef TOCO_HANDLE_CASE +} +} // namespace + +bool ResolveConstantBinaryOperator::Run(Model* model, std::size_t op_index) { + const auto binary_it = model->operators.begin() + op_index; + const auto* binary_op = binary_it->get(); + // Test for binary ops of types that we know how to resolve + if (binary_op->type != OperatorType::kAdd && + binary_op->type != OperatorType::kMul && + binary_op->type != OperatorType::kSub && + binary_op->type != OperatorType::kDiv && + binary_op->type != OperatorType::kTensorFlowMinimum && + binary_op->type != OperatorType::kTensorFlowMaximum && + binary_op->type != OperatorType::kTensorFlowLess && + binary_op->type != OperatorType::kTensorFlowLessEqual && + binary_op->type != OperatorType::kTensorFlowGreater && + binary_op->type != OperatorType::kTensorFlowGreaterEqual) { + return false; + } + CHECK_EQ(binary_op->inputs.size(), 2); + + const auto& input0_array = model->GetArray(binary_op->inputs[0]); + const auto& input1_array = model->GetArray(binary_op->inputs[1]); + // Check if both inputs are constant parameters. + if (!input0_array.buffer || !input1_array.buffer) { + return false; + } + + auto& output_array = *model->arrays[binary_op->outputs[0]]; + // Yield until the output array dims have been resolved. + if (!output_array.has_shape()) { + return false; + } + + // At the moment we don't want to care about fused activation functions. + // The idea is that we should do the present constants-propagation before + // activation functions get fused. + if (binary_op->fused_activation_function != + FusedActivationFunctionType::kNone) { + AddMessageF( + "Not resolving constant %s because it has a fused activation function", + LogName(*binary_op)); + return false; + } + + // Check that input data types agree. 
+ CHECK(input0_array.data_type == input1_array.data_type); + + // Do the actual constants propagation + EvaluateBinaryOperatorOnConstantInputs(model, binary_op); + + // Remove the binary operator and its inputs + if (CountOpsWithInput(*model, binary_op->inputs[0]) == 1) { + model->arrays.erase(binary_op->inputs[0]); + } + if (CountOpsWithInput(*model, binary_op->inputs[1]) == 1) { + model->arrays.erase(binary_op->inputs[1]); + } + AddMessageF("Resolved constant %s to the equivalent constant array", + LogName(*binary_op)); + model->operators.erase(binary_it); + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_concatenation.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_concatenation.cc new file mode 100644 index 0000000000..0983c43849 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_concatenation.cc @@ -0,0 +1,196 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include + +#include "absl/strings/str_cat.h" +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +namespace { + +// Copies data from multiple source arrays to a destination array based on a +// concatenation dimension. From each array in input_arrays, it copies chunk +// sizes provided in array_copy_size vector (per array). It uses the buffer +// in concatenated_array as destination buffer. +template +void CopyTensorSegments(const std::vector& input_arrays, + const std::vector& array_copy_size, + const int num_elements_concatenated_array, + Array* concatenated_array) { + for (Array* input_array : input_arrays) { + if (!input_array->buffer) { + return; + } + } + + auto& concatenated_array_buffer = + concatenated_array->GetMutableBuffer().data; + concatenated_array_buffer.resize(num_elements_concatenated_array); + + // It does not matter which array to use to find the value for the total + // number of copy steps. + CHECK(!input_arrays.empty()); + CHECK_NE(array_copy_size[0], 0); + const int total_copy_steps = + input_arrays[0]->GetBuffer().data.size() / array_copy_size[0]; + + // Initialize the source pointers to point to beginning of the array buffers. + std::vector src_ptr; + src_ptr.reserve(input_arrays.size()); + for (Array* input_array : input_arrays) { + src_ptr.push_back(input_array->GetBuffer().data.data()); + } + + // Copy the data from input_arrays to concatenated_array_buffer. 
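+  // For a rough illustration, assuming two float inputs of shapes [2, 2] and
+  // [2, 3] concatenated along axis 1: array_copy_size is {2, 3}, there are 2
+  // copy steps, and each step appends one row from each input, filling the
+  // [2, 5] destination buffer row by row.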
+  T* dest_ptr = concatenated_array_buffer.data();
+  for (int s = 0; s < total_copy_steps; s++) {
+    for (int i = 0; i < input_arrays.size(); i++) {
+      std::copy(src_ptr[i], src_ptr[i] + array_copy_size[i], dest_ptr);
+      src_ptr[i] += array_copy_size[i];
+      dest_ptr += array_copy_size[i];
+    }
+  }
+}
+
+// Receives a series of input arrays of type Array and an integer showing the
+// axis on which those arrays will be concatenated. It returns the concatenated
+// array.
+template <ArrayDataType A>
+void ConcatenateTensorBuffers(const std::vector<Array*>& input_arrays,
+                              int concatenation_axis,
+                              Array* concatenated_array) {
+  int num_elements_concatenated_array = 1;
+  for (int i = 0; i < concatenated_array->shape().dimensions_count(); i++) {
+    num_elements_concatenated_array *= concatenated_array->shape().dims()[i];
+  }
+  // Prepare the data needed for segmented copy from multiple source arrays to
+  // a destination array based on a concatenation dimension.
+  std::vector<int> array_copy_size(input_arrays.size());
+  int count = 0;
+  for (Array* input_array : input_arrays) {
+    const Shape array_shape = input_array->shape();
+    array_copy_size[count] = 1;
+    for (int i = concatenation_axis; i < array_shape.dimensions_count(); i++) {
+      array_copy_size[count] *= array_shape.dims()[i];
+    }
+    count++;
+  }
+
+  // Do the actual data copy.
+  CopyTensorSegments<A, DataType<A>>(input_arrays, array_copy_size,
+                                     num_elements_concatenated_array,
+                                     concatenated_array);
+}
+
+// Sets the minimum and maximum values for the concatenated array. If they are
+// already set (e.g. because of a previous pass in TOCO), it leaves them
+// unchanged and returns. Otherwise it uses the input arrays' min and max
+// values to compute the concatenated array's min and max.
+void SetMinMaxForConcatenedArray(const std::vector<Array*>& input_arrays,
+                                 Array* concatenated_array) {
+  CHECK(concatenated_array->data_type == ArrayDataType::kFloat);
+  // If the minmax is already set, use it
+  if (concatenated_array->minmax) return;
+
+  double concat_min = std::numeric_limits<double>::infinity();
+  double concat_max = -std::numeric_limits<double>::infinity();
+
+  for (Array* input_array : input_arrays) {
+    // If any of the input arrays' minmax is not set, return.
+    // TODO(ghodrat): shall we add the logic to compute the minmax?
+    if (!input_array->minmax) return;
+    const MinMax& input_minmax = input_array->GetMinMax();
+    concat_min = std::min(concat_min, input_minmax.min);
+    concat_max = std::max(concat_max, input_minmax.max);
+  }
+  MinMax& minmax = concatenated_array->GetOrCreateMinMax();
+  minmax.min = concat_min;
+  minmax.max = concat_max;
+}
+
+}  // namespace
+
+// Resolves the concatenation operator if all its inputs are constant arrays.
+bool ResolveConstantConcatenation::Run(Model* model, std::size_t op_index) {
+  const auto concat_it = model->operators.begin() + op_index;
+  const auto* concat_base_op = concat_it->get();
+  if (concat_base_op->type != OperatorType::kConcatenation) {
+    return false;
+  }
+  const auto* concat_op =
+      static_cast<const ConcatenationOperator*>(concat_base_op);
+
+  for (const string& input_name : concat_op->inputs) {
+    // We only expect constant unquantized arrays as input, otherwise we return.
+    // We also make sure the shapes of the input arrays are known and they are
+    // all discardable.
+ const Operator* input_op = GetOpWithOutput(*model, input_name); + if (input_op) return false; + if (!IsConstantParameterArray(*model, input_name)) return false; + if (!model->GetArray(input_name).has_shape()) return false; + if (model->GetArray(input_name).quantization_params) return false; + if (!IsDiscardableArray(*model, input_name)) return false; + } + + const int concatenation_axis = concat_op->concat_dim; + + CHECK_EQ(concat_op->outputs.size(), 1); + string concatenated_array_name = concat_op->outputs[0]; + Array& concatenated_array = model->GetOrCreateArray(concatenated_array_name); + std::vector input_arrays; + for (const string& input_name : concat_op->inputs) { + input_arrays.push_back(&model->GetArray(input_name)); + } + + switch (concatenated_array.data_type) { + case ArrayDataType::kFloat: + ConcatenateTensorBuffers( + input_arrays, concatenation_axis, &concatenated_array); + SetMinMaxForConcatenedArray(input_arrays, &concatenated_array); + break; + case ArrayDataType::kUint8: + ConcatenateTensorBuffers( + input_arrays, concatenation_axis, &concatenated_array); + break; + case ArrayDataType::kInt32: + ConcatenateTensorBuffers( + input_arrays, concatenation_axis, &concatenated_array); + break; + case ArrayDataType::kInt64: + ConcatenateTensorBuffers( + input_arrays, concatenation_axis, &concatenated_array); + break; + default: + LOG(FATAL) << "ArrayDataType not supported"; + } + + // Remove all the resolved arrays. + for (const string& input_name : concat_op->inputs) { + model->arrays.erase(input_name); + } + + // Remove concatenate operator + model->operators.erase(concat_it); + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc new file mode 100644 index 0000000000..244adcc4c4 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc @@ -0,0 +1,76 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +bool ResolveConstantFakeQuant::Run(Model* model, std::size_t op_index) { + const auto fakequant_it = model->operators.begin() + op_index; + const auto* fakequant_base_op = fakequant_it->get(); + if (fakequant_base_op->type != OperatorType::kFakeQuant) { + return false; + } + + const auto* fakequant_op = + static_cast(fakequant_base_op); + + // Yield until the fakequant MinMax has been resolved. + if (!fakequant_op->minmax) { + return false; + } + + // This transformation only applies when the input array is constant. 
+ if (!IsConstantParameterArray(*model, fakequant_op->inputs[0])) { + return false; + } + + const auto& input_array = model->GetArray(fakequant_op->inputs[0]); + auto& output_array = model->GetArray(fakequant_op->outputs[0]); + CHECK(input_array.data_type == ArrayDataType::kFloat); + output_array.data_type = ArrayDataType::kFloat; + CHECK(!output_array.buffer); + const auto& input_buffer = input_array.GetBuffer(); + auto& output_buffer = output_array.GetMutableBuffer(); + const int size = input_buffer.data.size(); + output_buffer.data.resize(size); + QuantizationParams qparams; + GetQuantizationParamsFromMinMax( + model->flags, *fakequant_op->minmax, &qparams); + for (int i = 0; i < size; i++) { + const double src_val = input_buffer.data[i]; + const double unclamped_quantized_val = + std::round(qparams.zero_point + src_val / qparams.scale); + const double quantized_val = + std::min(255., std::max(0., unclamped_quantized_val)); + const double dst_val = qparams.scale * (quantized_val - qparams.zero_point); + output_buffer.data[i] = dst_val; + } + if (CountOpsWithInput(*model, fakequant_op->inputs[0]) == 1) { + model->arrays.erase(fakequant_op->inputs[0]); + } + model->operators.erase(fakequant_it); + + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_tensorflow_shape.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_tensorflow_shape.cc new file mode 100644 index 0000000000..8cc6db1619 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_tensorflow_shape.cc @@ -0,0 +1,62 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +bool ResolveConstantTensorFlowShape::Run(Model* model, std::size_t op_index) { + const auto tfshape_it = model->operators.begin() + op_index; + const auto* tfshape_base_op = tfshape_it->get(); + if (tfshape_base_op->type != OperatorType::kTensorFlowShape) { + return false; + } + + const auto* tfshape_op = + static_cast(tfshape_base_op); + + const auto& input_array = model->GetArray(tfshape_op->inputs[0]); + auto& output_array = model->GetArray(tfshape_op->outputs[0]); + + // Yield until the input array's shape has been resolved. + if (!input_array.has_shape()) { + return false; + } + + // Create a buffer for the output array, making it a constant array, and + // copy the input shape into the output buffer. 
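+  // For example, an input array of shape [1, 224, 224, 3] would produce an
+  // output buffer holding the four values {1, 224, 224, 3}.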
+ CHECK(!output_array.buffer); + auto& output_buffer = output_array.GetMutableBuffer(); + output_buffer.data = input_array.shape().dims(); + + // Erase the input array if no longer used + if (IsDiscardableArray(*model, tfshape_op->inputs[0]) && + CountOpsWithInput(*model, tfshape_op->inputs[0]) == 1) { + model->arrays.erase(tfshape_op->inputs[0]); + } + model->operators.erase(tfshape_it); + + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc new file mode 100644 index 0000000000..bb9bda3c82 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc @@ -0,0 +1,175 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/runtime/types.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) { + const auto unary_it = model->operators.begin() + op_index; + const auto* unary_op = unary_it->get(); + // Test for unary ops of types that we know how to resolve + if (unary_op->type != OperatorType::kTensorFlowRsqrt && + unary_op->type != OperatorType::kTensorFlowSqrt && + unary_op->type != OperatorType::kTensorFlowSquare && + unary_op->type != OperatorType::kTensorFlowSum && + unary_op->type != OperatorType::kTensorFlowMin && + unary_op->type != OperatorType::kTensorFlowMax && + unary_op->type != OperatorType::kTensorFlowReshape) { + return false; + } + // Check if the input is a constant parameter. + if (!IsConstantParameterArray(*model, unary_op->inputs[0])) { + return false; + } + + // if the unary op involves a tensor required by a rnn state, ignore it + for (const auto& rnn_state : model->flags.rnn_states()) { + if (unary_op->inputs[0] == rnn_state.back_edge_source_array()) { + return false; + } + if (unary_op->inputs[0] == rnn_state.state_array()) { + return false; + } + } + + // At the moment we don't want to care about fused activation functions. + // The idea is that we should do the present constants-propagation before + // activation functions get fused. + if (unary_op->fused_activation_function != + FusedActivationFunctionType::kNone) { + AddMessageF( + "Not resolving constant %s " + " because it has a fused activation function", + LogName(*unary_op)); + return false; + } + const auto& input_array = model->GetArray(unary_op->inputs[0]); + // We have already tested above for existence of buffers (synonymous to being + // a constant param). 
+ CHECK(input_array.buffer); + // At the moment we only support float buffers. + if (input_array.buffer->type != ArrayDataType::kFloat) { + return false; + } + const auto& input_float_data = + input_array.GetBuffer().data; + // Create the float buffer on the output array, effectively turning it into + // a constant parameter + const auto& output_name = unary_op->outputs[0]; + auto& output_array = model->GetArray(output_name); + // Yield until the output array dims have been resolved. + if (!output_array.has_shape()) { + return false; + } + + int input_buffer_size = RequiredBufferSizeForShape(input_array.shape()); + int output_buffer_size = RequiredBufferSizeForShape(output_array.shape()); + const Shape& input_shape = input_array.shape(); + const Shape& output_shape = output_array.shape(); + + auto& output_float_data = + output_array.GetMutableBuffer().data; + output_float_data.resize(output_buffer_size); + + const int output_dims_count = output_shape.dimensions_count(); + if (unary_op->type == OperatorType::kTensorFlowReshape) { + CHECK(input_buffer_size == output_buffer_size); + memcpy(output_float_data.data(), input_float_data.data(), + input_buffer_size * sizeof(input_float_data[0])); + } else if (unary_op->type == OperatorType::kTensorFlowSum) { + // At the moment only full reduction across all dimensions is supported. + for (int i = 0; i < output_dims_count; i++) { + CHECK_EQ(output_shape.dims(i), 1); + } + float sum = 0.f; + const int input_size = RequiredBufferSizeForShape(input_shape); + for (int i = 0; i < input_size; i++) { + sum += input_float_data[i]; + } + output_float_data[0] = sum; + } else if (unary_op->type == OperatorType::kTensorFlowMin) { + // At the moment only full reduction across all dimensions is supported. + // TODO(starka): Output should not be padded. + for (int i = 0; i < output_dims_count; i++) { + CHECK_EQ(output_shape.dims(i), 1); + } + float min = input_float_data[0]; + const int input_size = RequiredBufferSizeForShape(input_shape); + for (int i = 0; i < input_size; i++) { + min = std::min(min, input_float_data[i]); + } + output_float_data[0] = min; + } else if (unary_op->type == OperatorType::kTensorFlowMax) { + // At the moment only full reduction across all dimensions is supported. + // TODO(starka): Output should not be padded. + for (int i = 0; i < output_dims_count; i++) { + CHECK_EQ(output_shape.dims(i), 1); + } + float max = input_float_data[0]; + const int input_size = RequiredBufferSizeForShape(input_shape); + for (int i = 0; i < input_size; i++) { + max = std::max(max, input_float_data[i]); + } + output_float_data[0] = max; + } else if (unary_op->type == OperatorType::kTensorFlowRsqrt || + unary_op->type == OperatorType::kTensorFlowSqrt || + unary_op->type == OperatorType::kTensorFlowSquare) { + // Element-wise ops. Should have perfectly matching sizes here. 
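+  // As a small worked example of the cases handled below: for an input buffer
+  // {4.f, 9.f}, Sqrt yields {2.f, 3.f}, Rsqrt yields {0.5f, 1.f/3.f}, and
+  // Square yields {16.f, 81.f}.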
+ const int input_size = RequiredBufferSizeForShape(input_shape); + for (int i = 0; i < output_dims_count; i++) { + CHECK_EQ(output_shape.dims(i), input_shape.dims(i)); + } + + for (int i = 0; i < input_size; i++) { + const float val = input_float_data[i]; + float outval = 0.f; + if (unary_op->type == OperatorType::kTensorFlowRsqrt) { + outval = 1.0f / std::sqrt(val); + } else if (unary_op->type == OperatorType::kTensorFlowSqrt) { + outval = std::sqrt(val); + } else if (unary_op->type == OperatorType::kTensorFlowSquare) { + outval = val * val; + } else { + LOG(FATAL) << "should not get here."; + } + output_float_data[i] = outval; + } + } else { + LOG(FATAL) << "should not get here."; + } + for (const auto& input : unary_op->inputs) { + if (CountOpsWithInput(*model, input) == 1) { + model->arrays.erase(input); + } + } + AddMessageF("Resolved constant %s to the equivalent constant array", + LogName(*unary_op)); + model->operators.erase(unary_it); + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_mean_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_mean_attributes.cc new file mode 100644 index 0000000000..d25c773f19 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_mean_attributes.cc @@ -0,0 +1,51 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +bool ResolveMeanAttributes::Run(Model* model, std::size_t op_index) { + auto* mean_op = model->operators[op_index].get(); + if (mean_op->type != OperatorType::kMean) return false; + auto* op = static_cast(mean_op); + + if (!op->reduction_indices.empty()) return false; + if (op->inputs.size() != 2) return false; + if (!IsConstantParameterArray(*model, op->inputs[1])) return false; + + const auto& indices_array = *model->arrays[op->inputs[1]]; + if (!indices_array.has_shape()) return false; + + op->reduction_indices = indices_array.GetBuffer().data; + + // At the moment, we only support simultaneous reduction over width and + // height. This is mainly limited by the fact that currently, the runtime + // arrays are always 4-dimensional. 
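+  // In other words, for a NHWC activation array only reduction_indices of
+  // {1, 2} or {2, 1} (height and width) pass the CHECKs below.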
+ CHECK_EQ(op->reduction_indices.size(), 2); + CHECK((op->reduction_indices[0] == 1 && op->reduction_indices[1] == 2) || + (op->reduction_indices[0] == 2 && op->reduction_indices[1] == 1)); + + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_pad_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_pad_attributes.cc new file mode 100644 index 0000000000..d5f5869c62 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_pad_attributes.cc @@ -0,0 +1,55 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +bool ResolvePadAttributes::Run(Model* model, std::size_t op_index) { + const auto pad_it = model->operators.begin() + op_index; + auto* pad_op = pad_it->get(); + if (pad_op->type != OperatorType::kPad) return false; + + auto* op = static_cast(pad_op); + if (!op->left_padding.empty()) return false; + + CHECK_EQ(op->inputs.size(), 2); + if (!IsConstantParameterArray(*model, op->inputs[1])) return false; + + const auto& array = *model->arrays[op->inputs[1]]; + if (!array.has_shape()) return false; + + const std::vector& dims = array.shape().dims(); + CHECK_EQ(dims.size(), 2); + + std::vector buffer = array.GetBuffer().data; + + for (int i = 0; i < dims[0]; ++i) { + op->left_padding.push_back(buffer[i * 2]); + op->right_padding.push_back(buffer[i * 2 + 1]); + } + + // TODO(dkalenichenko): Delete the extra input? + + return true; +} +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_reorder_axes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_reorder_axes.cc new file mode 100644 index 0000000000..8fa7b83bed --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_reorder_axes.cc @@ -0,0 +1,93 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +bool ResolveReorderAxes::Run(Model* model, std::size_t op_index) { + auto reorder_it = model->operators.begin() + op_index; + auto* reorder_op = static_cast(reorder_it->get()); + if (reorder_op->type != OperatorType::kReorderAxes) { + return false; + } + const auto& input_array_name = reorder_op->inputs[0]; + const auto& output_array_name = reorder_op->outputs[0]; + auto& input_array = model->GetArray(input_array_name); + auto& output_array = model->GetArray(output_array_name); + string constant_input_array_name = input_array_name; + if (!input_array.buffer) { + const auto* op_producing_input = GetOpWithOutput(*model, input_array_name); + if (op_producing_input && + op_producing_input->type == OperatorType::kFakeQuant) { + constant_input_array_name = op_producing_input->inputs[0]; + } + } + auto& constant_input_array = model->GetArray(constant_input_array_name); + if (!constant_input_array.buffer) { + return false; + } + // Yield until output dims have been resolved. + if (!output_array.has_shape()) { + return false; + } + // Reorder the input array dims and buffer data + CHECK(constant_input_array.buffer->type == ArrayDataType::kFloat); + CHECK(!output_array.buffer); + auto& input_data = + constant_input_array.GetMutableBuffer().data; + std::vector reordered_data; + reordered_data.resize(RequiredBufferSizeForShape(output_array.shape())); + const auto input_axes_order = reorder_op->input_axes_order; + const auto output_axes_order = reorder_op->output_axes_order; + // TODO(b/62904716) Shapes should be used directly. + Shape input_shape = constant_input_array.shape(); + Shape output_shape = output_array.shape(); + if (AxesCount(input_axes_order) == 2) { + UnextendShape(&input_shape, 2); + UnextendShape(&output_shape, 2); + } + ShuffleArray(input_shape, input_axes_order, output_axes_order, output_shape, + input_data.data(), reordered_data.data()); + input_data = reordered_data; + input_array.copy_shape(output_array.shape()); + constant_input_array.copy_shape(output_array.shape()); + + // Update the edges of the graph to point to the input array + for (const auto& other_op : model->operators) { + for (auto& input : other_op->inputs) { + if (input == output_array_name) { + input = input_array_name; + } + } + } + + AddMessageF("Reordered axes for array %s", input_array_name); + + // Remove the op and output array. + model->arrays.erase(output_array_name); + model->operators.erase(reorder_it); + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_reshape_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_reshape_attributes.cc new file mode 100644 index 0000000000..bed2a85bd2 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_reshape_attributes.cc @@ -0,0 +1,49 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +bool ResolveReshapeAttributes::Run(Model* model, std::size_t op_index) { + const auto reshape_it = model->operators.begin() + op_index; + auto* reshape_op = reshape_it->get(); + if (reshape_op->type != OperatorType::kTensorFlowReshape) { + return false; + } + + auto* op = static_cast(reshape_op); + + if (!op->shape.empty()) return false; + + if (IsConstantParameterArray(*model, reshape_op->inputs[1])) { + const auto& constant_input_array = *model->arrays[reshape_op->inputs[1]]; + op->shape = constant_input_array.GetBuffer().data; + } + + if (op->shape.empty()) return false; + + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_slice_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_slice_attributes.cc new file mode 100644 index 0000000000..1d0a2ec8f6 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_slice_attributes.cc @@ -0,0 +1,52 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +bool ResolveSliceAttributes::Run(Model* model, std::size_t op_index) { + const auto slice_it = model->operators.begin() + op_index; + auto* slice_op = slice_it->get(); + if (slice_op->type != OperatorType::kSlice) return false; + + auto* op = static_cast(slice_op); + if (!op->begin.empty()) return false; + + CHECK_EQ(op->inputs.size(), 3); + if (!IsConstantParameterArray(*model, op->inputs[1])) return false; + if (!IsConstantParameterArray(*model, op->inputs[2])) return false; + + const auto& begin_array = *model->arrays[op->inputs[1]]; + if (!begin_array.has_shape()) return false; + + const auto& size_array = *model->arrays[op->inputs[2]]; + if (!size_array.has_shape()) return false; + + op->begin = begin_array.GetBuffer().data; + op->size = size_array.GetBuffer().data; + + // TODO(dkalenichenko): Delete the extra inputs? + + return true; +} +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc new file mode 100644 index 0000000000..5fc3b25bc1 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc @@ -0,0 +1,62 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +bool ResolveStridedSliceAttributes::Run(Model* model, std::size_t op_index) { + const auto slice_it = model->operators.begin() + op_index; + auto* slice_op = slice_it->get(); + if (slice_op->type != OperatorType::kStridedSlice) return false; + + auto* op = static_cast(slice_op); + if (!op->start_indices.empty()) return false; + + CHECK_EQ(op->inputs.size(), 4); + if (!IsConstantParameterArray(*model, op->inputs[1])) return false; + if (!IsConstantParameterArray(*model, op->inputs[2])) return false; + if (!IsConstantParameterArray(*model, op->inputs[3])) return false; + + const auto& start_array = *model->arrays[op->inputs[1]]; + if (!start_array.has_shape()) return false; + + const auto& stop_array = *model->arrays[op->inputs[2]]; + if (!stop_array.has_shape()) return false; + + const auto& stride_array = *model->arrays[op->inputs[3]]; + if (!stride_array.has_shape()) return false; + + op->start_indices = start_array.GetBuffer().data; + op->stop_indices = stop_array.GetBuffer().data; + op->strides = stride_array.GetBuffer().data; + + // Only 4D arrays are supported for now. + CHECK_EQ(op->start_indices.size(), 4); + CHECK_EQ(op->stop_indices.size(), 4); + CHECK_EQ(op->strides.size(), 4); + + // TODO(dkalenichenko): Delete the extra inputs? + + return true; +} +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_concat.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_concat.cc new file mode 100644 index 0000000000..b482f5cf51 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_concat.cc @@ -0,0 +1,86 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +bool ResolveTensorFlowConcat::Run(Model* model, std::size_t op_index) { + auto concat_it = model->operators.begin() + op_index; + const auto* tf_concat_op = concat_it->get(); + if (tf_concat_op->type != OperatorType::kTensorFlowConcat && + tf_concat_op->type != OperatorType::kTensorFlowConcatV2) { + return false; + } + + CHECK_GE(tf_concat_op->inputs.size(), 2); + // TensorFlow Concat and ConcatV2 nodes only differ by the ordering + // of inputs: in Concat, the concat_dim is the first input, while in + // ConcatV2, it is the last input. + std::size_t concat_dim_pos = 0; + if (tf_concat_op->type == OperatorType::kTensorFlowConcatV2) { + concat_dim_pos = tf_concat_op->inputs.size() - 1; + } + const string concat_dim_name = tf_concat_op->inputs[concat_dim_pos]; + std::vector concat_input_names; + for (std::size_t i = 0; i < tf_concat_op->inputs.size(); i++) { + if (i != concat_dim_pos) { + concat_input_names.push_back(tf_concat_op->inputs[i]); + } + } + // If the concat_dim array hasn't been resolved to a constant yet, + // we need to yield. + const auto& concat_dim_array = model->GetArray(concat_dim_name); + if (!concat_dim_array.buffer) { + AddMessageF("Waiting for the concat_dim of %s to be resolved to a constant", + LogName(*tf_concat_op)); + return false; + } + + CHECK(concat_dim_array.data_type == ArrayDataType::kInt32); + const auto& concat_dim_data = + concat_dim_array.GetBuffer().data; + CHECK_EQ(concat_dim_data.size(), 1); + const int concat_dim = concat_dim_data[0]; + + // Create the Concatenation op replacing the TensorFlowConcat op. + auto* concatenation_op = new ConcatenationOperator; + concatenation_op->concat_dim = concat_dim; + concatenation_op->inputs = concat_input_names; + concatenation_op->outputs = {tf_concat_op->outputs[0]}; + auto depth_concat_it = model->operators.emplace(concat_it, concatenation_op); + CHECK_EQ(depth_concat_it->get(), concatenation_op); + // Update invalidated iterator + concat_it = depth_concat_it + 1; + CHECK_EQ(concat_it->get(), tf_concat_op); + + // Remove the concat_dim array if it is not used by anything else. + if (CountOpsWithInput(*model, concat_dim_name) == 1) { + model->arrays.erase(concat_dim_name); + } + // Remove the TensorFlowConcat op + model->operators.erase(concat_it); + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc new file mode 100644 index 0000000000..bea7487051 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc @@ -0,0 +1,106 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +bool ResolveTensorFlowMatMul::Run(Model* model, std::size_t op_index) { + auto matmul_it = model->operators.begin() + op_index; + if (matmul_it->get()->type != OperatorType::kTensorFlowMatMul) { + return false; + } + const auto* matmul_op = matmul_it->get(); + + // Find the op producing the array passed to this MatMul + auto previous_op_it = model->operators.begin(); + bool found = false; + for (; previous_op_it != model->operators.end(); ++previous_op_it) { + for (const auto& output : (*previous_op_it)->outputs) { + if (output == matmul_op->inputs[0]) { + found = true; + break; + } + } + if (found) { + break; + } + } + Operator* previous_op = (found) ? previous_op_it->get() : nullptr; + + // construct the new FullyConnectedOperator + auto* fc_op = new FullyConnectedOperator; + fc_op->outputs = matmul_op->outputs; + + // insert the newly constructed FullyConnectedOperator + auto fc_it = model->operators.emplace(matmul_it, fc_op); + + // refresh invalidated iterator + matmul_it = fc_it + 1; + DCHECK_EQ(matmul_it->get(), matmul_op); + + // The way that TensorFlow encodes FullyConnected ops is as a pair + // (Reshape, MatMul), so we want to remove the Reshape op and rewrite the + // MatMul + // op as a FullyConnected. However, TensorFlow skips the Reshape ops if the + // input doesn't need reshaping, so we can't just match (Reshape, MatMul) + // pairs. + if (previous_op && previous_op->type == OperatorType::kTensorFlowReshape) { + AddMessageF("Combining %s and %s into %s", LogName(*previous_op), + LogName(*matmul_op), LogName(*fc_op)); + const auto& previous_op_output = previous_op->outputs[0]; + if (CountOpsWithInput(*model, previous_op_output) == 1) { + model->arrays.erase(previous_op_output); + } + CHECK_EQ(previous_op->inputs.size(), 2); + fc_op->inputs = {previous_op->inputs[0], matmul_op->inputs[1]}; + // Only remove Reshape node if no other node uses its output. + if (CountOpsWithInput(*model, previous_op_output) == 1) { + const auto& previous_op_shape = previous_op->inputs[1]; + if (CountOpsWithInput(*model, previous_op_shape) == 1 && + !GetOpWithOutput(*model, previous_op_shape)) { + model->arrays.erase(previous_op_shape); + } + model->operators.erase(previous_op_it); + } + + // We may have just invalidated matmul_it, so let's refresh it now. 
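+  // (Erasing an element from model->operators invalidates any iterator at or
+  // past the erased position, which can include matmul_it.)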
+ matmul_it = model->operators.begin(); + for (; matmul_it != model->operators.end(); ++matmul_it) { + if (matmul_it->get() == matmul_op) { + break; + } + } + CHECK(matmul_it != model->operators.end()); + CHECK(matmul_it->get() == matmul_op); + } else { + AddMessageF("Replacing %s by a FullyConnected operator", + LogName(*matmul_op)); + fc_op->inputs = {matmul_op->inputs[0], matmul_op->inputs[1]}; + } + + // erase the MatMul operator + model->operators.erase(matmul_it); + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_merge.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_merge.cc new file mode 100644 index 0000000000..cfa5ce0716 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_merge.cc @@ -0,0 +1,63 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +bool ResolveTensorFlowMerge::Run(Model* model, std::size_t op_index) { + const auto merge_it = model->operators.begin() + op_index; + const auto* merge_op = merge_it->get(); + if (merge_op->type != OperatorType::kTensorFlowMerge) { + return false; + } + + // We need to yield until this Merge node has only 1 input, which will mean + // that that is the selected input. Other graph transformations on other nodes + // such as ResolveTensorFlowSwitch, will take care of trimming the + // non-selected inputs, so that at some point there will be only 1 input left. + if (merge_op->inputs.size() > 1) { + AddMessageF("Waiting for %s to be resolved", LogName(*merge_op)); + return false; + } + + // Now that the merge node has 1 input exactly, it is the same as an Identity + // node and can be resolved trivially. + CHECK_EQ(merge_op->inputs.size(), 1); + + // Update the edges of the graph ahead of removing the node. + for (const auto& other_op : model->operators) { + for (auto& input : other_op->inputs) { + if (input == merge_op->outputs[0]) { + input = merge_op->inputs[0]; + } + } + } + + // Remove the node and its output array. + AddMessageF("Removing already-resolved %s", LogName(*merge_op)); + model->arrays.erase(merge_op->outputs[0]); + model->operators.erase(merge_it); + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_squeeze.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_squeeze.cc new file mode 100644 index 0000000000..1d3f42b5ec --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_squeeze.cc @@ -0,0 +1,54 @@ +/* Copyright 2017 The TensorFlow Authors. 
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_squeeze.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_squeeze.cc
new file mode 100644
index 0000000000..1d3f42b5ec
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_squeeze.cc
@@ -0,0 +1,54 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+bool ResolveTensorFlowSqueeze::Run(Model* model, std::size_t op_index) {
+  const auto squeeze_it = model->operators.begin() + op_index;
+  const auto* squeeze_op = squeeze_it->get();
+  if (squeeze_op->type != OperatorType::kSqueeze) {
+    return false;
+  }
+
+  CHECK_EQ(squeeze_op->inputs.size(), 1);
+  CHECK_EQ(squeeze_op->outputs.size(), 1);
+
+  // If the output is consumed only by a Reshape op, this Squeeze is a trivial
+  // passthrough and can be removed.
+  if (CountOpsWithInput(*model, squeeze_op->outputs[0]) == 1) {
+    const auto* next_op = GetOpWithInput(*model, squeeze_op->outputs[0]);
+    if (next_op->type == OperatorType::kTensorFlowReshape) {
+      AddMessageF(
+          "%s is trivial because its output is only consumed by a "
+          "Reshape op",
+          LogName(*squeeze_op));
+
+      return RemoveTrivialPassthroughOp(this, model, op_index);
+    }
+  }
+
+  return false;
+}
+
+}  // namespace toco
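The next file resolves TensorFlow Switch nodes once their boolean predicate has become a known constant. As a standalone sketch of the output-selection convention it relies on (illustrative names only, not the toco API): the data input goes to output 0 when the predicate is false and to output 1 when it is true, while consumers of the non-selected output get trimmed away.

#include <cassert>

// Which of the two Switch outputs carries the data, given the predicate.
int SelectedSwitchOutput(bool predicate) { return predicate ? 1 : 0; }
// The other output is dead; its consumers can be discarded.
int NonSelectedSwitchOutput(bool predicate) { return predicate ? 0 : 1; }

int main() {
  assert(SelectedSwitchOutput(false) == 0);
  assert(NonSelectedSwitchOutput(false) == 1);
  assert(SelectedSwitchOutput(true) == 1);
  assert(NonSelectedSwitchOutput(true) == 0);
  return 0;
}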
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_switch.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_switch.cc
new file mode 100644
index 0000000000..55adfca037
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_switch.cc
@@ -0,0 +1,123 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+bool ResolveTensorFlowSwitch::Run(Model* model, std::size_t op_index) {
+  const auto switch_it = model->operators.begin() + op_index;
+  const auto* switch_op = switch_it->get();
+  if (switch_op->type != OperatorType::kTensorFlowSwitch) {
+    return false;
+  }
+
+  CHECK_EQ(switch_op->inputs.size(), 2);
+  CHECK_EQ(switch_op->outputs.size(), 2);
+  const string& predicate_name = switch_op->inputs[1];
+  // If the predicate array hasn't been resolved to a constant yet,
+  // we need to yield.
+  if (!IsConstantParameterArray(*model, predicate_name)) {
+    AddMessageF(
+        "Waiting for the boolean predicate of %s to be resolved to a constant",
+        LogName(*switch_op));
+    return false;
+  }
+
+  // The predicate should be boolean, and should consist of a single value.
+  const auto& predicate_array = model->GetArray(predicate_name);
+  CHECK(predicate_array.data_type == ArrayDataType::kBool);
+  for (const auto& dim : predicate_array.shape().dims()) {
+    CHECK_EQ(dim, 1);
+  }
+
+  // Obtain the predicate boolean value.
+  const auto& predicate_data =
+      predicate_array.GetBuffer<ArrayDataType::kBool>().data;
+  CHECK_EQ(predicate_data.size(), 1);
+  const bool predicate_value = predicate_data[0];
+
+  // From the TensorFlow docs on .switch() in
+  // third_party/tensorflow/python/ops/control_flow_ops.py
+  //
+  //   If `pred` is false, the `data` input is forwarded to the first output.
+  //   Otherwise, the data goes to the second output.
+  //
+  // Note that this comment used to say the opposite and was recently fixed:
+  // https://github.com/tensorflow/tensorflow/commit/bc456e361d49d1d89a74b80060c70efb51fd7d87#diff-76ab9dafbe12c20ddc3769c6b108986c
+  const int selected_output_index = predicate_value ? 1 : 0;
+  const int nonselected_output_index = predicate_value ? 0 : 1;
+
+  // Update the edges of the graph ahead of removing the node:
+  // edges that were pointing to the selected output should instead
+  // point to the input of the Switch node.
+  for (const auto& other_op : model->operators) {
+    for (auto& input : other_op->inputs) {
+      if (input == switch_op->outputs[selected_output_index]) {
+        input = switch_op->inputs[0];
+      }
+    }
+  }
+
+  // It remains to handle the edges that were pointing to the nonselected
+  // output. We will just discard those edges. Concretely, at the moment,
+  // our only examples of graphs with Switch nodes have them feeding into Merge
+  // nodes, so what we're saying here is that we'll make the convention,
+  // in our toco internal representation, that Merge nodes with only 1 input
+  // are Merge nodes that have been resolved already and should behave as
+  // Identity nodes, simply forwarding their input.
+  for (const auto& other_op : model->operators) {
+    auto input_it = other_op->inputs.begin();
+    while (input_it != other_op->inputs.end()) {
+      if (*input_it == switch_op->outputs[nonselected_output_index]) {
+        // Let us guard our assumption that only Merge nodes consume the
+        // outputs of Switch nodes:
+        CHECK(other_op->type == OperatorType::kTensorFlowMerge);
+        input_it = other_op->inputs.erase(input_it);
+      } else {
+        ++input_it;
+      }
+    }
+  }
+
+  // Remove the output arrays if they are now unused.
+ for (int i = 0; i < 2; i++) { + if (!GetOpWithInput(*model, switch_op->outputs[i])) { + model->arrays.erase(switch_op->outputs[i]); + } + } + // Remove input arrays if they are only used by the switch itself and aren't + // the output of another op (will get handled by RemoveUnusedOp in that case). + for (const auto& input : switch_op->inputs) { + if (CountOpsWithInput(*model, input) == 1 && + !GetOpWithOutput(*model, input)) { + model->arrays.erase(input); + } + } + // Remove the switch node itself. + AddMessageF("Removing already-resolved %s", LogName(*switch_op)); + model->operators.erase(switch_it); + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_tile.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_tile.cc new file mode 100644 index 0000000000..9f7e7c42a2 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_tile.cc @@ -0,0 +1,97 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +namespace { + +void RemoveTileOperator(Model* model, Operator* tile_op, Operator* binary_op, + int operand_index) { + CHECK(tile_op->type == OperatorType::kTensorFlowTile); + CHECK_EQ(binary_op->inputs.size(), 2); + CHECK_EQ(tile_op->inputs.size(), 2); + const string tile_multiplier_array = tile_op->inputs[1]; + const string tile_output_array = tile_op->outputs[0]; + binary_op->inputs[operand_index] = tile_op->inputs[0]; + auto tile_it = model->operators.begin(); + for (; tile_it != model->operators.end(); ++tile_it) { + if (tile_it->get() == tile_op) { + break; + } + } + CHECK(tile_it != model->operators.end()); + CHECK(tile_it->get() == tile_op); + model->operators.erase(tile_it); + if (!CountOpsWithInput(*model, tile_multiplier_array) && + !GetOpWithOutput(*model, tile_multiplier_array)) { + model->arrays.erase(tile_multiplier_array); + } + if (!CountOpsWithInput(*model, tile_output_array)) { + model->arrays.erase(tile_output_array); + } +} +} // namespace + +bool ResolveTensorFlowTile::Run(Model* model, std::size_t op_index) { + const auto binary_it = model->operators.begin() + op_index; + auto* binary_op = binary_it->get(); + // Test for binary ops of types that we know how to resolve + if (binary_op->inputs.size() != 2) { + return false; + } + if (binary_op->type != OperatorType::kAdd && + binary_op->type != OperatorType::kMul && + binary_op->type != OperatorType::kSub && + binary_op->type != OperatorType::kDiv) { + return false; + } + + Operator* const op[2] = { + GetOpWithOutput(*model, binary_op->inputs[0]), + GetOpWithOutput(*model, 
binary_op->inputs[1]), + }; + + // In the unlikely case where both operands are Tile, we can't infer the + // output + // size without the Tile nodes, so we have to bail out. + if (op[0] && op[0]->type == OperatorType::kTensorFlowTile && op[1] && + op[1]->type == OperatorType::kTensorFlowTile) { + return false; + } + + for (int i = 0; i < 2; i++) { + if (op[i] && op[i]->type == OperatorType::kTensorFlowTile) { + // We can only remove a Tile operator is no other op than the present + // binary op was consuming its tiled output. + if (CountOpsWithInput(*model, binary_op->inputs[i]) == 1) { + AddMessageF("Removing %s", LogName(*op[i])); + RemoveTileOperator(model, op[i], binary_op, i); + return true; + } + } + } + return false; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD b/tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD new file mode 100644 index 0000000000..8931498782 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD @@ -0,0 +1,31 @@ +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Apache 2.0 + +load( + "//tensorflow:tensorflow.bzl", + "tf_cc_test", +) + +tf_cc_test( + name = "resolve_constant_concatenation_test", + srcs = ["resolve_constant_concatenation_test.cc"], + deps = [ + "//tensorflow/contrib/lite/toco:graph_transformations", + "//tensorflow/contrib/lite/toco:model", + "//tensorflow/contrib/lite/toco:tooling_util", + "@com_google_googletest//:gtest_main", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_concatenation_test.cc b/tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_concatenation_test.cc new file mode 100644 index 0000000000..c6705ad305 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_concatenation_test.cc @@ -0,0 +1,221 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include + +#include +#include +//#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" + +namespace toco { + +namespace { +// A gmock matcher that check that elements of a float vector match to a given +// tolerance. 
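// A typical use, as in the tests below, combines it with gmock's
// ElementsAreArray (the literal values here are purely illustrative):
//   EXPECT_THAT(actual_floats,
//               ElementsAreArray(ArrayFloatNear({1.0f, 2.0f}, 1e-5)));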
+std::vector<::testing::Matcher<float>> ArrayFloatNear(
+    const std::vector<float>& values, float max_abs_error = 1e-5) {
+  std::vector<::testing::Matcher<float>> matchers;
+  matchers.reserve(values.size());
+  for (const float& v : values) {
+    matchers.emplace_back(testing::FloatNear(v, max_abs_error));
+  }
+  return matchers;
+}
+}  // namespace
+
+// The following 3 tests make sure the concatenation operation on different
+// axis values matches the TensorFlow results listed below:
+//
+// x0 = [[[0, 1], [2, 3]], [[4, 5], [6, 7]]]
+// x1 = [[[10, 11], [12, 13]], [[14, 15], [16, 17]]]
+// x2 = [[[20, 21], [22, 23]], [[24, 25], [26, 27]]]
+// x3 = [[[30, 31], [32, 33]], [[34, 35], [36, 37]]]
+//
+// ConcatAtAxis0 test:
+// t0 = tf.concat([x0, x1, x2, x3], 0)
+// [[[ 0  1]
+//   [ 2  3]]
+//
+//  [[ 4  5]
+//   [ 6  7]]
+//
+//  [[10 11]
+//   [12 13]]
+//
+//  [[14 15]
+//   [16 17]]
+//
+//  [[20 21]
+//   [22 23]]
+//
+//  [[24 25]
+//   [26 27]]
+//
+//  [[30 31]
+//   [32 33]]
+//
+//  [[34 35]
+//   [36 37]]]
+//
+// ConcatAtAxis1 test:
+// t1 = tf.concat([x0, x1, x2, x3], 1)
+// [[[ 0  1]
+//   [ 2  3]
+//   [10 11]
+//   [12 13]
+//   [20 21]
+//   [22 23]
+//   [30 31]
+//   [32 33]]
+//
+//  [[ 4  5]
+//   [ 6  7]
+//   [14 15]
+//   [16 17]
+//   [24 25]
+//   [26 27]
+//   [34 35]
+//   [36 37]]]
+//
+// ConcatAtAxis2 test:
+// t2 = tf.concat([x0, x1, x2, x3], 2)
+// [[[ 0  1 10 11 20 21 30 31]
+//   [ 2  3 12 13 22 23 32 33]]
+//
+//  [[ 4  5 14 15 24 25 34 35]
+//   [ 6  7 16 17 26 27 36 37]]]
+
+class ResolveConstantConcatenationTest : public ::testing::Test {
+ protected:
+  ResolveConstantConcatenationTest() {}
+
+  // Prepare a hypothetical TOCO model with one Concatenation operator in it
+  // together with 4 arrays as its inputs.
+  // It receives the dimension of concatenation as input.
+  void PrepareModel(Model* model, int concat_dim) {
+    std::vector<string> concat_input_names = {"array0", "array1", "array2",
+                                              "array3"};
+
+    const int kDim = 3;
+    const int kElementPerDim = 2;
+    const int kBufSize = 8;
+    const int kNumArrays = 4;
+    static float in_buf[kNumArrays][kBufSize] = {
+        {0., 1., 2., 3., 4., 5., 6., 7.},
+        {10., 11., 12., 13., 14., 15., 16., 17.},
+        {20., 21., 22., 23., 24., 25., 26., 27.},
+        {30., 31., 32., 33., 34., 35., 36., 37.}};
+    int cnt = 0;
+    for (const string& concat_input_name : concat_input_names) {
+      Array& in_array = model->GetOrCreateArray(concat_input_name);
+      in_array.data_type = ArrayDataType::kFloat;
+
+      // Initialize shape for the input array.
+ Shape* in_array_shape = in_array.mutable_shape(); + std::vector* in_array_shape_dim = in_array_shape->mutable_dims(); + for (int i = 0; i < kDim; i++) { + in_array_shape_dim->push_back(kElementPerDim); + } + auto& in_array_buffer = + in_array.GetMutableBuffer(); + in_array_buffer.data.resize(kBufSize); + float* buf_ptr = + in_array.GetMutableBuffer().data.data(); + std::copy(in_buf[cnt], in_buf[cnt] + kBufSize, buf_ptr); + cnt++; + } + auto* concatenation_op = new ConcatenationOperator; + concatenation_op->concat_dim = concat_dim; + concatenation_op->inputs = concat_input_names; + concatenation_op->outputs = {"concat_op_outputs"}; + Array& out_array = model->GetOrCreateArray(concatenation_op->outputs[0]); + out_array.data_type = ArrayDataType::kFloat; + Shape* out_array_shape = out_array.mutable_shape(); + std::vector* out_array_shape_dim = out_array_shape->mutable_dims(); + out_array_shape_dim->resize(kDim); + for (int i = 0; i < kDim; i++) { + if (i == concat_dim) { + (*out_array_shape_dim)[i] = kNumArrays * kElementPerDim; + } else { + (*out_array_shape_dim)[i] = kElementPerDim; + } + } + model->operators.push_back(std::unique_ptr(concatenation_op)); + } +}; + +TEST_F(ResolveConstantConcatenationTest, ConcatAtAxis0) { + Model model; + const int concat_dim = 0; + PrepareModel(&model, concat_dim); + + GraphTransformationsSet graph_transformation_set; + graph_transformation_set.Add(new toco::ResolveConstantConcatenation); + EXPECT_THAT(model.arrays.size(), 5); + (*graph_transformation_set.begin())->Run(&model, /*op_index=*/0); + EXPECT_THAT(model.arrays.size(), 1); + + auto& concatenated_array = (*model.arrays.begin()).second; + EXPECT_THAT(concatenated_array->GetBuffer().data, + ElementsAreArray(ArrayFloatNear( + {0., 1., 2., 3., 4., 5., 6., 7., 10., 11., 12., + 13., 14., 15., 16., 17., 20., 21., 22., 23., 24., 25., + 26., 27., 30., 31., 32., 33., 34., 35., 36., 37.}))); +} + +TEST_F(ResolveConstantConcatenationTest, ConcatAtAxis1) { + Model model; + const int concat_dim = 1; + PrepareModel(&model, concat_dim); + + GraphTransformationsSet graph_transformation_set; + graph_transformation_set.Add(new toco::ResolveConstantConcatenation); + EXPECT_THAT(model.arrays.size(), 5); + (*graph_transformation_set.begin())->Run(&model, /*op_index=*/0); + EXPECT_THAT(model.arrays.size(), 1); + + auto& concatenated_array = (*model.arrays.begin()).second; + EXPECT_THAT(concatenated_array->GetBuffer().data, + ElementsAreArray(ArrayFloatNear( + {0., 1., 2., 3., 10., 11., 12., 13., 20., 21., 22., + 23., 30., 31., 32., 33., 4., 5., 6., 7., 14., 15., + 16., 17., 24., 25., 26., 27., 34., 35., 36., 37.}))); +} + +TEST_F(ResolveConstantConcatenationTest, ConcatAtAxis2) { + Model model; + const int concat_dim = 2; + PrepareModel(&model, concat_dim); + + GraphTransformationsSet graph_transformation_set; + graph_transformation_set.Add(new toco::ResolveConstantConcatenation); + EXPECT_THAT(model.arrays.size(), 5); + (*graph_transformation_set.begin())->Run(&model, /*op_index=*/0); + EXPECT_THAT(model.arrays.size(), 1); + + auto& concatenated_array = (*model.arrays.begin()).second; + EXPECT_THAT(concatenated_array->GetBuffer().data, + ElementsAreArray(ArrayFloatNear( + {0., 1., 10., 11., 20., 21., 30., 31., 2., 3., 12., + 13., 22., 23., 32., 33., 4., 5., 14., 15., 24., 25., + 34., 35., 6., 7., 16., 17., 26., 27., 36., 37.}))); +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/unfuse_activation_functions.cc 
b/tensorflow/contrib/lite/toco/graph_transformations/unfuse_activation_functions.cc new file mode 100644 index 0000000000..4e273343df --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/unfuse_activation_functions.cc @@ -0,0 +1,73 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/runtime/types.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +bool UnfuseActivationFunctions::Run(Model* model, std::size_t op_index) { + const auto it = model->operators.begin() + op_index; + auto* op = it->get(); + + // If a conv operation has an im2col array, yield: it should be dropped first. + if ((op->type == OperatorType::kConv) && (op->outputs.size() == 2)) { + return false; + } + + Operator* ac_op = nullptr; + switch (op->fused_activation_function) { + case FusedActivationFunctionType::kRelu: + ac_op = new ReluOperator; + break; + case FusedActivationFunctionType::kRelu6: + ac_op = new Relu6Operator; + break; + case FusedActivationFunctionType::kRelu1: + ac_op = new Relu1Operator; + break; + default: + return false; + } + + // At this point we know that the op has a fused activation function. At the + // moment that only happens with ops having a single output, may be + // relaxed in the future. + CHECK_EQ(op->outputs.size(), 1); + + // Emplace unfused activation function, drop the fused one. + model->operators.emplace(it + 1, ac_op); + op->fused_activation_function = FusedActivationFunctionType::kNone; + + // Wire up arrays, constructing a new intermediate array to connect the + // op to its new unfused activation function. + ac_op->outputs = op->outputs; + const string& tmp_array_name = + AvailableArrayName(*model, op->outputs[0] + "_unfused"); + CHECK(!model->arrays.count(tmp_array_name)); + model->GetOrCreateArray(tmp_array_name); + ac_op->inputs = {tmp_array_name}; + op->outputs = {tmp_array_name}; + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc new file mode 100644 index 0000000000..c889149ada --- /dev/null +++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc @@ -0,0 +1,1508 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include + +#include "google/protobuf/map.h" +#include "google/protobuf/text_format.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_split.h" +#include "absl/strings/strip.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/model_flags.pb.h" +#include "tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_cluster.h" +#include "tensorflow/contrib/lite/toco/tensorflow_util.h" +#include "tensorflow/contrib/lite/toco/toco_port.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/common_runtime/device_factory.h" +#include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/common_runtime/process_function_library_runtime.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/public/session_options.h" +#include "tensorflow/core/public/version.h" + +using tensorflow::AttrValue; +using tensorflow::DT_BOOL; +using tensorflow::DT_FLOAT; +using tensorflow::DT_INT32; +using tensorflow::DT_INT64; +using tensorflow::DT_UINT8; +using tensorflow::GraphDef; +using tensorflow::NodeDef; +using tensorflow::TensorProto; +using tensorflow::TensorShapeProto; + +namespace toco { +namespace { +bool HasAttr(const NodeDef& node, const string& attr_name) { + return node.attr().count(attr_name) > 0; +} + +const string& GetStringAttr(const NodeDef& node, const string& attr_name) { + CHECK(HasAttr(node, attr_name)); + const auto& attr = node.attr().at(attr_name); + CHECK_EQ(attr.value_case(), AttrValue::kS); + return attr.s(); +} + +int GetIntAttr(const NodeDef& node, const string& attr_name) { + CHECK(HasAttr(node, attr_name)) << attr_name << " not found in:\n" + << node.DebugString(); + const auto& attr = node.attr().at(attr_name); + CHECK_EQ(attr.value_case(), AttrValue::kI); + return attr.i(); +} + +float GetFloatAttr(const NodeDef& node, const string& attr_name) { + CHECK(HasAttr(node, attr_name)); + const auto& attr = node.attr().at(attr_name); + CHECK_EQ(attr.value_case(), AttrValue::kF); + return attr.f(); +} + +bool GetBoolAttr(const NodeDef& node, const string& attr_name) { + CHECK(HasAttr(node, attr_name)); + const auto& attr = node.attr().at(attr_name); + CHECK_EQ(attr.value_case(), AttrValue::kB); + return attr.b(); +} + +tensorflow::DataType GetDataTypeAttr(const NodeDef& node, + const string& attr_name) { + CHECK(HasAttr(node, attr_name)); + const auto& attr = node.attr().at(attr_name); + CHECK_EQ(attr.value_case(), AttrValue::kType); + return attr.type(); +} + +const TensorShapeProto& GetShapeAttr(const NodeDef& node, + 
const string& attr_name) { + CHECK(HasAttr(node, attr_name)); + const auto& attr = node.attr().at(attr_name); + CHECK_EQ(attr.value_case(), AttrValue::kShape); + return attr.shape(); +} + +const TensorProto& GetTensorAttr(const NodeDef& node, const string& attr_name) { + CHECK(HasAttr(node, attr_name)); + const auto& attr = node.attr().at(attr_name); + CHECK_EQ(attr.value_case(), AttrValue::kTensor); + return attr.tensor(); +} + +const AttrValue::ListValue& GetListAttr(const NodeDef& node, + const string& attr_name) { + CHECK(HasAttr(node, attr_name)); + const auto& attr = node.attr().at(attr_name); + CHECK_EQ(attr.value_case(), AttrValue::kList); + return attr.list(); +} + +ArrayDataType ConvertDataType(tensorflow::DataType dtype) { + if (dtype == DT_UINT8) + return ArrayDataType::kUint8; + else if (dtype == DT_FLOAT) + return ArrayDataType::kFloat; + else if (dtype == DT_BOOL) + return ArrayDataType::kBool; + else if (dtype == DT_INT32) + return ArrayDataType::kInt32; + else if (dtype == DT_INT64) + return ArrayDataType::kInt64; + else + LOG(INFO) << "Unsupported data type in placehoder op: " << dtype; + return ArrayDataType::kNone; +} + +void ImportShape(const TFLITE_PROTO_NS::RepeatedPtrField< + tensorflow::TensorShapeProto_Dim>& input_dims, + Shape* shape) { + std::vector input_dims_only_sizes; + for (auto& d : input_dims) { + if (d.size() == 0) { + // Some TensorFlow shapes contain a 0 dim, effectively making + // them of flat size 0 even though they have other nonzero dims. + // This breaks our invariant, that array dims can't be 0. + // For now, tweaking this to record a 0-D shape instead. + input_dims_only_sizes.clear(); + break; + } + input_dims_only_sizes.push_back(d.size()); + } + *shape->mutable_dims() = input_dims_only_sizes; +} + +void ImportFloatArray(const TensorProto& input_tensor, Array* output_array) { + CHECK_EQ(input_tensor.dtype(), DT_FLOAT); + const auto& input_shape = input_tensor.tensor_shape(); + CHECK_LE(input_shape.dim_size(), 4); + ImportShape(input_shape.dim(), output_array->mutable_shape()); + int input_flat_size = 1; + for (int k = 0; k < input_shape.dim_size(); k++) { + input_flat_size *= input_shape.dim(k).size(); + } + auto& output_float_data = + output_array->GetMutableBuffer().data; + output_float_data.resize(input_flat_size); + if (input_tensor.float_val_size()) { + for (int i = 0; i < input_tensor.float_val_size(); i++) { + output_float_data[i] = input_tensor.float_val(i); + } + } else if (input_tensor.tensor_content().size() == + input_flat_size * sizeof(float)) { + toco::port::CopyToBuffer(input_tensor.tensor_content(), + reinterpret_cast(output_float_data.data())); + } else { + LOG(FATAL) << "Neither input_content nor float_val have the right " + "dimensions for this float tensor."; + } +} + +void ImportInt32Array(const TensorProto& input_tensor, Array* output_array) { + CHECK_EQ(input_tensor.dtype(), DT_INT32); + const auto& input_shape = input_tensor.tensor_shape(); + CHECK_LE(input_shape.dim_size(), 4); + ImportShape(input_shape.dim(), output_array->mutable_shape()); + int input_flat_size = 1; + for (int k = 0; k < input_shape.dim_size(); k++) { + input_flat_size *= input_shape.dim(k).size(); + } + auto& output_int_data = + output_array->GetMutableBuffer().data; + output_int_data.resize(input_flat_size); + if (input_tensor.int_val_size()) { + for (int i = 0; i < input_tensor.int_val_size(); i++) { + output_int_data[i] = input_tensor.int_val(i); + } + } else if (input_tensor.tensor_content().size() == + input_flat_size * sizeof(int32)) { + 
toco::port::CopyToBuffer(input_tensor.tensor_content(), + reinterpret_cast(output_int_data.data())); + } else { + LOG(FATAL) << "Neither input_content nor int_val have the right " + "dimensions for this int32 tensor."; + } +} + +void ImportInt64Array(const TensorProto& input_tensor, Array* output_array) { + CHECK_EQ(input_tensor.dtype(), DT_INT64); + const auto& input_shape = input_tensor.tensor_shape(); + CHECK_LE(input_shape.dim_size(), 4); + ImportShape(input_shape.dim(), output_array->mutable_shape()); + int input_flat_size = 1; + for (int k = 0; k < input_shape.dim_size(); k++) { + input_flat_size *= input_shape.dim(k).size(); + } + auto& output_int_data = + output_array->GetMutableBuffer().data; + output_int_data.resize(input_flat_size); + if (input_tensor.int64_val_size()) { + for (int i = 0; i < input_tensor.int64_val_size(); i++) { + output_int_data[i] = input_tensor.int64_val(i); + } + } else if (input_tensor.tensor_content().size() == + input_flat_size * sizeof(int64)) { + toco::port::CopyToBuffer(input_tensor.tensor_content(), + reinterpret_cast(output_int_data.data())); + } else { + LOG(FATAL) << "Neither input_content nor int64_val have the right " + "dimensions for this int64 tensor."; + } +} + +// Count the number of inputs of a given node. If `drop_control_dependency` is +// true, count the number of non-control-dependency inputs. +size_t GetInputsCount(const NodeDef& node, bool drop_control_dependency) { + if (drop_control_dependency) { + for (size_t i = 0; i < node.input_size(); ++i) { + if (node.input(i)[0] == '^') { + LOG(INFO) << "Reached first control dependency input: " + << node.input(i); + return i; + } + } + return node.input_size(); + } else { + return node.input_size(); + } +} + +void ConvertConstOperator(const NodeDef& node, Model* model) { + CHECK_EQ(node.op(), "Const"); + const auto& tensor = GetTensorAttr(node, "value"); + const auto dtype = GetDataTypeAttr(node, "dtype"); + + auto& array = model->GetOrCreateArray(node.name()); + array.data_type = dtype == DT_FLOAT + ? ArrayDataType::kFloat + : dtype == DT_INT32 + ? ArrayDataType::kInt32 + : dtype == DT_INT64 ? ArrayDataType::kInt64 + : ArrayDataType::kNone; + if (dtype == DT_FLOAT) { + ImportFloatArray(tensor, &array); + } else if (dtype == DT_INT32) { + ImportInt32Array(tensor, &array); + } else if (dtype == DT_INT64) { + ImportInt64Array(tensor, &array); + } else { + // do nothing, silently ignore the Const data. For example, there are consts + // of string type. We just make a dummy buffer to indicate that this array + // does not rely on external input. + array.GetMutableBuffer(); + } +} + +void ConvertConvOperator(const NodeDef& node, Model* model) { + CHECK_EQ(node.op(), "Conv2D"); + CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 2); + + // We only support NHWC, which is the default data_format. + // So if data_format is not defined, we're all good. + if (node.attr().count("data_format")) { + CHECK_EQ(GetStringAttr(node, "data_format"), "NHWC"); + } + CHECK_EQ(GetDataTypeAttr(node, "T"), DT_FLOAT); + + const auto& input_name = node.input(0); + const auto& weights_name = node.input(1); + const auto& reordered_weights_name = weights_name + "_reordered"; + // Check if a ReorderAxesOperator was already created for these weights + // (that happens when multiple layers share the same weights). 
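// For example (hypothetical array names): if two Conv2D nodes share one
// weights array "w", converting the first one emits ReorderAxes "w" ->
// "w_reordered" (HWIO -> OHWI), and converting the second one then finds an
// operator already producing "w_reordered" and reuses it rather than
// emitting a duplicate ReorderAxes.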
+ const Operator* existing_reorder = + GetOpWithOutput(*model, reordered_weights_name); + if (existing_reorder) { + // Check that it is safe to rely on the _reordered naming of the output + // array! + CHECK(existing_reorder->type == OperatorType::kReorderAxes); + } else { + // Create a new ReorderAxesOperator + auto* reorder = new ReorderAxesOperator; + reorder->inputs = {weights_name}; + reorder->outputs = {reordered_weights_name}; + reorder->input_axes_order = AxesOrder::kHWIO; + reorder->output_axes_order = AxesOrder::kOHWI; + model->operators.emplace_back(reorder); + } + auto* conv = new ConvOperator; + conv->inputs = {input_name, reordered_weights_name}; + conv->outputs = {node.name()}; + const auto& strides = GetListAttr(node, "strides"); + CHECK_EQ(strides.i_size(), 4); + CHECK_EQ(strides.i(0), 1); + CHECK_EQ(strides.i(3), 1); + conv->stride_height = strides.i(1); + conv->stride_width = strides.i(2); + const auto& padding = GetStringAttr(node, "padding"); + if (padding == "SAME") { + conv->padding.type = PaddingType::kSame; + } else if (padding == "VALID") { + conv->padding.type = PaddingType::kValid; + } else { + LOG(FATAL) << "Bad padding (only SAME and VALID are supported)"; + } + model->operators.emplace_back(conv); +} + +void ConvertDepthwiseConvOperator(const NodeDef& node, Model* model) { + CHECK_EQ(node.op(), "DepthwiseConv2dNative"); + CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 2); + + // We only support NHWC, which is the default data_format. + // So if data_format is not defined, we're all good. + if (node.attr().count("data_format")) { + CHECK_EQ(GetStringAttr(node, "data_format"), "NHWC"); + } + CHECK_EQ(GetDataTypeAttr(node, "T"), DT_FLOAT); + + const auto& input_name = node.input(0); + const auto& weights_name = node.input(1); + const auto& reordered_weights_name = weights_name + "_reordered"; + // Check if a ReorderAxesOperator was already created for these weights + // (that happens when multiple layers share the same weights). + const Operator* existing_reorder = + GetOpWithOutput(*model, reordered_weights_name); + if (existing_reorder) { + // Check that it is safe to rely on the _reordered naming of the output + // array! 
+ CHECK(existing_reorder->type == OperatorType::kReorderAxes); + } else { + // Create a new ReorderAxesOperator + auto* reorder = new ReorderAxesOperator; + reorder->inputs = {weights_name}; + reorder->outputs = {reordered_weights_name}; + reorder->input_axes_order = AxesOrder::kHWIM; + reorder->output_axes_order = AxesOrder::k1HWO; + model->operators.emplace_back(reorder); + } + auto* conv = new DepthwiseConvOperator; + conv->inputs = {input_name, reordered_weights_name}; + conv->outputs = {node.name()}; + const auto& strides = GetListAttr(node, "strides"); + CHECK_EQ(strides.i_size(), 4); + CHECK_EQ(strides.i(0), 1); + CHECK_EQ(strides.i(3), 1); + conv->stride_height = strides.i(1); + conv->stride_width = strides.i(2); + const auto& padding = GetStringAttr(node, "padding"); + if (padding == "SAME") { + conv->padding.type = PaddingType::kSame; + } else if (padding == "VALID") { + conv->padding.type = PaddingType::kValid; + } else { + LOG(FATAL) << "Bad padding (only SAME and VALID are supported)"; + } + model->operators.emplace_back(conv); +} + +void ConvertDepthToSpaceOperator(const NodeDef& node, Model* model) { + CHECK_EQ(node.op(), "DepthToSpace"); + CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 1); + CHECK_EQ(GetDataTypeAttr(node, "T"), DT_FLOAT); + auto* op = new DepthToSpaceOperator; + op->inputs.push_back(node.input(0)); + op->outputs.push_back(node.name()); + op->block_size = GetIntAttr(node, "block_size"); + QCHECK_GE(op->block_size, 2); + model->operators.emplace_back(op); +} + +void ConvertSpaceToDepthOperator(const NodeDef& node, Model* model) { + CHECK_EQ(node.op(), "SpaceToDepth"); + CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 1); + CHECK_EQ(GetDataTypeAttr(node, "T"), DT_FLOAT); + auto* op = new SpaceToDepthOperator; + op->inputs.push_back(node.input(0)); + op->outputs.push_back(node.name()); + op->block_size = GetIntAttr(node, "block_size"); + QCHECK_GE(op->block_size, 2); + model->operators.emplace_back(op); +} + +void ConvertBiasAddOperator(const NodeDef& node, Model* model) { + CHECK_EQ(node.op(), "BiasAdd"); + CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 2); + const auto& input_name = node.input(0); + const auto& bias_name = node.input(1); + CHECK_EQ(GetDataTypeAttr(node, "T"), DT_FLOAT); + auto* biasadd = new AddOperator; + biasadd->inputs.push_back(input_name); + biasadd->inputs.push_back(bias_name); + biasadd->outputs.push_back(node.name()); + model->operators.emplace_back(biasadd); +} + +void ConvertReluOperator(const NodeDef& node, Model* model) { + CHECK_EQ(node.op(), "Relu"); + CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 1); + const auto& input_name = node.input(0); + auto* relu = new ReluOperator; + relu->inputs.push_back(input_name); + relu->outputs.push_back(node.name()); + model->operators.emplace_back(relu); +} + +void ConvertRelu6Operator(const NodeDef& node, Model* model) { + CHECK_EQ(node.op(), "Relu6"); + CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 1); + const auto& input_name = node.input(0); + auto* op = new Relu6Operator; + op->inputs.push_back(input_name); + op->outputs.push_back(node.name()); + model->operators.emplace_back(op); +} + +void ConvertLogisticOperator(const NodeDef& node, Model* model) { + CHECK_EQ(node.op(), "Sigmoid"); + CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 1); + const auto& input_name = node.input(0); + auto* op = new LogisticOperator; + 
op->inputs.push_back(input_name); + op->outputs.push_back(node.name()); + model->operators.emplace_back(op); +} + +void ConvertTanhOperator(const NodeDef& node, Model* model) { + CHECK_EQ(node.op(), "Tanh"); + CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 1); + const auto& input_name = node.input(0); + auto* op = new TanhOperator; + op->inputs.push_back(input_name); + op->outputs.push_back(node.name()); + model->operators.emplace_back(op); +} + +void ConvertDivOperator(const NodeDef& node, Model* model) { + CHECK(node.op() == "Div" || node.op() == "RealDiv"); + CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 2); + auto* op = new DivOperator; + op->inputs.push_back(node.input(0)); + op->inputs.push_back(node.input(1)); + op->outputs.push_back(node.name()); + model->operators.emplace_back(op); +} + +void ConvertIdentityOperator(const NodeDef& node, Model* model) { + CHECK(node.op() == "Identity" || node.op() == "CheckNumerics" || + node.op() == "PlaceholderWithDefault"); + auto* op = new TensorFlowIdentityOperator; + // Amazingly, some TensorFlow graphs (at least rajeev_lstm.pb) have + // identity nodes with multiple inputs, but the other inputs seem + // to be gratuitous (in the case of rajeev_lstm.pb, these are + // enumerating the LSTM state arrays). We will just ignore extra + // inputs beyond the first input. + CHECK_GE(node.input_size(), 1); + const auto& input_name = node.input(0); + op->inputs.push_back(input_name); + op->outputs.push_back(node.name()); + model->operators.emplace_back(op); +} + +void ConvertFakeQuantWithMinMaxArgs(const NodeDef& node, Model* model) { + CHECK_EQ(node.op(), "FakeQuantWithMinMaxArgs"); + CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 1); + auto* op = new FakeQuantOperator; + op->inputs.push_back(node.input(0)); + op->minmax.reset(new MinMax); + auto& minmax = *op->minmax; + minmax.min = GetFloatAttr(node, "min"); + minmax.max = GetFloatAttr(node, "max"); + op->outputs.push_back(node.name()); + model->operators.emplace_back(op); +} + +void ConvertFakeQuantWithMinMaxVars(const NodeDef& node, Model* model) { + CHECK_EQ(node.op(), "FakeQuantWithMinMaxVars"); + const int num_inputs = + GetInputsCount(node, model->flags.drop_control_dependency()); + CHECK(num_inputs == 3 || num_inputs == 4); + auto* op = new FakeQuantOperator; + for (int i = 0; i < 3; i++) { + op->inputs.push_back(node.input(i)); + } + op->outputs.push_back(node.name()); + model->operators.emplace_back(op); +} + +void ConvertRsqrtOperator(const NodeDef& node, Model* model) { + CHECK_EQ(node.op(), "Rsqrt"); + CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 1); + auto* op = new TensorFlowRsqrtOperator; + op->inputs.push_back(node.input(0)); + op->outputs.push_back(node.name()); + model->operators.emplace_back(op); +} + +void ConvertSqrtOperator(const NodeDef& node, Model* model) { + CHECK_EQ(node.op(), "Sqrt"); + CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 1); + auto* op = new TensorFlowSqrtOperator; + op->inputs.push_back(node.input(0)); + op->outputs.push_back(node.name()); + model->operators.emplace_back(op); +} + +void ConvertSqueezeOperator(const NodeDef& node, Model* model) { + CHECK_EQ(node.op(), "Squeeze"); + CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 1); + auto* op = new SqueezeOperator; + op->inputs.push_back(node.input(0)); + op->outputs.push_back(node.name()); + + const auto& squeeze_dims = GetListAttr(node, "squeeze_dims"); + for (int 
i = 0; i < squeeze_dims.i_size(); ++i) { + op->squeeze_dims.push_back(squeeze_dims.i(i)); + } + + model->operators.emplace_back(op); +} + +void ConvertSquareOperator(const NodeDef& node, Model* model) { + CHECK_EQ(node.op(), "Square"); + CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 1); + auto* op = new TensorFlowSquareOperator; + op->inputs.push_back(node.input(0)); + op->outputs.push_back(node.name()); + model->operators.emplace_back(op); +} + +void ConvertAddOperator(const NodeDef& node, Model* model) { + CHECK_EQ(node.op(), "Add"); + CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 2); + auto* op = new AddOperator; + op->inputs.push_back(node.input(0)); + op->inputs.push_back(node.input(1)); + op->outputs.push_back(node.name()); + model->operators.emplace_back(op); +} + +void ConvertMulOperator(const NodeDef& node, Model* model) { + CHECK_EQ(node.op(), "Mul"); + CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 2); + auto* op = new MulOperator; + op->inputs.push_back(node.input(0)); + op->inputs.push_back(node.input(1)); + op->outputs.push_back(node.name()); + model->operators.emplace_back(op); +} + +void ConvertSubOperator(const NodeDef& node, Model* model) { + CHECK_EQ(node.op(), "Sub"); + CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 2); + auto* op = new SubOperator; + op->inputs.push_back(node.input(0)); + op->inputs.push_back(node.input(1)); + op->outputs.push_back(node.name()); + model->operators.emplace_back(op); +} + +void ConvertSumOperator(const NodeDef& node, Model* model) { + CHECK_EQ(node.op(), "Sum"); + CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 2); + auto* op = new TensorFlowSumOperator; + op->inputs.push_back(node.input(0)); + op->inputs.push_back(node.input(1)); + op->outputs.push_back(node.name()); + model->operators.emplace_back(op); +} + +void ConvertTileOperator(const NodeDef& node, Model* model) { + CHECK_EQ(node.op(), "Tile"); + CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 2); + auto* op = new TensorFlowTileOperator; + op->inputs.push_back(node.input(0)); + op->inputs.push_back(node.input(1)); + op->outputs.push_back(node.name()); + model->operators.emplace_back(op); +} + +void ConvertSliceOperator(const NodeDef& node, Model* model) { + CHECK_EQ(node.op(), "Slice"); + CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 3); + auto* op = new SliceOperator; + for (int i = 0; i < 3; ++i) { + op->inputs.push_back(node.input(i)); + } + op->outputs.push_back(node.name()); + model->operators.emplace_back(op); +} + +void ConvertPadOperator(const NodeDef& node, Model* model) { + CHECK_EQ(node.op(), "Pad"); + CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 2); + auto* op = new PadOperator; + op->inputs.push_back(node.input(0)); + op->inputs.push_back(node.input(1)); + op->outputs.push_back(node.name()); + model->operators.emplace_back(op); +} + +void ConvertShapeOperator(const NodeDef& node, Model* model) { + CHECK_EQ(node.op(), "Shape"); + CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 1); + auto* op = new TensorFlowShapeOperator; + op->inputs.push_back(node.input(0)); + op->outputs.push_back(node.name()); + model->operators.emplace_back(op); +} + +void ConvertSplitOperator(const NodeDef& node, Model* model) { + CHECK_EQ(node.op(), "Split"); + CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 2); + auto* op = new TensorFlowSplitOperator; + 
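// TensorFlow names the outputs of a multi-output node "name", "name:1",
// "name:2", and so on ("name" being shorthand for "name:0"); the loop below
// follows that convention so that references such as "split:1" in other
// nodes still resolve to the right array.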
op->inputs.push_back(node.input(0)); + op->inputs.push_back(node.input(1)); + const int num_split = GetIntAttr(node, "num_split"); + op->outputs.push_back(node.name()); + for (int i = 1; i < num_split; i++) { + op->outputs.push_back(absl::StrCat(node.name(), ":", i)); + } + op->num_split = num_split; + model->operators.emplace_back(op); +} + +void ConvertMergeOperator(const NodeDef& node, Model* model) { + CHECK_EQ(node.op(), "Merge"); + CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 2); + auto* op = new TensorFlowMergeOperator; + op->inputs.push_back(node.input(0)); + op->inputs.push_back(node.input(1)); + op->outputs.push_back(node.name()); + model->operators.emplace_back(op); +} + +void ConvertSwitchOperator(const NodeDef& node, Model* model) { + CHECK_EQ(node.op(), "Switch"); + CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 2); + auto* op = new TensorFlowSwitchOperator; + op->inputs.push_back(node.input(0)); + op->inputs.push_back(node.input(1)); + op->outputs.push_back(node.name()); + // Switch operators have two outputs: "name" and "name:1". + op->outputs.push_back(node.name() + ":1"); + model->operators.emplace_back(op); +} +void ConvertSoftmaxOperator(const NodeDef& node, Model* model) { + CHECK_EQ(node.op(), "Softmax"); + CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 1); + const auto& input_name = node.input(0); + auto* softmax = new SoftmaxOperator; + softmax->inputs.push_back(input_name); + softmax->outputs.push_back(node.name()); + // TensorFlow's Softmax doesn't seem to admit a 'beta' parameter. + CHECK(!node.attr().count("beta")); // Stab in the dark, just in case. + softmax->beta = 1.f; + model->operators.emplace_back(softmax); +} + +void ConvertLRNOperator(const NodeDef& node, Model* model) { + CHECK_EQ(node.op(), "LRN"); + CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 1); + const auto& input_name = node.input(0); + auto* lrn = new LocalResponseNormalizationOperator; + lrn->inputs.push_back(input_name); + lrn->outputs.push_back(node.name()); + lrn->range = GetIntAttr(node, "depth_radius"); + lrn->bias = GetFloatAttr(node, "bias"); + lrn->alpha = GetFloatAttr(node, "alpha"); + lrn->beta = GetFloatAttr(node, "beta"); + model->operators.emplace_back(lrn); +} + +void ConvertMaxPoolOperator(const NodeDef& node, Model* model) { + CHECK_EQ(node.op(), "MaxPool"); + CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 1); + const auto& input_name = node.input(0); + if (HasAttr(node, "T")) { + CHECK_EQ(GetDataTypeAttr(node, "T"), DT_FLOAT); + } else { + LOG(WARNING) << "Found MaxPool operator missing 'T' attribute"; + } + auto* maxpool = new MaxPoolOperator; + maxpool->inputs.push_back(input_name); + maxpool->outputs.push_back(node.name()); + const auto& strides = GetListAttr(node, "strides"); + CHECK_EQ(strides.i_size(), 4); + CHECK_EQ(strides.i(0), 1); + CHECK_EQ(strides.i(3), 1); + maxpool->stride_height = strides.i(1); + maxpool->stride_width = strides.i(2); + const auto& ksize = GetListAttr(node, "ksize"); + CHECK_EQ(ksize.i_size(), 4); + CHECK_EQ(ksize.i(0), 1); + CHECK_EQ(ksize.i(3), 1); + maxpool->kheight = ksize.i(1); + maxpool->kwidth = ksize.i(2); + const auto& padding = GetStringAttr(node, "padding"); + if (padding == "SAME") { + maxpool->padding.type = PaddingType::kSame; + } else if (padding == "VALID") { + maxpool->padding.type = PaddingType::kValid; + } else { + LOG(FATAL) << "Bad padding (only SAME and VALID are supported)"; + } + 
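// Note: "strides" and "ksize" above use TensorFlow's NHWC layout
// {1, height, width, 1}, which is why elements 0 and 3 are checked to be 1
// and elements 1 and 2 provide the pooling stride and kernel size.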
model->operators.emplace_back(maxpool); +} + +void ConvertAvgPoolOperator(const NodeDef& node, Model* model) { + CHECK_EQ(node.op(), "AvgPool"); + CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 1); + const auto& input_name = node.input(0); + CHECK_EQ(GetDataTypeAttr(node, "T"), DT_FLOAT); + auto* avgpool = new AveragePoolOperator; + avgpool->inputs.push_back(input_name); + avgpool->outputs.push_back(node.name()); + const auto& strides = GetListAttr(node, "strides"); + CHECK_EQ(strides.i_size(), 4); + CHECK_EQ(strides.i(0), 1); + CHECK_EQ(strides.i(3), 1); + avgpool->stride_height = strides.i(1); + avgpool->stride_width = strides.i(2); + const auto& ksize = GetListAttr(node, "ksize"); + CHECK_EQ(ksize.i_size(), 4); + CHECK_EQ(ksize.i(0), 1); + CHECK_EQ(ksize.i(3), 1); + avgpool->kheight = ksize.i(1); + avgpool->kwidth = ksize.i(2); + const auto& padding = GetStringAttr(node, "padding"); + if (padding == "SAME") { + avgpool->padding.type = PaddingType::kSame; + } else if (padding == "VALID") { + avgpool->padding.type = PaddingType::kValid; + } else { + LOG(FATAL) << "Bad padding (only SAME and VALID are supported)"; + } + model->operators.emplace_back(avgpool); +} + +void ConvertReshapeOperator(const NodeDef& node, Model* model) { + CHECK_EQ(node.op(), "Reshape"); + CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 2); + auto* op = new TensorFlowReshapeOperator; + op->inputs.push_back(node.input(0)); + op->inputs.push_back(node.input(1)); + op->outputs.push_back(node.name()); + model->operators.emplace_back(op); +} + +void ConvertMatMulOperator(const NodeDef& node, Model* model) { + CHECK_EQ(node.op(), "MatMul"); + CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 2); + // Transpose flags should be easy to support, but we don't have a + // GraphDef with them to test on at the moment. + CHECK_EQ(GetBoolAttr(node, "transpose_a"), false); + CHECK_EQ(GetBoolAttr(node, "transpose_b"), false); + const auto& input_name = node.input(0); + const auto& weights_name = node.input(1); + const auto& reordered_weights_name = weights_name + "_reordered"; + // Check if a ReorderAxesOperator was already created for these weights + // (that happens when multiple layers share the same weights). + const Operator* existing_reorder = + GetOpWithOutput(*model, reordered_weights_name); + if (existing_reorder) { + // Check that it is safe to rely on the _reordered naming of the output + // array! 
+ CHECK(existing_reorder->type == OperatorType::kReorderAxes); + } else { + // Create a new ReorderAxesOperator + auto* reorder = new ReorderAxesOperator; + reorder->inputs = {weights_name}; + reorder->outputs = {reordered_weights_name}; + reorder->input_axes_order = AxesOrder::kRC; + reorder->output_axes_order = AxesOrder::kCR; + model->operators.emplace_back(reorder); + } + auto* matmul = new TensorFlowMatMulOperator; + matmul->inputs = {input_name, reordered_weights_name}; + matmul->outputs = {node.name()}; + model->operators.emplace_back(matmul); +} + +void ConvertConcatOperator(const NodeDef& node, Model* model) { + Operator* op = nullptr; + if (node.op() == "Concat") { + op = new TensorFlowConcatOperator; + } else if (node.op() == "ConcatV2") { + op = new TensorFlowConcatV2Operator; + } else { + LOG(FATAL) << "Expected Concat or ConcatV2"; + } + const int num_inputs = + GetInputsCount(node, model->flags.drop_control_dependency()); + CHECK_GE(num_inputs, 2); + CHECK_EQ(num_inputs, 1 + GetIntAttr(node, "N")); + for (int i = 0; i < num_inputs; ++i) { + op->inputs.push_back(node.input(i)); + } + op->outputs.push_back(node.name()); + model->operators.emplace_back(op); +} + +void ConvertAllOperator(const NodeDef& node, Model* model) { + CHECK_EQ(node.op(), "All"); + auto* op = new TensorFlowAllOperator; + const int num_inputs = + GetInputsCount(node, model->flags.drop_control_dependency()); + for (int i = 0; i < num_inputs; ++i) { + op->inputs.push_back(node.input(i)); + } + op->outputs.push_back(node.name()); + model->operators.emplace_back(op); +} + +void ConvertAssertOperator(const NodeDef& node, Model* model) { + CHECK_EQ(node.op(), "Assert"); + auto* op = new TensorFlowAssertOperator; + const int num_inputs = + GetInputsCount(node, model->flags.drop_control_dependency()); + for (int i = 0; i < num_inputs; ++i) { + op->inputs.push_back(node.input(i)); + } + op->outputs.push_back(node.name()); + model->operators.emplace_back(op); +} + +void ConvertLessOperator(const NodeDef& node, Model* model) { + CHECK_EQ(node.op(), "Less"); + auto* op = new TensorFlowLessOperator; + const int num_inputs = + GetInputsCount(node, model->flags.drop_control_dependency()); + for (int i = 0; i < num_inputs; ++i) { + op->inputs.push_back(node.input(i)); + } + op->outputs.push_back(node.name()); + model->operators.emplace_back(op); +} + +void ConvertLessEqualOperator(const NodeDef& node, Model* model) { + CHECK_EQ(node.op(), "LessEqual"); + auto* op = new TensorFlowLessEqualOperator; + const int num_inputs = + GetInputsCount(node, model->flags.drop_control_dependency()); + for (int i = 0; i < num_inputs; ++i) { + op->inputs.push_back(node.input(i)); + } + op->outputs.push_back(node.name()); + model->operators.emplace_back(op); +} + +void ConvertGreaterOperator(const NodeDef& node, Model* model) { + CHECK_EQ(node.op(), "Greater"); + auto* op = new TensorFlowGreaterOperator; + const int num_inputs = + GetInputsCount(node, model->flags.drop_control_dependency()); + for (int i = 0; i < num_inputs; ++i) { + op->inputs.push_back(node.input(i)); + } + op->outputs.push_back(node.name()); + model->operators.emplace_back(op); +} + +void ConvertGreaterEqualOperator(const NodeDef& node, Model* model) { + CHECK_EQ(node.op(), "GreaterEqual"); + auto* op = new TensorFlowGreaterEqualOperator; + const int num_inputs = + GetInputsCount(node, model->flags.drop_control_dependency()); + for (int i = 0; i < num_inputs; ++i) { + op->inputs.push_back(node.input(i)); + } + op->outputs.push_back(node.name()); + 
model->operators.emplace_back(op); +} + +void ConvertMaxOperator(const NodeDef& node, Model* model) { + CHECK_EQ(node.op(), "Max"); + CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 2); + auto* op = new TensorFlowMaxOperator; + op->inputs.push_back(node.input(0)); + op->inputs.push_back(node.input(1)); + op->outputs.push_back(node.name()); + model->operators.emplace_back(op); +} + +void ConvertMinOperator(const NodeDef& node, Model* model) { + CHECK_EQ(node.op(), "Min"); + CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 2); + auto* op = new TensorFlowMinOperator; + op->inputs.push_back(node.input(0)); + op->inputs.push_back(node.input(1)); + op->outputs.push_back(node.name()); + model->operators.emplace_back(op); +} + +void ConvertMaximumOperator(const NodeDef& node, Model* model) { + CHECK_EQ(node.op(), "Maximum"); + CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 2); + auto* op = new TensorFlowMaximumOperator; + op->inputs.push_back(node.input(0)); + op->inputs.push_back(node.input(1)); + op->outputs.push_back(node.name()); + model->operators.emplace_back(op); +} + +void ConvertMinimumOperator(const NodeDef& node, Model* model) { + CHECK_EQ(node.op(), "Minimum"); + CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 2); + auto* op = new TensorFlowMinimumOperator; + op->inputs.push_back(node.input(0)); + op->inputs.push_back(node.input(1)); + op->outputs.push_back(node.name()); + model->operators.emplace_back(op); +} + +void ConvertUnsupportedOperator(const NodeDef& node, Model* model) { + LOG(INFO) << "Converting unsupported operation: " << node.op(); + auto* op = new TensorFlowUnsupportedOperator; + const int num_inputs = + GetInputsCount(node, model->flags.drop_control_dependency()); + for (int i = 0; i < num_inputs; ++i) { + op->inputs.push_back(node.input(i)); + } + op->outputs.push_back(node.name()); + op->tensorflow_op = node.op(); + node.SerializeToString(&op->tensorflow_node_def); + model->operators.emplace_back(op); + if (HasAttr(node, "_output_quantized")) { + op->quantized = GetBoolAttr(node, "_output_quantized"); + } + if (HasAttr(node, "_output_types")) { + const auto& output_types = GetListAttr(node, "_output_types"); + for (int i = 0; i < output_types.type_size(); ++i) { + op->output_data_types.push_back(ConvertDataType(output_types.type(i))); + } + } +} + +void ConvertStridedSliceOperator(const NodeDef& node, Model* model) { + CHECK_EQ(node.op(), "StridedSlice"); + CHECK_EQ(node.input_size(), 4); + + // Only a subset of the full TF op functionality is supported now. + if ( // No 64-bit indices. + GetDataTypeAttr(node, "Index") != DT_INT32 || + // No dimensionality changes. + GetIntAttr(node, "new_axis_mask") != 0 || + GetIntAttr(node, "shrink_axis_mask") != 0 || + // No sparse indices. + GetIntAttr(node, "ellipsis_mask") != 0 || + // Only 4D tensors are supported. 
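      // (begin_mask and end_mask are per-dimension bit fields, so any value
      //  above 15 = 0b1111 would refer to a dimension beyond the fourth.)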
+ GetIntAttr(node, "begin_mask") > 15 || + GetIntAttr(node, "end_mask") > 15) { + ConvertUnsupportedOperator(node, model); + return; + } + + auto* op = new StridedSliceOperator; + for (const auto& input : node.input()) { + op->inputs.push_back(input); + } + op->outputs.push_back(node.name()); + + op->begin_mask = GetIntAttr(node, "begin_mask"); + op->ellipsis_mask = GetIntAttr(node, "ellipsis_mask"); + op->end_mask = GetIntAttr(node, "end_mask"); + op->new_axis_mask = GetIntAttr(node, "new_axis_mask"); + op->shrink_axis_mask = GetIntAttr(node, "shrink_axis_mask"); + model->operators.emplace_back(op); +} + +void ConvertPlaceholderOperator(const NodeDef& node, Model* model) { + CHECK(node.op() == "Placeholder" || node.op() == "LegacyFedInput"); + if (node.op() == "Placeholder") { + CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 0); + } + auto& array = model->GetOrCreateArray(node.name()); + if (node.attr().count("dtype")) { + array.data_type = ConvertDataType(GetDataTypeAttr(node, "dtype")); + } + if (node.attr().count("shape")) { + const auto& shape = GetShapeAttr(node, "shape"); + auto num_dims = shape.dim_size(); + bool has_wildcard = false; + for (std::size_t i = 0; i < num_dims; i++) { + if (shape.dim(i).size() == -1) { + has_wildcard = true; + } + } + // TODO(b/62716978): This logic needs to be revisted. During dims + // refactoring it is an interim fix. + if (num_dims > 0 && !has_wildcard) { + auto& dst_array_dims = *array.mutable_shape()->mutable_dims(); + dst_array_dims.resize(num_dims); + for (std::size_t i = 0; i < num_dims; i++) { + dst_array_dims[i] = shape.dim(i).size(); + } + } + } +} + +void ConvertNoOpOperator(const NodeDef& node, Model* model) {} + +ArrayDataType GetArrayDataType(tensorflow::DataType tf_data_type) { + if (tf_data_type == DT_UINT8) { + return ArrayDataType::kUint8; + } else if (tf_data_type == DT_INT32) { + return ArrayDataType::kInt32; + } else if (tf_data_type == DT_FLOAT) { + return ArrayDataType::kFloat; + } else { + return ArrayDataType::kNone; + } +} + +void ConvertCastOperator(const NodeDef& node, Model* model) { + CHECK_EQ(node.op(), "Cast"); + CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 1); + const auto tf_src_dtype = GetDataTypeAttr(node, "SrcT"); + const auto tf_dst_dtype = GetDataTypeAttr(node, "DstT"); + CHECK(tf_src_dtype == DT_UINT8 || tf_src_dtype == DT_INT32 || + tf_src_dtype == DT_FLOAT); + CHECK(tf_dst_dtype == DT_UINT8 || tf_dst_dtype == DT_INT32 || + tf_dst_dtype == DT_FLOAT); + CHECK_NE(tf_src_dtype, tf_dst_dtype) + << "Same input and output data type. 
No need to cast."; + auto* op = new CastOperator; + op->src_data_type = GetArrayDataType(tf_src_dtype); + op->dst_data_type = GetArrayDataType(tf_dst_dtype); + op->inputs.push_back(node.input(0)); + op->outputs.push_back(node.name()); + model->operators.emplace_back(op); +} + +void ConvertFloorOperator(const NodeDef& node, Model* model) { + CHECK_EQ(node.op(), "Floor"); + CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 1); + const auto data_type = GetDataTypeAttr(node, "T"); + CHECK(data_type == DT_FLOAT); + auto* op = new FloorOperator; + op->inputs.push_back(node.input(0)); + op->outputs.push_back(node.name()); + model->operators.emplace_back(op); +} + +void ConvertGatherOperator(const NodeDef& node, Model* model) { + CHECK_EQ(node.op(), "Gather"); + CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 2); + const auto indices_data_type = GetDataTypeAttr(node, "Tindices"); + CHECK(indices_data_type == DT_INT32); + auto* op = new GatherOperator; + op->inputs.push_back(node.input(0)); + op->inputs.push_back(node.input(1)); + op->outputs.push_back(node.name()); + model->operators.emplace_back(op); +} + +void ConvertResizeBilinearOperator(const NodeDef& node, Model* model) { + CHECK_EQ(node.op(), "ResizeBilinear"); + CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 2); + auto* op = new ResizeBilinearOperator; + op->inputs.push_back(node.input(0)); + op->inputs.push_back(node.input(1)); + op->outputs.push_back(node.name()); + model->operators.emplace_back(op); +} + +void ConvertBatchNormWithGlobalNormalizationOperator(const NodeDef& node, + Model* model) { + CHECK_EQ(node.op(), "BatchNormWithGlobalNormalization"); + CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 5); + + // TODO(ahentz): to really match tensorflow we need to add variance_epsilon + // to the input, before feeding it into TensorFlowRsqrtOperator. + // CHECK_EQ(GetFloatAttr(node, "variance_epsilon"), 0.001f); + + string multiplier = node.name() + "_mul"; + if (GetBoolAttr(node, "scale_after_normalization")) { + // Create graph: + // v -> RSQRT -> + // MUL -> multiplier + // gamma -----> + string rsqrt = node.name() + "_rsqrt"; + + auto* rsqrt_op = new TensorFlowRsqrtOperator; + rsqrt_op->inputs.push_back(node.input(2)); + rsqrt_op->outputs.push_back(rsqrt); + model->operators.emplace_back(rsqrt_op); + + auto* mul_op = new MulOperator; + mul_op->inputs.push_back(rsqrt); + mul_op->inputs.push_back(node.input(4)); + mul_op->outputs.push_back(multiplier); + model->operators.emplace_back(mul_op); + } else { + // Create graph: + // v -> RSQRT -> multiplier + auto* rsqrt_op = new TensorFlowRsqrtOperator; + rsqrt_op->inputs.push_back(node.input(2)); + rsqrt_op->outputs.push_back(multiplier); + model->operators.emplace_back(rsqrt_op); + } + + auto* op = new BatchNormalizationOperator; + op->global_normalization = true; + + op->inputs.push_back(node.input(0)); + op->inputs.push_back(node.input(1)); + op->inputs.push_back(multiplier); + op->inputs.push_back(node.input(3)); + op->outputs.push_back(node.name()); + + model->operators.emplace_back(op); +} + +void ConvertFusedBatchNormOperator(const NodeDef& node, Model* model) { + CHECK_EQ(node.op(), "FusedBatchNorm"); + CHECK_EQ(node.input_size(), 5); + + // Declare shortcuts for the inputs. 
+ const string& gamma_input = node.input(1); + const string& beta_input = node.input(2); + const string& moving_mean_input = node.input(3); + const string& moving_variance_input = node.input(4); + + // Create an array holding the epsilon value (typically, 0.001). + const string epsilon_array_name = node.name() + "_epsilon_array"; + auto& epsilon_array = model->GetOrCreateArray(epsilon_array_name); + epsilon_array.data_type = ArrayDataType::kFloat; + *epsilon_array.mutable_shape()->mutable_dims() = {1}; + epsilon_array.GetMutableBuffer().data.push_back( + GetFloatAttr(node, "epsilon")); + + // Add epsilon to the moving variance. + const string epsilon_add_op_name = node.name() + "_epsilon"; + auto* epsilon_add_op = new AddOperator; + epsilon_add_op->inputs.push_back(moving_variance_input); + epsilon_add_op->inputs.push_back(epsilon_array_name); + epsilon_add_op->outputs.push_back(epsilon_add_op_name); + model->operators.emplace_back(epsilon_add_op); + + // Take the inverse square root of the (variance + epsilon). + const string rsqrt_op_name = node.name() + "_rsqrt"; + auto* rsqrt_op = new TensorFlowRsqrtOperator; + rsqrt_op->inputs.push_back(epsilon_add_op_name); + rsqrt_op->outputs.push_back(rsqrt_op_name); + model->operators.emplace_back(rsqrt_op); + + // Multiply the result by gamma. + const string multiplier = node.name() + "_mul"; + auto* mul_op = new MulOperator; + mul_op->inputs.push_back(rsqrt_op_name); + mul_op->inputs.push_back(gamma_input); + mul_op->outputs.push_back(multiplier); + model->operators.emplace_back(mul_op); + + // Now we have all required inputs for the BatchNormalizationOperator. + auto* op = new BatchNormalizationOperator; + op->global_normalization = true; + + op->inputs.push_back(node.input(0)); + op->inputs.push_back(moving_mean_input); + op->inputs.push_back(multiplier); + op->inputs.push_back(beta_input); + op->outputs.push_back(node.name()); + + model->operators.emplace_back(op); +} + +void ConvertSpaceToBatchNDOperator(const NodeDef& node, Model* model) { + CHECK_EQ(node.op(), "SpaceToBatchND"); + CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 3); + CHECK_EQ(GetDataTypeAttr(node, "Tblock_shape"), DT_INT32); + CHECK_EQ(GetDataTypeAttr(node, "Tpaddings"), DT_INT32); + auto* op = new SpaceToBatchNDOperator; + op->inputs.push_back(node.input(0)); + op->inputs.push_back(node.input(1)); + op->inputs.push_back(node.input(2)); + op->outputs.push_back(node.name()); + model->operators.emplace_back(op); +} + +void ConvertBatchToSpaceNDOperator(const NodeDef& node, Model* model) { + CHECK_EQ(node.op(), "BatchToSpaceND"); + CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 3); + CHECK_EQ(GetDataTypeAttr(node, "Tblock_shape"), DT_INT32); + CHECK_EQ(GetDataTypeAttr(node, "Tcrops"), DT_INT32); + auto* op = new BatchToSpaceNDOperator; + op->inputs.push_back(node.input(0)); + op->inputs.push_back(node.input(1)); + op->inputs.push_back(node.input(2)); + op->outputs.push_back(node.name()); + model->operators.emplace_back(op); +} + +void ConvertMeanOperator(const NodeDef& node, Model* model) { + CHECK_EQ(node.op(), "Mean"); + CHECK_EQ(node.input_size(), 2); + auto* op = new MeanOperator; + op->inputs.push_back(node.input(0)); + op->inputs.push_back(node.input(1)); + op->outputs.push_back(node.name()); + model->operators.emplace_back(op); +} + +void ConvertSvdfOperator(const NodeDef& node, Model* model) { + CHECK_EQ(node.op(), "Svdf"); + bool has_bias = (node.input_size() == 4); + auto* op = new SvdfOperator; + 
op->inputs.push_back(node.input(0)); + op->inputs.push_back(node.input(1)); + op->inputs.push_back(node.input(2)); + if (has_bias) { + op->inputs.push_back(node.input(3)); + } + op->outputs.push_back(node.name() + "_state"); + op->outputs.push_back(node.name()); + if (node.attr().at("ActivationFunction").s() == "Relu") { + op->fused_activation_function = FusedActivationFunctionType::kRelu; + } else { + op->fused_activation_function = FusedActivationFunctionType::kNone; + } + op->rank = node.attr().at("Rank").i(); + model->operators.emplace_back(op); +} + +void StripCaretFromArrayNames(Model* model) { + for (auto& op : model->operators) { + for (auto& input : op->inputs) { + input = string(absl::StripPrefix(input, "^")); + } + for (auto& output : op->outputs) { + output = string(absl::StripPrefix(output, "^")); + } + } + for (auto& array : model->arrays) { + if (absl::StartsWith(array.first, "^")) { + LOG(FATAL) << "What?"; + } + } +} + +void AddExtraOutputsFedIntoOtherOps(Model* model) { + for (const auto& consumer_op : model->operators) { + for (const string& input : consumer_op->inputs) { + const std::vector& split = absl::StrSplit(input, ':'); + if (split.size() != 2) { + continue; + } + int output_index = 0; + if (!absl::SimpleAtoi(split[1], &output_index)) { + continue; + } + auto* producer_op = GetOpWithOutput(*model, split[0]); + if (!producer_op) { + continue; + } + while (producer_op->outputs.size() <= output_index) { + using toco::port::StringF; + producer_op->outputs.push_back( + StringF("%s:%d", split[0], producer_op->outputs.size())); + } + } + } +} + +bool InlineAllFunctions(GraphDef* graphdef) { + if (graphdef->library().function().empty()) { + VLOG(kLogLevelModelUnchanged) << "No functions to inline."; + return false; + } + + // Override "_noinline" attribute on all functions + GraphDef graphdef_copy(*graphdef); + for (auto& function : + (*graphdef_copy.mutable_library()->mutable_function())) { + auto* attributes = function.mutable_attr(); + if (attributes->count(tensorflow::kNoInlineAttr) != 0) { + (*attributes)[tensorflow::kNoInlineAttr].set_b(false); + } + } + + // Construct minimum resources needed to use ExpandInlineFunctions(). + tensorflow::SessionOptions options; + auto* device_count = options.config.mutable_device_count(); + device_count->insert({"CPU", 1}); + std::vector devices; + TF_CHECK_OK(tensorflow::DeviceFactory::AddDevices( + options, "/job:localhost/replica:0/task:0", &devices)); + + tensorflow::FunctionLibraryDefinition fld(tensorflow::OpRegistry::Global(), + graphdef_copy.library()); + tensorflow::DeviceMgr device_mgr(devices); + tensorflow::OptimizerOptions o_opts; + tensorflow::ProcessFunctionLibraryRuntime pflr( + &device_mgr, tensorflow::Env::Default(), TF_GRAPH_DEF_VERSION, &fld, + o_opts, nullptr); + tensorflow::FunctionLibraryRuntime* flr; + flr = pflr.GetFLR("/job:localhost/replica:0/task:0/cpu:0"); + + tensorflow::Graph graph(fld); + tensorflow::GraphConstructorOptions gc_opts; + TF_CHECK_OK( + tensorflow::ConvertGraphDefToGraph(gc_opts, graphdef_copy, &graph)); + + // Iterate over the graph until there are no more nodes to be inlined. 
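+  // Each call to ExpandInlineFunctions() inlines the function-call nodes it
+  // can; the inlined bodies may expose further call nodes, so keep looping
+  // until a pass makes no change.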
+ bool graph_modified = false; + while (tensorflow::ExpandInlineFunctions(flr, &graph)) { + graph_modified = true; + LOG(INFO) << "Found functions that were inlined."; + } + + // Output inlined graph + if (graph_modified) { + graph.ToGraphDef(graphdef); + } + return graph_modified; +} +} // namespace + +std::unique_ptr ImportTensorFlowGraphDef(const ModelFlags& model_flags, + const GraphDef& tf_graph) { + LogDumpGraphDef(kLogLevelModelChanged, "AT IMPORT", tf_graph); + + GraphDef inlined_graph(tf_graph); + if (InlineAllFunctions(&inlined_graph)) { + LogDumpGraphDef(kLogLevelModelChanged, "AFTER INLINING", inlined_graph); + } + + Model* model = new Model; + ResolveModelFlags(model_flags, model); + + for (const auto& node : inlined_graph.node()) { + if (node.op() == "Const") { + ConvertConstOperator(node, model); + } else if (node.op() == "Conv2D") { + ConvertConvOperator(node, model); + } else if (node.op() == "DepthwiseConv2dNative") { + ConvertDepthwiseConvOperator(node, model); + } else if (node.op() == "DepthToSpace") { + ConvertDepthToSpaceOperator(node, model); + } else if (node.op() == "SpaceToDepth") { + ConvertSpaceToDepthOperator(node, model); + } else if (node.op() == "BiasAdd") { + ConvertBiasAddOperator(node, model); + } else if (node.op() == "Relu") { + ConvertReluOperator(node, model); + } else if (node.op() == "Relu6") { + ConvertRelu6Operator(node, model); + } else if (node.op() == "Sigmoid") { + ConvertLogisticOperator(node, model); + } else if (node.op() == "Tanh") { + ConvertTanhOperator(node, model); + } else if (node.op() == "MaxPool") { + ConvertMaxPoolOperator(node, model); + } else if (node.op() == "AvgPool") { + ConvertAvgPoolOperator(node, model); + } else if (node.op() == "Reshape") { + ConvertReshapeOperator(node, model); + } else if (node.op() == "MatMul") { + ConvertMatMulOperator(node, model); + } else if (node.op() == "Div" || node.op() == "RealDiv") { + ConvertDivOperator(node, model); + } else if (node.op() == "Identity" || node.op() == "CheckNumerics") { + ConvertIdentityOperator(node, model); + } else if (node.op() == "FakeQuantWithMinMaxVars") { + ConvertFakeQuantWithMinMaxVars(node, model); + } else if (node.op() == "FakeQuantWithMinMaxArgs") { + ConvertFakeQuantWithMinMaxArgs(node, model); + } else if (node.op() == "Rsqrt") { + ConvertRsqrtOperator(node, model); + } else if (node.op() == "Squeeze") { + ConvertSqueezeOperator(node, model); + } else if (node.op() == "Sqrt") { + ConvertSqrtOperator(node, model); + } else if (node.op() == "Square") { + ConvertSquareOperator(node, model); + } else if (node.op() == "Add") { + ConvertAddOperator(node, model); + } else if (node.op() == "Mul") { + ConvertMulOperator(node, model); + } else if (node.op() == "Sub") { + ConvertSubOperator(node, model); + } else if (node.op() == "Sum") { + ConvertSumOperator(node, model); + } else if (node.op() == "Tile") { + ConvertTileOperator(node, model); + } else if (node.op() == "Concat" || node.op() == "ConcatV2") { + ConvertConcatOperator(node, model); + } else if (node.op() == "LRN") { + ConvertLRNOperator(node, model); + } else if (node.op() == "Softmax") { + ConvertSoftmaxOperator(node, model); + } else if (node.op() == "All") { + ConvertAllOperator(node, model); + } else if (node.op() == "Assert") { + ConvertAssertOperator(node, model); + } else if (node.op() == "Less") { + ConvertLessOperator(node, model); + } else if (node.op() == "LessEqual") { + ConvertLessEqualOperator(node, model); + } else if (node.op() == "Greater") { + ConvertGreaterOperator(node, model); + } 
else if (node.op() == "GreaterEqual") { + ConvertGreaterEqualOperator(node, model); + } else if (node.op() == "Max") { + ConvertMaxOperator(node, model); + } else if (node.op() == "Min") { + ConvertMinOperator(node, model); + } else if (node.op() == "Maximum") { + ConvertMaximumOperator(node, model); + } else if (node.op() == "Minimum") { + ConvertMinimumOperator(node, model); + } else if (node.op() == "Merge") { + ConvertMergeOperator(node, model); + } else if (node.op() == "Pad") { + ConvertPadOperator(node, model); + } else if (node.op() == "StridedSlice") { + ConvertStridedSliceOperator(node, model); + } else if (node.op() == "Shape") { + ConvertShapeOperator(node, model); + } else if (node.op() == "Slice") { + ConvertSliceOperator(node, model); + } else if (node.op() == "Split") { + ConvertSplitOperator(node, model); + } else if (node.op() == "Switch") { + ConvertSwitchOperator(node, model); + } else if (node.op() == "Placeholder") { + ConvertPlaceholderOperator(node, model); + } else if (node.op() == "PlaceholderWithDefault") { + ConvertIdentityOperator(node, model); + } else if (node.op() == "LegacyFedInput") { + ConvertPlaceholderOperator(node, model); + } else if (node.op() == "NoOp") { + ConvertNoOpOperator(node, model); + } else if (node.op() == "Cast") { + ConvertCastOperator(node, model); + } else if (node.op() == "Floor") { + ConvertFloorOperator(node, model); + } else if (node.op() == "Gather") { + ConvertGatherOperator(node, model); + } else if (node.op() == "ResizeBilinear") { + ConvertResizeBilinearOperator(node, model); + } else if (node.op() == "BatchNormWithGlobalNormalization") { + ConvertBatchNormWithGlobalNormalizationOperator(node, model); + } else if (node.op() == "FusedBatchNorm") { + ConvertFusedBatchNormOperator(node, model); + } else if (node.op() == "SpaceToBatchND") { + ConvertSpaceToBatchNDOperator(node, model); + } else if (node.op() == "BatchToSpaceND") { + ConvertBatchToSpaceNDOperator(node, model); + } else if (node.op() == "Mean") { + ConvertMeanOperator(node, model); + } else if (node.op() == "Svdf") { + ConvertSvdfOperator(node, model); + } else { + ConvertUnsupportedOperator(node, model); + } + } + + StripCaretFromArrayNames(model); + AddExtraOutputsFedIntoOtherOps(model); + FixNoMissingArray(model); + FixNoOrphanedArray(model); + FixOperatorOrdering(model); + CheckInvariants(*model); + + // if rnn state arrays are constant, make them transient + for (const auto& rnn_state : model->flags.rnn_states()) { + model->GetArray(rnn_state.state_array()).buffer = nullptr; + } + + return std::unique_ptr(model); +} + +std::unique_ptr ImportTensorFlowGraphDef( + const ModelFlags& model_flags, const string& input_file_contents) { + std::unique_ptr tf_graph(new GraphDef); + CHECK(ParseFromStringEitherTextOrBinary(input_file_contents, tf_graph.get())); + + std::unique_ptr pruned_graph = + MaybeReplaceCompositeSubgraph(*tf_graph); + if (pruned_graph) { + tf_graph = std::move(pruned_graph); + } + return ImportTensorFlowGraphDef(model_flags, *tf_graph); +} +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.h b/tensorflow/contrib/lite/toco/import_tensorflow.h new file mode 100644 index 0000000000..d2eb423ca4 --- /dev/null +++ b/tensorflow/contrib/lite/toco/import_tensorflow.h @@ -0,0 +1,34 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_IMPORT_TENSORFLOW_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_IMPORT_TENSORFLOW_H_
+
+#include <memory>
+#include <string>
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/model_flags.pb.h"
+#include "tensorflow/core/framework/graph.pb.h"
+
+namespace toco {
+
+std::unique_ptr<Model> ImportTensorFlowGraphDef(
+    const ModelFlags& model_flags, const tensorflow::GraphDef& graph_def);
+
+std::unique_ptr<Model> ImportTensorFlowGraphDef(
+    const ModelFlags& model_flags, const string& input_file_contents);
+
+}  // namespace toco
+
+#endif  // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_IMPORT_TENSORFLOW_H_
diff --git a/tensorflow/contrib/lite/toco/model.h b/tensorflow/contrib/lite/toco/model.h
new file mode 100644
index 0000000000..d992f8458f
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/model.h
@@ -0,0 +1,1372 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_MODEL_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_MODEL_H_
+
+#include <initializer_list>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/model_flags.pb.h"
+#include "tensorflow/contrib/lite/toco/runtime/types.h"
+#include "tensorflow/contrib/lite/toco/toco_types.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+enum class OperatorType {
+  kNone,
+  // General-purpose neural network operators.
+  kAdd,
+  kAveragePool,
+  kBatchNormalization,
+  kConv,
+  kConcatenation,
+  kDepthwiseConv,
+  kDepthToSpace,
+  kSpaceToDepth,
+  kDequantize,
+  kDiv,
+  kFullyConnected,
+  kL2Normalization,
+  kL2Pool,
+  kLstmCell,
+  kLocalResponseNormalization,
+  kLogistic,
+  kMaxPool,
+  kFakeQuant,
+  kMul,
+  kRelu,
+  kRelu1,
+  kRelu6,
+  kSoftmax,
+  kSub,
+  kTanh,
+  kCast,
+  kFloor,
+  kGather,
+  kResizeBilinear,
+  kSpaceToBatchND,
+  kBatchToSpaceND,
+  kPad,
+  kStridedSlice,
+  kSlice,
+  kSqueeze,
+  kMean,
+  // The SVDF Op is a decomposition of a densely connected Op into
+  // low rank filters. For details:
+  // https://research.google.com/pubs/pub43813.html
+  kSvdf,
+  // Special operators used for importing TensorFlow nodes.
+  // The general intent is to have some graph transformation either
+  // drop them or rewrite them as general-purpose operators.
+  kTensorFlowAll,
+  kTensorFlowAssert,
+  kTensorFlowConcat,
+  kTensorFlowConcatV2,
+  kTensorFlowGreater,
+  kTensorFlowGreaterEqual,
+  kTensorFlowIdentity,
+  kTensorFlowLess,
+  kTensorFlowLessEqual,
+  kTensorFlowMax,
+  kTensorFlowMaximum,
+  kTensorFlowMin,
+  kTensorFlowMinimum,
+  kTensorFlowMatMul,
+  kTensorFlowMerge,
+  kTensorFlowReshape,
+  kTensorFlowRsqrt,
+  kTensorFlowShape,
+  kTensorFlowSplit,
+  kTensorFlowSqrt,
+  kTensorFlowSquare,
+  kTensorFlowSum,
+  kTensorFlowSwitch,
+  kTensorFlowTile,
+  // An unsupported TF operation. It's only needed to be able to represent TF
+  // graph internally and is expected to be dropped by graph transformations.
+  kTensorFlowUnsupported,
+  // Finally, TensorFlow uses different conventions for axes ordering,
+  // see AxesOrder, and this cannot always be resolved at the time of importing
+  // nodes, as TensorFlow parameters may be constant-expression subgraphs
+  // instead of being given as plain constant arrays. So we need to insert
+  // special nodes in the graph to shuffle axes.
+  kReorderAxes,
+};
+
+// Helper to deal with TensorFlow arrays using a different ordering of
+// dimensions ("axes") than our own.
+// TODO(benoitjacob): Ultimately, we shouldn't have any "ordering" of axes,
+// we should have associative arrays mapping symbolic axes identifiers (like
+// "output_depth") to dimensions. We would then not need this anymore.
+enum class AxesOrder {
+  kOneAxis,  // one-dimensional array, one unique axis.
+  kCR,       // column-major matrix storage order. Our standard.
+  kRC,       // row-major matrix storage order. TensorFlow default.
+  kOHWI,     // Our standard for conv weights
+  kHWIO,     // TensorFlow conv weights
+  k1HWO,     // Our standard for DepthwiseConv weights
+  kHWIM,     // TensorFlow DepthwiseConv weights
+  kNHWC,     // TensorFlow activations
+};
+
+// The type of the scalars in an array.
+// Note that this does not by itself tell whether the values in the array are
+// real (are literally interpreted as real numbers) or quantized (only acquire
+// a meaning as real numbers in conjunction with QuantizationParams).
+//
+// In practice though:
+//   float values are always real
+//   uint8 values are always quantized
+//   int32 values are either real or quantized (depending on whether
+//   QuantizationParams are present).
+//   other types are unused at the moment.
+//
+// kNone means that we don't know the data type yet, or that we don't care
+// because we'll be dropping the array anyway (e.g. some exotic array types
+// may be involved only in debug-only subgraphs that we may not be interested
+// in actually supporting).
+enum class ArrayDataType { kNone, kBool, kFloat, kUint8, kInt32, kInt64 };
+
+// Compile-time logic to map ArrayDataType to the corresponding C++ scalar type
+template <ArrayDataType A>
+struct DataTypeImpl {};
+template <>
+struct DataTypeImpl<ArrayDataType::kNone> {
+  typedef int Type;
+};
+template <>
+struct DataTypeImpl<ArrayDataType::kBool> {
+  typedef bool Type;
+};
+template <>
+struct DataTypeImpl<ArrayDataType::kFloat> {
+  typedef float Type;
+};
+template <>
+struct DataTypeImpl<ArrayDataType::kUint8> {
+  typedef uint8 Type;
+};
+template <>
+struct DataTypeImpl<ArrayDataType::kInt32> {
+  typedef int32 Type;
+};
+template <>
+struct DataTypeImpl<ArrayDataType::kInt64> {
+  typedef int64 Type;
+};
+
+template <ArrayDataType A>
+using DataType = typename DataTypeImpl<A>::Type;
+
+// Base class for type-specific buffer types.
+struct GenericBuffer {
+  // Non-default-constructible: only ArrayDataType-specific subclass
+  // objects may be constructed.
+ GenericBuffer() = delete; + // Non-copyable-or-movable: we should only store pointers-to-Buffer + // in containers, not Operators themselves, so there should be no + // copy or move. + GenericBuffer(const GenericBuffer&) = delete; + GenericBuffer(const GenericBuffer&&) = delete; + + // We need a virtual destructor so we can store pointers-to-Buffer + // in containers and have the containers call the right subclass destructor. + virtual ~GenericBuffer() {} + + const ArrayDataType type; + + protected: + // Constructor used by subclasses for specific ArrayDataType's. + explicit GenericBuffer(ArrayDataType t) : type(t) {} +}; + +// Type-specific buffer, containing type-specific storage. +template +struct Buffer : GenericBuffer { + Buffer() : GenericBuffer(A) {} + + std::vector> data; +}; + +// Base class for all operator classes. +struct Operator { + // Non-default-constructible: only OperatorType-specific subclass + // objects may be constructed. + Operator() = delete; + // Non-copyable-or-movable: we should only store pointers-to-Operator + // in containers, not Operators themselves, so there should be no + // copy or move. + Operator(const Operator&) = delete; + Operator(const Operator&&) = delete; + + // We need a virtual destructor so we can store pointers-to-Operator + // in containers and have the containers call the right subclass destructor. + virtual ~Operator() {} + + // The specific type of operator. Corresponds 1:1 to subclasses. + const OperatorType type; + + // The activation function that may be fused into this operator, + // or None if no activation function is fused. + FusedActivationFunctionType fused_activation_function; + + // Input arrays: either activation arrays or constant array parameters. + // We refer to them by their name, not by their address; the mapping of + // names to addresses is given by the Model, which owns both Operator's and + // Array's. Thus, an Operator on its own doesn't contain much information, + // it is meant to be used in conjunction with the Model that owns it. + std::vector inputs; + + // Output activation arrays. Same comments as for inputs apply here too. + std::vector outputs; + + // If true, the array has more outputs than are listed in the 'outputs' + // member. These need to be resolved by some graph transformation. + // This flag is only here to indicate that an operator should not be + // discarded as unused, even if from its 'outputs' member alone it + // looks unused. + bool unresolved_outputs = false; + + protected: + // Constructor used by subclasses for specific OperatorType's. + explicit Operator(OperatorType t) + : type(t), + fused_activation_function(FusedActivationFunctionType::kNone) {} +}; + +// Padding types for Conv-like operators. This is how padding is typically +// specified in model files. But for inference, we will need to resolve this +// to a FixedPadding, see below. +enum class PaddingType { kNone, kSame, kValid }; + +// Padding as resolved for a specific layer shape, as needed for inference. +// For a given layer shape, a given padding type will resolve to a choice of +// a number of padding rows and columns, which we call the padding height and +// width respectively. +struct FixedPadding { + int width = 0; + int height = 0; +}; + +// "Universal" padding struct containing both a generic PaddingType (as +// represented in a model file), and a FixedPadding (as needed for inference). +// The latter is resolved during the PropagateFixedSizes pass. 
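+// For example, kValid always resolves to FixedPadding{0, 0}, while kSame for
+// a 3x3 kernel with stride 1 resolves to FixedPadding{1, 1}, i.e. one row and
+// one column of zero padding on each side.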
+struct Padding { + FixedPadding& GetOrCreateFixedPadding() { + if (!fixed) { + FixedPadding* ptr = new FixedPadding; + fixed = std::unique_ptr(ptr); + } + return *fixed; + } + + Padding() : type(PaddingType::kNone) {} + PaddingType type; + std::unique_ptr fixed; +}; + +// "Convolutional" layer, as represented in model files. +// +// Inputs: +// inputs[0]: required: the input activations array +// inputs[1]: required: the Conv weights +// inputs[2]: optional: the bias vector, specifying the biases for each output +// channel. +// +// Outputs: +// outputs[0]: required: the output activations array +// outputs[1]: optional: the intermediate array of im2col-replicated input +// activations. Present when targeting implementations +// of Conv layers as Im2col+GEMM. +// +// TensorFlow equivalent: Conv2D +struct ConvOperator : Operator { + ConvOperator() : Operator(OperatorType::kConv) {} + Padding padding; + int stride_width = 0; + int stride_height = 0; +}; + +// Depthwise-separable convolution operator. +// +// Inputs: +// inputs[0]: required: the input activations array +// inputs[1]: required: the DepthwiseConv weights +// inputs[2]: optional: the bias vector, specifying the biases for each output +// channel. +// +// TensorFlow equivalent: DepthwiseConv2dNative +struct DepthwiseConvOperator : Operator { + DepthwiseConvOperator() : Operator(OperatorType::kDepthwiseConv) {} + Padding padding; + int stride_height = 0; + int stride_width = 0; + int depth_multiplier = 0; +}; + +// Depth-to-space transform operator. +// +// Inputs: +// inputs[0]: required: the input activations array +// +// TensorFlow equivalent: DepthToSpace +struct DepthToSpaceOperator : Operator { + DepthToSpaceOperator() : Operator(OperatorType::kDepthToSpace) {} + int block_size = 0; +}; + +// Space-to-depth transform operator. +// +// Inputs: +// inputs[0]: required: the input activations array +// +// TensorFlow equivalent: SpaceToDepth +struct SpaceToDepthOperator : Operator { + SpaceToDepthOperator() : Operator(OperatorType::kSpaceToDepth) {} + int block_size = 0; +}; + +// Fully-connected operator. +// +// Inputs: +// inputs[0]: required: the input activations array +// inputs[1]: required: the FullyConnected weights +// inputs[2]: optional: the bias vector, specifying the biases for each output +// channel. +// +// TensorFlow equivalent: a pair consisting of a Reshape node reshaping the +// input activations as a matrix, followed by a MatMul node. +struct FullyConnectedOperator : Operator { + FullyConnectedOperator() : Operator(OperatorType::kFullyConnected) {} +}; + +// Dequantization operator, converting a quantized array of integers with +// quantization parameters specifying how these integers correspond to real +// numbers +// (see QuantizationParams) to an output activations array of floating-point +// values. +// +// In floating-point image models, there is typically a Dequantization operator +// at the very beginning, converting the input image RGB data, consisting of +// uint8 integer values, to floating-point input activations. That is where +// image model parameters such as "mean_value" and "std_value" are typically +// handled. +// +// This is the only operator type that converts from quantized to +// floating-point, +// and there is at the moment no operator type at all to convert from +// floating-point +// to quantized. Every other operator does either float->float or +// quantized->quantized. 
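+//
+// For a uint8 input value v, the dequantized real value is typically computed
+// as (v - mean_value) / std_value; e.g. mean_value = 127.5 and
+// std_value = 127.5 map the [0, 255] pixel range onto [-1, 1].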
+// +// Inputs: +// inputs[0]: required: the input quantized activations array +// +// TensorFlow equivalent: Dequantize +struct DequantizeOperator : Operator { + DequantizeOperator() : Operator(OperatorType::kDequantize) {} +}; + +// Batch-normalization operator. +// +// We only support batch-normalization using pre-learned moments, so this is +// just +// computing (input - mean) * multiplier + offset. As such, this can be +// expressed as a combination of Add and Mul nodes, and indeed this is how +// we break it down during tooling for the purpose of fusing it into +// other operators. +// +// Inputs: +// inputs[0]: required: the input activations array +// inputs[1]: required: the learned mean array +// inputs[2]: required: the learned multiplier array +// inputs[3]: required: the learned offset array +// +// TensorFlow equivalent: a combination of Add and Mul nodes +struct BatchNormalizationOperator : Operator { + BatchNormalizationOperator() + : Operator(OperatorType::kBatchNormalization), + global_normalization(false) {} + bool global_normalization; +}; + +// L2-normalization operator. +// +// Inputs: +// inputs[0]: required: the input activations array +// +// TensorFlow equivalent: none. In TensorFlow, L2 normalization is implemented +// by a sub-graph of operators implementing L2-normalization +// from lower-level arithmetic nodes; during tooling, we identify such +// sub-graphs +// and replace them by L2NormalizationOperator's. See IdentifyL2Normalization. +struct L2NormalizationOperator : Operator { + L2NormalizationOperator() : Operator(OperatorType::kL2Normalization) {} +}; + +// LSTM Cell operator. +// +// Inputs: +// inputs[0]: required: the input data array +// inputs[1]: required: the previous output activations array +// inputs[2]: required: the learned weights array +// inputs[3]: required: the learned biases array +// inputs[4]: required: the previous output state +// outputs[0]: required: the output activations array +// outputs[1]: required: the new state array +// +// TensorFlow equivalent: none. In TensorFlow, an LSTM is implemented +// with a sub-graph of lower-level arithmetic nodes; during tooling, we identify +// such sub-graphs and replace them with LstmCells. See IdentifyLstmCell(). +struct LstmCellOperator : Operator { + enum Inputs { + DATA_INPUT = 0, + PREV_ACTIV_INPUT = 1, + WEIGHTS_INPUT = 2, + BIASES_INPUT = 3, + PREV_STATE_INPUT = 4, + NUM_INPUTS = 5 + }; + enum Outputs { + ACTIV_OUTPUT = 0, + STATE_OUTPUT = 1, + CONCAT_TEMP = 2, + ACTIV_TEMP = 3, + NUM_OUTPUTS = 4 + }; + LstmCellOperator() : Operator(OperatorType::kLstmCell) {} +}; + +// Element-wise multiplication operator. +// +// Inputs: +// inputs[0]: required: the left-hand side array +// inputs[1]: required: the right-hand side array +// +// TensorFlow equivalent: Mul +struct MulOperator : Operator { + MulOperator() : Operator(OperatorType::kMul) {} +}; + +// Element-wise Relu operator: +// x -> max(0, x) +// +// Inputs: +// inputs[0]: required: the input array +// +// TensorFlow equivalent: Relu +struct ReluOperator : Operator { + ReluOperator() : Operator(OperatorType::kRelu) {} +}; + +// Element-wise Relu1 operator: +// x -> min(max(x, -1), 1) +// +// Inputs: +// inputs[0]: required: the input array +// +// TensorFlow equivalent: none. 
We can construct the operator with Minimum +// and Maximum operations +struct Relu1Operator : Operator { + Relu1Operator() : Operator(OperatorType::kRelu1) {} +}; + +// Element-wise Relu6 operator: +// x -> max(0, min(6, x)) +// +// Inputs: +// inputs[0]: required: the input array +// +// TensorFlow equivalent: Relu6 +struct Relu6Operator : Operator { + Relu6Operator() : Operator(OperatorType::kRelu6) {} +}; + +// Element-wise Logistic operator: +// x -> Logistic(x) = 1 / (1 + exp(-x)) +// +// Inputs: +// inputs[0]: required: the input array +// +// TensorFlow equivalent: Sigmoid +struct LogisticOperator : Operator { + LogisticOperator() : Operator(OperatorType::kLogistic) {} +}; + +// Element-wise Tanh operator: +// x -> Tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x)) +// +// Inputs: +// inputs[0]: required: the input array +// +// TensorFlow equivalent: Tanh +struct TanhOperator : Operator { + TanhOperator() : Operator(OperatorType::kTanh) {} +}; + +// Element-wise addition operator. +// +// Inputs: +// inputs[0]: required: the left-hand side array +// inputs[1]: required: the right-hand side array +// +// TensorFlow equivalent: Add +struct AddOperator : Operator { + AddOperator() : Operator(OperatorType::kAdd) {} +}; + +// Concatenation operator: concatenates its inputs +// along the concat_dim dimension. +// +// Inputs: this operator accepts any number >= 1 of inputs. +// inputs[i]: the i-th array to concatenate. +// +// TensorFlow equivalent: Concat. +struct ConcatenationOperator : Operator { + ConcatenationOperator() : Operator(OperatorType::kConcatenation) {} + int concat_dim = 0; +}; + +// Reordering dimensions. Used only during tooling to transform graphs from +// the TensorFlow format. +// +// Inputs: +// inputs[0]: required: the input array +// +// TensorFlow equivalent: none. This is only useful to convert between formats. +struct ReorderAxesOperator : Operator { + ReorderAxesOperator() : Operator(OperatorType::kReorderAxes) {} + AxesOrder input_axes_order; + AxesOrder output_axes_order; +}; + +// Average-pooling operator. +// +// Inputs: +// inputs[0]: required: the input array +// +// TensorFlow equivalent: AveragePool +struct AveragePoolOperator : Operator { + AveragePoolOperator() : Operator(OperatorType::kAveragePool) {} + Padding padding; + int stride_height = 0; + int stride_width = 0; + int kheight = 0; + int kwidth = 0; +}; + +// Local response normalization operator. +// +// Inputs: +// inputs[0]: required: the input array +// +// TensorFlow equivalent: LRN +struct LocalResponseNormalizationOperator : Operator { + LocalResponseNormalizationOperator() + : Operator(OperatorType::kLocalResponseNormalization) {} + + int range = 0; + float bias = 0.f; + float alpha = 0.f; + float beta = 0.f; +}; + +// Max-pooling operator. +// +// Inputs: +// inputs[0]: required: the input array +// +// TensorFlow equivalent: MaxPool +struct MaxPoolOperator : Operator { + MaxPoolOperator() : Operator(OperatorType::kMaxPool) {} + Padding padding; + int stride_height = 0; + int stride_width = 0; + int kheight = 0; + int kwidth = 0; +}; + +// L2-pooling operator. +// +// Inputs: +// inputs[0]: required: the input array +// +// TensorFlow equivalent: none. Can be shimmed by squaring+avgpool+sqrt. +struct L2PoolOperator : Operator { + L2PoolOperator() : Operator(OperatorType::kL2Pool) {} + Padding padding; + int stride_height = 0; + int stride_width = 0; + int kheight = 0; + int kwidth = 0; +}; + +// The expected [min, max] range of values in a given array. +// Used for quantization only. 
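+// For example, under 8-bit quantization, a MinMax of [0.0, 6.0] means that
+// the 256 representable values are spread uniformly over that interval, in
+// steps of 6.0 / 255.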
+// This information typically comes from special nodes found in quantized +// models, +// see FakeQuantOperator, and is used during quantization to resolve +// actual quantization parameters (see QuantizationParams). +struct MinMax { + double min = 0.; + double max = 0.; +}; + +inline bool operator==(const MinMax& m1, const MinMax& m2) { + return m1.min == m2.min && m1.max == m2.max; +} + +// Fake-quantization operator. This does two things: +// - Annotate its input and output arrays with MinMax information, +// - Arithmetic-wise, this operator rounds incoming activation values +// to the nearest representable value on the scale of 256 +// values from the min to the max value dictated by its MinMax info. +// +// Inputs: +// inputs[0]: required: the input array +// inputs[1]: optional: the 'min' value, if it has not yet been resolved +// to a constant. +// inputs[2]: optional: the 'max' value, if it has not yet been resolved +// to a constant. +// +// TensorFlow equivalent: FakeQuantWithMinMaxVars, FakeQuantWithMinMaxArgs. +struct FakeQuantOperator : Operator { + FakeQuantOperator() : Operator(OperatorType::kFakeQuant) {} + std::unique_ptr minmax; +}; + +// Element-wise division operator. +// +// Inputs: +// inputs[0]: required: the left-hand side array +// inputs[1]: required: the right-hand side array +// +// TensorFlow equivalent: Div +struct DivOperator : Operator { + DivOperator() : Operator(OperatorType::kDiv) {} +}; + +// Element-wise identity (x->x) operator. +// +// Inputs: +// inputs[0]: required: the input array +// +// TensorFlow equivalent: Identity +struct TensorFlowIdentityOperator : Operator { + TensorFlowIdentityOperator() : Operator(OperatorType::kTensorFlowIdentity) {} +}; + +// General matrix multiplication operator. We don't want to support general +// matrix multiplication at inference time, so we resolve it during tooling +// to more specific operator types, namely, FullyConnected. +// +// Inputs: +// inputs[0]: required: the left-hand side matrix +// inputs[1]: required: the right-hand side matrix +// +// TensorFlow equivalent: MatMul +struct TensorFlowMatMulOperator : Operator { + TensorFlowMatMulOperator() : Operator(OperatorType::kTensorFlowMatMul) {} +}; + +// Padding operator. Pads a tensor with zeros. +// +// Inputs: +// inputs[0]: required: the input array +// inputs[1]: required: the padding array +// +// This operation pads a `input` with zeros according to the `paddings` you +// specify. `paddings` is an integer tensor with shape `[Dn, 2]`, where n is the +// rank of `input`. For each dimension D of `input`, `paddings[D, 0]` indicates +// how many zeros to add before the contents of `input` in that dimension, and +// `paddings[D, 1]` indicates how many zeros to add after the contents of +// `input` in that dimension. +// +// TensorFlow equivalent: Pad +struct PadOperator : Operator { + PadOperator() : Operator(OperatorType::kPad) {} + + std::vector left_padding; + std::vector right_padding; +}; + +// Strided slice operator. +// +// Inputs: +// inputs[0]: required: the input array +// +// TensorFlow equivalent: StridedSlice +struct StridedSliceOperator : Operator { + StridedSliceOperator() : Operator(OperatorType::kStridedSlice) {} + + std::vector start_indices; + std::vector stop_indices; + std::vector strides; + + int begin_mask; + int ellipsis_mask; + int end_mask; + int new_axis_mask; + int shrink_axis_mask; +}; + +// Reshaping operator, reshaping its input array to a two-dimensional shape +// (a "matrix"). 
This is used in the TensorFlow format, in conjunction with +// MatMul nodes, to implement fully-connected layers. +// +// Inputs: +// inputs[0]: required: the input array +// +// TensorFlow equivalent: Reshape --- except that we only support a special case +// here, where the output shape is a matrix (2D) shape. +struct TensorFlowReshapeOperator : Operator { + TensorFlowReshapeOperator() : Operator(OperatorType::kTensorFlowReshape) {} + std::vector shape; +}; + +// Removes dimensions of size 1 from the shape of a tensor. +// https://www.tensorflow.org/api_docs/python/tf/squeeze +// +// Inputs: +// inputs[0]: required: the input array +// +// TensorFlow equivalent: Squeeze +struct SqueezeOperator : Operator { + SqueezeOperator() : Operator(OperatorType::kSqueeze) {} + + std::vector squeeze_dims; +}; + +// Element-wise reciprocal-square-root (x^-0.5) operator. +// +// Inputs: +// inputs[0]: required: the input array +// +// TensorFlow equivalent: Rsqrt +struct TensorFlowRsqrtOperator : Operator { + TensorFlowRsqrtOperator() : Operator(OperatorType::kTensorFlowRsqrt) {} +}; + +// Shape operator. Extracts the shape of the tensor. +// +// Inputs: +// inputs[0]: required: the input array +// +// This operation outputs a 1-D integer tensor representing the shape of +// the input. +// +// TensorFlow equivalent: Shape. We currently assume that the output is int32 +// and not int64. The output type could be stored herein. +struct TensorFlowShapeOperator : Operator { + TensorFlowShapeOperator() : Operator(OperatorType::kTensorFlowShape) {} +}; + +// Element-wise square-root (x^0.5) operator. +// +// Inputs: +// inputs[0]: required: the input array +// +// TensorFlow equivalent: Sqrt +struct TensorFlowSqrtOperator : Operator { + TensorFlowSqrtOperator() : Operator(OperatorType::kTensorFlowSqrt) {} +}; + +// Element-wise square (x*x) operator. +// +// Inputs: +// inputs[0]: required: the input array +// +// TensorFlow equivalent: Square +struct TensorFlowSquareOperator : Operator { + TensorFlowSquareOperator() : Operator(OperatorType::kTensorFlowSquare) {} +}; + +// Element-wise subtraction operator. +// +// Inputs: +// inputs[0]: required: the left-hand side array +// inputs[1]: required: the right-hand side array +// +// TensorFlow equivalent: Sub +struct SubOperator : Operator { + SubOperator() : Operator(OperatorType::kSub) {} +}; + +// Global sum reduction: computes the sum of all of entries in the input array. +// Thus the output is "0-dimensional": it consists of a single scalar value. +// +// Inputs: +// inputs[0]: required: the input array +// +// TensorFlow equivalent: Sum --- except that we only support the special case +// of global reduction across all dimensions. +struct TensorFlowSumOperator : Operator { + TensorFlowSumOperator() : Operator(OperatorType::kTensorFlowSum) {} +}; + +// TensorFlow Tile equivalent. Refer to TensorFlow documentation for details. +// Not fully supported, just a placeholder to handle TensorFlow graphs and +// support graph transformations to other operator types by matching sub-graphs. +struct TensorFlowTileOperator : Operator { + TensorFlowTileOperator() : Operator(OperatorType::kTensorFlowTile) {} +}; + +// TensorFlow Slice equivalent. Refer to TensorFlow documentation for details. +struct SliceOperator : Operator { + SliceOperator() : Operator(OperatorType::kSlice) {} + + std::vector begin; + std::vector size; +}; + +// TensorFlow Split equivalent. Refer to TensorFlow documentation for details. 
+// Not fully supported, just a placeholder to handle TensorFlow graphs and +// support graph transformations to other operator types by matching sub-graphs. +struct TensorFlowSplitOperator : Operator { + TensorFlowSplitOperator() : Operator(OperatorType::kTensorFlowSplit) {} + int num_split = 0; +}; + +// TensorFlow Concat equivalent. Refer to TensorFlow documentation for details. +// Not fully supported, just a placeholder to handle TensorFlow graphs and +// support graph transformations to other operator types by matching sub-graphs. +// Concretely, once the concat dim becomes known, if it is the depth +// dimension then we can change this op into a DepthConcatenation op. +// Otherwise, we hope for some other graph transformation to drop this node. +struct TensorFlowConcatOperator : Operator { + TensorFlowConcatOperator() : Operator(OperatorType::kTensorFlowConcat) {} +}; + +// TensorFlow ConcatV2 equivalent. Refer to TensorFlow documentation for +// details. +// Not fully supported, just a placeholder to handle TensorFlow graphs and +// support graph transformations to other operator types by matching sub-graphs. +// Concretely, once the concat dim becomes known, if it is the depth +// dimension then we can change this op into a DepthConcatenation op. +// Otherwise, we hope for some other graph transformation to drop this node. +struct TensorFlowConcatV2Operator : Operator { + TensorFlowConcatV2Operator() : Operator(OperatorType::kTensorFlowConcatV2) {} +}; + +// TensorFlow Merge equivalent. Refer to TensorFlow documentation for details. +// +// Inputs: this operator accepts any number >= 1 of inputs. +// inputs[i]: the i-th array to merge. +// +// It is expected that graph transformations will drop all but exactly one +// of the inputs, at which point the Merge node will be equivalent to an +// Identity node forwarding the remaining input. +// +// Note: We do not currently support runtime control flow: we only support +// control flow that can be resolved at tooling time (independently of input +// activations). +struct TensorFlowMergeOperator : Operator { + TensorFlowMergeOperator() : Operator(OperatorType::kTensorFlowMerge) {} +}; + +// TensorFlow Switch equivalent. Refer to TensorFlow documentation for details. +// +// Inputs: +// inputs[0]: required: the input array +// inputs[1]: required: the boolean predicate, given as an array of size 1 +// and of type kBool, will determine which output gets selected. +// +// Outputs: a TensorFlow Switch node always has exactly two outputs. Depending +// on the boolean value that the input predicate resolves to (see note below), +// one or the other of the outputs will be 'selected': the input array will be +// forwarded to the 'selected output' as if by a Identity node, while the other +// output will be discarded, and any graph edge connecting that discarded output +// will be dropped. The rule for selecting outputs is as follows: +// outputs[0] will be selected if the input predicate resolves to 'true'. +// outputs[1] will be selected if the input predicate resolves to 'false'. +// +// Note: We do not currently support runtime control flow: we only support +// control flow that can be resolved at tooling time (independently of input +// activations). +struct TensorFlowSwitchOperator : Operator { + TensorFlowSwitchOperator() : Operator(OperatorType::kTensorFlowSwitch) {} +}; + +// TensorFlow All equivalent. Refer to TensorFlow documentation for details. 
+// Not fully supported, just a placeholder to handle TensorFlow graphs and +// support graph transformations to other operator types by matching sub-graphs. +// Typically, this is only used as an input to an Assert node, so can be +// removed as an unused node as we drop Assert nodes. +struct TensorFlowAllOperator : Operator { + TensorFlowAllOperator() : Operator(OperatorType::kTensorFlowAll) {} +}; + +// TensorFlow Assert equivalent. Refer to TensorFlow documentation for details. +// Not fully supported, just a placeholder to handle TensorFlow graphs and +// support graph transformations to other operator types by matching sub-graphs. +// Typically, we just drop Assert nodes. +struct TensorFlowAssertOperator : Operator { + TensorFlowAssertOperator() : Operator(OperatorType::kTensorFlowAssert) {} +}; + +// TensorFlow Less equivalent. Refer to TensorFlow documentation for details. +// Not fully supported, just a placeholder to handle TensorFlow graphs and +// support graph transformations to other operator types by matching sub-graphs. +// Typically, this is only used as an input to an Assert node, so can be +// removed as an unused node as we drop Assert nodes. +struct TensorFlowLessOperator : Operator { + TensorFlowLessOperator() : Operator(OperatorType::kTensorFlowLess) {} +}; + +// TensorFlow LessEqual equivalent. Refer to TensorFlow documentation for +// details. +// Not fully supported, just a placeholder to handle TensorFlow graphs and +// support graph transformations to other operator types by matching sub-graphs. +// Typically, this is only used as an input to an Assert node, so can be +// removed as an unused node as we drop Assert nodes. +struct TensorFlowLessEqualOperator : Operator { + TensorFlowLessEqualOperator() + : Operator(OperatorType::kTensorFlowLessEqual) {} +}; + +// TensorFlow Less equivalent. Refer to TensorFlow documentation for details. +// Not fully supported, just a placeholder to handle TensorFlow graphs and +// support graph transformations to other operator types by matching sub-graphs. +// Typically, this is only used as an input to an Assert node, so can be +// removed as an unused node as we drop Assert nodes. +struct TensorFlowGreaterOperator : Operator { + TensorFlowGreaterOperator() : Operator(OperatorType::kTensorFlowGreater) {} +}; + +// TensorFlow GreaterEqual equivalent. Refer to TensorFlow documentation for +// details. +// Not fully supported, just a placeholder to handle TensorFlow graphs and +// support graph transformations to other operator types by matching sub-graphs. +// Typically, this is only used as an input to an Assert node, so can be +// removed as an unused node as we drop Assert nodes. +struct TensorFlowGreaterEqualOperator : Operator { + TensorFlowGreaterEqualOperator() + : Operator(OperatorType::kTensorFlowGreaterEqual) {} +}; + +// Global max reduction: computes the max of all of entries in the input array. +// Thus the output is "0-dimensional": it consists of a single scalar value. +// +// Inputs: +// inputs[0]: required: the input array +// +// TensorFlow equivalent: Max --- except that we only support the special case +// of global reduction across all dimensions. +struct TensorFlowMaxOperator : Operator { + TensorFlowMaxOperator() : Operator(OperatorType::kTensorFlowMax) {} +}; + +// Global min reduction: computes the min of all of entries in the input array. +// Thus the output is "0-dimensional": it consists of a single scalar value. 
+// +// Inputs: +// inputs[0]: required: the input array +// +// TensorFlow equivalent: Min --- except that we only support the special case +// of global reduction across all dimensions. +struct TensorFlowMinOperator : Operator { + TensorFlowMinOperator() : Operator(OperatorType::kTensorFlowMin) {} +}; + +// Element-wise maximum operator. Currently it only supports scalar as +// the second operand. +// +// Inputs: +// inputs[0]: required: the left-hand side array +// inputs[1]: required: the right-hand side array +// +// TensorFlow equivalent: Maximum +struct TensorFlowMaximumOperator : Operator { + TensorFlowMaximumOperator() : Operator(OperatorType::kTensorFlowMaximum) {} +}; + +// Element-wise minimum operator. Currently it only supports scalar as +// the second operand. +// +// Inputs: +// inputs[0]: required: the left-hand side array +// inputs[1]: required: the right-hand side array +// +// TensorFlow equivalent: Minimum +struct TensorFlowMinimumOperator : Operator { + TensorFlowMinimumOperator() : Operator(OperatorType::kTensorFlowMinimum) {} +}; + +// General TF operation, unsupported by tf.mini. Expected to be dropped by +// graph transformations. +struct TensorFlowUnsupportedOperator : Operator { + TensorFlowUnsupportedOperator() + : Operator(OperatorType::kTensorFlowUnsupported) {} + + // The original TF operation type. Used for diagnostic purposes. + string tensorflow_op; + // A serialized tensorflow::NodeDef string. + string tensorflow_node_def; + // A boolean indicating if the unsupported op should be treated as quantized. + bool quantized = false; + // Output data types + std::vector output_data_types; +}; + +// Softmax activation function. +// +// Inputs: +// inputs[0]: required: the input array +// +// TensorFlow equivalent: Softmax +struct SoftmaxOperator : Operator { + SoftmaxOperator() : Operator(OperatorType::kSoftmax) {} + float beta = 0.f; +}; + +// Cast operator. +// +// Inputs: +// inputs[0]: required: the input array +// +// TensorFlow equivalent: Cast +struct CastOperator : Operator { + CastOperator() : Operator(OperatorType::kCast) {} + ArrayDataType src_data_type = ArrayDataType::kNone; + ArrayDataType dst_data_type = ArrayDataType::kNone; +}; + +// Floor operator. +// +// Inputs: +// inputs[0]: required: the input array +// +// TensorFlow equivalent: Floor +struct FloorOperator : Operator { + FloorOperator() : Operator(OperatorType::kFloor) {} +}; + +// Gather operator. It gathers slices from params according to indices. +// Only 1-D indices are supported at the moment. +// +// Inputs: +// inputs[0]: required: the params array +// inputs[1]: required: the indices to gather +// +// TensorFlow equivalent: Gather +struct GatherOperator : Operator { + GatherOperator() : Operator(OperatorType::kGather) {} + int input_rank; +}; + +// ResizeBilinear operator. It resizes input images with bilinear interpolation. +// It does not support align_corners at the moment. +// +// Inputs: +// inputs[0]: required: the input array +// inputs[1]: required: the new image size +// +// TensorFlow equivalent: ResizeBilinear +struct ResizeBilinearOperator : Operator { + ResizeBilinearOperator() : Operator(OperatorType::kResizeBilinear) {} +}; + +// SpaceToBatchND operator. It divides spatial dimensions into a grid of +// blocks and interleaves these blocks with the batch dimension. Currently, +// only 2-d blocks are supported. 
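+//
+// For example, with block shape [2, 2] and no padding, an input of shape
+// [1, 4, 4, depth] is rearranged into an output of shape [4, 2, 2, depth].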
+// +// Inputs: +// inputs[0]: required: the input array +// inputs[1]: required: the block shape +// inputs[2]: required: the paddings +// +// TensorFlow equivalent: SpaceToBatchND +struct SpaceToBatchNDOperator : Operator { + SpaceToBatchNDOperator() : Operator(OperatorType::kSpaceToBatchND) {} +}; + +// BatchToSpaceND operator. Rearranges data from batch into blocks of +// spatial data. Currently, only 2-d blocks are supported. Cropping is not +// supported, either, and the crops array should be all zero. +// +// Inputs: +// inputs[0]: required: the input array +// inputs[1]: required: the block shape +// inputs[2]: required: the crops +// +// TensorFlow equivalent: BatchToSpaceND +struct BatchToSpaceNDOperator : Operator { + BatchToSpaceNDOperator() : Operator(OperatorType::kBatchToSpaceND) {} +}; + +// Mean operator. +// +// Inputs: +// inputs[0]: required: the input array +// +// TensorFlow equivalent: Mean +struct MeanOperator : Operator { + MeanOperator() : Operator(OperatorType::kMean) {} + + std::vector reduction_indices; +}; + +// Svdf operator: +// +// Inputs: +// inputs[0]: required: the input array +// inputs[1]: required: weights_feature +// inputs[2]: required: weights_time +// inputs[3]: optional: bias +struct SvdfOperator : Operator { + SvdfOperator() : Operator(OperatorType::kSvdf) {} + int rank; +}; + +// Alloc's are used for transient arrays only. An Alloc specifies which interval +// of the "transient_data" workspace buffer passed to inference functions, is to +// be used for the transient array at hand. The 'start' and 'end' values are +// offsets from the start of the workspace buffer, expressed in bytes. +struct Alloc { + int start = 0; + int end = 0; +}; + +inline bool operator<(const Alloc& a, const Alloc& b) { + return a.start < b.start; +} + +// Quantization parameters, determining the mapping of quantized values +// to real values (i.e. determining how quantized values are mathematically +// interpreted). +// +// The correspondence is as follows: +// +// real_value = scale * (quantized_value - zero_point); +// +// In other words, zero_point designates which quantized value corresponds to +// the real 0 value, and scale designates the difference between the real values +// corresponding to consecutive quantized values differing by 1. +struct QuantizationParams { + int32 zero_point = 0; + double scale = 0.; +}; + +class Shape { + public: + // For Shape, we stick to half-way encapsulation for now: + // we hide the raw dims_ member, but expose it raw by accessors + // because from some brainstorming, it's not at all easy to + // anticipate which flavor of more hermetic encapsulation would + // actually buy us future-proof-ness without being needlessly + // cumbersome. + Shape() {} + Shape(std::initializer_list dim_list) : dims_(dim_list) {} + + void ReplaceDims(std::initializer_list dim_list) { + dims_ = std::vector(dim_list); + } + + const std::vector& dims() const { return dims_; } + std::vector* mutable_dims() { return &dims_; } + const int dimensions_count() const { return dims_.size(); } + + // We still have that one convenience accessor to avoid + // the awkward double bracket issue: shape.dims()[i]. + int dims(int i) const { return dims_[i]; } + + bool operator==(const Shape& comp) const { + return (this->dims_ == comp.dims()); + } + + bool operator!=(const Shape& comp) const { return !((*this) == comp); } + + private: + std::vector dims_; +}; + +// Array represents an array (either a constant parameter array or an +// activations array) in a Model. 
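
Before the Array struct itself, a brief illustrative aside for this write-up (not part of the patch): the QuantizationParams mapping defined above, real_value = scale * (quantized_value - zero_point), can be checked with a minimal standalone sketch. The ExampleQuantizationParams and Dequantize names and the sample numbers below are invented for illustration only:

#include <cstdint>
#include <cstdio>

// Illustrative stand-in for toco::QuantizationParams as declared above.
struct ExampleQuantizationParams {
  std::int32_t zero_point = 0;
  double scale = 0.;
};

// real_value = scale * (quantized_value - zero_point)
double Dequantize(std::uint8_t quantized_value,
                  const ExampleQuantizationParams& qp) {
  return qp.scale * (static_cast<double>(quantized_value) - qp.zero_point);
}

int main() {
  // uint8 data in [0, 255] mapped to roughly [-1, 1): zero_point designates
  // which quantized value corresponds to real 0, and scale is the step size
  // between consecutive quantized values.
  ExampleQuantizationParams qp;
  qp.zero_point = 128;
  qp.scale = 1.0 / 128.0;
  std::printf("%f\n", Dequantize(0, qp));    // -1.000000
  std::printf("%f\n", Dequantize(128, qp));  //  0.000000
  std::printf("%f\n", Dequantize(255, qp));  //  0.992188
  return 0;
}

With that mapping in mind, the Array struct that follows ties together the buffer, shape, minmax, and quantization_params fields.
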
+struct Array { + template + const Buffer& GetBuffer() const { + DCHECK(buffer); + DCHECK(buffer->type == A); + return *static_cast*>(buffer.get()); + } + template + Buffer& GetMutableBuffer() { + if (!buffer) { + Buffer* ptr = new Buffer; + buffer = std::unique_ptr(ptr); + } + DCHECK(buffer); + DCHECK(buffer->type == A); + return *static_cast*>(buffer.get()); + } + Alloc& GetOrCreateAlloc() { + if (!alloc) { + alloc = std::unique_ptr(new Alloc); + } + return *alloc; + } + MinMax& GetOrCreateMinMax() { + if (!minmax) { + minmax = std::unique_ptr(new MinMax); + } + return *minmax; + } + MinMax& GetMinMax() const { + DCHECK(minmax); + return *minmax; + } + QuantizationParams& GetOrCreateQuantizationParams() { + if (!quantization_params) { + quantization_params = + std::unique_ptr(new QuantizationParams); + } + return *quantization_params; + } + QuantizationParams& GetQuantizationParams() const { + DCHECK(quantization_params); + return *quantization_params; + } + + // The data type of the actual elements of this array, that is: + // - If there is a buffer (see 'buffer' member), it must be of the same + // type. + // - If there is no buffer, meaning that this is a runtime (i.e. activations) + // array, then this specifies the type of elements that there will be + // at runtime. + // + // Note that this only specifies the storage type of elements; this does + // not specify whether these are to be treated as 'real' or 'quantized' + // values. + // That is decided by whether the 'quantization_params' member is null. + ArrayDataType data_type = ArrayDataType::kNone; + // The final value that data_type should have at the end of graph + // transformations + ArrayDataType final_data_type = ArrayDataType::kNone; + // The dimensions of this array --- this specifies both sizes and strides + // (the storage layout). + // + // Issues with shape handling that remain include: + // - No way to distinguish between 0-dimensional dims and missing dims. + // - No way to describe dims that may be runtime-variable. + // - Addressing of dims by integer index differs in different graph formats + // (TensorFlow vs. other frameworks vs. what we have informally grown + // within toco). + // This is currently quite messy; see ReorderAxesOperator which is how we + // bridge some of these discrepancies at the moment. This is overdue for + // a redesign; I'm thinking that it would be nice to have more flexible + // dims that allow mapping 1:1, cleanly, dims as they are in various + // formats, + // then explicitly convert between different conventions. + + // Proto-style accessors + bool has_shape() const { return array_shape != nullptr; } + const Shape& shape() const { + CHECK(has_shape()); + return *array_shape; + } + Shape* mutable_shape() { + if (!array_shape) { + array_shape.reset(new Shape); + } + return array_shape.get(); + } + void copy_shape(const Shape& src_shape) { *mutable_shape() = src_shape; } + void clear_shape() { array_shape = nullptr; } + + // The constant buffer backing this array. This is non-null if and only if + // this is a constant parameter array. Conversely, this is null for + // activations arrays. + // + // Note that this buffer is pure storage. In the case of quantized values, + // it only stores the quantized values, it does not know by itself about the + // quantization parameters necessary to interprete these values, that is + // in the separate 'quantization_params' field. In fact, this 'buffer' field + // does no even know whether values are quantized. 
It only has a data_type, + // which must equal the 'data_type' member here, and which only describes + // the storage type of element, does not tell whether they are quantized i.e. + // whether they are to be interpreted with quantization_params. + std::unique_ptr buffer; + // Only for activation arrays (i.e. when 'buffer' is null). + // Only for code generation. + // + // Describes the allocation of this array within the workspace buffer + // allocated + // for all transient arrays. + std::unique_ptr alloc; + // Describes the [min, max] range of values + // to be assumed when determining quantization_params. + // + // Only used for quantization. In fact, only used for determining + // quantization_params. + // + // Used for both constant arrays (those having a 'buffer') and non-constant + // arrays (activations). Indeed, it is important to use the same min-max range + // as was used during training, even if that min-max range is slightly wrong + // w.r.t. actual buffer elements. Doing otherwise would defeat the point of + // re-training for quantization. + std::unique_ptr minmax; + // Quantization parameters. The non-null-ness of this pointer is what + // defines whether this array is quantized or not. + // + // If this is non-null, then these quantization parameters are to be used + // to assign a meaning as real numbers to the elements of this array. + std::unique_ptr quantization_params; + + private: + std::unique_ptr array_shape; +}; + +// Our Model struct, represents an entire model (our "top-level" struct). +// Owns everything. +struct Model { + Array& GetArray(const string& name) const { + DCHECK(arrays.count(name)); + return *arrays.at(name); + } + Array& GetOrCreateArray(const string& name) { + if (!arrays.count(name)) { + Array* ptr = new Array; + arrays[name] = std::unique_ptr(ptr); + } + Array& result = GetArray(name); + return result; + } + + // The list of operators. Notice how it's a list of unique_ptr's, implying + // that the Model is what owns Operator's and keeps them alive. + std::vector> operators; + // The associative array mapping names to Array's. + // Notice how it's a container of unique_ptr's, implying + // that the Model is what owns Array's and keeps them alive. + // The Operator's refer to these Array's by their name strings, not by their + // addresses. See Operator::inputs, Operator::outputs. + std::unordered_map> arrays; + // Generic flags, a place where we combine information passed to us via + // command-line parameters (e.g. --input_width=N) with information that + // we may or may not find in the input model file. + ModelFlags flags; + // For code-generation only: required size of the transient_data buffer + std::size_t transient_data_size = 0; + // For code-generation only: required alignment of the transient_data buffer + std::size_t transient_data_alignment = 0; +}; +} // namespace toco + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_MODEL_H_ diff --git a/tensorflow/contrib/lite/toco/model_cmdline_flags.cc b/tensorflow/contrib/lite/toco/model_cmdline_flags.cc new file mode 100644 index 0000000000..699c95753f --- /dev/null +++ b/tensorflow/contrib/lite/toco/model_cmdline_flags.cc @@ -0,0 +1,374 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/toco/model_cmdline_flags.h"
+
+#include
+#include
+
+#include "absl/strings/ascii.h"
+#include "absl/strings/numbers.h"
+#include "absl/strings/str_join.h"
+#include "absl/strings/str_split.h"
+#include "absl/strings/string_view.h"
+#include "absl/strings/strip.h"
+#include "tensorflow/contrib/lite/toco/args.h"
+#include "tensorflow/contrib/lite/toco/toco_graphviz_dump_options.h"
+#include "tensorflow/contrib/lite/toco/toco_port.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/util/command_line_flags.h"
+// "batch" flag only exists internally
+#ifdef PLATFORM_GOOGLE
+#include "base/commandlineflags.h"
+#endif
+
+namespace toco {
+
+bool ParseModelFlagsFromCommandLineFlags(
+    int* argc, char* argv[], string* msg,
+    ParsedModelFlags* parsed_model_flags_ptr) {
+  ParsedModelFlags& parsed_flags = *parsed_model_flags_ptr;
+  using tensorflow::Flag;
+  std::vector<tensorflow::Flag> flags = {
+      Flag("input_array", parsed_flags.input_array.bind(),
+           parsed_flags.input_array.default_value(),
+           "Name of the input array. If not specified, will try to read "
+           "that information from the input file."),
+      Flag("input_arrays", parsed_flags.input_arrays.bind(),
+           parsed_flags.input_arrays.default_value(),
+           "Names of the input arrays, comma-separated. If not specified, "
+           "will try to read that information from the input file."),
+      Flag("output_array", parsed_flags.output_array.bind(),
+           parsed_flags.output_array.default_value(),
+           "Name of the output array, when specifying a unique output array. "
+           "If not specified, will try to read that information from the "
+           "input file."),
+      Flag("output_arrays", parsed_flags.output_arrays.bind(),
+           parsed_flags.output_arrays.default_value(),
+           "Names of the output arrays, comma-separated. "
+           "If not specified, will try to read "
+           "that information from the input file."),
+      Flag("input_shape", parsed_flags.input_shape.bind(),
+           parsed_flags.input_shape.default_value(),
+           "Input array shape. For many models the shape takes the form "
+           "batch size, input array height, input array width, input array "
+           "depth."),
+      Flag("input_shapes", parsed_flags.input_shapes.bind(),
+           parsed_flags.input_shapes.default_value(),
+           "Shapes corresponding to --input_arrays, colon-separated. For "
+           "many models each shape takes the form batch size, input array "
+           "height, input array width, input array depth."),
+      Flag("mean_value", parsed_flags.mean_value.bind(),
+           parsed_flags.mean_value.default_value(),
+           "mean_value parameter for image models, used to compute input "
+           "activations from input pixel data."),
+      Flag("mean_values", parsed_flags.mean_values.bind(),
+           parsed_flags.mean_values.default_value(),
+           "mean_values parameter for image models, comma-separated list of "
+           "doubles, used to compute input activations from input pixel "
+           "data. 
Each entry in the list should match an entry in " + "--input_arrays."), + Flag("std_value", parsed_flags.std_value.bind(), + parsed_flags.std_value.default_value(), + "std_value parameter for image models, used to compute input " + "activations from input pixel data."), + Flag("std_values", parsed_flags.std_values.bind(), + parsed_flags.std_values.default_value(), + "std_value parameter for image models, comma-separated list of " + "doubles, used to compute input activations from input pixel " + "data. Each entry in the list should match an entry in " + "--input_arrays."), + Flag("variable_batch", parsed_flags.variable_batch.bind(), + parsed_flags.variable_batch.default_value(), + "If true, the model accepts an arbitrary batch size. Mutually " + "exclusive " + "with the 'batch' field: at most one of these two fields can be " + "set."), + Flag( + "drop_control_dependency", + parsed_flags.drop_control_dependency.bind(), + parsed_flags.drop_control_dependency.default_value(), + "If true, ignore control dependency requirements in input TensorFlow " + "GraphDef. Otherwise an error will be raised upon control dependency " + "inputs."), + Flag("rnn_states", parsed_flags.rnn_states.bind(), + parsed_flags.rnn_states.default_value(), ""), + Flag("model_checks", parsed_flags.model_checks.bind(), + parsed_flags.model_checks.default_value(), + "A list of model checks to be applied to verify the form of the " + "model. Applied after the graph transformations after import."), + Flag("graphviz_first_array", parsed_flags.graphviz_first_array.bind(), + parsed_flags.graphviz_first_array.default_value(), + "If set, defines the start of the sub-graph to be dumped to " + "GraphViz."), + Flag( + "graphviz_last_array", parsed_flags.graphviz_last_array.bind(), + parsed_flags.graphviz_last_array.default_value(), + "If set, defines the end of the sub-graph to be dumped to GraphViz."), + Flag("dump_graphviz", parsed_flags.dump_graphviz.bind(), + parsed_flags.dump_graphviz.default_value(), + "Dump graphviz during LogDump call. 
If string is non-empty then "
+           "it defines path to dump, otherwise will skip dumping."),
+      Flag("dump_graphviz_video", parsed_flags.dump_graphviz_video.bind(),
+           parsed_flags.dump_graphviz_video.default_value(),
+           "If true, will dump graphviz at each "
+           "graph transformation, which may be used to generate a video."),
+  };
+  bool asked_for_help =
+      *argc == 2 && (!strcmp(argv[1], "--help") || !strcmp(argv[1], "-help"));
+  if (asked_for_help) {
+    *msg += tensorflow::Flags::Usage(argv[0], flags);
+    return false;
+  } else {
+    if (!tensorflow::Flags::Parse(argc, argv, flags)) return false;
+  }
+  auto& dump_options = *GraphVizDumpOptions::singleton();
+  dump_options.graphviz_first_array = parsed_flags.graphviz_first_array.value();
+  dump_options.graphviz_last_array = parsed_flags.graphviz_last_array.value();
+  dump_options.dump_graphviz_video = parsed_flags.dump_graphviz_video.value();
+  dump_options.dump_graphviz = parsed_flags.dump_graphviz.value();
+
+  return true;
+}
+
+void ReadModelFlagsFromCommandLineFlags(
+    const ParsedModelFlags& parsed_model_flags, ModelFlags* model_flags) {
+  toco::port::CheckInitGoogleIsDone("InitGoogle is not done yet");
+
+// "batch" flag only exists internally
+#ifdef PLATFORM_GOOGLE
+  CHECK(!((base::SpecifiedOnCommandLine("batch") &&
+           parsed_model_flags.variable_batch.specified())))
+      << "The --batch and --variable_batch flags are mutually exclusive.";
+#endif
+  CHECK(!(parsed_model_flags.output_array.specified() &&
+          parsed_model_flags.output_arrays.specified()))
+      << "The --output_array and --output_arrays flags are mutually exclusive.";
+
+  if (parsed_model_flags.output_array.specified()) {
+    model_flags->add_output_arrays(parsed_model_flags.output_array.value());
+  }
+
+  if (parsed_model_flags.output_arrays.specified()) {
+    std::vector<string> output_arrays =
+        absl::StrSplit(parsed_model_flags.output_arrays.value(), ',');
+    for (const string& output_array : output_arrays) {
+      model_flags->add_output_arrays(output_array);
+    }
+  }
+
+  const bool uses_single_input_flags =
+      parsed_model_flags.input_array.specified() ||
+      parsed_model_flags.mean_value.specified() ||
+      parsed_model_flags.std_value.specified() ||
+      parsed_model_flags.input_shape.specified();
+
+  const bool uses_multi_input_flags =
+      parsed_model_flags.input_arrays.specified() ||
+      parsed_model_flags.mean_values.specified() ||
+      parsed_model_flags.std_values.specified() ||
+      parsed_model_flags.input_shapes.specified();
+
+  QCHECK(!(uses_single_input_flags && uses_multi_input_flags))
+      << "Use either the singular-form input flags (--input_array, "
+         "--input_shape, --mean_value, --std_value) or the plural form input "
+         "flags (--input_arrays, --input_shapes, --mean_values, --std_values), "
+         "but not both forms within the same command line.";
+
+  if (parsed_model_flags.input_array.specified()) {
+    QCHECK(uses_single_input_flags);
+    model_flags->add_input_arrays()->set_name(
+        parsed_model_flags.input_array.value());
+  }
+  if (parsed_model_flags.input_arrays.specified()) {
+    QCHECK(uses_multi_input_flags);
+    for (const auto& input_array :
+         absl::StrSplit(parsed_model_flags.input_arrays.value(), ',')) {
+      model_flags->add_input_arrays()->set_name(string(input_array));
+    }
+  }
+  if (parsed_model_flags.mean_value.specified()) {
+    QCHECK(uses_single_input_flags);
+    model_flags->mutable_input_arrays(0)->set_mean_value(
+        parsed_model_flags.mean_value.value());
+  }
+  if (parsed_model_flags.mean_values.specified()) {
+    QCHECK(uses_multi_input_flags);
+    std::vector<string> mean_values =
+
absl::StrSplit(parsed_model_flags.mean_values.value(), ','); + QCHECK(mean_values.size() == model_flags->input_arrays_size()); + for (int i = 0; i < mean_values.size(); ++i) { + char* last = nullptr; + model_flags->mutable_input_arrays(i)->set_mean_value( + strtod(mean_values[i].data(), &last)); + CHECK(last != mean_values[i].data()); + } + } + if (parsed_model_flags.std_value.specified()) { + QCHECK(uses_single_input_flags); + model_flags->mutable_input_arrays(0)->set_std_value( + parsed_model_flags.std_value.value()); + } + if (parsed_model_flags.std_values.specified()) { + QCHECK(uses_multi_input_flags); + std::vector std_values = + absl::StrSplit(parsed_model_flags.std_values.value(), ','); + QCHECK(std_values.size() == model_flags->input_arrays_size()); + for (int i = 0; i < std_values.size(); ++i) { + char* last = nullptr; + model_flags->mutable_input_arrays(i)->set_std_value( + strtod(std_values[i].data(), &last)); + CHECK(last != std_values[i].data()); + } + } + if (parsed_model_flags.input_shape.specified()) { + QCHECK(uses_single_input_flags); + if (model_flags->input_arrays().empty()) { + model_flags->add_input_arrays(); + } + auto* shape = model_flags->mutable_input_arrays(0)->mutable_shape(); + shape->Clear(); + const IntList& list = parsed_model_flags.input_shape.value(); + for (auto& dim : list.elements) { + shape->Add(dim); + } + } + if (parsed_model_flags.input_shapes.specified()) { + QCHECK(uses_multi_input_flags); + std::vector input_shapes = + absl::StrSplit(parsed_model_flags.input_shapes.value(), ':'); + QCHECK(input_shapes.size() == model_flags->input_arrays_size()); + for (int i = 0; i < input_shapes.size(); ++i) { + auto* shape = model_flags->mutable_input_arrays(i)->mutable_shape(); + shape->Clear(); + if (input_shapes[i].empty()) { + // empty i.e. 0-dimensional input shape. + // Unfortunately, the current toco::InputArray + // proto does not allow to distinguish between a known 0-D shape, + // and an unknown shape. Indeed, shape is currently a plain array, + // and it being empty means unknown shape. So here, we import a + // 0-D shape as a 1-D shape of size. + // TODO(benoitjacob): fix toco::InputArray to allow 0-D shape, + // probably by making shape an optional message, + // encapsulating the array. 
+ shape->Add(1); + } else { + for (const auto& dim_str : absl::StrSplit(input_shapes[i], ',')) { + int size; + CHECK(absl::SimpleAtoi(dim_str, &size)) + << "Failed to parse input_shape: " << input_shapes[i]; + shape->Add(size); + } + } + } + } + +#define READ_MODEL_FLAG(name) \ + do { \ + if (parsed_model_flags.name.specified()) { \ + model_flags->set_##name(parsed_model_flags.name.value()); \ + } \ + } while (false) + + READ_MODEL_FLAG(variable_batch); + READ_MODEL_FLAG(drop_control_dependency); + +#undef READ_MODEL_FLAG + + for (const auto& element : parsed_model_flags.rnn_states.value().elements) { + auto* rnn_state_proto = model_flags->add_rnn_states(); + for (const auto& kv_pair : element) { + const string& key = kv_pair.first; + const string& value = kv_pair.second; + if (key == "state_array") { + rnn_state_proto->set_state_array(value); + } else if (key == "back_edge_source_array") { + rnn_state_proto->set_back_edge_source_array(value); + } else if (key == "size") { + int32 size = 0; + CHECK(absl::SimpleAtoi(value, &size)); + CHECK_GT(size, 0); + rnn_state_proto->set_size(size); + } else if (key == "manually_create") { + CHECK_EQ(absl::AsciiStrToLower(value), "true"); + rnn_state_proto->set_manually_create(true); + } else { + LOG(FATAL) << "Unknown key '" << key << "' in --rnn_states"; + } + } + CHECK(rnn_state_proto->has_state_array() && + rnn_state_proto->has_back_edge_source_array() && + rnn_state_proto->has_size()) + << "--rnn_states must include state_array, back_edge_source_array and " + "size."; + } + + for (const auto& element : parsed_model_flags.model_checks.value().elements) { + auto* model_check_proto = model_flags->add_model_checks(); + for (const auto& kv_pair : element) { + const string& key = kv_pair.first; + const string& value = kv_pair.second; + if (key == "count_type") { + model_check_proto->set_count_type(value); + } else if (key == "count_min") { + int32 count = 0; + CHECK(absl::SimpleAtoi(value, &count)); + CHECK_GE(count, -1); + model_check_proto->set_count_min(count); + } else if (key == "count_max") { + int32 count = 0; + CHECK(absl::SimpleAtoi(value, &count)); + CHECK_GE(count, -1); + model_check_proto->set_count_max(count); + } else { + LOG(FATAL) << "Unknown key '" << key << "' in --model_checks"; + } + } + } +} + +ParsedModelFlags* UncheckedGlobalParsedModelFlags(bool must_already_exist) { + static auto* flags = [must_already_exist]() { + if (must_already_exist) { + fprintf(stderr, __FILE__ + ":" + "GlobalParsedModelFlags() used without initialization\n"); + fflush(stderr); + abort(); + } + return new toco::ParsedModelFlags; + }(); + return flags; +} + +ParsedModelFlags* GlobalParsedModelFlags() { + return UncheckedGlobalParsedModelFlags(true); +} + +void ParseModelFlagsOrDie(int* argc, char* argv[]) { + // TODO(aselle): in the future allow Google version to use + // flags, and only use this mechanism for open source + auto* flags = UncheckedGlobalParsedModelFlags(false); + string msg; + bool model_success = + toco::ParseModelFlagsFromCommandLineFlags(argc, argv, &msg, flags); + if (!model_success || !msg.empty()) { + // Log in non-standard way since this happens pre InitGoogle. 
+ fprintf(stderr, "%s", msg.c_str()); + fflush(stderr); + abort(); + } +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/model_cmdline_flags.h b/tensorflow/contrib/lite/toco/model_cmdline_flags.h new file mode 100644 index 0000000000..dfa3d3c1ef --- /dev/null +++ b/tensorflow/contrib/lite/toco/model_cmdline_flags.h @@ -0,0 +1,43 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_MODEL_CMDLINE_FLAGS_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_MODEL_CMDLINE_FLAGS_H_ + +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/args.h" +#include "tensorflow/contrib/lite/toco/model_flags.pb.h" + +namespace toco { +// Parse and remove arguments for models (in toco). Returns true if parsing +// is successful. msg has the usage string if there was an error or +// "--help" was specified +bool ParseModelFlagsFromCommandLineFlags( + int* argc, char* argv[], string* msg, + ParsedModelFlags* parsed_model_flags_ptr); +// Populate the ModelFlags proto with model data. +void ReadModelFlagsFromCommandLineFlags( + const ParsedModelFlags& parsed_model_flags, ModelFlags* model_flags); +// Parse the global model flags to a static +void ParseModelFlagsOrDie(int* argc, char* argv[]); +// Get the global parsed model flags +ParsedModelFlags* GlobalParsedModelFlags(); + +} // namespace toco + + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_MODEL_CMDLINE_FLAGS_H_ diff --git a/tensorflow/contrib/lite/toco/model_flags.proto b/tensorflow/contrib/lite/toco/model_flags.proto new file mode 100644 index 0000000000..743e08b16f --- /dev/null +++ b/tensorflow/contrib/lite/toco/model_flags.proto @@ -0,0 +1,119 @@ +// Copyright 2017 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +syntax = "proto2"; + +package toco; + +// Next ID to USE: 5. +message InputArray { + // Name of the input arrays, i.e. the arrays from which input activations + // will be read. + optional string name = 1; + + // Shape of the input. For many applications the dimensions are {batch, + // height, width, depth}. Often the batch is left "unspecified" by providing + // a value of -1. + // + // The last dimension is typically called 'depth' or 'channels'. For example, + // for an image model taking RGB images as input, this would have the value 3. 
+ repeated int32 shape = 2; + + // mean_value and std_value parameters control the interpretation of raw input + // activation values (elements of the input array) as real numbers. The + // mapping is given by: + // + // real_value = (raw_input_value - mean_value) / std_value + // + // In particular, the defaults (mean_value=0, std_value=1) yield + // real_value = raw_input_value. Often, non-default values are used in image + // models. For example, an image model taking uint8 image channel values as + // its raw inputs, in [0, 255] range, may use mean_value=128, std_value=128 to + // map them into the interval [-1, 1). + // + // Note: this matches exactly the meaning of mean_value and std_value in + // (TensorFlow via LegacyFedInput). + optional float mean_value = 3; + optional float std_value = 4 [default = 1.]; +} + +// ModelFlags encodes properties of a model that, depending on the file +// format, may or may not be recorded in the model file. The purpose of +// representing these properties in ModelFlags is to allow passing them +// separately from the input model file, for instance as command-line +// parameters, so that we can offer a single uniform interface that can +// handle files from different input formats. +// +// For each of these properties, and each supported file format, we +// detail in comments below whether the property exists in the given file +// format. +// +// Obsolete flags that have been removed: +// optional int32 input_depth = 3; +// optional int32 input_width = 4; +// optional int32 input_height = 5; +// optional int32 batch = 6 [ default = 1]; +// optional float mean_value = 7; +// optional float std_value = 8 [default = 1.]; +// optional int32 input_dims = 11 [ default = 4]; +// repeated int32 input_shape = 13; +// +// Next ID to USE: 16. +message ModelFlags { + // Information about the input arrays, i.e. the arrays from which input + // activations will be read. + repeated InputArray input_arrays = 1; + + // Name of the output arrays, i.e. the arrays into which output activations + // will be written. + repeated string output_arrays = 2; + + // If true, the model accepts an arbitrary batch size. Mutually exclusive with + // the 'batch' field: at most one of these two fields can be set. + optional bool variable_batch = 10; + + message RnnState { + optional string state_array = 1; + optional string back_edge_source_array = 2; + optional int32 size = 3; + // TODO(benoitjacob): manually_create is a temporary hack: + // due to discrepancies between the current toco dims tracking and + // TensorFlow shapes, for some models we need to manually create RNN state + // arrays with a specified shape. + // Maybe we should actually implement back-edges as operators of their own, + // which would remove the need for much special-casing, including here, + // we could probably consistently let PropagateFixedSizes handle state + // arrays. + optional bool manually_create = 4; + } + repeated RnnState rnn_states = 12; + + // Checks applied to the model, typically after toco's comprehensive + // graph transformations. + // Next ID to USE: 4. + message ModelCheck { + // Use the name of a type of operator to check its counts. + // Use "Total" for overall operator counts. + // Use "Arrays" for overall array counts. + optional string count_type = 1 [default = "None"]; + // A count of zero is a meaningful check, so negative used to mean disable. + optional int32 count_min = 2 [default = -1]; + // If count_max < count_min, then count_min is only allowed value. 
+ optional int32 count_max = 3 [default = -1]; + } + repeated ModelCheck model_checks = 14; + + // If true, ignore control dependency requirements in input TensorFlow + // GraphDef. Otherwise an error will be raised upon control dependency inputs. + optional bool drop_control_dependency = 15; +} diff --git a/tensorflow/contrib/lite/toco/python/BUILD b/tensorflow/contrib/lite/toco/python/BUILD new file mode 100644 index 0000000000..92246a8aed --- /dev/null +++ b/tensorflow/contrib/lite/toco/python/BUILD @@ -0,0 +1,76 @@ +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc") +load("//tensorflow:tensorflow.bzl", "tf_py_test") + +cc_library( + name = "toco_python_api", + srcs = ["toco_python_api.cc"], + hdrs = ["toco_python_api.h"], + deps = [ + "//tensorflow/contrib/lite/toco:model_flags_proto_cc", + "//tensorflow/contrib/lite/toco:toco_flags_proto_cc", + "//tensorflow/contrib/lite/toco:toco_port", + "//tensorflow/contrib/lite/toco:toco_tooling", + "//tensorflow/core:lib", + "//util/python:python_headers", + ], +) + +tf_py_wrap_cc( + name = "tensorflow_wrap_toco", + srcs = ["toco.i"], + deps = [ + ":toco_python_api", + "//tensorflow/contrib/lite/toco:model_flags_proto_cc", + "//tensorflow/contrib/lite/toco:toco_flags_proto_cc", + "//util/python:python_headers", + "@com_google_absl//absl/strings", + ], +) + +py_binary( + name = "toco_from_protos", + srcs = ["toco_from_protos.py"], + srcs_version = "PY2AND3", + deps = [ + ":tensorflow_wrap_toco", + "//tensorflow/python:platform", + ], +) + +py_binary( + name = "toco_wrapper", + srcs = ["toco_wrapper.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow:tensorflow_py", + ], +) + +tf_py_test( + name = "toco_from_protos_test", + srcs = ["toco_from_protos_test.py"], + additional_deps = [ + "//tensorflow:tensorflow_py", + "//tensorflow/contrib/lite/toco:model_flags_proto_py", + "//tensorflow/contrib/lite/toco:toco_flags_proto_py", + ], + data = [ + ":toco_from_protos", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/contrib/lite/toco/python/toco.i b/tensorflow/contrib/lite/toco/python/toco.i new file mode 100644 index 0000000000..3787cba4a3 --- /dev/null +++ b/tensorflow/contrib/lite/toco/python/toco.i @@ -0,0 +1,32 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +%include "std_string.i" + +%{ +#include "tensorflow/contrib/lite/toco/python/toco_python_api.h" +%} + +namespace toco { + +// Convert a model represented in `input_contents`. `model_flags_proto` +// describes model parameters. `toco_flags_proto` describes conversion +// parameters (see relevant .protos for more information). Returns a string +// representing the contents of the converted model. 
+PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw, + PyObject* toco_flags_proto_txt_raw, + PyObject* input_contents_txt_raw); + +} // namespace toco \ No newline at end of file diff --git a/tensorflow/contrib/lite/toco/python/toco_from_protos.py b/tensorflow/contrib/lite/toco/python/toco_from_protos.py new file mode 100644 index 0000000000..c0b032083b --- /dev/null +++ b/tensorflow/contrib/lite/toco/python/toco_from_protos.py @@ -0,0 +1,63 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Python console command to invoke TOCO from serialized protos.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import sys +from tensorflow.contrib.lite.toco.python import tensorflow_wrap_toco +from tensorflow.python.platform import app + +FLAGS = None + + +def execute(unused_args): + model_str = open(FLAGS.model_proto_file, "rb").read() + toco_str = open(FLAGS.toco_proto_file, "rb").read() + input_str = open(FLAGS.model_input_file, "rb").read() + + output_str = tensorflow_wrap_toco.TocoConvert(model_str, toco_str, input_str) + open(FLAGS.model_output_file, "wb").write(output_str) + sys.exit(0) + + +def main(): + global FLAGS + parser = argparse.ArgumentParser( + description="Invoke toco using protos as input.") + parser.add_argument( + "model_proto_file", + type=str, + help="File containing serialized proto that describes the model.") + parser.add_argument( + "toco_proto_file", + type=str, + help="File containing serialized proto describing how TOCO should run.") + parser.add_argument( + "model_input_file", type=str, help="Input model is read from this file.") + parser.add_argument( + "model_output_file", + type=str, + help="Result of applying TOCO conversion is written here.") + + FLAGS, unparsed = parser.parse_known_args() + + app.run(main=execute, argv=[sys.argv[0]] + unparsed) + + +if __name__ == "__main__": + main() diff --git a/tensorflow/contrib/lite/toco/python/toco_from_protos_test.py b/tensorflow/contrib/lite/toco/python/toco_from_protos_test.py new file mode 100644 index 0000000000..2a593beeca --- /dev/null +++ b/tensorflow/contrib/lite/toco/python/toco_from_protos_test.py @@ -0,0 +1,96 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import tempfile + +import tensorflow as tf +from tensorflow.contrib.lite.toco import model_flags_pb2 +from tensorflow.contrib.lite.toco import toco_flags_pb2 +from tensorflow.python.platform import googletest +from tensorflow.python.platform import resource_loader + + +def TensorName(x): + """Get the canonical (non foo:0 name).""" + return x.name.split(":")[0] + + +class TocoFromProtosTest(googletest.TestCase): + + def _run(self, sess, in_tensor, out_tensor, should_succeed): + """Use toco binary to check conversion from graphdef to tflite. + + Args: + sess: Active TensorFlow session containing graph. + in_tensor: TensorFlow tensor to use as input. + out_tensor: TensorFlow tensor to use as output. + should_succeed: Whether this is a valid conversion. + """ + # Build all protos and extract graphdef + graph_def = sess.graph_def + toco_flags = toco_flags_pb2.TocoFlags() + toco_flags.input_format = toco_flags_pb2.TENSORFLOW_GRAPHDEF + toco_flags.output_format = toco_flags_pb2.TFLITE + toco_flags.input_types.append(toco_flags_pb2.FLOAT) + toco_flags.inference_type = toco_flags_pb2.FLOAT + model_flags = model_flags_pb2.ModelFlags() + input_array = model_flags.input_arrays.add() + input_array.name = TensorName(in_tensor) + input_array.shape.extend(map(int, in_tensor.get_shape())) + model_flags.output_arrays.append(TensorName(out_tensor)) + # Shell out to run toco (in case it crashes) + with tempfile.NamedTemporaryFile() as fp_toco, \ + tempfile.NamedTemporaryFile() as fp_model, \ + tempfile.NamedTemporaryFile() as fp_input, \ + tempfile.NamedTemporaryFile() as fp_output: + fp_model.write(model_flags.SerializeToString()) + fp_toco.write(toco_flags.SerializeToString()) + fp_input.write(graph_def.SerializeToString()) + fp_model.flush() + fp_toco.flush() + fp_input.flush() + tflite_bin = resource_loader.get_path_to_datafile("toco_from_protos") + cmdline = " ".join([ + tflite_bin, fp_model.name, fp_toco.name, fp_input.name, fp_output.name + ]) + exitcode = os.system(cmdline) + if exitcode == 0: + stuff = fp_output.read() + self.assertEqual(stuff is not None, should_succeed) + else: + self.assertFalse(should_succeed) + + def test_toco(self): + """Run a couple of TensorFlow graphs against TOCO through the python bin.""" + with tf.Session() as sess: + img = tf.placeholder(name="img", dtype=tf.float32, shape=(1, 64, 64, 3)) + val = img + tf.constant([1., 2., 3.]) + tf.constant([1., 4., 4.]) + out = tf.identity(val, name="out") + out2 = tf.sin(val, name="out2") + # This is a valid mdoel + self._run(sess, img, out, True) + # This uses an invalid function. + # TODO(aselle): Check to make sure a warning is included. + self._run(sess, img, out2, True) + # This is an identity graph, which doesn't work + self._run(sess, img, img, False) + + +if __name__ == "__main__": + googletest.main() diff --git a/tensorflow/contrib/lite/toco/python/toco_python_api.cc b/tensorflow/contrib/lite/toco/python/toco_python_api.cc new file mode 100644 index 0000000000..8a5e483f3f --- /dev/null +++ b/tensorflow/contrib/lite/toco/python/toco_python_api.cc @@ -0,0 +1,85 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include "tensorflow/core/platform/logging.h" + +#include "tensorflow/contrib/lite/toco/model_flags.pb.h" +#include "tensorflow/contrib/lite/toco/python/toco_python_api.h" +#include "tensorflow/contrib/lite/toco/toco_flags.pb.h" +#include "tensorflow/contrib/lite/toco/toco_port.h" +#include "tensorflow/contrib/lite/toco/toco_tooling.h" +#include "tensorflow/contrib/lite/toco/toco_types.h" + +namespace toco { + +#if PY_MAJOR_VERSION >= 3 +#define TOCO_PY_TO_CPPSTRING PyBytes_AsStringAndSize +#define TOCO_FROM_CPPSTRING_TO_PY PyBytes_FromStringAndSize +#else +#define TOCO_PY_TO_CPPSTRING PyString_AsStringAndSize +#define TOCO_FROM_CPPSTRING_TO_PY PyString_FromStringAndSize +#endif + +// NOTE(aselle): We are using raw PyObject's here because we want to make +// sure we input and output bytes rather than unicode strings for Python3. +PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw, + PyObject* toco_flags_proto_txt_raw, + PyObject* input_contents_txt_raw) { + // Use Python C API to validate and convert arguments. In py3 (bytes), + // in py2 (str). + auto ConvertArg = [&](PyObject* obj, bool* error) { + char* buf; + Py_ssize_t len; + if (TOCO_PY_TO_CPPSTRING(obj, &buf, &len) == -1) { + *error = true; + return std::string(); + } else { + *error = false; + return std::string(buf, len); + } + }; + + bool error; + std::string model_flags_proto_txt = + ConvertArg(model_flags_proto_txt_raw, &error); + if (error) return nullptr; + std::string toco_flags_proto_txt = + ConvertArg(toco_flags_proto_txt_raw, &error); + if (error) return nullptr; + std::string input_contents_txt = ConvertArg(input_contents_txt_raw, &error); + if (error) return nullptr; + + // Use toco to produce new outputs + toco::ModelFlags model_flags; + if (!model_flags.ParseFromString(model_flags_proto_txt)) { + LOG(FATAL) << "Model proto failed to parse." << std::endl; + } + toco::TocoFlags toco_flags; + if (!toco_flags.ParseFromString(toco_flags_proto_txt)) { + LOG(FATAL) << "Toco proto failed to parse." << std::endl; + } + std::unique_ptr model = + toco::Import(toco_flags, model_flags, input_contents_txt); + toco::Transform(toco_flags, model.get()); + string output_file_contents_txt; + Export(toco_flags, *model, &output_file_contents_txt); + + // Convert arguments back to byte (py3) or str (py2) + return TOCO_FROM_CPPSTRING_TO_PY(output_file_contents_txt.data(), + output_file_contents_txt.size()); +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/python/toco_python_api.h b/tensorflow/contrib/lite/toco/python/toco_python_api.h new file mode 100644 index 0000000000..dc378353f7 --- /dev/null +++ b/tensorflow/contrib/lite/toco/python/toco_python_api.h @@ -0,0 +1,33 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef _THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_PYTHON_TOCO_PYTHON_API_H_ +#define _THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_PYTHON_TOCO_PYTHON_API_H_ + +#include +#include + +namespace toco { + +// Convert a model represented in `input_contents`. `model_flags_proto` +// describes model parameters. `toco_flags_proto` describes conversion +// parameters (see relevant .protos for more information). Returns a string +// representing the contents of the converted model. +PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw, + PyObject* toco_flags_proto_txt_raw, + PyObject* input_contents_txt_raw); + +} // namespace toco + +#endif // _THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_PYTHON_TOCO_PYTHON_API_H_ diff --git a/tensorflow/contrib/lite/toco/python/toco_wrapper.py b/tensorflow/contrib/lite/toco/python/toco_wrapper.py new file mode 100644 index 0000000000..e39b5f22c7 --- /dev/null +++ b/tensorflow/contrib/lite/toco/python/toco_wrapper.py @@ -0,0 +1,35 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Wrapper for runninmg toco binary embedded in pip site-package. + +NOTE: this mainly exists since PIP setup.py cannot install binaries to bin/. +It can only install Python "console-scripts." This will work as a console +script. See tools/pip_package/setup.py (search for CONSOLE_SCRIPTS). +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import sys +import tensorflow as tf + + +def main(): + # Pip installs the binary in aux-bin off of main site-package install. + # Just find it and exec, passing all arguments in the process. + # TODO(aselle): it is unfortunate to use all of tensorflow to lookup binary. + binary = os.path.join(tf.__path__[0], 'aux-bin/toco') + os.execvp(binary, sys.argv) diff --git a/tensorflow/contrib/lite/toco/runtime/common.h b/tensorflow/contrib/lite/toco/runtime/common.h new file mode 100644 index 0000000000..bd55544f57 --- /dev/null +++ b/tensorflow/contrib/lite/toco/runtime/common.h @@ -0,0 +1,26 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_RUNTIME_COMMON_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_RUNTIME_COMMON_H_ + +#ifndef ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK +#ifdef GEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK +#define ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK +#endif +#endif + +#include "tensorflow/contrib/lite/kernels/internal/common.h" + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_RUNTIME_COMMON_H_ diff --git a/tensorflow/contrib/lite/toco/runtime/types.h b/tensorflow/contrib/lite/toco/runtime/types.h new file mode 100644 index 0000000000..df63b2d59e --- /dev/null +++ b/tensorflow/contrib/lite/toco/runtime/types.h @@ -0,0 +1,32 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_RUNTIME_TYPES_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_RUNTIME_TYPES_H_ + +#include "tensorflow/contrib/lite/kernels/internal/common.h" +#include "tensorflow/contrib/lite/kernels/internal/compatibility.h" +#include "tensorflow/contrib/lite/kernels/internal/types.h" + +namespace toco { + +// TODO(ahentz): These are just stopgaps for now, untils we move all +// the code over to tflite. 
+using tflite::Dims; +using tflite::FusedActivationFunctionType; +using tflite::RequiredBufferSizeForDims; + +} // namespace toco + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_RUNTIME_TYPES_H_ diff --git a/tensorflow/contrib/lite/toco/tensorflow_graph_matching/BUILD b/tensorflow/contrib/lite/toco/tensorflow_graph_matching/BUILD new file mode 100644 index 0000000000..0c1a1141fc --- /dev/null +++ b/tensorflow/contrib/lite/toco/tensorflow_graph_matching/BUILD @@ -0,0 +1,102 @@ +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Apache 2.0 + +load( + "//tensorflow:tensorflow.bzl", + "tf_cc_test", +) + +cc_library( + name = "cluster_utils", + srcs = [ + "cluster_utils.cc", + ], + hdrs = [ + "cluster_utils.h", + ], + deps = [ + "//tensorflow/contrib/lite/toco:toco_port", + ], +) + +cc_library( + name = "cluster", + srcs = [ + "cluster.cc", + ], + hdrs = [ + "cluster.h", + ], + deps = [ + ":cluster_utils", + "//tensorflow/contrib/lite/toco:model", + "//tensorflow/contrib/lite/toco:tooling_util", + "//tensorflow/core:protos_all_cc", + ], +) + +cc_library( + name = "resolve_svdf", + srcs = [ + "resolve_svdf.cc", + ], + hdrs = [ + "resolve_svdf.h", + ], + visibility = ["//visibility:public"], + deps = [ + ":cluster", + ":cluster_utils", + "//tensorflow/contrib/lite/toco:model", + "//tensorflow/contrib/lite/toco:toco_port", + "//tensorflow/contrib/lite/toco:tooling_util", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "@protobuf_archive//:protobuf_headers", + ], +) + +tf_cc_test( + name = "resolve_svdf_test", + srcs = ["resolve_svdf_test.cc"], + deps = [ + ":cluster", + ":cluster_utils", + ":resolve_cluster", + ":resolve_svdf", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "resolve_cluster", + srcs = [ + "resolve_cluster.cc", + ], + hdrs = [ + "resolve_cluster.h", + ], + visibility = ["//visibility:public"], + deps = [ + ":cluster", + ":cluster_utils", + ":resolve_svdf", + "//tensorflow/contrib/lite/toco:tooling_util", + "//tensorflow/core:protos_all_cc", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster.cc b/tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster.cc new file mode 100644 index 0000000000..98a130ea39 --- /dev/null +++ b/tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster.cc @@ -0,0 +1,52 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster.h" + +namespace toco { + +void Cluster::SetGraphDefInfo(const tensorflow::GraphDef* graph_def) { + graph_def_ = graph_def; + for (const tensorflow::NodeDef& node : graph_def_->node()) { + if (StrContains(node.name(), name_)) { + nodes_.push_back(&node); + } + } +} + +bool Cluster::FindClusterInputsAndOutputs() { + // For every node N in the graph: + // If N belongs to this cluster C, then each of N's inputs that are not part + // of C are then inputs of C. + // If N does not belong to cluster C, then each of N's inputs that belong to C + // are then outputs of C. + for (const tensorflow::NodeDef& node : graph_def_->node()) { + if (StrContains(node.name(), name_)) { + for (int i = 0; i < node.input_size(); i++) { + if (!StrContains(node.input(i), name_)) { + inputs_.push_back(node.input(i)); + } + } + } else { + for (int i = 0; i < node.input_size(); i++) { + if (StrContains(node.input(i), name_)) { + outputs_.push_back(node.input(i)); + } + } + } + } + return (!inputs_.empty()) && (!outputs_.empty()); +} + +} // end namespace toco diff --git a/tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster.h b/tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster.h new file mode 100644 index 0000000000..18ff73ac39 --- /dev/null +++ b/tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster.h @@ -0,0 +1,101 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_GRAPH_MATCHING_CLUSTER_H +#define TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_GRAPH_MATCHING_CLUSTER_H + +#include +#include + +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster_utils.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" + +namespace toco { + +// The base class for Cluster. A cluster is group of nodes all related to each +// other because their name match a given "pattern", which shows they all belong +// to a composite op supported in TFLite. The nodes in a cluster will be +// collapsed into a single composite op node plus a series of constant nodes +// holding the input parameters to that node. The nodes in a cluster are assumed +// to be using the same device. By changing the "pattern" we can have different +// subclasses of the base Cluster class. 
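
As an aside for this write-up (not part of the patch): the name-based boundary detection that Cluster::FindClusterInputsAndOutputs performs in cluster.cc above can be sketched with plain strings. The NodeStub type, the FindClusterBoundary function, and the sample node names below are invented for illustration; StrContains is a local copy of the helper declared in cluster_utils.h:

#include <iostream>
#include <string>
#include <vector>

// Illustrative stand-in for a tensorflow::NodeDef: just a name and its inputs.
struct NodeStub {
  std::string name;
  std::vector<std::string> inputs;
};

// Local copy of the substring helper from cluster_utils.cc.
bool StrContains(const std::string& x, const std::string& pattern) {
  return x.find(pattern) != std::string::npos;
}

// Mimics Cluster::FindClusterInputsAndOutputs: edges that cross the cluster
// boundary (identified purely by substring matching on node names) become the
// cluster's inputs and outputs.
void FindClusterBoundary(const std::vector<NodeStub>& nodes,
                         const std::string& cluster_name,
                         std::vector<std::string>* cluster_inputs,
                         std::vector<std::string>* cluster_outputs) {
  for (const NodeStub& node : nodes) {
    if (StrContains(node.name, cluster_name)) {
      // A node inside the cluster: its external inputs feed the cluster.
      for (const std::string& input : node.inputs) {
        if (!StrContains(input, cluster_name)) cluster_inputs->push_back(input);
      }
    } else {
      // A node outside the cluster: any input it takes from inside the
      // cluster is an output of the cluster.
      for (const std::string& input : node.inputs) {
        if (StrContains(input, cluster_name)) cluster_outputs->push_back(input);
      }
    }
  }
}

int main() {
  std::vector<NodeStub> nodes = {
      {"img", {}},
      {"svdf1/weights", {}},
      {"svdf1/matmul", {"img", "svdf1/weights"}},
      {"out", {"svdf1/matmul"}},
  };
  std::vector<std::string> inputs, outputs;
  FindClusterBoundary(nodes, "svdf1", &inputs, &outputs);
  for (const auto& s : inputs) std::cout << "input: " << s << "\n";    // img
  for (const auto& s : outputs) std::cout << "output: " << s << "\n";  // svdf1/matmul
  return 0;
}

The Cluster class declared next keeps this kind of state (name_, nodes_, inputs_, outputs_) together with the replacement nodes generated when the cluster is collapsed.
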
+class Cluster { + public: + virtual ~Cluster() {} + + virtual void CreateNodes() = 0; + + // Save the following info from the original GraphDef this cluster is from: + // 1- a pointer to the GraphDef + // 2- All the nodes in GraphDef which belong to this cluster. + void SetGraphDefInfo(const tensorflow::GraphDef* graph_def); + + const string& GetName() const { return name_; } + + const std::vector>& GetNewNodes() const { + return new_nodes_; + } + + const std::vector& GetNodes() { return nodes_; } + + void SetName(const string& name) { name_ = name; } + + void SetDevice(const string& device) { device_ = device; } + + // Find the input(s) and output(s) of this Cluster. + bool FindClusterInputsAndOutputs(); + + protected: + string name_; + string device_; + std::vector inputs_; + std::vector outputs_; + + // Used to hold the pointers to nodes which are in this cluster. These nodes + // are pointing to the nodes in graph_def_. + std::vector nodes_; + + // Used to cache the newly generated nodes: like the nodes created by + // collapsing Const nodes, or the nodes which is used to show the composite + // op. + std::vector> new_nodes_; + + const tensorflow::GraphDef* graph_def_; /*Not owned*/ +}; + +// A factory interface for cluster class. +// It defines a virtual function interface which is responsible for creating +// a cluster. Each cluster factory is responsible to pack a cluster of nodes +// into a cluster using a name-based pattern matching approach. +class ClusterFactoryInterface { + public: + virtual ~ClusterFactoryInterface() {} + + // Creates a cluster of nodes using a name-based pattern matching approach. It + // uses a node as a seed and if its name matches a certain pattern, then it + // builds the cluster around that node. + virtual std::unique_ptr CreateCluster( + const tensorflow::NodeDef& node, + const tensorflow::GraphDef& graph_def) const = 0; +}; + +} // end namespace toco + +#endif // TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_GRAPH_MATCHING_CLUSTER_H diff --git a/tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster_utils.cc b/tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster_utils.cc new file mode 100644 index 0000000000..14c3cd6487 --- /dev/null +++ b/tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster_utils.cc @@ -0,0 +1,34 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include +#include "tensorflow/contrib/lite/toco/toco_types.h" +namespace toco { + +bool StrContains(const string& x, const string& search_pattern) { + return x.find(search_pattern) != string::npos; +} + +void Transpose2DTensor(const float* tensor, int row, int col, + float* transposed_tensor) { + float* result = transposed_tensor; + for (int r = 0; r < row; ++r) { + for (int c = 0; c < col; ++c) { + *(result + c * row) = *tensor++; + } + ++result; + } +} + +} // end namespace toco diff --git a/tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster_utils.h b/tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster_utils.h new file mode 100644 index 0000000000..a15e480e70 --- /dev/null +++ b/tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster_utils.h @@ -0,0 +1,33 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_GRAPH_MATCHING_CLUSTERUTILS_H +#define TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_GRAPH_MATCHING_CLUSTERUTILS_H + +#include + +namespace toco { + +// Check if string x includes string search_pattern. +bool StrContains(const string& x, const string& search_pattern); + +// Transpose a 2D tensor of size row * col pointed by "tensor" and return the +// results in "transposed_tensor". "transposed_tensor" must be pre-allocated +// by the same size as "tensor". +void Transpose2DTensor(const float* tensor, int row, int col, + float* transposed_tensor); + +} // end namespace toco + +#endif // TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_GRAPH_MATCHING_CLUSTERUTILS_H diff --git a/tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_cluster.cc b/tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_cluster.cc new file mode 100644 index 0000000000..fddf6cc836 --- /dev/null +++ b/tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_cluster.cc @@ -0,0 +1,151 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_cluster.h" + +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster.h" +#include "tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster_utils.h" +#include "tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_svdf.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" + +namespace toco { + +using tensorflow::GraphDef; +using tensorflow::NodeDef; + +void AddNodeToGraph(const NodeDef& node, + const std::vector& cluster_names, GraphDef* graph) { + NodeDef* new_node = graph->add_node(); + new_node->set_op(node.op()); + new_node->set_name(node.name()); + new_node->set_device(node.device()); + // If the inputs are coming from a node which belongs to another cluster, then + // those inputs are renamed to the source cluster name. Otherwise the original + // input name is used. + for (const string& node_input : node.input()) { + bool input_from_cluster = false; + for (const string& cluster_name : cluster_names) { + if (StrContains(node_input, cluster_name) && + !StrContains(node.name(), cluster_name)) { + new_node->add_input(cluster_name); + input_from_cluster = true; + break; + } + } + if (!input_from_cluster) { + new_node->add_input(node_input); + } + } + for (const auto& attr : node.attr()) { + (*new_node->mutable_attr())[attr.first] = attr.second; + } +} + +bool FindCluster(const ClusterFactoryInterface& cluster_factory, + const GraphDef& graph_def, + std::unordered_map* is_node_in_cluster, + std::vector>* clusters) { + for (const NodeDef& node : graph_def.node()) { + // If the node is not assigned to any cluster, then we check if it belong to + // the cluster_factory. + bool node_in_cluster = (*is_node_in_cluster)[node.name()]; + if (!node_in_cluster) { + std::unique_ptr cluster = + cluster_factory.CreateCluster(node, graph_def); + if (cluster) { + // Label all the nodes in is_node_in_cluster which are in this cluster + // as belonged to this cluster. + for (const NodeDef* cluster_node : cluster->GetNodes()) { + (*is_node_in_cluster)[cluster_node->name()] = true; + } + clusters->push_back(std::move(cluster)); + } + } + } + return (!clusters->empty()); +} + +std::unique_ptr MaybeResolveClusters( + const GraphDef& graph_def, + const std::vector& cluster_factories) { + std::unique_ptr pruned_graph(new GraphDef); + // The structure to keep track of which cluster each node is assigned to, and + // to initialize them to all un-assigned, + std::unordered_map is_node_in_cluster; + for (const NodeDef& node : graph_def.node()) { + is_node_in_cluster[node.name()] = false; + } + + std::vector cluster_names; + std::vector> all_clusters; + // Find the clusters for all available cluster factories. 
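+  // [Editorial note.] Factories are applied in order and share
+  // is_node_in_cluster, so a node claimed by an earlier factory is not
+  // considered again by later ones (see the check in FindCluster above).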
+ for (const ClusterFactoryInterface* cluster_factory : cluster_factories) { + std::vector> clusters; + if (FindCluster(*cluster_factory, graph_def, &is_node_in_cluster, + &clusters)) { + for (auto itr = clusters.begin(); itr != clusters.end(); ++itr) { + cluster_names.push_back((*itr)->GetName()); + (*itr)->CreateNodes(); + all_clusters.push_back(std::move(*itr)); + } + } + } + + for (const std::unique_ptr& cluster : all_clusters) { + for (const std::unique_ptr& src_node : + cluster->GetNewNodes()) { + // Add it to the output GraphDef. + AddNodeToGraph(*src_node, cluster_names, pruned_graph.get()); + } + } + + // Add any node which is not part of a cluster. + for (const NodeDef& node : graph_def.node()) { + bool node_in_cluster = is_node_in_cluster[node.name()]; + if (!node_in_cluster) { + AddNodeToGraph(node, cluster_names, pruned_graph.get()); + } + } + + if (pruned_graph->node_size() == 0) { + return nullptr; + } else { + return pruned_graph; + } +} + +std::unique_ptr MaybeReplaceCompositeSubgraph( + const GraphDef& tf_graph) { + SvdfClusterFactory svdf_cluster_factory; + + std::vector cluster_factories; + cluster_factories.push_back(&svdf_cluster_factory); + + std::unique_ptr pruned_graph = + MaybeResolveClusters(tf_graph, cluster_factories); + + // Copy function definitions + *(pruned_graph->mutable_library()) = tf_graph.library(); + return pruned_graph; +} + +} // end namespace toco diff --git a/tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_cluster.h b/tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_cluster.h new file mode 100644 index 0000000000..7d33dd1885 --- /dev/null +++ b/tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_cluster.h @@ -0,0 +1,63 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_GRAPH_MATCHING_RESOLVE_CLUSTER_H +#define TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_GRAPH_MATCHING_RESOLVE_CLUSTER_H + +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster.h" +#include "tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_svdf.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" + +namespace toco { + +// Given a graph info and a list of cluster classes (cluster_factories), it +// partitions the graph to clusters, and then collapses each cluster into their +// corresponding composite ops. It generates a new graph using the newly +// generated composite ops. Each cluster factory is responsible to recognize a +// cluster of nodes into a cluster using a name-based pattern matching approach. +std::unique_ptr MaybeResolveClusters( + const tensorflow::GraphDef& graph_def, + const std::vector& cluster_factories); + +// Adds a node to a given graph. The added node will be a copy of a given source +// node, except for the inputs. 
If the inputs are coming from a node which +// belongs to another cluster, then those inputs are renamed to the source +// cluster name. +void AddNodeToGraph(const tensorflow::NodeDef& node, + const std::vector& cluster_names, + tensorflow::GraphDef* graph); + +// Given a graph and a cluster class, it finds all the nodes which belong to a +// given class factory, encapsulate them inside a cluster of the given type and +// returns a vector of those clusters. It also labels the nodes in that graph if +// they belong to the generated clusters. +bool FindCluster(const ClusterFactoryInterface& cluster_factory, + const tensorflow::GraphDef& graph_def, + std::unordered_map* is_node_in_cluster, + std::vector>* clusters); + +// Receives a graph and generates another graph by replacing the cluster of +// nodes which matches a given composite op. Each composite op is represented +// using a class factory. +std::unique_ptr MaybeReplaceCompositeSubgraph( + const tensorflow::GraphDef& tf_graph); + +} // end namespace toco + +#endif // CONTRIB_LITE_TOCO_TENSORFLOW_GRAPH_MATCHING_RESOLVE_CLUSTER_H diff --git a/tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_svdf.cc b/tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_svdf.cc new file mode 100644 index 0000000000..d6a099817c --- /dev/null +++ b/tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_svdf.cc @@ -0,0 +1,285 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_svdf.h" + +#include +#include +#include +#include +#include +#include + +#include "google/protobuf/map.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster.h" +#include "tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster_utils.h" +#include "tensorflow/contrib/lite/toco/toco_port.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" + +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" +#include "tensorflow/core/platform/logging.h" + +using tensorflow::GraphDef; +using tensorflow::NodeDef; + +namespace toco { + +namespace { + +// Receives a vector of cluster nodes and returns only those which are array +// partitions (of type 'Const' and have the pattern 'part_<.*>' in their name. +// Since these nodes are connected to a Concatenate node, it makes sure the +// axis value input of the Concatenate operator is 0. 
+void FilterPartitionedConstNodes( + const string& const_pattern, + const std::vector& cluster_nodes, + std::vector* const_node_parts) { + for (const NodeDef* node : cluster_nodes) { + string node_name_to_upper = node->name(); + std::transform(node_name_to_upper.begin(), node_name_to_upper.end(), + node_name_to_upper.begin(), ::toupper); + if (StrContains(node->name(), const_pattern) && node->op() == "Const") { + if (StrContains(node_name_to_upper, "/PART_")) { + const_node_parts->push_back(node); + } else if (StrContains(node->name(), "AXIS") && + StrContains(node->name(), "CONCAT")) { + // For now only supporting Concatenate on Axix 0 + const auto& value_attr = node->attr().at("value"); + const tensorflow::TensorProto& tensor = value_attr.tensor(); + CHECK_EQ(tensor.int_val(0), 0); + } + } + } + sort(const_node_parts->begin(), const_node_parts->end(), + [](const NodeDef* a, const NodeDef* b) { + return (a->name().compare(b->name()) < 0 && + (a->name().size() < b->name().size())); + }); +} + +} // namespace + +// SvdfCluster methods + +int SvdfCluster::InferFilterRank() { + for (const NodeDef* node : nodes_) { + if (StrContains(node->name(), "Reshape/shape")) { + const auto& value_attr = node->attr().at("value"); + const tensorflow::TensorProto& tensor = value_attr.tensor(); + std::vector shape_values( + tensor.tensor_content().size() / sizeof(int), 0); + port::CopyToBuffer(tensor.tensor_content(), + reinterpret_cast(shape_values.data())); + CHECK_EQ(shape_values.size(), 3); + // shape_value array is arranged as: + // [num_units, rank, -1] + CHECK_EQ(shape_values[2], -1); + return shape_values[1]; + } + } + return -1; +} + +void SvdfCluster::CreateNodes() { + for (const string& const_pattern : const_node_patterns_) { + CreateConstNode(const_pattern); + } + std::unique_ptr svdf_node(new NodeDef); + svdf_node->set_op("Svdf"); + svdf_node->set_name(name_); + svdf_node->set_device(device_); + + // Add the main input. + svdf_node->add_input(inputs_[0]); + + // Add the rest of the inputs to Svdf cell: weights and bias. + CHECK(new_nodes_.size() == 3 || new_nodes_.size() == 2); + string* weights_feature_input = svdf_node->add_input(); + string* weights_time_input = svdf_node->add_input(); + string* bias_input; + if (new_nodes_.size() == 3) { + bias_input = svdf_node->add_input(); + } + for (const std::unique_ptr& node : new_nodes_) { + const string node_name = node->name(); + if (StrContains(node_name, "SVDF_weights_feature")) { + *weights_feature_input = node_name; + } else if (StrContains(node_name, "SVDF_weights_time")) { + *weights_time_input = node_name; + } else if (StrContains(node_name, "SVDF_bias")) { + CHECK(bias_input) << "Bias input cannot be provided when there are only " + "two Const input nodes!"; + *bias_input = node_name; + } else { + // Unexpected input for Svdf op. + LOG(FATAL) << "Unexpected input node for SVDF op! Accepted inputs are: " + "weights_feature, weights_time and bias."; + } + } + const int rank = InferFilterRank(); + CHECK_GT(rank, 0); + + // Add Svdf activation and rank. + string activation_function = + StrContains(outputs_[0], "Relu") ? "Relu" : "None"; + (*svdf_node->mutable_attr())["ActivationFunction"].set_s(activation_function); + (*svdf_node->mutable_attr())["Rank"].set_i(rank); + + // Finally add it to the list of the newly created nodes. + new_nodes_.push_back(std::move(svdf_node)); +} + +void SvdfCluster::CreateConstNode(const string& const_pattern) { + // Find the nodes with pattern like: "const_pattern"/part_xxx of type Const. 
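+  // [Editorial illustration.] With const_pattern = "SVDF_weights_feature",
+  // this matches partitioned Const nodes such as
+  // "Svdf1/SVDF_weights_feature/part_0" (as in resolve_svdf_test.cc); the
+  // parts are then merged below into a single Const input of the Svdf op.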
+ std::vector const_node_parts; + FilterPartitionedConstNodes(const_pattern, nodes_, &const_node_parts); + + if (const_node_parts.empty()) return; + + bool transpose_tensor_value = + StrContains(const_pattern, "SVDF_weights_feature"); + + // Merge them if necessary. + std::unique_ptr merged_node(new NodeDef); + MaybeMergeConstNodes(const_node_parts, transpose_tensor_value, merged_node); + new_nodes_.push_back(std::move(merged_node)); +} + +void SvdfCluster::MaybeMergeConstNodes( + const std::vector& const_node_parts, + bool transpose_tensor_value, + const std::unique_ptr& merged_node) { + merged_node->set_name(const_node_parts[0]->name()); + merged_node->set_op("Const"); + merged_node->set_device(const_node_parts[0]->device()); + (*merged_node->mutable_attr())["dtype"].set_type( + const_node_parts[0]->attr().at("dtype").type()); + + // Figuring out Value attribute for the merged node. + // Assuming the partitioning is done on Axis 0. + // The attributes which are inferred: + // * Shape and dimensions + // * Float content values + + // Inferring shape and dimension + int dim0_size = 0; + int dim1_size = 1; + tensorflow::TensorProto* allocated_tensor = + (*merged_node->mutable_attr())["value"].mutable_tensor(); + tensorflow::TensorShapeProto* allocated_tensor_shape = + allocated_tensor->mutable_tensor_shape(); + auto tensor_shape_dim0 = allocated_tensor_shape->add_dim(); + int allocated_content_flat_size = 0; + for (int i = 0; i < const_node_parts.size(); i++) { + const auto& value_attr = const_node_parts[i]->attr().at("value"); + const tensorflow::TensorProto& tensor = value_attr.tensor(); + if (i == 0) { + allocated_tensor->set_dtype(tensor.dtype()); + } else { + CHECK_EQ(allocated_tensor->dtype(), tensor.dtype()); + } + allocated_content_flat_size += tensor.tensor_content().size(); + CHECK(tensor.has_tensor_shape()); + const tensorflow::TensorShapeProto shape = tensor.tensor_shape(); + dim0_size += shape.dim(0).size(); + for (int d = 1; d < shape.dim_size(); d++) { + if (i == 0) { + allocated_tensor_shape->add_dim()->set_size(shape.dim(d).size()); + allocated_tensor_shape->set_unknown_rank(shape.unknown_rank()); + dim1_size *= shape.dim(d).size(); + } else { + CHECK_EQ(shape.dim(d).size(), allocated_tensor_shape->dim(d).size()); + CHECK_EQ(allocated_tensor_shape->unknown_rank(), shape.unknown_rank()); + } + } + } + + // Copying the float content from each array partition. + std::unique_ptr allocated_content( + new char[allocated_content_flat_size]); + char* content_ptr = allocated_content.get(); + for (int i = 0; i < const_node_parts.size(); i++) { + const auto& value_attr = const_node_parts[i]->attr().at("value"); + const tensorflow::TensorProto& tensor = value_attr.tensor(); + port::CopyToBuffer(tensor.tensor_content(), content_ptr); + content_ptr += tensor.tensor_content().size(); + } + + // Transpose the tensor if needed. + if (transpose_tensor_value) { + // We use dimension 0 to show the row size for the tensor. + // We use multiplication of the rest of dimension size to for the col size + // of the tensor. + std::unique_ptr transposed_tensor( + new float[dim0_size * dim1_size]); + Transpose2DTensor(reinterpret_cast(allocated_content.get()), + dim0_size, dim1_size, transposed_tensor.get()); + allocated_tensor_shape->clear_dim(); + allocated_tensor_shape->add_dim()->set_size(dim1_size); + allocated_tensor_shape->add_dim()->set_size(dim0_size); + + // Set the tensor attributes. 
+ allocated_tensor->set_tensor_content( + string(reinterpret_cast(transposed_tensor.get()), + allocated_content_flat_size)); + } else { + tensor_shape_dim0->set_size(dim0_size); + + // Set the tensor attributes. + allocated_tensor->set_tensor_content( + string(reinterpret_cast(allocated_content.get()), + allocated_content_flat_size)); + } +} + +// SvdfClusterFactory methods + +std::unique_ptr SvdfClusterFactory::CreateCluster( + const NodeDef& node, const GraphDef& graph_def) const { + std::vector node_patterns = {"SVDF_weights_feature", + "SVDF_weights_time", "SVDF_bias"}; + + string node_name_to_upper = node.name(); + std::transform(node_name_to_upper.begin(), node_name_to_upper.end(), + node_name_to_upper.begin(), ::toupper); + std::unique_ptr cluster = nullptr; + if (node_name_to_upper.find("SVDF", 0) != string::npos) { + size_t weights_pos = node.name().find(node_patterns[0]); + if (weights_pos != string::npos) { + // Assuming the node name has a pattern like: + // "SOMESTRING1/CELLNAME/SEARCH_PATTERN/SOMESTRING2", we use + // CELLNAME as the cluster name. + size_t cell_pos = node.name().rfind("/", weights_pos - 2) + 1; + string cell_name = + node.name().substr(cell_pos, weights_pos - cell_pos - 1); + cluster = std::unique_ptr(new SvdfCluster); + cluster->SetName(cell_name); + cluster->SetDevice(node.device()); + cluster->SetGraphDefInfo(&graph_def); + CHECK(cluster->FindClusterInputsAndOutputs()); + + for (const string& const_pattern : node_patterns) { + cluster->AddConstNodePattern(const_pattern); + } + } + } + return std::move(cluster); +} + +} // end namespace toco diff --git a/tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_svdf.h b/tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_svdf.h new file mode 100644 index 0000000000..c4c6c34117 --- /dev/null +++ b/tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_svdf.h @@ -0,0 +1,82 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_GRAPH_MATCHING_RESOLVE_SVDF_H +#define TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_GRAPH_MATCHING_RESOLVE_SVDF_H + +#include +#include + +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster.h" +#include "tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster_utils.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" + +namespace toco { + +class SvdfCluster : public Cluster { + public: + // For this cluster, it collapses all the nodes in nodes_ into a composite op + // and it returns all the newly generated ops in new_nodes_. + void CreateNodes() override; + + // A helper function to set the pattern of Const nodes which CreateNodes() + // should handle specially. 
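+  // [Editorial illustration.] SvdfClusterFactory::CreateCluster (in
+  // resolve_svdf.cc) registers the patterns "SVDF_weights_feature",
+  // "SVDF_weights_time" and "SVDF_bias" through this method.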
+ void AddConstNodePattern(const string& const_pattern) { + const_node_patterns_.push_back(const_pattern); + } + + virtual ~SvdfCluster() {} + + private: + // The main function which is used to create Const nodes for this cluster. + // These Const nodes are the inputs to the composite op generated for this + // cluster. + void CreateConstNode(const string& const_pattern); + + // Receives a vector of Const nodes, merge them (if necessary) and returns + // only one Const node holding all the arrays contents. It transposes it if + // needed. + void MaybeMergeConstNodes( + const std::vector& const_node_parts, + bool transpose_tensor_value, + const std::unique_ptr& merged_node); + + // Infer the value of Svdf filter rank, by looking up a reshape operator which + // is used for 'output' which reshapes output from [num_filters, batch, 1] + // shape to [num_units, rank, batch] shape. The 2nd shape element is rank. + int InferFilterRank(); + + std::vector const_node_patterns_; +}; + +class SvdfClusterFactory : public ClusterFactoryInterface { + public: + // Creates a cluster of nodes using a name-based pattern matching approach. It + // uses a node as a seed and if its name matches a certain pattern, then it + // builds the cluster around that node. + // This factory expects nodes which have "SVDF_weights_feature" and + // "SVDF_weights_time" pattern in their names (and optionally "SVDF_bias") + // and it creates an SVDF Op from them. + std::unique_ptr CreateCluster( + const tensorflow::NodeDef& node, + const tensorflow::GraphDef& graph_def) const; +}; + +} // end namespace toco + +#endif // TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_GRAPH_MATCHING_RESOLVE_SVDF_H diff --git a/tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_svdf_test.cc b/tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_svdf_test.cc new file mode 100644 index 0000000000..664e828c19 --- /dev/null +++ b/tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_svdf_test.cc @@ -0,0 +1,212 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_svdf.h" + +#include +#include +#include + +#include +#include +#include "tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster.h" +#include "tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster_utils.h" +#include "tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_cluster.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/logging.h" + +using tensorflow::GraphDef; +using tensorflow::NodeDef; + +namespace toco { + +class ResolveSvdfTest : public ::testing::Test { + public: + ResolveSvdfTest() { + AddNewNode("Input1", "Const", {}); + AddNewNode("Svdf1/SVDF_weights_feature/part_0", "Const", {}, + {0.1, 0.2, 0.3}); + AddNewNode("Svdf1/SVDF_weights_feature/part_0/read", "Identity", + {"Svdf1/SVDF_weights_feature/part_0"}); + AddNewNode("Svdf1/SVDF_weights_time/part_0", "Const", {}, {0.1, 0.2, 0.3}); + AddNewNode("Svdf1/SVDF_weights_time/part_0/read", "Identity", + {"Svdf1/SVDF_weights_time/part_0"}); + + AddNewNode("Svdf1/f1", "SVDF_F1", + {"Input1", "Svdf1/SVDF_weights_feature/part_0/read"}); + AddNewNode("Svdf1/f2", "SVDF_F2", + {"Svdf1/SVDF_weights_time/part_0/read", "Svdf1/f1"}); + AddNewNode("Svdf1/Relu", "Relu", {"Svdf1/f2"}); + AddShapeNode("Svdf1/Reshape/shape", {10, 1, -1}); + AddNewNode("Output1", "Const", {"Svdf1/Relu"}); + + AddNewNode("Input2", "Const", {}); + AddNewNode("Svdf2/SVDF_weights_feature/part_0", "Const", {}, + {0.1, 0.2, 0.3}); + AddNewNode("Svdf2/SVDF_weights_feature/part_0/read", "Identity", + {"Svdf2/SVDF_weights_feature/part_0"}); + AddNewNode("Svdf2/SVDF_weights_time/part_0", "Const", {}, {0.1, 0.2, 0.3}); + AddNewNode("Svdf2/SVDF_weights_time/part_0/read", "Identity", + {"Svdf2/SVDF_weights_time/part_0"}); + + AddNewNode("Svdf2/f1", "SVDF_F1", + {"Input1", "Svdf2/SVDF_weights_feature/part_0/read"}); + AddNewNode("Svdf2/f2", "SVDF_F2", + {"Svdf2/SVDF_weights_time/part_0/read", "Svdf2/f1"}); + AddNewNode("Svdf2/Relu", "Relu", {"Svdf2/f2"}); + AddShapeNode("Svdf2/Reshape/shape", {10, 2, -1}); + AddNewNode("Output2", "Const", {"Svdf2/Relu"}); + } + + ~ResolveSvdfTest() override {} + + protected: + void AddNewNode(const string& name, const string& op, + const std::vector& inputs) { + NodeDef* node = graph_.add_node(); + node->set_name(name); + node->set_op(op); + node->set_device(""); + for (int i = 0; i < inputs.size(); i++) { + node->add_input(); + node->set_input(i, inputs[i]); + } + } + + void AddNewNode(const string& name, const string& op, + const std::vector& inputs, + const std::vector& values) { + NodeDef* node = graph_.add_node(); + node->set_name(name); + node->set_op(op); + node->set_device(""); + for (int i = 0; i < inputs.size(); i++) { + node->add_input(); + node->set_input(i, inputs[i]); + } + // Add the float vector as an attribute to the node. 
+ (*node->mutable_attr())["dtype"].set_type(tensorflow::DT_FLOAT); + tensorflow::TensorProto* allocated_tensor = new tensorflow::TensorProto; + tensorflow::TensorShapeProto* allocated_tesnor_shape = + new tensorflow::TensorShapeProto; + auto tensor_shape_dim0 = allocated_tesnor_shape->add_dim(); + tensor_shape_dim0->set_size(values.size()); + allocated_tensor->set_allocated_tensor_shape(allocated_tesnor_shape); + allocated_tensor->set_tensor_content( + string(reinterpret_cast(values.data()), + values.size() * sizeof(float))); + (*node->mutable_attr())["value"].set_allocated_tensor(allocated_tensor); + } + + void AddShapeNode(const string& name, const std::vector& values) { + NodeDef* node = graph_.add_node(); + node->set_name(name); + node->set_op("Const"); + node->set_device(""); + // Add the float vector as an attribute to the node. + (*node->mutable_attr())["dtype"].set_type(tensorflow::DT_INT32); + tensorflow::TensorProto* allocated_tensor = new tensorflow::TensorProto; + tensorflow::TensorShapeProto* allocated_tesnor_shape = + new tensorflow::TensorShapeProto; + auto tensor_shape_dim0 = allocated_tesnor_shape->add_dim(); + tensor_shape_dim0->set_size(values.size()); + allocated_tensor->set_allocated_tensor_shape(allocated_tesnor_shape); + allocated_tensor->set_tensor_content( + string(reinterpret_cast(values.data()), + values.size() * sizeof(int))); + (*node->mutable_attr())["value"].set_allocated_tensor(allocated_tensor); + } + + GraphDef graph_; + SvdfClusterFactory svdf_cluster_factory_; + std::vector> clusters_; +}; + +TEST_F(ResolveSvdfTest, TestTranspose2DTensor) { + static float matrix[] = {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.}; + static float expected_transposed_matrix[] = {1., 5., 9., 2., 6., 10., + 3., 7., 11., 4., 8., 12.}; + float* transposed_matrix = new float[12]; + Transpose2DTensor(matrix, 3, 4, transposed_matrix); + + std::vector actual; + actual.insert( + actual.end(), transposed_matrix, + transposed_matrix + sizeof(expected_transposed_matrix) / sizeof(float)); + std::vector expected; + expected.insert(expected.end(), expected_transposed_matrix, + expected_transposed_matrix + + sizeof(expected_transposed_matrix) / sizeof(float)); + delete[] transposed_matrix; +} + +TEST_F(ResolveSvdfTest, TestResolveSvdfFlow) { + std::unordered_map is_node_in_cluster; + for (const NodeDef& node : graph_.node()) { + is_node_in_cluster[node.name()] = false; + } + + std::vector cluster_names; + CHECK(FindCluster(svdf_cluster_factory_, graph_, &is_node_in_cluster, + &clusters_)); + + for (const std::unique_ptr& cluster : clusters_) { + cluster_names.push_back(cluster->GetName()); + cluster->CreateNodes(); + } + + EXPECT_THAT(cluster_names, + testing::UnorderedElementsAreArray({"Svdf1", "Svdf2"})); + + std::vector new_node_names; + std::vector content_array(3); + for (const std::unique_ptr& cluster : clusters_) { + // After CreateNodes in each cluster we have three nodes: Svdf, + // weights_feature and weights_time. + CHECK_EQ(cluster->GetNewNodes().size(), 3); + for (const std::unique_ptr& node : + cluster->GetNewNodes()) { + new_node_names.push_back(node->name()); + if (node->op() == "Const") { + CHECK_EQ(node->attr().at("dtype").type(), tensorflow::DT_FLOAT); + toco::port::CopyToBuffer( + node->attr().at("value").tensor().tensor_content(), + reinterpret_cast(content_array.data())); + EXPECT_THAT(content_array, + testing::UnorderedElementsAreArray({0.1, 0.2, 0.3})); + } else { + // Checking the Svdf node attributes (rank and activation type) are + // correct. 
+ if (node->name() == "Svdf1") { + CHECK_EQ(node->attr().at("Rank").i(), 1); + } else if (node->name() == "Svdf2") { + CHECK_EQ(node->attr().at("Rank").i(), 2); + } + CHECK_EQ(node->attr().at("ActivationFunction").s(), "Relu"); + } + } + } + EXPECT_THAT(new_node_names, testing::UnorderedElementsAreArray( + {"Svdf2/SVDF_weights_feature/part_0", + "Svdf2/SVDF_weights_time/part_0", "Svdf2", + "Svdf1/SVDF_weights_feature/part_0", + "Svdf1/SVDF_weights_time/part_0", "Svdf1"})); +} + +} // end namespace toco diff --git a/tensorflow/contrib/lite/toco/tensorflow_util.cc b/tensorflow/contrib/lite/toco/tensorflow_util.cc new file mode 100644 index 0000000000..82e2800ca2 --- /dev/null +++ b/tensorflow/contrib/lite/toco/tensorflow_util.cc @@ -0,0 +1,197 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/toco/tensorflow_util.h" + +#include +#include +#include + +#ifdef GOOGLE_PLATFORM +#include "file/logging/log_lines.h" +#endif +#include "google/protobuf/map.h" +#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" +#include "tensorflow/contrib/lite/toco/toco_port.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +using tensorflow::AttrValue; +using tensorflow::GraphDef; + +void LogDumpGraphDef(int log_level, const string& message, + const GraphDef& tf_graph) { + if (!VLOG_IS_ON(log_level)) { + return; + } + std::set ops; + for (const auto& node : tf_graph.node()) { + ops.insert(node.op()); + } + string dump; + toco::port::AppendF(&dump, R"MSG( +BEGIN DUMP OF TENSORFLOW GRAPHDEF (%s) +There are %d nodes. 
+There are %zu different op types: +)MSG", message, tf_graph.node_size(), ops.size()); + for (const auto& op : ops) { + toco::port::AppendF(&dump, " %s\n", op); + } + dump.append(R"MSG( +PROTO DUMP +)MSG"); + for (const auto& node : tf_graph.node()) { + toco::port::AppendF(&dump, R"MSG( +BEGIN NODE: name = %s + op = %s + inputs = [ +)MSG", node.name(), node.op()); + for (const auto& input : node.input()) { + toco::port::AppendF(&dump, " %s\n", input); + } + dump.append(" ]\n"); + for (const auto& attr : node.attr()) { + toco::port::AppendF(&dump, " ATTR: name = %s\n", attr.first); + if (attr.second.value_case() == AttrValue::kFunc) { + dump.append(" func\n"); + } else if (attr.second.value_case() == AttrValue::kPlaceholder) { + toco::port::AppendF(&dump, " placeholder: %s\n", + attr.second.placeholder()); + } else if (attr.second.value_case() == AttrValue::kS) { + dump.append(" string:\n"); + dump.append(R"MSG( + BEGIN EMBEDDED STRING +)MSG"); + const auto& lines = absl::StrSplit(attr.second.s(), '\n'); + for (const auto& line : lines) { + toco::port::AppendF(&dump, " %s\n", line); + } + dump.append(R"MSG( + END EMBEDDED STRING +)MSG"); + } else if (attr.second.value_case() == AttrValue::kI) { + toco::port::AppendF(&dump, " int: %lld\n", attr.second.i()); + } else if (attr.second.value_case() == AttrValue::kF) { + toco::port::AppendF(&dump, " float: %g\n", attr.second.f()); + } else if (attr.second.value_case() == AttrValue::kB) { + toco::port::AppendF(&dump, " bool: %s\n", + attr.second.b() ? "true" : "false"); + } else if (attr.second.value_case() == AttrValue::kType) { + toco::port::AppendF(&dump, " type: %s\n", + tensorflow::DataType_Name(attr.second.type())); + } else if (attr.second.value_case() == AttrValue::kShape) { + dump.append(" shape: [ "); + const auto& shape = attr.second.shape(); + for (int i = 0; i < shape.dim_size(); i++) { + toco::port::AppendF(&dump, "%lld ", shape.dim(i).size()); + } + dump.append("]\n"); + } else if (attr.second.value_case() == AttrValue::kTensor) { + const auto& tensor = attr.second.tensor(); + dump.append(" TENSOR:\n"); + toco::port::AppendF(&dump, " type: %s\n", + tensorflow::DataType_Name(tensor.dtype())); + const auto& shape = tensor.tensor_shape(); + dump.append(" shape: [ "); + for (int i = 0; i < shape.dim_size(); i++) { + toco::port::AppendF(&dump, "%lld ", shape.dim(i).size()); + } + dump.append("]\n"); + if (!tensor.tensor_content().empty()) { + toco::port::AppendF(&dump, " tensor_content: %zu bytes\n", + tensor.tensor_content().size()); + } + if (tensor.dtype() == tensorflow::DT_INT32) { + CHECK_EQ(0, tensor.tensor_content().size() % sizeof(int32)); + const int size = tensor.tensor_content().size() / sizeof(int32); + std::vector data(size); + toco::port::CopyToBuffer(tensor.tensor_content(), + reinterpret_cast(data.data())); + const int kMaxValsToPrint = 4; + dump.append(" tensor_content as ints: [ "); + for (int i = 0; i < kMaxValsToPrint && i < size; i++) { + toco::port::AppendF(&dump, "%d ", data[i]); + } + if (size > kMaxValsToPrint) { + dump.append("... 
"); + } + dump.append("]\n"); + } + if (tensor.dtype() == tensorflow::DT_FLOAT) { + CHECK_EQ(0, tensor.tensor_content().size() % sizeof(float)); + const int size = tensor.tensor_content().size() / sizeof(float); + std::vector data(size); + toco::port::CopyToBuffer(tensor.tensor_content(), + reinterpret_cast(data.data())); + const int kMaxValsToPrint = 4; + dump.append(" tensor_content as floats: [ "); + for (int i = 0; i < kMaxValsToPrint && i < size; i++) { + toco::port::AppendF(&dump, "%g ", data[i]); + } + if (size > kMaxValsToPrint) { + dump.append("... "); + } + dump.append("]\n"); + } + if (tensor.int_val_size()) { + toco::port::AppendF(&dump, " int_val: %d ints: [ ", + tensor.int_val_size()); + const int kMaxValsToPrint = 4; + for (int i = 0; i < kMaxValsToPrint && i < tensor.int_val_size(); + i++) { + toco::port::AppendF(&dump, "%d ", tensor.int_val(i)); + } + if (tensor.int_val_size() > kMaxValsToPrint) { + dump.append("... "); + } + dump.append("]\n"); + } + if (tensor.float_val_size()) { + toco::port::AppendF(&dump, " float_val: %d floats: [ ", + tensor.float_val_size()); + const int kMaxValsToPrint = 4; + for (int i = 0; i < kMaxValsToPrint && i < tensor.float_val_size(); + i++) { + toco::port::AppendF(&dump, "%g ", tensor.float_val(i)); + } + if (tensor.float_val_size() > kMaxValsToPrint) { + dump.append("... "); + } + dump.append("]\n"); + } + if (tensor.string_val_size()) { + toco::port::AppendF(&dump, " string_val: %d strings\n", + tensor.string_val_size()); + } + } else if (attr.second.value_case() == AttrValue::kList) { + dump.append(" LIST\n"); + } + } + dump.append("END NODE\n"); + } + toco::port::AppendF(&dump, "END DUMP OF TENSORFLOW GRAPHDEF (%s)\n", message); +#if defined(GOOGLE_PLATFORM) + VLOG_LINES(log_level, dump); +#else + VLOG(log_level) << dump; +#endif +} +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/tensorflow_util.h b/tensorflow/contrib/lite/toco/tensorflow_util.h new file mode 100644 index 0000000000..152b4f7a72 --- /dev/null +++ b/tensorflow/contrib/lite/toco/tensorflow_util.h @@ -0,0 +1,32 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_UTIL_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_UTIL_H_ + +#include +#include + +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" + +namespace toco { + +void LogDumpGraphDef(int log_level, const string& message, + const tensorflow::GraphDef& tf_graph); + +} // namespace toco + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_UTIL_H_ diff --git a/tensorflow/contrib/lite/toco/tflite/BUILD b/tensorflow/contrib/lite/toco/tflite/BUILD new file mode 100644 index 0000000000..e910e3957f --- /dev/null +++ b/tensorflow/contrib/lite/toco/tflite/BUILD @@ -0,0 +1,142 @@ +licenses(["notice"]) # Apache 2.0 + +load( + "//tensorflow:tensorflow.bzl", + "tf_cc_test", +) + +cc_library( + name = "operator", + srcs = [ + "operator.cc", + ], + hdrs = [ + "builtin_operator.h", + "custom_operator.h", + "operator.h", + "simple_operator.h", + ], + deps = [ + ":types", + "//tensorflow/contrib/lite/schema:schema_fbs", + "//tensorflow/contrib/lite/toco:model", + "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/memory", + "@flatbuffers//:flatbuffers", + ], +) + +tf_cc_test( + name = "operator_test", + srcs = [ + "operator_test.cc", + ], + deps = [ + ":operator", + "//tensorflow/contrib/lite/toco:tooling_util", + "//tensorflow/core:protos_all_cc", + "@com_google_googletest//:gtest_main", + "@flatbuffers//:flatbuffers", + ], +) + +cc_library( + name = "types", + srcs = [ + "types.cc", + ], + hdrs = [ + "types.h", + ], + deps = [ + "//tensorflow/contrib/lite/schema:schema_fbs", + "//tensorflow/contrib/lite/toco:model", + ], +) + +tf_cc_test( + name = "types_test", + srcs = [ + "types_test.cc", + ], + deps = [ + ":types", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "export", + srcs = [ + "export.cc", + ], + hdrs = [ + "export.h", + ], + visibility = ["//visibility:public"], + deps = [ + ":operator", + ":types", + "//tensorflow/contrib/lite:schema_fbs_version", + "//tensorflow/contrib/lite/schema:schema_fbs", + "//tensorflow/contrib/lite/toco:model", + "//tensorflow/contrib/lite/toco:tooling_util", + "@com_google_absl//absl/strings", + "@flatbuffers//:flatbuffers", + ], +) + +tf_cc_test( + name = "export_test", + srcs = [ + "export_test.cc", + ], + deps = [ + ":export", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "import", + srcs = [ + "import.cc", + ], + hdrs = [ + "import.h", + ], + visibility = ["//visibility:public"], + deps = [ + ":operator", + ":types", + "//tensorflow/contrib/lite/schema:schema_fbs", + "//tensorflow/contrib/lite/toco:model", + "@flatbuffers//:flatbuffers", + ], +) + +tf_cc_test( + name = "import_test", + srcs = [ + "import_test.cc", + ], + deps = [ + ":import", + "//tensorflow/contrib/lite:schema_fbs_version", + "//tensorflow/contrib/lite/schema:schema_fbs", + "@com_google_googletest//:gtest_main", + "@flatbuffers//:flatbuffers", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/contrib/lite/toco/tflite/builtin_operator.h b/tensorflow/contrib/lite/toco/tflite/builtin_operator.h new file mode 100644 index 0000000000..93cc79ddb6 --- /dev/null +++ b/tensorflow/contrib/lite/toco/tflite/builtin_operator.h @@ 
-0,0 +1,74 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_BUILTIN_OPERATOR_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_BUILTIN_OPERATOR_H_ + +#include "absl/memory/memory.h" +#include "tensorflow/contrib/lite/toco/tflite/operator.h" + +namespace toco { + +namespace tflite { + +// Builtin operators have special TF Lite objects describing their options. +// This class has the boilerplate code for creating those. +// +// Template arguments: +// - T1 must derive from ::toco::Operator. +// - T2 must be one of TF Lite's objects defining Builtin Options, such as +// ::tflite::Conv2DOptions. +template +class BuiltinOperator : public BaseOperator { + public: + using TocoOperator = T1; + using TfLiteOptions = T2; + + BuiltinOperator(::tflite::BuiltinOperator op, OperatorType type) + : BaseOperator(::tflite::EnumNameBuiltinOperator(op), type) {} + + // Build the configuration object in the given flatbuffer builder. Return + // its offset. + virtual flatbuffers::Offset WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const = 0; + + // Read options from the TF Lite object and set the corresponding values in + // the tf.mini operator. + virtual void ReadOptions(const TfLiteOptions& opt, + TocoOperator* op) const = 0; + + Options Serialize(const Operator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + auto options = WriteOptions(static_cast(op), builder); + return Options::Builtin(TfLiteEnum, options.Union()); + } + + std::unique_ptr Deserialize( + const BuiltinOptions* builtin_options, + const CustomOptions* custom_options) const override { + auto op = absl::make_unique(); + auto* options = static_cast(builtin_options); + if (options) { + ReadOptions(*options, op.get()); + } + return std::unique_ptr(op.release()); + } +}; + +} // namespace tflite + +} // namespace toco + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_BUILTIN_OPERATOR_H_ diff --git a/tensorflow/contrib/lite/toco/tflite/custom_operator.h b/tensorflow/contrib/lite/toco/tflite/custom_operator.h new file mode 100644 index 0000000000..1a4bfac7d4 --- /dev/null +++ b/tensorflow/contrib/lite/toco/tflite/custom_operator.h @@ -0,0 +1,74 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_CUSTOM_OPERATOR_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_CUSTOM_OPERATOR_H_ + +#include "flatbuffers/flexbuffers.h" +#include "absl/memory/memory.h" +#include "tensorflow/contrib/lite/toco/tflite/operator.h" + +namespace toco { + +namespace tflite { + +// Custom operators have a generic byte buffer describing their options. This +// class provides the boilerplate code for populating those options using +// flexbuffers. Note that most of toco's operators will likely be supported +// as builtin operators in TF Lite. +// +// Template argument T must derive from ::toco::Operator. +template +class CustomOperator : public BaseOperator { + public: + using TocoOperator = T; + using BaseOperator::BaseOperator; + + // Populate the given flexbuffer with options obtained from the tf.mini + // operator. + virtual void WriteOptions(const TocoOperator& op, + flexbuffers::Builder* fbb) const {} + + // Set options in the given tf.mini operator using values from the flexbuffer + // map. + virtual void ReadOptions(const flexbuffers::Map& m, TocoOperator* op) const {} + + Options Serialize(const Operator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + flexbuffers::Builder fbb; + fbb.Map( + [&]() { WriteOptions(static_cast(op), &fbb); }); + fbb.Finish(); + return Options::Custom(builder->CreateVector(fbb.GetBuffer())); + } + + std::unique_ptr Deserialize( + const BuiltinOptions* builtin_options, + const CustomOptions* custom_options) const override { + auto op = absl::make_unique(); + if (custom_options) { + auto flexbuffer_map = + flexbuffers::GetRoot(custom_options->data(), custom_options->size()) + .AsMap(); + ReadOptions(flexbuffer_map, op.get()); + } + return std::unique_ptr(op.release()); + } +}; + +} // namespace tflite + +} // namespace toco + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_CUSTOM_OPERATOR_H_ diff --git a/tensorflow/contrib/lite/toco/tflite/export.cc b/tensorflow/contrib/lite/toco/tflite/export.cc new file mode 100644 index 0000000000..beda710614 --- /dev/null +++ b/tensorflow/contrib/lite/toco/tflite/export.cc @@ -0,0 +1,322 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/contrib/lite/toco/tflite/export.h" + +#include "flatbuffers/flexbuffers.h" +#include "absl/strings/str_join.h" +#include "tensorflow/contrib/lite/schema/schema_generated.h" +#include "tensorflow/contrib/lite/toco/tflite/operator.h" +#include "tensorflow/contrib/lite/toco/tflite/types.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/contrib/lite/version.h" + +namespace toco { + +namespace tflite { + +using ::tflite::Buffer; +using ::tflite::BuiltinOperator; +using ::tflite::BuiltinOperator_CUSTOM; +using ::tflite::BuiltinOperator_MAX; +using ::tflite::BuiltinOperator_MIN; +using ::tflite::CreateBuffer; +using ::tflite::CreateModel; +using ::tflite::CreateOperator; +using ::tflite::CreateTensor; +using ::tflite::Operator; +using ::tflite::OperatorCode; +using ::tflite::SubGraph; +using ::tflite::Tensor; +using flatbuffers::FlatBufferBuilder; +using flatbuffers::Offset; +using flatbuffers::Vector; + +namespace { + +details::OperatorKey GetOperatorKey(const ::toco::Operator& op) { + string custom_code; + if (op.type == OperatorType::kTensorFlowUnsupported) { + const TensorFlowUnsupportedOperator& unsupported_op = + static_cast(op); + custom_code = unsupported_op.tensorflow_op; + } + return details::OperatorKey(op.type, custom_code); +} + +} // Anonymous namespace. + +namespace details { + +void LoadTensorsMap(const Model& model, TensorsMap* tensors_map) { + // First find a list of unique array names. + std::set names; + for (const auto& array_pair : model.arrays) { + names.insert(array_pair.first); + } + + // Now assign indices to them and fill in the map. + int index = 0; + for (const auto& name : names) { + (*tensors_map)[name] = index; + ++index; + } +} + +void LoadOperatorsMap(const Model& model, OperatorsMap* operators_map) { + // First find a list of unique operator types. + std::set keys; + for (const auto& op : model.operators) { + keys.insert(GetOperatorKey(*op)); + } + // Now assign indices to them and fill in the map. + int index = 0; + for (const auto& key : keys) { + (*operators_map)[key] = index; + ++index; + } +} +} // namespace details + +Offset>> ExportTensors( + const Model& model, const details::TensorsMap& tensors_map, + FlatBufferBuilder* builder, std::vector* buffers_to_write) { + // In the end we will need to produce a vector sorted by the indices of the + // tensors in the tensors_map. 
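+  // [Editorial note.] Using a std::map keyed by that index keeps the entries
+  // ordered, so iterating ordered_tensors below yields the tensors already
+  // sorted by their position in tensors_map.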
+ std::map> ordered_tensors; + + for (const auto& array_pair : model.arrays) { + const string& tensor_name = array_pair.first; + const toco::Array& array = *array_pair.second; + + int buffer_index = buffers_to_write->size(); + auto type = DataType::Serialize(array.data_type); + buffers_to_write->push_back(&array); + + std::vector shape; + if (array.has_shape()) { + for (int d : array.shape().dims()) { + shape.push_back(d); + } + } + + Offset> min; + Offset> max; + Offset> scale; + Offset> zero_point; + if (array.minmax) { + min = builder->CreateVector( + std::vector{static_cast(array.minmax->min)}); + max = builder->CreateVector( + std::vector{static_cast(array.minmax->max)}); + } + if (array.quantization_params) { + scale = builder->CreateVector(std::vector{ + static_cast(array.quantization_params->scale)}); + zero_point = builder->CreateVector( + std::vector{array.quantization_params->zero_point}); + } + auto q_param = ::tflite::CreateQuantizationParameters(*builder, min, max, + scale, zero_point); + + int index = tensors_map.at(tensor_name); + ordered_tensors[index] = + CreateTensor(*builder, builder->CreateVector(shape), type, buffer_index, + builder->CreateString(tensor_name), q_param); + } + + std::vector> tensor_vector; + tensor_vector.reserve(ordered_tensors.size()); + for (const auto& tensor : ordered_tensors) { + tensor_vector.push_back(tensor.second); + } + + return builder->CreateVector(tensor_vector); +} + +Offset> ExportInputTensors( + const Model& model, const details::TensorsMap& tensors_map, + FlatBufferBuilder* builder) { + std::vector inputs; + for (const auto& input : model.flags.input_arrays()) { + inputs.push_back(tensors_map.at(input.name())); + } + return builder->CreateVector(inputs); +} + +Offset> ExportOutputTensors( + const Model& model, const details::TensorsMap& tensors_map, + FlatBufferBuilder* builder) { + std::vector outputs; + for (const string& output : model.flags.output_arrays()) { + outputs.push_back(tensors_map.at(output)); + } + return builder->CreateVector(outputs); +} + +Offset>> ExportOperatorCodes( + const Model& model, + const std::map>& ops_by_type, + const details::OperatorsMap& operators_map, FlatBufferBuilder* builder, + std::set* error_summary) { + // Map from operator name to TF Lite enum value, for all builtins. + std::map builtin_ops; + for (int i = BuiltinOperator_MIN; i <= BuiltinOperator_MAX; ++i) { + BuiltinOperator op = static_cast(i); + string name = EnumNameBuiltinOperator(op); + if (op != BuiltinOperator_CUSTOM && !name.empty()) { + builtin_ops[name] = op; + } + } + + // We will need to produce a vector of codes in the same order as they + // appear in the operators_map. + std::map> ordered_opcodes; + + for (const auto& op : model.operators) { + const details::OperatorKey operator_key = GetOperatorKey(*op); + int op_index = operators_map.at(operator_key); + + if (ops_by_type.count(op->type) == 0) { + LOG(FATAL) << "Unsupported operator: " << HelpfulOperatorTypeName(*op); + } + + string name = ops_by_type.at(op->type)->name(); + if (builtin_ops.count(name) > 0) { + ordered_opcodes[op_index] = + CreateOperatorCode(*builder, builtin_ops[name], 0); + } else { + // If use the custom operation code if it's available in the OperatorKey. 
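+      // [Editorial note.] In other words: for an unsupported op, fall back to
+      // the custom code recorded in the OperatorKey and emit it as a CUSTOM
+      // operator code.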
+ if (!operator_key.custom_code.empty()) { + name = operator_key.custom_code; + } + if (error_summary) { + error_summary->insert(name); + } + ordered_opcodes[op_index] = CreateOperatorCode( + *builder, BuiltinOperator_CUSTOM, builder->CreateString(name)); + } + } + + std::vector> opcode_vector; + opcode_vector.reserve(ordered_opcodes.size()); + for (const auto& opcode : ordered_opcodes) { + opcode_vector.push_back(opcode.second); + } + + return builder->CreateVector(opcode_vector); +} + +Offset>> ExportOperators( + const Model& model, + const std::map>& ops_by_type, + const details::OperatorsMap& operators_map, + const details::TensorsMap& tensors_map, FlatBufferBuilder* builder) { + // The operators are in execution order, so we just follow tf.mini order. + std::vector> op_vector; + for (const auto& op : model.operators) { + if (ops_by_type.count(op->type) == 0) { + LOG(FATAL) << "Op type '" << OperatorTypeName(op->type) + << "' not supported"; + } + + std::vector inputs; + for (const string& input : op->inputs) { + inputs.push_back(tensors_map.at(input)); + } + + std::vector outputs; + for (const string& output : op->outputs) { + outputs.push_back(tensors_map.at(output)); + } + + auto options = ops_by_type.at(op->type)->Serialize(*op, builder); + int op_index = operators_map.at(GetOperatorKey(*op)); + // The only supported CustomOptionFormat is FLEXBUFFERS now. + op_vector.push_back(CreateOperator( + *builder, op_index, builder->CreateVector(inputs), + builder->CreateVector(outputs), options.type, options.builtin, + options.custom, ::tflite::CustomOptionsFormat_FLEXBUFFERS)); + } + + return builder->CreateVector(op_vector); +} + +Offset>> ExportBuffers( + const Model& model, const std::vector& buffers_to_write, + FlatBufferBuilder* builder) { + std::vector> buffer_vector; + size_t index = 0; + for (const Array* array_ptr : buffers_to_write) { + const Array& array = *array_ptr; + Offset> data_buffer = DataBuffer::Serialize(array, builder); + buffer_vector.push_back(CreateBuffer(*builder, data_buffer)); + index++; + } + return builder->CreateVector(buffer_vector); +} + +void Export(const Model& model, bool allow_custom_ops, + string* output_file_contents) { + flatbuffers::FlatBufferBuilder builder(/*initial_size=*/10240); + + const auto ops_by_type = BuildOperatorByTypeMap(); + + details::TensorsMap tensors_map; + details::LoadTensorsMap(model, &tensors_map); + + details::OperatorsMap operators_map; + details::LoadOperatorsMap(model, &operators_map); + + std::vector buffers_to_write; + Array empty_array; + buffers_to_write.push_back(&empty_array); + + auto tensors = ExportTensors(model, tensors_map, &builder, &buffers_to_write); + auto inputs = ExportInputTensors(model, tensors_map, &builder); + auto outputs = ExportOutputTensors(model, tensors_map, &builder); + + std::set error_summary; + auto op_codes = ExportOperatorCodes(model, ops_by_type, operators_map, + &builder, &error_summary); + if (!allow_custom_ops && !error_summary.empty()) { + LOG(QFATAL) << "Some of the operators in the model are not supported by " + "the standard TensorFlow Lite runtime. If you have a custom " + "implementation for them you can disable this error with " + "--allow_custom_ops. Here is a list of operators for which " + "you will need custom implementations: " + << absl::StrJoin(error_summary, ", ") << "."; + } + + auto ops = + ExportOperators(model, ops_by_type, operators_map, tensors_map, &builder); + + // TODO(aselle): add support to toco for multiple subgraphs. 
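As a usage note, the entry point being assembled here is the two-argument Export() declared in export.h below; passing allow_custom_ops=false triggers the QFATAL above when an operator has no TF Lite builtin. A minimal calling sketch, assuming the toco::Model has already been built and that the caller writes the resulting bytes to disk (the helper below is illustrative, not part of this change):

#include <fstream>
#include <string>

#include "tensorflow/contrib/lite/toco/tflite/export.h"

// Serializes an already-converted toco::Model into a .tflite file.
void WriteTfLiteFile(const toco::Model& model, const std::string& path) {
  std::string flatbuffer_bytes;
  // false: fail loudly if any operator lacks a TF Lite builtin equivalent.
  toco::tflite::Export(model, /*allow_custom_ops=*/false, &flatbuffer_bytes);
  std::ofstream output(path, std::ios::binary);
  output.write(flatbuffer_bytes.data(), flatbuffer_bytes.size());
}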
+  auto subgraph = CreateSubGraph(builder, tensors, inputs, outputs, ops);
+  std::vector<Offset<SubGraph>> subgraphs = {subgraph};
+
+  auto buffers = ExportBuffers(model, buffers_to_write, &builder);
+  auto description = builder.CreateString("TOCO Converted.");
+  auto new_model_location =
+      CreateModel(builder, TFLITE_SCHEMA_VERSION, op_codes,
+                  builder.CreateVector(subgraphs), description, buffers);
+  ::tflite::FinishModelBuffer(builder, new_model_location);
+  const uint8_t* buffer = builder.GetBufferPointer();
+  int size = builder.GetSize();
+  *output_file_contents = string(reinterpret_cast<const char*>(buffer), size);
+}
+
+}  // namespace tflite
+
+}  // namespace toco
diff --git a/tensorflow/contrib/lite/toco/tflite/export.h b/tensorflow/contrib/lite/toco/tflite/export.h
new file mode 100644
index 0000000000..44012b7126
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/tflite/export.h
@@ -0,0 +1,76 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_EXPORT_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_EXPORT_H_
+
+#include "tensorflow/contrib/lite/toco/model.h"
+
+namespace toco {
+
+namespace tflite {
+
+// Transform the given tf.mini model into a TF Lite flatbuffer and deposit the
+// result in the given string.
+void Export(const Model& model, bool allow_custom_ops,
+            string* output_file_contents);
+// This is provided for backward-compatibility.
+inline void Export(const Model& model, string* output_file_contents) {
+  Export(model, true, output_file_contents);
+}
+
+namespace details {
+
+// A map from tensor name to its final position in the TF Lite buffer.
+using TensorsMap = std::unordered_map<string, int>;
+
+// A key to identify an operator.
+// Only when `type` is `kTensorFlowUnsupported`, `custom_code` is filled to
+// identify which operation is used.
+struct OperatorKey {
+  OperatorKey(OperatorType type, const std::string& custom_code)
+      : type(type), custom_code(custom_code) {}
+  const OperatorType type;
+  const std::string custom_code;
+
+  bool operator<(const OperatorKey& other) const {
+    if (type < other.type) return true;
+    if (type > other.type) return false;
+    return custom_code < other.custom_code;
+  }
+
+  bool operator==(const OperatorKey& other) const {
+    return type == other.type && custom_code == other.custom_code;
+  }
+
+  struct Hash {
+    std::size_t operator()(const OperatorKey& key) const {
+      return std::hash<size_t>()(static_cast<size_t>(key.type)) ^
+             std::hash<std::string>()(key.custom_code);
+    }
+  };
+};
+
+// A map from operator type to its final position in the TF Lite buffer.
+using OperatorsMap = std::unordered_map; + +void LoadTensorsMap(const Model& model, TensorsMap* tensors_map); +void LoadOperatorsMap(const Model& model, OperatorsMap* operators_map); + +} // namespace details +} // namespace tflite + +} // namespace toco + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_EXPORT_H_ diff --git a/tensorflow/contrib/lite/toco/tflite/export_test.cc b/tensorflow/contrib/lite/toco/tflite/export_test.cc new file mode 100644 index 0000000000..e395645383 --- /dev/null +++ b/tensorflow/contrib/lite/toco/tflite/export_test.cc @@ -0,0 +1,69 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/toco/tflite/export.h" + +#include +#include + +namespace toco { + +namespace tflite { +namespace { + +class ExportTest : public ::testing::Test { + protected: + // This is a very simplistic model. We are not interested in testing all the + // details here, since tf.mini's testing framework will be exercising all the + // conversions multiple times, and the conversion of operators is tested by + // separate unittests. + void BuildTestModel() { + input_model_.GetOrCreateArray("tensor_one"); + input_model_.GetOrCreateArray("tensor_two"); + input_model_.operators.emplace_back(new ConvOperator); + input_model_.operators.emplace_back(new AddOperator); + auto unsupported_operator = new TensorFlowUnsupportedOperator; + unsupported_operator->tensorflow_op = "MyCrazyOp"; + input_model_.operators.emplace_back(unsupported_operator); + } + + Model input_model_; +}; + +TEST_F(ExportTest, LoadTensorsMap) { + BuildTestModel(); + + details::TensorsMap tensors; + details::LoadTensorsMap(input_model_, &tensors); + EXPECT_EQ(0, tensors["tensor_one"]); + EXPECT_EQ(1, tensors["tensor_two"]); +} + +TEST_F(ExportTest, LoadOperatorsMap) { + BuildTestModel(); + + details::OperatorsMap operators; + details::LoadOperatorsMap(input_model_, &operators); + EXPECT_EQ(0, operators[details::OperatorKey(OperatorType::kAdd, "")]); + EXPECT_EQ(1, operators[details::OperatorKey(OperatorType::kConv, "")]); + EXPECT_EQ(2, operators[details::OperatorKey( + OperatorType::kTensorFlowUnsupported, "MyCrazyOp")]); +} + +// TODO(ahentz): tests for tensors, inputs, outpus, opcodes and operators. + +} // namespace +} // namespace tflite + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/tflite/import.cc b/tensorflow/contrib/lite/toco/tflite/import.cc new file mode 100644 index 0000000000..bbf201fd28 --- /dev/null +++ b/tensorflow/contrib/lite/toco/tflite/import.cc @@ -0,0 +1,183 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/toco/tflite/import.h" + +#include "flatbuffers/flexbuffers.h" +#include "tensorflow/contrib/lite/schema/schema_generated.h" +#include "tensorflow/contrib/lite/toco/tflite/operator.h" +#include "tensorflow/contrib/lite/toco/tflite/types.h" + +namespace toco { + +namespace tflite { + +namespace details { +void LoadTensorsTable(const ::tflite::Model& input_model, + TensorsTable* tensors_table) { + // TODO(aselle): add support to toco for multiple subgraphs. + auto tensors = (*input_model.subgraphs())[0]->tensors(); + if (!tensors) return; + for (const auto* tensor : *tensors) { + tensors_table->push_back(tensor->name()->c_str()); + } +} + +void LoadOperatorsTable(const ::tflite::Model& input_model, + OperatorsTable* operators_table) { + auto opcodes = input_model.operator_codes(); + if (!opcodes) return; + for (const auto* opcode : *opcodes) { + if (opcode->builtin_code() != ::tflite::BuiltinOperator_CUSTOM) { + operators_table->push_back( + EnumNameBuiltinOperator(opcode->builtin_code())); + } else { + operators_table->push_back(opcode->custom_code()->c_str()); + } + } +} +} // namespace details + +void ImportTensors(const ::tflite::Model& input_model, Model* model) { + auto tensors = (*input_model.subgraphs())[0]->tensors(); + auto* buffers = input_model.buffers(); + // auto tensors = input_model.tensors(); + if (!tensors) return; + for (const auto* input_tensor : *tensors) { + Array& array = model->GetOrCreateArray(input_tensor->name()->c_str()); + array.data_type = DataType::Deserialize(input_tensor->type()); + int buffer_index = input_tensor->buffer(); + auto* buffer = buffers->Get(buffer_index); + DataBuffer::Deserialize(*input_tensor, *buffer, &array); + + auto shape = input_tensor->shape(); + if (shape) { + for (int i = 0; i < shape->Length(); ++i) { + auto d = shape->Get(i); + array.mutable_shape()->mutable_dims()->push_back(d); + } + } + + auto quantization = input_tensor->quantization(); + if (quantization) { + // Note that tf.mini only supports a single quantization parameters for + // the whole array. + if (quantization->min() && quantization->max()) { + CHECK_EQ(1, quantization->min()->Length()); + CHECK_EQ(1, quantization->max()->Length()); + MinMax& minmax = array.GetOrCreateMinMax(); + minmax.min = quantization->min()->Get(0); + minmax.max = quantization->max()->Get(0); + } + if (quantization->scale() && quantization->zero_point()) { + CHECK_EQ(1, quantization->scale()->Length()); + CHECK_EQ(1, quantization->zero_point()->Length()); + QuantizationParams& q = array.GetOrCreateQuantizationParams(); + q.scale = quantization->scale()->Get(0); + q.zero_point = quantization->zero_point()->Get(0); + } + } + } +} + +void ImportOperators( + const ::tflite::Model& input_model, + const std::map>& ops_by_name, + const details::TensorsTable& tensors_table, + const details::OperatorsTable& operators_table, Model* model) { + // TODO(aselle): add support for multiple subgraphs. 
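+  // Each operator record carries an index into the model's operator_codes
+  // table; that index is resolved to a name through operators_table, and the
+  // name selects the BaseOperator used to deserialize the options below.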
+ auto ops = (*input_model.subgraphs())[0]->operators(); + + if (!ops) return; + for (const auto* input_op : *ops) { + int index = input_op->opcode_index(); + if (index < 0 || index > operators_table.size()) { + LOG(FATAL) << "Index " << index << " must be between zero and " + << operators_table.size(); + } + string opname = operators_table.at(index); + if (ops_by_name.count(opname) == 0) { + LOG(FATAL) << "Op '" << opname << "' not supported"; + } + + auto new_op = ops_by_name.at(opname)->Deserialize( + input_op->builtin_options(), input_op->custom_options()); + model->operators.emplace_back(new_op.release()); + auto* op = model->operators.back().get(); + + auto inputs = input_op->inputs(); + for (int i = 0; i < inputs->Length(); i++) { + auto input_index = inputs->Get(i); + const string& input_name = tensors_table.at(input_index); + op->inputs.push_back(input_name); + } + auto outputs = input_op->outputs(); + for (int i = 0; i < outputs->Length(); i++) { + auto output_index = outputs->Get(i); + const string& output_name = tensors_table.at(output_index); + op->outputs.push_back(output_name); + } + } +} + +void ImportIOTensors(const ::tflite::Model& input_model, + const details::TensorsTable& tensors_table, Model* model) { + auto inputs = (*input_model.subgraphs())[0]->inputs(); + if (inputs) { + for (int input : *inputs) { + const string& input_name = tensors_table.at(input); + model->flags.add_input_arrays()->set_name(input_name); + } + } + + auto outputs = (*input_model.subgraphs())[0]->outputs(); + if (outputs) { + for (int output : *outputs) { + const string& output_name = tensors_table.at(output); + model->flags.add_output_arrays(output_name); + } + } +} + +std::unique_ptr Import(const ModelFlags& model_flags, + const string& input_file_contents) { + const ::tflite::Model* input_model = + ::tflite::GetModel(input_file_contents.data()); + + // Full list of all known operators. + const auto ops_by_name = BuildOperatorByNameMap(); + + if (input_model->subgraphs()->size() != 1) { + LOG(FATAL) << "# of subgraphs in tflite should be exactly 1 for now."; + } + std::unique_ptr model; + model.reset(new Model); + + details::TensorsTable tensors_table; + details::LoadTensorsTable(*input_model, &tensors_table); + + details::OperatorsTable operators_table; + details::LoadOperatorsTable(*input_model, &operators_table); + + ImportTensors(*input_model, model.get()); + ImportOperators(*input_model, ops_by_name, tensors_table, operators_table, + model.get()); + ImportIOTensors(*input_model, tensors_table, model.get()); + + return model; +} + +} // namespace tflite + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/tflite/import.h b/tensorflow/contrib/lite/toco/tflite/import.h new file mode 100644 index 0000000000..3c27a2843c --- /dev/null +++ b/tensorflow/contrib/lite/toco/tflite/import.h @@ -0,0 +1,49 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_IMPORT_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_IMPORT_H_ + +#include "tensorflow/contrib/lite/schema/schema_generated.h" +#include "tensorflow/contrib/lite/toco/model.h" + +namespace toco { + +namespace tflite { + +// Parse the given string as TF Lite flatbuffer and return a new tf.mini model. +std::unique_ptr Import(const ModelFlags &model_flags, + const string &input_file_contents); + +namespace details { + +// The names of all tensors found in a TF Lite model. +using TensorsTable = std::vector; + +// The names of all operators found in TF Lite model. If the operator is +// builtin, the string representation of the corresponding enum value is used +// as name. +using OperatorsTable = std::vector; + +void LoadTensorsTable(const ::tflite::Model &input_model, + TensorsTable *tensors_table); +void LoadOperatorsTable(const ::tflite::Model &input_model, + OperatorsTable *operators_table); + +} // namespace details +} // namespace tflite + +} // namespace toco + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_IMPORT_H_ diff --git a/tensorflow/contrib/lite/toco/tflite/import_test.cc b/tensorflow/contrib/lite/toco/tflite/import_test.cc new file mode 100644 index 0000000000..309fa6d7f6 --- /dev/null +++ b/tensorflow/contrib/lite/toco/tflite/import_test.cc @@ -0,0 +1,141 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/toco/tflite/import.h" + +#include "flatbuffers/flexbuffers.h" +#include +#include +#include "tensorflow/contrib/lite/schema/schema_generated.h" +#include "tensorflow/contrib/lite/version.h" + +namespace toco { + +namespace tflite { +namespace { + +using ::testing::ElementsAre; + +class ImportTest : public ::testing::Test { + protected: + template + flatbuffers::Offset> CreateDataVector( + const std::vector& data) { + return builder_.CreateVector(reinterpret_cast(data.data()), + sizeof(T) * data.size()); + } + // This is a very simplistic model. We are not interested in testing all the + // details here, since tf.mini's testing framework will be exercising all the + // conversions multiple times, and the conversion of operators is tested by + // separate unittests. 
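+  // The model built below contains two float tensors ("tensor_one" and
+  // "tensor_two"), three buffers, one custom and one builtin operator code,
+  // and a single subgraph.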
+ void BuildTestModel() { + // The tensors + auto q = ::tflite::CreateQuantizationParameters( + builder_, + /*min=*/builder_.CreateVector({0.1f}), + /*max=*/builder_.CreateVector({0.2f}), + /*scale=*/builder_.CreateVector({0.3f}), + /*zero_point=*/builder_.CreateVector({100ll})); + auto buf0 = ::tflite::CreateBuffer(builder_, CreateDataVector({})); + auto buf1 = + ::tflite::CreateBuffer(builder_, CreateDataVector({1.0f, 2.0f})); + auto buf2 = + ::tflite::CreateBuffer(builder_, CreateDataVector({3.0f})); + auto buffers = builder_.CreateVector( + std::vector>({buf0, buf1, buf2})); + auto t1 = ::tflite::CreateTensor(builder_, + builder_.CreateVector({1, 2, 3, 4}), + ::tflite::TensorType_FLOAT32, 1, + builder_.CreateString("tensor_one"), q); + auto t2 = + ::tflite::CreateTensor(builder_, builder_.CreateVector({2, 1}), + ::tflite::TensorType_FLOAT32, 2, + builder_.CreateString("tensor_two"), q); + auto tensors = builder_.CreateVector( + std::vector>({t1, t2})); + + // The operator codes. + auto c1 = + ::tflite::CreateOperatorCode(builder_, ::tflite::BuiltinOperator_CUSTOM, + builder_.CreateString("custom_op_one")); + auto c2 = ::tflite::CreateOperatorCode( + builder_, ::tflite::BuiltinOperator_CONV_2D, 0); + auto opcodes = builder_.CreateVector( + std::vector>({c1, c2})); + + auto subgraph = ::tflite::CreateSubGraph(builder_, tensors, 0, 0, 0); + std::vector> subgraph_vector( + {subgraph}); + auto subgraphs = builder_.CreateVector(subgraph_vector); + auto s = builder_.CreateString(""); + builder_.Finish(::tflite::CreateModel(builder_, TFLITE_SCHEMA_VERSION, + opcodes, subgraphs, s, buffers)); + + input_model_ = ::tflite::GetModel(builder_.GetBufferPointer()); + } + string InputModelAsString() { + return string(reinterpret_cast(builder_.GetBufferPointer()), + builder_.GetSize()); + } + flatbuffers::FlatBufferBuilder builder_; + // const uint8_t* buffer_ = nullptr; + const ::tflite::Model* input_model_ = nullptr; +}; + +TEST_F(ImportTest, LoadTensorsTable) { + BuildTestModel(); + + details::TensorsTable tensors; + details::LoadTensorsTable(*input_model_, &tensors); + EXPECT_THAT(tensors, ElementsAre("tensor_one", "tensor_two")); +} + +TEST_F(ImportTest, LoadOperatorsTable) { + BuildTestModel(); + + details::OperatorsTable operators; + details::LoadOperatorsTable(*input_model_, &operators); + EXPECT_THAT(operators, ElementsAre("custom_op_one", "CONV_2D")); +} + +TEST_F(ImportTest, Tensors) { + BuildTestModel(); + + auto model = Import(ModelFlags(), InputModelAsString()); + + ASSERT_GT(model->arrays.count("tensor_one"), 0); + Array& a1 = model->GetArray("tensor_one"); + EXPECT_EQ(ArrayDataType::kFloat, a1.data_type); + EXPECT_THAT(a1.GetBuffer().data, + ElementsAre(1.0f, 2.0f)); + ASSERT_TRUE(a1.has_shape()); + EXPECT_THAT(a1.shape().dims(), ElementsAre(1, 2, 3, 4)); + + const auto& mm = a1.minmax; + ASSERT_TRUE(mm.get()); + EXPECT_FLOAT_EQ(0.1, mm->min); + EXPECT_FLOAT_EQ(0.2, mm->max); + + const auto& q = a1.quantization_params; + ASSERT_TRUE(q.get()); + EXPECT_FLOAT_EQ(0.3, q->scale); + EXPECT_EQ(100, q->zero_point); +} + +// TODO(ahentz): still need tests for Operators and IOTensors. + +} // namespace +} // namespace tflite + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc new file mode 100644 index 0000000000..8a33500ddc --- /dev/null +++ b/tensorflow/contrib/lite/toco/tflite/operator.cc @@ -0,0 +1,627 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/toco/tflite/operator.h" + +#include "tensorflow/contrib/lite/toco/tflite/builtin_operator.h" +#include "tensorflow/contrib/lite/toco/tflite/custom_operator.h" +#include "tensorflow/contrib/lite/toco/tflite/simple_operator.h" +#include "tensorflow/contrib/lite/toco/tflite/types.h" + +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" + +namespace toco { + +namespace tflite { + +class AveragePool + : public BuiltinOperator { + public: + using BuiltinOperator::BuiltinOperator; + + flatbuffers::Offset WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + auto padding = Padding::Serialize(op.padding.type); + auto activation_function = + ActivationFunction::Serialize(op.fused_activation_function); + return ::tflite::CreatePool2DOptions(*builder, padding, op.stride_width, + op.stride_height, op.kwidth, + op.kheight, activation_function); + } + + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override { + op->padding.type = Padding::Deserialize(options.padding()); + op->stride_width = options.stride_w(); + op->stride_height = options.stride_h(); + op->kwidth = options.filter_width(); + op->kheight = options.filter_height(); + op->fused_activation_function = + ActivationFunction::Deserialize(options.fused_activation_function()); + } +}; + +class Convolution + : public BuiltinOperator { + public: + using BuiltinOperator::BuiltinOperator; + + flatbuffers::Offset WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + auto padding = Padding::Serialize(op.padding.type); + auto activation_function = + ActivationFunction::Serialize(op.fused_activation_function); + return ::tflite::CreateConv2DOptions(*builder, padding, op.stride_width, + op.stride_height, activation_function); + } + + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override { + op->padding.type = Padding::Deserialize(options.padding()); + op->stride_width = options.stride_w(); + op->stride_height = options.stride_h(); + op->fused_activation_function = + ActivationFunction::Deserialize(options.fused_activation_function()); + } +}; + +class DepthwiseConvolution + : public BuiltinOperator { + public: + using BuiltinOperator::BuiltinOperator; + + flatbuffers::Offset WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + auto padding = Padding::Serialize(op.padding.type); + auto activation_function = + ActivationFunction::Serialize(op.fused_activation_function); + return ::tflite::CreateDepthwiseConv2DOptions( + *builder, padding, op.stride_width, op.stride_height, + op.depth_multiplier, activation_function); + } + + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override { + op->padding.type = Padding::Deserialize(options.padding()); + 
op->stride_width = options.stride_w(); + op->stride_height = options.stride_h(); + op->depth_multiplier = options.depth_multiplier(); + op->fused_activation_function = + ActivationFunction::Deserialize(options.fused_activation_function()); + } +}; + +class Add : public BuiltinOperator { + public: + using BuiltinOperator::BuiltinOperator; + + flatbuffers::Offset WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + auto activation_function = + ActivationFunction::Serialize(op.fused_activation_function); + return ::tflite::CreateAddOptions(*builder, activation_function); + } + + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override { + op->fused_activation_function = + ActivationFunction::Deserialize(options.fused_activation_function()); + } +}; + +class Cast : public CustomOperator { + public: + using CustomOperator::CustomOperator; + void WriteOptions(const TocoOperator& op, + flexbuffers::Builder* fbb) const override { + fbb->Int("src_data_type", DataType::Serialize(op.src_data_type)); + fbb->Int("dst_data_type", DataType::Serialize(op.dst_data_type)); + } + void ReadOptions(const flexbuffers::Map& m, TocoOperator* op) const override { + op->src_data_type = DataType::Deserialize(m["src_data_type"].AsInt64()); + op->dst_data_type = DataType::Deserialize(m["dst_data_type"].AsInt64()); + } +}; + +class Concatenation + : public BuiltinOperator { + public: + using BuiltinOperator::BuiltinOperator; + flatbuffers::Offset WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + return ::tflite::CreateConcatenationOptions(*builder, op.concat_dim); + } + + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override { + op->concat_dim = options.axis(); + } +}; + +class DepthToSpace : public CustomOperator { + public: + using CustomOperator::CustomOperator; + void WriteOptions(const TocoOperator& op, + flexbuffers::Builder* fbb) const override { + fbb->Int("block_size", op.block_size); + } + void ReadOptions(const flexbuffers::Map& m, TocoOperator* op) const override { + op->block_size = m["block_size"].AsInt64(); + } +}; + +class FakeQuant : public CustomOperator { + public: + using CustomOperator::CustomOperator; + void WriteOptions(const TocoOperator& op, + flexbuffers::Builder* fbb) const override { + fbb->Float("min", op.minmax->min); + fbb->Float("max", op.minmax->max); + } + void ReadOptions(const flexbuffers::Map& m, TocoOperator* op) const override { + auto* minmax = new MinMax; + minmax->min = m["min"].AsFloat(); + minmax->max = m["max"].AsFloat(); + op->minmax.reset(minmax); + } +}; + +class FullyConnected + : public BuiltinOperator { + public: + using BuiltinOperator::BuiltinOperator; + flatbuffers::Offset WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + auto activation_function = + ActivationFunction::Serialize(op.fused_activation_function); + return ::tflite::CreateFullyConnectedOptions(*builder, activation_function); + } + + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override { + op->fused_activation_function = + ActivationFunction::Deserialize(options.fused_activation_function()); + } +}; + +class Svdf : public BuiltinOperator { + public: + using BuiltinOperator::BuiltinOperator; + flatbuffers::Offset WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + auto activation_function = + 
ActivationFunction::Serialize(op.fused_activation_function); + return ::tflite::CreateSVDFOptions(*builder, op.rank, activation_function); + } + + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override { + op->fused_activation_function = + ActivationFunction::Deserialize(options.fused_activation_function()); + op->rank = options.rank(); + } +}; + +class L2Normalization + : public BuiltinOperator { + public: + using BuiltinOperator::BuiltinOperator; + flatbuffers::Offset WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + auto activation_function = + ActivationFunction::Serialize(op.fused_activation_function); + return ::tflite::CreateL2NormOptions(*builder, activation_function); + } + + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override { + op->fused_activation_function = + ActivationFunction::Deserialize(options.fused_activation_function()); + } +}; + +class L2Pool : public BuiltinOperator { + public: + using BuiltinOperator::BuiltinOperator; + flatbuffers::Offset WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + auto padding = Padding::Serialize(op.padding.type); + auto activation_function = + ActivationFunction::Serialize(op.fused_activation_function); + return ::tflite::CreatePool2DOptions(*builder, padding, op.stride_width, + op.stride_height, op.kwidth, + op.kheight, activation_function); + } + + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override { + op->padding.type = Padding::Deserialize(options.padding()); + op->stride_width = options.stride_w(); + op->stride_height = options.stride_h(); + op->kwidth = options.filter_width(); + op->kheight = options.filter_height(); + op->fused_activation_function = + ActivationFunction::Deserialize(options.fused_activation_function()); + } +}; + +class LocalResponseNormalization + : public BuiltinOperator< + LocalResponseNormalizationOperator, + ::tflite::LocalResponseNormalizationOptions, + ::tflite::BuiltinOptions_LocalResponseNormalizationOptions> { + public: + using BuiltinOperator::BuiltinOperator; + flatbuffers::Offset WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + return ::tflite::CreateLocalResponseNormalizationOptions( + *builder, op.range, op.bias, op.alpha, op.beta); + } + + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override { + op->range = options.radius(); + op->bias = options.bias(); + op->alpha = options.alpha(); + op->beta = options.beta(); + } +}; + +class MaxPool : public BuiltinOperator { + public: + using BuiltinOperator::BuiltinOperator; + flatbuffers::Offset WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + auto padding = Padding::Serialize(op.padding.type); + auto activation_function = + ActivationFunction::Serialize(op.fused_activation_function); + return ::tflite::CreatePool2DOptions(*builder, padding, op.stride_width, + op.stride_height, op.kwidth, + op.kheight, activation_function); + } + + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override { + op->padding.type = Padding::Deserialize(options.padding()); + op->stride_width = options.stride_w(); + op->stride_height = options.stride_h(); + op->kwidth = options.filter_width(); + op->kheight = options.filter_height(); + op->fused_activation_function = + ActivationFunction::Deserialize(options.fused_activation_function()); 
+ } +}; + +class Mul : public BuiltinOperator { + public: + using BuiltinOperator::BuiltinOperator; + + flatbuffers::Offset WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + auto activation_function = + ActivationFunction::Serialize(op.fused_activation_function); + return ::tflite::CreateMulOptions(*builder, activation_function); + } + + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override { + op->fused_activation_function = + ActivationFunction::Deserialize(options.fused_activation_function()); + } +}; + +class Reshape + : public BuiltinOperator { + public: + using BuiltinOperator::BuiltinOperator; + + flatbuffers::Offset WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + return ::tflite::CreateReshapeOptions(*builder, + builder->CreateVector(op.shape)); + } + + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override { + op->shape.insert(op->shape.end(), options.new_shape()->begin(), + options.new_shape()->end()); + } +}; + +class Softmax + : public BuiltinOperator { + public: + using BuiltinOperator::BuiltinOperator; + flatbuffers::Offset WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + return ::tflite::CreateSoftmaxOptions(*builder, op.beta); + } + + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override { + op->beta = options.beta(); + } +}; + +class SpaceToDepth + : public BuiltinOperator { + public: + using BuiltinOperator::BuiltinOperator; + flatbuffers::Offset WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + return ::tflite::CreateSpaceToDepthOptions(*builder, op.block_size); + } + + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override { + op->block_size = options.block_size(); + } +}; + +class Split : public CustomOperator { + public: + using CustomOperator::CustomOperator; + void WriteOptions(const TocoOperator& op, + flexbuffers::Builder* fbb) const override { + fbb->Int("num_split", op.num_split); + } + void ReadOptions(const flexbuffers::Map& m, TocoOperator* op) const override { + op->num_split = m["num_split"].AsInt64(); + } +}; + +class TensorFlowUnsupported : public BaseOperator { + public: + using BaseOperator::BaseOperator; + + Options Serialize(const Operator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + auto fbb = + WriteOptions(static_cast(op)); + if (fbb) { + return Options::Custom(builder->CreateVector(fbb->GetBuffer())); + } else { + return Options::Custom(0); + } + } + + std::unique_ptr Deserialize( + const BuiltinOptions* builtin_options, + const CustomOptions* custom_options) const override { + auto op = absl::make_unique(); + if (custom_options) { + auto flexbuffer_map = + flexbuffers::GetRoot(custom_options->data(), custom_options->size()) + .AsMap(); + ReadOptions(flexbuffer_map, op.get()); + } + return std::unique_ptr(op.release()); + } + + std::unique_ptr WriteOptions( + const TensorFlowUnsupportedOperator& op) const { + auto fbb = absl::make_unique(); + + ::tensorflow::NodeDef node_def; + if (!node_def.ParseFromString(op.tensorflow_node_def)) { + LOG(ERROR) << "Failed to parse TensorFlow NodeDef"; + return std::unique_ptr(); + } + + bool has_valid_attr = false; + size_t map_start = fbb->StartMap(); + for (const auto& pair : node_def.attr()) { + const char* key = pair.first.c_str(); + const auto& attr = pair.second; + 
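+      // Only scalar string, int, float and bool attributes survive the trip
+      // into the flexbuffer; any other AttrValue case is skipped with a
+      // warning.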
switch (attr.value_case()) { + case ::tensorflow::AttrValue::kS: + fbb->String(key, attr.s()); + has_valid_attr = true; + break; + case ::tensorflow::AttrValue::kI: + fbb->Int(key, attr.i()); + has_valid_attr = true; + break; + case ::tensorflow::AttrValue::kF: + fbb->Float(key, attr.f()); + has_valid_attr = true; + break; + case ::tensorflow::AttrValue::kB: + fbb->Bool(key, attr.b()); + has_valid_attr = true; + break; + default: + LOG(WARNING) << "Ignoring unsupported attribute type with key '" + << key << "'"; + break; + } + } + if (!has_valid_attr) { + return std::unique_ptr(); + } + fbb->EndMap(map_start); + fbb->Finish(); + return std::unique_ptr(fbb.release()); + } + + void ReadOptions(const flexbuffers::Map& m, + TensorFlowUnsupportedOperator* op) const { + ::tensorflow::NodeDef node_def; + auto attr = node_def.mutable_attr(); + + const auto& keys = m.Keys(); + for (size_t i = 0; i < keys.size(); ++i) { + const auto key = keys[i].AsKey(); + const auto& value = m[key]; + switch (value.GetType()) { + case flexbuffers::TYPE_STRING: + (*attr)[key].set_s(value.AsString().c_str()); + break; + case flexbuffers::TYPE_INT: + (*attr)[key].set_i(value.AsInt64()); + break; + case flexbuffers::TYPE_FLOAT: + (*attr)[key].set_f(value.AsFloat()); + break; + case flexbuffers::TYPE_BOOL: + (*attr)[key].set_b(value.AsBool()); + break; + default: + LOG(WARNING) << "Ignoring unsupported attribute type with key '" + << key << "'"; + break; + } + } + node_def.SerializeToString(&op->tensorflow_node_def); + } +}; + +namespace { +// Build a vector containing all the known operators. +std::vector> BuildOperatorList() { + std::vector> ops; + + // Builtin Operators. + ops.emplace_back(new Add(::tflite::BuiltinOperator_ADD, OperatorType::kAdd)); + ops.emplace_back(new AveragePool(::tflite::BuiltinOperator_AVERAGE_POOL_2D, + OperatorType::kAveragePool)); + ops.emplace_back(new Concatenation(::tflite::BuiltinOperator_CONCATENATION, + OperatorType::kConcatenation)); + ops.emplace_back( + new Convolution(::tflite::BuiltinOperator_CONV_2D, OperatorType::kConv)); + ops.emplace_back( + new DepthwiseConvolution(::tflite::BuiltinOperator_DEPTHWISE_CONV_2D, + OperatorType::kDepthwiseConv)); + ops.emplace_back(new FullyConnected(::tflite::BuiltinOperator_FULLY_CONNECTED, + OperatorType::kFullyConnected)); + ops.emplace_back( + new L2Normalization(::tflite::BuiltinOperator_L2_NORMALIZATION, + OperatorType::kL2Normalization)); + ops.emplace_back( + new L2Pool(::tflite::BuiltinOperator_L2_POOL_2D, OperatorType::kL2Pool)); + ops.emplace_back(new LocalResponseNormalization( + ::tflite::BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION, + OperatorType::kLocalResponseNormalization)); + ops.emplace_back(new MaxPool(::tflite::BuiltinOperator_MAX_POOL_2D, + OperatorType::kMaxPool)); + ops.emplace_back(new Mul(::tflite::BuiltinOperator_MUL, OperatorType::kMul)); + ops.emplace_back(new Reshape(::tflite::BuiltinOperator_RESHAPE, + OperatorType::kTensorFlowReshape)); + ops.emplace_back( + new Softmax(::tflite::BuiltinOperator_SOFTMAX, OperatorType::kSoftmax)); + ops.emplace_back(new SpaceToDepth(::tflite::BuiltinOperator_SPACE_TO_DEPTH, + OperatorType::kSpaceToDepth)); + ops.emplace_back( + new Svdf(::tflite::BuiltinOperator_SVDF, OperatorType::kSvdf)); + + // Custom Operators. 
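+  // These are exported with the BuiltinOperator_CUSTOM code, identified by
+  // their name string, and carry their parameters in a flexbuffer.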
+  ops.emplace_back(new Cast("CAST", OperatorType::kCast));
+  ops.emplace_back(
+      new DepthToSpace("DEPTH_TO_SPACE", OperatorType::kDepthToSpace));
+  ops.emplace_back(new FakeQuant("FAKE_QUANT", OperatorType::kFakeQuant));
+  ops.emplace_back(new Split("SPLIT", OperatorType::kTensorFlowSplit));
+  ops.emplace_back(new TensorFlowUnsupported(
+      "TENSORFLOW_UNSUPPORTED", OperatorType::kTensorFlowUnsupported));
+
+  // These operators are supported by Toco, but not by TF Lite, and have no
+  // attributes.
+  ops.emplace_back(new SimpleOperator<TensorFlowRsqrtOperator>(
+      "RSQRT", OperatorType::kTensorFlowRsqrt));
+  ops.emplace_back(
+      new SimpleOperator<DivOperator>("DIV", OperatorType::kDiv));
+
+  // Simple Operators.
+  ops.emplace_back(new SimpleOperator<DequantizeOperator>(
+      "DEQUANTIZE", OperatorType::kDequantize));
+  ops.emplace_back(
+      new SimpleOperator<FloorOperator>("FLOOR", OperatorType::kFloor));
+  ops.emplace_back(
+      new SimpleOperator<GatherOperator>("GATHER", OperatorType::kGather));
+  ops.emplace_back(
+      new SimpleOperator<ReluOperator>("RELU", OperatorType::kRelu));
+  ops.emplace_back(
+      new SimpleOperator<Relu1Operator>("RELU1", OperatorType::kRelu1));
+  ops.emplace_back(
+      new SimpleOperator<Relu6Operator>("RELU6", OperatorType::kRelu6));
+  ops.emplace_back(new SimpleOperator<ResizeBilinearOperator>(
+      "RESIZE_BILINEAR", OperatorType::kResizeBilinear));
+  ops.emplace_back(new SimpleOperator<LogisticOperator>(
+      "LOGISTIC", OperatorType::kLogistic));
+  ops.emplace_back(
+      new SimpleOperator<TanhOperator>("TANH", OperatorType::kTanh));
+
+  return ops;
+}
+}  // namespace
+
+std::map<OperatorType, std::unique_ptr<BaseOperator>> BuildOperatorByTypeMap() {
+  std::map<OperatorType, std::unique_ptr<BaseOperator>> result;
+
+  std::vector<std::unique_ptr<BaseOperator>> ops = BuildOperatorList();
+  for (auto& op : ops) {
+    result[op->type()] = std::move(op);
+  }
+
+  return result;
+}
+
+std::map<string, std::unique_ptr<BaseOperator>> BuildOperatorByNameMap() {
+  std::map<string, std::unique_ptr<BaseOperator>> result;
+
+  std::vector<std::unique_ptr<BaseOperator>> ops = BuildOperatorList();
+  for (auto& op : ops) {
+    result[op->name()] = std::move(op);
+  }
+
+  return result;
+}
+
+}  // namespace tflite
+
+}  // namespace toco
diff --git a/tensorflow/contrib/lite/toco/tflite/operator.h b/tensorflow/contrib/lite/toco/tflite/operator.h
new file mode 100644
index 0000000000..37df302d46
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/tflite/operator.h
@@ -0,0 +1,89 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_OPERATOR_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_OPERATOR_H_
+
+#include "flatbuffers/flatbuffers.h"
+#include "tensorflow/contrib/lite/schema/schema_generated.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+
+namespace toco {
+
+namespace tflite {
+
+class BaseOperator;
+
+// Return a map containing all known TF Lite Operators, keyed by their names.
+std::map<string, std::unique_ptr<BaseOperator>> BuildOperatorByNameMap();
+
+// Return a map containing all known TF Lite Operators, keyed by the type of
+// their tf.mini counterparts.
+std::map<OperatorType, std::unique_ptr<BaseOperator>> BuildOperatorByTypeMap();
+
+// These are the flatbuffer types for custom and builtin options.
+using CustomOptions = flatbuffers::Vector; +using BuiltinOptions = void; + +// A simple wrapper around the flatbuffer objects used to describe options that +// configure operators. +struct Options { + // Build custom options. + static Options Custom(flatbuffers::Offset offset) { + return {::tflite::BuiltinOptions_NONE, 0, offset}; + } + + // Build builtin options of the given type. + static Options Builtin(::tflite::BuiltinOptions type, + flatbuffers::Offset offset) { + return {type, offset, 0}; + } + + ::tflite::BuiltinOptions type; + flatbuffers::Offset builtin; + flatbuffers::Offset custom; +}; + +// A BaseOperator encapsulates the relationship between operators in tf.mini +// and TF lite, and provides methods for converting between those two formats. +class BaseOperator { + public: + // Build an operator with the given TF Lite name and tf.mini type. + BaseOperator(const string& name, OperatorType type) + : name_(name), type_(type) {} + virtual ~BaseOperator() = default; + + string name() const { return name_; } + OperatorType type() const { return type_; } + + // Given a tf.mini operator, create the corresponding flatbuffer options and + // return their offsets. + virtual Options Serialize(const Operator& op, + flatbuffers::FlatBufferBuilder* builder) const = 0; + + // Read TF Lite options and create the appropriate tf.mini operator. + virtual std::unique_ptr Deserialize( + const BuiltinOptions* builtin_options, + const CustomOptions* custom_options) const = 0; + + private: + string name_; + OperatorType type_; +}; + +} // namespace tflite + +} // namespace toco + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_OPERATOR_H_ diff --git a/tensorflow/contrib/lite/toco/tflite/operator_test.cc b/tensorflow/contrib/lite/toco/tflite/operator_test.cc new file mode 100644 index 0000000000..543a9bd06c --- /dev/null +++ b/tensorflow/contrib/lite/toco/tflite/operator_test.cc @@ -0,0 +1,370 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/toco/tflite/operator.h" + +#include "flatbuffers/flexbuffers.h" +#include +#include +#include "tensorflow/contrib/lite/toco/tooling_util.h" + +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" + +namespace toco { + +namespace tflite { +namespace { + +class OperatorTest : public ::testing::Test { + protected: + // Return the operator for the given name and type. + const BaseOperator& GetOperator(const string& name, OperatorType type) { + using OpsByName = std::map>; + using OpsByType = std::map>; + + static auto* by_name = new OpsByName(BuildOperatorByNameMap()); + static auto* by_type = new OpsByType(BuildOperatorByTypeMap()); + + // Make sure the two maps were consitently built. 
+ CHECK(by_name->count(name)) << "No operator for '" << name << "'."; + BaseOperator* op1 = by_name->at(name).get(); + CHECK(op1->type() == type) << "while verifying '" << name << "'."; + + CHECK(by_type->count(type)) + << "No operator for '" << OperatorTypeName(type) << "'."; + BaseOperator* op2 = by_type->at(type).get(); + CHECK(op2->name() == name) + << "while verifying '" << OperatorTypeName(type) << "'."; + + return *op1; + } + + // Use the given BaseOperator to serialize the tf.mini operator into a set of + // TF Lite options. Proceed to deserialize the options back into a new + // tf.mini operator, which is then returned. If `options` is given, it will + // be populated with the serialized options. + template + std::unique_ptr SerializeAndDeserialize(const BaseOperator& op, + const T& toco_op, + Options* options = nullptr) { + flatbuffers::FlatBufferBuilder builder; + Options input_options = op.Serialize(toco_op, &builder); + + if (options) { + *options = input_options; + } + + builder.Finish(CreateOperator(builder, 0, 0, 0, input_options.type, + input_options.builtin, input_options.custom, + ::tflite::CustomOptionsFormat_FLEXBUFFERS)); + auto* output_options = + flatbuffers::GetRoot<::tflite::Operator>(builder.GetBufferPointer()); + auto new_toco_op = op.Deserialize(output_options->builtin_options(), + output_options->custom_options()); + + CHECK(dynamic_cast(new_toco_op.get())) + << "Cannot cast " << HelpfulOperatorTypeName(*new_toco_op) << " to " + << HelpfulOperatorTypeName(toco_op); + + return std::unique_ptr(dynamic_cast(new_toco_op.release())); + } + + // Verify serialization and deserialization of simple operators (those + // that don't have any configuration parameters). + template + void CheckSimpleOperator(const string& name, OperatorType type) { + Options options; + auto output_toco_op = + SerializeAndDeserialize(GetOperator(name, type), T(), &options); + + ASSERT_EQ(0, options.builtin.o); + ASSERT_EQ(0, options.custom.o); + ASSERT_EQ(::tflite::BuiltinOptions_NONE, options.type); + + ASSERT_NE(nullptr, output_toco_op.get()); + } +}; + +TEST_F(OperatorTest, SimpleOperators) { + CheckSimpleOperator("DEQUANTIZE", + OperatorType::kDequantize); + CheckSimpleOperator("FLOOR", OperatorType::kFloor); + CheckSimpleOperator("GATHER", OperatorType::kGather); + CheckSimpleOperator("RELU", OperatorType::kRelu); + CheckSimpleOperator("RELU1", OperatorType::kRelu1); + CheckSimpleOperator("RELU6", OperatorType::kRelu6); + CheckSimpleOperator("RESIZE_BILINEAR", + OperatorType::kResizeBilinear); + CheckSimpleOperator("LOGISTIC", OperatorType::kLogistic); + CheckSimpleOperator("TANH", OperatorType::kTanh); +} + +TEST_F(OperatorTest, BuiltinAdd) { + AddOperator op; + op.fused_activation_function = FusedActivationFunctionType::kRelu6; + auto output_toco_op = + SerializeAndDeserialize(GetOperator("ADD", OperatorType::kAdd), op); + EXPECT_EQ(op.fused_activation_function, + output_toco_op->fused_activation_function); +} + +TEST_F(OperatorTest, CustomCast) { + CastOperator op; + op.src_data_type = ArrayDataType::kFloat; + op.dst_data_type = ArrayDataType::kUint8; + auto output_toco_op = + SerializeAndDeserialize(GetOperator("CAST", OperatorType::kCast), op); + EXPECT_EQ(op.src_data_type, output_toco_op->src_data_type); + EXPECT_EQ(op.dst_data_type, output_toco_op->dst_data_type); +} + +TEST_F(OperatorTest, CustomConcatenation) { + ConcatenationOperator op; + op.concat_dim = 123; + auto output_toco_op = SerializeAndDeserialize( + GetOperator("CONCATENATION", OperatorType::kConcatenation), op); + 
EXPECT_EQ(op.concat_dim, output_toco_op->concat_dim); +} + +TEST_F(OperatorTest, CustomDepthToSpace) { + DepthToSpaceOperator op; + op.block_size = 123; + auto output_toco_op = SerializeAndDeserialize( + GetOperator("DEPTH_TO_SPACE", OperatorType::kDepthToSpace), op); + EXPECT_EQ(op.block_size, output_toco_op->block_size); +} + +TEST_F(OperatorTest, CustomFakeQuant) { + FakeQuantOperator op; + auto* minmax = new MinMax; + minmax->min = -10; + minmax->max = 200; + op.minmax.reset(minmax); + auto output_toco_op = SerializeAndDeserialize( + GetOperator("FAKE_QUANT", OperatorType::kFakeQuant), op); + EXPECT_EQ(op.minmax->min, output_toco_op->minmax->min); + EXPECT_EQ(op.minmax->max, output_toco_op->minmax->max); +} + +TEST_F(OperatorTest, CustomFullyConnected) { + FullyConnectedOperator op; + op.fused_activation_function = FusedActivationFunctionType::kRelu6; + auto output_toco_op = SerializeAndDeserialize( + GetOperator("FULLY_CONNECTED", OperatorType::kFullyConnected), op); + EXPECT_EQ(op.fused_activation_function, + output_toco_op->fused_activation_function); +} + +TEST_F(OperatorTest, BuiltinL2Pool) { + L2PoolOperator op; + op.stride_width = 123; + op.stride_height = 124; + op.padding.type = PaddingType::kValid; + op.kwidth = 480; + op.kheight = 1080; + auto output_toco_op = SerializeAndDeserialize( + GetOperator("L2_POOL_2D", OperatorType::kL2Pool), op); + EXPECT_EQ(op.stride_width, output_toco_op->stride_width); + EXPECT_EQ(op.stride_height, output_toco_op->stride_height); + EXPECT_EQ(op.padding.type, output_toco_op->padding.type); + EXPECT_EQ(op.kwidth, output_toco_op->kwidth); + EXPECT_EQ(op.kheight, output_toco_op->kheight); +} + +TEST_F(OperatorTest, BuiltinLocalResponseNormalization) { + LocalResponseNormalizationOperator op; + op.range = 123; + op.bias = 1.23; + op.alpha = 12.3; + op.beta = .123; + auto output_toco_op = SerializeAndDeserialize( + GetOperator("LOCAL_RESPONSE_NORMALIZATION", + OperatorType::kLocalResponseNormalization), + op); + EXPECT_EQ(op.range, output_toco_op->range); + EXPECT_EQ(op.bias, output_toco_op->bias); + EXPECT_EQ(op.alpha, output_toco_op->alpha); + EXPECT_EQ(op.beta, output_toco_op->beta); +} + +TEST_F(OperatorTest, BuiltinMaxPool) { + MaxPoolOperator op; + op.stride_width = 123; + op.stride_height = 124; + op.padding.type = PaddingType::kValid; + op.kwidth = 480; + op.kheight = 1080; + auto output_toco_op = SerializeAndDeserialize( + GetOperator("MAX_POOL_2D", OperatorType::kMaxPool), op); + EXPECT_EQ(op.stride_width, output_toco_op->stride_width); + EXPECT_EQ(op.stride_height, output_toco_op->stride_height); + EXPECT_EQ(op.padding.type, output_toco_op->padding.type); + EXPECT_EQ(op.kwidth, output_toco_op->kwidth); + EXPECT_EQ(op.kheight, output_toco_op->kheight); +} + +TEST_F(OperatorTest, BuiltinReshape) { + TensorFlowReshapeOperator op; + op.shape = {1, 2, 4, 5, 8}; + auto output_toco_op = SerializeAndDeserialize( + GetOperator("RESHAPE", OperatorType::kTensorFlowReshape), op); + EXPECT_EQ(op.shape, output_toco_op->shape); +} + +TEST_F(OperatorTest, CustomSoftmax) { + SoftmaxOperator op; + op.beta = 123.1; + auto output_toco_op = SerializeAndDeserialize( + GetOperator("SOFTMAX", OperatorType::kSoftmax), op); + EXPECT_EQ(op.beta, output_toco_op->beta); +} + +TEST_F(OperatorTest, BuiltinSpaceToDepth) { + SpaceToDepthOperator op; + op.block_size = 123; + auto output_toco_op = SerializeAndDeserialize( + GetOperator("SPACE_TO_DEPTH", OperatorType::kSpaceToDepth), op); + EXPECT_EQ(op.block_size, output_toco_op->block_size); +} + +TEST_F(OperatorTest, 
CustomSplit) { + TensorFlowSplitOperator op; + op.num_split = 123; + auto output_toco_op = SerializeAndDeserialize( + GetOperator("SPLIT", OperatorType::kTensorFlowSplit), op); + EXPECT_EQ(op.num_split, output_toco_op->num_split); +} + +TEST_F(OperatorTest, BuiltinAveragePool) { + AveragePoolOperator op; + op.fused_activation_function = FusedActivationFunctionType::kRelu6; + op.stride_width = 123; + op.stride_height = 124; + op.padding.type = PaddingType::kValid; + op.kwidth = 480; + op.kheight = 1080; + auto output_toco_op = SerializeAndDeserialize( + GetOperator("AVERAGE_POOL_2D", OperatorType::kAveragePool), op); + EXPECT_EQ(op.fused_activation_function, + output_toco_op->fused_activation_function); + EXPECT_EQ(op.stride_width, output_toco_op->stride_width); + EXPECT_EQ(op.stride_height, output_toco_op->stride_height); + EXPECT_EQ(op.padding.type, output_toco_op->padding.type); + EXPECT_EQ(op.kwidth, output_toco_op->kwidth); + EXPECT_EQ(op.kheight, output_toco_op->kheight); +} + +TEST_F(OperatorTest, BuiltinConvolution) { + ConvOperator op; + op.stride_width = 123; + op.stride_height = 124; + op.padding.type = PaddingType::kValid; + op.fused_activation_function = FusedActivationFunctionType::kRelu6; + auto output_toco_op = + SerializeAndDeserialize(GetOperator("CONV_2D", OperatorType::kConv), op); + EXPECT_EQ(op.stride_width, output_toco_op->stride_width); + EXPECT_EQ(op.stride_height, output_toco_op->stride_height); + EXPECT_EQ(op.padding.type, output_toco_op->padding.type); + EXPECT_EQ(op.fused_activation_function, + output_toco_op->fused_activation_function); +} + +TEST_F(OperatorTest, BuiltinDepthwiseConvolution) { + DepthwiseConvOperator op; + op.stride_width = 123; + op.stride_height = 124; + op.padding.type = PaddingType::kValid; + op.depth_multiplier = 6; + op.fused_activation_function = FusedActivationFunctionType::kRelu6; + auto output_toco_op = SerializeAndDeserialize( + GetOperator("DEPTHWISE_CONV_2D", OperatorType::kDepthwiseConv), op); + EXPECT_EQ(op.stride_width, output_toco_op->stride_width); + EXPECT_EQ(op.stride_height, output_toco_op->stride_height); + EXPECT_EQ(op.padding.type, output_toco_op->padding.type); + EXPECT_EQ(op.depth_multiplier, output_toco_op->depth_multiplier); + EXPECT_EQ(op.fused_activation_function, + output_toco_op->fused_activation_function); +} + +TEST_F(OperatorTest, BuiltinL2Norm) { + L2NormalizationOperator op; + op.fused_activation_function = FusedActivationFunctionType::kRelu6; + auto output_toco_op = SerializeAndDeserialize( + GetOperator("L2_NORMALIZATION", OperatorType::kL2Normalization), op); + EXPECT_EQ(op.fused_activation_function, + output_toco_op->fused_activation_function); +} + +TEST_F(OperatorTest, BuiltinMul) { + MulOperator op; + op.fused_activation_function = FusedActivationFunctionType::kRelu6; + auto output_toco_op = + SerializeAndDeserialize(GetOperator("MUL", OperatorType::kMul), op); + EXPECT_EQ(op.fused_activation_function, + output_toco_op->fused_activation_function); +} + +TEST_F(OperatorTest, Svdf) { + SvdfOperator op; + op.fused_activation_function = FusedActivationFunctionType::kRelu; + auto output_toco_op = + SerializeAndDeserialize(GetOperator("SVDF", OperatorType::kSvdf), op); + EXPECT_EQ(op.fused_activation_function, + output_toco_op->fused_activation_function); +} + +TEST_F(OperatorTest, TensorFlowUnsupported) { + TensorFlowUnsupportedOperator op; + op.tensorflow_op = "MyCustomUnsupportedOp"; + + ::tensorflow::NodeDef node_def; + auto attr = node_def.mutable_attr(); + (*attr)["float_attr"].set_f(2.0); + 
(*attr)["str_attr"].set_s("Hello World"); + (*attr)["int_attr"].set_i(17); + (*attr)["bool_attr"].set_b(true); + node_def.SerializeToString(&op.tensorflow_node_def); + + auto output_toco_op = + SerializeAndDeserialize(GetOperator("TENSORFLOW_UNSUPPORTED", + OperatorType::kTensorFlowUnsupported), + op); + + ::tensorflow::NodeDef output_node_def; + output_node_def.ParseFromString(output_toco_op->tensorflow_node_def); + const auto& output_attr = output_node_def.attr(); + EXPECT_EQ(2.0, output_attr.at("float_attr").f()); + EXPECT_EQ("Hello World", output_attr.at("str_attr").s()); + EXPECT_EQ(17, output_attr.at("int_attr").i()); + EXPECT_EQ(true, output_attr.at("bool_attr").b()); +} + +TEST_F(OperatorTest, TensorFlowUnsupportedWithoutAttr) { + TensorFlowUnsupportedOperator op; + op.tensorflow_op = "MyCustomUnsupportedOp"; + auto output_toco_op = + SerializeAndDeserialize(GetOperator("TENSORFLOW_UNSUPPORTED", + OperatorType::kTensorFlowUnsupported), + op); + + ::tensorflow::NodeDef output_node_def; + output_node_def.ParseFromString(output_toco_op->tensorflow_node_def); + EXPECT_TRUE(output_node_def.attr().empty()); +} + +} // namespace +} // namespace tflite + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/tflite/simple_operator.h b/tensorflow/contrib/lite/toco/tflite/simple_operator.h new file mode 100644 index 0000000000..992b98baca --- /dev/null +++ b/tensorflow/contrib/lite/toco/tflite/simple_operator.h @@ -0,0 +1,50 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_SIMPLE_OPERATOR_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_SIMPLE_OPERATOR_H_ + +#include "tensorflow/contrib/lite/toco/tflite/operator.h" + +namespace toco { + +namespace tflite { + +// Simple operators don't have any configuration options and can be trivially +// serialized and deserialized. Note that most of toco's operators will +// likely be supported as builtin operators in TF Lite. Simple (and custom) +// operators are mostly a convenience for the times when tf.mini supports more +// operators than TF Lite. +// +// Template argument T must derive from ::toco::Operator. 
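+// +// As an illustrative sketch only (FloorOperator, "FLOOR" and +// OperatorType::kFloor are hypothetical stand-in names, and `ops` stands for +// whatever operator list the exporter assembles), a parameter-free operator +// would typically be registered along the lines of: +// +//   ops.emplace_back(new SimpleOperator<FloorOperator>( +//       "FLOOR", OperatorType::kFloor)); +//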
+template <typename T> +class SimpleOperator : public BaseOperator { + public: + using BaseOperator::BaseOperator; + Options Serialize(const Operator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + return Options(); + } + std::unique_ptr<Operator> Deserialize( + const BuiltinOptions* builtin_options, + const CustomOptions* custom_options) const override { + return std::unique_ptr<Operator>(new T); + } +}; + +} // namespace tflite + +} // namespace toco + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_SIMPLE_OPERATOR_H_ diff --git a/tensorflow/contrib/lite/toco/tflite/types.cc b/tensorflow/contrib/lite/toco/tflite/types.cc new file mode 100644 index 0000000000..5b4dbfae24 --- /dev/null +++ b/tensorflow/contrib/lite/toco/tflite/types.cc @@ -0,0 +1,165 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/toco/tflite/types.h" + +namespace toco { + +namespace tflite { + +namespace { +template <ArrayDataType T> +DataBuffer::FlatBufferOffset CopyBuffer( + const Array& array, flatbuffers::FlatBufferBuilder* builder) { + using NativeT = ::toco::DataType<T>; + const auto& src_data = array.GetBuffer<T>().data; + const uint8_t* dst_data = reinterpret_cast<const uint8_t*>(src_data.data()); + auto size = src_data.size() * sizeof(NativeT); + return builder->CreateVector(dst_data, size); +} + +template <ArrayDataType T> +void CopyBuffer(const ::tflite::Buffer& buffer, Array* array) { + using NativeT = ::toco::DataType<T>; + auto* src_buffer = buffer.data(); + const NativeT* src_data = + reinterpret_cast<const NativeT*>(src_buffer->data()); + int num_items = src_buffer->size() / sizeof(NativeT); + + std::vector<NativeT>* dst_data = &array->GetMutableBuffer<T>().data; + for (int i = 0; i < num_items; ++i) { + dst_data->push_back(*src_data); + ++src_data; + } +} +} // namespace + +::tflite::TensorType DataType::Serialize(ArrayDataType array_data_type) { + switch (array_data_type) { + case ArrayDataType::kFloat: + return ::tflite::TensorType_FLOAT32; + case ArrayDataType::kInt32: + return ::tflite::TensorType_INT32; + case ArrayDataType::kUint8: + return ::tflite::TensorType_UINT8; + default: + // FLOAT32 is filled for unknown data types. + // TODO(ycling): Implement type inference in TF Lite interpreter. + return ::tflite::TensorType_FLOAT32; + } +} + +ArrayDataType DataType::Deserialize(int tensor_type) { + switch (::tflite::TensorType(tensor_type)) { + case ::tflite::TensorType_FLOAT32: + return ArrayDataType::kFloat; + case ::tflite::TensorType_INT32: + return ArrayDataType::kInt32; + case ::tflite::TensorType_UINT8: + return ArrayDataType::kUint8; + default: + LOG(FATAL) << "Unhandled tensor type '" << tensor_type << "'."; + } +} + +flatbuffers::Offset<flatbuffers::Vector<uint8_t>> DataBuffer::Serialize( + const Array& array, flatbuffers::FlatBufferBuilder* builder) { + if (!array.buffer) return 0; // an empty buffer, usually an output.
+ + switch (array.data_type) { + case ArrayDataType::kFloat: + return CopyBuffer(array, builder); + case ArrayDataType::kInt32: + return CopyBuffer(array, builder); + case ArrayDataType::kUint8: + return CopyBuffer(array, builder); + default: + LOG(FATAL) << "Unhandled array data type."; + } +} + +void DataBuffer::Deserialize(const ::tflite::Tensor& tensor, + const ::tflite::Buffer& buffer, Array* array) { + if (tensor.buffer() == 0) return; // an empty buffer, usually an output. + if (buffer.data() == nullptr) return; // a non-defined buffer. + + switch (tensor.type()) { + case ::tflite::TensorType_FLOAT32: + return CopyBuffer(buffer, array); + case ::tflite::TensorType_INT32: + return CopyBuffer(buffer, array); + case ::tflite::TensorType_UINT8: + return CopyBuffer(buffer, array); + default: + LOG(FATAL) << "Unhandled tensor type."; + } +} + +::tflite::Padding Padding::Serialize(PaddingType padding_type) { + switch (padding_type) { + case PaddingType::kSame: + return ::tflite::Padding_SAME; + case PaddingType::kValid: + return ::tflite::Padding_VALID; + default: + LOG(FATAL) << "Unhandled padding type."; + } +} + +PaddingType Padding::Deserialize(int padding) { + switch (::tflite::Padding(padding)) { + case ::tflite::Padding_SAME: + return PaddingType::kSame; + case ::tflite::Padding_VALID: + return PaddingType::kValid; + default: + LOG(FATAL) << "Unhandled padding."; + } +} + +::tflite::ActivationFunctionType ActivationFunction::Serialize( + FusedActivationFunctionType faf_type) { + switch (faf_type) { + case FusedActivationFunctionType::kNone: + return ::tflite::ActivationFunctionType_NONE; + case FusedActivationFunctionType::kRelu: + return ::tflite::ActivationFunctionType_RELU; + case FusedActivationFunctionType::kRelu6: + return ::tflite::ActivationFunctionType_RELU6; + case FusedActivationFunctionType::kRelu1: + return ::tflite::ActivationFunctionType_RELU1; + default: + LOG(FATAL) << "Unhandled fused activation function type."; + } +} + +FusedActivationFunctionType ActivationFunction::Deserialize( + int activation_function) { + switch (::tflite::ActivationFunctionType(activation_function)) { + case ::tflite::ActivationFunctionType_NONE: + return FusedActivationFunctionType::kNone; + case ::tflite::ActivationFunctionType_RELU: + return FusedActivationFunctionType::kRelu; + case ::tflite::ActivationFunctionType_RELU6: + return FusedActivationFunctionType::kRelu6; + case ::tflite::ActivationFunctionType_RELU1: + return FusedActivationFunctionType::kRelu1; + default: + LOG(FATAL) << "Unhandled fused activation function type."; + } +} + +} // namespace tflite + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/tflite/types.h b/tensorflow/contrib/lite/toco/tflite/types.h new file mode 100644 index 0000000000..f7c5140510 --- /dev/null +++ b/tensorflow/contrib/lite/toco/tflite/types.h @@ -0,0 +1,58 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_TYPES_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_TYPES_H_ + +#include "tensorflow/contrib/lite/schema/schema_generated.h" +#include "tensorflow/contrib/lite/toco/model.h" + +namespace toco { + +namespace tflite { + +struct DataType { + static ::tflite::TensorType Serialize(ArrayDataType array_data_type); + static ArrayDataType Deserialize(int tensor_type); +}; + +struct DataBuffer { + using FlatBufferOffset = flatbuffers::Offset>; + + // Build the flatbuffer representation of a toco's Array and return the + // corresponding offset into the flatbuffer. Note that data from the array + // will be copied into the flatbuffer. + static FlatBufferOffset Serialize(const Array& array, + flatbuffers::FlatBufferBuilder* builder); + // Copy data from the given tensor into toco's Array. + static void Deserialize(const ::tflite::Tensor& tensor, + const ::tflite::Buffer& buffer, Array* array); +}; + +struct Padding { + static ::tflite::Padding Serialize(PaddingType padding_type); + static PaddingType Deserialize(int padding); +}; + +struct ActivationFunction { + static ::tflite::ActivationFunctionType Serialize( + FusedActivationFunctionType faf_type); + static FusedActivationFunctionType Deserialize(int activation_function); +}; + +} // namespace tflite + +} // namespace toco + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_TYPES_H_ diff --git a/tensorflow/contrib/lite/toco/tflite/types_test.cc b/tensorflow/contrib/lite/toco/tflite/types_test.cc new file mode 100644 index 0000000000..174b78f3e6 --- /dev/null +++ b/tensorflow/contrib/lite/toco/tflite/types_test.cc @@ -0,0 +1,191 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/toco/tflite/types.h" + +#include +#include + +namespace toco { + +namespace tflite { +namespace { + +using flatbuffers::FlatBufferBuilder; +using flatbuffers::Offset; +using flatbuffers::Vector; + +// These are types that exist in TF Mini but don't have a correspondence +// in TF Lite. +static const ArrayDataType kUnsupportedTocoTypes[] = { + ArrayDataType::kNone, ArrayDataType::kBool, ArrayDataType::kInt64}; + +// These are TF Lite types for which there is no correspondence in TF Mini. +static const ::tflite::TensorType kUnsupportedTfLiteTypes[] = { + ::tflite::TensorType_FLOAT16}; + +// A little helper to match flatbuffer offsets. +MATCHER_P(HasOffset, value, "") { return arg.o == value; } + +// Helper function that creates an array, writes it into a flatbuffer, and then +// reads it back in. +template +Array ToFlatBufferAndBack(std::initializer_list<::toco::DataType> items) { + // NOTE: This test does not construct the full buffers list. 
Since + // Deserialize normally takes a buffer, we need to synthesize one and provide + // an index that is non-zero so the buffer is not assumed to be empty. + Array src; + src.data_type = T; + src.GetMutableBuffer<T>().data = items; + + Array result; + flatbuffers::FlatBufferBuilder builder; + builder.Finish(CreateTensor(builder, 0, DataType::Serialize(T), + /*buffer*/ 1)); // Can't use 0 which means empty. + flatbuffers::FlatBufferBuilder buffer_builder; + Offset<Vector<uint8_t>> data_buffer = + DataBuffer::Serialize(src, &buffer_builder); + buffer_builder.Finish(::tflite::CreateBuffer(buffer_builder, data_buffer)); + + auto* tensor = + flatbuffers::GetRoot<::tflite::Tensor>(builder.GetBufferPointer()); + auto* buffer = + flatbuffers::GetRoot<::tflite::Buffer>(buffer_builder.GetBufferPointer()); + DataBuffer::Deserialize(*tensor, *buffer, &result); + return result; +} + +TEST(DataType, SupportedTypes) { + std::vector<std::pair<ArrayDataType, ::tflite::TensorType>> testdata = { + {ArrayDataType::kUint8, ::tflite::TensorType_UINT8}, + {ArrayDataType::kInt32, ::tflite::TensorType_INT32}, + {ArrayDataType::kFloat, ::tflite::TensorType_FLOAT32}}; + for (auto x : testdata) { + EXPECT_EQ(x.second, DataType::Serialize(x.first)); + EXPECT_EQ(x.first, DataType::Deserialize(x.second)); + } +} + +TEST(DataType, UnsupportedTypes) { + for (::tflite::TensorType t : kUnsupportedTfLiteTypes) { + EXPECT_DEATH(DataType::Deserialize(t), "Unhandled tensor type."); + } + + // Unsupported types are all serialized as FLOAT32 currently. + for (ArrayDataType t : kUnsupportedTocoTypes) { + EXPECT_EQ(::tflite::TensorType_FLOAT32, DataType::Serialize(t)); + } +} + +TEST(DataBuffer, EmptyBuffers) { + flatbuffers::FlatBufferBuilder builder; + Array array; + EXPECT_THAT(DataBuffer::Serialize(array, &builder), HasOffset(0)); + + builder.Finish(::tflite::CreateTensor(builder)); + auto* tensor = + flatbuffers::GetRoot<::tflite::Tensor>(builder.GetBufferPointer()); + flatbuffers::FlatBufferBuilder buffer_builder; + Offset<Vector<uint8_t>> v = buffer_builder.CreateVector<uint8_t>({}); + buffer_builder.Finish(::tflite::CreateBuffer(buffer_builder, v)); + auto* buffer = + flatbuffers::GetRoot<::tflite::Buffer>(buffer_builder.GetBufferPointer()); + + DataBuffer::Deserialize(*tensor, *buffer, &array); + EXPECT_EQ(nullptr, array.buffer); +} + +TEST(DataBuffer, UnsupportedTypes) { + for (ArrayDataType t : kUnsupportedTocoTypes) { + flatbuffers::FlatBufferBuilder builder; + Array array; + array.data_type = t; + array.GetMutableBuffer(); // This is OK.
+ EXPECT_DEATH(DataBuffer::Serialize(array, &builder), + "Unhandled array data type."); + } + + for (::tflite::TensorType t : kUnsupportedTfLiteTypes) { + flatbuffers::FlatBufferBuilder builder; + builder.Finish(::tflite::CreateTensor(builder, 0, t, /*buffer*/ 1)); + flatbuffers::FlatBufferBuilder buffer_builder; + Offset> v = buffer_builder.CreateVector({1}); + buffer_builder.Finish(::tflite::CreateBuffer(buffer_builder, v)); + auto* buffer = flatbuffers::GetRoot<::tflite::Buffer>( + buffer_builder.GetBufferPointer()); + auto* tensor = + flatbuffers::GetRoot<::tflite::Tensor>(builder.GetBufferPointer()); + Array array; + EXPECT_DEATH(DataBuffer::Deserialize(*tensor, *buffer, &array), + "Unhandled tensor type."); + } +} + +TEST(DataBuffer, Float) { + Array recovered = ToFlatBufferAndBack({1.0f, 2.0f}); + EXPECT_THAT(recovered.GetBuffer().data, + ::testing::ElementsAre(1.0f, 2.0f)); +} + +TEST(DataBuffer, Uint8) { + Array recovered = ToFlatBufferAndBack({127, 244}); + EXPECT_THAT(recovered.GetBuffer().data, + ::testing::ElementsAre(127, 244)); +} + +TEST(DataBuffer, Int32) { + Array recovered = ToFlatBufferAndBack({1, 1 << 30}); + EXPECT_THAT(recovered.GetBuffer().data, + ::testing::ElementsAre(1, 1 << 30)); +} + +TEST(Padding, All) { + EXPECT_EQ(::tflite::Padding_SAME, Padding::Serialize(PaddingType::kSame)); + EXPECT_EQ(PaddingType::kSame, Padding::Deserialize(::tflite::Padding_SAME)); + + EXPECT_EQ(::tflite::Padding_VALID, Padding::Serialize(PaddingType::kValid)); + EXPECT_EQ(PaddingType::kValid, Padding::Deserialize(::tflite::Padding_VALID)); + + EXPECT_DEATH(Padding::Serialize(static_cast(10000)), + "Unhandled padding type."); + EXPECT_DEATH(Padding::Deserialize(10000), "Unhandled padding."); +} + +TEST(ActivationFunction, All) { + std::vector< + std::pair> + testdata = {{FusedActivationFunctionType::kNone, + ::tflite::ActivationFunctionType_NONE}, + {FusedActivationFunctionType::kRelu, + ::tflite::ActivationFunctionType_RELU}, + {FusedActivationFunctionType::kRelu6, + ::tflite::ActivationFunctionType_RELU6}, + {FusedActivationFunctionType::kRelu1, + ::tflite::ActivationFunctionType_RELU1}}; + for (auto x : testdata) { + EXPECT_EQ(x.second, ActivationFunction::Serialize(x.first)); + EXPECT_EQ(x.first, ActivationFunction::Deserialize(x.second)); + } + + EXPECT_DEATH(ActivationFunction::Serialize( + static_cast(10000)), + "Unhandled fused activation function type."); + EXPECT_DEATH(ActivationFunction::Deserialize(10000), + "Unhandled fused activation function type."); +} + +} // namespace +} // namespace tflite + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/toco.cc b/tensorflow/contrib/lite/toco/toco.cc new file mode 100644 index 0000000000..f01ec0ec61 --- /dev/null +++ b/tensorflow/contrib/lite/toco/toco.cc @@ -0,0 +1,119 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include <cstdio> +#include <memory> +#include <string> + +#include "absl/strings/string_view.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/model_cmdline_flags.h" +#include "tensorflow/contrib/lite/toco/model_flags.pb.h" +#include "tensorflow/contrib/lite/toco/toco_cmdline_flags.h" +#include "tensorflow/contrib/lite/toco/toco_flags.pb.h" +#include "tensorflow/contrib/lite/toco/toco_port.h" +#include "tensorflow/contrib/lite/toco/toco_tooling.h" +#include "tensorflow/contrib/lite/toco/toco_types.h" +#include "tensorflow/core/platform/logging.h" + +#ifndef CHECK_OK +#define CHECK_OK(val) CHECK_EQ((val).ok(), true) +#define QCHECK_OK(val) QCHECK_EQ((val).ok(), true) +#endif + +namespace toco { +namespace { + +#define QCHECK_REQUIRE_TOCO_FLAG(arg) \ + QCHECK(parsed_toco_flags.arg.specified()) << "Missing required flag: " #arg; + +void CheckFilePermissions(const ParsedTocoFlags& parsed_toco_flags, + const ParsedModelFlags& parsed_model_flags, + const TocoFlags& toco_flags) { + port::CheckInitGoogleIsDone("InitGoogle is not done yet"); + + QCHECK_REQUIRE_TOCO_FLAG(input_file) + QCHECK_OK(port::file::Exists(parsed_toco_flags.input_file.value(), + port::file::Defaults())) + << "Specified input_file does not exist: " + << parsed_toco_flags.input_file.value(); + QCHECK_OK(port::file::Readable(parsed_toco_flags.input_file.value(), + port::file::Defaults())) + << "Specified input_file exists, but is not readable: " + << parsed_toco_flags.input_file.value(); + + QCHECK_REQUIRE_TOCO_FLAG(output_file); + QCHECK_OK(port::file::Writable(parsed_toco_flags.output_file.value())) + << "Specified output_file is not writable: " + << parsed_toco_flags.output_file.value(); +} + +void ToolMain(const ParsedTocoFlags& parsed_toco_flags, + const ParsedModelFlags& parsed_model_flags) { + ModelFlags model_flags; + ReadModelFlagsFromCommandLineFlags(parsed_model_flags, &model_flags); + + TocoFlags toco_flags; + ReadTocoFlagsFromCommandLineFlags(parsed_toco_flags, &toco_flags); + + CheckFilePermissions(parsed_toco_flags, parsed_model_flags, toco_flags); + + string input_file_contents; + CHECK_OK(port::file::GetContents(parsed_toco_flags.input_file.value(), + &input_file_contents, + port::file::Defaults())); + std::unique_ptr<Model> model = + Import(toco_flags, model_flags, input_file_contents); + Transform(toco_flags, model.get()); + string output_file_contents; + Export(toco_flags, *model, toco_flags.allow_custom_ops(), + &output_file_contents); + CHECK_OK(port::file::SetContents(parsed_toco_flags.output_file.value(), + output_file_contents, + port::file::Defaults())); +} + +} // namespace +} // namespace toco + +int main(int argc, char** argv) { + toco::string msg; + toco::ParsedTocoFlags parsed_toco_flags; + toco::ParsedModelFlags parsed_model_flags; + + // If no args were specified, give a help string to be helpful. + int* effective_argc = &argc; + char** effective_argv = argv; + if (argc == 1) { + // No arguments, so manufacture help argv. + static int dummy_argc = 2; + static char* dummy_argv[] = {argv[0], const_cast<char*>("--help")}; + effective_argc = &dummy_argc; + effective_argv = dummy_argv; + } + + // Parse toco flags and command flags in sequence, each one strips off args, + // giving InitGoogle a chance to handle all remaining arguments. 
+ bool toco_success = toco::ParseTocoFlagsFromCommandLineFlags( + effective_argc, effective_argv, &msg, &parsed_toco_flags); + bool model_success = toco::ParseModelFlagsFromCommandLineFlags( + effective_argc, effective_argv, &msg, &parsed_model_flags); + if (!toco_success || !model_success || !msg.empty()) { + fprintf(stderr, "%s", msg.c_str()); + fflush(stderr); + return 1; + } + toco::port::InitGoogle(argv[0], effective_argc, &effective_argv, true); + toco::ToolMain(parsed_toco_flags, parsed_model_flags); +} diff --git a/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc b/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc new file mode 100644 index 0000000000..d43c3b4a8e --- /dev/null +++ b/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc @@ -0,0 +1,206 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "absl/strings/numbers.h" +#include "absl/strings/str_join.h" +#include "absl/strings/str_split.h" +#include "absl/strings/strip.h" +#include "tensorflow/contrib/lite/toco/toco_cmdline_flags.h" +#include "tensorflow/contrib/lite/toco/toco_port.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/util/command_line_flags.h" + +namespace toco { + +bool ParseTocoFlagsFromCommandLineFlags( + int* argc, char* argv[], string* msg, + ParsedTocoFlags* parsed_toco_flags_ptr) { + using tensorflow::Flag; + ParsedTocoFlags& parsed_flags = *parsed_toco_flags_ptr; + std::vector flags = { + Flag("input_file", parsed_flags.input_file.bind(), + parsed_flags.input_file.default_value(), + "Input file (model of any supported format). For Protobuf " + "formats, both text and binary are supported regardless of file " + "extension."), + Flag("output_file", parsed_flags.output_file.bind(), + parsed_flags.output_file.default_value(), + "Output file. " + "For Protobuf formats, the binary format will be used."), + Flag("input_format", parsed_flags.input_format.bind(), + parsed_flags.input_format.default_value(), + "Input file format. One of: tensorflow_graphdef, "), + Flag("output_format", parsed_flags.output_format.bind(), + parsed_flags.output_format.default_value(), "Output file format."), + Flag("default_ranges_min", parsed_flags.default_ranges_min.bind(), + parsed_flags.default_ranges_min.default_value(), + "If defined, will be used as the default value for the min bound " + "of min/max ranges used for quantization."), + Flag("default_ranges_max", parsed_flags.default_ranges_max.bind(), + parsed_flags.default_ranges_max.default_value(), + "If defined, will be used as the default value for the max bound " + "of min/max ranges used for quantization."), + Flag("input_type", parsed_flags.input_type.bind(), + parsed_flags.input_type.default_value(), + "Data type of the input array in the " + "output file. 
"), + Flag("input_types", parsed_flags.input_types.bind(), + parsed_flags.input_types.default_value(), + "Data types of the input arrays in the " + "output file. " + "Comma-separated list matching the enumeration order of " + "input_arrays."), + Flag("inference_type", parsed_flags.inference_type.bind(), + parsed_flags.inference_type.default_value(), + "Data type, in the output file, of internal and output arrays " + "that are FLOAT in the input file. Thus, the value FLOAT means " + "keep doing floating-point inference, while the value " + "QUANTIZED_UINT8 means replace all internal floating-point " + "arithmetic by integer arithmetic producing 8-bit integer " + "activations instead of float activations --- which we call " + "\'quantized inference\'."), + Flag("drop_fake_quant", parsed_flags.drop_fake_quant.bind(), + parsed_flags.drop_fake_quant.default_value(), + "Ignore and discard FakeQuant nodes. For instance, that can be used " + "to " + "generate plain float code without fake-quantization from a " + "quantized " + "graph."), + Flag( + "reorder_across_fake_quant", + parsed_flags.reorder_across_fake_quant.bind(), + parsed_flags.reorder_across_fake_quant.default_value(), + "Normally, FakeQuant nodes must be strict boundaries for graph " + "transformations, in order to ensure that quantized inference has " + "the " + "exact same arithmetic behavior as quantized training --- which is " + "the " + "whole point of quantized training and of FakeQuant nodes in the " + "first " + "place. However, that entails subtle requirements on where exactly " + "FakeQuant nodes must be placed in the graph. Some quantized graphs " + "have FakeQuant nodes at unexpected locations, that prevent graph " + "transformations that are necessary in order to generate inference " + "code for these graphs. 
Such graphs should be fixed, but as a " + "temporary work-around, setting this reorder_across_fake_quant flag " + "allows toco to perform necessary graph transformations on them, " + "at the cost of no longer faithfully matching inference and training " + "arithmetic."), + Flag("allow_custom_ops", parsed_flags.allow_custom_ops.bind(), + parsed_flags.allow_custom_ops.default_value(), + "If true, allow TOCO to create TF Lite Custom operators for all the " + "unsupported TensorFlow ops."), + }; + bool asked_for_help = + *argc == 2 && (!strcmp(argv[1], "--help") || !strcmp(argv[1], "-help")); + if (asked_for_help) { + *msg += tensorflow::Flags::Usage(argv[0], flags); + return false; + } else { + return tensorflow::Flags::Parse(argc, argv, flags); + } +} + +void ReadTocoFlagsFromCommandLineFlags(const ParsedTocoFlags& parsed_toco_flags, + TocoFlags* toco_flags) { + namespace port = toco::port; + port::CheckInitGoogleIsDone("InitGoogle is not done yet"); + + enum class FlagRequirement { kNone, kMustBeSpecified, kMustNotBeSpecified }; + +#define ENFORCE_FLAG_REQUIREMENT(name, requirement) \ + do { \ + if (requirement == FlagRequirement::kMustBeSpecified) { \ + QCHECK(parsed_toco_flags.name.specified()) \ + << "Missing required flag: " << #name; \ + } \ + if (requirement == FlagRequirement::kMustNotBeSpecified) { \ + QCHECK(!parsed_toco_flags.name.specified()) \ + << "Given other flags, this flag should not have been specified: " \ + << #name; \ + } \ + } while (false) + +#define READ_TOCO_FLAG(name, requirement) \ + ENFORCE_FLAG_REQUIREMENT(name, requirement); \ + do { \ + if (parsed_toco_flags.name.specified()) { \ + toco_flags->set_##name(parsed_toco_flags.name.value()); \ + } \ + } while (false) + +#define PARSE_TOCO_FLAG(Type, name, requirement) \ + ENFORCE_FLAG_REQUIREMENT(name, requirement); \ + do { \ + if (parsed_toco_flags.name.specified()) { \ + Type x; \ + QCHECK(Type##_Parse(parsed_toco_flags.name.value(), &x)) \ + << "Unrecognized " << #Type << " value " \ + << parsed_toco_flags.name.value(); \ + toco_flags->set_##name(x); \ + } \ + } while (false) + + PARSE_TOCO_FLAG(FileFormat, input_format, FlagRequirement::kMustBeSpecified); + PARSE_TOCO_FLAG(FileFormat, output_format, FlagRequirement::kMustBeSpecified); + FlagRequirement tflite_flags_requirement = + toco_flags->output_format() == TFLITE + ? 
FlagRequirement::kMustBeSpecified + : FlagRequirement::kMustNotBeSpecified; + PARSE_TOCO_FLAG(IODataType, inference_type, tflite_flags_requirement); + READ_TOCO_FLAG(default_ranges_min, FlagRequirement::kNone); + READ_TOCO_FLAG(default_ranges_max, FlagRequirement::kNone); + READ_TOCO_FLAG(drop_fake_quant, FlagRequirement::kNone); + READ_TOCO_FLAG(reorder_across_fake_quant, FlagRequirement::kNone); + READ_TOCO_FLAG(allow_custom_ops, FlagRequirement::kNone); + +#undef READ_TOCO_FLAG +#undef PARSE_TOCO_FLAG + + const bool input_type_specified = parsed_toco_flags.input_type.specified(); + const bool input_types_specified = parsed_toco_flags.input_types.specified(); + if (toco_flags->output_format() == TFLITE) { + QCHECK(input_type_specified || input_types_specified) + << "When output_format=TFLITE, either input_type or input_types needs " + "to be specified."; + } else { + QCHECK(!input_type_specified && !input_types_specified) + << "With this output_format, neither input_type nor input_types must " + "be specified."; + } + QCHECK(!(input_type_specified && input_types_specified)) + << "input_type and input_types are mutually exclusive"; + if (input_type_specified) { + IODataType type; + QCHECK(IODataType_Parse(parsed_toco_flags.input_type.value(), &type)) + << "Unrecognized input_type: " << parsed_toco_flags.input_type.value(); + toco_flags->add_input_types(type); + } + if (input_types_specified) { + std::vector input_types = + absl::StrSplit(parsed_toco_flags.input_types.value(), ','); + for (const string& t : input_types) { + IODataType type; + QCHECK(IODataType_Parse(t, &type)) + << "Unrecognized input_types value " << t + << " in input_types=" << parsed_toco_flags.input_types.value(); + toco_flags->add_input_types(type); + } + } +} +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/toco_cmdline_flags.h b/tensorflow/contrib/lite/toco/toco_cmdline_flags.h new file mode 100644 index 0000000000..155a6fea87 --- /dev/null +++ b/tensorflow/contrib/lite/toco/toco_cmdline_flags.h @@ -0,0 +1,35 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_CMDLINE_FLAGS_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_CMDLINE_FLAGS_H_ + +#include +#include +#include "tensorflow/contrib/lite/toco/args.h" +#include "tensorflow/contrib/lite/toco/toco_flags.pb.h" + +namespace toco { +// Parse and remove arguments handled from toco. Returns true if parsing +// is successful. msg has the usage string if there was an error or +// "--help" was specified +bool ParseTocoFlagsFromCommandLineFlags(int* argc, char* argv[], string* msg, + ParsedTocoFlags* parsed_toco_flags_ptr); +// Populate the TocoFlags proto with parsed_toco_flags data. 
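+// +// Typical call sequence (a sketch mirroring toco.cc; error handling omitted): +// +//   ParsedTocoFlags parsed_toco_flags; +//   string msg; +//   ParseTocoFlagsFromCommandLineFlags(&argc, argv, &msg, &parsed_toco_flags); +//   TocoFlags toco_flags; +//   ReadTocoFlagsFromCommandLineFlags(parsed_toco_flags, &toco_flags);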
+void ReadTocoFlagsFromCommandLineFlags(const ParsedTocoFlags& parsed_toco_flags, + TocoFlags* toco_flags); + +} // namespace toco + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_CMDLINE_FLAGS_H_ diff --git a/tensorflow/contrib/lite/toco/toco_flags.proto b/tensorflow/contrib/lite/toco/toco_flags.proto new file mode 100644 index 0000000000..fd7c29fdc7 --- /dev/null +++ b/tensorflow/contrib/lite/toco/toco_flags.proto @@ -0,0 +1,126 @@ +// Copyright 2017 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +syntax = "proto2"; +package toco; + +// Supported I/O file formats. Some formats may be input-only or output-only. +enum FileFormat { + FILE_FORMAT_UNKNOWN = 0; + + // GraphDef, third_party/tensorflow/core/framework/graph.proto + TENSORFLOW_GRAPHDEF = 1; + + // Tensorflow's mobile inference model. + // third_party/tensorflow/contrib/tflite/schema.fbs + TFLITE = 2; + + // GraphViz + // Export-only. + GRAPHVIZ_DOT = 3; +} + +// IODataType describes the numeric data types to be used by the output format. +// See input_type and inference_type below. +enum IODataType { + IO_DATA_TYPE_UNKNOWN = 0; + + // Float32, not quantized + FLOAT = 1; + + // Uint8, quantized + QUANTIZED_UINT8 = 2; + + // Int32, not quantized + INT32 = 3; + + // Int64, not quantized + INT64 = 4; + + // String, not quantized + STRING = 5; +} + +// TocoFlags encodes extra parameters that drive tooling operations, that +// are not normally encoded in model files and in general may not be thought +// of as properties of models, instead describing how models are to be +// processed in the context of the present tooling job. +// Next Id: 11 +message TocoFlags { + // Input file format + optional FileFormat input_format = 1; + + // Output file format + optional FileFormat output_format = 2; + + // Numeric data types of the input arrays in the output format. + // This controls what input types the output file will be expecting. + // This is not a description of the input types of the input file. + // For example, the input file may have a float input placeholder, + // but we may want to generate a quantized TFLite file from it, + // or a float TFLite file taking a quantized input. + // + // The length of this list should match the length of the input_arrays + // list in ModelFlags. + repeated IODataType input_types = 9; + + // Numeric data type of the internal activations array and output array. + // + // As a matter of implementation detail, most model + // parameter arrays (weights, etc) will tend to also use this data type. + // Not all will, though: for instance, bias vectors will typically + // get quantized as int32 when weights and activations get quantized as uint8. + optional IODataType inference_type = 4; + + // default_ranges_min and default_ranges_max are helpers to experiment + // with quantization of models. Normally, quantization requires the input + // model to have (min, max) range information for every activations array. 
+ // This is needed in order to know how to quantize arrays and still achieve + // satisfactory accuracy. However, in some circumstances one would just like + // to estimate the performance of quantized inference, without caring about + // accuracy. That is what default_ranges_min and default_ranges_max are for: + // when specified, they will be used as default (min, max) range boundaries + // for all activation arrays that lack (min, max) range information, thus + // allowing for quantization to proceed. + // + // It should be clear from the above explanation that these parameters are + // for experimentation purposes only and should not be used in production: + // they make it easy to quantize models, but the resulting quantized model + // will be inaccurate. + optional float default_ranges_min = 5; + optional float default_ranges_max = 6; + + // Ignore and discard FakeQuant nodes. For instance, that can be used to + // generate plain float code without fake-quantization from a quantized + // graph. + optional bool drop_fake_quant = 7; + + // Normally, FakeQuant nodes must be strict boundaries for graph + // transformations, in order to ensure that quantized inference has the + // exact same arithmetic behavior as quantized training --- which is the + // whole point of quantized training and of FakeQuant nodes in the first + // place. However, that entails subtle requirements on where exactly + // FakeQuant nodes must be placed in the graph. Some quantized graphs + // have FakeQuant nodes at unexpected locations, that prevent graph + // transformations that are necessary in order to generate inference + // code for these graphs. Such graphs should be fixed, but as a + // temporary work-around, setting this reorder_across_fake_quant flag + // allows toco to perform necessary graph transformations on them, + // at the cost of no longer faithfully matching inference and training + // arithmetic. + optional bool reorder_across_fake_quant = 8; + + // If true, allow TOCO to create TF Lite Custom operators for all the + // unsupported TensorFlow ops. + optional bool allow_custom_ops = 10; +} diff --git a/tensorflow/contrib/lite/toco/toco_graphviz_dump_options.cc b/tensorflow/contrib/lite/toco/toco_graphviz_dump_options.cc new file mode 100644 index 0000000000..4e98e7081d --- /dev/null +++ b/tensorflow/contrib/lite/toco/toco_graphviz_dump_options.cc @@ -0,0 +1,22 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/contrib/lite/toco/toco_graphviz_dump_options.h" + +namespace toco { +GraphVizDumpOptions* GraphVizDumpOptions::singleton() { + static auto* ptr = new GraphVizDumpOptions; + return ptr; +} +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/toco_graphviz_dump_options.h b/tensorflow/contrib/lite/toco/toco_graphviz_dump_options.h new file mode 100644 index 0000000000..ae0541f62b --- /dev/null +++ b/tensorflow/contrib/lite/toco/toco_graphviz_dump_options.h @@ -0,0 +1,34 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_GRAPHVIZ_DUMP_OPTIONS_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_GRAPHVIZ_DUMP_OPTIONS_H_ + +#include + +namespace toco { + +// Global data for determining whether to output graph viz format from toco. +struct GraphVizDumpOptions { + std::string graphviz_first_array; + std::string graphviz_last_array; + std::string dump_graphviz; + bool dump_graphviz_video = false; + + static GraphVizDumpOptions* singleton(); +}; + +} // namespace toco + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_GRAPHVIZ_DUMP_OPTIONS_H_ diff --git a/tensorflow/contrib/lite/toco/toco_port.cc b/tensorflow/contrib/lite/toco/toco_port.cc new file mode 100644 index 0000000000..a1c8696cd0 --- /dev/null +++ b/tensorflow/contrib/lite/toco/toco_port.cc @@ -0,0 +1,227 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include + +#include "tensorflow/contrib/lite/toco/toco_port.h" +#include "tensorflow/contrib/lite/toco/toco_types.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { +namespace port { +void CopyToBuffer(const string& src, char* dest) { + memcpy(dest, src.data(), src.size()); +} + +#ifdef PLATFORM_GOOGLE +void CopyToBuffer(const Cord& src, char* dest) { src.CopyToArray(dest); } +#endif +} // namespace port +} // namespace toco + +#if defined(PLATFORM_GOOGLE) && !defined(__APPLE__) && !defined(__ANDROID__) + +// Wrap Google file operations. 
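+// (This branch is only compiled for google-internal builds; the POSIX-based +// implementations further below cover the Apple, Android and open-source +// cases.)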
+ +#include "base/init_google.h" +#include "file/base/file.h" +#include "file/base/filesystem.h" +#include "file/base/helpers.h" +#include "file/base/options.h" +#include "file/base/path.h" + +namespace toco { +namespace port { + +void InitGoogle(const char* usage, int* argc, char*** argv, bool remove_flags) { + ::InitGoogle(usage, argc, argv, remove_flags); +} + +void CheckInitGoogleIsDone(const char* message) { + ::CheckInitGoogleIsDone(message); +} + +namespace file { + +// Conversion to our wrapper Status. +Status ToStatus(const ::util::Status& uts) { + return Status(uts.ok(), uts.error_message()); +} + +// Conversion to our wrapper Options. +toco::port::file::Options ToOptions(const ::file::Options& options) { + CHECK_EQ(&options, &::file::Defaults()); + return Options(); +} + +Status Writable(const string& filename) { + File* f = nullptr; + const auto status = ::file::Open(filename, "w", &f, ::file::Defaults()); + if (f) { + QCHECK_OK(f->Close(::file::Defaults())); + } + return ToStatus(status); +} + +Status Readable(const string& filename, const file::Options& options) { + return ToStatus(::file::Readable(filename, ::file::Defaults())); +} + +Status Exists(const string& filename, const file::Options& options) { + auto status = ::file::Exists(filename, ::file::Defaults()); + return ToStatus(status); +} + +Status GetContents(const string& filename, string* contents, + const file::Options& options) { + return ToStatus(::file::GetContents(filename, contents, ::file::Defaults())); +} + +Status SetContents(const string& filename, const string& contents, + const file::Options& options) { + return ToStatus(::file::SetContents(filename, contents, ::file::Defaults())); +} + +string JoinPath(const string& a, const string& b) { + return ::file::JoinPath(a, b); +} + +} // namespace file +} // namespace port +} // namespace toco + +#else // (__APPLE__ || __ANDROID__) + +#include +#include +#include +#include +#include + +#if defined(PLATFORM_GOOGLE) +#include "base/commandlineflags.h" +#endif + +namespace toco { +namespace port { + +static bool port_initialized = false; + +void InitGoogle(const char* usage, int* argc, char*** argv, bool remove_flags) { + if (!port_initialized) { +#if defined(PLATFORM_GOOGLE) + ParseCommandLineFlags(argc, argv, remove_flags); +#endif + port_initialized = true; + } +} + +void CheckInitGoogleIsDone(const char* message) { + CHECK(port_initialized) << message; +} + +namespace file { + +Status Writable(const string& filename) { + FILE* f = fopen(filename.c_str(), "w"); + if (f) { + fclose(f); + return Status(true, ""); + } + return Status(false, "not writable"); +} + +Status Readable(const string& filename, const file::Options& options) { + FILE* f = fopen(filename.c_str(), "r"); + if (f) { + fclose(f); + return Status(true, ""); + } + return Status(false, "not readable"); +} + +Status Exists(const string& filename, const file::Options& options) { + struct stat statbuf; + int ret = stat(filename.c_str(), &statbuf); + return Status(ret != -1, ""); +} + +Status GetContents(const string& path, string* output, + const file::Options& options) { + output->clear(); + + int fd = open(path.c_str(), O_RDONLY); + if (fd == -1) { + return Status(false, "can't open() for read"); + } + + // Direct read, for speed. + const int kBufSize = 1 << 16; + char buffer[kBufSize]; + while (true) { + int size = read(fd, buffer, kBufSize); + if (size == 0) { + // Done. + close(fd); + return Status(true, ""); + } else if (size == -1) { + // Error. 
+ close(fd); + return Status(false, "error during read()"); + } else { + output->append(buffer, size); + } + } + + CHECK(0); + return Status(false, "internal error"); +} + +Status SetContents(const string& filename, const string& contents, + const file::Options& options) { + int fd = open(filename.c_str(), O_WRONLY | O_CREAT, 0664); + if (fd == -1) { + return Status(false, "can't open() for write"); + } + + size_t i = 0; + while (i < contents.size()) { + size_t to_write = contents.size() - i; + ssize_t written = write(fd, &contents[i], to_write); + if (written == -1) { + close(fd); + return Status(false, "write() error"); + } + i += written; + } + close(fd); + + return Status(true, ""); +} + +string JoinPath(const string& base, const string& filename) { + if (base.empty()) return filename; + string base_fixed = base; + if (!base_fixed.empty() && base_fixed.back() == '/') base_fixed.pop_back(); + string filename_fixed = filename; + if (!filename_fixed.empty() && filename_fixed.front() == '/') + filename_fixed.erase(0, 1); + return base_fixed + "/" + filename_fixed; +} + +} // namespace file +} // namespace port +} // namespace toco + +#endif // (__APPLE || __ANDROID__) diff --git a/tensorflow/contrib/lite/toco/toco_port.h b/tensorflow/contrib/lite/toco/toco_port.h new file mode 100644 index 0000000000..b5cb7a11e7 --- /dev/null +++ b/tensorflow/contrib/lite/toco/toco_port.h @@ -0,0 +1,80 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_PORT_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_PORT_H_ + +// Portability layer for toco tool. Mainly, abstract filesystem access so we +// can build and use on google internal environments and on OSX. 
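+// +// Rough usage sketch of the file helpers declared below (illustrative only; +// the path is a placeholder): +// +//   string contents; +//   auto status = toco::port::file::GetContents( +//       "/tmp/some_model.pb", &contents, toco::port::file::Defaults()); +//   CHECK(status.ok()) << status.error_message(); +//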
+ +#include +#include "tensorflow/contrib/lite/toco/format_port.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/platform.h" +#if defined(PLATFORM_GOOGLE) +#include "absl/strings/cord.h" +#endif // PLATFORM_GOOGLE + +#ifdef PLATFORM_GOOGLE +#define TFLITE_PROTO_NS proto2 +#else +#define TFLITE_PROTO_NS google::protobuf +#endif + +namespace toco { +namespace port { + +class Status { + public: + Status() {} + + Status(bool ok, const string& message) : ok_(ok), message_(message) {} + + bool ok() const { return ok_; } + + const string error_message() const { return message_; } + + private: + bool ok_ = false; + string message_; +}; + +void InitGoogle(const char* usage, int* argc, char*** argv, bool remove_flags); +void CheckInitGoogleIsDone(const char* message); + +namespace file { +class Options {}; +inline Options Defaults() { + Options o; + return o; +} +Status GetContents(const string& filename, string* contents, + const Options& options); +Status SetContents(const string& filename, const string& contents, + const Options& options); +string JoinPath(const string& base, const string& filename); +Status Writable(const string& filename); +Status Readable(const string& filename, const Options& options); +Status Exists(const string& filename, const Options& options); +} // namespace file + +// Copy `src` string to `dest`. User must ensure `dest` has enough space. +#if defined(PLATFORM_GOOGLE) +void CopyToBuffer(const ::Cord& src, char* dest); +#endif // PLATFORM_GOOGLE +void CopyToBuffer(const string& src, char* dest); +} // namespace port +} // namespace toco + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_PORT_H_ diff --git a/tensorflow/contrib/lite/toco/toco_port_test.cc b/tensorflow/contrib/lite/toco/toco_port_test.cc new file mode 100644 index 0000000000..650a617aeb --- /dev/null +++ b/tensorflow/contrib/lite/toco/toco_port_test.cc @@ -0,0 +1,58 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/contrib/lite/toco/toco_port.h" +#include "tensorflow/contrib/lite/toco/toco_types.h" + +#include +#include + +namespace toco { +namespace port { +namespace { + +#ifdef PLATFORM_GOOGLE +#define TFLITE_PREFIX "third_party/tensorflow/contrib/lite/" +#else +#define TFLITE_PREFIX "tensorflow/contrib/lite/" +#endif + +TEST(TocoPortTest, Exists) { + EXPECT_TRUE( + file::Exists(TFLITE_PREFIX "toco/toco_port_test.cc", file::Defaults()) + .ok()); + + EXPECT_FALSE( + file::Exists("non-existent_file_asldjflasdjf", file::Defaults()).ok()); +} + +TEST(TocoPortTest, Readable) { + EXPECT_TRUE( + file::Readable(TFLITE_PREFIX "toco/toco_port_test.cc", file::Defaults()) + .ok()); + + EXPECT_FALSE( + file::Readable("non-existent_file_asldjflasdjf", file::Defaults()).ok()); +} + +TEST(TocoPortTest, JoinPath) { + EXPECT_EQ("part1/part2", file::JoinPath("part1", "part2")); + EXPECT_EQ("part1/part2", file::JoinPath("part1/", "part2")); + EXPECT_EQ("part1/part2", file::JoinPath("part1", "/part2")); + EXPECT_EQ("part1/part2", file::JoinPath("part1/", "/part2")); +} + +} // namespace +} // namespace port +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/toco_tooling.cc b/tensorflow/contrib/lite/toco/toco_tooling.cc new file mode 100644 index 0000000000..232538a841 --- /dev/null +++ b/tensorflow/contrib/lite/toco/toco_tooling.cc @@ -0,0 +1,277 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/toco/toco_tooling.h" + +#include +#include +#include + +#include "absl/strings/str_join.h" +#include "tensorflow/contrib/lite/toco/allocate_transient_arrays.h" +#include "tensorflow/contrib/lite/toco/dump_graphviz.h" +#include "tensorflow/contrib/lite/toco/export_tensorflow.h" +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/import_tensorflow.h" +#include "tensorflow/contrib/lite/toco/model_flags.pb.h" +#include "tensorflow/contrib/lite/toco/tflite/export.h" +#include "tensorflow/contrib/lite/toco/tflite/import.h" +#include "tensorflow/contrib/lite/toco/toco_flags.pb.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { +namespace { +// CHECK-fails if the model contains a kTensorFlowUnsupported operation. 
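+// (Transform() only runs this check when the output format is neither TFLITE +// nor GRAPHVIZ_DOT, i.e. when exporting back to a TensorFlow GraphDef.)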
+void CheckUnsupportedOperations(const Model& model) { + std::set unsupported_ops; + for (auto& op : model.operators) { + if (op->type == OperatorType::kTensorFlowUnsupported) { + unsupported_ops.insert( + static_cast(op.get()) + ->tensorflow_op); + } + } + QCHECK(unsupported_ops.empty()) + << "These unsupported ops were not removed by graph transformations: " + << absl::StrJoin(unsupported_ops, ", "); +} + +void MakeGeneralGraphTransformationsSet( + GraphTransformationsSet* transformations) { + CHECK(transformations->empty()); + transformations->Add(new ResolveReshapeAttributes); + transformations->Add(new PropagateArrayDataTypes); + transformations->Add(new PropagateFixedSizes); + transformations->Add(new RemoveTensorFlowAssert); + transformations->Add(new RemoveTensorFlowIdentity); + transformations->Add(new RemoveTrivialConcatenation); + transformations->Add(new RemoveTrivialConcatenationInput); + transformations->Add(new RemoveUnusedOp); + transformations->Add(new EnsureBiasVectors); + transformations->Add(new ResolveReorderAxes); + transformations->Add(new ResolveTensorFlowMatMul); + transformations->Add(new FuseBinaryIntoPrecedingAffine); + transformations->Add(new FuseBinaryIntoFollowingAffine); + transformations->Add(new ResolveBatchNormalization); + transformations->Add(new ResolveConstantBinaryOperator); + transformations->Add(new ResolveConstantUnaryOperator); + transformations->Add(new ResolveTensorFlowMerge); + transformations->Add(new ResolveTensorFlowSqueeze); + transformations->Add(new ResolveTensorFlowSwitch); + transformations->Add(new ResolveTensorFlowTile); + transformations->Add(new ResolveTensorFlowConcat); + transformations->Add(new IdentifyL2Normalization); + transformations->Add(new IdentifyL2Pool); + transformations->Add(new IdentifyRelu1); + transformations->Add(new RemoveTrivialBinaryOperator); + transformations->Add(new ReadFakeQuantMinMax); + transformations->Add(new ResolvePadAttributes); + transformations->Add(new ResolveStridedSliceAttributes); + transformations->Add(new ResolveSliceAttributes); + transformations->Add(new ResolveMeanAttributes); + transformations->Add(new ResolveConstantTensorFlowShape); + transformations->Add(new MakeInitialDequantizeOperator); +} + +void SetArrayFinalDataTypes(const TocoFlags& toco_flags, Model* model) { + const bool output_is_tflite = toco_flags.output_format() == TFLITE; + + if (output_is_tflite) { + if (!toco_flags.input_types().empty()) { + for (int i = 0; i < model->flags.input_arrays_size(); i++) { + int input_types_index = toco_flags.input_types_size() == 1 ? 
0 : i; + const auto input_type = toco_flags.input_types(input_types_index); + ArrayDataType final_data_type = ArrayDataType::kNone; + switch (input_type) { + case FLOAT: + final_data_type = ArrayDataType::kFloat; + break; + case QUANTIZED_UINT8: + final_data_type = ArrayDataType::kUint8; + break; + case INT32: + final_data_type = ArrayDataType::kInt32; + break; + case INT64: + final_data_type = ArrayDataType::kInt64; + break; + default: + LOG(FATAL) << "Unknown data type"; + } + model->arrays[model->flags.input_arrays(i).name()]->final_data_type = + final_data_type; + } + } + } else { + for (int i = 0; i < model->flags.input_arrays_size(); i++) { + model->arrays[model->flags.input_arrays(i).name()]->final_data_type = + ArrayDataType::kFloat; + } + } +} + +} // namespace + +std::unique_ptr Import(const TocoFlags& toco_flags, + const ModelFlags& model_flags, + const string& input_file_contents) { + std::unique_ptr model; + switch (toco_flags.input_format()) { + case TENSORFLOW_GRAPHDEF: + model = ImportTensorFlowGraphDef(model_flags, input_file_contents); + break; + case TFLITE: + model = toco::tflite::Import(model_flags, input_file_contents); + ResolveModelFlags(model_flags, model.get()); + CheckInvariants(*model); + break; + default: + LOG(FATAL) << "Unhandled input_format"; + } + + LogDump(kLogLevelModelChanged, "AT IMPORT", *model); + + return model; +} + +void Transform(const TocoFlags& toco_flags, Model* model) { + const FileFormat output_format = toco_flags.output_format(); + const IODataType inference_type = toco_flags.inference_type(); + + const bool output_is_tflite = output_format == TFLITE; + + const bool output_is_tflite_quantized = + output_is_tflite && inference_type == QUANTIZED_UINT8; + + if (output_is_tflite) { + QCHECK(toco_flags.input_types_size() == 1 || + toco_flags.input_types_size() == model->flags.input_arrays_size()) + << "Mismatched numbers of input_arrays and input_types"; + } + + if (output_is_tflite_quantized) { + for (const auto& input_type : toco_flags.input_types()) { + QCHECK_NE(input_type, FLOAT) + << "Quantized inference is not allowed with float inputs."; + } + } + + SetArrayFinalDataTypes(toco_flags, model); + + GraphTransformationsSet transformations; + MakeGeneralGraphTransformationsSet(&transformations); + auto* remove_trivial_reshape = new RemoveTrivialReshape; + transformations.Add(remove_trivial_reshape); + if (output_format == TFLITE) { + transformations.Add(new FuseActivationFunctions); + } else { + transformations.Add(new UnfuseActivationFunctions); + } + if (output_format != TENSORFLOW_GRAPHDEF) { + transformations.Add(new ResolveConstantFakeQuant); + } + if (toco_flags.drop_fake_quant()) { + transformations.Add(new DropFakeQuant); + } else { + // See the doc for --reorder_across_fake_quant: that flag is needed to + // support some existing models, e.g. WordLens, that have FakeQuant + // nodes in the wrong places. + // We currently unconditionally enable that behavior when the output + // format is DarwiNN because the DarwiNN test code does not make it + // easy to pass a new toco flag. Once that is resolved on the DarwiNN + // tests side, the special-casing of DarwiNN here can go away. + // TODO(benoitjacob): so drop it when we can. + if ((output_is_tflite_quantized && + toco_flags.reorder_across_fake_quant())) { + transformations.Add(new DropFakeQuant); + } + } + transformations.Add(new ConvertPureConvToDepthwise); + // TFLite export does not yet support fused LSTM cell. 
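+ // As a consequence, fused-LSTM identification only runs when exporting back
+ // to a TensorFlow GraphDef: a conversion whose --output_format resolves to
+ // TENSORFLOW_GRAPHDEF takes the branch below, while one targeting TFLITE
+ // skips it (the flag spelling shown here is illustrative).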
+ if (output_format == TENSORFLOW_GRAPHDEF) { + transformations.Add(new IdentifyLstmCell); + } + transformations.Add(new ResolveConstantConcatenation); + RunGraphTransformations(model, "general graph transformations", + transformations); + if (output_is_tflite_quantized) { + RunGraphTransformations(model, "pre-quantization graph transformations", + {new HardcodeMinMax, new DropFakeQuant}); + } + + if (output_is_tflite_quantized) { + if (toco_flags.has_default_ranges_min() && + toco_flags.has_default_ranges_max()) { + UseDefaultMinMaxRangeValues(model, toco_flags.default_ranges_min(), + toco_flags.default_ranges_max()); + } + CheckIsReadyForQuantization(*model); + RunGraphTransformations( + model, "quantization graph transformations", + {new Quantize, new RemoveTrivialQuantizedActivationFunc, + new RemoveFinalDequantizeOp}); + } else { + GraphTransformationsSet dequantization_transformations{new Dequantize}; + // Dequantize creates FakeQuant nodes. We may want to discard + // those immediately. + if (toco_flags.drop_fake_quant()) { + dequantization_transformations.Add(new DropFakeQuant); + } + + RunGraphTransformations(model, "dequantization graph transformations", + dequantization_transformations); + } + + LogDump(kLogLevelModelChanged, "AFTER TRANSFORMATIONS", *model); + + if (output_format != GRAPHVIZ_DOT && output_format != TFLITE) { + // By now there shouldn't be any unsupported ops when exporting to + // TensorFlow GraphDef. + CheckUnsupportedOperations(*model); + } + + if (output_is_tflite) { + AllocateTransientArrays(model, kDefaultTransientDataAlignment); + LogDump(kLogLevelModelChanged, "AFTER ALLOCATION", *model); + } + + CheckModelCounts(*model); + CheckFinalDataTypesSatisfied(*model); + + int64 ops_count; + if (EstimateArithmeticOpsCount(*model, &ops_count)) { + LOG(INFO) << "Estimated count of arithmetic ops: " << 1e-9 * ops_count + << " billion (note that a multiply-add is counted as 2 ops)."; + } +} + +void Export(const TocoFlags& toco_flags, const Model& model, + bool allow_custom_ops, string* output_file_contents) { + switch (toco_flags.output_format()) { + case TENSORFLOW_GRAPHDEF: + ExportTensorFlowGraphDef(model, output_file_contents); + break; + case TFLITE: + toco::tflite::Export(model, allow_custom_ops, output_file_contents); + break; + case GRAPHVIZ_DOT: + DumpGraphviz(model, output_file_contents); + break; + default: + LOG(FATAL) << "Unhandled output_format"; + } +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/toco_tooling.h b/tensorflow/contrib/lite/toco/toco_tooling.h new file mode 100644 index 0000000000..9c5a93a211 --- /dev/null +++ b/tensorflow/contrib/lite/toco/toco_tooling.h @@ -0,0 +1,50 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_TOOLING_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_TOOLING_H_ + +#include +#include + +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/model_flags.pb.h" +#include "tensorflow/contrib/lite/toco/toco_flags.pb.h" + +namespace toco { + +// Imports the input file into a Model object. +std::unique_ptr Import(const TocoFlags& toco_flags, + const ModelFlags& model_flags, + const string& input_file_contents); + +// Transforms a Model. The resulting Model is ready to be passed +// to Export with the exact same toco_flags. +void Transform(const TocoFlags& toco_flags, Model* model); + +// Exports the Model, which must be of the 'lowered' form returned by +// Transform, to a file of the format given by +// toco_flags.output_format(). +void Export(const TocoFlags& toco_flags, const Model& model, + bool allow_custom_ops, string* output_file_contents); + +// This if for backward-compatibility with internal tools. +inline void Export(const TocoFlags& toco_flags, const Model& model, + string* output_file_contents) { + Export(toco_flags, model, true, output_file_contents); +} + +} // namespace toco + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_TOOLING_H_ diff --git a/tensorflow/contrib/lite/toco/toco_types.h b/tensorflow/contrib/lite/toco/toco_types.h new file mode 100644 index 0000000000..ad42497ada --- /dev/null +++ b/tensorflow/contrib/lite/toco/toco_types.h @@ -0,0 +1,45 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TYPES_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TYPES_H_ + +#include +#include "tensorflow/core/platform/platform.h" + +#if defined(PLATFORM_GOOGLE) || defined(GOOGLE_INTEGRAL_TYPES) +#include "tensorflow/core/platform/google/integral_types.h" +#else +#include "tensorflow/core/platform/default/integral_types.h" +#endif + +namespace toco { +#ifdef PLATFORM_GOOGLE +using ::string; +#else +using std::string; +#endif + +using tensorflow::int16; +using tensorflow::int32; +using tensorflow::int64; +using tensorflow::int8; +using tensorflow::uint16; +using tensorflow::uint32; +using tensorflow::uint64; +using tensorflow::uint8; + +} // namespace toco + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TYPES_H_ diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc new file mode 100644 index 0000000000..bcbfed62d3 --- /dev/null +++ b/tensorflow/contrib/lite/toco/tooling_util.cc @@ -0,0 +1,1552 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/toco/tooling_util.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/str_replace.h" +#include "tensorflow/contrib/lite/toco/dump_graphviz.h" +#include "tensorflow/contrib/lite/toco/model_flags.pb.h" +#include "tensorflow/contrib/lite/toco/toco_graphviz_dump_options.h" +#include "tensorflow/contrib/lite/toco/toco_port.h" +#include "tensorflow/core/platform/logging.h" + + +namespace toco { + +string LogName(const Operator& op) { + const string& opname = HelpfulOperatorTypeName(op); + if (op.outputs.empty()) { + return toco::port::StringF("{%s operator}", opname); + } else { + return toco::port::StringF("{%s operator with output %s}", opname, + op.outputs[0]); + } +} + +bool IsInputArray(const Model& model, const string& name) { + for (const auto& input_array : model.flags.input_arrays()) { + if (input_array.name() == name) { + return true; + } + } + return false; +} + +bool IsArrayConsumed(const Model& model, const string& name) { + if (GetOpWithInput(model, name)) { + return true; + } + for (const string& model_output : model.flags.output_arrays()) { + if (model_output == name) { + return true; + } + } + for (const auto& rnn_state : model.flags.rnn_states()) { + if (rnn_state.back_edge_source_array() == name) { + return true; + } + } + return false; +} + +int CountTrueOutputs(const Model& model, const Operator& op) { + int count = 0; + for (const string& output : op.outputs) { + if (IsArrayConsumed(model, output)) { + ++count; + } + } + return count; +} + +int CountOpsWithInput(const Model& model, const string& array_name) { + int count = 0; + for (const auto& op : model.operators) { + for (auto& input : op->inputs) { + if (input == array_name) { + count++; + } + } + } + return count; +} + +bool DeleteArrayIfUnused(const string& array_name, Model* model) { + if (CountOpsWithInput(*model, array_name) == 0) { + model->arrays.erase(array_name); + return true; + } + return false; +} + +std::vector>::const_iterator FindOpWithOutput( + const Model& model, const string& array_name) { + for (auto it = model.operators.begin(); it != model.operators.end(); ++it) { + for (auto& output : it->get()->outputs) { + if (output == array_name) { + return it; + } + } + } + return model.operators.end(); +} + +std::vector>::iterator FindOpWithOutput( + Model& model, const string& array_name) { + for (auto it = model.operators.begin(); it != model.operators.end(); ++it) { + for (auto& output : it->get()->outputs) { + if (output == array_name) { + return it; + } + } + } + return model.operators.end(); +} + +Operator* GetOpWithOutput(const Model& model, const string& array_name) { + auto it = FindOpWithOutput(model, array_name); + return it == model.operators.end() ? nullptr : it->get(); +} + +// GetFirstOpWithInput assumes that this finds the first op. 
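+// That holds because FindOpWithInput below walks model.operators from the
+// beginning, so the iterator it returns refers to the earliest operator, in
+// the model's current ordering, that consumes the given array.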
+std::vector>::const_iterator FindOpWithInput( + const Model& model, const string& array_name) { + for (auto it = model.operators.begin(); it != model.operators.end(); ++it) { + for (auto& input : it->get()->inputs) { + if (input == array_name) { + return it; + } + } + } + return model.operators.end(); +} + +std::vector>::const_iterator FindOp( + const Model& model, const Operator* op) { + for (auto it = model.operators.begin(); it != model.operators.end(); ++it) { + if (it->get() == op) { + return it; + } + } + return model.operators.end(); +} + +std::vector>::iterator FindOp(Model& model, + const Operator* op) { + for (auto it = model.operators.begin(); it != model.operators.end(); ++it) { + if (it->get() == op) { + return it; + } + } + return model.operators.end(); +} + +Operator* GetOpWithInput(const Model& model, const string& array_name) { + auto it = FindOpWithInput(model, array_name); + return it == model.operators.end() ? nullptr : it->get(); +} + +Operator* GetFirstOpWithInput(const Model& model, const string& array_name) { + auto it = FindOpWithInput(model, array_name); + return it == model.operators.end() ? nullptr : it->get(); +} + +string FormatArraysList(const Model& model, const std::vector& list) { + if (list.empty()) { + return "[]"; + } + string result = ""; + if (list.size() > 1) { + result += "[ "; + } + for (std::size_t i = 0; i < list.size(); i++) { + if (i > 0) { + result += ", "; + } + result += list[i]; + } + if (list.size() > 1) { + result += " ]"; + } + return result; +} + +const char* OperatorTypeName(OperatorType type) { + switch (type) { +#define HANDLE_OPERATORTYPENAME_CASE(c) \ + case OperatorType::k##c: \ + return #c; + HANDLE_OPERATORTYPENAME_CASE(Add) + HANDLE_OPERATORTYPENAME_CASE(AveragePool) + HANDLE_OPERATORTYPENAME_CASE(BatchNormalization) + HANDLE_OPERATORTYPENAME_CASE(Conv) + HANDLE_OPERATORTYPENAME_CASE(Concatenation) + HANDLE_OPERATORTYPENAME_CASE(DepthwiseConv) + HANDLE_OPERATORTYPENAME_CASE(DepthToSpace) + HANDLE_OPERATORTYPENAME_CASE(SpaceToDepth) + HANDLE_OPERATORTYPENAME_CASE(FullyConnected) + HANDLE_OPERATORTYPENAME_CASE(Dequantize) + HANDLE_OPERATORTYPENAME_CASE(L2Normalization) + HANDLE_OPERATORTYPENAME_CASE(LocalResponseNormalization) + HANDLE_OPERATORTYPENAME_CASE(Logistic) + HANDLE_OPERATORTYPENAME_CASE(LstmCell) + HANDLE_OPERATORTYPENAME_CASE(MaxPool) + HANDLE_OPERATORTYPENAME_CASE(L2Pool) + HANDLE_OPERATORTYPENAME_CASE(FakeQuant) + HANDLE_OPERATORTYPENAME_CASE(Mul) + HANDLE_OPERATORTYPENAME_CASE(Relu) + HANDLE_OPERATORTYPENAME_CASE(Relu1) + HANDLE_OPERATORTYPENAME_CASE(Relu6) + HANDLE_OPERATORTYPENAME_CASE(ReorderAxes) + HANDLE_OPERATORTYPENAME_CASE(Softmax) + HANDLE_OPERATORTYPENAME_CASE(Div) + HANDLE_OPERATORTYPENAME_CASE(Tanh) + HANDLE_OPERATORTYPENAME_CASE(TensorFlowAll) + HANDLE_OPERATORTYPENAME_CASE(TensorFlowAssert) + HANDLE_OPERATORTYPENAME_CASE(TensorFlowGreater) + HANDLE_OPERATORTYPENAME_CASE(TensorFlowGreaterEqual) + HANDLE_OPERATORTYPENAME_CASE(TensorFlowIdentity) + HANDLE_OPERATORTYPENAME_CASE(TensorFlowLess) + HANDLE_OPERATORTYPENAME_CASE(TensorFlowLessEqual) + HANDLE_OPERATORTYPENAME_CASE(TensorFlowMatMul) + HANDLE_OPERATORTYPENAME_CASE(TensorFlowMax) + HANDLE_OPERATORTYPENAME_CASE(TensorFlowMaximum) + HANDLE_OPERATORTYPENAME_CASE(TensorFlowMerge) + HANDLE_OPERATORTYPENAME_CASE(TensorFlowMin) + HANDLE_OPERATORTYPENAME_CASE(TensorFlowMinimum) + HANDLE_OPERATORTYPENAME_CASE(Pad) + HANDLE_OPERATORTYPENAME_CASE(StridedSlice) + HANDLE_OPERATORTYPENAME_CASE(TensorFlowReshape) + HANDLE_OPERATORTYPENAME_CASE(Squeeze) + 
HANDLE_OPERATORTYPENAME_CASE(TensorFlowRsqrt) + HANDLE_OPERATORTYPENAME_CASE(TensorFlowShape) + HANDLE_OPERATORTYPENAME_CASE(Slice) + HANDLE_OPERATORTYPENAME_CASE(TensorFlowSplit) + HANDLE_OPERATORTYPENAME_CASE(TensorFlowSqrt) + HANDLE_OPERATORTYPENAME_CASE(TensorFlowSquare) + HANDLE_OPERATORTYPENAME_CASE(TensorFlowSwitch) + HANDLE_OPERATORTYPENAME_CASE(Sub) + HANDLE_OPERATORTYPENAME_CASE(TensorFlowSum) + HANDLE_OPERATORTYPENAME_CASE(TensorFlowTile) + HANDLE_OPERATORTYPENAME_CASE(TensorFlowConcat) + HANDLE_OPERATORTYPENAME_CASE(TensorFlowConcatV2) + HANDLE_OPERATORTYPENAME_CASE(Cast) + HANDLE_OPERATORTYPENAME_CASE(Floor) + HANDLE_OPERATORTYPENAME_CASE(Gather) + HANDLE_OPERATORTYPENAME_CASE(ResizeBilinear) + HANDLE_OPERATORTYPENAME_CASE(SpaceToBatchND) + HANDLE_OPERATORTYPENAME_CASE(BatchToSpaceND) + HANDLE_OPERATORTYPENAME_CASE(Mean) + HANDLE_OPERATORTYPENAME_CASE(Svdf) + HANDLE_OPERATORTYPENAME_CASE(TensorFlowUnsupported) + default: + LOG(FATAL) << "Unhandled op type"; +#undef HANDLE_OPERATORTYPENAME_CASE + } +} + +string HelpfulOperatorTypeName(const Operator& op) { + if (op.type == OperatorType::kTensorFlowUnsupported) { + return toco::port::StringF( + "(Unsupported TensorFlow op: %s)", + static_cast(op).tensorflow_op); + } + return OperatorTypeName(op.type); +} + +void LogSummary(int log_level, const Model& model) { + VLOG(log_level) << "Operators summary (" << model.operators.size() + << " operators): "; + std::unordered_multiset ops_by_type; + for (const auto& op : model.operators) { + ops_by_type.insert(op->type); + } + auto it = ops_by_type.begin(); + while (it != ops_by_type.end()) { + int count = ops_by_type.count(*it); + VLOG(log_level) << " " << OperatorTypeName(*it) << ": " << count; + std::advance(it, count); + } +} + +void LogArray(int log_level, const Model& model, const string& name) { + const auto& array = model.GetArray(name); + VLOG(log_level) << "Array: " << name; + switch (array.data_type) { + case ArrayDataType::kNone: + break; + case ArrayDataType::kFloat: + VLOG(log_level) << " Data type: kFloat"; + break; + case ArrayDataType::kInt32: + VLOG(log_level) << " Data type: kInt32"; + break; + case ArrayDataType::kUint8: + VLOG(log_level) << " Data type: kUint8"; + break; + default: + VLOG(log_level) << " Data type: other (numerical value: " + << static_cast(array.data_type) << ")"; + break; + } + if (array.buffer) { + VLOG(log_level) << " Constant Buffer"; + } + if (array.alloc) { + VLOG(log_level) << " Transient Alloc"; + } + if (array.has_shape()) { + const Shape& array_shape = array.shape(); + if (array_shape.dimensions_count() == 0) { + VLOG(log_level) << " (Zero dimensions)"; + } else { + string message = " Dims: "; + bool first = true; + for (const int dim : array_shape.dims()) { + if (!first) { + message += ", "; + } + first = false; + toco::port::AppendF(&message, "%d", dim); + } + VLOG(log_level) << message; + } + } + if (array.minmax) { + VLOG(log_level) << " MinMax: " << array.minmax->min << " .. 
" + << array.minmax->max; + } + if (array.quantization_params) { + VLOG(log_level) << " QuantizationParams: zero_point=" + << array.quantization_params->zero_point + << ", scale=" << array.quantization_params->scale; + } +} + +void DumpGraphvizVideoFrame(const Model& model) { + namespace port = toco::port; + + const auto& dump_options = *GraphVizDumpOptions::singleton(); + if (!dump_options.dump_graphviz_video) { + return; + } + CHECK(!dump_options.dump_graphviz.empty()); + // TODO(benoitjacob): the static data here means that this function + // is stateful, not reentrant, and effectively leaks memory till exit + // (since dump_hashes can only grow in size). It also means that it + // really only is intended to be called for a single model during the + // process' lifetime. So it's not great design at all. The overriding + // design aspect here is to make the video-dumping code as unintrusive + // and self-contained as possible. Eventually, we'll want to have that + // cleaned-up, but that will require some form of general statefulness + // in toco (some kind of 'tooling state' data structure) that does + // not exist at present, and would be premature to design here just for + // this new video-dumping feature. + static int dump_id = 0; + static std::unordered_set dump_hashes; + string graphviz_dump; + DumpGraphviz(model, &graphviz_dump); + std::size_t hash = std::hash{}(graphviz_dump); + if (!dump_hashes.count(hash)) { + dump_hashes.insert(hash); + CHECK(port::file::SetContents( + port::file::JoinPath( + dump_options.dump_graphviz, + toco::port::StringF("toco_video_%05d.dot", dump_id)), + graphviz_dump, port::file::Defaults()) + .ok()); + dump_id++; + } +} + +void LogDump(int log_level, const string& message, const Model& model) { + namespace port = toco::port; + const auto& dump_options = *GraphVizDumpOptions::singleton(); + + DumpGraphvizVideoFrame(model); + if (!dump_options.dump_graphviz.empty()) { + string graphviz_dump; + + DumpGraphviz(model, &graphviz_dump); + CHECK(port::file::SetContents( + port::file::JoinPath( + dump_options.dump_graphviz, + absl::StrCat("toco_", + absl::StrReplaceAll(message, {{" ", "_"}}), + ".dot")), + graphviz_dump, port::file::Defaults()) + .ok()); + } + + if (!VLOG_IS_ON(log_level)) { + return; + } + VLOG(log_level) << "BEGIN DUMP OF TOCO MODEL (" << message << ")"; + LogSummary(log_level, model); + std::unordered_set already_printed_arrays; + for (const auto& op : model.operators) { + for (const auto& input : op->inputs) { + if (!already_printed_arrays.count(input)) { + already_printed_arrays.insert(input); + LogArray(log_level, model, input); + } + } + VLOG(log_level) << HelpfulOperatorTypeName(*op) << " : "; + VLOG(log_level) << " " << FormatArraysList(model, op->inputs) << " -> " + << FormatArraysList(model, op->outputs); + if (op->fused_activation_function != FusedActivationFunctionType::kNone) { + VLOG(log_level) << " (with fused activation function)"; + } + for (const auto& output : op->outputs) { + if (!already_printed_arrays.count(output)) { + already_printed_arrays.insert(output); + LogArray(log_level, model, output); + } + } + } + VLOG(log_level) << "END DUMP OF TOCO MODEL (" << message << ")"; +} + +// Note remaining raw-array extension in ProcessTensorFlowReshapeOperator(). 
+void ExtendShape(Shape* shape, int new_shape_size) { + CHECK_GE(new_shape_size, shape->dimensions_count()); + const int size_increase = new_shape_size - shape->dimensions_count(); + auto* shape_dims = shape->mutable_dims(); + shape_dims->insert(shape_dims->begin(), size_increase, 1); +} + +// TODO(b/62904716) Remove along with remaining uses. +void UnextendShape(Shape* shape, int new_shape_size) { + CHECK_LE(new_shape_size, shape->dimensions_count()); + const int size_reduction = shape->dimensions_count() - new_shape_size; + for (int i = 0; i < size_reduction; i++) { + CHECK_EQ(shape->dims(i), 1); + } + std::vector& shape_dims = *shape->mutable_dims(); + shape_dims.erase(shape_dims.begin(), shape_dims.begin() + size_reduction); +} + +void CheckShapeDimensions(const Shape& shape) { + for (int i = 0; i < shape.dimensions_count(); ++i) { + CHECK_GE(shape.dims()[i], 1) << "shape has dimension 0 at index << " << i + << ". shape = " << ShapeToString(shape); + } +} + +bool ShapesAgreeUpToBroadcasting(const Shape& shape0, const Shape& shape1) { + CheckShapeDimensions(shape0); + CheckShapeDimensions(shape1); + + const Shape* longer = &shape0; + const Shape* shorter = &shape1; + if (shape1.dimensions_count() > shape0.dimensions_count()) { + longer = &shape1; + shorter = &shape0; + } + + // Walk dimensions back to front until we run out of dimensions in the shorter + // shape. + int longer_index = longer->dimensions_count() - 1; + int shorter_index = shorter->dimensions_count() - 1; + while (shorter_index >= 0) { + const int d_long = longer->dims(longer_index); + const int d_short = shorter->dims(shorter_index); + // Broadcasting fails if the dimensions are different *and* neither is 1. + if ((d_long != d_short) && (d_long != 1) && (d_short != 1)) { + return false; + } + longer_index--; + shorter_index--; + } + return true; +} + +bool ShapesAgreeUpToExtending(const Shape& shape0, const Shape& shape1) { + CheckShapeDimensions(shape0); + CheckShapeDimensions(shape1); + + const Shape* longer = &shape0; + const Shape* shorter = &shape1; + if (shape1.dimensions_count() > shape0.dimensions_count()) { + longer = &shape1; + shorter = &shape0; + } + + // Walk dimensions back to front until we run out of dimensions in the shorter + // shape. + int longer_index = longer->dimensions_count() - 1; + int shorter_index = shorter->dimensions_count() - 1; + while (shorter_index >= 0) { + const int d_long = longer->dims(longer_index); + const int d_short = shorter->dims(shorter_index); + // Extending fails if the dimensions are different. + if (d_long != d_short) { + return false; + } + longer_index--; + shorter_index--; + } + + // The remaining dimensions in the longer shape must be 1. 
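+ // For example (illustrative shapes): {1, 1, 5, 7} agrees with {5, 7} up to
+ // extending, but {2, 5, 7} does not, because its leftover leading dimension
+ // is 2 rather than 1.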
+ while (longer_index >= 0) { + const int d_long = longer->dims(longer_index); + if (d_long != 1) { + return false; + } + longer_index--; + } + + return true; +} + +int RequiredBufferSizeForShape(const Shape& shape) { + int max_offset = 1; + for (const auto& dim : shape.dims()) { + CHECK_GE(dim, 1); + max_offset *= dim; + } + return max_offset; +} + +bool IsConstantParameterArray(const Model& model, const string& name) { + if (!model.arrays.count(name)) { + return false; + } + + return !!model.arrays.at(name)->buffer; +} + +void CheckNoMissingArray(const Model& model) { + for (const auto& op : model.operators) { + for (const auto& input : op->inputs) { + CHECK(model.arrays.count(input)); + } + for (const auto& output : op->outputs) { + CHECK(model.arrays.count(output)); + } + } + for (const auto& input_array : model.flags.input_arrays()) { + CHECK(model.arrays.count(input_array.name())) + << "Input array not found: " << input_array.name(); + } + for (const string& output_array : model.flags.output_arrays()) { + CHECK(model.arrays.count(output_array)) + << "Output array not found: " << output_array; + } + for (const auto& rnn_state : model.flags.rnn_states()) { + CHECK(model.arrays.count(rnn_state.state_array())); + CHECK(model.arrays.count(rnn_state.back_edge_source_array())); + } +} + +void FixNoMissingArray(Model* model) { + for (const auto& op : model->operators) { + for (const auto& input : op->inputs) { + if (!model->arrays.count(input)) { + model->GetOrCreateArray(input); + } + } + for (const auto& output : op->outputs) { + if (!model->arrays.count(output)) { + model->GetOrCreateArray(output); + } + } + } + for (const string& output_array : model->flags.output_arrays()) { + if (!model->arrays.count(output_array)) { + model->GetOrCreateArray(output_array); + } + } +} + +void CheckNoOrphanedArray(const Model& model) { + std::unordered_set arrays_without_known_use; + for (const auto& array : model.arrays) { + arrays_without_known_use.insert(array.first); + } + for (const auto& op : model.operators) { + for (const auto& input : op->inputs) { + arrays_without_known_use.erase(input); + } + for (const auto& output : op->outputs) { + arrays_without_known_use.erase(output); + } + } + if (!arrays_without_known_use.empty()) { + for (const auto& array : arrays_without_known_use) { + LOG(INFO) << "Error: Orphaned array: " << array; + } + } + CHECK(arrays_without_known_use.empty()); +} + +void FixNoOrphanedArray(Model* model) { + std::unordered_set arrays_without_known_use; + for (const auto& array : model->arrays) { + arrays_without_known_use.insert(array.first); + } + for (const auto& op : model->operators) { + for (const auto& input : op->inputs) { + arrays_without_known_use.erase(input); + } + for (const auto& output : op->outputs) { + arrays_without_known_use.erase(output); + } + } + for (const auto& array : arrays_without_known_use) { + model->arrays.erase(array); + } +} + +void CheckArrayFieldsConsistent(const Model& model) { + for (const auto& array_entry : model.arrays) { + const auto& array = array_entry.second; + if (array->has_shape()) { + for (int d : array->shape().dims()) { + CHECK_GE(d, 1); + } + } + // It's OK to have a buffer or an alloc, but not both. + // (Since allocs are for transient arrays without a buffer). + CHECK(!array->buffer || !array->alloc); + // If there is a buffer, its type should be consistent with data_type. 
+ if (array->buffer) { + CHECK(array->buffer->type == array->data_type); + } + } +} + +void CheckOperatorOrdering(const Model& model) { + std::unordered_set arrays_behind_us; + for (const auto& array_entry : model.arrays) { + if (!GetOpWithOutput(model, array_entry.first)) { + arrays_behind_us.insert(array_entry.first); + } + } + for (const auto& op : model.operators) { + for (const auto& input : op->inputs) { + if (!IsConstantParameterArray(model, input)) { + CHECK(arrays_behind_us.count(input)); + } + } + for (const auto& output : op->outputs) { + CHECK(!arrays_behind_us.count(output)); + arrays_behind_us.insert(output); + } + } + for (const string& output_array : model.flags.output_arrays()) { + CHECK(arrays_behind_us.count(output_array)); + } +} + +void FixOperatorOrdering(Model* model) { + std::unordered_set arrays_behind_us; + for (const auto& array_entry : model->arrays) { + if (!GetOpWithOutput(*model, array_entry.first)) { + arrays_behind_us.insert(array_entry.first); + } + } + std::vector> old_operators; + std::swap(old_operators, model->operators); + std::set remaining; + for (std::size_t i = 0; i < old_operators.size(); i++) { + remaining.insert(i); + } + std::unordered_map reason_why_leftover; + while (true) { + bool inserted_something = false; + for (auto i : remaining) { + bool can_insert = true; + auto& op = old_operators[i]; + CHECK(op.get()); + for (const auto& input : op->inputs) { + if (!IsConstantParameterArray(*model, input) && + !arrays_behind_us.count(input)) { + for (const string& output : op->outputs) { + reason_why_leftover[output] = input; + } + can_insert = false; + break; + } + } + if (can_insert) { + model->operators.emplace_back(nullptr); + for (const auto& output : op->outputs) { + arrays_behind_us.insert(output); + } + std::swap(op, model->operators.back()); + remaining.erase(i); + inserted_something = true; + break; + } + } + if (!inserted_something) { + break; + } + } + if (!remaining.empty()) { + LOG(ERROR) + << "No viable ordering of operators was found. " + << "Here is a 'backtrace' of at least one part of the graph that is " + << "problematic. It starts with the first operator that has as " + << "problematic input array, and then walks back the graph to " + << "the operator that produced that input array, etc., until we find " + << "the root cause:"; + LOG(ERROR) << "BEGIN TRACE OF OPERATOR WITH BAD INPUT"; + LOG(ERROR) << "Here is the first-encountered operator with a bad input: "; + const Operator* bad_op = old_operators[*remaining.begin()].get(); + std::unordered_set bad_inputs_already_traced; + // The following while(true) loop should always end with a LOG(FATAL). + while (true) { + LOG(ERROR) << HelpfulOperatorTypeName(*bad_op) << " : " + << FormatArraysList(*model, bad_op->inputs) << " -> " + << FormatArraysList(*model, bad_op->outputs); + bool found_bad_output = false; + string bad_output; + for (const string& output : bad_op->outputs) { + if (reason_why_leftover.count(output)) { + found_bad_output = true; + bad_output = output; + break; + } + } + CHECK(found_bad_output); + const string& bad_input = reason_why_leftover[bad_output]; + LOG(ERROR) << "The bad input here is: " << bad_input; + if (bad_inputs_already_traced.count(bad_input)) { + LOG(FATAL) + << "Cycle found! We already encountered that " + << "input array, " << bad_input << ", earlier in the " + << "above trace! We expect graphs to be acyclic, even " + << "RNNs. 
Let us know if some graph actually needs to have " + << "cycles, but first, please check if it really is " + << "an *inference* graph. *Training* graphs are out-of-scope " + << "for toco."; + } + bad_inputs_already_traced.insert(bad_input); + bad_op = nullptr; + for (auto i : remaining) { + const Operator* op = old_operators[i].get(); + for (const string& output : op->outputs) { + if (bad_input == output) { + bad_op = op; + break; + } + } + if (bad_op) { + break; + } + } + if (!bad_op) { + LOG(ERROR) << "And that's the root cause: " + << "that array, " << bad_input << ", isn't produced by any " + << "operator, or provided in any other way."; + LOG(ERROR) << "END TRACE OF OPERATOR WITH BAD INPUT"; + LOG(FATAL) << "(The above was a multi-line fatal error)"; + } + LOG(ERROR) << "And that array is the output of the following operator:"; + } + } + CHECK(remaining.empty()) + << "Should never get here! In case of bad graph, " + << "the above code should have generated a FATAL error already!"; +} + +// Checks that the --input_arrays of the Model are actually used by at least +// one of the --output_arrays i.e. that the graph contains a path from each one +// of the inputs to at least one of the outputs. This catches cases where the +// user passed the wrong --input_arrays or --output_arrays, which otherwise may +// result in cryptic error messages. +void CheckInputUsedByOutputs(const Model& model) { + std::set used_arrays; + for (const string& output : model.flags.output_arrays()) { + used_arrays.insert(output); + } + for (int i = model.operators.size() - 1; i >= 0; i--) { + bool is_op_used = false; + for (const string& op_output : model.operators[i]->outputs) { + if (used_arrays.count(op_output)) { + is_op_used = true; + break; + } + } + if (!is_op_used) { + continue; + } + for (const string& op_input : model.operators[i]->inputs) { + used_arrays.insert(op_input); + } + } + for (const auto& input_array : model.flags.input_arrays()) { + QCHECK(used_arrays.count(input_array.name())) + << "The graph does not connect the input (" << input_array.name() + << ") specified by --input_arrays to any of the specified " + << "--output_arrays (" + << absl::StrJoin(model.flags.output_arrays(), ", ") + << "). Did you pass the wrong flags for this model, " + << "or is that model's graph actually incomplete?"; + } +} + +void CheckInvariants(const Model& model) { + CheckNoMissingArray(model); + CheckNoOrphanedArray(model); + CheckArrayFieldsConsistent(model); + CheckOperatorOrdering(model); + CheckInputUsedByOutputs(model); +} + +void CheckCountInRange(const ::toco::ModelFlags::ModelCheck& model_check, + const int count, const string& count_description) { + if (model_check.count_min() >= 0) { + CHECK_GE(count, model_check.count_min()) + << "Mismatch in " << count_description << ": count was " << count + << ", but the specified " + << (model_check.count_max() > model_check.count_min() ? 
"minimum" + : "value") + << " was " << model_check.count_min() << "."; + } + if (model_check.count_max() > model_check.count_min()) { + CHECK_LE(count, model_check.count_max()) + << "Mismatch in " << count_description << ": count was " << count + << ", but the specified maximum was " << model_check.count_max() << "."; + } +} + +void CheckModelCounts(const Model& model) { + std::unordered_multiset ops_by_type; + std::unordered_map op_type_by_name; + if (model.flags.model_checks_size() == 0) { + return; + } + + for (const auto& op : model.operators) { + ops_by_type.insert(op->type); + op_type_by_name[OperatorTypeName(op->type)] = op->type; + } + for (const auto& model_check : model.flags.model_checks()) { + string count_type = model_check.count_type(); + if (count_type == "None") { + continue; + } else if (count_type == "Arrays") { + CheckCountInRange(model_check, model.arrays.size(), "count of arrays"); + } else if (count_type == "Total") { + CheckCountInRange(model_check, model.operators.size(), + "count of all operator instances"); + } else { + // The check type is not itself checked against the set of valid + // operators, mainly because the enum set cannot be iterated in C++. + const int found_count = + op_type_by_name.count(count_type) > 0 + ? ops_by_type.count(op_type_by_name[count_type]) + : 0; + CheckCountInRange(model_check, found_count, + "count of instances of " + count_type + " operator"); + } + } +} + +void MakeArrayDims(int num_dims, int batch, int height, int width, int depth, + std::vector* out_dims) { + CHECK(out_dims->empty()); + if (num_dims == 1) { + CHECK_EQ(batch, 1); + *out_dims = {depth}; + } else if (num_dims == 2) { + *out_dims = {batch, depth}; + } else if (num_dims == 3) { + CHECK_EQ(batch, 1); + *out_dims = {height, width, depth}; + } else if (num_dims == 4) { + *out_dims = {batch, height, width, depth}; + } else { + LOG(FATAL) << "Should not get here: " << num_dims; + } +} + +void CreateOrCheckRnnStateArray(const string& name, int size, Model* model) { + int batch = 1; + int num_dims = -1; + for (const auto& input_array : model->flags.input_arrays()) { + // Pick 'num_dims' and 'batch' from the first input_arrays, unless we find + // a better match by name. 
+ if (input_array.name() == name || num_dims == -1) { + num_dims = input_array.shape_size(); + if (num_dims != 0) { + batch = input_array.shape(0); + } + } + } + Array& array = model->GetOrCreateArray(name); + if (array.has_shape()) { + num_dims = array.shape().dimensions_count(); + } + std::vector dims; + MakeArrayDims(num_dims, batch, 1, 1, size, &dims); + CHECK(array.data_type == ArrayDataType::kFloat || + array.data_type == ArrayDataType::kNone); + array.data_type = ArrayDataType::kFloat; + if (!array.has_shape()) { + Shape* shape = array.mutable_shape(); + *shape->mutable_dims() = dims; + } +} + +void ResolveModelFlags(const ModelFlags& model_flags, Model* model) { + // Merge info about input_arrays from model_flags into model->flags + for (const auto& specified_input_array : model_flags.input_arrays()) { + toco::InputArray* dst_input_array = nullptr; + for (int i = 0; i < model->flags.input_arrays_size(); i++) { + toco::InputArray* candidate_dst_input_array = + model->flags.mutable_input_arrays(i); + if (candidate_dst_input_array->name() == specified_input_array.name()) { + // specified_input_array from model_flags maps to dst_input_array + // in model->flags + dst_input_array = candidate_dst_input_array; + break; + } + } + if (!dst_input_array) { + // specified_input_array from model_flags is not found in model->flags. + // Match a name-less specified input array when there can be no ambiguity + // as there is only 1 input array. + if (model->flags.input_arrays_size() == 1 && + model_flags.input_arrays_size() == 1 && + !specified_input_array.has_name()) { + dst_input_array = model->flags.mutable_input_arrays(0); + } + } + if (!dst_input_array) { + // Still no match, so create a new input array to copy + // specified_input_array into. + dst_input_array = model->flags.add_input_arrays(); + dst_input_array->set_name(specified_input_array.name()); + } + +#define RESOLVE_MODEL_FLAG(field_name) \ + if (specified_input_array.has_##field_name()) { \ + if (dst_input_array->has_##field_name()) { \ + QCHECK_EQ(dst_input_array->field_name(), \ + specified_input_array.field_name()) \ + << "For input array '" << dst_input_array->name() << "', " \ + << "specified " #field_name " flag with value: " \ + << specified_input_array.field_name() \ + << " does not agree with already defined " #field_name \ + " of this model, with value: " \ + << specified_input_array.field_name(); \ + } else { \ + dst_input_array->set_##field_name(specified_input_array.field_name()); \ + } \ + } + RESOLVE_MODEL_FLAG(std_value); + RESOLVE_MODEL_FLAG(mean_value); +#undef RESOLVE_MODEL_FLAG + + if (!specified_input_array.shape().empty()) { + if (!dst_input_array->shape().empty()) { + QCHECK_EQ(specified_input_array.shape().size(), + dst_input_array->shape().size()) + << "For input array '" << specified_input_array.name() << "', " + << "size of specified input shape flag with size: " + << specified_input_array.shape().size() + << " does not agree with already defined input shape" + " of this model, with size: " + << dst_input_array->shape().size(); + // We treat the first dimension as a special case, since it is often + // a batch size and the input_shape flag is effectively overriding + // the model. 
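+ // For example (illustrative): a specified input shape of [1, 224, 224, 3]
+ // is accepted even if the model itself declares [8, 224, 224, 3]; only
+ // dimensions 1 and upward are required to match below.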
+ for (int i = 1; i < specified_input_array.shape().size(); i++) { + QCHECK_EQ(specified_input_array.shape().Get(i), + dst_input_array->shape().Get(i)) + << "At dimension number " << i << " of input array " + << specified_input_array.name() << ", the specified shape's " + << "dimension flag with dimension: " + << specified_input_array.shape().Get(i) + << " does not agree with already defined shape" + << " of this model, with dimension: " + << dst_input_array->shape().Get(i); + } + } else { + dst_input_array->mutable_shape()->CopyFrom( + specified_input_array.shape()); + } + } + } + + if (model_flags.output_arrays_size() > 0) { + model->flags.mutable_output_arrays()->CopyFrom(model_flags.output_arrays()); + } + +#define RESOLVE_MODEL_FLAG(name) \ + if (model_flags.has_##name()) { \ + if (model->flags.has_##name()) { \ + QCHECK_EQ(model_flags.name(), model->flags.name()) \ + << "Specified " #name " flag with value: " << model_flags.name() \ + << " does not agree with already defined " #name \ + " of this model, with value: " \ + << model->flags.name(); \ + } else { \ + model->flags.set_##name(model_flags.name()); \ + } \ + } + + RESOLVE_MODEL_FLAG(variable_batch) + RESOLVE_MODEL_FLAG(drop_control_dependency) + +#undef RESOLVE_MODEL_FLAG + + if (model->flags.rnn_states_size() == 0) { + model->flags.mutable_rnn_states()->CopyFrom(model_flags.rnn_states()); + } else { + CHECK_EQ(model->flags.rnn_states_size(), model_flags.rnn_states_size()); + for (int i = 0; i < model->flags.rnn_states_size(); i++) { + CHECK_EQ(model->flags.rnn_states(i).state_array(), + model_flags.rnn_states(i).state_array()); + CHECK_EQ(model->flags.rnn_states(i).back_edge_source_array(), + model_flags.rnn_states(i).back_edge_source_array()); + } + } + + if (model->flags.model_checks_size() == 0) { + model->flags.mutable_model_checks()->CopyFrom(model_flags.model_checks()); + } + + QCHECK_GT(model->flags.input_arrays_size(), 0) + << "This model does not define input arrays, so a " + "--input_arrays flag must be given on the command-line."; + QCHECK_GT(model->flags.output_arrays_size(), 0) + << "This model does not define output arrays, so a " + "--output_arrays flag must be given on the command-line."; + + for (const auto& input_array_proto : model->flags.input_arrays()) { + QCHECK(!input_array_proto.shape().empty()) + << "This model does not have shape defined for input array " + << input_array_proto.name() + << ", so one must be specified by a non-empty --input_shape " + "command-line flag."; + + auto& input_array = model->GetOrCreateArray(input_array_proto.name()); + if (input_array.data_type == ArrayDataType::kNone) { + // We start out with a float input array; + // that may get replaced by a uint8 array later, by + // MakeInitialDequantizeOp. + input_array.data_type = ArrayDataType::kFloat; + } + + // Compare/merge the model->flags describing the input_shape with + // the actual input array's shape. 
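+ // (If the array has no dimensions yet, the shape given in the flags is
+ // simply copied into it below; otherwise the two shapes must have the same
+ // rank and identical dimensions.)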
+ auto& input_array_dims = *input_array.mutable_shape()->mutable_dims(); + if (input_array_dims.empty()) { + for (auto dim : input_array_proto.shape()) { + CHECK_GE(dim, 1); + input_array_dims.push_back(dim); + } + } else { + CHECK_EQ(input_array_dims.size(), input_array_proto.shape_size()); + for (int i = 0; i < input_array_dims.size(); i++) { + CHECK_EQ(input_array_dims[i], input_array_proto.shape(i)); + } + } + + const float mean_value = input_array_proto.mean_value(); + const float std_value = input_array_proto.std_value(); + MinMax input_minmax; + input_minmax.min = (0.f - mean_value) / std_value; + input_minmax.max = (255.f - mean_value) / std_value; + if (input_array.minmax) { + if (input_array_proto.has_mean_value() || + input_array_proto.has_std_value()) { + CHECK(input_minmax == *input_array.minmax) + << input_minmax.min << ", " << input_minmax.max + << " != " << input_array.minmax->min << ", " + << input_array.minmax->max; + } + } else { + input_array.GetOrCreateMinMax() = input_minmax; + } + } + // Creation of the RNN state arrays + for (const auto& rnn_state : model->flags.rnn_states()) { + if (!rnn_state.manually_create()) { + continue; + } + CreateOrCheckRnnStateArray(rnn_state.state_array(), rnn_state.size(), + model); + } +} + +void CheckIsReadyForQuantization(const Model& model) { + for (const auto& op : model.operators) { + for (const auto& input : op->inputs) { + const auto& input_array = model.GetArray(input); + if (input_array.data_type != ArrayDataType::kFloat) { + // The array is not floats, no quantization needed. + continue; + } + if (input_array.minmax) { + // The array has minmax, we're good. + continue; + } + if (input_array.buffer) { + // The array has a constant buffer, so we can + // fall back to computing the minmax from actual array entries + // (with a WARNING about possible accuracy implications). + continue; + } + LOG(FATAL) + << "Array " << input << ", which is an input to the " + << HelpfulOperatorTypeName(*op) << " operator producing the output " + << "array " << op->outputs[0] << ", is lacking min/max data, " + << "which is necessary for quantization. 
Either target a " + << "non-quantized output format, or change the input graph to " + << "contain min/max information, or pass --default_ranges_min= and " + << "--default_ranges_max= if you do not care about the accuracy of " + << "results."; + } + } +} + +void UseDefaultMinMaxRangeValues(Model* model, double default_ranges_min, + double default_ranges_max) { + for (const auto& op : model->operators) { + for (const auto& input : op->inputs) { + auto& input_array = model->GetArray(input); + if (!input_array.minmax && !input_array.buffer) { + auto& minmax = input_array.GetOrCreateMinMax(); + minmax.min = default_ranges_min; + minmax.max = default_ranges_max; + } + } + for (const auto& output : op->outputs) { + auto& output_array = model->GetArray(output); + if (!output_array.minmax && !output_array.buffer) { + auto& minmax = output_array.GetOrCreateMinMax(); + minmax.min = default_ranges_min; + minmax.max = default_ranges_max; + } + } + } +} + +int ElementSize(ArrayDataType data_type) { + switch (data_type) { + case ArrayDataType::kFloat: + return 4; + case ArrayDataType::kInt32: + return 4; + case ArrayDataType::kUint8: + return 1; + default: + LOG(FATAL) << "Should not get here."; + return 0; + } +} + +void DropMinMax(Model* model, const string& array_name) { + auto& array = model->GetArray(array_name); + if (!!array.minmax) { + LOG(WARNING) << "Dropping MinMax information in array " << array_name + << ". Expect inaccuracy in quantized inference."; + array.minmax = nullptr; + } +} + +bool IsAllocatableTransientArray(const Model& model, const string& array_name) { + // The model's input and output arrays are externally allocated. + // They are not transient arrays. + if (IsInputArray(model, array_name)) { + return false; + } + for (const string& output_array : model.flags.output_arrays()) { + if (array_name == output_array) { + return false; + } + } + const auto& array = model.arrays.at(array_name); + // An array with a constant buffer isn't a transient array. + if (!!array->buffer) { + return false; + } + // An array without shape isn't allocatable. + if (!array->has_shape()) { + return false; + } + return true; +} + +string AvailableArrayName(const Model& model, const string& name) { + if (!model.arrays.count(name)) { + return name; + } + const int kNumSuffixesToTry = 1000; + for (int i = 0; i < kNumSuffixesToTry; i++) { + const string& name_with_suffix = toco::port::StringF("%s_%d", name, i); + if (!model.arrays.count(name_with_suffix)) { + return name_with_suffix; + } + } + LOG(FATAL) << "Could not find an available array name starting with " << name + << ". 
Tried " << kNumSuffixesToTry << " suffixes, all were taken!"; + return ""; +} + +string ShapeToString(const Shape& shape) { + if (shape.dimensions_count() == 0) { + return "[]"; + } + + return absl::StrCat("[ ", absl::StrJoin(shape.dims(), ", "), " ]"); +} + +void PrintArrayShape(Model* model, const string& name) { + if (!model->arrays[name]->has_shape()) { + LOG(INFO) << name << " has no shape"; + return; + } + LOG(INFO) << name + << " has shape: " << ShapeToString(model->arrays[name]->shape()); +} + +bool IsArrayFullyConnectedWeights(const Model& model, const string& name) { + bool is_fc_weights = false; + bool is_something_else = false; + for (const auto& op : model.operators) { + for (int input_index = 0; input_index < op->inputs.size(); input_index++) { + if (op->inputs[input_index] == name) { + if (op->type == OperatorType::kFullyConnected && input_index == 1) { + is_fc_weights = true; + } else { + is_something_else = true; + } + } + } + } + CHECK(!(is_fc_weights && is_something_else)); + return is_fc_weights; +} + +bool EstimateArithmeticOpsCount(const Model& model, int64* result) { + int64 total = 0; + for (const auto& op : model.operators) { + switch (op->type) { + case OperatorType::kFullyConnected: + case OperatorType::kConv: + case OperatorType::kDepthwiseConv: { + const auto& output_array = model.GetArray(op->outputs[0]); + const auto& weights_array = model.GetArray(op->inputs[1]); + if (!output_array.has_shape() || !weights_array.has_shape()) { + return false; + } + int cols = 1; + for (int i = 0; i < output_array.shape().dimensions_count() - 1; i++) { + cols *= output_array.shape().dims(i); + } + const int64 cost_per_col = + 2 * RequiredBufferSizeForShape(weights_array.shape()); + total += cost_per_col * cols; + if (op->inputs.size() > 2) { + // There is a bias vector. One more op per output value. + total += RequiredBufferSizeForShape(output_array.shape()); + } + break; + } + case OperatorType::kAdd: + case OperatorType::kSub: + case OperatorType::kMul: { + const auto& output_array = model.GetArray(op->outputs[0]); + if (!output_array.has_shape()) { + return false; + } + total += RequiredBufferSizeForShape(output_array.shape()); + break; + } + case OperatorType::kLogistic: + case OperatorType::kSoftmax: + case OperatorType::kTanh: { + const auto& output_array = model.GetArray(op->outputs[0]); + if (!output_array.has_shape()) { + return false; + } + // As a very rough ballpark, the cost of evaluating a math function + // such as tanh or logistic is about 32 multiplications, and about as + // many additions/subtractions. (Just a power-of-two order-of-magnitude + // from looking at actual implementations that we use in runtime/ code). 
+ total += 64 * RequiredBufferSizeForShape(output_array.shape()); + break; + } + case OperatorType::kMaxPool: { + const auto& maxpool = *static_cast(op.get()); + const auto& output_array = model.GetArray(op->outputs[0]); + if (!output_array.has_shape()) { + return false; + } + total += RequiredBufferSizeForShape(output_array.shape()) * + maxpool.kheight * maxpool.kwidth; + break; + } + case OperatorType::kAveragePool: { + const auto& avgpool = + *static_cast(op.get()); + const auto& output_array = model.GetArray(op->outputs[0]); + if (!output_array.has_shape()) { + return false; + } + total += RequiredBufferSizeForShape(output_array.shape()) * + avgpool.kheight * avgpool.kwidth; + break; + } + case OperatorType::kL2Pool: { + const auto* maxpool = static_cast(op.get()); + const auto& output_array = model.GetArray(op->outputs[0]); + if (!output_array.has_shape()) { + return false; + } + // The sum of squares requires (kheight*kwidth) multiply-adds, + // and then there is the sqrt which we ballpark at 32 ops. + const int64 cost_per_val = 2 * maxpool->kheight * maxpool->kwidth + 32; + total += + RequiredBufferSizeForShape(output_array.shape()) * cost_per_val; + break; + } + case OperatorType::kL2Normalization: { + const auto& output_array = model.GetArray(op->outputs[0]); + if (!output_array.has_shape()) { + return false; + } + // Computing the squared L2 norm is N multiply-adds so 2N ops, + // then the single inverse-sqrt is negligible, then we multiply each + // value by the resulting multiplier, so an extra N ops. Total 3N ops. + total += 3 * RequiredBufferSizeForShape(output_array.shape()); + break; + } + default: + break; + } + } + *result = total; + return true; +} + +namespace { + +void GetShuffleShape(AxesOrder input_axes_order, AxesOrder output_axes_order, + std::vector* shuffle) { + CHECK_EQ(AxesCount(input_axes_order), AxesCount(output_axes_order)); + shuffle->resize(4); + for (int i = 0; i < 4; i++) { + (*shuffle)[i] = i; + } + if (input_axes_order == output_axes_order) { + // nothing to do + } else if (AxesCount(input_axes_order) == 2) { + shuffle->resize(2); + (*shuffle)[0] = 1; + (*shuffle)[1] = 0; + } else if (input_axes_order == AxesOrder::kOHWI && + output_axes_order == AxesOrder::kHWIO) { + // 3210 <- 3210 + // HWIO <- OHWI + (*shuffle)[0] = 1; + (*shuffle)[1] = 2; + (*shuffle)[2] = 3; + (*shuffle)[3] = 0; + } else if (input_axes_order == AxesOrder::kHWIO && + output_axes_order == AxesOrder::kOHWI) { + // 3210 <- 3210 + // OHWI <- HWIO + (*shuffle)[0] = 3; + (*shuffle)[1] = 0; + (*shuffle)[2] = 1; + (*shuffle)[3] = 2; + } else { + LOG(FATAL) << "Bad shuffle"; + } +} + +// Extend shuffle is designed to match ExtendShape, which pads the shape with +// unit dimensions at the beginning. 
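+// For example (illustrative values): extending the 2-D shuffle {1, 0} to
+// newdim == 4 yields {0, 1, 3, 2}: the two padded leading axes map to
+// themselves, and the original entries are shifted up by the pad size.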
+void ExtendShuffle(const std::vector& input_shuffle, int newdim, + std::vector* extended_shuffle) { + *extended_shuffle = input_shuffle; + CHECK(newdim >= input_shuffle.size()); + const int pad_size = newdim - input_shuffle.size(); + extended_shuffle->resize(newdim); + for (int i = 0; i < pad_size; i++) { + (*extended_shuffle)[i] = i; + } + for (int i = pad_size; i < newdim; i++) { + (*extended_shuffle)[i] = input_shuffle[i - pad_size] + pad_size; + } +} + +} // end anonymous namespace + +void ShuffleDims(const Shape& input_shape, AxesOrder input_axes_order, + AxesOrder output_axes_order, Shape* output_shape) { + if (input_axes_order == AxesOrder::kHWIM && + output_axes_order == AxesOrder::k1HWO) { + // This special case isn't just a permutation, the IM pair of dims get + // merged into the 3 dim, so we have to special-case it. + *output_shape = Shape({1, input_shape.dims(0), input_shape.dims(1), + input_shape.dims(3) * input_shape.dims(2)}); + } else { + std::vector shuffle; + GetShuffleShape(input_axes_order, output_axes_order, &shuffle); + std::vector* output_dims = output_shape->mutable_dims(); + output_dims->resize(input_shape.dimensions_count()); + for (int i = 0; i < input_shape.dimensions_count(); i++) { + (*output_dims)[i] = input_shape.dims(shuffle[i]); + } + } +} + +void ShuffleArray(const Shape& input_shape, AxesOrder input_axes_order, + AxesOrder output_axes_order, const Shape& output_shape, + const float* input_data, float* output_data) { + if (input_axes_order == AxesOrder::kHWIM && + output_axes_order == AxesOrder::k1HWO) { + // This special case isn't just a permutation, the IM pair of dims get + // merged into the O dim, so we have to special-case it. Fortunately, + // as far as array shuffling is concerned, it's just the identity + // transformation. + memcpy(output_data, input_data, + RequiredBufferSizeForShape(input_shape) * sizeof(output_data[0])); + return; + } + CHECK(input_shape.dimensions_count() == output_shape.dimensions_count()); + const int dim = input_shape.dimensions_count(); + CHECK_LE(dim, 4); + std::vector shuffle; + GetShuffleShape(input_axes_order, output_axes_order, &shuffle); + CHECK(shuffle.size() >= dim); + for (int i = 0; i < dim; i++) { + CHECK(shuffle[i] >= 0 && shuffle[i] < dim); + CHECK(input_shape.dims(shuffle[i]) == output_shape.dims(i)); + } + Shape extended_input_shape = input_shape; + ExtendShape(&extended_input_shape, 4); + Shape extended_output_shape = output_shape; + ExtendShape(&extended_output_shape, 4); + std::vector extended_shuffle; + ExtendShuffle(shuffle, 4, &extended_shuffle); + + const std::vector& extended_input_dims = extended_input_shape.dims(); + const std::vector& extended_output_dims = extended_output_shape.dims(); + + // TODO(starka): Rework to handle different numbers of dimensions. 
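+ // The loop below assumes a dense, row-major layout of the (extended) 4-D
+ // arrays, with the innermost dimension contiguous; the input strides are
+ // then permuted through extended_shuffle so that walking the output in
+ // order reads the matching input elements.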
+ int input_strides[4]; + input_strides[3] = 1; + input_strides[2] = extended_input_dims[3]; + input_strides[1] = input_strides[2] * extended_input_dims[2]; + input_strides[0] = input_strides[1] * extended_input_dims[1]; + const int input_stride_0 = input_strides[extended_shuffle[3]]; + const int input_stride_1 = input_strides[extended_shuffle[2]]; + const int input_stride_2 = input_strides[extended_shuffle[1]]; + const int input_stride_3 = input_strides[extended_shuffle[0]]; + + const int output_size_0 = extended_output_dims[3]; + const int output_size_1 = extended_output_dims[2]; + const int output_size_2 = extended_output_dims[1]; + const int output_size_3 = extended_output_dims[0]; + const int output_stride_0 = 1; + const int output_stride_1 = output_size_0; + const int output_stride_2 = output_stride_1 * output_size_1; + const int output_stride_3 = output_stride_2 * output_size_2; + + for (int i3 = 0; i3 < output_size_3; i3++) { + const float* const input_ptr_3 = input_data + i3 * input_stride_3; + float* const output_ptr_3 = output_data + i3 * output_stride_3; + for (int i2 = 0; i2 < output_size_2; i2++) { + const float* const input_ptr_2 = input_ptr_3 + i2 * input_stride_2; + float* const output_ptr_2 = output_ptr_3 + i2 * output_stride_2; + for (int i1 = 0; i1 < output_size_1; i1++) { + const float* input_ptr = input_ptr_2 + i1 * input_stride_1; + float* output_ptr = output_ptr_2 + i1 * output_stride_1; + float* const output_ptr_end = + output_ptr + output_size_0 * output_stride_0; + while (output_ptr != output_ptr_end) { + *output_ptr = *input_ptr; + input_ptr += input_stride_0; + output_ptr += output_stride_0; + } + } + } + } +} + +int AxesCount(AxesOrder axes_order) { + switch (axes_order) { + case AxesOrder::kOneAxis: + return 1; + case AxesOrder::kRC: + return 2; + case AxesOrder::kCR: + return 2; + case AxesOrder::kHWIO: + return 4; + case AxesOrder::kOHWI: + return 4; + case AxesOrder::kHWIM: + return 4; + case AxesOrder::k1HWO: + return 4; + case AxesOrder::kNHWC: + return 4; + default: + LOG(FATAL) << "Bad AxesOrder"; + return 0; + } +} + +bool IsDiscardableArray(const Model& model, const string& array_name) { + for (const auto& input_array : model.flags.input_arrays()) { + if (array_name == input_array.name()) { + return false; + } + } + for (const string& output_array : model.flags.output_arrays()) { + if (array_name == output_array) { + return false; + } + } + for (const auto& rnn_state : model.flags.rnn_states()) { + if (array_name == rnn_state.state_array()) { + return false; + } + if (array_name == rnn_state.back_edge_source_array()) { + return false; + } + } + return true; +} + +void CheckFinalDataTypesSatisfied(const Model& model) { + for (const auto& array_entry : model.arrays) { + const auto& array = *array_entry.second; + if (array.final_data_type != ArrayDataType::kNone) { + CHECK(array.final_data_type == array.data_type); + } + } +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/tooling_util.h b/tensorflow/contrib/lite/toco/tooling_util.h new file mode 100644 index 0000000000..093945edb3 --- /dev/null +++ b/tensorflow/contrib/lite/toco/tooling_util.h @@ -0,0 +1,292 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
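Building on the dimension example above, a typical use of the strided copy in ShuffleArray is re-laying-out constant weight buffers. A hedged usage sketch against the toco utilities just defined (buffer names and the 64x3x3x16 shape are illustrative):

#include <vector>

// Sketch: convert an OHWI float weight buffer to HWIO using ShuffleDims and
// ShuffleArray as defined above; sizes come from RequiredBufferSizeForShape.
void ReorderOhwiToHwio(const std::vector<float>& ohwi_data,
                       std::vector<float>* hwio_data) {
  const Shape input_shape({64, 3, 3, 16});  // O, H, W, I
  Shape output_shape;
  ShuffleDims(input_shape, AxesOrder::kOHWI, AxesOrder::kHWIO, &output_shape);
  hwio_data->resize(RequiredBufferSizeForShape(output_shape));
  ShuffleArray(input_shape, AxesOrder::kOHWI, AxesOrder::kHWIO, output_shape,
               ohwi_data.data(), hwio_data->data());
}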
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TOOLING_UTIL_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TOOLING_UTIL_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "google/protobuf/text_format.h" +#include "tensorflow/core/platform/logging.h" +#if TOCO_SUPPORT_PORTABLE_PROTOS +#include "third_party/protobuf/src/google/protobuf/text_format.h" +#endif // TOCO_SUPPORT_PORTABLE_PROTOS +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/model_flags.pb.h" +#include "tensorflow/contrib/lite/toco/runtime/types.h" +#include "tensorflow/contrib/lite/toco/toco_flags.pb.h" +#include "tensorflow/contrib/lite/toco/toco_port.h" + +// TODO(aselle): Replace with using a container specific hash override instead. +namespace std { +template <> +struct hash { + size_t operator()(const toco::OperatorType& op) const { + return std::hash()(static_cast(op)); + } +}; +} // namespace std + +namespace toco { + +constexpr int kLogLevelModelChanged = 1; +constexpr int kLogLevelModelUnchanged = 2; + +string LogName(const Operator& op); + +bool IsInputArray(const Model& model, const string& name); +bool IsArrayConsumed(const Model& model, const string& name); +int CountTrueOutputs(const Model& model, const Operator& op); + +int CountOpsWithInput(const Model& model, const string& array_name); +bool DeleteArrayIfUnused(const string& array_name, Model* model); + +std::vector>::const_iterator FindOpWithOutput( + const Model& model, const string& array_name); +Operator* GetOpWithOutput(const Model& model, const string& array_name); + +std::vector>::iterator FindOpWithOutput( + Model& model, const string& array_name); +Operator* GetOpWithOutput(const Model& model, const string& array_name); + +std::vector>::const_iterator FindOpWithInput( + const Model& model, const string& array_name); +Operator* GetOpWithInput(const Model& model, const string& array_name); +Operator* GetFirstOpWithInput(const Model& model, const string& array_name); + +std::vector>::const_iterator FindOp( + const Model& model, const Operator* op); +std::vector>::iterator FindOp(Model& model, + const Operator* op); + +const char* OperatorTypeName(OperatorType type); +string HelpfulOperatorTypeName(const Operator& op); + +void DumpGraphvizVideoFrame(const Model& model); +void LogDump(int log_level, const string& message, const Model& model); +void LogSummary(int log_level, const string& message, const Model& model); + +inline bool ParseFromStringOverload(const std::string& in, + TFLITE_PROTO_NS::Message* proto) { + return TFLITE_PROTO_NS::TextFormat::ParseFromString(in, proto); +} + +template +bool ParseFromStringEitherTextOrBinary(const std::string& input_file_contents, + Proto* proto) { + if (proto->ParseFromString(input_file_contents)) { + return true; + } + + if (ParseFromStringOverload(input_file_contents, proto)) { + return true; + } + + return false; +} + +// TODO(b/36075966): Clean up when dims superseded by array shape. 
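ParseFromStringEitherTextOrBinary above tries the binary wire format first and falls back to proto text format, so callers can accept either flavor of flags file. A hedged usage sketch (the file reading is ordinary iostreams; ModelFlags is the proto used throughout toco):

#include <fstream>
#include <sstream>
#include <string>

// Sketch: load a ModelFlags proto supplied either as a binary or a text file.
bool LoadModelFlagsFromFile(const std::string& path,
                            toco::ModelFlags* model_flags) {
  std::ifstream file(path, std::ios::binary);
  if (!file) return false;
  std::ostringstream contents;
  contents << file.rdbuf();
  return toco::ParseFromStringEitherTextOrBinary(contents.str(), model_flags);
}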
+void ExtendShape(Shape* shape, int new_shape_size);
+
+// TODO(b/36075966): Clean up when dims superseded by array shape.
+void UnextendShape(Shape* shape, int new_shape_size);
+
+// Checks (using CHECK) that all dimensions of 'shape' are at least 1.
+void CheckShapeDimensions(const Shape& shape);
+
+// Given two shapes with potentially different dimensionality and dimension
+// arrays d0 and d1. Without loss of generality, assume that shape0 may have
+// higher dimensionality (length(d0) >= length(d1)). Then shape0 and shape1
+// "agree up to broadcasting" if:
+// - When walking the d0 and d1 from back to front with indices i0, i1,
+//   d0[i0] == d1[i1] or d0[i0] == 1 or d1[i1] == 1, for each dimension until
+//   i1 == 0 (inclusive).
+bool ShapesAgreeUpToBroadcasting(const Shape& shape0, const Shape& shape1);
+
+// A stricter constraint than ShapesAgreeUpToBroadcasting().
+//
+// Given two shapes with potentially different dimensionality and dimension
+// arrays d0 and d1. Without loss of generality, assume that shape0 may have
+// higher dimensionality (length(d0) >= length(d1)). Then shape0 and shape1
+// "agree up to extending" if:
+// - When walking the d0 and d1 from back to front with indices i0, i1,
+//   d0[i0] == d1[i1] for each dimension until i1 == 0 (inclusive).
+// - For the remaining indices [0..i0), d0[i0] == 1.
+bool ShapesAgreeUpToExtending(const Shape& shape0, const Shape& shape1);
+
+bool IsArrayFullyConnectedWeights(const Model& model, const string& name);
+
+// If there is a wildcard dimension (-1), this may return a negative value.
+int RequiredBufferSizeForShape(const Shape& shape);
+
+bool IsConstantParameterArray(const Model& model, const string& name);
+
+void CheckNoMissingArray(const Model& model);
+void CheckInvariants(const Model& model);
+
+void CheckModelCounts(const Model& model);
+
+void FixOperatorOrdering(Model* model);
+void FixNoMissingArray(Model* model);
+void FixNoOrphanedArray(Model* model);
+
+void ResolveModelFlags(const ModelFlags& model_flags, Model* model);
+
+template <ArrayDataType A>
+void GetQuantizationParamsFromMinMax(const ModelFlags& model_flags,
+                                     const MinMax& minmax,
+                                     QuantizationParams* quantization_params) {
+  using Integer = DataType<A>;
+  const Integer qmin = std::numeric_limits<Integer>::min();
+  const Integer qmax = std::numeric_limits<Integer>::max();
+  const double qmin_double = qmin;
+  const double qmax_double = qmax;
+  const double rmin = minmax.min;
+  const double rmax = minmax.max;
+  // 0 should always be a representable value. Let's assume that the initial
+  // min,max range contains 0.
+  CHECK_LE(rmin, 0.);
+  CHECK_GE(rmax, 0.);
+  if (rmin == rmax) {
+    // Special case where the min,max range is a point. Should be {0}.
+    CHECK_EQ(rmin, 0.);
+    CHECK_EQ(rmax, 0.);
+    quantization_params->zero_point = 0;
+    quantization_params->scale = 0.;
+    return;
+  }
+
+  // General case.
+  //
+  // First determine the scale.
+  const double scale = (rmax - rmin) / (qmax_double - qmin_double);
+
+  // Zero-point computation.
+  // First the initial floating-point computation. The zero-point can be
+  // determined from solving an affine equation for any known pair
+  // (real value, corresponding quantized value).
+  // We know two such pairs: (rmin, qmin) and (rmax, qmax).
+  // The arithmetic error on the zero point computed from either pair
+  // will be roughly machine_epsilon * (sum of absolute values of terms)
+  // so we want to use the variant that adds the smaller terms.
+ const double zero_point_from_min = qmin_double - rmin / scale; + const double zero_point_from_max = qmax_double - rmax / scale; + const double zero_point_from_min_error = + std::abs(qmin_double) + std::abs(rmin / scale); + const double zero_point_from_max_error = + std::abs(qmax_double) + std::abs(rmax / scale); + + const double zero_point_double = + zero_point_from_min_error < zero_point_from_max_error + ? zero_point_from_min + : zero_point_from_max; + + // Now we need to nudge the zero point to be an integer + // (our zero points are integer, and this is motivated by the requirement + // to be able to represent the real value "0" exactly as a quantized value, + // which is required in multiple places, for example in Im2col with SAME + // padding). + Integer nudged_zero_point = 0; + if (zero_point_double < qmin_double) { + nudged_zero_point = qmin; + } else if (zero_point_double > qmax_double) { + nudged_zero_point = qmax; + } else { + nudged_zero_point = static_cast(std::round(zero_point_double)); + } + // The zero point should always be in the range of quantized value, + // [qmin, qmax]. + CHECK_GE(nudged_zero_point, qmin); + CHECK_LE(nudged_zero_point, qmax); + + // Finally, store the result nudged quantization params. + quantization_params->zero_point = nudged_zero_point; + quantization_params->scale = scale; +} + +void CheckIsReadyForQuantization(const Model& model); +void UseDefaultMinMaxRangeValues(Model* model, double default_ranges_min, + double default_ranges_max); + +inline int Offset(const Shape& shape, const std::vector& indices) { + DCHECK_EQ(shape.dimensions_count(), indices.size()); + const int dims_count = shape.dimensions_count(); + int offset = 0; + for (int i = 0; i < dims_count; i++) { + const int index = indices[i]; + DCHECK(index >= 0 && index < shape.dims(i)); + offset *= shape.dims(i); + offset += index; + } + return offset; +} + +inline std::vector ReverseOffset(const Shape& shape, int index) { + DCHECK_GE(index, 0); + DCHECK_LT(index, RequiredBufferSizeForShape(shape)); + const int dims_count = shape.dimensions_count(); + std::vector indices(dims_count); + int residual = index; + for (int i = dims_count - 1; i >= 0; i--) { + indices[i] = residual % shape.dims(i); + residual /= shape.dims(i); + } + return indices; +} + +int ElementSize(ArrayDataType data_type); + +void DropMinMax(Model* model, const string& array_name); + +bool IsAllocatableTransientArray(const Model& model, const string& array_name); + +void CreateOrCheckRnnStateArray(const string& name, int size, Model* model); + +string AvailableArrayName(const Model& model, const string& name); + +// Formats a shape as a string: [ dims(0), dims(1), ..., dims(num_dims-1) ]. +string ShapeToString(const Shape& shape); + +void PrintArrayShape(Model* model, const string& name); + +void MakeArrayDims(int num_dims, int batch, int height, int width, int depth, + std::vector* out_dims); + +bool EstimateArithmeticOpsCount(const Model& model, int64* result); + +int AxesCount(AxesOrder axes_order); + +void ShuffleDims(const Shape& input_shape, AxesOrder input_axes_order, + AxesOrder output_axes_order, Shape* output_shape); +void ShuffleArray(const Shape& input_shape, AxesOrder input_axes_order, + AxesOrder output_axes_order, const Shape& output_shape, + const float* input_data, float* output_data); + +// Returns true if it may be OK for any graph transformation to ever discard +// that array. 
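To make the derivation above concrete, here is the uint8 case worked through for a real-valued range of [-1, 1]; the arithmetic mirrors GetQuantizationParamsFromMinMax but uses plain doubles so it can run standalone:

#include <cmath>
#include <cstdio>

int main() {
  const double qmin = 0.0, qmax = 255.0;  // uint8 quantized range
  const double rmin = -1.0, rmax = 1.0;   // real range; must contain 0
  const double scale = (rmax - rmin) / (qmax - qmin);  // 2/255 ~= 0.00784
  // Both zero-point candidates agree here because the range is symmetric.
  const double zp_from_min = qmin - rmin / scale;  // 127.5
  const double zp_from_max = qmax - rmax / scale;  // 127.5
  const int nudged_zero_point = static_cast<int>(std::round(zp_from_min));
  std::printf("scale=%f zp_min=%f zp_max=%f nudged=%d\n", scale, zp_from_min,
              zp_from_max, nudged_zero_point);  // nudged == 128
  // Check: the real value 0.0 now maps exactly to the integer 128.
  return 0;
}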
The idea is that we can't ever discard arrays that are either +// an input or an output of the whole graph, or that appear in RNN back-edges, +// as that would undercut explicit flags that the user might pass. +bool IsDiscardableArray(const Model& model, const string& array_name); + +void CheckFinalDataTypesSatisfied(const Model& model); + +} // namespace toco + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TOOLING_UTIL_H_ diff --git a/tensorflow/contrib/lite/toco/tooling_util_test.cc b/tensorflow/contrib/lite/toco/tooling_util_test.cc new file mode 100644 index 0000000000..22955ce956 --- /dev/null +++ b/tensorflow/contrib/lite/toco/tooling_util_test.cc @@ -0,0 +1,96 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include + +#include +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" + +namespace toco { + +enum class Agreement { kBroadcast, kExtend, kBroadcastNotExtend, kNeither }; + +// A pair of Shapes and whether they should agree up to broadcasting, extending +// or neither. +struct ShapePair { + Shape left; + Shape right; + Agreement agreement; +}; + +std::vector CreateShapePairs() { + return std::vector( + {// These agree up to broadcast. + {Shape({3}), Shape({3}), Agreement::kBroadcast}, + {Shape({256, 256, 3}), Shape({256, 256, 3}), Agreement::kBroadcast}, + {Shape({256, 256, 3}), Shape({3}), Agreement::kBroadcast}, + {Shape({8, 1, 6, 1}), Shape({7, 1, 5}), Agreement::kBroadcast}, + + // These extend (and therefore broadcast). + {Shape({3}), Shape({3}), Agreement::kExtend}, + {Shape({256, 256, 3}), Shape({256, 256, 3}), Agreement::kExtend}, + {Shape({1, 1, 3}), Shape({1, 1, 3}), Agreement::kExtend}, + {Shape({1, 1, 3}), Shape({3}), Agreement::kExtend}, + {Shape({1, 1, 3}), Shape({1, 3}), Agreement::kExtend}, + + // These strictly broadcast and do not extend. + {Shape({256, 256, 3}), Shape({3}), Agreement::kBroadcastNotExtend}, + {Shape({5, 4}), Shape({1}), Agreement::kBroadcastNotExtend}, + {Shape({5, 4}), Shape({4}), Agreement::kBroadcastNotExtend}, + {Shape({15, 3, 5}), Shape({15, 1, 5}), Agreement::kBroadcastNotExtend}, + {Shape({15, 3, 5}), Shape({3, 5}), Agreement::kBroadcastNotExtend}, + {Shape({15, 3, 5}), Shape({3, 1}), Agreement::kBroadcastNotExtend}, + + // These do not broadcast (and therefore also do not extend). + {Shape({3}), Shape({4}), Agreement::kNeither}, + {Shape({2, 1}), Shape({8, 4, 3}), Agreement::kNeither}}); +} + +// ShapeTest is an empty parameterized test fixture since there is no state. 
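The pairs above pin down the difference between the two predicates declared in tooling_util.h. A couple of concrete checks, duplicating what the parameterized test below verifies (CHECK is the logging macro already used throughout these files):

#include "tensorflow/contrib/lite/toco/tooling_util.h"

// {256, 256, 3} vs {3} broadcasts (trailing dims match) but does not extend,
// because extending only allows leading 1s and 256 != 1.
// {1, 1, 3} vs {3} both broadcasts and extends.
void ShapeAgreementExamples() {
  using toco::Shape;
  CHECK(toco::ShapesAgreeUpToBroadcasting(Shape({256, 256, 3}), Shape({3})));
  CHECK(!toco::ShapesAgreeUpToExtending(Shape({256, 256, 3}), Shape({3})));
  CHECK(toco::ShapesAgreeUpToExtending(Shape({1, 1, 3}), Shape({3})));
}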
+class ShapeTest : public ::testing::TestWithParam {}; + +TEST_P(ShapeTest, Agrees) { + const ShapePair& param = GetParam(); + + switch (param.agreement) { + case Agreement::kBroadcast: { + EXPECT_TRUE(ShapesAgreeUpToBroadcasting(param.left, param.right)); + break; + } + case Agreement::kExtend: { + EXPECT_TRUE(ShapesAgreeUpToExtending(param.left, param.right)); + // Anything that extends should also broadcast. + EXPECT_TRUE(ShapesAgreeUpToBroadcasting(param.left, param.right)); + break; + } + case Agreement::kBroadcastNotExtend: { + // Verify that it strictly broadcasts but does not extend. + EXPECT_TRUE(ShapesAgreeUpToBroadcasting(param.left, param.right)); + EXPECT_FALSE(ShapesAgreeUpToExtending(param.left, param.right)); + break; + } + case Agreement::kNeither: { + EXPECT_FALSE(ShapesAgreeUpToExtending(param.left, param.right)); + EXPECT_FALSE(ShapesAgreeUpToBroadcasting(param.left, param.right)); + break; + } + } +} + +INSTANTIATE_TEST_CASE_P(AgreeBroadcast, ShapeTest, + ::testing::ValuesIn(CreateShapePairs())); + +} // namespace toco diff --git a/tensorflow/contrib/lite/tools/BUILD b/tensorflow/contrib/lite/tools/BUILD new file mode 100644 index 0000000000..2d918fd4e8 --- /dev/null +++ b/tensorflow/contrib/lite/tools/BUILD @@ -0,0 +1,60 @@ +package(default_visibility = [ + "//visibility:public", +]) + +licenses(["notice"]) # Apache 2.0 + +cc_binary( + name = "generate_op_registrations", + srcs = ["gen_op_registration_main.cc"], + deps = [ + "//tensorflow/contrib/lite/tools:gen_op_registration", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + ], +) + +cc_library( + name = "gen_op_registration", + srcs = ["gen_op_registration.cc"], + hdrs = ["gen_op_registration.h"], + deps = [ + "//tensorflow/contrib/lite:framework", + "@com_googlesource_code_re2//:re2", + ], +) + +cc_test( + name = "gen_op_registration_test", + srcs = ["gen_op_registration_test.cc"], + data = [ + "//tensorflow/contrib/lite:testdata/0_subgraphs.bin", + "//tensorflow/contrib/lite:testdata/2_subgraphs.bin", + "//tensorflow/contrib/lite:testdata/empty_model.bin", + "//tensorflow/contrib/lite:testdata/test_model.bin", + "//tensorflow/contrib/lite:testdata/test_model_broken.bin", + ], + deps = [ + ":gen_op_registration", + "@com_google_googletest//:gtest", + ], +) + +cc_library( + name = "mutable_op_resolver", + srcs = ["mutable_op_resolver.cc"], + hdrs = ["mutable_op_resolver.h"], + deps = ["//tensorflow/contrib/lite:framework"], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/contrib/lite/tools/gen_op_registration.cc b/tensorflow/contrib/lite/tools/gen_op_registration.cc new file mode 100644 index 0000000000..57c2567e3b --- /dev/null +++ b/tensorflow/contrib/lite/tools/gen_op_registration.cc @@ -0,0 +1,46 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include +#include + +#include "third_party/re2/re2.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { + +string NormalizeCustomOpName(const string& op) { + string method(op); + RE2::GlobalReplace(&method, "([a-z])([A-Z])", "\\1_\\2"); + std::transform(method.begin(), method.end(), method.begin(), ::toupper); + return method; +} + +void ReadOpsFromModel(const ::tflite::Model* model, + std::vector* builtin_ops, + std::vector* custom_ops) { + if (!model) return; + auto opcodes = model->operator_codes(); + if (!opcodes) return; + for (const auto* opcode : *opcodes) { + if (opcode->builtin_code() != ::tflite::BuiltinOperator_CUSTOM) { + builtin_ops->push_back( + tflite::EnumNameBuiltinOperator(opcode->builtin_code())); + } else { + custom_ops->push_back(opcode->custom_code()->c_str()); + } + } +} + +} // namespace tflite diff --git a/tensorflow/contrib/lite/tools/gen_op_registration.h b/tensorflow/contrib/lite/tools/gen_op_registration.h new file mode 100644 index 0000000000..363bb2335c --- /dev/null +++ b/tensorflow/contrib/lite/tools/gen_op_registration.h @@ -0,0 +1,38 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOOLS_GEN_OP_REGISTRATION_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOOLS_GEN_OP_REGISTRATION_H_ + +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { + +// Convert the custom op name to registration name following the convention. +// Example: +// "custom_op" -> "CUSTOM_OP" +// "CustomOp" -> "CUSTOM_OP" +// Note "Register_" suffix will be added later in the tool. +string NormalizeCustomOpName(const string& op); + +// Read ops from the TFLite model. +// Enum name of builtin ops will be stored, such as "CONV_2D". +// Custom op name will be stored as it is. +void ReadOpsFromModel(const ::tflite::Model* model, + std::vector* builtin_ops, + std::vector* custom_ops); + +} // namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOOLS_GEN_OP_REGISTRATION_H_ diff --git a/tensorflow/contrib/lite/tools/gen_op_registration_main.cc b/tensorflow/contrib/lite/tools/gen_op_registration_main.cc new file mode 100644 index 0000000000..7b27066a21 --- /dev/null +++ b/tensorflow/contrib/lite/tools/gen_op_registration_main.cc @@ -0,0 +1,98 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
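NormalizeCustomOpName above inserts an underscore at every lowercase-to-uppercase boundary and then upper-cases the result. A self-contained sketch of the same transformation without RE2, matching the examples in the header comment (a sketch only; the shipped code keeps the regex):

#include <cctype>
#include <string>

// "CustomOp" -> "CUSTOM_OP", "custom_op" -> "CUSTOM_OP", "customop" -> "CUSTOMOP".
std::string NormalizeCustomOpNameSketch(const std::string& op) {
  std::string out;
  for (size_t i = 0; i < op.size(); ++i) {
    // Underscore at each lower-case -> upper-case boundary, like the regex.
    if (i > 0 && std::islower(static_cast<unsigned char>(op[i - 1])) &&
        std::isupper(static_cast<unsigned char>(op[i]))) {
      out.push_back('_');
    }
    out.push_back(
        static_cast<char>(std::toupper(static_cast<unsigned char>(op[i]))));
  }
  return out;
}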
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/tools/gen_op_registration.h" +#include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/util/command_line_flags.h" + +using tensorflow::Flag; +using tensorflow::Flags; + +namespace { + +void GenerateFileContent(const string& filename, + const std::vector& builtin_ops, + const std::vector& custom_ops) { + std::ofstream fout(filename); + + fout << "#include " + "\"third_party/tensorflow/contrib/lite/model.h\"\n"; + fout << "#include " + "\"third_party/tensorflow/contrib/lite/tools/mutable_op_resolver.h\"\n"; + fout << "namespace tflite {\n"; + fout << "namespace ops {\n"; + if (!builtin_ops.empty()) { + fout << "namespace builtin {\n"; + fout << "// Forward-declarations for the builtin ops.\n"; + for (const auto& op : builtin_ops) { + fout << "TfLiteRegistration* Register_" << op << "();\n"; + } + fout << "} // namespace builtin\n"; + } + + if (!custom_ops.empty()) { + fout << "namespace custom {\n"; + fout << "// Forward-declarations for the custom ops.\n"; + for (const auto& op : custom_ops) { + fout << "TfLiteRegistration* Register_" + << ::tflite::NormalizeCustomOpName(op) << "();\n"; + } + fout << "} // namespace custom\n"; + } + fout << "} // namespace ops\n"; + fout << "} // namespace tflite\n"; + + fout << "void RegisterSelectedOps(::tflite::MutableOpResolver* resolver) {\n"; + for (const auto& op : builtin_ops) { + fout << " resolver->AddBuiltin(::tflite::BuiltinOperator_" << op + << ", ::tflite::ops::builtin::Register_" << op << "());\n"; + } + for (const auto& op : custom_ops) { + fout << " resolver->AddCustom(\"" << op + << "\", ::tflite::ops::custom::Register_" + << ::tflite::NormalizeCustomOpName(op) << "());\n"; + } + fout << "}\n"; + fout.close(); +} +} // namespace + +int main(int argc, char** argv) { + string input_model; + string output_registration; + std::vector flag_list = { + Flag("input_model", &input_model, "path to the tflite model"), + Flag("output_registration", &output_registration, + "filename for generated registration code"), + }; + Flags::Parse(&argc, argv, flag_list); + + tensorflow::port::InitMain(argv[0], &argc, &argv); + std::vector builtin_ops; + std::vector custom_ops; + + std::ifstream fin(input_model); + std::stringstream content; + content << fin.rdbuf(); + const ::tflite::Model* model = ::tflite::GetModel(content.str().data()); + ::tflite::ReadOpsFromModel(model, &builtin_ops, &custom_ops); + GenerateFileContent(output_registration, builtin_ops, custom_ops); + return 0; +} diff --git a/tensorflow/contrib/lite/tools/gen_op_registration_test.cc b/tensorflow/contrib/lite/tools/gen_op_registration_test.cc new file mode 100644 index 0000000000..c65cffe340 --- /dev/null +++ b/tensorflow/contrib/lite/tools/gen_op_registration_test.cc @@ -0,0 +1,87 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
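For a model containing only the builtin CONV_2D plus one custom op named "CustomOp", GenerateFileContent above would emit roughly the following (reconstructed from the fout statements; the model contents are hypothetical and only for illustration):

#include "third_party/tensorflow/contrib/lite/model.h"
#include "third_party/tensorflow/contrib/lite/tools/mutable_op_resolver.h"
namespace tflite {
namespace ops {
namespace builtin {
// Forward-declarations for the builtin ops.
TfLiteRegistration* Register_CONV_2D();
}  // namespace builtin
namespace custom {
// Forward-declarations for the custom ops.
TfLiteRegistration* Register_CUSTOM_OP();
}  // namespace custom
}  // namespace ops
}  // namespace tflite
void RegisterSelectedOps(::tflite::MutableOpResolver* resolver) {
  resolver->AddBuiltin(::tflite::BuiltinOperator_CONV_2D, ::tflite::ops::builtin::Register_CONV_2D());
  resolver->AddCustom("CustomOp", ::tflite::ops::custom::Register_CUSTOM_OP());
}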
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/tools/gen_op_registration.h" +#include +#include + +using ::testing::ElementsAreArray; + +namespace tflite { + +class GenOpRegistrationTest : public ::testing::Test { + protected: + GenOpRegistrationTest() {} + + void ReadOps(const string& model_path) { + auto model = FlatBufferModel::BuildFromFile(model_path.data()); + if (model) { + ReadOpsFromModel(model->GetModel(), &builtin_ops_, &custom_ops_); + } + } + + std::vector builtin_ops_; + std::vector custom_ops_; +}; + +TEST_F(GenOpRegistrationTest, TestNonExistantFiles) { + ReadOps("/tmp/tflite_model_1234"); + EXPECT_EQ(builtin_ops_.size(), 0); + EXPECT_EQ(custom_ops_.size(), 0); +} + +TEST_F(GenOpRegistrationTest, TestModels) { + ReadOps("third_party/tensorflow/contrib/lite/testdata/test_model.bin"); + EXPECT_THAT(builtin_ops_, ElementsAreArray({"CONV_2D"})); + EXPECT_THAT(custom_ops_, ElementsAreArray({"testing_op"})); +} + +TEST_F(GenOpRegistrationTest, TestEmptyModels) { + ReadOps("third_party/tensorflow/contrib/lite/testdata/empty_model.bin"); + EXPECT_EQ(builtin_ops_.size(), 0); + EXPECT_EQ(custom_ops_.size(), 0); +} + +TEST_F(GenOpRegistrationTest, TestZeroSubgraphs) { + ReadOps("third_party/tensorflow/contrib/lite/testdata/0_subgraphs.bin"); + EXPECT_EQ(builtin_ops_.size(), 0); + EXPECT_EQ(custom_ops_.size(), 0); +} + +TEST_F(GenOpRegistrationTest, TestBrokenMmap) { + ReadOps("third_party/tensorflow/contrib/lite/testdata/test_model_broken.bin"); + EXPECT_EQ(builtin_ops_.size(), 0); + EXPECT_EQ(custom_ops_.size(), 0); +} + +TEST_F(GenOpRegistrationTest, TestNormalizeCustomOpName) { + std::vector> testcase = { + {"CustomOp", "CUSTOM_OP"}, + {"a", "A"}, + {"custom_op", "CUSTOM_OP"}, + {"customop", "CUSTOMOP"}, + }; + + for (const auto& test : testcase) { + EXPECT_EQ(NormalizeCustomOpName(test.first), test.second); + } +} +} // namespace tflite + +int main(int argc, char** argv) { + // On Linux, add: FLAGS_logtostderr = true; + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/tools/mutable_op_resolver.cc b/tensorflow/contrib/lite/tools/mutable_op_resolver.cc new file mode 100644 index 0000000000..8a921d7c5a --- /dev/null +++ b/tensorflow/contrib/lite/tools/mutable_op_resolver.cc @@ -0,0 +1,43 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/tools/mutable_op_resolver.h" + +namespace tflite { + +TfLiteRegistration* MutableOpResolver::FindOp( + tflite::BuiltinOperator op) const { + auto it = builtins_.find(op); + return it != builtins_.end() ? it->second : nullptr; +} + +TfLiteRegistration* MutableOpResolver::FindOp(const char* op) const { + auto it = custom_ops_.find(op); + return it != custom_ops_.end() ? 
it->second : nullptr;
+}
+
+void MutableOpResolver::AddBuiltin(tflite::BuiltinOperator op,
+                                   TfLiteRegistration* registration) {
+  registration->builtin_code = op;
+  builtins_.insert(std::make_pair(op, registration));
+}
+
+void MutableOpResolver::AddCustom(const char* name,
+                                  TfLiteRegistration* registration) {
+  registration->builtin_code = BuiltinOperator_CUSTOM;
+  custom_ops_.insert(std::make_pair(std::string(name), registration));
+}
+
+}  // namespace tflite
diff --git a/tensorflow/contrib/lite/tools/mutable_op_resolver.h b/tensorflow/contrib/lite/tools/mutable_op_resolver.h
new file mode 100644
index 0000000000..9546c32427
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/mutable_op_resolver.h
@@ -0,0 +1,45 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOOLS_MUTABLE_OP_RESOLVER_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOOLS_MUTABLE_OP_RESOLVER_H_
+
+#include <unordered_map>
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+
+// An OpResolver that is mutable, also used as the op in gen_op_registration.
+// A typical usage:
+//   MutableOpResolver resolver;
+//   resolver.AddBuiltin(BuiltinOperator_ADD, Register_ADD());
+//   resolver.AddCustom("CustomOp", Register_CUSTOM_OP());
+//   InterpreterBuilder(model, resolver)(&interpreter);
+class MutableOpResolver : public OpResolver {
+ public:
+  MutableOpResolver() {}
+  TfLiteRegistration* FindOp(tflite::BuiltinOperator op) const override;
+  TfLiteRegistration* FindOp(const char* op) const override;
+  void AddBuiltin(tflite::BuiltinOperator op, TfLiteRegistration* registration);
+  void AddCustom(const char* name, TfLiteRegistration* registration);
+
+ private:
+  std::unordered_map<int, TfLiteRegistration*> builtins_;
+  std::unordered_map<std::string, TfLiteRegistration*> custom_ops_;
+};
+
+}  // namespace tflite
+
+#endif  // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOOLS_MUTABLE_OP_RESOLVER_H_
diff --git a/tensorflow/contrib/lite/version.h b/tensorflow/contrib/lite/version.h
new file mode 100644
index 0000000000..a751afabe7
--- /dev/null
+++ b/tensorflow/contrib/lite/version.h
@@ -0,0 +1,23 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
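Tying the tool output to the resolver: a client that links the generated registration file can build an interpreter that only knows about the selected ops. A hedged sketch (RegisterSelectedOps is the symbol emitted by generate_op_registrations; error handling omitted):

#include <memory>

#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/model.h"
#include "tensorflow/contrib/lite/tools/mutable_op_resolver.h"

// Defined in the file produced by generate_op_registrations.
void RegisterSelectedOps(::tflite::MutableOpResolver* resolver);

std::unique_ptr<tflite::Interpreter> BuildWithSelectedOps(
    const tflite::FlatBufferModel& model) {
  tflite::MutableOpResolver resolver;
  RegisterSelectedOps(&resolver);
  std::unique_ptr<tflite::Interpreter> interpreter;
  tflite::InterpreterBuilder(model, resolver)(&interpreter);
  return interpreter;
}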
+==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_VERSION_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_VERSION_H_ + +// The version number of the Schema. Ideally all changes will be backward +// compatible. If that ever changes, we must ensure that version is the first +// entry in the new tflite root so that we can see that version is not 1. +#define TFLITE_SCHEMA_VERSION (3) + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_VERSION_H_ diff --git a/tensorflow/tools/ci_build/ci_sanity.sh b/tensorflow/tools/ci_build/ci_sanity.sh index 8d4e4c23dc..cff672c9df 100755 --- a/tensorflow/tools/ci_build/ci_sanity.sh +++ b/tensorflow/tools/ci_build/ci_sanity.sh @@ -401,9 +401,14 @@ cmd_status(){ } # Run bazel build --nobuild to test the validity of the BUILD files +# TODO(mikecase): Remove TF Lite exclusion from this list. Exclusion is +# necessary since the @androidsdk WORKSPACE dependency is commented +# commented out by default in TF WORKSPACE file. do_bazel_nobuild() { BUILD_TARGET="//tensorflow/..." - BUILD_CMD="bazel build --nobuild ${BAZEL_FLAGS} ${BUILD_TARGET}" + BUILD_TARGET="${BUILD_TARGET} -//tensorflow/contrib/lite/java/demo/app/src/main/..." + BUILD_TARGET="${BUILD_TARGET} -//tensorflow/contrib/lite/schema/..." + BUILD_CMD="bazel build --nobuild ${BAZEL_FLAGS} -- ${BUILD_TARGET}" ${BUILD_CMD} diff --git a/tensorflow/tools/ci_build/linux/cpu/run_py3_contrib.sh b/tensorflow/tools/ci_build/linux/cpu/run_py3_contrib.sh index 5de5a379ac..df6016504c 100755 --- a/tensorflow/tools/ci_build/linux/cpu/run_py3_contrib.sh +++ b/tensorflow/tools/ci_build/linux/cpu/run_py3_contrib.sh @@ -33,4 +33,35 @@ yes "" | $PYTHON_BIN_PATH configure.py bazel test --test_tag_filters=-no_oss,-oss_serial,-gpu,-benchmark-test -k \ --jobs=${N_JOBS} --test_timeout 300,450,1200,3600 \ --test_output=errors -- \ - //tensorflow/contrib/... + //tensorflow/contrib/... \ + -//tensorflow/contrib/lite/... 
\ + //tensorflow/contrib/lite:context_test \ + //tensorflow/contrib/lite:framework \ + //tensorflow/contrib/lite:interpreter_test \ + //tensorflow/contrib/lite:model_test \ + //tensorflow/contrib/lite/toco:toco \ + //tensorflow/contrib/lite:simple_memory_arena_test \ + //tensorflow/contrib/lite:string_util_test \ + //tensorflow/contrib/lite/kernels:activations_test \ + //tensorflow/contrib/lite/kernels:add_test \ + //tensorflow/contrib/lite/kernels:basic_rnn_test \ + //tensorflow/contrib/lite/kernels:concatenation_test \ + //tensorflow/contrib/lite/kernels:conv_test \ + //tensorflow/contrib/lite/kernels:depthwise_conv_test \ + //tensorflow/contrib/lite/kernels:embedding_lookup_test \ + //tensorflow/contrib/lite/kernels:embedding_lookup_sparse_test \ + //tensorflow/contrib/lite/kernels:fully_connected_test \ + //tensorflow/contrib/lite/testing:generated_examples_zip_test \ + //tensorflow/contrib/lite/kernels:hashtable_lookup_test \ + //tensorflow/contrib/lite/kernels:local_response_norm_test \ + //tensorflow/contrib/lite/kernels:lsh_projection_test \ + //tensorflow/contrib/lite/kernels:lstm_test \ + //tensorflow/contrib/lite/kernels:l2norm_test \ + //tensorflow/contrib/lite/kernels:mul_test \ + //tensorflow/contrib/lite/kernels:pooling_test \ + //tensorflow/contrib/lite/kernels:reshape_test \ + //tensorflow/contrib/lite/kernels:resize_bilinear_test \ + //tensorflow/contrib/lite/kernels:skip_gram_test \ + //tensorflow/contrib/lite/kernels:softmax_test \ + //tensorflow/contrib/lite/kernels:space_to_depth_test \ + //tensorflow/contrib/lite/kernels:svdf_test diff --git a/tensorflow/tools/ci_build/osx/cpu/run_contrib.sh b/tensorflow/tools/ci_build/osx/cpu/run_contrib.sh index 8042522ef8..ddaaddc917 100755 --- a/tensorflow/tools/ci_build/osx/cpu/run_contrib.sh +++ b/tensorflow/tools/ci_build/osx/cpu/run_contrib.sh @@ -34,4 +34,4 @@ bazel test --test_tag_filters=-no_oss,-gpu,-benchmark-test,-nomac \ --test_timeout 300,450,1200,3600 \ --test_size_filters=small,medium \ --jobs=${N_JOBS} --build_tests_only --test_output=errors -k -- \ - //tensorflow/contrib/... + //tensorflow/contrib/... -//tensorflow/contrib/lite/... 
diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD index c6e577223f..a3ab40ceef 100644 --- a/tensorflow/tools/pip_package/BUILD +++ b/tensorflow/tools/pip_package/BUILD @@ -158,6 +158,9 @@ sh_binary( "//tensorflow/contrib/graph_editor:graph_editor_pip", "//tensorflow/contrib/keras:keras", "//tensorflow/contrib/labeled_tensor:labeled_tensor_pip", + "//tensorflow/contrib/lite/toco:toco", + "//tensorflow/contrib/lite/toco/python:toco_wrapper", + "//tensorflow/contrib/lite/toco/python:toco_from_protos", "//tensorflow/contrib/ndlstm:ndlstm", "//tensorflow/contrib/nn:nn_py", "//tensorflow/contrib/predictor:predictor_pip", diff --git a/tensorflow/tools/pip_package/MANIFEST.in b/tensorflow/tools/pip_package/MANIFEST.in index ef6cf56421..86c5e4776d 100644 --- a/tensorflow/tools/pip_package/MANIFEST.in +++ b/tensorflow/tools/pip_package/MANIFEST.in @@ -4,6 +4,7 @@ recursive-include * *.so recursive-include * *.dll recursive-include * *.lib recursive-include * *.csv +recursive-include tensorflow/aux-bin * recursive-include tensorflow/include/tensorflow *.h recursive-include tensorflow/include/Eigen * recursive-include tensorflow/include/external * diff --git a/tensorflow/tools/pip_package/build_pip_package.sh b/tensorflow/tools/pip_package/build_pip_package.sh index cbf06a97d0..8249703ba7 100755 --- a/tensorflow/tools/pip_package/build_pip_package.sh +++ b/tensorflow/tools/pip_package/build_pip_package.sh @@ -137,6 +137,9 @@ function main() { fi fi fi + # Install toco as a binary in aux-bin. + mkdir "${TMPDIR}/tensorflow/aux-bin" + cp bazel-bin/tensorflow/contrib/lite/toco/toco ${TMPDIR}/tensorflow/aux-bin/ fi # protobuf pip package doesn't ship with header files. Copy the headers diff --git a/tensorflow/tools/pip_package/setup.py b/tensorflow/tools/pip_package/setup.py index 456c2e2908..60282f6aa3 100644 --- a/tensorflow/tools/pip_package/setup.py +++ b/tensorflow/tools/pip_package/setup.py @@ -69,6 +69,8 @@ if sys.version_info < (3, 4): # pylint: disable=line-too-long CONSOLE_SCRIPTS = [ 'freeze_graph = tensorflow.python.tools.freeze_graph:main', + 'toco_from_protos = tensorflow.contrib.lite.toco.python.toco_from_protos:main', + 'toco = tensorflow.contrib.lite.toco.python.toco_wrapper:main', 'saved_model_cli = tensorflow.python.tools.saved_model_cli:main', # We need to keep the TensorBoard command, even though the console script # is now declared by the tensorboard pip package. If we remove the @@ -188,7 +190,6 @@ headers = (list(find_files('*.h', 'tensorflow/core')) + list(find_files('*', 'external/eigen_archive')) + list(find_files('*.h', 'external/nsync/public'))) - setup( name=project_name, version=_VERSION.replace('-', ''), diff --git a/third_party/flatbuffers/flatbuffers.BUILD b/third_party/flatbuffers/flatbuffers.BUILD index a426db0c50..e1563103c8 100644 --- a/third_party/flatbuffers/flatbuffers.BUILD +++ b/third_party/flatbuffers/flatbuffers.BUILD @@ -104,6 +104,10 @@ cc_binary( "grpc/", "include/", ], + linkopts = [ + "-lm", + "-ldl", + ], deps = [ ":flatc_library", ], -- cgit v1.2.3